(论文复现)DeepAnt模型复现及应用

DeepAnt论文如下,其主要是用于时间序列的无监督粗差探测

 

其提出的模型架构如下:

        该文提出了一个无监督的时间序列粗差探测模型,其主要有预测模块和探测模块组成,其中预测模块的网络结构如下。
       预测结构是将时间序列数据组织成数据集之后经过两次的卷积和最大池化,最后将卷积结果通过一个全连接层转换为一个输出数据(若是单步预测则输出单元个数为1)
       探测模块是将模型的时序预测结果与该时刻的观测数据相比来计算欧氏距离,以此来作为当前时间点距离的异常分数。以此来作为数据粗差探测的标准。

        (本博客主要是分享复现代码,论文中的细节原理可自行下载学习)


 复现代码(数据不便分享):

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader,TensorDataset
from sklearn.preprocessing import MinMaxScaler,StandardScalerdef MSE(arr1,arr2):arr1,arr2 = np.array(arr1).flatten(),np.array(arr2).flatten()assert arr1.shape[0] == arr2.shape[0]return np.sum(np.power(arr1-arr2,2)) / arr1.shape[0]def MAE(arr1,arr2):arr1,arr2 = np.array(arr1).flatten(),np.array(arr2).flatten()assert arr1.shape[0] == arr2.shape[0]return np.sum(np.abs(arr1-arr2)) / arr1.shape[0]class MyData(Dataset):def __init__(self,arr,history_window,predict_len) -> None:self.length = arr.flatten().shape[0]self.history_window = history_windowself.dataset_x,self.dataset_y = self.get_dataset(arr,history_window,predict_len)def get_dataset(self,arr,history_window,predict_len):arr = np.array(arr).flatten()N = history_windowM = predict_lendataset_x = np.zeros((arr.shape[0] - N,N))dataset_y = np.zeros((arr.shape[0] - N,M))for i in range(arr.shape[0] - N):dataset_x[i] = arr[i:i+N]dataset_y[i] = arr[i+N:i+N+M]dataset_x = torch.from_numpy(dataset_x).to(torch.float)dataset_y = torch.from_numpy(dataset_y).to(torch.float)return (dataset_x,dataset_y)def __getitem__(self, index):		# 定义方法 data[i] 的返回值return (self.dataset_x[index,:],self.dataset_y[index,:])def __len__(self):					# 获取数据集样本个数return self.length - self.history_windowclass DeepAnt(nn.Module):def __init__(self,lag,p_w):super().__init__()self.convblock1 = nn.Sequential(nn.Conv1d(in_channels=1, out_channels=32, kernel_size=3, padding='valid'),nn.ReLU(inplace=True),nn.MaxPool1d(kernel_size=2))self.convblock2 = nn.Sequential(nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3, padding='valid'),nn.ReLU(inplace=True),nn.MaxPool1d(kernel_size=2))self.flatten = nn.Flatten()self.denseblock = nn.Sequential(nn.Linear(32, 40), # for lag = 10#nn.Linear(96, 40), # for lag = 20#nn.Linear(192, 40), # for lag = 30nn.ReLU(inplace=True),nn.Dropout(p=0.25),)self.out = nn.Linear(40, p_w)def forward(self, x):x = x.view(-1,1,lag)x = self.convblock1(x)x = self.convblock2(x)x = self.flatten(x)x = self.denseblock(x)x = self.out(x)return xdef Train(model,data_set,EPOCH,task_id):if torch.cuda.is_available():device = torch.device('cuda')print('cuda is used...')else:torch.device('cpu')print('cpu is used...')scale = StandardScaler()loss_fn = nn.MSELoss()model.to(device)loss_fn.to(device)train_x,train_y = data_set.dataset_x,data_set.dataset_ytrain_x = scale.fit_transform(train_x)train_x = torch.from_numpy(train_x).to(torch.float).to(device)train_y = train_y.to(device).to(torch.float)torch_dataset = TensorDataset(train_x,train_y)optimizer = torch.optim.Adam(model.parameters())BATCH_SIZE = 100model = model.train()train_loss = []print('======Start training...=======')print(f'Epoch is {EPOCH}\ntrain_x shape is {train_x.shape}\nBATCH_SIZE is {BATCH_SIZE}')for i in range(EPOCH):loader = DataLoader(dataset=torch_dataset,batch_size=BATCH_SIZE,shuffle=True)temp_1 = []for step,(batch_x,batch_y) in enumerate(loader):out = model(batch_x)optimizer.zero_grad()loss = loss_fn(out,batch_y)temp_1.append(loss.item())loss.backward()optimizer.step()torch.cuda.empty_cache()train_loss.append(np.mean(np.array(temp_1)))if i % 10 == 0:print(f"The {i}/{EPOCH} is end, loss is {np.round(np.mean(np.array(temp_1)),6)}.")print('========Training end...=======')model = model.eval()plt.plot(train_loss)pred = model(train_x).cpu().data.numpy()print(f'pred shape {pred.shape}')plt.figure()y = train_y.cpu().data.numpy().flatten()print(f'y shape {y.shape}')plt.plot(y,c='b',label='True')plt.plot(pred,'r',label='pred')plt.legend()plt.title('Train_result')plt.show()return predif __name__ == "__main__":data_f = pd.read_csv('HF05_processed.csv')data = np.array(pd.DataFrame(data_f)['OT'])lag = 10dataset = MyData(data,lag,1)model = DeepAnt(lag,1)res = Train(model,dataset,200,'1')data = data[lag:].flatten() plt.plot(data)plt.plot(res,c='r')err = data - res.flatten()anomaly_score = np.sqrt(np.power(err,2))plt.figure()plt.plot(anomaly_score)error_list = []threshold = 0.04for i in range(anomaly_score.shape[0]):if anomaly_score[i] > threshold:error_list.append(i)print(len(error_list))plt.figure()plt.plot(data)plt.plot(error_list,[data[i] for i in error_list],ls='',marker='x',c='r',markersize=4)plt.show()

运行结果:

 

才疏学浅,敬请指正!

欢迎交流:

邮箱:rton.xu@qq.com

QQ:2264787072

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/56613.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SQL Server数据库 -- 索引与视图

文章目录 一、索引 聚集索引非聚集索引二、视图三、自定义函数 标量函数表值函数四、游标五、总结 前言 在学习完创建库表、查询等知识点后,为了更加方便优化数据库的存储和内容,我们需要学习一系列的方法例如索引与视图等等,从而使我们更加…

IDEA中Git面板操作介绍 变基、合并、提取、拉取、签出

IDEA中Git面板操作介绍 变基、合并、提取、拉取、签出 面板介绍 变基、合并 提取、拉取 签出、Checkout 面板介绍 如图,在IDEA的Git面板中,仓库会分为本地仓库和远程仓库,代码仓库里面放的是各个分支。 分支前面的书签🔖标志…

Linux常用命令学习总结

Linux命令分类 1. Linux目录操作命令2. Linux文件名称3. Linux磁盘命令4. Linux进程与防火墙5. Linux用户与组的关系6. Linux权限操作(chmod命令)7. Linux中的文件类型文件的寻找 最近系统地学习下Linux命令的使用,因此作如下记录,以便随时复习和翻阅。 …

微信小程序(van-tabs) 去除横向滚动条样式(附加源码解决方案+报错图)

问题描述 今天第一次接触vant组件库。 ant官网地址适用于Vue3 支持Vue2、Vue3、微信小程序等 我在使用van-tabs组件时遇到了一个问题,如下图所示: 从图片上可以看到有个灰色的横向滚动条,一开始领导给我说这个问题,我反反复复都…

宝塔面板点击SSL闪退打不开怎么解决?

宝塔Linux面板点击SSL证书闪退如何解决?旧版本的宝塔Linux面板确实存在这种情况,如何解决?升级你的宝塔Linux面板即可。新手站长分享宝塔面板SSL闪退的解决方法: 宝塔面板点击SSL证书闪退解决方法 问题:宝塔Linux面板…

关于策略模式的注入问题

上面抄别人的 当在实现策略方法时,报null,排查后发现是接口实现有多个,需要添加别名 注入时添加Qeualifier,指定名称,如下图;如图上修改, 测试类中不用new具体行为策略了,注入别名即…

文心一言 VS 讯飞星火 VS chatgpt (72)-- 算法导论7.1 3题

三、请简要地证明:在规模为 n 的子数组上,PARTITION 的时间复杂度为 Θ(n)。 文心一言: 首先,我们需要明确PARTITION函数的具体定义。PARTITION函数通常用于快速排序算法中,它将一个数组分为两个子数组,使得一个子数…

java-IDEA MAVEN查看依赖树,解决jar包重复和冲突

如果这里面的依赖关系有红线,就说明有包冲突,一般都是版本不一致,可以在idea里下一个插件Maven Helper,点击install并重启IDEA 打开pom.xml文件,在下方会出现Dependency Analyzer,选择它会出现重复依赖列表,选择对应的依赖,右键红…

K8s中的PV和PVC和监控

1.PV和PVC PV:持久化存储,对存储资源进行抽象,对外提供可以调用的地方(类似:生产者) PVC:用于调用,不需要关心内部实现细节(类似:消费者) 2.实…

leetcode原题 路径总和 I II III(递归实现)

路径总和 I : 给你二叉树的根节点 root 和一个表示目标和的整数 targetSum 。判断该树中是否存在 根节点到叶子节点 的路径,这条路径上所有节点值相加等于目标和 targetSum 。如果存在,返回 true ;否则,返回 false 。…

[C++项目] Boost文档 站内搜索引擎(4): 搜索的相关接口的实现、线程安全的单例index接口、cppjieba分词库的使用、综合调试...

有关Boost文档搜索引擎的项目的前三篇文章, 已经分别介绍分析了: 项目背景: 🫦[C项目] Boost文档 站内搜索引擎(1): 项目背景介绍、相关技术栈、相关概念介绍…文档解析、处理模块parser的实现: 🫦[C项目] Boost文档 站内搜索引擎(2): 文档文本解析模块…

第4章 变量、作用域与内存

引言 由于js是一门只有在声明变量后才能明确类型的语言,并且在任意时刻都可以改变数据类型。这也引起了一些问题 原始值与引用值 原始值就是基本数据类型,引言值就是复杂数据类型 变量在赋值的时候。js会判断如果是原始值,访问时就是按值访问…