运动想象 (MI) 迁移学习系列 (5) : SSMT

运动想象迁移学习系列:SSMT

  • 0. 引言
  • 1. 主要贡献
  • 2. 网络结构
  • 3. 算法
  • 4. 补充
    • 4.1 为什么设置一种新的适配器?
    • 4.2 动态加权融合机制究竟是干啥的?
  • 5. 实验结果
  • 6. 总结
  • 欢迎来稿

论文地址:https://link.springer.com/article/10.1007/s11517-024-03032-z
论文题目:Semi-supervised multi-source transfer learning for cross-subject EEG motor imagery classification
论文代码:无

0. 引言

脑电图(EEG)运动意象(MI)分类是指利用脑电信号对受试者的运动意象活动进行识别和分类;随着脑机接口(BCI)的发展,这项任务越来越受到关注。然而,脑电图数据的收集通常是耗时且劳动密集型的,这使得很难从新受试者那里获得足够的标记数据来训练新模型。此外,不同个体的脑电信号表现出显着差异,导致在直接对从新受试者获得的脑电信号进行分类时,在现有受试者上训练的模型的性能显着下降。因此,充分利用现有受试者的脑电数据和新目标受试者的未标记脑电数据,提高目标受试者达到的心肌梗死分类性能至关重要。本研究提出了一种半监督多源迁移(SSMT)学习模型来解决上述问题;该模型学习信息和域不变表示,以解决跨主题的 MI-EEG 分类任务。具体而言,该文提出了一种动态转移加权模式,通过整合从多源域派生的加权特征来获得最终预测。

文中主要解决方法是针对无监督的脑电数据迁移学习方案,是一个不错的角度,也提出了很有新意的算法设计!!!

1. 主要贡献

  1. 一种基于 MMDCMMD域适应方法,用于解决单个 MI-EEG信号差异的问题,对齐每个源域和靶域之间的条件和边际分布差异。此外,伪标签被应用于目标域的未标记数据,并在整个训练过程中迭代更新。通过这种方式,条件分布信息将更新为近似真实的条件分布。
  2. 基于域间差异度量设计了一种动态权重转移模型,使每个源域能够根据其与目标域的相似性为训练过程做出贡献。因此,通过减轻与目标域显著差异的源域的不利影响,可以进一步提高分类器对目标域的预测性能。
  3. 通过一系列实验,在两个公开可用的 BCI数据集上评估了所提出的方法。结果表明,所提方法的每一项创新都有助于提高解码性能,与基线相比,解码性能更好。

2. 网络结构

在这里插入图片描述
SSMT两个主要阶段组成。预训练阶段预训练所有可用于在特征提取任务和原始监督分类任务中训练的标记数据,以获得仅包含特征提取器和分类器的全局模型。然后,利用预训练模型对目标域的未标记数据进行伪标记;再训练阶段包括三个主要步骤。首先,域适配器旨在减少每个源域和目标域之间的差异。然后,使用伪标签信息并不断更新以优化模型。最后,最终决策由MLP分类器的转移权重融合产生。

3. 算法

符号说明
{ X s k , y s k } k = 1 n \{X_s^k, y_s^k\}_{k=1}^n {Xsk,ysk}k=1n 表示存在n个源域 X t X_t Xt 表示目标域,包含两个部分,分别是 X l X_l Xl X u X_u Xu; X l X_l Xl y l y_l yl 表示目标域中已知(标记)的样本 X u X_u Xu 表示目标域中未标记的样本,即也不知道其对应的类别。

SSMT算法步骤

输入: { X s k , y s k } k = 1 n , X l , y l , X u \{X_s^k, y_s^k\}_{k=1}^n, X_l, y_l, X_u {Xsk,ysk}k=1n,Xl,yl,Xu

  1. 初始化权重参数 θ f , θ c \theta_f, \theta_c θf,θc

  2. 通过输入 { X s k , y s k } k = 1 n , X l , y l \{X_s^k, y_s^k\}_{k=1}^n, X_l, y_l {Xsk,ysk}k=1n,Xl,yl 直接训练预训练模型中的特征提取器 G f G_f Gf 和MLP分类器 G c G_c Gc , 并根据下面等式更新参数 θ f , θ c \theta_f, \theta_c θf,θc L c = − ∑ k = 1 n y s k ⋅ log ⁡ ( G c ( G f ( X s k ; θ f ) ; θ c ) ) − y l ⋅ log ⁡ ( G c ( G f ( X l ; θ f ) ; θ c ) ) , \begin{aligned} L_c= & {} -\sum _{k=1}^n \textbf{y}^k_s\cdot \log (G_c(G_f(\textbf{X}^k_s;\theta _f);\theta _c))\nonumber \\{} & {} -\textbf{y}_l\cdot \log (G_c(G_f(\textbf{X}_l;\theta _f);\theta _c)), \end{aligned} Lc=k=1nysklog(Gc(Gf(Xsk;θf);θc))yllog(Gc(Gf(Xl;θf);θc)),

  3. 生成测试集的伪标签: y ^ u = G c ( G f ( X u ; θ f ) ; θ c ) , \begin{aligned} \hat{\textbf{y}}_u=G_c(G_f(\textbf{X}_u;\theta _f);\theta _c), \end{aligned} y^u=Gc(Gf(Xu;θf);θc), 预训练阶段结束

  4. X l X_l Xl X u X_u Xu 的数据合并为目标域 X t X_t Xt,并连接所有域的数据(将 X s k X_s^k Xsk X t X_t Xt 的数据进行连接)

  5. 重复

  6. 将连接的数据输入 G f G_f Gf 来得到所有域的特征:
    F = [ G f ( X s 1 ; θ f ) , . . . , G f ( X s n ; θ f ) , G f ( X t ; θ f ) ] T F=[G_f(X_s^1;\theta_f),...,G_f(X_s^n;\theta_f),G_f(X_t;\theta_f)]^T F=[Gf(Xs1;θf),...,Gf(Xsn;θf),Gf(Xt;θf)]T

  7. 根据以下公式获取每个源域的差异损失转移权重: L d k = M M D ( D s k , D t ) + C M M D ( D s k , D t ) . \begin{aligned} L_d^k=MMD(\mathcal {D}^k_s, \mathcal {D}_t)+CMMD(\mathcal {D}^k_s, \mathcal {D}_t). \end{aligned} Ldk=MMD(Dsk,Dt)+CMMD(Dsk,Dt). C M M D ( D s k , D t ) = ∑ c = 1 C ∥ 1 m c ∑ x s k , i ∣ y s k , i = c ϕ ( G f ( x s k , i ; θ f ) ) − 1 n ^ c + n c ( ∑ x l i ∣ y l i = c ϕ ( G f ( x l i ; θ f ) ) + ∑ x u i ∣ y ^ u i = c ϕ ( G f ( x u i ; θ f ) ) ∥ , \begin{aligned} CMMD(\mathcal {D}^k_s, \mathcal {D}_t)= & {} \sum _{c=1}^C\Vert \frac{1}{m_c} \sum _{\textbf{x}_s^{k,i} |y^{k,i}_s=c} \phi (G_f(\textbf{x}_s^{k,i};\theta _f))\nonumber \\{} & {} -\frac{1}{\hat{n}_c+n_c}(\sum _{\textbf{x}_l^i |{y}_l^i=c} \phi (G_f(\textbf{x}_l^i;\theta _f))\nonumber \\{} & {} +\sum _{\textbf{x}_u^i |\hat{y}_u^i=c} \phi (G_f(\textbf{x}_u^i;\theta _f))\Vert , \end{aligned} CMMD(Dsk,Dt)=c=1Cmc1xsk,iysk,i=cϕ(Gf(xsk,i;θf))n^c+nc1(xliyli=cϕ(Gf(xli;θf))+xuiy^ui=cϕ(Gf(xui;θf)), M M D ( D s k , D t ) = ∥ 1 n s k ∑ i = 1 n s k ϕ ( G f ( x s k , i ; θ f ) ) − 1 n t ∑ i = 1 n t ϕ ( G f ( x t i ; θ f ) ) ∥ , \begin{aligned} MMD\left( \mathcal {D}^k_s, \mathcal {D}_t\right)= & {} \Bigg \Vert \frac{1}{n^k_s} \sum _{i=1}^{n^k_s} \phi (G_f(\textbf{x}_s^{k,i};\theta _f))\nonumber \\{} & {} - \frac{1}{n_t} \sum _{i=1}^{n_t} \phi (G_f(\textbf{x}_t^i;\theta _f))\Bigg \Vert , \end{aligned} MMD(Dsk,Dt)= nsk1i=1nskϕ(Gf(xsk,i;θf))nt1i=1ntϕ(Gf(xti;θf)) ,

  8. 基于下面式子对每个域的特征进行动态加权,然后将 F ∗ F^* F 作为 G c G_c Gc 的输入:

    w = [ W d 1 , … , W d n ] ⊤ = [ K − L d 1 2 ∑ k = 1 n K − L d k 2 , … , K − L d n 2 ∑ k = 1 n K − L d k 2 ] ⊤ , \begin{aligned} \textbf{w}= & {} [W^1_d, \ldots , W^n_d]^{\top }\nonumber \\= & {} \left[ \frac{K^{- {L_d^1}^2}}{\sum _{k=1}^n K^{- {L_d^k}^2}}, \ldots , \frac{K^{- {L_d^n}^2}}{\sum _{k=1}^n K^{- {L_d^k}^2}}\right] ^{\top }, \end{aligned} w==[Wd1,,Wdn][k=1nKLdk2KLd12,,k=1nKLdk2KLdn2], F ∗ = [ F s 1 ∗ , … , F s n ∗ , F t ] ⊤ = [ W d 1 F s 1 , … , W d n F s n , F t ] ⊤ , \begin{aligned} \textbf{F}^*=[{\textbf{F}^1_s}^*,\ldots ,{\textbf{F}^n_s}^*,\textbf{F}_t]^\top =[W^1_d\textbf{F}^1_s,\ldots ,W^n_d\textbf{F}^n_s,\textbf{F}_t]^\top , \end{aligned} F=[Fs1,,Fsn,Ft]=[Wd1Fs1,,WdnFsn,Ft],

  9. 根据下面等式,通过最小化 L L L 更新参数 θ f , θ c \theta_f, \theta_c θf,θc

L = L c + λ L d , \begin{aligned} L=L_c+\lambda L_d, \end{aligned} L=Lc+λLd,

  1. 通过预测 X u X_u Xu 更新 y ^ u \hat{y}_u y^u

  2. 直到收敛

  3. 返回 y ^ u \hat{y}_u y^u

4. 补充

4.1 为什么设置一种新的适配器?

最近的研究表明,随着域间差异的增加,分类器对特征的可转移性显着降低,这表明直接转移提取的特征是一种不安全的策略。因此,在不考虑个体信号差异的情况下,使用所有可用数据进行预训练的模型可能会导致目标受试者分类的性能下降。为了防止传统两级流水线引起的分布过拟合问题,设计了一种域适配器来减轻单个信号差异的负面影响。

尽管经典MMD已被广泛用作分布差异度量,但现有研究表明,在处理类权重偏差(即类不平衡数据)时,MMD并不总是可靠的。调查发现类条件分布之间的差异 P s ( x s k , i ∣ y s k , i = c ) P_s\left( \textbf{x}_s^{k,i} \mid y^{k,i}_s=c\right) Ps(xsk,iysk,i=c) P t ( x l i ∣ y l i = c ) P_t\left( \textbf{x}_l^i \mid y_l^i=c\right) Pt(xliyli=c)可以提供更合适的域差异量表,并导致卓越的域适应性能。什么时候 P s ( x s k , i ∣ y s k , i = c ) = P t ( x l i ∣ y l i = c ) P_s\left( \textbf{x}_s^{k,i} \mid y^{k,i}_s=c\right) =P_t\left( \textbf{x}_l^i \mid y_l^i=c\right) Ps(xsk,iysk,i=c)=Pt(xliyli=c),在源域中学习的分类器可以更安全地应用于目标域。基于这一概念,引入了条件最大均值差异(CMMD)度量,以对齐所有源域和目标域特征的类条件分布.

4.2 动态加权融合机制究竟是干啥的?

从所有数据中获得的特征 G f G_f Gf 可直接用于输入 G c G_c Gc 用于训练,但分类器的这种无歧视训练输入可能会导致不良结果。这一结果可归因于负转移当通过蛮力利用与目标关系不相关的来源时,就会发生负转移,从而导致对目标域的分类器预测有偏差。

为了减轻负迁移的影响,分类器被赋予了动态加权特征,用于最终决策融合。

5. 实验结果

对比实验结果:
在这里插入图片描述
消融实验结果:

在这里插入图片描述

  • PT:PT是仅包含特征提取器和MLP分类器的基本模型,可以完成简单的特征提取和分类任务。
  • DA:域适配器 (DA) 基于 MMD 和 CMMD。特别是,DA 仅使用通过预训练生成的伪标签来计算域间差异。
  • SS:SS 是一个迭代标签更新器。它的作用是在重新训练过程中周期性地生成和更新伪标签。
  • WF:WF是指动态加权模型,它对来自多源域的加权特征进行动态加权和整合。

6. 总结

到此,使用 SSMT 已经介绍完毕了!!! 如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。

如果觉得这篇文章对你有用,记得点赞、收藏并分享给你的小伙伴们哦😄。

欢迎来稿

欢迎投稿合作,投稿请遵循科学严谨、内容清晰明了的原则!!!! 有意者可以后台私信!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/529059.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux之线程控制

目录 一、POSIX线程库 二、线程的创建 三、线程等待 四、线程终止 五、分离线程 六、线程ID:pthread_t 1、获取线程ID 2、pthread_t 七、线程局部存储:__thread 一、POSIX线程库 由于Linux下的线程并没有独立特有的结构,所以Linux并…

白炽灯护眼还是LED护眼?热门护眼台灯全方位实测推荐

白炽灯是制热发光体,电流通过发热时产生热量,灯丝温度达到2000摄氏度以上,灯丝处于白炽状态了,它的光电转换效率低,费电,寿命也比较短,最主要的一点是光线会比较集中刺眼,而且调节性…

Cassandra 安装部署

文章目录 一、概述1.官方文档2. 克隆服务器3.安装准备3.1.安装 JDK 113.2.安装 Python3.3.下载文件 二、安装部署1.配置 Cassandra2.启动 Cassandra3.关闭Cassandra4.查看状态5.客户端连接服务器6.服务运行脚本 开源中间件 # Cassandrahttps://iothub.org.cn/docs/middleware/…

深入理解Java多线程与线程池:提升程序性能的利器

✨✨谢谢大家捧场,祝屏幕前的小伙伴们每天都有好运相伴左右,一定要天天开心哦!✨✨ 🎈🎈作者主页: 喔的嘛呀🎈🎈 目录 引言 一、实现多线程 1.1. 继承Thread类 1.2. 实现Runnab…

趣学前端 | Taro迁移完成之后,总结了一些踩坑经验

背景 四月份的时候,尝试将老的移动端项目改造成多端。因为老项目使用的React框架,综合考量,保障当前业务开发的进度同时,进行项目迁移,所以最后选择了Taro框架。迁移成本会低一些,上手快一些。 上个月&am…

js之原型链

在JavaScript中,原型链是一种用于实现继承和属性查找的机制。每个对象都有一个内部属性[[Prototype]],这个属性指向创建该对象时使用的构造函数的“prototype"属性。对象的方法和属性定义在它的原型对象上。 1.原型(Prototypes&#xf…

AIGC实战——GPT(Generative Pre-trained Transformer)

AIGC实战——GPT 0. 前言1. GPT 简介2. 葡萄酒评论数据集3. 注意力机制3.1 查询、键和值3.2 多头注意力3.3 因果掩码 4. Transformer4.1 Transformer 块4.2 位置编码 5. 训练GPT6. GPT 分析6.1 生成文本6.2 注意力分数 小结系列链接 0. 前言 注意力机制能够用于构建先进的文本…

mysql中 多表查询介绍

在 MySQL 中,多表查询是 SQL 语句的重要组成部分,用于从两个或多个表中检索数据。多表查询可以帮助我们更灵活地处理复杂的数据关系,并从中获取所需的信息。以下是 MySQL 中常见的多表查询及其特点、区别和应用场景。 常见多表查询 1. **内连…

java中几种对象存储(文件存储)中间件的介绍

一、前言 在博主得到系统中使用的对象存储主要有OSS(阿里云的对象存储) COS(腾讯云的对象存储)OBS(华为云的对象存储)还有就是MinIO 这些玩意。其实这种东西大差不差,几乎实现方式都是一样&…

【JAVA】CSS2:样式、选择器、伪类、颜色、字体、边框、列表、背景、盒子、布局、浮动

本文介绍了CSS样式、选择器、伪类、像素、颜色、字体、边框、列表、表格属性、背景、盒子、布局与浮动 1.样式 1.1 行内样式 <h1 style"color: aqua;font-size: large;">123</h1> 1.2 内部样式 <style>h1{color: red;font: 100;}</style>…

从16-bit 到 1.58-bit :大模型内存效率和准确性之间的最佳权衡

通过量化可以减少大型语言模型的大小&#xff0c;但是量化是不准确的&#xff0c;因为它在过程中丢失了信息。通常较大的llm可以在精度损失很小的情况下量化到较低的精度&#xff0c;而较小的llm则很难精确量化。 什么时候使用一个小的LLM比量化一个大的LLM更好? 在本文中&a…

关于比特币的AI对话

【ChatGPT】 比特币源码开源吗&#xff1f; 是的&#xff0c;比特币的源码是开源的。比特币项目是在MIT许可证下发布的&#xff0c;这意味着任何人都可以查看、修改、贡献和分发代码。比特币的源码托管在GitHub上&#xff0c;可以通过下面的链接进行访问&#xff1a; https://g…