【深度学习】SDXL-Lightning 体验,gradio教程,SDXL-Lightning 论文

文章目录

  • 资源
  • SDXL-Lightning 论文

资源

SDXL-Lightning论文:https://arxiv.org/abs/2402.13929

gradio教程:https://blog.csdn.net/qq_21201267/article/details/131989242

SDXL-Lightning :https://huggingface.co/ByteDance/SDXL-Lightning

SDXL-Lightning实时出图:https://huggingface.co/spaces/radames/Real-Time-Text-to-Image-SDXL-Lightning

SDXL-Lightning demo自己体验代码:

import timeimport gradio as gr
import torch
import base64
import io
from PIL import Image
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_filebase = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_4step_unet.safetensors"  # Use the correct ckpt for your step setting!# Load model.
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")# Ensure sampler uses "trailing" timesteps.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")def get_image_from_text(prompt):time1 = time.time()# Ensure using the same inference steps as the loaded model and CFG set to 0.image = pipe(prompt, num_inference_steps=4, guidance_scale=0).images[0]print("time:", time.time() - time1)return imagedef generate(prompt):result_image = get_image_from_text(prompt)return result_imagegr.close_all()
demo = gr.Interface(fn=generate,inputs=[gr.Textbox(label="提示词")],outputs=[gr.Image(label="输出图片")],title="文本生成图片",description="输入提示词,使用SD模型生成图片",allow_flagging="never",examples=["A girl smiling", "A beautiful sunset"])demo.launch(share=True, server_name="0.0.0.0", server_port=7869)

SDXL-Lightning 论文

摘要
我们提出了一种扩散蒸馏方法,在基于SDXL的一步/少步1024px文本到图像生成中实现了新的最先进水平。我们的方法结合了渐进和对抗性蒸馏,以在质量和模式覆盖之间取得平衡。在本文中,我们讨论了理论分析、鉴别器设计、模型构建和训练技术。我们将我们的蒸馏SDXL-Lightning模型开源,包括LoRA和完整的UNet权重。

模型链接:https://huggingface.co/ByteDance/SDXL-Lightning

  1. 引言
    扩散模型是一类新兴的生成模型,已在各种应用中取得了最先进的结果,如文本到图像、文本到视频和图像到视频等。然而,扩散模型的迭代生成过程缓慢且计算量大。如何更快地生成高质量样本是一个积极研究的领域,也是我们工作的主要焦点。

从概念上讲,生成涉及逐渐将样本在数据和噪声概率分布之间传输的概率流。扩散模型学习预测该流的任何位置的梯度。生成只是通过遵循流中预测的梯度,将样本从噪声分布传输到数据分布。由于流是复杂且弯曲的,生成必须一次小步骤地进行。形式上,流可以表示为常微分方程(ODE)。实践中,生成高质量数据样本需要超过50个推理步骤。

已经研究了不同的方法来减少推理步骤的数量。先前的研究提出了更好的ODE求解器来考虑流的弯曲性质。其他人提出了使流更直的公式。尽管如此,这些方法通常仍需要超过20个推理步骤。

另一方面,模型蒸馏可以在不到10个推理步骤下生成高质量的样本。它不是预测当前流位置的梯度,而是将模型更改为直接预测未来更远处的下一个流位置。现有方法可以在4或8个推理步骤下获得良好的结果,但是使用1或2个推理步骤仍然不符合生产要求。我们的方法属于模型蒸馏范畴,并且与现有方法相比获得了更优越的质量。

我们的方法结合了渐进蒸馏和对抗性蒸馏的优点。渐进蒸馏确保蒸馏模型遵循与原始模型相同的概率流,并具有相同的模式覆盖。然而,使用均方误差(MSE)损失的渐进蒸馏在8个推理步骤以下会产生模糊的结果,我们在论文中提供了理论分析。为了减轻这个问题,我们在蒸馏的每个阶段使用对抗损失,以在质量和模式覆盖之间取得平衡。渐进蒸馏还带来了另一个好处,即对于多步采样,我们的模型预测ODE轨迹上的下一个位置,而不是每次跳到ODE轨迹的端点,这更好地保留了原始模型行为,并促进了与LoRA模块和控制插件的更好兼容性。

此外,我们的论文提出了创新的鉴别器设计、损失目标和稳定的训练技术。具体来说,我们使用预训练的扩散UNet编码器作为鉴别器骨干,并完全在潜在空间中操作。我们提出了两个对抗损失目标来权衡样本质量和模式覆盖。我们研究了扩散计划和输出形式的影响。我们讨论了稳定对抗训练的技术。我们的蒸馏方法产生了支持1024px分辨率的一步/少步生成的新的最先进的SDXL模型。我们将我们的蒸馏模型开源为SDXL-Lightning。

2.4. 对抗性蒸馏
对抗性训练涉及一个最小最大化优化,其中包括一个旨在识别生成样本和真实样本的鉴别器网络,以及一个旨在欺骗鉴别器的生成器网络。最初提出为生成对抗网络(GANs),但它存在模式坍塌和不稳定性等问题。最近的研究发现,对抗目标可以纳入扩散训练和蒸馏中。SDXL-Turbo是使用对抗性扩散蒸馏的最新和最流行的开源模型。它遵循先前的工作,使用预训练的图像编码器DINOv2作为鉴别器骨干来加速训练。然而,这带来了几个限制。首先,使用现成的视觉编码器意味着它必须在像素空间而不是潜在空间中操作,这会显著增加计算、内存消耗和训练时间,使高分辨率的蒸馏变得不切实际。这很可能是SDXL-Turbo只支持最高512px分辨率的原因。其次,现成的视觉编码器只在t = 0时起作用。蒸馏模型必须被训练以跳到ODE轨迹端点x0,但由于一步推理的质量还不够好,再次为多步推理添加随机噪声。这种多步推理的方式显著改变了模型行为,使其与现有的LoRA模块和控制插件的兼容性降低。第三,现成的编码器可能很难找到适用于其他数据集(动漫、线条图等)和模态(视频、音频等)的编码器,这降低了蒸馏方法的泛化能力。最后,仅凭对抗目标本身不能强制模型遵循相同的概率流,因此不能强制模式覆盖。

我们的方法使用扩散模型的U-Net编码器作为鉴别器骨干。这使我们能够在潜在空间中有效地进行高分辨率模型的蒸馏,支持在所有时间步骤进行鉴别,并可泛化到所有数据集和模态。我们的方法还允许控制质量和模式覆盖之间的权衡,如后面3.2和3.4节所讨论的那样。

2.5. 其他蒸馏方法
我们简要讨论了我们的方法与其他蒸馏方法相比的优点。

一致性模型(CM)也需要在每个推理步骤中跳转到ODE轨迹的端点。这导致多步采样时模型行为的巨大变化,降低了与LoRA模块和插件的兼容性。该方法已应用于SDXL,但在8个步骤以下的生成质量较差。一致性轨迹模型(CTM)增加了对抗性损失,并支持跳转到任意流位置,但对抗性训练是在蒸馏后应用的,而不是在蒸馏过程中应用的,而且该方法尚未应用于大规模的文本到图像模型。

矫正流(RF)通过重复使用确定性数据和噪声对训练,使流变得直。然而,其少步生成质量仍然很差。此外,由于在蒸馏过程中模型只见过特定的数据和噪声对,它不再支持将数据与任意噪声配对,这影响了像SDEdit这样的图像编辑的能力。

得分蒸馏采样(SDS)已用于SDXL-Turbo来稳定对抗性训练,但其效果很小,并且不能单独作为蒸馏方法使用。变分得分蒸馏(VSD)最近在扩散蒸馏中使用。然而,在蒸馏过程中需要训练一个额外的负分布得分模型,而且像对抗训练中的鉴别器一样,它还涉及动态训练目标,这可能会对训练稳定性产生负面影响。没有开源模型供比较,我们的初步实验发现我们的方法达到了更好的质量。

2.6. LoRA
低秩适应(LoRA)是一种高效的微调技术。它只训练少量额外的参数,并已成为对现有文本到图像模型进行风格化模块训练的特别流行方法。

LCM-LoRA是首个表明模型蒸馏也可以作为LoRA模块进行训练的模型。这确保了最小的参数更改,并可以方便地插入到现有的生态系统中。

我们的工作受到这种方法的启发,我们提供了我们的蒸馏模型作为LoRA,以便进行方便的插拔,并且作为完整模型以获得更好的质量。

  1. 方法
    3.1. 为什么使用MSE蒸馏失败
    在这里插入图片描述

图1. 不同容量模型学习的多个可能流的示意图。针对少步生成的蒸馏学生模型无法具备与教师模型匹配的相同容量,导致使用MSE损失产生模糊结果。
学习到的概率流由数据集、前向函数、损失函数和模型容量确定。鉴于有限的训练样本,底层数据分布是模糊的。最大似然估计(MLE)是一种将均匀概率分配给观察到的样本,其他地方概率为零的分布。如果模型容量无限,它将学习到这种最大似然估计的流,并过度拟合以始终生成观察到的样本并生成没有新数据。实际上,扩散模型可以生成新数据,因为神经网络不是精确学习器。
当模型用于多步生成时,它被堆叠并具有更高的利普希茨常数和更多的非线性,以逼近更复杂的分布。但是当模型用于少步生成时,它不再具有足够的容量来很好地逼近相同的分布。这可以通过扩散模型在初始噪声上进行轻微更改而产生的结果发生非常明显的变化来证明,但是蒸馏模型的潜在遍历更加平滑。这解释了为什么使用MSE损失进行蒸馏会产生模糊的结果。学生模型简单地没有能力与教师相匹配。
此外,神经网络参数优化涉及复杂的景观。即使具有相同容量的模型也很难完全匹配输出,因为参数可能会卡在不同的局部最小值处。
我们发现其他距离度量,例如L1和感知损失,也会产生不理想的结果。另一方面,我们发现对抗目标对缓解这个问题是有效的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/498166.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

蓝桥杯_定时器的基本原理与应用

一 什么是定时器 定时器/计数器是一种能够对内部时钟信号或外部输入信号进行计数,当计数值达到设定要求时,向cpu提出中断处理请求,从而实现,定时或者计数功能的外设。 二 51单片机的定时/计数器 单片机外部晶振12MHZ,…

Facebook的元宇宙实践:数字化社交的新前景

近年来,元宇宙(Metaverse)这一概念备受瞩目,被认为是数字化社交的未来趋势之一。而在众多科技巨头中,Facebook(现更名为Meta)一直处于元宇宙发展的前沿。在本文中,我们将深入探讨Fac…

火灾安全护航:火灾监测报警摄像机助力建筑安全

火灾是建筑安全中最常见也最具破坏力的灾难之一,为了及时发现火灾、减少火灾造成的损失,火灾监测报警摄像机应运而生,成为建筑防火安全的重要技术装备。 火灾监测报警摄像机采用高清晰度摄像头和智能识别系统,能够全天候监测建筑内…

下载huggingface数据集到本地并读取.arrow文件遇到的问题

文章目录 1. 524MB中文维基百科语料(需要下载的数据集)2. 下载 hugging face 网站上的数据集3. 读取 .arrow 文件报错代码4. 纠正后代码 1. 524MB中文维基百科语料(需要下载的数据集) 2. 下载 hugging face 网站上的数据集 要将H…

视频二维码生成的应用领域:探寻多彩世界的大门

在数字化时代,生成二维码已成为连接线上线下的桥梁,而视频二维码则为信息传递注入了全新的生机和活力。二维码不再只是静态的黑白方块,通过二维彩虹技术的运用,视频二维码打开了一扇通往多彩世界的大门,让我们一同探寻…

【Java程序员面试专栏 算法思维】一 高频面试算法题:排序算法

一轮的算法训练完成后,对相关的题目有了一个初步理解了,接下来进行专题训练,以下这些题目就是汇总的高频题目,本篇主要聊聊排序算法,包括手撕排序算法,经典的TOPK问题以及区间合并,所以放到一篇Blog中集中练习 题目关键字解题思路时间空间快速排序双指针+递归+基准值分…

半导体行业案例:Jira与龙智插件助力某半导体企业实现精益项目管理

近日,龙智Atlassian技术团队收到了国内一家大型半导体企业的感谢信。龙智团队提供的半导体行业项目管理解决方案和服务受到了客户的好评: 在龙智团队的支持下,我们的业务取得了喜人的成果和进步。龙智公司的专业服务和产品,是我们…

预约出行真方便!苏州金龙海格客车服务京津冀上班族总运量破百万

元宵已过,全国各地上班族们纷纷回到各自的岗位,通勤路也再次变得繁忙。近日,“京津冀协同合作打造环京通勤定制快巴网络”入选交通运输部公布的道路客运转型发展典型案例。截至2023年11月底,京津冀定制快巴已开通北京至河北燕郊、…

c# ABB 机械手上位机连接

c# 程式开发和调试步骤如下: ABB 机械手要开启PC Interface功能。ABB 机械手设定ip地址。设定测试笔记本和机械手同一网段,用网线直连机械手,也可以通过交换机连接机械手。确保笔记本能够ping通和telnet 机械手80端口都是OK的。以上都OK的话…

阿里云A10推理qwen

硬件配置 vCPU:32核 内存:188 GiB 宽带:5 Mbps GPU:NVIDIA A10 24Gcuda 安装 wget https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda-repo-rhel7-12-1-local-12.1.0_530.30.02-1.x86_64.rpm s…

js 手写深拷贝方法

文章目录 一、深拷贝实现代码二、代码讲解2.1 obj.constructor(obj)2.2 防止循环引用 手写一个深拷贝是我们常见的面试题,在实现过程中我们需要考虑的类型很多,包括对象、数组、函数、日期等。以下就是深拷贝实现逻辑 一、深拷贝实现代码 const origin…

Python中re(正则)模块的使用

re 是 Python 标准库中的一个模块,用于支持正则表达式操作。通过 re 模块,可以使用各种正则表达式来搜索、匹配和操作字符串数据。 使用 re 模块可以帮助在处理字符串时进行高效的搜索和替换操作,特别适用于需要处理文本数据的情况。 # 导入…