美文网首页mxnetMXNet
MxNet源码解析(2) symbol

MxNet源码解析(2) symbol

作者: Junr_0926 | 来源:发表于2018-09-14 17:30 被阅读0次

    1. 前言

    我们在训练之前,先建立好一个图,然后我们可以在这个图上做我们想做的优化,这种形式称为Symbolic Programs。相对应的是Imperative Programs,也就是每一句代码都对应着程序的执行,在这种情况下,我们可以写类似于下面的代码:

    a = 2
    b= a + 1
    d = np.zeros(10)
    for i in range(d):
        d += np.zeros(10)
    

    这在symbolic的方式下是做不到的,因为在for循环开始时,程序并不知道d的值,也就无法判断循环的次数。
    因此我们可以说,symbolic更高效,imperative更灵活。

    MxNet是一个异步式的训练框架,它支持上面的两种形式。我们可以使用NDArray来进行imperative形式的程序编写,也可以使用symbol来建立图。

    2. op

    先来了解operator,不了解operator可能就很难理解源码中占据了很大一部分的operator的定义。就是通过这些operator来将symbol连接成为了一个图。

    • OpManager:单例结构体,通过OpManager::Global()总会返回同一个结构体。Op的构造函数会将OpManagerop_counter加一,并且将自己的index_注册为当前的op_counter
    • add_alias:将别名注册到`dmlc::Registry<Op>中
    • Get:根据name返回Op
    • GetAttrMap

    2.1 op

    • name:名字
    • description:该op的描述
    • num_inputs:输入的个数
    • num_outputs:输出的个数
    • get_num_outputs, get_num_inputs:函数,返回输出,输入的个数
    • attr_parser:函数,用于方便返回该op的参数
    • Op& Op::describe(const std::string& descr):方法用于将输入注册到description变量中,并返回这个op,方便接着调用其他方法。

    2.2 几个宏

    • #define NNVM_REGISTER_VAR_DEF(OpName):定义OpName
    • #define NNVM_REGISTER_VAR_DEF(TagName):定义TagName
    #define NNVM_REGISTER_OP(OpName) \
      DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName, __COUNTER__) = \
        ::dmlc::Register<::nnvm::op>::Get()->__REGISTER_OR_GET(#OpName)
    

    注册op,并返回该op

    3. Node

    Node是组成symbol的基本组件。
    结构体NodeEntry包含了:

    • node:指向node的指针
    • index:输出的索引值
    • version:输入的version

    结构体NodeAttrs包含了:

    • op: 指向operator的指针
    • name: node的名字
    • dict:attributes的字典

    Node包含:

    • attrs:结构体NodeAttrs成员,存储了op, name, attributes等信息。
    • inputs:输入,是一个元素为NodeEntry的向量
    • control_deps:保存了应该在该node执行之前执行的node
    • op():返回该Node的operator,就是返回attrs中保存的op
    • Create():类方法,静态方法,用于新建一个Node,返回指向它的指针
    • num_outputs:如果是变量,输出为1,否则返回op的输出

    几个函数

    定义在文件op_attr_types.h

    • FListinputNames:返回输入的名字,默认return {'data'}
    • FNumVisibleOutputs:用于隐藏一些输出
    • FListOutputNames:返回输出的名字
    • FMutateInputs:返回该node会改变的node的索引值
    • FInferNodeEntryAttr:推理出AttrType
    • FInferShape:推理shape,也就是上面的AttrTypeTshape
    • FInferType:推理类型
    • TIsBackward是否是反向传播
    • FInplaceOption
    • FGradient:返回node的梯度节点
    • FSetInputVarAttrOnCompose:为输入设置attribute
    • FCorrectLayout:推理layout
    • FInputGraph:返回输入,解释为图而不是数据

    这些函数是在定义具体的op时,可以选择注册对应的函数。

    4. Symbol

    Symbol是为了使用Node建立Graph。Symbol是我们能够直接接触的类,它定义了一系列方法用于更方便地构建图。在symbol的成员outputs中,定义了一组由NodeEntry组成的向量。

    • outputs:该symbol包含的输出,是一个元素是NodeEntry的向量
    • Copy:返回一个深拷贝,方式是通过遍历Node,每次访问到的Node保存起来,再建立起node之间的连接,最后将head加入到outputs中。
    • Symbol operator[] (size_t index) const:返回第n个输出。
    • ListInputs:返回输入
    • ListInputNames:返回输入的名字
    • Compose:组合symbol
    • operator ():调用compose,来组合symbol
    • AddControlDeps:加入控制,用于有向图的构建
    • GetInternals:返回一个symbol,它的输出是原来symbol的输出加上所有中间输出和输入
    • GetChildren
    • SetAttrs:设置attribution
    • GetAttrs
    • CreateFunctor:给定op和attrs,返回一个symbol
      我认为symbol中比较重要的函数是compose,在调用的时候我们是通过调用symbol的操作符()函数,也就是operator (),该函数将参数传递给Compose

    5. Graph

    Graph就是计算的时候使用的图

    • outputs:和symboloutputs一样,类型为std::vector<NodeEntry>
    • attrs:定义了图的一些属性
    • PostOrderDFSVisit:后序遍历图,给定参数head,进行拓扑排序。算法,貌似,就是拓扑排序算法。
    • DFSVisit:调用PostOrderDFSVisit,对图的head进行拓扑排序。参数为:const std::vector<NodeEntry>& heads, FVisit fvisit,其中head是反向传播时的头节点,fvisit是访问时调用的函数,该方法将fvisit(*n)作为访问节点时的函数,[](GNode n)->Node*{return->get();}作为hash函数,这个函数看签名返回的是一个指向节点的指针。图的节点入度计算如下:
    [](GNode n)->uint32_t {
      if (!(*n)) return 0;
      return (*n)->input.size() + (*n)->control_deps.size();
    }
    

    节点输入计算如下:

    [](GNode n, uint32_t index)->GNode {
      if (index < (*n)->input.size()) {
        return &(*n)->input.at(index).node;
      } else {
      return &(*n)->contorl_deps.at(index - (*n)->inputs.size());
    }
    

    6. IndexedGraph

    IndexedGraphGraph返回,

    • nodes_:成员变量,一个指向Node结构体的向量,Node定义如下:
    struct Node {
      const nnvm::Node* source;
      array_view<NodeEntry> inputs;
      array_view<uint32_t> control_deps;
      std::weak_ptr<nnvm::Node> weak_ref;
    };
    

    其中NodeEntry如下:

    struct NoodeEntry {
      uint32_t node_id;
      uint32_t index;
      uint32_t version;
    };
    

    成员变量:

    • input_nodes_:输入node的索引
    • mutable_input_nodes_
    • outputs:输出节点
    • node2index:node到索引的映射
    • entry_rptr_:
    • input_entries_
    • control_deps_
      方法:
    • DFSVisit
    • PostOrderDFSVisti

    7. pass

    7.1 gradient.cc

    • Gradientgradient会根据属于的graph,返回一个带反向传播图的新图。它主要由executor建立图的时候调用,调用方式如下:
    nnvm::Graph g_grad = nnvm::pass::Gradient(g, 
                symbol.outputs, xs, head_grad_entry_, ArggregateGradient,
                need_mirror, nullptr, zero_ops, "_copy");
    

    调用该方法会调用文件pass_function.h下的Gradient函数。该函数将传入的参数保存在graph下的attrs中。再通过applypass调用Gradient方法。也就是在该文件下定义的方法,签名:Graph Gradient(Graph src)

    1. 根据DFSVisit进行拓扑排序,将序列存储到topo_order
    2. 将输出的梯度保存在output_grads
    3. 根据mirror_fun在适当的地方插入新的节点,来实现内存的复用
    • DefaultAggregateGradient

    7.2 plan_memory.cc

    7.3 place_device.cc

    7.4 correct_layout.cc

    相关文章

      网友评论

        本文标题:MxNet源码解析(2) symbol

        本文链接:https://www.haomeiwen.com/subject/eudziftx.html