Computing Non-Maximum Suppression (NMS) for Rotated Quadrilaterals


Author: dalalaa | Published: 2019-11-30 11:31

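    Object detection often requires NMS over rotated quadrilaterals (usually rectangles). Where OpenCV is available, you can call cv2.dnn.NMSBoxesRotated directly.
    In environments where OpenCV cannot be used, however, you have to implement this yourself.
    This article provides a PyTorch version of NMSBoxesRotated. To make deployment with jit or ONNX easier, the functions depend on nothing but PyTorch (note that this NMS code is slow when run in plain Python).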

    The article has two parts: computing the overlap area of two rotated quadrilaterals, and NMS itself.

    Computing the overlap area


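    The approach to computing the overlap area, detailed in the subsections below, is: find the intersection points of the two quadrilaterals' edges, collect the vertices of each quadrilateral that lie inside the other, sort the resulting points into a consistent order, and compute the area of the polygon they form.

    (Figure: overview of the approach)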

    Finding the intersection of two line segments

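    First use the cross product to decide whether the two segments intersect, then compute the intersection point for segments that do.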

    def cross(a,b):
        '''Cross product of two 2-D vectors'''
        x1,y1 = a
        x2,y2 = b
        return x1 * y2 - x2 * y1
    def line_cross(line1,line2):
        '''Test whether two segments intersect and, if they do, return the intersection point'''
        a,b = line1
        c,d = line2
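        # The two signed triangle areas have the same sign, or one of them is 0 (an endpoint of one segment lies on the line through the other) ---> treated as not intersecting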
        if cross(c - a,b - a) * cross(d - a,b - a) >= 0:
            return False
        if cross(b - c,d - c) * cross(a - c,d - c) >= 0:
            return False
        x1,y1 = a
        x2,y2 = b
        x3,y3 = c
        x4,y4 = d
        
        k = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4) 
        if  k != 0:
            xp = ((x1*y2 - y1*x2) * (x3 - x4) - (x1 - x2) * (x3*y4 - y3*x4)) / k
            yp = ((x1*y2 - y1*x2) * (y3 - y4) - (y1 - y2) * (x3*y4 - y3*x4)) / k
        else:
            # parallel or collinear: no unique intersection point
            return False
        return xp,yp
    

    To check that the functions above are correct, you can run the following test:

    import torch
    import matplotlib.pyplot as plt
    from itertools import combinations
    lines = torch.randn((100,4)).view((-1,2,2))
    comb = combinations(lines,r =2 )
    plt.figure(figsize=(10,10))
    for line in lines:
        plt.plot(line[:,0],line[:,1],color = 'r')
    for line1,line2 in comb:
        r = line_cross(line1,line2)
        if r:
            plt.scatter(r[0],r[1],color = 'g')
    
    (Figure: segment intersection points)
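
The random plot is a visual check only. As a complement, here is a minimal plain-Python sketch of the same test on points whose intersection is known by hand. The names `cross2` and `segment_intersection` are illustrative and not part of the article's code; tuples are used instead of tensors so the example runs without PyTorch.

```python
def cross2(ax, ay, bx, by):
    """z-component of the cross product of 2-D vectors (ax, ay) and (bx, by)."""
    return ax * by - bx * ay


def segment_intersection(p1, p2, p3, p4):
    """Return the crossing point of segments p1-p2 and p3-p4, or None.

    Follows the same rule as the article's line_cross: touching endpoints
    and parallel/collinear segments count as "no intersection".
    """
    (x1, y1), (x2, y2), (x3, y3), (x4, y4) = p1, p2, p3, p4
    # p3 and p4 must lie strictly on opposite sides of the line p1-p2 ...
    d1 = cross2(x3 - x1, y3 - y1, x2 - x1, y2 - y1)
    d2 = cross2(x4 - x1, y4 - y1, x2 - x1, y2 - y1)
    if d1 * d2 >= 0:
        return None
    # ... and p1, p2 strictly on opposite sides of the line p3-p4.
    d3 = cross2(x2 - x3, y2 - y3, x4 - x3, y4 - y3)
    d4 = cross2(x1 - x3, y1 - y3, x4 - x3, y4 - y3)
    if d3 * d4 >= 0:
        return None
    k = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
    if k == 0:
        return None
    a = x1 * y2 - y1 * x2
    b = x3 * y4 - y3 * x4
    return ((a * (x3 - x4) - (x1 - x2) * b) / k,
            (a * (y3 - y4) - (y1 - y2) * b) / k)


# The diagonals of a 2x2 square cross at its centre.
print(segment_intersection((0, 0), (2, 2), (0, 2), (2, 0)))  # (1.0, 1.0)
# Parallel segments never cross.
print(segment_intersection((0, 0), (1, 0), (0, 1), (1, 1)))  # None
```

Because an endpoint that merely touches the other segment makes one of the products zero, such T-junctions return `None` here, just as they return `False` in `line_cross`.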

    Sorting the point set

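    The idea is to first connect every vertex to the center point, then define a function that decides the relative position of two such segments (clockwise or counter-clockwise), again using the cross product, and finally implement a quick sort on top of that function. The code is as follows: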

    def compare(a,b,center):
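        '''
        Decide whether segment a-center lies clockwise (True) or counter-clockwise (False) of segment b-center.
        1. Use the sign of the cross product: if it is negative, a-center is counter-clockwise of b-center; if it is positive, a-center is clockwise of b-center.
        2. If a, b and center are collinear, order by distance: the point farther from center takes the clockwise position.
    
        Rationale:
        det = a x b = |a| * |b| * sin(<a,b>)
        where <a,b> is the angle between a and b, i.e. the angle a must rotate counter-clockwise to reach b.
        If det is positive, a can rotate counter-clockwise onto b, so a is clockwise of b.
        If det is negative, a can rotate clockwise onto b, so a is counter-clockwise of b.
    
        '''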
        det = cross(a - center, b - center)
        if det > 0:
            return True
        elif det < 0:
            return False
        else:
            d_a = torch.sum((a - center) ** 2)
            d_b = torch.sum((b - center) ** 2)
            if d_a > d_b:
                return True
            else:
                return False
    
    def quick_sort(box,left,right,center = None):
        '''Quick sort (sorts box in place)'''
        if center is None:
            center = torch.mean(box,dim = 0)
        if left < right:
            q = partition(box,left,right,center)
            quick_sort(box,left,q - 1,center)
            quick_sort(box,q + 1,right,center)
    
    def partition(box,left,right,center = None):
        '''Partition helper for quick sort; uses the last element as the pivot'''
        x = box[right]
        i = left - 1
        for j in range(left,right):
            if compare(x,box[j],center):
                i += 1
                temp = box[i].clone()
                box[i] = box[j]
                box[j] = temp
                # rows of a torch.Tensor cannot be swapped with the tuple assignment below (indexing returns views, not copies)
                # box[i],box[j] = box[j],box[i]
        temp = box[i + 1].clone()
        box[i + 1] = box[right]
        box[right] = temp
        return i + 1
    

    As before, a short script can be used to check the result:

    import numpy as np
    import cv2
    empty = (np.ones((800,800,3)) * 255).astype(np.uint8)
    box = torch.rand((16,2)) * 800
    cv2.polylines(empty,[box.data.numpy().astype(np.int32)],True,(0,255,0),2)
    quick_sort(box,0,len(box) - 1)
    cv2.polylines(empty, [box.data.numpy().astype(np.int32)], True, (255, 0, 0), 8)
    plt.imshow(empty)
    plt.show()
    

    This produces the image below; the red outline is the polygon after sorting.

    (Figure: polygon after sorting its vertices)
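
A common alternative to a comparison-based sort is to sort by polar angle around the centroid. This is a different technique from the article's cross-product quick sort, shown only as a plain-Python sketch for comparison; `sort_by_angle` is an illustrative name. In PyTorch the same idea could presumably be written with `torch.atan2` and `torch.argsort`.

```python
import math


def sort_by_angle(points):
    """Return the points ordered by polar angle around their centroid.

    points is a list of (x, y) tuples. Points with exactly the same angle
    keep their input order, because sorted() is stable.
    """
    cx = sum(x for x, _ in points) / len(points)
    cy = sum(y for _, y in points) / len(points)
    return sorted(points, key=lambda p: math.atan2(p[1] - cy, p[0] - cx))


# Corners of the unit square given in scrambled order.
corners = [(1, 1), (0, 0), (1, 0), (0, 1)]
print(sort_by_angle(corners))  # [(0, 0), (1, 0), (1, 1), (0, 1)]
```

Unlike `compare`, this sketch does not break ties between collinear points by distance from the center.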

    Testing whether a point is inside a polygon

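    This function is needed to compute the intersection of convex quadrilaterals, because the vertices of that intersection come from three sources:

    1. vertices of box2 that lie inside box1;
    2. vertices of box1 that lie inside box2;
    3. intersection points of the edges of box1 and box2.

    The test is implemented as follows: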

    def inside(point,polygon):
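        '''
        Test whether a point lies inside a polygon.
        Method: ray casting.
        Cast a horizontal ray from point; if it crosses the polygon boundary an odd number of times, the point is inside the polygon, otherwise it is outside.
    
        To rule out special cases,
        an edge is counted as crossing the ray only when one endpoint is below the ray and the other is above or on the ray.
        '''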
        x0,y0  = point
        # cast a horizontal ray (constant y) from point to just left of the polygon's leftmost vertex
        left_line = torch.Tensor([[x0,y0],[torch.min(polygon,dim = 0)[0][0].item() - 1,y0]])
        lines = [[polygon[i],polygon[i+1]] for i in range(len(polygon) - 1)] + [[polygon[-1],polygon[0]]]
        ins = False
        for line in lines:
            (x1,y1),(x2,y2) = line
            if min(y1,y2) < y0 and max(y1,y2) >= y0:
                c = line_cross(left_line,line)
                if c and c[0] <= x0:
                    ins = not ins
        return ins
    

    Then check it with the following code:

    points = torch.rand(800,2) * 800
    for p_ in points:
        p = p_.clone().long()
        r = inside(p,box)
        if r:
            cv2.circle(empty,(p[0].item(),p[1].item()),5,color = (0,0,0),thickness=5)
        else:
            cv2.circle(empty,(p[0].item(),p[1].item()),5,color = (255,0,255),thickness=5)
    plt.imshow(empty)
    

    This gives the rather colorful image below:

    (Figure: points inside and outside the polygon)
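
For a check whose answer is known in advance, the ray-casting rule can also be sketched in plain Python on a square. This sketch computes the crossing x-coordinate directly instead of calling `line_cross`, so it is an illustration of the half-open edge rule rather than a copy of the article's function; `point_in_polygon` is an illustrative name.

```python
def point_in_polygon(point, polygon):
    """Ray casting with a horizontal ray pointing towards -x.

    polygon is a list of (x, y) vertices in order. An edge counts only if
    one endpoint is strictly below the ray and the other is on or above it,
    so a vertex lying exactly on the ray is not counted twice.
    """
    x0, y0 = point
    inside = False
    n = len(polygon)
    for i in range(n):
        x1, y1 = polygon[i]
        x2, y2 = polygon[(i + 1) % n]
        if min(y1, y2) < y0 <= max(y1, y2):
            # x-coordinate where this edge meets the horizontal line y = y0
            # (y1 != y2 is guaranteed by the condition above)
            x_hit = x1 + (y0 - y1) * (x2 - x1) / (y2 - y1)
            if x_hit <= x0:
                inside = not inside
    return inside


square = [(0, 0), (2, 0), (2, 2), (0, 2)]
print(point_in_polygon((1, 1), square))   # True
print(point_in_polygon((3, 1), square))   # False
print(point_in_polygon((-1, 1), square))  # False
```

The point (3, 1) is a useful case: the ray crosses two edges, so the parity flips twice and the point is correctly reported as outside.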

    Computing the overlap region of two quadrilaterals

    Note: this only works when the two quadrilaterals overlap in a single region, for example when both are convex.

    def intersection(box1,box2):
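        '''
        Test whether two boxes intersect; if they do, return the vertices of the overlap region.
        1. Find the vertices of box1 that lie inside box2;
        2. find the vertices of box2 that lie inside box1;
        3. find the intersection points of box1 and box2;
        4. all of these points together form the polygon of the overlap region;
        5. sort them clockwise.
        '''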
        quick_sort(box1,0,len(box1) - 1)
        quick_sort(box2,0,len(box2) - 1)
        # compute the overlap region
        # arrange each box into a list of edges
        lines1 = [[box1[i],box1[i + 1]] for i in range(len(box1) - 1)] + [[box1[-1],box1[0]]]
        lines2 = [[box2[i],box2[i + 1]] for i in range(len(box2) - 1)] + [[box2[-1],box2[0]]]
        cross_points = []
        # intersection points of the edges
        for l1 in lines1:
            for l2 in lines2:
                c = line_cross(l1,l2)
                if c:
                    cross_points.append(torch.Tensor(c).view(1,-1))
        # vertices of box1 inside box2, then vertices of box2 inside box1
        for b in box1:
            if inside(b,box2):
                cross_points.append(b.view(1,-1))
        for b in box2:
            if inside(b,box1):
                cross_points.append(b.view(1,-1))
        if len(cross_points) > 0:
            cross_points = torch.cat(cross_points,dim = 0)
            quick_sort(cross_points,0,len(cross_points) - 1)
            return cross_points
        else:
            return None
    

    The verification code is as follows:

    
    plt.figure(figsize=(18,10))
    for i in range(4):
        box1 = torch.rand((4,2)) * 800
        box2 = torch.rand((4,2)) * 800
        empty = (np.ones((800,800,3)) * 255).astype(np.uint8)
        quick_sort(box1,0,len(box1) - 1)
        quick_sort(box2,0,len(box2) - 1)
        cv2.polylines(empty, [box1.data.numpy().astype(np.int32)], True, (255, 0, 0), 4)
        cv2.polylines(empty, [box2.data.numpy().astype(np.int32)], True, (0, 255, 0), 4)
        cross_points = intersection(box1,box2)
        if cross_points is not None:
            cv2.polylines(empty, [cross_points.data.numpy().astype(np.int32)], True, (0, 0, 255), 4)
        plt.subplot(140 + i + 1)
        plt.imshow(empty)
    
    (Figure: overlap regions of pairs of quadrilaterals)
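
One way to cross-check the overlap polygon is Sutherland-Hodgman clipping, a different and well-known technique for clipping one convex polygon against another. The sketch below is not the article's method. It assumes both polygons are convex and listed in counter-clockwise order (in a y-up coordinate system), and `clip_convex` is an illustrative name.

```python
def clip_convex(subject, clip):
    """Clip convex polygon `subject` against convex polygon `clip`.

    Both are lists of (x, y) tuples in counter-clockwise order.
    Returns the vertices of the overlap region (empty list if none).
    """
    def is_left(p, a, b):
        # True if p is on the left of, or on, the directed line a -> b.
        return (b[0] - a[0]) * (p[1] - a[1]) - (b[1] - a[1]) * (p[0] - a[0]) >= 0

    def line_hit(p, q, a, b):
        # Point where segment p-q meets the infinite line through a and b.
        dx1, dy1 = q[0] - p[0], q[1] - p[1]
        dx2, dy2 = b[0] - a[0], b[1] - a[1]
        denom = dx1 * dy2 - dy1 * dx2
        t = ((a[0] - p[0]) * dy2 - (a[1] - p[1]) * dx2) / denom
        return (p[0] + t * dx1, p[1] + t * dy1)

    output = list(subject)
    for i in range(len(clip)):
        a, b = clip[i], clip[(i + 1) % len(clip)]
        current, output = output, []
        for j in range(len(current)):
            p, q = current[j - 1], current[j]  # edge p -> q of the subject
            if is_left(q, a, b):
                if not is_left(p, a, b):
                    output.append(line_hit(p, q, a, b))  # entering
                output.append(q)
            elif is_left(p, a, b):
                output.append(line_hit(p, q, a, b))      # leaving
        if not output:
            break
    return output


box_a = [(0, 0), (2, 0), (2, 2), (0, 2)]
box_b = [(1, 1), (3, 1), (3, 3), (1, 3)]
print(clip_convex(box_a, box_b))
# [(1.0, 1.0), (2.0, 1.0), (2, 2), (1.0, 2.0)]
```

The result contains one vertex of each square plus the two edge crossings, which matches the three sources of vertices listed earlier.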

    Computing the area of a polygon

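    The polygon area is also computed with the cross product, using both its geometric meaning (the signed area of a triangle) and its sign.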

    def polygon_area(polygon):
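        '''
        Compute the (signed) area of a polygon.
        https://blog.csdn.net/m0_37914500/article/details/78615284 computes the polygon area with vector cross products; this requires all vertices to be listed in order.
        '''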
        lines = [[polygon[i],polygon[i+1]] for i in range(len(polygon) - 1)] + [[polygon[-1],polygon[0]]]
        s_polygon = 0.0
        for line in lines:
            a,b = line
            s_tri = cross(a,b)
            s_polygon += s_tri
        return s_polygon / 2
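
To see why the callers wrap the result in `torch.abs`, here is a small plain-Python sketch of the same cross-product sum on a square of side 2. `signed_area` is an illustrative name; the orientation labels assume a y-up coordinate system.

```python
def signed_area(polygon):
    """Sum of cross products of consecutive vertices, divided by 2.

    Positive for counter-clockwise vertex order, negative for clockwise.
    """
    total = 0.0
    n = len(polygon)
    for i in range(n):
        x1, y1 = polygon[i]
        x2, y2 = polygon[(i + 1) % n]
        total += x1 * y2 - x2 * y1
    return total / 2


square_ccw = [(1, 1), (3, 1), (3, 3), (1, 3)]
print(signed_area(square_ccw))        # 4.0
print(signed_area(square_ccw[::-1]))  # -4.0
```

The sign depends only on the vertex order, so taking the absolute value gives the area regardless of which direction the sort produced.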
    

    Computing IoU

    IoU (intersection over union) is the area of the intersection of two polygons divided by the area of their union.
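
As a worked example, take two axis-aligned 2x2 squares, the second shifted by (1, 1) relative to the first (an illustrative case, not from the article). Each has area 4 and they share a 1x1 region, so:

```latex
\mathrm{IoU}(A, B) = \frac{|A \cap B|}{|A| + |B| - |A \cap B|}
                   = \frac{1}{4 + 4 - 1}
                   = \frac{1}{7} \approx 0.143
```

The union is never computed as a polygon; it follows from the two areas and the intersection area.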

    def intersection_of_union(box1,box2):
        '''
        iou = intersection(s_1,s_2) / (s_1 + s_2 - intersection(s_1,s_2))
        '''
        quick_sort(box1,0,len(box1) - 1)
        quick_sort(box2,0,len(box2) - 1)
        s_box1 = torch.abs(polygon_area(box1))
        s_box2 = torch.abs(polygon_area(box2))
        cross_points = intersection(box1,box2)
        if cross_points is not None:
            s_cross = torch.abs(polygon_area(cross_points))
        else:
            s_cross = torch.Tensor([[0]])
        iou = s_cross / (s_box1 + s_box2 - s_cross)
        return iou
    

    The results can be computed and displayed as follows:

    plt.figure(figsize=(18,10))
    for i in range(4):
        box1 = torch.rand((4,2)) * 800
        box2 = torch.rand((4,2)) * 800
        empty = (np.ones((800,800,3)) * 255).astype(np.uint8)
        quick_sort(box1,0,len(box1) - 1)
        quick_sort(box2,0,len(box2) - 1)
    #     s_box1 = torch.abs(polygon_area(box1))
    #     s_box2 = torch.abs(polygon_area(box2))
        cv2.polylines(empty, [box1.data.numpy().astype(np.int32)], True, (255, 0, 0), 4)
        cv2.polylines(empty, [box2.data.numpy().astype(np.int32)], True, (0, 255, 0), 4)
        cross_points = intersection(box1,box2)
        if cross_points is not None:
            cv2.polylines(empty, [cross_points.data.numpy().astype(np.int32)], True, (0, 0, 255), 4)
    #         s_cross = torch.abs(polygon_area(cross_points))
    #     else:
    #         s_cross = torch.Tensor([[0]])
        iou = intersection_of_union(box1,box2)
        print(iou.item())
        plt.subplot(140 + i + 1)
        plt.title("IOU : {}".format(iou.item()))
        plt.imshow(empty)
    
    
    (Figure: IoU values for four random pairs)

    NMS

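    The principle of NMS is probably familiar to most readers. It consists of the following steps:

    1. select the box with the highest score;
    2. delete every box whose IoU with that box exceeds nms_thresh;
    3. from the remaining boxes, select the one with the highest score and repeat step 2.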
    
    def nms(boxes,scores,score_thresh = 0.95,nms_thresh = 0.1):
        indices = torch.where(scores > score_thresh)[0]
        if len(indices) <= 1:
            return boxes[indices]
        boxes = boxes[indices]
        scores = scores[indices]
        keep_indices = []
        # sort by score, highest first
        order = torch.argsort(scores).flip(dims = [0])
        while order.shape[0] > 0:
            i = order[0]
            keep_indices.append(i)
            not_overlaps = []
            for j in range(len(order)):
                if order[j] != i:
                    iou = intersection_of_union(boxes[i],boxes[order[j]])
                    if iou < nms_thresh:
                        not_overlaps.append(j)
            order = order[not_overlaps]
        keep_boxes = boxes[[i.item() for i in keep_indices]]
        return keep_boxes
    

    Verification code:

    boxes = torch.rand((10,4,2)) * 800
    empty = (np.ones((800,800,3)) * 255).astype(np.uint8)
    for i in range(len(boxes)):
        quick_sort(boxes[i],0,len(boxes[i]) - 1)
    cv2.polylines(empty,boxes.data.numpy().astype(np.int32),True,(0,255,0),4)
    plt.subplot(121)
    plt.imshow(empty)
    scores = torch.arange(10) + 1
    keep_boxes = nms(boxes,scores)
    # print("keep indices",keep_indices,boxes.shape)
    # keep_boxes = boxes[[i.item() for i in keep_indices]]
    empty = (np.ones((800,800,3)) * 255).astype(np.uint8)
    cv2.polylines(empty,keep_boxes.data.numpy().astype(np.int32),True,(0,255,0),4)
    plt.subplot(122)
    plt.imshow(empty)
    

    Final result:

    (Figure: boxes before and after NMS)
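
The greedy loop itself does not depend on the box shape. The following plain-Python sketch returns the kept indices rather than the boxes and takes the IoU as a callback; `greedy_nms`, `iou_fn` and `interval_iou` are hypothetical names, and 1-D intervals stand in for quadrilaterals so that the expected result can be worked out by hand.

```python
def greedy_nms(scores, iou_fn, nms_thresh=0.1):
    """Return the indices of the kept boxes, highest score first.

    iou_fn(i, j) is a callback returning the IoU of boxes i and j.
    """
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        best = order[0]
        keep.append(best)
        # Keep only candidates whose overlap with the chosen box is below the threshold.
        order = [j for j in order[1:] if iou_fn(best, j) < nms_thresh]
    return keep


def interval_iou(a, b):
    """IoU of two 1-D intervals (lo, hi); a stand-in for the polygon IoU."""
    inter = max(0.0, min(a[1], b[1]) - max(a[0], b[0]))
    union = (a[1] - a[0]) + (b[1] - b[0]) - inter
    return inter / union


boxes = [(0, 10), (1, 11), (20, 30), (21, 29)]
scores = [0.9, 0.8, 0.7, 0.95]
kept = greedy_nms(scores, lambda i, j: interval_iou(boxes[i], boxes[j]))
print(kept)  # [3, 0]
```

Box 3 has the highest score and suppresses box 2 (IoU 0.8); box 0 is picked next and suppresses box 1 (IoU 9/11). Returning indices makes it easy to look up the matching scores or labels afterwards.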


        Article link: https://www.haomeiwen.com/subject/wutswctx.html