使用 T5 Transformer 进行多任务处理的指南

T5 (Text-to-Text Transfer Transformer) 模型是为探索迁移学习的局限性而进行的一项大规模研究(论文)的产物。它建立在 GPT、BERT 和 RoBERTa(仅举几例)模型等流行的架构之上,这些模型利用迁移学习取得了令人难以置信的成功。虽然类似 BERT 的模型可以进行微调以执行各种任务,但架构的约束意味着每个模型只能执行一项任务。

通常,这是通过在 Transformer 模型上添加特定于任务的层来完成的。例如,可以通过添加一个具有两个输出神经元(对应于每个类)的全连接层来适应二进制分类。T5 模型背离了这一传统,将所有 NLP 任务重新定义为文本转文本任务。这会导致任何 NLP 任务的共享框架作为模型的输入,并且模型的输出始终是一个字符串。在二元分类的例子中,T5 模型将简单地输出类的字符串表示形式(即 或 )。"0""1"

由于任何 NLP 任务的输入和输出格式都是相同的,因此可以教同一个 T5 模型执行多个任务!要指定应该执行哪个任务,我们可以简单地在模型的输入前面加上一个前缀(字符串)。Google AI 博客文章中的动画(如下所示)演示了这一概念。

 

摘自文章 Exploring Transfer Learning with T5: the Text-to-Text Transfer Transformer

在本文中,我们将使用这种技术来训练能够执行 3 个 NLP 任务、二元分类、多标签分类和回归的单个 T5 模型。

所有代码也可以在 Github 上找到

任务规范

二元分类

NLP 中二元分类的目标是将给定的文本序列分类为两类之一。在我们的任务中,我们将使用 Yelp Reviews 数据集将文本的情绪分类为正面 ( ) 或负面 ( )。"1""0"

多标签分类

在多标签分类中,应使用一组预定义标签的正确子集来标记给定的文本序列(请注意,该子集可以同时包括 null 集和完整标签集本身)。为此,我们将使用 Toxic Comments 数据集,其中每个文本都可以使用标签的任何子集进行标记。toxic, severe_toxic, obscene, threat, insult, identity_hate

回归

在回归任务中,目标变量是一个连续值。在我们的任务中,我们将使用 STS-B (Semantic Textual Similarity Benchmark) 数据集,目标是预测两个句子的相似性。相似性由 和 之间的连续值表示。05

数据准备

由于我们将使用 3 个数据集,因此我们将它们放在目录内的 3 个单独的子目录中。data

  • data/binary_classification
  • data/multilabel_classification
  • data/regression

下载

  1. 下载 Yelp 评论数据集。
  2. extract 和 to .train.csvtest.csvdata/binary_classification
  3. 下载 Toxic Comments 数据集。
  4. 将文件解压到 。csvdata/multilabel_classification
  5. 下载 STS-B 数据集。
  6. 将文件解压到 。csvdata/regression

合并数据集

如前所述,T5 模型的输入和输出始终是文本。通过使用前缀文本指定特定任务,该文本让模型知道它应该如何处理输入。

Simple Transformers 中 T5 模型的 input 数据格式反映了这一事实。输入是一个 Pandas 数据帧,其中包含 3 列 — 、 和 。这使得在多个任务上训练模型变得非常容易,因为您只需要将 .prefixinput_texttarget_textprefix

import pandas as pd
import json
from sklearn.model_selection import train_test_split
prefix = 'data/binary_classification/'binary_train_df = pd.read_csv(prefix + 'train.csv', header=None)
binary_train_df.head()binary_eval_df = pd.read_csv(prefix + 'test.csv', header=None)
binary_eval_df.head()binary_train_df[0] = (binary_train_df[0] == 2).astype(int)
binary_eval_df[0] = (binary_eval_df[0] == 2).astype(int)binary_train_df = pd.DataFrame({'prefix': ["binary classification" for i in range(len(binary_train_df))],'input_text': binary_train_df[1].str.replace('\n', ' '),'target_text': binary_train_df[0].astype(str),
})print(binary_train_df.head())binary_eval_df = pd.DataFrame({'prefix': ["binary classification" for i in range(len(binary_eval_df))],'input_text': binary_eval_df[1].str.replace('\n', ' '),'target_text': binary_eval_df[0].astype(str),
})print(binary_eval_df.head())
prefix                                         input_text  \
0  binary classification  Unfortunately, the frustration of being Dr. Go...   
1  binary classification  Been going to Dr. Goldberg for over 10 years. ...   
2  binary classification  I don't know what Dr. Goldberg was like before...   
3  binary classification  I'm writing this review to give you a heads up...   
4  binary classification  All the food is great here. But the best thing...   target_text  
0           0  
1           1  
2           0  
3           0  
4           1  prefix                                         input_text  \
0  binary classification  Contrary to other reviews, I have zero complai...   
1  binary classification  Last summer I had an appointment to get new ti...   
2  binary classification  Friendly staff, same starbucks fair you get an...   
3  binary classification  The food is good. Unfortunately the service is...   
4  binary classification  Even when we didn't have a car Filene's Baseme...   target_text  
0           1  
1           0  
2           1  
3           0  
4           1  
prefix = "data/multilabel_classification/"multi_train_df = pd.read_csv(prefix + 'train.csv')
multi_train_df["comment_text"].str.replace('\n', ' ').str.replace('\t', ' ')for col in multi_train_df.columns:if col not in ["id", "comment_text"]:multi_train_df[col] = multi_train_df[col].apply(lambda x: col if x else "")multi_train_df["target_text"] = multi_train_df['toxic'].str.cat(multi_train_df[[col for col in multi_train_df.columns if col not in ["id", "comment_text", "toxic"]]], sep=',')
multi_train_df["target_text"] = multi_train_df["target_text"].apply(lambda x: ",".join(word for word in x.split(",") if word)).apply(lambda x: x if x else "clean")
multi_train_df["input_text"] = multi_train_df["comment_text"].str.replace('\n', ' ')
multi_train_df["prefix"] = "multilabel classification"
multi_train_df = multi_train_df[["prefix", "input_text", "target_text"]]multi_train_df, multi_eval_df = train_test_split(multi_train_df, test_size=0.1)multi_train_df.head()
prefix    input_text    target_text
140162    multilabel classification    ban you got me banned on irc -    clean
135151    multilabel classification    This is a public computer Hi, I have a sligh...    clean
4901    multilabel classification    Why does nobody post anything on 'my talk' tha...    clean
58298    multilabel classification    Okay sorry I didn't read the article for a while.    clean
56472    multilabel classification    If you really feel that strongly about protect...    clean
prefix = 'data/regression/'sts_train_df = pd.read_csv(prefix + 'train.tsv', sep='\t', error_bad_lines=False).dropna()
sts_eval_df = pd.read_csv(prefix + 'dev.tsv', sep='\t', error_bad_lines=False).dropna()sts_train_df["sentence1"] = sts_train_df["sentence1"].str.replace('\n', ' ').str.replace('\t', ' ')
sts_train_df["sentence2"] = sts_train_df["sentence2"].str.replace('\n', ' ').str.replace('\t', ' ')
sts_eval_df["sentence1"] = sts_eval_df["sentence1"].str.replace('\n', ' ').str.replace('\t', ' ')
sts_eval_df["sentence2"] = sts_eval_df["sentence2"].str.replace('\n', ' ').str.replace('\t', ' ')
b'Skipping line 2509: expected 10 fields, saw 11\nSkipping line 2650: expected 10 fields, saw 11\nSkipping line 2727: expected 10 fields, saw 11\nSkipping line 3071: expected 10 fields, saw 11\nSkipping line 3393: expected 10 fields, saw 11\n'
b'Skipping line 1042: expected 10 fields, saw 11\nSkipping line 1066: expected 10 fields, saw 11\nSkipping line 1083: expected 10 fields, saw 11\nSkipping line 1137: expected 10 fields, saw 11\nSkipping line 1150: expected 10 fields, saw 11\n'
sts_train_df.drop(2001, inplace=True) # This line is badly formatted. Getting rid.
sts_train_df["input_text"] = sts_train_df.apply(lambda x: "sentence1: " + x["sentence1"] + " sentence2: " + x["sentence2"], axis=1)
sts_eval_df["input_text"] = sts_eval_df.apply(lambda x: "sentence1: " + x["sentence1"] + " sentence2: " + x["sentence2"], axis=1)sts_train_df["target_text"] = sts_train_df["score"].apply(lambda x: round(x * 5) / 5).astype(str)
sts_eval_df["target_text"] = sts_eval_df["score"].apply(lambda x: round(x * 5) / 5).astype(str)sts_train_df["prefix"] = "similarity"
sts_eval_df["prefix"] = "similarity"sts_train_df = sts_train_df[["prefix", "input_text", "target_text"]]
sts_eval_df = sts_eval_df[["prefix", "input_text", "target_text"]]
train_df = pd.concat([binary_train_df, multi_train_df, sts_train_df]).astype(str)
eval_df = pd.concat([binary_eval_df, multi_eval_df, sts_eval_df]).astype(str)
train_df.to_csv("data/train.tsv", "\t")
eval_df.to_csv("data/eval.tsv", "\t")

 

上面的笔记本加载每个数据集,针对 T5 对其进行预处理,最后将它们组合成一个统一的 DataFrame。

这为我们提供了一个具有 3 个唯一前缀的 DataFrame,即 、 和 。请注意,前缀本身是相当任意的,重要的是确保每个任务都有自己唯一的前缀。模型的输入将采用以下格式:binary classificationmultilabel classificationsimilarity

<prefix>: <input_text>

训练 时会自动添加。": "

其他一些需要注意的事项:

  • 多标签分类任务的输出是预测标签 () 的逗号分隔列表。如果未预测到标签,则输出应为 。toxic, severe_toxic, obscene, threat, insult, identity_hateclean
  • for the similarity 任务包括两个句子,如以下示例所示;input_text
    sentence1: A man plays the guitar. sentence2: The man sang and played his guitar.
  • 相似性任务的输出是一个介于 0.0 和 5.0 之间的数字(作为字符串),增量为 0.2。(例如 , , , )。这遵循 T5 论文作者使用的相同格式。0.00.43.05.0

从不同输入和输出的表示方式中可以看出,T5 模型的文本到文本方法在表示各种任务和我们可以执行的实际任务方面都为我们提供了极大的灵活性。

例如;

提出正确的问题:在新任务中训练 T5 Transformer 模型

T5 Transformer 将任何 NLP 任务框定为文本转文本任务,使其能够轻松学习新任务。让我们教...

towardsdatascience.com

 

唯一的限制是想象力!(嗯,想象力和计算资源,但那是另一回事了) 😅

回到数据,运行 Notebook 应该会给你一个 和一个文件,我们将在下一节中使用它来训练我们的模型!train.tsveval.tsv

设置

我们将使用 Simple Transformers 库(基于 Hugging Face Transformers)来训练 T5 模型。

下面给出的说明将安装所有要求。

  1. 从此处安装 Anaconda 或 Miniconda Package Manager。
  2. 创建新的虚拟环境并安装软件包。
    conda create -n simpletransformers python
    conda activate simpletransformers
    conda install pytorch cudatoolkit=10.1 -c pytorch
  3. 安装 simpletransformers。
    pip install simpletransformers

查看安装文档

训练 T5 模型

与往常一样,使用 Simple Transformers 训练模型非常简单。

import pandas as pd
from simpletransformers.t5 import T5Modeltrain_df = pd.read_csv("data/train.tsv", sep="\t").astype(str)
eval_df = pd.read_csv("data/eval.tsv", sep="\t").astype(str)model_args = {"max_seq_length": 196,"train_batch_size": 16,"eval_batch_size": 64,"num_train_epochs": 1,"evaluate_during_training": True,"evaluate_during_training_steps": 15000,"evaluate_during_training_verbose": True,"use_multiprocessing": False,"fp16": False,"save_steps": -1,"save_eval_checkpoints": False,"save_model_every_epoch": False,"reprocess_input_data": True,"overwrite_output_dir": True,"wandb_project": "T5 mixed tasks - Binary, Multi-Label, Regression",
}model = T5Model("t5", "t5-base", args=model_args)model.train_model(train_df, eval_data=eval_df)

 

这里使用的大多数参数都是相当标准的。

  • max_seq_length:选择此项后,大多数样本不会被截断。增加序列长度会显著影响模型的内存消耗,因此通常最好使其尽可能短(理想情况下不要截断输入序列)。
  • train_batch_size:越大越好(只要它适合您的 GPU)
  • eval_batch_size: 与train_batch_size
  • num_train_epochs:训练超过 1 个 epoch 可能会提高模型的性能,但显然也会增加训练时间(在 RTX Titan 上每个 epoch 大约 7 小时)。
  • evaluate_during_training:我们将定期根据测试数据测试模型,以了解它的学习情况。
  • evaluate_during_training_steps:上述测试模型的时间段。
  • evaluate_during_training_verbose:在测试完成后向我们显示结果。
  • use_multiprocessing:使用 multiprocessing 显著减少了分词化(在训练开始之前完成)所需的时间,但是,这目前会导致 T5 实施出现问题。所以,现在没有多处理。😢
  • fp16:FP16 或混合精度训练减少了训练模型的内存消耗(意味着可以进行更大的批处理)。不幸的是,目前 T5 的训练不稳定,因此它也被关闭了。fp16
  • save_steps:将此项设置为 this 表示不会保存 checkpoint。-1
  • save_eval_checkpoints:默认情况下,在训练期间执行评估时,将保存模型检查点。由于此 Experiment 仅用于演示,因此我们也不要浪费空间来保存这些 checkpoint。
  • save_model_every_epoch:我们只有 1 个 epoch,所以没有。也不需要这个。
  • reprocess_input_data:控制是从缓存加载特征(保存到磁盘),还是对输入序列再次进行分词化。只有在进行多次运行时,它才真正重要。
  • overwrite_output_dir:如果之前保存的模型位于同一输出目录中,这将覆盖它们。
  • wandb_project:用于可视化训练进度。

说到可视化,你可以在这里查看我的训练进度。感谢W&B的超棒图书馆!

测试 T5 模型

考虑到我们正在处理多个任务,最好使用合适的指标来评估每个任务。考虑到这一点,我们将使用以下指标;

  • 二元分类:F1 分数和准确性分数
  • 多标签分类:F1 分数(Hugging Face SQuAD 指标实施)和精确匹配(Hugging Face SQuAD 指标实施)
  • 相似性:Pearson 相关系数和 Spearman 相关
import json
from datetime import datetime
from pprint import pprint
from statistics import meanimport numpy as np
import pandas as pd
from scipy.stats import pearsonr, spearmanr
from simpletransformers.t5 import T5Model
from sklearn.metrics import accuracy_score, f1_score
from transformers.data.metrics.squad_metrics import compute_exact, compute_f1def f1(truths, preds):return mean([compute_f1(truth, pred) for truth, pred in zip(truths, preds)])def exact(truths, preds):return mean([compute_exact(truth, pred) for truth, pred in zip(truths, preds)])def pearson_corr(preds, labels):return pearsonr(preds, labels)[0]def spearman_corr(preds, labels):return spearmanr(preds, labels)[0]model_args = {"overwrite_output_dir": True,"max_seq_length": 196,"eval_batch_size": 32,"num_train_epochs": 1,"use_multiprocessing": False,"num_beams": None,"do_sample": True,"max_length": 50,"top_k": 50,"top_p": 0.95,"num_return_sequences": 3,
}# Load the trained model
model = T5Model("t5", "outputs", args=model_args)# Load the evaluation data
df = pd.read_csv("data/eval.tsv", sep="\t").astype(str)# Prepare the data for testing
to_predict = [prefix + ": " + str(input_text)for prefix, input_text in zip(df["prefix"].tolist(), df["input_text"].tolist())
]
truth = df["target_text"].tolist()
tasks = df["prefix"].tolist()# Get the model predictions
preds = model.predict(to_predict)# Saving the predictions if needed
with open(f"predictions/predictions_{datetime.now()}.txt", "w") as f:for i, text in enumerate(df["input_text"].tolist()):f.write(str(text) + "\n\n")f.write("Truth:\n")f.write(truth[i] + "\n\n")f.write("Prediction:\n")for pred in preds[i]:f.write(str(pred) + "\n")f.write("________________________________________________________________________________\n")# Taking only the first prediction
preds = [pred[0] for pred in preds]
df["predicted"] = preds# Evaluating the tasks separately
output_dict = {"binary classification": {"truth": [], "preds": [],},"multilabel classification": {"truth": [], "preds": [],},"similarity": {"truth": [], "preds": [],},
}results_dict = {}for task, truth_value, pred in zip(tasks, truth, preds):output_dict[task]["truth"].append(truth_value)output_dict[task]["preds"].append(pred)print("-----------------------------------")
print("Results: ")
for task, outputs in output_dict.items():if task == "multilabel classification":try:task_truth = output_dict[task]["truth"]task_preds = output_dict[task]["preds"]results_dict[task] = {"F1 Score": f1(task_truth, task_preds),"Exact matches": exact(task_truth, task_preds),}print(f"Scores for {task}:")print(f"F1 score: {f1(task_truth, task_preds)}")print(f"Exact matches: {exact(task_truth, task_preds)}")print()except:passelif task == "binary classification":try:task_truth = [int(t) for t in output_dict[task]["truth"]]task_preds = [int(p) for p in output_dict[task]["preds"]]results_dict[task] = {"F1 Score": f1_score(task_truth, task_preds),"Accuracy Score": accuracy_score(task_truth, task_preds),}print(f"Scores for {task}:")print(f"F1 score: {results_dict[task]['F1 Score']}")print(f"Accuracy Score: {results_dict[task]['Accuracy Score']}")print()except:passif task == "similarity":task_truth = [float(t) for t in output_dict[task]["truth"]]task_preds = [float(p) for p in output_dict[task]["preds"]]results_dict[task] = {"Pearson Correlation": pearson_corr(task_truth, task_preds),"Spearman Correlation": spearman_corr(task_truth, task_preds),}print(f"Scores for {task}:")print(f"Pearson Correlation: {results_dict[task]['Pearson Correlation']}")print(f"Spearman Correlation: {results_dict[task]['Spearman Correlation']}")print()with open(f"results/result_{datetime.now()}.json", "w") as f:json.dump(results_dict, f)

 

请注意,在准备数据时,会在 和 之间插入 a。这在训练时自动完成,但需要手动处理以进行预测。": “prefixinput_text

如果您想了解更多关于解码参数 (, ), 的信息,请参考 这篇文章num_beamsdo_samplemax_lengthtop_ktop_p

是时候看看我们的模型表现如何了!

-----------------------------------
Results:
Scores for binary classification:
F1 score: 0.96044512420231
Accuracy Score: 0.9605263157894737Scores for multilabel classification:
F1 score: 0.923048001002632
Exact matches: 0.923048001002632Scores for similarity:
Pearson Correlation: 0.8673017763553101
Spearman Correlation: 0.8644328787107548

该模型在每个任务上都表现得相当不错,尽管在 3 个单独的任务上进行了训练!在下一节中,我们将快速浏览一下如何尝试进一步提高模型的性能。

结束语

可能的改进

混合任务时出现的一个潜在问题是用于每个任务的数据集大小之间的差异。我们可以通过查看训练样本计数在数据集中看到这个问题。

binary classification        560000
multilabel classification 143613
similarity 5702

数据集基本上不平衡,任务的困境似乎特别可怕!这可以从任务落后于其他任务的评估分数中清楚地看出(尽管需要注意的是,我们查看的任务之间的指标并不相同)。similaritysimilarity

此问题的一种可能补救措施是对任务进行过度采样,以便 model.similarity

除此之外,增加训练 epoch 的数量(并优化其他超参数)也可能改进模型。

最后,调整解码参数也可以带来更好的结果。

结束语

T5 模型的文本转文本格式为将 Transformer 和 NLP 应用于各种任务铺平了道路,几乎不需要自定义。即使使用相同的模型执行多项任务,T5 模型的性能也很强大!

希望这将在不久的将来带来许多创新应用。

引用

  1. 使用统一的文本转文本转换器探索迁移学习的极限 — https://arxiv.org/abs/1910.10683
  2. Google AI 博客 — https://ai.googleblog.com/2020/02/exploring-transfer-learning-with-t5.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/787634.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SQL备忘记(一)

一前言 环境:win10 mysql 5.7.32 记录一些sql中平时容易弄错的或不明白一些知识点 二 正文 1 select语句执行顺序 FROM→WHERE→GROUP BY→HAVING→SELECT→ORDER BY --一个大概的执行顺序,具体执行顺序根据数据库管理系统S的不同而不同 如下成绩表score如上,可以看出,avg(…

Linux C++ 开发5 - 一文了解CMake构建

1. 什么是CMake?1.1. CMake的定义 1.2. CMake有哪些优势? 1.3. CMake 的特点 1.4. Cmake 、CMakeLists.txt 、Make 、Makefile 之间的关系2. 应用案例2.1. 项目概述 2.2. CMakeLists.txt2.2.1. 基本用法 2.2.2. 完整内容 2.2.3. 构建执行上一篇《Linux C++ 开发4 - 入门makef…

BLE 广播报文格式

广播报文结构 一个完整的BLE广播报文由四部分组成,分别是前导码、接入地址、协议数据单元和CRC校验码。Preamble 前导 Access address(接入设备) PDU CRC校验1 Bytes 4 Bytes 2-37 Bytes 3 Bytes前导码:用来同步时序,可以是0x55或者0xAA,由接入地址的第一个比特决定。如果接…

[JLOI2015] 骗我呢——一类经典反射容斥

加载解析界面 数字变化跳跃反射容斥 一层反射:有一条线 \(y=x+b\) 不能碰到。 从第一次碰到直线开始,将后面的部分沿直线翻折,最终一定会到达 \((n-b,n+b)\),因为 \(b\ne 0\),所以构成双射。答案即为 \(\binom{2n}{n}-\binom{2n}{n-b}\)。 注意,如果最终到达的位置是 \((…

南沙区信息学奥林匹克竞赛(信奥赛)介绍

​信息学奥林匹克竞赛(International Olympiad in Informatics,IOI)是一项旨在选拔和培养信息技术和计算机科学人才的国际性竞赛。该竞赛始于1989年,每年举办一次,由不同的国家轮流承办。参加比赛的选手来自全球各国,都是信息技术和计算机科学领域的尖子生。信息学奥林匹…

英文单词字母大小写在线转换工具html代码

这是一个简单而实用的在线大小写转换工具。它允许用户输入任意文本,并提供三种转换选项:转换为全大写、全小写或首字母大写。 使用这个工具非常简单快捷。用户只需要在输入框中输入想要转换的文本,选择合适的转换类型,然后点击"转换"按钮即可。转换结果会立即显示在输…

TCP的调试助手开发笔记

动图:1 先利用VS自带的socket类来写好TCP_CORE: 类目录如下:点击查看代码 using System; using System.Collections.Generic; using System.Diagnostics; using System.Linq; using System.Net; using System.Net.Sockets; using System.Text; using System.Text.RegularExp…

Why Transformers Need Adam: A Hessian Perspective

目录概符号说明所有参数的 Hessian 矩阵Block-wise Hessian代码Zhang Y., Chen C., Ding T., Li Z., Sun R. and Luo Z. Why transformers need adam: a hessian perspective. arXiv preprint, 2024.概 本文从 Hessian 矩阵的角度回答为什么 Adam 相较于其它方法, 比如 SGD 在 …

VL24 边沿检测

这个就是需要对a 进行打一拍last_a<=a; 需要理解的点是打一拍的last_a是落后a一个时钟周期的,也就是对当前时刻使用a时候,此时的last_a是a的上一时刻的值。`timescale 1ns/1ns module edge_detect(input clk,input rst_n,input a,output reg rise,output reg down ); reg …

RE入门第三天---TEA算法

OK,老规矩,先复习一下昨天的内容 ..... 几分钟就复习了,直接开干今天的内容 先找大佬的wp 来源: TEA系列加密解密 | Gruges Blog (g2uge.github.io) 逆向算法之TEA算法 - Sk2rw - 博客园 (cnblogs.com) 一.TEA加密解密简介 在密码学中,微型加密算法(Tiny Encryption Algo…

vue3 控制el-dialog 双向绑定显示隐藏

父组件<Contact v-model:isView="isView" /> 子组件<template><div><el-dialogwidth="400"title="微信二维码":model-value="props.isView"@closed="handleClose"><div class="dialog-div…