VAE生成人脸代码

news/2024/7/4 16:38:01/文章来源:https://www.cnblogs.com/xjlearningAI/p/18276484

基于VAE介绍的理论,简单实现VAE生成人脸,代码如下:

utils.py

import os
from torch.utils.data import Dataset
from torchvision.transforms import transforms
import glob
import cv2
import numpy as np
import torchclass MyDataset(Dataset):def __init__(self, img_path, device):super(MyDataset, self).__init__()self.device = deviceself.fnames = glob.glob(os.path.join(img_path+"*.jpg"))self.transforms = transforms.Compose([transforms.ToTensor(),])def __getitem__(self, idx):fname = self.fnames[idx]img = cv2.imread(fname, cv2.IMREAD_COLOR)img = self.transforms(img)img = img.to(self.device)return imgdef __len__(self):return len(self.fnames)

 

VAE.py

import torch
import torch.nn as nnclass VAE(nn.Module):def __init__(self, image_size: int, in_channels: int, latent_dim: int, hid_dims: int = None):super(VAE, self).__init__()self.latent_dim = latent_dimif not hid_dims:hid_dims = [32, 64, 128, 256]feature_size = image_size // (2**4)modules = []for h_d in hid_dims:modules.append(nn.Sequential(nn.Conv2d(in_channels, h_d, 3, 2, 1),nn.BatchNorm2d(h_d),nn.LeakyReLU()))in_channels = h_dself.encoder = nn.Sequential(*modules)self.fc_mu = nn.Linear(hid_dims[-1]*feature_size**2, latent_dim)self.fc_var = nn.Linear(hid_dims[-1]*feature_size**2, latent_dim)# decoderself.decoder_input = nn.Linear(latent_dim, hid_dims[-1]*feature_size**2)hid_dims.reverse()modules = []for i in range(len(hid_dims)-1):modules.append(nn.Sequential(nn.ConvTranspose2d(hid_dims[i], hid_dims[i+1], 3, 2, 1, 1),nn.BatchNorm2d(hid_dims[i+1]),nn.LeakyReLU()))self.decoder = nn.Sequential(*modules)self.decoder_out = nn.Sequential(nn.ConvTranspose2d(hid_dims[-1], hid_dims[-1], 3, 2, 1, 1),nn.BatchNorm2d(hid_dims[-1]),nn.LeakyReLU(),nn.Conv2d(hid_dims[-1], 3, 3, 1, 1, 1),nn.Sigmoid())def encode(self, x):x = self.encoder(x)x = torch.flatten(x, start_dim=1)mu = self.fc_mu(x)var = self.fc_var(x)return mu, vardef decode(self, x):x = self.decoder_input(x)x = x.view(-1, 256, 6, 6)x = self.decoder(x)x = self.decoder_out(x)return xdef re_parameterize(self, mu, log_var):std = torch.exp_(0.5*log_var)eps = torch.randn_like(std)return mu + std*epsdef forward(self, x):mu, log_var = self.encode(x)z = self.re_parameterize(mu, log_var)out = self.decode(z)return out, mu, log_vardef sample(self, n_samples, device):z = torch.randn((n_samples, self.latent_dim)).to(device)samples = self.decode(z)return samplesif __name__ == '__main__':DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")fake_input = torch.ones((1, 3, 96, 96))model = VAE(96, 3, 1024)out, *_ = model(fake_input)print(out.shape)print(model.sample(10, DEVICE).shape)

 

Loss.py

import torch
import torch.nn as nnclass Loss(nn.Module):def __init__(self, kld_weight=0.03):super(Loss, self).__init__()self.kld_weight = kld_weightself.criterion = nn.MSELoss(reduction='mean')def forward(self, input, output, mu, log_var):recon_loss = self.criterion(output, input)kld_loss = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())return recon_loss + self.kld_weight*kld_loss

 

train_vae.py

import os
import numpy as np
import torch
from VAE import VAE
import argparse
from torch.utils.data import DataLoader
from PIL import Image
from torch.optim import Adam
from utils import MyDataset
from torchvision.utils import save_image
from Loss import Loss
from tqdm import tqdmdef args_parser():parser = argparse.ArgumentParser(description="Parameters of training vae model")parser.add_argument("-b", "--batch_size", type=int, default=128)parser.add_argument("-i", "--in_channels", type=int, default=3)parser.add_argument("-d", "--latent_dim", type=int, default=256)parser.add_argument("-l", "--lr", type=float, default=1e-3)parser.add_argument("-w", "--weight_decay", type=float, default=1e-5)parser.add_argument("-e", "--epoch", type=int, default=500)parser.add_argument("-v", "--snap_epoch", type=int, default=1)parser.add_argument("-n", "--num_samples", type=int, default=64)parser.add_argument("-p", "--path", type=str, default="./results_linear")return parser.parse_args()def train(model, input_data, loss_fn, optimizer):optimizer.zero_grad()out, mu, log_var = model(input_data)total_loss = loss_fn(input_data, out, mu, log_var)total_loss.backward()optimizer.step()print("loss:", total_loss.item())if __name__ == '__main__':DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")opt = args_parser()loss_fn = Loss(kld_weight=0.03)dataset = MyDataset(img_path="../faces/", device=DEVICE)train_loader = DataLoader(dataset=dataset, batch_size=opt.batch_size, shuffle=True, num_workers=0)model = VAE(image_size=96, in_channels=opt.in_channels, latent_dim=opt.latent_dim)model.to(DEVICE)optimizer = Adam(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)for epoch in range(opt.epoch):model.train()data_bar = tqdm(train_loader)for step, data in enumerate(data_bar):train(model, data.to(DEVICE), loss_fn, optimizer)if epoch % opt.snap_epoch == 0 or epoch == opt.epoch - 1:model.eval()images = model.sample(opt.num_samples, DEVICE)imgs = images.detach().cpu().numpy()saved_image_path = os.path.join(opt.path, "images")os.makedirs(saved_image_path, exist_ok=True)fname = './my_generated-images-{0:0=4d}.png'.format(epoch)save_image(images, fname, nrow=8)saved_model_path = os.path.join(opt.path, "models")os.makedirs(saved_model_path, exist_ok=True)torch.save(model.state_dict(), os.path.join(saved_model_path, f"epoch_{epoch}.pth"))

 

没有调参,训练333个epoch,模型生成的结果如下:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/734203.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ret2shellcode

参考链接 :基本 ROP - CTF Wiki (ctf-wiki.org) 参考链接: https://www.freebuf.com/vuls/266711.html 参考链接:PWN入门(2-2-1)-栈迁移(x86) (yuque.com)介绍栈溢出漏洞的一种利用方式,通过向可写入可执行内存写入shellcode,并利用栈溢出漏洞将返回地址覆盖为shellco…

判断存在与否

问题:A、B两个表,在B表中判断其中数据在A表是否存在。 函数公式解决: =COUNTIF(A!A:A,A2)

有点难以描述的问题(转列+排序+去重)

问题: 以下数据起始值为0,与0同行的2在数据源中有重复,则取与之重复的2的同行数据即6,6在数据源中仍有重复,再取与之重复的6的同行数据8,8在数据源中没有重复,则为第一行第二列的结果。 以此类推。函数公式解决:=WRAPROWS(UNIQUE(SORT(TOCOL(A3:B8)),,1),2) 先用ToCal…

outside_的第三次博客作业

outside_的第三次博客作业 --Wenxiaowenyy 前言: 这次博客是本学期的最后一次博客,也是大一的最后篇博客,回顾这学期学习java的热情以及完成大作业的积极性比起学c语言的时候减少了许多,原因有很多,其一是自己变懒惰了,其二是这学期的java比起上学期的c语言的确难度增加了…

最大值减不为0的最小值

问题:一列中的最大值减去不为0的最小数(所有数据均为正数) 函数公式解决:=MAX(B4:B7)-MINIFS(B4:B7,B4:B7,">0")如果数据有正有负,则需要用MaxIfs减去MinIfs,MaxIfs与MinIfs用法相同。

搭建rust开发环境-记录

通过官网教程(https://www.rust-lang.org/zh-CN/learn/get-started)下载rustup安装 在项目开始的时候提示需要解析工具,按照提示需下载vsstudio,安装的时候选择c++开发桌面程序,不然后面build时候报错 参考文档:https://blog.csdn.net/weixin_44475303/article/details/1…

文本时间转成小数

问题:文本时间(**小时**分钟**秒)转成小数,小时部分为整数。 函数公式解决:传统套路 =SUBSTITUTE(SUBSTITUTE(A2,"小",),"钟",)*24 新套路 =SUBSTITUTES(A2,{"小","钟"},)*24 更新的套路(正则) =REGEXP(A2,"[小钟]",…

Face Adapter - 一键面部表情迁移、换脸工具 本地一键整合包下载

Face Adapter是一款高效的人脸编辑适配器,由浙江大学和腾讯联合开发,适用于预先训练的扩散模型,专门针对人脸再现和交换任务。只需要上传一张源脸和一张参考人脸,就能按照参考人脸的风格生成相同的面部的表情,一键生成两张换脸照片。类似的ID保持的人像生成软件,还有我们…

高级筛选超过15位的数字

问题:高级筛选直接设置条件,当条件的数字超过15位时会出现错误,要如何解决。 解决:在高级筛选条件中设置公式,公式所引用单元格为数据源表标题行下第一行,公式所在单元格上一个单元格必须空

OOP第三轮大作业总结

关于学习OOP的一点总结 本学期的pta也是走到尾声了,一路过来最深的体会是想做好面向对象真不是件容易的事情,但它确实在日常生活中发挥了很大的作用。个人很喜欢这种和实际结合起来的课程,但几个月下来我学得并不是很好,只能日后自己钻研了。 个人体会 关于语法: 1.因为一…

第三次大作业Blog

目录前言设计与分析踩坑心得改进建议总结 前言 知识点:类与对象的应用: 在三次大作业中,类与对象的应用无疑是核心和基础。这充分体现了Java作为一种面向对象编程语言的特性。通过定义类,我们可以创建具有特定属性和行为的对象,从而构建出复杂的程序逻辑。在每次大作业中,…

前端调用后端产生跨域问题解决

[参考文章](https://www.cnblogs.com/zhaodalei/p/17090119.html) ## 问题复现 * 前端的地址是* 后端的的请求资源地址 http://127.0.0.1:3000/api/category/list。 * 当前端请求获取后端数据时,会报如下错误,导致资源加载不出来。但是直接访问是可以获得数据的。说明不是数据…

java第三次大作业blog

pta第三次博客 目录 • pta第三次博客 o 1.前言 o 2.设计与分析 o 3.踩坑心得: o 4.改进建议 o 5.总结1.前言 这两次题目集的主要考察的知识点是继承和多态,包括对super、extend关键字的使用,方法的重写,类的继承,接口,排序,正则表达式等。 在数据处理方面,作业同样要求…

Python基础之多进程

目录1 多进程1.1 简介1.2 Linux下多进程1.3 multiprocessing1.4 Pool1.5 进程间通信1.6 分布式进程 1 多进程 1.1 简介 要让Python程序实现多进程(multiprocessing),我们先了解操作系统的相关知识。 Unix/Linux操作系统提供了一个fork()系统调用,它非常特殊。普通的函数调用…

[LeetCode] 169. Majority Element

排序,返回中值。class Solution:def majorityElement(self, nums: List[int]) -> int:#always existsnums.sort()return nums[len(nums)//2]

BUUCTF---childRSA(费马引理)

题目点击查看代码 from random import choice from Crypto.Util.number import isPrime, sieve_base as primes from flag import flagdef getPrime(bits):while True:n = 2while n.bit_length() < bits:n *= choice(primes)if isPrime(n + 1):return n + 1e = 0x10001 m = …

Python 使用__slots__来限制实例动态添加属性

在Python中,是可以随便在对象实例中动态添加属性的。那么,怎么样可以防止其他人在调用类实例的时候胡乱添加属性和方法?使用 __slots__ 属性,来限制 class 实例能添加的属性也就是说,只有在 __slots__ 变量中的属性才能被动态添加,否则会添加失败。例如,创建一个 Person …

[python] Python日志记录库loguru使用指北

Loguru是一个功能强大且易于使用的开源Python日志记录库。它建立在Python标准库中的logging模块之上,并提供了更加简洁直观、功能丰富的接口。Logging模块的使用见:Python日志记录库logging总结。Loguru官方仓库见:loguru,loguru官方文档见: loguru-doc。 Loguru的主要特点…

Codeforces Round 955 (Div. 2, with prizes from NEAR!) codeforces div2 955

A. Soccer ------------------------题解--------------- 给你开始比分和结束比分问你中间两队比分有没有相等过有可能就是YES不可能就是NO 结束时两队比分肯定>=各自队伍开始时比分,我们只需要让开始时大的先到达结束比分,再让开始时落后的比分到达结束时比分,只需要在心…

PTA题目集7~8的总结

PTA题目集7~8的总结 一、前言 第七次题目集为家居强电电路模拟程序3。本题模拟的控制设备包括:开关、互斥开关、分档调速器、连续调速器。模拟的受控设备包括:灯、风扇、受控窗帘。两种设备都有两根引脚,通过两根引脚电压的电压差驱动设备工作。输入信息有设备信息、连接信息…