Face Image Super-Resolution with SRGAN

Introduction

SRGAN was the first model to apply a GAN to image super-resolution. Before it, super-resolution was usually trained with pixel-wise losses such as L1 or L2, which push the model toward averaged results, in effect adding "blurry details" to the low-resolution image. SRGAN introduces a GAN to address this: a GAN can generate "realistic" images, so when the "real images" are sharp, the GAN can generate sharp images. With only the adversarial loss and no other constraint, however, the output is not tied to the given low-resolution input, so the pixel loss and the adversarial loss are combined. On top of that, SRGAN adds a perceptual loss, which measures the error between images in feature space.
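Put schematically, the generator objective looks roughly like this (a sketch of mine, not the post's trainer, which appears in full in the Training section; the λ weights and the 1e-8 stabilizer mirror the values used there):

import paddle as P

def srgan_generator_loss(sr, hr, vgg, disc, lambda1=1e-2, lambda2=1e-2):
    # pixel loss: keeps SR faithful to the ground-truth HR
    loss_pixel = P.mean((sr - hr) ** 2)
    # perceptual loss: compares the two images in VGG feature space
    loss_perceptual = P.mean((vgg(sr) - vgg(hr)) ** 2)
    # adversarial loss: rewards SR images the discriminator rates as real
    loss_adversarial = P.mean(-P.log(disc(sr) + 1e-8))
    return loss_pixel + lambda1 * loss_perceptual + lambda2 * loss_adversarial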

Setup

import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
import paddle as P
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout, AdaptiveAvgPool2D, MaxPool2D, AvgPool2D

# Initialize all weights as Normal(0, 0.01) and all biases as 0
nn.initializer.set_global_initializer(nn.initializer.Normal(mean=0.0, std=0.01), nn.initializer.Constant())

Loading the Data

We use the CelebA dataset for face image super-resolution.
To avoid running out of memory, the patch size is 44×44 (the largest that CelebA images allow anyway), rather than the 96×96 used in the paper.

SCALE = 4
PATH = '/path/to/data/celeba/img_align_celeba/'
DIRS = os.listdir(PATH)
PATCH_SIZE = [44, 44, 3]

def reader_patch(batchsize, scale=SCALE, patchsize=PATCH_SIZE):
    np.random.shuffle(DIRS)
    for filename in DIRS:
        LRs = np.zeros((batchsize, patchsize[2], patchsize[0], patchsize[1])).astype("float32")
        HRs = np.zeros((batchsize, patchsize[2], patchsize[0]*scale, patchsize[1]*scale)).astype("float32")
        image = Image.open(PATH + filename)
        sz = image.size
        # Largest region whose sides are multiples of patchsize*scale, cropped at a random offset
        sz_row = sz[1] // (patchsize[0]*scale) * patchsize[0] * scale
        diff_row = sz[1] - sz_row
        sz_col = sz[0] // (patchsize[1]*scale) * patchsize[1] * scale
        diff_col = sz[0] - sz_col
        row_min = np.random.randint(diff_row + 1)
        col_min = np.random.randint(diff_col + 1)
        HR = image.crop((col_min, row_min, col_min + sz_col, row_min + sz_row))
        LR = HR.resize((sz[0] // (patchsize[1]*scale) * patchsize[1],
                        sz[1] // (patchsize[0]*scale) * patchsize[0]), Image.BICUBIC)
        # Normalize to [-1, 1]
        LR = np.array(LR).astype("float32") / 255 * 2 - 1
        HR = np.array(HR).astype("float32") / 255 * 2 - 1
        # Cut aligned LR/HR patches at random positions (NCHW layout)
        for batch in range(batchsize):
            rowMin = np.random.randint(0, LR.shape[0] - patchsize[0] + 1)
            colMin = np.random.randint(0, LR.shape[1] - patchsize[1] + 1)
            LRs[batch, :, :, :] = LR[rowMin:rowMin+patchsize[0], colMin:colMin+patchsize[1], :].transpose([2, 0, 1])
            HRs[batch, :, :, :] = HR[scale*rowMin:scale*(rowMin+patchsize[0]), scale*colMin:scale*(colMin+patchsize[1])].transpose([2, 0, 1])
        yield LRs, HRs

def data_augmentation(LR, HR):  # data augmentation: random flips and rotations
    if np.random.randint(2) == 1:  # random horizontal flip
        LR = LR[:, :, :, ::-1]
        HR = HR[:, :, :, ::-1]
    n = np.random.randint(4)  # random rotation by n * 90 degrees
    if n == 1:
        LR = LR[:, :, ::-1, :].transpose([0, 1, 3, 2])
        HR = HR[:, :, ::-1, :].transpose([0, 1, 3, 2])
    if n == 2:
        LR = LR[:, :, ::-1, ::-1]
        HR = HR[:, :, ::-1, ::-1]
    if n == 3:
        LR = LR[:, :, :, ::-1].transpose([0, 1, 3, 2])
        HR = HR[:, :, :, ::-1].transpose([0, 1, 3, 2])
    return LR, HR

data = reader_patch(1)
for i in range(2):
    LR, HR = next(data)
    LR = LR.transpose([2, 3, 1, 0]).reshape(PATCH_SIZE[0], PATCH_SIZE[1], PATCH_SIZE[2])
    LR = Image.fromarray(np.uint8((LR + 1) / 2 * 255))
    HR = HR.transpose([2, 3, 1, 0]).reshape(PATCH_SIZE[0]*SCALE, PATCH_SIZE[1]*SCALE, PATCH_SIZE[2])
    HR = Image.fromarray(np.uint8((HR + 1) / 2 * 255))
    plt.subplot(1, 2, 1), plt.imshow(LR), plt.title('LRx' + str(SCALE))
    plt.subplot(1, 2, 2), plt.imshow(HR), plt.title('HR')
    plt.show()

Network Architecture

Overall generator structure:

This is a residual network named SRResNet. A first convolution extracts shallow features, a residual stage then extracts deep features, and finally an upsampling stage reconstructs the high-resolution image.
The residual stage consists of 16 residual blocks, one convolution, and a skip connection.
The upsampling stage consists of two upsampling blocks and one convolution.
Except for the first convolution and those in the upsampling stage, every convolution is followed by BN. (In fact, BN brings no benefit in SR and can even hurt slightly: the input and output of an SR network share a similar spatial distribution, and the way BN whitens intermediate features destroys that spatial representation, so some parameters must be spent restoring it. With the same parameter budget, the BN variant gives part of it up for this restoration and thus performs a little worse.)
All activations are PReLU; since I didn't know how to implement PReLU, I substituted ReLU...
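For what it's worth, newer Paddle versions do ship a PReLU layer; this note is mine, not from the original post, and assumes Paddle 2.x:

# Hypothetical drop-in for the ReLU used below (assumes paddle.nn.PReLU is available):
# num_parameters=1 learns a single slope shared by all elements, matching the intent
# of the old PReLU('all'); num_parameters=channel would learn one slope per channel.
prelu = nn.PReLU(num_parameters=1, init=0.25)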

class G(nn.Layer):  # generator: SRResNet
    def __init__(self, channel=64, num_rb=16):
        super(G, self).__init__()
        self.conv1 = nn.Conv2D(3, channel, 9, 1, 4)
        # self.prelu = nn.PReLU('all')
        self.prelu = nn.ReLU()
        self.rb_list = []
        for i in range(num_rb):
            self.rb_list += [self.add_sublayer('rb_%d' % i, RB(channel))]
        self.conv2 = nn.Conv2D(channel, channel, 3, 1, 1)
        self.bn = nn.BatchNorm2D(channel)
        self.us1 = US(channel, channel*4)
        self.us2 = US(channel, channel*4)
        self.conv3 = nn.Conv2D(channel, 3, 9, 1, 4)

    def forward(self, x):
        x = self.conv1(x)
        x = self.prelu(x)
        y = x
        for rb in self.rb_list:
            y = rb(y)
        y = self.conv2(y)
        y = self.bn(y)
        y = x + y  # long skip connection around the residual stage
        y = self.us1(y)
        y = self.us2(y)
        y = self.conv3(y)
        return y

Residual block:

This is a classic residual block: conv, BN, ReLU (PReLU), conv, BN, plus a skip connection.

class RB(nn.Layer):  # residual block
    def __init__(self, channel=64):
        super(RB, self).__init__()
        self.conv1 = nn.Conv2D(channel, channel, 3, 1, 1)
        self.bn1 = nn.BatchNorm2D(channel)
        # self.prelu = nn.PReLU('all')
        self.prelu = nn.ReLU()
        self.conv2 = nn.Conv2D(channel, channel, 3, 1, 1)
        self.bn2 = nn.BatchNorm2D(channel)

    def forward(self, x):
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.prelu(y)
        y = self.conv2(y)
        y = self.bn2(y)
        return x + y

Upsampling block:

It consists of a conv, a PixelShuffle with upscale_factor 2, and a PReLU.
The network uses two upsampling blocks, so the overall upscale factor is 4 (a quick shape check follows the class below).

class US(nn.Layer):  # upsampling block
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(US, self).__init__()
        self.conv = nn.Conv2D(in_channels, out_channels, kernel_size, stride, padding)
        self.ps = nn.PixelShuffle(2)
        # self.prelu = nn.PReLU('all')
        self.prelu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.ps(x)
        x = self.prelu(x)
        return x
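A quick shape check of one upsampling block (my own sketch, not from the original post):

us = US(64, 256)                  # the conv maps 64 -> 256 channels
x = P.randn([1, 64, 44, 44])
print(us(x).shape)                # [1, 64, 88, 88]: PixelShuffle(2) trades 4x channels for 2x height and width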

Overall discriminator structure:

This is a classic design: a series of conv-bn-leakyrelu blocks followed by two fully connected layers.
The first conv has no BN after it; all activations are LeakyReLU except the final sigmoid.
Because of the fully connected layers, different input sizes lead to different numbers of FC parameters, so the parameter count here differs from the paper's.
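To make that dependence concrete, here is the fc1 input size for the 44×44 patches used here (my arithmetic, not from the post):

# The discriminator sees HR-sized 176x176 patches; four stride-2 CNA blocks
# shrink them by 2^4 = 16, so with channel = 64 the flattened vector entering fc1 has
h = PATCH_SIZE[0] * 4 // 16   # 176 // 16 = 11
w = PATCH_SIZE[1] * 4 // 16   # 11
print(h * w * 64 * 8)         # 11 * 11 * 512 = 61952 features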

class D(nn.Layer):  # discriminator
    def __init__(self, channel=64):
        super(D, self).__init__()
        self.layer_list = []
        self.layer_list += [self.add_sublayer('conv', nn.Conv2D(3, channel, 3, 1, 1))]
        self.layer_list += [self.add_sublayer('lrelu1', nn.LeakyReLU())]
        self.layer_list += [self.add_sublayer('cna1', CNA(channel, channel, 3, 2, [1, 0, 1, 0]))]
        self.layer_list += [self.add_sublayer('cna2', CNA(channel, channel*2))]
        self.layer_list += [self.add_sublayer('cna3', CNA(channel*2, channel*2, 3, 2, [1, 0, 1, 0]))]
        self.layer_list += [self.add_sublayer('cna4', CNA(channel*2, channel*4))]
        self.layer_list += [self.add_sublayer('cna5', CNA(channel*4, channel*4, 3, 2, [1, 0, 1, 0]))]
        self.layer_list += [self.add_sublayer('cna6', CNA(channel*4, channel*8))]
        self.layer_list += [self.add_sublayer('cna7', CNA(channel*8, channel*8, 3, 2, [1, 0, 1, 0]))]
        self.layer_list += [self.add_sublayer('flatten', nn.Flatten(start_axis=1, stop_axis=3))]
        self.layer_list += [self.add_sublayer('fc1', nn.Linear((PATCH_SIZE[0]*4//16) * (PATCH_SIZE[1]*4//16) * channel*8, channel*16))]
        self.layer_list += [self.add_sublayer('lrelu2', nn.LeakyReLU())]
        # renamed from a duplicated 'fc1': reusing the name would unregister the first Linear's parameters
        self.layer_list += [self.add_sublayer('fc2', nn.Linear(channel*16, 1))]
        self.layer_list += [self.add_sublayer('sigmoid', nn.Sigmoid())]

    def forward(self, x):
        for layer in self.layer_list:
            x = layer(x)
        return x

conv + norm + act:

class CNA(nn.Layer):  # conv-norm-act
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(CNA, self).__init__()
        self.conv = nn.Conv2D(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm(out_channels)
        self.lrelu = nn.LeakyReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.lrelu(x)
        return x

The pretrained VGG19 network.
Code link:
https://github.com/PaddlePaddle/PaddleClas/blob/dygraph/ppcls/modeling/architectures/vgg.py
Parameter download link:
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/VGG19_pretrained.pdparams
We use the output of the activation layer after conv5_4.

class ConvBlock(nn.Layer):
    def __init__(self, input_channels, output_channels, groups, name=None):
        super(ConvBlock, self).__init__()
        self.groups = groups
        self._conv_1 = Conv2D(in_channels=input_channels, out_channels=output_channels,
                              kernel_size=3, stride=1, padding=1,
                              weight_attr=ParamAttr(name=name + "1_weights"), bias_attr=False)
        if groups == 2 or groups == 3 or groups == 4:
            self._conv_2 = Conv2D(in_channels=output_channels, out_channels=output_channels,
                                  kernel_size=3, stride=1, padding=1,
                                  weight_attr=ParamAttr(name=name + "2_weights"), bias_attr=False)
        if groups == 3 or groups == 4:
            self._conv_3 = Conv2D(in_channels=output_channels, out_channels=output_channels,
                                  kernel_size=3, stride=1, padding=1,
                                  weight_attr=ParamAttr(name=name + "3_weights"), bias_attr=False)
        if groups == 4:
            self._conv_4 = Conv2D(in_channels=output_channels, out_channels=output_channels,
                                  kernel_size=3, stride=1, padding=1,
                                  weight_attr=ParamAttr(name=name + "4_weights"), bias_attr=False)
        self._pool = MaxPool2D(kernel_size=2, stride=2, padding=0)

    def forward(self, inputs):
        x = self._conv_1(inputs)
        x = F.relu(x)
        if self.groups == 2 or self.groups == 3 or self.groups == 4:
            x = self._conv_2(x)
            x = F.relu(x)
        if self.groups == 3 or self.groups == 4:
            x = self._conv_3(x)
            x = F.relu(x)
        if self.groups == 4:
            x = self._conv_4(x)
            x = F.relu(x)
        y = x  # activation after the last conv in the block, before pooling
        x = self._pool(x)
        return x, y

class VGGNet(nn.Layer):
    def __init__(self):
        super(VGGNet, self).__init__()
        self.groups = [2, 2, 4, 4, 4]
        self._conv_block_1 = ConvBlock(3, 64, self.groups[0], name="conv1_")
        self._conv_block_2 = ConvBlock(64, 128, self.groups[1], name="conv2_")
        self._conv_block_3 = ConvBlock(128, 256, self.groups[2], name="conv3_")
        self._conv_block_4 = ConvBlock(256, 512, self.groups[3], name="conv4_")
        self._conv_block_5 = ConvBlock(512, 512, self.groups[4], name="conv5_")

    def forward(self, inputs):
        x, y = self._conv_block_1(inputs)
        x, y = self._conv_block_2(x)
        x, y = self._conv_block_3(x)
        x, y = self._conv_block_4(x)
        _, y = self._conv_block_5(x)
        return y
vgg19 = VGGNet()
vgg19.set_state_dict(P.load('/home/aistudio/work/vgg19_ww.pdparams'))
vgg19.eval()
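As a sanity check of the feature space the perceptual loss will compare in (my own sketch; 176×176 matches the HR patch size used here):

img = P.randn([1, 3, 176, 176])   # stand-in for a normalized image batch
feat = vgg19(img)
print(feat.shape)                  # [1, 512, 11, 11]: conv5_4 activations, 16x downsampled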

Helper Functions

Display images during training to monitor progress.

def show_image(srresnet=None, srgan=None, path=None):
    if srresnet is None:
        srresnet = G()
    srresnet.eval()
    if srgan is None:
        srgan = G()
    srgan.eval()
    fig = plt.figure(figsize=(25, 25))
    gs = plt.GridSpec(1, 4)
    gs.update(wspace=0.1, hspace=0.1)
    if path is None:
        image = Image.open(PATH + DIRS[np.random.randint(len(DIRS))])
    else:
        image = Image.open(path)
    image = image.crop([0, 0, image.size[0]//SCALE*SCALE, image.size[1]//SCALE*SCALE])
    # image = image.crop([0, 0, 40, 40])
    LR0 = image.resize((image.size[0]//SCALE, image.size[1]//SCALE), Image.BICUBIC)
    LR = np.array(LR0).astype('float32').reshape([LR0.size[1], LR0.size[0], 3, 1]).transpose([3, 2, 0, 1]) / 255 * 2 - 1
    LSR_srresnet = srresnet(P.to_tensor(LR)).numpy()
    LSR_srresnet = LSR_srresnet.reshape([3, LR0.size[1]*SCALE, LR0.size[0]*SCALE]).transpose([1, 2, 0])
    # LSR_srresnet = Image.fromarray(np.uint8((LSR_srresnet+1)/2*255))  ### the culprit behind the bright spots
    LSR_srresnet = (LSR_srresnet + 1) / 2
    LSR_srgan = srgan(P.to_tensor(LR)).numpy()
    print(np.max(LSR_srgan), np.min(LSR_srgan))
    LSR_srgan = LSR_srgan.reshape([3, LR0.size[1]*SCALE, LR0.size[0]*SCALE]).transpose([1, 2, 0])
    # LSR_srgan = Image.fromarray(np.uint8((LSR_srgan+1)/2*255))  ### the culprit behind the bright spots
    LSR_srgan = (LSR_srgan + 1) / 2
    ax = plt.subplot(gs[0])
    plt.imshow(LR0)
    plt.title('LR')
    ax = plt.subplot(gs[1])
    plt.imshow(LSR_srresnet)
    plt.title('SRResNet')
    ax = plt.subplot(gs[2])
    plt.imshow(LSR_srgan)
    plt.title('SRGAN')
    ax = plt.subplot(gs[3])
    plt.imshow(image)
    plt.title('HR')
    plt.show()

show_image()

Training

For comparison with SRGAN, an SRResNet is trained at the same time, i.e., the generator alone, trained with only the L2 loss.
SRGAN generator loss = image L2 loss + λ1 × perceptual loss + λ2 × adversarial loss, with λ1 = 1e-2 and λ2 = 1e-2.
SRResNet and the SRGAN generator start from identical initialization.
Since CelebA has far more images than DIV2K, relatively few epochs are needed.

def srresnet_trainer(lr, hr, srresnet, optimizer_srresnet):
    sr = srresnet(lr)
    loss = P.mean((sr - hr) ** 2)
    srresnet.clear_gradients()
    loss.backward()
    optimizer_srresnet.minimize(loss)

def srgan_trainer(lr, hr, srgan_g, srgan_d, vgg, optimizer_srgan_g, optimizer_srgan_d, λ1=1e-2, λ2=1e-2):
    sr = srgan_g(lr)
    # VGG features of SR and HR in one batch: first half is SR, second half is HR
    f = vgg(P.concat([sr, hr], axis=0))
    loss_content = P.mean((sr - hr) ** 2) + λ1 * P.mean((f[:f.shape[0]//2, :, :, :] - f[f.shape[0]//2:, :, :, :]) ** 2)
    d = srgan_d(P.concat([sr, hr], axis=0))
    loss_adversarial_g = P.mean(-P.log(d[:d.shape[0]//2, :] + 1e-8))
    loss_adversarial_d = (P.mean(-P.log(d[d.shape[0]//2:, :] + 1e-8)) + P.mean(-P.log(1 - d[:d.shape[0]//2, :] + 1e-8))) / 2
    loss_g = loss_content + λ2 * loss_adversarial_g
    vgg.clear_gradients()
    srgan_g.clear_gradients()
    srgan_d.clear_gradients()
    loss_g.backward(retain_graph=True)
    loss_adversarial_d.backward()
    optimizer_srgan_g.minimize(loss_g)
    optimizer_srgan_d.minimize(loss_adversarial_d)

def train(epoch_num=200, load_model=False, batchsize=1, model_path='./output/'):
    srresnet = G()
    srgan_g = G()
    srgan_g.set_state_dict(srresnet.state_dict())  # identical initialization for both generators
    srgan_d = D()
    srgan_d.train()
    optimizer_srresnet = P.optimizer.Adam(learning_rate=1e-4, beta1=0.9, parameters=srresnet.parameters())
    optimizer_srgan_g = P.optimizer.Adam(learning_rate=1e-4, beta1=0.9, parameters=srgan_g.parameters())
    optimizer_srgan_d = P.optimizer.Adam(learning_rate=1e-4, beta1=0.9, parameters=srgan_d.parameters())
    if load_model == True:
        srresnet.set_state_dict(P.load(model_path + 'srresnet.pdparams'))
        srgan_g.set_state_dict(P.load(model_path + 'srgan_g.pdparams'))
        srgan_d.set_state_dict(P.load(model_path + 'srgan_d.pdparams'))
        # backup copies ("备用" = spare), saved alongside the main checkpoints
        srresnet.set_state_dict(P.load(model_path + '备用srresnet.pdparams'))
        srgan_g.set_state_dict(P.load(model_path + '备用srgan_g.pdparams'))
        srgan_d.set_state_dict(P.load(model_path + '备用srgan_d.pdparams'))
    iteration_num = 0
    for epoch in range(epoch_num):
        reader = reader_patch(batchsize)
        for iteration in range(len(DIRS)):
            srresnet.train()
            srgan_g.train()
            iteration_num += 1
            LR, HR = next(reader)
            LR, HR = data_augmentation(LR, HR)
            LR = P.to_tensor(LR)
            HR = P.to_tensor(HR)
            srresnet_trainer(LR, HR, srresnet, optimizer_srresnet)
            srgan_trainer(LR, HR, srgan_g, srgan_d, vgg19, optimizer_srgan_g, optimizer_srgan_d)
            if iteration_num % 100 == 0:
                print('Epoch: ', epoch, ', Iteration: ', iteration_num)
                P.save(srresnet.state_dict(), model_path + 'srresnet.pdparams')
                P.save(srgan_g.state_dict(), model_path + 'srgan_g.pdparams')
                P.save(srgan_d.state_dict(), model_path + 'srgan_d.pdparams')
                P.save(srresnet.state_dict(), model_path + '备用srresnet.pdparams')
                P.save(srgan_g.state_dict(), model_path + '备用srgan_g.pdparams')
                P.save(srgan_d.state_dict(), model_path + '备用srgan_d.pdparams')
                show_image(srresnet, srgan_g)

# train(epoch_num=1, load_model=False, batchsize=16)
# train(epoch_num=998,  load_model=True, batchsize=16)

Testing

Some speckles are visible in the images. My guess was that they came from insufficient training; SRGAN shows more speckles overall, which would mean it needs more training than SRResNet, i.e., its ceiling is higher. Good grief, I can't believe I pretended to analyze this so seriously; how embarrassing... I'll leave it in as a record of my growth... The real cause of the speckles is the use of Image.fromarray(np.uint8())! That said, the point about insufficient training isn't entirely wrong: with enough training the outputs wouldn't go out of range, and this mess wouldn't show up at all.
Compared with SRResNet, SRGAN is less smooth, but some of its detail is inaccurate and looks more like noise, and odd artifacts sometimes appear, such as the bright glare on the forehead.
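The speckle mechanism is easy to reproduce (a minimal sketch of my own, not from the original code): casting to np.uint8 wraps out-of-range values around instead of saturating, so an output slightly above 1 becomes a near-black pixel and one slightly below -1 becomes a near-white one. Clipping before the cast avoids this.

import numpy as np

out = np.array([1.02, -1.03, 0.5])                    # generator outputs slightly outside [-1, 1]
print(np.uint8((out + 1) / 2 * 255))                  # wraps around: [1, 253, 191], i.e. speckles
print(np.uint8(np.clip((out + 1) / 2, 0, 1) * 255))   # saturates instead: [255, 0, 191]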

srresnet = G()
srgan_g = G()
model_path = './output/'
srresnet.set_state_dict(P.load(model_path+'srresnet.pdparams'))
srgan_g.set_state_dict(P.load(model_path+'srgan_g.pdparams'))
show_image(srresnet, srgan_g)   
