【Diffusion实战】训练一个diffusion模型生成蝴蝶图像(Pytorch代码详解)

  上一篇Diffusion实战是确确实实一步一步走的公式,这回采用一个更方便的库:diffusers,来实现Diffusion模型训练。


Diffusion实战篇:
  【Diffusion实战】训练一个diffusion模型生成S曲线(Pytorch代码详解)
Diffusion综述篇:
  【Diffusion综述】医学图像分析中的扩散模型(一)
  【Diffusion综述】医学图像分析中的扩散模型(二)


0、所需安装

pip install diffusers  # diffusers库
pip install datasets  

1、数据集下载

  下载地址:蝴蝶数据集
  下载好后的文件夹中包括以下文件,放在当前目录下就可以了。

在这里插入图片描述
加载数据集,并对一批数据进行可视化:

import torch
import torchvision
from datasets import load_dataset
from torchvision import transforms
import numpy as np
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Imagedef show_images(x):"""Given a batch of images x, make a grid and convert to PIL"""x = x * 0.5 + 0.5  # Map from (-1, 1) back to (0, 1)grid = torchvision.utils.make_grid(x)grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))return grid_imdef transform(examples):images = [preprocess(image.convert("RGB")) for image in examples["image"]]return {"images": images}device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)# 数据加载
dataset = load_dataset("./smithsonian_butterflies_subset", split='train')image_size = 32
batch_size = 64# 数据增强
preprocess = transforms.Compose([transforms.Resize((image_size, image_size)),  # Resizetransforms.RandomHorizontalFlip(),  # Randomly flip (data augmentation)transforms.ToTensor(),  # Convert to tensor (0, 1)transforms.Normalize([0.5], [0.5]),  # Map to (-1, 1)]
)dataset.set_transform(transform)# 数据装载
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)# 抽取一批数据可视化
xb = next(iter(train_dataloader))["images"].to(device)[:8]
print("X shape:", xb.shape)
show_images(xb).resize((8 * 64, 64), resample=Image.NEAREST)

输出可视化结果:

在这里插入图片描述


2、加噪调度器

  即DDPM论文中需要预定义的 β t {\beta_t } βt ,可使用DDPMScheduler类来定义,其中num_train_timesteps参数为时间步 t {t} t

from diffusers import DDPMScheduler# βt值
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)plt.figure(dpi=300)
plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r"${\sqrt{\bar{\alpha}_t}}$")
plt.plot((1 - noise_scheduler.alphas_cumprod.cpu()) ** 0.5, label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")
plt.legend(fontsize="x-large");

根据定义的 β t {\beta_t } βt ,可视化 α ˉ t {\sqrt {{{\bar \alpha }_t}}} αˉt 1 − α ˉ t {\sqrt {1 - {{\bar \alpha }_t}}} 1αˉt

在这里插入图片描述

  通过设置beta_start、beta_end和beta_schedule三个参数来控制噪声调度器的超参数 β t {\beta_t } βt

noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_start=0.001, beta_end=0.004)

在这里插入图片描述

  beta_schedule可以通过一个函数映射来为模型推理的每一步生成一个 β t {\beta_t } βt值。

noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')

在这里插入图片描述

x t = α ˉ t x 0 + 1 − α ˉ t ε {{x_t} = \sqrt {{{\bar \alpha }_t}} {x_0} + \sqrt {1 - {{\bar \alpha }_t}} \varepsilon } xt=αˉt x0+1αˉt ε 加噪前向过程可视化:

timesteps = torch.linspace(0, 999, 8).long().to(device)  # 随机采样时间步
noise = torch.randn_like(xb)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)  # 加噪
print("Noisy X shape", noisy_xb.shape)
show_images(noisy_xb).resize((8 * 64, 64), resample=Image.NEAREST)

输出为:

在这里插入图片描述


3、扩散模型定义

  diffusers库中模型的定义也非常简洁:

# 创建模型
from diffusers import UNet2DModelmodel = UNet2DModel(sample_size=image_size,  # the target image resolutionin_channels=3,  # the number of input channels, 3 for RGB imagesout_channels=3,  # the number of output channelslayers_per_block=2,  # how many ResNet layers to use per UNet blockblock_out_channels=(64, 128, 128, 256),  # More channels -> more parametersdown_block_types=("DownBlock2D",  # a regular ResNet downsampling block"DownBlock2D","AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention"AttnDownBlock2D",),up_block_types=("AttnUpBlock2D","AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention"UpBlock2D","UpBlock2D",  # a regular ResNet upsampling block),
)model.to(device)
with torch.no_grad():model_prediction = model(noisy_xb, timesteps).sample
model_prediction.shape  # 验证输出与输出尺寸相同

4、扩散模型训练

  定义优化器,和传统模型一样的训练写法:

# 定义噪声调度器
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2"
)# 优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)losses = []for epoch in range(30):for step, batch in enumerate(train_dataloader):clean_images = batch["images"].to(device)# 为图像添加随机噪声noise = torch.randn(clean_images.shape).to(clean_images.device)  # epsbs = clean_images.shape[0]# 为每一张图像随机选择一个时间步timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device).long()  # 根据时间步,向清晰的图像中加噪声, 前向过程:根号下αt^ * x0 + 根号下(1-αt^) * epsnoisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)# 获得模型预测结果noise_pred = model(noisy_images, timesteps, return_dict=False)[0]# 计算损失, 损失回传loss = F.mse_loss(noise_pred, noise)  loss.backward(loss)losses.append(loss.item())# 更新模型参数optimizer.step()optimizer.zero_grad()if (epoch + 1) % 5 == 0:loss_last_epoch = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)print(f"Epoch:{epoch+1}, loss: {loss_last_epoch}")

30个epoch训练过程如下所示:

在这里插入图片描述

可用以下代码查看损失曲线:

# 损失曲线可视化
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].plot(losses)
axs[1].plot(np.log(losses))  # 对数坐标
plt.show()

损失曲线可视化:

在这里插入图片描述


5、图像生成

  (1)通过建立pipeline生成图像:

# 图像生成
# 方法一:建立一个pipeline, 打包模型和噪声调度器
from diffusers import DDPMPipeline
image_pipe = DDPMPipeline(unet=model, scheduler=noise_scheduler)pipeline_output = image_pipe()
plt.figure()
plt.imshow(pipeline_output.images[0])
plt.axis('off')
plt.show()# 保存pipeline
image_pipe.save_pretrained("my_pipeline")  # 在当前目录下保存了一个 my_pipeline 的文件夹

生成的蝴蝶图像如下:

在这里插入图片描述

生成的my_pipeline文件夹如下:

在这里插入图片描述

  (2)通过随机采样循环生成图像:

# 方法二:模型调用, 写采样循环 
# 随机初始化8张图像:
sample = torch.randn(8, 3, 32, 32).to(device)for i, t in enumerate(noise_scheduler.timesteps):# 获得模型预测结果with torch.no_grad():residual = model(sample, t).sample# 根据预测结果更新图像sample = noise_scheduler.step(residual, t, sample).prev_sampleshow_images(sample)

8张生成图像如下:
在这里插入图片描述


6、代码汇总

import torch
import torchvision
from datasets import load_dataset
from torchvision import transforms
import numpy as np
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Imagedef show_images(x):"""Given a batch of images x, make a grid and convert to PIL"""x = x * 0.5 + 0.5  # Map from (-1, 1) back to (0, 1)grid = torchvision.utils.make_grid(x)grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))return grid_imdef transform(examples):images = [preprocess(image.convert("RGB")) for image in examples["image"]]return {"images": images}# --------------------------------------------------------------------------------
# 1、数据集加载与可视化
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)# 数据加载
dataset = load_dataset("./smithsonian_butterflies_subset", split='train')image_size = 32
batch_size = 64# 数据增强
preprocess = transforms.Compose([transforms.Resize((image_size, image_size)),  # Resizetransforms.RandomHorizontalFlip(),  # Randomly flip (data augmentation)transforms.ToTensor(),  # Convert to tensor (0, 1)transforms.Normalize([0.5], [0.5]),  # Map to (-1, 1)]
)dataset.set_transform(transform)# 数据装载
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
# --------------------------------------------------------------------------------# --------------------------------------------------------------------------------
# 抽取一批数据可视化
xb = next(iter(train_dataloader))["images"].to(device)[:8]
print("X shape:", xb.shape)
show_images(xb).resize((8 * 64, 64), resample=Image.NEAREST)
# --------------------------------------------------------------------------------# --------------------------------------------------------------------------------
# 2、噪声调度器
from diffusers import DDPMScheduler# 加噪声的系数βt
# noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
# noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_start=0.001, beta_end=0.004)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')plt.figure(dpi=300)
plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r"${\sqrt{\bar{\alpha}_t}}$")
plt.plot((1 - noise_scheduler.alphas_cumprod.cpu()) ** 0.5, label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")
plt.legend(fontsize="x-large");
# --------------------------------------------------------------------------------# --------------------------------------------------------------------------------
# 加噪声可视化
timesteps = torch.linspace(0, 999, 8).long().to(device)  # 随机采样时间步
noise = torch.randn_like(xb)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)  # 加噪
print("Noisy X shape", noisy_xb.shape)
show_images(noisy_xb).resize((8 * 64, 64), resample=Image.NEAREST)
# --------------------------------------------------------------------------------# --------------------------------------------------------------------------------
# 3、创建模型
from diffusers import UNet2DModelmodel = UNet2DModel(sample_size=image_size,  # the target image resolutionin_channels=3,  # the number of input channels, 3 for RGB imagesout_channels=3,  # the number of output channelslayers_per_block=2,  # how many ResNet layers to use per UNet blockblock_out_channels=(64, 128, 128, 256),  # More channels -> more parametersdown_block_types=("DownBlock2D",  # a regular ResNet downsampling block"DownBlock2D","AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention"AttnDownBlock2D",),up_block_types=("AttnUpBlock2D","AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention"UpBlock2D","UpBlock2D",  # a regular ResNet upsampling block),
)model.to(device)
with torch.no_grad():model_prediction = model(noisy_xb, timesteps).sample
model_prediction.shape  # 验证输出与输出尺寸相同
# --------------------------------------------------------------------------------# --------------------------------------------------------------------------------
# 4、扩散模型训练
# 定义噪声调度器
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2"
)# 优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)losses = []for epoch in range(30):for step, batch in enumerate(train_dataloader):clean_images = batch["images"].to(device)# 为图像添加随机噪声noise = torch.randn(clean_images.shape).to(clean_images.device)  # epsbs = clean_images.shape[0]# 为每一张图像随机选择一个时间步timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device).long()  # 根据时间步,向清晰的图像中加噪声, 前向过程:根号下αt^ * x0 + 根号下(1-αt^) * epsnoisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)# 获得模型预测结果noise_pred = model(noisy_images, timesteps, return_dict=False)[0]# 计算损失, 损失回传loss = F.mse_loss(noise_pred, noise)  loss.backward(loss)losses.append(loss.item())# 更新模型参数optimizer.step()optimizer.zero_grad()if (epoch + 1) % 5 == 0:loss_last_epoch = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)print(f"Epoch:{epoch+1}, loss: {loss_last_epoch}")
# --------------------------------------------------------------------------------# --------------------------------------------------------------------------------
# 损失曲线可视化
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].plot(losses)
axs[1].plot(np.log(losses))  # 对数坐标
plt.show()
# --------------------------------------------------------------------------------# --------------------------------------------------------------------------------
# 5、图像生成
# 方法一:建立一个pipeline, 打包模型和噪声调度器
from diffusers import DDPMPipeline
image_pipe = DDPMPipeline(unet=model, scheduler=noise_scheduler)pipeline_output = image_pipe()plt.figure()
plt.imshow(pipeline_output.images[0])
plt.axis('off')
plt.show()image_pipe.save_pretrained("my_pipeline")  # 在当前目录下保存了一个 my_pipeline 的文件夹# 方法二:模型调用, 写采样循环 
# 随机初始化8张图像:
sample = torch.randn(8, 3, 32, 32).to(device)for i, t in enumerate(noise_scheduler.timesteps):# 获得模型预测结果with torch.no_grad():residual = model(sample, t).sample# 根据预测结果更新图像sample = noise_scheduler.step(residual, t, sample).prev_sampleshow_images(sample)grid_im = show_images(sample).resize((8 * 64, 64), resample=Image.NEAREST)
plt.figure(dpi=300)
plt.imshow(grid_im)
plt.axis('off')
plt.show()
# --------------------------------------------------------------------------------

  参考资料:扩散模型从原理到实践. 人民邮电出版社. 李忻玮, 苏步升等.

  diffusers确实很方便使用,有点子PyCaret的感觉了~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/641956.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

六西格玛管理培训:我的转变与成长之旅

4月初,我参与了天行健咨询的六西格玛管理培训,这次经历不仅极大地提升了我的工作效率,还帮助我在工作中实现了卓越。现在,我想分享一些我在这次培训中的学习心得和实践经验,希望能对正在寻求提升绩效和卓越之路的大家有…

用于车载T-BOX汽车级的RA8900CE

用于车载T-BOX等高精度计时的汽车级时钟模块RTC:RA8900CE.车载实时时钟芯片RA8900CE内置32.768Khz的晶体,实现年、月、日、星期、小时、分钟和秒精准计时。RA8900CE满足AEC-Q200认证,内置温补功能,保证实时时钟的稳定可靠,功耗低至…

【目标检测】YOLO系列-YOLOv1 理论基础 通俗易懂

为方便大家理解YOLO的原理,这里将YOLOv1的部分内容基础内容进行用比较直白的话和例子进行阐述,为后续大家学习YOLO作为铺垫。 1、模型所干的活 工作中,大家经常将 Word 文档 上传到某转换器,然后转换输出为PDF文档。目标检测中我…

嵌入式Linux八股(二)——Linux

二、Linux 01.Linux系统编程 01.Linux系统文件类型: 7/8 种 普通文件:- 目录文件:d 字符设备文件:c 块设备文件:b 软连接:l 管道文件:p 套接字:s 未知文件 文件权限说明 …

抖音运营全攻略 沈阳新媒体运营培训

抖音发展趋势 数据显示,2023年,抖音日活量突破10亿。是目前最火的短视频软件。 抖音的总用户数量已超过12亿,日活10亿,人均单日使用时长超过2小时,这只是平均数据,其实大部分人刷抖音时间会超过3个小时&am…

Hbuilder快捷键个人习惯修改

自定义修改 [// {"key":"ctrld","command":"editor.action.deleteLines"},// {"key":"ctrle","command":"editor.action.addSelectionToNextFindMatch"}//目录内查找字符串{"key"…

动物解剖流程3d仿真展示动画支持反复观看和使用

在兽医专业的广袤领域中,动物解剖学作为基石学科,为组织胚胎学、生理学、病理解剖学、外科手术学、临床诊断学等科研教学提供了坚实的基础。而如今,随着科技的飞速发展,我们迎来了一个全新的学习时代——3D数字动物解刨虚拟仿真实…

亚马逊自养号测评三大要点,你了解多少?

对于初涉亚马逊测评的新手而言,近期最常听闻的莫过于平台风控日益严苛,尤其是在留评管理方面。稍有不慎,便可能面临评论被删除的风险。那么,如何有效规避这一困境呢?对于亚马逊新手小白来说,进行测评时又需…

helpdesk桌面运维常见问题解决

helpdesk是一套帮助IT团队管理IT工单生命周期、自动化日常工作、优化工作流程的软件或软件集合,它可以帮助IT团队提高生产力、降低成本、改善服务水平和客户体验。 在现代企业中,helpdesk桌面运维是一项至关重要的工作,helpdesk团队负责处理员…

【机器学习】分类与预测算法的评价与优化

以实际案例解析F1值与P-R曲线的应用 一、分类算法与性能评价的重要性二、F1值与P-R曲线的概念与意义三、实例解析:以垃圾邮件检测为例四、代码实现与结果分析五、结论与展望 在数据驱动的时代,机器学习算法以其强大的数据处理和分析能力,成为…

鸿蒙开发使用云数据库的坑

在使用华为官网提供的CloudDBZoneWrapper.js有个坑, CloudDBZoneWrapper.js需要使用两个包 const clouddb require(hw-agconnect/database-server/dist/index.js); const agconnect require(agconnect/common-server); const path require(path);/*配置区域 */…

C++ | Leetcode C++题解之第44题通配符匹配

题目&#xff1a; 题解&#xff1a; class Solution { public:bool isMatch(string s, string p) {auto allStars [](const string& str, int left, int right) {for (int i left; i < right; i) {if (str[i] ! *) {return false;}}return true;};auto charMatch []…