微调Fine tune

网络架构
一个神经网络一般可以分为两块

  • 特征抽取将原始像素变成容易线性分割的特征
  • 线性分类器来做分类
    在这里插入图片描述

微调:使用之前已经训练好的特征抽取模块来直接使用到现有模型上,而对于线性分类器由于标号可能发生改变而不能直接使用

训练
是一个目标数据集上的正常训练任务,但使用更强的正则化

  • 使用更小的学习率
  • 使用更小的数据迭代
    源数据集远复杂于目标数据,通常微调效果会更好

重用分类器权重

  • 源数据集可能也有目标数据中的部分标号
  • 可以使用预训练好模型分类器中对应标号对应的向量来做初始化

固定一些层
神经网络通常学习有层次的特征表示

  • 低层次的特征更加通用
  • 高层次的特征则更和数据集相关
    可以固定底部一些层的参数,不参与更新来减小模型的复杂度
  • 更强的正则

微调通过使用在大数据上得到的预训练好的模型来初始化模型权重来完成提升精度
预训练模型质量很重要
微调通常速度更快,精度更好

就是重用在大数据集上训练好的模型的特征提取模块,用来做自己模型的特征提取的初始化,用来使得相比于随机初始化有更好的效果

1. 实现

%matplotlib inline
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
# 热狗数据集来源于网络d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip','fba480ffa8aa7e0febbb511d181409f899b9baa5')data_dir = d2l.download_extract('hotdog')train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))
# 图像的大小和纵横比各有不同hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4)
# 数据增广normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], # 因为要使用imageNet上的特征提取模块,所以要对数据先进行归一化【方差,均值】[0.229, 0.224, 0.225])train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(), normalize])test_augs = torchvision.transforms.Compose([torchvision.transforms.Resize(256),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(), normalize])
# 定义和初始化模型pretrained_net = torchvision.models.resnet18(pretrained=True)pretrained_net.fc # 输出Linear(in_features=512, out_features=1000, bias=True)finetune_net = torchvision.models.resnet18(pretrained=True)
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2) # 只对最后一层的类别改变
nn.init.xavier_uniform_(finetune_net.fc.weight)
# 微调模型def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,param_group=True):train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'),transform=train_augs),batch_size=batch_size, shuffle=True)test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'),transform=test_augs),batch_size=batch_size)devices = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction="none")if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ["fc.weight", "fc.bias"]]trainer = torch.optim.SGD([{'params': params_1x}, {'params': net.fc.parameters(), # 最后一层分类器的学习率提高十倍,为了使其能够更快的学习'lr': learning_rate * 10}], lr=learning_rate,weight_decay=0.001)else:trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,weight_decay=0.001)d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)
# 使用较小的学习率train_fine_tuning(finetune_net, 5e-5)# 为了进行比较, 所有模型参数初始化为随机值 -》结果没有之前微调的效果好
scratch_net = torchvision.models.resnet18()
scratch_net.fc = nn.Linear(scratch_net.fc.in_features, 2)
train_fine_tuning(scratch_net, 5e-4, param_group=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/236362.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot整合validation数据校验

1. 首先引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-validation</artifactId></dependency> 点标识进去可以发现是通过Hibernate Validator使用 Java Bean Validation 2. 属性上…

二级分类菜单及三级分类菜单的层级结构返回

前言 在开发投诉分类功能模块时&#xff0c;遇到过这样一个业务场景&#xff1a;后端需要按层级结构返回二级分类菜单所需数据&#xff0c;换言之&#xff0c;将具有父子关系的List结果集数据转为树状结构数据来返回 二级分类菜单 前期准备 这里简单复刻下真实场景中 出现的…

Ubuntu22.04 使用Docker部署Neo4j出错 Exited(70)

项目场景&#xff1a; 最近需要使用Neo4j图数据库&#xff0c;因此打算使用docker部署 环境使用WSL Ubuntu22.04 问题描述 拉下最新Neo4j镜像&#xff0c;执行命令部署 启动容器脚本 docker run -d -p 7474:7474 -p 7687:7687 \ --name neo4j \ --env "NEO4J_AUTHneo…

Python语言学习笔记之七(JOSN应用)

本课程对于有其它语言基础的开发人员可以参考和学习&#xff0c;同时也是记录下来&#xff0c;为个人学习使用&#xff0c;文档中有此不当之处&#xff0c;请谅解。 1、认识Json JSON (JavaScript Obiect Notation)是一种轻量级的数据交换格式&#xff0c;它是ECMAScript的一…

英国人工智能初创公司Stability AI面临卖身压力;深度学习中的检索增强生成简介

&#x1f989; AI新闻 &#x1f680; 英国人工智能初创公司Stability AI面临卖身压力 摘要&#xff1a;多位知情人士透露&#xff0c;英国人工智能初创公司Stability AI正寻求出售公司&#xff0c;因为投资者对其财务状况的压力越来越大。管理层最近几周一直将自己标榜为收购…

Day46力扣打卡

最近一直在做以前的题&#xff0c;刷题量都没有怎么增长&#xff0c;感觉自己算法一直不太行&#xff0c;但也只能菜就多练了。 打卡记录 由子序列构造的最长回文串的长度&#xff08;区间DP&#xff09; 链接 第二次刷这道题&#xff0c;相比上回思路来的很快&#xff0c;但…

码云配置遇到秘钥不正确

你这个就是秘钥没有和git绑定&#xff0c; 需要 git config --global user.name "你的用户名随便写" git config --global user.email "你的邮箱"

SpringBoot+Redis获取电脑信息

获取电脑信息 测试 System.getProperties(); System: 是Java中的一个内置类&#xff0c;用于提供与系统相关的功能和信息。这个类中包含了一些静态方法和常量&#xff0c;可以让您方便地访问和操作系统级别的资源。 getProperties(): 是一个静态方法&#xff0c;它返回一个表示…

Windows11如何让桌面图标的箭头消失(去掉快捷键箭头)

在Windows 11中&#xff0c;桌面图标的箭头是快捷方式图标的一个标志&#xff0c;用来表示该图标是一个指向文件、文件夹或程序的快捷方式。如果要隐藏这些箭头&#xff0c;你需要修改Windows注册表或使用第三方软件。 在此之前&#xff0c;我需要提醒你&#xff0c;修改注册表…

俄罗斯方块小游戏开发

代码图&#xff1a; import pygame, randompygame.init()# 游戏界面参数 width 300 height 600 surface pygame.display.set_mode((width, height))# 颜色定义 black (0, 0, 0) white (255, 255, 255) red (200, 0, 0) green (0, 200, 0) blue (0, 0, 200)# 俄罗斯方块…

我与开源的历程

我在2000年开始接触开源&#xff0c;当时在松下航空电子美国总部工作。我负责将 IFE 系统从 Win31 迁移到 Linux。作为一个完全不懂 Linux 的小白&#xff0c;我不得不找到一台笔记本电脑安装并自学 Redhat Linux 6.1。2003年回到新加坡后&#xff0c;我发现没有一个凝聚 Linux…

Echarts 柱状图添加标记 最大值 最小值 平均值

标记 最大值 最小值 series: [//图表配置项 如大小&#xff0c;图表类型{name: 图例,type: bar,//图表类型data: [{value: 500,time: 2012-11-12},{value: 454,time: 2020-5-17},{value: 544,time: 2022-1-22},{value: 877,time: 2013-1-30}, {value: 877,time: 2012-11-12}] …