MNIST数据集:手搓softmax回归

news/2024/10/5 15:05:16/文章来源:https://www.cnblogs.com/bozhi233/p/18287897

源码:

import torch 
import torchvision as tv
from torch.utils import data
import matplotlib.pyplot as plt
import timedef get_fashion_mnist_labels(labels):text_labels = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']return [text_labels[int(i)] for i in labels]def show_fashion_mnist(imgs, num_rows, num_cols, titles=None, scale=0.5):figsize = (num_cols*scale, num_rows*scale)_, axes = plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i,(ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):ax.imshow(img.numpy())else:ax.imshow(img)ax.axis('off')if titles:ax.set_title(titles[i])plt.show()return axesdef get_dataloader_workers():  #@save"""使用4个进程来读取数据"""return 4def load_data_fashion_mnist(batch_size, resize=None):trans = [tv.transforms.ToTensor()]  # 创建一个将图像转换为张量的变换if resize:trans.insert(0, tv.transforms.Resize(resize))trans = tv.transforms.Compose(trans)mnist_train = tv.datasets.FashionMNIST(root='./data', train=True, download=True, transform=trans)  # 加载FashionMNIST训练数据集,并应用变换mnist_test = tv.datasets.FashionMNIST(root='./data', train=False, download=True, transform=trans)  # 加载FashionMNIST测试数据集,并应用变换return (data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=get_dataloader_workers()))def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1,keepdim=True)return X_exp / partitiondef net(X):return softmax(torch.matmul(X.reshape(-1, W.shape[0]), W) +b)def cross_entropy(y_hat, y):return -torch.log(y_hat[range(len(y_hat)), y])def accuracy(y_hat, y):if len(y_hat.shape)>1 and y_hat.shape[1]>1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())def evaluate_accuracy(net, data_iter):if isinstance(net, torch.nn.Module):net.eval() # 评估模式, 这会关闭dropoutmetric = Accumulator(2) # 正确预测数、预测总数with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel())return metric[0]/metric[1]def sgd(params, lr, batch_size):  #@save"""小批量随机梯度下降"""with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()def updater(batch_size):return sgd([W, b], lr, batch_size)class Accumulator:  #@save"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * n       #self.data 是一个列表,初始化为 n 个 0.0,用于存储累加的值。def add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)] #一个列表推导式,它遍历每一对 (a, b),并将 a 和 b 相加的结果生成一个新的列表。def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]
def train_epoch_ch3(net, train_iter, loss, updater):  #@save"""训练模型一个迭代周期(定义见第3章)"""# 将模型设置为训练模式if isinstance(net, torch.nn.Module):net.train()# 训练损失总和、训练准确度总和、样本数metric = Accumulator(3)for X, y in train_iter:# 计算梯度并更新参数y_hat = net(X)l = loss(y_hat, y)if isinstance(updater, torch.optim.Optimizer):# 使用PyTorch内置的优化器和损失函数updater.zero_grad()l.mean().backward()updater.step()else:# 使用定制的优化器和损失函数l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())# 返回训练损失和训练精度return metric[0] / metric[2], metric[1] / metric[2]def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save"""训练模型(定义见第3章)"""for epoch in range(num_epochs):train_metrics = train_epoch_ch3(net, train_iter, loss, updater)test_acc = evaluate_accuracy(net, test_iter)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f' % ( epoch + 1, train_metrics[0], train_metrics[1], test_acc))def predict_ch3(net, test_iter, n=6):  #@savefor X, y in test_iter:breaktrues = get_fashion_mnist_labels(y)preds = get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true +'\n' + pred for true, pred in zip(trues, preds)]show_fashion_mnist(X[:n].reshape(-1,28,28), 1, n, titles[:n])if __name__ == '__main__':batch_size = 256train_iter, test_iter = load_data_fashion_mnist(batch_size)num_inputs = 784num_outputs = 10W = torch.normal(0, 0.1, size=(num_inputs, num_outputs), requires_grad=True)b = torch.zeros(num_outputs, requires_grad=True)lr = 0.1num_epochs = 10loss = cross_entropy    # updater = lambda batch_size: sgd([W, b], lr, batch_size)train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)predict_ch3(net, test_iter)

另外感慨一下MNIST数据集下载速度真是比CIFAR快太多了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/739520.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java反射与Fastjson的危险反序列化

Preface 在前文中,我们介绍了 Java 的基础语法和特性和 fastjson 的基础用法,本文我们将深入学习fastjson的危险反序列化以及预期相关的 Java 概念。 什么是Java反射? 在前文中,我们有一行代码 Computer macBookPro = JSON.parseObject(preReceive,Computer.class); 这行代…

Win10双屏设置 之 鼠标不能从中间划过 问题解决

Win10双屏设置 之 鼠标不能从中间划过解决-百度经验 (baidu.com)

比赛获奖的武林秘籍:03 好的创意选取-获得国奖的最必要前提

本文主要介绍了大学生电子计算机类比赛和创新创业类比赛创意选取的重要性,并列举了好的创意选取和坏的创意选取的例子,同时说明了好的创意选取具有哪些特点,同时对常见的创意选取途径与来源进行了基本介绍。比赛获奖的武林秘籍:03 好的创意选取-获得国奖的最必要前提 摘要 …

阶段测试

Sre网络班阶段测试 一:用sed 命令修改/etc/fstab文件,删除文件中的空行,注释行,并保留文件备份(7分) 答案写这里:二: 用 find 命令查找出 /var/ 目录中大于1M且以db结尾的文件(7分) 答案写这里:三: 先判断当前主机是否安装了nginx包,如果没安装,则执行命令安装,…

时间序列分析专题——利用SPSS专家建模器进行建模

SPSS的专家建模器可以自动识别数据,给出最适合的模型,本章通过三个例题介绍如何使用SPSS实现时间序列分析。由于本人对时间序列分析的理解尚浅,做出模型后在论文上的呈现形式需要取查阅资料,以便更好地在论文上呈现 在此之前,我们还需要了解时间序列分析的一些基础的名词 …

如何在ubuntu上设置清华源

如何在ubuntu上设置清华源 apt介绍 apt(Advanced Packaging Tool)是一个在 Debian 和 Ubuntu 中的 Shell 前端软件包管理器。 apt 命令提供了查找、安装、升级、删除某一个、一组甚至全部软件包的命令,而且命令简洁而又好记。 apt 命令执行需要超级管理员权限(root)。 操作 …

c++ u7-02-高精度乘法

本节课作业: 链接:https://pan.baidu.com/s/13-FC86jSHGziRDA8lqzimg?pwd=owv1 提取码:owv1高精度乘法    #include<iostream> #include<cstdio> #include<cstring> using namespace std; string x , y; int a[50010] , b[50010] , c[50010…

node-red的基本指令

1. inject->debug输入到输出,调试结果在右边如果选择时间戳的话,可以选择立即执行,或者周期性,持续执行inject除了时间戳还有一些其他输入项可以选择inject选择json文件输出写好json文件之后点击格式化json,可以校对文件格式payload.number可以让输出只输出number的内容…

lombokjunit

lombok&junit 1 lombok先去官网或者maven仓库下载jar包https://mvnrepository.com/导入第三方包到项目中右键lib文件夹,点击add as library默认jvm不解析第三方注解,需要手动开启使用//@Setter // 生成set方法 1 //@Getter // 生成get方法 2 //@ToString // 生…

日期类异常类

日期类&异常类 作业:千位数字相乘 public static void main(String[] args) {// 两个千位数字相乘int[] arr1 = {7,8,9,9,8,9};int[] arr2 = {7,9,8,9,6,8};// 定义结果的数组int[] result = new int[12];for (int i = 0; i < arr1.length; i++) {for (int j = 0; j <…

包装类数学类位运算

包装类&数学类&位运算 1 包装类 把基本数据类型包装成引用数据类型byte short int long float double char boolean voidByte Short Integer Long Float Double Character Boolean VoidVoidVoid类构造方法是私有的,所以不能创建对象。并且Void是一个最终类,没有子类。…