美文网首页
使用K-means算法寻找yolo的锚框

使用K-means算法寻找yolo的锚框

作者: 小黄不头秃 | 来源:发表于2023-06-19 12:43 被阅读0次

    在使用yolov3算法时需要9个锚框,根据不同的数据锚框的大小是不一样的,于是yolov3使用K-means聚类算法计算出数据集中的9个框的期望值作为9个锚框,现在我们一起来讨论一下这些锚框是怎么生成的。

    • 第一步:提取出数据集中所有框的坐标值
    • 第二步:将框的坐标值转换成宽和高
    • 第三步:使用K-means聚类算法随机生成k个框。
    • 第四步:计算每一个框和初始框的相似度。原k-means算法使用的是欧氏距离计算点的相似度,这里改为使用1-IOU作为相似度。
    • 第五步:将所有的框分为k个类
    • 第六步:使用分类的均值更新k个框
    • 第七步:重复第四步-第七步
    • 第八步:计算框的改变值,当新框和原框改变值较小的时候,停止算法。

    可以根据自己的实际情况修改锚框的读取方式,代码中给出两种:

    • 从VOC 标签XML文件读取
    • 从txt标签文件读取

    K-means算法+生成锚框代码实现:

    import glob
    import random
    import xml.etree.ElementTree as ET
    import numpy as np
    import PIL.Image as Image
    import PIL.ImageDraw as D
    
    # 计算IOU
    def cas_iou(box, cluster):
        x = np.minimum(cluster[:, 0], box[0])
        y = np.minimum(cluster[:, 1], box[1])
    
        intersection = x * y
        area1 = box[0] * box[1]
        area2 = cluster[:, 0] * cluster[:, 1]
        iou = intersection / (area1 + area2 - intersection)
        return iou
    
    # 计算平均IOU
    def avg_iou(box, cluster):
        return np.mean([np.max(cas_iou(box[i], cluster)) for i in range(box.shape[0])])
    
    def kmeans(box, k):
        # 取出一共有多少框
        row = box.shape[0]
        # 每个框各个点的位置
        distance = np.empty((row, k)) # [699, 9]
    
        # 最后的聚类位置
        last_clu = np.zeros((row,)) # [699,]
    
        np.random.seed()
    
        # 随机选9个当聚类中心
        cluster = box[np.random.choice(row, k, replace=False)] # [9,2]
        # cluster = random.sample(row, k)
        while True:
            # 计算每一行距离五个点的iou情况。
            for i in range(row):
                distance[i] = 1 - cas_iou(box[i], cluster)
    
            # 取出最小点的索引值
            near = np.argmin(distance, axis=1) # [699, 1]
    
            # 算法结束条件
            if (last_clu == near).all():
                break
    
            # 求每一个类的中位点,
            for j in range(k):
                cluster[j] = np.median(box[near == j], axis=0) # 计算中位数 [9,2]
    
            last_clu = near
        return cluster
    
    # 从xml文件中读取
    def load_data(path):
        data = []
        # 对于每一个xml都寻找box
        for xml_file in glob.glob('{}/*xml'.format(path)):
            tree = ET.parse(xml_file)
            height = int(tree.findtext('./size/height'))
            width = int(tree.findtext('./size/width'))
            if height <= 0 or width <= 0:
                continue
    
            # 对于每一个目标都获得它的宽高
            for obj in tree.iter('object'):
                xmin = int(float(obj.findtext('bndbox/xmin'))) / width
                ymin = int(float(obj.findtext('bndbox/ymin'))) / height
                xmax = int(float(obj.findtext('bndbox/xmax'))) / width
                ymax = int(float(obj.findtext('bndbox/ymax'))) / height
    
                xmin = np.float64(xmin)
                ymin = np.float64(ymin)
                xmax = np.float64(xmax)
                ymax = np.float64(ymax)
                # 得到宽高
                data.append([xmax - xmin, ymax - ymin])
        return np.array(data)
    
    # 从txt中读取
    def load_data2(path, img_path):
        res = []
        with open(path) as f:
            for line in f.readlines():
                arr = line.strip().split(" ")
                box = np.array(arr[2:], dtype=np.float64)
                img_file_path = img_path + arr[0]
                w,h = Image.open(img_file_path).size
    
                x1 = (box[0] - box[2]/2) / w
                y1 = (box[1] - box[3]/2) / h
                x2 = (box[0] + box[2]/2) / w
                y2 = (box[1] + box[3]/2) / h
                res.append([x2-x1, y2-y1])
        return np.array(res)
    
    if __name__ == '__main__':
        SIZE = 416
        anchors_num = 9
        # 载入数据集,可以使用VOC的xml
        img_path = r"./dataset/car-identify/car-main/dataset/dataset/anno_img/"
        path = r"./dataset/car-identify/car-main/dataset/dataset/Imagesets/img_label.txt"
        path2 = r"./dataset/car-identify/car-main/dataset/dataset/annotation"
    
        # 载入所有的xml
        # 存储格式为转化为比例后的width,height
        data = load_data2(path, img_path)
    
        # 使用k聚类算法
        out = kmeans(data, anchors_num)
        out = out[np.argsort(out[:, 0])]
        print('acc:{:.2f}%'.format(avg_iou(data, out) * 100))
        print(out * SIZE)
        data = out * SIZE # [9, 2]
    
        # 这里是按照面积大小进行排序
        area = data[:,0] * data[:, 1]
        sort_index = np.argsort(area) # 获取面积由小到大的排序
        data = data[sort_index]
        new_data = []
        # 每三个为同一个尺寸:小框、中框、大框
        # 再按照宽高比进行排序,排完之后:竖框、方框、横框
        for i in range(0,9,3):
            boxes = data[i:i+3]
            ratio = boxes[:, 0] / boxes[:, 1]
            ratio_index = np.argsort(ratio)
            boxes = boxes[ratio_index]
            [new_data.append(box) for box in boxes]
        data = new_data
    
        f = open("./param/neuron_anchors.txt", 'w')
        row = np.shape(data)[0]
        for i in range(row):
            if i == 0:
                x_y = "%d,%d" % (data[i][0], data[i][1])
            else:
                x_y = ", %d,%d" % (data[i][0], data[i][1])
            f.write(x_y)
        f.close()
    

    相关文章

      网友评论

          本文标题:使用K-means算法寻找yolo的锚框

          本文链接:https://www.haomeiwen.com/subject/olztydtx.html