利用huggingface尝试的第一个简单的文本分类任务

news/2025/1/7 7:39:57/文章来源:https://www.cnblogs.com/hhhhappy/p/18653339

`'''
这是一个简单的文本分类任务,基本的流程步骤还是挺清晰完整的,和之前那个简单的cnn差不多,
用到了transformers包,还需用到huggingface的模型rbt3,
但是好像连接不上``
'''
'''
遇到的问题:
1.导入的包不可用,从AutoModelForTokenClassification换成了AutoModelForSequenceClassification
2.模型加载不出来,然后选择将模型下载到本地进行加载,但是这样的话会非常耗内存,后续考虑设置一个代理
3.数据处理的问题,在处理完数据,进行训练的时候,没有处理好设备选择问题,导致部分数据在cpu,部分数据在gpu
'''

step1 导入相关包

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

step2 加载数据

import pandas as pd

data = pd.read_csv("./ChnSentiCorp_htl_all.csv")
data = data.dropna()

step3 创建Dataset

from torch.utils.data import Dataset

class MyDataset(Dataset):
def init(self):
super().init()
self.data = pd.read_csv("./ChnSentiCorp_htl_all.csv")
self.data = self.data.dropna()

def __getitem__(self, index):return self.data.iloc[index]["review"], self.data.iloc[index]["label"]def __len__(self):return len(self.data)

dataset = MyDataset()

step4 划分数据集

from torch.utils.data import random_split

trainset, validset = random_split(dataset, lengths=[0.9, 0.1])

step5 创建Dataloader

tokenizer = AutoTokenizer.from_pretrained("./hfl/rbt3")

def collate_func(batch): # 自定义的数据合并函数,负责将一个批次(batch)的文本和标签进行处理
texts, labels = [], []
for item in batch:
texts.append(item[0])
labels.append(item[1])
inputs = tokenizer(texts, max_length=128, padding="max_length", truncation=True, return_tensors="pt")

# 确保把所有输入移到正确的设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inputs = {key: value.to(device) for key, value in inputs.items()}
inputs["labels"] = torch.tensor(labels).to(device)  # labels 也移到相同设备
return inputs

from torch.utils.data import DataLoader

trainloader = DataLoader(trainset, batch_size=32, shuffle=True, collate_fn=collate_func)
validloader = DataLoader(validset, batch_size=64, shuffle=False, collate_fn=collate_func)

step6 创建模型及优化器

from torch.optim import Adam

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForSequenceClassification.from_pretrained("./hfl/rbt3").to(device)
optimizer = Adam(model.parameters(), lr=2e-5)

step7 训练与验证

def evaluate():
model.eval() # 设置模型为评估模式
acc_num = 0
# 打印模型设备信息(只在验证开始时打印一次)
print(f"验证开始:\n模型设备: {next(model.parameters()).device}")
with torch.no_grad():
for batch in validloader:
batch = {k: v.to(device) for k, v in batch.items()} # 将数据移到指定设备
# 只需要在第一次 batch 时打印一次输入数据设备信息
if acc_num == 0:
print(f"输入数据设备: {batch['input_ids'].device}")
output = model(**batch)
pred = torch.argmax(output.logits, dim=-1)
acc_num += (pred.long() == batch["labels"].long()).float().sum()
return acc_num / len(validset)

def train(epoch=1, log_step=100):
global_step = 0
for ep in range(epoch):
model.train() # 设置模型为训练模式
# 打印模型和输入数据设备信息(只在每个 epoch 开始时打印一次)
print(f"Epoch {ep} 开始:")
print(f"模型设备: {next(model.parameters()).device}")
for batch in trainloader:
batch = {k: v.to(device) for k, v in batch.items()} # 移动到正确设备
# 只需要在第一次 batch 时打印一次输入数据设备信息
if global_step == 0:
print(f"输入数据设备: {batch['input_ids'].device}")
optimizer.zero_grad()
output = model(**batch)
output.loss.backward()
optimizer.step()
if global_step % log_step == 0:
print(f"ep:{ep}, global_step:{global_step}, loss:{output.loss.item()}")
global_step += 1
acc = evaluate() # 调用evaluate函数
print(f"ep:{ep}, acc:{acc}")

step8 模型训练

train()

step9 模型预测

sen = "我觉得这家酒店真不错,饭很好吃!"
id2_label = {0: "差评", 1: "好评"}
model.eval()

with torch.no_grad():
inputs = tokenizer(sen, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()} # 确保输入在正确的设备
# 只在预测开始时打印一次设备信息
print(f"预测开始:\n模型设备: {next(model.parameters()).device}")
print(f"输入数据设备: {inputs['input_ids'].device}")
logits = model(**inputs).logits
pred = torch.argmax(logits, dim=-1)
print(f"输入:{sen}\n模型预测结果:{id2_label.get(pred.item())}")

更简单的做法

from transformers import pipeline

model.config.id2label = id2_label
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
print(pipe(sen))
`
最终运行的结果是这样的:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/864405.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【嵌入式编程】内存分布

一、内存分布图在操作系统中,内存被组织和管理以支持进程的运行。以下是一些常见的内存分布概念: 【内核空间】:操作系统内核使用的内存区域,用于存储内核代码、数据结构和进程控制块(PCB)。【用户空间】:存储用户的代码。未初始化变量区(.bss):存放未初始化的全局变量…

北京健康证(立水桥地铁站附近)

体检:记得带身份证就可以,最好自带一支笔,用他的笔要排队,我买的是96的,带培训证的。 下面这个表只填:身份证号,手机号,姓名就可以,类别、是否培训以交钱时候前台登记的为准,照片也不采集,直接用身份证上的相片

北京健康证

体检:记得带身份证就可以,我买的是96的,带培训证的。 下面这个表只填:身份证号,手机号,姓名就可以,类别以交钱时候前台登记的为准,照片也不采集,直接用身份证上的相片

anaconda安装与环境配置

一、Anaconda简介 ​ Anaconda 是专门为了方便使用 Python 进行数据科学研究而建立的一组软件包,涵盖了数据科学领域常见的 Python 库,并且自带了专门用来解决软件环境依赖问题的 conda 包管理系统。主要是提供了包管理与环境管理的功能,可以很方便地解决多版本python并存、…

java学习报告

Java学习报告 目录 第一章 初识java与面向对象程序设计 1 第二章 java编程基础 3 第三章 面向对象程序设计(基础) 13 第四章 面向对象程序设计(进阶) 15 第五章 异常 17 第六章 java常用类 1720 初识java与面向对象程序设计Java概述计算机编程语言发展史“计算机之父”冯诺…

PyTorch Geometric框架下图神经网络的可解释性机制:原理、实现与评估

在机器学习领域存在一个普遍的认知误区,即可解释性与准确性存在对立关系。这种观点认为可解释模型在复杂度上存在固有限制,因此无法达到最优性能水平,神经网络之所以能够在各个领域占据主导地位,正是因为其超越了人类可理解的范畴。 其实这种观点存在根本性的谬误。研究表明…

25. K 个一组翻转链表(难)

目录题目法一、模拟--迭代法二、递归 题目给你链表的头节点 head ,每 k 个节点一组进行翻转,请你返回修改后的链表。 k 是一个正整数,它的值小于或等于链表的长度。如果节点总数不是 k 的整数倍,那么请将最后剩余的节点保持原有顺序。 你不能只是单纯的改变节点内部的值,而…

OpenWrt 系统UCI详解(Lua、C语言调用uci接口实例)

1 UCI简介 “uci"是"Unified Configuration Interface”(统一配置界面)的缩写,用于OpenWrt整个系统的配置集中化。 很多后台服务有自己的配置文件,并且配置文件格式都不相同,OpenWrt系统中需要开启各种服务,为了解决配置不兼容的问题,统一采用uci格式的配置文件。…

macbook 双系统win7忘密码 解决办法 通用

1开机出现以下图片按开机键强制重启2选择这一个3跟着选择4等待时间较长5继续678点击计算机9选择c盘10选择Windows进入system32文件夹11右键修改名字 将sethc 修改为sethc112然后键盘上输入cmd13将cmd名字修改为sethc 14关掉所有点击完成 15然后开机来到登录界面 按5次shift 次数…

GoLang 2024 安装激活详细使用教程(激活至2026,实测是永久,亲测!)

开发工具推荐:GoLang 安装激活详细使用教程(激活至2026,实际上永久,亲测!)申明:本教程 GoLang 补丁、激活码均收集于网络,请勿商用,仅供个人学习使用,如有侵权,请联系作者删除。若条件允许,希望大家购买正版 ! GoLang是JetBrains公司推出的一款功能强大的GO语言集成…

数值计算方法(3) 数值微分方法

+++ date = 2024-12-21T15:45:47+08:00 draft = true title = 数值计算方法(3) 数值微分方法 +++ 初次发布于我的个人文档 上一期讲了数值积分方法,这一次自然是要讲数值微分方法的,不然太不完善了。 更何况数值微分方法其实是基于数值积分方法得到的。 我们先从比较简单的估…

.Net NativeAOT另外一种选择-bflat

https://www.qiufengblog.com/articles/dotnet-native-bflat.html前言 说起bflat,还得先说NativeAOT,在.Net 7时,正式把NativeAOT合到Runtime中,地位是明显上升了,对NativeAOT的代码提交也越来越多了,之前还是corert时,几年也没有太大的进展. 这个时候的成果还是有ILC(ILCompil…