VMamba模型
- 摘要
- Abstract
- 1. VMamba模型
- 1.1 文献摘要
- 1.2 研究背景
- 1.3 状态空间模型(SSM)
- 1.4 VMamba架构
- 1.5 实验
- 1.5.1 ImageNet-1K 上的图像分类
- 1.5.2 COCO 上的物体检测
- 总结
- 2. pytorch练习
摘要
Abstract
1. VMamba模型
文献出处:VMamba: Visual State Space Model
1.1 文献摘要
CNN和VIT一直以来都是视觉领域的骨干网络,虽然 ViT 最近因其卓越的拟合能力而比 CNN 获得了突出地位,但其可扩展性在很大程度上受到注意力计算的二次复杂度的限制。
作者在本文提出了 VMamba,目的是为了将计算复杂度降低到线性,同时保留 ViT 的优势特征,同时也引入了交叉扫描模块(CSM),以实现具有全局感受野的 2D 图像空间中的 1D 选择性扫描。
实验结果证明了 VMamba 在各种视觉感知任务中的良好性能,凸显了与现有基准模型相比,其在输入缩放效率方面的显着优势。
1.2 研究背景
最近,状态空间模型(SSM)在自然语言处理(NLP)任务中展示了具有线性复杂性的长序列建模的巨大潜力。
作者提出了 VMamba,这是一种通用视觉主干,具有基于 SSM 的块,用于高效的视觉表示学习。 VMamba 在降低注意力计算复杂性方面的有效性很大程度上归功于 S6 模型中存在的选择性扫描机制,也称为选择性 SSM。与允许在上下文中进行密集信息路由的传统注意力计算方法不同,S6 要求一维数组(例如文本序列)中的每个元素仅通过压缩隐藏状态来获取上下文知识,从而将二次复杂度降低为线性复杂度。
然而,由于视觉数据的二维性质,单个扫描过程很难同时捕获不同方向上的依赖性信息,从而导致感受野受到限制。 我们将此问题称为“方向敏感”问题,并建议通过新引入的交叉扫描模块(CSM)来解决它。 CSM 不是以单向模式(列向或行向)遍历图像特征图的空间域,而是采用四向扫描策略,即从左上角和右下角开始遍历整个特征 映射到相反的位置(如下图)。 该策略确保特征图中的每个元素集成来自不同方向的所有其他位置的信息,从而在不增加计算复杂度的情况下实现全局感受野。
1.3 状态空间模型(SSM)
SSM 可以被视为线性时不变 (LTI) 系统,它通过隐藏状态 h(t) ε CN 将输入刺激 u(t) ε RL 映射到输出响应 y(t) ε RL。 它们通常被表述为线性常微分方程 (ODE)
离散化 状态空间模型(SSM)作为连续时间模型,在集成到深度学习算法中时面临着巨大的挑战。为了克服这个障碍,离散化过程势在必行。
作者首先使用 CSM(扫描扩展)扫描图像。然后通过 S6 块单独处理四个结果特征,并将四个输出特征合并(扫描合并)以构建最终的 2D 特征图。
通过 SS2D 模块传递数据涉及三个步骤:交叉扫描、使用 S6 块进行选择性扫描以及交叉合并。 给定输入数据,SS2D 首先沿着四个不同的遍历路径(即交叉扫描)将图像块展开为序列,使用单独的 S6 块并行处理每个块序列,然后重塑并合并结果序列以形成输出图 (即交叉合并)。 通过采用互补的遍历路径,SS2D使图像中的每个像素能够有效地整合来自不同方向的所有其他像素的信息,从而促进全局感受野的建立。
1.4 VMamba架构
VMamba-Tiny 的架构概述如下图所示。 VMamba 首先使用 Stem 模块将输入图像划分为图块,从而生成空间维度为 H 4 × W 4 \frac{H}{4} \times \frac{W}{4} 4H×4W 的 2D 特征图。
随后,多个网络阶段,每个阶段由 VSS 块组成,前面是下采样层(第一阶段除外),用于创建分辨率为 H 8 × W 8 \frac{H}{8} \times \frac{W}{8} 8H×8W 、 H 16 × W 16 \frac{H}{16} \times \frac{W}{16} 16H×16W 和 H 32 × W 32 \frac{H}{32} \times \frac{W}{32} 32H×32W。 下采样操作是通过补丁合并进行的,VSS块的详细结构如下图所示:
普通 VSS 块的结构如下图所示,这两个块都可以看作具有跳跃连接的残差网络。 残差网络包含两个分支:一个用于使用 3 × 3 深度卷积层进行特征提取,另一个由线性映射和激活层组成,激活层计算乘性门控信号。 Mamba 和普通 VSS 模块之间的主要区别在于用 SS2D 模块替换了 S6 模块,这使得选择性扫描能够适应 2D 视觉数据。
尽管在长序列建模方面效率很高,但基于 SSM 的架构 [14] 在处理较小规模的输入时经常会遇到计算速度降低的情况,这可能会限制 VMamba 的实际用途。
如下图所示,普通 VMamba-Tiny 模型实现了 426 个图像/秒的吞吐量,包含 22.9M 个参数和 5.6G FLOP(如果选择性扫描操作可以实现,FLOP 将降至 4.5G) 由单个 for 循环实现)。 低吞吐量和高内存开销给VMamba的实际部署带来了挑战。 因此,为了提高其推理速度,人们付出了巨大的努力,主要集中在实现细节和架构设计方面的进步。
从VMamba V0到V2,我们先后在torch.autograd.Function中实现了CSM,然后在Triton中重新实现了它。 这些修改有助于将吞吐量从 426 增加到 467。然后,在 V3 中,我们调整了与选择性扫描操作相关的 CUDA 实现,以适应 float16 输入张量并生成具有 float32 数据类型的输出张量。 与处理 float32 数据类型张量的实现相比,此调整提高了性能,特别是在训练期间,同时与对输入和输出张量使用 float16 相比,还实现了更高的数值稳定性。 此外,在 V4 和 V5 中,我们用线性变换(即 torch.nn.function.linear)替代了选择性扫描中相对较慢的 einsum 操作。 我们还采用了(B,C,H,W)的张量布局来消除不必要的数据排列。 这些变化导致吞吐量增加了 49.5%(从 426 增加到 637),并且不影响其他指标,例如参数数量、FLOP 和 ImageNet-1K 上的分类性能。
1.5 实验
1.5.1 ImageNet-1K 上的图像分类
我们使用 ImageNet-1K 数据集评估 VMamba 在图像分类方面的性能。 遵循[31]中概述的评估协议,VMamba-T/S/B模型从头开始训练300个epoch,前20个epoch专门用于预热,批量大小为1024。训练过程使用AdamW 优化器[34],贝塔设置为(0.9,0.999),动量为0.9,余弦衰减学习率调度器,初始学习率为1×10−3,权重衰减为0.05。 还应用了标签平滑 (0.1) 和指数移动平均 (EMA) 等其他技术。 除此之外,没有采用进一步的培训技术。
下表总结了 VMamba 与 ImageNet-1K 上基准骨干模型的比较结果。很明显,在相似的 FLOP 下,VMamba-T 的性能达到 82.5%,超过 RegNetY-4G 2.5%,超过 DeiT-S 2.5%。 2.7%,Swin-T 1.2%。 值得注意的是,VMamba 的这些性能优势在小型和基本规模模型中始终存在。 具体来说,VMamba-S 的 top-1 准确率达到 83.6%,比 RegNetY-8G 提高 1.9%,比 Swin-S 提高 0.6%。 同时,VMamba-B 的 top-1 准确率达到 83.9%,超过 RegNetY-16G 1.0%,超过 DeiT-B 0.6%。 在计算效率方面,虽然现有的基于 SSM 的视觉模型通常仅在大规模输入 [68](例如 1024 × 1024)下才表现出明显更好的吞吐量,但 VMamba-T 即使在输入分辨率为 224 × 224。这种性能更好,或者至少与最先进的方法相当,并且这种优势在 VMamba-S 和 VMamba-B 中仍然存在。 值得注意的是,随着输入大小从 224 × 224 扩展到 1024 × 1024,VMamba 相对于现有方法的优势变得更加明显,如表 4 所示。后续章节将对此主题进行进一步讨论。
1.5.2 COCO 上的物体检测
我们使用 MSCOCO 2017 数据集评估 VMamba 在对象检测方面的性能。 我们的训练框架是使用 MMDetection 库构建的,并且我们遵循 Swin中使用的超参数和 Mask-RCNN 检测器。 具体来说,我们采用 AdamW 优化器并对 12 和 36 epoch 的预训练分类模型(在 ImageNet-1K 上)进行微调。 VMamba-T/S/B 的丢弃路径率分别设置为 0.2%/0.3%/0.5%。 学习率初始化为 1×10−4,并在第 9 和 11 epoch 减少 10×。 我们实现了批量大小为 16 的多尺度训练和随机翻转,这与目标检测评估的既定实践一致。
VMamba 在 COCO 上的框/掩模平均精度 (AP) 方面保持优势,无论采用何种训练计划(12 或 36 epoch)。 具体来说,通过 12 epoch 的微调计划,VMamba-T/S/B 模型实现了 47.4%/48.7%/49.2% 的目标检测 mAP,超过了 Swin-T/S/B 4.7%/3.9%/2.3 % mAP 和 ConvNeXt-T/S/B 分别提高 3.2%/3.3%/2.2% mAP。 在相同配置下,VMambaT/S/B 的实例分割 mIoU 为 42.7%/43.7%/43.9%,比 Swin-T/S/B 高出 3.4%/2.8%/1.6% mIoU,而 ConvNeXt-T/S/ B 分别为 2.6%/1.9%/1.3% mIoU。 此外,VMamba 在多尺度训练的 36 epoch 微调方案下仍然具有优势,如表 2 所示。与 Swin [32]、ConvNeXt [33]、PVTv2 [55] 和 ViT 等同行相比 [12](使用适配器),VMamba-T/S 表现出卓越的性能,在对象检测上分别实现了 48.9%/49.9% mAP,在实例分割上分别实现了 43.7%/44.2% mIoU。 这些结果强调了 VMamba 在具有密集预测的下游任务中实现有希望的性能的潜力。
总结
本文介绍了 VMamba,这是一种多功能主干网络,专为使用状态空间模型 (SSM) 进行高效视觉表示学习而设计。 VMamba 的主要目标是将选择性 SSM 的优点(包括全局感受野、输入相关的加权参数和线性计算复杂性)融入视觉数据处理中。 具体来说,我们提出交叉扫描模块(CSM)来弥合一维选择性扫描和二维视觉数据之间的差距,并通过数学推导和定性可视化说明其与注意力机制的关系及其在实现全局感受野方面的有效性 。 此外,我们通过改进技术实现和架构设计,显着提高了 VMamba 的推理速度。 VMamba 系列(包括 VMamba-T/S/B 模型)的有效性已通过大量实验和消融研究得到证明,超越了流行的 CNN 和视觉 Transformer 的性能。 此外,VMamba 随着输入分辨率的提高而表现出卓越的可扩展性,在保持线性计算复杂性的同时表现出最小的性能下降。
下周我将具体通过pytorch实现这个网络架构,加油~
2. pytorch练习
数据集处理
import os
from shutil import copy, rmtree
import randomdef mk_file(file_path: str):if os.path.exists(file_path):# 如果文件夹存在,则先删除原文件夹在重新创建rmtree(file_path)os.makedirs(file_path)def main():# 保证随机可复现random.seed(0)# 将数据集中10%的数据划分到验证集中split_rate = 0.1# 指向你解压后的flower_photos文件夹cwd = os.getcwd()data_root = os.path.join(cwd, "CUB_200_2011")origin_CUB_path = os.path.join(data_root, "images")assert os.path.exists(origin_CUB_path), "path '{}' does not exist.".format(origin_CUB_path)CUB_class = [cla for cla in os.listdir(origin_CUB_path)if os.path.isdir(os.path.join(origin_CUB_path, cla))]# 建立保存训练集的文件夹train_root = os.path.join(data_root, "train")mk_file(train_root)for cla in CUB_class:# 建立每个类别对应的文件夹mk_file(os.path.join(train_root, cla))# 建立保存验证集的文件夹val_root = os.path.join(data_root, "val")mk_file(val_root)for cla in CUB_class:# 建立每个类别对应的文件夹mk_file(os.path.join(val_root, cla))for cla in CUB_class:cla_path = os.path.join(origin_CUB_path, cla)images = os.listdir(cla_path)num = len(images)# 随机采样验证集的索引eval_index = random.sample(images, k=int(num*split_rate))for index, image in enumerate(images):if image in eval_index:# 将分配至验证集中的文件复制到相应目录image_path = os.path.join(cla_path, image)new_path = os.path.join(val_root, cla)copy(image_path, new_path)else:# 将分配至训练集中的文件复制到相应目录image_path = os.path.join(cla_path, image)new_path = os.path.join(train_root, cla)copy(image_path, new_path)print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="") # processing barprint()print("processing done!")if __name__ == '__main__':main()
参数设置
import argparsedef str2bool(v):if v.lower() in ('yes', 'true', 't', 'y', '1'):return Trueelif v.lower() in ('no', 'false', 'f', 'n', '0'):return Falseelse:raise argparse.ArgumentTypeError('Boolean value expected.')def get_args():parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')parser.add_argument('data', metavar='DIR', nargs='?', default='imagenet',help='path to dataset (default: imagenet)')parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',help='models architecture: default: resnet18)') # arch是需要加载的预训练模型名parser.add_argument("--optimizer", default="SGD", type=str, help='["SGD", "Adam", "AdamW"]')parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',help='number of data loading workers (default: 4)')parser.add_argument('--epochs', default=120, type=int, metavar='N',help='number of total epochs to run')parser.add_argument('--start-epoch', default=0, type=int, metavar='N',help='manual epoch number (useful on restarts)')parser.add_argument('-b', '--batch-size', default=16, type=int,metavar='N',help='mini-batch size (default: 256), this is the total ''batch size of all GPUs on the current node when ''using Data Parallel or Distributed Data Parallel')# optimizerparser.add_argument('--lr', '--learning-rate', default=0.005, type=float,metavar='LR', help='initial learning rate', dest='lr')parser.add_argument('--momentum', default=0.9, type=float, metavar='M',help='momentum')parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,metavar='W', help='weight decay (default: 1e-4)',dest='weight_decay')# center lossparser.add_argument('--parts', default=32, type=int,metavar='N', help='number of parts (default: 32)')parser.add_argument('--alpha', default=0.95, type=float,metavar='N', help='weight for BAP loss')# schedulerparser.add_argument('--decay-step', default=20, type=int, metavar='N',help='learning rate decay step')parser.add_argument('--gamma', default=0.5, type=float, metavar='M',help='gamma')parser.add_argument('-p', '--print-freq', default=10, type=int,metavar='N', help='print frequency (default: 10)')parser.add_argument('--resume', default='', type=str, metavar='PATH',help='path to latest checkpoint (default: none)')parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',help='evaluate models on validation set')parser.add_argument('--pretrained', dest='pretrained', action='store_true',help='use pre-trained models')# parser.add_argument('--world-size', default=-1, type=int,# help='number of nodes for distributed training')# parser.add_argument('--rank', default=-1, type=int,# help='node rank for distributed training')# parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,# help='url used to set up distributed training')# parser.add_argument('--dist-backend', default='nccl', type=str,# help='distributed backend')parser.add_argument('--seed', default=1, type=int,help='seed for initializing training. ')parser.add_argument('--gpu', default=1, type=int,help='GPU id to use.')# parser.add_argument('--multiprocessing-distributed', action='store_true',# help='Use multi-processing distributed training to launch '# 'N processes per node, which has N GPUs. This is the '# 'fastest way to use PyTorch for either single node or '# 'multi node data parallel training')# parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")# trainingparser.add_argument('--dataset', type=str, default='CUB',choices=['CUB','Cars','Aircraft'],help='dataset for FGVC')parser.add_argument('--name', type=str, default='test_case')parser.add_argument('--lr_step', type=int, default=30) # lr_stepparser.add_argument('--resize-size', type=int, default=512, help='validation resize size')parser.add_argument('--crop-size', type=int, default=448, help='validation crop size')parser.add_argument('--VAL-CROP', type=str2bool, nargs='?', const=True, default=True,help='Evaluation method''If True, Evaluate on 256x256 resized and center cropped 224x224 map''If False, Evaluate on directly 224x224 resized map')# CAMparser.add_argument('--cam-thr', type=float, default=0.2, help='cam threshold value(default=0.15)')# Random Erasingparser.add_argument('--p', default=0.5, type=float, help='Random Erasing probability')parser.add_argument('--sh', default=0.4, type=float, help='max erasing area')parser.add_argument('--r1', default=0.3, type=float, help='aspect of erasing area')args = parser.parse_args()return args
Res2Net模型
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch
import torch.nn.functional as F
__all__ = ['Res2Net', 'res2net50']model_urls = {'res2net50_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_4s-06e79181.pth','res2net50_48w_2s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_48w_2s-afed724a.pth','res2net50_14w_8s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_14w_8s-6527dddc.pth','res2net50_26w_6s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_6s-19041792.pth','res2net50_26w_8s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_8s-2c7c9f12.pth','res2net101_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_26w_4s-02a759a1.pth',
}class Bottle2neck(nn.Module):expansion = 4def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale=4, stype='normal'):""" 构造函数参数:inplanes: 输入通道维度planes: 输出通道维度stride: 卷积步长。替代池化层。downsample: 当stride = 1时为NonebaseWidth: conv3x3的基本宽度scale: 尺度数量。type: 'normal': 正常设置。 'stage': 新阶段的第一个块。"""super(Bottle2neck, self).__init__()# 计算卷积核的宽度width = int(math.floor(planes * (baseWidth / 64.0)))# 第一个1x1卷积层self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False)self.bn1 = nn.BatchNorm2d(width * scale)# 计算重复次数if scale == 1:self.nums = 1else:self.nums = scale - 1# 如果是新阶段的第一个块,则使用平均池化层进行下采样if stype == 'stage':self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)# 定义重复的卷积层和BN层convs = []bns = []for i in range(self.nums):convs.append(nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, bias=False))bns.append(nn.BatchNorm2d(width))# 创建了两个 nn.ModuleList 对象 self.convs 和 self.bns,用于存储多个卷积层和批量归一化层。self.convs = nn.ModuleList(convs)self.bns = nn.ModuleList(bns)# 最后一个1x1卷积层self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(planes * self.expansion)# 激活函数self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stype = stypeself.scale = scaleself.width = widthdef forward(self, x):residual = x# 第一个1x1卷积层的计算out = self.conv1(x)out = self.bn1(out)out = self.relu(out)# 将输出按照宽度进行分割spx = torch.split(out, self.width, 1)for i in range(self.nums):# 如果是第一个块或者是新阶段的第一个块,则直接取分割后的部分if i == 0 or self.stype == 'stage':sp = spx[i]else:# 否则,累加之前的部分sp = sp + spx[i]# 对部分进行卷积、BN和ReLU操作sp = self.convs[i](sp)sp = self.relu(self.bns[i](sp))if i == 0:out = spelse:# 将处理后的部分拼接起来out = torch.cat((out, sp), 1)# 如果尺度不为1且为正常设置,将最后一个部分拼接到一起if self.scale != 1 and self.stype == 'normal':out = torch.cat((out, spx[self.nums]), 1)# 如果尺度不为1且为新阶段的第一个块,则对最后一个部分进行平均池化并拼接elif self.scale != 1 and self.stype == 'stage':out = torch.cat((out, self.pool(spx[self.nums])), 1)# 最后一个1x1卷积层的计算out = self.conv3(out)out = self.bn3(out)# 如果存在下采样,则对输入进行下采样if self.downsample is not None:residual = self.downsample(x)# 残差连接并进行ReLU激活out += residualout = self.relu(out)return outclass Res2Net(nn.Module):def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000):# 初始化Res2Net模型self.inplanes = 64 # 设置输入通道数为64self.baseWidth = baseWidthself.scale = scalesuper(Res2Net, self).__init__() # 调用父类的构造函数# 定义网络的第一层:7x7的卷积层,输入通道数为3,输出通道数为64,步长为2,填充为3self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)# Batch Normalization层,对每个channel的数据进行标准化self.bn1 = nn.BatchNorm2d(64)# 激活函数ReLUself.relu = nn.ReLU(inplace=True)# 最大池化层,窗口大小为3x3,步长为2,填充为1self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# 定义4个Res2Net的阶段(stage)self.layer1 = self._make_layer(block, 64, layers[0]) # 第一个阶段,输出通道数为64self.layer2 = self._make_layer(block, 128, layers[1], stride=2) # 第二个阶段,输出通道数为128,步长为2self.layer3 = self._make_layer(block, 256, layers[2], stride=2) # 第三个阶段,输出通道数为256,步长为2self.layer4 = self._make_layer(block, 512, layers[3], stride=2) # 第四个阶段,输出通道数为512,步长为2# 全局平均池化层,将每个通道的特征图变成一个数self.avgpool = nn.AdaptiveAvgPool2d(1)# 全连接层,将512维的特征向量映射到num_classes维的向量,用于分类self.fc = nn.Linear(512 * block.expansion, num_classes)# 初始化网络参数for m in self.modules():if isinstance(m, nn.Conv2d):# 使用kaiming正态分布初始化卷积层参数nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, nn.BatchNorm2d):# 将Batch Normalization层的权重初始化为1,偏置初始化为0nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)def _make_layer(self, block, planes, blocks, stride=1):# 构建Res2Net的一个阶段(stage),包含多个blockdownsample = Noneif stride != 1 or self.inplanes != planes * block.expansion:# 如果输入输出通道数不一致,或者步长不为1,需要添加下采样层downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion,kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(planes * block.expansion),)# 构建阶段的每个blocklayers = []layers.append(block(self.inplanes, planes, stride, downsample=downsample,stype='stage', baseWidth=self.baseWidth, scale=self.scale))self.inplanes = planes * block.expansionfor i in range(1, blocks):layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale))return nn.Sequential(*layers)def forward(self, x):# 定义前向传播过程x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.view(x.size(0), -1)x = self.fc(x)return xdef res2net50(pretrained=False, **kwargs):"""Constructs a Res2Net-50 model.Res2Net-50 refers to the Res2Net-50_26w_4s.Args:pretrained (bool): If True, returns a model pre-trained on ImageNet"""model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs)if pretrained:model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_4s']))return modeldef res2net50_26w_4s(pretrained=False, **kwargs):"""Constructs a Res2Net-50_26w_4s model.Args:pretrained (bool): If True, returns a model pre-trained on ImageNet"""model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs)if pretrained:model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_4s']))return modeldef res2net101_26w_4s(pretrained=False, **kwargs):"""Constructs a Res2Net-50_26w_4s model.Args:pretrained (bool): If True, returns a model pre-trained on ImageNet"""model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs)if pretrained:model.load_state_dict(model_zoo.load_url(model_urls['res2net101_26w_4s']))return modeldef res2net50_26w_6s(pretrained=False, **kwargs):"""Constructs a Res2Net-50_26w_4s model.Args:pretrained (bool): If True, returns a model pre-trained on ImageNet"""model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 6, **kwargs)if pretrained:model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_6s']))return modeldef res2net50_26w_8s(pretrained=False, **kwargs):"""Constructs a Res2Net-50_26w_4s model.Args:pretrained (bool): If True, returns a model pre-trained on ImageNet"""model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 8, **kwargs)if pretrained:model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_8s']))return modeldef res2net50_48w_2s(pretrained=False, **kwargs):"""Constructs a Res2Net-50_48w_2s model.Args:pretrained (bool): If True, returns a model pre-trained on ImageNet"""model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 48, scale = 2, **kwargs)if pretrained:model.load_state_dict(model_zoo.load_url(model_urls['res2net50_48w_2s']))return modeldef res2net50_14w_8s(pretrained=False, **kwargs):"""Constructs a Res2Net-50_14w_8s model.Args:pretrained (bool): If True, returns a model pre-trained on ImageNet"""model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 14, scale = 8, **kwargs)if pretrained:model.load_state_dict(model_zoo.load_url(model_urls['res2net50_14w_8s']))return modelif __name__ == '__main__':images = torch.rand(1, 3, 224, 224).cuda(0)model = res2net50_48w_2s(pretrained=False)model = model.cuda(0)print(model(images).size())print(model)
训练代码
# coding:utf-8 允许中文注释
import numpy as np
import osimport torchvisionos.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from option import get_args
from model import resnet50
from util import AverageMeter, accuracy, save_checkpoint, load_model_checkpoint
from res2net import res2net50_48w_2sdef init_seeds(seed=0):torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)if seed == 0:torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Falsebest_acc1 = 0.
def repeat_channels(x):# 这个函数将输入的 PIL 图像 x 复制到三个通道,模拟 RGB 图像return x.repeat(3, 1, 1)def main():print("Start...")global best_acc1args = get_args()DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')init_seeds(seed=0) # set random seedif args.gpu is not None:print("Use GPU: {} for training".format(args.gpu)) # 训练所用的GUP ID# directory for saveargs.log_folder = os.path.join('log', 'res2net50_48w_2s')if not os.path.exists(args.log_folder):os.makedirs(args.log_folder)if args.dataset == "CUB" and args.arch == "resnet50":channels = 2048num_classes = 200data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) # get data root pathimage_path = os.path.join(data_root, '/data/tgf/resnet/Data/')# train_dir = '/data/tgf/resnet/Data/trian'# valid_dir = '/data/tgf/resnet/Data/test'elif args.dataset == 'Cars' and args.arch == "resnet50":channels = 2048num_classes = 196data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) # get data root pathimage_path = os.path.join(data_root, '/tgf/resnet/CUB_200_2011/dataset')# train_dir = '/learn_pytorch/resnet/Data/trian'# valid_dir = '/learn_pytorch/resnet/Data/test'elif args.dataset == "Aircraft" and args.arch == "resnet50":channels = 2048num_classes = 100data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) # get data root pathimage_path = os.path.join(data_root, '/data/tgf/resnet/Data')# train_dir = '/learn_pytorch/resnet/Data/trian'# valid_dir = '/learn_pytorch/resnet/Data/test'else:raise Exception("No dataset named {}".format(args.dataset))# Modelprint("=> creating model '{}'".format(args.arch))print("num_classes ", num_classes)model = res2net50_48w_2s(pretrained=True)# model_weight_path = "./resnet50_pre.pth"# assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)# model.load_state_dict(torch.load(model_weight_path, map_location='cpu'))# change fc layer structurein_channel = model.fc.in_featuresmodel.fc = nn.Linear(in_channel, num_classes)model = model.cuda()cudnn.benchmark = True# Loading training/validation datasettrain_transform = transforms.Compose([transforms.Resize((512, 512)),transforms.RandomCrop((448, 448)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),# transforms.Lambda(repeat_channels),transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),])test_transform = transforms.Compose([transforms.Resize((512, 512)),transforms.CenterCrop((448, 448)), # RandomCrop for train and CenterCrop for testtransforms.ToTensor(),# transforms.Lambda(repeat_channels),transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),])train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"), transform=train_transform)print("train_dataset为:",train_dataset)valid_dataset = datasets.ImageFolder(root=os.path.join(image_path, "test"), transform=test_transform)train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size,shuffle=True, num_workers=args.workers, pin_memory=True)# print("train_loader为:",train_loader)valid_loader = DataLoader(dataset=valid_dataset, batch_size=args.batch_size,shuffle=False, num_workers=args.workers, pin_memory=True)print("using {} images for training, {} images for validation.".format(len(train_dataset), len(valid_dataset)))# define loss function (criterion), optimizer, and learning rate schedulercriterion = nn.CrossEntropyLoss().cuda()optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, nesterov=True, momentum=args.momentum,weight_decay=args.weight_decay)scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.decay_step, gamma=args.gamma)# optionally resume from a checkpointif args.resume:model, optimizer = load_model_checkpoint(model, optimizer, args)def train(train_loader, model, criterion, optimizer, epoch, args):# AverageMeter for Performancelosses = AverageMeter('Loss', ':.4e')top1 = AverageMeter('Acc@1', ':6.2f')top5 = AverageMeter('Acc@5', ':6.2f')DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# Switch to train modemodel.train()# lr = next(iter(optimizer.param_groups))['lr']train_bar = tqdm(train_loader) # 训练集进度条for batch_idx, (inputs, targets) in enumerate(train_bar):idx = batch_idxinputs, targets = Variable(inputs).cuda(), Variable(targets).cuda()# inputs, targets = Variable(inputs), Variable(targets)# compute outputoutputs = model(inputs) # 前向传播loss = criterion(outputs, targets)# # measure accuracy and record lossacc1, acc5 = accuracy(outputs, targets, topk=(1, 5))losses.update(loss.item(), inputs.size(0))top1.update(acc1[0], inputs.size(0))top5.update(acc5[0], inputs.size(0))# compute gradient and do SGD stepoptimizer.zero_grad()loss.backward() # !!optimizer.step()# print infodescription = "[Train:{0:3d}/{1:3d}] Top1-cls: {2:6.2f}, Top5-cls: {3:6.2f}, Loss: {4:7.4f},". \format(epoch + 1, args.epochs, top1.avg, top5.avg, losses.avg)train_bar.set_description(desc=description)return top1.avg, losses.avgbest_acc_epoch = 0for epoch in range(args.start_epoch, args.epochs):lr = next(iter(optimizer.param_groups))['lr']# ————————————————Train————————————————#train_acc1, train_losses = train(train_loader, model, criterion, optimizer, epoch, args)scheduler.step() # 放到每个epoch训练完之后# tensorboardwith SummaryWriter(log_dir=os.path.join(args.log_folder, 'no_seed/train'), comment='train') as writer:writer.add_scalar('Train/learning_rate', lr, epoch)writer.add_scalar('Train/train_acc1', train_acc1, epoch)writer.add_scalar('Train/train_loss', train_losses, epoch)writer.flush()writer.close()# ————————————————Test————————————————#val_acc1, val_losses = validate(valid_loader, model, criterion, epoch, args) # Test!!!# tensorboardwith SummaryWriter(log_dir=os.path.join(args.log_folder, 'no_seed/val'), comment='test') as writer:writer.add_scalar('Test/val_acc1', val_acc1, epoch)writer.add_scalar('Test/val_loss', val_losses, epoch)writer.flush()writer.close()is_best = val_acc1 > best_acc1 # True / Falsebest_acc1 = max(val_acc1, best_acc1)# save_checkpoint({# 'epoch': epoch + 1,# 'arch': args.arch,# 'state_dict': model.state_dict(),# 'best_acc1': best_acc1,# 'optimizer': optimizer.state_dict(),# # 'scheduler': scheduler.state_dict()# }, is_best, args.log_folder)if is_best:best_acc_epoch = epoch + 1savepath = "/data/tgf/resnet/log/resnet50_in_CUB/best.pth"torch.save(model, savepath)print("Until %d epochs, Best Acc@1 %.3f in the %d-th epoch" % (epoch + 1, best_acc1, best_acc_epoch))with open(os.path.join(args.log_folder, 'result.txt'), 'w') as file:file.write("best_acc1 {}".format(best_acc1))file.close()def validate(val_loader, model, criterion, epoch, args):# AverageMeter for Performancelosses = AverageMeter('Loss', ':.4e')top1 = AverageMeter('Acc@1', ':6.2f')top5 = AverageMeter('Acc@5', ':6.2f')# DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# switch to evaluate modemodel.eval()with torch.no_grad():val_bar = tqdm(val_loader)for batch_idx, (inputs, targets) in enumerate(val_bar):idx = batch_idxinputs, targets = Variable(inputs).cuda(), Variable(targets).cuda()# Compute outputoutputs = model(inputs)loss = criterion(outputs, targets)# measure accuracy and record lossacc1, acc5 = accuracy(outputs, targets, topk=(1, 5))losses.update(loss.item(), inputs.size(0))top1.update(acc1[0], inputs.size(0))top5.update(acc5[0], inputs.size(0))# print infodescription = "[Valid:{0:3d}/{1:3d}] Top1-cls: {2:6.2f}, Top5-cls: {3:6.2f}, Loss: {4:7.4f}, ". \format(epoch + 1, args.epochs, top1.avg, top5.avg, losses.avg)val_bar.set_description(desc=description)return top1.avg, losses.avgif __name__ == '__main__':main()
实验结果