Denoising Diffusion Probabilistic Model(DDPM)原理
1. 生成模型对比
记 真实图片为 \(x_0\),噪声图片为 \(x_t\),噪声变量 \(z\sim \mathcal{N}(\mu,\sigma^2)\),噪声变量 \(\varepsilon \sim \mathcal{N}(0,I)\),编码过程 \(q\),解码过程 \(p\)。
GAN网络
VAE网络
Diffusion网络
加噪编码:
去噪解码:
DDPM加噪过程分成了 t 步,每次只加入少量噪声,同样的去噪过程也分为t步,使得模型更容易学习,生成效果更稳定。
2. 原理解析
2.1 加噪过程
此过程中,时刻t时图片 \(x_t\) 只依赖于上一时刻图片 \(x_{t-1}\),故 将整个加噪过程建模成马尔可夫过程,则有
对于单个加噪过程,希望每次加入的噪声比较小,时间相邻的图片差异不要太大。
设加入噪声\(\sqrt{\beta_t}\varepsilon_t \sim \mathcal{N}(0,\beta_tI),\quad \varepsilon_t \sim \mathcal{N}(0,I), \beta_t \to 0^+\) :
由 \(x_t=\mu_t + \sigma_t \varepsilon\) (重参数化)可知,\(x_t\) 的均值为 \(k_t x_{t-1}\),方差为 \(\beta_t\) ,
其PDF定义为:
把 \(x_t\) 展开:
根据 \(X+Y \sim \mathcal{N}(\mu_1+\mu_2, \sigma_1^2 + \sigma_2^2)\) 有 \(x_t\)的方差:
希望最后一次加噪后的\(x_t\)为\(\mathcal{N}(0,I)\)分布纯噪声,即 \(\sigma_t^2 =1\)。解得 \(\beta_t=1-k_t^2\),即 \(k_t=\sqrt{1-\beta_t}, 0<k_t<1\),此时的均值 \(\mu_t= k_t k_{t-1}\cdots k_1 x_0\) 恰好也收敛于 0 。
令 \(k_t=\sqrt{\alpha_t}=\sqrt{1-\beta_t}\),则有 \(\alpha_t + \beta_t =1\)。
令 \(\bar{\alpha}_t=\alpha_t \alpha_{t-1}\cdots \alpha_1=\prod_{i=1}^t \alpha_i\),有
其中 \(\bar{\varepsilon} \sim \mathcal{N}(0,I)\) 为纯噪声,\(\bar{\alpha}_t\) 控制每步掺入噪声的强度,可以设置为超参数。故 整个加噪过程是确定的,不需要模型学习,相当于给每步时间 t 制作的了一份可供模型学习的样本标签数据。公式 (1) 描述局部过程分布,公式 (2) 描述整体过程分布。
2.2 去噪解码
设,神经网络的完整去噪过程为 \(p_\theta(x_{0}|x_t)\),由马尔科夫链有
对于单步去噪的分布则定义为
其中 均值 \(\mu_\theta\) 和 协方差 \(\Sigma_\theta\)(暂定)需要模型去学习得到。
如何从每步加噪 \(q(x_t|x_{t-1})\) 中,学习去噪的知识呢?我们先观察一下其逆过程 \(q(x_{t-1}|x_t)\) 的分布形式。
高斯分布:
逆过程:
由此可知 \(q(x_{t-1}|x_t)\) 服从高斯分布 \(\mathcal{N}(x_{t-1};\mu_q(x_t,x_0),\Sigma_q(t))\) , 其中 均值\(\mu_q(x_t,x_0)\)可以看作是只与 \(x_0\) 有关的函数,\(\Sigma_q(t)=\sigma_q^2(t)I\) 只与时间步 t 有关,可以看作常数,令 \(\Sigma_\theta(x_t,t)=\Sigma_q(t)\) 即可。
此时,确定优化目标只需要 \(q(x_{t-1}|x_t)\) 和 \(p_\theta(x_{t-1}|x_t)\) 两个分布尽可能相似即可,即最小化两个分布的KL散度来度量。
两个高斯分布的KL散度公式:
代入公式
此时,KL散度最小,只要两个分布的均值 \(\mu_q(x_t,x_0)\) \(\mu_\theta(x_t,t)\) 相近即可。又:
所以,相似的可以把 \(\mu_\theta(x_t,t)\) 设计成如下形式:
代入 \(\mu_q(x_t,x_0)\) \(\mu_\theta(x_t,t)\) 可得:
至此, \(\underset{\theta}{argmin} D_{KL}(q(x_{t-1}|x_t)||p_\theta(x_{t-1}|x_t))\) 的问题,被转化成了通过给定 \((x_t,t)\) 让模型预测图片 \(\hat{x}_\theta(x_t,t)\) 对比 真实图片 \(x_0\) 的问题 (Dalle2的训练采用此方式)。
然而,后续研究发现通过噪声直接预测图片的训练效果不太理想,DDPM通过重参数化把对图片的预测转化成对噪声的预测,获得了更好的实验效果。
\(\mu_q(x_t,x_0)\) 代入 \(x_0\)
同理,\(\mu_\theta(x_t,t)\) 也设计成相近的形式:
计算此时的KL散度:
由此可见,这两种方法在本质上是等价的,\(\Vert \bar{\varepsilon}_0 - \hat{\epsilon}_\theta(x_t,t) \Vert_2^2\) 更关注对加入的噪声的预测。
通俗的解释:给模型 一张加过噪的图片 和 时间步,让模型预测出最初的纯噪声长什么样子。
对于整个过程损失只需每步的损失都是最小即可:\(min\ L_{simple} = \mathbb{E}_{t\sim U\{2,T\}}\underset{\theta}{argmin}[\Vert \bar{\varepsilon}_0 - \hat{\epsilon}_\theta(x_t,t) \Vert_2^2]\)。
2.2.1 模型推理的采样过程
由模型的单步去噪的分布定义:
重参数化,代入 \(\mu_\theta\) \(\Sigma_\theta\):
综上,
-
训练过程:抽取图片 \(x_0\) 和 噪声 \(\bar{\varepsilon}_0\),循环时间步 \(t\sim U\{1,\cdots,T\}\):
- 预测噪声损失 \(\Vert \bar{\varepsilon}_0 - \hat{\epsilon}_\theta(\sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \bar{\varepsilon}_0,t) \Vert_2^2\)
- 然后抽取下一张图片,重复过程,直至损失收敛。
-
推理过程:随机初始噪声图片 \(x_T\),倒序循环时间步 \(t\sim U\{T,\cdots,1\}\) :
- 用模型采样噪声 \(\epsilon_\theta(x_t,t)\)
- 【噪声图片 \(x_t\)】-【采样噪声\(\epsilon_\theta\)】+【随机扰动方差噪声\(\sigma_q(t) \mathrm{z}\)】=【噪声图片 \(x_{t-1}\)】
- 【噪声图片 \(x_{t-1}\)】进行处理,至 \(t=1\) 时,不加扰动
- 【噪声图片 \(x_1\)】-【采样噪声\(\epsilon_\theta\)】= 【真实图片 \(x_{0}\)】
代码实现
超参数选择
时间步 \(T=1000\)
令 \(\bar{\beta_t}={1-\bar{\alpha}_t}= \prod \beta_t\)
加噪过程,每次只加入少许噪声,加噪至最后图片已经接近纯噪声需要稍微加大噪声。
确定取值范围 \(\beta_t \in [0.0001, 0.002]\)
!export CUDA_LAUNCH_BLOCKING=1
import os
from pathlib import Path
import math
import torch
import torchvision
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch.nn import Module, ModuleList
from torch import nn
import torch.nn.functional as F
from torch.optim import Adamfrom einops import rearrange
from einops.layers.torch import Rearrangeimport matplotlib.pyplot as plt
import PIL
class MyDataset(Dataset):def __init__(self,folder,image_size = 96,exts = ['jpg', 'jpeg', 'png', 'tiff']):super().__init__()self.folder = folderself.image_size = image_sizeself.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]self.transform = transforms.Compose([transforms.Resize((self.image_size, self.image_size)),transforms.RandomHorizontalFlip(),transforms.ToTensor(), # Scales data into [0,1]transforms.Lambda(lambda t: (t * 2) - 1), # Scale between [-1, 1]])def __len__(self):return len(self.paths)def __getitem__(self, index):path = self.paths[index]img = PIL.Image.open(path)return self.transform(img)def show_img_batch(batch, num_samples=16, cols=4):reverse_transforms = transforms.Compose([transforms.Lambda(lambda t: (t + 1) / 2), # [-1,1] -> [0,1]transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWCtransforms.Lambda(lambda t: t * 255.0),transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),transforms.ToPILImage(),])"""Plots some samples from the dataset"""plt.figure(figsize=(10, 10))for i in range(batch.shape[0]):if i == num_samples:breakplt.subplot(int(num_samples / cols) + 1, cols, i + 1)plt.imshow(reverse_transforms(batch[i, :, :, :]))plt.show()# dataset = MyDataset("/data/dataset/img_align_celeba")
# # drop_last 保证 batch_size 一致
# loader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)# for batch in loader:
# show_img_batch(batch)
# break
#时间转向量
class SinusoidalPosEmb(Module):def __init__(self, dim, theta = 10000):super().__init__()self.dim = dimself.theta = thetadef forward(self, x):device = x.devicehalf_dim = self.dim // 2emb = math.log(self.theta) / (half_dim - 1)emb = torch.exp(torch.arange(half_dim, device=device) * -emb)emb = x[:, None] * emb[None, :]emb = torch.cat((emb.sin(), emb.cos()), dim=-1)return embclass Block(Module):def __init__(self, dim, dim_out, up = False, dropout = 0.001):super().__init__()self.norm = nn.GroupNorm(num_groups=32, num_channels=dim) # 总通道dim,分为32组 self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)self.act = nn.SiLU()self.dropout = nn.Dropout(dropout)def forward(self, x, scale_shift = None):x = self.norm(x)x = self.proj(x)if (scale_shift is not None):scale, shift = scale_shiftx = x * (scale + 1) + shiftx = self.act(x)return self.dropout(x)class AttnBlock(nn.Module):def __init__(self, in_ch, out_ch, time_mlp_dim, up=False, isLast = False):super().__init__()# 使用 AdaLN 生成 scale, shift 这里输出channel乘2再拆分self.time_condition = nn.Sequential(nn.Linear(time_mlp_dim, time_mlp_dim),nn.SiLU(),nn.Linear(time_mlp_dim, out_ch * 2))self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_ch)self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_ch)self.act1 = nn.SiLU()self.short_cut = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()# self.attention = Attention(out_ch, heads = 4)if up:self.sample = nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'nearest'), # 直接插值,原地卷积nn.Conv2d(in_ch, in_ch, 3, padding = 1))else: # downself.sample = nn.Conv2d(in_ch, in_ch, 3, 2, 1)if isLast: # 最底层不缩放self.sample = nn.Conv2d(out_ch, out_ch, 3, padding = 1)self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding = 1, groups=32)self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding = 1)self.dropout = nn.Dropout(0.01)def forward(self, x, time_emb):x = self.norm1(x)h = self.sample(x)x = self.sample(x)h = self.act1(h)h = self.conv1(h)# (b, time_mlp_dim) -> (b, c+c)time_emb = self.time_condition(time_emb)time_emb = rearrange(time_emb, 'b c -> b c 1 1') # b c+c 1 1scale, shift = time_emb.chunk(2, dim = 1)h = h * (scale + 1) + shifth = self.norm2(h)h = self.act1(h)h = self.conv2(h)h = self.dropout(h)h = h + self.short_cut(x)return hclass Unet(Module):def __init__(self,image_channel = 3,init_dim = 32,down_channels = [(32, 64), (64, 128), (128, 256), (256, 512)],up_channels = [(512+512, 256), (256+256, 128), (128+128, 64), (64+64, 32)], #这里需要跳跃拼接,所以输入维度有两个拼一起):super().__init__()self.init_conv = nn.Conv2d(image_channel, init_dim, 7, padding = 3)time_pos_dim = 32time_mlp_dim = time_pos_dim * 4self.time_mlp = nn.Sequential(SinusoidalPosEmb(time_pos_dim),nn.Linear(time_pos_dim, time_mlp_dim),nn.GELU())self.downs = ModuleList([])self.ups = ModuleList([])for i in range(len(down_channels)):down_in, down_out = down_channels[i]isLast = len(down_channels)-1 == i # 底部self.downs.append(AttnBlock(down_in, down_out, time_mlp_dim, up=False, isLast=False))for i in range(len(up_channels)):up_in, up_out = up_channels[i]isLast = 0 == i # 底部self.ups.append(AttnBlock(up_in, up_out, time_mlp_dim, up=True, isLast=False))self.output = nn.Conv2d(init_dim, image_channel, 1)for p in self.parameters():if p.dim() > 1:nn.init.xavier_uniform_(p) #初始化权重def forward(self, img, time):# 时间转向量 time (b, 1) -> (b, time_mlp_dim)time_emb = self.time_mlp(time)x = self.init_conv(img)skip_connect = []for down in self.downs:x = down(x, time_emb)skip_connect.append(x)for up in self.ups:x = torch.cat((x, skip_connect.pop()), dim=1) #先在通道上拼接 再输入x = up(x, time_emb)return self.output(x)# device = "cuda"
# model = Unet().to(device)
# img_size = 64
# img = torch.randn((1, 3, img_size, img_size)).to(device)
# time = torch.tensor([4]).to(device)
# print(model(img, time).shape)
预测噪声损失 \(\Vert \bar{\varepsilon}_0 - \hat{\epsilon}_\theta(\sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \bar{\varepsilon}_0,t) \Vert_2^2\)
最大时间步 \(T=1000\)
加噪过程,每次只加入少许噪声,加噪至最后图片已经接近纯噪声需要稍微加大噪声。
确定取值范围 \(\beta_t \in [0.0001, 0.02]\) 令 \(\bar{\beta_t}={1-\bar{\alpha}_t}\)
模型输入参数:
img = \((\sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \bar{\varepsilon}_0)\)
time = \(t\)
image_size=96
epochs = 500
batch_size = 128
device = 'cuda'
T=1000
betas = torch.linspace(0.0001, 0.02, T).to('cuda') # torch.Size([1000])# train
alphas = 1 - betas # 0.9999 -> 0.98
alphas_cumprod = torch.cumprod(alphas, axis=0) # 0.9999 -> 0.0000
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1-alphas_cumprod)def get_val_by_index(val, t, x_shape):batch_t = t.shape[0]out = val.gather(-1, t)return out.reshape(batch_t, *((1,) * (len(x_shape) - 1))) # torch.Size([batch_t, 1, 1, 1])def q_sample(x_0, t, noise):device = x_0.devicesqrt_alphas_cumprod_t = get_val_by_index(sqrt_alphas_cumprod, t, x_0.shape)sqrt_one_minus_alphas_cumprod_t = get_val_by_index(sqrt_one_minus_alphas_cumprod, t, x_0.shape)noised_img = sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noisereturn noised_img# 校验输入的噪声图片
def print_noised_img():# print(alphas)# print(sqrt_alphas_cumprod)# print(sqrt_one_minus_alphas_cumprod)test_dataset = MyDataset("/data/dataset/img_align_celeba", image_size=96)# drop_last 保证 batch_size 一致test_loader = DataLoader(dataset, batch_size=1, shuffle=True, drop_last=True)for batch in test_loader:test_img = batch.to(device)# img = torch.randn((batch_size, 3, 64, 64)).to(device)time = torch.randint(190, 191, (1,), device=device).long()out = get_val_by_index(sqrt_alphas_cumprod, time, img.shape)# print('out', out)noise = torch.randn_like(test_img, device=img.device)# print('noise', noise, torch.mean(noise))noised_img = q_sample(test_img, time, noise)# print(noised_img.shape)show_img_batch(noised_img.detach().cpu())break
def train(checkpoint_prefix=None, start_epoch = 0):dataset = MyDataset("/data/dataset/img_align_celeba", image_size=image_size)# drop_last 保证 batch_size 一致loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)model = Unet().to(device)if start_epoch > 0:model.load_state_dict(torch.load(f'{checkpoint_prefix}_{start_epoch}.pth'))# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)for epoch in range(epochs):epoch += 1batch_idx = 0for batch in loader:batch_idx += 1img = batch.to(device)time = torch.randint(0, T, (batch_size,), device=device).long()noise = torch.randn_like(img, device=img.device)noised_img = q_sample(img, time, noise)noise_pred = model(noised_img, time)# loss = F.mse_loss(noise, noise_pred, reduction='sum')loss = (noise - noise_pred).square().sum(dim=(1, 2, 3)).mean(dim=0)loss.backward()optimizer.step()optimizer.zero_grad()if batch_idx % 100 == 0:print(f"Epoch {epoch+start_epoch} | Batch index {batch_idx:03d} Loss: {loss.item()}")if epoch % 100 == 0:torch.save(model.state_dict(), f'{checkpoint_prefix}_{start_epoch+epoch}.pth')train(checkpoint_prefix='ddpm_T1000_l2_epochs', start_epoch = 0)
Epoch 1 | Batch index 100 Loss: 20164.40625
...
Epoch 300 | Batch index 100 Loss: 343.15917
训练完成后,从模型中采样图片
\(x_{t-1}=\dfrac{1}{\sqrt{{\alpha}_t}}\left( x_t - \dfrac{1-{\alpha}_{t}}{\sqrt{1-\bar{\alpha}_t}}\hat{\epsilon}_\theta(x_t,t) \right) + \sigma_q(t) \mathrm{z}\)
\(\sigma^2_q(t)={\dfrac{(1-{\alpha}_{t})(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}}\) ,则 \(\sigma_q(t)=\sqrt{\dfrac{(1-{\alpha}_{t})(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}}\) 且 t=1 时,\(\sigma_q(t)=0\)
torch.cuda.empty_cache()
model = Unet().to(device)
model.load_state_dict(torch.load('ddpm_T1000_l2_epochs_300.pth'))
model.eval()
print("")
# 计算去噪系数
one_divided_sqrt_alphas = 1 / torch.sqrt(alphas)
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod# 计算方差
alphas_cumprod = alphas_cumprod
# 去掉最后一位,(1, 0)表示左侧 填充 1.0, 表示 alphas t-1
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)std_variance_q = torch.sqrt((1-alphas) * (1-alphas_cumprod_prev) / (1-alphas_cumprod))def p_sample_ddpm(model):def step_denoise(model, x_t, t):one_divided_sqrt_alphas_t = get_val_by_index(one_divided_sqrt_alphas, t, x_t.shape)one_minus_alphas_t = get_val_by_index(1-alphas, t, x_t.shape)one_minus_alphas_cumprod_t = get_val_by_index(sqrt_one_minus_alphas_cumprod, t, x_t.shape)mean = one_divided_sqrt_alphas_t * (x_t - (one_minus_alphas_t / one_minus_alphas_cumprod_t * model(x_t, t)))std_variance_q_t = get_val_by_index(std_variance_q, t, x_t.shape)if t == 0:return meanelse:noise_z = torch.randn_like(x_t, device=x_t.device)return mean + std_variance_q_t * noise_zimg_pred = torch.randn((4, 3, image_size, image_size), device=device)for i in reversed(range(0, T)):t = torch.tensor([i], device=device, dtype=torch.long)img_pred = step_denoise(model, img_pred, t)# print(img_pred.mean())# 每步截断效果比较差# if i % 200 == 0:# img_pred = torch.clamp(img_pred, -1.0, 1.0)torch.cuda.empty_cache()return img_predwith torch.no_grad():img = p_sample_ddpm(model)img = torch.clamp(img, -1.0, 1.0)show_img_batch(img.detach().cpu())
参考文章:
Denoising Diffusion Probabilistic Model
Improved Denoising Diffusion Probabilistic Models
Understanding Diffusion Models: A Unified Perspective
Improved Denoising Diffusion Probabilistic Models