【信号处理】基于CNN自编码器的心电信号异常检测识别(tensorflow)

关于

本项目主要实现卷积自编码器对于异常心电ECG信号的检测和识别,属于无监督学习中的生理信号检测的典型方法之一。

工具

 

方法实现

读取心电信号
normal_df = pd.read_csv("/heartbeat/ptbdb_normal.csv").iloc[:, :-1]
anomaly_df = pd.read_csv("/heartbeat/ptbdb_abnormal.csv").iloc[:, :-1]
normal_df.head()

信号可视化

def plot_sample(normal, anomaly):index = np.random.randint(0, len(normal_df), 2)fig, ax = plt.subplots(1, 2, sharey=True, figsize=(10, 4))ax[0].plot(normal.iloc[index[0], :].values, label=f"Case {index[0]}")ax[0].plot(normal.iloc[index[1], :].values, label=f"Case {index[1]}")ax[0].legend(shadow=True, frameon=True, facecolor="inherit", loc=1, fontsize=9)ax[0].set_title("Normal")ax[1].plot(anomaly.iloc[index[0], :].values, label=f"Case {index[0]}")ax[1].plot(anomaly.iloc[index[1], :].values, label=f"Case {index[1]}")ax[1].legend(shadow=True, frameon=True, facecolor="inherit", loc=1, fontsize=9)ax[1].set_title("Anomaly")plt.tight_layout()plt.show()plot_sample(normal_df, anomaly_df)

 

 信号均值计算及可视化
def plot_smoothed_mean(data, class_name = "normal", step_size=5, ax=None):df = pd.DataFrame(data)roll_df = df.rolling(step_size)smoothed_mean = roll_df.mean().dropna().reset_index(drop=True)smoothed_std = roll_df.std().dropna().reset_index(drop=True)margin = 3*smoothed_stdlower_bound = (smoothed_mean - margin).values.flatten()upper_bound = (smoothed_mean + margin).values.flatten()ax.plot(smoothed_mean.index, smoothed_mean)ax.fill_between(smoothed_mean.index, lower_bound, y2=upper_bound, alpha=0.3, color="red")ax.set_title(class_name, fontsize=9)fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharey=True)
axes = axes.flatten()
for i, label in enumerate(CLASS_NAMES, start=1):data_group = df.groupby("target")data = data_group.get_group(label).mean(axis=0, numeric_only=True).to_numpy()plot_smoothed_mean(data, class_name=label, step_size=20, ax=axes[i-1])
fig.suptitle("Plot of smoothed mean for each class", y=0.95, weight="bold")
plt.tight_layout()

 训练/测试数据划分
normal_df.drop("target", axis=1, errors="ignore", inplace=True)
normal = normal_df.to_numpy()
anomaly_df.drop("target", axis=1, errors="ignore", inplace=True)
anomaly = anomaly_df.to_numpy()X_train, X_test = train_test_split(normal, test_size=0.15, random_state=45, shuffle=True)
print(f"Train shape: {X_train.shape}, Test shape: {X_test.shape}, anomaly shape: {anomaly.shape}")
搭建自编码器
class AutoEncoder(Model):def __init__(self, input_dim, latent_dim):super(AutoEncoder, self).__init__()self.input_dim = input_dimself.latent_dim = latent_dimself.encoder = tf.keras.Sequential([layers.Input(shape=(input_dim,)),layers.Reshape((input_dim, 1)),  # Reshape to 3D for Conv1Dlayers.Conv1D(128, 3, strides=1, activation='relu', padding="same"),layers.BatchNormalization(),layers.MaxPooling1D(2, padding="same"),layers.Conv1D(128, 3, strides=1, activation='relu', padding="same"),layers.BatchNormalization(),layers.MaxPooling1D(2, padding="same"),layers.Conv1D(latent_dim, 3, strides=1, activation='relu', padding="same"),layers.BatchNormalization(),layers.MaxPooling1D(2, padding="same"),])# Previously, I was using UpSampling. I am trying Transposed Convolution this time around.self.decoder = tf.keras.Sequential([layers.Conv1DTranspose(latent_dim, 3, strides=1, activation='relu', padding="same"),
#             layers.UpSampling1D(2),layers.BatchNormalization(),layers.Conv1DTranspose(128, 3, strides=1, activation='relu', padding="same"),
#             layers.UpSampling1D(2),layers.BatchNormalization(),layers.Conv1DTranspose(128, 3, strides=1, activation='relu', padding="same"),
#             layers.UpSampling1D(2),layers.BatchNormalization(),layers.Flatten(),layers.Dense(input_dim)])def call(self, X):encoded = self.encoder(X)decoded = self.decoder(encoded)return decodedinput_dim = X_train.shape[-1]
latent_dim = 32model = AutoEncoder(input_dim, latent_dim)
model.build((None, input_dim))
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss="mae")
model.summary()
模型训练
epochs = 100
batch_size = 128
early_stopping = EarlyStopping(patience=10, min_delta=1e-3, monitor="val_loss", restore_best_weights=True)history = model.fit(X_train, X_train, epochs=epochs, batch_size=batch_size,validation_split=0.1, callbacks=[early_stopping])
训练可视化
plt.plot(history.history['loss'], label="Training loss")
plt.plot(history.history['val_loss'], label="Validation loss", ls="--")
plt.legend(shadow=True, frameon=True, facecolor="inherit", loc="best", fontsize=9)
plt.title("Training loss")
plt.ylabel("Loss")
plt.xlabel("Epoch")
plt.show()

 

信号重建可视化
fig, axes = plt.subplots(2, 5, sharey=True, sharex=True, figsize=(12, 6))
random_indexes = np.random.randint(0, len(X_train), size=5)for i, idx in enumerate(random_indexes):data = X_train[[idx]]plot_examples(model, data, ax=axes[0, i], title="Normal")for i, idx in enumerate(random_indexes):data = anomaly[[idx]]plot_examples(model, data, ax=axes[1, i], title="anomaly")plt.tight_layout()
fig.suptitle("Sample plots (Actual vs Reconstructed by the CNN autoencoder)", y=1.04, weight="bold")
fig.savefig("autoencoder.png")
plt.show()

计算重建MAE误差
train_mae = model.evaluate(X_train, X_train, verbose=0)
test_mae = model.evaluate(X_test, X_test, verbose=0)
anomaly_mae = model.evaluate(anomaly_df, anomaly_df, verbose=0)print("Training dataset error: ", train_mae)
print("Testing dataset error: ", test_mae)
print("Anormaly dataset error: ", anomaly_mae)

 异常检测阈值选取

MAE误差阈值=正常数据重建MAE均值+正常数据重建MAE标准差,此阈值可以用来直接检测某信号为正常信号还是异常心电信号。

def predict(model, X):pred = model.predict(X, verbose=False)loss = mae(pred, X)return pred, loss_, train_loss = predict(model, X_train)
_, test_loss = predict(model, X_test)
_, anomaly_loss = predict(model, anomaly)
threshold = np.mean(train_loss) + np.std(train_loss) # Setting threshold for distinguish normal data from anomalous databins = 40
plt.figure(figsize=(9, 5), dpi=100)
sns.histplot(np.clip(train_loss, 0, 0.5), bins=bins, kde=True, label="Train Normal")
sns.histplot(np.clip(test_loss, 0, 0.5), bins=bins, kde=True, label="Test Normal")
sns.histplot(np.clip(anomaly_loss, 0, 0.5), bins=bins, kde=True, label="anomaly")ax = plt.gca()  # Get the current Axes
ylim = ax.get_ylim()
plt.vlines(threshold, 0, ylim[-1], color="k", ls="--")
plt.annotate(f"Threshold: {threshold:.3f}", xy=(threshold, ylim[-1]), xytext=(threshold+0.009, ylim[-1]),arrowprops=dict(facecolor='black', shrink=0.05), fontsize=9)
plt.legend(shadow=True, frameon=True, facecolor="inherit", loc="best", fontsize=9)
plt.show()

模型评估
plot_confusion_matrix(model, X_train, X_test, anomaly, threshold=threshold)ytrue, ypred = prepare_labels(model, X_train, X_test, anomaly, threshold=threshold)
print(classification_report(ytrue, ypred, target_names=CLASS_NAMES))

 

代码获取

相关项目开发和问题,欢迎后台沟通交流。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/637619.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

7.Eureka注册中心

将user-service服务注册到eureka 将order-service服务注册到eureka eureka:client:service-url:defaultZone: http://localhost:10086/eureka/ <dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-starter-netflix…

[转载] 在IIS上启用https的免费ssl证书使用教程

一、申请证书 数字证书管理服务&#xff08;原SSL证书&#xff09;_SSL数字证书_HTTPS加密_服务器证书_CA认证-阿里云 二、添加证书 1、在控制台上做如下操作&#xff1a;文件》添加/删除管理单元》可用的管理单元》证书》添加》确定。 2、在证书管理单元中选择&#xff1a;…

基于spark进行数据分析的心力衰竭可视化大屏项目

基于spark进行数据分析的心力衰竭可视化大屏项目 项目背景 在当今的医疗领域&#xff0c;数据驱动的决策变得日益重要。心力衰竭作为常见的心血管疾病&#xff0c;其临床数据的分析对于改善患者治疗结果至关重要。本文将介绍如何利用Apache Spark进行大规模心力衰竭临床数据的…

IOS 32位调试环境搭建

一、背景 调试IOS程序经常使用gdb&#xff0c;目前gdb只支持32位程序调试&#xff0c;暂不支持IOS 64位程序调试。IOS 32位程序使用GDB调试之前&#xff0c;必须确保手机已越狱&#xff0c;否则无法安装和使用GDB调试软件。下面详细介绍GDB调试IOS 32位程序的环境搭建。 二、I…

SpringCloud 与 Dubbo 的区别详解

一、Spring Cloud 和 Dubbo 的概述 1.1 SpringCloud 简介 SpringCloud 是一个用于构建云原生应用的框架集合&#xff0c;它为开发者提供了一套完整的工具链&#xff0c;用于快速搭建分布式系统。SpringCloud 基于 SpringBoot 开发&#xff0c;具有如下特点&#xff1a; 提供…

智慧园区引领未来产业趋势:科技创新驱动园区发展,构建智慧化产业新体系

目录 一、引言 二、智慧园区引领未来产业趋势 1、产业集聚与协同发展 2、智能化生产与服务 3、绿色可持续发展 三、科技创新驱动园区发展 1、创新资源的集聚与整合 2、创新成果的转化与应用 3、创新文化的培育与弘扬 四、构建智慧化产业新体系 1、优化产业布局与结构…

TBWeb开发版V3.2.6免授权无后门Chatgpt系统源码下载及详细安装教程

TBWeb系统是基于 NineAI 二开的可商业化 TB Web 应用&#xff08;免授权&#xff0c;无后门&#xff0c;非盗版&#xff0c;已整合前后端&#xff0c;支持快速部署&#xff09;。相比稳定版&#xff0c;开发版进度更快一些。前端改进&#xff1a;对话页UI重构&#xff0c;参考C…

年如何在不丢失数据的情况下解锁锁定的 Android 手机?

当您忘记密码、PIN 码或图案并且想要解锁 Android 手机时&#xff0c;您可能会丢失 Android 手机上的数据。但您无需再担心&#xff0c;因为在这里&#xff0c;我们想出了几种解锁锁定的 Android 手机而不丢失数据的方法。 方法 1. 使用 Android Unlock 解锁锁定的 Android 且不…

网络原理-UDP和TCP

在传输层中有两个非常重要的协议&#xff0c;UDP和TCP&#xff0c;现在就来研究一下这两个协议。 UDP 报文格式 我们观察可以发现&#xff0c;里面UDP报文长度为2个字节&#xff0c;那么是多少呢&#xff1f;我们需要快速反应如下固定字节数据类型的取值范围&#xff1a; 字…

【大语言模型LLM】-使用大语言模型搭建点餐机器人

关于作者 行业&#xff1a;人工智能训练师/LLM 学者/LLM微调乙方PM发展&#xff1a;大模型微调/增强检索RAG分享国内大模型前沿动态&#xff0c;共同成长&#xff0c;欢迎关注交流… 大语言模型LLM基础-系列文章 【大语言模型LLM】-大语言模型如何编写Prompt?【大语言模型LL…

微信小程序vue.js+uniapp服装商城销售管理系统nodejs-java

本技术是java平台的开源应用框架&#xff0c;其目的是简化Sping的初始搭建和开发过程。默认配置了很多框架的使用方式&#xff0c;自动加载Jar包&#xff0c;为了让用户尽可能快的跑起来spring应用程序。 SpinrgBoot的主要优点有&#xff1a; 1、为所有spring开发提供了一个更快…

第七周C语言编程题

第七周C语言编程题 第一题 题目&#xff1a;循环结构练习05 用for语句输出倒三角图案 这是一个编程题模板。 要求用for语句&#xff0c;输出指定的由“*”符号组成的倒三角图案。 输入格式: 本题目没有输入。 输出格式: 按照下列格式输出由“*”符号组成的倒三角图案。…