IDR: Self-Supervised Image Denoising via Iterative Data Refinement

文章目录

  • IDR: Self-Supervised Image Denoising via Iterative Data Refinement
      • 1. noisy-clean pair 比较难获取
      • 2. noiser-noisy pair 比较容易获取,但是训练效果呢?
        • 2.1 noiser-noisy 训练的模型,能够对 noisy 图像一定程度的降噪
        • 2.2 noiser-noisy数据 越接近 noisy-clean 数据,训练的效果越好。
      • 3.通过训练 让noiser-noisy数据 更接近 noisy-clean 数据
      • 4.Fast Iterative Data Refinement
      • 5. sensenoise-500 dataset
        • 5.1 数据集 error
      • 6.训练

IDR: Self-Supervised Image Denoising via Iterative Data Refinement

IDR 是一个无监督降噪模型。

1. noisy-clean pair 比较难获取

noisy-clean pair:
x: noidy image, y: clean image
但是 y比较难获取

noisr-noisy pair
x + n , x

在这里插入图片描述

对噪声图像再添加噪声,得到 噪声更大的图像。

这里的n表示的是sensor的噪声模型(也可以是采样得到的,参考作者另一篇论文rethinking noise).

2. noiser-noisy pair 比较容易获取,但是训练效果呢?

作者的两个发现:

2.1 noiser-noisy 训练的模型,能够对 noisy 图像一定程度的降噪

如下图:

在这里插入图片描述

2.2 noiser-noisy数据 越接近 noisy-clean 数据,训练的效果越好。

在这里插入图片描述

3.通过训练 让noiser-noisy数据 更接近 noisy-clean 数据

1.训练F0,生成新的数据集
在这里插入图片描述

2.利用新的数据集训练F1.
由于 新的数据集 更接近 noisy-clean 数据,因此训练的结果对于noisy的表现会更好。
在这里插入图片描述

3.因此可以迭代训练,不断生成新的less biased数据集, 训练新的model
在这里插入图片描述

4.Fast Iterative Data Refinement

以上迭代训练需要生成多次数据集,训练多次model.

作者提出改进的方案:
a.每个epoch refine一次dataset, 不需要训练到完全收敛
b.利用上个epoch的model初始化下一个epoch的model
*
这样改进下来,和正常训练差别不大了,除了每个epoch要更新一次数据集。

实际的效果如下:
每次迭代,降噪效果都有改善。

请添加图片描述

5. sensenoise-500 dataset

IMX586, 3000x4000 pixels, low light conditions.

64 帧 = 4 帧 正常曝光noisy image + 60 帧 长曝光(1s-2s) use median value ad ground truth

正常曝光和长曝光的图像如何 保持亮度一致呢?需要设置 iso 和曝光时间:
在这里插入图片描述

图像示例:
在这里插入图片描述

最终图像数据是1010张(505pairs):
在这里插入图片描述

dng是噪声图, npy是groundtruth

5.1 数据集 error

部分ground truth 高亮区域偏红色。

6.训练

4种方案训练sensenoise 500

  1. pair
  2. add GP noise
  3. idr(本文)
  4. noise2noise: add gp noise

由于不知道数据集的实际噪声参数。因此add noise都是添加的一定范围

k = np.random.uniform(0.8, 3)
scale = np.random.uniform(1, 30)
# k = torch.FloatTensor(k)
# scale = torch.FloatTensor(scale)
in_img1 = add_noise_torch(gt_img, k, scale).to(device)
in_img2 = add_noise_torch(gt_img, k, scale).to(device) # 是否需要转化为int16类型,因为实际raw图数据都是整数
gt_img = gt_img.to(device)
# print(in_img.min(), in_img.max(), in_img.mean(), in_img.var())
# print(gt_img.min(), gt_img.max(), gt_img.mean(), gt_img.var()) # more and more little
gt_img = gt_img / 1023
in_img = in_img1 / 1023
in_img = torch.clamp(in_img, 0, 1)
in_img2 = in_img2 / 1023
in_img2 = torch.clamp(in_img2, 0, 1)

idr训练:

import glob
import os.pathimport cv2
import numpy as np
import rawpy
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdmfrom skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
from model import UNetSeeInDark
from sensenoise500 import add_noise_torch
from sid_dataset_sensenoise500 import sensenoise_dataset, apply_wb_ccm, sensenoise_dataset_2, \sensenoise_dataset_addnoise, sensenoise_dataset_addnoise_2, choose_k_sigma
import torchvisionif __name__ == "__main__":# 1.当前版本信息print(torch.__version__)print(torch.version.cuda)print(torch.backends.cudnn.version())print(torch.cuda.get_device_name(0))np.random.seed(0)torch.manual_seed(0)torch.cuda.manual_seed_all(0)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = False# 2. 设置device信息 和 创建modelmodel = UNetSeeInDark()model._initialize_weights()gpus = [1]#model = nn.DataParallel(model, device_ids=gpus)device = torch.device('cuda:1')model = model.cuda(device=gpus[0])# 6. 是否恢复模型resume = 0last_epoch = 0lr_epoch = 1if resume and last_epoch > 1:model.load_state_dict(torch.load(os.path.join(save_model_dir, f'checkpoint_{last_epoch:04d}.pth'), map_location=device))lr_epoch = 0.5**(last_epoch // 500)# 3. dataset 和 data loader, num_workers设置线程数目,pin_memory设置固定内存# train_dataset = sensenoise_dataset_addnoise_2(mode='train')# train_dataset_loader = DataLoader(train_dataset, batch_size=4*len(gpus), shuffle=True, num_workers=8, pin_memory=True)eval_dataset = sensenoise_dataset_2(mode='eval')eval_dataset_loader = DataLoader(eval_dataset, batch_size=1, num_workers=8, pin_memory=True)print('load dataset !')files = glob.glob(os.path.join('/home/wangzhansheng/dataset/sidd/SenseNoise500/final_datasetv3/', '*.dng'))files = sorted(files)[:400]datas = []for file in files:input_path = filetxt_path = input_path[:-4] + '.txt'para = np.loadtxt(txt_path)wb_gain = np.array(para[:3]).astype(np.float32)ccm = np.array(para[3:12]).astype(np.float32).reshape(3, 3)iso = para[-1]# gt_raw = np.load(gt_path).astype(np.int32)# gt_raw = np.dstack((gt_raw[0::2, 0::2], gt_raw[0::2, 1::2], gt_raw[1::2, 0::2], gt_raw[1::2, 1::2]))input_raw = rawpy.imread(input_path).raw_image_visible.astype(np.float32)input_raw = np.dstack((input_raw[0::2, 0::2], input_raw[0::2, 1::2], input_raw[1::2, 0::2], input_raw[1::2, 1::2]))datas.append([input_raw, wb_gain, ccm, input_path, iso])print(file, len(datas))# 4. 损失函数 和  优化器loss_fn = nn.L1Loss()learning_rate = 3*1e-4optimizer = optim.Adam(model.parameters(), lr=learning_rate)lr_step = 500scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_step, gamma=0.5)# 5. hyper para 设置epochs = 5000save_epoch = 100save_model_dir = 'saved_model_sensenoise500_addnoise_single_idr'eval_epoch = 100save_sample_dir = 'saved_sample_sensenoise500_addnoise_single_idr'if not os.path.exists(save_model_dir):os.makedirs(save_model_dir)# 7. 训练epochepoch_infos = []eval_infos = []patch_size = 512for epoch in range(last_epoch+1, epochs + 1):print('current epoch:', epoch, 'current lr:', optimizer.state_dict()['param_groups'][0]['lr'])if epoch < 101:save_epoch = 10eval_epoch = 10else:save_epoch = 100eval_epoch = 100# 8. train loopmodel_copy = UNetSeeInDark().to(device)model_copy.load_state_dict(model.state_dict())model_copy.eval()model.train()g_loss = []g_psnr = []kk = 0for idx in tqdm(np.random.permutation(len(datas))):data = datas[idx]#for data in np.random.shuffle(datas):# gt_path = file# txt_path = gt_path[:-4] + '.txt'# para = np.loadtxt(txt_path)# wb_gain = np.array(para[:3]).astype(np.float32)# ccm = np.array(para[3:12]).astype(np.float32).reshape(3, 3)# iso = para[-1]## gt_raw = np.load(gt_path).astype(np.int32)# iso, k, sigma = choose_k_sigma(iso/2)# k = k * np.random.uniform(0.8, 1.2)# sigma2 = np.sqrt(sigma) * np.random.uniform(0.8, 1.1)# short_raw = k * np.random.poisson(gt_raw / k) + np.random.normal(0., sigma2, gt_raw.shape)# gt_raw = gt_raw / 1023# short_raw = short_raw / 1023input_raw, wb_gain, ccm, gt_path, iso = data# croph, w, c = input_raw.shapeh1 = np.random.randint(0, h - patch_size)w1 = np.random.randint(0, w - patch_size)# short_raw = short_raw[h1:h1 + patch_size, w1:w1 + patch_size, :]short_raw = input_raw[h1:h1 + patch_size, w1:w1 + patch_size, :]# augmentif np.random.randint(2, size=1)[0] == 1:  # random flipshort_raw = np.flip(short_raw, axis=0)#gt_raw = np.flip(gt_raw, axis=0)if np.random.randint(2, size=1)[0] == 1:short_raw = np.flip(short_raw, axis=1)#gt_raw = np.flip(gt_raw, axis=1)if np.random.randint(2, size=1)[0] == 1:  # random transposeshort_raw = np.transpose(short_raw, (1, 0, 2))#gt_raw = np.transpose(gt_raw, (1, 0, 2))#in_img = torch.permute(input_patch, (0,3,1,2)).cuda(device=gpus[0])short_raw = np.ascontiguousarray(short_raw[np.newaxis, ...])gt_img = torch.from_numpy(short_raw).permute(0, 3, 1, 2)if epoch > last_epoch + 1:model_copy.eval()with torch.no_grad():gt_img_last = gt_img.to(device) / 1023gt_img = model_copy(gt_img_last).cpu()gt_img = torch.clamp(gt_img* 1023, 0, 1023)# print(gt_img_last.min(), gt_img_last.max(), gt_img_last.mean(), gt_img_last.var())# print(gt_img.min(), gt_img.max(), gt_img.mean(), gt_img.var()) # more and more littleif kk  > 50000:im1 = gt_img_last.cpu().float().numpy().squeeze().transpose(1, 2, 0)im2 = gt_img.float().numpy().squeeze().transpose(1, 2, 0) / 1023im1 = im1[..., [0, 1, 3]] ** (1 / 2.2)im2 = im2[..., [0, 1, 3]] ** (1 / 2.2)im1 = np.clip(im1 * 255 + 0.5, 0, 255).astype(np.uint8)im2 = np.clip(im2 * 255 + 0.5, 0, 255).astype(np.uint8)save_sample_dir3 = save_sample_dir + f'/{epoch:04}dd/'if not os.path.isdir(save_sample_dir3):os.makedirs(save_sample_dir3)filename_save = os.path.basename(gt_path)[:-4]cv2.imwrite(os.path.join(save_sample_dir3, '%s_dddd1.png' % (filename_save)), im1[..., ::-1])cv2.imwrite(os.path.join(save_sample_dir3, '%s_dddd2.png' % (filename_save)),im2[..., ::-1])iso, k, sigma = choose_k_sigma(iso/2)# k = np.random.uniform(0.8, 3)# scale = np.random.uniform(1, 30)# k = torch.FloatTensor(k)# scale = torch.FloatTensor(scale)scale = np.sqrt(sigma)in_img = add_noise_torch(gt_img, k, scale).to(device)gt_img = gt_img.to(device)# print(in_img.min(), in_img.max(), in_img.mean(), in_img.var())# print(gt_img.min(), gt_img.max(), gt_img.mean(), gt_img.var()) # more and more littlegt_img = gt_img / 1023in_img = in_img / 1023in_img = torch.clamp(in_img, 0, 1)# print(gt_img.shape, gt_img.min(), gt_img.max())# print(in_img.shape, in_img.min(), in_img.max())# print(wb_gain, ccm, iso, gt_path)out = model(in_img)loss = loss_fn(out, gt_img)optimizer.zero_grad()loss.backward()optimizer.step()# training resultg_loss.append(loss.data.detach().cpu())mse_value = np.mean((out.cpu().data.numpy() - gt_img.cpu().data.numpy()) ** 2)psnr = 10. * np.log10(1. / mse_value)g_psnr.append(psnr)mean_loss = np.mean(np.array(g_loss))mean_psnr = np.mean(np.array(g_psnr))print(f'epoch{epoch:04d} ,train loss: {mean_loss},train psnr: {mean_psnr}')epoch_infos.append([epoch, mean_loss, mean_psnr])# 9. save modelif epoch % save_epoch == 0:save_model_path = os.path.join(save_model_dir, f'checkpoint_{epoch:04d}.pth')torch.save(model.state_dict(), save_model_path)# 10. eval test and save some samples if neededif epoch % eval_epoch == 0:model.eval()k = 0with torch.no_grad():psnr_12800_0 = []psnr_12800_1 = []ssim_12800_0 = []ssim_12800_1 = []for data in tqdm(eval_dataset_loader):input_patch, gt_patch, wb_gain, ccm, gt_path, iso = datain_img = input_patch.permute(0, 3, 1, 2).cuda(device=gpus[0])gt_img = gt_patch.permute(0, 3, 1, 2).cuda(device=gpus[0])out = model(in_img)im1 = gt_img.detach().cpu().float().numpy().squeeze().transpose(1,2,0)im2 = out.detach().cpu().float().numpy().squeeze().transpose(1,2,0)im1 = np.clip(im1 * 255 + 0.5, 0, 255).astype(np.uint8)im2 = np.clip(im2 * 255 + 0.5, 0, 255).astype(np.uint8)temp_psnr = compare_psnr(im1, im2, data_range=255)temp_ssim = compare_ssim(im1, im2, data_range=255, channel_axis=-1)if iso <= 12800:psnr_12800_0.append(temp_psnr)ssim_12800_0.append(temp_ssim)else:psnr_12800_1.append(temp_psnr)ssim_12800_1.append(temp_ssim)# show training outsave_img = 1if save_img and k<10:k += 1im_input = in_img.detach().permute(0, 2, 3, 1).cpu().float().numpy()[0]im_gt = gt_img.detach().permute(0, 2, 3, 1).cpu().float().numpy()[0]im_out = out.detach().permute(0, 2, 3, 1).cpu().float().numpy()[0]wb_gain = wb_gain.data.cpu().numpy()[0]ccm = ccm.data.cpu().numpy()[0]gt_path = gt_path[0]pattern_sensenoise500 = 'RGGB'im_input_srgb = apply_wb_ccm(im_input[..., [0, 1, 3]], wb_gain, ccm, pattern_sensenoise500)im_gt_srgb = apply_wb_ccm(im_gt[..., [0, 1, 3]],  wb_gain, ccm, pattern_sensenoise500)im_out_srgb = apply_wb_ccm(im_out[..., [0, 1, 3]],  wb_gain, ccm, pattern_sensenoise500)im_input_srgb = np.clip(im_input_srgb * 255 + 0.5, 0, 255).astype(np.uint8)im_gt_srgb = np.clip(im_gt_srgb * 255 + 0.5, 0, 255).astype(np.uint8)im_out_srgb = np.clip(im_out_srgb * 255 + 0.5, 0, 255).astype(np.uint8)save_sample_dir2 = save_sample_dir + f'/{epoch:04}/'if not os.path.isdir(save_sample_dir2):os.makedirs(save_sample_dir2)# save_sample_path = os.path.join(save_sample_dir2, os.path.basename(gt_path)[:-4]+'.png')# cv2.imwrite(save_sample_path, np.hstack((im_gt_srgb,im_input_srgb, im_out_srgb))[..., ::-1])filename_save = os.path.basename(gt_path)[:-4]cv2.imwrite(os.path.join(save_sample_dir2, '%s_psnr_%.2f_out.png' % (filename_save, temp_psnr)), im_out_srgb[...,::-1])cv2.imwrite(os.path.join(save_sample_dir2, '%s_NOISY.png' % (filename_save)), im_input_srgb[...,::-1])cv2.imwrite(os.path.join(save_sample_dir2, '%s_GT.png' % (filename_save)), im_gt_srgb[...,::-1])print('eval dataset  psnr: ', np.array(psnr_12800_0).mean(), np.array(psnr_12800_1).mean())print('eval dataset  ssim: ', np.array(ssim_12800_0).mean(), np.array(ssim_12800_1).mean())eval_infos.append([epoch, np.array(psnr_12800_0).mean(), np.array(psnr_12800_1).mean(), np.array(ssim_12800_0).mean(), np.array(ssim_12800_1).mean()])scheduler.step() # 更新学习率np.savetxt('train_infos.txt',  epoch_infos, fmt='%.4f') # epoch loss psnrnp.savetxt('eval_infos.txt', eval_infos, fmt='%.4f')    # epoch psnr, ssim

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/27487.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

YOLOv7 yaml 文件简化

文章目录 修改方式common.pyyolo.pyYOLOv7-ELAN.yaml原始的 YOLOv7 yaml 文件的模块是拆开写的,比较乱, 改进起来也不太容易,这篇博文将 YOLOv7 yaml 文件换了一种写法, 参数量和计算量是完全和原来一致的,区别只是在于 yaml文件的写法不同, 封装后具体的结构可以参考…

希尔排序

希尔排序 排序步骤 1、分组&#xff0c;以任意长度进行分组&#xff08;这个长度我们称作增量gap&#xff09;&#xff1b;通常以总长度的一半这个数为依据进行分组&#xff0c;每间隔 gap 个数即为一组 2、组内排序&#xff1b;组内使用插入排序法进行排序 3、重新设置间隔…

微服务Gateway网关(自动定位/自定义过滤器/解决跨域)+nginx反向代理gateway集群

目录 Gateway网关 1.0.为什么需要网关&#xff1f; 1.1.如何使用gateway网关 1.2.网关从注册中心拉取服务 1.3.gateway自动定位 1.4.gateway常见的断言 1.5.gateway内置的过滤器 1.6.自定义过滤器-全局过滤器 1.7.解决跨域问题 2.nginx反向代理gateway集群 2.1.配置…

Upsource的下载安装使用

一&#xff0c;下载 下载地址&#xff1a; https://www.jetbrains.com/upsource/下载并解压到指定的文件夹 ├── api ├── apps ├── backups # 备份目录 ├── bin # 应用目录 ├── conf # 配置文件 ├── data ├── internal ├── launcher ├── lib ├─…

Java集合详解

1. 集合基础 1.1 集合概述 1.2 ArrayList构造方法和添加方法 1.3 ArrayList集合常用方法 1. 集合基础 1.1 集合概述 集合类的特点:提供一种存储空间可变的存储横型&#xff0c;存储的数据容量可以发生改变 ArrayList ArrayList< >: 可调整大小的数组实现 < >:是…

TextView 必填项pro版

优点 基本解决对齐方式,可以设置前缀隐藏和显示 /*** https://blog.csdn.net/u013982652/article/details/94404711* Android自定义TextView实现必填项前面的*号* 另一种实现方式(推荐使用这种,有非必填情况的话不会有对齐问题)* <p>* <cn.mvp.mlibs.weight.MiRequire…

【Fiddler】Fiddler实现mock测试(模拟接口数据)

软件接口测试过程中&#xff0c;经常会遇后端接口还没有开发完成&#xff0c;领导就让先介入测试&#xff0c;然后缩短项目时间&#xff0c;有的人肯定会懵&#xff0c;接口还没开发好&#xff0c;怎么介入测试&#xff0c;其实这就涉及到了我们要说的mock了。 一、mock原理 m…

CentOS 安装字体 微软雅黑

fc-list命令查看已经安装的字体 fc-list :langzh命令可以查看已安装的中文字体 找到windows系统里面的字体 上传到服务器 /usr/share/fonts/winFonts 下&#xff0c;winFonts目录是自己建立的&#xff0c;名称无要求 如果C:\Windows\Fonts下的字体没法直接传输将这个文件夹复…

Leetcode-每日一题【24.两两交换链表中的节点】

题目 给你一个链表&#xff0c;两两交换其中相邻的节点&#xff0c;并返回交换后链表的头节点。你必须在不修改节点内部的值的情况下完成本题&#xff08;即&#xff0c;只能进行节点交换&#xff09;。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4]输出&#xff1a;[…

rust 引用怎么用

本来好好的引用&#xff0c;被 rust 玩坏了&#xff0c;搞得自己都不会使用引用了&#xff0c;我们还是从简单的例子入手&#xff0c;来探索使用引用可能遇到额问题。 下面的示例代码编译不通过&#xff0c;在 s1 赋值给变量 s2 的过程中&#xff0c;字符串 neojos 值的所有权…

dede tag彩色随机大小的样式修改方法

dede tag彩色随机大小的样式修改方法&#xff0c;打开include/common.func.php 在最下面添加以下代码&#xff1a; //TAG彩色 jinmengqiang.cn function getTagStyle() { $minFontSize8; //最小字体大小,可根据需要自行更改 $maxFontSize18; //最大字体大小,可根据需要自行更改…

Unity 上传文件到阿里云 对象存储OSS服务器

首先登录阿里云 免费试用–对象存储OSS --点击立即试用&#xff0c;可以有三个月的免费试用 创建Buket 新建AccessKey ,新建完成后&#xff0c;会有一个CSV文件&#xff0c;下载下来&#xff0c;里面有Key &#xff0c;代码中需要用到 下载SDK 双击打开 sln文件&#xff0…