Data-Free,多目标域适应合并方案,简单又有效 | ECCV24

news/2024/10/30 9:21:57/文章来源:https://www.cnblogs.com/VincentLee/p/18515002

来源:晓飞的算法工程笔记 公众号,转载请注明出处

论文: Training-Free Model Merging for Multi-target Domain Adaptation

  • 论文地址:https://arxiv.org/abs/2407.13771
  • 论文代码:https://air-discover.github.io/ModelMerging

创新点


  • 对域适应的场景解析模型中的模式连通性进行了系统的探索,揭示了模型合并有效的潜在条件。
  • 引入了一种模型合并技术,包括参数合并和缓冲区合并,适用于多目标域适应任务,可应用于任何单目标域适应模型。
  • 在数据可用性受限的情况下,也能达到与使用多个合并数据集进行训练相当的性能。

内容概述


论文研究的是场景理解模型的多目标域适应(MTDA)。虽然之前的方法通过领域间一致性损失取得了可观的结果,但它们通常假设可以不切实际地同时访问所有目标领域的图像,忽略了数据传输带宽限制和数据隐私等问题。鉴于这些挑战,论文提出了一个问题:如何在不直接访问训练数据的情况下合并在不同领域独立适应的模型?

对此问题的解决方案包含两个部分,即合并模型参数和合并模型缓冲区(即归一化层统计数据)。在合并模型参数方面,模式连通性的实证分析意外地表明,对于使用相同的预训练主干权重训练的单独模型,线性合并就足够了。在合并模型缓冲区方面,使用高斯先验来建模现实世界分布,并从单独训练模型的缓冲区中估计新的统计数据。

论文的方法简单而有效,取得了与数据组合训练基线相当的性能,同时消除了访问训练数据的必要性。

方法


以往的方法假设,在适应阶段能够同时访问所有目标领域图像的非实际假设。相反,论文方法的流程包括两个不同的阶段:

  1. 单目标域适应阶段,分别训练适应于各个目标领域的模型。简单地采用最先进的无监督域适应方法HRDA,利用各种主干架构,如ResNet和视觉Transformer
  2. 模型合并阶段(主要关注点),专注于将这些适应后的模型合并在一起以创建一个稳健的模型,而不需要访问任何训练数据。该方法包含模型的两个关键组成部分:参数(即可学习层的权重和偏置)和缓冲区(即归一化层的运行统计信息)。

参数合并

论文通过对比实验发现,当从相同的预训练权重开始时,域适应模型能够有效地过渡到多样的目标领域,同时在参数空间中保持线性模式连接。因此,这些训练模型之间的简单中点合并可以生成在两个领域中都具有鲁棒性的模型。

缓冲区合并

缓冲区,即用于批归一化(BN)层的运行均值和方差,与领域有密切关系,因为它们封装了特定领域的特征。现有方法主要处理在同一领域内对两个训练于不同子集的模型的合并,而论文研究在完全不同目标领域中训练的两个模型的合并,因此缓冲区合并的问题变得不再简单。

BN层的引入是为了缓解内部协变量偏移的问题,即输入的均值和方差在经过内部可学习层时发生变化。在这种背景下,基本考虑是后续的可学习层预期合并的BN层的输出遵循正态分布。由于输出的BN层保留了输入符合高斯先验的归纳偏见,因此可以从 \(\mathbf{\Gamma}_A\)\(\mathbf{\Gamma}_B\) 中获取的值来估计 \(\boldsymbol{\mu}^{(i)}\)\([\boldsymbol{\sigma}^{(i)}]^2\) 。首先获得来自该高斯先验的数据点的均值和方差的两个集合,以及这些集合的大小,共同利用这些值来估计该分布的参数。

当将合并方法扩展到 \(m (m \geq 2)\) 个高斯分布时,可以按如下方式计算已跟踪批次的数量 \(n^{(i)}\) 、均值的加权平均 \(\boldsymbol{\mu}^{(i)}\) 和方差的加权平均。

\[\begin{equation} \label{m-buffer-merging-n-and-mean} \begin{split} n^{(i)} =& n^{(i)}_1 + n^{(i)}_2 + \cdots +n^{(i)}_M, \\ \boldsymbol{\mu}^{(i)} =& \frac{1}{n^{(i)}} (n^{(i)}_1 \boldsymbol{\mu}^{(i)}_1 + n^{(i)}_2 \boldsymbol{\mu}^{(i)}_2 + \cdots + n^{(i)}_M \boldsymbol{\mu}^{(i)}_M),\\ \boldsymbol{\sigma}^2 =& \frac{\sum_{j=1}^{M} n^{(i)} (\boldsymbol{\sigma}^i_j)^2 + \sum_{j=1}^{M} n_j^i (\boldsymbol{\mu}_j^i - \boldsymbol{\mu}^i)^2}{\sum_{j=1}^{M} n_j^i}. \end{split} \end{equation} \]

主要实验




如果本文对你有帮助,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】

work-life balance.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/824011.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

项目管理知识体系梳理

经常在做项目,但项目管理体系在大脑里面是混乱的,今天特意画一个图加深一下印象。关注公众号了解更多知识:

面试官:Spring Boot 控制层中,@Service 可以完全替代 @Controller 吗?90% 都会答错!

作者:毅航 来源:juejin.cn/post/7393533304505204787 在SpringBoot开发中,@Controller和@Service基本上是日常开发中使用的最频繁的两个注解。但你有没考虑过@Service代替@Controller注解来标注到控制层的场景?换言之,经过@Service标注的控制层能否实现将用户请求分发到服…

两台linux的文件传输

起因 本地拉取docker镜像timeout,然后就准备把阿里云上已经在运行的镜像打包下载下来。 指令1:rsync rsync 是一个非常强大的工具,用于文件同步和高效的数据传输。它可以用于备份、文件传输以及数据同步等多种场景。 rsync 的主要优点在于其高效性和灵活性,特别是在处理大量…

10 早期计算机如何编程

程序需要加载进入内存, 最早是纺织机利用穿孔纸卡进行编程,穿孔纸卡用在过人口普查,用于记录一条条数字,但机器只有汇总功能,汇总穿孔数目 后来机器功能增多,人需要一个控制面板执行不同操作, 最早是重新布线更换指令,后来有了插线板,控制面板成了可拔插,可以给机器插…

PbootCMS模板首页循环调用所有栏目和对应内容

{pboot:nav} 栏目链接:[nav:link] 栏目名称:[nav:name] {pboot:list scode=[nav:scode] num=4 order=date} 内容链接:[list:link] 内容名称:[list:title] 内容图片:[list:ico] 内容时间:[list:date style=Y-m-d] 内容描述:[list:description] {/pboot:list} {/pboot:nav…

PbootCMS自带的sitemap.xml增加tag标签链接

修改 SitemapModel.php 文件:打开 /apps/home/model/SitemapModel.php 文件 在 78 行后面增加以下代码:public function getSortTags($scode) {$join = array(array(ay_content_sort b, a.scode=b.scode, LEFT),array(ay_model c, b.mcode=c.mcode, LEFT));$scode_arr = arra…

PbootCMS 面包屑导航样式修改和自定义的设置方法

问题:PbootCMS面包屑导航样式修改和自定义的设置方法。 答案:面包屑调用:{pboot:position}自定义参数:separator=*:分隔符,默认为 >>。 separatoricon=*:分割图标,例如 separatoricon=fa fa-angle-double-right。 indextext=*:首页文本,默认为“首页”。 index…

Maximum execution time of 30 secon

这种问题出现在Web开发环境中,特别是PHP等脚本语言中,当某个脚本运行时间超过预设的最大执行时间(例如30秒)时,服务器会终止该脚本的执行以防止资源被长时间占用。 解决方案增加脚本的最大执行时间在PHP中,可以通过修改php.ini文件中的max_execution_time值来增加脚本的最…

工地货梯AI人数识别系统

工地货梯AI人数识别系统采用人体神经网络深度学习算法,工地货梯AI人数识别系统对升降机轿厢内的人数进行智能分析和识别,能够精确识别出升降机内的人数。系统可以实时监测升降机内的人数变化,并根据设定的门限值,当人数超过限制时自动触发图像抓取和报警功能。报警方式可以…

物品堆放限高监测系统

物品堆放限高监测系统采用神经网络深度学习算法,物品堆放限高监测系统能够实时监测物品堆放区域的状态。通过在现场安装监控摄像头,系统对摄像头拍摄的实时视频进行处理分析,识别并判断物品堆放的高度情况。当堆放超过限定的高度范围时,系统将立即触发语音告警功能。物品堆…

煤块堵塞监测识别系统

煤块堵塞监测识别系统利用现场摄像头实时监测煤矿生产线上的皮带煤块堵塞情况。煤块堵塞监测识别系统可以准确地识别出堆积在生产线上的煤块,并计算出其堆积的程度。当煤块堆积的程度超过预设的警戒范围时,系统会立刻通知相关工作人员前往现场进行物料疏通。与传统人工巡检相…

1、K8S环境渗透学习

一、概述Kubernetes,简称k8s,是当前主流的容器调度平台,被称为云原生时代的操作系统。在实际项目也经常发现厂商部署了使用k8s进行管理的云原生架构环境,在目前全面上云的趋势,有必要学习在k8s环境的下的一些攻击手法。 二、k8s用户 Kubernetes 集群中包含两类用户:一类是…