超分辨率(2)--基于EDSR网络实现图像超分辨率重建

目录

一.项目介绍

二.项目流程详解

2.1.构建网络模型

2.2.数据集处理

2.3.训练模块

2.4.测试模块

三.测试网络


一.项目介绍

EDSR全称Enhanced Deep Residual Networks,是SRResnet的升级版,其对网络结构进行了优化(去除了BN层),省下来的空间可以用于提升模型的size来增强表现力。

为什么要去除BN层:

Batch Norm是深度学习中非常重要的技术,不仅可以使训练更深的网络变容易,加速收敛,还有一定正则化的效果,可以防止模型过拟合。

但对于图像超分辨率来说,网络输出的图像在色彩、对比度、亮度上要求和输入一致,改变的仅仅是分辨率和一些细节,而Batch Norm,对图像来说类似于一种对比度的拉伸,任何图像经过Batch Norm后,其色彩的分布都会被归一化,也就是说,它破坏了图像原本的对比度信息,所以Batch Norm的加入反而影响了网络输出的质量。

网络结构及对比:

移除BN层后,模型更加轻量,BN层所消耗的存储空间等同于上一层CNN层所消耗的,作者指出相比于SRResNet,EDSR去掉BN层之后节约了40%的存储资源。

同时在BN腾出来的空间下插入更多的类似于残差块等CNN-based子网络来增加模型的表现力。

论文地址:

[1707.02921] Enhanced Deep Residual Networks for Single Image Super-Resolution (arxiv.org)icon-default.png?t=N7T8https://arxiv.org/abs/1707.02921源码地址:

developer0hye/EDAR: PyTorch implementation of Deep Convolution Networks based on EDSR for Compression(Jpeg) Artifacts Reduction (github.com)icon-default.png?t=N7T8https://github.com/developer0hye/EDAR

二.项目流程详解

2.1.构建网络模型

def default_conv(in_channels, out_channels, kernel_size, bias=True):return nn.Conv2d(in_channels, out_channels, kernel_size,padding=(kernel_size//2), bias=bias)class MeanShift(nn.Conv2d):def __init__(self, rgb_mean, rgb_std, sign=-1):super(MeanShift, self).__init__(3, 3, kernel_size=1)std = torch.Tensor(rgb_std)self.weight.data = torch.eye(3).view(3, 3, 1, 1)self.weight.data.div_(std.view(3, 1, 1, 1))self.bias.data = sign * torch.Tensor(rgb_mean)self.bias.data.div_(std)self.requires_grad = Falseclass ResBlock(nn.Module):def __init__(self, conv, n_feat, kernel_size,bias=True, act=nn.ReLU(True)):super(ResBlock, self).__init__()m = []for i in range(2):m.append(conv(n_feat, n_feat, kernel_size, bias=bias))if i == 0: m.append(act)# m是设置好的conv层# 设置网络内部层次结构为bodyself.body = nn.Sequential(*m)def forward(self, x):# 获取当前的结果res = self.body(x)# 当前得到的网络和最初的网络融合res += xreturn res

class EDAR(nn.Module):def __init__(self, conv=common.default_conv):super(EDAR, self).__init__()# 参数设置n_resblock = 8  # resnet长度n_feats = 64kernel_size = 3  # 卷积核大小#DIV 2K meanrgb_mean = (0.4488, 0.4371, 0.4040)rgb_std = (1.0, 1.0, 1.0)self.sub_mean = common.MeanShift(rgb_mean, rgb_std)# define head module# 经过卷积,特征图数由3->n_featsm_head = [conv(3, n_feats, kernel_size)]# define body module# Residual Block设置m_body = [common.ResBlock(conv, n_feats, kernel_size) for _ in range(n_resblock)]m_body.append(conv(n_feats, n_feats, kernel_size))# define tail module# 经过卷积,特征图数由n_feats->3m_tail = [conv(n_feats, 3, kernel_size)]self.add_mean = common.MeanShift(rgb_mean, rgb_std, 1)# 设置网络的三个层次self.head = nn.Sequential(*m_head)self.body = nn.Sequential(*m_body)self.tail = nn.Sequential(*m_tail)

前向传播过程:

    def forward(self, x):x = self.sub_mean(x)x = self.head(x)res = self.body(x)res += xx = self.tail(res)x = self.add_mean(x)# 将输入input张量每个元素的范围限制到区间 [min,max],返回结果到一个新张量。# 及输出一个新张量值x,并限制他的值在0~1之间return torch.clamp(x,0.0,1.0)

2.2.数据集处理

import os
import io
import random
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = Trueclass Dataset(object):def __init__(self, images_dir, patch_size=48, jpeg_quality=40, transforms=None):self.images = os.walk(images_dir).__next__()[2]self.images_path = []for img_file in self.images:if img_file.endswith((".ppm")):try:#print(os.path.join(images_dir, img_file))label = Image.open(os.path.join(images_dir, img_file))self.images_path.append(os.path.join(images_dir, img_file))except:print(f"Image {os.path.join(images_dir, img_file)} didn't get loaded")self.patch_size = patch_sizeself.jpeg_quality = jpeg_qualityself.transforms = transformsself.random_rotate = [0, 90, 180, 270]def __getitem__(self, idx):label = Image.open(self.images_path[idx]).convert('RGB')label = label.rotate(self.random_rotate[random.randrange(0,4)])# randomly crop patch from training setcrop_x = random.randint(0, label.width - self.patch_size)crop_y = random.randint(0, label.height - self.patch_size)# 使用crop函数对图片进行裁剪label = label.crop((crop_x, crop_y, crop_x + self.patch_size, crop_y + self.patch_size))# additive jpeg noisebuffer = io.BytesIO()label.save(buffer, format='jpeg', quality=random.randrange(self.jpeg_quality+1))input = Image.open(buffer).convert('RGB')if self.transforms is not None:input = self.transforms(input)label = self.transforms(label)#print("Image transformed")return input, labeldef __len__(self):return len(self.images_path)

2.3.训练模块

import argparse
import osfrom dataset import Dataset
from edar import EDARimport torch
from torch import nn
from torch.utils.data.dataloader import DataLoader
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torchvision import transforms
from torchvision.models.vgg import vgg16from utils import AverageMeter
from tqdm import tqdmif __name__ == '__main__':'''It enables benchmark mode in cudnn.benchmark mode is good whenever your input sizes for your network do not vary. This way, cudnn will look for the optimal set of algorithms for that particular configuration (which takes some time). This usually leads to faster runtime.But if your input sizes changes at each iteration, then cudnn will benchmark every time a new size appears, possibly leading to worse runtime performances.'''cudnn.benchmark = Truedevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 参数设置parser = argparse.ArgumentParser()# required为true的参数则是必须要设置的参数parser.add_argument('--images_dir', type=str, required=True)parser.add_argument('--outputs_dir', type=str, required=True)parser.add_argument('--jpeg_quality', type=int, default=40)parser.add_argument('--patch_size', type=int, default=48)parser.add_argument('--batch_size', type=int, default=16)parser.add_argument('--num_epochs', type=int, default=400)parser.add_argument('--lr', type=float, default=1e-4)parser.add_argument('--threads', type=int, default=1)parser.add_argument('--seed', type=int, default=123)parser.add_argument('--resume', default='', type=str, metavar='PATH',help='path to latest checkpoint (default: none)')opt = parser.parse_args()# 如果输出文件夹不存在,则自动创建一个文件夹if not os.path.exists(opt.outputs_dir):os.makedirs(opt.outputs_dir)torch.manual_seed(opt.seed)transforms_train = transforms.Compose([transforms.ToTensor()])# 模型设置model = EDAR().to(device)print("Model loaded")if opt.resume:if os.path.isfile(opt.resume):state_dict = model.state_dict()for n, p in torch.load(opt.resume, map_location=lambda storage, loc: storage).items():if n in state_dict.keys():state_dict[n].copy_(p)else:raise KeyError(n)# 损失函数设置criterion = nn.L1Loss()# 优化器设置optimizer = optim.Adam(model.parameters(), lr=opt.lr)print("Data processing started")# 数据集设置dataset = Dataset(opt.images_dir, opt.patch_size, opt.jpeg_quality,transforms=transforms_train)dataloader = DataLoader(dataset=dataset,batch_size=opt.batch_size,shuffle=True,num_workers=opt.threads,pin_memory=True,drop_last=True)print("Data loading completed")#vgg = vgg16(pretrained=True).cuda()#loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
#     for param in loss_network.parameters():
#         param.requires_grad = False# 开始训练for epoch in range(opt.num_epochs):epoch_losses = AverageMeter()print("Length of the dataset is", len(dataset))with tqdm(total=(len(dataset) - len(dataset) % opt.batch_size)) as _tqdm:_tqdm.set_description('epoch: {}/{}'.format(epoch + 1, opt.num_epochs))# 按照dataloader的格式取出datafor data in dataloader:inputs, labels = datainputs = inputs.to(device)labels = labels.to(device)#print(inputs.size(), labels.size())outs = model(inputs)# 损失值计算,参数是预测值和实际值loss = criterion(outs, labels)#perception_loss = criterion(loss_network(outs), loss_network(labels))#loss = loss + perception_loss*0.06epoch_losses.update(loss.item(), len(inputs))# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 更新参数optimizer.step()_tqdm.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))_tqdm.update(len(inputs))torch.save(model.state_dict(), os.path.join(opt.outputs_dir, '{}_epoch_{}.pth'.format("EDAR_", epoch)))

2.4.测试模块

import argparse
import os
import io
import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms
import PIL.Image as pil_image
import globfrom edar import EDARcudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")if __name__ == '__main__':# 参数设置parser = argparse.ArgumentParser()parser.add_argument('--weights_path', type=str, required=True)parser.add_argument('--image_path', type=str, required=True)parser.add_argument('--outputs_dir', type=str, required=True)parser.add_argument('--jpeg_quality', type=int, default=40)parser.add_argument('--input_dir', type=str, required=False)opt, unknown = parser.parse_known_args()model = EDAR()state_dict = model.state_dict()# 参数获取for n, p in torch.load(opt.weights_path, map_location=lambda storage, loc: storage).items():if n in state_dict.keys():state_dict[n].copy_(p)else:raise KeyError(n)model = model.to(device)print(device)model.eval()if opt.input_dir:filenames = [os.path.join(opt.input_dir, file) for file in os.listdir(opt.input_dir) if file.endswith(("ppm", "jpeg", "png", "jpg"))]print(filenames)else:filenames = opt.image_pathif not os.path.exists(opt.outputs_dir):os.makedirs(opt.outputs_dir)# 处理单个测试图片时使用:filename = filenamesprint("file is", filename)input = pil_image.open(filename).convert('RGB')print("Input size:", input.size)print("file is", filename)input = pil_image.open(filename).convert('RGB')print("Input size:", input.size)#buffer = io.BytesIO()#input.save(buffer, format='jpeg', quality=opt.jpeg_quality)#input = pil_image.open(buffer)#input.save(os.path.join(opt.outputs_dir, '{}_jpeg_q{}.png'.format(filename, opt.jpeg_quality)))input = transforms.ToTensor()(input).unsqueeze(0).to(device)output_path = os.path.join(opt.outputs_dir, '{}-{}.jpeg'.format(filename.split("/")[-1].split(".")[0], "edar"))if not os.path.exists(output_path):with torch.no_grad():pred = model(input)[-1]pred = pred.mul_(255.0).clamp_(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()output = pil_image.fromarray(pred, mode='RGB')print("Output size", output.size)print("Output dir is", opt.outputs_dir)output.save(output_path)#print(os.path.join(opt.outputs_dir, '{}_{}.png'.format(filename, "EDAR")))#print("Output saved")'''处理多个测试图片时使用:for filename in filenames:print("file is", filename)input = pil_image.open(filename).convert('RGB')print("Input size:", input.size)# buffer = io.BytesIO()# input.save(buffer, format='jpeg', quality=opt.jpeg_quality)# input = pil_image.open(buffer)# input.save(os.path.join(opt.outputs_dir, '{}_jpeg_q{}.png'.format(filename, opt.jpeg_quality)))input = transforms.ToTensor()(input).unsqueeze(0).to(device)output_path = os.path.join(opt.outputs_dir, '{}-{}.jpeg'.format(filename.split("/")[-1].split(".")[0], "edar"))if not os.path.exists(output_path):with torch.no_grad():pred = model(input)[-1]pred = pred.mul_(255.0).clamp_(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()output = pil_image.fromarray(pred, mode='RGB')print("Output size", output.size)print("Output dir is", opt.outputs_dir)output.save(output_path)# print(os.path.join(opt.outputs_dir, '{}_{}.png'.format(filename, "EDAR")))# print("Output saved")'''

三.测试网络

参数设置:

输入图片:

输出图片:

输入图片:

输出图片:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/536155.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解决ipconfig不是内部或外部命令,也不是可运行的程序或批处理文件

问题所示:ipconfig不是内部或外部命令,也不是可运行的程序或批处理文件。 解决办法如下: 1.右击此电脑,点击属性设置: 2.点击高级系统设置 3.点击进入环境变量 4.在系统变量中进行设置,双击PATH进行配置 5.点击新建&am…

python-0007-django模版

介绍 模版是对js,html等资源的封装 新建 在项目路径下新建模版文件夹templates(可以为其他名称),要是想细分业务的话,还可以在templates路径下继续建文件夹。如下图: 注册模版 在项目的settings找到T…

基于springboot实现线上阅读系统项目【项目源码+论文说明】

基于springboot实现线上阅读系统演示 摘要 随着社会发展速度的愈来愈快,以及社会压力变化的越来越快速,致使很多人采取各种不同的方法进行解压。大多数人的稀释压力的方法,是捧一本书籍,心情地让自己沉浸在情节里面,以…

DARTS: DIFFERENTIABLE ARCHITECTURE SEARCH

DARTS:可微架构搜索 论文链接:https://arxiv.org/abs/1806.09055 项目链接:https://github.com/quark0/darts ABSTRACT 本文通过以可微分的方式表述任务,解决了架构搜索的可扩展性挑战。与在离散和不可微搜索空间上应用进化或强…

如何在WordPress网站上设置多语言展示

在今天的全球化世界中,拥有多语言网站对于吸引更广泛的受众至关重要。前不就我们遇到Hostease的客户咨询我们的在线客服,他想要对他的wordpress网站支持多语言。我们提供给客户可以尝试以下的插件来支持多语言。 在本教程中,我们将逐步介绍如…

HarmonyOS 发送系统通知 基础通知

harmonyos中 提供了 各种不同功能的通知 来满足我们不同的业务需求 本文呢 我们先来看 最基本的通知 它的场景和使用方式 最基本的 是 当应用被切入后台 而应用内部发生了一些变化 就可以用通知的方式 来提醒用户 就比如 我们微信 被切到了后台 而这时 别人发消息过来了 那么…

AI人工智能改变项目管理及控制的5种方式,最后一个你认同吗?

前言 斯蒂芬霍金说:“AI is likely to be either the best or the worst thing to happen to humanity. 人工智能可能是发生在人类身上的最好或最坏的事情”。根据Gartner的说法,它将是前者。它预测到2020年,人工智能将减少1万个就业机会&am…

指纹芯片系列——ACM32FP0 二合一(主控+TK)锁控芯片,ACM32FP4三合一(主控+算法+语音)等介绍

随着智能设备的持续发展,指纹识别技术成为了现在智能终端市场和移动支付市场中占有率最高的生物识别技术。凭借高识别率、短耗时等优势,被广泛地运用在智能门锁、智能手机、智能家居等设备上。 我们推荐的在2015年进入指纹识别应用领域,自研高…

使用ChatGPT高效完成简历制作[中篇]-有爱AI实战教程(五)

演示站点: https://ai.uaai.cn 对话模块 官方论坛: www.jingyuai.com 京娱AI 导读:在使用 ChatGPT 时,当你给的指令越精确,它的回答会越到位,举例来说,假如你要请它帮忙写文案,如果没…

基本数据类型介绍

整数类型:byte、short、int、long Java各整数类型有固定的表数范围和字段长度,不受具体操作系统的影响,以保证Java程序的可移植性。 定义long类型的变量,赋值时需要以"l"或"L"作为后缀。 Java程序中变量通常…

【教学类-34-10】20240313 春天拼图(Midjounery生成线描图,4*4格拼图块)(AI对话大师)

作品展示: 背景需求: 利用华文彩云空心字(粗胖字体。凑满9个拼图)制作了3*3的拼图块 【教学类-34-09】20240310华文彩云学号拼图(3*3格子浅灰底图 深灰拼图块)(AI对话大师)-CSDN博…

Apache Paimon系列之:主键表

Apache Paimon系列之:主键表 一、主键表1.Bucket2.LSM Trees3.Compaction 二、数据分布1.固定Bucket2.动态Bucket3.正常动态Bucket模式4.跨分区更新插入动态存储桶模式 三、Merge Engine1.Deduplicate2.部分更新3.序列组4.聚合部分更新5.聚合6.Retract7.First Row 四…