史上最全面的ps-lite理解

概述

ps-lite是一个分布式参数服务器,具体什么是分布式,什么是参数服务器就不在此详述,talk is cheap, show me code。

代码

既然是分布式,那么我们就来看看整个框架有哪几部分。
在这里插入图片描述
可以看到有worker, server, scheduler.在这里我们假设有一个scheduler, 2个server,2个worker,既然是分布式,那么就假设分布在5台电脑上,在每台电脑上肯定要起一个进程,好了,我们首先先来启动一个scheduler。

scheduler node
全局观
  • 先启动一个节点
  • 等待各个work和sever发来的message
  • 一旦收集到4个message,就说明work和server都到齐了,这时给每一个work和server发一个message告诉他们对应的身份id,同时也让work去链接server,server去链接work
  • scheduler的初始化完成
具体代码

环境配置

export DMLC_NUM_SERVER=2
export DMLC_NUM_WORKER=2
export DMLC_PS_ROOT_URI='127.0.0.1'
export DMLC_PS_ROOT_PORT=8000 
export DMLC_ROLE='scheduler'

具体代码

一层代码
Start(0);//Start(int customer_id, const char* argv0 = nullptr) {Postoffice::Get()->Start(customer_id, argv0, true);
Finalize(0, true);//Finalize(int customer_id, const bool do_barrier = true) {Postoffice::Get()->Finalize(customer_id, do_barrier);
二层代码

首先出来了一个Postoffice的类型,每一个节点都有且只有一个Postoffice对象,具体如下:
在这里插入图片描述
用到什么成员变量和函数再说,不做一一介绍,首先看第一部分, 这里我把线程锁和一些无关紧要的逻辑到删除了,可以看到这个函数主要做了这几件事,首先是初始化node_ids_这个成员变量,接下的事情就是初始化变量van_, 最后执行了一个Barrier函数。

void Postoffice::Start(int customer_id, const char* argv0, const bool do_barrier) 
{
//读取环境变量InitEnvironment();//这一行核心内容就是这个:Van::Create(van_type="zmp"),此外还初始化一点成员变量
// init node_ids_直接看下面的图
// start vanvan_->Start(customer_id);  
// do a barrier hereif (do_barrier) Barrier(customer_id, kWorkerGroup + kServerGroup + kScheduler);
}
三层代码

下面首先来看看node_ids_的初始化流程,代码逻辑非常简单,可以自己去查看,这里我们假设有两个work ,两个server, 具体的结果如下:
在这里插入图片描述
接下来看看van_这个成员变量的初始化, 我们知道van_其实是一个zmp对象,zmp继承于Van这个类,在这个类的基础上加了两个成员变量,分别是:unordered_map senders_ 和变量void *receiver_ = nullptr, senders_是一个集合,就是发送的消息的结合,比如8号节点要给9号节点发消息,那么只要找到(9,socket_9)这个组合就行了,然后调用socket_9.send(message), receiver_就只有一个,因为你节点对外肯定只有一个门户。 这个等用到的时候再说,由于大体上改变的不多,所以也可以对照父类的机构看看,具体如下:
在这里插入图片描述
具体开看看代码:

void Van::Start(int customer_id) 
{//初始化scheduler_这个成员变量scheduler_.hostname = "DMLC_PS_ROOT_URI";scheduler_.port ="DMLC_PS_ROOT_PORT";scheduler_.role = Node::SCHEDULER;scheduler_.id = kScheduler;//确认本节点是scheduler节点is_scheduler_ = true;//初始化本节点,因为是scheduler,所以直接就是等于赋值就行my_node_ = scheduler_;//绑定接口,把本节点绑定到ip:port这个socket上,理论来说这个函数就是初始化了receiver_Bind(my_node_,  0)//连接上scheduler_,由于本节点就是scheduler_,其实就是初始化senders_,由于发送的节点很多,所以这里是一个map// 在这里就是senders_[1] = socket_1, socket_1中的body设置一点字符“ps1***”, 注意链接不是sendMsg,这一点一定要闹清楚Connect(scheduler_);//开启一个接收消息的线程,其实这里就是一直待阻塞了,等到所有的work和server都发发来了消息receiver_thread_ =new thread(&Van::Receiving, this);//然后就是等着ready_啥时候从false变成true,当是scheduler的时候,必须要有等worker和server节点过来,不然一直都是阻塞在这。while (!ready_.load()) {this_thread::sleep_for(std::chrono::milliseconds(100));// 如果设置了超时重传,就初始化resender_这个变量resender_ = new Resender(timeout, 10, this);}
四层代码

接下来再往里面深入看看代码,上可以看到主要就是Bind函数,Connect函数,以及Receiving函数

//这个函数对schedule节点的话,你不需要指定port ,但是对于work和server需要自己查找一个本地可用端口。
int Bind(const Node& node, int max_retry) override{//在这里可以看到receiver_这个变量被初始化了,//是一个socket,下面绑定了具体的ip:port,每次RecvMsg(Message* msg)时候里面都要从这个socket读取。receiver_ = zmq_socket(context_, ZMQ_ROUTER);string hostname = node.hostname;string addr =  "tcp://" + hostname + ":";int port = node.port;unsigned seed = static_cast<unsigned>(time(NULL) + port);for (int i = 0; i < max_retry + 1; ++i) {auto address = addr + std::to_string(port);if (zmq_bind(receiver_, address.c_str()) == 0) break;if (i == max_retry) {port = -1;} else {port = 10000 + rand_r(&seed) % 40000;}}return port;}

connect开始初始化senders_,或者在后面的时候就补充

void Connect(const Node& node) override 
{int id = node.id;auto it = senders_.find(id);if (it != senders_.end()) {zmq_close(it->second);}//如果找到了对应socket就关闭socket// worker doesn't need to connect to the other workers. same for serverif ((node.role == my_node_.role) && (node.id != my_node_.id)) {return;}void *sender = zmq_socket(context_, ZMQ_DEALER);//建立一个socket//我们知道对于scheduler而言,一开始就是知道自己的id,为1,下面这一个if条件就是说把自己的id捆绑到当下socket上if (my_node_.id != Node::kEmpty) {std::string my_id = "ps" + std::to_string(my_node_.id);zmq_setsockopt(sender, ZMQ_IDENTITY, my_id.data(), my_id.size());const char* watermark = Environment::Get()->find("DMLC_PS_WATER_MARK");if (watermark) {const int hwm = atoi(watermark);zmq_setsockopt(sender, ZMQ_SNDHWM, &hwm, sizeof(hwm));}}// connectstring addr = "tcp://" + node.hostname + ":" + to_string(node.port);zmq_connect(sender, addr.c_str());//将sender这个socket和目标地址连接senders_[id] = sender;//将目标id的socket存放起来后面使用}

最后看看receiving这个函数,其实在这里就开始等待work和server节点的接入了,假如现在开始有一个work发来消息,消息是控制信息,具体指令是ADD_NODE.

void Van::Receiving() {Meta nodes;Meta recovery_nodes;  // store recovery nodes 储存康复的节点recovery_nodes.control.cmd = Control::ADD_NODE;// 康复节点的control都设置为add_nodewhile (true) {Message msg;int recv_bytes = RecvMsg(&msg);//利用receiver_这个变量拿到消息recv_bytes_ += recv_bytes;//收到的中字节数累加// duplicated messageif (resender_ && resender_->AddIncomming(msg)) continue;//重传机制先不看if (!msg.meta.control.empty()) //如果是控制类型的消息{// control msgauto& ctrl = msg.meta.control;if (ctrl.cmd == Control::TERMINATE) {ProcessTerminateCommand();break;} else if (ctrl.cmd == Control::ADD_NODE) {ProcessAddNodeCommand(&msg, &nodes, &recovery_nodes);//当执行到这个位置的时候继续跳转} else if (ctrl.cmd == Control::BARRIER) {ProcessBarrierCommand(&msg);} else if (ctrl.cmd == Control::HEARTBEAT) {ProcessHearbeat(&msg);} } else //非控制类型的消息处理方式{ProcessDataMsg(&msg);}}
}

接下来看看scheduler对于控制类型消息的处理:

void Van::ProcessAddNodeCommandAtScheduler(Message* msg, Meta* nodes, Meta* recovery_nodes)
{auto dead_nodes = Postoffice::Get()->GetDeadNodes(heartbeat_timeout_);//查出心跳包超时的idunordered_set<int> dead_set(dead_nodes.begin(), dead_nodes.end());//又给转存到dead_set里面auto& ctrl = msg->meta.control;//拿到收到消息里面的control信息//下面这个函数就比较骚,名字叫做更新节点ID,记住当下是在schedule节点,我们先下去看看这个函数。UpdateLocalID(msg, &dead_set, nodes, recovery_nodes);//上面的函数代码看完后继续往下走recovery_nodes->control.cmd = Control::ADD_NODE;//不知道为啥又写一边time_t t = time(NULL);size_t num_nodes = Postoffice::Get()->num_servers() + Postoffice::Get()->num_workers();//根据上面updatelocalId的函数,我们知道当下nodes还是没有收集齐全,一旦收齐后进入if条件中if (nodes->control.node.size() == num_nodes) {// sort the nodes according their ip and port,这个排序就是不说了,就是根据IP和port给work,server排个序std::sort(nodes->control.node.begin(), nodes->control.node.end(), [](const Node& a, const Node& b) {return (a.hostname.compare(b.hostname) | (a.port < b.port)) > 0; });// assign node rankfor (auto& node : nodes->control.node) {string node_host_ip = node.hostname + ":" + to_string(node.port);if (connected_nodes_.find(node_host_ip) == connected_nodes_.end()) //如果ip:port不存在van_中的话{CHECK_EQ(node.id, Node::kEmpty);//判断是不是初始化节点int id = node.role == Node::SERVER? Postoffice::ServerRankToID(num_servers_)//如果是sever的话,就id产生一个id号,num_servers_初始化为0: Postoffice::WorkerRankToID(num_workers_);PS_VLOG(1) << "assign rank=" << id << " to node " << node.DebugString();node.id = id;//将这个节点的id赋值为idConnect(node);//链接这个节点, 其实就是建立一个socket, 然后senders_[id] = sender;//将目标id的socket存放起来后面使用Postoffice::Get()->UpdateHeartbeat(node.id, t);//更新心跳包connected_nodes_[node_host_ip] = id;//你work发message来了,我这里要把这个节点作为已经链接的节点} else {int id = node.role == Node::SERVER? Postoffice::ServerRankToID(num_servers_): Postoffice::WorkerRankToID(num_workers_);shared_node_mapping_[id] = connected_nodes_[node_host_ip];node.id = connected_nodes_[node_host_ip];}if (node.role == Node::SERVER) num_servers_++;//更新rankif (node.role == Node::WORKER) num_workers_++;}nodes->control.node.push_back(my_node_);//要把本节点放到里面nodes->control.cmd = Control::ADD_NODE;Message back;back.meta = *nodes;//消息包装nodes,广播到每一个work,serverfor (int r : Postoffice::Get()->GetNodeIDs(kWorkerGroup + kServerGroup)) {int recver_id = r;if (shared_node_mapping_.find(r) == shared_node_mapping_.end()) {back.meta.recver = recver_id;back.meta.timestamp = timestamp_++;Send(back);}}PS_VLOG(1) << "the scheduler is connected to " << num_workers_<< " workers and " << num_servers_ << " servers";ready_ = true;//到这里可以看到scheduler已经显示准备好了,至于其他work和server收没收到啥的,我不管了。} else if (!recovery_nodes->control.node.empty()) {auto dead_nodes = Postoffice::Get()->GetDeadNodes(heartbeat_timeout_);std::unordered_set<int> dead_set(dead_nodes.begin(), dead_nodes.end());// send back the recovery nodeCHECK_EQ(recovery_nodes->control.node.size(), 1);Connect(recovery_nodes->control.node[0]);Postoffice::Get()->UpdateHeartbeat(recovery_nodes->control.node[0].id, t);Message back;for (int r : Postoffice::Get()->GetNodeIDs(kWorkerGroup + kServerGroup)) {if (r != recovery_nodes->control.node[0].id && dead_set.find(r) != dead_set.end()) {// do not try to send anything to dead nodecontinue;}// only send recovery_node to nodes already exist// but send all nodes to the recovery_nodeback.meta = (r == recovery_nodes->control.node[0].id) ? *nodes : *recovery_nodes;back.meta.recver = r;back.meta.timestamp = timestamp_++;Send(back);}}
}

这里面的msg就是一个work发来的消息,deadnodes_set先不用管,nodes是一个meta类型(一个message的数据头,具体数据结构可以看之前的类图或者源码),recovery_nodes也是。

void Van::UpdateLocalID(Message* msg, std::unordered_set<int>* deadnodes_set, Meta* nodes, Meta* recovery_nodes) 
{auto& ctrl = msg->meta.control;size_t num_nodes = Postoffice::Get()->num_servers() + Postoffice::Get()->num_workers();//num_nodes=4;// assign an idif (msg->meta.sender == Meta::kEmpty) //因为是work节点发过来的,而work节点初始化时候的id就是KEmpty.{CHECK(is_scheduler_);CHECK_EQ(ctrl.node.size(), 1);//msg中的control命令中的节点集合就是work自己,所以就是1个节点。if (nodes->control.node.size() < num_nodes) {nodes->control.node.push_back(ctrl.node[0]);} //因为sizes小于4else //如果四个work和server到齐了,就进入else{// some node dies and restartsCHECK(ready_.load());for (size_t i = 0; i < nodes->control.node.size() - 1; ++i) {const auto& node = nodes->control.node[i];if (deadnodes_set->find(node.id) != deadnodes_set->end() &&node.role == ctrl.node[0].role) {auto& recovery_node = ctrl.node[0];// assign previous node idrecovery_node.id = node.id;recovery_node.is_recovery = true;nodes->control.node[i] = recovery_node;recovery_nodes->control.node.push_back(recovery_node);break;}}}}// update my id, 其实对于scheduler的话这个函数没用,因为是work节点刚push进来,但是如果是schedule发给这个work这个几点的消息,如果发现本地的ip和port和消息中的某个一点重合,那么就把本地节点的ID(初始化时候没有ID,只是等于Empty)改为schedule发过来的身份证ID。for (size_t i = 0; i < ctrl.node.size(); ++i) {const auto& node = ctrl.node[i];if (my_node_.hostname == node.hostname && my_node_.port == node.port) {if (getenv("DMLC_RANK") == nullptr || my_node_.id == Meta::kEmpty) {my_node_ = node;string rank = to_string(Postoffice::IDtoRank(node.id));//max((id - 8) / 2, 0)setenv("DMLC_RANK", rank.c_str(), true);}}}
}
worker/server

由于是框架,所以上面代码基本都覆盖了,具体看代码的时候就是看一下不同的分支判断就行了。


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部