知识蒸馏代码实现(以MNIST手写数字体为例,自定义MLP网络做为教师和学生网络)

dataloader_tools.py

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoaderdef load_data():# 载入MNIST训练集train_dataset = torchvision.datasets.MNIST(root = "../datasets/",train=True,transform=transforms.ToTensor(),download=True)# 载入MNIST测试集test_dataset = torchvision.datasets.MNIST(root = "../datasets/",train=False,transform=transforms.ToTensor(),download=True)# 生成训练集和测试集的dataloadertrain_dataloader = DataLoader(dataset=train_dataset,batch_size=12,shuffle=True)test_dataloader = DataLoader(dataset=test_dataset,batch_size=12,shuffle=False)return train_dataloader,test_dataloader

models.py

import torch
from torch import nn
# 教师模型
class TeacherModel(nn.Module):def __init__(self,in_channels=1,num_classes=10):super(TeacherModel,self).__init__()self.relu = nn.ReLU()self.fc1 = nn.Linear(784,1200)self.fc2 = nn.Linear(1200,1200)self.fc3 = nn.Linear(1200,num_classes)self.dropout = nn.Dropout(p=0.5) #p=0.5是丢弃该层一半的神经元.def forward(self,x):x = x.view(-1,784)x = self.fc1(x)x = self.dropout(x)x = self.relu(x)x = self.fc2(x)x = self.dropout(x)x = self.relu(x)x = self.fc3(x)return xclass StudentModel(nn.Module):def __init__(self,in_channels=1,num_classes=10):super(StudentModel,self).__init__()self.relu = nn.ReLU()self.fc1 = nn.Linear(784,20)self.fc2 = nn.Linear(20,20)self.fc3 = nn.Linear(20,num_classes)def forward(self,x):x = x.view(-1,784)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.relu(x)x = self.fc3(x)return x

train_tools.py

from torch import nn
import time
import torch
import tqdm
import torch.nn.functional as Fdef train(epochs, model, model_name, lr,train_dataloader,test_dataloader,device):# ----------------------开始计时-----------------------------------start_time = time.time()# 设置参数开始训练best_acc, best_epoch = 0, 0criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=lr)for epoch in range(epochs):model.train()# 训练集上训练模型权重for data, targets in tqdm.tqdm(train_dataloader):# 把数据加载到GPU上data = data.to(device)targets = targets.to(device)# 前向传播preds = model(data)loss = criterion(preds, targets)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()# 测试集上评估模型性能model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x, y in test_dataloader:x = x.to(device)y = y.to(device)preds = model(x)predictions = preds.max(1).indices  # 返回每一行的最大值和该最大值在该行的列索引num_correct += (predictions == y).sum()num_samples += predictions.size(0)acc = (num_correct / num_samples).item()if acc > best_acc:best_acc = accbest_epoch = epoch# 保存模型最优准确率的参数torch.save(model.state_dict(), f"../weights/{model_name}_best_acc_params.pth")model.train()print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc),f'loss={loss}')print(f'最优准确率的epoch为{best_epoch},值为:{best_acc},最优参数已经保存到:weights/{model_name}_best_acc_params.pth')# -------------------------结束计时------------------------------------end_time = time.time()run_time = end_time - start_time# 将输出的秒数保留两位小数if int(run_time) < 60:print(f'训练用时为:{round(run_time, 2)}s')else:print(f'训练用时为:{round(run_time / 60, 2)}minutes')def distill_train(epochs,teacher_model,student_model,model_name,train_dataloader,test_dataloader,alpha,lr,temp,device):# -------------------------------------开始计时--------------------------------start_time = time.time()# 定以损失函数hard_loss = nn.CrossEntropyLoss()soft_loss = nn.KLDivLoss(reduction="batchmean")# 定义优化器optimizer = torch.optim.Adam(student_model.parameters(), lr=lr)best_acc,best_epoch = 0,0for epoch in range(epochs):student_model.train()# 训练集上训练模型权重for data,targets in tqdm.tqdm(train_dataloader):# 把数据加载到GPU上data = data.to(device)targets = targets.to(device)# 教师模型预测with torch.no_grad():teacher_preds = teacher_model(data)# 学生模型预测student_preds = student_model(data)# 计算hard_lossstudent_hard_loss = hard_loss(student_preds,targets)# 计算蒸馏后的预测结果及soft_lossditillation_loss = soft_loss(F.softmax(student_preds/temp,dim=1),F.softmax(teacher_preds/temp,dim=1))# 将hard_loss和soft_loss加权求和loss = temp * temp * alpha * student_hard_loss + (1-alpha)*ditillation_loss# 反向传播,优化权重optimizer.zero_grad()loss.backward()optimizer.step()#测试集上评估模型性能student_model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x,y in test_dataloader:x = x.to(device)y = y.to(device)preds = student_model(x)predictions = preds.max(1).indices #返回每一行的最大值和该最大值在该行的列索引num_correct += (predictions ==y).sum()num_samples += predictions.size(0)acc = (num_correct/num_samples).item()if acc>best_acc:best_acc = accbest_epoch = epoch# 保存模型最优准确率的参数torch.save(student_model.state_dict(),f"../weights/{model_name}_best_acc_params.pth")student_model.train()print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch+1,acc))print(f'student_hard_loss={student_hard_loss},ditillation_loss={ditillation_loss},loss={loss}')print(f'最优准确率的epoch为{best_epoch},值为:{best_acc},')# --------------------------------结束计时----------------------------------end_time = time.time()run_time = end_time - start_time# 将输出的秒数保留两位小数if int(run_time) < 60:print(f'训练用时为:{round(run_time, 2)}s')else:print(f'训练用时为:{round(run_time / 60, 2)}minutes')

训练教师网络

import torch
from torchinfo import summary #用来可视化的
import models
import dataloader_tools
import train_tools# 设置随机数种子
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True# 载入MNIST训练集和测试集
train_dataloader,test_dataloader = dataloader_tools.load_data()# 定义教师模型
model = models.TeacherModel()
model = model.to(device)
# 打印模型的参数
summary(model)# 定义参数并开始训练
epochs = 10
lr = 1e-4
model_name = 'teacher'
train_tools.train(epochs,model,model_name,lr,train_dataloader,test_dataloader,device)
最优准确率的epoch为9,值为:0.9868999719619751

用非蒸馏的方法训练学生网络

import torch
from torchinfo import summary #用来可视化的
import dataloader_tools
import models
import train_tools# 设置随机数种子
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True# 生成训练集和测试集的dataloader
train_dataloader,test_dataloader = dataloader_tools.load_data()# 从头训练学生模型
model = models.StudentModel()
model = model.to(device)
# 查看模型参数
print(summary(model))# 定义参数并开始训练
epochs = 10
lr = 1e-4
model_name = 'student'
train_tools.train(epochs, model, model_name, lr,train_dataloader,test_dataloader,device)
最优准确率的epoch为9,准确率为:0.9382999539375305,最优参数已经保存到:weights/student_best_acc_params.pth
训练用时为:1.74minutes

用知识蒸馏的方法训练student model

import torch
import train_tools
import models
import dataloader_tools# 设置随机数种子
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True# 加载数据
train_dataloader,test_dataloader = dataloader_tools.load_data()# 加载训练好的teacher model
teacher_model = models.TeacherModel()
teacher_model = teacher_model.to(device)
teacher_model.load_state_dict(torch.load('../weights/teacher_best_acc_params.pth'))
teacher_model.eval()# 准备新的学生模型
student_model = models.StudentModel()
student_model = student_model.to(device)
student_model.train()# 开始训练
lr = 0.0001
epochs = 20
alpha = 0.3 # hard_loss权重
temp = 7 # 蒸馏温度
model_name = 'distill_student_loss'
# 调用train_tools中的
train_tools.distill_train(epochs,teacher_model,student_model,model_name,train_dataloader,test_dataloader,alpha,lr,temp,device)
最优准确率的epoch为9,值为:0.9204999804496765,
训练用时为:2.14minutes

在这里插入图片描述

loss改为:

# temp的平方乘在student_hard_loss
loss = temp * temp * alpha * student_hard_loss + (1 - alpha) * ditillation_loss
最优准确率的epoch为9,值为:0.9336999654769897,
训练用时为:2.12minutes

loss改为:

# temp的平方乘ditillation_loss
loss = alpha * student_hard_loss + temp * temp * (1 - alpha) * ditillation_loss
最优准确率的epoch为9,值为:0.9176999926567078,
训练用时为:2.09minutes

上面的几种loss,蒸馏损失都出现了负数的情况。不太对劲。
在这里插入图片描述

其它开源的知识蒸馏算法如下:

open-mmlab开源的工具箱包含知识蒸馏算法

mmrazor

github.com/open-mmlab/mmrazor

在这里插入图片描述

NAS:神经架构搜索
剪枝:Pruning
KD: 知识蒸馏
Quantization: 量化

自定义知识蒸馏算法:
在这里插入图片描述

mmdeploy

可以把算法部署到一些厂商支持的中间格式,如ONNX,tensorRT等。

在这里插入图片描述

HobbitLong的RepDistiller

github.com/HobbitLong/RepDistiller

在这里插入图片描述
在这里插入图片描述
里面有12种最新的知识蒸馏算法。

蒸馏网络可以应用于同一种模型,将大的学习的知识蒸馏到小的上面。
如下将resnet100做教师网络,resnet32做学生网络。

在这里插入图片描述

将一种模型迁移到另一种模型上。如vgg13做教师网络,mobilNetv2做学生网络:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/235421.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【散列函数的构造方法(直接定址法 ==除留余数法==),散列表的查找(1.开放地址法,2.链地址法(拉链法))】

文章目录 散列函数的构造方法直接定址法除留余数法 散列表的查找1.开放地址法线性探测法二次探测法伪随机探测法 2.链地址法&#xff08;拉链法&#xff09; 散列表的查找效率 散列函数的构造方法 散列存储 选取某个函数&#xff0c;依该函数按关键字计算元素的存储位置。 Loc…

单宁对葡萄酒可饮用性和陈酿潜力会有影响吗?

当在酿酒过程中葡萄酒中的单宁过量时&#xff0c;酿酒师可以使用白蛋白、酪蛋白和明胶等各种细化剂&#xff0c;这些药物可以与单宁分子结合&#xff0c;并将其作为沉淀物沉淀出来。随着葡萄酒的老化&#xff0c;单宁将形成长长的聚合链&#xff0c;氧气可以与单宁分子结合&…

正式版PS 2024 25新增功能 刚刚发布的虎标正式版

Adobe Photoshop 2024是一款业界领先的图像编辑软件&#xff0c;被广泛应用于设计、摄影、插图等领域。以下是这款软件的一些主要功能和特点&#xff1a; 丰富的工具和功能。Adobe Photoshop 2024提供了丰富的工具和功能&#xff0c;可以帮助用户对图像进行编辑、修饰和优化。…

回文链表,剑指offer 27,力扣 61

目录 题目&#xff1a; 我们直接看题解吧&#xff1a; 解题方法&#xff1a; 难度分析&#xff1a; 审题目事例提示&#xff1a; 解题分析&#xff1a; 解题思路&#xff08;数组列表双指针&#xff09;&#xff1a; 代码说明补充&#xff1a; 代码实现&#xff1a; 代码实现&a…

Course1-Week2-多输入变量的回归问题

Course1-Week2-多输入变量的回归问题 文章目录 Course1-Week2-多输入变量的回归问题1. 向量化和多元线性回归1.1 多维特征1.2 向量化1.3 用于多元线性回归的梯度下降法 2. 使梯度下降法更快收敛的技巧2.1 特征缩放2.2 判断梯度下降是否收敛2.3 如何设置学习率 3. 特征工程3.1 选…

Python中的Slice函数:灵活而强大的序列切片技术

更多资料获取 &#x1f4da; 个人网站&#xff1a;ipengtao.com Python中的Slice函数是一种强大且灵活的序列切片技术&#xff0c;用于从字符串、列表、元组等序列类型中提取子集。本文将深入研究Slice函数的功能和用法&#xff0c;提供详细的示例代码和解释&#xff0c;帮助读…

MAVEN冲突解决

MAVEN冲突解决 1.安装下面这个插件 2.安装成功点击pom文件 dependency analyzer标志&#xff0c;说明maven helper插件就安装成功 3.点击dependency analyzer之后就会进入到下面的页面 4.标记红色就是版本冲突&#xff0c;右击complie&#xff0c;排除不是使用的 5.POM 文件…

CISO在2024年应该优先考虑七项安全任务

专业安全媒体CyberTalk.org主编Shira Landau日前表示&#xff1a;现代企业的CISO们在2024年必须做出改变&#xff0c;要更多关注于企业整体安全路线图的推进与实现&#xff0c;让网络安全工作与业务发展目标保持更紧密的一致性。 首席信息安全官&#xff08;CISO&#xff09;是…

行行AI董事长李明顺:今天每个人都可以成为AI应用的创业者

“ AI创业的核心在于真正介入到应用层面&#xff0c;AI应该成为真正的应用支撑。 ” 整理 | 王娴 编辑 | 云舒 出品&#xff5c;极新 2023年11月28日&#xff0c;极新AIGC行业峰会在北京东升国际科学园顺利召开&#xff0c;行行AI董事长李明顺先生在会上做了题为《从大模型…

解决CentOS下PHP system命令unoconv转PDF提示“Unable to connect or start own listener“

centos系统下&#xff0c;用php的system命令unoconv把word转pdf时提示Unable to connect or start own listene的解决办法 unoconv -o /foo/bar/public_html/upload/ -f pdf /foo/bar/public_html/upload/test.docx 2>&1 上面这个命令在shell 终端能执行成功&#xff0c…

BootStrap完整页面尝试(感兴趣的同学可以做)

试采用BootStrap技术或者htmlcss&#xff0c;完成以下页面。 题目为选做&#xff0c;有兴趣的同学可以尝试。

C++:类和对象(中)

1.类的6个默认成员函数&#xff1a; 如果一个类中什么成员都没有&#xff0c;简称为空类。 空类中真的什么都没有吗&#xff1f;并不是&#xff0c;任何类在什么都不写时&#xff0c;编译器会自动生成以下6个默认成员函数。 默认成员函数&#xff1a;用户没有显式实现&#xff…