学习pytorch19 pytorch使用GPU训练

pytorch使用GPU进行训练

  • 1. 数据 模型 损失函数调用cuda()
  • 2. 使用谷歌免费GPU gogle colab 需要创建谷歌账号登录使用, 网络能访问谷歌
  • 3. 执行
  • 4. 代码

B站土堆学习视频: https://www.bilibili.com/video/BV1hE411t7RN/?p=30&spm_id_from=pageDriver&vd_source=9607a6d9d829b667f8f0ccaaaa142fcb

1. 数据 模型 损失函数调用cuda()

if torch.cuda.is_available():net = net.cuda()
if torch.cuda.is_available():loss_fn = loss_fn.cuda()
if torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()

2. 使用谷歌免费GPU gogle colab 需要创建谷歌账号登录使用, 网络能访问谷歌

土堆YYDS, colab省事省力好多

3. 执行

运行报错
https://stackoverflow.com/questions/59013109/runtimeerror-input-type-torch-floattensor-and-weight-type-torch-cuda-floatte
You get this error because your model is on the GPU, but your data is on the CPU. So, you need to send your input tensors to the GPU.
意思是数据在cpu但是模型在gpu导致报错
查看代码测试数据做模型预测的时候,有个变量名写错了,改正后正常运行。
在这里插入图片描述
在这里插入图片描述

4. 代码

import torch
import torchvision
from torch import nn
from torch.utils.tensorboard import SummaryWriter
import time
# from p24_model import *# 1. 准备数据集
# 训练数据
from torch.utils.data import DataLoadertrain_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),download=True)
# 测试数据
test_data = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),download=True)# 查看数据大小--size
print("训练数据集大小:", len(train_data))
print("测试数据集大小:", len(test_data))
# 利用DataLoader来加载数据集
train_loader = DataLoader(dataset=train_data, batch_size=64)
test_loader = DataLoader(dataset=test_data, batch_size=64)# 2. 导入模型结构 创建模型
class Cifar10Net(nn.Module):def __init__(self):super(Cifar10Net, self).__init__()self.net = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(kernel_size=2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self, x):x = self.net(x)return xnet = Cifar10Net()
if torch.cuda.is_available():net = net.cuda()# 3. 创建损失函数  分类问题--交叉熵
loss_fn = nn.CrossEntropyLoss()
if torch.cuda.is_available():loss_fn = loss_fn.cuda()# 4. 创建优化器
# learing_rate = 0.01
# 1e-2 = 1 * 10^(-2) = 0.01
learing_rate = 1e-2
print(learing_rate)
optimizer = torch.optim.SGD(net.parameters(), lr=learing_rate)# 设置训练网络的一些参数
epoch = 10   # 记录训练的轮数
total_train_step = 0  # 记录训练的次数
total_test_step = 0   # 记录测试的次数# 利用tensorboard显示训练loss趋势
writer = SummaryWriter('./train_logs')start_time = time.time()
print('start_time: ', start_time)
print(torch.cuda.is_available())
for i in range(epoch):# 训练步骤开始net.train()  # 可以加可以不加  只有当模型结构有 Dropout BatchNorml层才会起作用for data in train_loader:imgs, targets = data  # 获取数据if torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()output = net(imgs)    # 数据输入模型loss = loss_fn(output, targets)  # 损失函数计算损失 看计算的输出和真实的标签误差是多少# 优化器开始优化模型  1.梯度清零  2.反向传播  3.参数优化optimizer.zero_grad()  # 利用优化器把梯度清零 全部设置为0loss.backward()        # 设置计算的损失值,调用损失的反向传播,计算每个参数结点的参数optimizer.step()       # 调用优化器的step()方法 对其中的参数进行优化# 优化一次 认为训练了一次total_train_step += 1if total_train_step % 100 == 0:print('训练次数: {}   loss: {}'.format(total_train_step, loss))end_time = time.time()print('训练100次需要的时间:', end_time-start_time)# 直接打印loss是tensor数据类型,打印loss.item()是打印的int或float真实数值, 真实数值方便做数据可视化【损失可视化】# print('训练次数: {}   loss: {}'.format(total_train_step, loss.item()))writer.add_scalar('train-loss', loss.item(), global_step=total_train_step)# 利用现有模型做模型测试# 测试步骤开始total_test_loss = 0accuracy = 0net.eval()  # 可以加可以不加  只有当模型结构有 Dropout BatchNorml层才会起作用with torch.no_grad():for data in test_loader:imags, targets = dataif torch.cuda.is_available():imags = imags.cuda()targets = targets.cuda()output = net(imags)loss = loss_fn(output, targets)total_test_loss += loss.item()# 计算测试集的正确率preds = (output.argmax(1)==targets).sum()accuracy += preds# writer.add_scalar('test-loss', total_test_loss, global_step=i+1)writer.add_scalar('test-loss', total_test_loss, global_step=total_test_step)writer.add_scalar('test-accracy', accuracy/len(test_data), total_test_step)total_test_step += 1print("---------test loss: {}--------------".format(total_test_loss))print("---------test accuracy: {}--------------".format(accuracy/len(test_data)))# 保存每一个epoch训练得到的模型torch.save(net, './net_epoch{}.pth'.format(i))writer.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/258604.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

防水,也不怕水。Mate X5是如何做到让你湿手湿屏也不影响操作的?

相信不少人都碰到过当手机屏幕存在小水珠时,触控变得不灵敏,或者出现“幽灵触屏”,指东打西的情况。 尤其是在洗澡、做饭,或者在户外遇到下雨天气时,如果打湿的手机收到重要聊天消息或者电话,却因为湿屏导…

循环结构中 break、continue、return 和exit() 的区别

循环结构中 break、continue、return 和exit() 的区别 文章目录 循环结构中 break、continue、return 和exit() 的区别一、break语句二、continue语句三、return 语句四、exit() 函数 说明:本文内容参考牟海军 著《C语言进阶: 重点、难点与疑点解析》&a…

交换机与VLAN

交换机简介 交换机的作用 如今随着计算机速度不断提高,以及网络应用越来越多,局域网的负载变得越来越大了,交换机的使用也变得更有必要。交换机的作用主要有两个:一个是维护CAM(Context Address Memory)表…

WGCLOUD v3.5.0 新增支持监测交换机的接口状态UP DOWN

WGCLOUD v3.5.0开始 可以监测交换机或SNMP设备的接口状态了,直接上图

实时动作识别学习笔记

目录 yowo v2 yowof 判断是在干什么,不能获取细节信息 yowo v2 https://github.com/yjh0410/YOWOv2/blob/master/README_CN.md ModelClipmAPFPSweightYOWOv2-Nano1612.640ckptYOWOv2-Tiny

潮落云起:中国云桌面的产业变局

云桌面,又被称为桌面云、桌面虚拟化技术。这项技术的起源可以追溯到20世纪70年代,IBM通过一台计算机来实现多用户资源的桌面共享。在数十年的发展中,云桌面从技术实现方式到产品形态日趋丰富。 相比于传统PC数据单独存放、系统独立运维的分散…

医学图像数据处理流程以及遇到的问题

数据总目录: /home/bavon/datasets/wsi/hsil /home/bavon/datasets/wsi/lsil 1 规整文件命名以及xml拷贝 data_prepare.py 的 align_xml_svs 方法 if __name__ __main__: file_path "/home/bavon/datasets/wsi/lsil"# align_xml_svs(file_path) # b…

【android开发-22】android中音频和视频用法详解

1,播放音频 MediaPlayer是Android中用于播放音频和视频的类。它提供了许多方法来控制播放,例如播放、暂停、停止、释放等。下面是一个简单的MediaPlayer用法详解和参考代码例子。 首先,确保在布局文件中添加了一个MediaPlayer控件&#xff…

Python 小红书评论区采集 小红薯xhs精准用户获客

成品图 评论接口https://edith.xiaohongshu.com/api/sns/web/v2/comment/page?note_id笔记id&cursor光标 初次使用cursor为空,该接口为GET,需要x-s,x-t签名验证 子评论接口https://edith.xiaohongshu.com/api/sns/web/v2/comment/sub/page?note_id%s&r…

修改pip源

修改pip源 永久修改 PS C:\Users\Dell> pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/Writing to C:\Users\Dell\AppData\Roaming\pip\pip.ini临时修改 pip install -i(即--index-url简写) http://mirrors.aliyun.com/pypi/simple/ selenium…

Gerber文件使用详解

目录 概述 一、Gerber 格式 二、接线图示例 三、顶层丝印 四、顶级阻焊层 五、顶部助焊层 六、顶部(或顶部铜) 七、钻头 八、电路板概要 九、使用文本和字体进行 Gerber 导出 十、总结 概述 Gerber文件:它们是什么? PCB制造商如何使用它们? …

Swing程序设计(9)复选框,下拉框

文章目录 前言一、复选框二、下拉框总结 前言 该篇文章简单介绍了Java中Swing组件里的复选框组件、列表框组件、下拉框组件,这些在系统中都是常用的组件。 一、复选框 复选框(JCheckBox)在Swing组件中的使用也非常广泛,一个方形方…