基于 Quanto 和 Diffusers 的内存高效 transformer 扩散模型

news/2024/11/18 6:40:22/文章来源:https://www.cnblogs.com/huggingface/p/18388714

过去的几个月,我们目睹了使用基于 transformer 模型作为扩散模型的主干网络来进行高分辨率文生图 (text-to-image,T2I) 的趋势。和一开始的许多扩散模型普遍使用 UNet 架构不同,这些模型使用 transformer 架构作为扩散过程的主模型。由于 transformer 的性质,这些主干网络表现出了良好的可扩展性,模型参数量可从 0.6B 扩展至 8B。

随着模型越变越大,内存需求也随之增加。对扩散模型而言,这个问题愈加严重,因为扩散流水线通常由多个模型串成: 文本编码器、扩散主干模型和图像解码器。此外,最新的扩散流水线通常使用多个文本编码器 - 如: Stable Diffusion 3 有 3 个文本编码器。使用 FP16 精度对 SD3 进行推理需要 18.765GB 的 GPU 显存。

这么高的内存要求使得很难将这些模型运行在消费级 GPU 上,因而减缓了技术采纳速度并使针对这些模型的实验变得更加困难。本文,我们展示了如何使用 Diffusers 库中的 Quanto 量化工具脚本来提高基于 transformer 的扩散流水线的内存效率。

基础知识

你可参考 这篇文章 以获取 Quanto 的详细介绍。简单来说,Quanto 是一个基于 PyTorch 的量化工具包。它是 Hugging Face Optimum 的一部分,Optimum 提供了一套硬件感知的优化工具。

模型量化是 LLM 从业者必备的工具,但在扩散模型中并不算常用。Quanto 可以帮助弥补这一差距,其可以在几乎不伤害生成质量的情况下节省内存。

我们基于 H100 GPU 配置进行基准测试,软件环境如下:

  • CUDA 12.2
  • PyTorch 2.4.0
  • Diffusers (从源代码安装,参考 此提交)
  • Quanto (从源代码安装,参考 此提交)

除非另有说明,我们默认使用 FP16 进行计算。我们不对 VAE 进行量化以防止数值不稳定问题。你可于 此处 找到我们的基准测试代码。

截至本文撰写时,以下基于 transformer 的扩散模型流水线可用于 Diffusers 中的文生图任务:

  • PixArt-Alpha 及 PixArt-Sigma
  • Stable Diffusion 3
  • Hunyuan DiT
  • Lumina
  • Aura Flow

另外还有一个基于 transformer 的文生视频流水线: Latte。

为简化起见,我们的研究仅限于以下三个流水线: PixArt-Sigma、Stable Diffusion 3 以及 Aura Flow。下表显示了它们各自的扩散主干网络的参数量:

模型 Checkpoint 参数量(Billion)
PixArt https://huggingface.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS 0.611
Stable Diffusion 3 https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers 2.028
Aura Flow https://huggingface.co/fal/AuraFlow/ 6.843
请记住,本文主要关注内存效率,因为量化对推理延迟的影响很小或几乎可以忽略不计。

用 Quanto 量化 DiffusionPipeline

使用 Quanto 量化模型非常简单。

from optimum.quanto import freeze, qfloat8, quantize
from diffusers import PixArtSigmaPipeline
import torchpipeline = PixArtSigmaPipeline.from_pretrained("PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")quantize(pipeline.transformer, weights=qfloat8)
freeze(pipeline.transformer)

我们对需量化的模块调用 quantize() ,以指定我们要量化的部分。上例中,我们仅量化参数,保持激活不变,量化数据类型为 FP8。最后,调用 freeze() 以用量化参数替换原始参数。

然后,我们就可以如常调用这个 pipeline 了:

image = pipeline("ghibli style, a fantasy landscape with castles").images[0]
FP16 将 transformer 扩散主干网络量化为 FP8
FP16 image.FP8 quantized image.

我们注意到使用 FP8 可以节省显存,且几乎不影响生成质量; 我们也看到量化模型的延迟稍有变长:

Batch Size 量化 内存 (GB) 延迟 (秒)
1 12.086 1.200
1 FP8 11.547 1.540
4 12.087 4.482
4 FP8 11.548 5.109

我们可以用相同的方式量化文本编码器:

quantize(pipeline.text_encoder, weights=qfloat8)
freeze(pipeline.text_encoder)

文本编码器也是一个 transformer 模型,我们也可以对其进行量化。同时量化文本编码器和扩散主干网络可以带来更大的显存节省:

Batch Size 量化 是否量化文本编码器 显存 (GB) 延迟 (秒)
1 FP8 11.547 1.540
1 FP8 5.363 1.601
4 FP8 11.548 5.109
4 FP8 5.364 5.141

量化文本编码器后生成质量与之前的情况非常相似:

ckpt@pixart-bs@1-dtype@fp16-qtype@fp8-qte@1.png

上述攻略通用吗?

将文本编码器与扩散主干网络一起量化普遍适用于我们尝试的很多模型。但 Stable Diffusion 3 是个特例,因为它使用了三个不同的文本编码器。我们发现 _ 第二个 _ 文本编码器量化效果不佳,因此我们推荐以下替代方案:

  • 仅量化第一个文本编码器 (CLIPTextModelWithProjection) 或
  • 仅量化第三个文本编码器 (T5EncoderModel) 或
  • 同时量化第一个和第三个文本编码器

下表给出了各文本编码器量化方案的预期内存节省情况 (扩散 transformer 在所有情况下均被量化):

Batch Size 量化 量化文本编码器 1 量化文本编码器 2 量化文本编码器 3 显存 (GB) 延迟 (秒)
1 FP8 1 1 1 8.200 2.858
1 ✅ FP8 0 0 1 8.294 2.781
1 FP8 1 1 0 14.384 2.833
1 FP8 0 1 0 14.475 2.818
1 ✅ FP8 1 0 0 14.384 2.730
1 FP8 0 1 1 8.325 2.875
1 ✅ FP8 1 0 1 8.204 2.789
1 - - - 16.403 2.118
量化文本编码器: 1 量化文本编码器: 3 量化文本编码器: 1 和 3
Image with quantized text encoder 1.Image with quantized text encoder 3.Image with quantized text encoders 1 and 3.

其他发现

在 H100 上 bfloat16 通常表现更好

对于支持 bfloat16 的 GPU 架构 (如 H100 或 4090),使用 bfloat16 速度更快。下表列出了在我们的 H100 参考硬件上测得的 PixArt 的一些数字: Batch Size 精度 量化 显存 (GB) 延迟 (秒) 是否量化文本编码器

Batch Size 精度 量化 显存(GB) 延迟(秒) 是否量化文本编码器
1 FP16 INT8 5.363 1.538
1 BF16 INT8 5.364 1.454
1 FP16 FP8 5.363 1.601
1 BF16 FP8 5.363 1.495

qint8 的前途

我们发现使用 qint8 (而非 qfloat8 ) 进行量化,推理延迟通常更好。当我们对注意力 QKV 投影进行水平融合 (在 Diffusers 中调用 fuse_qkv_projections() ) 时,效果会更加明显,因为水平融合会增大 int8 算子的计算维度从而实现更大的加速。我们基于 PixArt 测得了以下数据以证明我们的发现:

Batch Size 量化 显存 (GB) 延迟 (秒) 是否量化文本编码器 QKV 融合
1 INT8 5.363 1.538
1 INT8 5.536 1.504
4 INT8 5.365 5.129
4 INT8 5.538 4.989

INT4 咋样?

在使用 bfloat16 时,我们还尝试了 qint4 。目前我们仅支持 H100 上的 bfloat16qint4 量化,其他情况尚未支持。通过 qint4 ,我们期望看到内存消耗进一步降低,但代价是推理延迟变长。延迟增加的原因是硬件尚不支持 int4 计算 - 因此权重使用 4 位,但计算仍然以 bfloat16 完成。下表展示了 PixArt-Sigma 的结果:

Batch Size 是否量化文本编码器 显存 (GB) 延迟 (秒)
1 9.380 7.431
1 3.058 7.604

但请注意,由于 INT4 量化比较激进,最终结果可能会受到影响。所以,一般对于基于 transformer 的模型,我们通常不量化最后一个投影层。在 Quanto 中,我们做法如下:

quantize(pipeline.transformer, weights=qint4, exclude="proj_out")
freeze(pipeline.transformer)

"proj_out" 对应于 pipeline.transformer 的最后一层。下表列出了各种设置的结果:

量化文本编码器: 否 , 不量化的层: 无 量化文本编码器: 否 , 不量化的层: "proj_out" 量化文本编码器: 是 , 不量化的层: 无 量化文本编码器: 是 , 不量化的层: "proj_out"
Image 1 without text encoder quantization.Image 2 without text encoder quantization but with proj_out excluded in diffusion transformer quantization.Image 3 with text encoder quantization.Image 3 with text encoder quantization but with proj_out excluded in diffusion transformer quantization..

为了恢复损失的图像质量,常见的做法是进行量化感知训练,Quanto 也支持这种训练。这项技术超出了本文的范围,如果你有兴趣,请随时与我们联系!

本文的所有实验结果都可以在 这里 找到。

加个鸡腿 - 在 Quanto 中保存和加载 Diffusers 模型

以下代码可用于对 Diffusers 模型进行量化并保存量化后的模型:

from diffusers import PixArtTransformer2DModel
from optimum.quanto import QuantizedPixArtTransformer2DModel, qfloat8model = PixArtTransformer2DModel.from_pretrained("PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", subfolder="transformer")
qmodel = QuantizedPixArtTransformer2DModel.quantize(model, weights=qfloat8)
qmodel.save_pretrained("pixart-sigma-fp8")

此代码生成的 checkpoint 大小为 587MB ,而不是原本的 2.44GB。然后我们可以加载它:

from optimum.quanto import QuantizedPixArtTransformer2DModel
import torchtransformer = QuantizedPixArtTransformer2DModel.from_pretrained("pixart-sigma-fp8")
transformer.to(device="cuda", dtype=torch.float16)

最后,在 DiffusionPipeline 中使用它:

from diffusers import DiffusionPipeline
import torchpipe = DiffusionPipeline.from_pretrained("PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",transformer=None,torch_dtype=torch.float16,
).to("cuda")
pipe.transformer = transformerprompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt).images[0]

将来,我们计划支持在初始化流水线时直接传入 transformer 就可以工作:

pipe = PixArtSigmaPipeline.from_pretrained("PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
- transformer=None,
+ transformer=transformer,torch_dtype=torch.float16,
).to("cuda")

QuantizedPixArtTransformer2DModel 实现可参考 此处。如果你希望 Quanto 支持对更多的 Diffusers 模型进行保存和加载,请在 此处 提出需求并 @sayakpaul

小诀窍

  • 根据应用场景的不同,你可能希望对流水线中不同的模块使用不同类型的量化。例如,你可以对文本编码器进行 FP8 量化,而对 transformer 扩散模型进行 INT8 量化。由于 Diffusers 和 Quanto 的灵活性,你可以轻松实现这类方案。
  • 为了优化你的用例,你甚至可以将量化与 Diffuser 中的其他 内存优化技术 结合起来,如 enable_model_cpu_offload()

总结

本文,我们展示了如何量化 Diffusers 中的 transformer 模型并优化其内存消耗。当我们同时对文本编码器进行量化时,效果变得更加明显。我们希望大家能将这些工作流应用到你的项目中并从中受益🤗。

感谢 Pedro Cuenca 对本文的细致审阅。


英文原文: https://hf.co/blog/quanto-diffusers

原文作者: Sayak Paul,David Corvoysier

译者: Matrix Yao (姚伟峰),英特尔深度学习工程师,工作方向为 transformer-family 模型在各模态数据上的应用及大规模模型的训练推理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/789607.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

opc da 服务器数据 转IEC61850项目案例

目录 1 案例说明 1 2 VFBOX网关工作原理 1 3 应用条件 2 4 查看OPC DA服务器的相关参数 2 5 配置网关采集opc da数据 4 6 用IEC61850协议转发数据 6 7 网关使用多个逻辑设备和逻辑节点的方法 9 8 在服务器上运行仰科OPC DA采集软件 10 9 案例总结 12 1 案例说明在OPC DA服务器上…

电科校园邮箱系统逻辑漏洞

校园邮件系统逻辑漏洞导致邮件轰炸 邮件轰炸 首先通过自己的账号登录进入邮件系统之后,进入到信息修改的界面发现存在邮箱绑定功能,在尝试绑定自己的邮箱之后,可以看到存在提示“找回密码时可以使用备用邮箱找回”。输入邮箱密码之后进入到下一个页面在此页面完成邮箱绑定,…

基于surging 如何利用peerjs进行语音视频通话

一 、 概述 PeerJS 是一个基于浏览器WebRTC功能实现的js功能包,简化了WebrRTC的开发过程,对底层的细节做了封装,直接调用API即可,再配合surging 协议组件化从而做到稳定,高效可扩展的微服务,再利用RtmpToWebrtc 引擎组件可以做到不仅可以利用httpflv 观看rtmp推流直播,还可…

gitee误删项目,重新上传

删除项目更目录.git 解除绑定

pinpoint-php-aop 内部原理

pinpoint-php-aop 是一个支持pinpoint-php agent 的库自动注入PHP内置函数,比如redis,pdo,mysqli 自动注入用户类,比如 guzzlehttp, predis怎样处理内置函数内置函数解释:PHP comes standard with many functions and constructs. There are also functions that require…

从代码到产品,我的IT职业成长之路

每个人的职业生涯都是一段充满转折和挑战的旅程,当然每一次职业转型都是一次重新定义自己的机会,从2015年开始,当时我刚踏入IT行业,成为一名Java开发者,后来随着时间的推移,我的职业方向逐渐转向了前端开发者,埋头于代码的世界。最终在2018年找到了属于自己的职业定位—…

1-0.AI工具

1-0.AI工具 一. 我知道或使用过的AI大模型平台 1. OpenAI 平台: OpenAI GPT 特点: 提供先进的自然语言处理能力,支持对话生成、文本总结、翻译等。包括GPT-3、GPT-4等版本。 2. Google AI 平台: Google Cloud AI 特点: 提供全面的AI和机器学习服务,包括AutoML、自然语言处理、…

折腾 Quickwit,Rust 编写的分布式搜索引擎 - 可观测性之分布式追踪

概述 分布式追踪是一种跟踪应用程序请求流经不同服务(如前端、后端、数据库等)的过程。它是一个强大的工具,可以帮助您了解应用程序的工作原理并调试性能问题。 Quickwit 是一个用于索引和搜索非结构化数据的云原生引擎,这使其非常适合用作追踪数据的后端。 此外,Quickwit…

POA:已开源,蚂蚁集团提出同时预训练多种尺寸网络的自监督范式 | ECCV 2024

论文提出一种新颖的POA自监督学习范式,通过弹性分支设计允许同时对多种尺寸的模型进行预训练。POA可以直接从预训练teacher生成不同尺寸的模型,并且这些模型可以直接用于下游任务而无需额外的预训练。这个优势显著提高了部署灵活性,并有助于预训练的模型在各种视觉任务中取得…

【信息收集】旁站和C段

一、 站长之家二、 google hacking2.1 网络空间搜索引擎2.2 在线c段 webscan.cc2.3 Nmap,Msscan扫描等2.4 常见端口表旁站往往存在业务功能站点,建议先收集已有IP的旁站,再探测C段,确认C段目标后,再在C段的基础上再收集一次旁站。 旁站是和已知目标站点在同一服务器但不同端…

茂名工厂智能视频监控系统

茂名工厂智能视频监控系统除开监控出入工作人员外,还必须监控车子,以追踪出入时长。除开组装超清精彩短视频监控监控摄像头外,茂名工厂智能视频监控系统还必须组装车辆识别系统和智能安全通道。办公室一般是一个主要的信息内容,在安装视频监控时,也需要考虑到防盗系统系统…

【信息收集】查找真实ip

一、 多地ping确认是否使用CDN二、查询历史DNS解析记录2.1 DNSDB2.2 微步在线2.3 Ipip.net2.4 viewdns三、phpinfo四、绕过CDN如果目标网站使用了CDN,使用了cdn真实的ip会被隐藏,如果要查找真实的服务器就必须获取真实的ip,根据这个ip继续查询旁站。 注意:很多时候,主站虽…