Trainer
类是gluon
下的一个类。顾名思义,就是主导“训练”的一个驱动方法。可以说不管写出多炫酷的网络结构,都需要从这个对象开始训练。可以从它的初始化方法__init__()
看到一些基本信息:
def __init__(self, params, optimizer, optimizer_params=None, kvstore='device'):
忽略最后一个参数kvstore
,我们可以知道其需要的几个参数是:params, optimizer, optimizer_params
。
先说params
,其注释部分有说:params : ParameterDict The set of parameters to optimize.
,就是需要优化的参数啦。举个例子,一个cnn里,就是每个卷积核的那些参数。这个params如果获取到的是一个ParameterDict
,我们就需要将其转化为一个list
:params = list(params.values())
,随后就把它写为一个内部 list 对象_params
。所以说,我感觉在写工业的代码时,要充分将收集到的对象规范化,一般我自己在写算法代码时从来不考虑这些。
所以,要是你的参数很容易得到,层数也不多,完全可以手动添加到一个list里,在mxnet的官方中文教程里就有这样的示范:
params = [W1, b1, W2, b2, W3, b3, W4, b4]
for param in params:
param.attach_grad()
就是收集到我们要更新的参数,然后记录它的梯度用以更新。既然说到这里了就可以看一下在这里如何更新:由于attach_grad
对梯度创建一个“占位”,当执行完计算输出->输出和已有的做交叉熵计算loss之后,需要执行mxnet自带的backword
来自动计算梯度。当然,具体mxnet(以后简写为mx)是如何计算的,它是拿cpp写的,有时间的话我也学习一下。当然,做完后向传播之后还要更新一下你的参数,比如SGD就是
def SGD(params, lr):
for param in params:
param[:] = param - lr * param.grad
说了这么多,好像越扯越远。确实,从想要说明参数表可以是 list rather than dict,花费了我大量不需要的时间······
回到正题。当我们有了_params
之后,我们就需要来解析优化器optimizer
了。在我们一般使用时,往往(或者绝对)输入的是一个字符串'sgd'
这样的,然而实际上这里要配合前后2个参数将其转化为一个Optimizer
类对象。这个Optimizer
产生的方式、也就是直接写的话,是这样的:
sgd = mx.optimizer.Optimizer.create_optimizer('sgd',param_dict=param_dict, learning_rate=.1)
然后把这个对象返回个内部私有对象_optimizer
。而在传递param_dict
之前,首先要使用_params
建立一个词典,方法也很简单:
param_dict = {i: param for i, param in enumerate(self._params)}
这里就要说到最重要的地方了。上边讲到的create_optimizer
是这样声明的:
def create_optimizer(name, **kwargs):
name是一个str,**kwargs是字典形式的参数表,返回一个optimizer
类对象,最骚的来了,这里的对象在创建的时候给一个叫opt_registry
的词典以name
为
key ,以返回的类为 value 添加一项。这一部分在函数def register(klass):
中。
而包含他们的文件optimizer.py 里有以register
为装饰器执行register,所以在执行create之前必然先创建了以输入的name
为名的对象。
这个装饰器的位置在类Optimizer之后,在所有优化方法之前。所以我们的操作就是:
所以当我们具体调用一个Trainer
时,背后的操作如下:
trainer=gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':.1})
1.进入optimizer.py
包,执行register = Optimizer.register
对象化register()
方法。
2.以register
函数对象为装饰器,遍历下边的具体优化方法类SGD
Adam
等,也就相当于register(SGD)
等等方法。
执行此函数时创建Optimezer
类,其中有一个字典opt_registry
,我们把具体优化方法类的名字的小写作为opt_registry
的一项的 key ,具体的优化方法类作为 value ,构建一个字典。
3.初始化函数对象create
:create = Optimizer.create_optimizer
。噼里啪啦一顿操作后完成,模块读取完毕。
4.回到有trainer
的模块,收集参数net.collect_params()
,将其构建为字典,然后Trainer
执行主方法,调用刚刚生产的create
,以参数字典为一个参数,字符串sgd
为name
,具体优化参数字典为**kwargs
。
具体的create
过程就是,检查字符串name
是否在刚刚构建的词典opt_registry
内,在的话匹配,即初始化与其匹配的优化方法类:
Optimizer.opt_registry[name.lower()](**kwargs)
5.结束了,返回 >_<
网友评论