前言
有一批图片需要进行标注,需要标注的目标没有移动,仅仅在每次拍摄时变换了一点角度。
每一张图片都去标注十分耗费时间,采取了一种方法,只标注一张,其余的图片根据已经标注的图片采用模板匹配的方法自动生成标注文件
步骤
- 读取已标注的xml文件和图片
- 模板匹配已标注的图片和剩余未标注的图片
- 根据模板匹配的结果生成透视矩阵
- 求透视矩阵的逆矩阵
- 通过逆透视矩阵把已标注图片上的坐标变换到未标注的图片上,并保存到xml中
备注:python版本3.6,opencv版本使用3.4.2.16(SURF位于cv2.xfeatures2d模块中,需要安装opencv-contrib-python)
代码
import os
import cv2
import tqdm
import click
import dict2xml
import numpy as np
from lxml import etree
import collections
### 解析xml
def parse_xml_to_dict(xml):
    """Recursively convert an XML element into nested dictionaries.

    An element without child elements becomes ``{tag: text}``. For an element
    with children, every ``object`` child is collected into a list under the
    ``"object"`` key, while any other tag maps directly to its converted
    value (so a repeated non-``object`` tag keeps only its last occurrence).

    Comments and processing instructions are ignored: their ``tag`` is not a
    string and would otherwise end up as a meaningless dictionary key.

    Args:
        xml: An element exposing ``tag``, ``text`` and iteration over its
            children (e.g. from ``lxml.etree`` or ``xml.etree.ElementTree``).

    Returns:
        A single-key dictionary ``{xml.tag: value}`` where ``value`` is the
        element text for a leaf, or a dictionary of converted children.
    """
    # Only real elements carry a string tag; comment / PI nodes do not.
    children = [child for child in xml if isinstance(child.tag, str)]
    if not children:
        return {xml.tag: xml.text}

    result = {}
    for child in children:
        value = parse_xml_to_dict(child)[child.tag]
        if child.tag == 'object':
            # Several objects may exist, so always accumulate them in a list.
            result.setdefault(child.tag, []).append(value)
        else:
            result[child.tag] = value
    return {xml.tag: result}
### 从字典生成xml
def generate_dict_to_xml(image_name, im_dict, folder, path, width, height, xml_save_path):
    """Serialise bounding boxes for one image into an annotation XML file.

    The document has an ``annotation`` root holding ``folder``, ``filename``,
    ``path``, ``source``, ``size`` (depth is always written as ``"3"``),
    ``segmented`` and one ``object`` entry per box.

    Args:
        image_name: Value written to the ``filename`` element.
        im_dict: Iterable of dictionaries with the keys ``"label"``,
            ``"left"``, ``"top"``, ``"right"`` and ``"bottom"``. Coordinates
            are truncated to integers with ``int()`` before being written.
        folder: Value written to the ``folder`` element.
        path: Value written to the ``path`` element.
        width: Image width, written as a string.
        height: Image height, written as a string.
        xml_save_path: Destination file; it is overwritten if it exists.
    """
    objects = []
    for info in im_dict:
        bndbox = collections.OrderedDict([
            ("xmin", str(int(info["left"]))),
            ("ymin", str(int(info["top"]))),
            ("xmax", str(int(info["right"]))),
            ("ymax", str(int(info["bottom"]))),
        ])
        objects.append(collections.OrderedDict([
            ("name", info["label"]),
            ("pose", "Unspecified"),
            ("truncated", "0"),
            ("difficult", "0"),
            ("bndbox", bndbox),
        ]))

    # Built from ordered pairs so the element order does not depend on the
    # ordering behaviour of plain dict literals.
    annotation = collections.OrderedDict([
        ("folder", folder),
        ("filename", image_name),
        ("path", path),
        ("source", collections.OrderedDict([("database", "Unknown")])),
        ("size", collections.OrderedDict([
            ("width", str(width)),
            ("height", str(height)),
            ("depth", "3"),
        ])),
        ("segmented", "0"),
        ("object", objects),
    ])
    base_dict = collections.OrderedDict([("annotation", annotation)])

    xml_str_data = dict2xml.dict2xml(base_dict)
    # The output carries no XML declaration, so readers will assume UTF-8.
    # Write UTF-8 explicitly instead of the platform's default encoding,
    # otherwise non-ASCII labels or file names yield an unreadable file.
    with open(xml_save_path, "w", encoding="utf-8") as f:
        f.write(xml_str_data)
### 通过逆透视矩阵进行坐标映射
def cvt_pos(pos, cvt_mat_t):
    """Map a 2-D point through a 3x3 projective transformation matrix.

    Args:
        pos: Indexable pair ``(u, v)``; each entry is converted with
            ``float()``, so numeric strings are accepted.
        cvt_mat_t: 3x3 matrix indexable as ``cvt_mat_t[row][col]``.

    Returns:
        Tuple ``(x, y)``: the transformed point after dividing by the
        homogeneous coordinate.
    """
    u, v = float(pos[0]), float(pos[1])

    def apply_row(row):
        # One row of the matrix applied to the homogeneous point (u, v, 1).
        return row[0] * u + row[1] * v + row[2]

    x = apply_row(cvt_mat_t[0]) / apply_row(cvt_mat_t[2])
    y = apply_row(cvt_mat_t[1]) / apply_row(cvt_mat_t[2])
    return (x, y)
## 让opencv可以读取带中文路径的图片
def cv_imread(file_path, read_type):
    """Load an image by decoding the file's raw bytes with OpenCV.

    The file is read through NumPy and decoded in memory so that images whose
    path contains Chinese characters can still be opened.

    Args:
        file_path: Path of the image file.
        read_type: OpenCV read flag forwarded to ``cv2.imdecode``
            (for example ``cv2.IMREAD_GRAYSCALE``).

    Returns:
        Whatever ``cv2.imdecode`` produces for the file contents.
    """
    raw_bytes = np.fromfile(file_path, dtype=np.uint8)
    return cv2.imdecode(raw_bytes, read_type)
### 通过opencv surf方法做模板匹配,通过匹配到的特征点计算逆透视矩阵
def image_match_surf(trans_image_path, ref_image_path, ref_xml_path, save_path):
    """Transfer the reference image's bounding boxes onto another image.

    SURF features of the reference and target images are matched, a
    homography is estimated with RANSAC, and every annotated box of the
    reference is mapped into the target image. Boxes that fall partly outside
    the target image are dropped. The result is written to
    ``<save_path>/<target image stem>.xml``.

    Args:
        trans_image_path: Image that should receive annotations.
        ref_image_path: Already annotated reference image.
        ref_xml_path: Annotation XML belonging to the reference image.
        save_path: Directory in which the generated XML is stored.

    Raises:
        ValueError: If an image cannot be decoded, or if there are too few
            feature matches to estimate a homography.
    """
    # Read bytes so the XML parser handles the file's own encoding; decoding
    # with the locale codec can corrupt non-ASCII labels, and lxml rejects
    # already-decoded strings that still carry an encoding declaration.
    with open(ref_xml_path, "rb") as fid:
        xml_data = parse_xml_to_dict(etree.fromstring(fid.read()))

    img1 = cv_imread(ref_image_path, cv2.IMREAD_GRAYSCALE)
    img2 = cv_imread(trans_image_path, cv2.IMREAD_GRAYSCALE)
    if img1 is None:
        raise ValueError("cannot decode reference image: {}".format(ref_image_path))
    if img2 is None:
        raise ValueError("cannot decode image: {}".format(trans_image_path))

    surf = cv2.xfeatures2d.SURF_create()
    kp1, des1 = surf.detectAndCompute(img1, None)
    kp2, des2 = surf.detectAndCompute(img2, None)
    # k=2 matching needs at least two candidate descriptors on each side.
    if des1 is None or des2 is None or len(kp1) < 2 or len(kp2) < 2:
        raise ValueError("not enough SURF features to match: {}".format(trans_image_path))

    # FLANN-based matcher with default parameters.
    flann = cv2.FlannBasedMatcher()
    matches = flann.knnMatch(des1, des2, k=2)
    # Ratio test: keep a match only when the best candidate is clearly closer
    # than the second best. Pairs with fewer than two candidates are skipped.
    good_matches = [
        pair[0] for pair in matches
        if len(pair) == 2 and pair[0].distance < 0.7 * pair[1].distance
    ]
    # A homography has 8 degrees of freedom, i.e. needs 4 point pairs.
    if len(good_matches) < 4:
        raise ValueError("too few good matches ({}) for: {}".format(
            len(good_matches), trans_image_path))

    points1 = np.float32([kp1[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
    points2 = np.float32([kp2[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)
    # h maps target-image points onto the reference image ...
    h, _ = cv2.findHomography(points2, points1, cv2.RANSAC, 5.0)
    if h is None:
        raise ValueError("homography estimation failed for: {}".format(trans_image_path))
    # ... so its inverse carries reference coordinates into the target image.
    # Computed once here rather than for every box corner.
    ref_to_trans = np.linalg.inv(h)

    height, width = img2.shape
    annotation = xml_data["annotation"]
    # A reference file without any <object> simply yields no boxes.
    objects = annotation.get("object", []) if isinstance(annotation, dict) else []

    im_dict = []
    for obj in objects:
        label = obj["name"]
        xmin = obj["bndbox"]["xmin"]
        ymin = obj["bndbox"]["ymin"]
        xmax = obj["bndbox"]["xmax"]
        ymax = obj["bndbox"]["ymax"]
        # Map all four corners and take their axis-aligned envelope: after a
        # rotation or perspective change the mapped (xmin, ymin) / (xmax, ymax)
        # pair alone no longer bounds the object.
        corners = [(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)]
        mapped = [cvt_pos(corner, ref_to_trans) for corner in corners]
        xs = [point[0] for point in mapped]
        ys = [point[1] for point in mapped]
        left, top, right, bottom = min(xs), min(ys), max(xs), max(ys)
        if left < 0 or top < 0 or right > width or bottom > height:
            continue
        im_dict.append({"label": label, "left": left, "top": top,
                        "right": right, "bottom": bottom})

    trans_image_name = os.path.basename(trans_image_path)
    xml_save_path = os.path.join(save_path, os.path.splitext(trans_image_name)[0] + ".xml")
    # "folder" and "path" are literal placeholder values for those elements.
    generate_dict_to_xml(trans_image_name, im_dict, "folder", "path", width, height, xml_save_path)
@click.command()
@click.option("-t", "--trans_image_path", required=True, help="需要标注的图片路径")
@click.option("-i", "--ref_image", required=True, help="参考图片")
@click.option("-x", "--ref_xml", required=True, help="参考图片的xml")
@click.option("-s", "--save_path", required=True, help="生成的xml文件保存路径")
def main(trans_image_path, ref_image, ref_xml, save_path):
    # Command-line entry point: creates one annotation XML in `save_path` for
    # every JPEG found directly inside the directory `trans_image_path`, using
    # `ref_image` and its annotation file `ref_xml` as the labelled reference.
    # Documented with comments on purpose: click would print a docstring as
    # the command's --help text.

    # Make sure the output directory exists before any file is written.
    os.makedirs(save_path, exist_ok=True)
    # Match the extension case-insensitively (".JPG" is common on cameras) and
    # sort the listing so that runs are reproducible.
    image_list = [
        os.path.join(trans_image_path, img)
        for img in sorted(os.listdir(trans_image_path))
        if img.lower().endswith((".jpg", ".jpeg"))
    ]
    for img in tqdm.tqdm(image_list):
        image_match_surf(img, ref_image, ref_xml, save_path)


if __name__ == "__main__":
    main()
网友评论