——百度
将搜索空间视为有向无环图,为该有向无环图设计可微采样器,该采样器可学习,可以由搜索得到的结构在验证集上的损失来优化,因此称之为:Gradient-based search using Differentiable Architecture Sampler,在CIFAR-10数据集上4 GPU hours可以完成一次搜索过程,达到2.82%的测试错误率和2.5M的参数量。
介绍
搜索一个鲁棒的神经单元(cell)而非整个网络,该单元包含许多变换特征的结构,一个神经网络包含许多这样的单元。下图表示了搜索过程,将一个单元的搜索空间表示为一个DAG(有向无环图),每个灰色节点表示为特征张量,由操作顺序命名。不颜色的边代表不同类型的操作,将某一节点转换为中间特征。同时,每个节点是所有前层节点中间转换特征的累加。在训练时,GDAS从整个DAG中采样一个子图,在子图中每个节点只接受所有前层节点的一个中间特征,具体地,在两个相邻节点的所有中间特征中,GDAS以可微的方式采样一种特征。由此,GDAS能端到端地以梯度下降的方式进行训练,来发现一个鲁棒的cell。
GDASGDAS的快主要来源于采样操作,一个DAG包含上百种参数化操作,有着上百万的参数量,直接优化整个DAG(DARTS)将带来两个缺点:1、在一个迭代步中更新大量的参数将耗费很长时间,导致搜索时间超过一天。2、同时优化不同的操作会使得它们相互竞争,例如,不同的操作可能会产生相反的结果。这些相反的操作结果会相互抵消而带来弥散,破坏两个相邻节点之间的信息流动和优化过程。为了解决这两个问题,GDAS在一次迭代中只采样一个子图,因此一次迭代只需要优化DAG的一个部分,加速了训练过程。
GDAS相较于先前的基于强化学习的方法(RL-based)和遗传算法的方法(EA-based)使得搜索过程可微,可以使用梯度下降法。对于强化学习和遗传算法,他们反馈的信息是通过长时间训练的轨迹来进行reward的,而GDAS则是通过损失来反馈的,而且在梯度下降法中,损失是一个连续的可以在每次迭代中给出的量。且GDAS中的采样过程是可以学习的。
方法
对于CNN,一个单元是全卷积的,将所有之前单元的输出作为输入,产生输出特征张量。将CNN中的单元表示为DAG,包含一系列有序计算节点,每个节点代表一个特征张量,由前面两个特征张量变换而来:
特征变换其中,分别代表第个节点,分别表示来自候选操作集中的两个操作函数。当计算节点数量时,整个单元的节点有7个,代表前面两个单元的输出,代表计算节点。代表该单元的输出张量,表示为。在GDAS中,候选操作集合包含8种操作:恒等映射,零操作,3*3 depth-wise卷积,3*3 depth-wise空洞卷积,5*5 depth-wise空洞卷积,3*3 平均池化,3*3 最大池化(一如DARTS)。
同样搜索两种单元:正常单元和降采样单元,每个正常单元的操作步长为1 ,降采样单元的步长为2,一旦搜搜到所有正常单元和降采样单元,就将其堆叠为完整网络。对于CIFAR-10,堆叠N个正常单元作为一个Block。如下图:
网络结构可微模型采样
定义神经结构为,参数为,NAS的目标是为了找到一个结构,实现当以最小化训练损失训练参数后,使得网络结构在验证集上的准确率最小化。数学表示:
优化问题表示网络结构的最佳权重,能实现训练损失最小化。将负的对数似然最为训练对象,分别表示训练集合验证集。
一个网络结构包含许多同样的神经单元,该单元由搜索空间中搜索而来,具体地,节点之间,从候选操作集合中采样一个变换函数,实际上是从一个离散概率分布中采样而来,在搜索过程中,计算单元中每个节点:
节点计算离散概率分布是被一个可学习的概率质量函数表示的:
是由维可学习向量中的第个元素,表示候选操作集合中第个操作。因此,实际上编码了相邻节点之间的操作采样概率,因此,一个单元的采样分布表示为的集合。
给定上两式,可以得到,即可计算训练集上的损失,但因为采样于离散概率分布,因此梯度不能反传至,为了令方向传播能进行,使用Gumbel-Max的思想重新表达上式:
Gumbel-Max其中,独立同分布于Gumbel(0,1),,其中服从0到1之间的均匀分布。是向量的第个分量,是节点之间的操作的参数权重。然后,以SoftMax函数来放松argmax,实际上就是Gumbel Softmax:
Gumbel Softmax为温度系数,当其趋于零时,。本文在前向传递时用argmax函数,在后向传播中用Gumbel softmax函数,这样就可以用梯度后向传播了。
训练:
上述损失函数的主要挑战是学习一个结构,为了避免计算高阶导,我们应用替代优化策略以迭代方式更新采样分布和所有函数的权重。
Eq.(8):Loss的一般形式该采样分布
由的集合编码而得到,参数
是的集合,表示所有单元所有操作的参数。
对于一个采样数据,首先采样结构,计算网络输出(仅与有关)。
算法1:(alternative optimization strategy (AOS))
算法结构
训练完成之后,需要从分布中得到最后的网络结构。每个节点都与前个节点有关,对于CNN,设置,假设是候选索引集,定义节点之间的连接重要性:,对于每个节点,保留先前节点中有最大重要性的2个连接,对于已经保留的节点之间的连接,使用函数来确定节点之间的操作。
本文固定降采样单元,仅仅搜索正常单元。设计的降采样单元如下:
实验
识别率基本与DARTS持平的情况下,搜索时间比它快5倍以上。
实验 实验
网友评论