信号处理--基于EEG脑电信号的眼睛状态的分析

本实验为生物信息学专题设计小项目。项目目的是通过提供的14导联EEG 脑电信号,实现对于人体睁眼和闭眼两个状态的数据分类分析。每个脑电信号的时长大约为117秒。

目录

加载相关的库函数

读取脑电信号数据并查看数据的属性

绘制脑电多通道连接矩阵

绘制两类数据的相对占比

 数据集划分和预处理

模型定义及可视化

模型训练及训练可视化

模型评价


加载相关的库函数

import tensorflow.compat.v1 as tf
from sklearn.metrics import confusion_matrix
import numpy as np
from scipy.io import loadmat
import os
from pywt import wavedec
from functools import reduce
from scipy import signal
from scipy.stats import entropy
from scipy.fft import fft, ifft
import pandas as pd
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from tensorflow import keras as K
import matplotlib.pyplot as plt
import scipy
from sklearn import metrics
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold,cross_validate
from tensorflow.keras.layers import Dense, Activation, Flatten, concatenate, Input, Dropout, LSTM, Bidirectional,BatchNormalization,PReLU,ReLU,Reshape
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.metrics import classification_report
from tensorflow.keras.models import Sequential, Model, load_model
import matplotlib.pyplot as plt;
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.decomposition import PCA
from tensorflow import keras
from sklearn.model_selection import cross_val_score
from tensorflow.keras.layers import Conv1D,Conv2D,Add
from tensorflow.keras.layers import MaxPool1D, MaxPooling2D
import seaborn as snsimport warnings
warnings.filterwarnings('ignore')

读取脑电信号数据并查看数据的属性

df = pd.read_csv("../input/eye-state-classification-eeg-dataset/EEG_Eye_State_Classification.csv")df.info()

 

 

绘制脑电多通道连接矩阵

 

plt.figure(figsize = (15,15))
cor_matrix = df.corr()
sns.heatmap(cor_matrix,annot=True)

 

绘制两类数据的相对占比

# Plotting target distribution 
plt.figure(figsize=(6,6))
df['eyeDetection'].value_counts().plot.pie(explode=[0.1,0.1], autopct='%1.1f%%', shadow=True, textprops={'fontsize':16}).set_title("Target distribution")

 数据集划分和预处理

data = df.copy()
y= data.pop('eyeDetection')
x= datax_new = StandardScaler().fit_transform(x)x_new = pd.DataFrame(x_new) 
x_new.columns = x.columnsx_train,x_test,y_train,y_test = train_test_split(x_new,y,test_size=0.15)x_train = np.array(x_train).reshape(-1,14,1)
x_test = np.array(x_test).reshape(-1,14,1)

模型定义及可视化

inputs = tf.keras.Input(shape=(14,1))Dense1 = Dense(64, activation = 'relu',kernel_regularizer=keras.regularizers.l2())(inputs)#Dense2 = Dense(128, activation = 'relu',kernel_regularizer=keras.regularizers.l2())(Dense1)
#Dense3 = Dense(256, activation = 'relu',kernel_regularizer=keras.regularizers.l2())(Dense2)lstm_1=  Bidirectional(LSTM(256, return_sequences = True))(Dense1)
drop = Dropout(0.3)(lstm_1)
lstm_3=  Bidirectional(LSTM(128, return_sequences = True))(drop)
drop2 = Dropout(0.3)(lstm_3)flat = Flatten()(drop2)#Dense_1 = Dense(256, activation = 'relu')(flat)Dense_2 = Dense(128, activation = 'relu')(flat)
outputs = Dense(1, activation='sigmoid')(Dense_2)model = tf.keras.Model(inputs, outputs)model.summary()tf.keras.utils.plot_model(model)def train_model(model,x_train, y_train,x_test,y_test, save_to, epoch = 2):opt_adam = keras.optimizers.Adam(learning_rate=0.001)es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=10)mc = ModelCheckpoint(save_to + '_best_model.h5', monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)lr_schedule = tf.keras.callbacks.LearningRateScheduler(lambda epoch: 0.001 * np.exp(-epoch / 10.))model.compile(optimizer=opt_adam,loss=['binary_crossentropy'],metrics=['accuracy'])history = model.fit(x_train,y_train,batch_size=20,epochs=epoch,validation_data=(x_test,y_test),callbacks=[es,mc,lr_schedule])saved_model = load_model(save_to + '_best_model.h5')return model,history

 

 

模型训练及训练可视化

model,history = train_model(model, x_train, y_train,x_test, y_test, save_to= './', epoch = 100)plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
# summarize history for loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()

 

 

 

模型评价

y_pred =model.predict(x_test)
y_pred = np.array(y_pred >= 0.5, dtype = np.int)
confusion_matrix(y_test, y_pred)print(classification_report(y_test, y_pred))

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/76217.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【大虾送书第六期】搞懂大模型的智能基因,RLHF系统设计关键问答

目录 ✨1、RLHF是什么? ✨2、RLHF适用于哪些任务? ✨3、RLHF和其他构建奖励模型的方法相比有何优劣? ✨4、什么样的人类反馈才是好的反馈 ✨5、RLHF算法有哪些类别,各有什么优缺点? ✨6、RLHF采用人类反馈会带来哪些局…

71 # 协商缓存的配置:通过内容

对比(协商)缓存 比较一下再去决定是用缓存还是重新获取数据,这样会减少网络请求,提高性能。 对比缓存的工作原理 客户端第一次请求服务器的时候,服务器会把数据进行缓存,同时会生成一个缓存标识符&#…

算法leetcode|72. 编辑距离(rust重拳出击)

文章目录 72. 编辑距离:样例 1:样例 2:提示: 分析:题解:rust:二维数组(易懂)滚动数组(更加优化的内存空间) go:c:python&a…

(排序) 剑指 Offer 21. 调整数组顺序使奇数位于偶数前面 ——【Leetcode每日一题】

❓剑指 Offer 21. 调整数组顺序使奇数位于偶数前面 难度:简单 输入一个整数数组,实现一个函数来调整该数组中数字的顺序,使得所有奇数在数组的前半部分,所有偶数在数组的后半部分。 示例: 输入:nums [1…

辅助笔记-安装Ubantu20.04.1虚拟机

安装Ubantu20.04.1虚拟机 文章目录 安装Ubantu20.04.1虚拟机步骤一:检查BIOS虚拟化支持步骤二:VMware17安装虚拟机步骤1:新建虚拟机步骤2:验证虚拟机能否上网 步骤3:设置Ubantu语言为中文 本文主要参考B站视频“P108_ …

公司内部测试团队可以替代专业的软件检测机构吗,性能测试怎么收费?

尽管软件测试是伴随着软件开发的发展而产生的,但是在信息技术日新月异的今天,软件测试逐渐走出开发附庸的定位。 一方面,很多大型企业都在内部设置了专门的测试团队以承接软件系统的测试工作,为产品质量把关。另一方面&#xff0c…

nginx反向代理后实现nginx和apache两种web服务器能够记录客户端的真实IP地址

一.构建环境 二.配置反向代理 1.基于源码安装的nginx环境下修改nginx.conf(设备1) 2.通过windows powershell进行修改hosts文件并测试 3.设备2和设备3上查看日志,可以看到访问来源都是代理服务器(2.190)而不是真实…

SpringBoot的日志信息及Lombok的常用注解

文章目录 一. 日志的介绍1. 什么是日志2. 日志的作用 二. 日志的使用1. 日志格式说明2. 自定义日志的输出3. 日志级别4. 日志级别的配置5. 日志持久化6. 更简单的输出日志-Lomok7. Lombok框架实现原理以及其他常见注解 一. 日志的介绍 1. 什么是日志 日志是我们程序重要组成部…

基于Mysql+Vue+Django的协同过滤和内容推荐算法的智能音乐推荐系统——深度学习算法应用(含全部工程源码)+数据集

目录 前言总体设计系统整体结构图系统流程图 运行环境Python 环境MySQL环境VUE环境 模块实现1. 数据请求和储存2. 数据处理计算歌曲、歌手、用户相似度计算用户推荐集 3. 数据存储与后台4. 数据展示 系统测试工程源代码下载其它资料下载 前言 本项目以丰富的网易云音乐数据为基…

WSL2和本地windows端口互通

众所周知 WSL 默认安装后,只允许windows访问 Windows Subsystem for Linux,而WSL是不能反之访问本地windows。我之前用vmware的思路认为是nat的网络模式,于是改成了桥接,结果wsl的桥接模式被我改的能访问本地,但是却不…

Vue--BM记事本

效果如下&#xff1a; 用到了如下的技术&#xff1a; 1.列表渲染&#xff1a;v-for key的设置 2.删除功能&#xff1a;v-on调用参数 fliter过滤 覆盖修改原数组 3.添加功能&#xff1a;v-model绑定&#xff0c;unshift修改原数组添加 html文件如下&#xff1a; <!DOCTYPE …

使用 Node.js 生成优化的图像格式

使用 Node.js 生成优化的图像格式 图像是任何 Web 应用程序的重要组成部分&#xff0c;但如果优化不当&#xff0c;它们也可能成为性能问题的主要根源。在本文中&#xff0c;我们将介绍如何使用 Node.js 自动生成优化的图像格式&#xff0c;并以最适合用户浏览器的格式显示它们…