15、Scalable Diffusion Models with Transformers

简介

官网
在这里插入图片描述
DiT(Diffusuion Transformer)将扩散模型的 UNet backbone 换成 Transformer,并且发现通过增加 Transformer 的深度/宽度或增加输入令牌数量,具有较高 Gflops 的 DiT 始终具有较低的 FID(~2.27),这样说明 DiT 是可扩展的(Scalable),网络复杂度(以 Gflops 度量)与样本质量(以 FID 度量)之间存在强相关性。

在这里插入图片描述

实现流程

在这里插入图片描述

Patchify

对于 256 × 256 × 3 256 \times256\times3 256×256×3 的图片, 中间特征 z 的维度是 32 × 32 × 4 32\times32\times4 32×32×4。DiT 的第1步和 ViT 一样,都是把图片 Patchify,并经过 Linear Embedding,最终变为 T 个 d 维的 tokens。在 Patchify 之后,将标准的基于 ViT 频率的位置编码 (sine-cosine 版本) 应用到所有的输入 tokens 上面。token T 的数量由 Patch 的大小 p 决定。Patch 的大小 p 和 token 的数量 T 之间满足 T = ( I / p ) 2 T=(I/p)^2 T=(I/p)2 的关系。当 Patch 的大小 p 越小时,token 的数量 T 越大。减半 p 将会使 T 变为4倍,使得计算量也变为4倍。尽管 p 对 GFLOPs 有显著的影响,但参数量没有很实质的影响。
在这里插入图片描述
在 DiT design space 里面使用 p = 2 , 4 , 8 p=2,4,8 p=2,4,8

DiT block design

在patchify之后,输入令牌由一系列转换器块处理。除了带噪声的图像输入,扩散模型有时还处理附加的条件信息,如噪声时间步长 t、类标签 c、自然语言等。

作者探索了4种不同类型的 Transformer Block,以不同的方式处理条件输入。这些设计都对标准 ViT Block 进行了微小的修改。

In-context conditioning

在这里插入图片描述
像这样的带条件输入的情况,In-context conditioning 的做法只需要将时间步长 t , 类标签 c 作为2个额外的 token 附加到输入的序列中。作者将它们视为与图像的 token 没有任何区别。这就有点类似于 ViT 的 [CLS] token。这就允许 DiT 使用标准的 ViT Block 而不用任何修改。经过了最后一个 Block 之后,删除条件 token 就行。这种方式带来的额外 GFLOPs 微不足道。

Cross-attention block

在这里插入图片描述
Cross-attention block 的做法是将 t和 c 的 Embedding 连接成一个长度为2的 Sequence,且与 image token 序列分开。这种方法给 Transformer Block 添加一个 Cross-Attention 块。这个操作带来的额外 GFLOPs 开销是大约 15%。

Adaptive layer norm (adaLN) block.

Adaptive Layer Norm (adaLN) Block 遵循 GAN[15]中的自适应归一化层,希望探索这个东西在扩散模型里面好不好用。没有直接学习缩放和移位参数 γ \gamma γ β \beta β ,而是改用噪声时间步长 t 和类标签 c 得到。adaLN 带来的额外的 GFLOPs 是最少的,因此计算效率最高。它也是唯一一种限制于将相同函数应用于所有令牌的条件调节机制

这里补充一些知识,归一化的一般公式是: z ˉ = z − μ σ 2 + ϵ ⊙ γ + β \bar{z} = \frac{z-\mu}{\sqrt{\sigma^2+\epsilon} }\odot \gamma + \beta zˉ=σ2+ϵ zμγ+β ,其中 z 是激活值, μ , σ 2 \mu,\sigma^2 μ,σ2 是根据激活值按照某个规则(如 BatchNorm、LayerNorm、GroupNorm等)计算出来的均值和方差。 γ , β \gamma,\beta γ,β 是两个可学习的参数。

主要是由于自适应归一化层(adaptive normalization layers)在 GAN 和 UNet backbone 中的广泛使用(归一化的参数(如均值和方差)是根据当前批次的数据或特定层的激活值计算得出的。而在自适应归一化层中,这些参数可能会根据另一组数据(如风格图像)进行调整,以实现特定的效果),考虑将 Transformer 块中的 standard layer norm layers 替换成 adaptive layer norm (adaLN)。不同的是,对于其中的 scale and shift parameters 即 γ , β \gamma,\beta γ,β ,并不是直接学习它们,而是从
t 和 c 的 embedding vector 的和去回归这两个参数。

adaLN-Zero block

在这里插入图片描述
在有监督学习中,对每个 Block 的第一个 Batch Norm 操作的缩放因子进行 Zero-Initialization 可以加速其大规模训练。基于 U-Net 的扩散模型使用类似的初始化策略,对每个 Block 的第一个卷积进行 Zero-Initialization。本文作者作了一些改进:除了回归计算缩放和移位参数 γ \gamma γ β \beta β 之外,还回归缩放系数 α \alpha α

作者初始化 MLP 使其输出的缩放系数 α \alpha α 全部为0,这样一来,DiT Block 就初始化为了 Identity Function。adaLN-Zero Block 带来的额外的 GFLOPs 可以忽略不计。

Transformer Decoder

在最后一个 DiT Block 之后,需要将 image tokens 的序列解码为输出噪声以及对角的协方差矩阵的预测结果。

而且,这两个输出的形状都与原始的空间输入一致。作者在这个环节使用标准的线性解码器,将每个 token 线性解码为 P × P × 2 C P\times P \times 2C P×P×2C 的张量,其中 C 是空间输入中到 DiT 的通道数。最后将解码的 tokens 重新排列到其原始空间布局中,得到预测的噪声和协方差。

模型尺寸

遵循 ViT 的做法,作者在缩放 DiT 时也是从下面几个维度进行考虑:深度 N ,hidden dimension d,head 数量。作者设计出4种不同尺寸的 DiT 模型 DiT-S, DiT-B, DiT-L 和 DiT-XL,从 0.3 到 118.6 GFLOPs,详细信息如下图所示。
在这里插入图片描述

实验

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/589849.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

go库x/text缺陷报告CVE-2022-32149的处理方案

#问题描述 go库 golang.org/x/text ,注意这里不是go的源码, 在0.3.8版本之前存在一个缺陷(Vulnerability) 缺陷ID CVE-2022-32149 具体描述 攻击者可以通过制作一个Accept-Language报头来导致拒绝服务。 具体的原因是,在解析这个Accept-L…

算法学习 | day34/60 不同路径/不同路径II

一、题目打卡 1.1 不同路径 题目链接:. - 力扣(LeetCode) 拿到手,首先见到答案需要求的是种类的个数,并且看题目,每次移动的时候只有两个方向,这也就说明,对于某一个位置来说&#x…

深入剖析主机安全中的零信任机制及其实施原理

引言 在数字化转型加速与云端服务普及的大背景下,传统依赖边界的网络安全模式逐渐显露出其局限性。面对愈发复杂多变的威胁环境,零信任安全架构作为新一代的安全范式应运而生,尤其是在主机层面的安全实践中,零信任机制正扮演着至…

基于YOLOv8车牌识别算法支持12种中文车牌类型(源码+图片+说明文档)

yolov8车牌识别算法,支持12种中文车牌类型 支持如下: 1.单行蓝牌 2.单行黄牌 3.新能源车牌 4.白色警用车牌 5.教练车牌 6.武警车牌 7.双层黄牌 8.双层白牌 9.使馆车牌 10.港澳粤Z牌 11.双层绿牌 12.民航车牌 图片测试demo: 直接运行detect_plate.py 或者…

DDL ---- 数据库的操作

1.查询所有数据库 show databases; 上图除了自创的,其他的四个都是mysql自带的数据库 。(不区分大小写) 2.查询当前数据库 select database(); 最开始没有使用数据库,那么查找结果为NULL 所以我们就需要先使用数据库&#xff…

shopee虾皮业绩一直没办法提升?不同时期要有不同的运营思路

店铺运营“开荒期”需要根据自身店铺数据调整运营策略,“运营期”就需要更多分析竞品的运营数据,分析接近上架时间段的出单同款/相似款,有效找到影响起量的因素;在出单缓慢,接近瓶颈期时找同行的策略方案,抓…

PyLMKit(8):ChatDB与你的数据库聊天,数据库问答

功能介绍 与你的结构化数据聊天:支持主流数据库、表格型excel等数据! ChatDB:支持数据库问答ChatTable:支持txt,excel,csv等pandas dataframe表格的问答 1.下载安装 pip install pylmkit -U pip install pymysql sqlalchemy s…

【学习心得】Numpy学习指南或复习手册

本文是自己在学习Numpy过后总是遗忘的很快,反思后发现主要是两个原因: numpy的知识点很多,很杂乱。练习不足,学习过后一段时间不敲代码就会忘记。 针对这两个问题,我写了这篇文章。希望将numpy的知识点织成一张网&…

【智能算法】金枪鱼群优化算法(TSO)原理及实现

目录 1.背景2.算法原理2.1算法思想2.2算法过程 3.代码展示4.参考文献 1.背景 2021年,Xie等人受到自然界中金枪鱼狩猎行为启发,提出了金枪鱼优化算法(Tuna swarm optimization,TSO)。 2.算法原理 2.1算法思想 TSO模…

选数异或(DP)

题目描述 给定一个长度为 n 的数列 A1, A2, , An 和一个非负整数 x,给定 m 次查询, 每次询问能否从某个区间 [l,r] 中选择两个数使得他们的异或等于 x 。 输入格式 输入的第一行包含三个整数 n, m, x 。 第二行包含 n 个整数 A1, A2, , An 。 接下来 m 行…

【Java基础】Java基础知识整合

文章目录 1. 转义字符2. 变量2.1 字符串与整型相加2.2 byte和short的区别2.3 float和double的区别2.4 char类型2.5 boolean类型2.6 自动类型转换及运算2.7 强制类型转换2.8 String的转换2.9 除法运算2.10 取模规则 3. 自增4. 逻辑运算符5. 赋值运算 6. 三元运算符:7…

FreeRTOS中断管理以及实验

FreeRTOS中断管理以及实验 继续记录学习FreeRTOS的博客,参照正点原子FreeRTOS的视频。 ARM Cortex-M 使用了 8 位宽的寄存器来配置中断的优先等级,这个寄存器就是中断优先级配置寄存器 , STM32寄存器中并且这个寄存器只使用[7:4]&#xff0c…