1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
def get_dataloader_workers():
    """Return the ``num_workers`` value passed to every ``DataLoader`` built here."""
    num_workers = 4
    return num_workers


def load_data_fashion_mnist(batch_size, resize=None):
    """Build the Fashion-MNIST training and test data loaders.

    The dataset is stored under ``./data`` and downloaded when missing.

    Args:
        batch_size: Batch size handed to both loaders.
        resize: Optional target size. When truthy, a ``transforms.Resize``
            step is applied before the tensor conversion.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Only the training loader
        shuffles its samples.
    """
    # Resize must run before ToTensor, so it goes first in the pipeline.
    if resize:
        steps = [transforms.Resize(resize), transforms.ToTensor()]
    else:
        steps = [transforms.ToTensor()]
    pipeline = transforms.Compose(steps)

    train_set, test_set = (
        torchvision.datasets.FashionMNIST(
            root="./data", train=is_train, transform=pipeline, download=True)
        for is_train in (True, False)
    )

    train_loader = data.DataLoader(
        train_set, batch_size, shuffle=True,
        num_workers=get_dataloader_workers())
    test_loader = data.DataLoader(
        test_set, batch_size, shuffle=False,
        num_workers=get_dataloader_workers())
    return (train_loader, test_loader)
class Accumulator:
    """Keep running sums of ``n`` quantities."""

    def __init__(self, n):
        """Start ``n`` running totals, each at ``0.0``."""
        self.data = [0.0 for _ in range(n)]

    def add(self, *args):
        """Add each positional value (converted to ``float``) to its total.

        Values are paired with totals by position using ``zip``.
        """
        updated = []
        for total, increment in zip(self.data, args):
            updated.append(total + float(increment))
        self.data = updated

    def reset(self):
        """Set every running total back to ``0.0``."""
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, idx):
        """Return the running total stored at position ``idx``."""
        return self.data[idx]


def accuracy(y_hat, y):
    """Count the correct predictions.

    When ``y_hat`` has more than one dimension and more than one column it
    is reduced with ``argmax`` along axis 1 before comparison with ``y``.

    Returns:
        The number of matching entries, as a ``float``.
    """
    has_class_columns = len(y_hat.shape) > 1 and y_hat.shape[1] > 1
    predicted = y_hat.argmax(axis=1) if has_class_columns else y_hat
    # Cast to the label dtype so the equality test compares like with like.
    matches = predicted.type(y.dtype) == y
    num_correct = matches.type(y.dtype).sum()
    return float(num_correct)


def evaluate_accuracy(net, data_iter, device):
    """Compute the accuracy of ``net`` over the batches of ``data_iter``.

    Args:
        net: Callable model. If it is a ``torch.nn.Module`` it is switched
            to evaluation mode first.
        data_iter: Iterable yielding ``(X, y)`` batches.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Correct predictions divided by the total number of labels.
    """
    if isinstance(net, torch.nn.Module):
        net.eval()
    totals = Accumulator(2)  # (correct predictions, labels seen)
    with torch.no_grad():
        for features, labels in data_iter:
            features, labels = features.to(device), labels.to(device)
            batch_correct = accuracy(net(features), labels)
            totals.add(batch_correct, labels.numel())
    num_correct, num_seen = totals[0], totals[1]
    return num_correct / num_seen
def train_ch6(net, trian_iter, test_iter, num_epochs, lr, device):
    """Train ``net`` with minibatch SGD and cross-entropy loss.

    ``nn.Linear`` and ``nn.Conv2d`` layers are re-initialised with Xavier
    uniform weights. After every epoch the average training loss, the
    training accuracy and the test accuracy are printed.

    Args:
        net: The ``nn.Module`` to train.
        trian_iter: Iterable yielding ``(X, y)`` training batches. The
            misspelled name is kept so existing keyword callers still work.
        test_iter: Iterable yielding ``(X, y)`` batches used for evaluation.
        num_epochs: Number of passes over the training data.
        lr: Learning rate for ``torch.optim.SGD``.
        device: Device the model and every batch are moved to.
    """
    # The body used to read an undefined (or unrelated global) ``train_iter``;
    # bind the real argument to the correctly spelled local name instead.
    train_iter = trian_iter

    def init_weights(m):
        # Exact type match on purpose: subclasses of these layers are skipped.
        if type(m) in (nn.Linear, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight)

    net.apply(init_weights)
    print('training on', device)

    # Move the model once, before the optimizer is constructed from its
    # parameters, instead of repeating the move at the start of every epoch.
    net = net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()

    for _ in range(num_epochs):
        # Running sums: (summed training loss, correct predictions, examples).
        metric = Accumulator(3)
        # evaluate_accuracy switches the model to eval mode, so restore
        # training mode at the start of each epoch.
        net.train()
        # Wrap the loader itself so tqdm can read its length for the bar.
        for X, y in tqdm(train_iter):
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            batch_loss = loss(y_hat, y)
            batch_loss.backward()
            optimizer.step()
            with torch.no_grad():
                # The loss is a batch mean; scale by batch size to get a sum.
                metric.add(batch_loss * X.shape[0], accuracy(y_hat, y),
                           X.shape[0])
        train_l = metric[0] / metric[2]
        train_acc = metric[1] / metric[2]
        test_acc = evaluate_accuracy(net, test_iter, device)
        print(f'loss {train_l:.3f}, train acc {train_acc:.3f},'
              f'test acc {test_acc:.3f}')
|