Transformers实战——Datasets板块

文章目录

  • 一、基本使用
    • 1.加载在线数据集
    • 2.加载数据集合集中的某一项任务
    • 3.按照数据集划分进行加载
    • 4.查看数据集
      • 查看一条数据集
      • 查看多条数据集
      • 查看数据集里面的某个字段
      • 查看所有的列
      • 查看所有特征
    • 5.数据集划分
    • 6.数据选取与过滤
    • 7.数据映射
    • 8.保存与加载
  • 二、加载本地数据集
    • 1.直接加载文件作为数据集
    • 2.加载文件夹内全部文件作为数据集
    • 3.通过预先加载的其他格式转换加载数据集
    • 4.Dataset with DataCollator

!pip install datasets
from datasets import load_dataset

一、基本使用

1.加载在线数据集

datasets = load_dataset("madao33/new-title-chinese")
datasets
'''
DatasetDict({train: Dataset({features: ['title', 'content'],num_rows: 5850})validation: Dataset({features: ['title', 'content'],num_rows: 1679})
})
'''

2.加载数据集合集中的某一项任务

boolq_dataset = load_dataset("super_glue", "boolq")
boolq_dataset
'''
DatasetDict({train: Dataset({features: ['question', 'passage', 'idx', 'label'],num_rows: 9427})validation: Dataset({features: ['question', 'passage', 'idx', 'label'],num_rows: 3270})test: Dataset({features: ['question', 'passage', 'idx', 'label'],num_rows: 3245})
})
'''

3.按照数据集划分进行加载

dataset = load_dataset("madao33/new-title-chinese", split="train")
dataset
'''
Dataset({features: ['title', 'content'],num_rows: 5850
})
'''
  • 可以取切片
dataset = load_dataset("madao33/new-title-chinese", split="train[10:100]")
dataset
'''
Dataset({features: ['title', 'content'],num_rows: 90
})
'''
  • 可以取百分比
dataset = load_dataset("madao33/new-title-chinese", split="train[:50%]")
dataset
'''
Dataset({features: ['title', 'content'],num_rows: 2925
})
'''

  • 可以取多个
dataset = load_dataset("madao33/new-title-chinese", split=["train[:50%]", "train[50%:]"])
dataset
'''
[Dataset({features: ['title', 'content'],num_rows: 2925}),Dataset({features: ['title', 'content'],num_rows: 2925})]
'''
boolq_dataset = load_dataset("super_glue", "boolq", split=["train[:50%]", "train[50%:]"])
boolq_dataset
'''
[Dataset({features: ['question', 'passage', 'idx', 'label'],num_rows: 4714}),Dataset({features: ['question', 'passage', 'idx', 'label'],num_rows: 4713})]
'''

4.查看数据集

datasets = load_dataset("madao33/new-title-chinese")
datasets
'''
DatasetDict({train: Dataset({features: ['title', 'content'],num_rows: 5850})validation: Dataset({features: ['title', 'content'],num_rows: 1679})
})
'''

查看一条数据集

datasets["train"][0]

查看多条数据集

datasets["train"][:2]
'''

查看数据集里面的某个字段

datasets["train"]["title"][:5]# 或者这样也可以
datasets["train"][:5]['title']

查看所有的列

datasets["train"].column_names
'''
['title', 'content']
'''

查看所有特征

datasets["train"].features
'''
{'title': Value(dtype='string', id=None),'content': Value(dtype='string', id=None)}
'''

5.数据集划分

dataset = datasets["train"]
dataset.train_test_split(test_size=0.1, seed=3407)
'''
DatasetDict({train: Dataset({features: ['title', 'content'],num_rows: 5265})test: Dataset({features: ['title', 'content'],num_rows: 585})
})
'''
  • 分类数据集可以按照比例划分(分布均衡),即单看某一个类别所占的比例 train和test中应该是一样的 比如0类在train中占0.3,那test中0类占比也是 0.3
dataset = boolq_dataset["train"]
dataset.train_test_split(test_size=0.1, stratify_by_column="label") 
'''
DatasetDict({train: Dataset({features: ['question', 'passage', 'idx', 'label'],num_rows: 8484})test: Dataset({features: ['question', 'passage', 'idx', 'label'],num_rows: 943})
})
'''# 解释分布均衡
dataset['train']['label'].count(1) / len(dataset['train'])
'''
0.6230551626591231
'''
dataset['test']['label'].count(1) / len(dataset['test'])
'''
0.6230551626591231
'''

6.数据选取与过滤

# 选取
datasets["train"].select([0, 1])
'''
Dataset({features: ['title', 'content'],num_rows: 2
})
'''
# 过滤
filter_dataset = datasets["train"].filter(lambda example: "中国" in example["title"])
filter_dataset["title"][:5]

7.数据映射

def add_prefix(example):example["title"] = 'Prefix: ' + example["title"]return example
prefix_dataset = datasets.map(add_prefix)
prefix_dataset["train"][:10]["title"]

from transformers import AutoTokenizertokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")def preprocess_function(example):model_inputs = tokenizer(example["content"], max_length=512, truncation=True)labels = tokenizer(example["title"], max_length=32, truncation=True)# label就是title编码的结果model_inputs["labels"] = labels["input_ids"]return model_inputs
processed_datasets = datasets.map(preprocess_function)
processed_datasets
'''
DatasetDict({train: Dataset({features: ['title', 'content', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],num_rows: 5850})validation: Dataset({features: ['title', 'content', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],num_rows: 1679})
})
'''
  • 使用多进程
    • 注意需要多一个参数 tokenizer=tokenizer
def preprocess_function(example, tokenizer=tokenizer):model_inputs = tokenizer(example["content"], max_length=512, truncation=True)labels = tokenizer(example["title"], max_length=32, truncation=True)# label就是title编码的结果model_inputs["labels"] = labels["input_ids"]return model_inputsprocessed_datasets = datasets.map(preprocess_function, num_proc=4)
processed_datasets
'''
DatasetDict({train: Dataset({features: ['title', 'content', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],num_rows: 5850})validation: Dataset({features: ['title', 'content', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],num_rows: 1679})
})
'''
  • TokenizerFast 可以使用 batched=True加速映射过程
processed_datasets = datasets.map(preprocess_function, batched=True)
processed_datasets
'''
DatasetDict({train: Dataset({features: ['title', 'content', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],num_rows: 5850})validation: Dataset({features: ['title', 'content', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],num_rows: 1679})
})
'''

  • 删除多余列
processed_datasets = datasets.map(preprocess_function, batched=True, remove_columns=datasets["train"].column_names)
processed_datasets
'''
DatasetDict({train: Dataset({features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],num_rows: 5850})validation: Dataset({features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],num_rows: 1679})
})
'''

8.保存与加载

processed_datasets.save_to_disk("./processed_data")

image.png


from datasets import load_from_diskprocessed_datasets = load_from_disk("./processed_data")
processed_datasets
'''
DatasetDict({train: Dataset({features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],num_rows: 5850})validation: Dataset({features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],num_rows: 1679})
})
'''

二、加载本地数据集

1.直接加载文件作为数据集

  • 这里加 split="train"是因为加载本地数据集会默认将数据集视为 train
dataset = load_dataset("csv", data_files="./ChnSentiCorp_htl_all.csv", split="train")
dataset
'''
Dataset({features: ['label', 'review'],num_rows: 7766
})
'''
  • 也可以用 Dataset
from datasets import Datasetdataset = Dataset.from_csv("./ChnSentiCorp_htl_all.csv")
dataset
'''
Dataset({features: ['label', 'review'],num_rows: 7766
})
'''

2.加载文件夹内全部文件作为数据集

image.png

dataset = load_dataset("csv", data_dir="/content/all_data", split='train')
dataset
'''
Dataset({features: ['label', 'review'],num_rows: 15532
})
'''
  • 指定加载文件夹中哪些文件
dataset = load_dataset("csv",data_files=['/content/all_data/ChnSentiCorp_htl_all.csv','/content/all_data/ChnSentiCorp_htl_all2.csv'], split='train')
dataset
'''
Dataset({features: ['label', 'review'],num_rows: 15532
})
'''

  • cache_dir:构建的数据集缓存目录,方便下次快速加载
dataset = load_dataset("csv", data_files=['/content/all_data/ChnSentiCorp_htl_all.csv','/content/all_data/ChnSentiCorp_htl_all2.csv'], split='train',cache_dir='dir')
dataset
'''
Dataset({features: ['label', 'review'],num_rows: 15532
})
'''

image.png


3.通过预先加载的其他格式转换加载数据集

import pandas as pddata = pd.read_csv("./ChnSentiCorp_htl_all.csv")
data.head()
dataset = Dataset.from_pandas(data)
dataset
'''
Dataset({features: ['label', 'review'],num_rows: 7766
})
'''
  • List格式的数据需要内嵌{},明确数据字段
# List格式的数据需要内嵌{},明确数据字段
data = [{"text": "abc"}, {"text": "def"}]
# data = ["abc", "def"] # 报错
Dataset.from_list(data)
'''
Dataset({features: ['text'],num_rows: 2
})
'''

4.Dataset with DataCollator

from transformers import  DataCollatorWithPadding
dataset = load_dataset("csv", data_files="./ChnSentiCorp_htl_all.csv", split='train')
dataset = dataset.filter(lambda x: x["review"] is not None)
dataset
'''
Dataset({features: ['label', 'review'],num_rows: 7765
})
'''
def process_function(examples):tokenized_examples = tokenizer(examples["review"], max_length=128, truncation=True)tokenized_examples["labels"] = examples["label"]return tokenized_examples
tokenized_dataset = dataset.map(process_function, batched=True, remove_columns=dataset.column_names)
tokenized_dataset
'''
Dataset({features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],num_rows: 7765
})
'''
print(tokenized_dataset[:3])
'''
{'input_ids': [[101, 6655, 4895, 2335, 3763, 1062, 6662, 6772, 6818, 117, 852, 3221, 1062, 769, 2900, 4850, 679, 2190, 117, 1963, 3362, 3221, 107, 5918, 7355, 5296, 107, 4638, 6413, 117, 833, 7478, 2382, 7937, 4172, 119, 2456, 6379, 4500, 1166, 4638, 6662, 5296, 119, 2791, 7313, 6772, 711, 5042, 1296, 119, 102], [101, 1555, 1218, 1920, 2414, 2791, 8024, 2791, 7313, 2523, 1920, 8024, 2414, 3300, 100, 2160, 8024, 3146, 860, 2697, 6230, 5307, 3845, 2141, 2669, 679, 7231, 106, 102], [101, 3193, 7623, 1922, 2345, 8024, 3187, 6389, 1343, 1914, 2208, 782, 8024, 6929, 6804, 738, 679, 1217, 7608, 1501, 4638, 511, 6983, 2421, 2418, 6421, 7028, 6228, 671, 678, 6821, 702, 7309, 7579, 749, 511, 2791, 7313, 3315, 6716, 2523, 1962, 511, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 
'labels': [1, 1, 1]}
'''
collator = DataCollatorWithPadding(tokenizer=tokenizer)
  • 动态填充,每个 batch_size长度不一样
from torch.utils.data import DataLoaderdl = DataLoader(tokenized_dataset, batch_size=4, collate_fn=collator, shuffle=True)
num = 0
for batch in dl:print(batch["input_ids"].size())num += 1if num > 10:break
'''
torch.Size([4, 128])
torch.Size([4, 128])
torch.Size([4, 128])
torch.Size([4, 115])
torch.Size([4, 128])
torch.Size([4, 117])
torch.Size([4, 128])
torch.Size([4, 128])
torch.Size([4, 128])
torch.Size([4, 128])
torch.Size([4, 127])
'''

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/188586.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

鞋业生产制造用什么ERP软件?能为企业带来哪些好处

鞋服这类商品的种类众多,同时也是我们生活当中较为常见的产品,各个制鞋企业有差异化的营销渠道和经营模式,日常生产过程存在的问题呈现多样化。 有些制鞋企业依然采用传统的管理方式,在这种模式之下,企业并不能随时掌…

一文了解ChatGPT Plus如何完成论文写作和AI绘图

2023年我们进入了AI2.0时代。微软创始人比尔盖茨称ChatGPT的出现有着重大历史意义,不亚于互联网和个人电脑的问世。360创始人周鸿祎认为未来各行各业如果不能搭上这班车,就有可能被淘汰在这个数字化时代,如何能高效地处理文本、文献查阅、PPT…

渗透测试--实战若依ruoyi框架

免责声明: 文章中涉及的漏洞均已修复,敏感信息均已做打码处理,文章仅做经验分享用途,切勿当真,未授权的攻击属于非法行为!文章中敏感信息均已做多层打马处理。传播、利用本文章所提供的信息而造成的任何直…

【超好用的工具库】hutool-all工具库的基本使用

简介(可不看): hutool-all是一个Java工具库,提供了许多实用的工具类和方法,用于简化Java开发过程中的常见任务。它包含了各种模块,涵盖了字符串操作、日期时间处理、加密解密、文件操作、网络通信、图片处…

河北大学选择ZStack Cube超融合一体机打造实训云平台

河北大学通过云轴科技ZStack Cube超融合一体机构建校园实训云平台,部署测试仅耗时1天,该平台能够更快地为学生提供高性能、高可用的云主机、云存储和云网络服务;同时也能满足日常运维管理要求,为学生提供更好的实训环境。 河北省…

【Mysql学习笔记】1 - Mysql入门

一、Mysql5.7安装配置 下载后会得到zip 安装文件解压的路径最好不要有中文和空格这里我解压到 D:\hspmysql\mysql-5.7.19-winx64 目录下 【根据自己的情况来指定目录,尽量选择空间大的盘】 添加环境变量 : 电脑-属性-高级系统设置-环境变量,在Path 环境变量增加mysq…

CMakeLists.txt基础指令与cmake-gui生成VS项目的步骤

简介 本博客主要介绍cmake的基本指令,同时,很多使用Visual Studio小白从Gitbub下载项目源码后,看到CMakeLists.txt,不知道如何使用Visual Studio编译源码;针对以上问题,做一下简单操作与解释,方…

centos7安装keepalived 保证Nginx的高可用

keepalived工作在虚拟路由器冗余协议 VRRP (Virtual Router Redundancy Protocol) 上,它允许一个静态 IP 在两个 Linux 系统之间进行故障转移。 环境准备: 两台虚拟机centos7,IP:192.168.213.4(backup) 192.168.213.6(master) 安…

解决更换NodeJs版本后npm -v返回空白

一、问题描述 win11电脑上输入cmd进入控制台,输入 node --version 有正常返回安装的nodejs的版本号 再输入 npm -v 返回空白。正常情况应该是要返回版本号。 二、问题背景 最近准备学习vue,在不久前已经安装了NodeJs和python。运行了好几个开源项…

SpringSecurity6 | 自动配置(上)

✅作者简介:大家好,我是Leo,热爱Java后端开发者,一个想要与大家共同进步的男人😉😉 🍎个人主页:Leo的博客 💞当前专栏: Java从入门到精通 ✨特色专栏&#xf…

小程序授权获取昵称

wxml: <form bindsubmit"formsubmit"><view style"width: 90%;display: flex;margin-left: 5%;"><view class"text1">昵称&#xff1a;</view><input style"width: 150px;margin-left: 30px;margin-top: 30px;…

CI/CD相关概念学习

文章目录 CI/CD相关概念学习前言CI/CD相关概念介绍集成地狱持续集成持续交付持续部署Devops CI/CD相关应用介绍JenkinsTekton PipelinesSpinnakerTravis CIGoCD CI/CD相关概念学习 前言 本文主要是介绍一些 CI/CD 相关的概念&#xff0c;通过阅读本文你将快速了解 CI/CD 是什么…