Source code:
import torch
import torchvision as tv
from torch.utils import data
import matplotlib.pyplot as plt
import time

def get_fashion_mnist_labels(labels):
    text_labels = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                   'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    return [text_labels[int(i)] for i in labels]

def show_fashion_mnist(imgs, num_rows, num_cols, titles=None, scale=0.5):
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axis('off')
        if titles:
            ax.set_title(titles[i])
    plt.show()
    return axes

def get_dataloader_workers():  #@save
    """Use 4 worker processes to read the data."""
    return 4

def load_data_fashion_mnist(batch_size, resize=None):
    trans = [tv.transforms.ToTensor()]  # transform that converts images to tensors
    if resize:
        trans.insert(0, tv.transforms.Resize(resize))
    trans = tv.transforms.Compose(trans)
    # Load the FashionMNIST training and test sets and apply the transform
    mnist_train = tv.datasets.FashionMNIST(root='./data', train=True, download=True, transform=trans)
    mnist_test = tv.datasets.FashionMNIST(root='./data', train=False, download=True, transform=trans)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=get_dataloader_workers()))

def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True)
    return X_exp / partition

def net(X):
    # W and b are globals defined in the __main__ block below
    return softmax(torch.matmul(X.reshape(-1, W.shape[0]), W) + b)

def cross_entropy(y_hat, y):
    return -torch.log(y_hat[range(len(y_hat)), y])

def accuracy(y_hat, y):
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())

def evaluate_accuracy(net, data_iter):
    if isinstance(net, torch.nn.Module):
        net.eval()  # evaluation mode; this disables dropout
    metric = Accumulator(2)  # number of correct predictions, total number of predictions
    with torch.no_grad():
        for X, y in data_iter:
            metric.add(accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

def sgd(params, lr, batch_size):  #@save
    """Minibatch stochastic gradient descent."""
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()

def updater(batch_size):
    return sgd([W, b], lr, batch_size)

class Accumulator:  #@save
    """Accumulate sums over n variables."""
    def __init__(self, n):
        self.data = [0.0] * n  # a list of n zeros that stores the running sums

    def add(self, *args):
        # walk over each (a, b) pair and build a new list of the sums
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
def train_epoch_ch3(net, train_iter, loss, updater):  #@save
    """Train the model for one epoch (defined in Chapter 3)."""
    # Set the model to training mode
    if isinstance(net, torch.nn.Module):
        net.train()
    # Sum of training loss, sum of training accuracy, number of examples
    metric = Accumulator(3)
    for X, y in train_iter:
        # Compute gradients and update the parameters
        y_hat = net(X)
        l = loss(y_hat, y)
        if isinstance(updater, torch.optim.Optimizer):
            # Using PyTorch's built-in optimizer and loss function
            updater.zero_grad()
            l.mean().backward()
            updater.step()
        else:
            # Using the custom optimizer and loss function
            l.sum().backward()
            updater(X.shape[0])
        metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    # Return the training loss and training accuracy
    return metric[0] / metric[2], metric[1] / metric[2]

def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save
    """Train the model (defined in Chapter 3)."""
    for epoch in range(num_epochs):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f' % (
            epoch + 1, train_metrics[0], train_metrics[1], test_acc))

def predict_ch3(net, test_iter, n=6):  #@save
    # Grab one test batch and compare true labels against predictions
    for X, y in test_iter:
        break
    trues = get_fashion_mnist_labels(y)
    preds = get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [true + '\n' + pred for true, pred in zip(trues, preds)]
    show_fashion_mnist(X[:n].reshape(-1, 28, 28), 1, n, titles[:n])

if __name__ == '__main__':
    batch_size = 256
    train_iter, test_iter = load_data_fashion_mnist(batch_size)
    num_inputs = 784
    num_outputs = 10
    W = torch.normal(0, 0.1, size=(num_inputs, num_outputs), requires_grad=True)
    b = torch.zeros(num_outputs, requires_grad=True)
    lr = 0.1
    num_epochs = 10
    loss = cross_entropy
    # updater = lambda batch_size: sgd([W, b], lr, batch_size)
    train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)
    predict_ch3(net, test_iter)
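One numerical caveat about the softmax above: it exponentiates the raw logits directly, so a single large entry in X makes torch.exp overflow to inf and the loss turns into nan. A minimal sketch of the standard fix, subtracting the row-wise maximum before exponentiating (stable_softmax is an illustrative name, not part of the script above):

def stable_softmax(X):
    # exp(x - c) / sum(exp(x - c)) equals exp(x) / sum(exp(x)) for any c,
    # so subtracting the per-row max changes nothing mathematically
    # but keeps torch.exp inside the float range
    X = X - X.max(dim=1, keepdim=True).values
    X_exp = torch.exp(X)
    return X_exp / X_exp.sum(dim=1, keepdim=True)

PyTorch's built-in nn.CrossEntropyLoss sidesteps this entirely by fusing log-softmax with the negative log-likelihood internally.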
On a side note, the MNIST dataset downloads so much faster than CIFAR does.
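Back to the code for one last note: since train_epoch_ch3 already branches on isinstance(updater, torch.optim.Optimizer), the same training functions can drive a concise nn-based model. A sketch of that usage, assuming the script above has already built train_iter and test_iter (net_concise, init_weights, loss_fn, and trainer are illustrative names):

from torch import nn

# Flatten each 1x28x28 image into a 784-vector, then one linear layer
net_concise = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=0.01)

net_concise.apply(init_weights)

# reduction='none' keeps per-example losses, matching what
# train_epoch_ch3 expects: it calls l.mean() / l.sum() itself
loss_fn = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net_concise.parameters(), lr=0.1)

train_ch3(net_concise, train_iter, test_iter, loss_fn, 10, trainer)

accuracy and evaluate_accuracy work unchanged here, since they only look at the argmax of the logits.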