知识蒸馏,需要合适的教师模型,学生模型,蒸馏数据,损失函数,训练策略,让小模型有大模型的知识

  • 知识蒸馏使用的是Teacher—Student模型,其中teacher是“知识”的输出者,student是“知识”的接受者。知识蒸馏的过程分为2个阶段:

    • 原始模型训练: 训练"Teacher模型", 它的特点是模型相对复杂,也可以由多个分别训练的模型集成而成。我们对"Teacher模型"不作任何关于模型架构、参数量、是否集成方面的限制,唯一的要求就是,对于输入X, 其都能输出Y,其中Y经过softmax的映射,输出值对应相应类别的概率值
    • 精简模型训练: 训练"Student模型", 它是参数量较小、模型结构相对简单的单模型。同样的,对于输入X,其都能输出Y,Y经过softmax映射后同样能输出对应相应类别的概率值
    • Teacher学习能力强,可以将它学到的知识迁移给学习能力相对弱的Student模型,以此来增强Student模型的泛化能力。复杂笨重但是效果好的Teacher模型不上线,就单纯是个导师角色,真正部署上线进行预测任务的是灵活轻巧的Student小模型。
    • 在这里插入图片描述
  • 知识蒸馏可以将大型的复杂模型(如深度神经网络)转换成小型的简化模型,从而减少了模型的存储空间和计算资源需求,使得模型更适合在资源受限的设备上部署和运行,如移动设备或嵌入式系统。通过知识蒸馏,可以将一个在大规模数据集上训练的模型的知识迁移到一个相似但规模较小的模型上,这有助于在资源受限的情况下进行迁移学习或模型微调知识蒸馏还可以帮助提高模型的泛化性能,因为学生模型在训练过程中利用了教师模型的“软标签”,这些软标签包含了教师模型对数据分布的更丰富的信息,有助于减少过拟合

  • 我们可以通过将 OpenCLIP 视觉语言模型中的知识转移到 ResNet18 模型中来探索知识蒸馏,以便对 STL10 数据集进行分类。将探讨用于蒸馏的数据、用于蒸馏的方法和模型架构对最终准确性的影响。可以帮助学习如何利用现有的大型模型,并将知识转移到更适合边缘部署的架构上使用,因为这些架构具有更低的内存消耗、更高的吞吐量或更好的架构支持(即:在 Jetson AGX Orin 的深度学习加速器上运行)。

  • 以下从概念上介绍知识蒸馏,包含了重现结果的代码。如果熟悉深度学习和模型训练,并且正在寻找将大型模型带入边缘的方法,可能会对您有所帮助!查看Nvidia的其他项目 clip-distillation(NVIDIA-AI-IOT/clip-distillation: Zero-label image classification via OpenCLIP knowledge distillation (github.com)),了解如何在没有任何标记数据的情况下使用知识提炼创建自己的自定义图像分类器!还将讨论如何在 NVIDIA Jetson Orin Nano 上对模型进行剖析和优化,以便最终实现实时部署。

什么是知识蒸馏

  • 知识蒸馏是一种将知识从一个神经网络(教师)转移到另一个神经网络(学生)的技术,如需更深入地了解知识蒸馏,建议阅读[2006.05525] Knowledge Distillation: A Survey (arxiv.org)。这一过程有多种形式,可分为以下几类

    • 响应知识蒸馏: 使用 divergence 损失(即使用 KL 散度)训练输出类概率分布,使其与教师概率分布相匹配。
    • 特征知识蒸馏: 训练学生模型的内部特征,使其与教师模型的内部特征直接匹配(即:使用均方误差)。
    • 关系知识蒸馏: 训练教师模型中特征的相对分布,使之与学生模型中特征的相对分布相匹配。
  • 本文将探讨(1、2),因为与关系知识提炼相比,(1、2)非常简单。我们特别感兴趣的是探索如何使用知识蒸馏来获取基于 transformer 的大型教师模型(OpenCLIP),并训练更快、内存更小的模型(ResNet18),使其更适合边缘部署。将通过将 OpenCLIP 调整为针对 STL10 分类数据集的图像分类器来探索这一概念,用于蒸馏的数据和技术如何影响最终模型的准确性。

  • 探讨的教师模型是 OpenCLIP。OpenCLIP 是 OpenAI 的 CLIP(对比语言-图像预训练)的开源实现。CLIP 模型的训练目的是将图像与文本进行匹配。具体来说,该模型由以下部分组成

    • 图像编码器,它能获取图像并生成代表图像的嵌入信息
    • 文本编码器,可接收文本提示并生成代表文本提示的嵌入内容
  • 通过对模型进行训练,配对图像和文本的图像和文本嵌入相似,而非配对图像和文本的图像和文本嵌入差异很大。这个模型的一个有趣之处在于,它是在大量非结构化数据的基础上训练出来的,所学习到的特征可以应用到各种各样的任务中。事实上,只需提供类别描述,它就能在分类任务中达到很好的零误差准确率。不过,这种模型的缺点是,与 ResNet 等基于 CNN 的架构相比,它的运行时间和内存消耗相对较高。这就引出了一个问题:我们能否利用 OpenCLIP 的功能,同时获得更低的内存消耗和延迟?本文将探讨如何使用 OpenCLIP 模型作为教师模型,训练 CNN 模型完成分类任务。我们将探讨用于训练的数据、针对分类任务调整模型的方法以及提炼模型的方法对最终准确性的影响。

  • 为了探索如何将 OpenCLIP 用于分类和知识蒸馏,使用斯坦福大学的 STL10 数据集。只探讨了 STL10 数据集,某些结果可能取决于特定的数据分布和任务。虽然我们不能保证这些结果会转换到其他任务和数据源中,这样您就可以用自己的数据来探索知识蒸馏。STL10 数据集是一个包含 10 个类别的分类数据集:

    • 与 MNIST 相比,它包含自然图像,适合使用 OpenCLIP。
    • 与 CIFAR10 相比,图像的分辨率为 96x96,而不是 32x32。这更接近 OpenCLIP 的训练分辨率。
    • 它包含大量未标记的图像(100,000 张),这使我们能够探索在训练过程中使用未标记数据的好处
  • 在开始使用教师模型训练学生之前,最好先对教师模型在当前任务中的表现有一个初步的预期。这大致有助于设定我们希望学生模型达到的最佳性能预期。由于 OpenCLIP 并未直接在 STL10 数据集上进行训练,因此我们有几种方法可以调整该模型以执行分类

    • 使用文本提示进行分类:在 STL10 数据集上使用 OpenCLIP 进行分类的第一种也是最简单的方法是使用文本提示定义类别,通过文本编码器运行提示,并将每个编码后的文本提示与 GT 标签进行比较。对于 STL10 数据集,我们可以按以下方式生成文本嵌入

    • import open_clip
      model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="laion2b_s34b_b79k"
      )
      tokenizer = open_clip.get_tokenizer("ViT-B-32")
      labels = ["an airplane","a bird","a car","a cat","a deer","a dog","a horse","a monkey","a ship","a truck"
      ]
      text = tokenizer(labels)
      text_embeddings = model.encode_text(text)
      
    • 现在,每个文本提示的嵌入都包含一个长度为 512 的向量。该向量与视觉编码器输出的大小相同。该向量与视觉特征的点积表示相似度,因此我们可以确定数据集的类别概率如下

    • import torch.nn.functional as F
      def embeddings_to_class_probs(vision_embeddings, text_embeddings)vision_embeddings = vision_embeddings / vision_embeddings.norm(dim=-1, keepdim=True)text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)logits = vision_embeddings @ text_embeddings.Tclass_probs = F.softmax(100. * logits, dim=-1)return class_probs
      
    • 现在,我们已经有了目标任务的文本嵌入以及将文本嵌入与图像嵌入进行比较的方法,剩下要做的就是通过 OpenCLIP 视觉编码器运行 STL10 数据集,计算输出类概率,并将结果与 GT 标签进行比较。

    • import tqdm
      from torchvision.datasets import STL10
      dataset = STL10(root=dataset_path,download=True,split="test"
      )
      num_correct = 0
      for image, label in tqdm.tqdm(dataset):input_tensor = preprocess(image).unsqueeze(0)vision_embeddings = model.encode_image(input_tensor)output_class_probs = embeddings_to_class_probs(vision_embeddings, text_embeddings)output_label = torch.argmax(dim=-1)num_correct += int(torch.count_nonzero(output_label == label))
      accuracy = 100. * num_correct / len(dataset)
      
    • 之后,开箱即用的 OpenCLIP 编码器在没有任何额外训练的情况下,在 STL10 测试数据集上获得了 96.68% 的准确率!在没有任何技巧的情况下,我们在 STL10 数据集上获得了相当有竞争力的准确率。

    • 使用线性 head 进行分类:使用文本提示作为类标签,我们能够在 STL10 数据集上实现相当高的准确率,而无需任何训练或基本真实标签。但是,如果我们有 GT 标签呢?我们能用它来提高准确率吗?有了这个选项,我们将探索如何使用一些 GT 数据,在 OpenCLIP 模型的末尾训练一个微小的逻辑回归层(线性层,然后是 softmax),看看这是否能提高准确率。为此,我们对线性层的定义如下

    • import torch.nn as nn
      linear_probe = nn.Linear(512, len(labels))
      
    • 然后,我们需要训练模型。这包括:从数据集中读取一批数据;运行 OpenCLIP 视觉编码器(无梯度);在 OpenCLIP 的输出上运行线性层;计算线性层输出与 GT 标签之间的交叉熵;更新线性层。

    • optimizer = torch.optim.Adam(linear_probe.parameters(), lr=3e-4)
      for epoch in range(num_epochs):for image, label in iter(train_loader):# ... run open-clip to get vision embeddingsoptimizer.zero_grad()output_logits = linear_probe(vision_embeddings)output_logprob = F.log_softmax(output_logits, dim=-1)loss = F.nll_loss(output_logprob, label)loss.backward()optimizer.step()
      
    • 训练完线性探针后,我们在 STL10 数据集上对其进行了评估,结果与之前类似,准确率达到了 98.57!通过使用一些标注数据,我们训练出了一个小型逻辑回归层,使 OpenCLIP 在 STL10 数据集上的准确率提高了近 +2%!这种改进可能是因为我们的文本提示(如 “an airplane”)可能与 STL10 数据集中的标签不完全匹配。但是,通过查看每个标签的一些示例,我们可以学习到更准确地代表类标签的参考嵌入。

训练学生模型以模仿 OpenCLIP

  • 我们现在已经看到,使用大型 OpenCLIP 模型,我们可以在 STL10 图像分类数据集上不费吹灰之力就取得具有竞争力的结果。但是,OpenCLIP 体积庞大,与其他模型架构相比,内存消耗和延迟可能较高。此外,作为视觉 Transformer 模型,OpenCLIP 在利用 Jetson AGX Orin 上的深度学习加速器(DLA)方面能力较弱,因为注意层中存在矩阵乘法。另一方面,像 resnet18 这样的 CNN 模型则通过 Jetson 上的 GPU 和 DLA 进行了高度优化,使我们能够以更高的吞吐量和更少的内存运行模型。然而,知识提炼会影响模型的准确性,因此我们希望更好地了解哪些因素最为重要。为此,我们进行了一些实验,试图回答几个问题:
    • 蒸馏法与使用 GT 标签进行的训练相比效果如何?可以使用 GT 标签从头开始训练 resnet18,并将其与使用 OpenCLIP 输出概率(使用文本提示标签和线性回归方法)训练 resnet18 进行比较。在每个实验中,我们只使用 STL10 训练数据集(5000 张图像),以便与基线训练进行公平比较。
      • 在使用相同数据进行训练的情况下,使用线性回归头的提炼模型比直接使用 GT 标签的准确度更高,即使每种方法都使用相同的可用数据和标签集。但是,文本提示方法无法达到从头开始训练 resnet18 的准确度。
      • 这表明,即使在数据分布相同的情况下,知识蒸馏也有能力提高模型的准确性。不过,在下一节中,我们将看到如何在提炼过程中利用无标签数据,从而更进一步。事实上,使用无标签数据,文本提示方法(在此过程中不需要任何地面实况标签)能够大大超过从头开始训练模型的效果!
    • 用于训练的数据分布对模型准确性有何影响?
      • 现在我们已经看到,知识提炼能够提高模型的准确性。但是,我们的最佳学生模型(准确率为 60.65%)仍然远远低于教师模型(准确率为 98.57%)。为什么会这样?学生模型 resnet18 是否缺乏模仿教师模型的能力?还是其他原因,也许是我们用来提炼的数据?
      • 为了帮助回答这个问题,我们进行了一系列实验,不仅使用了 STL10 训练数据集中的 5000 张图像,还使用了补充 STL10 数据集中提供的 100,000 张未标记图像进行知识提炼。只需在蒸馏过程中从 STL10 未标注数据集拆分出大量未标注图像,模型的准确性就会大幅提高!现在,带有文本提示标签的 resnet18 模型的准确率远远超过了使用 GT 标签训练的 resnet18,仅比带有文本提示标签的原始 OpenCLIP 模型低 2%。带有线性分类头的蒸馏模型则更进一步,达到了 96.88%,超过了带有文本提示的原始 OpenCLIP 模型,比带有线性分类头的最佳 OpenCLIP 变体低不到 2%。
      • 所以,总而言之、在没有标签,只有大量无标签数据的情况下,我们通过提炼文本提示 OpenCLIP 分类器,利用 resnet18 模型实现了 94.32% 的准确率;在有一些标签和大量无标签数据的情况下,我们可以通过提炼具有线性分类头的 OpenCLIP 分类器,使 resnet18 模型的准确率达到 96.88%。我们得到的经验是,蒸馏所使用的数据对于获得良好的精度非常重要。但我们的学生模型架构如何?使用 resnet50 这样的高容量模型能否取得更好的效果?
    • 学生模型架构对模型准确性有何影响?resnet50 会比 resnet18 获得更高的准确率吗?
      • 为了探索学生模型架构对最终模型准确性的影响,我们进行了一系列实验,使用了我们的最佳蒸馏配置,以及三种不同的模型架构:resnet18、resnet34 和 resnet50。在切换学生模型架构时,我们看到的差异相对较小,可以忽略不计。这意味着,至少对于这项任务(STL10 分类)来说,用于提炼的数据比学生模型架构重要得多。对于其他任务来说,情况很可能并非如此,但我们希望将这些结果包括进来,至少在这种情况下与大家分享这一发现,以便大家了解并优先考虑哪些因素需要首先探索。
    • 蒸馏方法对模型准确性有何影响?基于类概率还是内部特征进行训练更好?
      • 到目前为止,我们已经看到用于训练的数据对模型的准确性有很大影响,但用于蒸馏的方法又如何呢?如前所述,知识提炼可以通过以下几种方式进行:响应蒸馏: 拟合模型以学习输出类别概率;特征提炼: 拟合模型以学习内部特征
      • 为了探索这些决定的影响,我们进行了几次实验,训练 resnet18 学习 OpenCLIP 输出的视觉特征嵌入(512 维),而不是类别概率。我们将其输入文本提示或线性回归头,就像使用原始 OpenCLIP 模型一样。
      • 通过对特征进行训练(使用均方误差损失)获得的准确率略高于通过对输出类别概率进行训练(使用 KL 发散)获得的准确率。概述如下:通过特征训练,我们的文本提示学生准确率提高了 0.25;通过特征训练,我们的线性分类头学生准确率提高了 0.032
      • 虽然这些变化并不显著,但有趣的是,在 embeddings 上进行训练并不会对模型的准确性产生不利影响。之所以说这很有趣,是因为这些 embeddings 并不是明确针对 STL10 任务的,它们有可能像最初的 OpenCLIP 模型一样被重新利用,只需改变用于分类的文本提示或重新训练线性回归头即可
      • 不过,在本教程中,我们尚未以通用方式提炼 OpenCLIP。下一步,我们将对这种可能性进行有趣的探索。但现在,我们已经为目标任务建立了一个相当不错的分类模型。让我们讨论一下如何优化我们的学生模型以便部署,并看看延迟和内存消耗与原始 OpenCLIP 模型相比如何。

用 TensorRT 优化模型并比较性能

  • 以上我们展示了如何训练一个 resnet CNN 模型来模仿大型 OpenCLIP 模型。现在,让我们看看为什么这样的努力是值得的。使用我们的学生模型能带来哪些性能提升?为了了解每个模型的预期性能,我们将使用英伟达 TensorRT 对模型进行优化,并测量英伟达 Jetson 上的吞吐量和内存消耗。为此,我们用 ONNX 导出了模型,并用英伟达 TensorRT 进行了优化。下面我们将展示 OpenCLIP 和 resnet18 在 Jetson Orin Nano 上以 224x224 分辨率运行的性能,批量大小为 8。在使用 TensorRT 对每个模型进行优化后,resnet18 模型的速度是原始 open_clip 模型的 4.2 倍,而内存使用量则减少了 3.45 倍。

  • 为了测量内存,我们使用了 tegrastats。我们用 trtexec 记录了模型执行前和模型执行后的系统内存。上表中的内存是模型运行时系统内存的变化。对于这个特定的图像分类任务,Jetson Orin Nano 有足够能力运行原始 OpenCLIP 模型。但是,如果要运行更高分辨率的模型,这些吞吐量和内存消耗方面的差异可能会变得至关重要。

  • 我们探索了在 STL10 分类数据集上使用知识蒸馏来训练 resnet18 分类器。在这项任务中,我们取得了与原始 OpenCLIP 模型相当的准确率,同时大幅减少了运行时间和内存消耗。希望本教程能让您了解如何利用知识蒸馏技术将大型模型提升到边缘水平。除本介绍外,我们还创建了一个配套项目 clip-distillation,让您可以轻松创建零标签图像分类器,完成自己的自定义任务!它包括下载经过剪辑过滤的相关图片的脚本,以用于提炼;提取高效 CNN 模型以模仿 OpenCLIP 变换器模型的脚本,包括量化感知训练和结构稀疏性训练选项。

通过对 OpenCLIP 模型进行知识提炼

  • 可以使用零标记数据创建自己的定制图像分类模型。即使您不直接需要图像分类器,您也可能会发现这个项目很有帮助,它启发您如何使用知识提炼来优化用于推理的模型,或者作为一个示例,说明如何在 NVIDIA Jetson 上使用量化感知训练和结构化稀疏性来训练用于推理的模型。该项目NVIDIA-AI-IOT/clip-distillation: Zero-label image classification via OpenCLIP knowledge distillation (github.com)包括:
    • 从 LAION 数据库中搜索和下载相关数据的脚本,以用于提炼;
    • 将任何 OpenCLIP 模型提炼为任何 Pytorch 图像模型(timm)CNN 模型的脚本。支持面向下游 INT8 推理的量化感知训练 (QAT);支持使用 ASP 库执行 2:4 结构稀疏性训练;
    • 使用英伟达™(NVIDIA®)TensorRT 运行推理的脚本;支持 INT8 模式;支持在某些 NVIDIA Jetson 平台(如 NVIDIA Jetson Orin Nano)上加速 2:4 结构稀疏模型。

通过 CLIP 过滤搜索和下载图像

  • 在提炼模型时,我们需要做的第一件事就是获取用于提炼的数据。在这项任务中,我们将通过搜索 LAION 数据库来查找相关图片。我们提供了一个脚本来简化这项工作。要搜索相关图像,首先要创建一个 data/text_prompts.txt 文件,其中包含要查询的文本提示。每个提示都应独立成行。接下来,调用脚本查询与文本提示相匹配的图片。

    • python3 search_clip_images.py \"data/text_prompts.txt" \"data/image_urls.txt" \-n 5000 \-m 10000 \--max_workers 2 \--append
      
    • 这将输出一个文件 data/image_urls.txt,其中包含与我们的文本提示查询相匹配的图像 URL。现在我们已经找到了用于蒸馏的相关图片,我们需要下载它们。为此,我们调用以下脚本将图像下载到输出文件夹。

    • python3 download_images.py \"data/image_urls.txt" \"data/images" \--max_workers 32 \--timeout 2
      
    • 该脚本将把图片下载到 data/images 文件夹。每张图片都将根据其 URL 获得一个唯一的文件名。

计算 OpenCLIP 嵌入

  • 我们在上面下载的图像将在蒸馏过程中作为教师和学生模型的输入。遗憾的是,在训练过程中执行教师模型可能会比较慢。为了加快这一过程,我们将预先计算教师模型的输出,这样就不需要在训练过程中执行教师模型了。为此,请调用 compute_openclip_embeddings.py 脚本,如下所示、

    • python3 compute_openclip_embeddings.py \data/images \data/embeddings \--batch_size 16 \--num_workers 8 \--model_name ViT-B-32 \--pretrained laion2b_s34b_b79k
      
    • 这将把输出的嵌入文件写入 data/embeddings 文件夹,文件名与图像文件名一致,但文件扩展名除外。注:有关可用模型名称和预训练权重标识符,请参考 OpenCLIP Repoopen_clip/src/open_clip/pretrained.py at fb72f4db1b17133befd6c67c9cf32a533b85a321 · mlfoundations/open_clip (github.com)。

训练学生 CNN 模型以模仿 OpenCLIP 模型

  • 现在,我们已经有了用于知识提炼的数据,可以通过调用 distil_model_embeddings.py 脚本来执行提炼(学生模型训练),如下所示。

    • python3 distil_model_embeddings.py \resnet18 \data/images \data/embeddings \data/models/resnet18 \--output_dim 512 \--pretrained
      
    • 这将向 data/models/resnet18 输出模型检查点和信息。我们在本例中使用的提炼模型是 resnet18。该模型经过 TensorRT 高度优化,我们可以在训练过程中随时应用其他优化,如降低精度和结构稀疏性。

使用提炼的模型进行推理

  • 在提炼过程中,我们对学生模型进行了训练,使其与 open-clip 模型的特征相匹配。不过,我们有兴趣创建一个分类模型。要创建zero-shot分类模型,我们需要从描述类别标签的文本提示中生成文本嵌入。为此,我们使用了预先训练好的 OpenCLIP 文本编码器。我们调用 compute_openclip_text_embeddings.py 脚本来创建文本嵌入。

    • python3 compute_openclip_text_embeddings.py \data/text_prompts.txt \data/text_embeddings.npy \--model_name ViT-B-32
      
    • 在这种情况下,我们使用与图像搜索相同的文本提示作为分类的文本提示。现在,我们已经计算出了图像类别的文本提示,我们可以使用 PyTorch 模型进行图像分类,具体如下

    • python3 predict_pytorch.py \resnet18 \data/models/resnet18/checkpoint.pth \data/text_embeddings.npy \assets/cat.jpg \--text_prompts data/text_prompts.txt
      
  • 同样,我们也可以对实时摄像机画面进行如下推理:

    • python3 demo_pytorch.py \resnet18 \data/models/resnet18/checkpoint.pth \data/text_embeddings.npy \--text_prompts data/text_prompts.txt \--camera_device 0
      

利用结构化稀疏性训练学生模型

  • 训练脚本提供结构稀疏性训练功能。这可以在使用 TensorRT 的英伟达 Jetson 平台上部署模型时提供额外的加速。用结构化稀疏性训练模型

    • python3 distil_model_embeddings.py \resnet18 \data/images \data/embeddings \data/models/resnet18_sparse \--output_dim 512 \--pretrained \--init_checkpoint data/models/resnet18/checkpoint.pth \--use_asp \--num_epochs 25
      
    • 用 PyTorch 进行预测

    • python3 predict_pytorch.py \resnet18 \data/models/resnet18_sparse/checkpoint.pth \data/text_embeddings.npy \assets/cat.jpg \--text_prompts data/text_prompts.txt \--use_asp
      
    • PyTorch 演示

    • python3 demo_pytorch.py \resnet18 \data/models/resnet18_sparse/checkpoint.pth \data/text_embeddings.npy \--text_prompts data/text_prompts.txt \--camera_device 0 \--use_asp
      
    • 导出到 ONNX

    • python3 export_onnx.py \resnet18 \data/models/resnet18_sparse/checkpoint.pth \data/onnx/resnet18_sparse.onnx \--use_asp
      

使用量化感知训练和 INT8 精确度训练学生

  • 除了结构化稀疏性,我们还可以利用降低 INT8 精度来提高性能。量化感知训练是一种将使用 INT8 精度时引入的量化误差最小化的技术。它通过在训练过程中的模型前向传递过程中应用量化来实现这一目的。这样,模型就能在训练过程中适应量化误差。在使用训练后量化时,它还能让我们避免校准的需要。要使用量化感知训练来提炼模型,请按照以下步骤操作

    • 使用量化感知训练 (QAT) 训练模型

    • python3 distil_model_embeddings.py \resnet18 \data/images \data/embeddings \data/models/resnet18_qat \--output_dim 512 \--pretrained \--init_checkpoint data/models/resnet18/checkpoint.pth \--use_qat \--num_epochs 25
      
    • 用 PyTorch 进行预测

    • python3 predict_pytorch.py \resnet18 \data/models/resnet18_sparse/checkpoint.pth \data/text_embeddings.npy \assets/cat.jpg \--text_prompts data/text_prompts.txt \--use_qat
      
    • PyTorch 演示

    • python3 demo_pytorch.py \resnet18 \data/models/resnet18_sparse/checkpoint.pth \data/text_embeddings.npy \--text_prompts data/text_prompts.txt \--camera_device 0 \--use_qat
      
    • 导出到 ONNX

    • python3 export_onnx.py \resnet18 \data/models/resnet18_qat/checkpoint.pth \data/onnx/resnet18_qat.onnx \--use_qat
      
  • 希望您能在不使用任何标记数据的情况下训练出自己的图像分类模型。

知识蒸馏实战

  • 数据使用我以前在图像分类任务中的数据集——植物幼苗数据集,先将数据集转为训练集和验证集。执行代码:

    • import glob
      import os
      import shutil
      image_list=glob.glob('data1/*/*.png')
      print(image_list)
      file_dir='data'
      if os.path.exists(file_dir):print('true')#os.rmdir(file_dir)shutil.rmtree(file_dir)#删除再建立os.makedirs(file_dir)
      else:os.makedirs(file_dir)from sklearn.model_selection import train_test_split
      trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
      train_dir='train'
      val_dir='val'
      train_root=os.path.join(file_dir,train_dir)
      val_root=os.path.join(file_dir,val_dir)
      for file in trainval_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(train_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)
      for file in val_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(val_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)
      
  • 教师网络选用coatnet_2,是一个比较大一点的网络了,模型的大小有200M。训练50个epoch,最好的模型在92%左右。导入需要的库

    • import torch.optim as optim
      import torch
      import torch.nn as nn
      import torch.nn.parallel
      import torch.utils.data
      import torch.utils.data.distributed
      import torchvision.transforms as transforms
      from torchvision import datasets
      from torch.autograd import Variable
      from model.coatnet import coatnet_2
      import json
      import os
      
    • 定义训练和验证函数

    • def train(model, device, train_loader, optimizer, epoch):model.train()sum_loss = 0total_num = len(train_loader.dataset)print(total_num, len(train_loader))for batch_idx, (data, target) in enumerate(train_loader):data, target = Variable(data).to(device), Variable(target).to(device)output = model(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()print_loss = loss.data.item()sum_loss += print_lossif (batch_idx + 1) % 10 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),100. * (batch_idx + 1) / len(train_loader), loss.item()))ave_loss = sum_loss / len(train_loader)print('epoch:{},loss:{}'.format(epoch, ave_loss))
      Best_ACC=0
      # 验证过程
      @torch.no_grad()
      def val(model, device, test_loader):global Best_ACCmodel.eval()test_loss = 0correct = 0total_num = len(test_loader.dataset)print(total_num, len(test_loader))with torch.no_grad():for data, target in test_loader:data, target = Variable(data).to(device), Variable(target).to(device)output = model(data)loss = criterion(output, target)_, pred = torch.max(output.data, 1)correct += torch.sum(pred == target)print_loss = loss.data.item()test_loss += print_losscorrect = correct.data.item()acc = correct / total_numavgloss = test_loss / len(test_loader)if acc > Best_ACC:torch.save(model, file_dir + '/' + 'best.pth')Best_ACC = accprint('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(avgloss, correct, len(test_loader.dataset), 100 * acc))return acc
      
    • 定义全局参数

    • if __name__ == '__main__':# 创建保存模型的文件夹file_dir = 'CoatNet'if os.path.exists(file_dir):print('true')os.makedirs(file_dir, exist_ok=True)else:os.makedirs(file_dir)# 设置全局参数modellr = 1e-4BATCH_SIZE = 16EPOCHS = 50DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
      
    • 图像预处理与增强

    •     transform = transforms.Compose([transforms.RandomRotation(10),transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])])transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])])
      
    • 使用pytorch默认读取数据的方式。

    • 	# 读取数据dataset_train = datasets.ImageFolder('data/train', transform=transform)dataset_test = datasets.ImageFolder("data/val", transform=transform_test)with open('class.txt', 'w') as file:file.write(str(dataset_train.class_to_idx))with open('class.json', 'w', encoding='utf-8') as file:file.write(json.dumps(dataset_train.class_to_idx))# 导入数据train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
      
    • 设置模型和Loss

    •  	# 实例化模型并且移动到GPUcriterion = nn.CrossEntropyLoss()model_ft = coatnet_2()num_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Linear(num_ftrs, 12) # 维度对齐model_ft.to(DEVICE)# 选择简单暴力的Adam优化器,学习率调低optimizer = optim.Adam(model_ft.parameters(), lr=modellr)cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)# 训练val_acc_list= {}for epoch in range(1, EPOCHS + 1):train(model_ft, DEVICE, train_loader, optimizer, epoch)cosine_schedule.step()acc=val(model_ft, DEVICE, test_loader)val_acc_list[epoch]=accwith open('result.json', 'w', encoding='utf-8') as file:file.write(json.dumps(val_acc_list))torch.save(model_ft, 'CoatNet/model_final.pth')
      
    • 完成上面的代码就可以开始训练Teacher网络了。

  • 学生网络选用ResNet18,是一个比较小一点的网络了,模型的大小有40M。训练50个epoch,最好的模型在86%左右。导入需要的库

    • import torch.optim as optim
      import torch
      import torch.nn as nn
      import torch.nn.parallel
      import torch.utils.data
      import torch.utils.data.distributed
      import torchvision.transforms as transforms
      from torchvision import datasets
      from torch.autograd import Variable
      from torchvision.models.resnet import resnet18
      import json
      import os
      
    • 定义训练和验证函数

    • def train(model, device, train_loader, optimizer, epoch):model.train()sum_loss = 0total_num = len(train_loader.dataset)print(total_num, len(train_loader))for batch_idx, (data, target) in enumerate(train_loader):data, target = Variable(data).to(device), Variable(target).to(device)output = model(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()print_loss = loss.data.item()sum_loss += print_lossif (batch_idx + 1) % 10 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),100. * (batch_idx + 1) / len(train_loader), loss.item()))ave_loss = sum_loss / len(train_loader)print('epoch:{},loss:{}'.format(epoch, ave_loss))
      Best_ACC=0
      # 验证过程
      @torch.no_grad()
      def val(model, device, test_loader):global Best_ACCmodel.eval()test_loss = 0correct = 0total_num = len(test_loader.dataset)print(total_num, len(test_loader))with torch.no_grad():for data, target in test_loader:data, target = Variable(data).to(device), Variable(target).to(device)output = model(data)loss = criterion(output, target)_, pred = torch.max(output.data, 1)correct += torch.sum(pred == target)print_loss = loss.data.item()test_loss += print_losscorrect = correct.data.item()acc = correct / total_numavgloss = test_loss / len(test_loader)if acc > Best_ACC:torch.save(model, file_dir + '/' + 'best.pth')Best_ACC = accprint('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(avgloss, correct, len(test_loader.dataset), 100 * acc))return acc
      
    • 定义全局参数

    • if __name__ == '__main__':# 创建保存模型的文件夹file_dir = 'resnet'if os.path.exists(file_dir):print('true')os.makedirs(file_dir, exist_ok=True)else:os.makedirs(file_dir)# 设置全局参数modellr = 1e-4BATCH_SIZE = 16EPOCHS = 50DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
      
    • 图像预处理与增强

    •     transform = transforms.Compose([transforms.RandomRotation(10),transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])])transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])])
      
    • 使用pytorch默认读取数据的方式

    •     # 读取数据dataset_train = datasets.ImageFolder('data/train', transform=transform)dataset_test = datasets.ImageFolder("data/val", transform=transform_test)with open('class.txt', 'w') as file:file.write(str(dataset_train.class_to_idx))with open('class.json', 'w', encoding='utf-8') as file:file.write(json.dumps(dataset_train.class_to_idx))# 导入数据train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
      
    • 设置模型和Loss

    • 	# 实例化模型并且移动到GPUcriterion = nn.CrossEntropyLoss()model_ft = resnet18()print(model_ft)num_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Linear(num_ftrs, 12)model_ft.to(DEVICE)# 选择简单暴力的Adam优化器,学习率调低optimizer = optim.Adam(model_ft.parameters(), lr=modellr)cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)# 训练val_acc_list= {}for epoch in range(1, EPOCHS + 1):train(model_ft, DEVICE, train_loader, optimizer, epoch)cosine_schedule.step()acc=val(model_ft, DEVICE, test_loader)val_acc_list[epoch]=accwith open('result_student.json', 'w', encoding='utf-8') as file:file.write(json.dumps(val_acc_list))torch.save(model_ft, 'resnet/model_final.pth')
      
    • 完成上面的代码就可以开始训练Student网络了。

  • 蒸馏学生网络,学生网络继续选用ResNet18,使用Teacher网络蒸馏学生网络,训练50个epoch,最终ACC是89%。新建student_kd_train.py,导入需要的库

    • import torch.optim as optim
      import torch
      import torch.nn as nn
      import torch.nn.parallel
      import torch.utils.data
      import torch.utils.data.distributed
      import torchvision.transforms as transforms
      from torchvision import datasets
      from torch.autograd import Variable
      from torchvision.models.resnet import resnet18
      import json
      import os
      
    • 定义蒸馏函数

    • def distillation(y, labels, teacher_scores, temp, alpha):return nn.KLDivLoss()(F.log_softmax(y / temp, dim=1), F.softmax(teacher_scores / temp, dim=1)) * (temp * temp * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)
      
    • 定义训练和验证函数

    • # 定义训练过程
      def train(model, device, train_loader, optimizer, epoch):model.train()sum_loss = 0total_num = len(train_loader.dataset)print(total_num, len(train_loader))for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)teacher_output = teacher_model(data)  # 训练出教师的 teacher_outputteacher_output = teacher_output.detach()  # 切断老师网络的反向传播loss = distillation(output, target, teacher_output, temp=7.0, alpha=0.7)  # 通过老师的 teacher_output训练学生的outputloss.backward()optimizer.step()print_loss = loss.data.item()sum_loss += print_lossif (batch_idx + 1) % 10 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),100. * (batch_idx + 1) / len(train_loader), loss.item()))ave_loss = sum_loss / len(train_loader)print('epoch:{},loss:{}'.format(epoch, ave_loss))
      Best_ACC=0
      # 验证过程
      @torch.no_grad()
      def val(model, device, test_loader):global Best_ACCmodel.eval()test_loss = 0correct = 0total_num = len(test_loader.dataset)print(total_num, len(test_loader))with torch.no_grad():for data, target in test_loader:data, target = Variable(data).to(device), Variable(target).to(device)output = model(data)loss = criterion(output, target)_, pred = torch.max(output.data, 1)correct += torch.sum(pred == target)print_loss = loss.data.item()test_loss += print_losscorrect = correct.data.item()acc = correct / total_numavgloss = test_loss / len(test_loader)if acc > Best_ACC:torch.save(model, file_dir + '/' + 'best.pth')Best_ACC = accprint('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(avgloss, correct, len(test_loader.dataset), 100 * acc))return acc
      
    • 定义全局参数

    • if __name__ == '__main__':# 创建保存模型的文件夹file_dir = 'resnet_kd'if os.path.exists(file_dir):print('true')os.makedirs(file_dir, exist_ok=True)else:os.makedirs(file_dir)# 设置全局参数modellr = 1e-4BATCH_SIZE = 16EPOCHS = 50DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
      
    • 图像预处理与增强

    • 	transform = transforms.Compose([transforms.RandomRotation(10),transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])])transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])])
      
    • 使用pytorch默认读取数据的方式

    •     # 读取数据dataset_train = datasets.ImageFolder('data/train', transform=transform)dataset_test = datasets.ImageFolder("data/val", transform=transform_test)with open('class.txt', 'w') as file:file.write(str(dataset_train.class_to_idx))with open('class.json', 'w', encoding='utf-8') as file:file.write(json.dumps(dataset_train.class_to_idx))# 导入数据train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
      
    • 设置模型和Loss

    •     criterion = nn.CrossEntropyLoss()model_ft = resnet18()print(model_ft)num_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Linear(num_ftrs, 12)model_ft.to(DEVICE)# 选择简单暴力的Adam优化器,学习率调低optimizer = optim.Adam(model_ft.parameters(), lr=modellr)cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)# 训练val_acc_list= {}for epoch in range(1, EPOCHS + 1):train(model_ft, DEVICE, train_loader, optimizer, epoch)cosine_schedule.step()acc=val(model_ft, DEVICE, test_loader)val_acc_list[epoch]=accwith open('result_student.json', 'w', encoding='utf-8') as file:file.write(json.dumps(val_acc_list))torch.save(model_ft, 'resnet_kd/model_final.pth')
      
  • 加载保存的结果,然后绘制acc曲线。

    • import numpy as np
      from matplotlib import pyplot as plt
      import json
      teacher_file='result.json'
      student_file='result_student.json'
      student_kd_file='result_kd.json'
      def read_json(file):with open(file, 'r', encoding='utf8') as fp:json_data = json.load(fp)print(json_data)return json_data
      teacher_data=read_json(teacher_file)
      student_data=read_json(student_file)
      student_kd_data=read_json(student_kd_file)
      x =[int(x) for x in  list(dict(teacher_data).keys())]
      print(x)
      plt.plot(x, list(teacher_data.values()), label='teacher')
      plt.plot(x,list(student_data.values()), label='student without KD')
      plt.plot(x, list(student_kd_data.values()), label='student with KD')
      plt.title('Test accuracy')
      plt.legend()
      plt.show()
      

知识蒸馏的另一个例子

  • 知识蒸馏的核心思想是利用教师模型的输出作为附加的监督信号来训练学生模型。在传统的监督学习中,目标是最小化模型预测与真实标签之间的差距(损失函数)。而在知识蒸馏中,除了最小化模型预测与真实标签之间的差距外,还引入了一个额外的损失项,该项衡量了学生模型预测与教师模型预测之间的距离。 具体而言,损失函数通常由两部分组成:一部分是传统的交叉熵损失,用于衡量学生模型的预测与真实标签之间的差距;另一部分是知识蒸馏损失,用于衡量学生模型的预测与教师模型的预测之间的差距。知识蒸馏损失通常使用一些形式的距离度量来计算,例如平方误差损失或者交叉熵损失。在蒸馏求loss时候,需要采用蒸馏函数,这个函数就是把softmax函数在计算时候,预测出来的结果Z进行除以温度T,进行求解后验概率。下面是修改后的softmax函数,也就是蒸馏函数。

    • q i = e x p ( z i / T ) ∑ j e x p ( z j / T ) q_i=\frac{exp(z_i/T)}{\sum_jexp(z_j/T)} qi=jexp(zj/T)exp(zi/T)

    • 在这里插入图片描述

  • 首先我们要先训练出较大模型既teacher模型。再对teacher模型进行蒸馏,此时已经有一个训练好的teacher模型,所以我们能很容易知道teacher模型输入特征 x 之后,预测出来的结果teacher_preds标签。求到老师预测结果之后,我们需要求解学生在训练过程中的每一次结果student_preds标签。

    • 先求hard_loss,也就是学生模型的预测student_preds与真实标签targets之间的损失。
    • 再求soft_loss,也就是学生模型的预测student_preds与教师模型teacher_preds的预测之间的损失。
    • 求出hard_loss与soft_loss之后,求和总loss=a*hard_loss + (1-a)soft_loss,a是一个自己设置的权重参数,最后反向传播继续迭代。
  • 数据集采用的是手写数字的数据集mnist数据集,如果没有下载,代码部分中会进行下载,只需要把download改成True,然后就会保存在当前目录中。该数据集将其分成80%的训练集,20%的测试集,最后返回train_dataset和test_datatset。

    • class MyDataset(Dataset):def __init__(self,opt):self.opt = optdef MyData(self):## mnist数据集下载0mnist = datasets.MNIST(root='../datasets/', train=True, download=True, transform=transforms.Compose([transforms.Resize(self.opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),)dataset_size = len(mnist)train_size = int(0.8 * dataset_size)test_size = dataset_size - train_sizetrain_dataset, test_dataset = random_split(mnist, [train_size, test_size])train_dataloader = DataLoader(train_dataset,batch_size=self.opt.batch_size,shuffle=True,)test_dataloader = DataLoader(test_dataset,batch_size=self.opt.batch_size,shuffle=False,  # 在测试集上不需要打乱顺序)return train_dataloader,test_dataloader
      
    • 首先是teacher模型构造,经过三次线性层。

    • import torch.nn as nn
      import torch
      img_area = 784
      class TeacherModel(nn.Module):def __init__(self,in_channel=1,num_classes=10):super(TeacherModel,self).__init__()self.relu = nn.ReLU()self.fc1 = nn.Linear(img_area,1200)self.fc2 = nn.Linear(1200, 1200)self.fc3 = nn.Linear(1200, num_classes)self.dropout = nn.Dropout(p=0.5)def forward(self, x):x = x.view(-1, img_area)x = self.fc1(x)x = self.dropout(x)x = self.relu(x)x = self.fc2(x)x = self.dropout(x)x = self.relu(x)x = self.fc3(x)return x
      
    • 训练teacher模型,老师模型训练完成后其权重参数会保存在teacher.pth当中,为以后调用。

    • import torch.nn as nn
      import torch
      from tqdm import tqdm
      from dist.TeacherModel import TeacherModel
      weight_path = './teacher.pth'
      cuda = True if torch.cuda.is_available() else False
      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
      torch.backends.cudnn.benchmark = True #使用卷积cuDNN加速
      class TeacherTrainer():def __init__(self,opt,train_dataloader,test_dataloader):self.opt = optself.train_dataloader = train_dataloaderself.test_dataloader = test_dataloaderdef trainer(self):# 老师模型opt = self.opttrain_dataloader = self.train_dataloadertest_dataloader = self.test_dataloaderteacher_model = TeacherModel()teacher_model = teacher_model.to(device)criterion = nn.CrossEntropyLoss()optimizer_teacher = torch.optim.Adam(teacher_model.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))for epoch in range(opt.n_epochs):  ## epoch:50teacher_model.train()for data, targets in tqdm(train_dataloader):data = data.to(device)targets = targets.to(device)preds = teacher_model(data)loss = criterion(preds, targets)optimizer_teacher.zero_grad()loss = criterion(preds, targets)loss.backward()optimizer_teacher.step()teacher_model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x, y in test_dataloader:x = x.to(device)y = y.to(device)preds = teacher_model(x)predictions = preds.max(1).indicesnum_correct += (predictions == y).sum()num_samples += predictions.size(0)acc = (num_correct / num_samples).item()torch.save(teacher_model.state_dict(), weight_path)teacher_model.train()print('teacher: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))
      
    • 设置参数以及主函数

    • import argparse
      import torch
      from dist.DistillationTrainer import DistillationTrainer
      from dist.MyDateLoader import MyDataset
      from dist.TeacherTrainer import TeacherTrainer
      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
      def paras():## 超参数配置parser = argparse.ArgumentParser()parser.add_argument("--n_epochs", type=int, default=5, help="number of epochs of training")parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")parser.add_argument("--channels", type=int, default=1, help="number of image channels")parser.add_argument("--sample_interval", type=int, default=500, help="interval betwen image samples")opt = parser.parse_args()## opt = parser.parse_args(args=[])                 ## 在colab中运行时,换为此行print(opt)return opt
      if __name__ == '__main__':opt = paras()data = MyDataset(opt)train_dataloader, test_dataloader = data.MyData()# 训练Teacher模型teacher_trainer = TeacherTrainer(opt,train_dataloader,test_dataloader)teacher_trainer.trainer()
      
    • 学生模型的构建,学生模型也是经过了三次线性层,但是神经元没有teacher当中多。所以student模型会比teacher模型小很多。

    • import torch.nn as nn
      import torch
      img_area = 784
      class StudentModel(nn.Module):def __init__(self,in_channel=1,num_classes=10):super(StudentModel,self).__init__()self.relu = nn.ReLU()self.fc1 = nn.Linear(img_area,20)self.fc2 = nn.Linear(20, 20)self.fc3 = nn.Linear(20, num_classes)def forward(self, x):x = x.view(-1, img_area)x = self.fc1(x)# x = self.dropout(x)x = self.relu(x)x = self.fc2(x)# x = self.dropout(x)x = self.relu(x)x = self.fc3(x)return x
      
    • 知识蒸馏训练, 首先将teacher模型中的权重参数teacher.pth放入模型当中。

    • #加载训练好的模型
      teacher_model = TeacherModel()
      if os.path.exists(weights):teacher_model.load_state_dict(torch.load(weights))print('successfully')
      else:print('not loading')
      teacher_model = teacher_model.to(device)
      
    • 设置损失求解的函数, hard_loss用的就是普通的交叉熵损失函数,而soft_loss就是用的KL散度。

    • hard_loss = nn.CrossEntropyLoss()
      alpha = 0.3# hard_loss权重
      # soft_loss
      soft_loss = nn.KLDivLoss(reduction="batchmean")
      
    • 之后再进行蒸馏训练,温度为7。 先求得hard_loss就是用学生模型预测的标签和真实标签进行求得损失。再求soft_loss就是用学生模型预测的标签和老师模型预测的标签进行求得损失。使用softmax时候还需要进行除以温度temp。最后反向传播,求解模型

    •        for epoch in range(opt.n_epochs):  ## epoch:5for data, targets in tqdm(train_dataloader):data = data.to(device)targets = targets.to(device)# 老师模型预测with torch.no_grad():teacher_preds = teacher_model(data)# 学生模型预测student_preds = model(data)# 计算hard_lossstudent_loss = hard_loss(student_preds, targets)# 计算蒸馏后的预测损失ditillation_loss = soft_loss(F.softmax(student_preds / temp, dim=1),F.softmax(teacher_preds / temp, dim=1))loss = alpha * student_loss + (1 - alpha) * ditillation_lossoptimizer.zero_grad()loss.backward()optimizer.step()model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x, y in test_dataloader:x = x.to(device)y = y.to(device)preds = model(x)predictions = preds.max(1).indicesnum_correct += (predictions == y).sum()num_samples += predictions.size(0)acc = (num_correct / num_samples).item()model.train()print('distillation: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))
      
    • 整个蒸馏训练代码

    • import torch.nn as nn
      import torch
      import torch.nn.functional as F
      import os
      from tqdm import tqdm
      from dist.StudentModel import StudentModel
      from dist.TeacherModel import TeacherModel
      weights = './teacher.pth'
      ## 设置cuda:(cuda:0)
      cuda = True if torch.cuda.is_available() else False
      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
      torch.backends.cudnn.benchmark = True #使用卷积cuDNN加速
      class DistillationTrainer():def __init__(self,opt,train_dataloader,test_dataloader):self.opt = optself.train_dataloader = train_dataloaderself.test_dataloader = test_dataloaderdef trainer(self):opt = self.opttrain_dataloader = self.train_dataloadertest_dataloader = self.test_dataloaderteacher_model = TeacherModel()if os.path.exists(weights):teacher_model.load_state_dict(torch.load(weights))print('successfully')else:print('not loading')teacher_model = teacher_model.to(device)teacher_model.eval()model = StudentModel()model = model.to(device)temp = 7# hard_losshard_loss = nn.CrossEntropyLoss()# hard_loss权重alpha = 0.3# soft_losssoft_loss = nn.KLDivLoss(reduction="batchmean")optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))for epoch in range(opt.n_epochs):  ## epoch:5for data, targets in tqdm(train_dataloader):data = data.to(device)targets = targets.to(device)# 老师模型预测with torch.no_grad():teacher_preds = teacher_model(data)# 学生模型预测student_preds = model(data)# 计算hard_lossstudent_loss = hard_loss(student_preds, targets)# 计算蒸馏后的预测损失ditillation_loss = soft_loss(F.softmax(student_preds / temp, dim=1),F.softmax(teacher_preds / temp, dim=1))loss = alpha * student_loss + (1 - alpha) * ditillation_lossoptimizer.zero_grad()loss.backward()optimizer.step()model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x, y in test_dataloader:x = x.to(device)y = y.to(device)preds = model(x)predictions = preds.max(1).indicesnum_correct += (predictions == y).sum()num_samples += predictions.size(0)acc = (num_correct / num_samples).item()model.train()print('distillation: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))
      
    • 蒸馏训练的主函数,该部分大致与teacher模型训练类似,只是调用不同。

    • import argparse
      import torch
      from dist.DistillationTrainer import DistillationTrainer
      from dist.MyDateLoader import MyDataset
      from dist.TeacherTrainer import TeacherTrainer
      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
      def paras():## 超参数配置parser = argparse.ArgumentParser()parser.add_argument("--n_epochs", type=int, default=5, help="number of epochs of training")parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")parser.add_argument("--channels", type=int, default=1, help="number of image channels")parser.add_argument("--sample_interval", type=int, default=500, help="interval betwen image samples")opt = parser.parse_args()## opt = parser.parse_args(args=[])                 ## 在colab中运行时,换为此行print(opt)return opt
      if __name__ == '__main__':opt = paras()data = MyDataset(opt)train_dataloader, test_dataloader = data.MyData()# 训练Teacher模型# teacher_trainer = TeacherTrainer(opt,train_dataloader,test_dataloader)# teacher_trainer.trainer()distillation_trainer = DistillationTrainer(opt,train_dataloader,test_dataloader)distillation_trainer.trainer()
      
  • 示例

  • import torch
    from torch import nn
    import torch.nn.functional as F
    from torch.utils.data import DataLoader
    from tqdm import tqdm
    import torchvision
    from torchvision import transforms
    class TeacherModel(nn.Module):def __init__(self, in_channels=1, num_classes=10):super(TeacherModel, self).__init__()self.relu = nn.ReLU()self.fc1 = nn.Linear(784, 1200)self.fc2 = nn.Linear(1200, 1200)self.fc3 = nn.Linear(1200, num_classes)self.dropout = nn.Dropout(p=0.5)def forward(self, x):x = x.view(-1, 784) #输入的图像是一个28x28像素的灰度图像,因此输入层有784个神经元。x = self.relu(self.dropout(self.fc1(x)))x = self.relu(self.dropout(self.fc2(x)))x = self.fc3(x)return x
    class StudentModel(nn.Module):def __init__(self, in_channels=1, num_classes=10):super(StudentModel, self).__init__()self.relu = nn.ReLU()self.fc1 = nn.Linear(784, 20)self.fc2 = nn.Linear(20, 20)self.fc3 = nn.Linear(20, num_classes)self.dropout = nn.Dropout(p=0.5)def forward(self, x):x = x.view(-1, 784)x = self.relu(self.dropout(self.fc1(x)))x = self.relu(self.dropout(self.fc2(x)))x = self.fc3(x)return x
    def teacher(device, train_loader, test_loader):print('--------------teachermodel start--------------')model = TeacherModel()model = model.to(device)criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)epochs = 6for epoch in range(epochs):model.train()for data, target in tqdm(train_loader):data = data.to(device) #(32,1,28,28)target = target.to(device) #(32,)preds = model(data)loss = criterion(preds, target)optimizer.zero_grad()loss.backward()optimizer.step()model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x, y in test_loader:x = x.to(device)y = y.to(device)preds = model(x)predictions = preds.max(1).indicesnum_correct += (predictions.eq(y)).sum().item()num_samples += predictions.size(0)acc = num_correct / num_samplesmodel.train()print('Epoch:{}\t Acc:{:.4f}'.format(epoch + 1, acc))torch.save(model, 'teacher.pkl')print('--------------teachermodel end--------------')
    def student(device, train_loader, test_loader):print('--------------studentmodel start--------------')model = StudentModel()model = model.to(device)criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)epochs = 3for epoch in range(epochs):model.train()for data, target in tqdm(train_loader):data = data.to(device)target = target.to(device)preds = model(data)loss = criterion(preds, target)optimizer.zero_grad()loss.backward()optimizer.step()model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x, y in test_loader:x = x.to(device)y = y.to(device)# print(y)preds = model(x)predictions = preds.max(1).indices# print(predictions)num_correct += (predictions.eq(y)).sum().item()num_samples += predictions.size(0)acc = num_correct / num_samplesmodel.train()print('Epoch:{}\t Acc:{:.4f}'.format(epoch + 1, acc))print('--------------studentmodel prediction end--------------')
    def kd(teachermodel, device, train_loader, test_loader):print('--------------kdmodel start--------------')teachermodel.eval() #将模型设置为评估模式,Dropout和Batch Normalization 等层会被固定住;模型参数不会被更新,即不会进行反向传播和梯度更新,只会进行前向传播计算studentmodel = StudentModel()studentmodel = studentmodel.to(device)studentmodel.train()temp = 7    #蒸馏温度alpha = 0.3hard_loss = nn.CrossEntropyLoss()soft_loss = nn.KLDivLoss(reduction='batchmean')#计算两个概率分布之间的KL散度,用于度量两个分布之间的差异或相似性,需要提供两个输入张量:input预测的概率分布和 target目标概率分布optimizer = torch.optim.Adam(studentmodel.parameters(), lr=1e-4)epochs = 20for epoch in range(epochs):for data, target in tqdm(train_loader):data = data.to(device)target = target.to(device)with torch.no_grad(): #当你使用 torch.no_grad() 包裹代码块时,该代码块内的张量操作将不会被追踪,也不会计算梯度。这可以减少内存消耗并提高代码的执行效率。teacher_preds = teachermodel(data)student_preds = studentmodel(data)student_loss = hard_loss(student_preds, target) #hard_lossdistillation_loss = soft_loss(F.log_softmax(student_preds / temp, dim=1),F.softmax(teacher_preds / temp, dim=1))   #soft_lossloss = alpha * student_loss + (1 - alpha) * distillation_lossoptimizer.zero_grad()loss.backward()optimizer.step()studentmodel.eval()num_correct = 0num_samples = 0with torch.no_grad():for x, y in test_loader:x = x.to(device)y = y.to(device)preds = studentmodel(x)predictions = preds.max(1).indicesnum_correct += (predictions.eq(y)).sum().item()num_samples += predictions.size(0)acc = num_correct / num_samplesstudentmodel.train()print('Epoch:{}\t Acc:{:.4f}'.format(epoch + 1, acc))print('--------------kdmodel end--------------')
    if __name__ == '__main__':torch.manual_seed(0)device = torch.device("cpu")torch.backends.cudnn.benchmark = True#加载数据集X_train = torchvision.datasets.MNIST(root="dataset/",train=True,transform=transforms.ToTensor(),download=True)X_test = torchvision.datasets.MNIST(root="dataset/",train=False,transform=transforms.ToTensor(),download=True)train_loader = DataLoader(dataset=X_train, batch_size=32, shuffle=True)test_loader = DataLoader(dataset=X_test, batch_size=32, shuffle=False)#从头训练教师模型,并预测teacher(device, train_loader, test_loader)#从头训练学生模型,并预测student(device, train_loader, test_loader)#知识蒸馏训练学生模型model = torch.load('teacher.pkl')kd(model, device, train_loader, test_loader)
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/685485.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

3D Gaussian Splatting for Real-Time Radiance Field Rendering 论文阅读

如此热门的项目,网络上有很多大牛分析了这篇文章的做法,在这里简单记录一下个人粗浅的理解。 关于各种数学表达式的推导,论文和参考资料中都提供了较为详细的解读,本人能力有限,这一部分理解不够深刻,先不做…

分布式光伏管理平台功能介绍

一、项目管理系统 1、关键信息:板块化展现项目关键信息,包含所在区域、屋面类型、未来25年发电量、累计收益等信息。 (1) 可迅速获取项目核心要点 (2) 及时跟进修改,凸显项目信息 (3) 项目信息清晰展现,了解整体项目流程 2、项…

音视频开发4 FFmpeg windows 环境搭建,QT 安装,动态库的搜索路径

FFmpeg 为了让所有平台的开发者都能够学习到音视频开发的通用技术,本教程主要讲解跨平台的音视频开发库FFmpeg。其实只要你掌握了FFmpeg,也可以很快上手其他音视频开发库,因为底层原理都是一样的,你最终操作的都是一样的数据&…

【问题分析】锁屏界面调起google语音助手后壁纸不可见【Android 14】

1 问题描述 为系统和锁屏分别设置两张不同的壁纸,然后在锁屏界面长按Power调起google语音助手后,有时候会出现壁纸不可见的情况,如以下截图所示: 有的时候又是正常的,但显示的也是系统壁纸,并非是锁屏壁纸…

【gin框架入门】

1.介绍 Gin 是一个用 Golang编写的 高性能的web 框架, 由于http路由的优化,速度提高了近 40 倍。 Gin的特点就是封装优雅、API友好。 Gin的一些特性: 快速 基于 Radix 树的路由,小内存占用。没有反射。可预测的 API 性能。 支持中间件 传入…

聊聊测试团队管理

管理测试团队是一个复杂但至关重要的任务,它不仅关乎于保证软件产品的质量,还涉及到团队建设、流程优化、技能提升等多个方面。以下是一些关键策略,可以帮助您有效地管理测试团队,比如“持续培训和技术支持,明确目标&a…

AIGC 时代软件工程师:前景、需求与大模型提效探究

过去,在互联网浪潮汹涌的十年来,软件工程师的角色愈发凸显其不可或缺的价值。随着AIGC(人工智能生成内容)时代的到来,软件开发的每个环节都正在经历一场前所未有的革新。今天,我们深入研究了大型AI模型如何…

相交链表(给你两个链表,找出它们的第一个公共结点)的实现与讲解

一:题目 二:思路讲解 1:应该是先判断是否相交。 2:相交就计算出相交前的两条链表的长度差距 3:让长的那一条链表先走长度的差距,这样会距离交点的距离相等,然后再对两条链表的节点对应着去比…

KaiwuDB 参编的《分析型数据库技术要求》标准正式发布

近期,中国电子工业标准化技术协会正式发布团体标准《分析型数据库技术要求》(项目号:T-CESA 2023-006)。该标准由中国电子技术标准化研究院、KaiwuDB(上海沄熹科技有限公司) 等国内 16 家企业联合起草&…

C#开发的网络速度计 - 开源研究系列文章 - 个人小作品

上次发布了一个获取网络速度的例子( https://www.cnblogs.com/lzhdim/p/18167854 ),就是为了这次这个例子。用于在托盘里显示网络速度的图标,并且能够显示网络速度。下面就介绍一下这个小应用的源码。 1、 项目目录; 2、 源码介绍&#xff1b…

javac编译web项目中的src

对于单个文件的且不引用其他类文件的java源码用javac编译大家都很熟悉即 javac hello.java, 服务器未安装idea,现在在服务器里面直接编译src目录 1 idea项目结构如下 2 web目录为最终部署的代码 WEB-INF下面没有 classes 目录 3 使用javac 编译src javac -encod…

Python中的数据可视化:阶梯图matplotlib.pyplot.step()

【小白从小学Python、C、Java】 【计算机等考500强证书考研】 【Python-数据分析】 Python中的数据可视化: 阶梯图 matplotlib.pyplot.step() [太阳]选择题 matplotlib.pyplot.step()的功能是? import matplotlib.pyplot as plt import numpy as…