CogView3---CogView-3Plus-微调代码源码解析-三-

news/2024/10/23 9:25:16/文章来源:https://www.cnblogs.com/apachecn/p/18494398

CogView3 & CogView-3Plus 微调代码源码解析(三)

.\cogview3-finetune\sat\sgm\modules\diffusionmodules\guiders.py

# 导入 logging 模块,用于记录日志信息
import logging
# 从 abc 模块导入 ABC 类和 abstractmethod 装饰器,用于定义抽象基类和抽象方法
from abc import ABC, abstractmethod
# 导入类型注解,方便在函数签名中定义复杂数据结构
from typing import Dict, List, Optional, Tuple, Union
# 从 functools 模块导入 partial 函数,用于部分应用函数
from functools import partial
# 导入数学模块,提供数学函数
import math# 导入 PyTorch 库,提供张量计算功能
import torch
# 从 einops 模块导入 rearrange 和 repeat 函数,用于张量重排和重复
from einops import rearrange, repeat# 从上层模块导入工具函数,提供一些默认值和实例化配置的功能
from ...util import append_dims, default, instantiate_from_config# 定义一个抽象基类 Guider,继承自 ABC
class Guider(ABC):# 定义一个抽象方法 __call__,接受一个张量和一个浮点数,返回一个张量@abstractmethoddef __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:pass# 定义准备输入的方法,接受多个参数并返回一个元组def prepare_inputs(self, x: torch.Tensor, s: float, c: Dict, uc: Dict) -> Tuple[torch.Tensor, float, Dict]:pass# 定义一个类 VanillaCFG,表示基本的条件生成模型
class VanillaCFG:"""implements parallelized CFG"""# 初始化方法,接受比例和动态阈值配置def __init__(self, scale, dyn_thresh_config=None):# 定义一个 lambda 函数,根据 sigma 返回 scale,保持独立于步数scale_schedule = lambda scale, sigma: scale  # independent of step# 使用 partial 固定 scale 参数,创建 scale_schedule 方法self.scale_schedule = partial(scale_schedule, scale)# 实例化动态阈值对象,如果没有提供配置则使用默认配置self.dyn_thresh = instantiate_from_config(default(dyn_thresh_config,{"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"},))# 定义 __call__ 方法,使该类可以被调用,接受多个参数def __call__(self, x, sigma, step = None, num_steps = None, **kwargs):# 将输入张量 x 拆分为两个部分 x_u 和 x_cx_u, x_c = x.chunk(2)# 根据 sigma 计算 scale_valuescale_value = self.scale_schedule(sigma)# 使用动态阈值处理函数进行预测,返回预测结果x_pred = self.dyn_thresh(x_u, x_c, scale_value, step=step, num_steps=num_steps)return x_pred# 定义准备输入的方法,接受多个参数并返回一个元组def prepare_inputs(self, x, s, c, uc):# 初始化输出字典c_out = dict()# 遍历条件字典 c 的键for k in c:# 如果键是特定值,则将 uc 和 c 中的对应张量拼接if k in ["vector", "crossattn", "concat"]:c_out[k] = torch.cat((uc[k], c[k]), 0)# 否则确保两个字典中对应的值相等,并直接赋值else:assert c[k] == uc[k]c_out[k] = c[k]# 返回拼接后的张量和条件字典return torch.cat([x] * 2), torch.cat([s] * 2), c_out# 定义一个类 IdentityGuider,实现一个恒等引导器
class IdentityGuider:# 定义 __call__ 方法,直接返回输入张量def __call__(self, x, sigma, **kwargs):return x# 定义准备输入的方法,返回输入和条件字典def prepare_inputs(self, x, s, c, uc):# 初始化输出字典c_out = dict()# 遍历条件字典 c 的键for k in c:# 直接将条件字典 c 的值赋给输出字典c_out[k] = c[k]# 返回输入张量和条件字典return x, s, c_out# 定义一个类 LinearPredictionGuider,继承自 Guider
class LinearPredictionGuider(Guider):# 初始化方法,接受多个参数def __init__(self,max_scale: float,num_frames: int,min_scale: float = 1.0,additional_cond_keys: Optional[Union[List[str], str]] = None,):# 初始化最小和最大比例self.min_scale = min_scaleself.max_scale = max_scale# 计算比例的线性变化,生成 num_frames 个值self.num_frames = num_framesself.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)# 确保 additional_cond_keys 是一个列表,如果是字符串则转换为列表additional_cond_keys = default(additional_cond_keys, [])if isinstance(additional_cond_keys, str):additional_cond_keys = [additional_cond_keys]# 保存附加条件键self.additional_cond_keys = additional_cond_keys# 定义可调用对象的方法,接收输入张量 x 和 sigma,以及其他参数 kwargs,返回一个张量def __call__(self, x: torch.Tensor, sigma: torch.Tensor, **kwargs) -> torch.Tensor:# 将输入张量 x 拆分为两部分:x_u 和 x_cx_u, x_c = x.chunk(2)# 重排 x_u 的维度,使其形状为 (批量大小 b, 帧数 t, ...),t 由 num_frames 指定x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)# 重排 x_c 的维度,使其形状为 (批量大小 b, 帧数 t, ...),t 由 num_frames 指定x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)# 复制 scale 张量的维度,使其形状为 (批量大小 b, 帧数 t)scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])# 将 scale 的维度扩展到与 x_u 的维度一致,并移动到 x_u 的设备上scale = append_dims(scale, x_u.ndim).to(x_u.device)# 将 scale 转换为与 x_u 相同的数据类型scale = scale.to(x_u.dtype)# 返回经过计算的结果,重排为 (批量大小 b * 帧数 t, ...)return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")# 定义准备输入的函数,接收输入张量 x 和 s,以及条件字典 c 和 uc,返回一个元组def prepare_inputs(self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict) -> Tuple[torch.Tensor, torch.Tensor, dict]:# 初始化一个空字典 c_out 用于存放处理后的条件c_out = dict()# 遍历条件字典 c 的每一个键 kfor k in c:# 如果 k 是指定的条件键之一,进行拼接if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:# 将 uc[k] 和 c[k] 沿第0维拼接,并存入 c_outc_out[k] = torch.cat((uc[k], c[k]), 0)else:# 确保 c[k] 与 uc[k] 相等assert c[k] == uc[k]# 将 c[k] 直接存入 c_outc_out[k] = c[k]# 返回拼接后的 x 和 s 以及处理后的条件字典 c_outreturn torch.cat([x] * 2), torch.cat([s] * 2), c_out

.\cogview3-finetune\sat\sgm\modules\diffusionmodules\loss.py

# 导入所需的标准库和类型提示
import os
import copy
from typing import List, Optional, Union# 导入 NumPy 和 PyTorch 库
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# 导入 OmegaConf 中的 ListConfig
from omegaconf import ListConfig# 从自定义模块中导入所需的函数和类
from ...util import append_dims, instantiate_from_config
from ...modules.autoencoding.lpips.loss.lpips import LPIPS
from ...modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from ...util import get_obj_from_str, default
from ...modules.diffusionmodules.discretizer import generate_roughly_equally_spaced_steps, sub_generate_roughly_equally_spaced_steps# 定义标准扩散损失类,继承自 nn.Module
class StandardDiffusionLoss(nn.Module):# 初始化方法,设置损失类型和噪声级别等参数def __init__(self,sigma_sampler_config,type="l2",offset_noise_level=0.0,batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,):super().__init__()# 确保损失类型有效assert type in ["l2", "l1", "lpips"]# 根据配置实例化 sigma 采样器self.sigma_sampler = instantiate_from_config(sigma_sampler_config)# 保存损失类型和噪声级别self.type = typeself.offset_noise_level = offset_noise_level# 如果损失类型为 lpips,则初始化 lpips 模块if type == "lpips":self.lpips = LPIPS().eval()# 如果没有提供 batch2model_keys,则设置为空列表if not batch2model_keys:batch2model_keys = []# 如果 batch2model_keys 是字符串,则转为列表if isinstance(batch2model_keys, str):batch2model_keys = [batch2model_keys]# 将 batch2model_keys 转为集合以便于后续处理self.batch2model_keys = set(batch2model_keys)# 定义调用方法,计算损失def __call__(self, network, denoiser, conditioner, input, batch):# 使用条件器处理输入批次cond = conditioner(batch)# 从批次中提取附加模型输入additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)}# 生成 sigma 值sigmas = self.sigma_sampler(input.shape[0]).to(input.device)# 生成与输入相同形状的随机噪声noise = torch.randn_like(input)# 如果设置了噪声级别,调整噪声if self.offset_noise_level > 0.0:noise = noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level# 确保噪声数据类型与输入一致noise = noise.to(input.dtype)# 将输入与噪声和 sigma 结合,生成有噪声的输入noised_input = input.float() + noise * append_dims(sigmas, input.ndim)# 使用去噪网络处理有噪声的输入model_output = denoiser(network, noised_input, sigmas, cond, **additional_model_inputs)# 将去噪网络的权重调整为与输入相同的维度w = append_dims(denoiser.w(sigmas), input.ndim)# 返回损失值return self.get_loss(model_output, input, w)# 定义计算损失的方法def get_loss(self, model_output, target, w):# 根据损失类型计算 l2 损失if self.type == "l2":return torch.mean((w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1)# 根据损失类型计算 l1 损失elif self.type == "l1":return torch.mean((w * (model_output - target).abs()).reshape(target.shape[0], -1), 1)# 根据损失类型计算 lpips 损失elif self.type == "lpips":loss = self.lpips(model_output, target).reshape(-1)return loss# 定义线性中继扩散损失类,继承自 StandardDiffusionLoss
class LinearRelayDiffusionLoss(StandardDiffusionLoss):# 初始化方法,设置相关参数def __init__(self,sigma_sampler_config,type="l2",offset_noise_level=0.0,partial_num_steps=500,blurring_schedule='linear',batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,):# 调用父类构造函数,初始化基本参数super().__init__(sigma_sampler_config,  # sigma 采样器的配置type=type,  # 类型参数offset_noise_level=offset_noise_level,  # 偏移噪声水平batch2model_keys=batch2model_keys,  # 批次到模型的键映射)# 设置模糊调度参数self.blurring_schedule = blurring_schedule# 设置部分步骤数量self.partial_num_steps = partial_num_stepsdef __call__(self, network, denoiser, conditioner, input, batch):# 使用调节器处理批次数据,生成条件cond = conditioner(batch)# 生成额外的模型输入,筛选出与模型键对应的批次数据additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)}# 从批次中获取低分辨率输入lr_input = batch["lr_input"]# 生成随机整数,用于选择部分步骤rand = torch.randint(0, self.partial_num_steps, (input.shape[0],))# 从 sigma 采样器生成 sigma 值,并转换为输入数据类型和设备sigmas = self.sigma_sampler(input.shape[0], rand).to(input.dtype).to(input.device)# 生成与输入形状相同的随机噪声noise = torch.randn_like(input)# 如果偏移噪声水平大于0,则添加额外噪声if self.offset_noise_level > 0.0:# 生成额外随机噪声并调整其维度,乘以偏移噪声水平noise = noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level# 转换噪声为输入数据类型noise = noise.to(input.dtype)# 调整 rand 的维度并转换为输入数据类型和设备rand = append_dims(rand, input.ndim).to(input.dtype).to(input.device)# 根据模糊调度的不同方式计算模糊输入if self.blurring_schedule == 'linear':# 线性模糊处理blurred_input = input * (1 - rand / self.partial_num_steps) + lr_input * (rand / self.partial_num_steps)elif self.blurring_schedule == 'sigma':# 使用 sigma 最大值进行模糊处理max_sigmas = self.sigma_sampler(input.shape[0], torch.ones(input.shape[0])*self.partial_num_steps).to(input.dtype).to(input.device)blurred_input = input * (1 - sigmas / max_sigmas) + lr_input * (sigmas / max_sigmas)elif self.blurring_schedule == 'exp':# 指数模糊处理rand_blurring = (1 - torch.exp(-(torch.sin((rand+1) / self.partial_num_steps * torch.pi / 2)**4))) / (1 - torch.exp(-torch.ones_like(rand)))blurred_input = input * (1 - rand_blurring) + lr_input * rand_blurringelse:# 如果模糊调度不被支持,抛出未实现错误raise NotImplementedError# 将噪声添加到模糊输入中noised_input = blurred_input + noise * append_dims(sigmas, input.ndim)# 调用去噪声器处理模糊输入,获取模型输出model_output = denoiser(network, noised_input, sigmas, cond, **additional_model_inputs)# 调整去噪声器权重的维度w = append_dims(denoiser.w(sigmas), input.ndim)# 返回模型输出的损失值return self.get_loss(model_output, input, w)
# 定义一个名为 ZeroSNRDiffusionLoss 的类,继承自 StandardDiffusionLoss
class ZeroSNRDiffusionLoss(StandardDiffusionLoss):# 重载调用方法,接受网络、去噪器、条件、输入和批次作为参数def __call__(self, network, denoiser, conditioner, input, batch):# 使用条件生成器处理批次,得到条件变量cond = conditioner(batch)# 从批次中提取与模型键相交的额外输入additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)}# 生成累积的 alpha 值并获取索引alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0], return_idx=True)# 将 alpha 值移动到输入的设备上alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device)# 将索引移动到输入的数据类型和设备上idx = idx.to(input.dtype).to(input.device)# 将索引添加到额外模型输入中additional_model_inputs['idx'] = idx# 生成与输入形状相同的随机噪声noise = torch.randn_like(input)# 如果偏移噪声水平大于零,则添加额外噪声if self.offset_noise_level > 0.0:noise = noise + append_dims(# 生成随机噪声并调整维度,乘以偏移噪声水平torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level# 计算加入噪声的输入noised_input = input.float() * append_dims(alphas_cumprod_sqrt, input.ndim) + noise * append_dims((1-alphas_cumprod_sqrt**2)**0.5, input.ndim)# 使用去噪器处理带噪声的输入model_output = denoiser(network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs)# 计算 v-pred 权重w = append_dims(1/(1-alphas_cumprod_sqrt**2), input.ndim) # 返回损失值return self.get_loss(model_output, input, w)# 定义一个获取损失的函数def get_loss(self, model_output, target, w):# 如果损失类型为 L2,计算 L2 损失if self.type == "l2":return torch.mean(# 计算每个样本的 L2 损失并调整维度(w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1)# 如果损失类型为 L1,计算 L1 损失elif self.type == "l1":return torch.mean(# 计算每个样本的 L1 损失并调整维度(w * (model_output - target).abs()).reshape(target.shape[0], -1), 1)# 如果损失类型为 LPIPS,计算 LPIPS 损失elif self.type == "lpips":loss = self.lpips(model_output, target).reshape(-1)return loss

.\cogview3-finetune\sat\sgm\modules\diffusionmodules\model.py

# pytorch_diffusion + derived encoder decoder
# 导入数学库
import math
# 导入类型注解相关
from typing import Any, Callable, Optional# 导入 numpy 库
import numpy as np
# 导入 pytorch 库
import torch
# 导入 pytorch 神经网络模块
import torch.nn as nn
# 导入 rearrange 函数以处理张量重排列
from einops import rearrange
# 导入版本管理库
from packaging import version# 尝试导入 xformers 模块
try:import xformersimport xformers.ops# 如果成功导入,设置标志为 TrueXFORMERS_IS_AVAILABLE = True
except:# 如果导入失败,设置标志为 False,并打印提示信息XFORMERS_IS_AVAILABLE = Falseprint("no module 'xformers'. Processing without...")# 从其他模块导入 LinearAttention 和 MemoryEfficientCrossAttention
from ...modules.attention import LinearAttention, MemoryEfficientCrossAttentiondef get_timestep_embedding(timesteps, embedding_dim):"""此函数与 Denoising Diffusion Probabilistic Models 中的实现相匹配:来自 Fairseq。构建正弦嵌入。此实现与 tensor2tensor 中的实现相匹配,但与 "Attention Is All You Need" 第 3.5 节中的描述略有不同。"""# 确保时间步长是一维的assert len(timesteps.shape) == 1# 计算嵌入维度的一半half_dim = embedding_dim // 2# 计算嵌入因子的对数emb = math.log(10000) / (half_dim - 1)# 计算并生成指数衰减的嵌入emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)# 将嵌入移动到与时间步相同的设备上emb = emb.to(device=timesteps.device)# 扩展时间步并与嵌入相乘emb = timesteps.float()[:, None] * emb[None, :]# 将正弦和余弦嵌入拼接在一起emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)# 如果嵌入维度是奇数,则进行零填充if embedding_dim % 2 == 1:  # zero pademb = torch.nn.functional.pad(emb, (0, 1, 0, 0))# 返回最终的嵌入return embdef nonlinearity(x):# 使用 swish 激活函数return x * torch.sigmoid(x)def Normalize(in_channels, num_groups=32):# 返回一个 GroupNorm 归一化层return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)class Upsample(nn.Module):def __init__(self, in_channels, with_conv):# 初始化 Upsample 类super().__init__()# 记录是否使用卷积self.with_conv = with_conv# 如果使用卷积,则定义卷积层if self.with_conv:self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)def forward(self, x):# 使用最近邻插值将输入张量上采样x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")# 如果使用卷积,则应用卷积层if self.with_conv:x = self.conv(x)# 返回处理后的张量return xclass Downsample(nn.Module):def __init__(self, in_channels, with_conv):# 初始化 Downsample 类super().__init__()# 记录是否使用卷积self.with_conv = with_conv# 如果使用卷积,则定义卷积层if self.with_conv:# 因为 pytorch 卷积不支持不对称填充,需手动处理self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)def forward(self, x):# 如果使用卷积,先进行填充再应用卷积层if self.with_conv:pad = (0, 1, 0, 1)x = torch.nn.functional.pad(x, pad, mode="constant", value=0)x = self.conv(x)# 否则使用平均池化进行下采样else:x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)# 返回处理后的张量return xclass ResnetBlock(nn.Module):def __init__(self,*,in_channels,out_channels=None,conv_shortcut=False,dropout,temb_channels=512,):# 调用父类的初始化方法super().__init__()# 保存输入通道数self.in_channels = in_channels# 如果未指定输出通道数,则设置为输入通道数out_channels = in_channels if out_channels is None else out_channels# 保存输出通道数self.out_channels = out_channels# 保存是否使用卷积捷径的标志self.use_conv_shortcut = conv_shortcut# 初始化输入通道数的归一化层self.norm1 = Normalize(in_channels)# 定义第一层卷积,输入输出通道及卷积核参数self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)# 如果有时间嵌入通道,则定义时间嵌入投影层if temb_channels > 0:self.temb_proj = torch.nn.Linear(temb_channels, out_channels)# 初始化输出通道数的归一化层self.norm2 = Normalize(out_channels)# 定义 dropout 层self.dropout = torch.nn.Dropout(dropout)# 定义第二层卷积,输入输出通道及卷积核参数self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)# 如果输入和输出通道数不相同if self.in_channels != self.out_channels:# 如果使用卷积捷径,则定义卷积捷径层if self.use_conv_shortcut:self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)# 否则定义 1x1 卷积捷径层else:self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)# 前向传播函数def forward(self, x, temb):# 将输入赋值给 h 变量h = x# 对 h 进行归一化h = self.norm1(h)# 应用非线性激活函数h = nonlinearity(h)# 通过第一层卷积处理 hh = self.conv1(h)# 如果时间嵌入不为 Noneif temb is not None:# 将时间嵌入通过非线性激活函数处理后投影到输出通道,并与 h 相加h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]# 对 h 进行第二次归一化h = self.norm2(h)# 应用非线性激活函数h = nonlinearity(h)# 通过 dropout 层处理 hh = self.dropout(h)# 通过第二层卷积处理 hh = self.conv2(h)# 如果输入和输出通道数不相同if self.in_channels != self.out_channels:# 如果使用卷积捷径,则通过卷积捷径层处理 xif self.use_conv_shortcut:x = self.conv_shortcut(x)# 否则通过 1x1 卷积捷径层处理 xelse:x = self.nin_shortcut(x)# 返回 x 和 h 的相加结果return x + h
# 定义 LinAttnBlock 类,继承自 LinearAttention
class LinAttnBlock(LinearAttention):"""to match AttnBlock usage"""  # 文档字符串,说明该类用于匹配 AttnBlock 的使用方式# 初始化方法,接受输入通道数def __init__(self, in_channels):# 调用父类的初始化方法,设置维度和头数super().__init__(dim=in_channels, heads=1, dim_head=in_channels)# 定义 AttnBlock 类,继承自 nn.Module
class AttnBlock(nn.Module):# 初始化方法,接受输入通道数def __init__(self, in_channels):# 调用父类的初始化方法super().__init__()# 保存输入通道数self.in_channels = in_channels# 初始化归一化层self.norm = Normalize(in_channels)# 初始化查询卷积层self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)# 初始化键卷积层self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)# 初始化值卷积层self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)# 初始化输出投影卷积层self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)# 定义注意力计算方法def attention(self, h_: torch.Tensor) -> torch.Tensor:# 对输入进行归一化h_ = self.norm(h_)# 计算查询、键和值q = self.q(h_)k = self.k(h_)v = self.v(h_)# 获取查询的形状参数b, c, h, w = q.shape# 重新排列查询、键和值的形状q, k, v = map(lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))# 计算缩放的点积注意力h_ = torch.nn.functional.scaled_dot_product_attention(q, k, v)  # scale is dim ** -0.5 per default# 计算注意力# 返回重新排列后的注意力结果return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)# 定义前向传播方法def forward(self, x, **kwargs):# 将输入赋值给 h_h_ = x# 计算注意力h_ = self.attention(h_)# 应用输出投影h_ = self.proj_out(h_)# 返回输入与注意力结果的和return x + h_# 定义 MemoryEfficientAttnBlock 类,继承自 nn.Module
class MemoryEfficientAttnBlock(nn.Module):"""Uses xformers efficient implementation,see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223Note: this is a single-head self-attention operation"""  # 文档字符串,说明该类使用 xformers 高效实现的单头自注意力# 初始化方法,接受输入通道数def __init__(self, in_channels):# 调用父类的初始化方法super().__init__()# 保存输入通道数self.in_channels = in_channels# 初始化归一化层self.norm = Normalize(in_channels)# 初始化查询卷积层self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)# 初始化键卷积层self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)# 初始化值卷积层self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)# 初始化输出投影卷积层self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)# 初始化注意力操作,类型为可选的任意类型self.attention_op: Optional[Any] = None# 定义注意力机制的函数,输入为一个张量,输出也是一个张量def attention(self, h_: torch.Tensor) -> torch.Tensor:# 先对输入进行归一化处理h_ = self.norm(h_)# 通过线性变换生成查询张量q = self.q(h_)# 通过线性变换生成键张量k = self.k(h_)# 通过线性变换生成值张量v = self.v(h_)# 计算注意力# 获取查询张量的形状信息B, C, H, W = q.shape# 调整张量形状,将其从四维转为二维q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))# 对查询、键、值进行维度调整以便计算注意力q, k, v = map(lambda t: t.unsqueeze(3)  # 在最后增加一个维度.reshape(B, t.shape[1], 1, C)  # 调整形状.permute(0, 2, 1, 3)  # 变换维度顺序.reshape(B * 1, t.shape[1], C)  # 重新调整形状.contiguous(),  # 保证内存连续性(q, k, v),)# 使用内存高效的注意力操作out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)# 调整输出张量的形状out = (out.unsqueeze(0)  # 增加一个维度.reshape(B, 1, out.shape[1], C)  # 调整形状.permute(0, 2, 1, 3)  # 变换维度顺序.reshape(B, out.shape[1], C)  # 重新调整形状)# 将输出张量的形状恢复为原来的格式return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)# 定义前向传播函数def forward(self, x, **kwargs):# 输入数据赋值给 h_h_ = x# 通过注意力机制处理 h_h_ = self.attention(h_)# 通过输出投影处理 h_h_ = self.proj_out(h_)# 返回输入和处理后的 h_ 的和return x + h_
# 定义一个内存高效的交叉注意力包装类,继承自 MemoryEfficientCrossAttention
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):# 前向传播方法,接受输入张量和可选的上下文、掩码def forward(self, x, context=None, mask=None, **unused_kwargs):# 解包输入张量的维度:批量大小、通道数、高度和宽度b, c, h, w = x.shape# 重新排列输入张量的维度,将 (b, c, h, w) 转换为 (b, h*w, c)x = rearrange(x, "b c h w -> b (h w) c")# 调用父类的 forward 方法,处理重新排列后的输入out = super().forward(x, context=context, mask=mask)# 将输出张量的维度重新排列回 (b, c, h, w)out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)# 返回输入与输出的和,进行残差连接return x + out# 定义一个生成注意力模块的函数
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):# 检查传入的注意力类型是否在支持的类型列表中assert attn_type in ["vanilla","vanilla-xformers","memory-efficient-cross-attn","linear","none",], f"attn_type {attn_type} unknown"# 检查 PyTorch 版本,并且如果类型不是 "none",则验证是否可用 xformersif (version.parse(torch.__version__) < version.parse("2.0.0")and attn_type != "none"):assert XFORMERS_IS_AVAILABLE, (f"We do not support vanilla attention in {torch.__version__} anymore, "f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'")# 将注意力类型设置为 "vanilla-xformers"attn_type = "vanilla-xformers"# 根据注意力类型生成相应的注意力块if attn_type == "vanilla":# 验证注意力参数不为 Noneassert attn_kwargs is None# 返回标准的注意力块return AttnBlock(in_channels)elif attn_type == "vanilla-xformers":# 返回内存高效的注意力块return MemoryEfficientAttnBlock(in_channels)elif attn_type == "memory-efficient-cross-attn":# 设置查询维度为输入通道数attn_kwargs["query_dim"] = in_channels# 返回内存高效的交叉注意力包装类return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)elif attn_type == "none":# 返回一个身份映射层,不改变输入return nn.Identity(in_channels)else:# 返回线性注意力块return LinAttnBlock(in_channels)# 定义一个模型类,继承自 nn.Module
class Model(nn.Module):# 初始化方法,接受多个参数进行模型构建def __init__(self,*,ch,out_ch,ch_mult=(1, 2, 4, 8),num_res_blocks,attn_resolutions,dropout=0.0,resamp_with_conv=True,in_channels,resolution,use_timestep=True,use_linear_attn=False,attn_type="vanilla",# 定义前向传播方法,接受输入 x、时间步 t 和上下文 contextdef forward(self, x, t=None, context=None):# 确保输入 x 的高度和宽度与设定的分辨率相等(被注释掉)# assert x.shape[2] == x.shape[3] == self.resolution# 如果上下文不为 None,沿通道维度连接输入 x 和上下文if context is not None:# 假设上下文对齐,沿通道轴拼接x = torch.cat((x, context), dim=1)# 如果使用时间步,进行时间步嵌入if self.use_timestep:# 确保时间步 t 不为 Noneassert t is not None# 获取时间步嵌入temb = get_timestep_embedding(t, self.ch)# 通过第一层密集层处理时间步嵌入temb = self.temb.dense[0](temb)# 应用非线性变换temb = nonlinearity(temb)# 通过第二层密集层处理temb = self.temb.dense[1](temb)else:# 如果不使用时间步,设置时间步嵌入为 Nonetemb = None# 下采样hs = [self.conv_in(x)]  # 初始卷积层的输出for i_level in range(self.num_resolutions):for i_block in range(self.num_res_blocks):# 通过当前下采样层和时间步嵌入处理前一层输出h = self.down[i_level].block[i_block](hs[-1], temb)# 如果存在注意力层,则对输出进行注意力处理if len(self.down[i_level].attn) > 0:h = self.down[i_level].attn[i_block](h)# 将处理后的输出添加到列表hs.append(h)# 如果不是最后一层分辨率,进行下采样if i_level != self.num_resolutions - 1:hs.append(self.down[i_level].downsample(hs[-1]))# 中间处理h = hs[-1]  # 获取最后一层的输出h = self.mid.block_1(h, temb)  # 通过中间块处理h = self.mid.attn_1(h)  # 通过中间注意力层处理h = self.mid.block_2(h, temb)  # 再次通过中间块处理# 上采样for i_level in reversed(range(self.num_resolutions)):for i_block in range(self.num_res_blocks + 1):# 拼接上层输出和当前层的输出,然后通过上采样块处理h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)# 如果存在注意力层,则对输出进行注意力处理if len(self.up[i_level].attn) > 0:h = self.up[i_level].attn[i_block](h)# 如果不是第一层分辨率,进行上采样if i_level != 0:h = self.up[i_level].upsample(h)# 结束处理h = self.norm_out(h)  # 最后的归一化处理h = nonlinearity(h)  # 应用非线性变换h = self.conv_out(h)  # 通过输出卷积层处理return h  # 返回最终输出# 获取最后一层的卷积权重def get_last_layer(self):return self.conv_out.weight  # 返回输出卷积层的权重
# 定义一个编码器类,继承自 nn.Module
class Encoder(nn.Module):# 初始化方法,接收多个参数用于配置编码器def __init__(self,*,ch,out_ch,ch_mult=(1, 2, 4, 8),num_res_blocks,attn_resolutions,dropout=0.0,resamp_with_conv=True,in_channels,resolution,z_channels,double_z=True,use_linear_attn=False,attn_type="vanilla",mid_attn=True,**ignore_kwargs,):# 调用父类构造方法super().__init__()# 如果使用线性注意力,设置注意力类型为线性if use_linear_attn:attn_type = "linear"# 保存输入参数以供后续使用self.ch = chself.temb_ch = 0self.num_resolutions = len(ch_mult)self.num_res_blocks = num_res_blocksself.resolution = resolutionself.in_channels = in_channelsself.attn_resolutions = attn_resolutionsself.mid_attn = mid_attn# 下采样# 定义输入卷积层self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)# 当前分辨率初始化curr_res = resolution# 定义输入通道的倍率in_ch_mult = (1,) + tuple(ch_mult)self.in_ch_mult = in_ch_mult# 初始化下采样模块列表self.down = nn.ModuleList()# 遍历每个分辨率层级for i_level in range(self.num_resolutions):# 初始化块和注意力模块列表block = nn.ModuleList()attn = nn.ModuleList()# 输入和输出通道数计算block_in = ch * in_ch_mult[i_level]block_out = ch * ch_mult[i_level]# 遍历每个残差块for i_block in range(self.num_res_blocks):# 添加残差块到块列表中block.append(ResnetBlock(in_channels=block_in,out_channels=block_out,temb_channels=self.temb_ch,dropout=dropout,))# 更新输入通道数为当前块的输出通道数block_in = block_out# 如果当前分辨率在注意力分辨率列表中,添加注意力模块if curr_res in attn_resolutions:attn.append(make_attn(block_in, attn_type=attn_type))# 创建下采样模块down = nn.Module()down.block = blockdown.attn = attn# 如果不是最后一个分辨率,添加下采样层if i_level != self.num_resolutions - 1:down.downsample = Downsample(block_in, resamp_with_conv)# 更新当前分辨率为一半curr_res = curr_res // 2# 将下采样模块添加到列表中self.down.append(down)# 中间层self.mid = nn.Module()# 添加第一个残差块self.mid.block_1 = ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout,)# 如果使用中间注意力,添加注意力模块if mid_attn:self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)# 添加第二个残差块self.mid.block_2 = ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout,)# 结束层# 定义归一化层self.norm_out = Normalize(block_in)# 定义输出卷积层,根据是否双 z 通道设置输出通道数self.conv_out = torch.nn.Conv2d(block_in,2 * z_channels if double_z else z_channels,kernel_size=3,stride=1,padding=1,)# 定义前向传播方法,接受输入数据 xdef forward(self, x):# 时间步嵌入初始化为 Nonetemb = None# 下采样过程# 对输入 x 进行卷积操作,生成初始特征图 hshs = [self.conv_in(x)]# 遍历每个分辨率层for i_level in range(self.num_resolutions):# 遍历当前分辨率层中的每个残差块for i_block in range(self.num_res_blocks):# 使用当前层的残差块处理上一个层的输出和时间步嵌入h = self.down[i_level].block[i_block](hs[-1], temb)# 如果当前层有注意力机制,则应用注意力if len(self.down[i_level].attn) > 0:h = self.down[i_level].attn[i_block](h)# 将当前层的输出添加到特征图列表中hs.append(h)# 如果当前层不是最后一个分辨率层,则进行下采样if i_level != self.num_resolutions - 1:hs.append(self.down[i_level].downsample(hs[-1]))# 中间处理阶段h = hs[-1]  # 获取最后一层的输出# 通过中间块1处理输入h = self.mid.block_1(h, temb)# 如果中间层有注意力机制,则应用注意力if self.mid_attn:h = self.mid.attn_1(h)# 通过中间块2处理输出h = self.mid.block_2(h, temb)# 最终处理阶段h = self.norm_out(h)  # 应用输出归一化h = nonlinearity(h)   # 应用非线性激活函数h = self.conv_out(h)  # 通过输出卷积生成最终结果return h  # 返回最终输出
# 定义一个解码器类,继承自 PyTorch 的 nn.Module
class Decoder(nn.Module):# 初始化方法,定义解码器的参数def __init__(self,*,ch,  # 输入通道数out_ch,  # 输出通道数ch_mult=(1, 2, 4, 8),  # 通道数的倍增因子num_res_blocks,  # 残差块的数量attn_resolutions,  # 注意力机制应用的分辨率dropout=0.0,  # dropout 比例,默认值为 0resamp_with_conv=True,  # 是否使用卷积进行上采样in_channels,  # 输入的通道数resolution,  # 输入的分辨率z_channels,  # 潜在变量的通道数give_pre_end=False,  # 是否在前面给予额外的结束标志tanh_out=False,  # 输出是否经过 tanh 激活use_linear_attn=False,  # 是否使用线性注意力机制attn_type="vanilla",  # 注意力类型,默认为“vanilla”mid_attn=True,  # 是否在中间层使用注意力**ignorekwargs,  # 其他忽略的参数,采用关键字参数形式):# 初始化父类super().__init__()# 如果使用线性注意力机制,设置注意力类型为线性if use_linear_attn:attn_type = "linear"# 设置通道数self.ch = ch# 初始化时间嵌入通道数为0self.temb_ch = 0# 计算分辨率数量self.num_resolutions = len(ch_mult)# 设置残差块数量self.num_res_blocks = num_res_blocks# 设置输入分辨率self.resolution = resolution# 设置输入通道数self.in_channels = in_channels# 设置是否给出前置结束标志self.give_pre_end = give_pre_end# 设置激活函数输出self.tanh_out = tanh_out# 设置注意力分辨率self.attn_resolutions = attn_resolutions# 设置中间注意力self.mid_attn = mid_attn# 计算输入通道倍数、块输入通道和当前最低分辨率in_ch_mult = (1,) + tuple(ch_mult)# 计算当前块的输入通道数block_in = ch * ch_mult[self.num_resolutions - 1]# 计算当前分辨率curr_res = resolution // 2 ** (self.num_resolutions - 1)# 设置潜在变量的形状self.z_shape = (1, z_channels, curr_res, curr_res)# print(#     "Working with z of shape {} = {} dimensions.".format(#         self.z_shape, np.prod(self.z_shape)#     )# )# 创建注意力和残差块类make_attn_cls = self._make_attn()make_resblock_cls = self._make_resblock()make_conv_cls = self._make_conv()# 将潜在变量映射到块输入通道self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)# 中间层self.mid = nn.Module()# 创建第一个残差块self.mid.block_1 = make_resblock_cls(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout,)# 如果启用中间注意力,创建注意力层if mid_attn:self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)# 创建第二个残差块self.mid.block_2 = make_resblock_cls(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout,)# 上采样层self.up = nn.ModuleList()# 从高到低遍历每个分辨率级别for i_level in reversed(range(self.num_resolutions)):block = nn.ModuleList()  # 残差块列表attn = nn.ModuleList()   # 注意力层列表# 计算当前块的输出通道数block_out = ch * ch_mult[i_level]# 创建每个残差块for i_block in range(self.num_res_blocks + 1):block.append(make_resblock_cls(in_channels=block_in,out_channels=block_out,temb_channels=self.temb_ch,dropout=dropout,))# 更新块输入通道block_in = block_out# 如果当前分辨率在注意力分辨率中,添加注意力层if curr_res in attn_resolutions:attn.append(make_attn_cls(block_in, attn_type=attn_type))up = nn.Module()  # 上采样模块up.block = block  # 添加残差块up.attn = attn   # 添加注意力层# 如果不是最低分辨率,添加上采样层if i_level != 0:up.upsample = Upsample(block_in, resamp_with_conv)# 更新当前分辨率curr_res = curr_res * 2# 将上采样模块插入列表的开头self.up.insert(0, up)  # prepend to get consistent order# 结束层# 创建归一化层self.norm_out = Normalize(block_in)# 创建输出卷积层self.conv_out = make_conv_cls(block_in, out_ch, kernel_size=3, stride=1, padding=1)# 定义一个私有方法,用于返回注意力机制的构造函数def _make_attn(self) -> Callable:return make_attn# 定义一个私有方法,用于返回残差块的构造函数def _make_resblock(self) -> Callable:return ResnetBlock# 定义一个私有方法,用于返回二维卷积层的构造函数def _make_conv(self) -> Callable:return torch.nn.Conv2d# 获取最后一层的权重def get_last_layer(self, **kwargs):return self.conv_out.weight# 前向传播方法,接收输入 z 和可选参数def forward(self, z, **kwargs):# 确保输入 z 的形状与预期相同(被注释掉的检查)# assert z.shape[1:] == self.z_shape[1:]# 记录输入 z 的形状self.last_z_shape = z.shape# 初始化时间步嵌入temb = None# 将输入 z 传入卷积层h = self.conv_in(z)# 中间处理h = self.mid.block_1(h, temb, **kwargs)  # 通过第一块中间块处理if self.mid_attn:  # 如果启用了中间注意力h = self.mid.attn_1(h, **kwargs)  # 应用中间注意力层h = self.mid.block_2(h, temb, **kwargs)  # 通过第二块中间块处理# 上采样过程for i_level in reversed(range(self.num_resolutions)):  # 从最高分辨率到最低分辨率for i_block in range(self.num_res_blocks + 1):  # 遍历每个残差块h = self.up[i_level].block[i_block](h, temb, **kwargs)  # 通过上采样块处理if len(self.up[i_level].attn) > 0:  # 如果存在注意力层h = self.up[i_level].attn[i_block](h, **kwargs)  # 应用注意力层if i_level != 0:  # 如果不是最低分辨率h = self.up[i_level].upsample(h)  # 执行上采样# 结束处理if self.give_pre_end:  # 如果启用了预处理结束返回return hh = self.norm_out(h)  # 对输出进行归一化h = nonlinearity(h)  # 应用非线性激活函数h = self.conv_out(h, **kwargs)  # 通过最终卷积层处理if self.tanh_out:  # 如果启用了 Tanh 输出h = torch.tanh(h)  # 应用 Tanh 激活函数return h  # 返回最终输出

.\cogview3-finetune\sat\sgm\modules\diffusionmodules\openaimodel.py

# 导入操作系统模块,用于处理文件和目录操作
import os
# 导入数学模块,提供数学函数和常量
import math
# 从 abc 模块导入抽象方法装饰器,用于定义抽象基类
from abc import abstractmethod
# 从 functools 模块导入 partial 函数,用于偏函数应用
from functools import partial
# 从 typing 模块导入类型注解,用于类型提示
from typing import Iterable, List, Optional, Tuple, Union# 导入 numpy 库,通常用于数值计算
import numpy as np
# 导入 torch 库,通常用于深度学习
import torch as th
# 导入 PyTorch 的神经网络模块
import torch.nn as nn
# 导入 PyTorch 的功能模块,提供激活函数等
import torch.nn.functional as F
# 从 einops 导入 rearrange 函数,用于重排张量
from einops import rearrange# 导入自定义模块中的 SpatialTransformer 类
from ...modules.attention import SpatialTransformer
# 导入自定义模块中的实用函数
from ...modules.diffusionmodules.util import (avg_pool_nd,  # 平均池化函数checkpoint,   # 检查点函数conv_nd,      # 卷积函数linear,       # 线性变换函数normalization, # 归一化函数timestep_embedding, # 时间步嵌入函数zero_module,  # 零模块函数
)# 导入自定义模块中的实用函数
from ...util import default, exists# 定义一个空的占位函数,用于将模块转换为半精度浮点数
# dummy replace
def convert_module_to_f16(x):pass# 定义一个空的占位函数,用于将模块转换为单精度浮点数
def convert_module_to_f32(x):pass# 定义一个用于注意力池化的类,继承自 nn.Module
## go
class AttentionPool2d(nn.Module):"""从 CLIP 中改编: https://github.com/openai/CLIP/blob/main/clip/model.py"""# 初始化方法,设置各类参数def __init__(self,spacial_dim: int,  # 空间维度embed_dim: int,    # 嵌入维度num_heads_channels: int,  # 头通道数量output_dim: int = None,  # 输出维度(可选)):# 调用父类初始化方法super().__init__()# 定义位置嵌入参数,初始化为正态分布self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)# 定义查询、键、值的卷积投影self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)# 定义输出的卷积投影self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)# 计算头的数量self.num_heads = embed_dim // num_heads_channels# 初始化注意力机制self.attention = QKVAttention(self.num_heads)# 前向传播方法def forward(self, x):# 获取输入的批次大小和通道数b, c, *_spatial = x.shape# 将输入重塑为 (批次, 通道, 高*宽) 的形状x = x.reshape(b, c, -1)  # NC(HW)# 在最后一维上添加均值作为额外的特征x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)# 将位置嵌入加到输入上x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)# 对输入进行查询、键、值投影x = self.qkv_proj(x)# 应用注意力机制x = self.attention(x)# 对结果进行输出投影x = self.c_proj(x)# 返回第一个通道的结果return x[:, :, 0]# 定义一个时间步模块的基类,继承自 nn.Module
class TimestepBlock(nn.Module):"""任何模块的 forward() 方法接受时间步嵌入作为第二个参数。"""# 定义抽象的前向传播方法@abstractmethoddef forward(self, x, emb):"""将模块应用于 `x`,并给定 `emb` 时间步嵌入。"""# 定义一个时间步嵌入的顺序模块,继承自 nn.Sequential 和 TimestepBlock
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):"""一个顺序模块,将时间步嵌入作为额外输入传递给支持的子模块。"""# 重写前向传播方法def forward(self,x: th.Tensor,  # 输入张量emb: th.Tensor,  # 时间步嵌入张量context: Optional[th.Tensor] = None,  # 上下文张量(可选)):# 遍历所有子模块for layer in self:module = layer# 如果子模块是 TimestepBlock,则使用时间步嵌入进行计算if isinstance(module, TimestepBlock):x = layer(x, emb)# 如果子模块是 SpatialTransformer,则使用上下文进行计算elif isinstance(module, SpatialTransformer):x = layer(x, context)# 否则,仅使用输入进行计算else:x = layer(x)# 返回最终的输出return x# 定义一个上采样模块,继承自 nn.Module
class Upsample(nn.Module):"""一个可选卷积的上采样层。:param channels: 输入和输出的通道数。:param use_conv: 布尔值,确定是否应用卷积。:param dims: 确定信号是 1D、2D 还是 3D。如果是 3D,则在内两个维度上进行上采样。"""# 初始化方法,设置类的基本属性def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, third_up=False):# 调用父类初始化方法super().__init__()# 保存输入的通道数self.channels = channels# 如果没有指定输出通道数,则默认与输入通道数相同self.out_channels = out_channels or channels# 保存是否使用卷积的标志self.use_conv = use_conv# 保存维度信息self.dims = dims# 保存是否进行第三层上采样的标志self.third_up = third_up# 如果使用卷积,初始化卷积层if use_conv:self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)# 前向传播方法,定义输入如何通过网络进行处理def forward(self, x):# 确保输入的通道数与初始化时指定的通道数一致assert x.shape[1] == self.channels# 如果输入为三维数据if self.dims == 3:# 根据是否需要第三层上采样确定时间因子t_factor = 1 if not self.third_up else 2# 对输入进行上采样x = F.interpolate(x,(t_factor * x.shape[2], x.shape[3] * 2, x.shape[4] * 2),mode="nearest",)else:# 对输入进行上采样,比例因子为2x = F.interpolate(x, scale_factor=2, mode="nearest")# 如果使用卷积,则将输入通过卷积层处理if self.use_conv:x = self.conv(x)# 返回处理后的输出return x
# 定义一个转置上采样的类,继承自 nn.Module
class TransposedUpsample(nn.Module):"Learned 2x upsampling without padding"  # 文档字符串,描述该类的功能# 初始化方法,设置输入通道、输出通道和卷积核大小def __init__(self, channels, out_channels=None, ks=5):super().__init__()  # 调用父类的初始化方法self.channels = channels  # 保存输入通道数量self.out_channels = out_channels or channels  # 如果没有指定输出通道,则与输入通道相同# 定义一个转置卷积层,用于上采样self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2)# 前向传播方法,执行上采样操作def forward(self, x):return self.up(x)  # 返回上采样后的结果# 定义一个下采样层的类,继承自 nn.Module
class Downsample(nn.Module):"""A downsampling layer with an optional convolution.:param channels: channels in the inputs and outputs.:param use_conv: a bool determining if a convolution is applied.:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, thendownsampling occurs in the inner-two dimensions."""# 初始化方法,设置输入通道、是否使用卷积、维度等参数def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, third_down=False):super().__init__()  # 调用父类的初始化方法self.channels = channels  # 保存输入通道数量self.out_channels = out_channels or channels  # 如果没有指定输出通道,则与输入通道相同self.use_conv = use_conv  # 保存是否使用卷积的标志self.dims = dims  # 保存信号的维度stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))  # 确定步幅if use_conv:  # 如果使用卷积# print(f"Building a Downsample layer with {dims} dims.")  # 打印信息,表示正在构建下采样层# print(#     f"  --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "#     f"kernel-size: 3, stride: {stride}, padding: {padding}"# )  # 打印卷积层的设置参数# if dims == 3:#     print(f"  --> Downsampling third axis (time): {third_down}")  # 打印是否在第三维进行下采样# 定义卷积操作self.op = conv_nd(dims,self.channels,self.out_channels,3,stride=stride,padding=padding,)else:  # 如果不使用卷积assert self.channels == self.out_channels  # 确保输入通道与输出通道相同# 定义平均池化操作self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)# 前向传播方法,执行下采样操作def forward(self, x):assert x.shape[1] == self.channels  # 确保输入的通道数匹配return self.op(x)  # 返回下采样后的结果# 定义一个残差块的类,继承自 TimestepBlock
class ResBlock(TimestepBlock):"""A residual block that can optionally change the number of channels.:param channels: the number of input channels.:param emb_channels: the number of timestep embedding channels.:param dropout: the rate of dropout.:param out_channels: if specified, the number of out channels.:param use_conv: if True and out_channels is specified, use a spatialconvolution instead of a smaller 1x1 convolution to change thechannels in the skip connection.:param dims: determines if the signal is 1D, 2D, or 3D.:param use_checkpoint: if True, use gradient checkpointing on this module.:param up: if True, use this block for upsampling.:param down: if True, use this block for downsampling."""# 初始化方法,用于创建类的实例def __init__(self,channels,  # 输入通道数emb_channels,  # 嵌入通道数dropout,  # 丢弃率out_channels=None,  # 输出通道数,默认为 Noneuse_conv=False,  # 是否使用卷积use_scale_shift_norm=False,  # 是否使用缩放位移归一化dims=2,  # 数据维度,默认为 2use_checkpoint=False,  # 是否使用检查点up=False,  # 是否进行上采样down=False,  # 是否进行下采样kernel_size=3,  # 卷积核大小,默认为 3exchange_temb_dims=False,  # 是否交换时间嵌入维度skip_t_emb=False,  # 是否跳过时间嵌入):# 调用父类初始化方法super().__init__()# 设置输入通道数self.channels = channels# 设置嵌入通道数self.emb_channels = emb_channels# 设置丢弃率self.dropout = dropout# 设置输出通道数,如果未提供则默认与输入通道数相同self.out_channels = out_channels or channels# 设置是否使用卷积self.use_conv = use_conv# 设置是否使用检查点self.use_checkpoint = use_checkpoint# 设置是否使用缩放位移归一化self.use_scale_shift_norm = use_scale_shift_norm# 设置是否交换时间嵌入维度self.exchange_temb_dims = exchange_temb_dims# 如果卷积核大小是可迭代的,计算每个维度的填充大小if isinstance(kernel_size, Iterable):padding = [k // 2 for k in kernel_size]else:# 否则直接计算单个卷积核的填充大小padding = kernel_size // 2# 创建输入层的序列,包括归一化、激活函数和卷积操作self.in_layers = nn.Sequential(normalization(channels),  # 归一化nn.SiLU(),  # SiLU 激活函数conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),  # 卷积层)# 判断是否进行上采样或下采样self.updown = up or down# 如果进行上采样,初始化上采样层if up:self.h_upd = Upsample(channels, False, dims)  # 上采样层self.x_upd = Upsample(channels, False, dims)  # 上采样层# 如果进行下采样,初始化下采样层elif down:self.h_upd = Downsample(channels, False, dims)  # 下采样层self.x_upd = Downsample(channels, False, dims)  # 下采样层# 否则使用身份映射else:self.h_upd = self.x_upd = nn.Identity()  # 身份映射层# 设置是否跳过时间嵌入self.skip_t_emb = skip_t_emb# 根据是否使用缩放位移归一化计算嵌入输出通道数self.emb_out_channels = (2 * self.out_channels if use_scale_shift_norm else self.out_channels)# 如果跳过时间嵌入,输出警告并设置嵌入层为 Noneif self.skip_t_emb:print(f"Skipping timestep embedding in {self.__class__.__name__}")  # 警告信息assert not self.use_scale_shift_norm  # 确保不使用缩放位移归一化self.emb_layers = None  # 嵌入层设置为 Noneself.exchange_temb_dims = False  # 不交换时间嵌入维度# 否则创建嵌入层的序列else:self.emb_layers = nn.Sequential(nn.SiLU(),  # SiLU 激活函数linear(emb_channels,  # 嵌入通道数self.emb_out_channels,  # 嵌入输出通道数),)# 创建输出层的序列,包括归一化、激活函数、丢弃层和卷积层self.out_layers = nn.Sequential(normalization(self.out_channels),  # 归一化nn.SiLU(),  # SiLU 激活函数nn.Dropout(p=dropout),  # 丢弃层zero_module(conv_nd(dims,  # 数据维度self.out_channels,  # 输出通道数self.out_channels,  # 输出通道数kernel_size,  # 卷积核大小padding=padding,  # 填充)),  # 卷积层)# 根据输入和输出通道数设置跳过连接if self.out_channels == channels:self.skip_connection = nn.Identity()  # 身份映射层elif use_conv:self.skip_connection = conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding  # 卷积层)else:self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)  # 卷积层,卷积核大小为 1# 定义前向传播函数,接受输入张量和时间步嵌入def forward(self, x, emb):"""Apply the block to a Tensor, conditioned on a timestep embedding.:param x: an [N x C x ...] Tensor of features.:param emb: an [N x emb_channels] Tensor of timestep embeddings.:return: an [N x C x ...] Tensor of outputs."""# 调用检查点函数以保存中间计算结果,减少内存使用return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)# 定义实际的前向传播逻辑def _forward(self, x, emb):# 如果设置了 updown,则进行上采样和下采样if self.updown:# 分离输入层的最后一层和其他层in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]# 通过其他输入层处理输入 xh = in_rest(x)# 更新隐藏状态h = self.h_upd(h)# 更新输入 xx = self.x_upd(x)# 通过卷积层处理隐藏状态h = in_conv(h)else:# 直接通过输入层处理输入 xh = self.in_layers(x)# 如果跳过时间嵌入,则初始化嵌入输出为零张量if self.skip_t_emb:emb_out = th.zeros_like(h)else:# 通过嵌入层处理时间嵌入,确保数据类型与 h 一致emb_out = self.emb_layers(emb).type(h.dtype)# 扩展 emb_out 的形状以匹配 h 的形状while len(emb_out.shape) < len(h.shape):emb_out = emb_out[..., None]# 如果使用缩放和偏移规范化if self.use_scale_shift_norm:# 分离输出层中的规范化层和其他层out_norm, out_rest = self.out_layers[0], self.out_layers[1:]# 将嵌入输出分割为缩放和偏移scale, shift = th.chunk(emb_out, 2, dim=1)# 对隐藏状态进行规范化并应用缩放和偏移h = out_norm(h) * (1 + scale) + shift# 通过剩余的输出层处理隐藏状态h = out_rest(h)else:# 如果交换时间嵌入的维度if self.exchange_temb_dims:# 重新排列嵌入输出的维度emb_out = rearrange(emb_out, "b t c ... -> b c t ...")# 将嵌入输出与隐藏状态相加h = h + emb_out# 通过输出层处理隐藏状态h = self.out_layers(h)# 返回输入 x 与处理后的隐藏状态的跳跃连接return self.skip_connection(x) + h
# 定义一个注意力模块,允许空间位置相互关注
class AttentionBlock(nn.Module):"""An attention block that allows spatial positions to attend to each other.Originally ported from here, but adapted to the N-d case.https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66."""# 初始化方法,定义模块的基本参数def __init__(self,channels,  # 输入通道数num_heads=1,  # 注意力头的数量,默认为1num_head_channels=-1,  # 每个头的通道数,默认为-1use_checkpoint=False,  # 是否使用检查点use_new_attention_order=False,  # 是否使用新的注意力顺序):# 调用父类初始化方法super().__init__()self.channels = channels  # 保存输入通道数# 判断 num_head_channels 是否为 -1if num_head_channels == -1:self.num_heads = num_heads  # 如果为 -1,直接使用 num_headselse:# 断言通道数可以被 num_head_channels 整除assert (channels % num_head_channels == 0), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"self.num_heads = channels // num_head_channels  # 计算头的数量self.use_checkpoint = use_checkpoint  # 保存检查点标志self.norm = normalization(channels)  # 初始化归一化层self.qkv = conv_nd(1, channels, channels * 3, 1)  # 创建卷积层用于计算 q, k, v# 根据是否使用新注意力顺序选择相应的注意力类if use_new_attention_order:# 在分割头之前分割 qkvself.attention = QKVAttention(self.num_heads)else:# 在分割 qkv 之前分割头self.attention = QKVAttentionLegacy(self.num_heads)# 初始化输出投影层self.proj_out = zero_module(conv_nd(1, channels, channels, 1))# 前向传播方法def forward(self, x, **kwargs):# TODO 添加跨帧注意力并使用混合检查点# 使用检查点机制来调用内部前向传播函数return checkpoint(self._forward, (x,), self.parameters(), True)  # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!# return pt_checkpoint(self._forward, x)  # pytorch# 内部前向传播方法def _forward(self, x):b, c, *spatial = x.shape  # 解包输入张量的形状x = x.reshape(b, c, -1)  # 将输入张量重塑为 (batch_size, channels, spatial_dim)qkv = self.qkv(self.norm(x))  # 计算 q, k, vh = self.attention(qkv)  # 应用注意力机制h = self.proj_out(h)  # 对注意力结果进行投影return (x + h).reshape(b, c, *spatial)  # 返回重塑后的结果# 计算注意力操作的 FLOPS
def count_flops_attn(model, _x, y):"""A counter for the `thop` package to count the operations in anattention operation.Meant to be used like:macs, params = thop.profile(model,inputs=(inputs, timestamps),custom_ops={QKVAttention: QKVAttention.count_flops},)"""b, c, *spatial = y[0].shape  # 解包输入张量的形状num_spatial = int(np.prod(spatial))  # 计算空间维度的总数# 进行两个矩阵乘法,具有相同数量的操作。# 第一个计算权重矩阵,第二个计算值向量的组合。matmul_ops = 2 * b * (num_spatial**2) * c  # 计算矩阵乘法的操作数model.total_ops += th.DoubleTensor([matmul_ops])  # 将操作数累加到模型的总操作数中# 旧版 QKV 注意力模块
class QKVAttentionLegacy(nn.Module):"""A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping"""# 初始化方法,设置注意力头的数量def __init__(self, n_heads):super().__init__()  # 调用父类初始化方法self.n_heads = n_heads  # 保存注意力头的数量# 定义前向传播方法,接收 QKV 张量def forward(self, qkv):"""应用 QKV 注意力机制。:param qkv: 一个形状为 [N x (H * 3 * C) x T] 的张量,包含 Q、K 和 V。:return: 一个形状为 [N x (H * C) x T] 的张量,经过注意力处理后输出。"""# 获取输入张量的批量大小、宽度和长度bs, width, length = qkv.shape# 确保宽度可以被 (3 * n_heads) 整除,以分割 Q、K 和 Vassert width % (3 * self.n_heads) == 0# 计算每个头的通道数ch = width // (3 * self.n_heads)# 将 qkv 张量重塑并分割成 Q、K 和 V 三个部分q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)# 计算缩放因子,用于稳定性scale = 1 / math.sqrt(math.sqrt(ch))# 使用爱因斯坦求和约定计算注意力权重,乘以缩放因子weight = th.einsum("bct,bcs->bts", q * scale, k * scale)  # 使用 f16 比后续除法更稳定# 对权重进行 softmax 归一化,并保持原始数据类型weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)# 根据权重和 V 计算输出张量a = th.einsum("bts,bcs->bct", weight, v)# 将输出张量重塑为原始批量大小和通道数return a.reshape(bs, -1, length)# 定义静态方法以计算模型的浮点运算数@staticmethoddef count_flops(model, _x, y):# 调用辅助函数计算注意力层的浮点运算数return count_flops_attn(model, _x, y)
# 定义一个名为 QKVAttention 的类,继承自 nn.Module
class QKVAttention(nn.Module):"""A module which performs QKV attention and splits in a different order."""# 初始化方法,接收注意力头的数量def __init__(self, n_heads):super().__init__()  # 调用父类的初始化方法self.n_heads = n_heads  # 保存注意力头的数量# 前向传播方法,接收 qkv 张量并执行注意力计算def forward(self, qkv):"""Apply QKV attention.:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.:return: an [N x (H * C) x T] tensor after attention."""bs, width, length = qkv.shape  # 解包 qkv 张量的维度assert width % (3 * self.n_heads) == 0  # 确保宽度能够被注意力头数量整除ch = width // (3 * self.n_heads)  # 计算每个头的通道数q, k, v = qkv.chunk(3, dim=1)  # 将 qkv 张量分成 Q, K, V 三部分scale = 1 / math.sqrt(math.sqrt(ch))  # 计算缩放因子weight = th.einsum("bct,bcs->bts",  # 定义爱因斯坦求和约定,计算权重(q * scale).view(bs * self.n_heads, ch, length),  # 缩放后的 Q 重塑形状(k * scale).view(bs * self.n_heads, ch, length),  # 缩放后的 K 重塑形状)  # More stable with f16 than dividing afterwardsweight = th.softmax(weight.float(), dim=-1).type(weight.dtype)  # 计算权重的 softmax,确保其和为 1a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))  # 计算最终的注意力输出return a.reshape(bs, -1, length)  # 将输出重塑回原始批量形状@staticmethod# 计算 FLOPs 的静态方法def count_flops(model, _x, y):return count_flops_attn(model, _x, y)  # 调用函数计算注意力层的 FLOPs# 定义一个名为 Timestep 的类,继承自 nn.Module
class Timestep(nn.Module):def __init__(self, dim):super().__init__()  # 调用父类的初始化方法self.dim = dim  # 保存时间步的维度# 前向传播方法,接收时间步张量def forward(self, t):return timestep_embedding(t, self.dim)  # 调用时间步嵌入函数# 定义一个字典,将字符串类型映射到对应的 PyTorch 数据类型
str_to_dtype = {"fp32": th.float32,  # fp32 对应 float32"fp16": th.float16,  # fp16 对应 float16"bf16": th.bfloat16   # bf16 对应 bfloat16
}# 定义一个名为 UNetModel 的类,继承自 nn.Module
class UNetModel(nn.Module):"""The full UNet model with attention and timestep embedding.:param in_channels: channels in the input Tensor.:param model_channels: base channel count for the model.:param out_channels: channels in the output Tensor.:param num_res_blocks: number of residual blocks per downsample.:param attention_resolutions: a collection of downsample rates at whichattention will take place. May be a set, list, or tuple.For example, if this contains 4, then at 4x downsampling, attentionwill be used.:param dropout: the dropout probability.:param channel_mult: channel multiplier for each level of the UNet.:param conv_resample: if True, use learned convolutions for upsampling anddownsampling.:param dims: determines if the signal is 1D, 2D, or 3D.:param num_classes: if specified (as an int), then this model will beclass-conditional with `num_classes` classes.:param use_checkpoint: use gradient checkpointing to reduce memory usage.:param num_heads: the number of attention heads in each attention layer.:param num_heads_channels: if specified, ignore num_heads and instead usea fixed channel width per attention head.:param num_heads_upsample: works with num_heads to set a different numberof heads for upsampling. Deprecated.:param use_scale_shift_norm: use a FiLM-like conditioning mechanism."""# 参数 resblock_updown:是否在上采样/下采样过程中使用残差块# 参数 use_new_attention_order:是否使用不同的注意力模式以提高效率"""# 初始化方法def __init__(# 输入通道数self,in_channels,# 模型通道数model_channels,# 输出通道数out_channels,# 残差块的数量num_res_blocks,# 注意力分辨率attention_resolutions,# dropout 比例,默认为 0dropout=0,# 通道的倍增因子,默认值为 (1, 2, 4, 8)channel_mult=(1, 2, 4, 8),# 是否使用卷积重采样,默认为 Trueconv_resample=True,# 数据维度,默认为 2dims=2,# 类别数,默认为 Nonenum_classes=None,# 是否使用检查点,默认为 Falseuse_checkpoint=False,# 是否使用 fp16 精度,默认为 Falseuse_fp16=False,# 注意力头数,默认为 -1num_heads=-1,# 每个头的通道数,默认为 -1num_head_channels=-1,# 上采样时的头数,默认为 -1num_heads_upsample=-1,# 是否使用尺度偏移归一化,默认为 Falseuse_scale_shift_norm=False,# 是否使用残差块进行上采样/下采样,默认为 Falseresblock_updown=False,# 是否使用新的注意力顺序,默认为 Falseuse_new_attention_order=False,# 是否使用空间变换器,支持自定义变换器use_spatial_transformer=False,  # custom transformer support# 变换器的深度,默认为 1transformer_depth=1,  # custom transformer support# 上下文维度,默认为 Nonecontext_dim=None,  # custom transformer support# 嵌入数,默认为 Nonen_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model# 是否使用传统模式,默认为 Truelegacy=True,# 是否禁用自注意力,默认为 Nonedisable_self_attentions=None,# 注意力块的数量,默认为 Nonenum_attention_blocks=None,# 是否禁用中间自注意力,默认为 Falsedisable_middle_self_attn=False,# 是否在变换器中使用线性输入,默认为 Falseuse_linear_in_transformer=False,# 空间变换器的注意力类型,默认为 "softmax"spatial_transformer_attn_type="softmax",# 输入通道数,默认为 Noneadm_in_channels=None,# 是否使用 Fairscale 检查点,默认为 Falseuse_fairscale_checkpoint=False,# 是否将计算卸载到 CPU,默认为 Falseoffload_to_cpu=False,# 中间变换器的深度,默认为 Nonetransformer_depth_middle=None,# 配置条件嵌入维度,默认为 Nonecfg_cond_embed_dim=None,# 数据类型,默认为 "fp32"dtype="fp32",# 将模型的主体转换为 float16def convert_to_fp16(self):"""将模型的主体转换为 float16。"""# 对输入块应用转换模块,将其转换为 float16self.input_blocks.apply(convert_module_to_f16)# 对中间块应用转换模块,将其转换为 float16self.middle_block.apply(convert_module_to_f16)# 对输出块应用转换模块,将其转换为 float16self.output_blocks.apply(convert_module_to_f16)# 将模型的主体转换为 float32def convert_to_fp32(self):"""将模型的主体转换为 float32。"""# 对输入块应用转换模块,将其转换为 float32self.input_blocks.apply(convert_module_to_f32)# 对中间块应用转换模块,将其转换为 float32self.middle_block.apply(convert_module_to_f32)# 对输出块应用转换模块,将其转换为 float32self.output_blocks.apply(convert_module_to_f32)# 定义前向传播函数,接收输入数据和其他参数def forward(self, x, timesteps=None, context=None, y=None, scale_emb=None, **kwargs):"""应用模型于输入批次。:param x: 输入张量,形状为 [N x C x ...]。:param timesteps: 一维时间步批次。:param context: 通过 crossattn 插入的条件信息。:param y: 标签张量,形状为 [N],如果是类条件。:return: 输出张量,形状为 [N x C x ...]。"""# 如果输入数据类型不匹配,则转换为模型所需的数据类型if x.dtype != self.dtype:x = x.to(self.dtype)# 确保 y 的存在性与类数设置一致assert (y is not None) == (self.num_classes is not None), "must specify y if and only if the model is class-conditional"# 初始化存储中间结果的列表hs = []# 生成时间步嵌入t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)# 如果提供了缩放嵌入,则进行相应处理if scale_emb is not None:assert hasattr(self, "w_proj"), "w_proj not found in the model"t_emb = t_emb + self.w_proj(scale_emb.to(self.dtype))# 通过时间嵌入生成最终嵌入emb = self.time_embed(t_emb)# 如果模型是类条件,则将标签嵌入加入到最终嵌入中if self.num_classes is not None:assert y.shape[0] == x.shape[0]emb = emb + self.label_emb(y)# 将输入数据赋值给 h# h = x.type(self.dtype)h = x# 通过输入模块处理 h,并保存中间结果for module in self.input_blocks:h = module(h, emb, context)hs.append(h)# 通过中间模块进一步处理 hh = self.middle_block(h, emb, context)# 通过输出模块处理 h,并逐层合并中间结果for module in self.output_blocks:h = th.cat([h, hs.pop()], dim=1)h = module(h, emb, context)# 将 h 转换回原输入数据类型h = h.type(x.dtype)# 检查是否支持预测码本 IDif self.predict_codebook_ids:assert False, "not supported anymore. what the f*** are you doing?"else:# 返回最终输出结果return self.out(h)

.\cogview3-finetune\sat\sgm\modules\diffusionmodules\sampling.py

# 部分代码移植自 https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
"""Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
"""# 从 typing 模块导入字典和联合类型
from typing import Dict, Union# 导入 PyTorch 库
import torch
# 从 omegaconf 模块导入配置相关的类
from omegaconf import ListConfig, OmegaConf
# 导入 tqdm 库用于显示进度条
from tqdm import tqdm# 从相对路径模块导入采样相关的工具函数
from ...modules.diffusionmodules.sampling_utils import (get_ancestral_step,  # 获取祖先步骤linear_multistep_coeff,  # 线性多步骤系数to_d,  # 转换为 dto_neg_log_sigma,  # 转换为负对数sigmato_sigma,  # 转换为 sigma
)
# 从相对路径模块导入离散化工具
from ...modules.diffusionmodules.discretizer import generate_roughly_equally_spaced_steps
# 从相对路径模块导入工具函数
from ...util import append_dims, default, instantiate_from_config# 定义默认引导器配置
DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}# 定义用于生成引导嵌入的函数
def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):"""参考文献: https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298Args:timesteps (`torch.Tensor`):在这些时间步生成嵌入向量embedding_dim (`int`, *可选*, 默认为 512):生成的嵌入的维度dtype:生成嵌入的数据类型Returns:`torch.FloatTensor`: 形状为 `(len(timesteps), embedding_dim)` 的嵌入向量"""# 确保输入张量是一个一维张量assert len(w.shape) == 1# 将输入乘以 1000.0w = w * 1000.0# 计算嵌入维度的一半half_dim = embedding_dim // 2# 计算基础嵌入的系数emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)# 生成嵌入基础,转换为指数形式并调整为目标设备和数据类型emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb).to(w.device).to(w.dtype)# 生成最终的嵌入向量emb = w.to(dtype)[:, None] * emb[None, :]# 将正弦和余弦值连接在一起emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)# 如果嵌入维度为奇数,进行零填充if embedding_dim % 2 == 1:  # zero pademb = torch.nn.functional.pad(emb, (0, 1))# 确保生成的嵌入形状与预期一致assert emb.shape == (w.shape[0], embedding_dim)# 返回生成的嵌入向量return emb# 定义基础扩散采样器类
class BaseDiffusionSampler:# 初始化采样器def __init__(self,discretization_config: Union[Dict, ListConfig, OmegaConf],  # 离散化配置num_steps: Union[int, None] = None,  # 采样步数,默认为 Noneguider_config: Union[Dict, ListConfig, OmegaConf, None] = None,  # 引导器配置,默认为 Nonecfg_cond_scale: Union[int, None] = None,  # 条件缩放参数,默认为 Nonecfg_cond_embed_dim: Union[int, None] = 256,  # 条件嵌入维度,默认为 256verbose: bool = False,  # 是否显示详细信息device: str = "cuda",  # 设备类型,默认为 CUDA):# 设置采样步数self.num_steps = num_steps# 实例化离散化配置self.discretization = instantiate_from_config(discretization_config)# 实例化引导器配置self.guider = instantiate_from_config(default(guider_config,DEFAULT_GUIDER,))# 设置条件参数self.cfg_cond_scale = cfg_cond_scaleself.cfg_cond_embed_dim = cfg_cond_embed_dim# 设置详细模式和设备self.verbose = verboseself.device = device# 准备采样循环的函数def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):# 生成 sigma 值sigmas = self.discretization(self.num_steps if num_steps is None else num_steps, device=self.device)# 默认使用条件uc = default(uc, cond)# 根据 sigma 计算 x 的调整x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)# 获取 sigma 的数量num_sigmas = len(sigmas)# 创建新的一维张量 s_in,初始值为 1s_in = x.new_ones([x.shape[0]]).float()# 返回调整后的 x 和其他参数return x, s_in, sigmas, num_sigmas, cond, uc# 定义去噪函数,接受输入x、去噪器denoiser、噪声水平sigma、条件cond和无条件ucdef denoise(self, x, denoiser, sigma, cond, uc):# 检查条件缩放系数是否不为Noneif self.cfg_cond_scale is not None:# 获取输入批次的大小batch_size = x.shape[0]# 创建与批次大小相同的全1张量,并乘以条件缩放系数,生成缩放嵌入scale_emb = guidance_scale_embedding(torch.ones(batch_size, device=x.device) * self.cfg_cond_scale, embedding_dim=self.cfg_cond_embed_dim, dtype=x.dtype)# 使用去噪器处理输入,传入缩放嵌入denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), scale_emb=scale_emb)else:# 若无条件缩放系数,直接使用去噪器处理输入denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))# 对去噪后的结果进行进一步引导处理denoised = self.guider(denoised, sigma)# 返回最终去噪结果return denoised# 定义生成sigma的函数,接受sigma数量num_sigmasdef get_sigma_gen(self, num_sigmas):# 创建一个范围生成器,从0到num_sigmas-1sigma_generator = range(num_sigmas - 1)# 如果启用了详细输出if self.verbose:# 打印分隔线和采样设置信息print("#" * 30, " Sampling setting ", "#" * 30)print(f"Sampler: {self.__class__.__name__}")print(f"Discretization: {self.discretization.__class__.__name__}")print(f"Guider: {self.guider.__class__.__name__}")# 使用tqdm包装生成器以显示进度条sigma_generator = tqdm(sigma_generator,total=num_sigmas,desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",)# 返回sigma生成器return sigma_generator
# 定义一个单步扩散采样器类,继承自基本扩散采样器
class SingleStepDiffusionSampler(BaseDiffusionSampler):# 定义采样步骤方法,未实现def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):# 抛出未实现错误,表明该方法需在子类中实现raise NotImplementedError# 定义欧拉步骤方法,用于计算下一个状态def euler_step(self, x, d, dt):# 返回更新后的状态,基于当前状态、导数和时间增量return x + dt * d# 定义 EDM 采样器类,继承自单步扩散采样器
class EDMSampler(SingleStepDiffusionSampler):# 初始化 EDM 采样器的参数def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs):# 调用父类的初始化方法super().__init__(*args, **kwargs)# 设置采样器的参数self.s_churn = s_churn  # 变化率self.s_tmin = s_tmin    # 最小时间self.s_tmax = s_tmax    # 最大时间self.s_noise = s_noise  # 噪声强度# 定义采样步骤方法def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):# 计算调整后的 sigma 值sigma_hat = sigma * (gamma + 1.0)# 如果 gamma 大于 0,加入噪声if gamma > 0:# 生成与 x 形状相同的随机噪声eps = torch.randn_like(x) * self.s_noise# 更新 x 的值,加入噪声x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5# 去噪,得到去噪后的结果denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)# 计算导数d = to_d(x, sigma_hat, denoised)# 计算时间增量dt = append_dims(next_sigma - sigma_hat, x.ndim)# 执行欧拉步骤,更新 xeuler_step = self.euler_step(x, d, dt)# 进行可能的修正步骤,得到最终的 xx = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)# 返回更新后的 xreturn x# 定义调用方法def __call__(self, denoiser, x, cond, uc=None, num_steps=None):# 准备采样循环所需的参数x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)# 遍历 sigma 值for i in self.get_sigma_gen(num_sigmas):# 计算 gamma 值gamma = (min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)if self.s_tmin <= sigmas[i] <= self.s_tmaxelse 0.0)# 执行采样步骤,更新 xx = self.sampler_step(s_in * sigmas[i],s_in * sigmas[i + 1],denoiser,x,cond,uc,gamma,)# 返回最终的 xreturn x# 定义 DDIM 采样器类,继承自单步扩散采样器
class DDIMSampler(SingleStepDiffusionSampler):# 初始化 DDIM 采样器的参数def __init__(self, s_noise=0.1, *args, **kwargs):# 调用父类的初始化方法super().__init__(*args, **kwargs)# 设置噪声强度self.s_noise = s_noise# 定义采样步骤方法def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0):# 去噪,得到去噪后的结果denoised = self.denoise(x, denoiser, sigma, cond, uc)# 计算导数d = to_d(x, sigma, denoised)# 计算时间增量dt = append_dims(next_sigma * (1 - s_noise**2)**0.5 - sigma, x.ndim)# 计算欧拉步骤,加入噪声euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x)# 进行可能的修正步骤,得到最终的 xx = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)# 返回更新后的 xreturn x# 定义一个可调用的类方法,接收去噪器、输入数据、条件及其他参数def __call__(self, denoiser, x, cond, uc=None, num_steps=None):# 准备采样循环,返回处理后的数据和相关参数x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)# 遍历生成的 sigma 值for i in self.get_sigma_gen(num_sigmas):# 执行采样步骤,更新输入数据 xx = self.sampler_step(s_in * sigmas[i],    # 当前 sigma 乘以输入信号s_in * sigmas[i + 1],# 下一个 sigma 乘以输入信号denoiser,            # 传递去噪器x,                   # 当前数据cond,                # 条件信息uc,                  # 可选的额外条件self.s_noise,        # 传递噪声信息)# 返回最终处理后的数据return x
# 定义一个继承自 SingleStepDiffusionSampler 的类 AncestralSampler
class AncestralSampler(SingleStepDiffusionSampler):# 初始化方法,设定默认参数 eta 和 s_noisedef __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):# 调用父类的初始化方法super().__init__(*args, **kwargs)# 设置 eta 属性self.eta = eta# 设置 s_noise 属性self.s_noise = s_noise# 定义噪声采样器,生成与输入形状相同的随机噪声self.noise_sampler = lambda x: torch.randn_like(x)# 定义 ancestral_euler_step 方法,用于执行欧拉步长def ancestral_euler_step(self, x, denoised, sigma, sigma_down):# 计算偏导数 dd = to_d(x, sigma, denoised)# 将 sigma_down 和 sigma 的差值扩展到 x 的维度dt = append_dims(sigma_down - sigma, x.ndim)# 返回欧拉步长的结果return self.euler_step(x, d, dt)# 定义 ancestral_step 方法,执行采样步骤def ancestral_step(self, x, sigma, next_sigma, sigma_up):# 根据条件选择更新 x 的值x = torch.where(append_dims(next_sigma, x.ndim) > 0.0,  # 检查 next_sigma 是否大于 0x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim),  # 更新 x 的值x,  # 保持原值)# 返回更新后的 xreturn x# 定义调用方法,使得类可以被调用def __call__(self, denoiser, x, cond, uc=None, num_steps=None):# 准备采样循环,获取必要的输入x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)# 遍历 sigma 生成器,进行采样步骤for i in self.get_sigma_gen(num_sigmas):x = self.sampler_step(s_in * sigmas[i],  # 当前 sigma 值s_in * sigmas[i + 1],  # 下一个 sigma 值denoiser,  # 去噪器x,  # 当前 x 值cond,  # 条件uc,  # 额外条件)# 返回最终的 x 值return x# 定义一个继承自 BaseDiffusionSampler 的类 LinearMultistepSampler
class LinearMultistepSampler(BaseDiffusionSampler):# 初始化方法,设定默认的 order 参数def __init__(self,order=4,*args,**kwargs,):# 调用父类的初始化方法super().__init__(*args, **kwargs)# 设置 order 属性self.order = order# 定义调用方法,使得类可以被调用def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):# 准备采样循环,获取必要的输入x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)# 初始化一个列表 ds 用于存储导数ds = []# 将 sigmas 从 GPU 移到 CPU,并转换为 numpy 数组sigmas_cpu = sigmas.detach().cpu().numpy()# 遍历 sigma 生成器for i in self.get_sigma_gen(num_sigmas):# 计算当前的 sigmasigma = s_in * sigmas[i]# 使用去噪器处理当前输入denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs)# 使用引导函数对去噪结果进行处理denoised = self.guider(denoised, sigma)# 计算导数 dd = to_d(x, sigma, denoised)# 将导数添加到列表 dsds.append(d)# 如果 ds 的长度超过 order,移除最早的元素if len(ds) > self.order:ds.pop(0)# 计算当前的阶数cur_order = min(i + 1, self.order)# 计算当前阶数的线性多步系数coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j)for j in range(cur_order)]# 更新 x 值x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))# 返回最终的 x 值return x# 定义一个继承自 EDMSampler 的类 EulerEDMSampler
class EulerEDMSampler(EDMSampler):# 定义可能的校正步骤方法def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):# 返回 euler_step,表示不进行额外的校正return euler_step# 定义一个继承自 EDMSampler 的类 HeunEDMSampler
class HeunEDMSampler(EDMSampler):# 定义可能的校正步骤方法def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):):# 如果下一个噪声水平的总和小于一个非常小的阈值if torch.sum(next_sigma) < 1e-14:# 如果所有噪声水平为0,保存网络评估的结果return euler_stepelse:# 使用去噪器对当前步进行去噪处理denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)# 将去噪后的结果转换为新数据d_new = to_d(euler_step, next_sigma, denoised)# 计算当前数据与新数据的平均值d_prime = (d + d_new) / 2.0# 如果噪声水平不为0,则应用修正x = torch.where(# 检查噪声水平是否大于0,决定是否修正append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step)# 返回修正后的结果return x
# 定义一个 Euler 祖先采样器类,继承自 AncestralSampler
class EulerAncestralSampler(AncestralSampler):# 定义采样步骤的方法,接受多个参数def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):# 获取下一个采样步的 sigma 值sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)# 使用去噪器对当前输入进行去噪denoised = self.denoise(x, denoiser, sigma, cond, uc)# 使用 Euler 方法更新 x 的值x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)# 应用祖先步骤更新 x 的值x = self.ancestral_step(x, sigma, next_sigma, sigma_up)# 返回更新后的 xreturn x# 定义一个 DPMPP2S 祖先采样器类,继承自 AncestralSampler
class DPMPP2SAncestralSampler(AncestralSampler):# 获取变量的方法,计算相关参数def get_variables(self, sigma, sigma_down):# 将 sigma 和 sigma_down 转换为负对数形式t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]# 计算时间间隔 hh = t_next - t# 计算 s 值s = t + 0.5 * h# 返回计算的参数return h, s, t, t_next# 获取乘法因子的方法def get_mult(self, h, s, t, t_next):# 计算各个乘法因子mult1 = to_sigma(s) / to_sigma(t)mult2 = (-0.5 * h).expm1()mult3 = to_sigma(t_next) / to_sigma(t)mult4 = (-h).expm1()# 返回所有乘法因子return mult1, mult2, mult3, mult4# 采样步骤的方法,执行多个计算步骤def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):# 获取下一个采样步的 sigma 值sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)# 对输入进行去噪denoised = self.denoise(x, denoiser, sigma, cond, uc)# 使用 Euler 方法更新 x 的值x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)# 检查 sigma_down 是否接近于零if torch.sum(sigma_down) < 1e-14:# 如果噪声级别为 0,则保存网络评估x = x_eulerelse:# 获取变量 h, s, t, t_nexth, s, t, t_next = self.get_variables(sigma, sigma_down)# 获取乘法因子,并调整维度mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)]# 更新 x 的值x2 = mult[0] * x - mult[1] * denoised# 对 x2 进行去噪denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)# 计算最终的 x 值x_dpmpp2s = mult[2] * x - mult[3] * denoised2# 如果噪声级别不为 0,则应用校正x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)# 最终应用祖先步骤更新 xx = self.ancestral_step(x, sigma, next_sigma, sigma_up)# 返回更新后的 xreturn x# 定义一个 DPMPP2M 采样器类,继承自 BaseDiffusionSampler
class DPMPP2MSampler(BaseDiffusionSampler):# 获取变量的方法,计算相关参数def get_variables(self, sigma, next_sigma, previous_sigma=None):# 将 sigma 和 next_sigma 转换为负对数形式t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]# 计算时间间隔 hh = t_next - t# 如果提供了 previous_sigma,则进行额外计算if previous_sigma is not None:h_last = t - to_neg_log_sigma(previous_sigma)r = h_last / hreturn h, r, t, t_nextelse:# 如果没有提供,则返回 h 和 t 值return h, None, t, t_next# 获取乘法因子的方法def get_mult(self, h, r, t, t_next, previous_sigma):# 计算基础乘法因子mult1 = to_sigma(t_next) / to_sigma(t)mult2 = (-h).expm1()# 如果提供了 previous_sigma,则计算额外的乘法因子if previous_sigma is not None:mult3 = 1 + 1 / (2 * r)mult4 = 1 / (2 * r)return mult1, mult2, mult3, mult4else:# 返回基本的乘法因子return mult1, mult2# 采样步骤的方法,执行多个计算步骤def sampler_step(self,old_denoised,previous_sigma,sigma,next_sigma,denoiser,x,cond,uc=None,):# 使用去噪器对输入数据进行去噪,返回去噪后的结果denoised = self.denoise(x, denoiser, sigma, cond, uc)# 获取当前和下一个噪声级别相关的变量h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)# 计算多重系数,扩展维度以匹配输入数据的维度mult = [append_dims(mult, x.ndim)for mult in self.get_mult(h, r, t, t_next, previous_sigma)]# 计算标准化后的输出x_standard = mult[0] * x - mult[1] * denoised# 检查之前的去噪结果是否存在或下一噪声级别是否接近零if old_denoised is None or torch.sum(next_sigma) < 1e-14:# 如果噪声级别为零或处于第一步,返回标准化结果和去噪结果return x_standard, denoisedelse:# 计算去噪后的数据修正值denoised_d = mult[2] * denoised - mult[3] * old_denoised# 计算高级输出x_advanced = mult[0] * x - mult[1] * denoised_d# 如果噪声级别不为零且不是第一步,应用修正x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)# 返回最终输出和去噪结果return x, denoiseddef __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):# 准备采样循环,包括输入数据和条件信息的处理x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)old_denoised = None# 遍历噪声级别生成器for i in self.get_sigma_gen(num_sigmas):# 在每个步骤中执行采样,更新去噪结果x, old_denoised = self.sampler_step(old_denoised,None if i == 0 else s_in * sigmas[i - 1],s_in * sigmas[i],s_in * sigmas[i + 1],denoiser,x,cond,uc=uc,)# 返回最终的去噪结果return x
# 定义一个将输入信号传递到去噪器的函数
def relay_to_d(x, sigma, denoised, image, step, total_step):# 计算模糊度的变化量blurring_d = (denoised - image) / total_step# 根据模糊度和当前步长更新去噪图像blurring_denoised = image + blurring_d * step# 计算当前信号与去噪信号的差异,标准化为 sigma 的维度d = (x - blurring_denoised) / append_dims(sigma, x.ndim)# 返回计算得到的差异和模糊度变化return d, blurring_d# 定义一个线性中继EDM采样器,继承自EulerEDMSampler
class LinearRelayEDMSampler(EulerEDMSampler):# 初始化函数,设定部分步数def __init__(self, partial_num_steps=20, *args, **kwargs):# 调用父类初始化方法super().__init__(*args, **kwargs)# 设置部分步数self.partial_num_steps = partial_num_steps# 定义采样调用方法def __call__(self, denoiser, image, randn, cond, uc=None, num_steps=None):# 克隆随机数以保持不变randn_unit = randn.clone()# 准备采样循环,获取相关参数randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(randn, cond, uc, num_steps)# 初始化 x 为 Nonex = None# 遍历生成的 sigma 值for i in self.get_sigma_gen(num_sigmas):# 如果当前步数小于总步数减去部分步数,继续下一次循环if i < self.num_steps - self.partial_num_steps:continue# 如果 x 还未初始化,则根据图像和随机数计算初始值if x is None:x = image + randn_unit * append_dims(s_in * sigmas[i], len(randn_unit.shape))# 计算 gamma 值,控制采样过程中的噪声gamma = (min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)if self.s_tmin <= sigmas[i] <= self.s_tmaxelse 0.0)# 进行一次采样步骤x = self.sampler_step(s_in * sigmas[i],s_in * sigmas[i + 1],denoiser,x,cond,uc,gamma,step=i - self.num_steps + self.partial_num_steps,image=image,index=self.num_steps - i,)# 返回最终的图像return x# 定义欧拉步骤的计算方法def euler_step(self, x, d, dt, blurring_d):# 更新 x 的值return x + dt * d + blurring_d# 定义采样步骤的计算方法def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0, step=None, image=None, index=None):# 计算 sigma_hat,考虑 gamma 的影响sigma_hat = sigma * (gamma + 1.0)# 如果 gamma 大于 0,添加噪声if gamma > 0:eps = torch.randn_like(x) * self.s_noisex = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5# 使用去噪器去噪当前图像denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)# 计算 beta_t,控制去噪过程beta_t = next_sigma / sigma_hat * index / self.partial_num_steps - (index - 1) / self.partial_num_steps# 更新 x 的值,结合去噪结果x = x * append_dims(next_sigma / sigma_hat, x.ndim) + denoised * append_dims(1 - next_sigma / sigma_hat + beta_t, x.ndim) - image * append_dims(beta_t, x.ndim)# 返回更新后的图像return x# 定义零信噪比DDIM采样器,继承自SingleStepDiffusionSampler
class ZeroSNRDDIMSampler(SingleStepDiffusionSampler):# 初始化函数,设定是否使用条件生成def __init__(self,do_cfg=True,*args,**kwargs,):# 调用父类初始化方法super().__init__(*args, **kwargs)# 设置条件生成标志self.do_cfg = do_cfg# 准备采样循环的参数def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):# 计算累积的 alpha 值,并获取对应的索引alpha_cumprod_sqrt, indices = self.discretization(self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True)# 如果 uc 为 None,则使用 conduc = default(uc, cond)# 获取 sigma 的数量num_sigmas = len(alpha_cumprod_sqrt)# 初始化 s_in 为全 1 向量s_in = x.new_ones([x.shape[0]])# 返回准备好的参数return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, indices# 定义去噪函数,接受输入数据和其他参数def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, i=None, idx=None):# 初始化额外的模型输入字典additional_model_inputs = {}# 如果启用 CFG,准备包含索引的输入if self.do_cfg:additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * idx] * 2)# 否则只准备单个索引输入else:additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * idx])# 使用去噪器处理准备好的输入和额外参数,得到去噪后的结果denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs)# 使用引导器进一步处理去噪后的结果denoised = self.guider(denoised, alpha_cumprod_sqrt, step=i, num_steps=self.num_steps)# 返回去噪后的结果return denoised# 定义采样步骤函数,执行去噪和更新过程def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, i=None, idx=None, return_denoised=False):# 调用去噪函数,并转换结果为浮点型denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, i, idx).to(torch.float32)# 如果达到最后一步,返回去噪结果if i == self.num_steps - 1:if return_denoised:return denoised, denoisedreturn denoised# 计算当前步骤的 a_t 值a_t = ((1-next_alpha_cumprod_sqrt**2)/(1-alpha_cumprod_sqrt**2))**0.5# 计算当前步骤的 b_t 值b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t# 更新 x 的值,结合去噪后的结果x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised# 根据需要返回去噪结果if return_denoised:return x, denoisedreturn x# 定义可调用函数,用于处理采样和去噪流程def __call__(self, denoiser, x, cond, uc=None, num_steps=None):# 准备采样循环所需的输入数据x, s_in, alpha_cumprod_sqrts, num_sigmas, cond, uc, indices = self.prepare_sampling_loop(x, cond, uc, num_steps)# 根据 sigma 生成器逐步执行采样for i in self.get_sigma_gen(num_sigmas):x = self.sampler_step(s_in * alpha_cumprod_sqrts[i],s_in * alpha_cumprod_sqrts[i + 1],denoiser,x,cond,uc,i=i,idx=indices[self.num_steps-i-1],)# 返回最终的结果return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/819511.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CogView3---CogView-3Plus-微调代码源码解析-二-

CogView3 & CogView-3Plus 微调代码源码解析(二) .\cogview3-finetune\sat\sgm\models\__init__.py # 从同一模块导入 AutoencodingEngine 类,用于后续的自动编码器操作 from .autoencoder import AutoencodingEngine# 注释文本(可能是无关信息或标识符) #XuDwndGaCFo…

券后价复杂根源和解法

券后价领域划分不清楚 券后价在电商系统中是个很奇怪的存在 无论是按商品领域还是营销领域划分,它都不合适归类到这两者中间。结果就是券后价是个很不理想的拆分逻辑。 券后价可以理解是商品的价格属性,这个属性是由营销来计算控制。领域划分可以理解为商品领域,营销做计算!…

营销领域分析

用户与商品的连接用户购买商品是整个商业的基本盘。用户与商品是多对多关系,在这个基础之上就可衍生出许多行为。可以跟据商品的属性又可以设计各种运营方式。 用几个条件来归类交易产生的条件条件 人 物when 人什么时候需要商品 商品什么时候被需要why 人为什么需要商品 商品…

计数系统设计

在营销的场景里有三要素用户 商品 优惠在这三个要素里,再加一些如时间,数量,频次等变量,会演化出各种组合,使得业务变得非常灵活。各业务线为了满足业务,一般都会各自实现,且多数情况下都会重复实现,而且实现起来各地方都会产生交叉配置,交叉互斥的问题。在观察到这些…

PbootCMS授权码可以更换域名吗? 授权码丢失怎么办?

授权码可以更换域名吗?不可以:授权码是绑定特定域名的,如果需要更换域名,建议重新获取新的授权码。授权码丢失怎么办?重新获取:如果授权码丢失,可以重新访问授权页面,输入相同的域名再次获取授权码。扫码添加技术【解决问题】专注中小企业网站建设、网站安全12年。熟悉…

PbootCMS授权码是否可以用于不同域名的子域名?是否可以用于不同域名的子目录?

授权码是否可以用于不同域名的子域名?不可以:授权码是绑定特定域名的,不支持不同域名的子域名。例如,sub1.example.com 和 sub2.anotherdomain.com 需要分别获取授权码。18. 授权码是否可以用于不同域名的子目录?不可以:授权码是绑定特定域名的,不支持不同域名的子目录。…

PbootCMS验证码不显示或显示不清楚怎么办

验证码不显示或显示不清楚问题描述:后台登录时验证码不显示或显示不清楚。 解决方案:避免使用中文路径:确保所有文件和目录名称均为英文或数字。 切换PHP版本:推荐使用PHP 7.3、7.2、5.6版本。 检查文件权限:确保验证码相关文件和目录具有适当的读写权限(通常为755或644)…

值得信赖的FTP替代方案有哪些,一文带你详细了解!

FTP(文件传输协议)因其传输速度慢、安全隐患、管理复杂性、稳定性不足以及审计难题等缺陷,使得企业在寻找更高效的替代方案时显得尤为迫切。 FTP替代方案有哪些,简单了解看下吧: 1、SFTP:SFTP是建立在SSH(Secure Shell)协议之上的文件传输协议,提供了数据传输的加密和…

HTTP 2.0 新特性

HTTP 2.0 新特性HTTP 2.0 为什么使用二进制分帧?二进制协议比文本协议更加紧凑,减少占用空间 分帧层相当于将 HTTP 切分,更加灵活,比如可以对 header 帧做单独的特殊处理 分帧层有着属于自己的报文头,其中的 Stream Identity 使得操作系统具备将多个响应以及请求一一匹配的…

Python脚本检测笑脸漏洞

Python脚本检测笑脸漏洞 一、漏洞介绍 ​ vsftpd2.3.4中在6200端口存在一个shell,使得任何人都可以进行连接,并且VSFTPD v2.3.4 服务,是以 root 权限运行的,最终我们提到的权限也是root;当连接带有vsftpd 2.3.4版本的服务器的21端口时,输入用户中带有“😃 ”,密码…

Veritas Backup Exec 24.0 发布,新增功能概览

Veritas Backup Exec 24.0 发布,新增功能概览Veritas Backup Exec 24.0 发布,新增功能概览 Veritas Backup Exec 24.0 (Windows) - 面向中小型企业的数据备份和恢复 请访问原文链接:https://sysin.org/blog/veritas-backup-exec-24/ 查看最新版。原创作品,转载请保留出处。…

slope trick

slope trickP4597 序列 sequence 首先考虑 \(dp\) 。 由于只需将序列改为非严格递增,那么就有一个贪心,即最终答案的数集不会变大。 为什么呢? 这是因为只有序列某一位置严格递减时,才会进行修改。 修改可以将前面的数降到和后面的数一样大,或者将后面的数提到和前面的数一…