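Loss Functions for GAN Networks, with Code

Contents

  • Introduction to GAN loss functions
    • 1.L1 losses
    • 2.mse loss
    • 3.smooth L1
    • 4.charbonnier_loss
    • 5.perceptual loss (content and style losses)
    • 6.GAN loss
    • 7.WeightedTVLoss
    • 8.Complete code for convenience, including the training epoch code.

Introduction to GAN loss functions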

1.L1 losses

pixel_opt:
  type: L1Loss
  loss_weight: 1.0
  reduction: mean

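Compared with a plain L1 loss, this version adds three things: a loss weight (loss_weight), a reduction mode (reduction), and an element-wise weight (weight).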
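As a rough illustration of how a config entry like `pixel_opt` could be turned into a loss object, the sketch below pops `type` and passes the remaining keys as keyword arguments. The `LOSSES` table, the `build_loss` helper, and the `DummyL1Loss` class are hypothetical names made up for this example; they are not part of the code in this post.

```python
# Hypothetical sketch: map a config dict like `pixel_opt` to a loss object.

class DummyL1Loss:
    """Stand-in for the L1Loss class defined later in this post."""

    def __init__(self, loss_weight=1.0, reduction='mean'):
        self.loss_weight = loss_weight
        self.reduction = reduction


# Hypothetical lookup table from the `type` string to a class.
LOSSES = {'L1Loss': DummyL1Loss}


def build_loss(opt):
    """Instantiate the class named by opt['type'] with the remaining keys."""
    opt = dict(opt)                      # copy so the caller's dict is untouched
    loss_cls = LOSSES[opt.pop('type')]
    return loss_cls(**opt)


pixel_opt = {'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'}
cri_pix = build_loss(pixel_opt)
print(type(cri_pix).__name__, cri_pix.loss_weight, cri_pix.reduction)
```

The remaining keys (`loss_weight`, `reduction`) line up with the constructor arguments of the loss classes in the following sections, which is why the config can be passed through unchanged.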

First, loss_util.py defines the weighted_loss decorator:

import functools
from torch.nn import functional as F


def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are 'none', 'mean' and 'sum'.

    Returns:
        Tensor: Reduced loss tensor.
    """
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean:1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    else:
        return loss.sum()


def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights. Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            'none', 'mean' and 'sum'. Default: 'mean'.

    Returns:
        Tensor: Loss values.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        assert weight.dim() == loss.dim()
        assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    # if weight is not specified or reduction is sum, just reduce the loss
    if weight is None or reduction == 'sum':
        loss = reduce_loss(loss, reduction)
    # if reduction is mean, then compute mean over weight region
    elif reduction == 'mean':
        if weight.size(1) > 1:
            weight = weight.sum()
        else:
            weight = weight.sum() * loss.size(1)
        loss = loss.sum() / weight

    return loss


def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    **kwargs)`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> # weights need at least 2 dims because weight.size(1) is checked
    >>> pred = torch.Tensor([[0, 2, 3]])
    >>> target = torch.Tensor([[1, 1, 1]])
    >>> weight = torch.Tensor([[1, 0, 1]])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([[1., 1., 2.]])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        # get element-wise loss (the wrapped function uses reduction='none')
        loss = loss_func(pred, target, **kwargs)
        loss = weight_reduce_loss(loss, weight, reduction)
        return loss

    return wrapper
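To make the decorator mechanism easier to follow without tensors, here is a minimal pure-Python analogue that works on flat lists. It is only a sketch of the idea: `weighted_loss_sketch` and `l1_loss_list` are hypothetical names, and the 'mean' branch simply divides by the weight total instead of handling channel broadcasting the way `weight_reduce_loss` does.

```python
import functools


def weighted_loss_sketch(loss_func):
    """Add `weight` and `reduction` arguments to an element-wise loss."""

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean'):
        loss = loss_func(pred, target)           # element-wise, no reduction
        if weight is not None:
            loss = [v * w for v, w in zip(loss, weight)]
        if reduction == 'none':
            return loss
        if reduction == 'sum':
            return sum(loss)
        # 'mean': divide by the weight total when weights are given
        denom = sum(weight) if weight is not None else len(loss)
        return sum(loss) / denom

    return wrapper


@weighted_loss_sketch
def l1_loss_list(pred, target):
    return [abs(p - t) for p, t in zip(pred, target)]


pred, target, weight = [0.0, 2.0, 3.0], [1.0, 1.0, 1.0], [1.0, 0.0, 1.0]
print(l1_loss_list(pred, target))                     # mean over all 3 elements
print(l1_loss_list(pred, target, weight))             # mean over the weighted region
print(l1_loss_list(pred, target, reduction='none'))   # per-element values
```

The wrapped function only ever computes per-element values; weighting and reduction live in one place, so every loss decorated this way behaves consistently.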

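Next, define the L1 loss with weight support.
What is it for? The key addition is weight, which matches the shape of the element-wise loss.
For example, the L1 loss |image a - image b| has shape (N, C, H, W),
so weight is also (N, C, H, W), or a shape that broadcasts to (N, C, H, W),
such as (N, 1, H, W) or (N, C, 1, 1).
Note that with reduction='mean', weight_reduce_loss divides by weight.sum() when weight has more than one channel, so an (N, C, 1, 1) weight yields a value scaled differently from a per-pixel mean.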
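The broadcasting described above can be checked with a small PyTorch sketch. It uses a hypothetical spatial mask of shape (N, 1, H, W) and reproduces the "mean over the weighted region" rule by hand; the tensor sizes and the choice of mask are assumptions made for illustration.

```python
import torch

torch.manual_seed(0)
pred = torch.rand(2, 3, 4, 4)      # N, C, H, W
target = torch.rand(2, 3, 4, 4)

# Spatial mask of shape (N, 1, H, W): only the left half of each image counts.
mask = torch.zeros(2, 1, 4, 4)
mask[:, :, :, :2] = 1.0

elementwise = (pred - target).abs()        # (N, C, H, W)
weighted = elementwise * mask              # the mask broadcasts over C

# Mean over the weighted region: the mask applies to every channel,
# so the number of weighted elements is mask.sum() * C.
loss = weighted.sum() / (mask.sum() * pred.size(1))

# Same value computed directly on the selected region.
reference = elementwise[:, :, :, :2].mean()
print(loss.item(), reference.item())
```

Dividing by `mask.sum() * C` rather than by the total element count keeps the loss scale independent of how much of the image the mask covers.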

import math
import torch
from torch import autograd as autograd
from torch import nn as nn
from torch.nn import functional as F
import numpy as np

# the decorator from loss_util.py above (assumes the file is importable from here)
from loss_util import weighted_loss

_reduction_modes = ['none', 'mean', 'sum']


@weighted_loss
def l1_loss(pred, target):
    return F.l1_loss(pred, target, reduction='none')


class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss.

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(L1Loss, self).__init__()
        if reduction not in ['none', 'mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)

2.mse loss

@weighted_loss
def mse_loss(pred, target):
    return F.mse_loss(pred, target, reduction='none')


class MSELoss(nn.Module):
    """MSE (L2) loss.

    Args:
        loss_weight (float): Loss weight for MSE loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(MSELoss, self).__init__()
        if reduction not in ['none', 'mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)

3.smooth L1

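L1 is the absolute value of the difference.
L2 (MSE) is the square of the difference.
With L1, residuals tend to be sparse (many exact zeros), and outliers get no extra emphasis.
With L2, results tend to be smoother, and outliers are penalized heavily.
Smooth L1 combines the two: quadratic for small differences, linear for large ones.
This blog post explains it well: https://blog.csdn.net/Roaddd/article/details/114798798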


torch.nn.functional.smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mean')
torch.nn.SmoothL1Loss(size_average=None, reduce=None, reduction='mean')
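To make the comparison concrete, the pure-Python sketch below evaluates the three per-element losses on a few differences. It assumes the usual piecewise definition of smooth L1 with transition point `beta = 1.0` (the default in recent PyTorch releases; the signatures above predate that parameter). The function names are chosen for this example only.

```python
def l1(d):
    return abs(d)


def l2(d):
    return d * d


def smooth_l1(d, beta=1.0):
    """Quadratic for |d| < beta, linear otherwise."""
    a = abs(d)
    if a < beta:
        return 0.5 * a * a / beta
    return a - 0.5 * beta


for d in (0.1, 0.5, 1.0, 5.0):
    print(f'd={d}: L1={l1(d):.3f}  L2={l2(d):.3f}  smoothL1={smooth_l1(d):.3f}')
```

For the outlier `d = 5.0`, L2 gives 25 while smooth L1 gives 4.5, which is why smooth L1 is less dominated by large errors; for small differences it behaves like a scaled L2 and stays differentiable at zero.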

@weighted_loss
def smooth_l1_loss(pred, target):
    return F.smooth_l1_loss(pred, target, reduction='none')


class SmoothL1Loss(nn.Module):
    """Smooth L1 loss.

    Args:
        loss_weight (float): Loss weight for smooth L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        # fixed: the original called super(MSELoss, self), which raises a TypeError
        super(SmoothL1Loss, self).__init__()
        if reduction not in ['none', 'mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        return self.loss_weight * smooth_l1_loss(pred, target, weight, reduction=self.reduction)

4.charbonnier_loss

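Charbonnier loss squares the difference, adds ϵ, and then takes the square root: sqrt(x^2 + ϵ). It is a refinement of L1.
ϵ is a small constant that keeps the function differentiable at x = 0.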
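A short numeric check of that claim, written in plain Python so the gradient can be spelled out by hand. The helper names are hypothetical; `eps=1e-12` is taken from the default used in the code of this section.

```python
import math


def charbonnier(d, eps=1e-12):
    return math.sqrt(d * d + eps)


def charbonnier_grad(d, eps=1e-12):
    # d/dd sqrt(d^2 + eps) = d / sqrt(d^2 + eps)
    return d / math.sqrt(d * d + eps)


for d in (-1.0, -1e-3, 0.0, 1e-3, 1.0):
    print(f'd={d:+.3f}  loss={charbonnier(d):.6f}  grad={charbonnier_grad(d):+.6f}')
```

Away from zero the loss and gradient match |d| and sign(d) almost exactly, while at `d = 0` the gradient is a well-defined 0 instead of the kink that plain L1 has there.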

@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
    return torch.sqrt((pred - target)**2 + eps)


class CharbonnierLoss(nn.Module):
    """Charbonnier loss (one variant of Robust L1Loss, a differentiable
    variant of L1Loss).

    Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
        Super-Resolution".

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        eps (float): A value used to control the curvature near zero.
            Default: 1e-12.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
        super(CharbonnierLoss, self).__init__()
        if reduction not in ['none', 'mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction
        self.eps = eps

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)

5.perceptual loss (content and style losses)

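The perceptual loss consists mainly of a content loss and a style loss.

VGG19 weights download: vgg19-dcbb9e9d.pth
The perceptual loss needs a pretrained VGG model; VGG19 is used here as the example.

First, modify the VGG model so that it only returns the features of selected layers.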

import os
import torch
from collections import OrderedDict
from torch import nn as nn
from torchvision.models import vgg as vgg

from basicsr.utils.registry import ARCH_REGISTRY

VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
NAMES = {
    'vgg11': [
        'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
        'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
        'pool5'
    ],
    'vgg13': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
    ],
    'vgg16': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
        'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
        'pool5'
    ],
    'vgg19': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
        'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
        'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
    ]
}


def insert_bn(names):
    """Insert bn layer after each conv.

    Args:
        names (list): The list of layer names.

    Returns:
        list: The list of layer names with bn layers.
    """
    names_bn = []
    for name in names:
        names_bn.append(name)
        if 'conv' in name:
            position = name.replace('conv', '')
            names_bn.append('bn' + position)
    return names_bn


@ARCH_REGISTRY.register()
class VGGFeatureExtractor(nn.Module):
    """VGG network for feature extraction.

    In this implementation, we allow users to choose whether use normalization
    in the input feature and the type of vgg network. Note that the pretrained
    path must fit the vgg type.

    Args:
        layer_name_list (list[str]): Forward function returns the corresponding
            features according to the layer_name_list.
            Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image. Importantly,
            the input feature must in the range [0, 1]. Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        requires_grad (bool): If true, the parameters of VGG network will be
            optimized. Default: False.
        remove_pooling (bool): If true, the max pooling operations in VGG net
            will be removed. Default: False.
        pooling_stride (int): The stride of max pooling operation. Default: 2.
    """

    def __init__(self,
                 layer_name_list,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 requires_grad=False,
                 remove_pooling=False,
                 pooling_stride=2):
        super(VGGFeatureExtractor, self).__init__()

        self.layer_name_list = layer_name_list
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        self.names = NAMES[vgg_type.replace('_bn', '')]
        if 'bn' in vgg_type:
            self.names = insert_bn(self.names)

        # only borrow layers that will be used to avoid unused params
        max_idx = 0
        for v in layer_name_list:
            idx = self.names.index(v)
            if idx > max_idx:
                max_idx = idx

        if os.path.exists(VGG_PRETRAIN_PATH):
            vgg_net = getattr(vgg, vgg_type)(pretrained=False)
            state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
            vgg_net.load_state_dict(state_dict)
        else:
            vgg_net = getattr(vgg, vgg_type)(pretrained=True)

        features = vgg_net.features[:max_idx + 1]

        modified_net = OrderedDict()
        for k, v in zip(self.names, features):
            if 'pool' in k:
                # if remove_pooling is true, pooling operation will be removed
                if remove_pooling:
                    continue
                else:
                    # in some cases, we may want to change the default stride
                    modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
            else:
                modified_net[k] = v

        self.vgg_net = nn.Sequential(modified_net)

        if not requires_grad:
            self.vgg_net.eval()
            for param in self.parameters():
                param.requires_grad = False
        else:
            self.vgg_net.train()
            for param in self.parameters():
                param.requires_grad = True

        if self.use_input_norm:
            # the mean is for image with range [0, 1]
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            # the std is for image with range [0, 1]
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std

        output = {}
        for key, layer in self.vgg_net._modules.items():
            x = layer(x)
            if key in self.layer_name_list:
                output[key] = x.clone()

        return output

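Perceptual loss: use VGGFeatureExtractor to extract the feature maps of selected layers.

Once the features are extracted, compute the loss:
for example, extract the VGG feature maps of both gt and the model output, then measure their difference. The content difference usually uses an L1 loss; the style difference usually compares feature similarity, here through Gram matrices.
Default parameter settings: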

perceptual_opt:
  type: PerceptualLoss
  layer_weights:
    # before relu
    'conv1_2': 0.1
    'conv2_2': 0.1
    'conv3_4': 1
    'conv4_4': 1
    'conv5_4': 1
  vgg_type: vgg19
  use_input_norm: true
  perceptual_weight: !!float 1
  style_weight: 0
  range_norm: false
  criterion: l1

class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    Args:
        layer_weights (dict): The weight for each layer of vgg feature.
            Here is an example: {'conv5_4': 1.}, which means the conv5_4
            feature layer (before relu5_4) will be extracted with weight
            1.0 in calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool):  If True, normalize the input image in vgg.
            Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and the loss will multiplied by the
            weight. Default: 1.0.
        style_weight (float): If `style_weight > 0`, the style loss will be
            calculated and the loss will multiplied by the weight.
            Default: 0.
        criterion (str): Criterion used for perceptual loss. Default: 'l1'.
    """

    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 perceptual_weight=1.0,
                 style_weight=0.,
                 criterion='l1'):
        super(PerceptualLoss, self).__init__()
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight
        self.layer_weights = layer_weights
        self.vgg = VGGFeatureExtractor(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm,
            range_norm=range_norm)

        self.criterion_type = criterion
        if self.criterion_type == 'l1':
            self.criterion = torch.nn.L1Loss()
        elif self.criterion_type == 'l2':
            # fixed: torch.nn.L2loss does not exist; MSELoss is the L2 criterion
            self.criterion = torch.nn.MSELoss()
        elif self.criterion_type == 'fro':
            self.criterion = None
        else:
            raise NotImplementedError(f'{criterion} criterion has not been supported.')

    def forward(self, x, gt):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        # extract vgg features
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
                else:
                    percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = None

        # calculate style loss
        if self.style_weight > 0:
            style_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    style_loss += torch.norm(
                        self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
                else:
                    style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
                        gt_features[k])) * self.layer_weights[k]
            style_loss *= self.style_weight
        else:
            style_loss = None

        return percep_loss, style_loss

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        The result is a c x c matrix per sample; each entry measures the
        correlation between two channels (an uncentered covariance matrix).

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix.
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)  # bmm only works on 3-D tensors
        return gram
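The difference between the content term and the style term can be seen without loading VGG. The sketch below uses random tensors as stand-ins for VGG feature maps (an assumption for illustration only) and a standalone `gram_matrix` helper that mirrors the Gram computation used for the style loss.

```python
import torch
import torch.nn.functional as F


def gram_matrix(x):
    """Return the (n, c, c) Gram matrix of a (n, c, h, w) feature map."""
    n, c, h, w = x.size()
    features = x.reshape(n, c, h * w)
    return features.bmm(features.transpose(1, 2)) / (c * h * w)


torch.manual_seed(0)
feat_out = torch.rand(2, 8, 16, 16)    # stand-in for VGG features of the output
feat_gt = torch.rand(2, 8, 16, 16)     # stand-in for VGG features of the target

content_loss = F.l1_loss(feat_out, feat_gt)
style_loss = F.l1_loss(gram_matrix(feat_out), gram_matrix(feat_gt))

# Shuffling spatial positions changes the content term but not the Gram matrix.
perm = torch.randperm(16 * 16)
shuffled = feat_out.reshape(2, 8, -1)[:, :, perm].reshape(2, 8, 16, 16)
print(gram_matrix(feat_out).shape, content_loss.item(), style_loss.item())
```

Because the Gram matrix sums over all spatial positions, it keeps channel co-occurrence statistics (texture, style) and discards where things are in the image, which is what separates the style loss from the position-sensitive content loss.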

6.GAN loss

The GAN loss is essentially the discriminator's classification loss.
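As a worked example of that statement, the scalar sketch below spells out two common forms on a single discriminator logit: the vanilla loss (binary cross-entropy on logits) and the hinge loss for the discriminator. The logit values and function names are hypothetical and chosen only to make the arithmetic visible.

```python
import math


def softplus(x):
    # log(1 + exp(x)), written to stay stable for large |x|
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def vanilla_gan_loss(logit, target_is_real):
    """Binary cross-entropy on a raw logit with label 1 (real) or 0 (fake)."""
    return softplus(-logit) if target_is_real else softplus(logit)


def hinge_d_loss(logit, target_is_real):
    """Hinge loss for the discriminator."""
    return max(0.0, 1.0 - logit) if target_is_real else max(0.0, 1.0 + logit)


d_real, d_fake = 2.0, -1.5     # hypothetical discriminator logits

# Discriminator: real images should be classified real, generated ones fake.
loss_d = vanilla_gan_loss(d_real, True) + vanilla_gan_loss(d_fake, False)
# Generator: wants its output to be classified as real.
loss_g = vanilla_gan_loss(d_fake, True)
print(loss_d, loss_g, hinge_d_loss(d_real, True), hinge_d_loss(d_fake, False))
```

The same classification loss serves both players: the discriminator evaluates it with the true labels, while the generator evaluates it on generated samples with the label flipped to "real".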

A default parameter setting for the GAN loss:

# gan loss
gan_opt:
  type: GANLoss
  gan_type: vanilla
  real_label_val: 1.0
  fake_label_val: 0.0
  loss_weight: !!float 1e-1

The GANLoss and MultiScaleGANLoss code is as follows:

class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always 1.0
            for discriminators.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super(GANLoss, self).__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type == 'wgan_softplus':
            self.loss = self._wgan_softplus_loss
        elif self.gan_type == 'hinge':
            self.loss = nn.ReLU()
        else:
            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')

    def _wgan_loss(self, input, target):
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        return -input.mean() if target else input.mean()

    def _wgan_softplus_loss(self, input, target):
        """wgan loss with soft plus. softplus is a smooth approximation to the
        ReLU function.

        In StyleGAN2, it is called:
            Logistic loss for discriminator;
            Non-saturating loss for generator.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        return F.softplus(-input).mean() if target else F.softplus(input).mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
                return Tensor.
        """
        if self.gan_type in ['wgan', 'wgan_softplus']:
            return target_is_real
        # choose the label value by whether the target is a real or a generated
        # image: real maps to 1 and fake maps to 0 with the default settings
        target_val = (self.real_label_val if target_is_real else self.fake_label_val)
        return input.new_ones(input.size()) * target_val

    def forward(self, input, target_is_real, is_disc=False):
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        target_label = self.get_target_label(input, target_is_real)
        if self.gan_type == 'hinge':
            if is_disc:  # for discriminators in hinge-gan
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:  # for generators in hinge-gan
                loss = -input.mean()
        else:  # other gan types
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight


"""
MultiScaleGANLoss handles the case where the input is a list containing
several tensors: the GAN loss is computed for each tensor separately and
the results are averaged.
"""
class MultiScaleGANLoss(GANLoss):
    """
    MultiScaleGANLoss accepts a list of predictions
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight)

    def forward(self, input, target_is_real, is_disc=False):
        """
        The input is a list of tensors, or a list of (a list of tensors)
        """
        if isinstance(input, list):
            loss = 0
            for pred_i in input:
                if isinstance(pred_i, list):
                    # Only compute GAN loss for the last layer
                    # in case of multiscale feature matching
                    pred_i = pred_i[-1]
                # Safe operation: 0-dim tensor calling self.mean() does nothing
                loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean()
                loss += loss_tensor
            return loss / len(input)
        else:
            return super().forward(input, target_is_real, is_disc)
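The averaging over scales can be sketched independently of the classes above. The example assumes a discriminator that returns one logit map per scale; the three random tensors and the `multiscale_loss` helper are hypothetical stand-ins, and the vanilla criterion is used for simplicity.

```python
import torch
from torch import nn

criterion = nn.BCEWithLogitsLoss()

# Hypothetical discriminator outputs at three scales for one batch.
torch.manual_seed(0)
preds = [torch.randn(2, 1, 32, 32), torch.randn(2, 1, 16, 16), torch.randn(2, 1, 8, 8)]


def multiscale_loss(preds, target_is_real):
    total = 0.0
    for pred in preds:
        target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
        total = total + criterion(pred, target)
    return total / len(preds)      # each scale contributes equally


loss = multiscale_loss(preds, target_is_real=True)
print(loss.item())
```

Averaging per-scale means, instead of concatenating all logits, stops the highest-resolution output from dominating simply because it has more elements.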

7.WeightedTVLoss

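This loss is an interesting one: it is a gradient-style loss that penalizes the change between adjacent pixels (total variation). Adding it should make the output smoother and more continuous.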

class WeightedTVLoss(L1Loss):
    """Weighted TV loss.

    Args:
        loss_weight (float): Loss weight. Default: 1.0.
    """

    def __init__(self, loss_weight=1.0):
        super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)

    def forward(self, pred, weight=None):
        # slice the weight only when one is given; indexing None would raise a TypeError
        if weight is None:
            y_weight = None
            x_weight = None
        else:
            y_weight = weight[:, :, :-1, :]
            x_weight = weight[:, :, :, :-1]

        y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
        x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)

        loss = x_diff + y_diff

        return loss
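To see what the total variation term measures, the following sketch computes the unweighted version on three toy images. The `tv_loss` helper is a hypothetical standalone function written for this example; it takes the mean absolute difference along each axis, which is what the class above reduces to when no weight is passed and `loss_weight` is 1.

```python
import torch


def tv_loss(img):
    """Mean absolute difference between vertically and horizontally adjacent pixels."""
    y_diff = (img[:, :, :-1, :] - img[:, :, 1:, :]).abs().mean()
    x_diff = (img[:, :, :, :-1] - img[:, :, :, 1:]).abs().mean()
    return x_diff + y_diff


torch.manual_seed(0)
flat = torch.full((1, 3, 8, 8), 0.5)                                 # constant image
ramp = torch.linspace(0, 1, 8).view(1, 1, 1, 8).expand(1, 3, 8, 8)   # smooth gradient
noise = torch.rand(1, 3, 8, 8)                                       # random pixels

print(tv_loss(flat).item(), tv_loss(ramp).item(), tv_loss(noise).item())
```

The constant image scores zero, the smooth ramp scores a small value, and the noisy image scores much higher, so minimizing this term pushes the output toward locally smooth results.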

A similar loss appears in 3D LUT generation.
It returns two terms: tv is the TV (smoothness) loss and mn is the monotonicity loss.

class TV_3D(nn.Module):
    def __init__(self, dim=33):
        super(TV_3D, self).__init__()

        self.weight_r = torch.ones(3, dim, dim, dim - 1, dtype=torch.float)
        self.weight_r[:, :, :, (0, dim - 2)] *= 2.0
        self.weight_g = torch.ones(3, dim, dim - 1, dim, dtype=torch.float)
        self.weight_g[:, :, (0, dim - 2), :] *= 2.0
        self.weight_b = torch.ones(3, dim - 1, dim, dim, dtype=torch.float)
        self.weight_b[:, (0, dim - 2), :, :] *= 2.0
        self.relu = torch.nn.ReLU()

    def forward(self, LUT):
        dif_r = LUT.LUT[:, :, :, :-1] - LUT.LUT[:, :, :, 1:]
        dif_g = LUT.LUT[:, :, :-1, :] - LUT.LUT[:, :, 1:, :]
        dif_b = LUT.LUT[:, :-1, :, :] - LUT.LUT[:, 1:, :, :]
        tv = torch.mean(torch.mul((dif_r ** 2), self.weight_r)) + \
            torch.mean(torch.mul((dif_g ** 2), self.weight_g)) + \
            torch.mean(torch.mul((dif_b ** 2), self.weight_b))

        # A 3D LUT should be increasing, so each entry should be larger than
        # the one before it. dif_r, dif_g and dif_b are "previous minus next",
        # so they should be negative; the ReLU zeroes the negative values and
        # only penalizes the positive (decreasing) ones.
        mn = torch.mean(self.relu(dif_r)) + torch.mean(self.relu(dif_g)) + torch.mean(self.relu(dif_b))

        return tv, mn
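The monotonicity term is easiest to verify in one dimension. The sketch below applies the same "previous minus next, then ReLU" rule to two short sequences; `monotonicity_penalty` and the sample values are hypothetical and exist only for this illustration.

```python
def monotonicity_penalty(values):
    """Mean of relu(previous - next) over adjacent pairs."""
    diffs = [a - b for a, b in zip(values[:-1], values[1:])]
    return sum(max(d, 0.0) for d in diffs) / len(diffs)


increasing = [0.0, 0.25, 0.5, 0.75, 1.0]
with_dip = [0.0, 0.5, 0.25, 0.75, 1.0]     # one decreasing step of 0.25

print(monotonicity_penalty(increasing))
print(monotonicity_penalty(with_dip))
```

An increasing sequence incurs no penalty, while the single decreasing step contributes its size divided by the number of adjacent pairs, so the term only acts where the LUT stops being monotonic.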

8.Complete code for convenience, including the training epoch code.

import functools
import math
import torch
from torch import autograd as autograd
from torch import nn as nn
from torch.nn import functional as F
import numpy as np

_reduction_modes = ['none', 'mean', 'sum']


def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are 'none', 'mean' and 'sum'.

    Returns:
        Tensor: Reduced loss tensor.
    """
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean:1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    else:
        return loss.sum()


def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights. Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            'none', 'mean' and 'sum'. Default: 'mean'.

    Returns:
        Tensor: Loss values.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        assert weight.dim() == loss.dim()
        assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    # if weight is not specified or reduction is sum, just reduce the loss
    if weight is None or reduction == 'sum':
        loss = reduce_loss(loss, reduction)
    # if reduction is mean, then compute mean over weight region
    elif reduction == 'mean':
        if weight.size(1) > 1:
            weight = weight.sum()
        else:
            weight = weight.sum() * loss.size(1)
        loss = loss.sum() / weight

    return loss


def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    **kwargs)`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> # weights need at least 2 dims because weight.size(1) is checked
    >>> pred = torch.Tensor([[0, 2, 3]])
    >>> target = torch.Tensor([[1, 1, 1]])
    >>> weight = torch.Tensor([[1, 0, 1]])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([[1., 1., 2.]])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        # get element-wise loss (the wrapped function uses reduction='none')
        loss = loss_func(pred, target, **kwargs)
        loss = weight_reduce_loss(loss, weight, reduction)
        return loss

    return wrapper


################################################################################
"""
pixel_opt:
  type: L1Loss
  loss_weight: 1.0
  reduction: mean
"""


@weighted_loss
def l1_loss(pred, target):
    return F.l1_loss(pred, target, reduction='none')


class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss.

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(L1Loss, self).__init__()
        if reduction not in ['none', 'mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)


@weighted_loss
def mse_loss(pred, target):
    return F.mse_loss(pred, target, reduction='none')


class MSELoss(nn.Module):
    """MSE (L2) loss.

    Args:
        loss_weight (float): Loss weight for MSE loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(MSELoss, self).__init__()
        if reduction not in ['none', 'mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)


@weighted_loss
def smooth_l1_loss(pred, target):
    return F.smooth_l1_loss(pred, target, reduction='none')


class SmoothL1Loss(nn.Module):
    """Smooth L1 loss.

    Args:
        loss_weight (float): Loss weight for smooth L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        # fixed: the original called super(MSELoss, self), which raises a TypeError
        super(SmoothL1Loss, self).__init__()
        if reduction not in ['none', 'mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        return self.loss_weight * smooth_l1_loss(pred, target, weight, reduction=self.reduction)


@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
    return torch.sqrt((pred - target)**2 + eps)


class CharbonnierLoss(nn.Module):
    """Charbonnier loss (one variant of Robust L1Loss, a differentiable
    variant of L1Loss).

    Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
        Super-Resolution".

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        eps (float): A value used to control the curvature near zero.
            Default: 1e-12.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
        super(CharbonnierLoss, self).__init__()
        if reduction not in ['none', 'mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction
        self.eps = eps

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)


################################################################################
# perceptual loss
"""
perceptual_opt:
  type: PerceptualLoss
  layer_weights:
    # before relu
    'conv1_2': 0.1
    'conv2_2': 0.1
    'conv3_4': 1
    'conv4_4': 1
    'conv5_4': 1
  vgg_type: vgg19
  use_input_norm: true
  perceptual_weight: !!float 1
  style_weight: 0
  range_norm: false
  criterion: l1
"""
import os
import torch
from collections import OrderedDict
from torch import nn as nn
from torchvision.models import vgg as vgg

# from basicsr.utils.registry import ARCH_REGISTRY

VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
NAMES = {
    'vgg11': [
        'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
        'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
        'pool5'
    ],
    'vgg13': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
    ],
    'vgg16': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
        'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
        'pool5'
    ],
    'vgg19': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
        'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
        'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
    ]
}


def insert_bn(names):
    """Insert bn layer after each conv.

    Args:
        names (list): The list of layer names.

    Returns:
        list: The list of layer names with bn layers.
    """
    names_bn = []
    for name in names:
        names_bn.append(name)
        if 'conv' in name:
            position = name.replace('conv', '')
            names_bn.append('bn' + position)
    return names_bn


class VGGFeatureExtractor(nn.Module):
    """VGG network for feature extraction.

    In this implementation, we allow users to choose whether use normalization
    in the input feature and the type of vgg network. Note that the pretrained
    path must fit the vgg type.

    Args:
        layer_name_list (list[str]): Forward function returns the corresponding
            features according to the layer_name_list.
            Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image. Importantly,
            the input feature must in the range [0, 1]. Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        requires_grad (bool): If true, the parameters of VGG network will be
            optimized. Default: False.
        remove_pooling (bool): If true, the max pooling operations in VGG net
            will be removed. Default: False.
        pooling_stride (int): The stride of max pooling operation. Default: 2.
    """

    def __init__(self,
                 layer_name_list,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 requires_grad=False,
                 remove_pooling=False,
                 pooling_stride=2):
        super(VGGFeatureExtractor, self).__init__()

        self.layer_name_list = layer_name_list
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        self.names = NAMES[vgg_type.replace('_bn', '')]
        if 'bn' in vgg_type:
            self.names = insert_bn(self.names)

        # only borrow layers that will be used to avoid unused params
        max_idx = 0
        for v in layer_name_list:
            idx = self.names.index(v)
            if idx > max_idx:
                max_idx = idx

        if os.path.exists(VGG_PRETRAIN_PATH):
            vgg_net = getattr(vgg, vgg_type)(pretrained=False)
            state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
            vgg_net.load_state_dict(state_dict)
        else:
            vgg_net = getattr(vgg, vgg_type)(pretrained=True)

        features = vgg_net.features[:max_idx + 1]

        modified_net = OrderedDict()
        for k, v in zip(self.names, features):
            if 'pool' in k:
                # if remove_pooling is true, pooling operation will be removed
                if remove_pooling:
                    continue
                else:
                    # in some cases, we may want to change the default stride
                    modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
            else:
                modified_net[k] = v

        self.vgg_net = nn.Sequential(modified_net)

        if not requires_grad:
            self.vgg_net.eval()
            for param in self.parameters():
                param.requires_grad = False
        else:
            self.vgg_net.train()
            for param in self.parameters():
                param.requires_grad = True

        if self.use_input_norm:
            # the mean is for image with range [0, 1]
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            # the std is for image with range [0, 1]
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std

        output = {}
        for key, layer in self.vgg_net._modules.items():
            x = layer(x)
            if key in self.layer_name_list:
                output[key] = x.clone()

        return output


class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    Args:
        layer_weights (dict): The weight for each layer of vgg feature.
            Here is an example: {'conv5_4': 1.}, which means the conv5_4
            feature layer (before relu5_4) will be extracted with weight
            1.0 in calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool):  If True, normalize the input image in vgg.
            Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and the loss will multiplied by the
            weight. Default: 1.0.
        style_weight (float): If `style_weight > 0`, the style loss will be
            calculated and the loss will multiplied by the weight.
            Default: 0.
        criterion (str): Criterion used for perceptual loss. Default: 'l1'.
    """

    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 perceptual_weight=1.0,
                 style_weight=0.,
                 criterion='l1'):
        super(PerceptualLoss, self).__init__()
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight
        self.layer_weights = layer_weights
        self.vgg = VGGFeatureExtractor(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm,
            range_norm=range_norm)

        self.criterion_type = criterion
        if self.criterion_type == 'l1':
            self.criterion = torch.nn.L1Loss()
        elif self.criterion_type == 'l2':
            # fixed: torch.nn.L2loss does not exist; MSELoss is the L2 criterion
            self.criterion = torch.nn.MSELoss()
        elif self.criterion_type == 'fro':
            self.criterion = None
        else:
            raise NotImplementedError(f'{criterion} criterion has not been supported.')

    def forward(self, x, gt):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        # extract vgg features
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
                else:
                    percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = None

        # calculate style loss
        if self.style_weight > 0:
            style_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    style_loss += torch.norm(
                        self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
                else:
                    style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
                        gt_features[k])) * self.layer_weights[k]
            style_loss *= self.style_weight
        else:
            style_loss = None

        return percep_loss, style_loss

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        The result is a c x c matrix per sample; each entry measures the
        correlation between two channels (an uncentered covariance matrix).

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix.
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)  # bmm only works on 3-D tensors
        return gram


################################################################################
# gan loss
# Default configuration
"""
gan_opt:
  type: GANLoss
  gan_type: vanilla
  real_label_val: 1.0
  fake_label_val: 0.0
  loss_weight: !!float 1e-1
"""


class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan',
            'wgan_softplus', 'hinge'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always
            1.0 for discriminators.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super(GANLoss, self).__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type == 'wgan_softplus':
            self.loss = self._wgan_softplus_loss
        elif self.gan_type == 'hinge':
            self.loss = nn.ReLU()
        else:
            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')

    def _wgan_loss(self, input, target):
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        return -input.mean() if target else input.mean()

    def _wgan_softplus_loss(self, input, target):
        """wgan loss with soft plus. softplus is a smooth approximation to
        the ReLU function.

        In StyleGAN2, it is called:
            Logistic loss for discriminator;
            Non-saturating loss for generator.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        return F.softplus(-input).mean() if target else F.softplus(input).mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan,
                otherwise, return Tensor.
        """
        if self.gan_type in ['wgan', 'wgan_softplus']:
            return target_is_real
        # Pick the label value depending on whether the target is a real image
        # or a generated one. With the default settings this is simply
        # real = 1 and fake = 0, so the indirection looks redundant, but it
        # keeps the label values configurable.
        target_val = (self.real_label_val if target_is_real else self.fake_label_val)
        return input.new_ones(input.size()) * target_val

    def forward(self, input, target_is_real, is_disc=False):
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss is for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        target_label = self.get_target_label(input, target_is_real)
        if self.gan_type == 'hinge':
            if is_disc:  # for discriminators in hinge-gan
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:  # for generators in hinge-gan
                loss = -input.mean()
        else:  # other gan types
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight


"""
MultiScaleGANLoss handles the case where the input is a list containing several tensors:
the GAN loss is computed for each tensor separately and the results are averaged.
"""


class MultiScaleGANLoss(GANLoss):
    """MultiScaleGANLoss accepts a list of predictions"""

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight)

    def forward(self, input, target_is_real, is_disc=False):
        """The input is a list of tensors, or a list of (a list of tensors)"""
        if isinstance(input, list):
            loss = 0
            for pred_i in input:
                if isinstance(pred_i, list):
                    # Only compute GAN loss for the last layer
                    # in case of multiscale feature matching
                    pred_i = pred_i[-1]
                # Safe operation: 0-dim tensor calling self.mean() does nothing
                loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean()
                loss += loss_tensor
            return loss / len(input)
        else:
            return super().forward(input, target_is_real, is_disc)


#################################################################################################
class WeightedTVLoss(L1Loss):
    """Weighted TV loss.

    Args:
        loss_weight (float): Loss weight. Default: 1.0.
    """

    def __init__(self, loss_weight=1.0):
        super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)

    def forward(self, pred, weight=None):
        # Slicing a None weight would raise a TypeError, so handle that case first
        if weight is None:
            y_weight = None
            x_weight = None
        else:
            y_weight = weight[:, :, :-1, :]
            x_weight = weight[:, :, :, :-1]

        y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
        x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)

        loss = x_diff + y_diff

        return loss


class TV_3D(nn.Module):

    def __init__(self, dim=33):
        super(TV_3D, self).__init__()

        # Note: these weights are plain tensors, not registered buffers,
        # so they do not follow the module when it is moved with .to(device)
        self.weight_r = torch.ones(3, dim, dim, dim - 1, dtype=torch.float)
        self.weight_r[:, :, :, (0, dim - 2)] *= 2.0
        self.weight_g = torch.ones(3, dim, dim - 1, dim, dtype=torch.float)
        self.weight_g[:, :, (0, dim - 2), :] *= 2.0
        self.weight_b = torch.ones(3, dim - 1, dim, dim, dtype=torch.float)
        self.weight_b[:, (0, dim - 2), :, :] *= 2.0
        self.relu = torch.nn.ReLU()

    def forward(self, LUT):
        dif_r = LUT.LUT[:, :, :, :-1] - LUT.LUT[:, :, :, 1:]
        dif_g = LUT.LUT[:, :, :-1, :] - LUT.LUT[:, :, 1:, :]
        dif_b = LUT.LUT[:, :-1, :, :] - LUT.LUT[:, 1:, :, :]
        tv = (torch.mean(torch.mul((dif_r ** 2), self.weight_r)) +
              torch.mean(torch.mul((dif_g ** 2), self.weight_g)) +
              torch.mean(torch.mul((dif_b ** 2), self.weight_b)))

        # A 3D LUT should be monotonically increasing, so each entry should be
        # larger than the one before it. dif_r, dif_g and dif_b are all computed
        # as (earlier - later), so they should be negative. The ReLU zeroes out
        # the negative values, leaving only monotonicity violations as a penalty.
        mn = torch.mean(self.relu(dif_r)) + torch.mean(self.relu(dif_g)) + torch.mean(self.relu(dif_b))

        return tv, mn


# type: L1Loss
#     loss_weight: 1.0
#     reduction: mean

# type: PerceptualLoss
#     layer_weights:
#     # before relu
#     'conv1_2': 0.1
#     'conv2_2': 0.1
#     'conv3_4': 1
#     'conv4_4': 1
#     'conv5_4': 1
#     vgg_type: vgg19
#     use_input_norm: true
#     perceptual_weight: !!float 1
#     style_weight: 0
#     range_norm: false
#     criterion: l1

# type: GANLoss
#     gan_type: vanilla
#     real_label_val: 1.0
#     fake_label_val: 0.0
#     loss_weight: !!float 1e-1


def gan_loss_opti(net_g, net_d, input, gt, optimizer_g, optimizer_d, epoch, net_d_init_iters=100, net_d_iters=1):
    # 0. Define the losses.
    # Note: building them here reconstructs the VGG feature extractor on every
    # call; in a real training loop they are better created once outside.
    cri_pix = L1Loss(loss_weight=1.0, reduction='mean')
    cri_perceptual = PerceptualLoss(
        layer_weights={
            'conv1_2': 0.1,
            'conv2_2': 0.1,
            'conv3_4': 1,
            'conv4_4': 1,
            'conv5_4': 1
        },
        vgg_type='vgg19',
        use_input_norm=True,
        range_norm=False,
        perceptual_weight=1.0,
        style_weight=0.,
        criterion='l1')
    cri_gan = GANLoss(gan_type='vanilla', real_label_val=1.0, fake_label_val=0.0, loss_weight=0.1)

    # One optimization step: update the generator once, then the discriminator once.

    # optimize net_g
    # 1. Optimize the generator net_g first; freeze the discriminator net_d weights
    for p in net_d.parameters():
        p.requires_grad = False

    # 2. Zero the gradients
    optimizer_g.zero_grad()

    # 3. Forward pass of the generator; the input is a low-quality image
    output = net_g(input)

    # 4. Compute the losses used to train the generator:
    #    pixel (reconstruction) loss:      cri_pix(output, gt)
    #    perceptual content / style loss:  cri_perceptual(output, gt)
    #    gan loss, so that the prediction fools the discriminator:
    #                                      cri_gan(fake_g_pred, True, is_disc=False)
    l_g_total = 0
    loss_dict = OrderedDict()
    # While epoch <= net_d_init_iters only net_d is trained; net_g is not updated
    if (epoch % net_d_iters == 0 and epoch > net_d_init_iters):
        # pixel loss
        if cri_pix:
            l_pix = cri_pix(output, gt)
            l_g_total += l_pix
            loss_dict['l_pix'] = l_pix
        # perceptual loss
        if cri_perceptual:
            l_percep, l_style = cri_perceptual(output, gt)
            if l_percep is not None:
                l_g_total += l_percep
                loss_dict['l_percep'] = l_percep
            if l_style is not None:
                l_g_total += l_style
                loss_dict['l_style'] = l_style
        # gan loss
        fake_g_pred = net_d(output)
        l_g_gan = cri_gan(fake_g_pred, True, is_disc=False)
        l_g_total += l_g_gan
        loss_dict['l_g_gan'] = l_g_gan

        # 5. Backpropagate and update the generator. This must stay inside the
        #    branch: otherwise l_g_total is still the int 0 and has no backward().
        l_g_total.backward()
        optimizer_g.step()

    # optimize net_d
    # 6. Optimize the discriminator; set requires_grad back to True so it is trainable
    for p in net_d.parameters():
        p.requires_grad = True

    # 7. Zero the gradients
    optimizer_d.zero_grad()

    # real
    # 8. Loss for gt passed through the discriminator; push the gt prediction towards 1
    real_d_pred = net_d(gt)
    l_d_real = cri_gan(real_d_pred, True, is_disc=True)
    loss_dict['l_d_real'] = l_d_real
    loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
    l_d_real.backward()

    # fake
    # 9. Loss for the generated output passed through the discriminator;
    #    push the prediction for the generated image towards 0
    fake_d_pred = net_d(output.detach())
    l_d_fake = cri_gan(fake_d_pred, False, is_disc=True)
    loss_dict['l_d_fake'] = l_d_fake
    loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())

    # 10. Backpropagate and update the discriminator
    l_d_fake.backward()
    optimizer_d.step()

    # 11. for log
    log_dict = OrderedDict()
    for name, value in loss_dict.items():
        log_dict[name] = value.mean().item()
    # print(log_dict)

    return output, log_dict
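
To make the branches of GANLoss.forward easier to compare, here is a minimal sketch that evaluates each gan_type on a single raw discriminator output (a logit) using only the Python standard library. The helper names `softplus` and `gan_loss` are hypothetical and not part of the code above. The sketch assumes the default labels (real = 1.0, fake = 0.0) and leaves out loss_weight. Under those assumptions the 'vanilla' BCE-with-logits loss reduces to the same softplus expression as 'wgan_softplus'.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def gan_loss(logit: float, target_is_real: bool,
             gan_type: str = "vanilla", is_disc: bool = False) -> float:
    """Per-sample GAN loss on one logit (hypothetical helper, labels 1/0 assumed)."""
    if gan_type in ("vanilla", "wgan_softplus"):
        # BCE-with-logits against label 1 equals softplus(-x),
        # against label 0 it equals softplus(x).
        return softplus(-logit) if target_is_real else softplus(logit)
    if gan_type == "lsgan":
        target = 1.0 if target_is_real else 0.0
        return (logit - target) ** 2
    if gan_type == "wgan":
        return -logit if target_is_real else logit
    if gan_type == "hinge":
        if is_disc:
            # Discriminator: relu(1 - x) for real, relu(1 + x) for fake.
            signed = -logit if target_is_real else logit
            return max(0.0, 1.0 + signed)
        # Generator: simply raise the discriminator score.
        return -logit
    raise NotImplementedError(f"GAN type {gan_type} is not implemented.")


if __name__ == "__main__":
    for name in ("vanilla", "lsgan", "wgan", "wgan_softplus", "hinge"):
        real = gan_loss(0.5, True, name, is_disc=True)
        fake = gan_loss(0.5, False, name, is_disc=True)
        print(f"{name:14s} real={real:.4f} fake={fake:.4f}")
```

In the hinge case the discriminator stops receiving gradient from a real sample once its logit exceeds 1, which is why the loss for a logit of 2.0 is exactly zero.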

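
The Gram matrix used by the style loss can also be checked by hand on a tiny feature map. The following is a plain-Python sketch of the same computation as `_gram_mat` for a single sample, using nested lists instead of tensors. The function name `gram_matrix` and the example values are made up for illustration.

```python
def gram_matrix(feature_map):
    """Gram matrix of one feature map given as nested lists of shape (c, h, w).

    Mirrors _gram_mat for a single sample: flatten each channel to a vector
    of length h * w, take all pairwise dot products, divide by c * h * w.
    """
    c = len(feature_map)
    h = len(feature_map[0])
    w = len(feature_map[0][0])
    flat = [[v for row in channel for v in row] for channel in feature_map]
    scale = c * h * w
    return [
        [sum(a * b for a, b in zip(flat[i], flat[j])) / scale for j in range(c)]
        for i in range(c)
    ]


# Two channels of size 2 x 2 (illustrative values)
features = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[0.0, 1.0], [0.0, 1.0]],
]
gram = gram_matrix(features)

if __name__ == "__main__":
    for row in gram:
        print(row)
```

The result is symmetric, and its diagonal holds each channel's scaled energy, so matching Gram matrices compares channel co-activation statistics rather than spatial layout.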
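This article was submitted by an internet user. The views expressed are the author's own and do not represent the position of this site, which only provides information storage space, holds no ownership, and bears no related legal liability. If you repost it, please cite the source: http://www.hqwc.cn/news/520333.html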
