PyTorch入门之【AlexNet】

参考文献:https://www.bilibili.com/video/BV1DP411C7Bw/?spm_id_from=333.999.0.0&vd_source=98d31d5c9db8c0021988f2c2c25a9620
AlexNet 是一个经典的卷积神经网络模型,用于图像分类任务。

目录

  • 大纲
  • dataloader
  • model
  • train
  • test

大纲

在这里插入图片描述
各个文件的作用:

  • data就是数据集
  • dataloader.py就是数据集的加载以及实例初始化
  • model.py就是AlexNet模块的定义
  • train.py就是模型的训练
  • test.py就是模型的测试

dataloader

import torch
import torchvision
import torchvision.transforms as transformsimport matplotlib.pyplot as plt
import numpy as np# define the dataloader
transform = transforms.Compose([transforms.Resize(224),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])batch_size = 16trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,shuffle=True)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,shuffle=False)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')if __name__ == '__main__':# get some random training imagesdataiter = iter(train_loader)images, labels = next(dataiter)# print labelsprint(' '.join('%5s' % classes[labels[j]] for j in range(batch_size)))# show imagesimg_grid = torchvision.utils.make_grid(images)img_grid = img_grid / 2 + 0.5npimg = img_grid.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()

model

import torch.nn as nn
import torchclass AlexNet(nn.Module):def __init__(self, num_classes=10):super(AlexNet, self).__init__()self.conv_1 = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),nn.BatchNorm2d(96),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.conv_2 = nn.Sequential(nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.conv_3 = nn.Sequential(nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(384),nn.ReLU())self.conv_4 = nn.Sequential(nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(384),nn.ReLU())self.conv_5 = nn.Sequential(nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.fc_1 = nn.Sequential(nn.Dropout(0.5),nn.Linear(9216, 4096),nn.ReLU())self.fc_2 = nn.Sequential(nn.Dropout(0.5),nn.Linear(4096, 4096),nn.ReLU())self.fc_3= nn.Sequential(nn.Linear(4096, num_classes))def forward(self, x):out = self.conv_1(x)out = self.conv_2(out)out = self.conv_3(out)out = self.conv_4(out)out = self.conv_5(out)out = out.reshape(out.size(0), -1)out = self.fc_1(out)out = self.fc_2(out)out = self.fc_3(out)return outif __name__ == '__main__':model = AlexNet()print(model)x = torch.randn(1, 3, 224, 224)y = model(x)print(y.size())

train

import torch
import torch.nn as nnfrom dataloader import train_loader, test_loader
from model import AlexNet# define the hyperparameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 10
num_epochs = 20
learning_rate = 1e-3# load the model
model = AlexNet(num_classes).to(device)# loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  # train the model
total_len = len(train_loader)for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):# move tensors to the configured deviceimages = images.to(device)labels = labels.to(device)# forward passoutputs = model(images)loss = criterion(outputs, labels)# backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 100 == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_len, loss.item()))# Validationwith torch.no_grad():model.eval()correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()model.train()print('Accuracy of the network on the {} validation images: {} %'.format(10000, 100 * correct / total))# save the model checkpoint
torch.save(model.state_dict(), 'alexnet.pth')

test

import torchfrom dataloader import test_loader, classes
from model import AlexNet# load the pretrained model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AlexNet().to(device)
model.load_state_dict(torch.load('alexnet.pth', map_location=device))# test the pretrained model on CIFAR-10 test data
with torch.no_grad():model.eval()correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the {} validation images: {} %'.format(10000, 100 * correct / total))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/126748.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【java爬虫】使用vue+element-plus编写一个简单的管理页面

前言 前面我们已经将某宝联盟的数据获取下来了,并且编写了一个接口将数据返回,现在我们需要使用vueelement-plus编写一个简单的管理页面进行数据展示,由于第一次使用vue编写前端项目,所以只是编写了一个非常简单的页面。 项目结…

【Spatial-Temporal Action Localization(五)】论文阅读2020年

文章目录 1. Actions as Moving Points摘要和结论引言:针对痛点和贡献模型框架实验 1. Actions as Moving Points Actions as Moving Points (ECCV 2020) 摘要和结论 MovingCenter Detector (MOCdetector) 通过将动作实例视为移动点的轨迹。通过三个分支生成 tub…

使用css制作3D盒子,目的是把盒子并列制作成3D货架

注意事项:这个正方体的其他面的角度很难调,因此如果想动态生成,需要很复杂的设置动态的角度,反正我是折腾了半天没继续搞下去, 1. 首先看效果(第一个五颜六色的是透明多个面,第2-3都是只有3个面…

专业图像处理软件DxO PhotoLab 7 mac中文特点和功能

DxO PhotoLab 7 mac是一款专业的图像处理软件,它为摄影师和摄影爱好者提供了强大而全面的照片处理和编辑功能。 DxO PhotoLab 7 mac软件特点和功能 强大的RAW和JPEG格式处理能力:DxO PhotoLab 7可以处理来自各种相机的RAW格式图像,包括佳能、…

SpringBatch适配不同数据库的两种方法

一、配置JobRepository Configuration EnableBatchProcessing public class TaskArrangeConfig extends DefaultBatchConfigurer {Autowiredprivate DataSource dataSource;Autowiredprivate JobLauncher jobLauncher;Autowiredprivate JobExplorer jobExplorer;Autowiredpriv…

计算机毕设 大数据工作岗位数据分析与可视化 - python flask

文章目录 0 前言1 课题背景2 实现效果3 项目实现3.1 概括 3.2 Flask实现3.3 HTML页面交互及Jinja2 4 **完整代码**5 最后 0 前言 🔥 这两年开始毕业设计和毕业答辩的要求和难度不断提升,传统的毕设题目缺少创新和亮点,往往达不到毕业答辩的要…

沈阳陪诊系统|沈阳陪诊系统开发|沈阳陪诊系统功能和优势

在现代医疗服务中,陪诊系统服务正变得越来越重要。这项创新的服务提供了一种全新的方式,帮助患者在医院就诊时获得更好的照顾和支持。无论是面对复杂的医学流程还是需要心理支持,陪诊系统服务都能够为患者提供方便、专业的帮助。陪诊系统服务…

python+pygame+opencv+gpt实现虚拟数字人直播(一)

AI技术突飞猛进,不断的改变着人们的工作和生活。数字人直播作为新兴形式,必将成为未来趋势,具有巨大的、广阔的、惊人的市场前景。它将不断融合创新技术和跨界合作,提供更具个性化和多样化的互动体验,成为未来的一种趋…

leetCode 53.最大子数和 图解 + 贪心算法/动态规划+优化

53. 最大子数组和 - 力扣(LeetCode) 给你一个整数数组 nums ,请你找出一个具有最大和的连续子数组(子数组最少包含一个元素),返回其最大和。 子数组 是数组中的一个连续部分。 示例 1: 输入…

Android Camera FW 里的requestId和frameId

安卓相机frameworks里面经常出现requestId和frameId,最近简单看了一下代码,发现相关流程还是很复杂的,总结来看requestId 就是上层(java)发送的repeating(capture)请求的id,是从0开始递增的。 这是CameraD…

Godot 官方2D游戏笔记(1):导入动画资源和添加节点

前言 Godot 官方给了我们2D游戏和3D游戏的案例,不过如果是独立开发者只用考虑2D游戏就可以了,因为2D游戏纯粹,我们只需要关注游戏的玩法即可。2D游戏的美术素材简单,交互逻辑简单,我们可以把更多的时间放在游戏的玩法…

Transformer学习

这里写目录标题 Seq2Seq语音翻译为何不直接用语音辨识机器翻译?语法分析文章归类问题目标检测 TransformerEncoder结构multi-head attention block为何batch-norm 不如 layer-norm? Decoder结构decoder流程decoder结构decoder比encoder多了一个masked se…