stargan项目实战及源码解读

数据及代码链接见文末

​​​​​​​论文解析:Star GAN论文解析-CSDN博客

1.测试模块效果与实验分析

        测试数据需要准备两个文件夹src(源)和ref(目标),这两个文件夹下的文件夹名称代表各个domain。

运行测试模块:

python main.py --mode eval --num_domains 2 --w_hpf 1 \--resume_iter 100000 \--train_img_dir data/celeba_hq/train \--val_img_dir data/celeba_hq/val \--checkpoint_dir expr/checkpoints/celeba_hq \--eval_dir expr/eval/celeba_hq

或者指定参数:

 2.项目配置与数据源下载

        以人脸数据集为例,数据集下包含训练集和验证集,训练集和测试集下的文件夹代表一个一个domain 

        

        需要注意的是,数据集是做过特殊处理的,里面的人脸是对齐的,如果要训练自己的数据集,也需要做类似的处理 

环境配置:

  • 安装pytorch,默认为1.4版本,比1.4版本高也行
  • pip install ffmpeg
  • pip install opencv-python
  • pip install scikit-image
  • pip install pillow
  • pip install scipy
  • pip install tqdm
  • pip install munch

 常用参数

模型与损失函数相关

  

batch size

训练和测试输入与测试输出文件夹路径 

3.整体流程

         整个网络有四个网络组成,生成器、map映射网络、ecoder、判别器。

  • 生成网络,即对输入图像生成一张给定风格的图像
  • 映射网络,随机初始化一个向量,通过全连接层得到对应风格的转化向量。
  • ecoder:直接将图像编码为对应风格的向量
  • 判别器:对于输入图像,为每一种风格判断真假  

(1)生成器

        生成器生成特定风格的图像,生成器有U-net结构的网络堆叠而成,即先下采样,在上采样。此处的归一化策略采取Instance norm,即在实例维度进行归一化。并使用残差模块

代码

class Generator(nn.Module):def __init__(self, img_size=256, style_dim=64, max_conv_dim=512, w_hpf=1):super().__init__()dim_in = 2**14 // img_sizeself.img_size = img_sizeself.from_rgb = nn.Conv2d(3, dim_in, 3, 1, 1) #(in_channels,out_channels,kernel_size,stride,padding)self.encode = nn.ModuleList()self.decode = nn.ModuleList()self.to_rgb = nn.Sequential(nn.InstanceNorm2d(dim_in, affine=True), # 在每个实例维度进行归一化nn.LeakyReLU(0.2),nn.Conv2d(dim_in, 3, 1, 1, 0))# down/up-sampling blocksrepeat_num = int(np.log2(img_size)) - 4if w_hpf > 0:repeat_num += 1for _ in range(repeat_num):dim_out = min(dim_in*2, max_conv_dim)self.encode.append(ResBlk(dim_in, dim_out, normalize=True, downsample=True))self.decode.insert(0, AdainResBlk(dim_out, dim_in, style_dim,w_hpf=w_hpf, upsample=True))  # stack-likedim_in = dim_out# bottleneck blocksfor _ in range(2):self.encode.append(ResBlk(dim_out, dim_out, normalize=True)) # 残差模块self.decode.insert(0, AdainResBlk(dim_out, dim_out, style_dim, w_hpf=w_hpf))if w_hpf > 0:device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')self.hpf = HighPass(w_hpf, device)def forward(self, x, s, masks=None):x = self.from_rgb(x)cache = {}for block in self.encode:if (masks is not None) and (x.size(2) in [32, 64, 128]):cache[x.size(2)] = xx = block(x)for block in self.decode:x = block(x, s)if (masks is not None) and (x.size(2) in [32, 64, 128]):mask = masks[0] if x.size(2) in [32] else masks[1]mask = F.interpolate(mask, size=x.size(2), mode='bilinear')x = x + self.hpf(mask * cache[x.size(2)])return self.to_rgb(x)

 (2)Map映射网络

        map网络将随机初始化的隐向量转变为风格向量。 map映射网络主要由全连接层构成 

代码实现:

class MappingNetwork(nn.Module):def __init__(self, latent_dim=16, style_dim=64, num_domains=2):super().__init__()layers = []layers += [nn.Linear(latent_dim, 512)]layers += [nn.ReLU()]for _ in range(3):layers += [nn.Linear(512, 512)]layers += [nn.ReLU()]self.shared = nn.Sequential(*layers)self.unshared = nn.ModuleList()for _ in range(num_domains):self.unshared += [nn.Sequential(nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, style_dim))]def forward(self, z, y):h = self.shared(z)out = []for layer in self.unshared:out += [layer(h)]out = torch.stack(out, dim=1)  # (batch, num_domains, style_dim)idx = torch.LongTensor(range(y.size(0))).to(y.device)s = out[idx, y]  # (batch, style_dim)return s

 (3)判别器

        判别器用于判断生成图片和原始图片的真假。其也是由残差模块堆叠而成。具体来说,生成图片向量预测接近于1,原始图片预测接近于0。但是,与传统的生成器不同,这里的生成器对于每一个domain都要预测。

 

(4)style ecoder

        style ecoder为生成图片预测对应的风格向量。其输入为生成的图片,输出为风格向量。风格向量应该与生成这张图片时生成器输入的风格向量非常相近。其网络结构也与判别器相同。

4. 损失函数

1.Style reconstruction

         首先,在使用生成网络生成图片时,我们会输入一张图片和对应风格的向量s,然后生成得到对应风格的图片。在得到生成图片后,我们再使用ecoder将生成图片编码为对应风格的向量s'。很显然,我们希望s和s'足够接近。

 2.Style diversification(多样性损失)

首先,初始化2组向量z1和z2,然后经过map网络得到对应风格的编码s1和s2,很显然,s1和s2是不同的,我们现在希望根据s1和s2生成的结果差异越大越好,差异越大,多样性越高。即损失函数越大越好

 

3.Preserving source characteristics 

        可以理解为一种重构损失,我们希望生成的结果还是同一个人,因此,对于生成图片还原回去要与原来的输入图片足够接近。

4.Adversarial objective

即判别器损失,原始图片预测接近于1,而生成图像预测接近于0

总损失为上述损失的加权和

数据及代码链接:链接:https://pan.baidu.com/s/1aNlghgo6mtD4iWqNgMOWOQ?pwd=s206 
提取码:s206 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/594021.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[lesson07]函数参数的扩展

函数参数的扩展 函数参数的默认值 C中可以在函数声明时为参数提供一个默认值 当函数调用时没有提供参数的值,则使用默认值 参数的默认值必须在函数声明中指定 函数默认参数的规则 参数的默认值必须从右向左提供函数调用时使用了默认值,则后续参数必…

数据结构:详解【树和二叉树】

1. 树的概念及结构(了解) 1.1 树的概念 树是一种非线性的数据结构,它是由n(n>0)个有限结点组成一个具有层次关系的集合。把它叫做树是因为它看起来像一棵倒挂的树,也就是说它是根朝上,而叶朝…

小林coding图解计算机网络|基础篇02|键入网址到网页显示,期间发生了什么?

小林coding网站通道:入口 本篇文章摘抄应付面试的重点内容,详细内容还请移步:小林coding网站通道 文章目录 孤单小弟——HTTP真实地址查询——DNS指南好帮手——协议栈可靠传输——TCP远程定位——IP两点传输——MAC出口——网卡送别者——交…

一文搞懂 ThreadLocal

简介 ThreadLocal存取的数据,总是与当前线程相关,也就是说,JVM 为每个运行的线程,绑定了私有的本地实例存取空间,从而为多线程环境常出现的并发访问问题提供了一种隔离机制。 ThreadLocal的作用是提供线程内的局部变…

突破编程_前端_ACE编辑器(选中区域、跳转行以及点击事件)

1 选中区域 要在 ACE 编辑器中选中一个区域,通常需要使用编辑器的 selection 对象。 以下是一个简单的示例,展示了如何使用 ACE 编辑器的 API 来选中一个特定的区域: 初始化 ACE 编辑器:首先,需要在页面上初始化 AC…

arm开发板移植工具mkfs.ext4

文章目录 一、前言二、手动安装e2fsprogs1、下载源码包2、解压源码3、配置4、编译5、安装 三、移植四、验证五、总结 一、前言 在buildroot菜单中,可以通过勾选e2fsprogs工具来安装mkfs.ext4工具: Target packages -> Filesystem and flash utilit…

为移动云数据实现基于可撤销属性组的加密:多代理辅助方法

参考文献为2023年发表的Achieving Revocable Attribute Group-Based Encryption for Mobile Cloud Data: A Multi-Proxy Assisted Approach 动机 对于目前的代理辅助的可撤销基于属性加密来说,外包解密存一些缺点。当多个具有相同属性的用户请求外包转换时&#x…

整合Mybatis(Spring学习笔记十二)

一、导入相关的包 junit 包 Mybatis包 mysql数据库包 Spring相关的包 Aop相关的包 Mybatis-Spring包(现在就来学这个) 提示jdk版本不一致的朋友记得 jdk8只支持spring到5.x 所以如果导入的spring(spring-we…

家具木材选择,橡胶木和松木哪个好?福州中宅装饰,福州装修

装修中,选择橡胶木和松木作为家具材料是一个常见的选择。然而,对于哪种木材更适合做家具这个问题,需要从多个方面进行分析和比较。 首先,让我们来看看 橡 胶 木。橡胶木通常被认为是一种坚硬和耐用的木材,这使得它非常…

【快速解决】python缺少了PyQt5模块的QtMultimedia子模块

目录 问题描述 问题原因 解决方法 成功示范 问题描述 Traceback (most recent call last): File "d:\桌面\python项目\DesktopWords-master\main.py", line 4, in <module> from PyQt5.QtMultimedia import QMediaPlayer, QMediaContent ModuleNotFoundEr…

Jupyter IPython帮助文档及其魔法命令

1.IPython 的帮助文档 使用 help() 使用 ? 使用 &#xff1f;&#xff1f; tab 自动补全 shift tab 查看参数和函数说明 2.运行外部 Python 文件 使用下面命令运行外部 Python 文件&#xff08;默认是当前目录&#xff0c;也可以使用绝对路径&#xff09; %run *.py …

Spring-IoC 基于注解

基于xml方法见&#xff1a;http://t.csdnimg.cn/dir8j 注解是代码中的一种特殊标记&#xff0c;可以在编译、类加载和运行时被读取&#xff0c;执行相应的处理&#xff0c;简化 Spring的 XML配置。 格式&#xff1a;注解(属性1"属性值1",...) 可以加在类上…