-
知识蒸馏使用的是Teacher—Student模型,其中teacher是“知识”的输出者,student是“知识”的接受者。知识蒸馏的过程分为2个阶段:
- 原始模型训练: 训练"Teacher模型", 它的特点是模型相对复杂,也可以由多个分别训练的模型集成而成。我们对"Teacher模型"不作任何关于模型架构、参数量、是否集成方面的限制,唯一的要求就是,对于输入X, 其都能输出Y,其中Y经过softmax的映射,输出值对应相应类别的概率值。
- 精简模型训练: 训练"Student模型", 它是参数量较小、模型结构相对简单的单模型。同样的,对于输入X,其都能输出Y,Y经过softmax映射后同样能输出对应相应类别的概率值。
- Teacher学习能力强,可以将它学到的知识迁移给学习能力相对弱的Student模型,以此来增强Student模型的泛化能力。复杂笨重但是效果好的Teacher模型不上线,就单纯是个导师角色,真正部署上线进行预测任务的是灵活轻巧的Student小模型。
-
知识蒸馏可以将大型的复杂模型(如深度神经网络)转换成小型的简化模型,从而减少了模型的存储空间和计算资源需求,使得模型更适合在资源受限的设备上部署和运行,如移动设备或嵌入式系统。通过知识蒸馏,可以将一个在大规模数据集上训练的模型的知识迁移到一个相似但规模较小的模型上,这有助于在资源受限的情况下进行迁移学习或模型微调。知识蒸馏还可以帮助提高模型的泛化性能,因为学生模型在训练过程中利用了教师模型的“软标签”,这些软标签包含了教师模型对数据分布的更丰富的信息,有助于减少过拟合。
-
我们可以通过将 OpenCLIP 视觉语言模型中的知识转移到 ResNet18 模型中来探索知识蒸馏,以便对 STL10 数据集进行分类。将探讨用于蒸馏的数据、用于蒸馏的方法和模型架构对最终准确性的影响。可以帮助学习如何利用现有的大型模型,并将知识转移到更适合边缘部署的架构上使用,因为这些架构具有更低的内存消耗、更高的吞吐量或更好的架构支持(即:在 Jetson AGX Orin 的深度学习加速器上运行)。
-
以下
从概念上介绍知识蒸馏,包含了重现结果的代码
。如果熟悉深度学习和模型训练,并且正在寻找将大型模型带入边缘的方法,可能会对您有所帮助!查看Nvidia的其他项目 clip-distillation(NVIDIA-AI-IOT/clip-distillation: Zero-label image classification via OpenCLIP knowledge distillation (github.com)),了解如何在没有任何标记数据的情况下使用知识提炼创建自己的自定义图像分类器!还将讨论如何在 NVIDIA Jetson Orin Nano 上对模型进行剖析和优化,以便最终实现实时部署。
什么是知识蒸馏
-
知识蒸馏是一种将知识从一个神经网络(教师)转移到另一个神经网络(学生)的技术,如需更深入地了解知识蒸馏,建议阅读[2006.05525] Knowledge Distillation: A Survey (arxiv.org)。这一过程有多种形式,可分为以下几类
- 响应知识蒸馏: 使用 divergence 损失(即使用 KL 散度)训练输出类概率分布,使其与教师概率分布相匹配。
- 特征知识蒸馏: 训练学生模型的内部特征,使其与教师模型的内部特征直接匹配(即:使用均方误差)。
- 关系知识蒸馏: 训练教师模型中特征的相对分布,使之与学生模型中特征的相对分布相匹配。
-
本文将探讨(1、2),因为与关系知识提炼相比,(1、2)非常简单。我们特别感兴趣的是探索如何使用知识蒸馏来获取基于 transformer 的大型教师模型(OpenCLIP),并训练更快、内存更小的模型(ResNet18),使其更适合边缘部署。将通过将 OpenCLIP 调整为针对 STL10 分类数据集的图像分类器来探索这一概念,用于蒸馏的数据和技术如何影响最终模型的准确性。
-
探讨的教师模型是 OpenCLIP。OpenCLIP 是 OpenAI 的 CLIP(对比语言-图像预训练)的开源实现。CLIP 模型的训练目的是将图像与文本进行匹配。具体来说,该模型由以下部分组成
图像编码器,它能获取图像并生成代表图像的嵌入信息
文本编码器,可接收文本提示并生成代表文本提示的嵌入内容
-
通过对模型进行训练,配对图像和文本的图像和文本嵌入相似,而非配对图像和文本的图像和文本嵌入差异很大。这个模型的一个有趣之处在于,它是在大量非结构化数据的基础上训练出来的,所学习到的特征可以应用到各种各样的任务中。事实上,只需提供类别描述,它就能在分类任务中达到很好的零误差准确率。不过,这种模型的缺点是,与 ResNet 等基于 CNN 的架构相比,它的运行时间和内存消耗相对较高。这就引出了一个问题:我们能否利用 OpenCLIP 的功能,同时获得更低的内存消耗和延迟?本文将探讨如何使用 OpenCLIP 模型作为教师模型,训练 CNN 模型完成分类任务。我们将探讨用于训练的数据、针对分类任务调整模型的方法以及提炼模型的方法对最终准确性的影响。
-
为了探索如何将 OpenCLIP 用于分类和知识蒸馏,使用斯坦福大学的 STL10 数据集。只探讨了 STL10 数据集,某些结果可能取决于特定的数据分布和任务。虽然我们不能保证这些结果会转换到其他任务和数据源中,这样您就可以用自己的数据来探索知识蒸馏。STL10 数据集是一个包含 10 个类别的分类数据集:
- 与 MNIST 相比,它包含自然图像,适合使用 OpenCLIP。
- 与 CIFAR10 相比,图像的分辨率为 96x96,而不是 32x32。这更接近 OpenCLIP 的训练分辨率。
- 它包含大量未标记的图像(100,000 张),这使我们能够探索在训练过程中使用未标记数据的好处。
-
在开始使用教师模型训练学生之前,最好先对教师模型在当前任务中的表现有一个初步的预期。这大致有助于设定我们希望学生模型达到的最佳性能预期。由于 OpenCLIP 并未直接在 STL10 数据集上进行训练,因此我们有几种方法可以调整该模型以执行分类。
-
使用文本提示进行分类:在 STL10 数据集上使用 OpenCLIP 进行分类的第一种也是最简单的方法是使用文本提示定义类别,通过文本编码器运行提示,并将每个编码后的文本提示与 GT 标签进行比较。对于 STL10 数据集,我们可以按以下方式生成文本嵌入
-
import open_clip model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="laion2b_s34b_b79k" ) tokenizer = open_clip.get_tokenizer("ViT-B-32") labels = ["an airplane","a bird","a car","a cat","a deer","a dog","a horse","a monkey","a ship","a truck" ] text = tokenizer(labels) text_embeddings = model.encode_text(text)
-
现在,每个文本提示的嵌入都包含一个长度为 512 的向量。该向量与视觉编码器输出的大小相同。该向量与视觉特征的点积表示相似度,因此我们可以确定数据集的类别概率如下
-
import torch.nn.functional as F def embeddings_to_class_probs(vision_embeddings, text_embeddings)vision_embeddings = vision_embeddings / vision_embeddings.norm(dim=-1, keepdim=True)text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)logits = vision_embeddings @ text_embeddings.Tclass_probs = F.softmax(100. * logits, dim=-1)return class_probs
-
现在,我们已经有了目标任务的文本嵌入以及将文本嵌入与图像嵌入进行比较的方法,剩下要做的就是通过 OpenCLIP 视觉编码器运行 STL10 数据集,计算输出类概率,并将结果与 GT 标签进行比较。
-
import tqdm from torchvision.datasets import STL10 dataset = STL10(root=dataset_path,download=True,split="test" ) num_correct = 0 for image, label in tqdm.tqdm(dataset):input_tensor = preprocess(image).unsqueeze(0)vision_embeddings = model.encode_image(input_tensor)output_class_probs = embeddings_to_class_probs(vision_embeddings, text_embeddings)output_label = torch.argmax(dim=-1)num_correct += int(torch.count_nonzero(output_label == label)) accuracy = 100. * num_correct / len(dataset)
-
之后,开箱即用的 OpenCLIP 编码器在没有任何额外训练的情况下,在 STL10 测试数据集上获得了 96.68% 的准确率!在没有任何技巧的情况下,我们在 STL10 数据集上获得了相当有竞争力的准确率。
-
使用线性 head 进行分类:使用文本提示作为类标签,我们能够在 STL10 数据集上实现相当高的准确率,而无需任何训练或基本真实标签。但是,如果我们有 GT 标签呢?我们能用它来提高准确率吗?有了这个选项,我们将探索如何使用一些 GT 数据,在 OpenCLIP 模型的末尾训练一个微小的逻辑回归层(线性层,然后是 softmax),看看这是否能提高准确率。为此,我们对线性层的定义如下
-
import torch.nn as nn linear_probe = nn.Linear(512, len(labels))
-
然后,我们需要训练模型。这包括:从数据集中读取一批数据;运行 OpenCLIP 视觉编码器(无梯度);在 OpenCLIP 的输出上运行线性层;计算线性层输出与 GT 标签之间的交叉熵;更新线性层。
-
optimizer = torch.optim.Adam(linear_probe.parameters(), lr=3e-4) for epoch in range(num_epochs):for image, label in iter(train_loader):# ... run open-clip to get vision embeddingsoptimizer.zero_grad()output_logits = linear_probe(vision_embeddings)output_logprob = F.log_softmax(output_logits, dim=-1)loss = F.nll_loss(output_logprob, label)loss.backward()optimizer.step()
-
训练完线性探针后,我们在 STL10 数据集上对其进行了评估,结果与之前类似,准确率达到了 98.57!通过使用一些标注数据,我们训练出了一个小型逻辑回归层,使 OpenCLIP 在 STL10 数据集上的准确率提高了近 +2%!这种改进可能是因为我们的文本提示(如 “an airplane”)可能与 STL10 数据集中的标签不完全匹配。但是,通过查看每个标签的一些示例,我们可以学习到更准确地代表类标签的参考嵌入。
-
训练学生模型以模仿 OpenCLIP
- 我们现在已经看到,使用大型 OpenCLIP 模型,我们可以在 STL10 图像分类数据集上不费吹灰之力就取得具有竞争力的结果。但是,OpenCLIP 体积庞大,与其他模型架构相比,内存消耗和延迟可能较高。此外,作为视觉 Transformer 模型,OpenCLIP 在利用 Jetson AGX Orin 上的深度学习加速器(DLA)方面能力较弱,因为注意层中存在矩阵乘法。另一方面,像 resnet18 这样的 CNN 模型则通过 Jetson 上的 GPU 和 DLA 进行了高度优化,使我们能够以更高的吞吐量和更少的内存运行模型。然而,知识提炼会影响模型的准确性,因此我们希望更好地了解哪些因素最为重要。为此,我们进行了一些实验,试图回答几个问题:
- 蒸馏法与使用 GT 标签进行的训练相比效果如何?可以使用 GT 标签从头开始训练 resnet18,并将其与使用 OpenCLIP 输出概率(使用文本提示标签和线性回归方法)训练 resnet18 进行比较。在每个实验中,我们只使用 STL10 训练数据集(5000 张图像),以便与基线训练进行公平比较。
- 在使用相同数据进行训练的情况下,使用线性回归头的提炼模型比直接使用 GT 标签的准确度更高,即使每种方法都使用相同的可用数据和标签集。但是,文本提示方法无法达到从头开始训练 resnet18 的准确度。
- 这表明,即使在数据分布相同的情况下,知识蒸馏也有能力提高模型的准确性。不过,在下一节中,我们将看到如何在提炼过程中利用无标签数据,从而更进一步。事实上,使用无标签数据,文本提示方法(在此过程中不需要任何地面实况标签)能够大大超过从头开始训练模型的效果!
- 用于训练的数据分布对模型准确性有何影响?
- 现在我们已经看到,知识提炼能够提高模型的准确性。但是,我们的最佳学生模型(准确率为 60.65%)仍然远远低于教师模型(准确率为 98.57%)。为什么会这样?学生模型 resnet18 是否缺乏模仿教师模型的能力?还是其他原因,也许是我们用来提炼的数据?
- 为了帮助回答这个问题,我们进行了一系列实验,不仅使用了 STL10 训练数据集中的 5000 张图像,还使用了补充 STL10 数据集中提供的 100,000 张未标记图像进行知识提炼。只需在蒸馏过程中从 STL10 未标注数据集拆分出大量未标注图像,模型的准确性就会大幅提高!现在,带有文本提示标签的 resnet18 模型的准确率远远超过了使用 GT 标签训练的 resnet18,仅比带有文本提示标签的原始 OpenCLIP 模型低 2%。带有线性分类头的蒸馏模型则更进一步,达到了 96.88%,超过了带有文本提示的原始 OpenCLIP 模型,比带有线性分类头的最佳 OpenCLIP 变体低不到 2%。
- 所以,总而言之、在没有标签,只有大量无标签数据的情况下,我们通过提炼文本提示 OpenCLIP 分类器,利用 resnet18 模型实现了 94.32% 的准确率;在有一些标签和大量无标签数据的情况下,我们可以通过提炼具有线性分类头的 OpenCLIP 分类器,使 resnet18 模型的准确率达到 96.88%。我们得到的经验是,蒸馏所使用的数据对于获得良好的精度非常重要。但我们的学生模型架构如何?使用 resnet50 这样的高容量模型能否取得更好的效果?
- 学生模型架构对模型准确性有何影响?resnet50 会比 resnet18 获得更高的准确率吗?
- 为了探索学生模型架构对最终模型准确性的影响,我们进行了一系列实验,使用了我们的最佳蒸馏配置,以及三种不同的模型架构:resnet18、resnet34 和 resnet50。在切换学生模型架构时,我们看到的差异相对较小,可以忽略不计。这意味着,至少对于这项任务(STL10 分类)来说,用于提炼的数据比学生模型架构重要得多。对于其他任务来说,情况很可能并非如此,但我们希望将这些结果包括进来,至少在这种情况下与大家分享这一发现,以便大家了解并优先考虑哪些因素需要首先探索。
- 蒸馏方法对模型准确性有何影响?基于类概率还是内部特征进行训练更好?
- 到目前为止,我们已经看到用于训练的数据对模型的准确性有很大影响,但用于蒸馏的方法又如何呢?如前所述,知识提炼可以通过以下几种方式进行:响应蒸馏: 拟合模型以学习输出类别概率;特征提炼: 拟合模型以学习内部特征
- 为了探索这些决定的影响,我们进行了几次实验,训练 resnet18 学习 OpenCLIP 输出的视觉特征嵌入(512 维),而不是类别概率。我们将其输入文本提示或线性回归头,就像使用原始 OpenCLIP 模型一样。
- 通过对特征进行训练(使用均方误差损失)获得的准确率略高于通过对输出类别概率进行训练(使用 KL 发散)获得的准确率。概述如下:通过特征训练,我们的文本提示学生准确率提高了 0.25;通过特征训练,我们的线性分类头学生准确率提高了 0.032
- 虽然这些变化并不显著,但有趣的是,在 embeddings 上进行训练并不会对模型的准确性产生不利影响。之所以说这很有趣,是因为这些 embeddings 并不是明确针对 STL10 任务的,它们有可能像最初的 OpenCLIP 模型一样被重新利用,只需改变用于分类的文本提示或重新训练线性回归头即可。
- 不过,在本教程中,我们尚未以通用方式提炼 OpenCLIP。下一步,我们将对这种可能性进行有趣的探索。但现在,我们已经为目标任务建立了一个相当不错的分类模型。让我们讨论一下如何优化我们的学生模型以便部署,并看看延迟和内存消耗与原始 OpenCLIP 模型相比如何。
- 蒸馏法与使用 GT 标签进行的训练相比效果如何?可以使用 GT 标签从头开始训练 resnet18,并将其与使用 OpenCLIP 输出概率(使用文本提示标签和线性回归方法)训练 resnet18 进行比较。在每个实验中,我们只使用 STL10 训练数据集(5000 张图像),以便与基线训练进行公平比较。
用 TensorRT 优化模型并比较性能
-
以上我们展示了如何训练一个 resnet CNN 模型来模仿大型 OpenCLIP 模型。现在,让我们看看为什么这样的努力是值得的。使用我们的学生模型能带来哪些性能提升?为了了解每个模型的预期性能,我们将使用英伟达 TensorRT 对模型进行优化,并测量英伟达 Jetson 上的吞吐量和内存消耗。为此,我们用 ONNX 导出了模型,并用英伟达 TensorRT 进行了优化。下面我们将展示 OpenCLIP 和 resnet18 在 Jetson Orin Nano 上以 224x224 分辨率运行的性能,批量大小为 8。在使用 TensorRT 对每个模型进行优化后,resnet18 模型的速度是原始 open_clip 模型的 4.2 倍,而内存使用量则减少了 3.45 倍。
-
为了测量内存,我们使用了 tegrastats。我们用 trtexec 记录了模型执行前和模型执行后的系统内存。上表中的内存是模型运行时系统内存的变化。对于这个特定的图像分类任务,Jetson Orin Nano 有足够能力运行原始 OpenCLIP 模型。但是,如果要运行更高分辨率的模型,这些吞吐量和内存消耗方面的差异可能会变得至关重要。
-
我们探索了在 STL10 分类数据集上使用知识蒸馏来训练 resnet18 分类器。在这项任务中,我们取得了与原始 OpenCLIP 模型相当的准确率,同时大幅减少了运行时间和内存消耗。希望本教程能让您了解如何利用知识蒸馏技术将大型模型提升到边缘水平。除本介绍外,我们还创建了一个配套项目 clip-distillation,让您可以轻松创建零标签图像分类器,完成自己的自定义任务!它包括下载经过剪辑过滤的相关图片的脚本,以用于提炼;提取高效 CNN 模型以模仿 OpenCLIP 变换器模型的脚本,包括量化感知训练和结构稀疏性训练选项。
通过对 OpenCLIP 模型进行知识提炼
- 可以使用零标记数据创建自己的定制图像分类模型。即使您不直接需要图像分类器,您也可能会发现这个项目很有帮助,它启发您如何使用知识提炼来优化用于推理的模型,或者作为一个示例,说明如何在 NVIDIA Jetson 上使用量化感知训练和结构化稀疏性来训练用于推理的模型。该项目NVIDIA-AI-IOT/clip-distillation: Zero-label image classification via OpenCLIP knowledge distillation (github.com)包括:
- 从 LAION 数据库中搜索和下载相关数据的脚本,以用于提炼;
- 将任何 OpenCLIP 模型提炼为任何 Pytorch 图像模型(timm)CNN 模型的脚本。支持面向下游 INT8 推理的量化感知训练 (QAT);支持使用 ASP 库执行 2:4 结构稀疏性训练;
- 使用英伟达™(NVIDIA®)TensorRT 运行推理的脚本;支持 INT8 模式;支持在某些 NVIDIA Jetson 平台(如 NVIDIA Jetson Orin Nano)上加速 2:4 结构稀疏模型。
通过 CLIP 过滤搜索和下载图像
-
在提炼模型时,我们需要做的第一件事就是获取用于提炼的数据。在这项任务中,我们将通过搜索 LAION 数据库来查找相关图片。我们提供了一个脚本来简化这项工作。要搜索相关图像,首先要创建一个 data/text_prompts.txt 文件,其中包含要查询的文本提示。每个提示都应独立成行。接下来,调用脚本查询与文本提示相匹配的图片。
-
python3 search_clip_images.py \"data/text_prompts.txt" \"data/image_urls.txt" \-n 5000 \-m 10000 \--max_workers 2 \--append
-
这将输出一个文件 data/image_urls.txt,其中包含与我们的文本提示查询相匹配的图像 URL。现在我们已经找到了用于蒸馏的相关图片,我们需要下载它们。为此,我们调用以下脚本将图像下载到输出文件夹。
-
python3 download_images.py \"data/image_urls.txt" \"data/images" \--max_workers 32 \--timeout 2
-
该脚本将把图片下载到 data/images 文件夹。每张图片都将根据其 URL 获得一个唯一的文件名。
-
计算 OpenCLIP 嵌入
-
我们在上面下载的图像将在蒸馏过程中作为教师和学生模型的输入。遗憾的是,在训练过程中执行教师模型可能会比较慢。为了加快这一过程,我们将预先计算教师模型的输出,这样就不需要在训练过程中执行教师模型了。为此,请调用 compute_openclip_embeddings.py 脚本,如下所示、
-
python3 compute_openclip_embeddings.py \data/images \data/embeddings \--batch_size 16 \--num_workers 8 \--model_name ViT-B-32 \--pretrained laion2b_s34b_b79k
-
这将把输出的嵌入文件写入 data/embeddings 文件夹,文件名与图像文件名一致,但文件扩展名除外。注:有关可用模型名称和预训练权重标识符,请参考 OpenCLIP Repoopen_clip/src/open_clip/pretrained.py at fb72f4db1b17133befd6c67c9cf32a533b85a321 · mlfoundations/open_clip (github.com)。
-
训练学生 CNN 模型以模仿 OpenCLIP 模型
-
现在,我们已经有了用于知识提炼的数据,可以通过调用 distil_model_embeddings.py 脚本来执行提炼(学生模型训练),如下所示。
-
python3 distil_model_embeddings.py \resnet18 \data/images \data/embeddings \data/models/resnet18 \--output_dim 512 \--pretrained
-
这将向 data/models/resnet18 输出模型检查点和信息。我们在本例中使用的提炼模型是 resnet18。该模型经过 TensorRT 高度优化,我们可以在训练过程中随时应用其他优化,如降低精度和结构稀疏性。
-
使用提炼的模型进行推理
-
在提炼过程中,我们对学生模型进行了训练,使其与 open-clip 模型的特征相匹配。不过,我们有兴趣创建一个分类模型。要创建zero-shot分类模型,我们需要从描述类别标签的文本提示中生成文本嵌入。为此,我们使用了预先训练好的 OpenCLIP 文本编码器。我们调用 compute_openclip_text_embeddings.py 脚本来创建文本嵌入。
-
python3 compute_openclip_text_embeddings.py \data/text_prompts.txt \data/text_embeddings.npy \--model_name ViT-B-32
-
在这种情况下,我们使用与图像搜索相同的文本提示作为分类的文本提示。现在,我们已经计算出了图像类别的文本提示,我们可以使用 PyTorch 模型进行图像分类,具体如下
-
python3 predict_pytorch.py \resnet18 \data/models/resnet18/checkpoint.pth \data/text_embeddings.npy \assets/cat.jpg \--text_prompts data/text_prompts.txt
-
-
同样,我们也可以对实时摄像机画面进行如下推理:
-
python3 demo_pytorch.py \resnet18 \data/models/resnet18/checkpoint.pth \data/text_embeddings.npy \--text_prompts data/text_prompts.txt \--camera_device 0
-
利用结构化稀疏性训练学生模型
-
训练脚本提供结构稀疏性训练功能。这可以在使用 TensorRT 的英伟达 Jetson 平台上部署模型时提供额外的加速。用结构化稀疏性训练模型
-
python3 distil_model_embeddings.py \resnet18 \data/images \data/embeddings \data/models/resnet18_sparse \--output_dim 512 \--pretrained \--init_checkpoint data/models/resnet18/checkpoint.pth \--use_asp \--num_epochs 25
-
用 PyTorch 进行预测
-
python3 predict_pytorch.py \resnet18 \data/models/resnet18_sparse/checkpoint.pth \data/text_embeddings.npy \assets/cat.jpg \--text_prompts data/text_prompts.txt \--use_asp
-
PyTorch 演示
-
python3 demo_pytorch.py \resnet18 \data/models/resnet18_sparse/checkpoint.pth \data/text_embeddings.npy \--text_prompts data/text_prompts.txt \--camera_device 0 \--use_asp
-
导出到 ONNX
-
python3 export_onnx.py \resnet18 \data/models/resnet18_sparse/checkpoint.pth \data/onnx/resnet18_sparse.onnx \--use_asp
-
使用量化感知训练和 INT8 精确度训练学生
-
除了结构化稀疏性,我们还可以利用降低 INT8 精度来提高性能。量化感知训练是一种将使用 INT8 精度时引入的量化误差最小化的技术。它通过在训练过程中的模型前向传递过程中应用量化来实现这一目的。这样,模型就能在训练过程中适应量化误差。在使用训练后量化时,它还能让我们避免校准的需要。要使用量化感知训练来提炼模型,请按照以下步骤操作
-
使用量化感知训练 (QAT) 训练模型
-
python3 distil_model_embeddings.py \resnet18 \data/images \data/embeddings \data/models/resnet18_qat \--output_dim 512 \--pretrained \--init_checkpoint data/models/resnet18/checkpoint.pth \--use_qat \--num_epochs 25
-
用 PyTorch 进行预测
-
python3 predict_pytorch.py \resnet18 \data/models/resnet18_sparse/checkpoint.pth \data/text_embeddings.npy \assets/cat.jpg \--text_prompts data/text_prompts.txt \--use_qat
-
PyTorch 演示
-
python3 demo_pytorch.py \resnet18 \data/models/resnet18_sparse/checkpoint.pth \data/text_embeddings.npy \--text_prompts data/text_prompts.txt \--camera_device 0 \--use_qat
-
导出到 ONNX
-
python3 export_onnx.py \resnet18 \data/models/resnet18_qat/checkpoint.pth \data/onnx/resnet18_qat.onnx \--use_qat
-
-
希望您能在不使用任何标记数据的情况下训练出自己的图像分类模型。
知识蒸馏实战
-
数据使用我以前在图像分类任务中的数据集——植物幼苗数据集,先将数据集转为训练集和验证集。执行代码:
-
import glob import os import shutil image_list=glob.glob('data1/*/*.png') print(image_list) file_dir='data' if os.path.exists(file_dir):print('true')#os.rmdir(file_dir)shutil.rmtree(file_dir)#删除再建立os.makedirs(file_dir) else:os.makedirs(file_dir)from sklearn.model_selection import train_test_split trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42) train_dir='train' val_dir='val' train_root=os.path.join(file_dir,train_dir) val_root=os.path.join(file_dir,val_dir) for file in trainval_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(train_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name) for file in val_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(val_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)
-
-
教师网络选用coatnet_2,是一个比较大一点的网络了,模型的大小有200M。训练50个epoch,最好的模型在92%左右。导入需要的库
-
import torch.optim as optim import torch import torch.nn as nn import torch.nn.parallel import torch.utils.data import torch.utils.data.distributed import torchvision.transforms as transforms from torchvision import datasets from torch.autograd import Variable from model.coatnet import coatnet_2 import json import os
-
定义训练和验证函数
-
def train(model, device, train_loader, optimizer, epoch):model.train()sum_loss = 0total_num = len(train_loader.dataset)print(total_num, len(train_loader))for batch_idx, (data, target) in enumerate(train_loader):data, target = Variable(data).to(device), Variable(target).to(device)output = model(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()print_loss = loss.data.item()sum_loss += print_lossif (batch_idx + 1) % 10 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),100. * (batch_idx + 1) / len(train_loader), loss.item()))ave_loss = sum_loss / len(train_loader)print('epoch:{},loss:{}'.format(epoch, ave_loss)) Best_ACC=0 # 验证过程 @torch.no_grad() def val(model, device, test_loader):global Best_ACCmodel.eval()test_loss = 0correct = 0total_num = len(test_loader.dataset)print(total_num, len(test_loader))with torch.no_grad():for data, target in test_loader:data, target = Variable(data).to(device), Variable(target).to(device)output = model(data)loss = criterion(output, target)_, pred = torch.max(output.data, 1)correct += torch.sum(pred == target)print_loss = loss.data.item()test_loss += print_losscorrect = correct.data.item()acc = correct / total_numavgloss = test_loss / len(test_loader)if acc > Best_ACC:torch.save(model, file_dir + '/' + 'best.pth')Best_ACC = accprint('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(avgloss, correct, len(test_loader.dataset), 100 * acc))return acc
-
定义全局参数
-
if __name__ == '__main__':# 创建保存模型的文件夹file_dir = 'CoatNet'if os.path.exists(file_dir):print('true')os.makedirs(file_dir, exist_ok=True)else:os.makedirs(file_dir)# 设置全局参数modellr = 1e-4BATCH_SIZE = 16EPOCHS = 50DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
图像预处理与增强
-
transform = transforms.Compose([transforms.RandomRotation(10),transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])])transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])])
-
使用pytorch默认读取数据的方式。
-
# 读取数据dataset_train = datasets.ImageFolder('data/train', transform=transform)dataset_test = datasets.ImageFolder("data/val", transform=transform_test)with open('class.txt', 'w') as file:file.write(str(dataset_train.class_to_idx))with open('class.json', 'w', encoding='utf-8') as file:file.write(json.dumps(dataset_train.class_to_idx))# 导入数据train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
-
设置模型和Loss
-
# 实例化模型并且移动到GPUcriterion = nn.CrossEntropyLoss()model_ft = coatnet_2()num_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Linear(num_ftrs, 12) # 维度对齐model_ft.to(DEVICE)# 选择简单暴力的Adam优化器,学习率调低optimizer = optim.Adam(model_ft.parameters(), lr=modellr)cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)# 训练val_acc_list= {}for epoch in range(1, EPOCHS + 1):train(model_ft, DEVICE, train_loader, optimizer, epoch)cosine_schedule.step()acc=val(model_ft, DEVICE, test_loader)val_acc_list[epoch]=accwith open('result.json', 'w', encoding='utf-8') as file:file.write(json.dumps(val_acc_list))torch.save(model_ft, 'CoatNet/model_final.pth')
-
完成上面的代码就可以开始训练Teacher网络了。
-
-
学生网络选用ResNet18,是一个比较小一点的网络了,模型的大小有40M。训练50个epoch,最好的模型在86%左右。导入需要的库
-
import torch.optim as optim import torch import torch.nn as nn import torch.nn.parallel import torch.utils.data import torch.utils.data.distributed import torchvision.transforms as transforms from torchvision import datasets from torch.autograd import Variable from torchvision.models.resnet import resnet18 import json import os
-
定义训练和验证函数
-
def train(model, device, train_loader, optimizer, epoch):model.train()sum_loss = 0total_num = len(train_loader.dataset)print(total_num, len(train_loader))for batch_idx, (data, target) in enumerate(train_loader):data, target = Variable(data).to(device), Variable(target).to(device)output = model(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()print_loss = loss.data.item()sum_loss += print_lossif (batch_idx + 1) % 10 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),100. * (batch_idx + 1) / len(train_loader), loss.item()))ave_loss = sum_loss / len(train_loader)print('epoch:{},loss:{}'.format(epoch, ave_loss)) Best_ACC=0 # 验证过程 @torch.no_grad() def val(model, device, test_loader):global Best_ACCmodel.eval()test_loss = 0correct = 0total_num = len(test_loader.dataset)print(total_num, len(test_loader))with torch.no_grad():for data, target in test_loader:data, target = Variable(data).to(device), Variable(target).to(device)output = model(data)loss = criterion(output, target)_, pred = torch.max(output.data, 1)correct += torch.sum(pred == target)print_loss = loss.data.item()test_loss += print_losscorrect = correct.data.item()acc = correct / total_numavgloss = test_loss / len(test_loader)if acc > Best_ACC:torch.save(model, file_dir + '/' + 'best.pth')Best_ACC = accprint('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(avgloss, correct, len(test_loader.dataset), 100 * acc))return acc
-
定义全局参数
-
if __name__ == '__main__':# 创建保存模型的文件夹file_dir = 'resnet'if os.path.exists(file_dir):print('true')os.makedirs(file_dir, exist_ok=True)else:os.makedirs(file_dir)# 设置全局参数modellr = 1e-4BATCH_SIZE = 16EPOCHS = 50DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
图像预处理与增强
-
transform = transforms.Compose([transforms.RandomRotation(10),transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])])transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])])
-
使用pytorch默认读取数据的方式
-
# 读取数据dataset_train = datasets.ImageFolder('data/train', transform=transform)dataset_test = datasets.ImageFolder("data/val", transform=transform_test)with open('class.txt', 'w') as file:file.write(str(dataset_train.class_to_idx))with open('class.json', 'w', encoding='utf-8') as file:file.write(json.dumps(dataset_train.class_to_idx))# 导入数据train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
-
设置模型和Loss
-
# 实例化模型并且移动到GPUcriterion = nn.CrossEntropyLoss()model_ft = resnet18()print(model_ft)num_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Linear(num_ftrs, 12)model_ft.to(DEVICE)# 选择简单暴力的Adam优化器,学习率调低optimizer = optim.Adam(model_ft.parameters(), lr=modellr)cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)# 训练val_acc_list= {}for epoch in range(1, EPOCHS + 1):train(model_ft, DEVICE, train_loader, optimizer, epoch)cosine_schedule.step()acc=val(model_ft, DEVICE, test_loader)val_acc_list[epoch]=accwith open('result_student.json', 'w', encoding='utf-8') as file:file.write(json.dumps(val_acc_list))torch.save(model_ft, 'resnet/model_final.pth')
-
完成上面的代码就可以开始训练Student网络了。
-
-
蒸馏学生网络,学生网络继续选用ResNet18,使用Teacher网络蒸馏学生网络,训练50个epoch,最终ACC是89%。新建student_kd_train.py,导入需要的库
-
import torch.optim as optim import torch import torch.nn as nn import torch.nn.parallel import torch.utils.data import torch.utils.data.distributed import torchvision.transforms as transforms from torchvision import datasets from torch.autograd import Variable from torchvision.models.resnet import resnet18 import json import os
-
定义蒸馏函数
-
def distillation(y, labels, teacher_scores, temp, alpha):return nn.KLDivLoss()(F.log_softmax(y / temp, dim=1), F.softmax(teacher_scores / temp, dim=1)) * (temp * temp * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)
-
定义训练和验证函数
-
# 定义训练过程 def train(model, device, train_loader, optimizer, epoch):model.train()sum_loss = 0total_num = len(train_loader.dataset)print(total_num, len(train_loader))for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)teacher_output = teacher_model(data) # 训练出教师的 teacher_outputteacher_output = teacher_output.detach() # 切断老师网络的反向传播loss = distillation(output, target, teacher_output, temp=7.0, alpha=0.7) # 通过老师的 teacher_output训练学生的outputloss.backward()optimizer.step()print_loss = loss.data.item()sum_loss += print_lossif (batch_idx + 1) % 10 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),100. * (batch_idx + 1) / len(train_loader), loss.item()))ave_loss = sum_loss / len(train_loader)print('epoch:{},loss:{}'.format(epoch, ave_loss)) Best_ACC=0 # 验证过程 @torch.no_grad() def val(model, device, test_loader):global Best_ACCmodel.eval()test_loss = 0correct = 0total_num = len(test_loader.dataset)print(total_num, len(test_loader))with torch.no_grad():for data, target in test_loader:data, target = Variable(data).to(device), Variable(target).to(device)output = model(data)loss = criterion(output, target)_, pred = torch.max(output.data, 1)correct += torch.sum(pred == target)print_loss = loss.data.item()test_loss += print_losscorrect = correct.data.item()acc = correct / total_numavgloss = test_loss / len(test_loader)if acc > Best_ACC:torch.save(model, file_dir + '/' + 'best.pth')Best_ACC = accprint('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(avgloss, correct, len(test_loader.dataset), 100 * acc))return acc
-
定义全局参数
-
if __name__ == '__main__':# 创建保存模型的文件夹file_dir = 'resnet_kd'if os.path.exists(file_dir):print('true')os.makedirs(file_dir, exist_ok=True)else:os.makedirs(file_dir)# 设置全局参数modellr = 1e-4BATCH_SIZE = 16EPOCHS = 50DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
图像预处理与增强
-
transform = transforms.Compose([transforms.RandomRotation(10),transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])])transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])])
-
使用pytorch默认读取数据的方式
-
# 读取数据dataset_train = datasets.ImageFolder('data/train', transform=transform)dataset_test = datasets.ImageFolder("data/val", transform=transform_test)with open('class.txt', 'w') as file:file.write(str(dataset_train.class_to_idx))with open('class.json', 'w', encoding='utf-8') as file:file.write(json.dumps(dataset_train.class_to_idx))# 导入数据train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
-
设置模型和Loss
-
criterion = nn.CrossEntropyLoss()model_ft = resnet18()print(model_ft)num_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Linear(num_ftrs, 12)model_ft.to(DEVICE)# 选择简单暴力的Adam优化器,学习率调低optimizer = optim.Adam(model_ft.parameters(), lr=modellr)cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)# 训练val_acc_list= {}for epoch in range(1, EPOCHS + 1):train(model_ft, DEVICE, train_loader, optimizer, epoch)cosine_schedule.step()acc=val(model_ft, DEVICE, test_loader)val_acc_list[epoch]=accwith open('result_student.json', 'w', encoding='utf-8') as file:file.write(json.dumps(val_acc_list))torch.save(model_ft, 'resnet_kd/model_final.pth')
-
-
加载保存的结果,然后绘制acc曲线。
-
import numpy as np from matplotlib import pyplot as plt import json teacher_file='result.json' student_file='result_student.json' student_kd_file='result_kd.json' def read_json(file):with open(file, 'r', encoding='utf8') as fp:json_data = json.load(fp)print(json_data)return json_data teacher_data=read_json(teacher_file) student_data=read_json(student_file) student_kd_data=read_json(student_kd_file) x =[int(x) for x in list(dict(teacher_data).keys())] print(x) plt.plot(x, list(teacher_data.values()), label='teacher') plt.plot(x,list(student_data.values()), label='student without KD') plt.plot(x, list(student_kd_data.values()), label='student with KD') plt.title('Test accuracy') plt.legend() plt.show()
-
知识蒸馏的另一个例子
-
知识蒸馏的核心思想是利用教师模型的输出作为附加的监督信号来训练学生模型。在传统的监督学习中,目标是最小化模型预测与真实标签之间的差距(损失函数)。而在知识蒸馏中,除了最小化模型预测与真实标签之间的差距外,还引入了一个额外的损失项,该项衡量了学生模型预测与教师模型预测之间的距离。 具体而言,损失函数通常由两部分组成:一部分是传统的交叉熵损失,用于衡量学生模型的预测与真实标签之间的差距;另一部分是知识蒸馏损失,用于衡量学生模型的预测与教师模型的预测之间的差距。知识蒸馏损失通常使用一些形式的距离度量来计算,例如平方误差损失或者交叉熵损失。在蒸馏求loss时候,需要采用蒸馏函数,这个函数就是把softmax函数在计算时候,预测出来的结果Z进行除以温度T,进行求解后验概率。下面是修改后的softmax函数,也就是蒸馏函数。
-
q i = e x p ( z i / T ) ∑ j e x p ( z j / T ) q_i=\frac{exp(z_i/T)}{\sum_jexp(z_j/T)} qi=∑jexp(zj/T)exp(zi/T)
-
-
首先我们要先训练出较大模型既teacher模型。再对teacher模型进行蒸馏,此时已经有一个训练好的teacher模型,所以我们能很容易知道teacher模型输入特征 x 之后,预测出来的结果teacher_preds标签。求到老师预测结果之后,我们需要求解学生在训练过程中的每一次结果student_preds标签。
- 先求hard_loss,也就是学生模型的预测student_preds与真实标签targets之间的损失。
- 再求soft_loss,也就是学生模型的预测student_preds与教师模型teacher_preds的预测之间的损失。
- 求出hard_loss与soft_loss之后,求和总loss=a*hard_loss + (1-a)soft_loss,a是一个自己设置的权重参数,最后反向传播继续迭代。
-
数据集采用的是手写数字的数据集mnist数据集,如果没有下载,代码部分中会进行下载,只需要把download改成True,然后就会保存在当前目录中。该数据集将其分成80%的训练集,20%的测试集,最后返回train_dataset和test_datatset。
-
class MyDataset(Dataset):def __init__(self,opt):self.opt = optdef MyData(self):## mnist数据集下载0mnist = datasets.MNIST(root='../datasets/', train=True, download=True, transform=transforms.Compose([transforms.Resize(self.opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),)dataset_size = len(mnist)train_size = int(0.8 * dataset_size)test_size = dataset_size - train_sizetrain_dataset, test_dataset = random_split(mnist, [train_size, test_size])train_dataloader = DataLoader(train_dataset,batch_size=self.opt.batch_size,shuffle=True,)test_dataloader = DataLoader(test_dataset,batch_size=self.opt.batch_size,shuffle=False, # 在测试集上不需要打乱顺序)return train_dataloader,test_dataloader
-
首先是teacher模型构造,经过三次线性层。
-
import torch.nn as nn import torch img_area = 784 class TeacherModel(nn.Module):def __init__(self,in_channel=1,num_classes=10):super(TeacherModel,self).__init__()self.relu = nn.ReLU()self.fc1 = nn.Linear(img_area,1200)self.fc2 = nn.Linear(1200, 1200)self.fc3 = nn.Linear(1200, num_classes)self.dropout = nn.Dropout(p=0.5)def forward(self, x):x = x.view(-1, img_area)x = self.fc1(x)x = self.dropout(x)x = self.relu(x)x = self.fc2(x)x = self.dropout(x)x = self.relu(x)x = self.fc3(x)return x
-
训练teacher模型,老师模型训练完成后其权重参数会保存在teacher.pth当中,为以后调用。
-
import torch.nn as nn import torch from tqdm import tqdm from dist.TeacherModel import TeacherModel weight_path = './teacher.pth' cuda = True if torch.cuda.is_available() else False device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') torch.backends.cudnn.benchmark = True #使用卷积cuDNN加速 class TeacherTrainer():def __init__(self,opt,train_dataloader,test_dataloader):self.opt = optself.train_dataloader = train_dataloaderself.test_dataloader = test_dataloaderdef trainer(self):# 老师模型opt = self.opttrain_dataloader = self.train_dataloadertest_dataloader = self.test_dataloaderteacher_model = TeacherModel()teacher_model = teacher_model.to(device)criterion = nn.CrossEntropyLoss()optimizer_teacher = torch.optim.Adam(teacher_model.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))for epoch in range(opt.n_epochs): ## epoch:50teacher_model.train()for data, targets in tqdm(train_dataloader):data = data.to(device)targets = targets.to(device)preds = teacher_model(data)loss = criterion(preds, targets)optimizer_teacher.zero_grad()loss = criterion(preds, targets)loss.backward()optimizer_teacher.step()teacher_model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x, y in test_dataloader:x = x.to(device)y = y.to(device)preds = teacher_model(x)predictions = preds.max(1).indicesnum_correct += (predictions == y).sum()num_samples += predictions.size(0)acc = (num_correct / num_samples).item()torch.save(teacher_model.state_dict(), weight_path)teacher_model.train()print('teacher: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))
-
设置参数以及主函数
-
import argparse import torch from dist.DistillationTrainer import DistillationTrainer from dist.MyDateLoader import MyDataset from dist.TeacherTrainer import TeacherTrainer device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') def paras():## 超参数配置parser = argparse.ArgumentParser()parser.add_argument("--n_epochs", type=int, default=5, help="number of epochs of training")parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")parser.add_argument("--channels", type=int, default=1, help="number of image channels")parser.add_argument("--sample_interval", type=int, default=500, help="interval betwen image samples")opt = parser.parse_args()## opt = parser.parse_args(args=[]) ## 在colab中运行时,换为此行print(opt)return opt if __name__ == '__main__':opt = paras()data = MyDataset(opt)train_dataloader, test_dataloader = data.MyData()# 训练Teacher模型teacher_trainer = TeacherTrainer(opt,train_dataloader,test_dataloader)teacher_trainer.trainer()
-
学生模型的构建,学生模型也是经过了三次线性层,但是神经元没有teacher当中多。所以student模型会比teacher模型小很多。
-
import torch.nn as nn import torch img_area = 784 class StudentModel(nn.Module):def __init__(self,in_channel=1,num_classes=10):super(StudentModel,self).__init__()self.relu = nn.ReLU()self.fc1 = nn.Linear(img_area,20)self.fc2 = nn.Linear(20, 20)self.fc3 = nn.Linear(20, num_classes)def forward(self, x):x = x.view(-1, img_area)x = self.fc1(x)# x = self.dropout(x)x = self.relu(x)x = self.fc2(x)# x = self.dropout(x)x = self.relu(x)x = self.fc3(x)return x
-
知识蒸馏训练, 首先将teacher模型中的权重参数teacher.pth放入模型当中。
-
#加载训练好的模型 teacher_model = TeacherModel() if os.path.exists(weights):teacher_model.load_state_dict(torch.load(weights))print('successfully') else:print('not loading') teacher_model = teacher_model.to(device)
-
设置损失求解的函数, hard_loss用的就是普通的交叉熵损失函数,而soft_loss就是用的KL散度。
-
hard_loss = nn.CrossEntropyLoss() alpha = 0.3# hard_loss权重 # soft_loss soft_loss = nn.KLDivLoss(reduction="batchmean")
-
之后再进行蒸馏训练,温度为7。 先求得hard_loss就是用学生模型预测的标签和真实标签进行求得损失。再求soft_loss就是用学生模型预测的标签和老师模型预测的标签进行求得损失。使用softmax时候还需要进行除以温度temp。最后反向传播,求解模型
-
for epoch in range(opt.n_epochs): ## epoch:5for data, targets in tqdm(train_dataloader):data = data.to(device)targets = targets.to(device)# 老师模型预测with torch.no_grad():teacher_preds = teacher_model(data)# 学生模型预测student_preds = model(data)# 计算hard_lossstudent_loss = hard_loss(student_preds, targets)# 计算蒸馏后的预测损失ditillation_loss = soft_loss(F.softmax(student_preds / temp, dim=1),F.softmax(teacher_preds / temp, dim=1))loss = alpha * student_loss + (1 - alpha) * ditillation_lossoptimizer.zero_grad()loss.backward()optimizer.step()model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x, y in test_dataloader:x = x.to(device)y = y.to(device)preds = model(x)predictions = preds.max(1).indicesnum_correct += (predictions == y).sum()num_samples += predictions.size(0)acc = (num_correct / num_samples).item()model.train()print('distillation: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))
-
整个蒸馏训练代码
-
import torch.nn as nn import torch import torch.nn.functional as F import os from tqdm import tqdm from dist.StudentModel import StudentModel from dist.TeacherModel import TeacherModel weights = './teacher.pth' ## 设置cuda:(cuda:0) cuda = True if torch.cuda.is_available() else False device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') torch.backends.cudnn.benchmark = True #使用卷积cuDNN加速 class DistillationTrainer():def __init__(self,opt,train_dataloader,test_dataloader):self.opt = optself.train_dataloader = train_dataloaderself.test_dataloader = test_dataloaderdef trainer(self):opt = self.opttrain_dataloader = self.train_dataloadertest_dataloader = self.test_dataloaderteacher_model = TeacherModel()if os.path.exists(weights):teacher_model.load_state_dict(torch.load(weights))print('successfully')else:print('not loading')teacher_model = teacher_model.to(device)teacher_model.eval()model = StudentModel()model = model.to(device)temp = 7# hard_losshard_loss = nn.CrossEntropyLoss()# hard_loss权重alpha = 0.3# soft_losssoft_loss = nn.KLDivLoss(reduction="batchmean")optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))for epoch in range(opt.n_epochs): ## epoch:5for data, targets in tqdm(train_dataloader):data = data.to(device)targets = targets.to(device)# 老师模型预测with torch.no_grad():teacher_preds = teacher_model(data)# 学生模型预测student_preds = model(data)# 计算hard_lossstudent_loss = hard_loss(student_preds, targets)# 计算蒸馏后的预测损失ditillation_loss = soft_loss(F.softmax(student_preds / temp, dim=1),F.softmax(teacher_preds / temp, dim=1))loss = alpha * student_loss + (1 - alpha) * ditillation_lossoptimizer.zero_grad()loss.backward()optimizer.step()model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x, y in test_dataloader:x = x.to(device)y = y.to(device)preds = model(x)predictions = preds.max(1).indicesnum_correct += (predictions == y).sum()num_samples += predictions.size(0)acc = (num_correct / num_samples).item()model.train()print('distillation: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))
-
蒸馏训练的主函数,该部分大致与teacher模型训练类似,只是调用不同。
-
import argparse import torch from dist.DistillationTrainer import DistillationTrainer from dist.MyDateLoader import MyDataset from dist.TeacherTrainer import TeacherTrainer device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') def paras():## 超参数配置parser = argparse.ArgumentParser()parser.add_argument("--n_epochs", type=int, default=5, help="number of epochs of training")parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")parser.add_argument("--channels", type=int, default=1, help="number of image channels")parser.add_argument("--sample_interval", type=int, default=500, help="interval betwen image samples")opt = parser.parse_args()## opt = parser.parse_args(args=[]) ## 在colab中运行时,换为此行print(opt)return opt if __name__ == '__main__':opt = paras()data = MyDataset(opt)train_dataloader, test_dataloader = data.MyData()# 训练Teacher模型# teacher_trainer = TeacherTrainer(opt,train_dataloader,test_dataloader)# teacher_trainer.trainer()distillation_trainer = DistillationTrainer(opt,train_dataloader,test_dataloader)distillation_trainer.trainer()
-
-
示例
-
import torch from torch import nn import torch.nn.functional as F from torch.utils.data import DataLoader from tqdm import tqdm import torchvision from torchvision import transforms class TeacherModel(nn.Module):def __init__(self, in_channels=1, num_classes=10):super(TeacherModel, self).__init__()self.relu = nn.ReLU()self.fc1 = nn.Linear(784, 1200)self.fc2 = nn.Linear(1200, 1200)self.fc3 = nn.Linear(1200, num_classes)self.dropout = nn.Dropout(p=0.5)def forward(self, x):x = x.view(-1, 784) #输入的图像是一个28x28像素的灰度图像,因此输入层有784个神经元。x = self.relu(self.dropout(self.fc1(x)))x = self.relu(self.dropout(self.fc2(x)))x = self.fc3(x)return x class StudentModel(nn.Module):def __init__(self, in_channels=1, num_classes=10):super(StudentModel, self).__init__()self.relu = nn.ReLU()self.fc1 = nn.Linear(784, 20)self.fc2 = nn.Linear(20, 20)self.fc3 = nn.Linear(20, num_classes)self.dropout = nn.Dropout(p=0.5)def forward(self, x):x = x.view(-1, 784)x = self.relu(self.dropout(self.fc1(x)))x = self.relu(self.dropout(self.fc2(x)))x = self.fc3(x)return x def teacher(device, train_loader, test_loader):print('--------------teachermodel start--------------')model = TeacherModel()model = model.to(device)criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)epochs = 6for epoch in range(epochs):model.train()for data, target in tqdm(train_loader):data = data.to(device) #(32,1,28,28)target = target.to(device) #(32,)preds = model(data)loss = criterion(preds, target)optimizer.zero_grad()loss.backward()optimizer.step()model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x, y in test_loader:x = x.to(device)y = y.to(device)preds = model(x)predictions = preds.max(1).indicesnum_correct += (predictions.eq(y)).sum().item()num_samples += predictions.size(0)acc = num_correct / num_samplesmodel.train()print('Epoch:{}\t Acc:{:.4f}'.format(epoch + 1, acc))torch.save(model, 'teacher.pkl')print('--------------teachermodel end--------------') def student(device, train_loader, test_loader):print('--------------studentmodel start--------------')model = StudentModel()model = model.to(device)criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)epochs = 3for epoch in range(epochs):model.train()for data, target in tqdm(train_loader):data = data.to(device)target = target.to(device)preds = model(data)loss = criterion(preds, target)optimizer.zero_grad()loss.backward()optimizer.step()model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x, y in test_loader:x = x.to(device)y = y.to(device)# print(y)preds = model(x)predictions = preds.max(1).indices# print(predictions)num_correct += (predictions.eq(y)).sum().item()num_samples += predictions.size(0)acc = num_correct / num_samplesmodel.train()print('Epoch:{}\t Acc:{:.4f}'.format(epoch + 1, acc))print('--------------studentmodel prediction end--------------') def kd(teachermodel, device, train_loader, test_loader):print('--------------kdmodel start--------------')teachermodel.eval() #将模型设置为评估模式,Dropout和Batch Normalization 等层会被固定住;模型参数不会被更新,即不会进行反向传播和梯度更新,只会进行前向传播计算studentmodel = StudentModel()studentmodel = studentmodel.to(device)studentmodel.train()temp = 7 #蒸馏温度alpha = 0.3hard_loss = nn.CrossEntropyLoss()soft_loss = nn.KLDivLoss(reduction='batchmean')#计算两个概率分布之间的KL散度,用于度量两个分布之间的差异或相似性,需要提供两个输入张量:input预测的概率分布和 target目标概率分布optimizer = torch.optim.Adam(studentmodel.parameters(), lr=1e-4)epochs = 20for epoch in range(epochs):for data, target in tqdm(train_loader):data = data.to(device)target = target.to(device)with torch.no_grad(): #当你使用 torch.no_grad() 包裹代码块时,该代码块内的张量操作将不会被追踪,也不会计算梯度。这可以减少内存消耗并提高代码的执行效率。teacher_preds = teachermodel(data)student_preds = studentmodel(data)student_loss = hard_loss(student_preds, target) #hard_lossdistillation_loss = soft_loss(F.log_softmax(student_preds / temp, dim=1),F.softmax(teacher_preds / temp, dim=1)) #soft_lossloss = alpha * student_loss + (1 - alpha) * distillation_lossoptimizer.zero_grad()loss.backward()optimizer.step()studentmodel.eval()num_correct = 0num_samples = 0with torch.no_grad():for x, y in test_loader:x = x.to(device)y = y.to(device)preds = studentmodel(x)predictions = preds.max(1).indicesnum_correct += (predictions.eq(y)).sum().item()num_samples += predictions.size(0)acc = num_correct / num_samplesstudentmodel.train()print('Epoch:{}\t Acc:{:.4f}'.format(epoch + 1, acc))print('--------------kdmodel end--------------') if __name__ == '__main__':torch.manual_seed(0)device = torch.device("cpu")torch.backends.cudnn.benchmark = True#加载数据集X_train = torchvision.datasets.MNIST(root="dataset/",train=True,transform=transforms.ToTensor(),download=True)X_test = torchvision.datasets.MNIST(root="dataset/",train=False,transform=transforms.ToTensor(),download=True)train_loader = DataLoader(dataset=X_train, batch_size=32, shuffle=True)test_loader = DataLoader(dataset=X_test, batch_size=32, shuffle=False)#从头训练教师模型,并预测teacher(device, train_loader, test_loader)#从头训练学生模型,并预测student(device, train_loader, test_loader)#知识蒸馏训练学生模型model = torch.load('teacher.pkl')kd(model, device, train_loader, test_loader)