(动手学习深度学习)第13章 实战kaggle竞赛:狗的品种识别

文章目录

      • 1. 导入相关库
      • 2. 加载数据集
      • 3. 整理数据集
      • 4. 图像增广
      • 5. 读取数据
      • 6. 微调预训练模型
      • 7. 定义损失函数和评价损失函数
      • 9. 训练模型

1. 导入相关库

import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

2. 加载数据集

- 该数据集是完整数据集的小规模样本
# 下载数据集
d2l.DATA_HUB['dog_tiny'] = (d2l.DATA_URL + 'kaggle_dog_tiny.zip','0cb91d09b814ecdc07b50f31f8dcad3e81d6a86d')# 如果使用Kaggle比赛的完整数据集,请将下面的变量更改为False
demo = True
if demo:data_dir = d2l.download_extract('dog_tiny')
else:data_dir = os.path.join('..', 'data', 'dog-breed-identification')

3. 整理数据集

def reorg_dog_data(data_dir, valid_ratio):labels = d2l.read_csv_labels(os.path.join(data_dir, 'labels.csv'))d2l.reorg_train_valid(data_dir, labels, valid_ratio)d2l.reorg_test(data_dir)batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_dog_data(data_dir, valid_ratio)

4. 图像增广

transform_train = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0)),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
transform_test = torchvision.transforms.Compose([torchvision.transforms.Resize(256),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

5. 读取数据

train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder),transform=transform_train) for folder in ['train', 'train_valid']
]
valid_ds, test_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder),transform=transform_test) for folder in ['valid', 'test']
]
train_iter, train_valid_iter = [torch.utils.data.DataLoader(dataset, batch_size, shuffle=True, drop_last=True) for dataset in (train_ds, train_valid_ds)
]
valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False, drop_last=True
)
test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False, drop_last=True
)

6. 微调预训练模型

def get_net(devices):finetune_net = nn.Sequential()finetune_net.features = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)# 定义一个新的输出网络,共有120个输出类别finetune_net.output_new = nn.Sequential(nn.Linear(1000, 256),nn.ReLU(),nn.Linear(256, 120))finetune_net = finetune_net.to(devices[0])# 冻结参数for param in finetune_net.features.parameters():param.requires_grad = Falsereturn finetune_net
# 查看网络模型
get_net(devices=d2l.try_all_gpus())

在这里插入图片描述

7. 定义损失函数和评价损失函数

# 定义损失函数
loss = nn.CrossEntropyLoss(reduction='none')def evaluate_loss(data_iter, net, device):l_sum, n = 0.0, 0for features, labels in data_iter:features, labels = features.to(device[0]), labels.to(device[0])outputs = net(features)l = loss(outputs, labels)l_sum += l.sum()n += labels.numel()return (l_sum / n).to('cpu')
  1. 定义训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay):# 只训练小型定义输出网络net = nn.DataParallel(net, device_ids=devices).to(devices[0])trainer = torch.optim.SGD((param for param in net.parameters() if param.requires_grad),lr=lr, momentum=0.9, weight_decay=wd)scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)num_batches, timer = len(train_iter), d2l.Timer()legend = ['train loss']if valid_iter is not None:legend.append('valid loss')animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], legend=legend)for epoch in range(num_epochs):metric = d2l.Accumulator(2)for i, (features, labels) in enumerate(train_iter):timer.start()features, labels = features.to(devices[0]), labels.to(devices[0])trainer.zero_grad()output = net(features)l = loss(output, labels).sum()l.backward()trainer.step()metric.add(l, labels.shape[0])timer.stop()if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches, (metric[0] / metric[1], None))measures = f'train loss {metric[0] / metric[1]:.3f}'if valid_iter is not None :valid_loss = evaluate_loss(valid_iter, net, devices)animator.add(epoch + 1, (None, valid_loss.detach().cpu()))scheduler.step()if valid_iter is not None:measures += f', valid loss {valid_loss:.3f}'print(measures + f'\n{metric[1] * num_epochs / timer.sum():.1f}'f'examples/sec on {str(devices)}')

9. 训练模型

devices, num_epochs, lr, wd = d2l.try_all_gpus(), 10, 1e-4, 1e-4
lr_period, lr_decay, net, = 2, 0.9, get_net(devices)
import time# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f}')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/209642.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

关于 IBM Spectrum LSF

关于 IBM Spectrum LSF IBM Spectrum LSF 允许您使用 IBM Spectrum LSF 作为 HPC 调度软件来部署高性能计算 (HPC) 集群。 此产品使用基于开放式源代码 Terraform 的自动化来供应和配置 IBM Cloud 资源。 通过简单的步骤来定义配置属性并使用自动化部署,您可以在几…

如何判断交流回馈老化测试负载是否合格?

交流回馈老化测试负载是用于模拟实际工作环境中设备运行状态的测试工具,主要用于检测设备的耐久性和稳定性。 负载性能:需要检查负载的性能是否符合设计要求,这包括负载的功率、电流、电压等参数是否在规定的范围内,以及负载的工作…

TikTok与精神健康:社交媒体在压力时代的作用

在当今数字化和社交化的时代,社交媒体已成为人们生活中不可或缺的一部分。其中,TikTok作为一款备受欢迎的短视频应用,不仅改变了人们的娱乐方式,也对精神健康产生了深远的影响。 本文将深入探讨TikTok在压力时代对精神健康的作用…

opencv-Hough 直线变换

Hough 直线变换是一种在图像中检测直线的技术。它通过在极坐标空间中表示图像中的直线,将直线检测问题转换为参数空间的累加问题。OpenCV 提供了 cv2.HoughLines() 和 cv2.HoughLinesP() 函数来执行 Hough 直线变换。 cv2.HoughLines() lines cv2.HoughLines(ima…

【AGC】云存储服务端使用方法

【集成准备】 1、Python环境配置 下载Python和PyCharm并安装。 ​ 使用安装的python本身作为解释器。 ​ 安装AGC Python SDK。 ​云存储包安装完成。 ​ 2、AGC环境配置 在AGC创建项目和应用 ​ 开通云存储服务。 返回项目设置界面,选择Server SDK 页签…

Redis高可用之主从复制及哨兵模式

一、Redis的主从复制 1.1 Redis主从复制定义 主从复制是redis实现高可用的基础,哨兵模式和集群都是在主从复制的基础之上实现高可用; 主从复制实现数据的多级备份,以及读写分离(主服务器负责写,从服务器只能读) 1.2 主从复制流…

系统中正在运行的进程数量和等待运行时间的可运行进程数量计算猜测

vmstat的r输出是包含正在运行和等待运行时间的可运行进程的数量。 ps r的输出中是只包含正在运行的进程,于是通过“ ps r |awk ‘{if (NR >1) print $0}’ |wc -l ”可以统计出正在运行中的进程的数量。 那么根据上面的结果,等待运行时间的进程的…

创建vue项目体验

文章目录 使用vue-cli创建vue项目创建出的项目目录结构配置router 运行问题router未找到eslint报错 首页显示单页面内容替换 使用vue-cli创建vue项目 安装vue-cli,创建基本项目 选择步骤 一般创建成功后,提示使用下面的指令运行demo npm run serve创建…

docker compose搭建渗透测试vulstudy靶场示例

前言 渗透测试(Penetration test)即网络安全工程师/安全测试工程师/渗透测试工程师通过模拟黑客,在合法授权范围内,通过信息搜集、漏洞挖掘、权限提升等行为,对目标对象进行安全测试(或攻击)&am…

人物血条的制作_unity基础开发教程

人物血条的制作 场景创建导入素材血条制作血量控制代码部分 场景创建 随便创建一个地板、一个胶囊体,搭建一个简易的场景,我这里就继续使用前面文章创建的场景 导入素材 在unity编辑器中选择Window,点击Asset Store 点击Search online 在搜…

CVE-2022-21661

简介 CVE-2022-21661是一个与WordPress相关的漏洞,涉及到SQL注入问题。该漏洞主要源于WordPress的WQ_Tax_Query类中的clean_query函数,可能允许攻击者通过控制传递给该函数的数据来控制生成的SQL查询,从而执行任意的SQL代码。 当WordPress的…

C语言运算符详解

详细介绍了C语言表达式、算术运算符、赋值运算符、关系运算符、条件结构、逻辑运算符、位运算符的语法和使用方法,并讨论了运算符的优先级。 1、表达式与算术运算符 在C语言中,表达式是一个类似数学中的算式,表达式由变量、字面值、常量、运…