1. 介绍
这是一篇粗浅并且可能存在错误的个人理解
我们在使用MXNet的时候,都是通过调用python端提供的接口。通过一步步地构建symbol,在调用module
进行训练的时候,其实MXNet帮助我们建立了graph
。那么,这个图到底是如何建立的呢?反向传播又是如何进行的呢?以及我们在c++
, cu
端建立的op
是怎么通过 python
来进行调用的呢?
2. 从op到python
那么我们在c++端定义的operator到底是怎么引入到python的模块空间下面的呢?比如,我在{mxnet_root}/src/operator/
下面定义了一个operator,然后重新编译mxnet,就可以通过mxnet.sym
以及mxnet.module
来进行调用啦。这是怎么做到的呢?这里我们以symbol
为例分析。
首先我们在import mxnet
的时候会运行{mxnet_root}/python/mxnet
下面的__init__.py
,该文件中运行from . import symbol
,该行代码运行了symbol
文件夹下的__init__.py
。
之后依次运行:
- __init__.py:
from . import _internal, contrib, linalg, op, random, sparse, image
- _internal.py:
if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
from .._ctypes.symbol import SymbolBase, _set_symbol_class
from .._ctypes.symbol import _symbol_creator
- __init__.py:
from . import register
- register.py:
_init_op_module('mxnet', 'symbol', _make_symbol_function)
这里,register.py文件下进行 op 到symbol模块空间下的注册,也就是通过_init_op_module
来完成。具体大家可以通过该函数进行跟踪。
其中比较重要的函数是同文件下的_generate_symbol_function_code
函数。该函数直接生成op
对应的方法的code,再通过python
的exec
函数进行实际的定义,接着注册到对应的模块下面。
3. symbol
我们已经知道了op如何到symbol下面来让我们调用,那么我们调用symbol下的方法时,是怎么联系起来的呢?
直观上,我们通常通过mx.sym.var
(tensorflow下tf.placeholder
, pytorch下Tensor
)来创建输入变量,输入变量作为某个op
的输入,一步步构建新的symbol。最终的symbol往往是通过loss
函数,输出一个标量。
每当我们调用某个op时,例如mx.sym.FullyConnected
,会将输入参数通过一系列操作之后,传给_symbol_creator
。我们查看该函数源码,会发现它首先调用了MXSymbolCreateAtomicSymbol
来根据参数创建一个symbol
的handle
,之后调用symbol
的compose
进行具体op
的调用。
MXSymbolCreateAtomicSymbol
函数定义在src/c_api/c_api_symbolic.cc
里面。输入的第一个参数creator
其实是通过NNGetOpHandle
返回的一个OpHandle
,它是通过op_name
得到的。得到op
后通过CreateFunctor
返回symbol
,这里会建立symbol
的outputs
。
创建好symbol
后,就通过s._compose(name=name, **kwargs)
进行调用啦。其中kwargs
应该(还未仔细确认)是指类型为Symbol
的输入参数,例如我们指定的weight,data
等。
那么重点就到了compose
方法。
4. compose
symbol
的_compose
接着调用了NNSymbolCompose
,该函数输入参数为:symbol指针,name,参数个数,参数key,参数值的指针。
NNSymbolCompose
主要调用了Symbol::Compose
。该方法签名为:
void Symbol::Compose(const array_view<const Symbol*>& args,
const std::unordered_map<std::string, const Symbol*>& kwargs,
const std::string& name)
compose方法会首先进行一系列的输入的操作,检查。之后根据该symbol的输出是不是atomic
(根据op的输入判断),相应地进行compose。这里仅分析一下普通情况,也就是不是atomic。
首先我们要了解一下DFSVisit
这个方法
4.1 DFSVisit
查看该函数的源代码,我们看到,首先将传入的heads
(这里传入的一般是该symbol的outputs
),从类型NodeEntry
转换为GNode
。
之后调用PostOrderDFSVisit
,也就是post order深度优先遍历。后序遍历,我们知道,就是先子树,最后访问节点,也就是调用访问函数visit
。
PostOrderDFSVisit
的访问参数如下:
-
head_nodes
:根节点 - `[fvisit](GNode n) {fvisit(*n);}:访问函数,直接调用visit函数
-
n->get();
:hash函数,用于表示节点 -
(*n)->inputs.size() + (*n)->control_deps.size()
:入度的计算 - 根据
index
返回输入的函数
继续看PostOrderDFSVisit
源码,我们发现其实它是一个拓扑排序,用于DAG的遍历。它从每个根节点开始,不断将节点的输入加入visited, stack
,当某个节点的所有输入都完成了访问(入度等于访问次数),就从stack
删除,直到stack
为空。
了解了DFSVisit
的作用后(遍历DAG),我们再来看compose函数。
对于非atomic的情况,首先定义了一个访问函数find_replace_map
,它建立了一个replace_map
,这个map的key是类型为variable
的节点,对应的value是对应输入参数的outpus[0]。
接着又根据这个find_replace_map
建立了一个新的访问函数find_replace_plan
,它遍历每个节点的所有输入,如果输入节点存在在之前创建的replace_map
中,就将map对应的那一项加入到replace_plan
中,并且将该节点加入到update_node
中。
之后遍历replace_plan
完成替换,遍历update_node
,完成更新。
这里,我理解是对node直接的denpendecy做一些必要的处理。
5. 图
那么图是哪里构建的呢?
我们使用symbol构建网络的时候,需要调用bind
来进行必要的初始化。该方法返回一个executor
,executor
的创建会调用GraphExecutor::Init
。在这个函数里,进行了symbol到图的转换。
反向图的建立是通过Gradient
的pass
完成的。
网友评论