27、ResNet50处理STEW数据集,用于情感三分类+全备的代码

1、数据介绍

IEEE-Datasets-STEW:SIMULTANEOUS TASK EEG WORKLOAD DATASET :

该数据集由48名受试者的原始EEG数据组成,他们参加了利用SIMKAP多任务测试进行的多任务工作负荷实验。受试者在休息时的大脑活动也在测试前被记录下来,也包括在其中。Emotiv EPOC设备,采样频率为128Hz,有14个通道,用于获取数据,每个案例都有2.5分钟的EEG记录。受试者还被要求在每个阶段后以1到9的评分标准对其感知的心理工作量进行评分,评分结果在单独的文件中提供。

说明:每个受试者的数据遵循命名惯例:subno_task.txt。例如,sub01_lo.txt将是受试者1在休息时的原始脑电数据,而sub23_hi.txt将是受试者23在多任务测试中的原始脑电数据。每个数据文件的行对应于记录中的样本,列对应于EEG设备的14个通道: AF3, F7, F3, FC5, T7, P7, O1, O2, P8, T8, FC6, F4, F8, AF4。

数据说明、下载地址:

STEW: Simultaneous Task EEG Workload Data Set | IEEE Journals & Magazine | IEEE Xplore

2、代码

本次使用ResNet50,去做此情感数据的分类工作,数据导入+模型训练+测试代码如下:

import torch
import torchvision.datasets
from torch.utils.data import Dataset        # 继承Dataset类
import os
from PIL import Image
import numpy as np
from torchvision import transforms# 预处理
data_transform = transforms.Compose([transforms.Resize((224,224)),           # 缩放图像transforms.ToTensor(),                  # 转为Tensotransforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))       # 标准化
])path =  r'C:\STEW\test'for root,dirs,files in os.walk(path):print('root',root) #遍历到该目录地址print('dirs',dirs) #遍历到该目录下的子目录名 []print('files',files)  #遍历到该目录下的文件  []def read_txt_files(path):# 创建文件名列表file_names = []# 遍历给定目录及其子目录下的所有文件for root, dirs, files in os.walk(path):# 遍历所有文件for file in files:# 如果是 .txt 文件,则加入文件名列表if file.endswith('.txt'): # endswith () 方法用于判断字符串是否以指定后缀结尾,如果以指定后缀结尾返回True,否则返回False。file_names.append(os.path.join(root, file))# 返回文件名列表return file_namesclass DogCat(Dataset):      # 数据处理def __init__(self,root,transforms = None):                  # 初始化,指定路径,是否预处理等等#['cat.15454.jpg', 'cat.445.jpg', 'cat.46456.jpg', 'cat.656165.jpg', 'dog.123.jpg', 'dog.15564.jpg', 'dog.4545.jpg', 'dog.456465.jpg']imgs = os.listdir(root)self.imgs = [os.path.join(root,img) for img in imgs]    # 取出root下所有的文件self.transforms = data_transform                        # 图像预处理def __getitem__(self, index):       # 读取图片img_path = self.imgs[index]label = 1 if 'dog' in img_path.split('/')[-1] else 0 #然后,就可以根据每个路径的id去做label了。将img_path 路径按照 '/ '分割,-1代表取最后一个字符串,如果里面有dog就为1,cat就为0.data = Image.open(img_path)if self.transforms:     # 图像预处理data = self.transforms(data)return data,labeldef __len__(self):return len(self.imgs)dataset = DogCat('./data/',transforms=True)for img,label in dataset:print('img:',img.size(),'label:',label)
'''
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
'''import os# 获取file_path路径下的所有TXT文本内容和文件名
def get_text_list(file_path):files = os.listdir(file_path)text_list = []for file in files:with open(os.path.join(file_path, file), "r", encoding="UTF-8") as f:text_list.append(f.read())return text_list, filesclass ImageFolderCustom(Dataset):# 2. Initialize with a targ_dir and transform (optional) parameterdef __init__(self, targ_dir: str, transform=None) -> None:# 3. Create class attributes# Get all image pathsself.paths = list(pathlib.Path(targ_dir).glob("*/*.jpg")) # note: you'd have to update this if you've got .png's or .jpeg's# Setup transformsself.transform = transform# Create classes and class_to_idx attributesself.classes, self.class_to_idx = find_classes(targ_dir)# 4. Make function to load imagesdef load_image(self, index: int) -> Image.Image:"Opens an image via a path and returns it."image_path = self.paths[index]return Image.open(image_path) # 5. Overwrite the __len__() method (optional but recommended for subclasses of torch.utils.data.Dataset)def __len__(self) -> int:"Returns the total number of samples."return len(self.paths)# 6. Overwrite the __getitem__() method (required for subclasses of torch.utils.data.Dataset)def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:"Returns one sample of data, data and label (X, y)."img = self.load_image(index)class_name  = self.paths[index].parent.name # expects path in data_folder/class_name/image.jpegclass_idx = self.class_to_idx[class_name]# Transform if necessaryif self.transform:return self.transform(img), class_idx # return data, label (X, y)else:return img, class_idx # return data, label (X, y)import torchvision as tv
import numpy as np
import torch
import time
import os
from torch import nn, optim
from torchvision.models import resnet50
from torchvision.transforms import transformsos.environ["CUDA_VISIBLE_DEVICE"] = "0,1,2"# cifar-10进行测验class Cutout(object):"""Randomly mask out one or more patches from an image.Args:n_holes (int): Number of patches to cut out of each image.length (int): The length (in pixels) of each square patch."""def __init__(self, n_holes, length):self.n_holes = n_holesself.length = lengthdef __call__(self, img):"""Args:img (Tensor): Tensor image of size (C, H, W).Returns:Tensor: Image with n_holes of dimension length x length cut out of it."""h = img.size(1)w = img.size(2)mask = np.ones((h, w), np.float32)for n in range(self.n_holes):y = np.random.randint(h)x = np.random.randint(w)y1 = np.clip(y - self.length // 2, 0, h)y2 = np.clip(y + self.length // 2, 0, h)x1 = np.clip(x - self.length // 2, 0, w)x2 = np.clip(x + self.length // 2, 0, w)mask[y1: y2, x1: x2] = 0.mask = torch.from_numpy(mask)mask = mask.expand_as(img)img = img * maskreturn imgdef load_data_cifar10(batch_size=128,num_workers=2):# 操作合集# Data augmentationtrain_transform_1 = transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.RandomRotation(degrees=(-80,80)),  # 随机角度翻转transforms.ToTensor(),transforms.Normalize((0.491339968,0.48215827,0.44653124), (0.24703233,0.24348505,0.26158768)  # 两者分别为(mean,std)),Cutout(1, 16),  # 务必放在ToTensor的后面])train_transform_2 = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.491339968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)  # 两者分别为(mean,std))])test_transform = transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),transforms.Normalize((0.491339968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)  # 两者分别为(mean,std))])# 训练集1trainset1 = tv.datasets.CIFAR10(root='data',train=True,download=False,transform=train_transform_1,)# 训练集2trainset2 = tv.datasets.CIFAR10(root='data',train=True,download=False,transform=train_transform_2,)# 测试集testset = tv.datasets.CIFAR10(root='data',train=False,download=False,transform=test_transform,)# 训练数据加载器1trainloader1 = torch.utils.data.DataLoader(trainset1,batch_size=batch_size,shuffle=True,num_workers=num_workers,pin_memory=(torch.cuda.is_available()))# 训练数据加载器2trainloader2 = torch.utils.data.DataLoader(trainset2,batch_size=batch_size,shuffle=True,num_workers=num_workers,pin_memory=(torch.cuda.is_available()))# 测试数据加载器testloader = torch.utils.data.DataLoader(testset,batch_size=batch_size,shuffle=False,num_workers=num_workers,pin_memory=(torch.cuda.is_available()))return trainloader1,trainloader2,testloaderdef main():start = time.time()batch_size = 128cifar_train1,cifar_train2,cifar_test = load_data_cifar10(batch_size=batch_size)model = resnet50().cuda()# model.load_state_dict(torch.load('_ResNet50.pth'))# 存在已保存的参数文件# model = nn.DataParallel(model,device_ids=[0,])  # 又套一层model = nn.DataParallel(model,device_ids=[0,1,2])loss = nn.CrossEntropyLoss().cuda()optimizer = optim.Adam(model.parameters(),lr=0.001)for epoch in range(50):model.train()  # 训练时务必写loss_=0.0num=0.0# train on trainloader1(data augmentation) and trainloader2for i,data in enumerate(cifar_train1,0):x, label = datax, label = x.cuda(),label.cuda()# xp = model(x) #outputl = loss(p,label) #lossoptimizer.zero_grad()l.backward()optimizer.step()loss_ += float(l.mean().item())num+=1for i, data in enumerate(cifar_train2, 0):x, label = datax, label = x.cuda(), label.cuda()# xp = model(x)l = loss(p, label)optimizer.zero_grad()l.backward()optimizer.step()loss_ += float(l.mean().item())num += 1model.eval()  # 评估时务必写print("loss:",float(loss_)/num)# test on trainloader2,testloaderwith torch.no_grad():total_correct = 0total_num = 0for x, label in cifar_train2:# [b, 3, 32, 32]# [b]x, label = x.cuda(), label.cuda()# [b, 10]logits = model(x)# [b]pred = logits.argmax(dim=1)# [b] vs [b] => scalar tensorcorrect = torch.eq(pred, label).float().sum().item()total_correct += correcttotal_num += x.size(0)# print(correct)acc_1 = total_correct / total_num# Testwith torch.no_grad():total_correct = 0total_num = 0for x, label in cifar_test:# [b, 3, 32, 32]# [b]x, label = x.cuda(), label.cuda()# [b, 10]logits = model(x) #output# [b]pred = logits.argmax(dim=1)# [b] vs [b] => scalar tensorcorrect = torch.eq(pred, label).float().sum().item()total_correct += correcttotal_num += x.size(0)# print(correct)acc_2 = total_correct / total_numprint(epoch+1,'train acc',acc_1,'|','test acc:', acc_2)# 保存时只保存model.moduletorch.save(model.module.state_dict(),'resnet50.pth')print("The interval is :",time.time() - start)if __name__ == '__main__':main()

3、对你有帮助的话,给个关注吧~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/292683.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

15-deoxy-Δ12,14-PGJ2 ELISA kit

首款上市的15-d-PGJ2 ELISA试剂盒,可用于类花生酸研究 15-deoxy-Δ12,14-PGJ2(15-d-PGJ2)是PGD2的最终脱水产物之一,通过中间体Δ12-PGJ2形成。生理条件下,15-d-PGJ2存在于体液中,浓度介于10^(-12)至10^(-9…

Enge问题解决教程

目录 解决问题的一般步骤: 针对"Enge问题"的具体建议: 以下是一些普遍适用的解决问题的方法: 以下是一些更深入的Enge浏览器问题和解决办法: 浏览器性能问题: 浏览器插件与网站冲突: 浏览…

为什么C语言函数声明与定义的参数名称可以不一样呢??

为什么C语言函数声明与定义的参数名称可以不一样呢?? 在开始前我有一些资料,是我根据自己从业十年经验,熬夜搞了几个通宵,精心整理了一份「C语言的资料从专业入门到高级教程工具包」,点个关注,全部无偿共享…

不受父容器大小约束的TextView

序言 为了实现以下效果,特意开发了一个自定义控件。主要是红色的点赞数和评论数。 问题分析 自定义控件 该控件主要是在于忽略的父容器的大小限制,这样可以展示出全部内容。注意父容器的属性中需要下列配置。 package com.trs.myrb.view.count;impor…

计算机体系结构实验——Branch-Target Buffers

实验五 Branch-Target Buffers 本次实验的主要目的是加深对Branch-Target Buffers的理解。掌握使用Branch-Target Buffers减少或增加分支带来的延迟的情况。 实验内容: 将以下程序段修改为可利用WinMIPS64模拟器运行的程序。假设R3的初始值为R240 在使用forward…

【C++入门到精通】互斥锁 (Mutex) C++11 [ C++入门 ]

阅读导航 引言一、Mutex的简介二、Mutex的种类1. std::mutex (基本互斥锁)2. std::recursive_mutex (递归互斥锁)3. std::timed_mutex (限时等待互斥锁)4. std::recursive_timed_mutex (限时等待…

rhel7/centos7升级openssh到openssh9.5-p1

openssh9.3-p2以下版本有如下漏洞 在rhel7.4/7.5/7.6均做过测试。 本文需要用到的rpm包如下: https://download.csdn.net/download/kadwf123/88652359 升级步骤 1、升级前启动telnet ##升级前启动telnet服务 yum -y install telnet-server yum -y install xinetd…

Redis(非关系型数据库)

Redis(非关系型数据库) 文章目录 Redis(非关系型数据库)认识Redis(Remote Dictionary Server)1.Redis的基本介绍2.Redis的应用场景2.1 取最新N个数据的操作2.2 排行榜应用,取TOP N操作2.3 需要精准设定过期时间的应用2.4 计数器应用2.5 Uniq 操作,获取某段时间所有数…

MATLAB遗传算法工具箱的三种使用方法

MATLAB中有三种调用遗传算法的方式: 一、遗传算法的开源文件 下载“gatbx”压缩包文件,解压后,里面有多个.m文件,可以看到这些文件的编辑日期都是1998年,很古老了。 这些文件包含了遗传算法的基础操作,包含…

【深度学习实践】换脸应用dofaker本地部署

本文介绍了dofaker换脸应用的本地部署教程,dofaker支持windows、linux、cpu/gpu推理,不依赖于任何深度学习框架,是一个非常好用的换脸工具。 本教程的部署系统为windows 11,使用CPU推理。 注意: 1、请确保您的所有路…

【大模型】快速体验百度智能云千帆AppBuilder搭建知识库与小助手

文章目录 前言千帆AppBuilder什么是千帆AppBuilderAppBuilder能做什么 体验千帆AppBuilderJava知识库高考作文小助手 总结 前言 前天,在【百度智能云智算大会】上,百度智能云千帆AppBuilder正式开放服务。这是一个AI原生应用开发工作台,可以…

线程活跃性问题(死锁、活锁、饥饿)

1.什么是“死锁”? 在多线程并发中,两个及以上线程互相持有对方所需要的资源又不主动释放,导致程序进入无尽的阻塞这就是“死锁”; 2.写一段“死锁”代码 import java.util.concurrent.TimeUnit; /*** 写一段必然发生死锁代码*/ public class MustDead…