一、 半监督学习样本GT生成原理
在处理这样的打标数据时,首先获得文本行对应的局部图片,然后通过模型预测得到region score;根据region score使用分水岭算法得到字符的位置信息;然后将这个坐标信息恢复到原图的坐标轴上,作为GT;这种半监督的方式体现在,根据每个单词得到的字符box都是带有置信度的;置信度的计算方式为
image.jpeg在训练的初期,置信度会比较低,对置信度低于0.5的样本,会根据单词包含的字符数量平均划分区域,进行标注。同时设置标签置信度为0.5。
整体的过程可以用下面图的上面部分来展示:
image.jpeg
二、CRAFT-Reimplementation代码处理流程:
CRAFT官方提供的代码中并没有训练代码,github上的高分训练代码是
这份代码确实复现了半监督学习的实现细节,很赞下面的步骤建议结合原GitHub代码一起看,更加清晰
实际在训练的时候,定义Dataset的时候会传入网络Net;在数据处理的时候,已经有预测网络Net,原图image,单词Word以及单词框word_box;
- 获得单词局部图片
- 将单词图片透视变换到矩形图片,归一化高度为64
- 通过Net网络预测得到region_score的结果
img_torch = torch.from_numpy(imgproc.normalizeMeanVariance(input, mean=(0.485, 0.456, 0.406),
variance=(0.229, 0.224, 0.225)))
img_torch = img_torch.permute(2, 0, 1).unsqueeze(0)
img_torch = img_torch.type(torch.FloatTensor).cuda()
scores, _ = net(img_torch)
region_scores = scores[0, :, :, 0].cpu().data.numpy()
region_scores = np.uint8(np.clip(region_scores, 0, 1) * 255)
bgr_region_scores = cv2.resize(region_scores, (input.shape[1], input.shape[0]))
bgr_region_scores = cv2.cvtColor(bgr_region_scores, cv2.COLOR_GRAY2BGR)
- 使用分水岭算法得到仿真的字符框
pursedo_bboxes = watershed(input, bgr_region_scores, False)
- 根据仿真字符框和Word的字符数量计算confidence
confidence = self.get_confidence(real_char_nums, len(pursedo_bboxes))
- 如果计算的confidence小于0.5,则按照单词长度对每个单元格设置仿真字符框,然后设置
confidence=0.5。
if confidence <= 0.5:
width = input.shape[1]
height = input.shape[0]
width_per_char = width / len(word)
for i, char in enumerate(word):
if char == ' ':
continue
left = i * width_per_char
right = (i + 1) * width_per_char
bbox = np.array([[left, 0], [right, 0], [right, height],
[left, height]])
bboxes.append(bbox)
bboxes = np.array(bboxes, np.float32)
confidence = 0.5
- 将文本框坐标恢复到透视变换之前的坐标系空间
for j in range(len(bboxes)):
ones = np.ones((4, 1))
tmp = np.concatenate([bboxes[j], ones], axis=-1)
I = np.matrix(MM).I
ori = np.matmul(I, tmp.transpose(1, 0)).transpose(1, 0)
bboxes[j] = ori[:, :2]
- 根据图片大小和字符框位置生成region-score高斯热力图
- 根据region-score热力图,图片大小、单词、字符框生成affinity-score map
- 对图片以及热力图进行一些随机化操作,比如随机crop,flip,旋转等操作。
三、 存在的问题
在上面提到的实际代码的生成过程中,前面8个步骤都没有问题,到了第9步生成affinity-score的时候产生问题。
先放一张正确的生成GT的图给了解一下:
最左边是MLT数据的一张图片,原图上叠加了红色边框的是半监督学习生成的字符框,蓝色变宽是半监督学习生成的affinity-box;
第二张图是region-score的热力图,越蓝表示分数越低,最低为0,红色表示分数最高为1。
第三张图是affinity-score的热力图
第四张为mask图,用来忽略置信度低的区域以及非文字区域,减少计入loss损失的非必要部分内容。
image.jpeg下面是有问题的GT图了
可以明显看到第三张子图上,红的灿灿烂烂,烂七八糟,全魔乱舞,仿若极光的affinity-score图。
image.jpeg以及这样的,悠悠的蓝光从中下部散射,漏斗状散射而出:
image.jpeg这些都不是icdar15的数据集,所以作者在写这个代码出现这个问题很正常,但是我们在实际训练的时候,如果没有注意到这样的问题直接套用代码,训练的时候肯定会震荡,得不到最终的结果。因此修改这里的affinity-score GT的生成代码。
网友评论