Diffusion Model: DDPM

本文相关内容只记录看论文过程中一些难点问题,内容间逻辑性不强,甚至有点混乱,因此只作为本人“备忘”,不建议其他人阅读。

Denoising Diffusion Probabilistic Models: https://arxiv.org/abs/2006.11239

DDPM

一、基于 x_0 已知的情况下,x_t 分布的推导过程:推导过程中,直接递归迭代即可。同时,过程中使用了 —— 两个高斯分布的和也满足高斯分布,其中均值为两个高斯分布均值的和,方差为两个高斯分布方差的和。

二、逆向过程中,q(x_{t-1}|x_t, x_0) 分布求解

进一步根据 1 中的结果可得:

公式 9 中的 z_{\theta}(x_t,t) 就是 diffusion model 需要估计的噪声均值,而噪声的方式是由 \alpha_t 或者 \beta_t 直接得到的。

三、具体训练过程:训练过程比较直接,利用 一 中的公式即可。

https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py L274def q_sample(self, x_start, t, noise=None):noise = default(noise, lambda: torch.randn_like(x_start))return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)def get_loss(self, pred, target, mean=True):if self.loss_type == 'l1':loss = (target - pred).abs()if mean:loss = loss.mean()elif self.loss_type == 'l2':if mean:loss = torch.nn.functional.mse_loss(target, pred)else:loss = torch.nn.functional.mse_loss(target, pred, reduction='none')else:raise NotImplementedError("unknown loss type '{loss_type}'")return loss# 输入参数说明:
# x_start:原始图像 x0
# t:当前扩散步数
# noise:噪声,需要注意这里的 noise 与 x_start 维度相同;具体含义是每个位置上元素都服从 0-1 高斯分布
def p_losses(self, x_start, t, noise=None):# 生成第 t 步的高斯噪声noise = default(noise, lambda: torch.randn_like(x_start))# 根据本文 一 中推导的公式得到第 t 步加噪后的图像x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)# 模型预测结果,根据具体的设置,好像可以回归加的噪声,也可以直接回归原始图像model_out = self.model(x_noisy, t)loss_dict = {}if self.parameterization == "eps":# 模型估计噪声target = noiseelif self.parameterization == "x0":# 模型直接估计原始图像target = x_startelse:raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")# 使用 L1 或者 L2 Loss 计算误差loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])log_prefix = 'train' if self.training else 'val'loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})loss_simple = loss.mean() * self.l_simple_weightloss_vlb = (self.lvlb_weights[t] * loss).mean()loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})loss = loss_simple + self.original_elbo_weight * loss_vlbloss_dict.update({f'{log_prefix}/loss': loss})return loss, loss_dict

四、具体生成(采样)过程:根据 二 中推导的公式,依次计算前一步图像的分布。需要注意:

  1. 具体回归的均值的维度与图像维度完全相同,即图像每个位置(包括不同通道)都建模为高斯分布,均值就是无随机时图像应该有的“样子”
  2. 因此,在 T=0 步得到的均值就是最终生成的图像;不过在 T> 0 步依据均值和方差进行采样,可能的原因是增加生成图像的多样性。

https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py L222# 根据本文 二 中的公式计算 x_t-1 的均值和方差
def q_posterior(self, x_start, x_t, t):posterior_mean = (extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t)posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)return posterior_mean, posterior_variance, posterior_log_variance_clippeddef p_mean_variance(self, x, t, clip_denoised: bool):model_out = self.model(x, t)if self.parameterization == "eps":x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)elif self.parameterization == "x0":x_recon = model_outif clip_denoised:x_recon.clamp_(-1., 1.)model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)return model_mean, posterior_variance, posterior_log_variance# 基于估计的图像每个位置的均值 model_mean 和方差 model_log_variance 生成对应随机图像
@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):b, *_, device = *x.shape, x.devicemodel_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)noise = noise_like(x.shape, device, repeat_noise)# no noise when t == 0nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise# 从 T 步 ——> T-1 步 ——> ... ——> 0 步,依次进行反向估计
@torch.no_grad()
def p_sample_loop(self, shape, return_intermediates=False):device = self.betas.deviceb = shape[0]img = torch.randn(shape, device=device)intermediates = [img]for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),clip_denoised=self.clip_denoised)if i % self.log_every_t == 0 or i == self.num_timesteps - 1:intermediates.append(img)if return_intermediates:return img, intermediatesreturn img# 采样入口函数,batch_size 一次生成的图像数量
@torch.no_grad()
def sample(self, batch_size=16, return_intermediates=False):image_size = self.image_sizechannels = self.channelsreturn self.p_sample_loop((batch_size, channels, image_size, image_size),return_intermediates=return_intermediates)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/216499.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【开源】基于Vue和SpringBoot的农家乐订餐系统

项目编号: S 043 ,文末获取源码。 \color{red}{项目编号:S043,文末获取源码。} 项目编号:S043,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 用户2.2 管理员 三、系统展示四、核…

Windows下安装MySQL

几年前学习mycat中间件的时候在window机器上安装过MySql,但是由于电脑配置不高,同时打开Mysql服务,idea、SQlyog等软件非常卡,再加上SQLyog和MySQL版本不兼容导致登录不上,于是把它卸载了。最近做练习需要,…

java学习

【点我-这里送书】 本人详解 作者:王文峰,参加过 CSDN 2020年度博客之星,《Java王大师王天师》 公众号:JAVA开发王大师,专注于天道酬勤的 Java 开发问题中国国学、传统文化和代码爱好者的程序人生,期待你的关注和支持!本人外号:神秘小峯 山峯 转载说明:务必注明来源(…

Python中如何选择Web开发框架?

Python开发中Web框架可谓是百花齐放,各式各样的web框架层出不穷,那么对于需要进行Python开发的我们来说,如何选择web框架也就变成了一门学问了。本篇文章主要是介绍目前一些比较有特点受欢迎的Web框架,我们可以根据各个Web框架的特…

Leetcode103 二叉树的锯齿形层序遍历

二叉树的锯齿形层序遍历 题解1 层序遍历双向队列 给你二叉树的根节点 root ,返回其节点值的 锯齿形层序遍历 。(即先从左往右,再从右往左进行下一层遍历,以此类推,层与层之间交替进行)。 提示&#xff1a…

LLMLingua:集成LlamaIndex,对提示进行压缩,提供大语言模型的高效推理

大型语言模型(llm)的出现刺激了多个领域的创新。但是在思维链(CoT)提示和情境学习(ICL)等策略的驱动下,提示的复杂性不断增加,这给计算带来了挑战。这些冗长的提示需要大量的资源来进行推理,因此需要高效的解决方案,本文将介绍LLM…

深入理解MySQL索引及事务

✏️✏️✏️今天给各位带来的是关于数据库索引以及事务方面的基础知识 清风的CSDN博客 😛😛😛希望我的文章能对你有所帮助,有不足的地方还请各位看官多多指教,大家一起学习交流! 动动你们发财的小手&#…

Django必备知识点(图文详解)

目录 day02 django必备知识点 1.回顾 2.今日概要 3.路由系统 3.1 传统的路由 3.2 正则表达式路由 3.3 路由分发 小结 3.4 name 3.5 namespace 3.4 最后的 / 如何解决? 3.5 当前匹配对象 小结 4.视图 4.1 文件or文件夹 4.2 相对和绝对导入urls​编辑…

ubuntu22.04 arrch64版在线安装java环境

脚本 #安装java#!/bin/bashif type -p java; thenecho "Java has been installed."else#2.Installed Java , must install wgetwget -c https://repo.huaweicloud.com/java/jdk/8u151-b12/jdk-8u151-linux-arm64-vfp-hflt.tar.gz;tar -zxvf ./jdk-8u151-linux-arm6…

C#,《小白学程序》第九课:堆栈(Stack),先进后出的数据型式

1 文本格式 /// <summary> /// 《小白学程序》第九课&#xff1a;堆栈&#xff08;Stack&#xff09; /// 堆栈与队列是相似的数据形态&#xff1b;特点是&#xff1a;先进后出&#xff1b; /// 比如&#xff1a;狭窄的电梯&#xff0c;先进去的人只能最后出来&#xff1…

逸学java【初级菜鸟篇】10.I/O(输入/输出)

hi&#xff0c;我是逸尘&#xff0c;一起学java吧 目标&#xff08;任务驱动&#xff09; 1.请重点的掌握I/O的。 场景&#xff1a;最近你在企业也想搞一个短视频又想搞一个存储的云盘&#xff0c;你一听回想到自己对于这些存储的基础还不是很清楚&#xff0c;于是回家开始了…

ubuntu下配置qtcreator交叉编译环境

文章目录 安装交叉编译工具安装qt creator开发环境配置交叉编译示例demo参考 安装交叉编译工具 安装qt creator开发环境 1 官网 2 填写信息 3 下载 默认没有出现Qt5.15版本 WISONIC\80081001ub16-1001:~$ /opt/Qt/Tools/QtCreator/bin/qtcreator /opt/Qt/Tools/QtCreat…