美文网首页
深度学习矩阵乘法优化算法Strassen

深度学习矩阵乘法优化算法Strassen

作者: 半笔闪 | 来源:发表于2019-12-02 15:56 被阅读0次

上一篇介绍了深度学习中卷积的优化算法,这一篇来介绍一下矩阵乘法的优化算法——Strassen算法。我一直觉得Strassen和winograd有异曲同工之妙。矩阵乘法是深度学习中无处不在的运算,同样它是非常耗时的运算。

  • 朴素方法
    C = A • B,A是n x p 的矩阵,B是 p x m 的矩阵。
    C[i][j] = sum(A[i][k] * B[k][j]) for k = 0,1,2,...,p,i的区间为[0,n],j的区间为[0,m]。
    通过上面的式子我们可以知道,要想计算C,通过3级for循环即可,时间复杂度其实就是O(nmp),这里为了好说明,先让n=m=p,即所有的矩阵都是方阵。此时时间复杂度就是O(n^3)。
for(int i = 0; i < n; ++i){
    for(int i = 0; i < n; ++i){
        C[i][j] = 0;
        for(int i = 0; i < n; ++i){
            C[i][j] += a[i][k] * b[k][j];
        }
    }
}
  • 分治
    我们都应该知道一种排序算法叫做归并排序,这应该是分治算法的典型。分治算法的思想其实有三部:
    1、将问题分解为规模更小的子问题
    2、将子问题解出
    3、将子问题的解合并
    Strassen算法基于分治的思想:把nn的矩阵分割成4个n/2n/2的矩阵。如下:
    矩阵分法
    可以得出如下式子:
    C11 = A11 • B11 + A12 • B21
    C12 = A11 • B12 + A12 • B22
    C21= A21 • B11 + A22 • B21
    C22 = A21 • B12 + A22 • B22
    以上的式子没个都需要两个矩阵乘法和一个矩阵加法。也就是要想计算出C矩阵,总共需要8次矩阵乘法和4个矩阵加法。
  • Strassen算法
    我认为Strassen算法跟winograd算法相似之处在于Strassen算法的思想也是希望通过减少乘法,当然这里是矩阵乘法。根据分治法把矩阵四分,创建中间矩阵:
    S1 = B12 - B22
    S2 = A11 + A12
    S3 = A21 + A22
    S4 = B21 - B11
    S5 = A11 + A22
    S6 = B11 + B22
    S7 = A12 - A22
    S8 = B21 + B22
    S9 = A11 - A21
    S10 = B11 + B12
    计算:
    P1 = A11 • S1
    P2 = S2 • B22
    P3 = S3 • B11
    P4 = A22 • S4
    P5 = S5 • S6
    P6 = S7 • S8
    P7 = S9 • S10
    可根据以上7个结果计算C矩阵:
    C11 = P5 + P4 - P2 + P6
    C12 = P1 + P2
    C21 = P3 + P4
    C22 = P5 + P1 - P3 - P7
    我们来数一数现在需要多少矩阵乘法和矩阵加法:S需要10次矩阵加法,P需要7次矩阵乘法,最后计算出C矩阵需要8次矩阵加法。因此计算C矩阵总共需要7次矩阵乘法和18次矩阵加法。相比于朴素算法,减少了一次矩阵乘法,增加了14次矩阵加法。而根据计算机原理,当达到一定量级,14个矩阵加法是要快于1一个矩阵乘法的。就这样实现了矩阵乘法的加速。
    当然,Strassen算法也有适用场景,具体到实现,Strassen算法是做递归运算,此时需要创建大量动态数组,而分配这些数组的内存空间也要占用计算时间。故Strassen算法实现时需要设置一些界限,当达到一定量级时才使用Strassen算法进行加速,一般情况使用朴素方法进行计算。还有,当矩阵稀疏时也不适合使用Strassen算法,稀疏时矩阵中为0的元素就不需要做乘法,就无法体现出Strassen算法减少乘法的优势。

相关文章

网友评论

      本文标题:深度学习矩阵乘法优化算法Strassen

      本文链接:https://www.haomeiwen.com/subject/svkfgctx.html