[PyTorch][chapter 57][WGAN-GP 代码实现]

前言:

 下图为WGAN 的效果图:

  绿色为真实数据的分布: 8个高斯分布

  红色: 为随机产生的数据分布,跟真实分布基本一致

WGAN-GP:

1 判别器D: 最后一层去掉sigmoid
2 生成器G 和判别器D: loss不取log
3 损失函数 增加了penalty,使用Adam

 Wasserstein GAN
1 判别器D: 最后一层去掉sigmoid
2 生成器G 和判别器D: loss不取log
3 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
4 不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行
 


一  简介

    1.1 模型结构

 1.2 伪代码

      

从Wasserstein距离、对偶理论到WGAN - 科学空间|Scientific Spaces


二  wgan.py

 主要变化:

    Generator 中 去掉了之前的logit 函数

# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 11:10:19 2023@author: chengxf2
"""import torch
from   torch import nn#生成器模型
h_dim = 400
class Generator(nn.Module):def __init__(self):super(Generator,self).__init__()# z: [batch,input_features]self.net = nn.Sequential(nn.Linear(2, h_dim),nn.ReLU(True),nn.Linear( h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, 2))def forward(self, z):output = self.net(z)return output#鉴别器模型
class Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()hDim=400# x: [batch,input_features]self.net = nn.Sequential(nn.Linear(2, hDim),nn.ReLU(True),nn.Linear(hDim, hDim),nn.ReLU(True),nn.Linear(hDim, hDim),nn.ReLU(True),nn.Linear(hDim, 1),)def forward(self, x):#x:[batch,1]output = self.net(x)out = output.view(-1)return out

三 main.py

  主要变化:

    损失函数中增加了gradient_penalty

# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 11:28:32 2023@author: chengxf2
"""import visdom
from gan  import  Discriminator
from gan  import Generator
import numpy as np
import random
import torch
from   torch import nn, optim
from    matplotlib import pyplot as plt
from torch import autogradh_dim =400
batchsz = 256
viz = visdom.Visdom()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")def weights_init(net):if isinstance(net, nn.Linear):# net.weight.data.normal_(0.0, 0.02)nn.init.kaiming_normal_(net.weight)net.bias.data.fill_(0)def data_generator():"""8- gaussian destributionReturns-------None."""scale = 2a = np.sqrt(2.0)centers =[(1,0),(-1,0),(0,1),(0,-1),(1/a,1/a),(1/a,-1/a),(-1/a, 1/a),(-1/a,-1/a)]centers = [(scale*x, scale*y) for x,y in centers]while True:dataset =[]for i in range(batchsz):point = np.random.randn(2)*0.02center = random.choice(centers)point[0] += center[0]point[1] += center[1]dataset.append(point)dataset = np.array(dataset).astype(np.float32)dataset /=a#生成器函数是一个特殊的函数,可以返回一个迭代器yield datasetdef generate_image(D, G, xr, epoch):      #xr表示真实的sample"""Generates and saves a plot of the true distribution, the generator, and thecritic."""N_POINTS = 128RANGE = 3plt.clf()points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]points = points.reshape((-1, 2))             # (16384, 2)x = y = np.linspace(-RANGE, RANGE, N_POINTS)N = len(x)# draw contourwith torch.no_grad():points = torch.Tensor(points)      # [16384, 2]disc_map = D(points).cpu().numpy() # [16384]plt.contour(x, y, disc_map.reshape((N, N)).transpose())#plt.clabel(cs, inline=1, fontsize=10)plt.colorbar()# draw sampleswith torch.no_grad():z = torch.randn(batchsz, 2)                 # [b, 2]samples = G(z).cpu().numpy()                # [b, 2]plt.scatter(xr[:, 0], xr[:, 1], c='green', marker='.')plt.scatter(samples[:, 0], samples[:, 1], c='red', marker='+')viz.matplot(plt, win='contour', opts=dict(title='p(x):%d'%epoch))def gradient_penalty(D, xr,xf):#[b,1]t =  torch.rand(batchsz, 1).to(device)       #[b,1]=>[b,2]  保证每个sample t 相同t =  t.expand_as(xr)#sample penalty interpoation [b,2]mid = t*xr +(1-t)*xfmid.requires_grad_()pred = D(mid) #[256]'''grad_outputs:   如果outputs 是向量,则此参数必须写retain_graph:  True 则保留计算图, False则释放计算图create_graph: 若要计算高阶导数,则必须选为Trueallow_unused: 允许输入变量不进入计算'''grads = autograd.grad(outputs= pred, inputs = mid,grad_outputs= torch.ones_like(pred),create_graph=True,retain_graph=True,only_inputs=True)[0]gp = torch.pow(grads.norm(2, dim=1)-1,2).mean()return gpdef main():lambd = 0.2 #超参数maxIter = 1000torch.manual_seed(10)np.random.seed(10)data_iter  = data_generator()G = Generator().to(device)D = Discriminator().to(device)G.apply(weights_init)D.apply(weights_init)optim_G = optim.Adam(G.parameters(),lr =5e-4, betas=(0.5,0.9))optim_D = optim.Adam(D.parameters(),lr =5e-4, betas=(0.5,0.9))K = 5viz.line([[0,0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))for epoch in range(maxIter):#1: train Discrimator fistlyfor k in range(K):#1.1: train on real dataxr = next(data_iter)xr = torch.from_numpy(xr).to(device)predr = D(xr)#max(predr) == min(-predr)lossr = -predr.mean()#1.2: train on fake dataz = torch.randn(batchsz,2).to(device) #[b,2] 随机产生的噪声xf = G(z).detach() #固定G,不更新G参数 tf.stop_gradient()predf =D(xf)lossf = predf.mean()#1.3 gradient_penaltygp = gradient_penalty(D, xr,xf.detach())#aggregate allloss_D = lossr + lossf +lambd*gpoptim_D.zero_grad()loss_D.backward()optim_D.step()#print("\n Discriminator 训练结束 ",loss_D.item())# 2 train  Generator#2.1 train on fake dataz = torch.randn(batchsz, 2).to(device)xf = G(z)predf =D(xf) #期望最大loss_G= -predf.mean()#optimizeoptim_G.zero_grad()loss_G.backward()optim_G.step()if epoch %100 ==0:viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')generate_image(D, G, xr, epoch)print("\n epoch: %d"%epoch,"\t lossD: %7.4f"%loss_D.item(),"\t lossG: %7.4f"%loss_G.item())if __name__ == "__main__":main()

参考:

课时130 WGAN-GP实战_哔哩哔哩_bilibili

WGAN基本原理及Pytorch实现WGAN-CSDN博客

CSDN

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/129623.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OpenCV C++ Look Up Table(查找表)

OpenCV C Look Up Table(查找表) 引言 在图像处理和计算机视觉中,查找表(Look Up Table, LUT)是一种非常高效和实用的方法,用于快速地映射或更改图像的颜色和像素值。LUT 能够极大地提高图像处理算法的执…

【C++设计模式之解释器模式:行为型】分析及示例

简介 解释器模式(Interpreter Pattern)是一种行为型设计模式,它提供了一种解决问题的方法,通过定义语言的文法规则,解释并执行特定的语言表达式。 解释器模式通过使用表达式和解释器,将文法规则中的句子逐…

【java基础学习】之DOS命令

#java基础学习 1.常用的DOS命令: dir:列出当前目录下的文件以及文件夹 md: 创建目录 rd:删除目录cd:进入指定目录 cd.. :退回到上级目录 cd\ : 退回到根目录 del:删除文件 exit:退出dos命令行 1.dir:列出当前目录下的文件以及文件夹 2.md: 创建目录 …

【kubernetes】带你了解k8s中PV和PVC的由来

文章目录 1 为什么需要卷(Volume)2 卷的挂载2.1 k8s集群中可以直接使用2.2 需要额外的存储组件2.3 公有云 2 PV(Persistent Volume)3 SC(Storage Class) 和 PVC(Persistent Volume Claim)4 总结 1 为什么需要卷(Volume) Pod是由一个或者多个容器组成的,在启动Pod中…

Logback日志框架使用详解以及如何Springboot快速集成

Logback简介 日志系统是用于记录程序的运行过程中产生的运行信息、异常信息等&#xff0c;一般有8个级别&#xff0c;从低到高为All < Trace < Debug < Info < Warn < Error < Fatal < OFF off 最高等级&#xff0c;用于关闭所有日志记录fatal 指出每个…

【Java】微服务——RabbitMQ消息队列(SpringAMQP实现五种消息模型)

目录 1.初识MQ1.1.同步和异步通讯1.1.1.同步通讯1.1.2.异步通讯 1.2.技术对比&#xff1a; 2.快速入门2.1.RabbitMQ消息模型2.4.1.publisher实现2.4.2.consumer实现 2.5.总结 3.SpringAMQP3.1.Basic Queue 简单队列模型3.1.1.消息发送3.1.2.消息接收3.1.3.测试 3.2.WorkQueue3.…

磁盘满了对日志打印(Logback)的影响

背景 我们生产环境有一个服务半夜报警&#xff1a;磁盘剩余空间不足10%&#xff0c;请及时处理。排查后发现是新上线的一个功能&#xff0c;日志打太多导致的&#xff0c;解决方法有很多&#xff0c;就不赘述了。领导担心报警不及时、或者报警遗漏&#xff0c;担心磁盘满了对线…

sqli-lab靶场通关

文章目录 less-1less-2less-3less-4less-5less-6less-7less-8less-9less-10 less-1 1、提示输入参数id&#xff0c;且值为数字&#xff1b; 2、判断是否存在注入点 id1报错&#xff0c;说明存在 SQL注入漏洞。 3、判断字符型还是数字型 id1 and 11 --id1 and 12 --id1&quo…

spark-03

RDD是抽象概念&#xff0c;分区是物理概念

实用指南:如何解决企业组网中网络卡顿问题?

随着互联网的发展&#xff0c;企业逐步将办公应用系统部署在内网服务器或者上云了&#xff0c;导致很多日常工作都需要网络才能访问。员工在工作的时候网络不给力&#xff0c;卡顿半天也打不开&#xff0c;非常影响工作效率和心情。 在企业组网过程中&#xff0c;网络卡顿现象的…

给 Linux0.11 添加网络通信功能 (Day3: 完成 MIT6.S081 最终实验 网卡驱动(1. 安装工具链和依赖))

url: https://pdos.csail.mit.edu/6.S081/2020/labs/net.html 首先看 tools章节&#xff1a;https://pdos.csail.mit.edu/6.S081/2020/tools.html 浏览了一下&#xff0c;就是要我们安装依赖 执行以下命令 sudo apt-get install git build-essential gdb-multiarch qemu-syst…

100M跨境电商服务器能同时容纳多少人访问?

​  随着“出国”“出海”需求的业务量增多&#xff0c;网络的不断发展&#xff0c;服务商开始在带宽资源配备上作出各种改进。无论是纯国际带宽还是优化回国带宽租用&#xff0c;我们都可以独享&#xff0c;并且享受到大带宽。一般&#xff0c;做跨境电商业务的群体&#xf…