知识蒸馏实战代码教学二(代码实战部分)

一、上章原理回顾

具体过程:

        (1)首先我们要先训练出较大模型既teacher模型。(在图中没有出现)

        (2)再对teacher模型进行蒸馏,此时我们已经有一个训练好的teacher模型,所以我们能很容易知道teacher模型输入特征x之后,预测出来的结果teacher_preds标签。

        (3)此时,求到老师预测结果之后,我们需要求解学生在训练过程中的每一次结果student_preds标签。

        (4)先求hard_loss,也就是学生模型的预测student_preds与真实标签targets之间的损失。

        (5)再求soft_loss,也就是学生模型的预测student_preds与教师模型teacher_preds的预测之间的损失。

        (6)求出hard_loss与soft_loss之后,求和总loss=a*hard_loss + (1-a)soft_loss,a是一个自己设置的权重参数,我在代码中设置为a=0.3。

        (7)最后反向传播继续迭代。

二、代码实现

1、数据集

        数据集采用的是手写数字的数据集mnist数据集,如果没有下载,代码部分中会进行下载,只需要把download改成True,然后就会保存在当前目录中。该数据集将其分成80%的训练集,20%的测试集,最后返回train_dataset和test_datatset。

class MyDataset(Dataset):def __init__(self,opt):self.opt = optdef MyData(self):## mnist数据集下载0mnist = datasets.MNIST(root='../datasets/', train=True, download=False, transform=transforms.Compose([transforms.Resize(self.opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),)dataset_size = len(mnist)train_size = int(0.8 * dataset_size)test_size = dataset_size - train_sizetrain_dataset, test_dataset = random_split(mnist, [train_size, test_size])train_dataloader = DataLoader(train_dataset,batch_size=self.opt.batch_size,shuffle=True,)test_dataloader = DataLoader(test_dataset,batch_size=self.opt.batch_size,shuffle=False,  # 在测试集上不需要打乱顺序)return train_dataloader,test_dataloader

2、teacher模型和训练实现

       (1) 首先是teacher模型构造,经过三次线性层。

import torch.nn as nn
import torchimg_area = 784class TeacherModel(nn.Module):def __init__(self,in_channel=1,num_classes=10):super(TeacherModel,self).__init__()self.relu = nn.ReLU()self.fc1 = nn.Linear(img_area,1200)self.fc2 = nn.Linear(1200, 1200)self.fc3 = nn.Linear(1200, num_classes)self.dropout = nn.Dropout(p=0.5)def forward(self, x):x = x.view(-1, img_area)x = self.fc1(x)x = self.dropout(x)x = self.relu(x)x = self.fc2(x)x = self.dropout(x)x = self.relu(x)x = self.fc3(x)return x

        (2)训练teacher模型

        老师模型训练完成后其权重参数会保存在teacher.pth当中,为以后调用。

import torch.nn as nn
import torch## 创建文件夹
from tqdm import tqdmfrom dist.TeacherModel import TeacherModelweight_path = 'C:/Users/26394/PycharmProjects/untitled1/dist/params/teacher.pth'
## 设置cuda:(cuda:0)
cuda = True if torch.cuda.is_available() else False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')torch.backends.cudnn.benchmark = True #使用卷积cuDNN加速class TeacherTrainer():def __init__(self,opt,train_dataloader,test_dataloader):self.opt = optself.train_dataloader = train_dataloaderself.test_dataloader = test_dataloaderdef trainer(self):# 老师模型opt = self.opttrain_dataloader = self.train_dataloadertest_dataloader = self.test_dataloaderteacher_model = TeacherModel()teacher_model = teacher_model.to(device)criterion = nn.CrossEntropyLoss()optimizer_teacher = torch.optim.Adam(teacher_model.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))for epoch in range(opt.n_epochs):  ## epoch:50teacher_model.train()for data, targets in tqdm(train_dataloader):data = data.to(device)targets = targets.to(device)preds = teacher_model(data)loss = criterion(preds, targets)optimizer_teacher.zero_grad()loss = criterion(preds, targets)loss.backward()optimizer_teacher.step()teacher_model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x, y in test_dataloader:x = x.to(device)y = y.to(device)preds = teacher_model(x)predictions = preds.max(1).indicesnum_correct += (predictions == y).sum()num_samples += predictions.size(0)acc = (num_correct / num_samples).item()torch.save(teacher_model.state_dict(), weight_path)teacher_model.train()print('teacher: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

        (3)训练teacher模型

        模型参数都在paras()当中设置好了,直接调用teacher_model就行,然后将其权重参数会保存在teacher.pth当中。

import argparseimport torchfrom dist.DistillationTrainer import DistillationTrainer
from dist.MyDateLoader import MyDataset
from dist.TeacherTrainer import TeacherTrainerdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def paras():## 超参数配置parser = argparse.ArgumentParser()parser.add_argument("--n_epochs", type=int, default=5, help="number of epochs of training")parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")parser.add_argument("--channels", type=int, default=1, help="number of image channels")parser.add_argument("--sample_interval", type=int, default=500, help="interval betwen image samples")opt = parser.parse_args()## opt = parser.parse_args(args=[])                 ## 在colab中运行时,换为此行print(opt)return optif __name__ == '__main__':opt = paras()data = MyDataset(opt)train_dataloader, test_dataloader = data.MyData()# 训练Teacher模型teacher_trainer = TeacherTrainer(opt,train_dataloader,test_dataloader)teacher_trainer.trainer()

 3、学生模型的构建

        学生模型也是经过了三次线性层,但是神经元没有teacher当中多。所以student模型会比teacher模型小很多。

import torch.nn as nn
import torchimg_area = 784class StudentModel(nn.Module):def __init__(self,in_channel=1,num_classes=10):super(StudentModel,self).__init__()self.relu = nn.ReLU()self.fc1 = nn.Linear(img_area,20)self.fc2 = nn.Linear(20, 20)self.fc3 = nn.Linear(20, num_classes)def forward(self, x):x = x.view(-1, img_area)x = self.fc1(x)# x = self.dropout(x)x = self.relu(x)x = self.fc2(x)# x = self.dropout(x)x = self.relu(x)x = self.fc3(x)return x

4、知识蒸馏训练

(1)首先读取teacher模型。

        将teacher模型中的权重参数teacher.pth放入模型当中。

 #拿取训练好的模型teacher_model = TeacherModel()if os.path.exists(weights):teacher_model.load_state_dict(torch.load(weights))print('successfully')else:print('not loading')teacher_model = teacher_model.to(device)

(2)设置损失求解的函数

        hard_loss用的就是普通的交叉熵损失函数,而soft_loss就是用的KL散度。

        # hard_losshard_loss = nn.CrossEntropyLoss()# hard_loss权重alpha = 0.3# soft_losssoft_loss = nn.KLDivLoss(reduction="batchmean")

(3)之后再进行蒸馏训练,温度为7

  •         先求得hard_loss就是用学生模型预测的标签和真实标签进行求得损失。
  •         再求soft_loss就是用学生模型预测的标签和老师模型预测的标签进行求得损失。使用softmax时候还需要进行除以温度temp。
  •         最后反向传播,求解模型
       for epoch in range(opt.n_epochs):  ## epoch:5for data, targets in tqdm(train_dataloader):data = data.to(device)targets = targets.to(device)# 老师模型预测with torch.no_grad():teacher_preds = teacher_model(data)# 学生模型预测student_preds = model(data)# 计算hard_lossstudent_loss = hard_loss(student_preds, targets)# 计算蒸馏后的预测损失ditillation_loss = soft_loss(F.softmax(student_preds / temp, dim=1),F.softmax(teacher_preds / temp, dim=1))loss = alpha * student_loss + (1 - alpha) * ditillation_lossoptimizer.zero_grad()loss.backward()optimizer.step()model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x, y in test_dataloader:x = x.to(device)y = y.to(device)preds = model(x)predictions = preds.max(1).indicesnum_correct += (predictions == y).sum()num_samples += predictions.size(0)acc = (num_correct / num_samples).item()model.train()print('distillation: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

(4)整个蒸馏训练代码

import torch.nn as nn
import torch
import torch.nn.functional as F
import os
from tqdm import tqdmfrom dist.StudentModel import StudentModel
from dist.TeacherModel import TeacherModelweights = 'C:/Users/26394/PycharmProjects/untitled1//dist/params/teacher.pth'# D_weight_path = 'C:/Users/26394/PycharmProjects/untitled1/dist/params/distillation.pth'
## 设置cuda:(cuda:0)
cuda = True if torch.cuda.is_available() else False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')torch.backends.cudnn.benchmark = True #使用卷积cuDNN加速class DistillationTrainer():def __init__(self,opt,train_dataloader,test_dataloader):self.opt = optself.train_dataloader = train_dataloaderself.test_dataloader = test_dataloaderdef trainer(self):opt = self.opttrain_dataloader = self.train_dataloadertest_dataloader = self.test_dataloader#拿取训练好的模型teacher_model = TeacherModel()if os.path.exists(weights):teacher_model.load_state_dict(torch.load(weights))print('successfully')else:print('not loading')teacher_model = teacher_model.to(device)teacher_model.eval()model = StudentModel()model = model.to(device)temp = 7# hard_losshard_loss = nn.CrossEntropyLoss()# hard_loss权重alpha = 0.3# soft_losssoft_loss = nn.KLDivLoss(reduction="batchmean")optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))for epoch in range(opt.n_epochs):  ## epoch:5for data, targets in tqdm(train_dataloader):data = data.to(device)targets = targets.to(device)# 老师模型预测with torch.no_grad():teacher_preds = teacher_model(data)# 学生模型预测student_preds = model(data)# 计算hard_lossstudent_loss = hard_loss(student_preds, targets)# 计算蒸馏后的预测损失ditillation_loss = soft_loss(F.softmax(student_preds / temp, dim=1),F.softmax(teacher_preds / temp, dim=1))loss = alpha * student_loss + (1 - alpha) * ditillation_lossoptimizer.zero_grad()loss.backward()optimizer.step()model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x, y in test_dataloader:x = x.to(device)y = y.to(device)preds = model(x)predictions = preds.max(1).indicesnum_correct += (predictions == y).sum()num_samples += predictions.size(0)acc = (num_correct / num_samples).item()model.train()print('distillation: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

(5)蒸馏训练的主函数

        该部分大致与teacher模型训练类似,只是调用不同。

import argparseimport torchfrom dist.DistillationTrainer import DistillationTrainer
from dist.MyDateLoader import MyDataset
from dist.TeacherTrainer import TeacherTrainerdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def paras():## 超参数配置parser = argparse.ArgumentParser()parser.add_argument("--n_epochs", type=int, default=5, help="number of epochs of training")parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")parser.add_argument("--channels", type=int, default=1, help="number of image channels")parser.add_argument("--sample_interval", type=int, default=500, help="interval betwen image samples")opt = parser.parse_args()## opt = parser.parse_args(args=[])                 ## 在colab中运行时,换为此行print(opt)return optif __name__ == '__main__':opt = paras()data = MyDataset(opt)train_dataloader, test_dataloader = data.MyData()# 训练Teacher模型# teacher_trainer = TeacherTrainer(opt,train_dataloader,test_dataloader)# teacher_trainer.trainer()distillation_trainer = DistillationTrainer(opt,train_dataloader,test_dataloader)distillation_trainer.trainer()

三、总结

        总的来说,知识蒸馏是一种有效的模型压缩技术,可以通过在模型训练过程中引入额外的监督信号来训练简化的模型,从而获得与大型复杂模型相近的性能,但具有更小的模型尺寸和计算开销。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/478898.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

人工智能_普通服务器CPU_安装清华开源人工智能AI大模型ChatGlm-6B_001---人工智能工作笔记0096

使用centos安装,注意安装之前,保证系统可以联网,然后执行yum update 先去更新一下系统,可以省掉很多麻烦 20240219_150031 这里我们使用centos系统吧,使用习惯了. ChatGlm首先需要一台个人计算机,或者服务器, 要的算力,训练最多,微调次之,推理需要算力最少 其实很多都支持C…

openEuler学习——mysql(第一次总结)

1、openEuler 二进制方式安装MySQL 8.0.x。 思路是先从官网获取安装包链接如下https://mirrors.aliyun.com/mysql/MySQL-8.0/mysql-8.0.28-linux-glibc2.12-x86_64.tar.xz 然后解压安装修改权限,可以参考mysql官方网站步骤 [rootopenEuler-node1 ~]# wget -c https:…

面试redis篇-03缓存击穿

原理 缓存击穿:给某一个key设置了过期时间,当key过期的时候,恰好这时间点对这个key有大量的并发请求过来,这些并发的请求可能会瞬间把DB压垮 解决方案一:互斥锁 解决方案二:逻辑过期 提问与回答 面试官 :什么是缓存击穿 ? 怎么解决 ? 回答: 缓存击穿的意思…

【Linux 内核源码分析】虚拟内存地址空间

在现代操作系统中,每个进程被分配了独享的虚拟内存地址空间。这个地址空间可以视为一维线性空间,由多个连续的内存页组成。初始时,操作系统会将整个虚拟地址空间分成几个不同的区域,每个区域用于特定的目的。以下是一个常见的布局…

motplotlib图例案例1:通过多个legend完全控制图例显示顺序(指定按行排序 or 按列排序)

这个方法的核心,是手动的获得图中的handlers和labels,然后对它们进行切分和提取,最后分为几个legend进行显示。代码如下: 后来对下面的代码进行修改,通过handlers, labels get_legend_handles_labels(axs[axis])自动的…

从入门到精通全面指南助你掌握数据分析的技术与实践【第89篇—数据分析入门到精通】

从入门到精通全面指南助你掌握数据分析的技术与实践 在当今数字化时代,数据成为决策制定和问题解决的关键。数据分析作为一门强大的技能,不仅在商业领域中备受推崇,而且在科学研究、社会分析等多个领域也发挥着至关重要的作用。本文将带你深…

Allegro172版本如何用自带功能改变过孔网络属性操作指导

Allegro172版本如何用自带功能改变过孔网络属性操作指导 在用Allegro做PCB设计的时候,时常会需要将过孔的网络进行变更,可以将原来的过孔删除,再重新打一个,这种方法难免会繁琐一些。 当然我们可以借助skill工具来完成更换过孔网络的更改,除此之外,Allegro自带的功能完成…

Linux 常用命令100+

Linux 运维/开发/测试 常用命令100 线上查询及帮助命令(2个) 命令功能说明示例man 命令查看普通命令帮助,命令的词典,更复杂的还有info,但不常用。rootbrLinux ~]#man lshelp 命令查看Linux内置命令的帮助,比如cd命令。[rootbrL…

GPT在机器学习,深度学习等领域应用

详情点击公众号:技术科研吧 链接:GPT在机器学习,深度学习等领域应用 第一:2024年AI领域最新技术 1.OpenAI新模型-GPT-5 2.谷歌新模型-Gemini Ultra 3.Meta新模型-LLama3 4.科大讯飞-星火认知 5.百度-文心一言 6.MoonshotAI…

【计算机网络】网络基础

初识网络 一、网络发展二、认识协议三、认识网络协议1. 协议分层2. OSI 七层模型3. TCP/IP五层模型4. OS和网络协议栈 四、网络传输基本流程1. TCP/IP 协议通讯过程2. 以太网通信(1)以太网通信原理(2)数据碰撞 3. 数据跨网络传输 …

高效办公待办工具:打造高效待办事项管理桌面

在这个信息爆炸的时代,高效办公已成为职场人士的追求。高效办公的好处不言而喻:它能提升我们的工作效率,减少不必要的浪费,使我们在有限的时间内创造更多的价值。而在高效办公的道路上,管理好待办事项无疑是关键的一环…

力扣题目训练(16)

2024年2月9日力扣题目训练 2024年2月9日力扣题目训练530. 二叉搜索树的最小绝对差541. 反转字符串 II543. 二叉树的直径238. 除自身以外数组的乘积240. 搜索二维矩阵 II124. 二叉树中的最大路径和 2024年2月9日力扣题目训练 2024年2月9日第十六天编程训练,今天主要…