Table of Contents
- Links
- Imports
- class ModelArgs
- class Mamba
- def __init__
- def forward
- class ResidualBlock
- class RMSNorm
- Text generation demo
A simple, minimal implementation of Mamba. Compared with the original paper's implementation, state-spaces/mamba (github.com), the parameters are not carefully initialized, in favor of readability, and the original implements the parallel scan in CUDA, so it is much faster.
This post covers the remaining part: defining a sequence model from MambaBlock plus other components such as residual connections and normalization.
For an introduction to MambaBlock, see Mamba-minimal Mamba的最小限度实现 (一)-CSDN博客.
Links
From johnma2006/mamba-minimal: Simple, minimal implementation of the Mamba SSM in one file of PyTorch. (github.com)
Imports
```python
from __future__ import annotations
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Union  # needed for the dt_rank annotation below
from einops import rearrange, repeat, einsum
```
class ModelArgs
Model hyperparameter settings.
Parameter | Description |
---|---|
d_model | model dimension, corresponding to the number of input channels |
n_layer | number of residual blocks |
d_state | latent state dimension |
expand | expansion factor; d_inner = expand * d_model |
dt_rank | rank of Δ (delta) |
d_conv | kernel size of the 1D convolution |
vocab_size | vocabulary size |
pad_vocab_size_multiple | pads vocab_size up to a multiple of this value |
conv_bias | bias option for the 1D convolution |
bias | bias option for the linear projections |
```python
@dataclass
class ModelArgs:
    d_model: int
    n_layer: int
    vocab_size: int
    d_state: int = 16
    expand: int = 2
    dt_rank: Union[int, str] = 'auto'
    d_conv: int = 4
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    bias: bool = False

    def __post_init__(self):
        self.d_inner = int(self.expand * self.d_model)

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (self.pad_vocab_size_multiple
                                - self.vocab_size % self.pad_vocab_size_multiple)
```
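As a quick sanity check of what `__post_init__` derives, here is a minimal sketch (the configuration values below are hypothetical, not from any released checkpoint):

```python
args = ModelArgs(d_model=768, n_layer=24, vocab_size=50277)

print(args.d_inner)     # 1536  = expand * d_model
print(args.dt_rank)     # 48    = ceil(768 / 16), resolved from 'auto'
print(args.vocab_size)  # 50280 = 50277 padded up to a multiple of 8
```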
class Mamba
A complete Mamba sequence model, composed of a stack of wrapped MambaBlocks.
For nn.Embedding, see 深度学习:pytorch nn.Embedding详解-CSDN博客.
The lm_head layer is the output layer for next-token prediction: it maps the model's hidden states to a distribution over the vocabulary, and its weights are shared with the embedding (weight tying).
The input is a sequence $x$ of shape $(batch\_size, length)$, abbreviated as $(b, l)$; the output is the token probabilities of shape $(b, l, vocab\_size)$.
Component | Shape transform |
---|---|
embedding | (b, l) -> (b, l, d_model) |
layers | (b, l, d_model) -> (b, l, d_model) |
norm_f | (b, l, d_model) -> (b, l, d_model) |
lm_head | (b, l, d_model) -> (b, l, vocab_size) |
def __init__
```python
class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        """Full Mamba model."""
        super().__init__()
        self.args = args

        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        self.norm_f = RMSNorm(args.d_model)

        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # Tie output projection to embedding weights.
                                                     # See "Weight Tying" paper
```
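A minimal sketch of the weight tying, assuming the classes above (including MambaBlock from part one) are available in model.py; the sizes here are made up for illustration:

```python
from model import Mamba, ModelArgs

args = ModelArgs(d_model=64, n_layer=2, vocab_size=100)
model = Mamba(args)

# The output projection and the embedding are the same Parameter object,
# so they stay identical throughout training.
assert model.lm_head.weight is model.embedding.weight
print(model.lm_head.weight.shape)  # torch.Size([104, 64]) -- vocab padded to 104
```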
def forward
```python
def forward(self, input_ids):
    x = self.embedding(input_ids)

    for layer in self.layers:
        x = layer(x)

    x = self.norm_f(x)
    logits = self.lm_head(x)

    return logits
```
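A shape walk-through with random token ids matches the table above; again a sketch with hypothetical sizes, assuming model.py contains the classes from this post:

```python
import torch
from model import Mamba, ModelArgs

args = ModelArgs(d_model=64, n_layer=2, vocab_size=100)
model = Mamba(args)

input_ids = torch.randint(0, args.vocab_size, (2, 10))  # (b, l) = (2, 10)
logits = model(input_ids)
print(logits.shape)  # torch.Size([2, 10, 104]) -- (b, l, vocab_size), vocab padded to 104
```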
class ResidualBlock
A residual block wrapping a single MambaBlock.
For an introduction to MambaBlock, see Mamba-minimal Mamba的最小限度实现 (一)-CSDN博客.
```python
class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.d_model)

    def forward(self, x):
        output = self.mixer(self.norm(x)) + x
        return output
```
class RMSNorm
The normalization used throughout the model.
See RMSNorm论文阅读-CSDN博客 and
LLM中的RMSNorm - 知乎 (zhihu.com).
```python
class RMSNorm(nn.Module):
    def __init__(self,
                 d_model: int,
                 eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
        return output
```
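With its initial weight of all ones, RMSNorm simply rescales every feature vector to (approximately) unit root mean square; a small sketch, assuming the RMSNorm class above is in scope:

```python
import torch

norm = RMSNorm(d_model=8)
x = torch.randn(2, 5, 8) * 3.0   # arbitrary input scale
y = norm(x)

# Each position's feature vector now has RMS close to 1.
print(y.pow(2).mean(-1).sqrt())
```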
Text generation demo
From demo.ipynb.
There is also a colab_demo.
Load the model
```python
from model import Mamba, ModelArgs
from transformers import AutoTokenizer

# One of:
#     'state-spaces/mamba-2.8b-slimpj'
#     'state-spaces/mamba-2.8b'
#     'state-spaces/mamba-1.4b'
#     'state-spaces/mamba-790m'
#     'state-spaces/mamba-370m'
#     'state-spaces/mamba-130m'
pretrained_model_name = 'state-spaces/mamba-370m'

model = Mamba.from_pretrained(pretrained_model_name)
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
```
Generate text
Sampling is restricted to the top-k highest-probability tokens.
```python
import torch
import torch.nn.functional as F

def generate(model,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 50,
             sample: bool = True,
             top_k: int = 40):
    model.eval()

    input_ids = tokenizer(prompt, return_tensors='pt').input_ids

    for token_n in range(n_tokens_to_gen):
        with torch.no_grad():
            indices_to_input = input_ids
            next_token_logits = model(indices_to_input)[:, -1]

        probs = F.softmax(next_token_logits, dim=-1)
        (batch, vocab_size) = probs.shape

        if top_k is not None:
            (values, indices) = torch.topk(probs, k=top_k)
            probs[probs < values[:, -1, None]] = 0
            probs = probs / probs.sum(axis=1, keepdims=True)

        if sample:
            next_indices = torch.multinomial(probs, num_samples=1)
        else:
            next_indices = torch.argmax(probs, dim=-1)[:, None]

        input_ids = torch.cat([input_ids, next_indices], dim=1)

    output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]

    return output_completions
```

```python
print(generate(model, tokenizer, 'Mamba is the'))
```
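To see what the top-k filtering inside generate does, here is the same three-line pattern applied to a toy distribution (k = 2):

```python
import torch

probs = torch.tensor([[0.50, 0.30, 0.15, 0.05]])
(values, indices) = torch.topk(probs, k=2)        # values = [[0.50, 0.30]]
probs[probs < values[:, -1, None]] = 0            # zero everything below the k-th value
probs = probs / probs.sum(axis=1, keepdims=True)  # renormalize the survivors

print(probs)  # tensor([[0.6250, 0.3750, 0.0000, 0.0000]])
```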