用数组计算NMS(非极大值抑制)
def nms(boxes, scores, iou_threshold):
# 得分index排序,从大到小
sorted_indices = np.argsort(scores)[::-1]
keep_boxes = []
while sorted_indices.size > 0:
# 取得分最高的index
box_id = sorted_indices[0]
keep_boxes.append(box_id)
# 计算这个box和其余box的IOU
ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])
# 过滤掉大于iou阈值的box
keep_indices = np.where(ious < iou_threshold)[0]
# 剩下的需要继续计算的box的index,由于每次都保留了第一个index,计算iou的index要小1,所以index要加1.
sorted_indices = sorted_indices[keep_indices + 1]
return keep_boxes
def compute_iou(box, boxes):
# 计算xmin, ymin, xmax, ymax
xmin = np.maximum(box[0], boxes[:, 0])
ymin = np.maximum(box[1], boxes[:, 1])
xmax = np.minimum(box[2], boxes[:, 2])
ymax = np.minimum(box[3], boxes[:, 3])
# 计算交集面积
intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)
# 计算并集面积
box_area = (box[2] - box[0]) * (box[3] - box[1])
boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
union_area = box_area + boxes_area - intersection_area
# 计算IoU
iou = intersection_area / union_area
return iou
网友评论