联邦学习中的模型聚合

目录

联邦学习中的模型聚合

1.client-server 算法

2. fully decentralized(完全去中心化)算法


联邦学习中的模型聚合

在联邦学习的情景下引入了多任务学习,其采用的手段是使每个client/task节点的训练数据分布不同,从而使各任务节点学习到不同的模型,且每个任务节点以及全局(global)的模型都由多个分量模型集成。该论文最关键与核心的地方在于将各任务节点学习到的模型进行聚合/通信,依据模型聚合方式的不同,可以将模型采用的算法分为client-server方法,和fully decentralized(完全去中心化)的方法

因为有多种任务聚合器(Aggregator)要实现,采取的措施是先实现Aggregator抽象基类,实现好一些通用方法,并规定好抽象方法的接口,然后具体的任务聚合类继承抽象基类,然后做具体的实现。

我们先来看任务聚合器(Aggregator)这一抽象基类

class Aggregator(ABC):r"""Aggregator的基类. `Aggregator`规定了client之间的通信"""def __init__(self,clients,global_learners_ensemble,log_freq,global_train_logger,global_test_logger,sampling_rate=1.,sample_with_replacement=False,test_clients=None,verbose=0,seed=None,*args,**kwargs):rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))self.rng = random.Random(rng_seed) # 随机数生成器self.np_rng = np.random.default_rng(rng_seed) # numpy随机数生成器if test_clients is None:test_clients = []self.clients = clients #  List[Client]self.test_clients = test_clients #  List[Client]self.global_learners_ensemble = global_learners_ensemble # List[Learner]self.device = self.global_learners_ensemble.deviceself.log_freq = log_freqself.verbose = verbose# verbose: 调整输出打印的冗余度(verbosity), # `0` 表示quiet(无任何打印输出), `1` 显示日志, `2` 显示所有局部日志; 默认是 `0`self.global_train_logger = global_train_loggerself.global_test_logger = global_test_loggerself.model_dim = self.global_learners_ensemble.model_dim # #模型特征维度self.n_clients = len(clients)self.n_test_clients = len(test_clients)self.n_learners = len(self.global_learners_ensemble)# 存储为每个client分配的权重(权重为0-1之间的小数)self.clients_weights =\torch.tensor([client.n_train_samples for client in self.clients],dtype=torch.float32)self.clients_weights = self.clients_weights / self.clients_weights.sum()self.sampling_rate = sampling_rate  #  clients在每一轮使用的比例,默认为`1.`self.sample_with_replacement = sample_with_replacement #对client进行采用是可重复还是无重复的,with_replacement=True表示可重复的,否则是不可重复的# 每轮迭代需要使用到的client个数self.n_clients_per_round = max(1, int(self.sampling_rate * self.n_clients))# 采样得到的client列表self.sampled_clients = list()# 记载当前的迭代通信轮数self.c_round = 0 self.write_logs()@abstractmethoddef mix(self): """该方法用于完成各client之间的权重参数与通信操作"""pass@abstractmethoddef update_clients(self): """该方法用于将所有全局分量模型拷贝到各个client,相当于boardcast操作"""passdef update_test_clients(self):"""将全局(gobal)的所有分量模型都拷贝到各个client上"""def write_logs(self):"""对全局(global)的train和test数据集的loss和acc做记录需要对所有client的所有样本做累加,然后除以所有client的样本总数做平均。"""def save_state(self, dir_path):"""保存aggregator的模型state,。例如, `global_learners_ensemble`中每个分量模型'learner'的state字典(以`.pt`文件格式),以及`self.clients` 中每个client的 `learners_weights` (注意,这个权重不是模型内部的参数,而是进行继承的时候对各个分量模型赋予的权重,包含train和test两部分,以一个大小为n_clients(n_test_clients)× n_learners的numpy数组的格式,即`.npy` 文件)。"""def load_state(self, dir_path):"""加载aggregator的模型state,即save_state方法里保存的那些"""def sample_clients(self):"""对clients进行采样,如果self.sample_with_replacement为True,则为可重复采样,否则,则为不可重复采用。最终得到一个clients子集列表并赋予self.sampled_clients"""

1.client-server 算法

这种方式的通信/聚合方法也称中心化(centralized)方法,因为该方法在每一轮迭代最后将所有client的权重数据汇集到server节点。这种方法的优化迭代部分的伪代码示意如下:

 

落实到具体代码实现上,这种方法的Aggregator设计如下:

class CentralizedAggregator(Aggregator):r""" 标准的中心化Aggreagator所有clients在每一轮迭代末和average client完全同步."""def mix(self):self.sample_clients()# 对self.sampled_clients中每个client的参数进行优化for client in self.sampled_clients:# 相当于伪代码第11行调用的LocalSolver函数client.step()# 遍历global模型(self.global_learners_ensemble) 中每一个分量模型(learner)# 相当于伪代码第13行for learner_id, learner in enumerate(self.global_learners_ensemble):# 获取所有client中对应learner_id的分量模型learners = [client.learners_ensemble[learner_id] for client in self.clients]# global模型的分量模型为所有client对应分量模型取平均,相当于伪代码第14行average_learners(learners, learner, weights=self.clients_weights)# 将更新后的模型赋予所有clients,相当于伪代码第5行的boardcast操作self.update_clients()# 通信轮数+1self.c_round += 1if self.c_round % self.log_freq == 0:self.write_logs()def update_clients(self):"""此函数负责将所有全局分量模型拷贝到各个client,相当于伪代码中第5行的boardcast操作"""for client in self.clients:for learner_id, learner in enumerate(client.learners_ensemble):copy_model(learner.model, self.global_learners_ensemble[learner_id].model)if callable(getattr(learner.optimizer, "set_initial_params", None)):learner.optimizer.set_initial_params(self.global_learners_ensemble[learner_id].model.parameters())

2. fully decentralized(完全去中心化)算法

这种方法之所以被称为去中心化的,因为该方法在每一轮迭代不需要所有client的权重数据汇集到一个特定的server节点,而只需要完成每个节点和其邻居进行通信(参数共享)即可。这种方法的优化迭代部分的伪代码示意如下:

落实到具体代码实现上,这种方法的Aggregator设计如下:

 

class DecentralizedAggregator(Aggregator):def __init__(self,clients,global_learners_ensemble,mixing_matrix,log_freq,global_train_logger,global_test_logger,sampling_rate=1.,sample_with_replacement=True,test_clients=None,verbose=0,seed=None):super(DecentralizedAggregator, self).__init__(clients=clients,global_learners_ensemble=global_learners_ensemble,log_freq=log_freq,global_train_logger=global_train_logger,global_test_logger=global_test_logger,sampling_rate=sampling_rate,sample_with_replacement=sample_with_replacement,test_clients=test_clients,verbose=verbose,seed=seed)self.mixing_matrix = mixing_matrixassert self.sampling_rate >= 1, "partial sampling is not supported with DecentralizedAggregator"def update_clients(self):passdef mix(self):# 对各clients的模型参数进行优化for client in self.clients:client.step()# 存储每个模型各参数混合的权重# 行对应不同的client,列对应单个模型中不同的参数# (注意:每个分量有独立的mixing_matrix)mixing_matrix = torch.tensor(self.mixing_matrix.copy(),dtype=torch.float32,device=self.device)# 遍历global模型(self.global_learners_ensemble) 中每一个分量模型(learner)# 相当于伪代码第14行for learner_id, global_learner in enumerate(self.global_learners_ensemble):# 用于将指定learner_id的各client的模型state读出暂存state_dicts = [client.learners_ensemble[learner_id].model.state_dict() for client in self.clients]# 遍历global模型中的各参数, key对应模型中参数的名称for key, param in global_learner.model.state_dict().items():shape_ = param.shapemodels_params = torch.zeros(self.n_clients, int(np.prod(shape_)), device=self.device)for ii, sd in enumerate(state_dicts):# models_params的第ii个下标存储的是第ii个client的(名为key的)参数models_params[ii] = sd[key].view(1, -1) # models_params的每一行是一个client的参数# @符号表示矩阵乘/矩阵向量乘# 故这里表示每个client参数是其他所有client参数的混合models_params = mixing_matrix @ models_paramsfor ii, sd in enumerate(state_dicts):# 将第ii个client的(名为key的)参数存入state_dicts中对应位置sd[key] = models_params[ii].view(shape_)# 将更新好的参数从state_dicts存入各client节点的模型中for client_id, client in enumerate(self.clients):client.learners_ensemble[learner_id].model.load_state_dict(state_dicts[client_id])# 通信轮数+1self.c_round += 1if self.c_round % self.log_freq == 0:self.write_logs()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/7988.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ELK增量同步数据【MySql->ES】

一、前置条件 1. linux,已经搭建好的logstasheskibana【系列版本7.0X】,es 的plugs中安装ik分词器 ES版本: Logstash版本: (以上部署,都是运维同事搞的,我不会部署,同事给力&#…

Neighborhood Contrastive Learning for Novel Class Discovery (CVPR 2021)

Neighborhood Contrastive Learning for Novel Class Discovery (CVPR 2021) 摘要 在本文中,我们解决了新类发现(NCD)的问题,即给定一个具有已知类的有标签数据集,在一组未标记的样本中揭示新的类。我们利用ncd的特性构建了一个新的框架&am…

JMeter之事务控制器实践

目录 前言 事务控制器 JMeter控制器添加路径: Generate parent sample 1、不勾选任何选项: 2、勾选【Generate parent sample】 3、Include duration of timer and pre-post processors in generated sample 小结 前言 在JMeter中,事…

Linux下GO IDE安装和配置(附快捷键)

目前,GoLand、VSCode 这些 IDE 都很优秀,但它们都是 Windows 系统下的 IDE。在 Linux 系统下我们可以选择将 Vim 配置成 Go IDE。熟练 Vim IDE 操作之后,开发效率不输 GoLand 和 VSCode。有多种方法可以配置一个 Vim IDE,这里我选…

基于免疫优化算法的物流配送中心选址规划研究(Matlab实现)

目录 1 概述 2 物流配送中心选址规划研究 3 Matlab代码 4 结果 1 概述 影响物流配送中心选址的因素有很多,精确选址优化问题亟待解决。通过充分考虑货物的配送时间,将免疫算法加入其中,介绍了物流配送选址模型的构建以及免疫算法实现的相关步骤,最后利用matlab软件进行分析,提出…

UE5.2 LyraDemo源码阅读笔记(二)

UE5.2 LyraDemo源码阅读笔记(二) 创建了关卡中的体验玩家Actor和7个体验玩法入口之后。 接下来操作关卡中的玩家与玩法入口交互,进入玩法入口,选择进入B_LyraFrontEnd_Experience玩法入口,也就是第3个入口。触发以下请…

web学习1--maven--项目管理工具

写在前面: 这学期搞主攻算法去了,web的知识都快忘了。开始复习学习了。 文章目录 maven介绍功能介绍maven安装jar包搜索仓库 pom文件项目介绍父工程依赖管理属性控制可选依赖构建 依赖管理依赖的传递排除依赖可选依赖 maven生命周期分模块开发模块聚合…

算法的时间复杂度

算法的时间复杂度 什么是时间复杂度 时间复杂度是衡量算法执行时间随输入规模增长而增长的度量标准。它描述了算法运行时间与问题规模之间的关系,用于评估算法的效率和性能。 通常情况下,时间复杂度表示为大O符号(O)&#xff0…

限时等待的互斥量

本文结束一种新的锁&#xff0c;称为 timed_mutex 代码如下&#xff1a; #include<iostream> #include<mutex> #include<thread> #include<string> #include<chrono>using namespace std;timed_mutex tmx;void fun1(int id, const string&a…

C/C++的发展历程和未来趋势

文章目录 C/C的起源C/C的应用C/C开发的工具C/C未来趋势 C/C的起源 C语言 C语言是一种通用的高级编程语言&#xff0c;由美国计算机科学家Dennis Ritchie在20世纪70年代初期开发出来。起初&#xff0c;C语言是作为操作系统UNIX的开发语言而创建的。C语言的设计目标是提供一种功…

基于Springboot+Vue的手机商城(源代码+数据库)081

基于SpringbootVue的手机商城(源代码数据库)081 一、系统介绍 本项目前后端分离&#xff08;该项目还有ssmvue版本&#xff09; 本系统分为管理员、用户两种角色 用户角色包含以下功能&#xff1a; 登录、注册、商品搜索、收藏、购物车、订单提交、评论、退款、收货地址管…

【C++】vector模拟实现

&#x1f680; 作者简介&#xff1a;一名在后端领域学习&#xff0c;并渴望能够学有所成的追梦人。 &#x1f681; 个人主页&#xff1a;不 良 &#x1f525; 系列专栏&#xff1a;&#x1f6f8;C &#x1f6f9;Linux &#x1f4d5; 学习格言&#xff1a;博观而约取&#xff0…