ChatGLM2-6B Lora 微调训练医疗问答任务

一、ChatGLM2-6B Lora 微调

LoRA 微调技术的思想很简单,在原始 PLM (Pre-trained Language Model) 增加一个旁路,一般是在 transformer 层,做一个降维再升维的操作,模型的输入输出维度不变,来模拟 intrinsic rank,如下图的 AB。训练时冻结 PLM 的参数,只训练 AB ,,输出时将旁路输出与 PLM 的参数叠加,进而影响原始模型的效果。该方式,可以大大降低训练的参数量,而性能可以优于其它参数高效微调方法,甚至和全参数微调(Fine-Tuning)持平甚至超过。

对于 AB 参数的初始化,A 使用随机高斯分布,B 使用 0 矩阵,这样在最初时可以保证旁路为一个 0 矩阵,最开始时使用原始模型的能力。

在这里插入图片描述
对于 lora 微调的实现可以使用 HuggingFace 开源的 PEFT 库,地址如下:

https://github.com/huggingface/peft

下载依赖:

pip install peft -i https://pypi.tuna.tsinghua.edu.cn/simple

使用方式也很简单,例如先查看 ChatGLM2-6B 的模型结构:

from transformers import AutoModelmodel_name = "chatglm-6b"
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
print(model)

输出结果:

ChatGLMForConditionalGeneration((transformer): ChatGLMModel((embedding): Embedding((word_embeddings): Embedding(65024, 4096))(rotary_pos_emb): RotaryEmbedding()(encoder): GLMTransformer((layers): ModuleList((0-27): 28 x GLMBlock((input_layernorm): RMSNorm()(self_attention): SelfAttention((query_key_value): Linear(in_features=4096, out_features=4608, bias=True)(core_attention): CoreAttention((attention_dropout): Dropout(p=0.0, inplace=False))(dense): Linear(in_features=4096, out_features=4096, bias=False))(post_attention_layernorm): RMSNorm()(mlp): MLP((dense_h_to_4h): Linear(in_features=4096, out_features=27392, bias=False)(dense_4h_to_h): Linear(in_features=13696, out_features=4096, bias=False))))(final_layernorm): RMSNorm())(output_layer): Linear(in_features=4096, out_features=65024, bias=False))
)

可以看出 ChatGLM 主要由 28 层的 GLMBlock 进行提取和理解语义特征,下面借助 PEFT 库将 Lora 旁路层注入到模型中,主要关注下 query_key_value 层的变化:

from transformers import AutoTokenizer, AutoModel, AutoConfig
from peft import LoraConfig, get_peft_model, TaskTypemodel_name = "chatglm-6b"
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)config = LoraConfig(peft_type="LORA",task_type=TaskType.CAUSAL_LM,inference_mode=False,r=8,lora_alpha=16,lora_dropout=0.1,fan_in_fan_out=False,bias='lora_only',target_modules=["query_key_value"]
)model = get_peft_model(model, config)
print(model)

其中 r 就是 lora 中秩的大小。

输出结果:

PeftModelForCausalLM((base_model): LoraModel((model): ChatGLMForConditionalGeneration((transformer): ChatGLMModel((embedding): Embedding((word_embeddings): Embedding(65024, 4096))(rotary_pos_emb): RotaryEmbedding()(encoder): GLMTransformer((layers): ModuleList((0-27): 28 x GLMBlock((input_layernorm): RMSNorm()(self_attention): SelfAttention((query_key_value): Linear(in_features=4096, out_features=4608, bias=True(lora_dropout): ModuleDict((default): Dropout(p=0.1, inplace=False))(lora_A): ModuleDict((default): Linear(in_features=4096, out_features=8, bias=False))(lora_B): ModuleDict((default): Linear(in_features=8, out_features=4608, bias=False))(lora_embedding_A): ParameterDict()(lora_embedding_B): ParameterDict())(core_attention): CoreAttention((attention_dropout): Dropout(p=0.0, inplace=False))(dense): Linear(in_features=4096, out_features=4096, bias=False))(post_attention_layernorm): RMSNorm()(mlp): MLP((dense_h_to_4h): Linear(in_features=4096, out_features=27392, bias=False)(dense_4h_to_h): Linear(in_features=13696, out_features=4096, bias=False))))(final_layernorm): RMSNorm())(output_layer): Linear(in_features=4096, out_features=65024, bias=False))))
)

可以对比下原始的 ChatGLM 模型结构, query_key_value 层中已经被加入下 loraAB 层,下面可以通过 model.print_trainable_parameters() 打印可训练的参数量:

trainable params: 2,078,720 || all params: 6,245,533,696 || trainable%: 0.03328330453698988

可以看到可训练的参数量只有 0.03328330453698988

下面依然借助前面文章使用的医疗问答数据集,在 ChatGLM2 lora 微调下的效果。

对该数据集不了解的小伙伴可以参考下面这篇文章:

ChatGLM2-6B P-Tuning v2 微调训练医疗问答任务

二、ChatGLM2-6B Lora 微调

解析数据,构建 Dataset 数据集 qa_dataset.py

# -*- coding: utf-8 -*-
from torch.utils.data import Dataset
import torch
import json
import numpy as npclass QADataset(Dataset):def __init__(self, data_path, tokenizer, max_source_length, max_target_length) -> None:super().__init__()self.tokenizer = tokenizerself.max_source_length = max_source_lengthself.max_target_length = max_target_lengthself.max_seq_length = self.max_source_length + self.max_target_lengthself.data = []with open(data_path, "r", encoding='utf-8') as f:for line in f:if not line or line == "":continuejson_line = json.loads(line)content = json_line["content"]summary = json_line["summary"]self.data.append({"question": content,"answer": summary})print("data load , size:", len(self.data))def preprocess(self, question, answer):prompt = self.tokenizer.build_prompt(question, None)a_ids = self.tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True,max_length=self.max_source_length)b_ids = self.tokenizer.encode(text=answer, add_special_tokens=False, truncation=True,max_length=self.max_target_length)context_length = len(a_ids)input_ids = a_ids + b_ids + [self.tokenizer.eos_token_id]labels = [self.tokenizer.pad_token_id] * context_length + b_ids + [self.tokenizer.eos_token_id]pad_len = self.max_seq_length - len(input_ids)input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_lenlabels = labels + [self.tokenizer.pad_token_id] * pad_lenlabels = [(l if l != self.tokenizer.pad_token_id else -100) for l in labels]return input_ids, labelsdef __getitem__(self, index):item_data = self.data[index]input_ids, labels = self.preprocess(**item_data)return {"input_ids": torch.LongTensor(np.array(input_ids)),"labels": torch.LongTensor(np.array(labels))}def __len__(self):return len(self.data)

构造 Lora 结构,微调训练 train_lora.py

# -*- coding: utf-8 -*-
import pandas as pd
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
from qa_dataset import QADataset
from peft import LoraConfig, get_peft_model, TaskType
from tqdm import tqdm
import torch
import os, time, sysdef train(epoch, model, device, loader, optimizer, gradient_accumulation_steps):model.train()time1 = time.time()for index, data in enumerate(tqdm(loader, file=sys.stdout, desc="Train Epoch: " + str(epoch))):input_ids = data['input_ids'].to(device, dtype=torch.long)labels = data['labels'].to(device, dtype=torch.long)outputs = model(input_ids=input_ids,labels=labels,)loss = outputs.loss# 反向传播,计算当前梯度loss.backward()# 梯度累积步数if (index % gradient_accumulation_steps == 0 and index != 0) or index == len(loader) - 1:# 更新网络参数optimizer.step()# 清空过往梯度optimizer.zero_grad()# 100轮打印一次 lossif index % 100 == 0 or index == len(loader) - 1:time2 = time.time()tqdm.write(f"{index}, epoch: {epoch} -loss: {str(loss)} ; each step's time spent: {(str(float(time2 - time1) / float(index + 0.0001)))}")def validate(tokenizer, model, device, loader, max_length):model.eval()predictions = []actuals = []with torch.no_grad():for _, data in enumerate(tqdm(loader, file=sys.stdout, desc="Validation Data")):input_ids = data['input_ids'].to(device, dtype=torch.long)labels = data['labels'].to(device, dtype=torch.long)generated_ids = model.generate(input_ids=input_ids,max_length=max_length,do_sample=False,temperature=0)preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g ingenerated_ids]target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in labels]predictions.extend(preds)actuals.extend(target)return predictions, actualsdef main():model_name = "chatglm-6b"train_json_path = "./data/train.json"val_json_path = "./data/val.json"max_source_length = 128max_target_length = 512epochs = 5batch_size = 1lr = 1e-4lora_rank = 8lora_alpha = 32gradient_accumulation_steps = 16model_output_dir = "output"# 设备device = torch.device("cuda:0")# 加载分词器和模型tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)model = AutoModel.from_pretrained(model_name, trust_remote_code=True)# setup peftpeft_config = LoraConfig(task_type=TaskType.CAUSAL_LM,inference_mode=False,r=lora_rank,lora_alpha=lora_alpha,lora_dropout=0.1)model = get_peft_model(model, peft_config)model.is_parallelizable = Truemodel.model_parallel = Truemodel.print_trainable_parameters()# 转为半精度model = model.half()model.float()print("Start Load Train Data...")train_params = {"batch_size": batch_size,"shuffle": True,"num_workers": 0,}training_set = QADataset(train_json_path, tokenizer, max_source_length, max_target_length)training_loader = DataLoader(training_set, **train_params)print("Start Load Validation Data...")val_params = {"batch_size": batch_size,"shuffle": False,"num_workers": 0,}val_set = QADataset(val_json_path, tokenizer, max_source_length, max_target_length)val_loader = DataLoader(val_set, **val_params)optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)model = model.to(device)print("Start Training...")for epoch in range(epochs):train(epoch, model, device, training_loader, optimizer, gradient_accumulation_steps)print("Save Model To ", model_output_dir)model.save_pretrained(model_output_dir)# 验证print("Start Validation...")with torch.no_grad():predictions, actuals = validate(tokenizer, model, device, val_loader, max_target_length)# 验证结果存储final_df = pd.DataFrame({"Generated Text": predictions, "Actual Text": actuals})val_data_path = os.path.join(model_output_dir, "predictions.csv")final_df.to_csv(val_data_path)print("Validation Data To ", val_data_path)if __name__ == '__main__':main()

开始训练:

在这里插入图片描述

等待训练结束后,可以在输出目录看到保存的模型,仅只有 lora 层的参数,所以模型比较小:

在这里插入图片描述

此时可以查看下 predictions.csv 中验证集的效果。

三、模型测试

from transformers import AutoTokenizer, AutoModel, AutoConfig
from peft import PeftConfig, PeftModel, LoraConfig, get_peft_model, TaskType
import torchdef load_lora_config(model):config = LoraConfig(task_type=TaskType.CAUSAL_LM,inference_mode=False,r=8,lora_alpha=32,lora_dropout=0.1,target_modules=["query_key_value"])return get_peft_model(model, config)device = torch.device("cuda:0")model_name = "chatglm-6b"
lora_dir = "output"model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)config = PeftConfig.from_pretrained(lora_dir)
model = PeftModel.from_pretrained(model, lora_dir)model = model.to(device)
model.eval()response, history = model.chat(tokenizer, "5月至今上腹靠右隐痛,右背隐痛带酸,便秘,喜睡,时有腹痛,头痛,腰酸症状?", history=[])
print("回答:", response)

输出:

在这里插入图片描述

回答: 你好,根据你的叙述,考虑是胃炎引来的。建议你平时留意饮食规律,不要吃辛辣刺激性食物,多喝热水,可以口服奥美拉唑肠溶胶囊和阿莫西林胶囊实施救治,如果效果不好,建议去医院做胃镜仔细检查。除了及时救治胃痛外,患者朋友理应始终保持愉快的心态去直面疾病,只有这样才能令得患者及时对症救治,同时要多看重自身饮食护理,多观注自身的症状变动,认为这样一定能将胃痛撵走。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/110300.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

pytorch代码实现之动态卷积模块ODConv

ODConv动态卷积模块 ODConv可以视作CondConv的延续,将CondConv中一个维度上的动态特性进行了扩展,同时了考虑了空域、输入通道、输出通道等维度上的动态性,故称之为全维度动态卷积。ODConv通过并行策略采用多维注意力机制沿核空间的四个维度…

【JavaSE笔记】抽象类与接口

一、抽象类 1、概念 在面向对象的概念中,所有的对象都是通过类来描绘的,但是反过来,并不是所有的类都是用来描绘对象的,如果一个类中没有包含足够的信息来描绘一个具体的对象,这样的类就是抽象类。 package demo2…

气传导耳机哪个好?值得推荐的气传导耳机分享

​随着生活节奏的加快,人们越来越关注听力健康。气传导耳机以其独特的传导方式和舒适的佩戴感受,逐渐成为耳机市场的新宠。气传导耳机不入耳设计听音,让你在享受音乐的同时,也能保护你的听力安全。今天我们就一起来看看几款值得大…

饲料添加剂 微生物 屎肠球菌

声明 本文是学习GB 7300.503-2023 饲料添加剂 第5部分:微生物 屎肠球菌. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本文件规定了饲料添加剂屎肠球菌的技术要求、采样、检验规则、标签、包装、运输、贮存和保质 期&#xff0…

【vue】vue 中插槽的三种类型:

文章目录 一、匿名插槽&#xff1a;二、具名插槽&#xff1a;三、作用域插槽 一、匿名插槽&#xff1a;<slot></slot> 1.没有为插槽指定名称 2.通过slot标签可以添加匿名插槽 3.在使用组件的时候&#xff0c;组件中的内容会填充到所有匿名插槽的位置&#xff0c;所…

数据结构——散列函数、散列表

文章目录 前言一、散列表的基本概念二、散列函数的构造方法三、处理冲突的方法1. 开放定址法&#xff1a;2. 拉链法 四、散列查找及性能分析总结 前言 散列表的基本概念散列函数的构造方法处理冲突的方法散列查找及性能分析 提示&#xff1a;以下是本篇文章正文内容&#xff0…

七天学会C语言-第一天(C语言基本语句)

一、固定格式 这个是C程序的基本框架&#xff0c;需要记住&#xff01;&#xff01;&#xff01; #include<stdio.h>int main(){return 0; }二、printf 语句 简单输出一句C程序&#xff1a; #include<stdio.h> int main(){printf("大家好&#xff0c;&quo…

S7-1200PLC和LED电子看板通信(TCP/IP)

S7-200SMART PLC和LED电子看板通信应用,请查看下面文章链接: SMART 200 PLC UDP通讯应用LED看板_RXXW_Dor的博客-CSDN博客开放式用户通信 (OUC) 库:数据解析:https://rxxw-control.blog.csdn.net/article/details/121424897这篇博客我们主要介绍S7-1200PLC和LED电子看板通…

PowerDesigner 逆向工程以及IDEA中UML插件

1、MySQL数据库连接&#xff08;JDBC方式&#xff09; 1.1 新建一个pdm&#xff0c;dbms选择mysql 1.2 Database - Connect 选择数据库连接 1.3 配置连接信息 数据库连接这里是通过一个配置文件来获取连接信息的&#xff0c;首次的话因为没有&#xff0c;所以我们需要选择…

Michael.W基于Foundry精读Openzeppelin第34期——MerkleProof.sol

Michael.W基于Foundry精读Openzeppelin第34期——MerkleProof.sol 0. 版本0.1 MerkleProof.sol 1. 目标合约2. 代码精读2.1 processProof(bytes32[] memory proof, bytes32 leaf) && processProofCalldata(bytes32[] calldata proof, bytes32 leaf)2.2 verify(bytes32[…

人工智能现在可以从文本中生成具有CD音质的音乐,而且只会越来越好

想象一下&#xff0c;键入“戏剧性的介绍音乐”并听到一首飙升的交响乐&#xff0c;或者编写“令人毛骨悚然的脚步声”并获得高质量的音效。这是稳定音频的承诺&#xff0c;一个文本到音频的人工智能模型周三宣布由能合成立体声的稳定人工智能44.1千赫来自文字描述的音乐或声音…

进化算法、遗传编程和学习

一、说明 进化算法是一系列搜索算法&#xff0c;其灵感来自自然界&#xff08;达尔文主义&#xff09;进化过程。所有不同家庭成员的共同点是&#xff0c;通过应用受自然遗传学和自然选择启发的 算子&#xff0c;通过进化出最初 随机的候选解决方案群体来解决问题&#…