Mamba-minimal Mamba的最小限度实现 (二)

文章目录

    • 链接
    • 导入所需包
    • class ModelArgs
    • class Mamba
      • def __ init __
      • def forward
    • class ResidualBlock
    • class RNSNorm
    • 文本生成demo

manba的简单最小限度实现,和原始论文实现 state-spaces/mamba (github.com)相比,为了可读性对参数没有很好的初始化,原论文用CUDA写了并行扫描,所以速度会快。
这里是剩余部分介绍,主要包括利用MambaBlock和其他组件如残差连接,归一化等定义一个序列模型。

MambaBlock的介绍Mamba-minimal Mamba的最小限度实现 (一)-CSDN博客

链接

来自johnma2006/mamba-minimal: Simple, minimal implementation of the Mamba SSM in one file of PyTorch. (github.com)

导入所需包

from __future__ import annotations
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import rearrange, repeat, einsum

class ModelArgs

模型参数设置

参数介绍
d_model模型维度,和输入数据通道对应
n_layer残差块的数目
d_state潜在状态维度
expand扩展因子,d_in = d_state * state
dt_rankdelta的秩
d_conv1D卷积的卷积核大小
vocab_size词汇表的大小
pad_vocab_size_multiple确保vocab_size是设定值的倍数
conv_bias1D卷积的bias选项
biaslm_head映射的bias选项
@dataclass
class ModelArgs:d_model: intn_layer: intvocab_size: intd_state: int = 16expand: int = 2dt_rank: Union[int, str] = 'auto'd_conv: int = 4 pad_vocab_size_multiple: int = 8conv_bias: bool = Truebias: bool = Falsedef __post_init__(self):self.d_inner = int(self.expand * self.d_model)if self.dt_rank == 'auto':self.dt_rank = math.ceil(self.d_model / 16)if self.vocab_size % self.pad_vocab_size_multiple != 0:self.vocab_size += (self.pad_vocab_size_multiple- self.vocab_size % self.pad_vocab_size_multiple)

class Mamba

一个完整的序列处理Mamba模型,包含多个被包裹的MambaBlock。

nn.Embedding参照深度学习:pytorch nn.Embedding详解-CSDN博客

lm_head层则是预测下一个token的输出层,它将模型的输出映射到一个概率分布上,以便于模型预测下一个token,权重和Embedding公用。

输入一个序列 x ( b a t c h _ s i z e , l e n g t h ) x(batch\_size, length) x(batch_size,length) 简写为 ( b , l ) (b, l) (b,l),输出取词的概率 ( b , l , v o c a b _ s i z e ) (b, l, vocab\_size) (b,l,vocab_size)

组件尺寸变换
embedding(b, l) -> (b, l, d_model)
layers(b, l, d_model) -> (b, l, d_model)
norm_f\
lm_head(b, l, d_model) -> (b, l, vocab_size)

def __ init __

class Mamba(nn.Module):def __init__(self, args: ModelArgs):"""Full Mamba model."""super().__init__()self.args = argsself.embedding = nn.Embedding(args.vocab_size, args.d_model)self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])self.norm_f = RMSNorm(args.d_model)self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)self.lm_head.weight = self.embedding.weight  # Tie output projection to embedding weights.# See "Weight Tying" paper

def forward

 def forward(self, input_ids):x = self.embedding(input_ids)for layer in self.layers:x = layer(x)x = self.norm_f(x)logits = self.lm_head(x)return logits

class ResidualBlock

一个包裹MambaBlock的一个残差块

MambaBlock的介绍Mamba-minimal Mamba的最小限度实现 (一)-CSDN博客

class ResidualBlock(nn.Module):def __init__(self, args: ModelArgs):super().__init__()self.args = argsself.mixer = MambaBlock(args)self.norm = RMSNorm(args.d_model)def forward(self, x):output = self.mixer(self.norm(x)) + xreturn output

class RNSNorm

所用到的归一化

可以参考RMSNorm论文阅读-CSDN博客

LLM中的RMSNorm - 知乎 (zhihu.com)

class RMSNorm(nn.Module):def __init__(self,d_model: int,eps: float = 1e-5):super().__init__()self.eps = epsself.weight = nn.Parameter(torch.ones(d_model))def forward(self, x):output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weightreturn output

文本生成demo

来自demo.ipynb

这里是一个colab_demo

加载模型

from model import Mamba, ModelArgs
from transformers import AutoTokenizer# One of:
#     'state-spaces/mamba-2.8b-slimpj'
#     'state-spaces/mamba-2.8b'
#     'state-spaces/mamba-1.4b'
#     'state-spaces/mamba-790m'
#     'state-spaces/mamba-370m'
#     'state-spaces/mamba-130m'
pretrained_model_name = 'state-spaces/mamba-370m'model = Mamba.from_pretrained(pretrained_model_name)
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

生成文本
在概率为top-k的输出中采样

import torch
import torch.nn.functional as Fdef generate(model,tokenizer,prompt: str,n_tokens_to_gen: int = 50,sample: bool = True,top_k: int = 40):model.eval()input_ids = tokenizer(prompt, return_tensors='pt').input_idsfor token_n in range(n_tokens_to_gen):with torch.no_grad():indices_to_input = input_idsnext_token_logits = model(indices_to_input)[:, -1]probs = F.softmax(next_token_logits, dim=-1)(batch, vocab_size) = probs.shapeif top_k is not None:(values, indices) = torch.topk(probs, k=top_k)probs[probs < values[:, -1, None]] = 0probs = probs / probs.sum(axis=1, keepdims=True)if sample:next_indices = torch.multinomial(probs, num_samples=1)else:next_indices = torch.argmax(probs, dim=-1)[:, None]input_ids = torch.cat([input_ids, next_indices], dim=1)output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]return output_completions

在这里插入图片描述

print(generate(model, tokenizer, 'Mamba is the'))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/525909.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

菜鸟笔记-14Python绘图颜色使用

Python中绘图主要依赖于各种库&#xff0c;其中matplotlib是最常用且功能强大的一个。在matplotlib中&#xff0c;你可以使用各种颜色来表示不同的数据点、线条或填充区域。下面我将详细介绍如何在Python中使用matplotlib来设置绘图颜色&#xff0c;并给出具体的例子。 14.1颜…

面向对象高级编程下

面向对象高级编程下 面向对象高级编程下一. 转换函数二. non-explict-one-argument ctor三. explicit-one-argument ctor四. pointer-like classes1. 智能指针2. 迭代器 五. function-like classes六. namespace七. 模板1.类模板2.函数模板3.成员模板 八.模板特化和偏特化1. 模…

Edu18 -- Divide by Three --- 题解

目录 Divide by Three&#xff1a; 题目大意&#xff1a; ​编辑​编辑思路解析&#xff1a; 代码实现&#xff1a; Divide by Three&#xff1a; 题目大意&#xff1a; 思路解析&#xff1a; 一个数字是3的倍数&#xff0c;那么他的数位之和也是3的倍数&#xff0c;所以我…

0基础学习VR全景平台篇第143篇:限定访问功能

大家好&#xff0c;欢迎观看蛙色VR官方——后台使用系列课程&#xff01;这期&#xff0c;我们将为大家介绍如何使用限定访问功能。 一.什么是限定访问功能&#xff1f; 限定访问&#xff0c;就是可以在编辑后台设置可以访问作品的用户的类型&#xff0c;还有可以访问作品的IP…

【每日刷题】栈与队列-LC394、LC347、LC215

题外话&#xff1a;感觉脑子没长到栈这块…最近刷栈的题都好难啊…哭哭…坚持坚持&#xff01;多刷几遍就好了&#xff01;&#xff01; 1. LC394.字符串解码 题目链接 先说数据结构。 维护两个栈&#xff1a;一个栈存之前的字符串&#xff0c;另一个栈存之后的字符串的重复…

OpenAI GPT LLMs 高级提示词工程方法汇总

原文地址&#xff1a;An Introduction to Prompt Engineering for OpenAI GPT LLMs Github&#xff1a;Prompt-Engineering-Intro 2023 年 3 月 2 日 Naive 提示词&#xff1a;带有提示的情感分类器 prompt Decide whether a Tweets sentiment is positive, neutral, or …

基于ThinkPHP框架的校园一卡通系统设计与实现

目 录 摘 要 I Abstract II 引 言 1 1 相关技术 3 1.1 框架技术 3 1.1.1 Bootstrap 3 1.1.2 ThinkPHP框架 3 1.2 前端技术 4 1.2.1 JavaScript 4 1.2.2 ECharts 4 1.3 B/S架构 4 1.4 数据库技术 5 1.4.1 MySQL 5 1.5 本章小结 6 2 系统分析 7 2.1 功能需求分析 7 2.2 非功能需…

每日OJ题_路径dp①_力扣62. 不同路径

目录 力扣62. 不同路径 解析代码 力扣62. 不同路径 62. 不同路径 难度 中等 一个机器人位于一个 m x n 网格的左上角 &#xff08;起始点在下图中标记为 “Start” &#xff09;。 机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角&#xff08;在下图中标…

博士推荐 | 薄膜、涂层技术和液晶材料/器件领域的博士

编辑 / 木子 审核 / 朝阳 伟骅英才 伟骅英才致力于以大数据、区块链、AI人工智能等前沿技术打造开放的人力资本生态&#xff0c;用科技解决职业领域问题&#xff0c;提升行业数字化服务水平&#xff0c;提供创新型的产业与人才一体化服务的人力资源解决方案和示范平台&#x…

SLAM|初识SLAM

在空间中&#xff0c;人可以通过固定不动的事物来作为参考系中的参照物。 而这些固定不动的东西可以称之为特征&#xff0c;空间可以理解成特征存在的空间。 而参照物的意义&#xff0c;可以变成是看到某某参照物&#xff0c;就按这个某某参照物进行位置移动。 比如说碰到这个…

基于Springboot的高校宣讲会管理系统。Javaee项目,springboot项目。

演示视频&#xff1a; 基于Springboot的高校宣讲会管理系统。Javaee项目&#xff0c;springboot项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系结构&#xff0c;通过Spring Spri…

【Python】专栏文章索引

为了方便 快速定位 和 便于文章间的相互引用等 作为一个快速准确的导航工具 Python 目录&#xff1a; &#xff08;一&#xff09;装饰器函数 &#xff08;二&#xff09;牛客网—软件开发-Python专项练习 &#xff08;三&#xff09;time模块