【Python】科研代码学习:七 TrainingArguments,Trainer

【Python】科研代码学习:七 TrainingArguments,Trainer

  • TrainingArguments
    • 重要的方法
  • Trainer
    • 重要的方法
    • 使用 Trainer 的简单例子

TrainingArguments

  • HF官网API:Training
    众所周知,推理是一个大头,训练是另一个大头
    之前的很多内容,都是为训练这里做了一个小铺垫
    如何快速有效地调用代码,训练大模型,才是重中之重(不然学那么多HF库感觉怪吃苦的)
  • 首先看训练参数,再看训练器吧。
    首先,它的头文件是 transformers.TrainingArguments
    再看它源码的参数,我勒个去,太多了吧。
    ※ 我这里挑重要的讲解,全部请看API去。
  • output_dir (str)设置模型输出预测,或者中继点 (checkpoints) 的输出目录。模型训练到一半,肯定需要有中继点文件的嘛,就相当于游戏存档有很多一样,防止跑一半直接程序炸了,还要从头训练
  • overwrite_output_dir (bool, optional, defaults to False):把这个参数设置成 True,就会覆盖其中 output_dir 中的文档。一般在从中继点继续训练时需要这么用
  • do_train (bool, optional, defaults to False):指明我在做训练集的训练任务
  • do_eval (bool, optional):指明我在做验证集的评估任务
  • do_predict (bool, optional, defaults to False) :指明我在做测试集的预测任务
  • evaluation_strategy :评估策略:训练时不评估 / 每 eval_steps 步评估,或者每 epoch 评估
"no": No evaluation is done during training.
"steps": Evaluation is done (and logged) every eval_steps.
"epoch": Evaluation is done at the end of each epoch.
  • per_device_train_batch_size :训练时每张卡的batch大小,默认为8
    per_device_eval_batch_size :评估时每张卡的batch大小,默认为8
  • learning_rate (float, optional, defaults to 5e-5):学习率,里面使用的是 AdamW optimizer
    其他相应的 AdamW Optimizer 的参数还有:
    weight_decay adam_beta1adam_beta2adam_epsilon
  • num_train_epochs:训练的 epoch 个数,默认为3,可以设置小数。
  • lr_scheduler_type:具体作用要查看 transformers 里的 Scheduler 是干什么用的
  • warmup_ratiowarmup_steps :让一开始的学习率从0逐渐升到 learning_rate 用的
  • logging_dir :设置 logging 输出的文档
    除此之外还有一些和 logging相关的参数:
    logging_strategy ,logging_first_step ,logging_steps ,logging_nan_inf_filter 设置日志的策略
  • 与保存模型中继文件相关的参数:
    save_strategy :不保存中继文件 / 每 epoch 保存 / 每 save_steps 步保存
"no": No save is done during training.
"epoch": Save is done at the end of each epoch.
"steps": Save is done every save_steps.

save_steps :如果是整数,表示多少步保存一次;小数,则是按照总训练步,多少比例之后保存一次
save_total_limit :最多中继文件的保存上限,如果超过上限,会先把最旧的那个中继文件删了再保存新的
save_safetensors :使用 savetensor来存储和加载 tensors,默认为 True
push_to_hub :是否保存到 HF hub

  • use_cpu (bool, optional, defaults to False):是否用 cpu 训练
  • seed (int, optional, defaults to 42) :训练的种子,方便复现和可重复实验
  • data_seed :数据采样的种子
  • 数据精读相关的一些参数:
    FP32、TF32、FP16、BF16、FP8、FP4、NF4、INT8
    bf16 (bool, optional, defaults to False)fp16 (bool, optional, defaults to False)
    tf32 (bool, optional)
  • run_name :展示在 wandb and mlflow logging 中的描述
  • load_best_model_at_end (bool, optional, defaults to False):是否保存效果最好的中继点作为最终模型,与 save_total_limit 有些交互操作
    如果上述设置成 True 的话,考虑 metric_for_best_model ,即如何评估效果最好。默认为 loss 即损失最小
    如果你修改了 metric_for_best_model 的话,考虑 greater_is_better ,即指标越大越好还是越小越好
  • 一些加速相关的参数,貌似都比较麻烦
    fsdp
    fsdp_config
    deepspeed
    accelerator_config
  • optim :设置 optimizer,默认为 adamw_torch
    也可以设置成 adamw_hf, adamw_torch, adamw_torch_fused, adamw_apex_fused, adamw_anyprecision or adafactor.
  • resume_from_checkpoint :传入中继点文件的目录,从中继点继续训练

重要的方法

  • ※ 那我怎么访问或者修改上述参数呢?
    由于这个需要实例化,所以我们需要使用OO的方法修改
    下面讲一下其中重要的方法
  • set_dataloader:设置 dataloader
    在这里插入图片描述
from transformers import TrainingArgumentsargs = TrainingArguments("working_dir")
args = args.set_dataloader(train_batch_size=16, eval_batch_size=64)
args.per_device_train_batch_size
  • 设置 logging 相关的参数
    在这里插入图片描述
  • 设置 optimizer
    在这里插入图片描述
  • 设置保存策略
    在这里插入图片描述
  • 设置训练策略
    在这里插入图片描述
  • 设置评估策略
    在这里插入图片描述
  • 设置测试策略
    在这里插入图片描述

Trainer

  • 终于到大头了。Trainer 是主要用 pt 训练的,主要支持 GPUs (NVIDIA GPUs, AMD GPUs)/ TPUs
  • 看下源码,它要的东西不少,讲下重要参数:
  • model:要么是 transformers.PretrainedModel 类型的,要么是简单的 torch.nn.Module 类型的
  • argsTrainingArguments 类型的训练参数。如果不提供的话,默认使用 output_dir/tmp_trainer 里面的那个训练参数
  • data_collator DataCollator 类型参数,给训练集或验证集做数据分批和预处理用的,如果没有tokenizer默认使用 default_data_collator,否则默认使用 DataCollatorWithPadding (Will default to default_data_collator() if no tokenizer is provided, an instance of DataCollatorWithPadding otherwise.)
  • train_dataset (torch.utils.data.Dataset or torch.utils.data.IterableDataset, optional) :提供训练的数据集,当然也可以是 Datasets 类型的数据
  • eval_dataset :类似的验证集的数据集
  • tokenizer :提供 tokenizer 分词器
  • compute_metrics :验证集使用时候的计算指标,具体得参考 EvalPrediction 类型
  • optimizers :可以提供 Tuple(optimizer, scheduler)。默认使用 AdamW 以及 get_linear_schedule_with_warmup() controlled by args
    在这里插入图片描述

重要的方法

  • compute_loss:设置如何计算损失
    在这里插入图片描述
  • train:设置训练集训练任务,第一个参数可以设置是否从中继点开始训练
    在这里插入图片描述
  • evaluate:设置验证集评估任务,需要提供验证集
    在这里插入图片描述
  • predict:设置测试集任务
    在这里插入图片描述
  • save_model:保存模型参数到 output_dir在这里插入图片描述
  • training_step:设置每一个训练的 step,把一个batch的输入经过了何种操作,得到一个 torch.Tensor
    在这里插入图片描述

使用 Trainer 的简单例子

  • 主要就是加载一些参数,传进去即可
    模型、训练参数、训练集、验证集、计算指标
    调用训练方法 .train()
    最后保存模型即可 .save_model()
from transformers import (Trainer,)
trainer = Trainer(model=model,args=training_args,train_dataset=small_train_dataset,eval_dataset=small_eval_dataset,compute_metrics=compute_metrics,
)trainer.train()
trainer.save_model(outputdir="./xxx")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/528962.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

微信小程序(五十三)修改用户头像与昵称

注释很详细,直接上代码 上一篇 新增内容: 1.外界面个人资料基本模块 2.资料修改界面同步问题实现(细节挺多,考虑了后期转服务器端的方便之处) 源码: app.json {"window": {},"usingCompone…

盘点CSV文件在Excel中打开后乱码问题的两种处理方法

目录 一、CSV文件乱码问题概述 二、修改文件编码格式 1.识别CSV文件编码 2.修改编码格式 3.在Excel中打开修改后的CSV文件 案例 三、利用文本编辑器进行预处理 1.打开CSV文件并检查乱码 2.替换或删除乱码字符 3.保存并导入Excel 案例 四、注意事项 1、识别原始编码…

vulhub中Weblogic WLS Core Components 反序列化命令执行漏洞复现(CVE-2018-2628)

Oracle 2018年4月补丁中,修复了Weblogic Server WLS Core Components中出现的一个反序列化漏洞(CVE-2018-2628),该漏洞通过t3协议触发,可导致未授权的用户在远程服务器执行任意命令。 访问http://your-ip:7001/consol…

CodeReview 规范及实施

优质博文:IT-BLOG-CN 一、为什么需要CodeReview 随着业务压力增大,引发代码质量下降,代码质量的下降导致了开发效率的降低,维护成功高等问题,开发效率下降后又加重了业务压力,最终陷入了死亡三角的内耗之…

腾讯云服务器99元一年厉害了,老用户可以买,续费也是99元

良心腾讯云推出99元一年服务器,新用户和老用户均可以购买,续费不涨价,续费也是99元,配置为轻量2核2G4M、50GB SSD盘、300GB月流量、4M带宽:优惠价格99元一年,续费99元,官方活动页面 txybk.com/g…

Linux/Windows下部署OpenCV环境(Java/SpringBoot/IDEA)

环境 本文基于Linux(CentOS 7)、SpringBoot部署运行OpenCV 4.5.5,并顺带记录Windows/IDEA下如何调试SpringBoot调用OpenCV项目。 Windows下调试 首先我们编写代码,并在Windows/IDEA下调试通过。 下载Windows版安装包&#xff0…

第108讲:Mycat实践指南:枚举分片下的水平分表详解

文章目录 1.枚举分片的概念2.水平分表枚举分片案例2.1.准备测试的表结构2.2.配置Mycat实现枚举分片的水平分表2.2.1.配置Schema配置文件2.2.2.配置Rule分片规则配置文件2.2.3.配置Server配置文件2.2.4.重启Mycat 2.3.写入数据观察水平分表效果 1.枚举分片的概念 枚举分片是根据…

Python 负载测试工具库之locust使用详解

概要 在当今的互联网时代,负载测试是确保应用程序和网站性能稳定的关键步骤之一。Python Locust是一个强大的开源负载测试工具,它可以模拟大量用户并测量应用程序的性能。本文将提供有关Python Locust的全面指南,包括安装和配置、基本概念、性能测试、任务编写、报告生成以…

龙年新征程!凌恩生物2月客户文章累计IF>300

2024年2月,凌恩生物助力客户发表文章47篇(6篇预发表),累计影响因子317.3分,其中包括Advanced Materials、Microbiome、Journal of Hazardous Materials、Small、Molecular Psychiatry、Science of the Total Environme…

MySQL 学习笔记(基础篇 Day3)

「写在前面」 本文为黑马程序员 MySQL 教程的学习笔记。本着自己学习、分享他人的态度,分享学习笔记,希望能对大家有所帮助。推荐先按顺序阅读往期内容: 1. MySQL 学习笔记(基础篇 Day1) 2. MySQL 学习笔记&#xff08…

Elasticsearch 的开源分析可视化工具--Kibana 安装 使用

1、简介 Kibana 是一款 Elasticsearch 的开源分析可视化工具,能够与存储在 Elasticsearch 中的数据进行交互,在开发中能够监控Elasticsearch中的数据,并且能够执行命令操作Elasticsearch中的数据。 2、Kibana 的下载与安装 2.1、Kibana 下载…

NIN网络中的网络

是什么 intro LeNet→AlexNet→VGG→NiN→GoogLeNet→ResNetLeNet→AlexNet→VGG 卷积层模块充分抽取空间特征全连接层输出分类结果AlexNet & VGG 改进在于把两个模块加宽 、加深(加宽指增加通道数,那加深呢?(层数增加叭 Ni…