FasterViT实战:使用FasterViT实现图像分类任务(一)

文章目录

  • 摘要
  • 安装包
    • 安装timm
    • 安装 grad-cam
  • 数据增强Cutout和Mixup
  • EMA
  • 项目结构
  • 计算mean和std
  • 生成数据集

摘要

论文翻译:https://blog.csdn.net/m0_47867638/article/details/131542132
官方源码:https://github.com/NVlabs/FasterViT

这是一篇来自英伟达的论文。FasterViT结合了CNN的快速局部表示学习和ViT的全局建模特性的优点。新提出的分层注意力(HAT)方法将具有二次复杂度的全局自注意力分解为具有减少计算成本的多级注意力。受益于基于窗口的高效自注意力。每个窗口都可以访问参与局部和全局表示学习的专用载体Token。在高层次上,全局的自注意力使高效的跨窗口通信能够以较低的成本实现。FasterViT在精度与图像吞吐量方面达到了SOTA Pareto-front。广泛地验证了它在各种CV任务上的有效性,包括分类、目标检测和分割。HAT可以用作现有网络的即插即用模块并增强它们。

在这里插入图片描述

这篇文章使用FasterViT完成植物分类任务,模型采用faster_vit_0_224向大家展示如何使用FasterViT。由于没有预训练模型,faster_vit_0_224在这个数据集上实现了86+%的ACC,如下图:

请添加图片描述
在这里插入图片描述

通过这篇文章能让你学到:

  1. 如何使用数据增强,包括transforms的增强、CutOut、MixUp、CutMix等增强手段?
  2. 如何实现FasterViT模型实现训练?
  3. 如何使用pytorch自带混合精度?
  4. 如何使用梯度裁剪防止梯度爆炸?
  5. 如何使用DP多显卡训练?
  6. 如何绘制loss和acc曲线?
  7. 如何生成val的测评报告?
  8. 如何编写测试脚本测试测试集?
  9. 如何使用余弦退火策略调整学习率?
  10. 如何使用AverageMeter类统计ACC和loss等自定义变量?
  11. 如何理解和统计ACC1和ACC5?
  12. 如何使用EMA?
  13. 如果使用Grad-CAM 实现热力图可视化?

如果基础薄弱,对上面的这些功能难以理解可以看我的专栏:经典主干网络精讲与实战
这个专栏,从零开始时,一步一步的讲解这些,让大家更容易接受。

安装包

安装timm

使用pip就行,命令:

pip install timm

mixup增强和EMA用到了timm

安装 grad-cam

pip install grad-cam

数据增强Cutout和Mixup

为了提高成绩我在代码中加入Cutout和Mixup这两种增强方式。实现这两种增强需要安装torchtoolbox。安装命令:

pip install torchtoolbox

Cutout实现,在transforms中。

from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([transforms.Resize((224, 224)),Cutout(),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

需要导入包:from timm.data.mixup import Mixup,

定义Mixup,和SoftTargetCrossEntropy

  mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,prob=0.1, switch_prob=0.5, mode='batch',label_smoothing=0.1, num_classes=12)criterion_train = SoftTargetCrossEntropy()

参数详解:

mixup_alpha (float): mixup alpha 值,如果 > 0,则 mixup 处于活动状态。

cutmix_alpha (float):cutmix alpha 值,如果 > 0,cutmix 处于活动状态。

cutmix_minmax (List[float]):cutmix 最小/最大图像比率,cutmix 处于活动状态,如果不是 None,则使用这个 vs alpha。

如果设置了 cutmix_minmax 则cutmix_alpha 默认为1.0

prob (float): 每批次或元素应用 mixup 或 cutmix 的概率。

switch_prob (float): 当两者都处于活动状态时切换cutmix 和mixup 的概率 。

mode (str): 如何应用 mixup/cutmix 参数(每个’batch’,‘pair’(元素对),‘elem’(元素)。

correct_lam (bool): 当 cutmix bbox 被图像边框剪裁时应用。 lambda 校正

label_smoothing (float):将标签平滑应用于混合目标张量。

num_classes (int): 目标的类数。

EMA

EMA(Exponential Moving Average)是指数移动平均值。在深度学习中的做法是保存历史的一份参数,在一定训练阶段后,拿历史的参数给目前学习的参数做一次平滑。具体实现如下:


import logging
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn_logger = logging.getLogger(__name__)class ModelEma:def __init__(self, model, decay=0.9999, device='', resume=''):# make a copy of the model for accumulating moving average of weightsself.ema = deepcopy(model)self.ema.eval()self.decay = decayself.device = device  # perform ema on different device from model if setif device:self.ema.to(device=device)self.ema_has_module = hasattr(self.ema, 'module')if resume:self._load_checkpoint(resume)for p in self.ema.parameters():p.requires_grad_(False)def _load_checkpoint(self, checkpoint_path):checkpoint = torch.load(checkpoint_path, map_location='cpu')assert isinstance(checkpoint, dict)if 'state_dict_ema' in checkpoint:new_state_dict = OrderedDict()for k, v in checkpoint['state_dict_ema'].items():# ema model may have been wrapped by DataParallel, and need module prefixif self.ema_has_module:name = 'module.' + k if not k.startswith('module') else kelse:name = knew_state_dict[name] = vself.ema.load_state_dict(new_state_dict)_logger.info("Loaded state_dict_ema")else:_logger.warning("Failed to find state_dict_ema, starting from loaded model weights")def update(self, model):# correct a mismatch in state dict keysneeds_module = hasattr(model, 'module') and not self.ema_has_modulewith torch.no_grad():msd = model.state_dict()for k, ema_v in self.ema.state_dict().items():if needs_module:k = 'module.' + kmodel_v = msd[k].detach()if self.device:model_v = model_v.to(device=self.device)ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

加入到模型中。

#初始化
if use_ema:model_ema = ModelEma(model_ft,decay=model_ema_decay,device='cpu',resume=resume)# 训练过程中,更新完参数后,同步update shadow weights
def train():optimizer.step()if model_ema is not None:model_ema.update(model)# 将model_ema传入验证函数中
val(model_ema.ema, DEVICE, test_loader)

针对没有预训练的模型,容易出现EMA不上分的情况,这点大家要注意啊!

项目结构

InceptionNext_Demo
├─data1
│  ├─Black-grass
│  ├─Charlock
│  ├─Cleavers
│  ├─Common Chickweed
│  ├─Common wheat
│  ├─Fat Hen
│  ├─Loose Silky-bent
│  ├─Maize
│  ├─Scentless Mayweed
│  ├─Shepherds Purse
│  ├─Small-flowered Cranesbill
│  └─Sugar beet
├─models
│  └─faster_vit.py
├─mean_std.py
├─makedata.py
├─train.py
├─cam_image.py
└─test.py

models:来源官方代码,对面的代码做了一些适应性修改。
mean_std.py:计算mean和std的值。
makedata.py:生成数据集。
ema.py:EMA脚本
train.py:训练InceptionNext模型
cam_image.py:热力图可视化

计算mean和std

为了使模型更加快速的收敛,我们需要计算出mean和std的值,新建mean_std.py,插入代码:

from torchvision.datasets import ImageFolder
import torch
from torchvision import transformsdef get_mean_and_std(train_data):train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=False, num_workers=0,pin_memory=True)mean = torch.zeros(3)std = torch.zeros(3)for X, _ in train_loader:for d in range(3):mean[d] += X[:, d, :, :].mean()std[d] += X[:, d, :, :].std()mean.div_(len(train_data))std.div_(len(train_data))return list(mean.numpy()), list(std.numpy())if __name__ == '__main__':train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())print(get_mean_and_std(train_dataset))

数据集结构:

image-20220221153058619

运行结果:

([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])

把这个结果记录下来,后面要用!

生成数据集

我们整理还的图像分类的数据集结构是这样的

data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet

pytorch和keras默认加载方式是ImageNet数据集格式,格式是

├─data
│  ├─val
│  │   ├─Black-grass
│  │   ├─Charlock
│  │   ├─Cleavers
│  │   ├─Common Chickweed
│  │   ├─Common wheat
│  │   ├─Fat Hen
│  │   ├─Loose Silky-bent
│  │   ├─Maize
│  │   ├─Scentless Mayweed
│  │   ├─Shepherds Purse
│  │   ├─Small-flowered Cranesbill
│  │   └─Sugar beet
│  └─train
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet

新增格式转化脚本makedata.py,插入代码:

import glob
import os
import shutilimage_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):print('true')#os.rmdir(file_dir)shutil.rmtree(file_dir)#删除再建立os.makedirs(file_dir)
else:os.makedirs(file_dir)from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(train_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)for file in val_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(val_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)

完成上面的内容就可以开启训练和测试了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/26241.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

初学者怎么学习c++(合集)

学习c方法1 找一本好的书本教材,辅助看教学视频。好的教材,可以让你更快更好的进入C/C的世界。在校学生的话,你们的教材通常都是不错的。如果是自学,推荐使用谭浩强出的C/C经典入门教材。看视频是学习比较直观的方式。建议先看课本…

Smartbi 身份认证绕过漏洞

0x00 简介 Smartbi是广州思迈特软件有限公司旗下的商业智能BI和数据分析产品,致力于为企业客户提供一站式商业智能解决方案。 0x01 漏洞概述 Smartbi在安装时会内置三个用户(public、service、system),在使用特定接口时&#x…

高压放大器需要注意哪些指标

高压放大器是一种专门用于输出高电压信号的电子设备,主要应用于精密测量、医疗设备、电力电子等领域中。在选择高压放大器时,需要注意其性能指标,以确保设备的稳定性和可靠性。 以下是高压放大器需要注意的性能指标: 输出电压范围…

利用ArcGIS Pro制作三维效果图

1、新建工程 打开Arcgispro,新建工程,这里我们要用到的模板为全局场景。 2、添加数据 这里添加的数据需要有一个字段内容是数值的,这个字段也是接下来要进行拉伸的字段。 3、高度拉伸 数据添加进来后,如下图所示,这时图层处于2D图层里。 这时我们点中该图层,回到菜单栏…

Spring Cloud Gateway下的GC停顿排查之旅

01 背景 在微服务架构体系流行的当下,Spring Cloud全家桶已经是大多数团队的首选,我们也不例外,并且选择了Spring Cloud Gateway作为了业务网关,进行了一些通用能力的开发,如鉴权、路由等等。作为一个成熟的框架&#…

如何学习Java集合框架? - 易智编译EaseEditing

要学习Java集合框架相关的技术和知识,可以按照以下步骤进行: 掌握Java基础知识: 在学习集合框架之前,确保你已经具备良好的Java编程基础,包括语法、面向对象编程(OOP)原理和常用的核心类库等。…

win10电脑出现网络问题时,如何解决?

我们的Windows可能会出现各种网络连接问题: 尝试连接Wi-Fi网络时出现错误:Windows无法连接到此网络;可以通过Wifi访问互联网,但通过电缆访问以太网却无法正常工作;尝试通过电缆连接互联网时出现错误: Wind…

ios 启动页storyboard 使用记录

本文简单记录ios启动页storyboard 如何使用和注意事项。 xcode窗口简介 以xcode14为例,新建项目如下图,左边文件栏中的LaunchScreen.storyboard 为默认启动页布局。窗口中间部分是storyboard中的组件列表,右侧为预览,可以看到渲…

goland设置内置命令行为当前项目环境

goland设置内置的命令行为当前项目环境 修改 GoLand 中的 SSH 终端配置即可

Python正则表达式校验某个字符串是否是合格的email

Python正则表达式校验某个字符串是否是合格的email 可以借助正则表达式校验某个字符串是否是合规的电子邮箱。对于邮箱的正则表达式有严格的模式,如:^[a-zA-Z0-9_&*-](?:\\.[a-zA-Z0-9_&*-])*(?:[a-zA-Z0-9-]\\.)[a-zA-Z]{2,7}$ 对应的Python…

ESP32设备驱动-直流电机与L298N电机驱动器

直流电机与L298N电机驱动器 文章目录 直流电机与L298N电机驱动器1、L298N介绍2、硬件准备3、软件准备4、驱动实现在本文中,我们将介绍如何使用ESP32通过L298N电机驱动器驱动直流电机。 1、L298N介绍 L298N 电机驱动器模块非常易于与微控制器一起使用,而且相对便宜。 它被广泛…

tinymce编辑器导入docx、doc格式Word文档完整版

看此文章之前需要注意一点 在前端使用导入Word文档并自动解析成html再插入到tinymce编辑器中,在这里我使用的是mammoth.js识别Word内容,并set到编辑器中,使用mammoth只可解析.docx格式的Word,目前的mammoth不支持.doc格式&#x…