32、[ShallowFBCSPNet、EEG-ITNet、EEGResNet、EEGInception]4种模型处理脑机接口-MOABB数据库+代码+结果

脑机接口基准之母—MOABB数据库介绍:

18、MOABB:BCI创新模型基准测试的群虫之心-CSDN博客

Dataset:

BNCI 2014-001 Motor Imagery dataset. (BCI IV2a):

https://paperswithcode.com/dataset/bci-competition-4-version-iia

BNCI 2014-002 Motor Imagery dataset. (BCI IV2a,与001参考电极不同):

https://paperswithcode.com/dataset/bnci-2014-002-motor-imagery-dataset-1

BNCI 2014-004 Motor Imagery dataset.(BCI IV2b):

https://paperswithcode.com/dataset/bnci-2014-004-motor-imagery-dataset

BNCI 2015-001 Motor Imagery dataset(5s的右手、双脚持续的运动想象图像):

https://paperswithcode.com/dataset/bnci-2015-001-motor-imagery-dataset-1

BNCI 2015-004 Motor Imagery dataset(7s的5项不同的心理任务MT):

https://paperswithcode.com/dataset/bnci-2015-004-motor-imagery-dataset-1

另外,欢迎大家加入此群聊,今天新建的,用于脑机接口技术交流和知识分享(公益性质),欢迎各位粉丝和有志从事BCI领域的同僚加入!本人专注于研发BCI领域的深度学习模型,致力于研发一种可以媲美EEGNet的新型CNN模型,工作也是脑机接口技术方向,目前就职于国内脑机接口一所龙头企业(研究院)

代码:

导入数据:

from braindecode.datasets import MOABBDataset#1、导入数据
subject_id = 1
# BNCI2014001 表示 BCIC IV 2a 数据集   subject_ids表示试验者编号
dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])

想导入哪个数据直接更换dataset_name编号即可,换被试直接更改subject_id即可

其他处理和剩余代码(时间仓促,3点开会,直接贴上后续代码,代码备注详尽!)


from braindecode.datasets import MOABBDataset
#--------------------------------------------------------------------------------------------------------------
#1、导入数据
subject_id = 1
# BNCI2014001 表示 BCIC IV 2a 数据集   subject_ids表示试验者编号
dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])
#--------------------------------------------------------------------------------------------------------------
#2、滤波处理
from braindecode.preprocessing import (exponential_moving_standardize, preprocess, Preprocessor)
from numpy import multiplylow_cut_hz = 4.  # low cut frequency for filtering
high_cut_hz = 38.  # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000
# Factor to convert from V to uV
factor = 1e6preprocessors = [Preprocessor('pick_types', eeg=True, meg=False, stim=False),  # Keep EEG sensorsPreprocessor(lambda data: multiply(data, factor)),  # Convert from V to uVPreprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz),  # Bandpass filterPreprocessor(exponential_moving_standardize,  # Exponential moving standardizationfactor_new=factor_new, init_block_size=init_block_size)
]# Transform the data
preprocess(dataset, preprocessors)
#--------------------------------------------------------------------------------------------------------------
#3、剪切计算窗口
from braindecode.preprocessing import create_windows_from_eventstrial_start_offset_seconds = -0.5 #截取试验之前0.5s数据,4.5s=1125数据点
# Extract sampling frequency, check that they are same in all datasets
sfreq = dataset.datasets[0].raw.info['sfreq']
assert all([ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets])
# Calculate the trial start offset in samples.
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)# Create windows using braindecode function for this. It needs parameters to define how
# trials should be used.
windows_dataset = create_windows_from_events(dataset,trial_start_offset_samples=trial_start_offset_samples,trial_stop_offset_samples=0,preload=True)
#--------------------------------------------------------------------------------------------------------------
#4、数据切分
splitted = windows_dataset.split('session')
train_set = splitted['0train']
valid_set = splitted['1test']
#--------------------------------------------------------------------------------------------------------------
#5、创建模型
import torch
from braindecode.util import set_random_seeds
from braindecode.models import ShallowFBCSPNet
from braindecode.models import EEGITNet
from braindecode.models import EEGResNet
from braindecode.models import EEGInceptioncuda = torch.cuda.is_available()  # check if GPU is available, if True chooses to use it
device = 'cuda' if cuda else 'cpu'
if cuda:torch.backends.cudnn.benchmark = True
# Set random seed to be able to roughly reproduce results
# Note that with cudnn benchmark set to True, GPU indeterminism
# may still make results substantially different between runs.
# To obtain more consistent results at the cost of increased computation time,
# you can set `cudnn_benchmark=False` in `set_random_seeds`
# or remove `torch.backends.cudnn.benchmark = True`
seed = 20200220
#seed = 0
set_random_seeds(seed=seed, cuda=cuda)n_classes = 4
# Extract number of chans and time steps from dataset
n_chans = train_set[0][0].shape[0] #22Channels
input_window_samples = train_set[0][0].shape[1] #4.5s=1125model = ShallowFBCSPNet(n_chans,n_classes,input_window_samples=input_window_samples,final_conv_length='auto')
#model = EEGITNet(in_channels=n_chans,n_classes=n_classes,input_window_samples=input_window_samples)#model = EEGResNet(in_chans=n_chans,n_classes=n_classes,n_first_filters=8,input_window_samples=input_window_samples,final_pool_length = 'auto')
#model = EEGInception(in_channels=n_chans,n_classes=n_classes,input_window_samples=input_window_samples)
# Send model to GPU
if cuda:model.cuda()
#--------------------------------------------------------------------------------------------------------------
#6、模型训练
from skorch.callbacks import LRScheduler
from skorch.helper import predefined_split
from braindecode import EEGClassifier
# These values we found good for shallow network:
lr = 0.0625 * 0.01
#lr = 0.001
weight_decay = 0
# For deep4 they should be:
# lr = 1 * 0.01
# weight_decay = 0.5 * 0.001
batch_size = 64
n_epochs = 500clf = EEGClassifier(model,criterion=torch.nn.NLLLoss,optimizer=torch.optim.AdamW,train_split=predefined_split(valid_set),  # using valid_set for validationoptimizer__lr=lr,optimizer__weight_decay=weight_decay,batch_size=batch_size,callbacks=["accuracy", ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),],device=device,
)
# Model training for a specified number of epochs. `y` is None as it is already supplied
# in the dataset.
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score# X_train, X_test, y_train, y_test = train_test_split(train_set, y=None, test_size=0.4, random_state=0)
# scores = cross_val_score(clf, train_set,y=None,cv=5)
# scores
# print("Accuracy: %0.2f (+/- %0.2f)" % (scores.mean(), scores.std() * 2))
clf.fit(train_set, y=None, epochs=n_epochs)
#--------------------------------------------------------------------------------------------------------------
#7、输出结果图像
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import pandas as pd# Extract loss and accuracy values for plotting from history object
results_columns = ['train_loss', 'valid_loss', 'train_accuracy', 'valid_accuracy']
df = pd.DataFrame(clf.history[:, results_columns], columns=results_columns,index=clf.history[:, 'epoch'])# get percent of misclass for better visual comparison to loss
df = df.assign(train_misclass=100 - 100 * df.train_accuracy,valid_misclass=100 - 100 * df.valid_accuracy)plt.style.use('seaborn')
fig, ax1 = plt.subplots(figsize=(8, 3))
df.loc[:, ['train_loss', 'valid_loss']].plot(ax=ax1, style=['-', ':'], marker='o', color='tab:blue', legend=False, fontsize=14)ax1.tick_params(axis='y', labelcolor='tab:blue', labelsize=14)
ax1.set_ylabel("Loss", color='tab:blue', fontsize=14)ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axisdf.loc[:, ['train_misclass', 'valid_misclass']].plot(ax=ax2, style=['-', ':'], marker='o', color='tab:red', legend=False)
ax2.tick_params(axis='y', labelcolor='tab:red', labelsize=14)
ax2.set_ylabel("Misclassification Rate [%]", color='tab:red', fontsize=14)
ax2.set_ylim(ax2.get_ylim()[0], 85)  # make some room for legend
ax1.set_xlabel("Epoch", fontsize=14)# where some data has already been plotted to ax
handles = []
handles.append(Line2D([0], [0], color='black', linewidth=1, linestyle='-', label='Train'))
handles.append(Line2D([0], [0], color='black', linewidth=1, linestyle=':', label='Valid'))
plt.legend(handles, [h.get_label() for h in handles], fontsize=14)
plt.tight_layout()
plt.show()
plt.savefig('acc_loss.png')
#--------------------------------------------------------------------------------------------------------------
#8、混淆矩阵
# from sklearn.metrics import confusion_matrix
# from braindecode.visualization import plot_confusion_matrix# # generate confusion matrices
# # get the targets
# y_true = valid_set.get_metadata().target
# y_pred = clf.predict(valid_set)# # generating confusion matrix
# confusion_mat = confusion_matrix(y_true, y_pred)# # add class labels
# # label_dict is class_name : str -> i_class : int
# label_dict = valid_set.datasets[0].windows.event_id.items()
# # sort the labels by values (values are integer class labels)
# labels = list(dict(sorted(list(label_dict), key=lambda kv: kv[1])).keys())# # plot the basic conf. matrix
# plot_confusion_matrix(confusion_mat, class_names=labels)
#--------------------------------------------------------------------------------------------------------------
from sklearn.metrics import confusion_matrix
from braindecode.visualization import plot_confusion_matrix
# generate confusion matrices
# get the targets
y_true = valid_set.get_metadata().target
y_pred = clf.predict(valid_set)# generating confusion matrix
confusion_mat = confusion_matrix(y_true, y_pred)# add class labels
# label_dict is class_name : str -> i_class : int
# 命令改变的地方 调用方式改变
label_dict = valid_set.datasets[0].window_kwargs[0][1]['mapping']
# sort the labels by values (values are integer class labels)
# 有所改变  但是意思没变
labels = [k for k, v in sorted(label_dict.items(), key=lambda kv: kv[1])]# plot the basic conf. matrix
plot_confusion_matrix(confusion_mat, class_names=labels)
plt.savefig('混淆矩阵.png')
import torchvision.models as models
from torchsummary import summary
summary(model,(1,22,1125),batch_size=64,device="cuda")
print(model)

Result:

1、MOABB-BNCI 2014-001

ShallowFBCSPNet-Sub1

2、MOABB-BNCI 2014-004

ShallowFBCSPNet-Sub4

3、MOABB-BNCI 2015-001

ShallowFBCSPNet-Sub1

4、ShallowFBCSPNet-Sub1

ShallowFBCSPNet-Sub4

上述4个数据给出了其中一个模型的混淆矩阵图,后续我会加上其余的混淆矩阵,希望看到这篇博客的人士:学生也好,工作人员也好,加入我们,大家一起学习,一起进步,共同在脑机接口-算法研发这条路上共同奋进!

                                                                                         ——2024年1月5日,15:00 于北京-馒头

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/323525.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI小蜜批量写作助手:多级指令,插件,GPTs满足不同写作需求

为什么会开发这个脚本? 爆文项目的核心是矩阵怼量 具体怎么做这里介绍很清楚了: AI爆文撸流量主保姆级教程3.0脚本写作教程(解放双手) 我在刚做爆文项目时候,都是手动操作,复制指令,组合指令…

2023 | 美团技术团队热门技术文章汇总

新年好!时光飞逝,我们告别了难忘的2023,迎来了充满希望的2024。再次感谢大家的一路相伴~~ 今天,我们整理了2023年公众号阅读量靠前的10篇技术文章,欢迎大家品阅。祝愿大家在新的一年里,幸福平安&#xff0…

[足式机器人]Part2 Dr. CAN学习笔记-动态系统建模与分析 Ch02-6频率响应与滤波器

本文仅供学习使用 本文参考: B站:DR_CAN Dr. CAN学习笔记-动态系统建模与分析 Ch02-6频率响应与滤波器 1st order system 一阶系统 低通滤波器——Loss Pass Filter

C#上位机与欧姆龙PLC的通信10----开发专用的通讯工具软件(WPF版)

1、介绍 上节开发了一个winform版的通讯测试工具,这节再搞个wpf版的,wpf是什么?请自行百度,也可以看前面的博客,WPF真入门教程,wpf的界面效果是比winform漂亮,因为wpf使用了web项目中的css样式…

国图公考:2024年上半年中小学教师资格考试(笔试)报考须知

(一)信息填报时间:2024年1月12日9:00至1月15日16:00 (二)信息确认时间:2024年1月13日9:00至1月16日16:00 (三)网上缴费时间:2024年1月13日9:00至1月17日24:00

新年福利|这款价值数万的报表工具永久免费了

随着数据资产的价值逐渐凸显,越来越多的企业会希望采用报表工具来处理数据分析,了解业务经营状况,从而辅助经营决策。不过,企业在选型报表工具的时候经常会遇到以下几个问题: 各个报表工具有很多功能和特性&#xff0c…

硬盘基本知识(磁头、磁道、扇区、柱面)

概述 盘片(platter) 磁头(head) 磁道(track) 扇区(sector) 柱面(cylinder) 盘片 片面 和 磁头 硬盘中一般会有多个盘片组成,每个盘片包含两个面…

大数据毕设分享 flink大数据淘宝用户行为数据实时分析与可视化

文章目录 0 前言1、环境准备1.1 flink 下载相关 jar 包1.2 生成 kafka 数据1.3 开发前的三个小 tip 2、flink-sql 客户端编写运行 sql2.1 创建 kafka 数据源表2.2 指标统计:每小时成交量2.2.1 创建 es 结果表, 存放每小时的成交量2.2.2 执行 sql &#x…

商品小程序(6.商品详情)

目录 一、获取商品详情数据二、渲染商品详情页的UI结构1、渲染轮播图区域2、实现轮播图预览效果3、渲染商品信息区域4、渲染商品详情信息5、解决商品价格闪烁的问题 三、渲染详情页底部的商品导航区域1、渲染商品导航区域的UI结构2、点击跳转到购物车页面 本章主要完成商品详情…

大数据框架ElasticSearch学习网站,让你的技能瞬间升级!

介绍:Elasticsearch是一个分布式、免费和开放的搜索和分析引擎,它适用于所有类型的数据,包括文本Elasticsearch是一个分布式、免费和开放的搜索和分析引擎,它适用于所有类型的数据,包括文本、数字、地理空间、结构化和…

DVenom:一款功能强大的Shellcode加密封装和加载工具

关于DVenom DVenom是一款功能强大的Shellcode加密封装和加载工具,该工具专为红队研究人员设计,可以帮助红队成员通过对Shellcode执行加密封装和加载实现反病毒产品的安全检测绕过。 功能介绍 1、支持绕过某些热门反病毒产品; 2、提供了多种…

解决uniapp打包成apk后uni.getStorageSync获取不到值

uniapp写的项目,在hbuilderx中云打包成apk后我在登录存储的token死都获取不到,导致后续接口请求头没有token连接不到接口,只有运行到手机或者模拟器还有打包成apk后是获取不到,其他的小程序还有网页都可以获取到 试过了很多种方法…