diffusion model(四)文生图diffusion model(classifier-free guided)

文章目录

系列阅读

  • diffusion model(一)DDPM技术小结 (denoising diffusion probabilistic)
  • diffusion model(二)—— DDIM技术小结
  • diffusion model(三)—— classifier guided diffusion model
URL
paperClassifier-Free Diffusion Guidance
GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models
githubhttps://github.com/openai/glide-text2im

文生图diffusion model(classifier-free guided)

背景

在classifier-guided这篇博客我们提到对于一般的DM(如DDPM, DDIM)的采样过程是直接从一个噪声分布,通过不断采样来生成图片。但这个方法生成的图片类别是随机的,classifier-guided通过额外训练一个分类器来不断矫正每一个时间步的生成图片,最终实现特定类别图片的生成。

Classifier-free的核心思路是:我们无需训练额外的分类器,直接训练带类别信息的噪声预测模型来实现特定类别图片的生成,即 ϵ θ ( x t , t ) → ϵ ^ θ ( x t , y , t ) \epsilon_{\theta}(x_t, t) \rightarrow \hat{\epsilon}_{\theta}(x_t, y, t) ϵθ(xt,t)ϵ^θ(xt,y,t)。从而简化整体的pipeline。

此外,classifier-free方法不局限于类别信息的融入,它还能实现将语义信息融入到diffusion model中,实现更为灵活的文生图。这用classifier-guide是很难做到的。目前的很多工作如DALLE,Stable Diffusion, Imagen等都是Classifier-free形式。如:

在这里插入图片描述

下面我们来看他是怎么做的吧!

方法大意

classifier-free diffusion的实现非常简单。下面对比普通的diffusion model,classifier-guided与classifier-free三种方式的差异。

模型训练目标实现功能训练数据
DM (DDPM, DDIM) ϵ θ ( x t , t ) \epsilon_{\theta}(x_t, t) ϵθ(xt,t)从服从高斯分布的噪声中生成图片图片
classifier-guided DM ϵ θ ( x t , t ) \epsilon_{\theta}(x_t, t) ϵθ(xt,t)和分类器 p ( y ∣ x t ) p(y|x_t) p(yxt)从服从高斯分布的噪声中生成特定类别的图片DM:图片 分类器:图片-标签对
classifier-free DM ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t), ϵ θ ( x t , t ) \epsilon_{\theta}(x_t, t) ϵθ(xt,t)从服从高斯分布的噪声中生成符合文本描述的图片图片-文本对
  • 对于训练 ϵ θ ( x t , t ) \epsilon_{\theta}(x_t, t) ϵθ(xt,t)来估计 x t x_t xt在时间 t t t上添加的噪声,再根据采样公式推出 x t − 1 x_{t-1} xt1,从而实现图片生成。训练数据只需要准备图片即可。
  • 对于classifier-guided DM是在普通DM的基础上,额外再训练一个Classifier来获得当前时间步生成的图片类别概率分布,从而实现特定类别的图片生成。
  • 对于classifier-free DM将类别信息(或语义信息)集成到diffusion model的训练过程中,训练 ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t) ϵ θ ( x t , y = ∅ , t ) ( 即 ϵ θ ( x t , t ) ) \epsilon_{\theta}(x_t, y=\empty,t)(\text{即}\epsilon_{\theta}(x_t,t)) ϵθ(xt,y=,t)(ϵθ(xt,t))。训练的时候也会加入无类别信息(或语义信息)的图片进行训练。

回答3个问题深入理解classifier-free DM

  1. 模型如何融入类别信息(或语义信息)
  2. 如何训练 ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t) ϵ θ ( x t , y = ∅ , t ) \epsilon_{\theta}(x_t, y=\empty,t) ϵθ(xt,y=,t)
  3. 如何进行采样生成

模型如何融入类别信息(或语义信息)

采用交叉注意力机制融入

我们知道,深度学习模型推理的本质可以理解为一系列的数值计算,因此将类别信息(或语义信息)融入到模型中需要预先将其转化为数值。转化的方法有很多,如可以用一个embedding layer。也可以用NLP模型,如Bert、T5、CLIP的text encoder等将类别信息(或语义信息)转化为数值向量,一般称为text embedding。随后需要将text embedding和原本模型中的image representation进行融合。最为常见且有效的方法是用交叉注意力机制CrossAttention。具体来说就是将text embedding作为注意力机制中的keyvalue,原始的图片表征作为query。大家熟知的Stable Diffusion用的就是这个融入方法。交叉注意力机制融入语义信息的本质是spatial-wise attention。

在这里插入图片描述

class SpatialCrossAttention(nn.Module):def __init__(self, dim, context_dim, heads=4, dim_head=32) -> None:super(SpatialCrossAttention, self).__init__()self.scale = dim_head ** -0.5self.heads = headshidden_dim = dim_head * headsself.proj_in = nn.Conv2d(dim, context_dim, kernel_size=1, stride=1, padding=0)self.to_q = nn.Linear(context_dim, hidden_dim, bias=False)self.to_k = nn.Linear(context_dim, hidden_dim, bias=False)self.to_v = nn.Linear(context_dim, hidden_dim, bias=False)self.to_out = nn.Conv2d(hidden_dim, dim, 1)def forward(self, x, context=None):x_q = self.proj_in(x) b, c, h, w = x_q.shapex_q = rearrange(x_q, "b c h w -> b (h w) c")if context is None:context = x_qif context.ndim == 2:context = rearrange(context, "b c -> b () c")q = self.to_q(x_q)k = self.to_k(context)v = self.to_v(context)q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=self.heads), (q, k, v))sim = einsum('b i d, b j d -> b i j', q, k) * self.scale# attention, what we cannot get enough ofattn = sim.softmax(dim=-1)out = einsum('b i j, b j d -> b i d', attn, v)out = rearrange(out, '(b h) n d -> b n (h d)', h=self.heads)out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w)out = self.to_out(out)return out 

基于channel-wise attention融入

该融入方法与time-embedding的融入方法相同,在时间中往往会预先和time-embedding进行融合,再融入到图片特征中,伪代码如下:

# mixture time-embedding and label embedding
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
if self.num_classes is not None:assert y.shape == (x.shape[0],)emb = emb + self.label_emb(y)
while len(emb_out.shape) < len(h.shape):emb_out = emb_out[..., None]
emb_out = self.emb_layers(emb).type(h.dtype)  # h is image feature
scale, shift = th.chunk(emb_out, 2, dim=1)  # One half of the embedding is used for scaling and the other half for offset
h = h * (1 + scale) + shift  

基于channel-wise的融入粒度没有CrossAttention细。一般适用类别数量有限的特征融入,如时间embedding,类别embedding。而语义信息的融入更推荐上面CrossAttention的方法。

在这里插入图片描述

如何训练 ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t) ϵ θ ( x t , y = ∅ , t ) \epsilon_{\theta}(x_t, y=\emptyset,t) ϵθ(xt,y=,t)

ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t)的训练需要图文对,但互联网上具备文本描述的图片只是浩如烟海的图片海洋中的一小部分。仅用具备图文对数据训练 ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t)将会大大束缚DM的生成多样性。另外,为了使得模型更好的捕获图文的联系 ϵ θ ( x t , y = ∅ , t ) \epsilon_{\theta}(x_t, y=\empty,t) ϵθ(xt,y=,t)的数据不宜过多,否则模型生成结果的保真度会降低。反之,若 ϵ θ ( x t , y = ∅ , t ) \epsilon_{\theta}(x_t, y=\empty,t) ϵθ(xt,y=,t)数据过少,将会影响生成结果的多样性。需要根据实际的场景进行调整。

有两个实践中的trick需要注意

  • 在实践中,为了统一 y = ∅ y=\empty y= y ≠ ∅ y \neq \empty y=两种情形,通常会给定一个 y = ∅ y=\empty y=的embedding(可以随机初始化,也可以人为给定),来统一两种情形的建模。
  • 即使所有的数据都有图片对也没有关系,只需在每一个batch中随机将某些数据的标签编码替换为 y = ∅ y=\empty y=的embedding即可。另外

如何进行采样生成

classifier-free diffusion的采样生成过程与前面介绍的DDPM,DDIM类似。唯一有所区别的是将原本的 ϵ ( x t , t ) \epsilon(x_t, t) ϵ(xt,t)用下式代替。
ϵ ^ θ ( x t , y , t ) = ϵ θ ( x t , y = ∅ , t ) + s [ ϵ θ ( x t , y , t ) − ϵ θ ( x t , y = ∅ , t ) ] \begin{align} \hat{\epsilon}_{\theta}(x_t, y, t)=\epsilon_{\theta}(x_t, y=\empty,t) + s[\epsilon_{\theta}(x_t, y, t) - \epsilon_{\theta}(x_t, y=\empty, t) ]\tag{1} \end{align} ϵ^θ(xt,y,t)=ϵθ(xt,y=,t)+s[ϵθ(xt,y,t)ϵθ(xt,y=,t)](1)

下面给出详细的推导过程:

首先根据贝叶斯公式有
p ( y ∣ x t ) = p ( x t ∣ y ) p ( y ) ⏞ 先验分布 p ( x t ) ⇒ p ( y ∣ x t ) ∝ p ( x t ∣ y ) / p ( x t ) ⇒ 取对数 log ⁡ p ( y ∣ x t ) = log ⁡ p ( x t ∣ y ) − log ⁡ p ( x t ) ⇒ 对 x t 求导 ∇ x t log ⁡ p ( y ∣ x t ) = ∇ x t log ⁡ p ( x t ∣ y ) − ∇ x t log ⁡ p ( x t ) ⇒ 根据score function ∇ x t log ⁡ p θ ( x t ) = − 1 1 − α ‾ t ϵ θ ( x t ) ∇ x t log ⁡ p ( y ∣ x t ) = − 1 1 − α ‾ t ( ϵ θ ( x t , y , t ) − ϵ θ ( x t , y = ∅ , t ) ) (2) \begin{aligned} p (y| x_t) & = \frac{p (x_t|y) \overbrace{p(y)}^{\text{先验分布}} } {p(x_t) } \\ \Rightarrow p (y| x_t) & \propto p (x_t|y) / {p (x_t) } \\ \stackrel{取对数} \Rightarrow \log{p (y| x_t)} & = \log{p (x_t|y)} - \log{{p (x_t) }} \\ \stackrel{对x_t求导} \Rightarrow \nabla_{x_t}\log{p (y| x_t)} & = \nabla_{x_t}\log{p (x_t|y)} - \nabla_{x_t}\log{{p (x_t) }} \\ \stackrel{\text{根据score function} \nabla_{x_t} \log p_\theta (x_t) = - \frac{1}{\sqrt{1 - \overline{\alpha}_t}} \epsilon_\theta(x_t)} \Rightarrow \nabla_{x_t}\log{p (y| x_t)} & = - \frac{1}{\sqrt{1 - \overline{\alpha}_t}}(\epsilon_{\theta}(x_t, y, t) - \epsilon_{\theta}(x_t, y=\empty, t) ) \end{aligned} \tag{2} p(yxt)p(yxt)取对数logp(yxt)xt求导xtlogp(yxt)根据score functionxtlogpθ(xt)=1αt 1ϵθ(xt)xtlogp(yxt)=p(xt)p(xty)p(y) 先验分布p(xty)/p(xt)=logp(xty)logp(xt)=xtlogp(xty)xtlogp(xt)=1αt 1(ϵθ(xt,y,t)ϵθ(xt,y=,t))(2)
当我们得到 ∇ x t log ⁡ p ( y ∣ x t ) \nabla_{x_t}\log{p (y| x_t)} xtlogp(yxt),参考classifier-guided的式(17)
ϵ ^ ( x t ∣ y ) ⏟ 本文中的 ϵ ^ θ ( x t , y , t ) : = ϵ θ ( x t ) ⏟ 本文中的 ϵ θ ( x t , y = ∅ , t ) − s 1 − α ‾ t ∇ x t log ⁡ p ϕ ( y ∣ x t ) (3) \underbrace{\hat{\epsilon}(x_t|y)}_{\text{本文中的}\hat{\epsilon}_{\theta}(x_t, y, t)} := \underbrace{\epsilon_\theta(x_t)}_{\text{本文中的}\epsilon_{\theta}(x_t, y=\empty, t)} - s\sqrt{1 - \overline{\alpha}_t}\nabla_{x_t} \log{p_\phi(y|x_t)} \tag{3} 本文中的ϵ^θ(xt,y,t) ϵ^(xty):=本文中的ϵθ(xt,y=,t) ϵθ(xt)s1αt xtlogpϕ(yxt)(3)
可得
ϵ ^ θ ( x t , y , t ) = ϵ θ ( x t , y = ∅ , t ) + s [ ϵ θ ( x t , y , t ) − ϵ θ ( x t , y = ∅ , t ) ] \begin{align} \hat{\epsilon}_{\theta}(x_t, y, t)=\epsilon_{\theta}(x_t, y=\empty,t) + s[\epsilon_{\theta}(x_t, y, t) - \epsilon_{\theta}(x_t, y=\empty, t) ]\tag{4} \end{align} ϵ^θ(xt,y,t)=ϵθ(xt,y=,t)+s[ϵθ(xt,y,t)ϵθ(xt,y=,t)](4)
后面的采样过程与之前的方式一致。

结语

本文详细介绍了classifier-free的提出背景与具体实现方案。它是后续一系列如stable diffusion,DALLE等文生图工作的基石。

参考文献

[1]: Classifier-Free Diffusion Guidance
[2]: GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/17574.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

React 之 CSS编写方式

一、概述 整个前端已经是组件化的天下&#xff0c;而CSS的设计就不是为组件化而生的&#xff0c;所以在目前组件化的框架中都在需要一种合适的CSS解决方案 在组件化中选择合适的CSS解决方案应该符合以下条件&#xff1a; 可以编写局部css&#xff1a;css具备自己的具备作用域&a…

静态路由介绍

目录 静态路由配置方法&#xff08;基本配置&#xff09;&#xff1a; 静态路由的拓展配置 负载均衡 1.环回接口——测试 2.手工汇总——子网汇总 3.路由黑洞&#xff08;黑洞路由) 4.缺省路由 5.空接口——NULL 0 6.浮动静态路由 静态路由配置方法&#xff08;基本配置&#x…

Spark---第 1 章 Spark 内核概述

Spark 内核泛指 Spark 的核心运行机制&#xff0c;包括 Spark 核心组件的运行机制、Spark 任务调度机制、Spark 内存管理机制、Spark 核心功能的运行原理等&#xff0c;熟练掌握 Spark 内核原理&#xff0c;能够帮助我们更好地完成 Spark 代码设计&#xff0c;并能够帮助我们准…

python车牌识别

识别结果 蓝牌 绿牌 黄牌 环境 python:3.9\opencv:4.5.1 环境安装 pip3 install opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple pip3 install hyperlpr -i https://pypi.tuna.tsinghua.edu.cn/simple 修改 cd /Library/Frameworks/Python.framework/Versi…

14、双亲委托模型

双亲委托模型 先直接来看一幅图 双亲委派模型的好处&#xff1a; 主要是为了安全性&#xff0c;避免用户自己编写的类动态替换Java的一些核心类&#xff0c;比如 String。 同时也避免了类的重复加载&#xff0c;因为JVM中区分不同类&#xff0c;不仅仅是根据类名&#xff0c…

【Linux】生产者消费者模型

目录 一、生产者消费者模型的优点 二、基于BlockingQueue的生产者消费者模型 1、BlockingQueue 2、C queue模拟阻塞队列的生产消费模型 3、POSIX信号量 3.1、信号量接口 三、基于环形队列的生产消费模型 1、模型说明 2、代码实现 3、互斥锁与信号量 一、生产者消费者模…

Windows+IDEA+Nginx反向代理本机实现简单集群

先简单创建一个项目&#xff0c;可以是Maven也可以是Spring Initializr&#xff0c;如果是 Maven则需要自己配置启动类 按照目录路径创建controller类 package com.cloud.SR.controller;import org.springframework.beans.factory.annotation.Autowired; import org.springfram…

动作捕捉技术在发布会中的应用,虚拟数字人如何实现实时动作捕捉

动作捕捉又称为动态捕捉&#xff0c;是一种将真人动作转化为数字数据的技术&#xff0c;通过传感器等设备记录真人的运动轨迹&#xff0c;再将这些动作捕捉数据转化为计算机可以识别的数字数据&#xff0c;近年来在发布会中得到了广泛应用。 素材源于网络 在2023海丝之路文化和…

Linux:项目自动化构建工具——make/Makefile

文章目录 一.make与Makefile的关系1.Makefile2.make 二.项目清理1.clean2. .PHONY 前言&#xff1a; 本章主要内容有认识与学习Linux环境下如何使用项目自动化构建工具——make/makefile。 一.make与Makefile的关系 当我们编写一个较大的软件项目时&#xff0c;通常需要将多个…

集成学习-BaggingVoting和多个模型的混淆矩阵

当涉及到集成学习时&#xff0c;投票法和袋装法是两种常见的技术&#xff0c;用于将多个基学习器&#xff08;base learner&#xff09;组合成一个强大的集成模型。 投票法&#xff08;Voting&#xff09;&#xff1a;投票法是一种简单且常用的集成学习方法。在投票法中&#…

Graalvm编译spring boot 3 + jpa 的原生镜像

编译spring boot 3 native jpa的原生镜像 其中涉及版本&#xff1a; maven: 3.5.4 jdk: 17 graalvm: 22.3 springboot jpa: 3.0.8 一、Windows 1、graalvm安装 GraalVM22.3.0安装地址 解压到任意目录后添加JAVA_HOME环境变量 新增path&#xff1a;%JAVA_HOME%与%JAVA_H…

MySQL数据库(三)

前言 聚合查询、分组查询、联合查询是数据库知识中最重要的一部分&#xff0c;是将表的行与行之间进行运算。 目录 前言 一、聚合查询 &#xff08;一&#xff09;聚合函数 1、count 2、sum 3、avg 4、max 5、min 二、分组查询 &#xff08;一&#xff09;group by …