【论文精读】DALL·E

摘要

       本文利用从互联网上收集的2.5亿个图像/文本对数据,训练了一个120亿参数的自回归transformer,进而得到一个可以通过自然语言/图像控制生成的高保真图像生成模型。在大多数数据集上的表现超越以往的方法。

框架

       本文的目标为通过训练一个自回归transformer,通过将文本和图像tokens自回归建模为单个数据流,进而结合图像解码器进行图像生成,整体分为两个阶段:
image

  • 第一阶段:训练一个离散变分自编码器(dVAE),其编码器会将输入图像从 256 × 256 256 × 256 256×256压缩为 32 × 32 32 × 32 32×32的图像tokens,其中每个token都会映射到 K = 8192 K = 8192 K=8192的codebook向量中。相比于直接使用像素作为图像token,这可以使后续步骤的自回归transfromer的上下文大小减少192倍,同时不会大幅降低视觉质量(如上图)。
  • 第二阶段:将256个由BPE编码的文本tokens 与 32 × 32 = 1024 32 × 32 = 1024 32×32=1024个图像tokens拼接起来,基于此训练一个自回归transformer,实现对文本和图像tokens的联合分布的建模。

       整个过程可以被看作最大化模型在图像 x x x、文本 y y y和tokens z z z上的联合似然的ELBO,通过因子分解可以将该分布建模为 p θ , ψ ( x , y , z ) = p θ ( x ∣ y , z ) p ψ ( y , z ) p_{θ,ψ}(x, y, z) = p_θ(x | y, z)p_ψ(y, z) pθ,ψ(x,y,z)=pθ(xy,z)pψ(y,z),对应的ELBO为:
ln ⁡ p θ , ψ ( x , y ) ≥ E z ∼ q ϕ ( z ∣ x ) ( ln ⁡ p θ ( x ∣ y , z ) − β D K L ( q ϕ ( y , z ∣ x ) , p ψ ( y , z ) ) ) \ln p_{θ,ψ}(x, y) \ge \mathbb{E}_{z \sim q_ϕ(z|x)}(\ln p_θ(x|y,z)-\beta D_{KL}(q_ϕ(y,z|x),p_ψ(y,z))) lnpθ,ψ(x,y)Ezqϕ(zx)(lnpθ(xy,z)βDKL(qϕ(y,zx),pψ(y,z)))

       其中, q ϕ q_ϕ qϕ表示给定图像 x x x经过dVAE编码器生成的 32 × 32 32 × 32 32×32的tokens的分布; p θ p_θ pθ表示由tokens经过dVAE解码器生成的图像的分布; p ψ p_ψ pψ表示由自回归transformer建模的文本和图像tokens的联合分布。该ELBO只适用与 β = 1 \beta = 1 β=1的情况。

Learning the Visual Codebook

       第一阶段的训练目标为最大化 ϕ ϕ ϕ θ θ θ的ELBO,即通过给定图像训练dVAE。先验 p ψ p_ψ pψ初始化为基于codebook( K = 8192 K = 8192 K=8192)向量的均匀分类分布(uniform categorical distribution); q ϕ q_ϕ qϕ初始化为在编码器输出的 32 × 32 × 8192 32 × 32×8192 32×32×8192的logits参数化的分类分布(uniform categorical)。

       由于 p ψ p_ψ pψ是一个离散分布,无法使用梯度进行优化,故此处采用gumbel-softmax松弛,用 q ϕ τ q ^τ_ ϕ qϕτ取代 q ϕ q_ϕ qϕ,当 τ → 0 τ → 0 τ0时,松弛程度会逐渐缩小,逼近原始分布。 p θ p_θ pθ的似然使用log-laplace分布评估,以避免离群值导致的生成模糊问题。

       松弛后的ELBO使用Adam和EMA优化,以下配置对训练稳定性很重要:

  • 松弛temperature和步长的具体退火方法。实验发现 τ τ τ退火到1/16时,松弛ELBO的 q ϕ τ q ^τ_ ϕ qϕτ和真实ELBO的 q ϕ q_ϕ qϕ之间的gap就会消失。
  • 在编码器的末尾和解码器的开头使用1 × 1卷积。 实验发现,通过减少松弛方法周围的卷积层的感受野大小,可以使其泛化到真实ELBO的情况。
  • 将编码器和解码器的输出激活值乘一个小的常数,可以使初始化时的训练更加训练。

       另外,KL权重增加到 β = 6.6 β = 6.6 β=6.6时,可以得到更好codebook,故而使训练结束时的重构误差更小。

Learning the Prior

       第二阶段在固定 ϕ ϕ ϕ θ θ θ的情况下,最大化关于 ψ ψ ψ的ELBO,学习文本和图像token的联合先验分布。其中, p ψ p_ψ pψ是一个120亿参数的稀疏transformer。

       具体,给定一个文本/图像对,首先通过对小写文本进行BPE编码(词汇表大小为16384)得到最多256个文本token ,并对dVAE编码器输出的logits进行argmax采样codebook得到1024个图像token,此处没有添加gumbel噪声。最后,拼接这些文本和图像token作为单个数据流进行自回归建模。

       本文限制文本标题的最大长度为256,每个文本位置都会学习一个特殊的“padding” token,当对应位置没有文本token时使用此token。得到文本和图像token的交叉熵损失后,将文本交叉熵损失乘以1/8,图像交叉熵损失乘以7/8,以对loss归一化,本阶段也使用EMA和Adam进行优化。

Data Collection

       本文从互联网上收集了2.5亿个文本/图像对,创建了一个与JFT-300M相似规模的数据集。该数据集包括一部分Conceptual Captions和YFCC100M的经过滤的子集。

Mixed-Precision Training

       为了节省GPU内存并增加吞吐量,模型的大多数参数、Adam矩阵和模型激活值都以FP16存储,并使用了activation checkpointing技术。

       在训练过程中发现,随着模型变得更深更广,resblocks的激活梯度会单调减少,较深层的resblocks的激活梯度可能小于FP16的最小值,其会被四舍五入为0,这种现象称为下溢(underflow)。实验发现消除下溢可以使训练更加稳定。
image
       故对于模型中每个的resblock,通过执行“gradient scale”可以解决下溢问题,如上图。

Distributed Optimization

image
       DELL-E有120亿参数,以FP16精度存储时会消耗约24GB内存,这超过单张NVIDIA V100 GPU的16GB内存,故使用参数分片。如上图。
image
       本文的实现中,每台机器上的每个GPU都独立地计算其参数分片梯度的低秩因子,而不依赖于其相邻的GPU。一旦计算出低秩因子,每台机器都会将其error buffer设置为其八个GPU上未压缩的参数梯度的平均值(通过reduce-scatter获得)与通过解压缩低秩因子得到的梯度之间的残差(两者偏差)。

       对于一个模型训练集群,其机器之间的带宽远低于同一机器上不同GPU之间的带宽,故机器之间梯度平均操作(all-reduce)成为训练期间的主要速度瓶颈,通过引入PowerSGD压缩梯度,可以大大降低这种成本。 PowerSGD会将未压缩的参数梯度的通信操替换为基于其低秩因子的两个更小的通信操作。给定压缩rank r r r和transformer激活尺寸 d m o d e l d_{model} dmodel,其压缩率为 1 − 5 r / ( 8 d m o d e l ) 1 − 5r/(8d_{model}) 15r/(8dmodel)。如上表显示,无论模型大或小,该方法可以实现约85%的压缩率。

Sample Generation

image
       对于从transformer中生成的一系列图像,本文采用预训练CLIP对生成图像与文本标题的匹配程度来分配分数并排序。如上图显示了给定生成的N张图像,并从中选择的top-k图像。除非另有说明,用于定性和定量结果的所有样本都是在不降低temperature的情况下获得的,并使用N = 512重新排序。

实验

Quantitative Results

image
       上图定性比较了DALL-E和AttnGAN、DM-GAN和DF-GAN的生成。
image
       上图为人类验证实验。给定一个文本标题,相比DF-GAN,DALL-E的生成在93%的情况下与文本标题更好地匹配,而获得更多的人类投票。在90%的情况下,也因为更真实而获得了大多数人类投票。
image
       上图(a上)为在MS-COCO数据集上验证的定量结果,DALL-E与之前最佳方法只差2个点的FID分数。由于DALL-E的训练数据中包含一个YFCC100M的过滤子集,其中包含MS-COCO验证集中大约21%的图像,故为了隔离这种影响,另外分别计算了验证集有这些图像(实线)和没有这些图像(虚线)的FID信息,结果没有明显变化。

       用dVAE编码器的token训练transformer,可以使模型学习更多的图像低频信息,使图像在视觉上更真实。,但这也不利于模型学习产生高频细节。为了验证模型的高频建模能力,本实验对验证图像和模型生成的样本应用了不同半径的高斯滤波器,并计算对应IS值。结果如上图(a下),随着模糊半径的增加,DALL-E和其他方法之间的差距越拉越大,当模糊半径大于等于2时,DALL-E取得了最佳结果。

       DALL-E在CUB数据集上的表现比较差,如上图(b),和之前的主要方法有近40点FID的差距。经过检测发现,训练数据集中包含12%的CUB数据,但去除这些数据后模型表现仍旧不佳。故推测zero-shot DALL-E不太可能在CUB等专业分布的数据集上获得优势。

       上图(c)显示了当用于CLIP重排序的样本增加时,DALL-E的FID有了明显改进。
image
       上图显示了DALL-E在CUB数据集中不同文本标题下的生成示例。

Qualitative Findings

image
       通过验证发现DALL-E有不以最初预期的方式进行泛化的能力。当给出文本标题“a tapir made of accordion… ”,该模型似乎画了一个以手风琴为身 体的貘(上图a)。这表明,其发展出一种基本的能力,可以在较高的抽象层次上组合概念。

       DALL-E似乎也能进行组合泛化,例如在渲染如“an illustration of a baby hedgehog in a christmas sweater walking a dog”这样的句子(上图b、c)。

       在有限的可靠性程度上,还发现该模型能够由自然语言控制图像到图像的翻译。当模型被赋予标题“the exact same cat on the top as a sketch at the bottom”时,其能够在底部画一个类似的猫的草图(上图d)。 这也适用于其他几种类型的转换,包括图像操作(例如改变图像的颜色、将其转换为灰度或翻转图像)和样式转换(例如在贺卡、邮票或手机壳上画猫)。一些只涉及改变动物颜色的转换,表明DALL-E能够执行基本的对象分割。

Appendix

Details for Discrete VAE

Architecture

       dVAE编码器和解码器都为具有bottleneck-style resblocks的ResNets。编码器的第一层卷积核尺寸为 7 × 7 7 × 7 7×7,编码器的最后一层卷积核尺寸为 1 × 1 1×1 1×1(输出尺寸为 32 × 32 × 8192 32 × 32 × 8192 32×32×8192,用作图像token的分类分布的logits)。解码器的第一层卷积和最后一层卷积核尺寸都为 1 × 1 1×1 1×1。编码器使用最大池化下采样,解码器使用最近邻上采样。

Training

image
       dVAE在与transformer相同的数据集上进行训练,使用上图中给出的数据增强代码。以下量在训练过程中使用余弦退火进行衰减:

  • KL权重 β β β在前5000次迭代中从0增加到6.6
  • 松弛 τ τ τ在前150000次迭代中从1退火到1/16
  • 在1200000迭代中,step size从 1 ⋅ 1 0 − 4 1\cdot10^{−4} 1104退火到 1.25 ⋅ 1 0 − 6 1.25 \cdot 10^{−6} 1.25106

       使用 β 1 = 0.9 , β 2 = 0.999 , ϵ = 1 0 − 8 β_1 = 0.9, β_2 = 0.999, ϵ = 10^{−8} β1=0.9,β2=0.999,ϵ=108的AdamW和 1 0 − 4 10^{−4} 104的weight decay和 0.999 0.999 0.999的EMA优化模型。该模型在64个16 GB NVIDIA V100 gpu上使用混合精度训练,每个gpu的batch size为8,总batch size为512,总共3000000次。

Details for Transformer

Architecture

image
       本文第二阶段模型是一个仅解码器的稀疏transformer,其输入tokens embedding格式如上图。其包括64个注意力层,每个层使用62个注意力头,每个头的维度大小为64。
image
       该模型使用三种稀疏注意力mask,如上图。给定自注意力层的索引 i i i i ∈ [ 1 , 63 ] i ∈ [1, 63] i[1,63]),如果 i − 2 m o d 4 = 0 i − 2\ mod \ 4 = 0 i2 mod 4=0,则使用列注意力mask(c),否则使用行注意力mask,例如,前四个自注意力层分别使用row、column、row、row。卷积注意力mask(d)仅用于最后的自注意力层。

Training

image
       对于训练transformer的训练,在使用dVAE编码器编码图像之前,首先对图像进行如上图代码所示的数据增强。在用BPE编码文本标题时,还应用了10%的BPE dropout。该模型使用逐resblock缩放和梯度压缩进行训练,总压缩rank为896(每个GPU的参数分片使用112的压缩rank)。

       使用 β 1 = 0.9 , β 2 = 0.96 , ϵ = 1 0 − 8 β_1 = 0.9, β_2 = 0.96, ϵ = 10^{−8} β1=0.9,β2=0.96,ϵ=108的AdamW与和 4.5 ⋅ 1 0 − 2 4.5 \cdot 10^{−2} 4.5102的weight decay和0.99的EMA优化参数。在应用Adam更新前,会使用阈值为4的norm对解压后的梯度进行裁剪,梯度裁剪仅在训练开始的预热阶段运行。为了节省内存,大部分Adam矩阵以FP16格式存储,其中运行平均值为1-6-9格式(即1位用于符号,6位用于指数,9位用于尾数),运行方差为0-6-10格式,在更新参数或动量之前,会将运行方差裁剪为5。其次,还会异步地将模型参数从GPU复制到CPU(每25次更新复制一次),以获得更稳定的更新。

       该模型在1024个16 GB NVIDIA V100 gpu和总batch size为1024的设置下训练模型,总共进行了430000更新。step size在前5000次迭代中,线性退火到 4.5 ⋅ 1 0 − 4 4.5 · 10^{−4} 4.5104,并在每次训练损失趋于稳定时将step size减半,训练周期内,总共减半了5次,比初始步长小32倍的最终步长结束训练。

Details for Human Evaluation Experiments

image
       对于人类验证实验,本文使每个模型对每个文本标题生成一个示例图像,并给定文本和示例图像让人类给出比较结果,实验提交给了亚马逊的Mechanical Turk,每组生成都由五名不同的人类回答。工作人员被要求比较两张图像并选择答案:(1)哪张图像最真实,(2)哪张图像最匹配文本标题。提供给人类的实验设置如上图。

Zero-Shot Image-to-Image Translation

image
       上图显示了DALL-E的zero-shot图像到图像转换的示例。

reference

Ramesh, A. , Pavlov, M. , Goh, G. , Gray, S. , Voss, C. , & Radford, A. , et al. (2021). Zero-shot text-to-image generation.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/471844.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

牛客小白月赛87

说明 年后第一次写题,已经麻了,这次的题很简单但居然只写了两道题。有种本该发挥80分的水平,但是只做出了20分的水平的感觉。不过剩下几个题(除了G题),比完赛一小时内就AC了。欢迎大家交流学习。&#xff0…

java8-用optional取代nu11

本章内容口nu11引用引发的问题,以及为什么要避免nu11引用从nu11到optiona1:以nu11安全的方式重写你的域模型让optiona1发光发热:去除代码中对nu11的检查 读取optiona1中可能值的几种方法口对可能缺失值的再思考 如果你作为Java程序员曾经遭遇过Nu11PointerException…

【动态规划初识】不同的二叉搜索树

每日一道算法题之不同二叉搜索树个数 一、题目描述二、思路三、C++代码一、题目描述 题目来源:LeetCode 给你一个整数 n ,求恰由 n 个节点组成且节点值从 1 到 n 互不相同的 二叉搜索树 有多少种?返回满足题意的二叉搜索树的种数。 C++程序要求输入输出格式如下: 示例1:…

BUUCTF misc 专题(47)[SWPU2019]神奇的二维码

下载附件,得到一张二维码图片,并用工具扫描(因为图片违规了,所以就不放了哈。工具的话,一般的二维码扫描都可以) swpuctf{flag_is_not_here},(刚开始出了点小差错对不住各位师傅&am…

【机器学习笔记】5 机器学习实践

数据集划分 子集划分 训练集(Training Set):帮助我们训练模型,简单的说就是通过训练集的数据让我们确定拟合曲线的参数。 验证集(Validation Set):也叫做开发集( Dev Set &#xf…

基于Springboot的社区物资交易互助平台(有报告)。Javaee项目,springboot项目。

演示视频: 基于Springboot的社区物资交易互助平台(有报告)。Javaee项目,springboot项目。 项目介绍: 采用M(model)V(view)C(controller)三层体系…

在已有代码基础上创建Git仓库

在已有代码基础上创建Git仓库 背景方法处理问题 背景 先进行了代码编写,后续想放入仓库方便大家一起合作开发,此时需要在已有代码的基础上建立仓库。 方法 首先在Gitee或者GitHub上创建仓库,这里以Gitee为例。创建完后,我们可以…

阿里云BGP多线精品EIP香港CN2线路低时延,价格贵

阿里云香港等地域服务器的网络线路类型可以选择BGP(多线)和 BGP(多线)精品,普通的BGP多线和精品有什么区别?BGP(多线)适用于香港本地、香港和海外之间的互联网访问。使用BGP&#xf…

kali无线渗透之WEP加密模式与破解13

WEP加密是最早在无线加密中使用的技术,新的升级程序在设置上和以前的有点不同,功能当然也比之前丰富一些。但是随着时间的推移,人们发现了WEP标准的许多漏洞。随着计算能力的提高,利用难度也越来越低。尽管WEP加密方式存在许多漏洞…

【二叉树层序遍历】【队列】Leetcode 102 107 199 637 429 515 116 117 104 111

【二叉树层序遍历】【队列】Leetcode 102 107 199 637 429 515 116 117 102. 二叉树的层序遍历解法 用队列实现107. 二叉树的层序遍历 II解法199. 二叉树的右视图 解法637. 二叉树的层平均值 解法429. N叉树的层序遍历515. 在每个树行中找最大值116. 填充每个节点的下一个右侧节…

102.网游逆向分析与插件开发-网络通信封包解析-解读喊话道具数据包并且利用Net发送

内容参考于:易道云信息技术研究院VIP课 上一个内容:解读聊天数据包并且利用Net发送 码云地址(游戏窗口化助手 分支):https://gitee.com/dye_your_fingers/sro_-ex.git 码云版本号:cc6370dc5ca6b0176aafc…

【C语言】指针的进阶篇,深入理解指针和数组,函数之间的关系

欢迎来CILMY23的博客喔,本期系列为【C语言】指针的进阶篇,深入理解指针和数组,函数之间的关系,图文讲解其他指针类型以及指针和数组,函数之间的关系,带大家更深刻理解指针,以及数组指针&#xf…