
PyTorch - Using CUDA

Blazing fast!


Related APIs

  1. torch.cuda.is_available() # whether a GPU is available
  2. torch.cuda.device_count() # number of available GPUs
  3. torch.cuda.get_device_name() # get the device name
  4. torch.device() # select the device to use (see the sketch below)
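A minimal sketch of probing the CUDA environment with these calls and picking a device (the device name in the comment is only an example; the actual output depends on your hardware):

import torch

print(torch.cuda.is_available())           # True if a usable CUDA GPU is present
print(torch.cuda.device_count())           # e.g. 1
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))   # e.g. 'NVIDIA GeForce RTX 3090'

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)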

Example demo

Note: every tensor and the model must be moved to the GPU explicitly by calling to(device); otherwise a device-mismatch error is raised.
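For example (a minimal sketch, assuming a GPU is available), feeding a CPU tensor to a layer whose parameters live on the GPU raises a RuntimeError, while moving the tensor first works:

import torch
from torch import nn

device = torch.device('cuda:0')
layer = nn.Linear(4, 2).to(device)   # parameters live on the GPU

x_cpu = torch.randn(1, 4)            # still on the CPU
try:
    layer(x_cpu)                     # device mismatch
except RuntimeError as e:
    print(e)                         # complains that tensors are on different devices

out = layer(x_cpu.to(device))        # fine once the input is moved to the GPU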


import torch
from torch import nn

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # use the GPU if available

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device) # run the model on the GPU
print(model)

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device) # move the tensors to the GPU

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
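To drive this train function, something like the following would work (a minimal sketch; the synthetic dataset, loss, and optimizer here are assumptions for illustration and stand in for a real 28x28 image dataset such as FashionMNIST):

from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in data: 1024 "images" of shape 28x28 with labels in [0, 10).
X = torch.randn(1024, 28, 28)
y = torch.randint(0, 10, (1024,))
dataloader = DataLoader(TensorDataset(X, y), batch_size=64)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for epoch in range(5):
    print(f"Epoch {epoch + 1}")
    train(dataloader, model, loss_fn, optimizer)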

Related resources

PyTorch CUDA official documentation
