美文网首页MXNet
MxNet源码解析(1) KVStore,pslite源码解析

MxNet源码解析(1) KVStore,pslite源码解析

作者: Junr_0926 | 来源:发表于2018-08-26 01:52 被阅读0次

    1. 前言

    从毕业开始工作已经两个多月,这期间相当一部分的时间都用在了对MxNet的学习上,而在MxNet的众多部分中,又是pslite这一部分接触最多。因此,今天将我一直以来的学习过程中的心得和收获总结在这里,也为以后对MxNet的继续学习做一个铺垫

    2. MxNet构成

    MxNet作为一个深度学习框架,它最大的特点应该是分布式训练的支持了。从初次接触MxNet到现在的两个多月里,我认为MxNet主要有以下几个大的部分:

    • symbol和graph,负责构建计算的图和反向传播的图
    • Engine,负责根据图的节点的依赖关系,并行运算
    • Parameter server和KVStore,负责参数的同步和传递
    • operator,定义图的节点的op
    • NDArray,数据的存储和计算
    • Executor,对图进行处理用于计算

    3. Parameter server

    参数服务器的概念并不复杂,主要思想就是,将模型的参数保存在server中,另外通过worker来完成具体的计算任务,当每完成一个计算任务时就会得到对应参数的梯度,这时将梯度传送给server,由server来完成参数的更新,worker再从server那里取回更新后的参数。
    在MxNet中,当我们需要进行分布式的训练时,就需要使用到它了。在MxNet中,为了完成参数在不同机器前的同步和更新,主要实现了两大部分。一是pslite,另一个是KVStore

    3.1 KVStore

    为了更方便理解,我从KVStore开始讲起。在MXNet中,可能很多人并不会直接操作KVStore,在官方文档中,甚至提到,不建议直接操作KVStore,但是,每个人在使用MXNet的过程中,都肯定用到了KVStore。其实,在我们建立module.Module的时候,就会调用KVStorepushpull操作。
    kvstore主要分为两种,一种是单机下,一种是多机下。单机下又分为将参数存储在GPU显存和CPU内存上两种情况。

    3.1.1 comm.h

    comm.h文件中定义了Comm类,该类用于设备间的信息传递,也就是communication。从Comm类中派生出了两个子类CommCPU用于CPU内存通信,CommDevice用于GPU通信。

    Comm类中定义了几个纯虚函数:

    • Init:根据存储类型和shape初始化
    • Reduce:输入NDArray的一个vector,返回它们的和
    • Broadcast:将一个NDArray复制到vector中的每一个元素
    • BroadcastRowSparse

    CommCPU

    将数据复制到CPU内存中,在那里做操作。

    • Init:初始化key对应的KVStore,创建key对应的NDArray,保存在merge_buf_[key].merged中。(不分配内存)
    • Reduce:将输入的vector<NDArray>& src的每个元素求和并返回。当src只有一个元素时,若不是sparse的就直接返回src[0],若是则将src[0]拷贝至merged_buf返回。如果src元素多于一个,那么:
    if (stype == kDefaultStorage) {
          std::vector<Engine::VarHandle> const_vars(src.size() - 1); // 定义engine pushasync的输入,用于engine根据该操作的输入来规划操作的执行
          std::vector<NDArray> reduce(src.size());
          CopyFromTo(src[0], &buf_merged, priority);
          reduce[0] = buf_merged;
    
          if (buf.copy_buf.empty()) { // copy_buf用于GPU数据拷贝至CPU,由于第0个元素存储在buf_merged,这里只需要src.size()-1个
            buf.copy_buf.resize(src.size()-1);
            for (size_t j = 0; j < src.size() - 1; ++j) {
              // allocate copy buffer
              buf.copy_buf[j] = NDArray(
                src[0].shape(), pinned_ctx_, false, src[0].dtype());
            }
          }
          CHECK(stype == buf.copy_buf[0].storage_type())
               << "Storage type mismatch detected. " << stype << "(src) vs. "
               << buf.copy_buf[0].storage_type() << "(buf.copy_buf)";
          for (size_t i = 1; i < src.size(); ++i) {
            CopyFromTo(src[i], &(buf.copy_buf[i-1]), priority); // 定义拷贝操作
            reduce[i] = buf.copy_buf[i-1];
            const_vars[i-1] = reduce[i].var(); // 定义拷贝操作的输入
          }
    
          Engine::Get()->PushAsync( // push该操作至engine,engine会根据输入来规划什么时候执行
            [reduce, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
              ReduceSumCPU(reduce);
              on_complete();
            }, Context::CPU(), const_vars, {reduce[0].var()},
            FnProperty::kCPUPrioritized, priority, "KVStoreReduce");
    
        }
    
    • Broadcast

    CommDevice

    • Init:将存储类型和shape信息存储在sorted_key_attrs_中。
    • InitBuffersAndComm:将vector<NDArray>& src的context信息存储在devs中,通过InitMergeBuffersorted_key_attrs_信息,将所有的KVPairs分别存储在GPU上。
    • Reduce:和CommCPU的reduce一样,同样也是为了累积求和。

    3.1.2 kvstore.h

    kvstore.h中定义了几个纯虚函数

    • Init:根据参数定义的一组KVPairs初始化
    • Push:将一组KVPairs执行Push操作
    • Pull: Pull操作
    • Updater:用于参数的更新

    3.1.3 local 和 device

    在初始化KVStore我们需要提供KVStore的类型,在MXNet中提供了localdevice两种用于单机训练时的类型。不论哪种,都在文件kvstore_local.h中定义。两者最主要的区别就是对Comm的选择,local会使用CommCPU来进行comm_的初始化,device使用CommDevice来初始化。
    KVStoreLocal有以下几个重要的方法:

    • Init:设置key的类型(str或者int),进行初始化。初始化的方法是使用comm_的初始化方法,同时还会在local_保存一个pinned_ctx_类型的拷贝。pinned_ctx_指的是不会被移出cache的内存。
    • Push:根据输入的KVPairs,使用comm_Reduce方法,进行相同keyvalue的求和。并且如果注册了updater_的话,会调用updater_进行更新。在进行更新之前,如果是在GPU端更新,会先将保存在local_的参数拷贝至GPU。
    • Pull:Pull方法主要的工作是将存储在local_的参数复制到对应的输出中。
      经过对这几个主要方法的理解,我们就清楚了KVStore的主要工作方式,也就对它对内存和显存的占用有了一个清晰的了解。具体的实现细节还是要参考源码去了解。

    3.1.4 KVStoreDist

    这篇博客的重点还是去试图了解分布式下的KVStore,当我们使用dist-*去create KVStore的时候,就会使用到类KVStoreDistKVStoreDist分两个主要部分,一个是worker,一个是server
    如果该节点是worker,首先会创建一个ps_worker_ = new ps::KVWorker<char>(0, new_customer_id);这个ps::KVWorker将在pslite部分具体解析,它是主要的完成pushpull操作的部分。
    server的启动:在我们通过import mxnet的时候,会导入kvstore_server,而导入该文件会允许语句_init_kvstore_server_module(),阅读该函数源码不难发现,它会判断当前节点是否是server节点,如果是就会调用server.run(),然后调用c++代码的MXKVStoreRunServer,也就是类KVStoreDistRunServer方法,该方法会创建server_ = new KVStoreDistServer();

    • set_updater:updater的设置是通过python端的函数定义来完成的,它通过ctype转换成为了c端的函数,并且通过pickle序列化为字符串传递给server。
      当然,我们的主要注意力还是放在pushpull的实现上。
    • Push_:push操作首先会通过comm_进行Reduce操作,并将结果存储在comm_buf_[key]中,完成了本地的Reduce后,调用EncodeDefaultKey函数将存储为key : intval : NDArray形式的KVPair,转化为PSKV形式,该形式用于Push操作。之后会通过PushDefault方法完成操作,该方法定义了函数push_to_servers,将comm_buf_[key]作为输入,通过Engine::Get()->PushAsync方法完成push操作的异步执行(只是将任务发给Engine,由Engine完成调度)。Engine会在适当的时机执行push_to_servers,该函数调用ps_worker_ZPush方法来实现分布式的push
    • PullImpl:pull操作由该函数来完成,该函数会根据keysserver端的结果获取到对应的NDArray中。中间结果会保存在comm_buf_[key]中,这里由于之前push将该变量作为了输入,Engine在调度执行时会考虑到这点,保证所有对comm_buf_[key]的操作都在对它的读入完成之后,也就是push完成之后(push将它作为了输入)。类似于Push_操作,Pull操作定义了函数pull_from_servers作为异步执行的函数,调用PushAsync发送给Engine。pull_from_servers函数调用了ps_worker_ZPull方法来完成分布式的pull操作。

    这里的分析只是简单的流程的总结,更多实现的细节可以通过阅读源码来了解。

    3.1.5 KVStoreDistServer

    如果当前节点是server,那么就会建立一个KVStoreDistServer对象,由该对象完成对workerpush,pull请求的处理。其中最重要的方法是DataHandleEx,它根据RequestType来调用相应的函数完成对数据的处理。
    KVStoreDistServer的构造函数中,会执行ps_server_ = new ps::KVServer<char>(0);它建立了一个ps::KVServer对象,该对象调用ps_server_->set_request_handle(std::bind(&KVStoreDistServer::DataHandleEx, this, _1, _2, _3));DataHandleEx绑定在自己的request_handle_上。

    • Run:前面提到过,如果该节点是server会调用RunServer方法,该方法就会调用if (server_) server_->Run();阅读KVStoreDistServer的源码发现,Run仅仅只有一行exec_.Start();。这一行会调用Executor exec_;Start方法,源码如下
    void Start() {
        std::unique_lock<std::mutex> lk(mu_);
        while (true) {
          cond_.wait(lk, [this]{return !queue_.empty();}); // queue_为空,则等待被唤醒
          Block blk = std::move(queue_.front()); // 取出queue头元素
          queue_.pop();
          lk.unlock(); // 释放锁,给其他线程操作queue
    
          if (blk.f) { // 如果blk定义了一个function,则允许他
            blk.f();
            blk.p->set_value(); // 返回function的结果
          } else {
            blk.p->set_value(); break;
          }
          lk.lock(); // 获取锁,执行下一个循环
        }
    

    调用ExecutorExec方法,会在queue中添加一个执行函数的block,代码如下

    void Exec(const Func& func) {
        Block blk(func); // 建立block
        auto fut = blk.p->get_future();
        {
          std::lock_guard<std::mutex> lk(mu_);
          queue_.push(std::move(blk));
          cond_.notify_one(); // 通知别的线程运行
        }
        fut.wait();
      }
    

    有了上面的知识,我们来看一下具体怎么处理数据。

    • DataHandleDefault:该方法是默认的数据处理的方法,由于DataHandleEx被绑定为了数据的处理函数,当RequestTypekDefaultPushPull,就会调用该函数。它会根据传入的信息,提取对应的key,将对应的数据存储在store_[key]。如果从worker来的request类型是push,就会分两种情况运行。一种是初始化的时候,由于初始化同样通过调用push来完成,因此初始化的push只会将store_[key]设置为对应的值。另一种是初始化后,每一次的push都会进行相应的操作。这里每一次从任何一个worker来的某一个keypush操作,都会存储在updates.merged中,并且除了第一次的push,之后的push会进行updates.merged += updates.temp_array;也就是和之前的push相加。并且ApplyUpdates只会在push数达到worker的个数的时候,才会真正地进行。也只有在ApplyUpdates真正执行的时候才会将回复返回给worker。这样,就实现了同步。

    对于server的讲解,这里也只是简单地描述它的同步和执行的简单机制,具体更多的实现细节,可以参考源码来了解。

    3.2 pslite

    通过前面的了解,我们知道了worker会使用ps_worker_ZPush方法来完成push操作,使用ZPull方法来完成pull操作。类似地,server会使用ps_server_request_handle_来进行数据处理的传递,使用SimpleApprequest_handle_来完成Command处理的传递。这一部分,我们就来了解它们的实现。
    KVWorkerKVServer都定义在文件kv_app.h中,它们都继承自SimpleApp

    3.2.1 kv_app

    kv_app是MxNet主要应用的部分。ps-lite实现了两个app,一个是simple_app,一个是kv_app

    KVWorker

    当数据从MXNet端的push函数传递到parameter server端时,调用了如下方法:

    int ZPush(const SArray<Key>& keys,
                const SArray<Val>& vals,
                const SArray<int>& lens = {},
                int cmd = 0,
                const Callback& cb = nullptr) {
        int ts = obj_->NewRequest(kServerGroup);
        AddCallback(ts, cb);
        KVPairs<Val> kvs;
        kvs.keys = keys;
        kvs.vals = vals;
        kvs.lens = lens;
        Send(ts, true, cmd, kvs);
        return ts;
      }
    

    该方法将kServerGroup作为数据传输对象,建立了KVPairs,通过Send方法,将数据发送给server。Send方法完成了数据从KVParisMessage的转换,然后调用Postoffice::Get()->van()->Send(msg);来执行数据的发送。

    相应地,在执行pull操作的时候,调用了Pull_方法,该方法首先定义了一个回调函数,该函数在完成pull操作后执行,具体来说就是当发出的请求都得到了回应后,会在Process方法中执行下列函数:

    // finished, run callbacks
      if (obj_->NumResponse(ts) == Postoffice::Get()->num_servers() - 1)  {
        RunCallback(ts);
      }
    

    KVServer

    前面说到过,KVServer会使用request_handle_来调用KVStore的数据处理函数。KVServer会在方法KVServer<Val>::Process中调用request_handle_,在这之前它会将得到的Message转换为KVMetaKVPairs。这样就完成了数据从接收到,再到传递给MXNet端的数据处理函数的过程。

    由于时间有限,内容较多,就不一一介绍函数。

    3.2.2 postoffice.cc

    Postoffice是一个类似于全局管理者的角色,它完成了环境初始化等必要工作

    3.2.3 van

    从前面的介绍我们看到,所有的数据在发送的最后,调用的都是van的send方法。van的具体实现类是ZMQVan。由于本人对于zmq也只是个初学者,这里有兴趣的同学可以去详细了解它的实现以及性能。

    3.2.4 meta.proto

    zmq在进行数据传输的时候,会建立socket,并且将字符串传递给对应的对象。在代码中,使用了protobuf来进行数据到字符串的转换工作。

    3.2.4 SArray.h

    SArray全名Shared array,它完成了在进行数据赋值过程中的零拷贝,及时是不同类型间数据的赋值,仅仅是将数据指向的指针进行赋值,同时将类型进行保存而已。

    3.2.5 message.h

    后记

    今天已经很晚了,只能在pslite部分草草收尾,希望下次进行总结的时候能够做的更好。
    总体来说,MXNet对于我这样一个初学者来说有很多可以学习的地方,并且它异步的实现和parameter server的设计,都是非常值得学习的内容。

    相关文章

      网友评论

        本文标题:MxNet源码解析(1) KVStore,pslite源码解析

        本文链接:https://www.haomeiwen.com/subject/awtwiftx.html