美文网首页
算法设计技巧: 分治法 (Divide & Conquer)

算法设计技巧: 分治法 (Divide & Conquer)

作者: 胡拉哥 | 来源:发表于2020-04-14 19:55 被阅读0次

    分治法是一种非常通用的算法设计技巧. 在很多实际问题中, 相比直接求解, 分治法往往能显著降低算法的计算复杂度. 常见的可以用分治法求解的问题有: 排序, 矩阵乘法, 整数乘法, 离散傅里叶变换等. 分治法的一般思路如下:

    1. [Divide] 把原问题拆分成同类型的子问题.
    2. [Conquer] 用递归的方式求解子问题.
    3. [Combine] 根据子问题的解构造原问题的解.

    分治法最关键的步骤是如何 低成本地 利用子问题的解构来造原问题的解. 它包含两个方面: 1. 可行性, 即, 可以用子问题的解来构造原问题的解; 2. 高效性, 即构造原问题解的时间复杂度较低. 换句话说, 分治法需要比直接求解效率高. 分治法一般是通过递归求解子问题, 其时间复杂度的分析需要用到如下定理.

    Master Theorem[1]. 考虑递归式T(n) = a T(n/b) + f(n). 当f(n) \in \Theta(n^d), d\geq 0时, 我们有
    \begin{aligned} T(n) = \begin{cases} \Theta(n^d) & \text{ if } a < b^{d} \\ \Theta(n^d \log n) & \text{ if } a = b^d \\ \Theta(n^{\log_b a}) & \text{ if } a > b^d \end{cases} \end{aligned}

    Counting Inversions[2]

    在推荐场景中, 一个常用的方法是协同过滤(Collaborative Filtering), 即为相似的用户推荐它们共同喜欢的事物. 以推荐歌曲为例, 我们把用户AB对歌曲的偏好分别进行排序, 然后计算有多少首歌在AB的排序是"不同的". 最后根据这种不同来定义用户AB的相似性, 从而进行歌曲推荐. 具体来说, 对用户A对歌曲的偏好按1, 2, \ldots, n编号. 用户B对歌曲的偏好可以表示为
    a_1 < a_2 < \ldots < a_n.

    对任意的i<j, 如果a_i > a_j, 我们称a_ia_j称为 反序(inversion). 换句话说, 用户A把歌曲i排在歌曲j的前面, 但是用户B把歌曲j排在了歌曲i的前面.

    问题描述

    给定无重复数字的整数序列a_1, a_2, \ldots, a_n, 计算其反序的数量.

    算法设计

    直接求解的思路是考虑所有的二元组(a_i, a_j)并判断它们是否反序. 这个算法的时间复杂度是O(n^2). 下面我们用分治法来降低计算复杂度.

    注意到这个问题实际上与排序非常类似. 通过对序列进行排序的同时记录不满足"顺序"的二元组数量, 即反序的数量. 令k=\lfloor n/2 \rfloor. 把序列分成两部分:

    \begin{aligned} & \text{left} = \{a_1, a_2, \ldots, a_k\} \\ & \text{right} = \{a_{k+1}, a_{k+2}, \ldots, a_n\}. \end{aligned}

    用递归的方式对left和right进行排序, 同时计算left和right中的反序数量. 当序列的长度为1时, 返回0. 下一步是合并子问题的解. 注意到left和right已经是按照从小到大的顺序进行排列. 比较left和right的第一个元素并把较小的元素添加到结果中直到left或right为空, 最后再把剩余的序列添加到结果集. 在比较过程中我们需要记录反序的数量. 当right中的元素小于left中的元素时, 反序的增量为"left中剩余元素的数量". 最终的结果包含三部分之和: left中反序的数量, rihgt中反序的数量和合并时反序的数量.

    Python实现

    整体的计算过程.

    def sort_and_count(x):
        if len(x) == 1:
            return x, 0
        k = len(x) // 2
        left, count_left = sort_and_count(x[0: k])
        right, count_right = sort_and_count(x[k:])
        # 把子问题的解拼接成原问题的解
        combined, count = merge_and_count(left, right)
        return combined, count + count_left + count_right
    

    归并过程.

    def merge_and_count(left, right):
        """ 把left和right合并且计算inversion的数量
        注意: left和right已经排好序
        """
        combined = []
        count = 0
        while len(left) and len(right):
            if left[0] > right[0]:  # 反序(左边的编号小于右边的编号是正序)
                combined.append(right.pop(0))
                count += len(left)
            else:  # 正序
                combined.append(left.pop(0))
        return combined + left + right, count
    

    完整代码

    计算复杂度

    容易分析归并过程的时间复杂度是O(n). 令T(n)代表算法的时间复杂度, 我们有

    T(n) \leq 2T(n/2) + cn, \quad c \text{ is constant}.

    根据Master Theorem, 我们得到T(n) = \Theta(n\log n).

    Closest Pair[2]

    Closest Pair是计算几何里的一个基本问题: 给定二维平面上的n个点, 找到距离最近的两个点. 通过计算任意两点的距离可以在O(n^2)找到距离最近的两点. 下面利用分治法可以把时间复杂度降低到O(n\log n).

    算法设计

    如果所有点是一维的, 我们可以把它们排序, 然后计算所有相邻两点的最小距离. 排序耗时O(n\log n), 计算相邻点的最小距离耗时O(n), 因此算法的时间复杂度为O(n\log n). 在二维情形, 我们的思路是类似的:

    1. 沿着x轴方向对点集P进行排序得到P_x = \{(x_1, y_1), (x_2, y_2), \ldots, (x_n, y_n) \}.

    2. P_x按与x轴垂的方向均分成两部分QR:
      \begin{aligned} & Q = \{(x_1, y_1), (x_2, y_2), \ldots, (x_{k}, y_{k})\}, \quad k=\lfloor n/2\rfloor \\ & R = \{(x_{k+1}, y_{k+1}), (x_{k+2}, y_{k+2}), \ldots, (x_n, y_n)\}. \end{aligned}

    3. 递归地求解QR中的closest pair(如下图所示).

      image
    4. 根据QR的计算结果构造原问题的解(见下文).

    合并(Combine)

    (q_0, q_1), (r_0, r_1)分别是QR中的closest pair. 如果P的closest pair在PQ中, 我们只需要从(q_0,q_1)(r_0, r_1)选择距离小的pair作为结果输出. 否则P的closest pair其中一点在Q中, 另一点在R中, 这时我们需要比较QR中的点. 这样一来, 合并的时间复杂度为O(n^2)! 接下来我们要把合并的时间复杂度降低为O(n).

    \delta = \min \{d(q_0, q_1), d(r_0, r_1)\}, 其中d(x,y)代表x, y两点之间的距离. 设L = \{x = x_0 \} 代表QR的分割线. 如果存在q\in Q, r\in R使得d(p,r) < \delta, 那么qrx轴方向距离L一定不超过\delta. 令S = \{(x,y)\in P, \text{s.t. } |x-x_0| < \delta \}, 因此q, r \in S. 如下图所示, S中的点在蓝色虚线之间.
    [图片上传失败...(image-facf8c-1586865336187)]
    S中的点按y轴从小到大排序, 得到集合S_y = \{s_1, s_2 \ldots \}, 其中s_i是一个二元组(代表它在平面中的位置). 我们有如下定理(稍后给出证明):

    定理 如果存在s_i, s_j \in S_y满足d(s_i,s_j) < \delta, 那么|i-j| \leq 15.

    这样一来, 我们可以在O(n)的时间内找到所有距离不超过\delta的点对, 并记录距离最小的点对作为结果输出(如果存在). 思路思路如下:

    pairs_within_delta = []  # S中距离不超过delta的点的集合
    for s in Sy:
        for t in 15 points after s:
            if d(s, t) < delta:
                add (s,t) to pairs_within_delta
    output the minimum distance pair in pairs_within_delta
    

    求解子问题QR之前, 首先把P根据y轴从小到大排序得到P_y, 这样一来可以在O(n)时间内构造S_y, 即依次过滤掉P_y中距离L超过\delta的点. 在上述算法中, 外层循环次数是O(n), 内层循环是常数, 因此在合并步骤中构造closest pair的时间复杂度最终降低为O(n).

    Python实现

    先把输入点集P分别按x轴和y轴排序, 得到P_xP_y. 递归求解的过程参考函数closest_pair_xy.

    
    def closest_pair(points):
        """ 计算二维点集中的closest pair.
        :param points: P = [(x1,y1), (x2,y2), ..., (xn, yn)]
        :return: 两个距离最近的点
        """
        
        # 把P按x轴和y轴分别进行排序, 得到Px和Py
        # 注意: P, Px, Py 三个集合是相同的(仅仅排序不同)
        Px = sorted(points, key=lambda item: item[0])
        Py = sorted(points, key=lambda item: item[1])
        return closest_pair_xy(Px, Py)
        
        
    def closest_pair_xy(Px, Py):
        """ 计算closest pair
        :param Px: 把points按x轴升序排列
        :param Py: 把points按y轴升序排列
        :return: point1, point2
        """
        if len(Px) <= 3:
            return search_closest_pair(Px)
        # 构造子问题的输入: Qx, Rx, Qy, Ry
        k = len(Px) // 2
        Q, R = Px[0: k], Px[k:]
        Qx, Qy = sorted(Q, key=lambda x: x[0]), sorted(Q, key=lambda x: x[1])
        Rx, Ry = sorted(R, key=lambda x: x[0]), sorted(R, key=lambda x: x[1])
        # 求解子问题
        q0, q1 = closest_pair_xy(Qx, Qy)
        r0, r1 = closest_pair_xy(Rx, Ry)
        # 合并子问题的解
        return combine_results_of_sub_problems(Py, Qx, q0, q1, r0, r1)
    
    
    def search_closest_pair(points):
        """ 用枚举的方式寻找closest pair
        :param points: [(x1,y1), (x2,y2), ...]
        :return: closest pair
        """
        n = len(points)
        dist_min, i_min, j_min = math.inf, 0, 0
        for i in range(n-1):
            for j in range(i+1, n):
                d = get_dist(points[i], points[j])
                if d < dist_min:
                    dist_min, i_min, j_min = d, i, j
        return points[i_min], points[j_min]
    
    
    def get_dist(a, b):
        """ 计算两点之间的欧式距离
        """
        return math.sqrt(math.pow(a[0]-b[0], 2) + math.pow(a[1]-b[1], 2))
    

    T(n)代表closest_pair_xy的计算时间. 根据前文分析, 合并子问题的解并输出原问题的解的时间复杂度为O(n), 因此我们有
    T(n) \leq 2T(n/2) + cn, \quad c \text{ is constant}.
    根据Master Theorem, 我们有T(n) = \Theta(n\log n). 此外, 把P分别按x,y轴排序的时间复杂度为O(n\log n), 因此算法整体的时间复杂度是O(n\log n).

    下面是合并过程的实现.

    def combine_results_of_sub_problems(Py, Qx, q0, q1, r0, r1):
        """
        :param Py: P按y轴排序的结果
        :param Qx: P在x=x0处被切分成Q和R. Qx是Q按x轴排序的结果
        :param q0: (q0, q1)是Q中的closest pair
        :param q1: 参考q0
        :param r0: (r0, r1)是R中的closest pair
        :param r1: 参考r0
        :return: closest pair in P
        """
        # 计算Sy
        d = min(get_dist(q0, q1), get_dist(r0, r1))
        Sy = get_sy(Py, Qx, d)
        # 检查是否存在距离更小的pair
        s1, s2 = closest_pair_of_sy(Sy)
        if s1 and s2 and get_dist(s1, s2) < d:
            return s1, s2
        elif get_dist(q0, q1) < get_dist(r0, r1):
            return q0, q1
        else:
            return r0, r1
    
    
    def get_sy(Py, Qx, d):
        """ 根据Py计算Sy.
        :param Py: P按y轴排序的结果
        :param Qx: P在x=x0处被切分成Q和R. Qx是Q按x轴排序的结果
        :param d: delta
        :return: S
        """
        x0 = Qx[-1][0]  # Q最右边点的x坐标值
        return [p for p in Py if p[0] - x0 < d]
    
    
    def closest_pair_of_sy(Sy):
        """ 计算集合Sy的closest pair
        """
        n = len(Sy)
        if n <= 1:
            return None, None
        dist_min, i_min, j_min = math.inf, 0, 0
        for i in range(n-1):
            for j in range(i + 1, i + 16):
                if j == len(Sy):
                    break
                d = get_dist(Sy[i], Sy[j])
                if d < dist_min:
                    dist_min, i_min, j_min = d, i, j
        return Sy[i_min], Sy[j_min]
    

    完整代码

    定理证明

    定理 如果存在s_i, s_j \in S_y满足d(s_i,s_j) < \delta, 那么|i-j| \leq 15.

    根据前文的描述, 已知S中的点在下图蓝色虚线之间. 把S中的点按y轴从小到大排序得到S_y = \{s_1, s_2, \ldots\}, 其中s_i代表平面中的一个点. 为了方便描述, 我们把下图中蓝色虚线之间用单位长度为\delta/2的网格划分.
    [图片上传失败...(image-dad441-1586865336187)]
    假设存在s_i,s_j使得d(s_i, s_j) < \delta. 我们要证明|i-j| \leq 15. 证明分为两步:

    1. s_is_j必须在不同的网格中. 反证法. 假设s_i, s_j在同一个网格中(意味着s_i,s_j\in P or Q), 它们的距离d(s_i, s_j) \leq \frac{\delta}{2}\sqrt{2} < \delta. 注意\deltaPQ中的最短距离, 因此矛盾.
    2. |i-j|\leq 15. 反证法. 假设|i-j| \geq 16. 如上图所示s_is_j之间至少相差3行(网格). 因此d(s_i, s_j) \geq \frac{\delta}{2} \cdot 3 > \delta, 矛盾.

    参考文献


    1. T.H. Cormen, C. E. Leiserson, R.L. Rivest and C. Stein. Introduction to Algorithms. Third edition. The MIT Press, 2009.

    2. J. Kleinberg and E. Tardos. Algorithm Design. Pearson, 2006.

    相关文章

      网友评论

          本文标题:算法设计技巧: 分治法 (Divide & Conquer)

          本文链接:https://www.haomeiwen.com/subject/bgkxvhtx.html