美文网首页
1.mxnet代码学习1——Trainer

1.mxnet代码学习1——Trainer

作者: 吃个小烧饼 | 来源:发表于2017-12-08 03:58 被阅读216次

    Trainer类是gluon下的一个类。顾名思义,就是主导“训练”的一个驱动方法。可以说不管写出多炫酷的网络结构,都需要从这个对象开始训练。可以从它的初始化方法__init__()看到一些基本信息:
    def __init__(self, params, optimizer, optimizer_params=None, kvstore='device'):
    忽略最后一个参数kvstore,我们可以知道其需要的几个参数是:params, optimizer, optimizer_params

    先说params,其注释部分有说:params : ParameterDict The set of parameters to optimize.,就是需要优化的参数啦。举个例子,一个cnn里,就是每个卷积核的那些参数。这个params如果获取到的是一个ParameterDict,我们就需要将其转化为一个listparams = list(params.values()),随后就把它写为一个内部 list 对象_params。所以说,我感觉在写工业的代码时,要充分将收集到的对象规范化,一般我自己在写算法代码时从来不考虑这些。

    所以,要是你的参数很容易得到,层数也不多,完全可以手动添加到一个list里,在mxnet的官方中文教程里就有这样的示范:

    params = [W1, b1, W2, b2, W3, b3, W4, b4]
    for param in params:
        param.attach_grad()
    

    就是收集到我们要更新的参数,然后记录它的梯度用以更新。既然说到这里了就可以看一下在这里如何更新:由于attach_grad对梯度创建一个“占位”,当执行完计算输出->输出和已有的做交叉熵计算loss之后,需要执行mxnet自带的backword来自动计算梯度。当然,具体mxnet(以后简写为mx)是如何计算的,它是拿cpp写的,有时间的话我也学习一下。当然,做完后向传播之后还要更新一下你的参数,比如SGD就是

    def SGD(params, lr):
        for param in params:
            param[:] = param - lr * param.grad
    

    说了这么多,好像越扯越远。确实,从想要说明参数表可以是 list rather than dict,花费了我大量不需要的时间······

    回到正题。当我们有了_params之后,我们就需要来解析优化器optimizer了。在我们一般使用时,往往(或者绝对)输入的是一个字符串'sgd'这样的,然而实际上这里要配合前后2个参数将其转化为一个Optimizer类对象。这个Optimizer产生的方式、也就是直接写的话,是这样的:

    sgd = mx.optimizer.Optimizer.create_optimizer('sgd',param_dict=param_dict, learning_rate=.1)
    

    然后把这个对象返回个内部私有对象_optimizer。而在传递param_dict之前,首先要使用_params建立一个词典,方法也很简单:

    param_dict = {i: param for i, param in enumerate(self._params)}
    

    这里就要说到最重要的地方了。上边讲到的create_optimizer是这样声明的:

    def create_optimizer(name, **kwargs):
    

    name是一个str,**kwargs是字典形式的参数表,返回一个optimizer类对象,最骚的来了,这里的对象在创建的时候给一个叫opt_registry的词典以name
    key ,以返回的类为 value 添加一项。这一部分在函数def register(klass):中。

    而包含他们的文件optimizer.py 里有以register为装饰器执行register,所以在执行create之前必然先创建了以输入的name为名的对象。
    这个装饰器的位置在类Optimizer之后,在所有优化方法之前。所以我们的操作就是:

    所以当我们具体调用一个Trainer时,背后的操作如下:

    trainer=gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':.1})
    

    1.进入optimizer.py包,执行register = Optimizer.register对象化register()方法。

    2.register函数对象为装饰器,遍历下边的具体优化方法类SGD Adam等,也就相当于register(SGD)等等方法。
    执行此函数时创建Optimezer类,其中有一个字典opt_registry,我们把具体优化方法类的名字的小写作为opt_registry的一项的 key ,具体的优化方法类作为 value ,构建一个字典。

    3.初始化函数对象createcreate = Optimizer.create_optimizer。噼里啪啦一顿操作后完成,模块读取完毕。

    4.回到有trainer的模块,收集参数net.collect_params(),将其构建为字典,然后Trainer执行主方法,调用刚刚生产的create,以参数字典为一个参数,字符串sgdname,具体优化参数字典为**kwargs
    具体的create过程就是,检查字符串name是否在刚刚构建的词典opt_registry内,在的话匹配,即初始化与其匹配的优化方法类:

    Optimizer.opt_registry[name.lower()](**kwargs)
    

    5.结束了,返回 >_<

    相关文章

      网友评论

          本文标题:1.mxnet代码学习1——Trainer

          本文链接:https://www.haomeiwen.com/subject/lifsixtx.html