diffusion model(十三):DiT技术小结

info
paperhttps://arxiv.org/abs/2212.09748
githubhttps://github.com/facebookresearch/DiT/tree/main
个人博客主页http://myhz0606.com/article/dit
create date2024-03-08

阅读前需要具备以下前置知识:

DDPM(扩散模型基本原理):知乎地址 个人博客地址 paper

LDM (隐空间扩散模型基本原理,stable diffusion 底层架构) 知乎地址 个人博客地址 paper

classifier-free guided(文生图基本原理) 知乎地址 个人博客地址 paper

Motivate

虽然Transformer架构已经在诸多自然语言处理和计算机视觉任务中展现出卓越的scalable能力,但目前主导扩散模型架构的仍是UNet。本文旨在探讨以Transformer取代UNet在扩散模型中的可行性和潜在方案,并对所提出的Diffusion Transformer (DIT)架构的scalable能力进行了验证和评估。

Method

采用DiT架构替换UNet主要需要探索以下几个关键问题:

  1. Token化处理。Transformer的输入为一维序列,形式为 R T × d \mathbb{R}^{T \times d} RT×d(忽略batch维度),而LDM的latent表征 z ∈ R H f × W f × C z \in \mathbb{R}^{\frac{H}{f} \times \frac{W}{f} \times C} zRfH×fW×C为spatial张量。因此,需要设计合适的Token化方法将二维latent映射为一维序列。
  2. 条件信息嵌入。sable diffusion火出圈的一个关键在于它能够根据用户的文本指令生成高质量的图像。这里面的核心在于需要将文本特征嵌入到扩散模型中协同生成。并且扩散模型的每一个生成还需要融入time-embedding来引入时间步的信息。因此,若要用Transformer架构取代Unet需要系统研究Transformer架构的条件嵌入

DiT这篇paper的核心在于对上述两个问题的系统研究。

在这里插入图片描述

Patchify(token化)

假定原始图片 x ∈ R 256 × 256 × 3 x \in \mathbb{R} ^ {256\times256\times3} xR256×256×3,经过auto-encoder后得到latent表征 z ∈ R 32 × 32 × 4 z \in \mathbb{R} ^ {32\times32\times4} zR32×32×4。首先DiT 用ViT中patch化的方式将隐表征 z z z 转化为token序列,随后给序列添加位置编码。图中展示了patch化的过程。patch_size p是一个超参数。文中分别尝试了p=2,4,8。(DiT的输出会将每一个token线性解码成pxpx2C,再reshape为nose和协方差)

在这里插入图片描述

DiT block设计

这个部分系统探究了4中在DiT中引入控制信号的方案。

(一)In-context conditioning

直接将时间步信号、文本控制信号作为addition token和输入sequence进行拼接。其角色类似于类似于ViT里面的[CLS]token。这样做有一个好处,原本的ViT架构都可以不动,并且增加的的计算量可以忽略不计。

(二)Cross-Attention block

这个方法首先将时间步信号 ( R 1 × d ) (\mathbb{R} ^{1 \times d}) (R1×d)和文本信号 ( R 1 × d ) (\mathbb{R} ^{1 \times d}) (R1×d)进行拼接,得到拼接后的控制信号 ( R 2 × d ) (\mathbb{R} ^{2 \times d}) (R2×d)。随后类似文献[1]的做法,在ViT中添加cross attention层,将控制信号作为cross-attention的key,value进行融入。

(三)Adaptive Layer Norm (adaLN) block

作者参考文献[2]提出的adaptive normalization layer(adaLN),将transformer block的layer norm替换为adaLN。简单来说就是,原本的将原本layer norm用于仿射变换的scale parameter γ \gamma γ和shift parameter β \beta β 用condition embedding来替代。下面给出了最简的示例代码便于理解。

论文原话:Rather than directly learn dimensionwise scale and shift parameters γ and β, we regress them from the sum of the embedding vectors of t and c.

import numpy as npclass LayerNorm:def __init__(self, feature_dim, epsilon=1e-6):self.epsilon = epsilonself.gamma = np.random.rand(feature_dim)  # scale parametersself.beta = np.random.rand(feature_dim)  # shift parametrsdef __call__(self, x: np.ndarray) -> np.ndarray:"""Args:x (np.ndarray): shape: (batch_size, sequence_length, feature_dim)return:x_layer_norm (np.ndarray): shape: (batch_size, sequence_length, feature_dim)"""_mean = np.mean(x, axis=-1, keepdims=True)_std = np.var(x, axis=-1, keepdims=True)x_layer_norm = self.gamma * (x - _mean / (_std + self.epsilon)) + self.betareturn x_layer_normclass DiTAdaLayerNorm:def __init__(self,feature_dim, epsilon=1e-6):self.epsilon = epsilonself.weight = np.random.rand(feature_dim, feature_dim * 2)def __call__(self, x, condition):"""Args:x (np.ndarray): shape: (batch_size, sequence_length, feature_dim)condition (np.ndarray): shape: (batch_size, 1, feature_dim)Ps: condition = time_cond_embedding + class_cond_embeddingreturn:x_layer_norm (np.ndarray): shape: (batch_size, sequence_length, feature_dim)"""affine = condition @ self.weight  # shape: (batch_size, 1, feature_dim * 2)gamma, beta = np.split(affine, 2, axis=-1)_mean = np.mean(x, axis=-1, keepdims=True)_std = np.var(x, axis=-1, keepdims=True)x_layer_norm = gamma * (x - _mean / (_std + self.epsilon)) + betareturn x_layer_norm

(四)adaLN-Zero block

这个方法是(三)的延伸。简单来说就是condition embedding除了融入到layer norm中,还作为residual的强度融入到residual连接中。下面给出了最简的示例代码

import numpy as npclass LayerNorm:def __init__(self, epsilon=1e-6):self.epsilon = epsilondef __call__(self, x: np.ndarray, gamma: np.ndarray, beta: np.ndarray) -> np.ndarray:"""Args:x (np.ndarray): shape: (batch_size, sequence_length, feature_dim)gamma (np.ndarray): shape: (batch_size, 1, feature_dim), generated by condition embeddingbeta (np.ndarray): shape: (batch_size, 1, feature_dim), generated by condition embeddingreturn:x_layer_norm (np.ndarray): shape: (batch_size, sequence_length, feature_dim)"""_mean = np.mean(x, axis=-1, keepdims=True)_std = np.var(x, axis=-1, keepdims=True)x_layer_norm = self.gamma * (x - _mean / (_std + self.epsilon)) + self.betareturn x_layer_normclass DiTBlock:def __init__(self, feature_dim):self.MultiHeadSelfAttention = lambda x: x # mock multi-head self-attentionself.layer_norm = LayerNorm()self.MLP = lambda x: x # mock multi-layer perceptronself.weight = np.random.rand(feature_dim, feature_dim * 6)def __call__(self, x: np.ndarray, time_embedding: np.ndarray, class_emnedding: np.ndarray) -> np.ndarray:"""Args:x (np.ndarray): shape: (batch_size, sequence_length, feature_dim)time_embedding (np.ndarray): shape: (batch_size, 1, feature_dim)class_emnedding (np.ndarray): shape: (batch_size, 1, feature_dim)return:x (np.ndarray): shape: (batch_size, sequence_length, feature_dim)"""condition_embedding = time_embedding + class_emneddingaffine_params = condition_embedding @ self.weight  # shape: (batch_size, 1, feature_dim * 6)gamma_1, beta_1, alpha_1, gamma_2, beta_2, alpha_2 = np.split(affine_params, 6, axis=-1)x = x + alpha_1 * self.MultiHeadSelfAttention(self.layer_norm(x, gamma_1, beta_1))x = x + alpha_2 * self.MLP(self.layer_norm(x, gamma_2, beta_2))return x

Result

作者在imagenet数据上,以classifier-free的方式训练DiT(仅做class-control,即text condition embedding为类别embedding)。作者设置了4种不同model size的DiT,并开展实验。

在这里插入图片描述

DiT的scalable能力验证

作者分别尝试了的patch size,不同model size的DiT,从图中不难发现

  • patch size越小生成的效果越好(意味着初始时sequence的token数越多)。这里不太明白为什么作者不实验p=1的情形。因为latent表征本身就可以视作是CNN抽取的隐式token,只要flatten即可,很多hybrid的架构(CNN+ViT)都是这么玩的,或许是为了控制计算量?
  • model size越大生成效果越好。从实验结果中DiT-XLDiT-L的差距很小,可能是因为训练数据量还不够大体现不出大模型的优势

在这里插入图片描述

在这里插入图片描述

DiT Block有效性验证

作者在imagenet数据集上验证上面提出的四种DiT block的的生成效果。ada LN-Zero方案的生成效果最好。

在这里插入图片描述

小结

DiT 系统研究了diffusion transformer的token化和条件嵌入两个关键问题,验证了基于transformer架构的扩散模型的scalable能力。

参考文献

[1] Attention is all you need.

[2] Film: Visual reasoning with a general conditioning layer.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/526727.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java 抽象类和接口

登神长阶 第三阶 抽象类和接口 🍀🍀🍀🍀🍀🍀🍀🍀🍀🍀🍀🍀🍀🍀🍀🍀🍀&…

分割模型TransNetR的pytorch代码学习笔记

这个模型在U-net的基础上融合了Transformer模块和残差网络的原理。 论文地址:https://arxiv.org/pdf/2303.07428.pdf 具体的网络结构如下: 网络的原理还是比较简单的, 编码分支用的是预训练的resnet模块,解码分支则重新设计了。…

Masked Generative Distillation(MGD)2022年ECCV

Masked Generative Distillation(MGD)2022年ECCV 摘要 **目前的蒸馏算法通常通过模仿老师的输出来提高学生的表现。本文表明,教师还可以通过引导学生特征恢复来提高学生的代表性。从这个角度来看,我们提出的掩模生成蒸馏&#x…

先进电机技术 —— 高速电机与低速电机

一、背景 高速电机是指转速远高于一般电机的电动机,通常其转速在每分钟几千转至上万转甚至几十万转以上。这类电机具有功率密度高、响应速度快、输出扭矩大等特点,在航空航天、精密仪器、机器人、电动汽车、高端装备制造等领域有着广泛的应用。 高速电…

无人机生态环境监测、图像处理与GIS数据分析

构建“天空地”一体化监测体系是新形势下生态、环境、水文、农业、林业、气象等资源环境领域的重大需求,无人机生态环境监测在一体化监测体系中扮演着极其重要的角色。通过无人机航空遥感技术可以实现对地表空间要素的立体观测,获取丰富多样的地理空间数…

跨平台大小端判断与主机节序转网络字节序使用

1.macOS : 默认使用小端 ,高位使用高地址,转换为网络字节序成大端 #include <iostream> #include <arpa/inet.h> int main() {//大小端判断union{short s;char c[sizeof(short)];}un;un.s = 0x0102;printf("低地址:%d,高地址:%d\n",un.c[0],un.c[1]);if …

灯塔批量添加指纹信息

攻击地址https://github.com/loecho-sec/ARL-Finger-ADD 指纹文件https://github.com/lemonlove7/EHole_magic/blob/main/finger.json 成功导入可以看到

基于Springboot的在线租房和招聘平台(有报告)。Javaee项目,springboot项目。

演示视频&#xff1a; 基于Springboot的在线租房和招聘平台&#xff08;有报告&#xff09;。Javaee项目&#xff0c;springboot项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系结…

采用 Amazon DocumentDB 和 Amazon Bedrock 上的 Claude 3 构建游戏行业产品推荐

前言 大语言模型&#xff08;LLM&#xff09;自面世以来即展示了其创新能力&#xff0c;但 LLM 面临着幻觉等挑战。如何通过整合外部数据库的知识&#xff0c;检索增强生成&#xff08;RAG&#xff09;已成为通用和可行的解决方案。这提高了模型的准确性和可信度&#xff0c;特…

【个人开发】llama2部署实践(三)——python部署llama服务(基于GPU加速)

1.python环境准备 注&#xff1a;llama-cpp-python安装一定要带上前面的参数安装&#xff0c;如果仅用pip install装&#xff0c;启动服务时并没将模型加载到GPU里面。 # CMAKE_ARGS"-DLLAMA_METALon" FORCE_CMAKE1 pip install llama-cpp-python CMAKE_ARGS"…

PyTorch搭建LeNet训练集详细实现

一、下载训练集 导包 import torch import torchvision import torch.nn as nn from model import LeNet import torch.optim as optim import torchvision.transforms as transforms import matplotlib.pyplot as plt import numpy as npToTensor()函数&#xff1a; 把图像…

【脚本玩漆黑的魅影】全自动刷努力值

文章目录 原理全部代码 原理 全自动练级&#xff0c;只不过把回城治疗改成吃红苹果。 吃一个可以打十下&#xff0c;背包留10个基本就练满了。 吃完会自动停止。 if img.getpixel(data_attack[0]) data_attack[1] or img.getpixel(data_attack_2[0]) data_attack_2[1]: # …