数据增强,迁移学习,Resnet分类实战

目录

1. 数据增强(Data Augmentation)

2. 迁移学习

3. 模型保存    

4. 102种类花分类实战

1. 数据集

2.导入包

3. 数据读取与预处理操作 

4. Datasets制作输入数据

5.将标签的名字读出 

6.展示原始数据 

7.加载models中提供的模型 

8.初始化 

9.优化器设置 

10.训练模块


1. 数据增强(Data Augmentation)

        数据不够怎么办?采用翻转,镜像,增加数据

        如何更加高效利用数据?多利用几次

        在pytorch中有数据预处理部分:

            数据增强:torchvision中transforms模块自带功能,比较实用

            数据预处理:torchvision中transforms也帮我们实现好了,直接调用即可

            DataLoader模块直接读取batch数据

        pyorch官网:https://pytorch.org/vision/stable

2. 迁移学习

        在训练自己的模型时出现一些问题:

        1. 自己的数据不够好

        2. 训练参数花费时间多

        3. 训练模型太难

        解决方法:

        有前人已经训练好了模型,其实就是将训练的参数保留下来,而且目标都差不多。那么把别人的模型参数当成初始化参数,所有的结构和前人模型一样。

        网络模块设置:

    加载预训练模型,torchvision中有很多经典网络架构,调用起来十分方便,并且也可以用人家训练好的权重参数来继续训练,也就是所谓的迁移学习

    需要注意的是别人训练好的任务根咱们的可不是完全一样的,需要把最后的head层改一改,一般也就是最后的全连接层,改成咱们自己的任务

    训练时可以完全重头训练,也可以只训练最后咱们任务层,因为前几层都是做特征提取的,本质任务目标一致的。

        总结:迁移学习策略

                1. 将卷积层当成初始化权重参数

                2.将卷积层权重参数冻住不变,全连接层重新训练(一般是,数据量少,冻住的层数多)

3. 模型保存    

        网络模型保存与测试

            模型保存的时候可以带有选择性,例如在验证集中如果当前效果好则保存

            读取模型进行实际测试

4. 102种类花分类实战

      1. 数据集

        有训练集,测试集。一共102种花,每种花有25~100个图像

2.导入包

import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms,models,datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

3. 数据读取与预处理操作 

data_dir = './flower_data'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

制作好数据源:

    data_transforms中指定了所有图像预处理操作

    ImageFolder假设所有文件按文件夹保存好,每个文件夹下面存储同一类别的图片,文件夹的名字为分类的名字

data_transforms = {'train' : transforms.Compose([transforms.RandomRotation(45),#随即旋转,-45度到45度之间随机选transforms.CenterCrop(224), #从中心点开始裁剪transforms.RandomHorizontalFlip(p=0.5),#随即水平翻转,选择一个概率transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1), #参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=Btransforms.ToTensor(),#转换成tensor格式transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])#均值,标准差]), 'valid':transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])                   ])
}

 4. Datasets制作输入数据

        采用batch,将数据分组输入。

batch_size  = 8image_datasets = {x : datasets.ImageFolder(os.path.join(data_dir,x),data_transforms[x]) for x in ['train','valid']}
dataloaders = {x : torch.utils.data.DataLoader(image_datasets[x],batch_size=batch_size,shuffle=True) for x in ['train','valid']}
dataset_szies = {x :len(image_datasets[x]) for x in ['train','valid']}
class_names = image_datasets['train'].classes
print(image_datasets)
print(dataloaders)

5.将标签的名字读出 

        用123....打标签好像不好,用花的名字作为标签

#读取标签对应的实际名字
with open('./flower_data/cat_to_name.json','r') as f:cat_to_name = json.load(f)
print(cat_to_name)

6.展示原始数据 

        展示下数据

            注意tensor的数据需要转换成numpy格式,而且还需要还原成标准化的结果

def im_convert(tensor):'''展示数据'''image = tensor.to('cpu').clone().detach()image = image.numpy().squeeze()image = image.transpose(1,2,0)image = image * np.array((0.229,0.224,0.225)) + np.array((0.485,0.456,0.406))image = image.clip(0,1)return imagefig = plt.figure(figsize=(20,12))
colunms = 4
rows = 2dataiter = iter(dataloaders['valid'])
inputs,classes =next(dataiter)for idx in range(colunms * rows):ax = fig.add_subplot(rows,colunms,idx+1,xticks = [],yticks = [])ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])plt.imshow(im_convert(inputs[idx]))
plt.show()

7.加载models中提供的模型 

        加载models中提供的模型,并且直接用训练好的权重当作初始化参数

            第一次执行需要下载,可能会比较慢

model_name = 'resnet' #可选的会比较多['resnet','alexnet','vgg','squeezenet','densenet','inception']
# 是否用人家训练好的特征来做
feature_extract = True#是否用GPU训练
train_on_gpu = torch.cuda.is_available()if not train_on_gpu:print('CUDA is not available .  Training on CPU...')
else:print('CUDA is  available .  Training on GPU...')device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')def set_parameter_requires_grad(model,feature_extracting):if feature_extract:for param in model.parameters():param.requires_grad = False #冻不冻住model_ft = models.resnet152()
print(model_ft)

8.初始化 

        迁移学习,用前人的参数,改变全连接层。

def initalize_model(model_name,num_classes,feature_extract,use_pretrained=True):#选择合适的模型,不同模型的初始化方法稍微有点区别model_ft = Noneinput_size = 0if model_name == 'resnet':model_ft = models.resnet152(pretrained=use_pretrained)set_parameter_requires_grad(model_ft,feature_extract)num_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Sequential(nn.Linear(num_ftrs,num_classes),nn.LogSoftmax(dim=1))input_size = 224return model_ft,input_sizefeature_extract = True
model_ft,input_size = initalize_model(model_name,102,feature_extract,use_pretrained=True)#GPU计算
model_ft = model_ft.to(device)#模型保存
filename = 'checkpoint.pth'#是否训练所有层params_to_updata = model_ft.parameters()
print('Params to learn')
if feature_extract:params_to_updata = []for name,param in model_ft.named_parameters():if param.requires_grad == True:params_to_updata.append(param)print("\t",name)
else:for name,param in model_ft.named_parameters():if param.requires_grad == True:print("\t",name)print(model_ft)

 

9.优化器设置 

#优化器设置
optimizer_ft = optim.Adam(params_to_updata,lr=1e-2)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft,step_size=7,gamma=0.1) #学习率每7个epoch衰减成原来的1/10
#最后一层已经LogSoftmax()了,所以不能nn.CrossEntropyLoss()来计算了,nn.CrossEntropyLoss()相当于logSoftmax()和nn.NLLoss()整合
criterion = nn.NLLLoss()

10.训练模块

def train_model(model,dataloaders,criterion,optimizer,num_epochs=25,is_inception=False,filename=filename):since = time.time()best_acc = 0'''checkpoint = torch.laod(filename)best_acc = checkpoint['best_acc]model.load_state_dict(checkpoint['optimizer'])model.class_to_idx = checkpoint['mapping']'''model.to(device)val_acc_history = []train_acc_history = []train_losses = []vaild_losses = []LRs = [optimizer.param_groups[0]['lr']]best_model_wts = copy.deepcopy(model.state_dict())for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch,num_epochs-1))print('-'*10)#训练和验证for phase in ['train','valid']:if phase == 'train':model.train() #训练else:model.eval()  #验证running_loss = 0.0running_corrects = 0#把数据都取个遍for inputs,labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)#清零optimizer.zero_grad()#只有训练的时候计算和更新梯度with torch.set_grad_enabled(phase=='train'):if is_inception and phase == 'train':outputs,aux_outputs = model(inputs)loss1 = criterion(outputs,labels)loss2 = criterion(aux_outputs,labels)loss = loss1 + loss2else: #resnet执行的是这里outputs= model(inputs)loss = criterion(outputs,labels)_,preds = torch.max(outputs,1)#训练阶段更新权重if phase == 'train':loss.backward()optimizer.step()#计算损失running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(dataloaders[phase].dataset)epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)time_elapsed = time.time() - sinceprint('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase,epoch_loss,epoch_acc))#得到最好的那次的模型if phase == 'valid' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())state ={'state_dict' : model.state_dict(),'best_acc': best_acc,'optimizer': optimizer.state_dict(),}torch.save(state,filename)if phase == 'valid':val_acc_history.append(epoch_acc)vaild_losses.append(epoch_loss)scheduler.step(epoch_loss)if phase=='train':train_acc_history.append(epoch_acc)train_losses.append(epoch_loss)print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))LRs.append(optimizer.param_groups[0]['lr'])print()time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60,time_elapsed % 60))print('Best val Acc: {:.4f}'.format(best_acc))#训练完后用最好的一次当作模型的最终结果model.load_state_dict(best_model_wts)return model, val_acc_history,train_acc_history,vaild_losses,train_losses,LRs
#开始训练!!!
model_ft,val_acc_history,train_acc_history,vaild_losses,train_losses,LRs = train_model(model_ft,dataloaders,criterion,optimizer_ft,num_epochs=5)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/690848.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

德国Dürr杜尔机器人维修技巧分析

在工业生产中,杜尔工业机器人因其高效、精准和稳定性而备受青睐。然而,即便是最精良的设备,也难免会出现Drr机械手故障。 一、传感器故障 1. 视觉传感器故障:可能导致机器人无法正确识别目标物,影响工作效率。解决方法…

力扣每日一题124:二叉树中的最大路径和

题目 困难 二叉树中的 路径 被定义为一条节点序列,序列中每对相邻节点之间都存在一条边。同一个节点在一条路径序列中 至多出现一次 。该路径 至少包含一个 节点,且不一定经过根节点。 路径和 是路径中各节点值的总和。 给你一个二叉树的根节点 root…

Apache ECharts

Apache ECharts介绍: Apache ECharts 是一款基于 Javascript 的数据可视化图表库,提供直观,生动,可交互,可个性化定制的数据可视化图表。 官网地址:https://echarts.apache.org/zh/index.html Apache ECh…

基于STC12C5A60S2系列1T 8051单片机实现一主单片机与一从单片机进行双向串口通信功能

基于STC12C5A60S2系列1T 8051单片机实现一主单片机与一从单片机进行双向串口通信功能 STC12C5A60S2系列1T 8051单片机管脚图STC12C5A60S2系列1T 8051单片机串口通信介绍STC12C5A60S2系列1T 8051单片机串口通信的结构基于STC12C5A60S2系列1T 8051单片机串口通信的特殊功能寄存器…

Django项目运行报错:ModuleNotFoundError: No module named ‘MySQLdb‘

解决方法: 在__init__.py文件下,新增下面这段代码 import pymysql pymysql.install_as_MySQLdb() 注意:确保你的 python 有下载 pymysql 库,没有的话可以使用 pip install pymysql安装 原理:用pymysql来代替mysqlL…

深度学习:基于人工神经网络ANN的降雨预测

前言 系列专栏:【深度学习:算法项目实战】✨︎ 本专栏涉及创建深度学习模型、处理非结构化数据以及指导复杂的模型,如卷积神经网络(CNN)、递归神经网络 (RNN),包括长短期记忆 (LSTM) 、门控循环单元 (GRU)、自动编码器 (AE)、受限玻尔兹曼机(…

docker搭建redis6.0(docker rundocker compose演示)

文章讲了:docker下搭建redis6.0.20遇到一些问题,以及解决后的最佳实践方案 文章实现了: docker run搭建redisdocker compose搭建redis 搭建一个redis’的过程中遇到很多问题,先简单说一下搭建的顺序 找一个redis.conf文件&…

LangChain:大模型框架的深度解析与应用探索

在数字化的时代浪潮中,人工智能技术正以前所未有的速度蓬勃发展,而大模型作为其中的翘楚,以生成式对话技术逐渐成为推动行业乃至整个社会进步的核心力量。再往近一点来说,在公司,不少产品都戴上了人工智能的帽子&#…

maven mirrorOf的作用

在工作中遇到了一个问题导致依赖下载不了,最后发现是mirror的问题,决定好好去看一下mirror的配置,以及mirrorOf的作用,以前都是直接复制过来使用,看了之后才明白什么意思。 过程 如果你设置了镜像,镜像会匹…

Doris【部署 01】Linux部署MPP数据库Doris稳定版(下载+安装+连接+测试)

本次安装测试的为稳定版2.0.8官方文档 https://doris.apache.org/zh-CN/docs/2.0/get-starting/quick-start 这个简短的指南将告诉你如何下载 Doris 最新稳定版本,在单节点上安装并运行它,包括创建数据库、数据表、导入数据及查询等。 Linux部署稳定版Do…

蓝桥杯单片机之模块代码《多样点灯方式》

过往历程 历程1:秒表 历程2:按键显示时钟 历程3:列矩阵按键显示时钟 历程4:行矩阵按键显示时钟 历程5:新DS1302 历程6:小数点精确后两位ds18b20 历程7:35定时器测量频率 历程8&#xff…

基于fastapi sqladmin开发,实现可动态配置admin

1. 功能介绍: 1. 支持动态创建表、类,属性,唯一约束、外键,索引,关系,无需写代码,快速创建业务对象; 2. 支持配置admin显示参数,支持sqladmin原生参数设置,动…