图片速览 BitNet: 1-bit LLM

输入数据

  • 模型使用absmax 量化方法进行b比特量化,将输入量化到 [ − Q b , Q b ] ( Q b = 2 b − 1 ) \left[-Q_{b},Q_{b}\right](Q_{b}=2^{b-1}) [Qb,Qb](Qb=2b1)
    x ~ = Q u a n t ( x ) = C l i p ( x × Q b γ , − Q b + ϵ , Q b − ϵ ) , Clip ⁡ ( x , a , b ) = max ⁡ ( a , min ⁡ ( b , x ) ) , γ = ∣ ∣ x ∣ ∣ ∞ , \widetilde{x}=\mathrm{Quant}(x)=\mathrm{Clip}(x\times\frac{Q_b}{\gamma},-Q_b+\epsilon,Q_b-\epsilon),\\ \operatorname{Clip}(x,a,b)=\max(a,\min(b,x)),\quad\gamma=||x||_\infty, x =Quant(x)=Clip(x×γQb,Qb+ϵ,Qbϵ),Clip(x,a,b)=max(a,min(b,x)),γ=∣∣x,

  • 其中 ε 是一个小的浮点数,可防止在执行截断时溢出。

// https://github.com/kyegomez/BitNet/blob/main/bitnet/bitbnet_b158.py
def absmean_quantize_weights(weights):"""Quantizes the weights to -1, 0, or +1 using an absmean quantization function.Parameters:- weights (Tensor): The weights of a neural network layer.Returns:- Tensor: The quantized weights."""# Calculate the average absolute value (γ) of the weightsgamma = torch.mean(torch.abs(weights))# Scale weights by γ and round to the nearest integer among {-1, 0, +1}quantized_weights = torch.clamp(torch.round(weights / gamma), min=-1, max=1)return quantized_weights

权重

  • 权重 W 的二值化可以公式化为:

α = 1 n m ∑ i j W i j W ~ = S i g n ( W − α ) , Sign ⁡ ( W i j ) = { + 1 , if W i j > 0 , − 1 , if W i j ≤ 0 , \\ \alpha=\frac1{nm}\sum_{ij}W_{ij} \\ \widetilde{W}=\mathrm{Sign}(W-\alpha),\\ \left.\operatorname{Sign}(W_{ij})=\left\{\begin{array}{ll}+1,&\quad\text{if}W_{ij}>0,\\-1,&\quad\text{if}W_{ij}\leq0,\end{array}\right.\right. α=nm1ijWijW =Sign(Wα),Sign(Wij)={+1,1,ifWij>0,ifWij0,

在这里插入图片描述

矩阵乘法

  • 使用上述量化方程,矩阵乘法可以写成:

y = W ~ x ~ y=\widetilde W\widetilde{x} y=W x

  • 为了保持量化后的方差,我们在激活量化之前引入了一个 LayerNorm函数。这样,输出 y 的方差就估计为 1

y = W ~ x ~ = W ~ Quant ( LN ( x ) ) × β γ Q b y=\widetilde{W}\widetilde{x}=\widetilde{W}\text{Quant}(\text{LN}(x))\times\frac{\beta\gamma}{Q_b} y=W x =W Quant(LN(x))×Qbβγ
L N ( x ) = x − E ( x ) V a r ( x ) + ϵ , β = 1 n m ∥ W ∥ 1 \mathrm{LN}(x)=\frac{x-E(x)}{\sqrt{\mathrm{Var}(x)+\epsilon}},\quad\beta=\frac1{nm}\|W\|_1 LN(x)=Var(x)+ϵ xE(x),β=nm1W1

在这里插入图片描述

// https://github.com/kyegomez/BitNet/blob/main/bitnet/bitlinear.py
import torch
from torch import Tensor, nnclass BitLinear(nn.Linear):"""BitLinear is a custom linear layer that performs binarization of weights and quantization of activationsin a group-wise manner.Args:in_features (int): Number of input features.out_features (int): Number of output features.bias (bool, optional): If set to False, the layer will not learn an additive bias. Default is True.num_groups (int, optional): Number of groups to divide the weights and activations into. Default is 1."""def __init__(self,in_features: int,out_features: int,bias: bool = True,num_groups: int = 1,b: int = 8,):super().__init__(in_features, out_features, bias)self.in_features = in_featuresself.out_features = out_featuresself.b = bself.num_groups = num_groupsself.eps = 1e-5self.norm = nn.LayerNorm(in_features)def ste(self, x):"""Applies the sign function for binarization and uses Straight-Through Estimator (STE) during backward pass.Args:x (Tensor): Input tensor.Returns:Tensor: Binarized tensor."""binarized_x = torch.sign(x)binarized_x = (binarized_x - x).detach() + xreturn binarized_xdef binarize_weights_groupwise(self):"""Binarizes the weights of the layer in a group-wise manner using STE.Returns:Tensor: Binarized weights tensor."""group_size = self.weight.shape[0] // self.num_groupsbinarized_weights = torch.zeros_like(self.weight)for g in range(self.num_groups):start_idx = g * group_sizeend_idx = (g + 1) * group_sizeweight_group = self.weight[start_idx:end_idx]alpha_g = weight_group.mean()binarized_weights[start_idx:end_idx] = self.ste(weight_group - alpha_g)return binarized_weightsdef quantize_activations_groupwise(self, x):"""Quantizes the activations of the layer in a group-wise manner.Args:x (Tensor): Input tensor.b (int, optional): Number of bits for quantization. Default is 8.Returns:Tensor: Quantized activations tensor."""Q_b = 2 ** (self.b - 1)group_size = x.shape[0] // self.num_groupsquantized_x = torch.zeros_like(x)for g in range(self.num_groups):start_idx = g * group_sizeend_idx = (g + 1) * group_sizeactivation_group = x[start_idx:end_idx]gamma_g = activation_group.abs().max()quantized_x[start_idx:end_idx] = torch.clamp(activation_group * Q_b / (gamma_g + self.eps),-Q_b + self.eps,Q_b - self.eps,)return quantized_xdef dequantize_activations_groupwise(self, x):"""Dequantizes the activations of the layer in a group-wise manner.Args:x (Tensor): Quantized input tensor.b (int, optional): Number of bits used during the quantization. Default is 8.Returns:Tensor: Dequantized activations tensor."""Q_b = 2 ** (self.b - 1)dequantized_x = torch.zeros_like(x)for g in range(self.num_groups):start_idx = g * x.shape[0] // self.num_groupsend_idx = (g + 1) * x.shape[0] // self.num_groupsquantized_group = x[start_idx:end_idx]gamma_g = quantized_group.abs().max()dequantized_x[start_idx:end_idx] = quantized_group * gamma_g / Q_breturn dequantized_xdef forward(self, x: Tensor) -> Tensor:"""Forward pass of the BitLinear layer.Args:x (Tensor): Input tensor.Returns:Tensor: Output tensor."""# Normalize inputx = self.norm(x)# Binarize weights and quantize activationsbinarized_weights = self.binarize_weights_groupwise()# Perform linear transformationoutput = torch.nn.functional.linear(x, binarized_weights, self.bias)# Quantize activationsoutput = self.quantize_activations_groupwise(output)# Dequantize activationsoutput = self.dequantize_activations_groupwise(output)# Return outputreturn output# Example usage
bitlinear = BitLinear(10, 5, num_groups=2, b=8)
input_tensor = torch.randn(5, 10)  # Example input tensor
output = bitlinear(input_tensor)
print(output)  # Example output tensor

CG

  • 【自然语言处理】【大模型】BitNet:用1-bit Transformer训练LLM

  • BitNet: Scaling 1-bit Transformers for Large Language Models

  • The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits

  • Implementation of “BitNet: Scaling 1-bit Transformers for Large Language Models” in pytorch

  • DB-LLM: Accurate Dual-Binarization for Efficient LLMs

  • 如何看待微软提出的BitNet b1.58?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/522213.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

浅谈字典攻击

一、前言 字典攻击是一种常见的密码破解方法,它使用预先编制的字典文件作为攻击字典,通过尝试猜测密码的方式来破解密码。下面是一个关于字典攻击的博客,希望能够为您了解字典攻击提供帮助。 二、字典攻击概述 字典攻击是一种密码破解方法&…

Unity笔记:C#基础(1)

杂项 虚函数 CSDN - C虚函数详解 cnblog - C#中的虚函数virtual 常量池与new 在C#中,string是不可变的,这意味着对string对象的操作通常会返回一个新的string对象,而不会修改原始的string对象。因此,几乎所有涉及更改string内…

(未解决)macOS matplotlib 中文是方框

reference: Mac OS系统下实现python matplotlib包绘图显示中文(亲测有效)_mac plt 中文值-CSDN博客 module ‘matplotlib.font_manager‘ has no attribute ‘_rebuild‘解决方法_font_manager未解析-CSDN博客 # 问题描述(笑死 显而易见 # solve 找到…

开源文生图大模型Playground v2.5发布:超越SD、DALL·E 3和 Midjourney

前言 在AI技术迅速发展的今天,文生图模型成为了艺术创作、设计创新等领域的重要工具。Playground v2.5的发布,不仅在技术上取得了突破,更在开源文化的推广与实践上迈出了重要一步。 Huggingface模型下载:https://huggingface.co/…

虚拟机环境搭建

搭建vm环境,配置虚拟机,期间遇到不支持,重启电脑后还是没用 此主机支持 AMD-V,但 AMD-V 处于禁用状态。 如果已在 BIOS/固件设置中禁用 AMD-V,或主机自更改此设置后从未重新启动,则 AMD-V 可能被禁用。 确…

PDF文件中有多个文件如何一次性的全部分割出来? 这个办法绝对能够帮到你

PDF作为一种常用的文件格式,广泛应用于各种文档、报表、合同等文件的制作和传输。但有时候,我们可能会遇到一个问题:PDF文件中包含了多个文件,我们需要单独提取其中的一个或几个文件。那么,该如何操作呢?下…

常见BUG如何在测试过程中分析定位

前言 在测试的日常工作中,相信经常有测试的小伙伴遇到类似的情况:在项目上线时,只要出现问题(bug),就很容易成为“背锅侠”。 软件测试人员在工作中是无法避免的要和开发人员和产品经理打交道的&#xff…

黑马点评-发布探店笔记

探店笔记 探店笔记类似点评网站的评价,往往是图文结合。 对应的表有两个: tb_blog:探店笔记表,包含笔记中的标题、文字、图片等 tb_blog_comments:其他用户对探店笔记的评价 流程如下: 上传接口&#…

005-事件捕获、冒泡事件委托

事件捕获、冒泡&事件委托 1、事件捕获与冒泡2、事件冒泡示例3、阻止事件冒泡4、阻止事件默认行为5、事件委托6、事件委托优点 1、事件捕获与冒泡 2、事件冒泡示例 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /…

【嵌入式高级C语言】9:万能型链表懒人手册

文章目录 序言单向不循环链表拼图框架搭建 - Necessary功能拼图块1 创建链表头信息结构体 - Necessary2 链表头部插入 - Optional3 链表的遍历 - Optional4 链表的销毁 - Necessary5 链表头信息结构体销毁 - Necessary6 获取链表中节点的个数 - Optional7 链表尾部插入 - Optio…

如何做代币分析:以 ARB 币为例

作者&#xff1a;lesleyfootprint.network 编译&#xff1a;mingfootprint.network 数据源&#xff1a;ARB 代币仪表板 &#xff08;仅包括以太坊数据&#xff09; 在加密货币和数字资产领域&#xff0c;代币分析起着至关重要的作用。代币分析指的是深入研究与代币相关的数据…

【人工智能课程】计算机科学博士作业三

【人工智能课程】计算机科学博士作业三 来源&#xff1a;李宏毅2022课程第10课的作业 1 图片攻击概念 图片攻击是指故意对数字图像进行修改&#xff0c;以使机器学习模型产生错误的输出或者产生预期之外的结果。这种攻击是通过将微小的、通常对人类难以察觉的扰动应用于输入…