【LM、LLM】浅尝二叉树在前馈神经网络上的应用

前言

随着大模型的发展,模型参数量暴涨,以Transformer的为组成成分的隐藏神经元数量增长的越来越多。因此,降低前馈层的推理成本逐渐进入视野。前段时间看到本文介绍的相关工作还是MNIST数据集上的实验,现在这个工作推进到BERT上面来了,再次引起兴趣记录一下。该工作将前馈神经基于二叉树结构进行改装,加速前向传播的速度,称为:快速前馈网络(FFF),然后应用FFF,取代BERT中的前馈网络(FF),实现12个神经元加速推理。

快速前馈网络算法概述

快速前馈网络(Fast Feedforward Network,FFF)是由两部分组成的:节点网络集合 N \mathcal{N} N 和叶子网络集合 L \mathcal{L} L

  • 节点网络集合 N \mathcal{N} N 包含了一组节点网络,每个节点网络都是一个 < dim ⁡ I , n , 1 > \left<\dim_I,n,1\right> dimI,n,1-前馈网络,并在输出上增加了一个 sigmoid 激活函数。这些节点网络按照平衡的可微分二叉树的形式排列,其中 N m + 1 , 2 n N_{m+1,2n} Nm+1,2n N m + 1 , 2 n + 1 N_{m+1,2n+1} Nm+1,2n+1 N m , n N_{m,n} Nm,n 的子节点。

  • 叶子网络集合 L \mathcal{L} L 包含了一组叶子网络,每个叶子网络都是一个 < dim ⁡ I , ℓ , dim ⁡ O > \left<\dim_I,\ell,\dim_O\right> dimI,,dimO-前馈网络。叶子网络没有子节点,它们的输出直接作为 FFF 的输出。

前向传播过程由下面算法控制。

算法的输入包括一个输入样本 ι \iota ι 和根节点 N 0 , 0 N_{0,0} N0,0,输出为该样本在 FFF 中的输出。

算法定义了两个函数: F o r w a r d T Forward_T ForwardT F o r w a r d I {Forward}_I ForwardI。其中, F o r w a r d T {Forward}_T ForwardT 函数用于计算节点的输出,而 F o r w a r d I {Forward}_I ForwardI 函数用于计算节点的指示值(indicator value)。

  • F o r w a r d T {Forward}_T ForwardT 函数中,如果当前节点是叶子节点,则直接调用该节点的前馈传播函数 N m , n ( ι ) N_{m,n}(\iota) Nm,n(ι) 来计算输出。否则,首先计算当前节点的输出 c m , n = N m , n ( ι ) c_{m,n}=N_{m,n}(\iota) cm,n=Nm,n(ι),然后递归地调用 F o r w a r d T {Forward}_T ForwardT 函数来计算当前节点的两个子节点的输出,并将它们加权相加作为当前节点的输出。
  • F o r w a r d I {Forward}_I ForwardI 函数中,如果当前节点是叶子节点,则直接调用该节点的前馈传播函数 N m , n ( ι ) N_{m,n}(\iota) Nm,n(ι) 来计算输出。否则,首先计算当前节点的输出 c m , n = N m , n ( ι ) c_{m,n}=N_{m,n}(\iota) cm,n=Nm,n(ι),然后根据输出值的大小决定选择哪个子节点进行递归计算。


传统前馈神经网络

快速前馈神经网络

与传统的前馈神经网络算法相比,该算法的主要区别在于引入了一个计算节点的指示值。指示值表示了当前节点的输出是否大于等于阈值(这里的阈值为0.5),根据指示值的大小来确定选择哪个子节点进行计算。这种方式可以大大减少计算量,提高前向传播的效率。同时,FFF 是一种具有平衡二叉树结构的前馈神经网络,其中节点网络和叶子网络分别用于处理中间层和输出层的计算。通过利用二叉树结构和递归计算,FFF 可以实现快速的前向传播。

UltraFastBERT

UltraFastBERT,一种BERT变体,在推理过程中使用0.3%的神经元,同时表现 与类似的BERT模型相当。UltraFastBERT选择性地使用4095个神经元中的12个(有选择的执行矩阵乘法(CMM))进行每层推理。这是通过用快速前馈网络(FFFs)取代前馈网络来实现的。

FFF_BMM代码

import torch
from torch import nn
import mathclass FFF(nn.Module):def __init__(self, input_width: int, depth: int, output_width: int, *args, **kwargs):super().__init__(*args, **kwargs)self.input_width = input_widthself.depth = depthself.output_width = output_widthself.n_nodes = 2 ** (depth + 1) - 1self.initialise_weights()def initialise_weights(self):init_factor_l1 = 1.0 / math.sqrt(self.input_width)init_factor_l2 = 1.0 / math.sqrt(self.depth + 1)self.w1s = nn.Parameter(torch.empty(self.n_nodes, self.input_width).uniform_(-init_factor_l1, +init_factor_l1), requires_grad=True)self.w2s = nn.Parameter(torch.empty(self.n_nodes, self.output_width).uniform_(-init_factor_l2, +init_factor_l2), requires_grad=True)def forward(self, x):# the shape of x is (batch_size, input_width)# retrieve the indices of the relevant elementsbatch_size = x.shape[0]current_nodes = torch.zeros((batch_size,), dtype=torch.long, device=x.device)all_nodes = torch.zeros(batch_size, self.depth+1, dtype=torch.long, device=x.device)all_logits = torch.empty((batch_size, self.depth+1), dtype=torch.float, device=x.device)for i in range(self.depth+1):all_nodes[:, i] = current_nodesplane_coeffs = self.w1s.index_select(dim=0, index=current_nodes)			# (batch_size, input_width)plane_coeff_score = torch.bmm(x.unsqueeze(1), plane_coeffs.unsqueeze(-1))	# (batch_size, 1, 1)plane_score = plane_coeff_score.squeeze(-1).squeeze(-1) 					# (batch_size,)all_logits[:, i] = plane_scoreplane_choices = (plane_score >= 0).long()									# (batch_size,)current_nodes = current_nodes * 2 + plane_choices + 1						# (batch_size,)# get the weightsselected_w2s = self.w2s.index_select(0, index=all_nodes.flatten()) \.view(batch_size, self.depth+1, self.output_width)	# (batch_size, depth+1, output_width)# forward passmlp1 = torch.nn.functional.gelu(all_logits)				# (batch_size, depth+1)mlp2 = torch.bmm(mlp1.unsqueeze(1), selected_w2s) 		# (batch_size, output_width)# donereturn mlp2

从代码可以看出,与传统的批矩阵乘法(BMM)不同的是,在forward中,基于二叉树的结构,通过迭代计算节点的索引和权重,使用激活函数(GeLU)对结果进行处理,并最终得到输出。

结果

在推理过程中仅使用0.3%的神经元,同时表现与类似的BERT模型相当(下游任务没有降很多点);实现78倍CPU加速,实现40倍PyTorch加速。

总结

该工作很有趣,将传统前馈神经网络定义成一棵二叉树,其叶子是小型神经网络,在每个非叶子节点处都有一个微小的神经网络(单个神经元也可以工作)来决定走哪条路径取决于在输入上。在训练期间,它们对所选路径进行加权平均值,从而得出树的所有叶子(在输入上评估为神经网络)的总加权平均值,但在推理过程中,它们可以只遵循投票最高的分支,从而得出建议的结果指数加速。并且,基于FFF的思想,将工作推到BERT这种语言模型上,初步证明了大模型的前馈层的神经元并不是都需要参与推理。

文章及公开的代码还介绍了条件矩阵乘法的详细细节,因此感兴趣可以深入研究一下。

参考文献

【1】paper:Exponentially Faster Language Modelling,https://arxiv.org/abs/2311.10770
【2】code:https://github.com/pbelcak/fastfeedforward
【3】paper:Fast Feedforward Networks,https://arxiv.org/abs/2308.14711

【4】code:https://github.com/pbelcak/UltraFastBERT
【5】model:https://huggingface.co/pbelcak/UltraFastBERT-1x11-long

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/216575.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用Linux JumpServer堡垒机本地部署与远程访问

&#x1f308;个人主页&#xff1a;聆风吟 &#x1f525;系列专栏&#xff1a;网络奇遇记、Cpolar杂谈 &#x1f516;少年有梦不应止于心动&#xff0c;更要付诸行动。 文章目录 &#x1f4cb;前言一. 安装Jump server二. 本地访问jump server三. 安装 cpolar内网穿透软件四. 配…

发现有一个会Python的男友魅力值杠杠的!!!

Python能做什么&#xff1f; 可以做日常任务&#xff0c;比如自动备份你的MP3&#xff0c;可以做网站&#xff0c;很多著名的网站像知乎、YouTube就是Python写的&#xff0c; 可以做网络游戏的后台&#xff0c;很多在线游戏的后台都是Python开发的。 上面说的这些本人并没有实…

免费获取GPT-4的五种工具

不可否认&#xff0c;由OpenAI带来的GPT-4已是全球最受欢迎的、功能最强大的大语言模型&#xff08;LLM&#xff09;之一。大多数人都需要使用ChatGPT Plus的订阅服务去访问GPT-4。为此&#xff0c;他们通常需要每月支付20美元。那么问题来了&#xff0c;如果您不想每月有这笔支…

【SpringBoot篇】阿里云OSS—存储文件的利器

文章目录 &#x1f339;什么是阿里云OSS⭐阿里云OSS的优点 &#x1f3f3;️‍&#x1f308;为什么要使用云服务OSS&#x1f384;使用步骤⭐OSS开通⭐参考官方SDK &#x1f354;编写代码⭐上传文件 &#x1f339;综合案例 &#x1f339;什么是阿里云OSS 阿里云对象存储&#xf…

转录组学习第5弹-比对参考基因组

比对参考基因组 在构建文库的过程中需要将DNA片段化&#xff0c;因此测序得到的序列只是基因组的部分序列。为了确定测序reads在基因组上的位置&#xff0c;需要将reads比对回参考基因组上&#xff0c;这个步骤叫做比对&#xff0c;即文献中所提到的alignment或mapping。包括基…

面试题:说说什么是本地缓存、分布式缓存以及多级缓存,它们各自的优缺点?

文章目录 前言一、缓存的概念&#xff08;什么是缓存&#xff09;二、为什么要用缓存&#xff08;为什么要用redis作为缓存&#xff09;三、缓存的分类有哪些1、本地缓存2、分布式缓存3、多级缓存 总结 前言 像MySql等传统的关系型数据库已经不能适用于所有的业务场景&#xf…

JSP EL表达式之 empty

好 本文我们还是继续说EL表达式 我们来讲一个非空判断的好手 empty 我们直接编写代码如下 <% page contentType"text/html; charsetUTF-8" pageEncoding"UTF-8" %> <%request.setCharacterEncoding("UTF-8");%> <!DOCTYPE html&…

CSS-长度单位篇

px&#xff1a;像素em&#xff1a;相对元素font-size的倍数rem&#xff1a;相对根字体大小&#xff0c;html标签就是根%&#xff1a;相对父元素计算 注意&#xff1a;CSS中设置长度&#xff0c;必须加单位&#xff0c;否则样式无效&#xff01;

【LeetCode 热题 HOT 100】题解笔记 —— Day01

❤ 作者主页&#xff1a;欢迎来到我的技术博客&#x1f60e; ❀ 个人介绍&#xff1a;大家好&#xff0c;本人热衷于Java后端开发&#xff0c;欢迎来交流学习哦&#xff01;(&#xffe3;▽&#xffe3;)~* &#x1f34a; 如果文章对您有帮助&#xff0c;记得关注、点赞、收藏、…

远程网络监控(RMON)

远程网络监控是一个使 IT 团队能够获得远程网络可见性的过程&#xff0c;它涉及主动监控网络以帮助网络无缝运行&#xff0c;这些监控远程网络的系统提供对性能的实时洞察&#xff0c;及时检测问题并在影响最终用户之前解决问题。这样&#xff0c;远程网络虽然相距遥远&#xf…

洛谷P1049装箱问题 ————递归+剪枝+回溯

没没没没没没没没没错&#xff0c;又是一道简单的递归&#xff0c;只不过加了剪枝&#xff0c;我已经不想再多说&#xff0c;这道题写了一开始写了普通深搜&#xff0c;然后tle了一个点&#xff0c;后面改成剪枝&#xff0c;就ac了&#xff0c;虽然数据很水&#xff0c;但是不妨…

类和对象(3)日期类的实现

日期类的实现 一&#xff0c;声明二&#xff0c;函数成员定义2.1构造函数2.2获取月份天数2.3比较运算符2.3.1等于和大于2.3.2其他 2.4计算运算符2.4.1 &&2.4.2-&&- 2.5日期-日期 一&#xff0c;声明 class Date { public:Date(int year 1, int month 1, int…