扩散模型(三):Score-based Generative Models

news/2025/3/31 2:27:12/文章来源:https://www.cnblogs.com/kawhisyz/p/18797184

扩散模型(三):Score-based Generative Models

Score function

为建模数据分布,我呢吧将数据分布写作能量模型EBM形式:

\[p_{\theta}(x)=\frac{\exp(-E_{\theta}(x))}{Z_{\theta}} \]

其中, \(Z_\theta\) 为使其满足概率密度分布的归一化项。我们使用最大似然估计对参数 \(\theta\) 进行训练:

\[max_\theta\sum_i^N\log{p_\theta(x_i)}=max_\theta\sum_i^N-E_\theta(x)-\log{Z_\theta} \]

此时会出现归一化项 \(Z_\theta\) 无法计算的问题,为解决该问题,我们引入score function \(s_\theta(x)=\nabla_x\log{p(x)}\) ,score function的计算与\(Z_\theta\)无关:

\[\nabla_x\log{p(x)}=-\nabla_xf_\theta(x)-s_\theta(x) \]

那么,假设我们已经获得了数据分布 \(p(x)\) 的score function \(\nabla_x\log{p(x)}\) ,该如何获得该数据分布的样本呢?

此时,我们可以采用朗之万动力学采样法,从某个先验分布中随机采样的初始样本开始,利用如下迭代公式,逐渐得到服从数据分布 \(p(x)\) 的样本:

\[x_{i+1}=x_i+\epsilon\nabla_x\log{p(x)}+\sqrt{2\epsilon}z_i,\ z_i\sim\mathcal{N}(0,\mathbf{I}),\ i=0,1,...K \]

当迭代次数 \(K\) 足够多时,样本 \(x_K\) 收敛为从 \(p(x)\) 采样的样本。

Score matching

下面我们讨论如何训练网络估计出score function,我们设计如下损失函数:

\[\mathcal{L}=\mathbb{E}_{x\sim p(x)}\left[||\nabla_x\log{p(x)}-s_{\theta}(x)||^2\right] \]

然而,由于我们不知道真实的数据分布 \(p(x)\) ,该损失函数无法计算。此时,我们可以使用分数匹配score matching的方法在不知道真实数据分布的基础上计算该损失。

分数匹配的相关方法有很多,在此我们讨论使用条件得分匹配替代得分匹配的方法。

\(x_0\sim p_0(x_0)\) 表示原始数据分布,我们定义一个已知解析式的条件分布(例如给定方差系数的正态分布) \(x_0,x\sim p_0(x_0)p(x|x_0)\) ,那么该条件分布的score function \(\nabla_x\log{p(x|x_0)}\) 便可以通过采样一对 \((x_0,x)\) 计算。但我们想要网络学会估计的是分布 \(x\sim p(x)\) 的score function。

我们先给出结论,得分匹配 \(\mathbb{E}_{x\sim p(x)}\left[||\nabla_x\log{p(x)}-s_\theta(x)||^2\right]\) 、条件得分匹配 \(\mathbb{E}_{x_0,x\sim p_0(x_0)p(x|x_0)}\left[||\nabla_x\log{p(x|x_0)}-s_\theta(x)||^2\right]\) 二者作为优化目标是等价的。因此,我们可以将后者作为优化目标。下面给出该结论的证明:

首先,对得分匹配,我们有:

\[\begin{align*} &\mathbb{E}_{x\sim p(x)}\left[||\nabla_x\log{p(x)}-s_\theta(x)||^2\right] \\ = &\mathbb{E}_{x\sim p(x)}\left[||\nabla_x\log{p(x)}||^2+||s_\theta(x)||^2-2s_{\theta}(x)\nabla_x\log{p(x)}\right] \end{align*} \]

对条件得分匹配,我们有:

\[\begin{align*} &\mathbb{E}_{x_0,x\sim p_0(x_0)p(x|x_0)}\left[||\nabla_x\log{p(x|x_0)}-s_\theta(x)||^2\right] \\ = &\mathbb{E}_{x_0,x\sim p(x)p(x_0|x)}\left[||\nabla_x\log{p(x|x_0)}||^2+||s_\theta(x)||^2-2s_{\theta}(x)\nabla_x\log{p(x|x_0)}\right] \\ = &\mathbb{E}_{x\sim p(x),x_0\sim p(x_0|x)}\left[||\nabla_x\log{p(x|x_0)}||^2+||s_\theta(x)||^2-2s_{\theta}(x)\nabla_x\log{p(x|x_0)}\right] \\ = &\mathbb{E}_{x\sim p(x)}\left[\mathbb{E}_{x_0\sim p(x_0|x)}\left[||\nabla_x\log{p(x|x_0)}||^2\right]+||s_\theta(x)||^2-2s_\theta(x)\mathbb{E}_{x_0\sim p(x_0|x)}\left[\nabla_x\log{p(x|x_0)}\right]\right] \end{align*} \]

针对最后一项 \(\mathbb{E}_{x_0\sim p(x_0|x)}\left[\nabla_x\log{p(x|x_0)}\right]\) ,我们进行如下化简:

\[\begin{align*} \mathbb{E}_{x_0\sim p(x_0|x)}\left[\nabla_x\log{p(x|x_0)}\right] &= \int{p(x_0|x)\nabla_x\log{p(x|x_0)}dx_0} \\ &= \int{\frac{p(x|x_0)p_x(x_0)\nabla_x\log{p(x|x_0)}dx_0}{p(x)}} \\ &= \int{\frac{p(x|x_0)\nabla_xp(x|x_0)dx_0}{p(x)}} \\ &= \frac{\nabla_x\int{p(x|x_0)}p_0(x_0)dx}{p(x)} \\ &= \frac{\nabla_x p(x)}{p(x)} \\ &= \nabla_x\log{p(x)} \end{align*} \]

因此可以得到如下化简结果:

\[\begin{align*} &\mathbb{E}_{x_0,x\sim p_0(x_0)p(x|x_0)}\left[||\nabla_x\log{p(x|x_0)}-s_\theta(x)||^2\right] \\ = &\mathbb{E}_{x\sim p(x)}\left[\mathbb{E}_{x_0\sim p(x_0|x)}\left[||\nabla_x\log{p(x|x_0)}||^2\right]+||s_\theta(x)||^2-2s_\theta(x)\nabla_x\log{p(x)}\right] \end{align*} \]

我们将得分匹配和条件得分匹配二者的化简结果相减,发现得到一个与网络参数 \(\theta\) 无关的结果:

\[\mathbb{E}_{x\sim p(x)}\left[\mathbb{E}_{x_0\sim p(x_0|x)}\left[||\nabla_x\log{p(x|x_0)}||^2\right]-||\nabla_x\log{p(x)}||^2\right] \]

自此,我们证明了得分匹配和条件得分匹配的优化目标等价。

存在问题

上述给定的理论基础可以确定设计score-based generative model的基本方案,但实际存在中还存在一些问题。

首先,我们给出流形假设(manifold hypothesis)的定义,现实世界的数据分布倾向于聚集在内嵌在一个高维空间(ambinet space)的低维流形(manifold)上。

基于该假设,直接训练score-based model会遇到如下问题。对于数据密度较低的区域,采样到的训练样本过少,导致该区域的score无法被准确估计,对于分布之外的区域,score的估计更加不准确,而我们使用的朗之万采样法的起点为从某一分布中采样的随机样本,大概率会落在低概率区域或是分布之外,从而无法准确地生成样本。

解决方案:Noise Conditional Score Networks (NCSN)

为了解决上述问题,NSCN采样了如下方法,利用不同尺度的高斯噪声扰动训练数据,使其覆盖整个概率空间,并训练网络估计不同尺度噪声扰动下数据分布的score。

为了权衡噪声过小无法有效填充低密度区域和噪声过大使其严重偏离原始数据分布,NCSN定义了一系列从大到小的噪声级别 \(\{\sigma_i\}^L_{i=1}\) 的高斯分布作为不同采样步数的条件分布:

\[q_{\sigma}(\tilde{x}|x)=\mathcal{N}(\tilde{x}|x,\sigma^2\mathbf{I}) \]

NCSN训练一个网络 \(s_\theta(\tilde{x},\sigma)\) 用于估计不同尺度噪声扰动下的conditional score \(\nabla_{\tilde{x}}\log{q_\sigma(\tilde{x})}\)

由上文证明的条件得分匹配等价于得分匹配,我们可以把训练最小化目标转化为:

\[\begin{align*} &\mathbb{E}_{x,\tilde{x}\sim p(x)q_\sigma(\tilde{x}|x)}\left[||s_\theta(\tilde{x},\sigma)-\nabla_{\tilde{x}}\log{q_\sigma(\tilde{x}|x)}||^2\right] \\ = &\mathbb{E}_{x,\tilde{x}\sim p(x)q_\sigma(\tilde{x}|x)}\left[||s_\theta(\tilde{x},\sigma)+\frac{\tilde{x}-x}{\sigma^2}||^2\right] \end{align*} \]

为不同噪声级别的目标函数增加权重项 \(\lambda(\sigma_i)\) (一般是 \(\sigma_i^2\) ),我们得到最终的目标函数:

\[\mathcal{L}=\sum_{i=1}^L{\lambda(\sigma_i)}\mathbb{E}_{x,\tilde{x}\sim p(x)q_\sigma(\tilde{x}|x)}\left[||s_\theta(\tilde{x},\sigma)+\frac{\tilde{x}-x}{\sigma^2}||^2\right] \]

NCSN的采样方法为退火朗之万采样法。首先,从某个固定的先验分布(原作者采用均匀分布)采样一个初始样本 \(\tilde{x}_0\) ,之后,从 \(\sigma_L\) 开始到 \(\sigma_1\) ,在每种噪声强度 \(\sigma_i\) 下,设定步长为:

\[\alpha_i = \epsilon \cdot \sigma_i^2 / \sigma^2_L \]

之后按如下公式进行 \(T\) 次采样:

\[\tilde{x}_t = \tilde{x}_{t-1}+\frac{\alpha_i}{2}s_\theta(\tilde{x}_{t-1},\sigma_i)+\sqrt{\alpha_i}z_t,z_t\sim\mathcal{N}(0,\mathbf{I}) \]

其中 \(\{\sigma_i\}^L_{i=1},\epsilon,T\) 为预先设定的超参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/906932.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

APP性能测试工具-GT

GT(随身调)是腾讯研发的一款可以用来做App性能测试的工具,可以对APP进行快速的性能测试,检测App的CPU、内存、流量、电量、帧率/流畅度等等、还能开启日志的查看、Crash日志查看、网络数据包的抓取、APP内部参数的调试、真机代码耗时统计等。 虽然现在该项目已经停止维护了…

一文速通Python并行计算:04 Python多线程编程-多线程同步(上)—基于条件变量、事件和屏障

本文介绍了Python多线程同步的三种机制:条件变量(Condition)、事件(Event)和屏障(Barrier),条件变量指的是线程等待特定条件满足后执行,适用于生产者-消费者模型;Event指的是线程通过事件标志进行同步,适用于线程间简单通信;Barrier指的是多个线程需同步到同一阶段…

docker desktop windows安装

我的机器windows 11 家庭版 下载docker desktop for windows 就直接安装了。安装后打开,遇到了界面转圈圈加载不出来问题,docker engine也是stopped. 病急乱投医,先是说要启用hyper-v,控制面板=》程序和功能里没有发现有hyper-v,一看是家庭版,网上倒是有一个脚本可以在家…

C语言打卡学习第6天(2025.3.25)(补发)

只做了一些有关循环分支函数求值的题,感觉循环函数其实差不多,只有一些细微差别,可能是做的题还不够多或者看运用场景吧

C语言打卡学习第5天(2025.3.24)(补发)

1、把char,getchar,putchar简单看了一下,求ascii值之类的 之类的简单看了一下 2、交换值那一题很奇怪,结果我输出的跟答案要求是一样的,交过去之后显示答案错误,白天的时候问一下

Vulnstack红日靶场通关(持续更新)

带你速通内网渗透相关知识点!!!Vulnstack通关 来源于《内网渗透实战攻略》实战部分 个人是写下自己的笔记 攻击链:探索发现阶段->入侵和感染阶段->攻击和利用阶段->探索感知阶段->传播阶段->持久化和恢复阶段 Windows权限级别前置知识:权限层级 账户类型 权…

Ubuntu 24.04安装MySQL,并且配置外网访问

安装启动更新软件包列表sudo apt update安装MySQL软件包sudo apt install mysql-server启动MySQL服务sudo systemctl start mysql重启命令:systemctl restart mysql配置外网访问 需要修改一个配置 vim /etc/mysql/mysql.conf.d/mysqld.cnf注释掉 这行 配置 bind-address …

2022CCPC Online Contest G - Name the Puppy

对正串和反串分别建立 Trie 树,定义 \(dp[i][j]\) 表示正串 Trie 树上编号为 \(i\) 的点匹配反串 Trie 树上编号为 \(j\) 的点所能拼出最长 anti-border 的长度。 如此,从根节点开始搜索,直到无法匹配为止都可以搜,搜到底后回到根节点继续匹配,可以证明,拼出来的 anti-bo…

互联网不景气了那就玩玩嵌入式吧,用纯.NET开发并制作一个智能桌面机器人(四):结合BotSharp智能体框架开发语音交互

前言 前段时间太忙了博客一直都没来得及更新,但是不代表我已经停止开发了,刚好最近把语音部分给调整了一下,所以就来分享一下具体的内容了。我想说一下,更新晚还是有好处的,社区已经有很多的小伙伴自己实现了一些语音对话功能的案例,比如小智也有.NET客户端了,还有就是一…

【AI News | 20250327】每日AI进展

AI Repos 1、playwright-mcp 使用Playwright提供浏览器自动化功能的MCP服务,核心是让LLM通过结构化的可访问性快照与网页交互,不需要依赖截图或视觉模型。可以用来自动填写网页表单、自动收集网页信息、自动进行网页测试等。支持两种模式:快照模式(默认):使用可访问性快照…

markdown常用命令行格式

Markdown 主要命令(语法)如下:标题 使用 # 号表示标题,# 的个数决定标题的级别:一级标题 二级标题 三级标题 四级标题 五级标题 六级标题段落 & 换行 直接输入文字形成段落,使用两个以上空格或 进行换行:这是一个段落。 这是同一段的下一行。 使用 <br> 也可…

微调可以获得什么

1.改变模型的行为: 使模型的响应更稳定; 使模型聚焦于某一领域; 发展期潜力,在某一方面更加出色,比如对话 2.获取新的知识: 学习预训练阶段没学过的知识; 纠正过时的错误和信息;