一、概述
首先感谢代码大佬的分享,我也是在大佬的基础上进行学习。此文既是笔记,也是分享,请结合代码一起食用。代码下载在第三节中给出链接。项目中使用多核CPU开启多进程模拟多机之间互相进行联邦学习的。
这是博主写的一些准备:https://blog.csdn.net/zzxxxaa1/article/details/121421075
里面介绍了python分布式的一些基础知识,如果对分布式和多线程不了解的话,这里十分重要。
二、系统介绍
(1)总体实现方式
使用多进程的方式来模拟server和worker之间的关系。其中主进程既要有server的功能,又要有client的功能。代码中的world_size就是进程数,即包括server在内共有多少个节点。
服务端 rank = 0
横向联邦学习的服务端的主要功能是将被选择的客户端上传的本地模型进行模型聚合。但这里需要特别注意的是,事实上,对于一个功能完善的联邦学习框架,比如我们FATE平台,服务端的功能要复杂得多,比如服务端需要对各个客户端节点进行网络监控、对失败节点发出重连信号等。此文由于是在本地模拟的,不涉及网络通信细节和失败故障等处理,因此不讨论这些功能细节,仅涉及模型聚合功能。
客户端 rank != 0
横向联邦学习的客户端主要功能是接收服务端的下发指令和全局模型,利用本地数据进行局部模型训练。与前一节一样,对于一个功能完善的联邦学习框架,客户端的功能同样相当复杂,比如需要考虑本地的资源(CPU、内存等)是否满足训练需要、当前的网络中断、当前的训练由于受到外界因素影响而中断等。读者如果对这些设计细节感兴趣,可以查看当前流行的联邦学习框架源代码和文档,比如FATE,获取更多的实现细节。此文我们仅考虑客户端本地的模型训练细节。我们首先定义客户端类Client,类中的主要函数包括以下两种。
- 定义构造函数。在客户端构造函数中,客户端的主要工作包括:首先,将配置信息拷贝到客户端中;然后,按照配置中的模型信息获取模型,通常由服务端将模型参数传递给客户端,客户端将该全局模型覆盖掉本地模型;最后,配置本地训练数据,在本案例中,我们通过torchvision 的datasets 模块获取cifar10 数据集后按客户端ID切分,不同的客户端拥有不同的子数据集,相互之间没有交集。
- 定义模型本地训练函数。本例是一个图像分类的例子,因此,我们使用交叉熵作为本地模型的损失函数,利用梯度下降来求解并更新参数值,实现细节如下面代码块所示。
(2)系统流程
- 获取配置参数(里面包括了模型选取,学习率,batch size,epoch等)
- 构建神经网络,获取权重参数(引入先前训练好的参数,帮助神经网络快速拟合。)
- 创建多线程,分别模拟服务端(server)和工作节点(worker)
- 进行全局迭代,每一轮迭代都要更新参数。
- 迭代结束,关闭进程组
关于系统整体流程,我做成了流程图看得更清晰。
(3)文件目录介绍
.extra_util:
- distributed_utils.py:用于构建进程和进程间的通信。
- model.py:关于模型构建的类定义。
- my_dataset.py:用于获取数据集长度、初始化、获取图片等。
- utils.py:用于数据集划分和图片显示,将图片和分类保存成json格式。
Federal_defense:
- class_mean.py:网络构建,制作数据集,做了一个小的demo来查看分布式的计算结果
- client_mean.py:这也是一个小demo,主要区别在local train里面。
根目录下的文件:
- client.py: 构建客户端函数,并实现初始化函数、训练函数和测试函数。
- distributed_minist.py: 这是一个minist数据集的一个测试demo。
- get_weight_matrix.py: 不同进程之间的类平均值的通信。
- main.py: 主函数,完整的流程。
- my_data_loader.py: 划分数据集和数据增强。
! 这里面的文件结构非常杂乱有些引用都没有用到,但是可以看得出作者的一点一点的研究过程,每个文件都可以去看看,或许会给你很多启发。
三、环境搭建和代码运行
(1)环境搭建
Python、pytorch
(2)代码运行
代码下载地址:https://gitee.com/wang-qifan-hust/pytorch_federated_learning?_from=gitee_search
在根目录中执行下面命令:python main.py
四、部分代码分析
(1)分布式实现代码( 伪代码,不可运行 )
import os
import copy
import socket
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.multiprocessing import Process
from torch.nn.parallel import DistributedDataParallel as DDP
from client import *
from Federal_defense.class_mean import Net
def run(rank, size):
""" 这里是伪代码,简单描述系统流程 """
model = Net() #构建神经网络
model.load_state_dict(weights, strict=False) # 载入权重
client = Client(args, rank)
dist.barrier() # 确保所有进程所对应的客户端client类实例化完成。
for i in epochs:
# 复制模型本地训练
local_model = copy.deepcopy(model)
diff = client.local_train(local_model)
# 更新所有进程参数
dist.all_reduce(value, op=dist.ReduceOp.SUM)
# 等待所有进程完成此操作
dist.barrier()
# 更新model的参数
model
#阻塞进程,等待所有进程都运行到这里
dist.barrier()
dist.destroy_process_group() # 释放进程组
def init_processes(rank, size, fn, backend='gloo'):
""" 初始化分布式环境 """
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) # 无线局域网适配器WLAN所对应的IPV4地址
s.connect(("8.8.8.8", 80))
os.environ['MASTER_ADDR'] = s.getsockname()[0]
os.environ['MASTER_PORT'] = '29500'
# 等待所有进程准备好
dist.init_process_group(backend, rank=rank, world_size=size)
fn(rank, size)
if __name__ == "__main__":
size = 3 # 这个就是world_size,也就是进程数,目前最大可调整到5
mp.spawn(init_processes, # 主运行函数,初始化进程
args=(size, run), # 主函数的参数,绑定run函数
nprocs=size, # 当前节点的进程数
join=True) # 加入同一进程池
(1)客户端类
# client.py
import os
import sys
import math
import copy
import torch
import numpy as np
import torch.nn.functional
import torch.utils.data
import torch.optim.lr_scheduler as lr_scheduler
sys.path.append('extra_utils')
from extra_utils.model import resnet34, resnet101, mnist_Net
from extra_utils.distributed_utils import init_distributed_mode, cleanup, is_main_process
import torch.distributed as dist
from tqdm import tqdm
# 构建客户端对象
class Client(object):
def __init__(self, args, conf, train_loader, train_sampler, eval_loader, rank):
self.conf = conf
self.rank = rank
self.train_loader = train_loader
self.eval_loader = eval_loader
self.train_sampler = train_sampler
def local_train(self, model, args, global_epoch, cost_list, train_length, accuracy_list,
accuracy_1, accuracy_2, rank, cost_1, cost_2):
# train_length是方便我们每隔一段时间画出损失函数点
# for name, param in model.state_dict().items(): # 遍历模型参数
# self.local_model.state_dict()[name].copy_(param.clone()) # 将全局模型复制一份到本地训练的模型
local_model = copy.deepcopy(model) # 拷贝拷贝可变类型就是完完全全拷贝了一份,是完完全全的两个内存,互不干扰
# 是否冻结权重 我们默认不冻结
if args.freeze_layers:
for name, para in local_model.named_parameters():
# 除最后的全连接层外,其他权重全部冻结
if "fc" not in name:
para.requires_grad_(False) # 即只训练全连接层
pg = [p for p in local_model.parameters() if p.requires_grad] # 将我们需要训练的各层参数(全连接层)以 列表生成式的方式 生成列表。
optimizer = torch.optim.Adam(pg, lr=self.conf['lr']) # 注意此时更新的不是model模型,而是pg列表所对应的模型参数
else:
optimizer = torch.optim.Adam(local_model.parameters(), lr=self.conf['lr'])
# 若冻结权重,则带有BN结构的网络不会被训练,而训练带有BN结构的网络时使用SyncBatchNorm才有意义
if args.syncBN:
# 使用SyncBatchNorm后训练会更耗时
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
# 将模型中所有的BN层替换成具有同步功能的BN层
lf = lambda x: ((1 + math.cos(x * math.pi / self.conf["local_epochs"])) / 2) * (1 - args.lrf) + args.lrf # cosine
# 利用lambda函数定义一个输入参数x和函数lf的关系。
# 使用随机梯度下降算法作为优化算法
# 需要训练的全连接层参数、初始学习率、动量、正则项
criterion = torch.nn.CrossEntropyLoss()
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
local_model.train() # 进入模型训练模式
for epoch in range(self.conf["local_epochs"]):
self.train_sampler.set_epoch(epoch) # 这种方法使得我们的各个设备在每一轮所获得的数据都是不一样的。
if is_main_process():
self.train_loader = tqdm(self.train_loader, file=sys.stdout)
total_loss = 0
running_loss = 0
local_loss = 0
for batch_id, batch in enumerate(self.train_loader):
inputs, target = batch
if rank == 2:
for j in range(len(target)): # 标签翻转攻击
if target[j] == 1:
target[j] = 4
optimizer.zero_grad() # 梯度清零
output = local_model(inputs) # 前向传播算法
loss = criterion(output, target) # 使用交叉熵计算损失值
loss.backward() # 反向传播计算得到了梯度
total_loss += loss.item()
running_loss += loss.item()
optimizer.step() # 利用反向传播得到的梯度,利用优化算法更新网络参数(权重)
scheduler.step() # 更新学习率
if (batch_id + 1) % (train_length / (self.conf["batch_size"]*5*self.conf["world_size"])) == 0:
# 每次本地训练共取出5个点来绘制。
cost_list.append(running_loss)
# print(running_loss)
# 使用命令行参数控制是否打印其他进程图线
if not args.only_0:
# 用于绘制除rank0之外的其他进程的图线的Loss代码。
loss_1 = torch.FloatTensor([running_loss]) # 方便发送接收rank1的准确率
loss_2 = torch.FloatTensor([running_loss]) # 方便发送接收rank2的准确率
if is_main_process():
dist.recv(loss_1, src=1) # 接收来源于rank1的数据,并将其覆盖于loss_1
dist.recv(loss_2, src=2) # 接收来源于rank2的数据,并将其覆盖于loss_2
elif rank == 1:
dist.send(loss_1, dst=0) # 如果是rank1,发送数据acc1
else:
dist.send(loss_2, dst=0) # 如果是rank2,发送数据acc2
dist.barrier()
loss_1 = loss_1.item() # 取出tensor的data数据。避免最后plt.plot出错如下。
loss_2 = loss_2.item() # VisibleDeprecationWarning: Creating an ndarray from ragged nested...
if is_main_process():
cost_1.append(loss_1)
cost_2.append(loss_2)
running_loss = 0
# 我们要对所有的进程执行此操作,否则会引起每个进程微小的误差。
# 此部分会极大地拉慢程序运行速度。
with torch.no_grad():
local_model.eval() # 进入模型评估模式
correct = 0
dataset_size = 0
for batch_id, batch in enumerate(self.eval_loader):
inputs, target = batch
dataset_size += inputs.size()[0]
output = local_model(inputs)
local_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
pred = output.data.max(1)[1]
correct += pred.eq(target.data.view_as(pred)).sum().item()
acc = 100.0 * (float(correct) / float(dataset_size)) # 准确率
accuracy_list.append(acc)
# print(acc)
# 使用命令行参数控制是否打印其他进程图线
if not args.only_0:
# 用于绘制除rank0之外的其他进程的图线的准确率代码。
acc1 = torch.FloatTensor([acc]) # 方便发送接收rank1的准确率
acc2 = torch.FloatTensor([acc]) # 方便发送接收rank2的准确率
if is_main_process():
dist.recv(acc1, src=1) # 接收来源于rank1的数据,并将其覆盖于acc1
dist.recv(acc2, src=2) # 接收来源于rank2的数据,并将其覆盖于acc2
elif rank == 1:
dist.send(acc1, dst=0) # 如果是rank1,发送数据acc1
else:
dist.send(acc2, dst=0) # 如果是rank2,发送数据acc2
dist.barrier()
acc1 = acc1.item() # 取出tensor的data数据。避免最后plt.plot出错如下。
acc2 = acc2.item() # VisibleDeprecationWarning: Creating an ndarray from ragged nested...
if is_main_process():
accuracy_1.append(acc1)
accuracy_2.append(acc2)
# print(acc)
# print(acc1)
# print(acc2)
local_loss = 0
dist.barrier() # 防止打印进度条的时候,会由于其他进程先训练完成而导致打印出来的字符串混乱掉进度条。
if is_main_process():
print("Rank %d, Global_epoch [%d/%d], Local Epoch [%d/%d] loss : %f."
% (self.rank, global_epoch + 1, self.conf["global_epochs"], epoch + 1,
self.conf["local_epochs"], total_loss))
# 各进程loss大小差距不大,个人认为应该是因为迁移学习所造成的
# 取出权重矩阵
weight_matrix = []
bias_matrix = []
for name, parm in local_model.named_parameters():
if (name == 'module.fc3.weight') | (name == 'fc3.weight'): # 可以将最后一层全连接层的权重矩阵取出
weight_matrix = parm.detach()
elif (name == 'module.fc3.bias') | (name == 'fc3.bias'):
bias_matrix = parm.detach()
bias_matrix = bias_matrix.unsqueeze(1)
if is_main_process():
print(bias_matrix.size())
# tensor.detach()返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,
# 不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad。
mean = torch.ones([weight_matrix.size()[1], 1])
mean = mean / weight_matrix.size()[1]
class_mean = weight_matrix.mm(mean) # 利用矩阵乘法求得 此客户端下 该轮次的 类平均值torch.Size([10, 1])
class_mean = class_mean + bias_matrix / weight_matrix.size()[1]
diff = dict() # 生成一个空的字典
for name, data in local_model.state_dict().items(): # 遍历更新之后的各层模型参数。并返回每层对应的名字(name)和数据。
# print(data != model.state_dict()[name]) # 用于打印出来是否参数相等
diff[name] = (data - model.state_dict()[name]) # 将当前name和全局模型所对应name的数据进行相减,得到权重大小的变化量即权重差
# print(diff[name])
return diff, class_mean # 返回网络参数的变化.value为tensor类型
# 模型评估
@torch.no_grad() # 装饰器的方法实现with.no_grad()
# 我们模型评估的时候使用全部测试集
def model_eval(self, model):
model.eval() # 进入模型评估模式
total_loss = 0.0
correct = 0
dataset_size = 0
if is_main_process():
self.eval_loader = tqdm(self.eval_loader, file=sys.stdout)
for batch_id, batch in enumerate(self.eval_loader): # batch_id就为enumerate()遍历集合所返回的批量序号
inputs, target = batch # 得到数据集和标签
dataset_size += inputs.size()[0] # data.size()=[batch,通道数,32,32]、target.size()=[batch]
output = model(inputs)
if self.conf["type"] == "mnist":
total_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
elif self.conf["type"] == "flower":
total_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
elif self.conf["type"] == "cifar":
total_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
else:
raise TypeError("Not find Appropriate mode.")
# sum up batch loss
# .data意即将变量的tensor取出来
# 因为tensor包含data和grad,分别放置数据和计算的梯度
pred = output.data.max(1)[1] # get the index of the max log-probability
# 按照从左往右的 第一维 取出最大值的索引 torch.max()
correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
# torch.view_as(tensor)即将调用函数的变量,转变为同参数tensor同样的形状
# torch.eq()对两个张量tensor进行逐元素比较,如果相等则返回True,否则返回False。True和False作运算时可以作1、0使用
# .cpu()这一步将预测结果放到cpu上,利用电脑内存存储列表值。从而避免测试过程中爆显存。
# .sum()是将我们一个批量的预测值求和,便于累加到correct变量中。
# .item()取出 单元素张量的元素值 并返回该值,保持原元素类型不变。
acc = 100.0 * (float(correct) / float(dataset_size)) # 准确率
aver_loss = total_loss / dataset_size # 平均损失
return acc, aver_loss
五、总结
总体来说,这位大佬分享的代码非常好,代码注释也十分多,给我十分大的帮助。只是说这个文件目录确实有些头疼。看了半天才看明白。不过也能够看得出大佬的研究历程。在这个项目里真正的体会到了分布式计算,模拟多个节点一起协作的场景。由理论学习转到了实战演示,帮助非常大。
这里面比较头疼的是各种数据的变换和多进程之间的协作,需要一步一步的去理解,有些头疼。