# PyTorch example: device selection, a small feed-forward classifier, and a training loop.
# Select the computation device: the first CUDA GPU when available, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
class NeuralNetwork(nn.Module):
    """Feed-forward classifier for 28x28 single-channel images.

    Each image is flattened to a 784-vector and passed through a
    three-layer MLP (two hidden layers of width 512 with ReLU) that
    emits 10 unnormalized class scores (logits).
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        # Attribute names are kept as-is: they define the state_dict keys.
        layers = [
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits for a batch of images ``x``."""
        flat = self.flatten(x)
        return self.linear_relu_stack(flat)
# Build the network, move its parameters to the chosen device,
# and display the layer structure.
model = NeuralNetwork().to(device)
print(model)
def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader``.

    For each mini-batch: move the data to the module-level ``device``,
    compute predictions and the loss, backpropagate, and take an
    optimizer step. Progress is printed every 100 batches.
    """
    size = len(dataloader.dataset)
    model.train()  # enable training-mode behavior (dropout, batch-norm updates)
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(device), targets.to(device)

        # Forward pass and loss.
        logits = model(inputs)
        loss = loss_fn(logits, targets)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Periodic progress report.
        if batch_idx % 100 == 0:
            loss_val = loss.item()
            current = batch_idx * len(inputs)
            print(f"loss: {loss_val:>7f} [{current:>5d}/{size:>5d}]")
# (Blog footer from the scraped page: "The comment system is not enabled; commenting is unavailable.")