《Communication-Efficient Learning of Deep Networks from Decentralized Data》

Communication-Efficient Learning of Deep Networks from Decentralized Data

这篇文章算是联邦学习的开山之作吧,提出了FedAvg的算法,文中对比了不同客户端本地训练次数,客户端训练数据集划分的影响。

0. Abstract

现代移动设备可以获取大量适合学习模型的数据,然而,这些丰富的数据通常是隐私敏感的、数量很大的,这可能导致无法记录到数据中心并使用传统方法进行培训。本文提倡一种替代方案,将训练数据分布在移动设备上,并通过聚合本地计算的更新来学习共享模型,称为联合学习。
本文在五种模型四个数据集下测试了联邦学习的效果。

1. Introduction

每个客户端都有一个从未上传到服务器的本地训练数据集。相反,每个客户端计算一个对服务器维护的当前全局模型的更新,并且只有这个更新才会被通信。

主要贡献
1)确定了移动设备上分散数据的训练问题作为一个重要的研究方向;
2)选择一个简单实用的算法,可以应用于这一设置;
3)对所提出的方法进行了广泛的实证评估。更具体地说,我们引入了FederatedAveraging算法,它将每个客户端的局部随机梯度下降(SGD)与执行模型平均的服务器相结合。对该算法进行了大量的实验,证明了该算法对不平衡和非iid数据分布具有鲁棒性,并且可以将在分散数据上训练深度网络所需的通信次数减少几个数量级。

联邦学习的理想问题有以下属性:
1)在移动设备上进行的训练比在数据中心中通常可用的代理数据上进行的训练具有明显的优势。
2)这些数据是隐私敏感的或大的(与模型的大小相比),所以最好不要纯粹为了模型训练的目的而将其记录到数据中心(服务于集中收集原则)。
3)对于监督任务,数据上的标签可以从用户交互中自然地推断出来。

隐私安全:与持久数据上的数据中心培训相比,联邦学习具有明显的隐私优势。
联邦优化:我们将联邦学习中隐含的优化问题称为联邦优化,它与分布式优化建立了联系(并形成对比)。联邦优化与典型的分布式优化问题有几个关键的区别:

  1. 非独立同分布:标签分布不均
  2. 不平衡:样本数量不同
  3. 大规模
  4. 通信约束:移动设备经常离线,或者连接速度慢或费用高。

我们的目标是使用额外的计算,以减少训练模型所需的通信轮数。有两种主要的方法可以增加计算量:1)增加并行性,在每个通信轮之间我们使用更多的客户端独立工作;2)增加了每个客户端的计算量,每个客户端在每个通信轮之间执行一个更复杂的计算,而不是执行一个简单的计算,比如梯度计算。我们研究了这两种方法,但我们实现的加速主要是由于在每个客户机上增加了更多的计算量,一旦使用了客户机上的最小并行度。

2. FedAvg

我们在每一轮中选择一个客户端的Cfraction,并计算这些客户端所持有的所有数据的损失梯度。因此,C控制全局批量大小,C = 1对应的是全批量(非随机)梯度下降我们将此基线算法称为FederatedSGD(或federdsgd)。

FedAvg计算量由三个关键参数控制:C,每轮执行计算的客户端的比例; E,每个客户端在每一轮中对其本地数据集进行训练的次数; 和B,用于客户端更新的本地minibatch大小。 我们写b=∞来表示整个局部数据集被视为单个小批处理。 因此,在该算法族的一个端点上,我们可以取b= ∞ \infty 和e=1,这正好对应于FEDSGD。
在这里插入图片描述
θ \theta θ是两个模型的聚合参数,聚合权重一个是 θ \theta θ一个是 1 − θ 1-\theta 1θ,根据实验结果图我们可以发现,如果两个模型的初始模型参数不同,使用平均聚合0.5的结果不是很好,反而单独使用一个模型的权重, θ \theta θ接近0或者1反而更好。但是如果两个模型初始化使用相同的模型参数,那么使用0.5进行聚合的结果就会比较好。

在这里插入图片描述
FedAvg算法分为客户端的服务器两块,客户端就是简单地本地训练,这里训练首先划分了一个batch的大小,batch集合为 B \mathcal{B} B,然后客户端需要迭代E轮,那样客户端就总共需要循环 E ∗ n k B E*\frac{n_k}{B} EBnk次。服务器负责选中随机客户端,并且下发全局模型并且等待客户端返回更新结果,这里的更新权重是以客户端本地的数据量来决定的,客户端本地的数据量越多,FedAvg聚合时所占的权重也就越大。

3. Experiment

实验设置:初步研究包括两个数据集上的三个模型族。前两个是MNIST数字识别任务[26]:
1)一个简单的多层感知器,有2个隐藏层,每个层有200个单元,使用ReLu激活(总共199,210个参数),我们称之为MNIST 2NN。
2)有两个5x5卷积层的CNN(第一个是32个通道,第二个是64个,每个都有2x2 max pooling),全连接层有512个单元,ReLu激活,最后有softmax输出层(总共1663370个参数)。

研究了在客户端上划分MNIST数据的两种方法:IID,数据被洗牌,然后划分为100个客户端,每个客户端接收600个示例,以及Non-IID,按照数字标签对数据排序,将其划分为200个大小为300的分片,并为100个客户端分配2个分片。这是一种病态的非iid数据分区,因为大多数客户机将只有两个数字的示例,这让我们可以探索我们的算法在高度非iid数据上的破坏程度。

对于语言建模,我们从莎士比亚全集[32]构建了一个数据集。生成包含1146个客户机的数据集。对于每个客户,我们将数据分解为一组训练线(角色的前80%的线)和测试线(最后的20%,四舍五入到至少一行)。同样分为IID和non-IID的两组数据,使用的模型为LSTM。

在这里插入图片描述

对于b= ∞ \infty (对于MNIST每轮处理所有600个客户机示例为一个批次),在增加客户机部分方面只有很小的优势。 使用较小的批处理大小b=10显示了使用c≥0.1的显著改进,尤其是在非IID情况下。 基于这些结果,在我们剩下的大部分实验中,我们确定C=0.1,这在计算效率和收敛速度之间取得了很好的平衡。 比较表1中b= ∞ \infty 和b=10列的轮数,可以看到显著的加速,我们接下来将对此进行研究。

在这里插入图片描述

在本节中,我们将C=0.1,并在每一轮中为每个客户机添加更多的计算量,或者减少B,或者增加E,或者两者兼而有之。我们根据这个统计信息对表中每一部分的行进行排序。 我们看到,通过改变e和b来增加u是有效的。
通过实验结果可以看出,增大客户端本地训练迭代次数,减少batch的大小可以有效地加大客户端本地的计算时长,同时也可以带来更大的训练加速,在更少的大轮次内达到收敛的效果。

准确率曲线图
在这里插入图片描述
在CIFAR10数据集上作者还与SGD进行了对比,这里的SGD但就是对集中式机器学习,并以此作为基线。(其实我觉得这有点不公平,这就是分布式机器学习和集中式机器学习对比了,效率肯定吊打单机学习)
在这里插入图片描述
在LSTM上结果相同
在这里插入图片描述

4. Conclusion

我们的实验表明联邦学习是可行的,因为FedAvg使用相对较少的通信轮来训练高质量的模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/29538.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

opencv -13 掩模

什么是掩膜? 在OpenCV中,掩模(mask)是一个与图像具有相同大小的二进制图像,用于指定哪些像素需要进行操作或被考虑。掩模通常用于选择特定区域或进行像素级别的过滤操作。 OpenCV 中的很多函数都会指定一个掩模&…

matlab入门

命名规则: clc:清除命令行的所有命令 clear all:清除所有工作区的内容 注释:两个% 空格 %% matlab的数据类型 1、数字 3 3 * 5 3 / 5 3 5 3 - 52、字符与字符串 s a %% 求s的ascill码 abs(s) char(97) num2str(65) str I…

Simulink仿真模块 - Data Store Read

Data Store Read:从数据存储中读取数据 在仿真库中的位置为:Simulink / Signal Routing 模型为: 说明 Data Store Read 模块将指定数据存储中的数据复制到其输出中。多个 Data Store Read 模块可从同一个数据存储读取数据。 用来读取数据的源数据存储由 Data Store Memory 模…

上门服务小程序|上门家政小程序开发

随着现代生活节奏的加快和人们对便利性的追求,上门家政服务逐渐成为了许多家庭的首选。然而,传统的家政服务存在着信息不透明、服务质量不稳定等问题,给用户带来了困扰。为了解决这些问题,上门家政小程序应运而生。上门家政小程序…

自动收小麦机(牛客2023萌新)

题目链接 示例1 输入 复制 4 1 2 1 1 4 5 2 2 2 3 4 输出 复制 10 说明 在第4格放出水流后,水流会流向第3格,由于第3格高度比第4格低,所以水流继续向左流向第2格,因为平地水流只能流2格,所以到达第2格后水流停…

sqli-labs 堆叠注入 解析

打开网页首先判断闭合类型 说明为双引号闭合 我们可以使用单引号将其报错 先尝试判断回显位 可以看见输出回显位为2,3 尝试暴库爆表 这时候进行尝试堆叠注入,创造一张新表 ?id-1 union select 1,database(),group_concat(table_name) from informatio…

店铺记账用什么软件好?应该如何选购?

店铺记账过程中,会遇到各种问题:手写记账容易出错、效率低下、数据容易丢失;手动整理数据导致实际库存和账面库存不匹配,影响补货和订单管理。 而借助专业的店铺记账软件,可以有效解决上面这些问题,通过自动…

TCP的三次握手过程

TCP 是面向连接的协议,所以使用 TCP 前必须先建立连接,而建立连接是通过三次握手来进行的。三次握手的过程如下图: 刚开始客户端处于 closed 的状态,服务端处于 listen 状态。 第一次握手:客户端给服务端发一个 SYN 报…

AI数字人:图像超分辨率模型 Real-ESRGAN

1 Real-ESRGAN介绍 1.1 Real-ESRGAN是什么? Real-ESRGAN全名为Enhanced Super-Resolution GAN:增强的超分辨率的对抗生成网络,是由腾讯ARC实验室发布的一个盲图像超分辨率模型,它的目标是开发出实用的图像/视频修复算法&#xf…

Ceph

Ceph简介 Ceph使用C语言开发,是一个开放、自我修复和自我管理的开源分布式存储系统。具有高扩展性、高性能、高可靠性的优点。Ceph目前已得到众多云计算厂商的支持并被广泛应用。RedHat及OpenStack,Kubernetes都可与Ceph整合以支持虚拟机镜像的后端存储…

消息队列——RabbitMQ基本概念+容器化部署和简单工作模式程序

目录 基本概念 MQ 的优势 1.应用解耦 2.异步提速 3.削峰填谷 MQ 的劣势 使用mq的条件 常见MQ产品 RabbitMQ简介 RabbitMQ的六种工作模式 JMS RabbitMQ安装和配置。 RabbitMQ控制台使用。 RabbitMQ快速入门——生产者 需求: RabbitMQ快速入门——消费者 小结 基本概…

快7月底了,让我康康有多少准备跳槽的

前两天跟朋友感慨,今年的铜三铁四、裁员、疫情影响导致好多人都没拿到offer!现在已经快7月底了,具体金九银十只剩下2个月。 对于想跳槽的职场人来说,绝对要从现在开始做准备了。这时候,很多高薪技术岗、管理岗的缺口和市场需求也…