目录
联邦学习中的模型聚合
1.client-server 算法
2. fully decentralized(完全去中心化)算法
联邦学习中的模型聚合
在联邦学习的情景下引入了多任务学习,其采用的手段是使每个client/task节点的训练数据分布不同,从而使各任务节点学习到不同的模型,且每个任务节点以及全局(global)的模型都由多个分量模型集成。该论文最关键与核心的地方在于将各任务节点学习到的模型进行聚合/通信,依据模型聚合方式的不同,可以将模型采用的算法分为client-server方法,和fully decentralized(完全去中心化)的方法
因为有多种任务聚合器(Aggregator)要实现,采取的措施是先实现Aggregator抽象基类,实现好一些通用方法,并规定好抽象方法的接口,然后具体的任务聚合类继承抽象基类,然后做具体的实现。
我们先来看任务聚合器(Aggregator)这一抽象基类
class Aggregator(ABC):r"""Aggregator的基类. `Aggregator`规定了client之间的通信"""def __init__(self,clients,global_learners_ensemble,log_freq,global_train_logger,global_test_logger,sampling_rate=1.,sample_with_replacement=False,test_clients=None,verbose=0,seed=None,*args,**kwargs):rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))self.rng = random.Random(rng_seed) # 随机数生成器self.np_rng = np.random.default_rng(rng_seed) # numpy随机数生成器if test_clients is None:test_clients = []self.clients = clients # List[Client]self.test_clients = test_clients # List[Client]self.global_learners_ensemble = global_learners_ensemble # List[Learner]self.device = self.global_learners_ensemble.deviceself.log_freq = log_freqself.verbose = verbose# verbose: 调整输出打印的冗余度(verbosity), # `0` 表示quiet(无任何打印输出), `1` 显示日志, `2` 显示所有局部日志; 默认是 `0`self.global_train_logger = global_train_loggerself.global_test_logger = global_test_loggerself.model_dim = self.global_learners_ensemble.model_dim # #模型特征维度self.n_clients = len(clients)self.n_test_clients = len(test_clients)self.n_learners = len(self.global_learners_ensemble)# 存储为每个client分配的权重(权重为0-1之间的小数)self.clients_weights =\torch.tensor([client.n_train_samples for client in self.clients],dtype=torch.float32)self.clients_weights = self.clients_weights / self.clients_weights.sum()self.sampling_rate = sampling_rate # clients在每一轮使用的比例,默认为`1.`self.sample_with_replacement = sample_with_replacement #对client进行采用是可重复还是无重复的,with_replacement=True表示可重复的,否则是不可重复的# 每轮迭代需要使用到的client个数self.n_clients_per_round = max(1, int(self.sampling_rate * self.n_clients))# 采样得到的client列表self.sampled_clients = list()# 记载当前的迭代通信轮数self.c_round = 0 self.write_logs()@abstractmethoddef mix(self): """该方法用于完成各client之间的权重参数与通信操作"""pass@abstractmethoddef update_clients(self): """该方法用于将所有全局分量模型拷贝到各个client,相当于boardcast操作"""passdef update_test_clients(self):"""将全局(gobal)的所有分量模型都拷贝到各个client上"""def write_logs(self):"""对全局(global)的train和test数据集的loss和acc做记录需要对所有client的所有样本做累加,然后除以所有client的样本总数做平均。"""def save_state(self, dir_path):"""保存aggregator的模型state,。例如, `global_learners_ensemble`中每个分量模型'learner'的state字典(以`.pt`文件格式),以及`self.clients` 中每个client的 `learners_weights` (注意,这个权重不是模型内部的参数,而是进行继承的时候对各个分量模型赋予的权重,包含train和test两部分,以一个大小为n_clients(n_test_clients)× n_learners的numpy数组的格式,即`.npy` 文件)。"""def load_state(self, dir_path):"""加载aggregator的模型state,即save_state方法里保存的那些"""def sample_clients(self):"""对clients进行采样,如果self.sample_with_replacement为True,则为可重复采样,否则,则为不可重复采用。最终得到一个clients子集列表并赋予self.sampled_clients"""
1.client-server 算法
这种方式的通信/聚合方法也称中心化(centralized)方法,因为该方法在每一轮迭代最后将所有client的权重数据汇集到server节点。这种方法的优化迭代部分的伪代码示意如下:
落实到具体代码实现上,这种方法的Aggregator设计如下:
class CentralizedAggregator(Aggregator):r""" 标准的中心化Aggreagator所有clients在每一轮迭代末和average client完全同步."""def mix(self):self.sample_clients()# 对self.sampled_clients中每个client的参数进行优化for client in self.sampled_clients:# 相当于伪代码第11行调用的LocalSolver函数client.step()# 遍历global模型(self.global_learners_ensemble) 中每一个分量模型(learner)# 相当于伪代码第13行for learner_id, learner in enumerate(self.global_learners_ensemble):# 获取所有client中对应learner_id的分量模型learners = [client.learners_ensemble[learner_id] for client in self.clients]# global模型的分量模型为所有client对应分量模型取平均,相当于伪代码第14行average_learners(learners, learner, weights=self.clients_weights)# 将更新后的模型赋予所有clients,相当于伪代码第5行的boardcast操作self.update_clients()# 通信轮数+1self.c_round += 1if self.c_round % self.log_freq == 0:self.write_logs()def update_clients(self):"""此函数负责将所有全局分量模型拷贝到各个client,相当于伪代码中第5行的boardcast操作"""for client in self.clients:for learner_id, learner in enumerate(client.learners_ensemble):copy_model(learner.model, self.global_learners_ensemble[learner_id].model)if callable(getattr(learner.optimizer, "set_initial_params", None)):learner.optimizer.set_initial_params(self.global_learners_ensemble[learner_id].model.parameters())
2. fully decentralized(完全去中心化)算法
这种方法之所以被称为去中心化的,因为该方法在每一轮迭代不需要所有client的权重数据汇集到一个特定的server节点,而只需要完成每个节点和其邻居进行通信(参数共享)即可。这种方法的优化迭代部分的伪代码示意如下:
落实到具体代码实现上,这种方法的Aggregator设计如下:
class DecentralizedAggregator(Aggregator):def __init__(self,clients,global_learners_ensemble,mixing_matrix,log_freq,global_train_logger,global_test_logger,sampling_rate=1.,sample_with_replacement=True,test_clients=None,verbose=0,seed=None):super(DecentralizedAggregator, self).__init__(clients=clients,global_learners_ensemble=global_learners_ensemble,log_freq=log_freq,global_train_logger=global_train_logger,global_test_logger=global_test_logger,sampling_rate=sampling_rate,sample_with_replacement=sample_with_replacement,test_clients=test_clients,verbose=verbose,seed=seed)self.mixing_matrix = mixing_matrixassert self.sampling_rate >= 1, "partial sampling is not supported with DecentralizedAggregator"def update_clients(self):passdef mix(self):# 对各clients的模型参数进行优化for client in self.clients:client.step()# 存储每个模型各参数混合的权重# 行对应不同的client,列对应单个模型中不同的参数# (注意:每个分量有独立的mixing_matrix)mixing_matrix = torch.tensor(self.mixing_matrix.copy(),dtype=torch.float32,device=self.device)# 遍历global模型(self.global_learners_ensemble) 中每一个分量模型(learner)# 相当于伪代码第14行for learner_id, global_learner in enumerate(self.global_learners_ensemble):# 用于将指定learner_id的各client的模型state读出暂存state_dicts = [client.learners_ensemble[learner_id].model.state_dict() for client in self.clients]# 遍历global模型中的各参数, key对应模型中参数的名称for key, param in global_learner.model.state_dict().items():shape_ = param.shapemodels_params = torch.zeros(self.n_clients, int(np.prod(shape_)), device=self.device)for ii, sd in enumerate(state_dicts):# models_params的第ii个下标存储的是第ii个client的(名为key的)参数models_params[ii] = sd[key].view(1, -1) # models_params的每一行是一个client的参数# @符号表示矩阵乘/矩阵向量乘# 故这里表示每个client参数是其他所有client参数的混合models_params = mixing_matrix @ models_paramsfor ii, sd in enumerate(state_dicts):# 将第ii个client的(名为key的)参数存入state_dicts中对应位置sd[key] = models_params[ii].view(shape_)# 将更新好的参数从state_dicts存入各client节点的模型中for client_id, client in enumerate(self.clients):client.learners_ensemble[learner_id].model.load_state_dict(state_dicts[client_id])# 通信轮数+1self.c_round += 1if self.c_round % self.log_freq == 0:self.write_logs()