上一篇介绍了深度学习中卷积的优化算法,这一篇来介绍一下矩阵乘法的优化算法——Strassen算法。我一直觉得Strassen和winograd有异曲同工之妙。矩阵乘法是深度学习中无处不在的运算,同样它是非常耗时的运算。
- 朴素方法
C = A • B,A是n x p 的矩阵,B是 p x m 的矩阵。
C[i][j] = sum(A[i][k] * B[k][j]) for k = 0,1,2,...,p,i的区间为[0,n],j的区间为[0,m]。
通过上面的式子我们可以知道,要想计算C,通过3级for循环即可,时间复杂度其实就是O(nmp),这里为了好说明,先让n=m=p,即所有的矩阵都是方阵。此时时间复杂度就是O(n^3)。
for(int i = 0; i < n; ++i){
for(int i = 0; i < n; ++i){
C[i][j] = 0;
for(int i = 0; i < n; ++i){
C[i][j] += a[i][k] * b[k][j];
}
}
}
- 分治
我们都应该知道一种排序算法叫做归并排序,这应该是分治算法的典型。分治算法的思想其实有三部:
1、将问题分解为规模更小的子问题
2、将子问题解出
3、将子问题的解合并
Strassen算法基于分治的思想:把nn的矩阵分割成4个n/2n/2的矩阵。如下:
矩阵分法
可以得出如下式子:
C11 = A11 • B11 + A12 • B21
C12 = A11 • B12 + A12 • B22
C21= A21 • B11 + A22 • B21
C22 = A21 • B12 + A22 • B22
以上的式子没个都需要两个矩阵乘法和一个矩阵加法。也就是要想计算出C矩阵,总共需要8次矩阵乘法和4个矩阵加法。 - Strassen算法
我认为Strassen算法跟winograd算法相似之处在于Strassen算法的思想也是希望通过减少乘法,当然这里是矩阵乘法。根据分治法把矩阵四分,创建中间矩阵:
S1 = B12 - B22
S2 = A11 + A12
S3 = A21 + A22
S4 = B21 - B11
S5 = A11 + A22
S6 = B11 + B22
S7 = A12 - A22
S8 = B21 + B22
S9 = A11 - A21
S10 = B11 + B12
计算:
P1 = A11 • S1
P2 = S2 • B22
P3 = S3 • B11
P4 = A22 • S4
P5 = S5 • S6
P6 = S7 • S8
P7 = S9 • S10
可根据以上7个结果计算C矩阵:
C11 = P5 + P4 - P2 + P6
C12 = P1 + P2
C21 = P3 + P4
C22 = P5 + P1 - P3 - P7
我们来数一数现在需要多少矩阵乘法和矩阵加法:S需要10次矩阵加法,P需要7次矩阵乘法,最后计算出C矩阵需要8次矩阵加法。因此计算C矩阵总共需要7次矩阵乘法和18次矩阵加法。相比于朴素算法,减少了一次矩阵乘法,增加了14次矩阵加法。而根据计算机原理,当达到一定量级,14个矩阵加法是要快于1一个矩阵乘法的。就这样实现了矩阵乘法的加速。
当然,Strassen算法也有适用场景,具体到实现,Strassen算法是做递归运算,此时需要创建大量动态数组,而分配这些数组的内存空间也要占用计算时间。故Strassen算法实现时需要设置一些界限,当达到一定量级时才使用Strassen算法进行加速,一般情况使用朴素方法进行计算。还有,当矩阵稀疏时也不适合使用Strassen算法,稀疏时矩阵中为0的元素就不需要做乘法,就无法体现出Strassen算法减少乘法的优势。
网友评论