引言
SRGAN是第一个将GAN用在图像超分辨率上的模型。在这之前,超分辨率常用的损失是L1、L2这种像素损失,这使得模型倾向于学习到平均的结果,也就是给低分辨率图像增加“模糊的细节”。SRGAN引入GAN来解决这个问题。GAN可以生成“真实”的图像, 那么当“真实的图像”是清晰的图像时,也意味着GAN可以生成清晰的图像。但是,如果只用GAN损失,没有其他约束,并不能生成与低分辨率图像对应的高分辨率图像。所以,将像素损失和对抗损失相结合。此外,SRGAN还使用了感知损失,计算图像在特征空间的损失。
准备
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
import paddle
import paddle as P
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout, AdaptiveAvgPool2D, MaxPool2D, AvgPool2Dnn.initializer.set_global_initializer(nn.initializer.Normal(mean=0.0,std=0.01), nn.initializer.Constant())
加载数据
使用CelebA数据集,实现人脸图像超分辨率。
为了不OOM,切块大小为44×44(而且CelebA也只能切这么大了),与原文96×96不同。
SCALE = 4
PATH = '/path/to/data/celeba/img_align_celeba/'
DIRS = os.listdir(PATH)
PATCH_SIZE = [44, 44, 3]def reader_patch(batchsize,scale=SCALE,patchsize=PATCH_SIZE):np.random.shuffle(DIRS)for filename in DIRS:LRs = np.zeros((batchsize,patchsize[2],patchsize[0],patchsize[1])).astype("float32")HRs = np.zeros((batchsize,patchsize[2],patchsize[0]*scale,patchsize[1]*scale)).astype("float32")image = Image.open(PATH+filename)sz = image.sizesz_row = sz[1]//(patchsize[0]*scale)*patchsize[0]*scalediff_row = sz[1] - sz_rowsz_col = sz[0]//(patchsize[1]*scale)*patchsize[1]*scalediff_col = sz[0] - sz_colrow_min = np.random.randint(diff_row+1)col_min = np.random.randint(diff_col+1)HR = image.crop((col_min,row_min,col_min+sz_col,row_min+sz_row))LR = HR.resize((sz[0]//(patchsize[1]*scale)*patchsize[1],sz[1]//(patchsize[0]*scale)*patchsize[0]), Image.BICUBIC)LR = np.array(LR).astype("float32") / 255 * 2 - 1HR = np.array(HR).astype("float32") / 255 * 2 - 1for batch in range(batchsize):rowMin, colMin = np.random.randint(0,LR.shape[0]-patchsize[0]+1), np.random.randint(0,LR.shape[1]-patchsize[1]+1)LRs[batch,:,:,:] = LR[rowMin:rowMin+patchsize[0], colMin:colMin+patchsize[1],:].transpose([2,0,1])HRs[batch,:,:,:] = HR[scale*rowMin:scale*(rowMin+patchsize[0]), scale*colMin:scale*(colMin+patchsize[1])].transpose([2,0,1])yield LRs, HRsdef data_augmentation(LR, HR): #数据增强:随机翻转、旋转if np.random.randint(2) == 1:LR = LR[:,:,:,::-1]HR = HR[:,:,:,::-1]n = np.random.randint(4)if n == 1:LR = LR[:,:,::-1,:].transpose([0,1,3,2])HR = HR[:,:,::-1,:].transpose([0,1,3,2])if n == 2:LR = LR[:,:,::-1,::-1]HR = HR[:,:,::-1,::-1]if n == 3:LR = LR[:,:,:,::-1].transpose([0,1,3,2])HR = HR[:,:,:,::-1].transpose([0,1,3,2])return LR, HRdata = reader_patch(1)
for i in range(2):LR, HR = next(data)LR = LR.transpose([2,3,1,0]).reshape(PATCH_SIZE[0],PATCH_SIZE[1],PATCH_SIZE[2])LR = Image.fromarray(np.uint8((LR+1)/2*255))HR = HR.transpose([2,3,1,0]).reshape(PATCH_SIZE[0]*SCALE,PATCH_SIZE[1]*SCALE,PATCH_SIZE[2])HR = Image.fromarray(np.uint8((HR+1)/2*255))plt.subplot(1,2,1), plt.imshow(LR),plt.title('LRx'+str(SCALE))plt.subplot(1,2,2), plt.imshow(HR),plt.title('HR')plt.show()
网络结构
生成器整体结构:
这是一个残差网络,名为SRResNet。首先用一个卷积提取浅层特征,然后经过一个残差层提取深层特征,最后是一个上采样层重建出高分辨率图像。
其中残差层包括16个残差块、一个卷积和跳级连接。
上采样层有两个上采样块和一个卷积。
除了第一个卷积和上采样层中的卷积,每个卷积后面都有BN(其实,BN在SR中没有效果甚至略差,SR输入和输出有相似的空间分布,而BN白化中间的特征的方式完全破坏了原始空间的表征,因此需要部分参数来恢复这种表征,所以同样多的参数,有BN的还要拿出一部分参数做恢复,效果就差了点)。
激活函数都为PReLU,由于我不知道怎么实现PReLU,所以用ReLU代替。。。
class G(nn.Layer): # 生成器SRResNetdef __init__(self, channel=64, num_rb=16):super(G, self).__init__()self.conv1 = nn.Conv2D(3, channel, 9, 1, 4)# self.prelu = nn.PReLU('all')self.prelu = nn.ReLU()self.rb_list = []for i in range(num_rb):self.rb_list += [self.add_sublayer('rb_%d' % i, RB(channel))]self.conv2 = nn.Conv2D(channel, channel, 3, 1, 1)self.bn = nn.BatchNorm2D(channel)self.us1 = US(channel, channel*4)self.us2 = US(channel, channel*4)self.conv3 = nn.Conv2D(channel, 3, 9, 1, 4)def forward(self, x):x = self.conv1(x)x = self.prelu(x)y = xfor rb in self.rb_list:y = rb(y)y = self.conv2(y)y = self.bn(y)y = x + yy = self.us1(y)y = self.us2(y)y = self.conv3(y)return y
残差块:
这是一个经典的残差块:conv、bn、relu(prelu)、conv、bn加跳过连接。
class RB(nn.Layer): # 残差块def __init__(self, channel=64):super(RB, self).__init__()self.conv1 = nn.Conv2D(channel, channel, 3, 1, 1)self.bn1 = nn.BatchNorm2D(channel)# self.prelu = nn.PReLU('all')self.prelu = nn.ReLU()self.conv2 = nn.Conv2D(channel, channel, 3, 1, 1)self.bn2 = nn.BatchNorm2D(channel)def forward(self, x):y = self.conv1(x)y = self.bn1(y)y = self.prelu(y)y = self.conv2(y)y = self.bn2(y)return x + y
上采样块:
包括conv、upscale_factor为2的pixelshuffle和prelu。
网络里用了两个上采样块,所以总的upscale_factor为4。
class US(nn.Layer): # 上采样块def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):super(US, self).__init__()self.conv = nn.Conv2D(in_channels, out_channels, kernel_size, stride, padding)self.ps = nn.PixelShuffle (2)# self.prelu = nn.PReLU('all')self.prelu = nn.ReLU()def forward(self, x):x = self.conv(x)x = self.ps(x)x = self.prelu(x)return x
判别器整体结构:
这是一个经典的结构,包括一系列的conv-bn-leakyrelu和两个全连接。
第一个conv后没有bn;除了最后的激活函数为sigmoid,其他都为leakyrelu。
由于有全连接的存在,不同的输入尺寸会有不同的全连接参数数量,这里的参数数量与论文中不同。
class D(nn.Layer): # 判别器def __init__(self, channel=64):super(D, self).__init__()self.layer_list = []self.layer_list += [self.add_sublayer('conv', nn.Conv2D(3, channel, 3, 1, 1))]self.layer_list += [self.add_sublayer('lrelu1', nn.LeakyReLU())]self.layer_list += [self.add_sublayer('cna1', CNA(channel, channel, 3, 2, [1,0,1,0]))]self.layer_list += [self.add_sublayer('cna2', CNA(channel, channel*2))]self.layer_list += [self.add_sublayer('cna3', CNA(channel*2, channel*2, 3, 2, [1,0,1,0]))]self.layer_list += [self.add_sublayer('cna4', CNA(channel*2, channel*4))]self.layer_list += [self.add_sublayer('cna5', CNA(channel*4, channel*4, 3, 2, [1,0,1,0]))]self.layer_list += [self.add_sublayer('cna6', CNA(channel*4, channel*8))]self.layer_list += [self.add_sublayer('cna7', CNA(channel*8, channel*8, 3, 2, [1,0,1,0]))]self.layer_list += [self.add_sublayer('flatten', nn.Flatten(start_axis=1, stop_axis=3))]self.layer_list += [self.add_sublayer('fc1', nn.Linear(PATCH_SIZE[0]*4//16*PATCH_SIZE[1]*4//16*channel*8, channel*16))]self.layer_list += [self.add_sublayer('lrelu2', nn.LeakyReLU())]self.layer_list += [self.add_sublayer('fc1', nn.Linear(channel*16, 1))]self.layer_list += [self.add_sublayer('sigmoid', nn.Sigmoid())]def forward(self, x):for layer in self.layer_list:x = layer(x)return x
conv + norm + act:
class CNA(nn.Layer): # conv-norm-actdef __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):super(CNA, self).__init__()self.conv = nn.Conv2D(in_channels, out_channels, kernel_size, stride, padding)self.bn = nn.BatchNorm(out_channels)self.lrelu = nn.LeakyReLU()def forward(self, x):x = self.conv(x)x = self.bn(x)x = self.lrelu(x)return x
预训练网络VGG19。
代码链接:
https://github.com/PaddlePaddle/PaddleClas/blob/dygraph/ppcls/modeling/architectures/vgg.py
参数下载链接:
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/VGG19_pretrained.pdparams
这里使用conv5_4后激活层的输出。
class ConvBlock(nn.Layer):def __init__(self, input_channels, output_channels, groups, name=None):super(ConvBlock, self).__init__()self.groups = groupsself._conv_1 = Conv2D(in_channels=input_channels,out_channels=output_channels,kernel_size=3,stride=1,padding=1,weight_attr=ParamAttr(name=name + "1_weights"),bias_attr=False)if groups == 2 or groups == 3 or groups == 4:self._conv_2 = Conv2D(in_channels=output_channels,out_channels=output_channels,kernel_size=3,stride=1,padding=1,weight_attr=ParamAttr(name=name + "2_weights"),bias_attr=False)if groups == 3 or groups == 4:self._conv_3 = Conv2D(in_channels=output_channels,out_channels=output_channels,kernel_size=3,stride=1,padding=1,weight_attr=ParamAttr(name=name + "3_weights"),bias_attr=False)if groups == 4:self._conv_4 = Conv2D(in_channels=output_channels,out_channels=output_channels,kernel_size=3,stride=1,padding=1,weight_attr=ParamAttr(name=name + "4_weights"),bias_attr=False)self._pool = MaxPool2D(kernel_size=2, stride=2, padding=0)def forward(self, inputs):x = self._conv_1(inputs)x = F.relu(x)if self.groups == 2 or self.groups == 3 or self.groups == 4:x = self._conv_2(x)x = F.relu(x)if self.groups == 3 or self.groups == 4:x = self._conv_3(x)x = F.relu(x)if self.groups == 4:x = self._conv_4(x)x = F.relu(x)y = xx = self._pool(x)return x, yclass VGGNet(nn.Layer):def __init__(self):super(VGGNet, self).__init__()self.groups = [2, 2, 4, 4, 4]self._conv_block_1 = ConvBlock(3, 64, self.groups[0], name="conv1_")self._conv_block_2 = ConvBlock(64, 128, self.groups[1], name="conv2_")self._conv_block_3 = ConvBlock(128, 256, self.groups[2], name="conv3_")self._conv_block_4 = ConvBlock(256, 512, self.groups[3], name="conv4_")self._conv_block_5 = ConvBlock(512, 512, self.groups[4], name="conv5_")def forward(self, inputs):x, y = self._conv_block_1(inputs)x, y = self._conv_block_2(x)x, y = self._conv_block_3(x)x, y = self._conv_block_4(x)_, y = self._conv_block_5(x)return y
vgg19 = VGGNet()
vgg19.set_state_dict(P.load('/home/aistudio/work/vgg19_ww.pdparams'))
vgg19.eval()
辅助函数
在训练迭代中显示图像,以观察效果。
def show_image(srresnet=None, srgan=None, path=None):if srresnet == None:srresnet = G()srresnet.eval()if srgan == None:srgan = G() srgan.eval()fig = plt.figure(figsize=(25, 25))gs = plt.GridSpec(1, 4)gs.update(wspace=0.1, hspace=0.1)if path == None:image = Image.open(PATH+DIRS[np.random.randint(len(DIRS))])else:image = Image.open(path)image = image.crop([0,0,image.size[0]//SCALE*SCALE,image.size[1]//SCALE*SCALE])# image = image.crop([0,0,40,40])LR0 = image.resize((image.size[0]//SCALE,image.size[1]//SCALE),Image.BICUBIC)LR = np.array(LR0).astype('float32').reshape([LR0.size[1],LR0.size[0],3,1]).transpose([3,2,0,1]) / 255 * 2 - 1LSR_srresnet = srresnet(P.to_tensor(LR)).numpy()LSR_srresnet = LSR_srresnet.reshape([3,LR0.size[1]*SCALE,LR0.size[0]*SCALE]).transpose([1,2,0])# LSR_srresnet = Image.fromarray(np.uint8((LSR_srresnet+1)/2*255)) ### 亮斑的罪魁祸首LSR_srresnet = (LSR_srresnet+1)/2LSR_srgan = srgan(P.to_tensor(LR)).numpy()print(np.max(LSR_srgan), np.min(LSR_srgan))LSR_srgan = LSR_srgan.reshape([3,LR0.size[1]*SCALE,LR0.size[0]*SCALE]).transpose([1,2,0])# LSR_srgan = Image.fromarray(np.uint8((LSR_srgan+1)/2*255)) ### 亮斑的罪魁祸首LSR_srgan = (LSR_srgan+1)/2ax = plt.subplot(gs[0])plt.imshow(LR0)plt.title('LR')ax = plt.subplot(gs[1])plt.imshow(LSR_srresnet)plt.title('SRResNet')ax = plt.subplot(gs[2])plt.imshow(LSR_srgan)plt.title('SRGAN')ax = plt.subplot(gs[3])plt.imshow(image)plt.title('HR')plt.show()show_image()
训练
为了与SRGAN作比较,同时训练一个SRResNet,也就是只使用了生成器,并只用L2损失来训练的网络。
SRGAN生成器的损失 = 图像L2损失 + λ1×感知损失 + λ2×对抗损失, 其中λ1=1e-2, λ2=1e-2。
SRResNet和SRGAN的生成器相同初始化。
由于Celeba比DIV2K图像数量多很多,epoch可以相对少一些。
def srresnet_trainer(lr, hr, srresnet, optimizer_srresnet):sr = srresnet(lr)loss = P.mean((sr-hr)**2)srresnet.clear_gradients()loss.backward()optimizer_srresnet.minimize(loss)def srgan_trainer(lr, hr, srgan_g, srgan_d, vgg, optimizer_srgan_g, optimizer_srgan_d, λ1=1e-2, λ2=1e-2):sr = srgan_g(lr)f = vgg(P.concat([sr,hr],axis=0))loss_content = P.mean((sr-hr)**2) + λ1*P.mean((f[:f.shape[0]//2,:,:,:]-f[f.shape[0]//2:,:,:,:])**2)d = srgan_d(P.concat([sr,hr],axis=0))loss_adversarial_g = P.mean(-P.log(d[:d.shape[0]//2,:]+1e-8))loss_adversarial_d = (P.mean(-P.log(d[d.shape[0]//2:,:]+1e-8)) + P.mean(-P.log(1-d[:d.shape[0]//2,:]+1e-8))) / 2loss_g = loss_content + λ2*loss_adversarial_gvgg.clear_gradients()srgan_g.clear_gradients()srgan_d.clear_gradients()loss_g.backward(retain_graph=True)loss_adversarial_d.backward()optimizer_srgan_g.minimize(loss_g)optimizer_srgan_d.minimize(loss_adversarial_d)def train(epoch_num=200, load_model=False, batchsize=1, model_path = './output/'):srresnet = G()srgan_g = G()srgan_g.set_state_dict(srresnet.state_dict())srgan_d = D()srgan_d.train()optimizer_srresnet = P.optimizer.Adam(learning_rate=1e-4, beta1=0.9, parameters=srresnet.parameters())optimizer_srgan_g = P.optimizer.Adam(learning_rate=1e-4, beta1=0.9, parameters=srgan_g.parameters())optimizer_srgan_d = P.optimizer.Adam(learning_rate=1e-4, beta1=0.9, parameters=srgan_d.parameters())if load_model == True:srresnet.set_state_dict(P.load(model_path+'srresnet.pdparams'))srgan_g.set_state_dict(P.load(model_path+'srgan_g.pdparams'))srgan_d.set_state_dict(P.load(model_path+'srgan_d.pdparams'))srresnet.set_state_dict(P.load(model_path+'备用srresnet.pdparams'))srgan_g.set_state_dict(P.load(model_path+'备用srgan_g.pdparams'))srgan_d.set_state_dict(P.load(model_path+'备用srgan_d.pdparams'))iteration_num = 0for epoch in range(epoch_num):reader = reader_patch(batchsize)for iteration in range(len(DIRS)):srresnet.train()srgan_g.train()iteration_num += 1 LR, HR = next(reader)LR, HR = data_augmentation(LR, HR)LR = P.to_tensor(LR)HR = P.to_tensor(HR)srresnet_trainer(LR, HR, srresnet, optimizer_srresnet)srgan_trainer(LR, HR, srgan_g, srgan_d, vgg19, optimizer_srgan_g, optimizer_srgan_d)if(iteration_num % 100 == 0):print('Epoch: ', epoch, ', Iteration: ', iteration_num) P.save(srresnet.state_dict(), model_path+'srresnet.pdparams')P.save(srgan_g.state_dict(), model_path+'srgan_g.pdparams')P.save(srgan_d.state_dict(), model_path+'srgan_d.pdparams')P.save(srresnet.state_dict(), model_path+'备用srresnet.pdparams')P.save(srgan_g.state_dict(), model_path+'备用srgan_g.pdparams')P.save(srgan_d.state_dict(), model_path+'备用srgan_d.pdparams')show_image(srresnet, srgan_g) # train(epoch_num=1, load_model=False, batchsize=16)
# train(epoch_num=998, load_model=True, batchsize=16)
测试
可以看到图中有一些斑点,根据我的猜测,这是训练不充分导致的,总体上SRGAN的斑点更多,说明它比SRResNet需要更多训练,也就是它的上限更高。 老天爷,我之前竟然装模作样瞎分析一番,尴了个大尬。。。不删了,作为我成长的见证。。。出现斑点的原因其实是用了Image.fromarray(np.uint8())!不过说训练不充分也有道理,训练充分的话就不会超出范围,也就没这个幺蛾子啦。。
相对SRResNet来说,SRGAN不那么平滑,但是有些细节并不准确,更像是噪声,而且有时会出现奇怪的东西,例如额头上的亮光。
srresnet = G()
srgan_g = G()
model_path = './output/'
srresnet.set_state_dict(P.load(model_path+'srresnet.pdparams'))
srgan_g.set_state_dict(P.load(model_path+'srgan_g.pdparams'))
show_image(srresnet, srgan_g)