这是一篇使用增强学习来进行模型搜索的论文。
结构如下图:
overview
由于不知道网络的长度和结构,作者使用了一个RNN作为控制器,使用该控制器来产生一串信息,用于构建网络。之后训练该网络,并用网络的accuracy作为reward返回给控制器来更新控制器的参数,达到更优的策略。
其中控制器(RNN)的设计借鉴了sequence to sequence的思想,不同的是它优化的是一个不可微的目标,也就是 网络的accuracy。
方法
CNN上图展示了如何使用RNN控制器产生一个简单的CNN网络,对于CNN网络的每一层,控制器都会产生一组超参数,当层数达到一个阈值,就会停止。RNN的参数会通过增强学习算法更新,以得到更好的模型结构。
使用REINFORCE来训练
控制器可以看作agent,控制器产生一组token,也就是超参数,看作agent的action,使用产生的模型在验证集的准确率作为reward。因此,控制器需要优化下面公式:
但是 REINFORCE
它的一个经验近似公式如下:
empirical approximation
是一个batch中的模型个数
是超参数的个数
由于以上公式会遇到variance过大的问题,可以使用如下带baseline的公式
with baseline
分布式训练加速
分布式框架如下图
思路:其中 parameter server共同保存了控制器的所有参数,这些server将参数分发给controller,每一个controller使用得到的参数进行模型的构建,这里由于得到的参数可能不同,构建模型的策略是随机的,导致每次构建的网络结构也会不同。每个controller会构建一个batch,也就是 anchor
相对应的,agent的action选择如下图:
agent action
它会根据前面算出的概率和自己的策略,判断是否加入connection。
最终,所有没有后续连接的层都会被连接到输出层,如果连接的两个层大小不一致,就将小的层用0来填充。(pad with zeros)
产生RNN网络cell
为了产生RNN网络cell,类似于LSTM,作者使用了一种树的结构,每一个树的节点都会拥有一个操作(addition, elementwise multiplication, etc.)和一个激活函数(tanh, sigmoid等)。每一个节点的输入,都连接了两个其他节点的输出。为了使用上面描述的方法,作者将每个节点编号,按照顺序预测。如下图:
RNN
根据预测的结果,将会按照如下方式构建网络:
Computation steps
总结
这篇文章将增强学习的算法应用在了模型预测上,并且巧妙地使用RNN来预测参数。总体思路依旧是通过在一个有限的搜索空间进行高效的搜索,来不断提高agent预测的模型的准确率。
note:REINFORCE算法真神奇,能够直接使用一个简单的标量reward来知道agent更新参数的方式。
网友评论