组装自己的稳定扩散模型

在本文中,我们将利用 Hugging Face Diffusers 库的组件实现自己的稳定扩散模型,可以像 diffuser.diffuse() 一样简单地生成图像。

在线工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 

1、概述

在我们开始使用代码之前,让我们回顾一下扩散器的推理工作原理。

  • 我们向扩散器输入提示。
  • 该提示通过文本编码器给出数学表示(嵌入)。
  • 产生了潜在的噪声。
  • U-Net 结合提示来预测潜在的噪声。
  • 与调度程序一起从潜在噪声中减去预测噪声。
  • 经过多次迭代后,去噪后的潜在图像被解压缩以生成最终生成的图像。

使用的主要组件有:

  • 文本编码器
  • U-Net模型
  • VAE 解码器

2、环境搭建

! pip install -Uqq fastcore transformers diffusers
import logging; logging.disable(logging.WARNING) # <1>
from fastcore.all import *
from fastai.imports import *
from fastai.vision.all import *

3、获取组件

要处理提示,我们需要下载CLIP分词器和文本编码器。 分词器会将提示分割成标记,而文本编码器会将标记转换为数字表示(嵌入)。

from transformers import CLIPTokenizer, CLIPTextModeltokz = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=torch.float16)
txt_enc = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=torch.float16).to('cuda')

float16 用于提高性能。

U-Net将预测图像中的噪声,而VAE将对生成的图像进行解压缩。

from diffusers import AutoencoderKL, UNet2DConditionModelvae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-ema', torch_dtype=torch.float16).to('cuda')
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16).to("cuda")

调度器(scheduler)将控制最初添加到图像中的噪声量,还将控制从图像中减去 U-Net 预测的噪声量。

from diffusers import LMSDiscreteSchedulersched = LMSDiscreteScheduler(beta_start = 0.00085,beta_end = 0.012,beta_schedule = 'scaled_linear',num_train_timesteps = 1000
); sched
LMSDiscreteScheduler {"_class_name": "LMSDiscreteScheduler","_diffusers_version": "0.16.0","beta_end": 0.012,"beta_schedule": "scaled_linear","beta_start": 0.00085,"num_train_timesteps": 1000,"prediction_type": "epsilon","trained_betas": null
}

4、定义生成参数

生成所需的六个主要参数是:

  • prompt:提示
  • w, h:图像的宽度和高度
  • n_inf_steps:描述输出图像的噪声程度的数字(推理步数)
  • g_scale:描述扩散器应遵循提示的程度的数字(引导尺度)
  • bs:批大小
  • seed:种子
prompt = ['a photograph of an astronaut riding a horse']
w, h = 512, 512
n_inf_steps = 70
g_scale = 7.5
bs = 1
seed = 77

5、编码提示

现在我们需要解析提示。 为此,我们首先将其分词,然后对得到的标记进行编码以生成嵌入。

首先,让我们进行分词:

txt_inp = tokz(prompt,padding = 'max_length',max_length = tokz.model_max_length,truncation = True,return_tensors = 'pt'
); txt_inp

结果如下:

{'input_ids': tensor([[49406,   320,  8853,   539,   550, 18376,  6765,   320,  4558, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0]])}

标记 49407 是一个填充标记,表示 '<|endoftext|>'。 这些标记的注意力掩码为 0。

tokz.decode(49407)

输出如下:

'<|endoftext|>'

现在使用文本编码器,我们将创建这些标记的嵌入向量:

txt_emb = txt_enc(txt_inp['input_ids'].to('cuda'))[0].half(); txt_emb

输出如下:

tensor([[[-0.3884,  0.0229, -0.0523,  ..., -0.4902, -0.3066,  0.0674],[ 0.0292, -1.3242,  0.3076,  ..., -0.5254,  0.9766,  0.6655],[ 0.4609,  0.5610,  1.6689,  ..., -1.9502, -1.2266,  0.0093],...,[-3.0410, -0.0674, -0.1777,  ...,  0.3950, -0.0174,  0.7671],[-3.0566, -0.1058, -0.1936,  ...,  0.4258, -0.0184,  0.7588],[-2.9844, -0.0850, -0.1726,  ...,  0.4373,  0.0092,  0.7490]]],device='cuda:0', dtype=torch.float16, grad_fn=<NativeLayerNormBackward0>)

查看txt_emb的形状:

txt_emb.shape

输出如下:

torch.Size([1, 77, 768])

6、CFG 的嵌入

我们还需要为空提示(也称为无条件提示)创建嵌入。 这种嵌入用于控制引导。

txt_inp['input_ids'].shape
torch.Size([1, 77])
max_len = txt_inp['input_ids'].shape[-1] # <1>
uncond_inp = tokz([''] * bs, # <2>padding = 'max_length',max_length = max_len,return_tensors = 'pt',
); uncond_inp

我们使用提示的最大长度,因此无条件提示嵌入与文本提示嵌入的大小相匹配。
我们还将包含空提示的列表与批量大小相乘,以便每个文本提示都有一个空提示。

{'input_ids': tensor([[49406, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407]]), 'attention_mask': tensor([[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0]])}
uncond_inp['input_ids'].shape
torch.Size([1, 77])
uncond_emb = txt_enc(uncond_inp['input_ids'].to('cuda'))[0].half()
uncond_emb.shape
torch.Size([1, 77, 768])

然后我们可以将无条件嵌入和文本嵌入连接在一起。 这允许根据每个提示生成图像,而无需通过 U-Net 两次。

embs = torch.cat([uncond_emb, txt_emb])

7、创建噪声图像

现在是时候创建我们的噪声图像了,这将是生成的起点。

我们将创建一个64 x 64 像素的单个潜在图像,并且也有 4 个通道。 对潜在图像进行去噪后,我们将其解压缩为具有 3 个通道的 512 x 512 像素图像。

bs, unet.config.in_channels, h//8, w//8
(1, 4, 64, 64)
print(torch.randn((2, 3, 4)))
print(torch.randn((2, 3, 4)).shape)
tensor([[[ 0.2818,  1.9993, -0.2554, -1.8170],[-0.5899,  0.6199,  0.4697,  0.8363],[ 0.4416, -1.1702,  0.0392, -1.3377]],[[ 1.6029,  0.2883, -0.4365,  0.5624],[-1.4361, -0.6055,  0.9542, -0.2457],[-1.4045, -0.2218,  0.3492, -0.1245]]])
torch.Size([2, 3, 4])
torch.manual_seed(seed)
lats = torch.randn((bs, unet.config.in_channels, h//8, w//8)); lats.shape
torch.Size([1, 4, 64, 64])

潜在张量是 4 阶张量。 1 指的是批量大小,即生成的图像数量。 4 是通道数,64 是高度和宽度的像素数。

lats = lats.to('cuda').half(); lats
tensor([[[[-0.5044, -0.4163, -0.1365,  ..., -1.6104,  0.1381,  1.7676],[ 0.7017,  1.5947, -1.4434,  ..., -1.5859, -0.4089, -2.8164],[ 1.0664, -0.0923,  0.3462,  ..., -0.2390, -1.0947,  0.7554],...,[-1.0283,  0.2433,  0.3337,  ...,  0.6641,  0.4219,  0.7065],[ 0.4280, -1.5439,  0.1409,  ...,  0.8989, -1.0049,  0.0482],[-1.8682,  0.4988,  0.4668,  ..., -0.5874, -0.4019, -0.2856]],[[ 0.5688, -1.2715, -1.4980,  ...,  0.2230,  1.4785, -0.6821],[ 1.8418, -0.5117,  1.1934,  ..., -0.7222, -0.7417,  1.0479],[-0.6558,  0.1201,  1.4971,  ...,  0.1454,  0.4714,  0.2441],...,[ 0.9492,  0.1953, -2.4141,  ..., -0.5176,  1.1191,  0.5879],[ 0.2129,  1.8643, -1.8506,  ...,  0.8096, -1.5264,  0.3191],[-0.3640, -0.9189,  0.8931,  ..., -0.4944,  0.3916, -0.1406]],[[-0.5259,  1.5059, -0.3413,  ...,  1.2539,  0.3669, -0.1593],[-0.2957, -0.1169, -2.0078,  ...,  1.9268,  0.3833, -0.0992],[ 0.5020,  1.0068, -0.9907,  ..., -0.3008,  0.7324, -1.1963],...,[-0.7437, -1.1250,  0.1349,  ..., -0.6714, -0.6753, -0.7920],[ 0.5415, -0.5269, -1.0166,  ...,  1.1270, -1.7637, -1.5156],[-0.2319,  0.9165,  1.6318,  ...,  0.6602, -1.2871,  1.7568]],[[ 0.7100,  0.4133,  0.5513,  ...,  0.0326,  0.9175,  1.4922],[ 0.8862,  1.3760,  0.8599,  ..., -2.1172, -1.6533,  0.8955],[-0.7783, -0.0246,  1.4717,  ...,  0.0328,  0.4316, -0.6416],...,[ 0.0855, -0.1279, -0.0319,  ..., -0.2817,  1.2744, -0.5854],[ 0.2402,  1.3945, -2.4062,  ...,  0.3435, -0.5254,  1.2441],[ 1.6377,  1.2539,  0.6099,  ...,  1.5391, -0.6304,  0.9092]]]],device='cuda:0', dtype=torch.float16)

我们的潜在变量具有代表噪声的随机值。 这种噪声需要进行缩放,以便它可以与调度程序一起工作。

#| id: DgrthbcIEzVO
#| colab: {base_uri: 'https://localhost:8080/'}
#| id: DgrthbcIEzVO
#| outputId: 761f0f3c-010e-4dfa-b7a3-6d94d026d4cc
sched.set_timesteps(n_inf_steps); sched
LMSDiscreteScheduler {"_class_name": "LMSDiscreteScheduler","_diffusers_version": "0.16.0","beta_end": 0.012,"beta_schedule": "scaled_linear","beta_start": 0.00085,"num_train_timesteps": 1000,"prediction_type": "epsilon","trained_betas": null
}
lats *= sched.init_noise_sigma; sched.init_noise_sigma
tensor(14.6146)
sched.sigmas
tensor([14.6146, 13.3974, 12.3033, 11.3184, 10.4301,  9.6279,  8.9020,  8.2443,7.6472,  7.1044,  6.6102,  6.1594,  5.7477,  5.3709,  5.0258,  4.7090,4.4178,  4.1497,  3.9026,  3.6744,  3.4634,  3.2680,  3.0867,  2.9183,2.7616,  2.6157,  2.4794,  2.3521,  2.2330,  2.1213,  2.0165,  1.9180,1.8252,  1.7378,  1.6552,  1.5771,  1.5031,  1.4330,  1.3664,  1.3030,1.2427,  1.1852,  1.1302,  1.0776,  1.0272,  0.9788,  0.9324,  0.8876,0.8445,  0.8029,  0.7626,  0.7236,  0.6858,  0.6490,  0.6131,  0.5781,0.5438,  0.5102,  0.4770,  0.4443,  0.4118,  0.3795,  0.3470,  0.3141,0.2805,  0.2455,  0.2084,  0.1672,  0.1174,  0.0292,  0.0000])
sched.timesteps
tensor([999.0000, 984.5217, 970.0435, 955.5652, 941.0870, 926.6087, 912.1304,897.6522, 883.1739, 868.6957, 854.2174, 839.7391, 825.2609, 810.7826,796.3043, 781.8261, 767.3478, 752.8696, 738.3913, 723.9130, 709.4348,694.9565, 680.4783, 666.0000, 651.5217, 637.0435, 622.5652, 608.0870,593.6087, 579.1304, 564.6522, 550.1739, 535.6957, 521.2174, 506.7391,492.2609, 477.7826, 463.3043, 448.8261, 434.3478, 419.8696, 405.3913,390.9130, 376.4348, 361.9565, 347.4783, 333.0000, 318.5217, 304.0435,289.5652, 275.0870, 260.6087, 246.1304, 231.6522, 217.1739, 202.6957,188.2174, 173.7391, 159.2609, 144.7826, 130.3043, 115.8261, 101.3478,86.8696,  72.3913,  57.9130,  43.4348,  28.9565,  14.4783,   0.0000],dtype=torch.float64)
plt.plot(sched.timesteps, sched.sigmas[:-1])

8、去噪

降噪过程现在可以开始了!

from tqdm.auto import tqdmfor i, ts in enumerate(tqdm(sched.timesteps)):inp = torch.cat([lats] * 2) # <1>inp = sched.scale_model_input(inp, ts) # <2>with torch.no_grad(): preds = unet(inp, ts, encoder_hidden_states=embs).sample # <3>pred_uncond, pred_txt = preds.chunk(2) # <4>pred = pred_uncond + g_scale * (pred_txt - pred_uncond) # <4>lats = sched.step(pred, ts, lats).prev_sample #<5>
  • 我们首先创建两个潜在变量:一个用于文本提示,一个用于无条件提示。
  • 然后我们进一步缩放潜在的噪声。
  • 然后我们预测噪声。
  • 然后我们进行指导。
  • 然后,我们从图像中减去预测的引导噪声。

9、解码

我们现在可以解码潜在图像并显示它。

with torch.no_grad(): img = vae.decode(1/0.18215*lats).sample
img = (img / 2 + 0.5).clamp(0, 1)
img = img[0].detach().cpu().permute(1, 2, 0).numpy()
img = (img * 255).round().astype('uint8')
Image.fromarray(img)

现在你就拥有了我们使用文本编码器、VAE 和 U-Net 实现的稳定扩散!


原文链接:组装自己的稳定扩散 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/221535.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Portraiture全新4.1.2版本升级更新

关于PS修图插件&#xff0c;相信大家都有安装过使用过&#xff0c;而且还不止安装了一款&#xff0c;比如最为经典的DR5.0人像精修插件&#xff0c;Retouch4me11合1插件&#xff0c;Portraiture磨皮插件&#xff0c;这些都是人像精修插件中的领跑者。其中 Portraiture 刚刚升级…

内网隧道学习

默认密码&#xff1a;hongrisec2019 一.环境搭建 网卡学习 一个网卡一个分段&#xff0c;想象成一个管道 192.168.52一段 192.168.150一段 仅主机模式保证不予外界连通&#xff0c;保证恶意操作不会跑到真实机之上 52段是内部通信&#xff0c;150段属于服务器&#xff08;…

C语言学习笔记之函数篇

与数学意义上的函数不同&#xff0c;C语言中的函数又称为过程&#xff0c;接口&#xff0c;具有极其重要的作用。教科书上将其定义为&#xff1a;程序中的子程序。 在计算机科学中&#xff0c;子程序&#xff08;英语&#xff1a;Subroutine, procedure, function, routine, me…

FFmpeg零基础学习(二)——视频文件信息获取

目录 前言正文一、获取宽高信息1、核心代码2、AVFormatContext3、avformat_alloc_context4、avformat_open_input5、avformat_find_stream_info6、av_dump_format7、av_find_best_stream End、遇到的问题1、Qt Debug模式avformat_alloc_context 无法分配对象&#xff0c;而Rele…

【Spring】Spring是什么?

文章目录 前言什么是Spring什么是容器什么是 IoC传统程序开发控制反转式程序开发理解Spring IoCDI Spring帮助网站 前言 前面我们学习了 servlet 的相关知识&#xff0c;但是呢&#xff1f;使用 servlet 进行网站的开发步骤还是比较麻烦的&#xff0c;而我们本身程序员就属于是…

广度优先遍历与最短路径

广度优先遍历从某个顶点 v 出发&#xff0c;首先访问这个结点&#xff0c;并将其标记为已访问过&#xff0c;然后顺序访问结点v的所有未被访问的邻接点 {vi,..,vj} &#xff0c;并将其标记为已访问过&#xff0c;然后将 {vi,...,vj} 中的每一个节点重复节点v的访问方法&#xf…

Runloop解析

RunLoop 前言 ​ 本文介绍RunLoop的概念&#xff0c;并使用swift和Objective-C来描述RunLoop机制。 简介 ​ RunLoop——运行循环&#xff08;死循环&#xff09;&#xff0c;它提供了一个事件循环机制在程序运行过程中处理各种事件&#xff0c;例如用户交互、网络请求、定…

HT97226 免输出电容立体声耳机放大器的应用与曲线

HT97226应用&#xff1a; ・耳机 ・多媒体音频接口 ・机顶盒 ・ 蓝光/DVD播放器 ・LCD电视 ・音频消费电子产品 HT97226应用图于曲线&#xff1a; HT97226是一款差分输入/单端输入、可直接输出驱动的耳机放大器。5V供…

网络渗透测试(认识)

ARP协议 逻辑地址变成物理地址 32bit的IP地址变换成48bit的mac地址 ARP两个字节&#xff08;0x0806&#xff09; ARP解析协议 每一个主机都有ARP高速缓存&#xff0c;此缓存中记录了最近一段时间的内其他IP地址与其MAC地址的对应关系 如果本机想与某台主机通信&#xff0c;首先…

Azure Machine Learning - 创建Azure AI搜索服务

目录 准备工作查找 Azure AI 搜索产品/服务选择订阅设置资源组为服务命名选择区域选择层创建服务配置身份验证扩展服务何时添加第二个服务将多个服务添加到订阅 Azure AI 搜索是用于将全文搜索体验添加到自定义应用的 Azure 资源&#xff0c;本文介绍如何创建Azure AI搜索服务 …

5.前端--CSS-基本概念【2023.11.26】

1. CSS 语法规范 CSS 规则由两个主要的部分构成&#xff1a;选择器以及一条或多条声明。 属性和属性值之间用英文“:”分开 多个“键值对”之间用英文“;”进行区分 选择器 : 简单来说&#xff0c;就是选择标签用的。 声明 &#xff1a;就是改变样式 2.CSS引入方式 按照 CSS 样…

网络运维与网络安全 学习笔记2023.11.26

网络运维与网络安全 学习笔记 第二十七天 今日目标 NAT场景与原理、静态NAT、动态NAT PAT原理与配置、动态PAT之EasyIP、静态PAT之NAT Server NAT场景与原理 项目背景 为节省IP地址和费用&#xff0c;企业内网使用的都是“私有IP地址” Internet网络的组成设备&#xff0c…