FSQ: FINITE SCALAR QUANTIZATION: VQ-VAE MADE SIMPLE

Paper name

FINITE SCALAR QUANTIZATION: VQ-VAE MADE SIMPLE

Paper Reading Note

Paper URL: https://arxiv.org/abs/2309.15505

Code URL:

  • (官方 jax 实现) https://github.com/google-research/google-research/tree/master/fsq
  • (pytorch 实现) https://github.com/lucidrains/vector-quantize-pytorch?tab=readme-ov-file#finite-scalar-quantization

TL;DR

  • 2023 年 google 发表的文章,可以用于文本、视频生成领域中。提出一种称为有限标量量化(FSQ)的简单方案来替换 VQ-VAEs 中的向量量化(VQ)。解决传统 VQ 中的两个主要问题:
    • 需要避免 codebook collapse 的辅助损失
    • 大 codebook size 情况下码本利用率低

Introduction

背景

  • vqgan、phenaki 等工作表明,对使用 GAN 损失训练的 VQ-VAE 表示进行自回归 transformer 模型训练可以实现强大的图像和视频生成模型
  • 然而 VQVAE 中的向量量化 (VQ)有一定缺点:
    • 训练优化困难:在训练VQ-VAE时,目标是学习一个码本 C,希望其包含输入数据(通常是图像)的压缩语义表示。在前向传播中,图像 x 被编码成一个表示 z(通常是特征向量序列),并且 z 中的每个向量都被量化为 C 中最接近的向量。量化操作是不可微的。在使用 VQ 在潜在表示中训练 VAE 时,需要使用直通估计器(STE),将梯度从解码器输入复制到编码器输出,从而使得编码器的梯度能正常传导。由于这仍然不能产生用于码书向量的梯度,vqvae 工作进一步引入了两个辅助损失,将码字向量拉向(未量化的)表示向量,反之亦然。这导致随着 C 的大小增加,可能会有码本坍塌(codebook collapse)问题,也即许多码字将不会被使用。后续一些工作需要单独处理这个问题,比如重新初始化整个码书等

本文方案

  • 提出一种称为有限标量量化(FSQ)的简单方案来替换 VQ-VAEs 中的向量量化(VQ)。新方案希望解决传统 VQ 中的两个主要问题:
    • 消除辅助损失
    • 提高码本利用率
    • 作为 VQ 的可替换组件
  • 具体实现方式
    • 将 VAE 表示投影到少量维度(通常少于10)。每个维度被量化为一组固定的值,由这些数值集合的乘积给出(隐式的)码本 (codebook)
      • 比如对于一个具有 d 个 channel 的向量 z,如果将每个条目 zi 映射到 L 个值(比如通过 z i = R o u n d ( ⌊ L / 2 ⌋ t a n h ( z i ) ) z_i = Round(\lfloor L / 2 \rfloor tanh(z_{i})) zi=Round(⌊L/2tanh(zi)) ,其中 Round 是四舍五入算子),则可以获得一个量化后的向量 z ^ \hat{z} z^,其中 z ^ \hat{z} z^ L d L^d Ld 个唯一可能的向量之一。下图展示了 d=3, L=3 的 FSQ,码本 C = ( − 1 , − 1 , − 1 ) , ( − 1 , − 1 , 0 ) , ( − 1 , − 1 , 1 ) , . . . , ( 1 , 1 , 1 ) C = {(−1, −1, −1), (−1, −1, 0), (−1, −1, 1), . . . , (1, 1, 1)} C=(1,1,1),(1,1,0),(1,1,1),...,(1,1,1),其中码本大小 ∣ C ∣ = L d = 27 |C|=L^{d}=27 C=Ld=27
        FSQ 对比 VQ
    • 为了让 round 也获得梯度,这里也使用了 STE。因此,使用 FSQ 在使用重构损失训练的自动编码器中,我们获得了对编码器的梯度,这迫使模型将信息分散到多个 quantization bins 中,因为这减少了重构损失。结果是,我们获得了一个使用所有码字的量化器,而不需要任何辅助损失
    • 尽管 FSQ 的设计要简单得多,但本文在图像生成、多模态生成、深度估计等任务中获得了有竞争力的结果。FSQ 的优点是不会遭受码本坍塌(codebook collapse),并且不需要 VQ 中为了避免码本坍塌而使用的复杂机制(承诺损失、码本重新播种、码分割、熵惩罚等)
FSQ 超参数
  • FSQ具有以下超参数:通道数 d 和每个通道的级别数 L = [L1, . . . , Ld]。在我们的大部分实验中,为了进行公平比较,我们将根据我们打算用 FSQ 替换的 VQ 码书选择目标码书大小 |C|。然而,各种不同的 d 和 Li 的配置都可以近似给定的 |C|。我们在研究中探索了各种配置,并发现并非所有选择都会产生最佳结果。然而,我们发现了一个简单的启发式方法,在所有考虑的任务中表现良好:对于所有i,使用 Li ≥ 5。在表1中,我们列出了常见目标 |C| 的 L 值。
    FSQ levels 推荐设置
FSQ 参数量
  • 我们注意到,与 VQ 相比,FSQ 的参数更少,因为在 VQ 中,会学习一个大小为 |C| · d 的码书。例如,对于典型的 |C|=2^12=4096 和 d=512,这将导致 2M 个参数,而 FSQ 则不具备这些参数。此外,由于对于 FSQ 来说,d 往往比 VQ 小得多(例如,对于这个 |C|,FSQ 的 d 为5,见表1),所以当训练 FSQ 时,最终的编码器层也具有更少的参数。为了弥补这一点,我们探索了在 VAE 编码器末尾或解码器开始处添加更多的全连接层,但发现这样做并没有进一步的收益。因此,在本文中的所有模型中,具有相同码书大小的 FSQ 具有更少的参数。

代码实现

  • 主要看 quantization 实现,可以看出来实现方面很简单,不考虑 L 要兼容奇数偶数的情况下,其实量化操作主要就是对输入的 z 非线性缩放后进行四舍五入操作
import jax.numpy as jnp
import numpy as npdef round_ste(z):"""使用直通梯度进行四舍五入。"""zhat = jnp.round(z)  # 对输入数组z进行四舍五入return z + jax.lax.stop_gradient(zhat - z)  # 返回z加上z和四舍五入后的z的差值,这个差值的梯度会被停止传播class FSQ:"""量化器类。"""def __init__(self, levels: list[int], eps: float = 1e-3):"""初始化量化器。Args:levels (list[int]): 分量化等级列表。eps (float): 用于边界处理的微小值。"""self._levels = levelsself._eps = epsself._levels_np = np.asarray(levels)  # 将等级列表转换为NumPy数组self._basis = np.concatenate(([1], np.cumprod(self._levels_np[:-1]))).astype(np.uint32)  # 创建基数组self._implicit_codebook = self.indexes_to_codes(np.arange(self.codebook_size))  # 初始化隐式码本@propertydef num_dimensions(self) -> int:"""输入期望的维度数量。"""return len(self._levels)@propertydef codebook_size(self) -> int:"""码本的大小。"""return np.prod(self._levels)@propertydef codebook(self):"""返回隐式码本。形状为 (prod(levels), num_dimensions)。"""return self._implicit_codebookdef bound(self, z: jax.Array) -> jax.Array:"""对形状为 (..., d) 的数组z进行边界处理。"""half_l = (self._levels_np - 1) * (1 - self._eps) / 2  # 计算每个维度的边界offset = jnp.where(self._levels_np % 2 == 1, 0.0, 0.5)  # 计算偏移量shift = jnp.tan(offset / half_l)  # 计算偏移return jnp.tanh(z + shift) * half_l - offset  # 应用边界和偏移def quantize(self, z: jax.Array) -> Codeword:"""量化z,返回量化后的zhat,形状与z相同。"""quantized = round_ste(self.bound(z))  # 对边界处理后的z进行四舍五入# 重新归一化到[-1, 1]区间half_width = self._levels_np // 2return quantized / half_widthdef _scale_and_shift(self, zhat_normalized):# 将归一化后的zhat缩放和偏移,使其范围在[0, ..., L-1]half_width = self._levels_np // 2return (zhat_normalized * half_width) + half_widthdef _scale_and_shift_inverse(self, zhat):# 逆操作,将范围在[0, ..., L-1]的zhat转换回归一化形式half_width = self._levels_np // 2return (zhat - half_width) / half_widthdef codes_to_indexes(self, zhat: Codeword) -> Indices:"""将码转换为码本中的索引。Args:zhat (Codeword): 归一化后的码。Returns:Indices: 码本中的索引。"""assert zhat.shape[-1] == self.num_dimensions  # 确保码的维度与期望的维度相同zhat = self._scale_and_shift(zhat)  # 缩放和偏移return (zhat * self._basis).sum(axis=-1).astype(jnp.uint32)  # 计算索引def indexes_to_codes(self, indices: Indices) -> Codeword:"""`codes_to_indexes`的逆操作。Args:indices (Indices): 码本中的索引。Returns:Codeword: 转换后的码。"""indices = indices[..., jnp.newaxis]  # 扩展索引维度codes_non_centered = np.mod(np.floor_divide(indices, self._basis), self._levels_np)  # 计算非中心化的码return self._scale_and_shift_inverse(codes_non_centered)  # 逆缩放和偏移

Experiments

MaskGIT 与 UViM 回顾

  • 首先简要回顾一下 MaskGIT 和 UViM。
    • 在MaskGIT中,作者首先训练了一个(卷积)VQ-GAN自编码器用于重建(第一阶段)。然后,他们冻结了自编码器,并训练了一个掩码变换器 BERT style来预测量化表示(第二阶段):给定一个表示 z,随机地“掩盖”一部分 token,即用一个特殊的 MASK token 替换。生成的序列 zM 被馈送到一个 transformer 中,除了一个 class token 外,transformer 为每个被掩盖的 token 预测一个分布。在推断过程中,最初只有 MASK token 和 class token 与 transformer 一起被馈送。然后,根据预测的置信度选择一些 token 位置,并抽样相应的 token。这些 token 用于替换输入中的掩码 token,并且再次运行模型,直到所有输入 token 都被揭示。
    • UViM 是一种通用架构,用于解决计算机视觉中的各种(密集)预测任务。在第一阶段,基于 transformer 的 VQ-VAE 被训练来模拟目标任务的标签空间。可选地,VQ-VAE编码器和解码器都可以依赖于任务输入(RGB图像用于深度估计和分割,灰度图像用于着色)作为附加信息或“上下文”,这对于某些任务是有益的。在第二阶段,一个编码器-解码器 transformer 被训练来预测由 VQ-VAE 编码器产生的量化 token 的密集标签,给定任务输入。对于推断,使用条件于输入的 transformer 自回归地对代码进行采样,然后将其馈送给 VQ-VAE 解码器。这个架构对于三个任务是共享的,但是为每个任务学习了不同的权重。

VQ、FSQ 对比的指标

  • 基于 MaskGIT 进行对比,对较低分辨率的 128 × 128 ImageNet 图像进行了训练,训练时间较原始论文更短(第一阶段 100 个 epoch,第二阶段 200 个 epoch。这使我们能够调整码本大小和其他超参数。对于 VQ,我们使用了 MaskGIT 的辅助熵损失,旨在增加码本的熵(以增加利用率)。我们只调整码本大小。对于 FSQ,我们探索各种 d 和 Li 来匹配这些码本大小。
  • 跟踪以下指标:
    • 重建 FID:即当将 50k 验证图像馈送到量化自动编码器时,由 GAN 训练的自动编码器获得的 FID。如果第二阶段 transformer 能够完美地对数据建模,这就是它将实现的理想 FID。
    • 码本使用情况:在对验证集进行编码时至少使用了一次的码字的比例。
    • sample FID:即第二阶段使用 transformer 对表示 z ^ \hat{z} z^ 进行抽样(按类条件)时获得的 FID
    • 压缩成本: 作为模拟表示底层离散分布的困难程度的代理(即,建模复杂性)。请注意,任何预测离散码的分布的 transformer 都可以用于对相应表示进行无损压缩。对于掩码 transformer,唯一的要求是一个确定性的掩码 schedule,逐渐揭示输入。使用这样的 schedule,我们可以通过将 transofmre 输出与熵编码配对,将任何 z ^ \hat{z} z^ 压缩到比特。我们使用了 M2T 中采用的确定性掩码 schedule

训练细节

MaskGIT
  • 早期实验表明,与 VQ 相比,FSQ 在 Precision & Recall 点上处于不同的位置(FSQ 具有更高的召回率和更低的精确度)。受扩散文献的启发,因此我们将分类器无关的引导(CFG)添加到 MaskGIT 中:在训练过程中,我们用 MASK 标记替换 10% 的类别标签,以让模型学习无条件分布。在推断过程中,我们对 logits 进行插值:设 l c l_{c} lc 是在给定类别标签 c 的情况下获得的 logits, l ∅ l_{∅} l 是无条件 logits。在推断过程中,我们计算新的 logits l ′ = l c + α ( l c − l ∅ ) l^{′} = l_{c} + α(l_{c} − l_{∅}) l=lc+α(lcl) ,其中 α 是 CFG 推断权重。直观地说,这将预测的分布拉向无条件分布。这一点之前在 MUSE 里面已经被探讨过。
UVIM
  • 三个任务都做实验:panoptic segmentation, depth estimation, colorization
  • stage2 会训练 3 次取平均精度
  • 评测指标:
    • 全景分割使用全景质量(PQ)
    • 深度估计使用 RMSE
    • 着色使用 FID-5k
  • 实验了 UViM 使用的码本分割策略,以避免在 VQ-VAE 中出现未使用的码字。具体来说,他们采用了 Linde 等人(1980)的算法,在整个训练过程中,检测未使用的向量,然后将这些向量替换为将最常用的嵌入拆分为两个新的嵌入,并为每个嵌入添加噪声。由于我们观察到在全景分割任务中关闭码本分割时出现了训练不稳定性,我们使用深度估计任务进行此项消融研究。

VQ 对比 FSQ

  • 128px Imagenet 实验:
    • FSQ 当 codebook size 越大重建 FID 越低,比 VQ 更适合大 codebook(VQ 在 2^11 下精度最高,再增加 codebook 利用率就逐渐降低了)
    • 低 codebook size 下 VQ 略优于 FSQ,原因主要是 VQ 表达能力天生更强
    • 生成精度随着 codebook size 增加 FSQ 也是变好的
    • codebook 利用率 FSQ 更高(2^14 下基本全部使用,之后略有下降)
      在这里插入图片描述
  • 如果再继续增加 codebook size,sample FID 能否继续增加?
    • 作者认为不能。因为如上图 d 所示,表示的压缩成本不断增加。这表明对于 transformer 来说,量化表示变得越来越复杂了。2^12 codebook 之后再增加精度趋向饱和。

MaskGIT 实验

  • 定量对比:VQ 和 FSQ 精度接近
    在这里插入图片描述
    在这里插入图片描述
  • 可视化对比:效果接近
    在这里插入图片描述

UViM 实验

  • 定量对比:FSQ 比 VQ 精度差不多,FSQ 稍差一点
    在这里插入图片描述
  • 可视化对比
    在这里插入图片描述

Thoughts

  • 消除 vq 的各种辅助损失看起来很有吸引力
  • FSQ 精度上比 vq 并没有明显优势,这里可能还有优化空间

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/484934.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

教你用Appium搭建Android自动化测试框架(详细教程)

如题,本文附有仓库地址以及代码,目录如下: 1.下载安装appium 2.配置并使用appium 3.配置工程 一、下载安装appium 下载appium-desktop的mac版本,我选择用1.17.1版本,最新的版本1.22.0,拆分一个桌面应用…

2.22作业

作业要求&#xff1a; 程序代码&#xff1a; #include<myhead.h>int num 520; pthread_mutex_t mutex;void* task1(void* arg) {pthread_mutex_lock(&mutex);num 1314;sleep(3);printf("task1:num %d\n", num);pthread_mutex_unlock(&mutex); }voi…

医药之链:基于Django的智能药品管理系统

框架 Python 3.7 django 3.2.13 Bootstrap&#xff08;前端&#xff09; sqlite&#xff08;数据库&#xff09;导包 pip install django3.2.13 pip install pandas pip install xlwt环境搭建 登录 zfx 123456

iOS 不用VOIP也可以实现微信来电邀请推送持续响铃震动及及时来电取消

iOS 不用VOIP也可以实现微信来电邀请推送持续响铃震动及及时来电取消 前言一 、关于NotificationServiceExtension1. 创建2. 推送机制3. API 实例方法二、 APNS Push Payload2. 来电邀请2. 来电取消三、数据共享3. 配置共享域四 知识点就到这里,下面是主要代码4. 1 Notificati…

C++ 八数码问题理解 `IDA*` 算法原则:及时止损,缘尽即散

1.前言 八数码是典型的状态搜索案例。如字符串转换问题、密码锁问题都是状态搜索问题。 状态搜索问题指由一种状态转换到到最终状态&#xff0c;求解中间需要经过多少步转换&#xff0c;或者说最小需要转换多少步&#xff0c;或者说有多少种转换方案。本文和大家聊聊八数码问…

如何在Shopify 创建Mega Menu

在Shopify上创建Mega Menu可以通过以下步骤完成&#xff1a; 1. 选择适合的主题 首先&#xff0c;选择一个适合的主题&#xff0c;因为不是所有的Shopify主题都支持Mega Menu。确保选择一个具有自定义菜单功能的主题&#xff0c;或者您可以使用Shopify App Store中的应用程序来…

(十三)【Jmeter】线程(Threads(Users))之tearDown 线程组

简述 操作路径如下: 作用:在正式测试结束后执行清理操作,如关闭连接、释放资源等。配置:设置清理操作的采样器、执行顺序等参数。使用场景:确保在测试结束后应用程序恢复到正常状态,避免资源泄漏或对其他测试的影响。优点:提供清理操作,确保测试环境的整洁和可重复性…

mysql索引问题

今天在工作写项目的时候&#xff0c;突然发现很多地方没有加索引&#xff0c;然后我就去加了&#xff0c;查了不少资料&#xff0c;捡起来了不少东西&#xff0c;来简单聊一聊&#xff0c;工作中最重要的一个细节&#xff1a;索引 mysq的存储结构 首先要聊一聊mysql的存储模式…

Docker镜像和容器

1.Docker的架构和底层技术 Docker提供了一个开发、打包、运行APP&#xff08;应用application&#xff09;的平台把APP和底层infrastructure&#xff08;基础设备&#xff09;隔离开来 ApplicationDocker EngineInfrastructure(physical/virtual) 1.1Docker Engine组成 后台…

【前端素材】推荐优质后台管理系统GramOs平台模板(附源码)

一、需求分析 后台管理系统是一种用于管理网站、应用程序或系统的工具&#xff0c;它通常作为一个独立的后台界面存在&#xff0c;供管理员或特定用户使用。下面详细分析后台管理系统的定义和功能&#xff1a; 1. 定义 后台管理系统是一个用于管理和控制网站、应用程序或系统…

Imagewheel私人图床搭建结合内网穿透实现无公网IP远程访问教程

文章目录 1.前言2. Imagewheel网站搭建2.1. Imagewheel下载和安装2.2. Imagewheel网页测试2.3.cpolar的安装和注册 3.本地网页发布3.1.Cpolar临时数据隧道3.2.Cpolar稳定隧道&#xff08;云端设置&#xff09;3.3.Cpolar稳定隧道&#xff08;本地设置&#xff09; 4.公网访问测…

工具篇-- 定时任务quartz

文章目录 前言一、quartz 介绍:二、quartz 的简单使用:2.1 引入jar&#xff1a;2.2 定义任务&#xff1a; 三、quartz 核心组件:3.1 JobDetail&#xff1a;3.1.1 JobDetail介绍&#xff1a;3.1.2 JobDetail 和job 的关系&#xff1a; 3.2 trigger&#xff1a;3.2.1 trigger 介绍…