美文网首页
笔记:Checkmate: breaking the memor

笔记:Checkmate: breaking the memor

作者: 裳湑_ | 来源:发表于2021-01-26 13:11 被阅读0次

    原文链接:https://arxiv.org/abs/1910.02653
    我个人认为非常优美的一篇论文,居然没人发过笔记,太奇怪了,那就我自己稍微写写。分段和原文不同,但公式标号基本是一样的。

    1. background

    这篇论文要解决的问题是训练过程中显存不够,如何节约开销。一个非常自然的想法就是把不用的显存释放掉,从而节约开销。这在推理过程中是很有用的,很自然能发现在线性网络(VGG等)中,这可以让显存开销达到O(1)的水平。然而在训练过程中,由于反向计算梯度的过程中这些中间结果还会被用到,当backward开始时,显存开销为O(n)(n为算子个数,因此也近似是中间张量的个数)。

    Checkpoint这篇论文中,给出了一个自然的思路:只保留一部分中间结果,对于其它中间结果,当需要时再用保留到的来计算。如果计算的开销很小,这个trade-off是值得的(BTW,这些计算通常可以被fuse到其它算子上)。典型代表就是CNN中常见的Conv-Batch Norm-Noise-Activate,除卷积层外的算子都是开销低的。

    Checkpoint只讨论了形如VGG的线形网络(即input依次经过算子1、2、3...没有其它路径):一言蔽之,将网络分段(segment),每个segment保存最后一个算子的输出,其输入为上一个segment的输出,在backward计算梯度的时候,segment内部的算子会重新跑一遍,用重新跑的结果来算梯度。
    这样,我们至多将网络再跑了一遍,但可以将显存开销降到O(\sqrt n):每\sqrt n个算子分成一个segment。实际应用中不会这么好。

    结合实际需求,作者又提出了两个算法,一个是用户提供指示函数m,用来表示每个算子可否被重用(虽然原文写的m的值域是自然数,但实际只能是\{0,1\},否则原文中的mapa还需要修正),另一个是提供显存上界B和重计算高开销的算子集合C(在m的值域修改成\{0,1\}后,Cm是一样的)。后者会根据显存上界贪心地确认最终的m,然后调用前者。(这个贪心也有点问题,实际上会可能超过显存上界)

    这个算法可以进一步优化:一个segment可以视为一个子图,从而可以递归调用该算法。(这样最终的m终于大于1了,尽管在原文alg.2里它还是只在\{0,1\}上)
    虽然上面一段说到有一些奇怪的问题,但它已经的确在torch.util.checkpoint和MXNet中用到了。最终性能不会有明显下降,但显存开销有明显优化。

    2. intro

    事情没有到此结束。上文的算法在VGG形状的网络里没什么大问题,m值域假了都无所谓。但在ResNet甚至ResNeXt广泛使用后,网络不再只是线形,而是有更多样的形状,这时segment的观点就不适用了。

    Checkmate解决了这个问题(适用于任意图),同时修复了m的值在非递归时不大于1的问题。文章中还提到它可以兼顾跨层保存的张量和层内计算中的张量,不过Checkpoint也可以做到,因此不是一个nontrivial的特性。总的来说,Checkmate将segment方法拓广为rematerialization,即有需求时再重新计算。
    Checkmate的思想简单且优美,即利用线性规划。概括来说,它的主要思路是:

    1. 每个时间点的显存使用量是关于哪些张量目前被保存的线性函数;
    2. 张量在每个时间点是否重新计算、保存或释放之间是布尔表达式的关系。对应的布尔不等式可以转换为线性不等式,这是论文中最优美的一点。
    3. 重新计算的代价与每个时间点是否重新计算是线性关系,且需要尽可能小,因此将其作为线性规划的目标。

    3. method

    然后我们就可以开始列线性规划的方程了:

    3.1 basic

    首先引入符号:G=(V,E)为计算图,包含了backward的部分。图上的点为V=\{v_1,v_2\dots v_n\},其下标符合拓扑序。对于v_i的输出,其显存占用为M_i,计算的代价为C_i。假设算子都是单一输出,从而可以用v_i来表示v_i算子所输出的张量。
    接着引入两组布尔变量R,S\in \{0,1\},其中R_{t,i}表示在阶段tv_i是否重计算,S_{t,i}表示v_i是否会从上一阶段保留到当前阶段。

    由此,我们可以写出最基本的约束条件:张量的计算要满足依赖关系、张量的保存需要首先经过计算、边界条件。这会被写成:
    \arg\min_{R,S}\sum_{t}\sum_{i} C_{t,i}R_{t,i} \tag{1a}
    subject to:
    R_{t,j}\leq R_{t,i}+S_{t,i},\forall t\forall (v_i,v_j)\in E \tag{1b}
    S_{t,i}\leq R_{t-1,i}+S_{t-1,i},\forall t>1\forall i \tag{1c}
    \sum_i S_{1,i}=0, \tag{1d}
    \sum_t R_{t,n}\geq 1\tag{1e}
    R,S\in \{0,1\}\tag{1f}
    最优化目标为计算的总代价最少,其约束的含义依次为:

    1. 计算v_j前,考虑其依赖的v_i:若在之前的阶段计算,则需要保存到该阶段(S_{t,i})。在本阶段计算(R_{t,i}),则需要在本阶段保存到v_j的计算,这点会体现在后续的显存约束;
    2. 张量为了保存要么在本阶段计算,要么在上一阶段保存至本阶段;
    3. 初始状态下所有张量都没有保存;
    4. 最终整个图都被计算过,保证有输出;
    5. R,S的值域限制
    3.2 memory: intro of U

    然后写出显存占用的表达式。显存占用要满足的是每个时刻(注意不是阶段)的占用都不大于一个给定常数M_{budget}

    首先还是引入符号:使用U_{t,k}表示阶段t中,算子k的重计算被考虑后的显存使用量。由此,我们可以写出显存上限对应的约束是
    U_{t,k}\leq M_{budget},\forall t\forall k\tag{2a}
    且有
    U_{t,0}=M_{input}+2M_{param}+\sum_i M_iS_{t,i}\tag{2b}
    其中前两项是假设输入、参数和参数的梯度都留有保存空间,均为常数。

    3.3 memory: intro of FREE

    为了约束每个阶段中间过程的显存使用,继续引入符号:

    使用FREE_{t,i,k}表示阶段t中,计算v_k后,v_i的显存是否被释放;
    使用DEPS[k]表示v_k在图上的前驱,即DEPS[k]=\{i:(v_i,v_k)\in E\},使用USERS[k]表示图上的后继。

    因此有:
    U_{t,k+1}=U_{t,k}-mem\_freed_t(v_k)+R_{t,k+1}M_{k+1}\tag{3}
    mem\_freed_t(v_k)=\sum_{i\in DEPS[k]\cup\{k\}}M_i\times FREE_{t,i,k}\tag{4}
    即算子计算后考虑其前驱是否可以释放。
    其中判断是否释放的函数为
    FREE_{t,i,k}=R_{t,k}\times (1-S_{t+1,i})\Pi_{j\in USERS[i],j>k}(1-R_{t,j})\tag{5}
    这个式子可以用布尔表达式的观点来看:v_k计算后释放v_i的条件是v_k被计算(R_{t,k})、且不再需要保存v_i(1-S_{t,k})、且后续依赖v_i的张量(v_j)都不会计算(1-R_{t,j})。
    由于i在一个stage至多计算一次、因此至多释放一次,有约束
    \sum_{k\in USERS[i]}FREE_{t,i,k}\leq 1,\forall t\forall i\tag{5'}

    3.4 analysis of FREE

    到上一步,显存的约束已经写完了。但是需要注意到,(5)不是线性表达式,而是由布尔与运算构成的乘法。
    解决这个问题是这篇文章最优美的地方。它使用了下述两个引理:

    Lemma 4.1 If x_1\dots x_n\in\{0,1\}, then \Pi_i x_i=1_{\sum_i (1-x_i)=0}(x_1,\dots x_n)
    Lemma 4.2 If 0\leq y\leq \kappa, then x=1_{y=0}(y) if and only if x\in\{0,1\} and (1-x)\leq y\leq \kappa(1-x)

    这两个引理都很容易证明。将4.1应用在(5)上,有
    FREE_{t,i,k}=1_{num\_hazards(t,i,k)=0}(t,i,k)\tag{6}
    其中num\_hazards(t,i,k)=1-R_{t,k}+S_{t+1,i}+\sum_{j\in USERS[i],j>k}R_{t,j}\tag{6'}
    至此,乘法已经转换为加法。然而它还在示性函数的条件上,需要使用4.2进一步变换:
    FREE_{t,i,k}\in\{0,1\}\tag{7a}
    1-FREE_{t,i,k}\leq num\_hazards(t,i,k)\tag{7b}
    \kappa(1-FREE_{t,i,k})\geq num\_hazards(t,i,k)\tag{7c}
    相当精彩。

    3.5 conclusion and prune

    因此,我们可以将原问题转化为:
    优化1a,约束条件为1b, 1c, 1d, 1e, 1f, 2a, 2b, 3, 4, 6', 7a, 7b, 7c。
    注意这是整数线性规划(ILP),属于NP问题,同时FREE的下标表明参数量至多为O(n^3)(当然现实场景的计算图通常会稀疏很多),因此需要一些优化:

    首先是根据拓扑序,指定每个阶段计算的算子,从而将条件下三角化:
    R_{i,i}=1,\forall i\tag{8a}
    \sum_{i\geq t}S_{t,i}=0,\forall t\tag{8b}
    \sum_{i>t}R_{t,i}=0, \forall t\tag{8c}
    这三个条件用来替换原条件中的两个边界条件(1d), (1e)
    论文中实验表明在八层的网络(加上backward后共有17个算子),原问题需要9.4小时,新问题只需要0.23秒。

    然后观察到FREE_{t,k,k}=1时,可以直接令R_{t,k}=0。因此将(4)中求和范围里的\{k\}删去,假设FREE_{t,k,k}=0。这可以减少n^2个变量。

    3.6 approximation

    ILP是NP问题,所以还是要考虑近似算法。作者采用的是将原问题的整数域\{0,1\}放到实数域[0,1]上来做,然后获得一个近似解S^*,R^*。将S^*近似为S^{int}S^{int}_{t,i}:=1_{S^*_{t,i}>0.5}
    然后通过S^{int}来计算对应的R^{int},即在S_{t-1,i}^{int}\neq S_{t,i}^{int}处填补对应的R_{t,i}=1FREE_{t,i,\_},然后再考虑计算i所依赖的张量是否保存,若未保存则在该阶段添加其计算。

    该近似过程可能让显存约束条件被违背,解决策略是在求实数解的时候把原本\leq M_{budget}的约束改为\leq (1-\epsilon) M_{budget}\epsilon=0.1时实验效果较好。

    相关文章

      网友评论

          本文标题:笔记:Checkmate: breaking the memor

          本文链接:https://www.haomeiwen.com/subject/pndwzktx.html