理解: Federated Learning

Google 近日 (大概是3月7日) 开源了TensorFlow Federated (TFF)框架。个人觉得这个框架在工程应用领域很有价值,所以对其核心原理做一个整理和解读,以供参考。内容主要参考Google的两篇论文:
[1] McMahan, et. Communication-Efficient Learning of Deep Networks from Decentralized Data. 28 Feb 2017
[2] Keith Bonawitz, et. Towards Federated Learning at Scale: System Design. 4 Feb 2019

0x01 模型算法

Federated Learning 的特点主要有三:1) 训练数据更为真实;2) 训练过程中,无需将敏感数据集中到数据中心; 3) 对于监督学习,可以很自然的通过用户交互获取数据的labels。然而个人认为,在实际工程中应用Federeated Learning模型,可以将一些特殊功能从硬编码中解放出来,我们可以再应用上线初期构建一个基本的机器学习模型,并在用户的使用过程中不断的训练该模型,以优化用户体验。
广义上来讲,Federated Learning可以看作是一种分布式机器学习的模型,然而其优化特性又与经典分布式机器学习模型截然不同,根据文献[1] 可以总结出几个特性:

  • Non-IID( Non-independent and identically distributed ):分散各个设备端的数据集因为用户的特异性而存在较大的差异;
  • Unbalanced:因为用户对于服务端的访问频率差异,使得各个设备端训练数据量存在不同;
  • Massively distributed:设备端数量远远大于设备端训练模型的数量;
  • Limited communication:因为网络等原因,使得设备端访问连接受到限制。

当然,FL模型包含用户数据隐私问题,但本文并不作过多关注。现在假设又K个用户,选取其中C用户比例用于一轮机器学习训练,那么FL所关注的非凸机器学习模型可以表示为
\underset{\omega \in \mathbb{R}^{d}}{min} f(\omega) \qquad 这里\qquad f(\omega) \overset{\text{def}}{=} \frac{1}{n} \sum_{i=1}^n f_i(\omega)其中f_i(\omega)=\ell(x_i,y_i;\omega),即样本数据(x_i,y_i)以及参数\omega。如果假设又K个用户数据分区,那么用\mathcal{P_k}表示k用户的数据点索引的集合,令n_k=|\mathcal{P}_k|,因此将上述模型改写为
f(\omega)=\sum_{k=1}^K \frac{n_k}{n} F_k(\omega) \qquad 这里 \qquad F_k(\omega)=\frac{1}{n_k}\sum_{i \in \mathcal{P}_k} f_i(\omega)这里对于独立同分布(IID)数据,那么对于F_k(\omega)显然存在关系
\mathbb{E}_{\mathcal{P}_k}[F_k(\omega)]=f(\omega)然而对于FL模型的数据集则无法得到该关系。当然我们更为关心Federated Learning 模型的SGD算法,以一个典型实现为例,令C=1(实际上,C决定了每轮迭代的batch size)且合适的学习率\eta,对于每个用户k计算g_k=\nabla F_k(\omega_t)\omega_t表示当前用户设备的模型,而对于服务端则收集此些模型,并利用g_k的平均值进行模型更新\omega_{t+1} \leftarrow \omega_t+\eta \sum\nolimits_{k=1}^K \frac{n_k}{n}g_k,又有\sum\nolimits_{k=1}^K \frac{n_k}{n}g_k=\nabla f(\omega),因此等效的有
\forall k,\omega_{t+1}^k \leftarrow \omega_t + \eta g_k于是
\omega_{t+1} \leftarrow \sum\nolimits_{k=1}^K \frac{n_k}{n} \omega_{t+1}^k如此便是,每个用户使用本地数据集对模型进行训练,然后server端搜集训练结果并根据权重计算梯度平均值。当然也可以计算梯度平均值之前在用户设备端进行多次迭代
\omega^k \leftarrow \omega^k + \eta \nabla F_k(\omega_k)此即FederatedAveraging (or FedAvg)算法。用E表示每轮每个用户设备端训练次数,而用B表示设备端本地的 minibatch size,当B=\inftyE=1时即为FedSGD,而对于一个客户端n_k个样本,每轮更新数量可由u_k=E\frac{n_k}{B}得出。
由此,文献中给出一段算法伪代码

\text{Server excutes}:
   initialize \omega_0
   for each round t=1,2,\cdots do
        m \leftarrow max(C\cdot K, 1)
        S_t \leftarrow (\text{random set of } m \text{ clients})
        for each client k \in S_t in parallel do
            \omega_{t+1}^k \leftarrow \text{ClientUpdate}(k, \omega_t)
        \omega_{t+1} \leftarrow \sum\nolimits_{k=1}^K \frac{n_k}{n} \omega_{t+1}^k

\text{ClientUpdate}(k, \omega_t): // Run on client k
  \mathcal{B} \leftarrow (\text{split } \mathcal{P}_k \text{ into batches of size }B)
  for each local epoch i from 1 to E do
        for batch b \in \mathcal{B} do
            \omega \leftarrow \omega - \eta \nabla \ell (\omega;b)
  return \omega to server

0x02 架构设计

根据上述Federated Learning模型以及FedSGD算法,Google团队基于TensorFlow设计并开源了FL框架。如下图是其架构的网络协议
Federated Learning Protocol [2]

,在FL任务开始时,服务器筛首先选出所有设备端的一个有效子集作为本轮FL任务的执行者,继而服务器向所有子集中的设备发送数据,数据主要包括TF的计算图以及执行该计算图的方法。而在每轮训练任务中,服务端于本轮开始时需要向设备端发送当前模型的超参数以及从checkpoint得到的必要状态数据。之后每个接收到任务的设备根据全局参数和状态数据以及本地数据集执行计算任务,并将更新发送到服务端,最后服务端合并所有设备更新(即执行AvgFed算法)。该过程中包含了三个主要的Phase: 1) Selection, 即从设备集群中筛选有效设备子集,例如设计设备,通常要确认设备是否是闲置状态,充电状态以及是否为计费网络等等因素;2) Configure,server根据全局模型的聚合机制进行配置配置,向连接的有效设备分发PL计划;3) Reporting,Server端接受设备提交更新,并根据AvgFed进行更新全局模型,当然这里有一个裁定机制,即根据有效设备子集返回情况。

由此可见,对于一个设备端的应用设计主要包括连接,获取模型和参数状态数据,执行计算,最后提交更新,当然,本文没有过多关注隐私数据处理,简而言之客户端应用架构设计如下图所示
Device
对于Server端,Google采用了一个松耦合的编程模型——Actor Modal。并实现一个自顶向下的框架结构
Actors in the FL Server Architecture.png
Coordinators是top-level的actor,负责全局同步以及在锁定步骤中推进训练迭代。在Server端有多个Coordinators,每一个都负责一个FL设备集群,每一个Coordinators都注册一个地址以及一个FL集群(FL population)。随意Coordinator与FL Population形成一一对应的管理结构。而Selector负责接收和转发设备的连接,同时它也会定期搜集FL 集群的设备信息,并决定是否接受设备。在主聚合器(Master Aggregator)和聚合器(Aggregator)生成后,Coordinator会指示Selector将其FL集群子集转接到聚合器。而主聚合器(Master Aggregator)负责管理每个FL任务的迭代周期,它可以根据FL集群和更新提交数量来生成聚合器,以实现弹性聚合计算。

当然,除了实现上述松耦合以及弹性计算的架构之外,系统还要考虑敏感数据的安全性问题,但本文不再赘述。综上所述,是对Federated Learning模型以及系统架构设计的简单整理和理解,根据Google提供的工具,我们还可以对编程过程构建一个工作链
Model Engineer Workflow

0x03 一点想法

对于经典分布式学习来说,往往为了训练一个优秀的模型,而不得不构建一个庞大的计算集群和数据存储集群,这对于中小型企业来说无疑是过于昂贵且复杂的。并且如此体量的数据集构建也是一个极其重要的问题,而对于Federated Learning来说,种种这些问题都迎刃而解,且所使用的数据集更为真实。所以从工程落地角度来讲FL更具备价值。
然而,TFF架构设计是基于现有经典非凸模型以及TensorFlow 实现的,严格意义上讲,FL构建了一个集群模型训练环境,而非一个可拓扑的神经网络结构。这让我们不得不去憧憬一个想法,利用P2P等类似的点对点网络构建一个大的社会级神经网络,那会得到怎么样的结果呢?

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,732评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 87,496评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,264评论 0 338
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,807评论 1 277
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,806评论 5 368
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,675评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,029评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,683评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 41,704评论 1 299
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,666评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,773评论 1 332
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,413评论 4 321
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,016评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,978评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,204评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,083评论 2 350
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,503评论 2 343

推荐阅读更多精彩内容