Google 近日 (大概是3月7日) 开源了TensorFlow Federated (TFF)框架。个人觉得这个框架在工程应用领域很有价值,所以对其核心原理做一个整理和解读,以供参考。内容主要参考Google的两篇论文:
[1] McMahan, et. Communication-Efficient Learning of Deep Networks from Decentralized Data. 28 Feb 2017
[2] Keith Bonawitz, et. Towards Federated Learning at Scale: System Design. 4 Feb 2019
0x01 模型算法
Federated Learning 的特点主要有三:1) 训练数据更为真实;2) 训练过程中,无需将敏感数据集中到数据中心; 3) 对于监督学习,可以很自然的通过用户交互获取数据的labels。然而个人认为,在实际工程中应用Federeated Learning模型,可以将一些特殊功能从硬编码中解放出来,我们可以再应用上线初期构建一个基本的机器学习模型,并在用户的使用过程中不断的训练该模型,以优化用户体验。
广义上来讲,Federated Learning可以看作是一种分布式机器学习的模型,然而其优化特性又与经典分布式机器学习模型截然不同,根据文献[1] 可以总结出几个特性:
- Non-IID( Non-independent and identically distributed ):分散各个设备端的数据集因为用户的特异性而存在较大的差异;
- Unbalanced:因为用户对于服务端的访问频率差异,使得各个设备端训练数据量存在不同;
- Massively distributed:设备端数量远远大于设备端训练模型的数量;
- Limited communication:因为网络等原因,使得设备端访问连接受到限制。
当然,FL模型包含用户数据隐私问题,但本文并不作过多关注。现在假设又个用户,选取其中用户比例用于一轮机器学习训练,那么FL所关注的非凸机器学习模型可以表示为
其中,即样本数据以及参数。如果假设又个用户数据分区,那么用表示用户的数据点索引的集合,令,因此将上述模型改写为
这里对于独立同分布(IID)数据,那么对于显然存在关系
然而对于FL模型的数据集则无法得到该关系。当然我们更为关心Federated Learning 模型的SGD算法,以一个典型实现为例,令(实际上,决定了每轮迭代的batch size)且合适的学习率,对于每个用户计算,表示当前用户设备的模型,而对于服务端则收集此些模型,并利用的平均值进行模型更新,又有,因此等效的有
于是
如此便是,每个用户使用本地数据集对模型进行训练,然后server端搜集训练结果并根据权重计算梯度平均值。当然也可以计算梯度平均值之前在用户设备端进行多次迭代
此即FederatedAveraging (or FedAvg)算法。用表示每轮每个用户设备端训练次数,而用表示设备端本地的 minibatch size,当且时即为FedSGD,而对于一个客户端个样本,每轮更新数量可由得出。
由此,文献中给出一段算法伪代码
:
initialize
for each round do
for each client in parallel do
: // Run on client
for each local epoch from to do
for batch do
return to server
0x02 架构设计
根据上述Federated Learning模型以及FedSGD算法,Google团队基于TensorFlow设计并开源了FL框架。如下图是其架构的网络协议,在FL任务开始时,服务器筛首先选出所有设备端的一个有效子集作为本轮FL任务的执行者,继而服务器向所有子集中的设备发送数据,数据主要包括TF的计算图以及执行该计算图的方法。而在每轮训练任务中,服务端于本轮开始时需要向设备端发送当前模型的超参数以及从checkpoint得到的必要状态数据。之后每个接收到任务的设备根据全局参数和状态数据以及本地数据集执行计算任务,并将更新发送到服务端,最后服务端合并所有设备更新(即执行AvgFed算法)。该过程中包含了三个主要的Phase: 1) Selection, 即从设备集群中筛选有效设备子集,例如设计设备,通常要确认设备是否是闲置状态,充电状态以及是否为计费网络等等因素;2) Configure,server根据全局模型的聚合机制进行配置配置,向连接的有效设备分发PL计划;3) Reporting,Server端接受设备提交更新,并根据AvgFed进行更新全局模型,当然这里有一个裁定机制,即根据有效设备子集返回情况。
当然,除了实现上述松耦合以及弹性计算的架构之外,系统还要考虑敏感数据的安全性问题,但本文不再赘述。综上所述,是对Federated Learning模型以及系统架构设计的简单整理和理解,根据Google提供的工具,我们还可以对编程过程构建一个工作链
0x03 一点想法
对于经典分布式学习来说,往往为了训练一个优秀的模型,而不得不构建一个庞大的计算集群和数据存储集群,这对于中小型企业来说无疑是过于昂贵且复杂的。并且如此体量的数据集构建也是一个极其重要的问题,而对于Federated Learning来说,种种这些问题都迎刃而解,且所使用的数据集更为真实。所以从工程落地角度来讲FL更具备价值。
然而,TFF架构设计是基于现有经典非凸模型以及TensorFlow 实现的,严格意义上讲,FL构建了一个集群模型训练环境,而非一个可拓扑的神经网络结构。这让我们不得不去憧憬一个想法,利用P2P等类似的点对点网络构建一个大的社会级神经网络,那会得到怎么样的结果呢?