【信号处理】基于DGGAN的单通道脑电信号增强和情绪检测(tensorflow)

关于

情绪检测,是脑科学研究中的一个常见和热门的方向。在进行情绪检测的分类中,真实数据不足,经常导致情绪检测模型的性能不佳。因此,对数据进行增强,成为了一个提升下游任务的重要的手段。本项目通过DCGAN模型实现脑电信号的扩充。

 图片来源:https://www.medicalnewstoday.com/articles/seizure-eeg

工具

数据

方法实现

DCGAN速递:https://arxiv.org/abs/1511.06434

数据加载和预处理
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.layers import Embedding
from tensorflow.keras.layers import LSTM
from tensorflow.keras.optimizers import SGD
from sklearn.metrics import accuracy_score
from model_DCGAN import DCGAN
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from sklearn.utils import shuffle
from sklearn.ensemble import GradientBoostingClassifieruse_feature_reduction = Truetf.keras.backend.clear_session()df=pd.read_csv('dataset/emotions.csv')encode = ({'NEUTRAL': 0, 'POSITIVE': 1, 'NEGATIVE': 2} )
#new dataset with replaced values
df_encoded = df.replace(encode)print(df_encoded.head())
print(df_encoded['label'].value_counts()),x=df_encoded.drop(["label"]  ,axis=1)
y = df_encoded.loc[:,'label'].valuesscaler = StandardScaler()
scaler.fit(x)
x = scaler.transform(x)
y = to_categorical(y)x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2, random_state = 4)if use_feature_reduction:# Feature reduction partest = GradientBoostingClassifier(n_estimators=10, learning_rate=0.1, random_state=0).fit(x_train,y_train.argmax(-1))# Obtain feature importance results from Gradient Boosting Regressorfeature_importance = est.feature_importances_epsilon_feature = 1e-2x_train = x_train[:, feature_importance > epsilon_feature]x_test = x_test[:, feature_importance > epsilon_feature]
设置DCGAN优化器

# setup optimzers
gen_optim = Adam(1e-4, beta_1=0.5)
disc_optim = RMSprop(5e-4)
 训练GAN生成类别0脑电数据
# generate samples for class 0
generator_class = 0
dcgan = DCGAN(gen_optim, disc_optim, noise_dim=100, dropout=0.3, input_dim=x_train.shape[2])
x_train_class_0 = x_train[y_train[:,generator_class]==1,:]
loss_history_class_0, acc_history_class_0, grads_history_class_0 = dcgan.train(x_train_class_0, epochs=100)
print("Class 0 fake samples are generating")
generator_class_0 = dcgan.generator
generated_samples_class_0, _ = dcgan.generate_fake_data(N=len(x_train_class_0))
  训练GAN生成类别1脑电数据
# generate samples for class 1
generator_class = 1
dcgan = DCGAN(gen_optim, disc_optim, noise_dim=100, dropout=0.3, input_dim=x_train.shape[2])
x_train_class_1 = x_train[y_train[:,generator_class]==1,:]
loss_history_class_1, acc_history_class_1, grads_history_class_1 = dcgan.train(x_train_class_1, epochs=100)
print("Class 1 fake samples are generating")
generator_class_1 = dcgan.generator
generated_samples_class_1, _ = dcgan.generate_fake_data(N=len(x_train_class_1))
 训练GAN生成类别2脑电数据
# generate samples for class 2
generator_class = 2
dcgan = DCGAN(gen_optim, disc_optim, noise_dim=100, dropout=0.3, input_dim=x_train.shape[2])
x_train_class_2 = x_train[y_train[:,generator_class]==1,:]
loss_history_class_2, acc_history_class_2, grads_history_class_2 = dcgan.train(x_train_class_2,epochs=100)
print("Class 2 fake samples are generating")
generator_class_2 = dcgan.generator
generated_samples_class_2, _ = dcgan.generate_fake_data(N=len(x_train_class_2))
合成数据融入真实训练数据集
generated_samples = np.concatenate((generated_samples_class_0,generated_samples_class_1,generated_samples_class_2),axis=0)
generated_y =np.concatenate((np.zeros((len(x_train_class_0),),dtype=np.int32),np.ones((len(x_train_class_1),),dtype=np.int32),2 * np.ones((len(x_train_class_2),),dtype=np.int32)),axis=0)generated_y = to_categorical(generated_y)x_train_all = np.concatenate((x_train,generated_samples),axis=0)
y_train_all = np.concatenate((y_train,generated_y), axis=0)#shuffle training data
x_train_all, y_train_all = shuffle(x_train_all,y_train_all)
 基于数据增强的LSTM模型情绪检测
model = Sequential()
model.add(LSTM(64, input_shape=(1,x_train_all.shape[2]),activation="relu",return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(32,activation="sigmoid"))
model.add(Dropout(0.2))
model.add(Dense(3, activation='sigmoid'))
model.compile(loss = 'categorical_crossentropy', optimizer = "adam", metrics = ['accuracy'])
model.summary()history = model.fit(x_train_all, y_train_all, epochs = 250, validation_data= (x_test, y_test))
score, acc = model.evaluate(x_test, y_test)pred = model.predict(x_test)
predict_classes = np.argmax(pred,axis=1)
expected_classes = np.argmax(y_test,axis=1)
print(expected_classes.shape)
print(predict_classes.shape)
correct = accuracy_score(expected_classes,predict_classes)
print(f"Test Accuracy: {correct}")

已附DCGAN模型

相关项目和代码问题,欢迎沟通交流。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/573584.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

预期为文件结尾。json [行2,列1]

报错背景 在huggingface上传数据集后,Dataset Viewer无法显示,报错: The dataset viewer is not available for this split. Cannot extract the features (columns) for the split train of the config default of the dataset. Error cod…

【Java初阶(六)上】封装 继承 多态

❣博主主页: 33的博客❣ ▶文章专栏分类: Java从入门到精通◀ 🚚我的代码仓库: 33的代码仓库🚚 目录 1.前言2.封装2.1如何实现封装2.2 访问限定符2.3包的概念2.3.1导入包中的类2.3.2自定义包 3.继承3.1为什么要继承3.2继承的概念3.3继承的语法3.4父类成员…

『Apisix进阶篇』动态负载均衡:APISIX的实战演练与策略应用

🚀『Apisix系列文章』探索新一代微服务体系下的API管理新范式与最佳实践 【点击此跳转】 📣读完这篇文章里你能收获到 🎯 掌握APISIX中多种负载均衡策略的原理及其适用场景。📈 学习如何通过APISIX的Admin API和Dashboard进行负…

QtCreator调试时无法显示std::string的内容

在银河麒麟V10或Ubuntu下使用QtCreator调试代码时&#xff0c;std::string类型变量在大多数情况下不显示实际内容&#xff0c;而是显示"<无法访问>"字样&#xff0c;鼠标点击进去也是看不见任何有用信息&#xff0c;这样非常影响调试效率&#xff0c;为此&…

android 11 SystemUI 状态栏打开之后的界面层级关系说明之一

比如WiFi 图标的父layout为&#xff1a; Class Name: ButtonRelativeLayout Class Name: QSTileView Class Name: TilePage Class Name: PagedTileLayout Class Name: QSPanel Class Name: NonInterceptingScrollView Class Name: QSContainerImpl Class Name: FrameLayout Cl…

软件开发服务合同套用模板(WORD原件)

一、合作方式 二、合同标的 三、开发进度及软件成果交付 四、开发费用 五、付款结算方式 六、知识产权条款 七、双方的权利和义务 八、验收 九、售后服务支持 十、培训 十一、保密责任 十二、不可抗力 十三、争议的解决 十四、其它事项 软件全套资料包领取&#xff1a;软件开发…

手机termux免root安装kali:一步到位+图形界面_termux安装kali-

1.工具 安卓包括鸿蒙手机、WiFi、充足的电量、脑子 2.浏览器搜索termuxvnc viewer下载安装。 3.对抗华为纯净模式需要一些操作先断网弹窗提示先不开等到继续安装的时候连上网智能检测过后就可以了 termux正常版本可以通过智能监测失败了就说明安装包是盗版 4.以后出现类似…

区块链技术与大数据结合的商业模式探索

hello宝子们...我们是艾斯视觉擅长ui设计和前端开发10年经验&#xff01;希望我的分享能帮助到您&#xff01;如需帮助可以评论关注私信我们一起探讨&#xff01;致敬感谢感恩&#xff01; 随着区块链技术和大数据技术的不断发展&#xff0c;两者的结合为企业带来了新的商业模式…

聊聊多模态大模型处理的思考

多模态&#xff1a;文本、音频、视频、图像等多形态的展现形式。目前部门内业务要求领域大模型需要是多模态——支持音频/文本。从个人思考的角度来审视下&#xff0c;审视下多模态大模型的实现方式。首先就要区分输入与输出&#xff0c;即输入的模态与输出的模态。从目前来看&…

HBase的Python API(happybase)操作

一、Windows下安装Python库&#xff1a;happybase pip install happybase -i https://pypi.tuna.tsinghua.edu.cn/simple 二、 开启HBase的Thrift服务 想要使用Python API连接HBase&#xff0c;需要开启HBase的Thrift服务。所以&#xff0c;在Linux服务器上&#xff0c;执行如…

2024最新版克魔助手抓包教程(9) - 克魔助手 IOS 数据抓包

引言 在移动应用程序的开发中&#xff0c;了解应用程序的网络通信是至关重要的。数据抓包是一种很好的方法&#xff0c;可以让我们分析应用程序的网络请求和响应&#xff0c;了解应用程序的网络操作情况。克魔助手是一款非常强大的抓包工具&#xff0c;可以帮助我们在 Android …

前端Webpack5高级进阶课程

课程介绍 本套视频教程主要内容包含React/Vue最新版本脚手架分析、基于Webpack5编写自己的loader和plugin等&#xff0c;让你开发时选择更多样&#xff0c;最后&#xff0c;用不到一百行的代码实现Webpack打包。通过本套视频教程的学习&#xff0c;可以帮你彻底打通Webpack的任…