1. KVStore里的Barrier
在mxnet的分布式训练里,主要模式就是参数服务器。每个worker或者agent就是一台machine,server用于参数的更新。那么,当我们期望在不同的worker之间进行同步的时候,就会需要到barrier
这个方法。
当代码运行在worker的时候,我们可以通过调用kv._barrier()
来进行同步。它的作用就是,会阻塞代码运行,直到每个worker都运行了kv._barrier()
。然后接着运行。这样就实现了同步。
那么它是怎么做到的呢?
通过源码,我们不难发现,python端的接口调用了c++端的方法:
void Barrier() override {
ps::Postoffice::Get()->Barrier(ps_worker_->get_customer()->customer_id(), ps::kWorkerGroup);
}
这个全局的Postoffice
的Barrier
方法的部分源码如下:
void Postoffice::Barrier(int customer_id, int node_group) {
// 省略部分代码
// 省略部分代码
std::unique_lock<std::mutex> ulk(barrier_mu_);
barrier_done_[0][customer_id] = false;
Message req;
req.meta.recver = kScheduler;
req.meta.request = true;
req.meta.control.cmd = Control::BARRIER;
req.meta.app_id = 0;
req.meta.customer_id = customer_id;
req.meta.control.barrier_group = node_group;
req.meta.timestamp = van_->GetTimestamp();
CHECK_GT(van_->Send(req), 0);
barrier_cond_.wait(ulk, [this, customer_id] {
return barrier_done_[0][customer_id];
});
}
可以看到该方法会首先对barrier_mu_
上锁,之后将对应的barrier_done_
设置为false
。然后将这次的barrier信息发送给scheduler。告诉scheduler需要进行一次barrier。然后就阻塞等待barrier_done_
被设置为true
,代表完成了barrier,也就是其他的worker也都进行了barrier。
那么问题就变成了,每个worker都是怎么直到其他worker也进行了barrier的?
首先我们要知道,在参数服务器也就是PS中,每个进程都会建立kvstore。如果是worker,会在构造函数中运行如下代码:
if (IsWorkerNode()) {
int new_customer_id = GetNewCustomerId();
ps_worker_ = new ps::KVWorker<char>(0, new_customer_id);
ps::StartAsync(new_customer_id, "mxnet\0");
if (!ps::Postoffice::Get()->is_recovery()) {
ps::Postoffice::Get()->Barrier(
new_customer_id,
ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler);
}
}
其中ps::StartAsync
如下:
inline void StartAsync(int customer_id, const char* argv0 = nullptr) {
Postoffice::Get()->Start(customer_id, argv0, false);
}
也就是说,worker在建立起ps_worker_
后,开始运行postoffice,而postoffice的Start
会进行一系列的操作,并调用van_->Start
,接着van
的Start
会进行一系列的初始化后,开启接受消息的线程,也就是
receiver_thread_ = std::unique_ptr<std::thread>(
new std::thread(&Van::Receiving, this));
而receiving
函数会使用ProcessBarrierCommand
处理barrier信号,该函数会++barrier_count_[group]
,也就是将对应group的barrier次数进行统计。当barrier_count_[group]
等于这个group的个数的时候。它会发送类似于ACK的返回信息。
然后worker会调用Manage
方法来处理该message。Manage
发现是barrier的返回信息,将barrier_done_设置为true
,然后将等待的线程唤醒。也就是python端调用barrier后被阻塞的地方。
至此,就完成了一次worker之间的barrier。