美文网首页
#算法学习录#Strassen矩阵乘法

#算法学习录#Strassen矩阵乘法

作者: LRC_cheng | 来源:发表于2016-04-30 01:25 被阅读0次

    今天我们谈谈一个“土豪”算法——Strassen矩阵乘法算法
    之所以说它“土豪”,是因为它带来了巨大的空间开销。
    先来考察一个问题:请用三次实数乘法计算复数a+bi和c+di相乘。
    由于:
    a×(c+d)=ac+ad=s1 ;
    b×(c-d)=bc-bd=s2 ;
    d×(a+b)=ad+bd=s3 ;
    故有实部:s1 -s3 =ac-bd,
    虚部:s2+ s3 =ad+bc。
    这样,四次的乘法就变成三次乘法。

    Strassen矩阵乘法也是如此,把A,B,C矩阵分解为n/2×n/2子矩阵,进行7次递归计算n/2×n/2矩阵的乘法,其运行时间的递归式:

    T(n)= Θ(1)             if n=1;

          7T(n/2)+Θ(n^2 )    if n>1;

    令:
    S1=B12-B22;
    S2=A11+A12;
    S3=A21+A22;
    S4=B21-B11;
    S5=A11+A22;
    S6=B11+B22;
    S7=A12-A22;
    S8=B21+B22;
    S9=A11-A21;
    S10=B11+B12;
    那么: P1= A11·S1 = A11·(B12-B22)
    P2= S2·B22 = (A11+A12)·B22
    P3= S3·B11 = (A21+A22)·B11
    P4= A22·S4 = A22·(B21-B11)
    P5= S5·S6 = (A11+A22)·(B11+B22)
    P6= S7·S8 = (A12-A22)·(B21+B22)
    P7= S9·S10 = (A11-A21)·(B11+B12)

    C11= P5 + P4 - P2 + P6=A11×B11+A12×B21
    C12= P1 + P2=A11×B12+A12×B22
    C21= P3 + P4=A21×B11+A22×B21
    C22= P5 + P1 - P3 - P7=A21×B12+A22×B22

    Strassen算法的具体实现(C语言):
    // Element-wise helpers defined later in this file; declared here because
    // Strassen calls them before their definitions appear.
    void Add(int **A, int **B, int **Q, int Size);
    void Sub(int **A, int **B, int **Q, int Size);

    // Allocates a Size x Size int matrix as an array of separately allocated rows.
    static int **StrassenNewMatrix(int Size){
     int **M = new int*[Size];
     for (int i = 0; i < Size; i++){
      M[i] = new int[Size];
     }
     return M;
    }

    // Releases a matrix obtained from StrassenNewMatrix.
    static void StrassenDeleteMatrix(int **M, int Size){
     for (int i = 0; i < Size; i++){
      delete[] M[i];
     }
     delete[] M;
    }

    // Classical triple-loop product, used when Size is odd and cannot be halved.
    static void StrassenClassical(int **A, int **B, int **Result, int Size){
     // Accumulate into a scratch matrix so that Result may alias A or B.
     int **T = StrassenNewMatrix(Size);
     for (int i = 0; i < Size; i++){
      for (int j = 0; j < Size; j++){
       int Sum = 0;
       for (int k = 0; k < Size; k++){
        Sum += A[i][k] * B[k][j];
       }
       T[i][j] = Sum;
      }
     }
     for (int i = 0; i < Size; i++){
      for (int j = 0; j < Size; j++){
       Result[i][j] = T[i][j];
      }
     }
     StrassenDeleteMatrix(T, Size);
    }

    /*
     * Multiplies two Size x Size integer matrices with Strassen's algorithm:
     * Result = A * B.
     *
     * A, B   : input matrices, each an array of Size row pointers.
     * Result : output matrix, an array of Size row pointers of Size ints each.
     *          It is written only after every read of A and B has finished.
     * Size   : matrix dimension. Even sizes are split into four quadrants and
     *          multiplied with 7 recursive products; odd sizes greater than 1
     *          fall back to the classical product, because Size / 2 would
     *          otherwise drop the last row and column.
     *
     * Returns 0 on success and -1 if Size is negative.
     */
    int Strassen(int **A, int **B, int **Result, int Size){
     if (Size < 0){
      return -1;
     }
     if (Size == 0){
      return 0;
     }
     if (Size == 1){
      // Base case: 1x1 product.
      Result[0][0] = A[0][0] * B[0][0];
      return 0;
     }
     if (Size % 2 != 0){
      StrassenClassical(A, B, Result, Size);
      return 0;
     }
     int NewSize = Size / 2;

     /* Quadrants of A and B as row-pointer views: no elements are copied. */
     int **A11 = new int*[NewSize];
     int **A12 = new int*[NewSize];
     int **A21 = new int*[NewSize];
     int **A22 = new int*[NewSize];
     int **B11 = new int*[NewSize];
     int **B12 = new int*[NewSize];
     int **B21 = new int*[NewSize];
     int **B22 = new int*[NewSize];
     for (int i = 0; i < NewSize; i++){
      A11[i] = A[i];
      A12[i] = A[i] + NewSize;
      A21[i] = A[i + NewSize];
      A22[i] = A[i + NewSize] + NewSize;

      B11[i] = B[i];
      B12[i] = B[i] + NewSize;
      B21[i] = B[i + NewSize];
      B22[i] = B[i + NewSize] + NewSize;
     }

     /* The seven recursive products. */
     int **P1 = StrassenNewMatrix(NewSize);
     int **P2 = StrassenNewMatrix(NewSize);
     int **P3 = StrassenNewMatrix(NewSize);
     int **P4 = StrassenNewMatrix(NewSize);
     int **P5 = StrassenNewMatrix(NewSize);
     int **P6 = StrassenNewMatrix(NewSize);
     int **P7 = StrassenNewMatrix(NewSize);
     /* Scratch matrices holding the sums/differences of quadrants of A and B. */
     int **AResult = StrassenNewMatrix(NewSize);
     int **BResult = StrassenNewMatrix(NewSize);

     // P1 = A11*(B12-B22)
     Sub(B12, B22, BResult, NewSize);
     Strassen(A11, BResult, P1, NewSize);

     // P2 = (A11+A12)*B22
     Add(A11, A12, AResult, NewSize);
     Strassen(AResult, B22, P2, NewSize);

     // P3 = (A21+A22)*B11
     Add(A21, A22, AResult, NewSize);
     Strassen(AResult, B11, P3, NewSize);

     // P4 = A22*(B21-B11)
     Sub(B21, B11, BResult, NewSize);
     Strassen(A22, BResult, P4, NewSize);

     // P5 = (A11+A22)*(B11+B22)
     Add(A11, A22, AResult, NewSize);
     Add(B11, B22, BResult, NewSize);
     Strassen(AResult, BResult, P5, NewSize);

     // P6 = (A12-A22)*(B21+B22)
     Sub(A12, A22, AResult, NewSize);
     Add(B21, B22, BResult, NewSize);
     Strassen(AResult, BResult, P6, NewSize);

     // P7 = (A11-A21)*(B11+B12)
     Sub(A11, A21, AResult, NewSize);
     Add(B11, B12, BResult, NewSize);
     Strassen(AResult, BResult, P7, NewSize);

     // Combine the products straight into the four quadrants of Result:
     //   C11 = P5 + P4 - P2 + P6    C12 = P1 + P2
     //   C21 = P3 + P4              C22 = P5 + P1 - P3 - P7
     for (int i = 0; i < NewSize; i++){
      for (int j = 0; j < NewSize; j++){
       Result[i][j] = P5[i][j] + P4[i][j] - P2[i][j] + P6[i][j];
       Result[i][j + NewSize] = P1[i][j] + P2[i][j];
       Result[i + NewSize][j] = P3[i][j] + P4[i][j];
       Result[i + NewSize][j + NewSize] = P5[i][j] + P1[i][j] - P3[i][j] - P7[i][j];
      }
     }

     // Release the products and scratch matrices (rows and row-pointer arrays).
     StrassenDeleteMatrix(P1, NewSize);
     StrassenDeleteMatrix(P2, NewSize);
     StrassenDeleteMatrix(P3, NewSize);
     StrassenDeleteMatrix(P4, NewSize);
     StrassenDeleteMatrix(P5, NewSize);
     StrassenDeleteMatrix(P6, NewSize);
     StrassenDeleteMatrix(P7, NewSize);
     StrassenDeleteMatrix(AResult, NewSize);
     StrassenDeleteMatrix(BResult, NewSize);
     // The quadrant views own only their row-pointer arrays, not the rows.
     delete[] A11; delete[] A12; delete[] A21; delete[] A22;
     delete[] B11; delete[] B12; delete[] B21; delete[] B22;
     return 0;
    }

    // Element-wise matrix addition: Q = A + B for Size x Size matrices.
    // Each output element depends only on the same position of A and B.
    void Add(int **A, int **B, int **Q, int Size){
     for (int row = 0; row < Size; ++row){
      int col = 0;
      while (col < Size){
       const int sum = A[row][col] + B[row][col];
       Q[row][col] = sum;
       ++col;
      }
     }
    }

    // Element-wise matrix subtraction: Q = A - B for Size x Size matrices.
    // Each output element depends only on the same position of A and B.
    void Sub(int**A, int**B, int **Q, int Size){
     for (int row = 0; row < Size; ++row){
      int col = 0;
      while (col < Size){
       const int diff = A[row][col] - B[row][col];
       Q[row][col] = diff;
       ++col;
      }
     }
    }
    演示结果:
     


    与暴力求解相比:
       for(i=0;i<m;i++)
         for(j=0;j<m;j++){
             C[i][j]=0;
          for(k=0;k<n;k++)
                C[i][j]+=A[i][k]*B[k][j];            
    }   
    其运行时间(Θ(n^lg7),2.80<lg7<2.81)比暴力求解(Θ(n^3))稍快,但其精度较低(在处理小数计算时),且消耗了大量存储空间。

    最后附上源代码:https://github.com/LRC-cheng/Algorithms_Practise.git

    相关文章

      网友评论

          本文标题:#算法学习录#Strassen矩阵乘法

          本文链接:https://www.haomeiwen.com/subject/rqrkrttx.html