美文网首页
MXNet中的图是怎么构建的?

MXNet中的图是怎么构建的?

作者: Junr_0926 | 来源:发表于2019-02-20 23:47 被阅读0次

1. 介绍

这是一篇粗浅并且可能存在错误的个人理解

我们在使用MXNet的时候,都是通过调用python端提供的接口。通过一步步地构建symbol,在调用module进行训练的时候,其实MXNet帮助我们建立了graph。那么,这个图到底是如何建立的呢?反向传播又是如何进行的呢?以及我们在c++, cu端建立的op是怎么通过 python来进行调用的呢?

2. 从op到python

那么我们在c++端定义的operator到底是怎么引入到python的模块空间下面的呢?比如,我在{mxnet_root}/src/operator/下面定义了一个operator,然后重新编译mxnet,就可以通过mxnet.sym以及mxnet.module来进行调用啦。这是怎么做到的呢?这里我们以symbol为例分析。

首先我们在import mxnet的时候会运行{mxnet_root}/python/mxnet下面的__init__.py,该文件中运行from . import symbol ,该行代码运行了symbol文件夹下的__init__.py

之后依次运行:

  • __init__.py: from . import _internal, contrib, linalg, op, random, sparse, image
  • _internal.py:
if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
    from .._ctypes.symbol import SymbolBase, _set_symbol_class
    from .._ctypes.symbol import _symbol_creator
  • __init__.py: from . import register
  • register.py: _init_op_module('mxnet', 'symbol', _make_symbol_function)

这里,register.py文件下进行 op 到symbol模块空间下的注册,也就是通过_init_op_module来完成。具体大家可以通过该函数进行跟踪。
其中比较重要的函数是同文件下的_generate_symbol_function_code函数。该函数直接生成op对应的方法的code,再通过pythonexec函数进行实际的定义,接着注册到对应的模块下面。

3. symbol

我们已经知道了op如何到symbol下面来让我们调用,那么我们调用symbol下的方法时,是怎么联系起来的呢?

直观上,我们通常通过mx.sym.var(tensorflow下tf.placeholder, pytorch下Tensor)来创建输入变量,输入变量作为某个op的输入,一步步构建新的symbol。最终的symbol往往是通过loss函数,输出一个标量。

每当我们调用某个op时,例如mx.sym.FullyConnected,会将输入参数通过一系列操作之后,传给_symbol_creator。我们查看该函数源码,会发现它首先调用了MXSymbolCreateAtomicSymbol来根据参数创建一个symbolhandle,之后调用symbolcompose进行具体op的调用。

MXSymbolCreateAtomicSymbol函数定义在src/c_api/c_api_symbolic.cc里面。输入的第一个参数creator其实是通过NNGetOpHandle返回的一个OpHandle,它是通过op_name得到的。得到op后通过CreateFunctor返回symbol,这里会建立symboloutputs

创建好symbol后,就通过s._compose(name=name, **kwargs)进行调用啦。其中kwargs应该(还未仔细确认)是指类型为Symbol的输入参数,例如我们指定的weight,data等。

那么重点就到了compose方法。

4. compose

symbol_compose接着调用了NNSymbolCompose,该函数输入参数为:symbol指针,name,参数个数,参数key,参数值的指针。
NNSymbolCompose主要调用了Symbol::Compose。该方法签名为:

void Symbol::Compose(const array_view<const Symbol*>& args,
                     const std::unordered_map<std::string, const Symbol*>& kwargs,
                     const std::string& name)

compose方法会首先进行一系列的输入的操作,检查。之后根据该symbol的输出是不是atomic(根据op的输入判断),相应地进行compose。这里仅分析一下普通情况,也就是不是atomic。

首先我们要了解一下DFSVisit这个方法

4.1 DFSVisit

查看该函数的源代码,我们看到,首先将传入的heads(这里传入的一般是该symbol的outputs),从类型NodeEntry转换为GNode

之后调用PostOrderDFSVisit,也就是post order深度优先遍历。后序遍历,我们知道,就是先子树,最后访问节点,也就是调用访问函数visit

PostOrderDFSVisit的访问参数如下:

  • head_nodes:根节点
  • `[fvisit](GNode n) {fvisit(*n);}:访问函数,直接调用visit函数
  • n->get();:hash函数,用于表示节点
  • (*n)->inputs.size() + (*n)->control_deps.size():入度的计算
  • 根据index返回输入的函数

继续看PostOrderDFSVisit源码,我们发现其实它是一个拓扑排序,用于DAG的遍历。它从每个根节点开始,不断将节点的输入加入visited, stack,当某个节点的所有输入都完成了访问(入度等于访问次数),就从stack删除,直到stack为空。

了解了DFSVisit的作用后(遍历DAG),我们再来看compose函数。

对于非atomic的情况,首先定义了一个访问函数find_replace_map,它建立了一个replace_map,这个map的key是类型为variable的节点,对应的value是对应输入参数的outpus[0]

接着又根据这个find_replace_map建立了一个新的访问函数find_replace_plan,它遍历每个节点的所有输入,如果输入节点存在在之前创建的replace_map中,就将map对应的那一项加入到replace_plan中,并且将该节点加入到update_node中。

之后遍历replace_plan完成替换,遍历update_node,完成更新。

这里,我理解是对node直接的denpendecy做一些必要的处理。

5. 图

那么图是哪里构建的呢?

我们使用symbol构建网络的时候,需要调用bind来进行必要的初始化。该方法返回一个executorexecutor的创建会调用GraphExecutor::Init。在这个函数里,进行了symbol到图的转换。

反向图的建立是通过Gradientpass完成的。

相关文章

网友评论

      本文标题:MXNet中的图是怎么构建的?

      本文链接:https://www.haomeiwen.com/subject/qtskyqtx.html