核心思想
首先说这篇文章的核心思想,它认为,在剪枝时,将一个filter剪掉,这个filter引起的loss越小,说明这个filter越不重要,这就叫做oracle。文中提出最好的剪枝算法,就是对每个filter的重要度用oracle做评估,然后选出最小的剪掉;然后再fine-tuning并重新一遍,再把最小的剪掉。
这个方法一听就知道不靠谱,因为计算量太大了,所以作者提出我们可以用一些近似的方法去估计每个filter的重要度,但核心还是想用oracle。
方法
下面就介绍本文算法的三个核心内容:
1. Damage Isolation
首先,最好的评估oracle的方式是把一个filter删掉,看看网络最终的loss变化有多大。文中这一节讲了很多内容最后就是想说明:我们想只用下一层的输出,而非最终的输出来评估删掉一个filter造成的影响。为啥?因为快啊,不然每个filter都要跑整个网络,多慢。举个栗子:你删了第i层的一个filter,然后就看看第i+1层的输出变化有多大。等等,为啥不直接用第i层的?因为你第i层的filter数量都和原来不一样了,咋做l2-loss?没法比,只能比i+1。公式如下:

2. Multi-path Training-time Pruning Framework
这部分内容多,也很绕,挨个来说:
2.1 首先,以往的的剪枝算法都是先剪枝,再fine-tune,太浪费时间了,所以作者想,我们边剪枝,边训练。于是有了下面这张图。
图中有两条路径。左边的称为base path(u),代表剪枝后的网络;右边的称为score path(v),用于一一对应的给每个filter的重要度打分。所以在一次forward的过程中,我们既可以根据输出得出filter的重要度,又可以直接用loss做反向传播更新模型。以下图为例,带花纹的的代表已经剪枝的,假设conv1已经剪枝了两个,现尝试剪枝第3个,此时只把v当中的给它删掉,u当中的留下,所以uv输出的差值就是删掉这个filter而引起的loss,也就是图中所示的t。
但是这里有一个问题,你如果每次只选取一个filter去测试它的重要度,那工作了也太大了,所以作者采用每次随机mask一堆(x个)filters,然后把t值记录到每个选取到的filter中,在经过n个batch后,每个filter中记录的loss虽然不是自己的,但是它的期望应该是等于”其他filter的平均loss+当前filter的loss“。所以各个filter之间依然是有可比性的。

3. Binary Filter Search
这里部分就是想说,对于上一部分所说的x的值应该如何确定,答案是每次都选一半!等计算出重要度后,在从最不重要的一般中再选一半。整体的算法如下:

这里大概记录一下以便以后翻阅回忆吧,这个过程实在是太绕了,懒得写完了。
网友评论