GPT建模与预测实战

代码链接见文末

效果图:

1.数据样本生成方法

训练配置参数:

--epochs 40 --batch_size 8 --device 0 --train_path data/train.pkl

其中train.pkl是处理后的文件

因此,我们首先需要执行preprocess.py进行预处理操作,配置参数:

--data_path data/novel --save_path data/train.pkl --win_size 200 --step 200

其中--vocab_file是语料表,一般不用修改,--log_path是日志路径

预处理流程如下:

  • 首先,初始化tokenizer
  • 读取作文数据集目录下的所有文件,预处理后,对于每条数据,使用滑动窗口对其进行截断
  • 最后,序列化训练数据 

代码如下:

# 初始化tokenizertokenizer = CpmTokenizer(vocab_file="vocab/chinese_vocab.model")#pip install jiebaeod_id = tokenizer.convert_tokens_to_ids("<eod>")   # 文档结束符sep_id = tokenizer.sep_token_id# 读取作文数据集目录下的所有文件train_list = []logger.info("start tokenizing data")for file in tqdm(os.listdir(args.data_path)):file = os.path.join(args.data_path, file)with open(file, "r", encoding="utf8")as reader:lines = reader.readlines()title = lines[1][3:].strip()    # 取出标题lines = lines[7:]   # 取出正文内容article = ""for line in lines:if line.strip() != "":  # 去除换行article += linetitle_ids = tokenizer.encode(title, add_special_tokens=False)article_ids = tokenizer.encode(article, add_special_tokens=False)token_ids = title_ids + [sep_id] + article_ids + [eod_id]# train_list.append(token_ids)# 对于每条数据,使用滑动窗口对其进行截断win_size = args.win_sizestep = args.stepstart_index = 0end_index = win_sizedata = token_ids[start_index:end_index]train_list.append(data)start_index += stepend_index += stepwhile end_index+50 < len(token_ids):  # 剩下的数据长度,大于或等于50,才加入训练数据集data = token_ids[start_index:end_index]train_list.append(data)start_index += stepend_index += step# 序列化训练数据with open(args.save_path, "wb") as f:pickle.dump(train_list, f)

2.模型训练过程

 (1) 数据与标签

        在训练过程中,我们需要根据前面的内容预测后面的内容,因此,对于每一个词的标签需要向后错一位。最终预测的是每一个位置的下一个词的token_id的概率。

(2)训练过程

        对于每一轮epoch,我们需要统计该batch的预测token的正确数与总数,并计算损失,更新梯度。

训练配置参数:

--epochs 40 --batch_size 8 --device 0 --train_path data/train.pkl
def train_epoch(model, train_dataloader, optimizer, scheduler, logger,epoch, args):model.train()device = args.deviceignore_index = args.ignore_indexepoch_start_time = datetime.now()total_loss = 0  # 记录下整个epoch的loss的总和epoch_correct_num = 0   # 每个epoch中,预测正确的word的数量epoch_total_num = 0  # 每个epoch中,预测的word的总数量for batch_idx, (input_ids, labels) in enumerate(train_dataloader):# 捕获cuda out of memory exceptiontry:input_ids = input_ids.to(device)labels = labels.to(device)outputs = model.forward(input_ids, labels=labels)logits = outputs.logitsloss = outputs.lossloss = loss.mean()# 统计该batch的预测token的正确数与总数batch_correct_num, batch_total_num = calculate_acc(logits, labels, ignore_index=ignore_index)# 统计该epoch的预测token的正确数与总数epoch_correct_num += batch_correct_numepoch_total_num += batch_total_num# 计算该batch的accuracybatch_acc = batch_correct_num / batch_total_numtotal_loss += loss.item()if args.gradient_accumulation_steps > 1:loss = loss / args.gradient_accumulation_stepsloss.backward()# 梯度裁剪torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)# 进行一定step的梯度累计之后,更新参数if (batch_idx + 1) % args.gradient_accumulation_steps == 0:# 更新参数optimizer.step()# 更新学习率scheduler.step()# 清空梯度信息optimizer.zero_grad()if (batch_idx + 1) % args.log_step == 0:logger.info("batch {} of epoch {}, loss {}, batch_acc {}, lr {}".format(batch_idx + 1, epoch + 1, loss.item() * args.gradient_accumulation_steps, batch_acc, scheduler.get_lr()))del input_ids, outputsexcept RuntimeError as exception:if "out of memory" in str(exception):logger.info("WARNING: ran out of memory")if hasattr(torch.cuda, 'empty_cache'):torch.cuda.empty_cache()else:logger.info(str(exception))raise exception# 记录当前epoch的平均loss与accuracyepoch_mean_loss = total_loss / len(train_dataloader)epoch_mean_acc = epoch_correct_num / epoch_total_numlogger.info("epoch {}: loss {}, predict_acc {}".format(epoch + 1, epoch_mean_loss, epoch_mean_acc))# save modellogger.info('saving model for epoch {}'.format(epoch + 1))model_path = join(args.save_model_path, 'epoch{}'.format(epoch + 1))if not os.path.exists(model_path):os.mkdir(model_path)model_to_save = model.module if hasattr(model, 'module') else modelmodel_to_save.save_pretrained(model_path)logger.info('epoch {} finished'.format(epoch + 1))epoch_finish_time = datetime.now()logger.info('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time))return epoch_mean_loss

(3)部署与网页预测展示

        app.py既是模型预测文件,又能够在网页中展示,这需要我们下载一个依赖库:

pip install streamlit

        

生成下一个词流程,每次只根据当前位置的前context_len个token进行生成:

  • 第一步,先将输入文本截断成训练的token大小,训练时我们采用的200,截断为后200个词
  • 第二步,预测的下一个token的概率,采用温度采样和topk/topp采样

最终,我们不断的以自回归的方式不断生成预测结果

这里指定模型目录 

进入项目路径

执行streamlit run app.py 

 生成效果:

 数据与代码链接:https://pan.baidu.com/s/1XmurJn3k_VI5OR3JsFJgTQ?pwd=x3ci 
提取码:x3ci 

 

         

      

 

         

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/614694.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数学建模-最优包衣厚度终点判别法-三(Bayes判别分析法和梯度下降算法)

&#x1f49e;&#x1f49e; 前言 hello hello~ &#xff0c;这里是viperrrrrrr~&#x1f496;&#x1f496; &#xff0c;欢迎大家点赞&#x1f973;&#x1f973;关注&#x1f4a5;&#x1f4a5;收藏&#x1f339;&#x1f339;&#x1f339; &#x1f4a5;个人主页&#xff…

OceanMind海睿思入选《2024 中国MarTech行业生态图》

「Morketing研究院」正式发布《2024 中国MarTech行业生态图》&#xff0c;中新赛克海睿思作为国内数据治理优秀厂商&#xff0c;成功入选「数据与分析」板块「数据管理平台」子类&#xff0c;占据Martech领域关键节点。 ◎《2024中国MarTech行业生态图》 关于MarTech生态图 《…

毅速ESU丨增材制造有助于传统制造企业打造新增长极

在科技浪潮的推动下&#xff0c;传统制造企业正面临着前所未有的挑战与机遇。产品的复杂程度不断提升&#xff0c;个性化需求层出不穷&#xff0c;越来越短的生产周期&#xff0c;不断升级的品质要求等&#xff0c;传统的生产模式在应对这些变化并不容易。而增材制造&#xff0…

图像入门处理4(How to get the scaling ratio between different kinds of images)

just prepare for images fusion and registration ! attachments for some people who need link1 图像处理入门 3

bilibili PC客户端架构设计——基于Electron

众所周知&#xff0c;bilibili是个学习的网站&#xff0c;网页端和粉版移动端都非常的好用&#xff0c;不过&#xff0c;相对其它平台来说bilibili的PC客户端也算是大器晚成了。在有些场景PC客户端的优势也是显而易见的&#xff0c;比如&#xff0c;跓留电脑桌面的快捷、独立的…

鱼眼摄像头畸变校正方法概述

1. 摘要 鱼眼摄像头以其独特的广阔视场和其他特点&#xff0c;在各个领域得到了广泛应用。然而&#xff0c;与针孔相机相比&#xff0c;鱼眼摄像头存在显著的畸变&#xff0c;导致拍摄的图像失畸变严重。鱼眼摄像头畸变是数字图像处理中常见的问题&#xff0c;需要有效的校正技…

King’s LIMS类器官版

King’s LIMS服务检测行业17年&#xff0c;为行业用户提供专业的系统及服务。King’s LIMS类器官版是患者、医生、实验室三位一体的数字化、流程化、标准化的医学实验室检测管理平台。旨在助力临床药物筛选、验证。系统通过全业务流程电子化管理实现患者/医生报告出具的全过程信…

《师兄啊师兄》:玄机科技打造国漫新高峰,IP运营再显神力

在这个国漫蓬勃发展的时代&#xff0c;玄机科技再次以其超凡的制作水准和出色的IP运营能力&#xff0c;为我们带来了一部国漫新经典——《师兄啊师兄》。这部作品不仅在画面、剧情上达到了行业新高度&#xff0c;更在IP运营上展现出了其强大的实力与前瞻性。 《师兄啊师兄》的画…

从数据中台到上层应用全景架构示例

一、前言 对于大型企业而言&#xff0c;数据已经成为基本的生产资料&#xff0c;但是有很多公司还是值关心上层应用&#xff0c;而忽略了数据的治理&#xff0c;从而并不能很好的发挥公司的数据资产效益。比如博主自己是做后端的&#xff0c;主要是做应用层&#xff0c;也就是…

Ubuntu无网络标识的解决方法

1.出现的情况的特点 2.解决办法 2.1 进入root并输入密码 sudo su 2.2 更新NetworkManager的配置 得先有gedit或者vim&#xff0c;两个随意一个&#xff0c;这里用的gedit&#xff0c;没有就先弄gedit&#xff0c;有的话直接下一步 apt-get install gedit 或者vim apt-get ins…

dfs板子

递归实现排列 留着明早省赛之前看 #include<iostream> using namespace std; int arr[10010]; int brr[10010]; int n,k; void dfs(int num){if(num > n){for(int i 1;i < n;i){cout << arr[i] << " ";}cout << endl;return;}for(in…

【c语言】atoi函数---使用和模拟实现(详解)

atoi函数---使用和模拟实现 atoi函数在Cplusplus中的定义 atoi函数的使用 #include <stdio.h> #include <stdlib.h>int main() {char arr[20] "4831213";int ret 0;ret atoi(arr);printf("arr:%s\n", arr);printf("ret:%d\n", re…