学习笔记13:微调模型

news/2025/3/25 18:17:17/文章来源:https://www.cnblogs.com/gongzb/p/18230156

转自:https://www.cnblogs.com/miraclepbc/p/14360807.html

resnet预训练模型

resnet模型与之前笔记中的vgg模型不同,需要我们直接覆盖掉最后的全连接层
先看一下resnet模型的结构:

我们需要先将所有的参数都设置成requires_grad = False
然后再重新定义fc层,并覆盖掉原来的。
重新定义的fc层的requires_grad默认为True

 
for p in model.parameters():p.requries_grad = Falsein_f = model.fc.in_features
model.fc = nn.Linear(in_f, 4)

当定义optimizer的时候,需要注意,传进去的参数是fc层的参数,而不是所有层的参数

optimizer = torch.optim.Adam(model.fc.parameters(), lr = 0.001)

微调

微调的一般步骤是:

  • 重新定义全连接层
  • 训练重新定义的全连接层
  • 解冻部分其他层
  • 训练整个模型
    注意:微调是在训练完新的全连接层后,才能进行的。也就相当于整个模型训练了两次。
    optimizer这时的参数就是整个模型的参数了。
    代码:
for param in model.parameters():param.requires_grad = Trueextend_epoch = 30
optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)

全部代码

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms, models
import os
import shutil
%matplotlib inlinetrain_transform = transforms.Compose([transforms.Resize(224),transforms.RandomCrop(192),transforms.RandomHorizontalFlip(),transforms.RandomRotation(0.2),transforms.ColorJitter(brightness = 0.5),transforms.ColorJitter(contrast = 0.5),transforms.ToTensor(),transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5])
])
test_transform = transforms.Compose([transforms.Resize((192, 192)),transforms.ToTensor(),transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5])
])
train_ds = datasets.ImageFolder("E:/datasets2/29-42/29-42/dataset2/4weather/train",transform = train_transform
)
test_ds = datasets.ImageFolder("E:/datasets2/29-42/29-42/dataset2/4weather/test",transform = test_transform
)
train_dl = torch.utils.data.DataLoader(train_ds, batch_size = 8, shuffle = True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size = 8)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = models.resnet101(pretrained = True)
for p in model.parameters():p.requries_grad = False
in_f = model.fc.in_features
model.fc = nn.Linear(in_f, 4)loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.fc.parameters(), lr = 0.001)
epochs = 30
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 7, gamma = 0.1)def fit(epoch, model, trainloader, testloader):correct = 0total = 0running_loss = 0model.train()for x, y in trainloader:x, y = x.to(device), y.to(device)y_pred = model(x)loss = loss_func(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()with torch.no_grad():y_pred = torch.argmax(y_pred, dim = 1)correct += (y_pred == y).sum().item()total += y.size(0)running_loss += loss.item()exp_lr_scheduler.step()epoch_acc = correct / totalepoch_loss = running_loss / len(trainloader.dataset)test_correct = 0test_total = 0test_running_loss = 0model.eval()with torch.no_grad():for x, y in testloader:x, y = x.to(device), y.to(device)y_pred = model(x)loss = loss_func(y_pred, y)y_pred = torch.argmax(y_pred, dim = 1)test_correct += (y_pred == y).sum().item()test_total += y.size(0)test_running_loss += loss.item()epoch_test_acc = test_correct / test_totalepoch_test_loss = test_running_loss / len(testloader.dataset)print('epoch: ', epoch, 'loss: ', round(epoch_loss, 3),'accuracy: ', round(epoch_acc, 3),'test_loss: ', round(epoch_test_loss, 3),'test_accuracy: ', round(epoch_test_acc, 3))return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acctrain_loss = []
train_acc = []
test_loss = []
test_acc = []
for epoch in range(epochs):epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch, model, train_dl, test_dl)train_loss.append(epoch_loss)train_acc.append(epoch_acc)test_loss.append(epoch_test_loss)test_acc.append(epoch_test_acc)for param in model.parameters():param.requires_grad = True
extend_epoch = 30
optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)
train_loss = []
train_acc = []
test_loss = []
test_acc = []
for epoch in range(extend_epoch):epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch, model, train_dl, test_dl)train_loss.append(epoch_loss)train_acc.append(epoch_acc)test_loss.append(epoch_test_loss)test_acc.append(epoch_test_acc)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/719707.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

成熟的双向同步方案,能够解决哪些同步问题?

在企业的数据流转管控过程中,经常会遇到频繁的数据备份、同步,人工重复这样的工作程序,既繁琐又容易出错。因此对于企业而言,选择一款高效且安全的同步软件成为了企业运营中的关键一环,不仅能够提高工作效率,还能确保数据的安全性。在选择双向同步方案时,首先要明确自己…

从数据库设计到性能调优,全面掌握openGemini应用开发最佳实践

据库设计和性能调优最重要的干货都在这里了!本文分享自华为云社区《DTSE Tech Talk openGemini :从数据库设计到性能调优,全面掌握openGemini应用开发最佳实践》,作者:华为云开源。 在本期《从数据库设计到性能调优,全面掌握openGemini应用开发最佳实践》的主题直播中,…

MBD闲谈 第03期:MBD的“禁区”——底层驱动

转载自:autoMBD, 版权归autoMBD所有,转载请注明作者和来源 原文链接:http://www.360doc.com/content/22/0820/17/15913066_1044626106.shtml全文约3562字,你将看到以下内容:底层驱动的那些事底层驱动为啥是MBD“禁区” 底层驱动与模型集成下期预告1 底层驱动的那些事 先…

allure的suites(测试套)中未显示返回值参数,显示No information about test execution is available.(转自大佬,亲测有用)

转自大佬:https://blog.csdn.net/sbdxmnz/article/details/137016423ExecutionNo information about test execution is available.解决方法: 添加代码,因为pytest输出文本形式测试报告时未存储响应内容 # 将接口响应的文本内容附加到Allure报告中 allure.attach(接口响应.…

学习笔记9:卷积神经网络实现MNIST分类(GPU加速)

转自:https://www.cnblogs.com/miraclepbc/p/14345342.html 相关包导入 import torch import pandas as pd import numpy as np import matplotlib.pyplot as plt from torch import nn import torch.nn.functional as F from torch.utils.data import TensorDataset from tor…

笔记2:张量简介

张量生成方法 转自:https://www.cnblogs.com/miraclepbc/p/14329476.html张量的形状及类型张量的计算张量的梯度手写线性回归张量生成方法 张量的形状及类型 张量的计算 张量的梯度 手写线性回归

笔记3:逻辑回归(分批次训练)

转自:https://www.cnblogs.com/miraclepbc/p/14332084.html 相关库导入 import torch import pandas as pd import numpy as np import matplotlib.pyplot as plt from torch import nn %matplotlib inline数据读入及预处理 data = pd.read_csv(E:/datasets/dataset/credit-a.…

【深度好文】到底什么是质量意识?如何衡量,如何提升?

大家好,我是狂师! 在软件测试中,质量意识是一个核心且至关重要的概念。相信大家,经常会听到:"这个家伙质量意识很强,某某某要提升质量意识“之类的话语。 在企业中,“质量意识”不仅关乎产品和服务的优劣,更是企业竞争力和可持续发展的关键因素。那么,到底什么是…

域名

顶级域名、二级域名与三级域名互联网名称与数字地址分配机构(ICANN)负责管理和协调国际互联网络域名系统。根据ICANN的定义,一个完整的域名至少有两个部分,各部分之间用“.”来分隔,最后一个“.”的右边部分称为顶级域名,也称为一级域名;最后一个“.”的左边部分称为二级…

异构数据源同步之数据同步 → DataX 使用细节

开心一刻 中午我妈微信给我消息 妈:儿子啊,妈电话欠费了,能帮妈充个话费吗 我:妈,我知道了,我帮你充 当我帮我妈把话费充好,正准备回微信的时候,我妈微信给我发消息了 妈:等会儿子,不用充了,刚刚有个二臂帮妈充上了 我输入框中的(妈,充好了)是发还是不发?简单使…

js日期格式化代码

js 日期格式化代码 分享一个前端实用的 js 日期格式化代码,相当给力。1 export function getFillDate(key) {2 if(key < 10) {3 return `0${key}`;4 }else{5 return `${key}`;6 }7 }8 /**9 * 时间戳转化为年月日 10 * @param times 时间戳 11 * @param ym…

一周万星的文本转语音开源项目「GitHub 热点速览」

上周的热门开源项目让我想起了「图灵测试」,测试者在不知道对面是机器还是人类的前提下随意提问,最后根据对方回复的内容,判断与他们交谈的是人还是计算机。如果无法分辨出回答者是机器还是人类,则说明机器已通过测试,具有人类的智力水平。 ​虽然现在大模型的回答还充满 …