联邦学习开山之作Communication-Efficient Learning of Deep Networks from Decentralized Data

1 介绍

1.1 背景

越来越多的手机和平板电脑成为许多人的主要计算设备。这些设备上强大的传感器(包括摄像头、麦克风和GPS),加上它们经常被携带的事实,意味着它们可以访问前所未有的大量数据,其中大部分本质上是私人的。根据这些数据学习的模型持有承诺通过支持更智能的应用程序来大大提高可用性,但数据的敏感性意味着将其存储在集中位置存在风险和责任。

1.2 本文贡献

本文的主要贡献是

  • 将来自移动设备的分散数据的训练问题(联邦学习)确定为一个重要的研究方向;

  • 选择可以应用于该设置的简单实用的算法FedAvg;

  • 对所提出的方法进行广泛的实证评估。

更具体地说,本文介绍了FedAvg算法,它将每个客户端上的局部随机梯度下降(SGD)与执行模型平均的服务器相结合。本文对该算法进行了广泛的实验,证明了它对不平衡非IID数据分布具有鲁棒性,并且可以将在分散数据上训练深度网络所需的通信轮次减少几个数量级。

1.3 联邦学习的理想问题

  • 对真实世界的移动设备上的数据进行训练 比 对数据中心可获得的代理数据进行训练,有明显的优势;
  • 数据是隐私的或数据量很大;
  • 对于监督任务,标签可以从用户交互中自然推断出来。

举例:

  • 图像分类任务。预测哪些照片最有可能在未来被多次查看或分享。用户拍摄的照片是隐私的,但对于本地,用户对照片的删除、共享等行为就是推断出来的标签。
  • 单词预测。用户在手机上输入时,输入法预测下一个单词。输入信息是隐私的,用户选择的下一个单词就是推断出来的标签。

1.4 联邦学习与分布式的对比

  • 非独立同分布:不同用户对移动设备的使用是不同的,因此数据非独立同分布。
  • 不平衡:一些用户会比其他人更频繁地使用服务或应用程序,从而导致本地训练数据的数量不同。
  • 大规模分布式:预计参与优化的客户端数量将远远大于每个客户端的平均实例数量。
  • 通信受限:移动设备有时候离线,或处于缓慢昂贵的连接中。

2 FedAvg

2.1 损失函数

对于机器学习问题,对于样本\((x_i,y_i)\)的损失为\(f_i(w)\),那么全局损失定义为:

\[f(w)\overset{\text{def}}{=}\frac{1}{n}\overset{n}{\sum}\limits_{i=1}f_i(w) \]

在联邦学习问题中,假设有\(K\)个客户端,第\(k\)个客户端的数据集为\(P_k\),数据集大小\(n_k=|P_k|\)。那么对于客户端\(k\),该客户端数据的损失函数为:

\[F_k(w)=\frac{1}{n_k}\sum\limits_{i\in P_k}f_i(w) \]

全局的损失函数定义为客户端损失的加权平均:

\[f(w)=\overset{K}{\sum}\limits_{k=1}\frac{n_k}{n}F_k(w) \]

2.2 通信成本与计算成本

对于数据集中到中心的情况,由于数据量较大,通信成本相对较小,计算成本较大。

通信成本指客户端与中央服务器之间传输数据所需的成本。联邦学习中,会受到移动设备带宽限制,同时客户端通常仅在有电源和有WiFi等情况下愿意参与优化,因此通信成本较大。而设备数据量小、手机有GPU等特性使得计算成本较小。

为了减小通信成本,方法:

  • 增加并行,每轮使用更多客户端(对应“客户端通常仅在有电源和有WiFi等情况下愿意参与优化”限制)。
  • 每个客户端在每个通信轮之间执行更复杂的计算,而不是执行像梯度计算这样的简单计算。

2.3 相关工作

以往工作没有考虑不平衡和非独立同分布数据,以及客户端数量少。

2.4 FedSGD

根据当前的模型\(w_t\)计算梯度\(g_k=\nabla F_k(w_t)\)。由于:

\[\nabla f(w_t)=\nabla[\overset{K}{\sum}\limits_{k=1}\frac{n_k}{n}F_k(w_t)]=\overset{K}{\sum}\limits_{k=1}\frac{n_k}{n}g_k \]

那么中心服务器聚合梯度并进行更新的结果为:

\[w_{t+1}\leftarrow w_t-\eta\nabla f(w_t)=w_t-\eta\overset{K}{\sum}\limits_{k=1}\frac{n_k}{n}g_k \]

上式也等价于客户端先在本地做一次梯度更新,中心服务器再对模型进行加权平均:

\[w^k_{t+1}\leftarrow w_t-\eta g_k \]

\[w_{t+1}\leftarrow \overset{K}{\sum}\limits_{k=1}\frac{n_k}{n}w^k_{t+1} \]

2.5 FedAvg

写成上述第二种形式后,可以在做平均之前,多次迭代本地更新:

\[w^k\leftarrow w^k-\eta\nabla F_k(w^k) \]

每个客户端可以多次计算上式得到本地在第\(t\)轮的最终模型,最后中心服务器将这些本地模型进行聚合得到\(w^{t+1}\)

这就是FedAvg的思想,该算法主要有三个超参数:

  • \(C\):每次选择的客户端的比例
  • \(B\):本地训练时batchsize,当\(B=\infty\),即全批量
  • \(E\):本地训练轮数

\(B=\infty,E=1\)时,FedAvg和FedSGD等价

这里还定义了每轮的本地更新次数:\(u_k=E\frac{n_k}{B}\),由该公式也可以算出,FedSGD每轮本地更新次数为1。

完整的伪代码:

至此我们可以简单比较FedSGD和FedAvg:

算法 local server
FedSGD 计算本轮梯度 收集local的梯度,加权平均后作为server要下降的梯度
FedAvg 多次梯度下降,得到本轮的本地模型 收集local的模型,加权平均后作为本轮得到的模型

3 实验

3.1 模型初始化

聚合参数\(\theta\):以\(\theta w+(1-\theta)w^{'}\)对两个模型进行聚合,得到最终模型。

左图是使用两个初始模型\(w,w^{'}\)训练不同数据得到的损失,右图是两模型使用同一个\(w\)初始化训练不同数据,可以看出右边损失较小,且当\(\theta=0.5\)效果最好。因此在联邦学习实验中,每个客户端需要共享相同的初始化模型。

3.2 数据集和训练任务

选取大小适中的数据集,以便研究超参数

第一个任务是MNIST数字识别,使用两个模型:

  • 多层感知机。2个隐藏层,每个隐藏层有200个单元,使用ReLU激活。

    199210个参数:图像为\(28\times 28\),转为一维后是784。第一层\(784*200+偏置200\),第二层\(200*200+偏置200\),第三层\(200*10+偏置10\)

  • \(32*5*5\)卷积+\(2*2\)最大池化+\(64*5*5\)卷积+\(2*2\)最大池化+512单元全连接+ReLU+Softmax

数据集划分:

  • iid:划分100个客户端,每个客户端接收600张图。
  • 非iid:先按数字对图片进行排序,并划分成200个大小为300的碎片,给100个客户端每个分2个碎片,即每个客户端分到的数据只包含两个数字。

分出来的数据集有iid和非iid,但都是平衡的。

第二个任务是字符预测,使用LSTM,读取一行字符预测下一个字符。

数据集是莎士比亚全集,每个说话角色为一个客户端,共1146个。每个客户端,前80%的行是训练集,后20%行是测试集。

数据集划分:

  • iid:将所有文字平均划分给每个客户端。
  • 非iid:每个客户端仅有该角色说的话。

学习率设置在\(10^{\frac{1}{3}}\)\(10^{\frac{1}{6}}\)区间。

3.3 增加并行性

\(C\)控制并行量,因此先改变\(C\)

实验记录了MLP达到97测试集准确率和CNN达到99测试准确率所需要的通信轮数。

使用小批量,当\(C=0.1\)时效果就已经较好。为平衡计算效率和收敛速度,之后实验固定\(C=0.1\)

3.4 增加每个客户端的计算量

在FedAvg算法部分,我们已经指出,每轮本地更新次数为\(u_k=E\frac{n_k}{B}\)。在实验中设置独立同分布的更新次数为期望更新次数,即\(u=E\frac{n}{kB}\)

首先,对于两种任务,增加\(B\)都减少了通信轮数。

对于MNIST任务,iid效果比非iid更显著。实际生活我们设备上的数字也不会是规律性的,因此这种情况是该方法鲁棒的论证。

对于莎士比亚数据集,非iid效果很好,而这代表了我们在现实生活的数据分布(不同的人说话数量相差很大)。推测是某些客户端有较大的数据集,使本地训练更具有价值。

3.5 FedSGD vs FedAvg

可以看出,FedAvg不仅减少通信轮数,还提高了测试精度(蓝色实线是FedSGD)。推测是模型平均会产生类似dropout正则化的收益。

3.6 是否能过度优化客户端

对于非常大的本地迭代次数,FedAvg可能会停滞或发散。这一结果表明,对于某些模型,尤其是在收敛的后期阶段,减少每轮的本地计算量(即减小E或增大B)可能是有益的,就像衰减学习率一样。

3.7 CIFAR实验

数据集包含50000个训练数据和10000个测试数据,将其平均划分给100个客户端,每个客户端包含500个训练数据和100个测试数据。

使用的模型为两个卷积层+两个全连接层+一个线性变换层。

图像会经过裁剪为\(24*24\)、左右反转、调整对比度、亮度等预处理。

单机的SGD对比10个客户端的FedSGD和FedAvg:

现有的模型,对CIFAR分类任务的测试精度已经很高,但这里只要达到80%左右即可,原因是本文的目标是评估FedAvg方法,而非提高CIFAR测试精度。

不同学习率的影响:

3.8 大规模LSTM实验

为了证明方法在现实世界问题上有效,还在大规模的预测下一个单词任务上进行了实验。

训练数据集由来自大型社交网络的1000万个公开帖子组成。按作者对帖子进行了分组,总共有超过500,000名客户。文中将每个客户端的数据集限制为最多5000个单词,并对10000个作者的数据进行了测试。

原文链接:https://arxiv.org/abs/1602.05629

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/835413.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

高级语言程序设计课程第7次个人作业

2024高级语言程序设计:https://edu.cnblogs.com/campus/fzu/2024C/ 高级语言程序设计课程第7次个人作业:https://edu.cnblogs.com/campus/fzu/2024C/homework/13304 学号:102400128 姓名:吴俊衡 1: 问题:无2: 问题:刚开始没想出来怎么弄,后面递归不会就用了多个for循环3:…

TYPE-C PD浅谈(四)

TYPE-C PD浅谈(四) 当对接识别完成后,Provider会先在VBUS上提供5V,接着会在CC脚位上送出Source Capability(SRC_CAP),格式如下:内容定义了供电的各种选项,如共有几组电源可选,相对应的电压电流等。 当Consumer接收到SRC_CAP封包后,会针对电源列表的内容,挑选一组电压…

STM32F103开发

本节我们将会对STM32的硬件资源进行介绍,包括如下内容:点亮LED; 检测按键按下和松开事件; 串口; 点亮128*128 TFT_LCD液晶屏;一、点亮LED 1.1 电路原理图 LED电路原理图如下图所示:其中:LED1连接到PA8`引脚,低电平点亮; LED2连接到PD2引脚,低电平点亮;1.2 GPIO引脚…

团队项目Scrum冲刺-day7

一、每天举行站立式会议 站立式会议照片一张昨天已完成的工作成员 任务陈国金 协助代码沙箱Docker实现凌枫 创建题目页面陈卓恒 协助开发创建题目页面谭立业 协助开发创建题目页面廖俊龙 接口测试曾平凡 前端页面测试曾俊涛 代码沙箱Docker实现薛秋昊 协助代码沙箱Docker实现今…

爱码单车队-冲刺日志第四天

会议记录:今天主要是投入一些后端的开发任务,然后开始实现基础的登录绑定用户的功能。

2024-2025-1 20241314 《计算机基础与程序设计》第八周学习总结

2024-2025-1 20241314 《计算机基础与程序设计》第八周学习总结 作业信息这个作业属于哪个课程 <班级的链接>(2024-2025-1-计算机基础与程序设计)这个作业要求在哪里 2024-2025-1计算机基础与程序设计第八周作业这个作业的目标 功能设计与面向对象设计 面向对象设计过…

bolt.new只要5分钟就能完成1个网站,神奇惊艳

下面是我用bolt.new5分钟完成的首页,前端页面配色和css都很好看,我觉得很惊艳啦!网站截图如下# 网站介绍生成随机yes或者no答案的网站,每次点击oracle按钮,可以生成yes或者no的随机答案,帮助选择困难症用户,轻松选择yes还是no 网站网址https://yesnooracle.dev 功能特点…

2024-2025-1 20241415《计算机基础与程序设计》第八周学习总结

如2024-2025-1 20241415 《计算机基础与程序设计》第八周学习总结 作业信息这个作业属于哪个课程 2024-2025-1-计算机基础与程序设计这个作业要求在哪里 2024-2025-1计算机基础与程序设计第八周作业这个作业的目标 功能设计与面向对象设计,面向对象设计过程,面向对象语言三要…

爱码单车队-冲刺日志第二天

会议记录:在第二天的冲刺中,团队成员在接口文档设计与前端框架方面取得了一定进展。经过小组的共同努力,确定了前段设计与后端对接的接口, 为后续的开发工作奠定了基础。虽然目前前端设计与预期原型存在较大差距,但基本功能已经得以实现。 接下来,团队将继续优化功能,并…

猿人学web端爬虫攻防大赛赛题第20题——2022新春快乐

题目网址:https://match.yuanrenxue.cn/match/20 解题步骤解题之前需要先了解wasm是什么:https://docs.pingcode.com/ask/294587.html看数据包。sign是一串加密的字符串,t一看就是时间戳。全局搜索api/match/20,只有一处。打断点,触发。看下sign的生成逻辑。 "sign&q…

[Tricks-00004]CF1954F(自己胡的 trick,被 Burnside 完爆)

介绍下自己的离奇思路: 先读清楚题意!要求是旋转等价,即两个以 \(c\) 个 \(1\) 开头,总 \(1\) 个数不超过 \(k+c\) 的字符串算一种。 那怎么刻画"只算一种"这个条件呢?一个想法可以是,对每个字符串赋一个权值,一种字符串的权值即旋转出来的每个合法的,把它们…

工控机维修数据恢复

工控机是一种加固的增强型个人计算机,由于经常在环境比较恶劣的情况下运行,对数据的安全性要求也更高。 一、数据丢失的原因 用户误操作:如错误删除文件、不小心切断电源等,这些操作可能导致数据丢失或损坏。 入侵与感染:恶意程序可能破坏硬盘数据,甚至具有格式硬盘的功能…