自然语言处理实验2 字符级RNN分类实验

实验2 字符级RNN分类实验

必做题:

(1)数据准备:academy_titles.txt为“考硕考博”板块的帖子标题,job_titles.txt为“招聘信息”板块的帖子标题,将上述两个txt进行划分,其中训练集为70%,测试集为30%。二分类标签:考硕考博为0,招聘信息为1。字符使用One-hot方法表示。

(2)设计模型:在训练集上训练字符级RNN模型。注意,字符级不用分词,是将文本的每个字依次送入模型。

(3)将训练好的模型在测试数据集上进行验证,计算准确率,并分析实验结果。要给出每一部分的代码。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.model_selection import train_test_split# 读取academy_titles文件内容
with open('C:\\Users\\hp\\Desktop\\academy_titles.txt', 'r', encoding='utf-8') as file:academy_titles = file.readlines()# 读取job_titles文件内容
with open('C:\\Users\\hp\\Desktop\\job_titles.txt', 'r', encoding='utf-8') as file:job_titles = file.readlines()# 将招聘信息与学术信息分开
academy_titles = [title.strip() for title in academy_titles]
job_titles = [title.strip() for title in job_titles]# 构建标签和数据
X = academy_titles + job_titles
y = [0] * len(academy_titles) + [1] * len(job_titles)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 构建字符到索引的映射
all_chars = set(''.join(academy_titles + job_titles))
char_to_index = {char: i for i, char in enumerate(all_chars)}# 将文本转换为模型可接受的输入形式
def text_to_input(text, max_len, char_to_index):X_indices = np.zeros((len(text), max_len, len(char_to_index)), dtype=np.float32)for i, title in enumerate(text):for t, char in enumerate(title):X_indices[i, t, char_to_index[char]] = 1return torch.tensor(X_indices)max_len = max([len(title) for title in X])
X_train_indices = text_to_input(X_train, max_len, char_to_index)
X_test_indices = text_to_input(X_test, max_len, char_to_index)# 构建字符级RNN模型
class CharRNN(nn.Module):def __init__(self, input_size, hidden_size):super(CharRNN, self).__init__()self.hidden_size = hidden_sizeself.i2h = nn.LSTM(input_size, hidden_size)self.fc = nn.Linear(hidden_size, 1)self.sigmoid = nn.Sigmoid()def forward(self, input):hidden, _ = self.i2h(input)output = self.fc(hidden[-1])output = self.sigmoid(output)return outputmodel = CharRNN(input_size=len(char_to_index), hidden_size=128)# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 转换数据为PyTorch张量
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).view(-1, 1)# 定义新的训练周期数和学习率
num_epochs = 30
learning_rate = 0.01# 定义新的优化器
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
best_accuracy = 0.0
best_model = None# 训练模型并输出每一轮的准确率
for epoch in range(num_epochs):optimizer.zero_grad()output = model(X_train_indices)output = output.view(-1, 1)loss = criterion(output, y_train_tensor[:output.size(0)])loss.backward()optimizer.step()# 计算训练集准确率predictions = (output > 0.5).float()correct = (predictions == y_train_tensor[:output.size(0)]).float()accuracy = correct.sum() / len(correct)print(f'Epoch {epoch+1}, 训练集准确率: {accuracy.item()}')# 保存准确率最高的模型if accuracy > best_accuracy:best_accuracy = accuracybest_model = model.state_dict().copy()# 加载最佳模型参数
model.load_state_dict(best_model)# 使用测试集上准确率最高的模型进行测试
test_output = model(X_test_indices)
test_output = test_output.view(-1, 1)
test_loss = criterion(test_output, y_test_tensor[:test_output.size(0)])
predictions = (test_output > 0.5).float()
correct = (predictions == y_test_tensor[:test_output.size(0)]).float()
accuracy = correct.sum() / len(correct)print(f'使用测试集上准确率最高的模型进行测试,准确率: {accuracy.item()}')

 这个实验准确率目前是偏低的,但是我没有很多时间去一直调整参数

希望后面有需要的同学,可以去调整参数!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/539912.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

概率论与数理统计(随机事件与概率)

1随机事件与概率 1.1随机事件及其运算规律 1.1.1运算 交换律结合律分配律德摩根律 1.2概率的定义及其确定方法 1.2.1概率的统计定义 频率 设在 n 次试验中,事件 A 发生了(A)次,则称为事件 A 发生的频率。 1.2.2概率的统计定义 在一组恒定不变的条…

GPT-SoVITS开源音色克隆框架的训练与调试

GPT-SoVITS开源框架的报错与调试 遇到的问题解决办法 GPT-SoVITS是一款创新的跨语言音色克隆工具,同时也是一个非常棒的少样本中文声音克隆项目。 它是是一个开源的TTS项目,只需要1分钟的音频文件就可以克隆声音,支持将汉语、英语、日语三种…

vscode 导入前端项目

vscode 导入前端项目 导入安装依赖 运行 参考vscode 下载 导入 安装依赖 运行 在前端项目的终端中输入npm run serve

KKVIEW: 远程控制软件哪个好用

远程控制软件哪个好用 随着科技的发展和工作方式的改变,远程控制软件越来越受到人们的关注和需求。无论是在家中远程办公,还是技术支持人员为远程用户提供帮助,选择一款高效稳定的远程控制软件至关重要。在众多选择中,有几款远程…

【数学建模】线性规划

针对未来可能的数学建模比赛内容,我对学习的内容做了一些调整,所以先跳过灰色关联分析和模糊综合评价的代码,今天先来了解一下运筹规划类——线性规划模型。 背景: 某数学建模游戏有三种题型,分别是A,B&am…

【AI论文阅读笔记】ResNet残差网络

论文地址:https://arxiv.org/abs/1512.03385 摘要 重新定义了网络的学习方式 让网络直接学习输入信息与输出信息的差异(即残差) 比赛第一名1 介绍 不同级别的特征可以通过网络堆叠的方式来进行丰富 梯度爆炸、梯度消失解决办法:1.网络参数的初始标准化…

微博热搜榜单采集,微博热搜榜单爬虫,微博热搜榜单解析,完整代码(话题榜+热搜榜+文娱榜和要闻榜)

文章目录 代码1. 话题榜2. 热搜榜3. 文娱榜和要闻榜 过程1. 话题榜2. 热搜榜3. 文娱榜和要闻榜 代码 1. 话题榜 import requests import pandas as pd import urllib from urllib import parse headers { authority: weibo.com, accept: application/json, text/pl…

jdk版本规则看这里

Java Development Kit (JDK) 的版本号是由几个不同的数字和有时的字母组合来定义的,这些数字和字母表达了版本的不同层面。下面是 JDK 版本号的一般结构和它们各自的含义: JDK 版本号的组成 主版本号 - 表示主要的发布版本。例如,在 JDK 8 或…

使用 WXT 开发浏览器插件(上手使用篇)

WXT (https://wxt.dev/), Next-gen Web Extension Framework. 号称下一代浏览器开发框架. 可一套代码 (code base) 开发支持多个浏览器的插件. 上路~ WXT 提供了脚手架可以方便我们快速进行开发,但是我们得先安装好环境依赖,这里我们使用 npm, 所以需要…

某赛通电子文档安全管理系统 DecryptApplication 任意文件读取漏洞(2024年3月发布)

漏洞简介 某赛通电子文档安全管理系统 DecryptApplication 接口处任意文件读取漏洞,未经身份验证的攻击者利用此漏洞获取系统内部敏感文件信息,导致系统处于极不安全的状态。 漏洞等级高危影响版本*漏洞类型任意文件读取影响范围>1W 产品简介 …

Selenium 学习(0.20)——软件测试之单元测试

我又(浪完)回来了…… 很久没有学习了,今天忙完终于想起来学习了。没有学习的这段时间,主要是请了两个事假(5工作日和10工作日)放了个年假(13天),然后就到现在了。 看了下…

pytorch之诗词生成3--utils

先上代码: import numpy as np import settingsdef generate_random_poetry(tokenizer, model, s):"""随机生成一首诗:param tokenizer: 分词器:param model: 用于生成古诗的模型:param s: 用于生成古诗的起始字符串,默认为空串:return: …