生动理解深度学习精度提升利器——测试时增强(TTA)

测试时增强(Test-Time Augmentation,TTA)是一种在深度学习模型的测试阶段应用数据增强的技术手段。它是通过对测试样本进行多次随机变换或扰动,产生多个增强的样本,并使用这些样本进行预测的多数投票或平均来得出最终预测结果。

为了直观理解TTA执行的过程,这里我绘制了流程示意图如下所示:

TTA的过程如下:

  1. 数据增强:

    • 在测试时,对每个测试样本应用随机的变换或扰动操作,生成多个增强样本。
    • 常用的数据增强操作包括随机翻转、随机旋转、随机裁剪、随机缩放等。这些操作可以增加样本的多样性,模拟真实世界中的不确定性和变化。
  2. 多次预测:

    • 使用训练好的模型对生成的增强样本进行多次预测。
    • 对于每个增强样本,都会得到一个预测结果。
  3. 预测结果集成:

    • 对多次预测的结果进行集成,常用的集成方式有多数投票和平均。
    • 对于分类任务,多数投票即选择预测结果中出现次数最多的类别作为最终的预测类别。对于回归任务,平均即将多次预测结果进行平均。

接下来针对性地对比分析下使用TTA带来的优点和缺点:

优点:

  • 提高鲁棒性:通过应用数据增强,TTA可以增加样本的多样性和泛化能力,提高模型在面对未见过的输入分布和未知变化时的鲁棒性。
  • 提高准确性:通过多次预测和集成,TTA可以减少预测结果的随机性和偶然误差,提高最终预测结果的稳定性和准确性。
  • 模型评估和排名:TTA可以改变模型预测的不确定性,使得模型评估更可靠,能够更好地对不同模型进行性能排名。

缺点:

  • 计算开销:生成和预测多个增强样本会增加计算量。特别是在大型模型和复杂任务中,可能导致推理时间的显著增加,限制了TTA的实际应用。
  • 可能造成过拟合:对于已包含在训练数据中的变换或扰动,如果在测试时反复应用,可能会导致模型对这些特定样本的过拟合,从而影响模型的泛化能力。

TTA是一种常用的技术手段,通过应用数据增强和集成预测结果,可以提高深度学习模型在测试阶段的性能和鲁棒性。然而,TTA的应用需要平衡计算开销和预测准确性,并谨慎处理可能导致模型过拟合的问题。根据具体任务和需求,可以灵活选择合适的增强操作和集成策略来使用TTA。

下面是demo代码实现,如下所示:

import numpy as np
import torch
import torchvision.transforms as transformsdef test_time_augmentation(model, image, n_augmentations):# 定义数据增强的变换transform = transforms.Compose([transforms.ToTensor(),# 在此添加你需要的任何其他数据增强操作])# 存储多次预测结果的列表predictions = []# 对图像应用多次增强和预测for _ in range(n_augmentations):augmented_image = transform(image)augmented_image = augmented_image.unsqueeze(0)  # 增加一个维度作为批次with torch.no_grad():# 切换模型为评估模式,确保不执行梯度计算model.eval()# 使用增强的图像进行预测output = model(augmented_image)_, predicted = torch.max(output.data, 1)predictions.append(predicted.item())# 执行多数投票并返回最终预测结果final_prediction = np.bincount(predictions).argmax()return final_prediction

在前文鸟类细粒度识别项目实验中测试发现,应用TTA技术后,对应的评估指标上有明显的涨点,但是很明显地可以发现:在整个测试过程中资源消耗增加明显,且耗时显著增长,这也是TTA无法避免的劣势,在对精度要求较高的场景下可以有限考虑引入TTA,但是对于计算时耗要求较高的场景则不推荐使用TTA。

开源社区里面也有一些优秀的实现,这里推荐一个,地址在这里,如下所示:

目前有将近1k的star量,还是蛮不错的。

安装方法如下所示:

pip安装:
pip install ttach源码安装:
pip install git+https://github.com/qubvel/ttach
        Input|           # input batch of images / / /|\ \ \      # apply augmentations (flips, rotation, scale, etc.)| | | | | | |     # pass augmented batches through model| | | | | | |     # reverse transformations for each batch of masks/labels\ \ \ / / /      # merge predictions (mean, max, gmean, etc.)|           # output batch of masks/labelsOutput

目前支持分割、分类、关键点检测三种任务,实例使用如下所示:

Segmentation model wrapping [docstring]:
import ttach as tta
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')Classification model wrapping [docstring]:
tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform())Keypoints model wrapping [docstring]:
tta_model = tta.KeypointsTTAWrapper(model, tta.aliases.flip_transform(), scaled=True)
data transforms 实例实现如下所示:
# defined 2 * 2 * 3 * 3 = 36 augmentations !
transforms = tta.Compose([tta.HorizontalFlip(),tta.Rotate90(angles=[0, 180]),tta.Scale(scales=[1, 2, 4]),tta.Multiply(factors=[0.9, 1, 1.1]),        ]
)tta_model = tta.SegmentationTTAWrapper(model, transforms)

Custom model (multi-input / multi-output)实现如下所示:

# Example how to process ONE batch on images with TTA
# Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is 2D tensor (B, N)for transformer in transforms: # custom transforms or e.g. tta.aliases.d4_transform() # augment imageaugmented_image = transformer.augment_image(image)# pass to modelmodel_output = model(augmented_image, another_input_data)# reverse augmentation for mask and labeldeaug_mask = transformer.deaugment_mask(model_output['mask'])deaug_label = transformer.deaugment_label(model_output['label'])# save resultslabels.append(deaug_mask)masks.append(deaug_label)# reduce results as you want, e.g mean/max/min
label = mean(labels)
mask = mean(masks)

Transforms详情如下所示:

TransformParametersValues
HorizontalFlip--
VerticalFlip--
Rotate90anglesList[0, 90, 180, 270]
Scalescales
interpolation
List[float]
"nearest"/"linear"
Resizesizes
original_size
interpolation
List[Tuple[int, int]]
Tuple[int,int]
"nearest"/"linear"
AddvaluesList[float]
MultiplyfactorsList[float]
FiveCropscrop_height
crop_width
int
int

支持的结果融合方法如下:

mean
gmean (geometric mean)
sum
max
min
tsharpen (temperature sharpen with t=0.5)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/103168.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

826. 安排工作以达到最大收益;2257. 统计网格图中没有被保卫的格子数;816. 模糊坐标

826. 安排工作以达到最大收益 核心思想:排序维护最大利润。首先我们需要对工人按照能力排序,前面工人满足的最大利润后面的工人肯定是满足的,所以我们只需要用一个tmp来维护小于等于当前工人的最大利润,然后如何得到tmp&#xff…

flutter开发实战-实现自定义bottomNavigationBar样式awesome_bottom_bar

flutter开发实战-实现自定义bottomNavigationBar样式awesome_bottom_bar 在开发过程中,需要自定义bottomNavigationBar样式,可以自定义实现,这里使用的是awesome_bottom_bar库 一、awesome_bottom_bar 在pubspec.yaml中引入awesome_bottom_…

es滚动查询分析和使用步骤

ES在进行普通的查询时,默认只会查询出来10条数据。我们通过设置es中的size可以将最终的查询结果从10增加到10000。如果需要查询数据量大于es的翻页限制或者需要将es的数据进行导出又当如何? Elasticsearch提供了一种称为"滚动查询"&#xff08…

探索Apache Hive:融合专业性、趣味性和吸引力的数据库操作奇幻之旅

文章目录 版权声明一 数据库操作二 Hive数据表操作2.1 表操作语法和数据类型2.2 Hive表分类2.3 内部表Vs外部表2.4 内部表操作2.4.1 创建内部表2.4.2 其他创建内部表的形式2.4.3 数据分隔符2.4.4 自定义分隔符2.4.5 删除内部表 2.5 外部表操作2.5.1 创建外部表2.5.2 操作演示2.…

Jmeter进阶使用指南-分布式测试

当你需要模拟大量并发用户并测试应用程序的性能时,JMeter的分布式测试功能非常有用。分布式测试允许你使用多个JMeter实例来模拟并发用户,从而提供更高的负载。 下面是一个详细的介绍和讲解分布式测试的步骤: 准备主机和从机: 首…

QT 插件化图像算法软件架构

为什么要做插件化软件架构? 通过 结构化、模块化、松耦合、高内聚、插件化,有助于提升软件开发效率。 1、通过结构化、模块化、插件化方式的软件设计与开发,减少重复开发、重复测试、重复BUG修复,从而提高开发效率、提升代码质量…

flask bootstrap页面json格式化

html <!DOCTYPE html> <html lang"en"> <head><!-- 新 Bootstrap5 核心 CSS 文件 --> <link rel"stylesheet" href"static/bootstrap-5.0.0-beta1-dist/css/bootstrap.min.css"><!-- 最新的 Bootstrap5 核心 …

算法通关村17关 | 透析跳跃游戏

1. 跳跃游戏 题目 LeetCode55 给定一个非负整数数组&#xff0c;最初位于数组的第一个位置&#xff0c;数组中的每个元素代表你再该位置可以跳跃的最大长度&#xff0c;判断你是否能够达到最后一个位置。 思路 如果当前位置元素如果是3&#xff0c;我们无需考虑是跳几步&#…

【HTTP爬虫ip实操】智能路由构建高效稳定爬虫系统

在当今信息时代&#xff0c;数据的价值越来越受到重视。对于许多企业和个人而言&#xff0c;网络爬取成为了获取大量有用数据的关键手段之一。然而&#xff0c;在面对反爬机制、封锁限制以及频繁变动的网站结构时&#xff0c;如何确保稳定地采集所需数据却是一个不容忽视且具挑…

【前端】CSS-Grid网格布局

目录 一、grid布局是什么二、grid布局的属性三、容器属性1、display①、语句②、属性值 2、grid-template-columns属性、grid-template-rows属性①、定义②、属性值1&#xff09;、固定的列宽和行高2&#xff09;、repeat()函数3&#xff09;、auto-fill关键字4&#xff09;、f…

Redis多机数据库实现

Redis多机数据库实现 为《Redis设计与实现》笔记 复制 客户端可以使用SLAVEOF命令将指定服务器设置为该服务器的主服务器 127.0.0.1:12345> SLAVEOF 127.0.0.1 6379127.0.0.1:6379将被设置为127.0.0.1:123456的主服务器 旧版复制功能的实现 Redis的复制功能分为同步&a…

OpenHarmony:如何使用HDF驱动控制LED灯

一、程序简介 该程序是基于OpenHarmony标准系统编写的基础外设类&#xff1a;RGB LED。 目前已在凌蒙派-RK3568开发板跑通。详细资料请参考官网&#xff1a;https://gitee.com/Lockzhiner-Electronics/lockzhiner-rk3568-openharmony/tree/master/samples/b02_hdf_rgb_led。 …