美文网首页
mxnet源码分析1

mxnet源码分析1

作者: 迷途的Go | 来源:发表于2017-12-30 20:38 被阅读0次

如何使用

from mxnet import nd
from mxnet.gluon import nn
from mxnet import gluon
from mxnet import autograd

class Net(nn.Block):
    def __init__(self, **kwargs):
        super(Net, self).__init__(**kwargs)
        self.dense0 = nn.Dense(4, use_bias=False)
        self.dense1 = nn.Dense(2, use_bias=False)

    def forward(self, x):
        return self.dense1((self.dense0(x)))
def train():
    net = Net()
    net.initialize()
    w = net.dense0.weight
    print ('weight shape after initialize', w.shape, 'weight params', w.data())
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 1})
    data = nd.ones(shape=(1, 1, 28, 28))
    label = nd.ones(shape=(1, 10))
    loss = gluon.loss.L2Loss()
    with autograd.record():
        res = net(data)
        w = net.dense0.weight
        print ('net[0] name', net.dense0.name, 'weight shape', w.shape, '\nparams', w.data(), 'grad', w.grad())
        L = loss(res, label)
    L.backward()
    trainer.step(batch_size=1)
    print ('net[0] name', net.dense0.name, 'weight shape', w.shape, '\nparams', w.data(), 'grad', w.grad())
if __name__ == '__main__':
    train()

执行输出结果:

('weight shape after initialize', (4, 0))
('net[0] name', 'dense0', 'weight shape', (4L, 784L), '\nparams', 
[[ 0.04118239  0.05352169 -0.04762455 ...,  0.03089482 -0.00140258
   0.01266012]
 [-0.00697319 -0.00986735 -0.03128323 ...,  0.02195714 -0.04105704
   0.01050965]
 [ 0.02380178 -0.04182156  0.04908523 ..., -0.05005977 -0.0463761
   0.0436078 ]
 [-0.04813539 -0.03545294 -0.01216894 ...,  0.06526501 -0.00576673
  -0.02751607]]
<NDArray 4x784 @cpu(0)>, 'grad', 
[[ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]]
<NDArray 4x784 @cpu(0)>)
('net[0] name', 'dense0', 'weight shape', (4L, 784L), '\nparams', 
[[ 0.02016377  0.03250307 -0.06864318 ...,  0.00987619 -0.0224212
  -0.00835851]
 [-0.05362909 -0.05652324 -0.07793912 ..., -0.02469876 -0.08771293
  -0.03614624]
 [ 0.0333778  -0.03224555  0.05866124 ..., -0.04048375 -0.03680009
   0.05318382]
 [-0.03410936 -0.02142691  0.00185709 ...,  0.07929104  0.00825929
  -0.01349004]]
<NDArray 4x784 @cpu(0)>, 'grad', 
[[ 0.02101862  0.02101862  0.02101862 ...,  0.02101862  0.02101862
   0.02101862]
 [ 0.04665589  0.04665589  0.04665589 ...,  0.04665589  0.04665589
   0.04665589]
 [-0.00957601 -0.00957601 -0.00957601 ..., -0.00957601 -0.00957601
  -0.00957601]
 [-0.01402603 -0.01402603 -0.01402603 ..., -0.01402603 -0.01402603
  -0.01402603]]
<NDArray 4x784 @cpu(0)>)

以上代码包含了一个神经网络的典型结构:

  • 定义网络,上面是一个mlp
  • 网络初始化
  • 训练网络
    • 前向传播
    • 计算loss
    • 反向传播得到梯度
    • 更新权重参数

上面的代码证明了两样东西:

  • 定义网络后初始化给出的第二维是0, 这个是由于mxnet参数初始化延迟推导, 不知道输入,没办法知道第二维参数,相比于pytorch,优点是不用定义每一层网络的输入大小,但是一次forward之前就不知道参数的形状了
  • weight=weight-lr*grad, 以第一个参数为例,上述打印的结果前向传播的时候大小是0.04118239, 梯度是0, 一次反向传播后,梯度是0.02101862, 新的参数是0.02016377=0.04118239-0.02101862

以上涉及了gluon的关键组件:

  • gluon.nn.Block,Sequential, HybridBlock, HybridSequential的父类
  • loss
  • gluon.Trainer, 用来辅助更新模型参数的辅助类
  • mxnet.optimizer
  • mxnet.nd

相关文章

网友评论

      本文标题:mxnet源码分析1

      本文链接:https://www.haomeiwen.com/subject/jdungxtx.html