美文网首页
LeetCode:字符串相乘

LeetCode:字符串相乘

作者: 阿臻同学 | 来源:发表于2020-11-21 13:40 被阅读0次

    字符串相乘 - LeetCode

    导入依赖

    主要依赖的库有:

    • math:用来进行幂运算。
    • random:用来生成随机测试用例
    import math
    import random
    

    拆分、填充

    • 默认输入的格式为不固定长度的字符串,如 "123456"
    • 需要对输入的字符串拆分成长度为 N 的数字类型列表,如 [1,2,3,4,5,6]
    • 并对其进行填充,找到 2 的指数 i,满足 2^{i-1}\lt 2N \le 2^{i},如:

    N=6 时,有 i=4,2^4=16

    满足 2^{3} \lt 2 \times 6 \le 2^{4}

    • 使用 0 对列表进行填充后的长度L满足: L = 2^{i},如 [1,2,3,4,5,6,0,0,0,0,0,0,0,0,0,0]
    def to_list(num1:str, num2:str) -> tuple:
        # 拆分为 list
        a = [int(i) for i in num1]
        b = [int(i) for i in num2]
        
        # 反转列表,将低阶项系数放在列表前面
        a.reverse()
        b.reverse()
        max_len = max(len(a),len(b))
        
        # 对齐使长度相等
        l = len(a)-len(b)
        zeros = [0] * abs(l) 
        if l < 0:
            a = a + zeros
        elif l > 0:
            b = b + zeros
            
        # 补充前导 0,使得长度为 2^n
        fill_count = int(2**math.ceil(math.log2(max_len*2)) - max_len)
        fill = [0] * fill_count
        return a+fill,b+fill
    

    多项式表示

    对于输入的两个数 AB ,将其处理成两个多项式:

    A(x) = a_0 + a_1x^1 + a_2x^2 + \cdots + a_{N-1}x^{N-1} = \sum_{j=0}^{N-1}a_jx^j

    B(x) = b_0 + b_1x^1 + b_2x^2 + \cdots + b_{N-1}x^{N-1} = \sum_{j=0}^{N-1}b_jx^j

    最终的目标是对多项式 C(x) = A(x) \times B(x) 进行求解。

    C(x) = c_0 + c_1x^1 + c_2x^2 + \cdots + c_{2N-1}x^{2N-1} = \sum_{j=0}^{2N-1}c_jx^j

    傅里叶变换求解

    • 将处理后的两个列表进行快速傅里叶变换(fft),得到 2N 个点值对的取值

      [x_0,x_1,x_2,\cdots,x_{2N-1}]

    • 得到两个新的列表并将其按元素相乘,得到待求解的多项式 C(x) 的值

      [A(x_0)B(x_0),A(x_1)B(x_1),A(x_2)B(x_2),\cdots,A(x_{2N-1})B(x_{2N-1})]

    • 再进行逆离散傅里叶变换(idft),将点值表示转换为 C(x) 的系数;

      [c_0,c_1,c_2,\cdots,c_{2N-1}]

    • 傅里叶变换后的结果是虚数,其实部四舍五入后取整,便是结果多项式对应项的系数,将以 10 为底的多项式计算求和,得到乘法的结果。

      \sum_{j=0}^{2N-1} c_i \times 10^j

    def multiply(num1: str, num2: str) -> str:
        l = len(num1)
        a,b = to_list(num1, num2)
        # 傅里叶变换
        a_fft, b_fft = fft(a), fft(b)
        t = []
        # 对应项相乘
        for i in range(len(a_fft)):
            t.append(a_fft[i] * b_fft[i]) 
        # 逆傅里叶变换
        ans = idft(t)
        sum = 0
        # 计算多项式
        for i,r in enumerate(ans):
            # 实部四舍五入取整
            sum += int(r.real+0.5) * (10 ** i)
        return str(sum)
    

    傅里叶变换实现

    傅里叶变换与逆傅里叶变换的主要区别在于:逆傅里叶变换需要对计算的结果除以 N (并不是在递归中进行),并且在计算的过程中 \omega = \omega^{-1}

    def _ft(l:list, idft = False):
        """
        基础的变换方法,通过变量控制进行dft还是idft
        
        :param bool idft: 控制进行傅里叶变换还是逆傅里叶变换
        """
        
        n = len(l)
        if n == 1:
            return l
        
        # dft 与 idft 分别处理 $\omega$
        o_n_e = -2j if idft else 2j
        o = 1
        o_n = math.e ** (o_n_e * math.pi / n)
        
        # 拆分奇偶项
        even_index = l[::2]
        odd_index = l[1::2]
        
        y_even = _ft(even_index, idft)
        y_odd = _ft(odd_index, idft)
    
        y = [0]*n
        for i in range(n//2):
            y[i] = y_even[i] + o * y_odd[i]
            y[i+n//2] = y_even[i] - o * y_odd[i]
            o *= o_n
        return y
    
    def fft(l:list):
        """
        傅里叶变换
        """
        output = _ft(l)
        return output
    
    def idft(l:list):
        """
        逆傅里叶变换
        """
        n = len(l)
        output = _ft(l,True)
        # 将计算的结果除以 $N$
        output = [i/n for i in output]
        return output
    

    测试

    multiply()方法输出的结果与自带的乘法计算结果进行比较,并输出测试结果。

    def test(num1:str, num2:str):
        r = int(multiply(num1,num2))
        s = int(num1)*int(num2)
        t = 30
        print(f"{'-'*t} Test {'-'*t}")
        print(f"Test case: \n\t{num1} \n\t{num2}")
        print(f"Program output: \n\t{r}")
        print(f"Expected output: \n\t{s}")
        print(f"❌ FAILED" if r != s else "✔ OK")
        return r == s
    

    编写测试用例

    # 测试用例数
    test_cases = 10
    # 数据长度
    INT_MAX = 1e100
    for i in range(test_caese):
        num1 = str(random.randint(0, INT_MAX))
        num2 = str(random.randint(0, INT_MAX))
        test(num1, num2)
    

    LeetCode AC 代码

    class Solution:
        def to_list(self, num1:str, num2:str) -> tuple:
            a = [int(i) for i in num1]
            b = [int(i) for i in num2]
            a.reverse()
            b.reverse()
            l = len(a)-len(b)
            max_len = max(len(a),len(b))
            # 对齐使长度相等
            zeros = [0] * abs(l) 
            if l < 0:
                a = a + zeros
            elif l > 0:
                b = b + zeros
            # 补充前导 0,使得长度为 2^n
            fill_count = int(2**math.ceil(math.log2(max_len*2)) - max_len)
            fill = [0] * fill_count
            return a+fill,b+fill
        
        def multiply(self, num1: str, num2: str) -> str:
            l = len(num1)
            a,b = self.to_list(num1, num2)
            a_fft, b_fft = self.fft(a), self.fft(b)
    
            
            t = []
            _3 = []
            for i in range(len(a_fft)):
                t.append(a_fft[i] * b_fft[i]) 
            ans = self.idft(t)
            
            sum = 0
            for i,r in enumerate(ans):
                sum += int(r.real+0.5) * (10 ** i)
            return str(sum)
            
            
    
        def _ft(self, l:list, idft = False):
            n = len(l)
            if n == 1:
                return l
            o_n_e = -2j if idft else 2j
            even_index = l[::2]
            odd_index = l[1::2]
            o = 1
            o_n = math.e ** (o_n_e * math.pi / n)
            
            y_even = self._ft(even_index, idft)
            y_odd = self._ft(odd_index, idft)
            
            y = [0]*n
            for i in range(n//2):
                y[i] = y_even[i] + o * y_odd[i]
                y[i+n//2] = y_even[i] - o * y_odd[i]
                o *= o_n
            return y
        
        def fft(self, l:list):
            output = self._ft(l)
            return output
        
        def idft(self, l:list):
            n = len(l)
            output = self._ft(l,True)
            output = [i/n for i in output]
            return output
    

    相关文章

      网友评论

          本文标题:LeetCode:字符串相乘

          本文链接:https://www.haomeiwen.com/subject/nvzsiktx.html