softmax回实战

1.数据集

MNIST数据集 (LeCun et al., 1998) 是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。 我们将使用类似但更复杂的Fashion-MNIST数据集 (Xiao et al., 2017)。

import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2ld2l.use_svg_display()

读取数据集

我们可以通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中。

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

读取小批量

为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。 回顾一下,在每次迭代中,数据加载器每次都会读取一小批量数据,大小为batch_size。 通过内置数据迭代器,我们可以随机打乱了所有样本,从而无偏见地读取小批量。

batch_size = 256def get_dataloader_workers():  #@save"""使用4个进程来读取数据"""return 4train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())

整合所有组件

def load_data_fashion_mnist(batch_size, resize=None):  #@save"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

我们现在已经准备好使用Fashion-MNIST数据集,便于下面的章节调用来评估各种分类算法。

  • 数据迭代器是获得更高性能的关键组件。依靠实现良好的数据迭代器,利用高性能计算来避免减慢训练过程。

2.初始化模型参数

在softmax回归中,我们的输出与类别一样多。 因为我们的数据集有10个类别,所以网络输出维度为10。 因此,权重将构成一个784×10的矩阵, 偏置将构成一个1×10的行向量。

num_inputs = 784
num_outputs = 10W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

3.定义softmax操作

def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True)return X_exp / partition  # 这里应用了广播机制

正如上述代码,对于任何随机输入,我们将每个元素变成一个非负数。 此外,依据概率原理,每行总和为1。

X = torch.normal(0, 1, (2, 5))
X_prob = softmax(X)
X_prob, X_prob.sum(1)(tensor([[0.1686, 0.4055, 0.0849, 0.1064, 0.2347],[0.0217, 0.2652, 0.6354, 0.0457, 0.0321]]),tensor([1.0000, 1.0000]))

4.定义模型

定义softmax操作后,我们可以实现softmax回归模型。 下面的代码定义了输入如何通过网络映射到输出。 注意,将数据传递到模型之前,我们使用reshape函数将每张原始图像展平为向量。

def net(X):return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)

5.定义损失函数

接下来,我们引入交叉熵损失函数。 这可能是深度学习中最常见的损失函数,因为目前分类问题的数量远远超过回归问题的数量。

def cross_entropy(y_hat, y):return - torch.log(y_hat[range(len(y_hat)), y])

小批量随机梯度下降来优化模型的损失函数,设置学习率为0.1

lr = 0.1def updater(batch_size):return d2l.sgd([W, b], lr, batch_size)

6.分类精度

为了计算精度,我们执行以下操作。 首先,如果y_hat是矩阵,那么假定第二个维度存储每个类的预测分数。 我们使用argmax获得每行中最大元素的索引来获得预测类别。 然后我们将预测类别与真实y元素进行比较。 由于等式运算符“==”对数据类型很敏感, 因此我们将y_hat的数据类型转换为与y的数据类型一致。 结果是一个包含0(错)和1(对)的张量。 最后,我们求和会得到正确预测的数量。

def accuracy(y_hat, y):  #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())

7.训练

在这里,我们重构训练过程的实现以使其可重复使用。 首先,我们定义一个函数来训练一个迭代周期。 请注意,updater是更新模型参数的常用函数,它接受批量大小作为参数。

def train_epoch_ch3(net, train_iter, loss, updater):  #@save"""训练模型一个迭代周期(定义见第3章)"""# 将模型设置为训练模式if isinstance(net, torch.nn.Module):net.train()# 训练损失总和、训练准确度总和、样本数metric = Accumulator(3)for X, y in train_iter:# 计算梯度并更新参数y_hat = net(X)l = loss(y_hat, y)if isinstance(updater, torch.optim.Optimizer):# 使用PyTorch内置的优化器和损失函数updater.zero_grad()l.mean().backward()updater.step()else:# 使用定制的优化器和损失函数l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())# 返回训练损失和训练精度return metric[0] / metric[2], metric[1] / metric[2]

在展示训练函数的实现之前,我们定义一个在动画中绘制数据的实用程序类Animator, 它能够简化本书其余部分的代码。

class Animator:  #@save"""在动画中绘制数据"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):# 增量地绘制多条线if legend is None:legend = []d2l.use_svg_display()self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols == 1:self.axes = [self.axes, ]# 使用lambda函数捕获参数self.config_axes = lambda: d2l.set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)self.X, self.Y, self.fmts = None, None, fmtsdef add(self, x, y):# 向图表中添加多个数据点if not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nif not self.X:self.X = [[] for _ in range(n)]if not self.Y:self.Y = [[] for _ in range(n)]for i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x, y, fmt in zip(self.X, self.Y, self.fmts):self.axes[0].plot(x, y, fmt)self.config_axes()display.display(self.fig)display.clear_output(wait=True)

接下来我们实现一个训练函数, 它会在train_iter访问到的训练数据集上训练一个模型net。 该训练函数将会运行多个迭代周期(由num_epochs指定)。 在每个迭代周期结束时,利用test_iter访问到的测试数据集对模型进行评估。 我们将利用Animator类来可视化训练进度。

def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save"""训练模型(定义见第3章)"""animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],legend=['train loss', 'train acc', 'test acc'])for epoch in range(num_epochs):train_metrics = train_epoch_ch3(net, train_iter, loss, updater)test_acc = evaluate_accuracy(net, test_iter)animator.add(epoch + 1, train_metrics + (test_acc,))train_loss, train_acc = train_metricsassert train_loss < 0.5, train_lossassert train_acc <= 1 and train_acc > 0.7, train_accassert test_acc <= 1 and test_acc > 0.7, test_acc

训练:

num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

8.预测

现在训练已经完成,我们的模型已经准备好对图像进行分类预测。 给定一系列图像,我们将比较它们的实际标签(文本输出的第一行)和模型预测(文本输出的第二行)。

def predict_ch3(net, test_iter, n=6):  #@save"""预测标签(定义见第3章)"""for X, y in test_iter:breaktrues = d2l.get_fashion_mnist_labels(y)preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true +'\n' + pred for true, pred in zip(trues, preds)]d2l.show_images(X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])predict_ch3(net, test_iter)

9.总结:

  • 借助softmax回归,我们可以训练多分类的模型。
  • 训练softmax回归循环模型与训练线性回归模型非常相似:先读取数据,再定义模型和损失函数,然后使用优化算法训练模型。大多数常见的深度学习模型都有类似的训练过程。

10.简洁实现:

10.1读取数据

import torch
from torch import nn
from d2l import torch as d2lbatch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

10.2定义模型,初始化参数

跟线性规划模型是一样的,只不过有十个输出

# PyTorch不会隐式地调整输入的形状。因此,
# 我们在线性层前定义了展平层(flatten),来调整网络输入的形状
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights);

10.3定义损失函数

loss = nn.CrossEntropyLoss(reduction='none')

10.4定义优化器

我们使用学习率为0.1的小批量随机梯度下降作为优化算法

updater = torch.optim.SGD(net.parameters(), lr=0.1)

10.5训练

调用上面写好的函数,只不过这次的模型,损失函数,优化器都是pytorch内置的

num_epochs = 10
train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

10.6预测:

predict_ch3(net, test_iter)

10.7小结

  • 使用深度学习框架的高级API,我们可以更简洁地实现softmax回归。
  • 从计算的角度来看,实现softmax回归比较复杂。在许多情况下,深度学习框架在这些著名的技巧之外采取了额外的预防措施,来确保数值的稳定性。这使我们避免了在实践中从零开始编写模型时可能遇到的陷阱。
    数,优化器**都是pytorch内置的
num_epochs = 10
train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

10.6预测:

predict_ch3(net, test_iter)

10.7小结

  • 使用深度学习框架的高级API,我们可以更简洁地实现softmax回归。
  • 从计算的角度来看,实现softmax回归比较复杂。在许多情况下,深度学习框架在这些著名的技巧之外采取了额外的预防措施,来确保数值的稳定性。这使我们避免了在实践中从零开始编写模型时可能遇到的陷阱。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/418419.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Web自动化测试 —— cookie复用

一、cookie简介 cookie是一些数据&#xff0c;存储于用户电脑的文本文件中 当web服务器想浏览器发送web页面时&#xff0c;在链接关闭后&#xff0c;服务端不会记录用户信息 二、为什么要使用Cookie自动化登录 复用浏览器仍然在每次用例开始都需要人为介入若用例需要经常执行&…

浅析性能测试策略及适用场景

前言 面对日益复杂的业务场景和不同的系统架构&#xff0c;前期的需求分析和准备工作&#xff0c;需要耗费很多的时间。而不同的测试策略&#xff0c;也对我们的测试结果是否符合预期目标至关重要。 这篇博客&#xff0c;聊聊我个人对常见的性能测试策略的理解&#xff0c;以…

Android Matrix绘制PaintDrawable设置BitmapShader,手指触点为圆心scale放大原图,Kotlin(二)

Android Matrix绘制PaintDrawable设置BitmapShader&#xff0c;手指触点为圆心scale放大原图&#xff0c;Kotlin&#xff08;二&#xff09; 在 Android Matrix绘制PaintDrawable设置BitmapShader&#xff0c;手指触点为圆心scale放大原图&#xff0c;Kotlin-CSDN博客 基础上&…

信息登记小程序怎么做_重塑用户互动,开启全新营销篇章

信息登记小程序&#xff1a;重塑用户互动&#xff0c;开启全新营销篇章 在数字化浪潮中&#xff0c;小程序以其便捷、高效的特点&#xff0c;逐渐成为企业与用户之间沟通的桥梁。其中&#xff0c;信息登记小程序更是凭借其独特的定位&#xff0c;在众多小程序中脱颖而出。本文…

如何用GPT进行论文润色与改写?

详情点击链接&#xff1a;如何用GPT进行论文润色与改写&#xff1f; 一OpenAI 1.最新大模型GPT-4 Turbo 2.最新发布的高级数据分析&#xff0c;AI画图&#xff0c;图像识别&#xff0c;文档API 3.GPT Store 4.从0到1创建自己的GPT应用 5. 模型Gemini以及大模型Claude2二定…

C#,入门教程(20)——列表(List)的基础知识

上一篇&#xff1a; C#&#xff0c;入门教程(19)——循环语句&#xff08;for&#xff0c;while&#xff0c;foreach&#xff09;的基础知识https://blog.csdn.net/beijinghorn/article/details/124060844 List顾名思义就是数据列表&#xff0c;区别于数据数组&#xff08;arr…

html5实现好看的年会邀请函源码模板

文章目录 1.设计来源1.1 邀请函主界面1.2 诚挚邀请界面1.3 关于我们界面1.4 董事长致词界面1.5 公司合作方界面1.6 活动流程界面1.7 加盟支持界面1.8 加盟流程界面1.9 加盟申请界面1.10 活动信息界面 2.效果和源码2.1 动态效果2.2 源码目录结构 源码下载 作者&#xff1a;xcLei…

最长公共前缀(Leetcode14)

例题&#xff1a; 分析&#xff1a; 我们可以先定义两个变量 i &#xff0c; j&#xff0c; j表示数组中的每一个字符串&#xff0c; i 表示每个字符串中的第几个字符。一列一列地进行比较&#xff0c;先比较第一列的字符&#xff0c;若都相同&#xff0c;则 i &#xff0c;继…

linux-部署Samba文件共享服务

linux-部署Samba文件共享服务 1、使用命令安装samba服务和samba客户端 dnf install samba samba-client # 或者 yum install samba samba-client2、配置文件的设置(可提前备份smb.conf) vim /etc/samba/smb.conf [global]workgroup SAMBAsecurity userpassdb backend tdbsam…

算法总结——单调栈

纵有疾风起&#xff0c;人生不言弃。本文篇幅较长&#xff0c;如有错误请不吝赐教&#xff0c;感谢支持。 文章目录 一、单调栈的定义二、单调栈的应用&#xff1a;寻找左边第一个比它小的数单调栈的思想&#xff08;重点&#xff09;&#xff1a;寻找左边第一个比它小的数的下…

组件通信方式

组件通信 父子组件通信 单向数据流 属性传递props&#xff08;还有插槽&#xff0c;$attrs非属性&#xff09;/$emit&#xff0c;发布订阅模式 方法也可以作为属性 父子组件渲染生命周期&#xff1a; 获取组件实例。$children、ref&$refs/$parent 祖先和后代 组件和后代通信…

css绘制下拉框头部三角(分实心/空心)

1:需求图: 手绘下拉框 带三角 2:网上查了一些例子,但都是实心的, 可参考,如图: (原链接: https://blog.csdn.net/qq_33463449/article/details/113375804) 3:简洁版的: a: 实心: <view class"angle"/>.angle{width:0;height:0;border-left: 10px solid t…