前言
PubTabNet是IBM公司公布的基于图像的表格识别数据集。
其包含了568k+表格图片,其标注数据是HTML的表格结构,下载压缩包磁盘存储大小10G+。
GitHub相关地址
IBM的下载地址
相关论文:
Image-based table recognition: data, model, and evaluation
此篇论文的核心在于通过Encoder - double Decoder实现表格结构与单元格的识别
目前暂时没有GitHub复现项目。
PubTabNet数据
表格图片
该数据集的表格都是PDF截图,清晰度不是很高
示例如下:
PMC2838834_005_00.png
训练结构数据
训练结构数据位于.jsonl文件中,该文件每一行都是一条json数据,可以通过jsonlines库逐条读取。
数据结构:
- imgid: 图像id
- html: html单元格详细描述
cells:单元格列表
cells列表成员:
tokens:单元格文本字符级别信息,
bbox:单元格文本范围的bounding box,这个并不是单元格范围,而是单元格内文本范围!数字坐标说明:x_min, y_min, x_max, y_max,即文本范围对角线两个点的坐标。
该坐标应该是相应PDF页面中的坐标,并不是从0开始,具体使用时,需要考虑图片真实大小,并进行换算才能使用。
structure字典
structure字典成员:
tokens:对应cells的HTML格式
- split:表示训练数据或验证数据,分别为train与val
- filename:图像名称,如PMC2838834_005_00.png
示例如下(只贴出两个单元格以及部分HTML结构数据):
{
"imgid": 4,
"html": {
"cells": [
{
"tokens": [
"<b>",
"M",
"a",
"i",
"n",
" ",
"c",
"e",
"l",
"l",
"u",
"l",
"a",
"r",
" ",
"p",
"r",
"o",
"c",
"e",
"s",
"s",
"</b>"
],
"bbox": [
1,
4,
76,
13
]
},
{
"tokens": [
"<b>",
"M",
"o",
"d",
"u",
"l",
"a",
"t",
"e",
"d",
" ",
"p",
"a",
"t",
"h",
"w",
"a",
"y",
"s",
"</b>"
],
"bbox": [
92,
4,
167,
13
]
}
],
"structure": {
"tokens": [
"<thead>",
"<tr>",
"<td>",
"</td>",
"<td>",
"</td>",
"<td",
" colspan=\"2\"",
">",
"</td>",
"<td",
" colspan=\"3\"",
">",
"</td>",
"</tr>",
"...",
"..."
]
}
},
"split": "train",
"filename": "PMC2838834_005_00.png"
}
位于GitHub的PubTabNet相关代码,只有读取数据、将数据转换为HTML的功能,并没有表格识别相关的代码。
PubTabNet转换为SciTSR训练数据格式
因为之前在研究SciTSR数据,所以需要将PubTabNet的训练集数据转换为SciTSR的格式,从而扩大训练集,即我们需要根据.jsonl中的数据,提炼出chunk, structure数据来。
SciTSR的介绍链接在此
这一段是干货,因为这意味着不同数据集可以通用了。
import tqdm
import os
import jsonlines
import json
import logging
import sys
import re
# Put the parent directory (resolved against the current working directory)
# at the front of the module search path.
sys.path.insert(0, os.path.abspath('../'))
class Transform:
    """Convert PubTabNet ``.jsonl`` annotations into SciTSR-style files.

    For every annotated image two JSON files are written next to the input
    file, under ``<split>/structure/<name>.json`` and
    ``<split>/chunk/<name>.chunk``.
    """

    # Matches any HTML tag so it can be stripped from the cell text.
    _TAG_PATTERN = re.compile(r'\<.*?\>')

    def transform(self, pubtabnet_data_file: str):
        """Read the PubTabNet annotation file and write SciTSR-style output.

        :param pubtabnet_data_file: path to the PubTabNet ``.jsonl`` file.
        :return: None. Logs an error and returns early if the file is missing.
        """
        if pubtabnet_data_file is None or not os.path.exists(pubtabnet_data_file):
            logging.error('No PubTabNet data file!')
            return
        root_path = os.path.dirname(pubtabnet_data_file)
        prepared_splits = set()
        with open(pubtabnet_data_file, encoding='utf-8') as reader:
            for img in jsonlines.Reader(reader):
                img_filename = img['filename']
                img_type = img['split']
                # os.path.join keeps the output relative when root_path is
                # empty; plain '{0}/{1}' formatting would yield '/train/...'.
                gfte_structure_folder = os.path.join(root_path, img_type, 'structure')
                gfte_chunk_folder = os.path.join(root_path, img_type, 'chunk')
                if img_type not in prepared_splits:
                    os.makedirs(gfte_structure_folder, exist_ok=True)
                    os.makedirs(gfte_chunk_folder, exist_ok=True)
                    prepared_splits.add(img_type)
                print('Handle image: {0}'.format(img_filename))
                cells = img['html']['cells']
                structure_list = img['html']['structure']['tokens']
                gfte_structure = self.get_row_col_position(structure_list)
                base_name = img_filename.replace('.png', '')
                gfte_structure_file = '{0}.json'.format(base_name)
                gfte_chunk_file = '{0}.chunk'.format(base_name)
                structure_data_dict, chunk_data_dict = self.transfer_GFTE_data(cells, gfte_structure)
                # transfer_GFTE_data returns (-1, -1) when cells and structure
                # do not pair up; such images are skipped.
                if not isinstance(structure_data_dict, dict) or \
                        not isinstance(chunk_data_dict, dict):
                    continue
                with open(os.path.join(gfte_structure_folder, gfte_structure_file),
                          "w", encoding='utf-8') as write_file:
                    json.dump(structure_data_dict, write_file)
                with open(os.path.join(gfte_chunk_folder, gfte_chunk_file),
                          "w", encoding='utf-8') as write_file:
                    json.dump(chunk_data_dict, write_file)

    @staticmethod
    def _place_cell(cells: list, occupied: set, row: int, col: int,
                    rowspan: int, colspan: int) -> int:
        """Append one cell at the first free column and return the next column."""
        # Skip grid slots already claimed by a rowspan cell from a row above.
        while (row, col) in occupied:
            col += 1
        cells.append({'start_row': row,
                      'end_row': row + rowspan - 1,
                      'start_col': col,
                      'end_col': col + colspan - 1})
        # Reserve the slots this cell covers in the following rows.
        for covered_row in range(row + 1, row + rowspan):
            for covered_col in range(col, col + colspan):
                occupied.add((covered_row, covered_col))
        return col + colspan

    def get_row_col_position(self, structure_list: list):
        """Derive the grid position of every cell from the HTML structure tokens.

        A cell is either a single ``'<td>'`` token or the sequence ``'<td'``,
        one or more ``' rowspan="N"'`` / ``' colspan="N"'`` tokens, then ``'>'``.
        Both span attributes may appear on the same cell, and cells spanning
        several rows push later cells of the covered rows to the right.

        :param structure_list: list of HTML structure tokens.
        :return: list of dicts with ``start_row``, ``end_row``, ``start_col``
            and ``end_col`` (all inclusive), one per cell in document order.
        """
        gfte_structure = []
        occupied = set()
        row_index = 0
        col_index = 0
        pending_spans = None  # span attributes of an opened '<td' awaiting '>'
        for table_element in structure_list:
            if table_element == '<tr>':
                col_index = 0
            elif table_element == '</tr>':
                row_index += 1
            elif table_element == '<td>':
                col_index = self._place_cell(gfte_structure, occupied,
                                             row_index, col_index, 1, 1)
            elif table_element == '<td':
                pending_spans = {'rowspan': 1, 'colspan': 1}
            elif pending_spans is not None:
                if table_element == '>':
                    col_index = self._place_cell(gfte_structure, occupied,
                                                 row_index, col_index,
                                                 pending_spans['rowspan'],
                                                 pending_spans['colspan'])
                    pending_spans = None
                else:
                    attribute = table_element.strip().replace('"', '')
                    name, _, value = attribute.partition('=')
                    # Unknown or malformed attributes leave the span at 1.
                    if name in pending_spans and value.isdigit():
                        pending_spans[name] = max(int(value), 1)
        return gfte_structure

    def transfer_GFTE_data(self, cell_list: list, structure_list: list):
        """Pair PubTabNet cells with grid positions and build the output dicts.

        :param cell_list: PubTabNet cells, each with ``tokens`` and an
            optional ``bbox`` given as ``[x_min, y_min, x_max, y_max]``.
        :param structure_list: grid positions from ``get_row_col_position``.
        :return: ``(structure_data_dict, chunk_data_dict)``, or ``(-1, -1)``
            when the inputs are not two lists of equal length.
        """
        logging.info('Construct GFTE structure data')
        if not isinstance(cell_list, list) or \
                not isinstance(structure_list, list) or \
                len(cell_list) != len(structure_list):
            logging.error('Cell is not pair of structure!')
            return -1, -1
        structure_data_dict = {'cells': []}
        chunk_data_dict = {'chunks': []}
        for index, (cell, structure) in enumerate(zip(cell_list, structure_list)):
            raw_text = ''.join(cell['tokens']).strip()
            pure_text = self._TAG_PATTERN.sub('', raw_text).strip()
            structure_data = {'id': index,
                              'tex': raw_text,
                              'content': pure_text.split()}
            structure_data.update(structure)
            structure_data_dict['cells'].append(structure_data)
            bbox = cell.get('bbox', None)
            if bbox is not None:
                # Reorder [x_min, y_min, x_max, y_max] into [x0, x1, y0, y1].
                pos = [bbox[0], bbox[2], bbox[1], bbox[3]]
            else:
                # Placeholder for cells without a text box; filled in below.
                pos = [0, 0, 0, 0]
            chunk_data_dict['chunks'].append({'pos': pos, 'text': raw_text})
        # Cells without a box borrow x from a cell sharing their column edge
        # and y from a cell sharing their row edge, when such cells exist.
        for structure_data, chunk_data in zip(structure_data_dict['cells'],
                                              chunk_data_dict['chunks']):
            if chunk_data['pos'] != [0, 0, 0, 0]:
                continue
            cell_id = structure_data['id']
            x0 = x1 = y0 = y1 = None
            for structure, chunk in zip(structure_data_dict['cells'],
                                        chunk_data_dict['chunks']):
                if structure['id'] == cell_id or chunk['pos'] == [0, 0, 0, 0]:
                    continue
                if y0 is None and structure['start_row'] == structure_data['start_row']:
                    y0 = chunk['pos'][2]
                if y1 is None and structure['end_row'] == structure_data['end_row']:
                    y1 = chunk['pos'][3]
                if x0 is None and structure['start_col'] == structure_data['start_col']:
                    x0 = chunk['pos'][0]
                if x1 is None and structure['end_col'] == structure_data['end_col']:
                    x1 = chunk['pos'][1]
                if x0 is not None and x1 is not None and y0 is not None and y1 is not None:
                    chunk_data['pos'] = [x0, x1, y0, y1]
                    break
        return structure_data_dict, chunk_data_dict
if __name__ == "__main__":
    # Convert the full PubTabNet 2.0.0 annotation file. The bundled sample at
    # /data/scitsr/examples/PubTabNet_Examples.jsonl can be used for a quick run.
    Transform().transform('/data/pubtabnet/PubTabNet_2.0.0.jsonl')
UNET所需数据准备
UNET非常适合图像分割,对于表格来说,如果能够通过UNET将表格中行与列进行标注,则能够方便根据表格结构提取各个单元格信息。
事实上,UNET预测的是“线”,训练数据类似于如下格式,即两点一线,标签为0代表横线,标签为1代表竖线
{
"label": "0",
"line_color": [
0,
0,
128
],
"fill_color": [
0,
0,
128
],
"points": [
[
0,
0
],
[
503,
0
]
]
},
{
"label": "0",
"line_color": [
0,
0,
128
],
"fill_color": [
0,
0,
128
],
"points": [
[
0,
275
],
[
503,
275
]
]
},
{
"label": "1",
"line_color": [
0,
0,
128
],
"fill_color": [
0,
0,
128
],
"points": [
[
0,
0
],
[
0,
276
]
]
}
训练数据转换代码,则是根据SciTSR的训练数据:chunk与structure得到。
如下代码实现了:
- 根据每行每列的最大与最小坐标,以及邻接单元格信息,尝试绘制表格中的横线与竖线。而这仅仅是根据原数据中的单元格中文本框的坐标计算得到的
- 将得到的线绘制到单独一张图片,构成训练数据示意图片
import os
import numpy as np
from PIL import Image
import cv2
import json
def sameTable(ymin_1, ymin_2, ymax_1, ymax_2):
    """Return True when two vertical extents are close enough to match.

    The gap between the two ``ymin`` values and the gap between the two
    ``ymax`` values must both fall within one of the accepted tolerance
    pairs: (5, 5), (4, 7) or (7, 4).
    """
    top_gap = abs(ymin_1 - ymin_2)
    bottom_gap = abs(ymax_1 - ymax_2)
    accepted_limits = ((5, 5), (4, 7), (7, 4))
    return any(top_gap <= top_limit and bottom_gap <= bottom_limit
               for top_limit, bottom_limit in accepted_limits)
def _line_shape(label, points):
    """Build one line annotation; label '0' is horizontal, '1' is vertical."""
    return {'label': label,
            'line_color': [0, 0, 128],
            'fill_color': [0, 0, 128],
            'points': points}


def _track_extreme(axis_bounds, index, key, value, choose):
    """Keep the ``choose``-selected extreme (min or max) of ``key`` per row/column."""
    entry = axis_bounds.setdefault(index, {})
    entry[key] = value if key not in entry else choose(entry[key], value)


def _row_col_bounds(structure_data_list, chunk_data_list, table_xmin,
                    table_ymin, table_ymax, width_ratio, height_ratio, is_scitsr):
    """Collect ymin/ymax per row and xmin/xmax per column in image pixels."""
    bounds = {'row': {}, 'col': {}}
    for structure in structure_data_list:
        bndbox = chunk_data_list[structure['id']]['pos']
        xmin = int((int(bndbox[0]) - table_xmin) / width_ratio)
        xmax = int((int(bndbox[1]) - table_xmin) / width_ratio)
        if is_scitsr:
            # SciTSR's y axis points upwards, so distances are measured from
            # the top of the table (table_ymax) and y0/y1 swap roles.
            ymin = int(abs(int(bndbox[3]) - table_ymax) / height_ratio)
            ymax = int(abs(int(bndbox[2]) - table_ymax) / height_ratio)
        else:
            ymin = int((int(bndbox[2]) - table_ymin) / height_ratio)
            ymax = int((int(bndbox[3]) - table_ymin) / height_ratio)
        _track_extreme(bounds['row'], structure['start_row'], 'ymin', ymin, min)
        _track_extreme(bounds['row'], structure['end_row'], 'ymax', ymax, max)
        _track_extreme(bounds['col'], structure['start_col'], 'xmin', xmin, min)
        _track_extreme(bounds['col'], structure['end_col'], 'xmax', xmax, max)
    return bounds


def _row_separator_shapes(structure_data_list, bounds):
    """Build the horizontal lines under every row except the last one."""
    shapes = []
    for row in sorted(bounds['row'])[:-1]:
        segment = None
        for cell in structure_data_list:
            if cell['start_row'] != row:
                continue
            end_row = cell['end_row']
            xmin = bounds['col'][cell['start_col']]['xmin']
            xmax = bounds['col'][cell['end_col']]['xmax']
            y = bounds['row'][end_row]['ymax']
            if end_row == row:
                # Cells ending in this row extend one shared segment.
                if segment is None:
                    segment = [[xmin, y], [xmax, y]]
                else:
                    segment[1][0] = xmax
            else:
                # A row-spanning cell interrupts the segment and gets its own
                # line at the bottom of the row where it ends.
                if segment is not None:
                    shapes.append(_line_shape('0', segment))
                shapes.append(_line_shape('0', [[xmin, y], [xmax, y]]))
                segment = None
        if segment is not None:
            shapes.append(_line_shape('0', segment))
    return shapes


def _col_separator_shapes(structure_data_list, bounds):
    """Build the vertical lines right of every column except the last one."""
    shapes = []
    for col in sorted(bounds['col'])[:-1]:
        segment = None
        for cell in structure_data_list:
            if cell['start_col'] != col:
                continue
            end_col = cell['end_col']
            start_row = cell['start_row']
            # Start at the bottom of the previous row when it is known;
            # otherwise (first row, or no cell ends in the previous row)
            # start at the top of the cell's own row.
            above = bounds['row'].get(start_row - 1, {}).get('ymax')
            ymin = above if above is not None else bounds['row'][start_row]['ymin']
            ymax = bounds['row'][cell['end_row']]['ymax']
            x = bounds['col'][end_col]['xmax']
            if end_col == col:
                if segment is None:
                    segment = [[x, ymin], [x, ymax]]
                else:
                    segment[1][1] = ymax
            else:
                if segment is not None:
                    shapes.append(_line_shape('1', segment))
                shapes.append(_line_shape('1', [[x, ymin], [x, ymax]]))
                segment = None
        if segment is not None:
            shapes.append(_line_shape('1', segment))
    return shapes


def draw_cell_line(root_path: str = r'/data/scitsr/train', is_scitsr: bool = True,
                   max_files: int = 200):
    """Derive table row/column lines from SciTSR-style chunk/structure files.

    Reads images from ``<root_path>/img`` together with the matching
    ``chunk`` and ``structure`` files. For each image it writes a line-only
    mask and an overlay (``*_line.png``) to ``cell_mask_img`` and the line
    annotations as JSON to ``table_net_data``.

    :param root_path: dataset folder containing ``img``, ``chunk`` and
        ``structure`` sub-folders.
    :param is_scitsr: True when the chunk y coordinates grow upwards (SciTSR),
        False when they grow downwards.
    :param max_files: stop after this many directory entries; ``None``
        processes everything. Defaults to 200.
    :return: None
    """
    directory = os.path.join(root_path, 'img')
    chunk_directory = os.path.join(root_path, 'chunk')
    structure_directory = os.path.join(root_path, 'structure')
    final_cell_directory = os.path.join(root_path, 'cell_mask_img')
    final_table_directory = os.path.join(root_path, 'table_mask_img')
    table_net_data_directory = os.path.join(root_path, 'table_net_data')
    for output_directory in (final_cell_directory,
                             final_table_directory,
                             table_net_data_directory):
        os.makedirs(output_directory, exist_ok=True)
    files = os.listdir(directory)
    for file_index, file in enumerate(files):
        if max_files is not None and file_index >= max_files:
            break
        print('Handle the {0}/ {1} file: {2}'.format(file_index + 1, len(files), file))
        file_path = os.path.join(directory, file)
        img = cv2.imread(file_path)
        if img is None:
            # cv2.imread returns None for entries it cannot decode.
            print('Skip unreadable image: {0}'.format(file_path))
            continue
        height, width = img.shape[:2]
        # White single-channel canvas that receives only the lines.
        col_mask = np.full((height, width), 255, dtype=np.uint8)
        chunk_file = os.path.join(chunk_directory, file.replace('.png', '.chunk'))
        structure_file = os.path.join(structure_directory, file.replace('.png', '.json'))
        table_net_data_file = os.path.join(table_net_data_directory, file.replace('.png', '.json'))
        if not os.path.exists(chunk_file) or not os.path.exists(structure_file):
            print('Skip image without chunk/structure data: {0}'.format(file))
            continue
        with open(chunk_file, encoding='utf-8') as f:
            chunk_data_list = json.load(f)['chunks']
        with open(structure_file, encoding='utf-8') as f:
            structure_data_list = json.load(f)['cells']
        if not chunk_data_list:
            print('Skip image without chunks: {0}'.format(file))
            continue
        positions = [chunk['pos'] for chunk in chunk_data_list]
        table_xmin = int(min(pos[0] for pos in positions))
        table_xmax = int(max(pos[1] for pos in positions))
        table_ymin = int(min(pos[2] for pos in positions))
        table_ymax = int(max(pos[3] for pos in positions))
        # Ratio between annotation units and image pixels on each axis.
        width_ratio = (table_xmax - table_xmin) / width
        height_ratio = (table_ymax - table_ymin) / height
        if width_ratio == 0 or height_ratio == 0:
            # A zero extent would divide by zero when rescaling coordinates.
            print('Skip image with degenerate table extent: {0}'.format(file))
            continue
        row_col_bound_dict = _row_col_bounds(structure_data_list, chunk_data_list,
                                             table_xmin, table_ymin, table_ymax,
                                             width_ratio, height_ratio, is_scitsr)
        table_net_data = {'version': '3.16.7',
                          'flags': {},
                          'lineColor': [0, 255, 0, 128],
                          'fillColor': [255, 0, 0, 128],
                          'imagePath': file_path,
                          'shapes': []}
        shapes = table_net_data['shapes']
        # The four table borders come first.
        shapes.append(_line_shape('0', [[0, 0], [width, 0]]))
        shapes.append(_line_shape('0', [[0, height - 1], [width, height - 1]]))
        shapes.append(_line_shape('1', [[0, 0], [0, height]]))
        shapes.append(_line_shape('1', [[width - 1, 0], [width - 1, height]]))
        shapes.extend(_row_separator_shapes(structure_data_list, row_col_bound_dict))
        shapes.extend(_col_separator_shapes(structure_data_list, row_col_bound_dict))
        draw_line(col_mask, shapes)
        cv2.imwrite(os.path.join(final_cell_directory, file), col_mask)
        draw_line(img, shapes)
        updated = file.replace('.png', '') + '_line.png'
        cv2.imwrite(os.path.join(final_cell_directory, updated), img)
        with open(table_net_data_file, "w", encoding='utf-8') as f:
            json.dump(table_net_data, f, indent=True, ensure_ascii=False)
def draw_line(col_mask, line_list: list):
    """Draw every line annotation in ``line_list`` onto ``col_mask`` in place.

    :param col_mask: image array handed to ``cv2.line``.
    :param line_list: dicts whose ``'points'`` entry holds the two end points
        as ``[[x0, y0], [x1, y1]]``.
    :return: None
    """
    thickness = 1
    for line in line_list:
        points = line['points']
        # The width goes in ``thickness``; the original passed it as
        # ``lineType``, which selects the line algorithm instead.
        cv2.line(col_mask,
                 (points[0][0], points[0][1]),
                 (points[1][0], points[1][1]),
                 (0, 0, 128),
                 thickness=thickness)
def draw_cell(col_mask,
              structure_data_list: list,
              row_col_bound_dict: dict,
              is_scitsr: bool,
              table_xmin,
              table_xmax,
              table_ymin,
              table_ymax):
    """Outline every cell whose bounds are known, then the whole table.

    :param col_mask: image the rectangles are drawn on.
    :param structure_data_list: cells with ``start_row``/``end_row``/
        ``start_col``/``end_col`` keys.
    :param row_col_bound_dict: ``{'row': {...}, 'col': {...}}`` holding
        ``ymin``/``ymax`` per row and ``xmin``/``xmax`` per column.
    :param is_scitsr: selects how the table height is computed.
    :param table_xmin: smallest x of the table.
    :param table_xmax: largest x of the table.
    :param table_ymin: smallest y of the table.
    :param table_ymax: largest y of the table.
    :return: None
    """
    col_bounds = row_col_bound_dict['col']
    row_bounds = row_col_bound_dict['row']
    for cell in structure_data_list:
        first_row, last_row = cell['start_row'], cell['end_row']
        first_col, last_col = cell['start_col'], cell['end_col']
        left = col_bounds.get(first_col, {}).get('xmin', None)
        right = col_bounds.get(last_col, {}).get('xmax', None)
        top = row_bounds.get(first_row, {}).get('ymin', None)
        bottom = row_bounds.get(last_row, {}).get('ymax', None)
        # Cells with any unknown edge are left undrawn.
        if all(edge is not None for edge in (left, right, top, bottom)):
            draw_rectangle(col_mask, left, right, top, bottom)
    # Outer border of the table, anchored at the image origin.
    table_width = table_xmax - table_xmin
    if is_scitsr:
        table_height = abs(table_ymin - table_ymax)
    else:
        table_height = table_ymax - table_ymin
    draw_rectangle(col_mask, 0, table_width, 0, table_height)
def draw_rectangle(img, xmin, xmax, ymin, ymax):
    """Draw the four edges of an axis-aligned rectangle onto ``img`` in place.

    :param img: image array handed to ``cv2.line``.
    :param xmin: x coordinate of the left edge.
    :param xmax: x coordinate of the right edge.
    :param ymin: y coordinate of the top edge.
    :param ymax: y coordinate of the bottom edge.
    :return: None
    """
    thickness = 1
    color = (0, 0, 128)
    edges = (
        ((xmin, ymin), (xmax, ymin)),  # top edge
        ((xmin, ymin), (xmin, ymax)),  # left edge
        ((xmax, ymin), (xmax, ymax)),  # right edge
        ((xmin, ymax), (xmax, ymax)),  # bottom edge
    )
    for start, end in edges:
        # The width goes in ``thickness``; the original passed it as
        # ``lineType``, which selects the line algorithm instead.
        cv2.line(img, start, end, color, thickness=thickness)
if __name__ == '__main__':
    # The PubTabNet validation split uses top-down y coordinates, hence
    # is_scitsr=False.
    pubtabnet_val_root = r'/data/pubtabnet/val/'
    draw_cell_line(root_path=pubtabnet_val_root, is_scitsr=False)
示例
原图如下:
PMC1181812_008_00.png
纯线段图如下:
PMC1181812_008_00.png
将纯线段与原图叠加如下:
PMC1181812_008_00_line.png
效果并不完美,这是因为数据源提供的坐标仅仅是各个单元格内文本区域的bounding box,而不是单元格本身的边界;之后会进一步完善。
网友评论