Rethinking Data Augmentation for Image Super-resolution

文章目录

  • Rethinking Data Augmentation for Image Super-resolution:
    • 1.概述
    • 2.一些现有方法的分析
    • 3.cutblur
    • 4.MOA 各种策略的混合
    • 5.降噪
    • 6.cutblur 代码

Rethinking Data Augmentation for Image Super-resolution:

A Comprehensive Analysis and a New Strategy

1.概述

根据方法应用的位置将现有的增强技术分为两组:像素域就是针对图像 和 特征域就是中间的特征层。
作者提出cutblur方法,正则化模型使模型可以学到 在图像的什么区域区增强以及如何增强。
作者基于提出cutblur数据增强方法以及其他一些辅助方法 构建一个混合的数据增强策略,效果很好。

作者主要利用EDSR model 在 DIV2K和 RealSR 两个数据集上从头训练,进行分析。

一些数据增强方法 被提出,但是一般应用在high level任务上,对于超分,降噪这种low level任务不一定可行,甚至有害,也比较好理解。
许多增强方法的核心思想是对训练信号进行部分屏蔽或混淆,从而使模型获得更强的泛化能力。然而,与分类等高级任务不同,模型学习抽象图像。
low level任务像素之间的局部和全局关系在低层次视觉任务中尤为重要,如去噪和超分辨率。

2.一些现有方法的分析

下图是一些方法的展示:
这里简单说明一下:
blend就是 对通道 随机加减乘除 随机数,改变颜色
rgb permute 我理解使 各通道换一下位置
cutout 随机去掉一块
mixup 2张图像混合
cutmix 从其他图像取一块贴在新图像上
cutmixup 从其他图像取一块 与 新图像上的一块 混合
在这里插入图片描述

以上这些DA方法作者进行实验
cutout(%0.1%) 就是随机丢弃千分之一的像素,丢弃很少。
可以看出最简单方法 blend, rgb permute, 都是有效的。
作者提出的cut blur效果最好
在这里插入图片描述

关于cutmix介绍:参考https://meetonfriday.com/posts/b4202d1/
在这里插入图片描述

3.cutblur

cutblur也很简单:
从hr中取一块 paste lr中 改变原来的lr
从lr中取一块 paste hr中 改变原来的hr

使用和不使用cutblur的效果对比:
在这里插入图片描述

不使用cutblur训练的模型,当输入cutblur的图片是 会产生过度锐化
毕竟不使用cutblur训练,对全图都会进行超分,也在情理之中。不过这也说明了利用cutblur训练会使model学会在哪些区域做超分,哪些区域保持不变。
在这里插入图片描述

作者又和将hr作为输入,hr作为输出来锻炼model对hr的识别能力,发现效果不如cutblur。

4.MOA 各种策略的混合

MOA就是table1中的各个策略随机选择进行 DA

  1. 不同尺寸的超分模型
    模型越大对DA的包容性越大,否则小模型本身拟合能力不够,不能处理DA数据。
    在这里插入图片描述

  2. 不同的dataset size
    数据集比较小的时候,大模型很容易过拟合,这个时候DA方法有助于帮助改善,下图的c,d
    在这里插入图片描述

  3. 在div2k 和 realsr训练比较
    realsr上应用DA,效果普遍提升较大
    在这里插入图片描述

  4. 一个应用
    现在手机上会有背景虚化或者前后景分辨率不同的图像,这个时候的图像适合 本文提出DA数据的训练

5.降噪

在这里插入图片描述

另外model对不同噪声水平的图像降噪的泛化能力更高

6.cutblur 代码

在这里插入图片描述

看一下下面代码的实现,修改的都是im2,即始终修改的使 lr图像,hr图像始终不变。实际代码 和论文图示,论文解释,论文公式都有点出入。

im1是HR
im2是LR

def cutblur(im1, im2, prob=1.0, alpha=1.0):if im1.size() != im2.size():raise ValueError("im1 and im2 have to be the same resolution.")if alpha <= 0 or np.random.rand(1) >= prob:return im1, im2cut_ratio = np.random.randn() * 0.01 + alphah, w = im2.size(2), im2.size(3)ch, cw = np.int(h*cut_ratio), np.int(w*cut_ratio)cy = np.random.randint(0, h-ch+1)cx = np.random.randint(0, w-cw+1)# apply CutBlur to inside or outsideif np.random.random() > 0.5:im2[..., cy:cy+ch, cx:cx+cw] = im1[..., cy:cy+ch, cx:cx+cw]else:im2_aug = im1.clone()im2_aug[..., cy:cy+ch, cx:cx+cw] = im2[..., cy:cy+ch, cx:cx+cw]im2 = im2_augreturn im1, im2

参数设置

# default augmentation policiesif opt.use_moa:opt.augs = ["blend", "rgb", "mixup", "cutout", "cutmix", "cutmixup", "cutblur"]opt.prob = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]opt.alpha = [0.6, 1.0, 1.2, 0.001, 0.7, 0.7, 0.7]opt.aux_prob, opt.aux_alpha = 1.0, 1.2opt.mix_p = Noneif "RealSR" in opt.dataset:opt.mix_p = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.4]if "DN" in opt.dataset or "JPEG" in opt.dataset:opt.prob = [0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6]if "CARN" in opt.model and not "RealSR" in opt.dataset:opt.prob = [0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2]

训练代码:
LR, HR都是 N3HW shape
然后调用augments增强函数

            HR = inputs[0].to(self.dev)LR = inputs[1].to(self.dev)# match the resolution of (LR, HR) due to CutBlurif HR.size() != LR.size():scale = HR.size(2) // LR.size(2)LR = F.interpolate(LR, scale_factor=scale, mode="nearest")HR, LR, mask, aug = augments.apply_augment(HR, LR,opt.augs, opt.prob, opt.alpha,opt.aux_alpha, opt.aux_alpha, opt.mix_p)SR = self.net(LR)if aug == "cutout":SR, HR = SR*mask, HR*maskloss = self.loss_fn(SR, HR)self.optim.zero_grad()loss.backward()

完整的DA代码:

"""
CutBlur
Copyright 2020-present NAVER corp.
MIT license
"""
import numpy as np
import torch
import torch.nn.functional as Fdef apply_augment(im1, im2,augs, probs, alphas,aux_prob=None, aux_alpha=None,mix_p=None
):idx = np.random.choice(len(augs), p=mix_p)aug = augs[idx]prob = float(probs[idx])alpha = float(alphas[idx])mask = Noneif aug == "none":im1_aug, im2_aug = im1.clone(), im2.clone()elif aug == "blend":im1_aug, im2_aug = blend(im1.clone(), im2.clone(),prob=prob, alpha=alpha)elif aug == "mixup":im1_aug, im2_aug, = mixup(im1.clone(), im2.clone(),prob=prob, alpha=alpha,)elif aug == "cutout":im1_aug, im2_aug, mask, _ = cutout(im1.clone(), im2.clone(),prob=prob, alpha=alpha)elif aug == "cutmix":im1_aug, im2_aug = cutmix(im1.clone(), im2.clone(),prob=prob, alpha=alpha,)elif aug == "cutmixup":im1_aug, im2_aug = cutmixup(im1.clone(), im2.clone(),mixup_prob=aux_prob, mixup_alpha=aux_alpha,cutmix_prob=prob, cutmix_alpha=alpha,)elif aug == "cutblur":im1_aug, im2_aug = cutblur(im1.clone(), im2.clone(),prob=prob, alpha=alpha)elif aug == "rgb":im1_aug, im2_aug = rgb(im1.clone(), im2.clone(),prob=prob)else:raise ValueError("{} is not invalid.".format(aug))return im1_aug, im2_aug, mask, augdef blend(im1, im2, prob=1.0, alpha=0.6):if alpha <= 0 or np.random.rand(1) >= prob:return im1, im2c = torch.empty((im2.size(0), 3, 1, 1), device=im2.device).uniform_(0, 255)rim2 = c.repeat((1, 1, im2.size(2), im2.size(3)))rim1 = c.repeat((1, 1, im1.size(2), im1.size(3)))v = np.random.uniform(alpha, 1)im1 = v * im1 + (1-v) * rim1im2 = v * im2 + (1-v) * rim2return im1, im2def mixup(im1, im2, prob=1.0, alpha=1.2):if alpha <= 0 or np.random.rand(1) >= prob:return im1, im2v = np.random.beta(alpha, alpha)r_index = torch.randperm(im1.size(0)).to(im2.device)im1 = v * im1 + (1-v) * im1[r_index, :]im2 = v * im2 + (1-v) * im2[r_index, :]return im1, im2def _cutmix(im2, prob=1.0, alpha=1.0):if alpha <= 0 or np.random.rand(1) >= prob:return Nonecut_ratio = np.random.randn() * 0.01 + alphah, w = im2.size(2), im2.size(3)ch, cw = np.int(h*cut_ratio), np.int(w*cut_ratio)fcy = np.random.randint(0, h-ch+1)fcx = np.random.randint(0, w-cw+1)tcy, tcx = fcy, fcxrindex = torch.randperm(im2.size(0)).to(im2.device)return {"rindex": rindex, "ch": ch, "cw": cw,"tcy": tcy, "tcx": tcx, "fcy": fcy, "fcx": fcx,}def cutmix(im1, im2, prob=1.0, alpha=1.0):c = _cutmix(im2, prob, alpha)if c is None:return im1, im2scale = im1.size(2) // im2.size(2)rindex, ch, cw = c["rindex"], c["ch"], c["cw"]tcy, tcx, fcy, fcx = c["tcy"], c["tcx"], c["fcy"], c["fcx"]hch, hcw = ch*scale, cw*scalehfcy, hfcx, htcy, htcx = fcy*scale, fcx*scale, tcy*scale, tcx*scaleim2[..., tcy:tcy+ch, tcx:tcx+cw] = im2[rindex, :, fcy:fcy+ch, fcx:fcx+cw]im1[..., htcy:htcy+hch, htcx:htcx+hcw] = im1[rindex, :, hfcy:hfcy+hch, hfcx:hfcx+hcw]return im1, im2def cutmixup(im1, im2,mixup_prob=1.0, mixup_alpha=1.0,cutmix_prob=1.0, cutmix_alpha=1.0
):c = _cutmix(im2, cutmix_prob, cutmix_alpha)if c is None:return im1, im2scale = im1.size(2) // im2.size(2)rindex, ch, cw = c["rindex"], c["ch"], c["cw"]tcy, tcx, fcy, fcx = c["tcy"], c["tcx"], c["fcy"], c["fcx"]hch, hcw = ch*scale, cw*scalehfcy, hfcx, htcy, htcx = fcy*scale, fcx*scale, tcy*scale, tcx*scalev = np.random.beta(mixup_alpha, mixup_alpha)if mixup_alpha <= 0 or np.random.rand(1) >= mixup_prob:im2_aug = im2[rindex, :]im1_aug = im1[rindex, :]else:im2_aug = v * im2 + (1-v) * im2[rindex, :]im1_aug = v * im1 + (1-v) * im1[rindex, :]# apply mixup to inside or outsideif np.random.random() > 0.5:im2[..., tcy:tcy+ch, tcx:tcx+cw] = im2_aug[..., fcy:fcy+ch, fcx:fcx+cw]im1[..., htcy:htcy+hch, htcx:htcx+hcw] = im1_aug[..., hfcy:hfcy+hch, hfcx:hfcx+hcw]else:im2_aug[..., tcy:tcy+ch, tcx:tcx+cw] = im2[..., fcy:fcy+ch, fcx:fcx+cw]im1_aug[..., htcy:htcy+hch, htcx:htcx+hcw] = im1[..., hfcy:hfcy+hch, hfcx:hfcx+hcw]im2, im1 = im2_aug, im1_augreturn im1, im2def cutblur(im1, im2, prob=1.0, alpha=1.0):if im1.size() != im2.size():raise ValueError("im1 and im2 have to be the same resolution.")if alpha <= 0 or np.random.rand(1) >= prob:return im1, im2cut_ratio = np.random.randn() * 0.01 + alphah, w = im2.size(2), im2.size(3)ch, cw = np.int(h*cut_ratio), np.int(w*cut_ratio)cy = np.random.randint(0, h-ch+1)cx = np.random.randint(0, w-cw+1)# apply CutBlur to inside or outsideif np.random.random() > 0.5:im2[..., cy:cy+ch, cx:cx+cw] = im1[..., cy:cy+ch, cx:cx+cw]else:im2_aug = im1.clone()im2_aug[..., cy:cy+ch, cx:cx+cw] = im2[..., cy:cy+ch, cx:cx+cw]im2 = im2_augreturn im1, im2def cutout(im1, im2, prob=1.0, alpha=0.1):scale = im1.size(2) // im2.size(2)fsize = (im2.size(0), 1)+im2.size()[2:]if alpha <= 0 or np.random.rand(1) >= prob:fim2 = np.ones(fsize)fim2 = torch.tensor(fim2, dtype=torch.float, device=im2.device)fim1 = F.interpolate(fim2, scale_factor=scale, mode="nearest")return im1, im2, fim1, fim2fim2 = np.random.choice([0.0, 1.0], size=fsize, p=[alpha, 1-alpha])fim2 = torch.tensor(fim2, dtype=torch.float, device=im2.device)fim1 = F.interpolate(fim2, scale_factor=scale, mode="nearest")im2 *= fim2return im1, im2, fim1, fim2def rgb(im1, im2, prob=1.0):if np.random.rand(1) >= prob:return im1, im2perm = np.random.permutation(3)im1 = im1[:, perm]im2 = im2[:, perm]return im1, im2

官方代码:https://github.com/clovaai/cutblur
这里扩展一下,如果是降噪任务:

  1. rgb permute, blend 比较通用
  2. cutout 这里是随机去掉千分之一的pixel 也可以提高对图像内容的理解,之前有介绍过另一篇论文。
  3. mixup 函数是 一个batch里的图像打乱后与原先batch混合,lr,hr同样的打乱顺序,似乎还行。
  4. 但是这里我想到另一个mixup方法,就是lr, hr以一定程度混合得到lr, 会不会更好,提高处理不同noise level图像的泛化能力。
  5. cutmix, cutmixup类似mixup 在一个batch种的图像,随机打乱后, 对应位置的patch进行替换 或者 混合。
  6. cutblur, 将hr中的patch 替换到 lr中 本文提到的方法。 当然我感觉也可以cut and mixup, 而不只是cut and paste.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/521455.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

springcloud容错机制改进

文章目录 Spring Cloud的容错机制改进常见方案示例 Spring Cloud的容错机制改进常见方案 Spring Cloud的容错机制改进常见方案主要包括以下几个方面&#xff1a; 替换或升级熔断器组件&#xff1a; 从Hystrix迁移到其他高性能的熔断器&#xff1a;由于Hystrix已停止维护&#…

案例研究|辛格林电梯借助DataEase实现数据整合与智能展示

辛格林电梯&#xff08;SIGLEN&#xff09;于2012年创立&#xff0c;是电梯领域的领军品牌之一。该公司总部位于广东佛山&#xff0c;是全国首批获得A1级电梯制造资质的企业&#xff0c;拥有省级工程技术研究中心。辛格林电梯专注于研发和生产高品质电梯产品&#xff0c;涵盖别…

手持气象站的优势

【TH-SQ5】手持气象站作为一种便携式的气象监测设备&#xff0c;具有许多显著的优势&#xff0c;主要体现在以下几个方面&#xff1a; 便携性&#xff1a;手持气象站设计精巧&#xff0c;重量轻&#xff0c;易于携带&#xff0c;这使得用户能够轻松地在不同地点进行气象观测&a…

视觉和GPT再碰火花!CVPR`24 | RegionGPT:面向复杂区域理解的VLM(港大英伟达)

文章链接&#xff1a;https://arxiv.org/pdf/2403.02330 视觉语言模型&#xff08;VLMs&#xff09;通过将大语言模型&#xff08;LLMs&#xff09;与图像文本对集成&#xff0c;经历了快速的发展&#xff0c;但由于视觉编码器的空间意识有限以及使用缺乏详细的区域特定字幕的…

2.1 决策就是为了幸福

传统经济学中&#xff0c;决策是围绕着均衡展开的。 四个均衡 供求均衡 消费者均衡 生产者均衡 厂商均衡 解决的是如何确定产量&#xff0c;能够使厂商利润最大的问题 MR MC 边际收益等于边际成本&#xff0c;此时利润最大&#xff1b;使得边际收益等于边际成本的产量也就成…

Python实现时间序列分析Theta模型(ThetaModel算法)项目实战

说明&#xff1a;这是一个机器学习实战项目&#xff08;附带数据代码文档视频讲解&#xff09;&#xff0c;如需数据代码文档视频讲解可以直接到文章最后获取。 1.项目背景 时间序列分析中的Theta模型(Theta Model)是由Athanasios Theodoridis在2008年提出的一种统计预测方法&…

K8S之实现业务的金丝雀发布

如何实现金丝雀发布 金丝雀发布简介优缺点在k8s中实现金丝雀发布 金丝雀发布简介 金丝雀发布的由来&#xff1a;17 世纪&#xff0c;英国矿井工人发现&#xff0c;金丝雀对瓦斯这种气体十分敏感。空气中哪怕有极其微量的瓦斯&#xff0c;金丝雀也会停止歌唱&#xff1b;当瓦斯…

【Java】线程池

在编程中经常会使用线程来异步处理任务&#xff0c;但是每个线程的创建和销毁都需要一定的开销。如果每次执行一个任务都需要开一个新线程去执行&#xff0c;则这些线程的创建和销毁将消耗大量的资源&#xff1b;并且线程都是“各自为政”的&#xff0c;很难对其进行控制&#…

或运算的理解

||或运算的含义&#xff1a;至少有一个成立&#xff1b; 根据德.摩根定律&#xff0c; 非(P 且 Q) (非 P) 或 (非 Q) 非(P 或 Q) (非 P) 且 (非 Q) 与运算取非等价于 各自取非&#xff0c;然后再或 例如&#xff1a; !(A&&B) !A||!B 或运算&#xff1a; 否一推一…

AIGC绘画关键词 - 写实少女

Unity3D特效百例案例项目实战源码Android-Unity实战问题汇总游戏脚本-辅助自动化Android控件全解手册再战Android系列Scratch编程案例软考全系列Unity3D学习专栏蓝桥系列ChatGPT和AIGC &#x1f449;关于作者 专注于Android/Unity和各种游戏开发技巧&#xff0c;以及各种资源分…

学术神器ChatGPT在论文分析中的妙用!

话语分析是一个广泛的研究领域&#xff0c;它关注的是人们在实际社会交际场景中使用的语言单位。话语分析旨在揭示语言、社会和文化之间的互动关系&#xff0c;以及话诺在构建意义、传递信息、维护社会关系和表达权力等方面的作用。话语分析包括对话分析、批判性话语分析、语篇…

Find My产品越来越得到市场认可,伦茨科技ST17H6x芯片赋能厂家

苹果发布AirTag发布以来&#xff0c;大家都更加注重物品的防丢&#xff0c;苹果的 Find My 就可以查找 iPhone、Mac、AirPods、Apple Watch&#xff0c;如今的Find My已经不单单可以查找苹果的设备&#xff0c;随着第三方设备的加入&#xff0c;将丰富Find My Network的版图。产…