Introduction to the PubTabNet Dataset (with Practical Code)

Author: blade_he | Published: 2020-12-24 16:33

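Preface

PubTabNet is an image-based table recognition dataset released by IBM.
It contains more than 568k table images annotated with the HTML structure of each table. The downloaded archive takes up over 10 GB on disk.
Related GitHub repository
IBM download page
Related paper:
Image-based table recognition: data, model, and evaluation
The core idea of the paper is an encoder plus dual-decoder architecture that recognizes both the table structure and the cell content.

2020-12-24 15_28_04-Start.png
At the time of writing, there is no GitHub project that reproduces the paper.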

PubTabNet Data

Table Images

The table images in this dataset are all cropped from PDFs, and their resolution is not very high.
An example:


PMC2838834_005_00.png

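Training Structure Data

The structure annotations are stored in a .jsonl file. Each line of the file is one JSON record, so the file can be read record by record with the jsonlines library.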
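
As a minimal sketch of that record-by-record reading, the snippet below uses only the standard json module in place of jsonlines (a substitution made here so the example has no dependencies). The helper name iter_records and the two in-memory sample lines are made up for illustration; real records also carry the html field described below.

```python
import io
import json
from collections import Counter


def iter_records(stream):
    """Yield one dict per non-empty line of a JSON Lines stream."""
    for line in stream:
        line = line.strip()
        if line:
            yield json.loads(line)


# Two made-up records standing in for lines of the real .jsonl file.
sample = io.StringIO(
    '{"imgid": 4, "split": "train", "filename": "PMC2838834_005_00.png"}\n'
    '{"imgid": 5, "split": "val", "filename": "PMC1181812_008_00.png"}\n'
)
records = list(iter_records(sample))

# Count how many records belong to each split.
split_counts = Counter(record["split"] for record in records)
print(split_counts)
```

For the real file, the StringIO object would be replaced by an open file handle, so only one line is held in memory at a time.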
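Data structure:

  • imgid: image ID
  • html: detailed description of the table in HTML terms
    cells: list of cells
    Members of each cells entry:
      tokens: the cell text, one character per token (HTML tags such as <b> are single tokens)
      bbox: bounding box of the cell text. Note that this is the extent of the text inside the cell, not of the cell itself. The four numbers are x_min, y_min, x_max, y_max, i.e. the two diagonal corners of the text region.
      These coordinates appear to be expressed in the coordinate system of the corresponding PDF page rather than starting from 0, so in practice they need to be rescaled against the actual image size before use.
    structure: a dictionary
    Members of structure:
      tokens: the HTML structure tokens corresponding to cells
  • split: whether the record belongs to the training or validation set (train or val)
  • filename: image file name, e.g. PMC2838834_005_00.png

An example (only two cells and part of the HTML structure are shown):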

{
  "imgid": 4,
  "html": {
    "cells": [
      {
        "tokens": [
          "<b>",
          "M",
          "a",
          "i",
          "n",
          " ",
          "c",
          "e",
          "l",
          "l",
          "u",
          "l",
          "a",
          "r",
          " ",
          "p",
          "r",
          "o",
          "c",
          "e",
          "s",
          "s",
          "</b>"
        ],
        "bbox": [
          1,
          4,
          76,
          13
        ]
      },
      {
        "tokens": [
          "<b>",
          "M",
          "o",
          "d",
          "u",
          "l",
          "a",
          "t",
          "e",
          "d",
          " ",
          "p",
          "a",
          "t",
          "h",
          "w",
          "a",
          "y",
          "s",
          "</b>"
        ],
        "bbox": [
          92,
          4,
          167,
          13
        ]
      }
    ],
    "structure": {
      "tokens": [
        "<thead>",
        "<tr>",
        "<td>",
        "</td>",
        "<td>",
        "</td>",
        "<td",
        " colspan=\"2\"",
        ">",
        "</td>",
        "<td",
        " colspan=\"3\"",
        ">",
        "</td>",
        "</tr>",
        "...",
        "..."
      ]
    }
  },
  "split": "train",
  "filename": "PMC2838834_005_00.png"
}
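
To show how the two token lists fit together, here is a small sketch that merges structure tokens and cell tokens back into one HTML string. It assumes that cells are listed in the same order as the td elements in the structure tokens, which is what the record above suggests; the function name build_html and the shortened record are illustrative only and are not taken from the PubTabNet code.

```python
def build_html(record: dict) -> str:
    """Merge structure tokens and cell tokens into one HTML table string."""
    cells = iter(record["html"]["cells"])
    parts = []
    for token in record["html"]["structure"]["tokens"]:
        parts.append(token)
        # A cell's content goes right after its opening tag is complete:
        # either a plain "<td>" or the ">" that closes '<td colspan="..."'.
        if token in ("<td>", ">"):
            cell = next(cells, None)
            if cell is not None:
                parts.append("".join(cell["tokens"]))
    return "<table>" + "".join(parts) + "</table>"


# Shortened, made-up record in the same shape as the example above.
record = {
    "html": {
        "cells": [
            {"tokens": ["<b>", "A", "</b>"], "bbox": [1, 4, 76, 13]},
            {"tokens": ["B"], "bbox": [92, 4, 167, 13]},
        ],
        "structure": {
            "tokens": ["<thead>", "<tr>", "<td>", "</td>",
                       "<td", ' colspan="2"', ">", "</td>",
                       "</tr>", "</thead>"]
        },
    }
}
html = build_html(record)
print(html)
```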

The PubTabNet code on GitHub only reads the data and converts it to HTML. It does not include any table recognition code.

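Converting PubTabNet to the SciTSR Training Data Format

I had been working with the SciTSR data, so I needed to convert the PubTabNet training set into the SciTSR format in order to enlarge the training set. In other words, the chunk and structure data have to be derived from the records in the .jsonl file.
An introduction to SciTSR is linked here
This is the practical part, because it means the two datasets can be used interchangeably.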

import os
import jsonlines
import json
import logging
import re


class Transform:
    def transform(self, pubtabnet_data_file: str):
        if pubtabnet_data_file is None or not os.path.exists(pubtabnet_data_file):
            logging.error('No PubTabNet data file!')
            return
        root_path = os.path.dirname(pubtabnet_data_file)

        with open(pubtabnet_data_file, encoding='utf-8') as reader:
            for img in jsonlines.Reader(reader):
                # if img['filename'] == 'PMC5577841_001_00.png':
                # print(img)
                img_filename = img['filename']

                img_type = img['split']
                gfte_structure_folder = '{0}/{1}/structure/'.format(root_path, img_type)
                if not os.path.exists(gfte_structure_folder):
                    os.makedirs(gfte_structure_folder)
                gfte_chunk_folder = '{0}/{1}/chunk/'.format(root_path, img_type)
                if not os.path.exists(gfte_chunk_folder):
                    os.makedirs(gfte_chunk_folder)

                print('Handle image: {0}'.format(img_filename))
                cells = img['html']['cells']
                structure_list = img['html']['structure']['tokens']
                gfte_structure = self.get_row_col_position(structure_list)

                gfte_structure_file = '{0}.json'.format(img_filename.replace('.png', ''))
                gfte_chunk_file = '{0}.chunk'.format(img_filename.replace('.png', ''))
                structure_data_dict, chunk_data_dict = self.transfer_GFTE_data(cells, gfte_structure)
                if isinstance(structure_data_dict, dict) and \
                        isinstance(chunk_data_dict, dict):
                    with open(os.path.join(gfte_structure_folder,
                                           gfte_structure_file), "w", encoding='utf-8') as write_file:
                        json.dump(structure_data_dict, write_file)

                    with open(os.path.join(gfte_chunk_folder,
                                           gfte_chunk_file), "w", encoding='utf-8') as write_file:
                        json.dump(chunk_data_dict, write_file)

    def get_row_col_position(self, structure_list: list):
        """Derive start/end row and column indices for every <td> token.

        Handles cells that carry colspan, rowspan, or both, and skips grid
        slots still occupied by a rowspan cell from an earlier row.
        """
        row_index = 0
        col_index = 0
        occupied = set()  # (row, col) grid slots already taken by some cell
        gfte_structure = []
        for index, table_element in enumerate(structure_list):
            if table_element == '<tr>':
                col_index = 0
            elif table_element == '</tr>':
                row_index += 1
            elif table_element in ('<td>', '<td'):
                row_span, col_span = 1, 1
                if table_element == '<td':
                    # Attribute tokens such as ' colspan="2"' follow until the closing '>'
                    next_index = index + 1
                    while next_index < len(structure_list) and structure_list[next_index] != '>':
                        match = re.match(r'\s*(colspan|rowspan)="?(\d+)"?',
                                         structure_list[next_index])
                        if match and match.group(1) == 'colspan':
                            col_span = int(match.group(2))
                        elif match:
                            row_span = int(match.group(2))
                        next_index += 1
                # Skip slots covered by a rowspan cell that started in an earlier row
                while (row_index, col_index) in occupied:
                    col_index += 1
                for row in range(row_index, row_index + row_span):
                    for col in range(col_index, col_index + col_span):
                        occupied.add((row, col))
                gfte_structure.append({'start_row': row_index,
                                       'end_row': row_index + row_span - 1,
                                       'start_col': col_index,
                                       'end_col': col_index + col_span - 1})
                col_index += col_span
        return gfte_structure

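    def transfer_GFTE_data(self, cell_list: list, structure_list: list):
        """Build the SciTSR structure and chunk dictionaries for one table.

        Returns (None, None) when the cells cannot be paired one-to-one with
        the parsed structure entries.
        """
        logging.info('Construct GFTE structure data')
        if cell_list is None or \
                structure_list is None or \
                not isinstance(cell_list, list) or \
                not isinstance(structure_list, list) or \
                len(cell_list) != len(structure_list):
            logging.error('Cell list and structure list do not match!')
            return None, None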
        structure_data_dict = {'cells': []}
        chunk_data_dict = {'chunks': []}
        for index, (cell, structure) in enumerate(zip(cell_list, structure_list)):
            structure_data = {}
            structure_data['id'] = index
            raw_text = ''.join(cell['tokens']).strip()
            structure_data['tex'] = raw_text
            pure_text = re.sub(r'\<.*?\>', '', raw_text).strip()
            structure_data['content'] = pure_text.split()
            structure_data.update(structure)
            structure_data_dict['cells'].append(structure_data)

            chunk_data = {}
            if cell.get('bbox', None) is not None:
                chunk_data['pos'] = [cell['bbox'][0],
                                     cell['bbox'][2],
                                     cell['bbox'][1],
                                     cell['bbox'][3]]
            else:
                chunk_data['pos'] = [0, 0, 0, 0]
            chunk_data['text'] = raw_text
            chunk_data_dict['chunks'].append(chunk_data)
        for index, (structure_data, chunk_data) in enumerate(zip(structure_data_dict['cells'],
                                                                 chunk_data_dict['chunks'])):
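            # pos order is x0, x1, y0, y1. A cell without a bbox borrows its extent
            # from other cells that share its start/end row and start/end column.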
            if chunk_data['pos'] == [0, 0, 0, 0]:
                id = structure_data['id']
                start_row = structure_data['start_row']
                end_row = structure_data['end_row']
                start_col = structure_data['start_col']
                end_col = structure_data['end_col']
                x0 = None
                x1 = None
                y0 = None
                y1 = None
                for structure, chunk in zip(structure_data_dict['cells'],
                                            chunk_data_dict['chunks']):
                    if structure['id'] != id and chunk['pos'] != [0, 0, 0, 0]:
                        if y0 is None and structure['start_row'] == start_row:
                            y0 = chunk['pos'][2]
                        if y1 is None and structure['end_row'] == end_row:
                            y1 = chunk['pos'][3]
                        if x0 is None and structure['start_col'] == start_col:
                            x0 = chunk['pos'][0]
                        if x1 is None and structure['end_col'] == end_col:
                            x1 = chunk['pos'][1]
                    if x0 is not None and x1 is not None and y0 is not None and y1 is not None:
                        chunk_data['pos'] = [x0, x1, y0, y1]
                        break
        return structure_data_dict, chunk_data_dict


if __name__ == "__main__":
    transform = Transform()
    transform.transform('/data/pubtabnet/PubTabNet_2.0.0.jsonl')
    # transform.transform('/data/scitsr/examples/PubTabNet_Examples.jsonl')
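
To make the target format concrete, the following sketch converts a single cell in isolation. It mirrors the per-cell logic of the script above (tag stripping and reordering of the bounding box) but is not a replacement for it; the function name cell_to_scitsr and the sample values are made up for illustration.

```python
import re


def cell_to_scitsr(cell: dict, position: dict, cell_id: int):
    """Return (structure_entry, chunk_entry) for one PubTabNet cell."""
    raw_text = "".join(cell["tokens"]).strip()
    # Drop inline tags such as <b>...</b> to get the plain words.
    plain_text = re.sub(r"<.*?>", "", raw_text).strip()
    structure_entry = {"id": cell_id, "tex": raw_text, "content": plain_text.split()}
    structure_entry.update(position)
    # PubTabNet stores [x_min, y_min, x_max, y_max]; the chunk uses [x0, x1, y0, y1].
    x_min, y_min, x_max, y_max = cell.get("bbox", [0, 0, 0, 0])
    chunk_entry = {"pos": [x_min, x_max, y_min, y_max], "text": raw_text}
    return structure_entry, chunk_entry


cell = {"tokens": ["<b>", "M", "a", "i", "n", "</b>"], "bbox": [1, 4, 76, 13]}
position = {"start_row": 0, "end_row": 0, "start_col": 0, "end_col": 0}
structure_entry, chunk_entry = cell_to_scitsr(cell, position, 0)
print(structure_entry)
print(chunk_entry)
```

Note the reordering of the coordinates: the second and third values swap places, which is easy to overlook when comparing the two formats by eye.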

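Preparing Data for UNET

UNET is well suited to image segmentation. For tables, if UNET can mark the rows and columns, the content of each cell can then be extracted easily from the table structure.
What UNET actually predicts here is "lines". The training data looks like the following: each line is defined by two points, with label 0 for a horizontal line and label 1 for a vertical line.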

{
   "label": "0",
   "line_color": [
    0,
    0,
    128
   ],
   "fill_color": [
    0,
    0,
    128
   ],
   "points": [
    [
     0,
     0
    ],
    [
     503,
     0
    ]
   ]
  },
  {
   "label": "0",
   "line_color": [
    0,
    0,
    128
   ],
   "fill_color": [
    0,
    0,
    128
   ],
   "points": [
    [
     0,
     275
    ],
    [
     503,
     275
    ]
   ]
  },
  {
   "label": "1",
   "line_color": [
    0,
    0,
    128
   ],
   "fill_color": [
    0,
    0,
    128
   ],
   "points": [
    [
     0,
     0
    ],
    [
     0,
     276
    ]
   ]
  }
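
As a rough illustration of this annotation format, the sketch below builds two such line entries and rasterizes them into a small 0/1 grid using plain Python lists. The helper names make_shape and rasterize are hypothetical, and the grid only stands in for a real mask image; the actual pipeline later in this post draws with OpenCV instead.

```python
def make_shape(x0, y0, x1, y1):
    """Build one line annotation; label "0" = horizontal, "1" = vertical."""
    if y0 == y1:
        label = "0"
    elif x0 == x1:
        label = "1"
    else:
        raise ValueError("only axis-aligned lines are expected")
    return {"label": label,
            "line_color": [0, 0, 128],
            "fill_color": [0, 0, 128],
            "points": [[x0, y0], [x1, y1]]}


def rasterize(shapes, width, height):
    """Draw the lines into a height x width grid of 0/1 values."""
    mask = [[0] * width for _ in range(height)]
    for shape in shapes:
        (x0, y0), (x1, y1) = shape["points"]
        for y in range(min(y0, y1), max(y0, y1) + 1):
            for x in range(min(x0, x1), max(x0, x1) + 1):
                # Ignore anything that falls outside the grid.
                if 0 <= x < width and 0 <= y < height:
                    mask[y][x] = 1
    return mask


# One horizontal line along the top edge and one vertical line along the left edge.
shapes = [make_shape(0, 0, 5, 0), make_shape(0, 0, 0, 3)]
mask = rasterize(shapes, width=6, height=4)
for row in mask:
    print(row)
```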

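The training data is generated from the SciTSR-format training data, i.e. the chunk and structure files.
The code below does the following:

  1. Using the minimum and maximum coordinates of each row and column, together with adjacent-cell information, it tries to draw the horizontal and vertical lines of the table. Note that these are computed only from the text bounding boxes of the cells in the source data.
  2. It draws the resulting lines onto a separate image, which serves as an illustration of the training data.
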
import os
import numpy as np
from PIL import Image
import cv2
import json


# Returns if columns belong to same table or not
def sameTable(ymin_1, ymin_2, ymax_1, ymax_2):
    min_diff = abs(ymin_1 - ymin_2)
    max_diff = abs(ymax_1 - ymax_2)

    if min_diff <= 5 and max_diff <= 5:
        return True
    elif min_diff <= 4 and max_diff <= 7:
        return True
    elif min_diff <= 7 and max_diff <= 4:
        return True
    return False


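def draw_cell_line(root_path: str = r'/data/scitsr/train', is_scitsr: bool = True):
    """
    Convert data in the SciTSR format (chunk + structure), together with the
    PubTabNet images, into the data structure required by table_net.
    :param root_path: dataset folder containing the img, chunk and structure sub-folders
    :param is_scitsr: True for original SciTSR data (y axis points upwards),
                      False for data converted from PubTabNet
    :return: None; mask images and JSON files are written under root_path
    """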
    directory = os.path.join(root_path, 'img')
    chunk_directory = os.path.join(root_path, 'chunk')
    structure_directory = os.path.join(root_path, 'structure')
    final_cell_directory = os.path.join(root_path, 'cell_mask_img')
    if not os.path.exists(final_cell_directory):
        os.makedirs(final_cell_directory)
    final_table_directory = os.path.join(root_path, 'table_mask_img')
    if not os.path.exists(final_table_directory):
        os.makedirs(final_table_directory)
    table_net_data_directory = os.path.join(root_path, 'table_net_data')
    if not os.path.exists(table_net_data_directory):
        os.makedirs(table_net_data_directory)

    files = os.listdir(directory)
    for index, file in enumerate(files):
        # Only the first 200 files are processed
        if index == 200:
            break
        print('Handle the {0}/ {1} file: {2}'.format(index + 1, len(files), file))
        file_path = os.path.join(directory, file)
        img = cv2.imread(file_path)
        height, width = img.shape[:2]
        # Create a white single-channel (grayscale) mask; uint8 is the conventional dtype for image masks
        col_mask = np.full((height, width), 255, dtype=np.uint8)

        chunk_file = os.path.join(chunk_directory, file.replace('.png', '.chunk'))
        structure_file = os.path.join(structure_directory, file.replace('.png', '.json'))
        table_net_data_file = os.path.join(table_net_data_directory, file.replace('.png', '.json'))
        with open(chunk_file, encoding='utf-8') as f:
            chunk_data_list = json.load(f)['chunks']
        with open(structure_file, encoding='utf-8') as f:
            structure_data_list = json.load(f)['cells']

        table_xmin = int(min([bndbox[0] for bndbox in [cell['pos'] for cell in chunk_data_list]]))
        table_xmax = int(max([bndbox[1] for bndbox in [cell['pos'] for cell in chunk_data_list]]))
        table_ymin = int(min([bndbox[2] for bndbox in [cell['pos'] for cell in chunk_data_list]]))
        table_ymax = int(max([bndbox[3] for bndbox in [cell['pos'] for cell in chunk_data_list]]))

        width_ratio = (table_xmax - table_xmin) / width
        height_ration = (table_ymax - table_ymin) / height

        row_col_bound_dict = {'row': {}, 'col': {}}
        # First collect ymin/ymax for every row and xmin/xmax for every column
        for structure in structure_data_list:
            index = structure['id']
            bndbox = chunk_data_list[index]['pos']
            xmin = int((int(bndbox[0]) - table_xmin) / width_ratio)
            xmax = int((int(bndbox[1]) - table_xmin) / width_ratio)
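            # In SciTSR, x0 and x1 are in increasing order, but y0 and y1 are reversed
            # because the y axis points upwards. Relative coordinates are therefore
            # cell_ymin = -(y2 - ymax) and cell_ymax = -(y1 - ymax), with y1, y2 = pos[2], pos[3].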
            if is_scitsr:
                ymin = int(abs(int(bndbox[3]) - table_ymax) / height_ration)
                ymax = int(abs(int(bndbox[2]) - table_ymax) / height_ration)
            else:
                ymin = int((int(bndbox[2]) - table_ymin) / height_ration)
                ymax = int((int(bndbox[3]) - table_ymin) / height_ration)

            start_row = structure['start_row']
            if row_col_bound_dict['row'].get(start_row, None) is None:
                row_col_bound_dict['row'][start_row] = {'ymin': ymin}
            else:
                if row_col_bound_dict['row'][start_row].get('ymin', None) is None:
                    row_col_bound_dict['row'][start_row]['ymin'] = ymin
                else:
                    if row_col_bound_dict['row'][start_row]['ymin'] > ymin:
                        row_col_bound_dict['row'][start_row]['ymin'] = ymin

            end_row = structure['end_row']
            if row_col_bound_dict['row'].get(end_row, None) is None:
                row_col_bound_dict['row'][end_row] = {'ymax': ymax}
            else:
                if row_col_bound_dict['row'][end_row].get('ymax', None) is None:
                    row_col_bound_dict['row'][end_row]['ymax'] = ymax
                else:
                    if row_col_bound_dict['row'][end_row]['ymax'] < ymax:
                        row_col_bound_dict['row'][end_row]['ymax'] = ymax

            start_col = structure['start_col']
            if row_col_bound_dict['col'].get(start_col, None) is None:
                row_col_bound_dict['col'][start_col] = {'xmin': xmin}
            else:
                if row_col_bound_dict['col'][start_col].get('xmin', None) is None:
                    row_col_bound_dict['col'][start_col]['xmin'] = xmin
                else:
                    if row_col_bound_dict['col'][start_col]['xmin'] > xmin:
                        row_col_bound_dict['col'][start_col]['xmin'] = xmin

            end_col = structure['end_col']
            if row_col_bound_dict['col'].get(end_col, None) is None:
                row_col_bound_dict['col'][end_col] = {'xmax': xmax}
            else:
                if row_col_bound_dict['col'][end_col].get('xmax', None) is None:
                    row_col_bound_dict['col'][end_col]['xmax'] = xmax
                else:
                    if row_col_bound_dict['col'][end_col]['xmax'] < xmax:
                        row_col_bound_dict['col'][end_col]['xmax'] = xmax

        table_net_data = {'version': '3.16.7',
                          'flags': {},
                          'lineColor': [0, 255, 0, 128],
                          'fillColor': [255, 0, 0, 128],
                          'imagePath': file_path,
                          'shapes': []}
        # First add the four border lines of the table
        table_net_data['shapes'].append({'label': '0',
                                         'line_color': [0, 0, 128],
                                         'fill_color': [0, 0, 128],
                                         'points': [[0, 0], [width, 0]]})
        table_net_data['shapes'].append({'label': '0',
                                         'line_color': [0, 0, 128],
                                         'fill_color': [0, 0, 128],
                                         'points': [[0, height - 1],
                                                    [width, height - 1]]})
        table_net_data['shapes'].append({'label': '1',
                                         'line_color': [0, 0, 128],
                                         'fill_color': [0, 0, 128],
                                         'points': [[0, 0],
                                                    [0, height]]})
        table_net_data['shapes'].append({'label': '1',
                                         'line_color': [0, 0, 128],
                                         'fill_color': [0, 0, 128],
                                         'points': [[width - 1, 0],
                                                    [width - 1, height]]})

        row_list = list(row_col_bound_dict.get('row', {}).keys())
        row_list.sort()
        col_list = list(row_col_bound_dict.get('col', {}).keys())
        col_list.sort()
        # Only add the horizontal line at the end (bottom) of each row
        for row in row_list:
            if row == row_list[-1]:
                continue
            same_row_list = [structure for structure
                             in structure_data_list
                             if structure['start_row'] == row]
            line_points = None
            for cell in same_row_list:
                end_row = cell['end_row']
                start_col = cell['start_col']
                end_col = cell['end_col']
                xmin = row_col_bound_dict['col'][start_col]['xmin']
                xmax = row_col_bound_dict['col'][end_col]['xmax']
                y = row_col_bound_dict['row'][end_row]['ymax']
                if end_row == row:
                    if line_points is None:
                        line_points = [[xmin, y], [xmax, y]]
                    else:
                        line_points[1][0] = xmax
                else:
                    if line_points is not None:
                        table_net_data['shapes'].append({'label': '0',
                                                         'line_color': [0, 0, 128],
                                                         'fill_color': [0, 0, 128],
                                                         'points': line_points})
                    line_points = [[xmin, y], [xmax, y]]
                    table_net_data['shapes'].append({'label': '0',
                                                     'line_color': [0, 0, 128],
                                                     'fill_color': [0, 0, 128],
                                                     'points': line_points})
                    line_points = None
            if line_points is not None:
                table_net_data['shapes'].append({'label': '0',
                                                 'line_color': [0, 0, 128],
                                                 'fill_color': [0, 0, 128],
                                                 'points': line_points})

        # Only add the vertical line at the end (right side) of each column
        for col in col_list:
            if col == col_list[-1]:
                continue
            same_col_list = [structure for structure
                             in structure_data_list
                             if structure['start_col'] == col]
            line_points = None
            for cell in same_col_list:
                end_col = cell['end_col']
                start_row = cell['start_row']
                end_row = cell['end_row']
                if start_row - 1 >= 0:
                    ymin = row_col_bound_dict['row'][start_row - 1]['ymax']
                else:
                    ymin = row_col_bound_dict['row'][start_row]['ymin']
                ymax = row_col_bound_dict['row'][end_row]['ymax']
                x = row_col_bound_dict['col'][end_col]['xmax']
                if end_col == col:
                    if line_points is None:
                        line_points = [[x, ymin], [x, ymax]]
                    else:
                        line_points[1][1] = ymax
                else:
                    if line_points is not None:
                        table_net_data['shapes'].append({'label': '1',
                                                         'line_color': [0, 0, 128],
                                                         'fill_color': [0, 0, 128],
                                                         'points': line_points})
                    line_points = [[x, ymin], [x, ymax]]
                    table_net_data['shapes'].append({'label': '1',
                                                     'line_color': [0, 0, 128],
                                                     'fill_color': [0, 0, 128],
                                                     'points': line_points})
                    line_points = None
            if line_points is not None:
                table_net_data['shapes'].append({'label': '1',
                                                 'line_color': [0, 0, 128],
                                                 'fill_color': [0, 0, 128],
                                                 'points': line_points})

        draw_line(col_mask, table_net_data['shapes'])
        cv2.imwrite(os.path.join(final_cell_directory, file), col_mask)

        draw_line(img, table_net_data['shapes'])
        updated = file.replace('.png', '') + '_line.png'
        cv2.imwrite(os.path.join(final_cell_directory, updated), img)
        with open(table_net_data_file, "w", encoding='utf-8') as f:
            json.dump(table_net_data, f, indent=True, ensure_ascii=False)


def draw_line(col_mask, line_list: list):
    """Draw every line annotation in line_list onto col_mask (modified in place)."""
    for line in line_list:
        points = line['points']
        thickness = 1
        cv2.line(col_mask,
                 (points[0][0], points[0][1]),
                 (points[1][0], points[1][1]),
                 (0, 0, 128),
                 thickness=thickness)


def draw_cell(col_mask,
              structure_data_list: list,
              row_col_bound_dict: dict,
              is_scitsr: bool,
              table_xmin,
              table_xmax,
              table_ymin,
              table_ymax):
    for structure in structure_data_list:
        start_row = structure['start_row']
        end_row = structure['end_row']
        start_col = structure['start_col']
        end_col = structure['end_col']

        xmin = row_col_bound_dict['col'].get(start_col, {}).get('xmin', None)
        xmax = row_col_bound_dict['col'].get(end_col, {}).get('xmax', None)
        ymin = row_col_bound_dict['row'].get(start_row, {}).get('ymin', None)
        ymax = row_col_bound_dict['row'].get(end_row, {}).get('ymax', None)

        if xmin is not None and xmax is not None and ymin is not None and ymax is not None:
            draw_rectangle(col_mask, xmin, xmax, ymin, ymax)
    if is_scitsr:
        draw_rectangle(col_mask, 0, table_xmax - table_xmin, 0, abs(table_ymin - table_ymax))
    else:
        draw_rectangle(col_mask, 0, table_xmax - table_xmin, 0, table_ymax - table_ymin)


def draw_rectangle(img, xmin, xmax, ymin, ymax):
    """Draw the four sides of a cell rectangle onto img (modified in place)."""
    thickness = 1
    # Top horizontal line
    cv2.line(img, (xmin, ymin), (xmax, ymin), (0, 0, 128), thickness=thickness)
    # Left vertical line
    cv2.line(img, (xmin, ymin), (xmin, ymax), (0, 0, 128), thickness=thickness)
    # Right vertical line
    cv2.line(img, (xmax, ymin), (xmax, ymax), (0, 0, 128), thickness=thickness)
    # Bottom horizontal line
    cv2.line(img, (xmin, ymax), (xmax, ymax), (0, 0, 128), thickness=thickness)


if __name__ == '__main__':
    draw_cell_line(root_path=r'/data/pubtabnet/val/', is_scitsr=False)
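
The row and column bounds are the heart of the script above, so here is a compact sketch of just that step. It assumes image-style coordinates (y grows downwards, as in the converted PubTabNet data) and skips the rescaling and the SciTSR y-axis flip; the helper names update_bound and row_col_bounds and the four sample cells are made up for illustration.

```python
def update_bound(bounds, key, name, value, pick):
    """Keep the smallest (pick=min) or largest (pick=max) value seen for bounds[key][name]."""
    entry = bounds.setdefault(key, {})
    entry[name] = value if name not in entry else pick(entry[name], value)


def row_col_bounds(cells, chunks):
    """Return (rows, cols) with ymin/ymax per row and xmin/xmax per column."""
    rows, cols = {}, {}
    for cell in cells:
        x0, x1, y0, y1 = chunks[cell["id"]]["pos"]
        # A cell's top edge bounds its first row, its bottom edge its last row.
        update_bound(rows, cell["start_row"], "ymin", y0, min)
        update_bound(rows, cell["end_row"], "ymax", y1, max)
        # Likewise, left edge for the first column and right edge for the last.
        update_bound(cols, cell["start_col"], "xmin", x0, min)
        update_bound(cols, cell["end_col"], "xmax", x1, max)
    return rows, cols


# A made-up 2 x 2 table: structure entries plus text boxes in x0, x1, y0, y1 order.
cells = [
    {"id": 0, "start_row": 0, "end_row": 0, "start_col": 0, "end_col": 0},
    {"id": 1, "start_row": 0, "end_row": 0, "start_col": 1, "end_col": 1},
    {"id": 2, "start_row": 1, "end_row": 1, "start_col": 0, "end_col": 0},
    {"id": 3, "start_row": 1, "end_row": 1, "start_col": 1, "end_col": 1},
]
chunks = [
    {"pos": [1, 76, 4, 13]},
    {"pos": [92, 167, 4, 13]},
    {"pos": [3, 60, 20, 29]},
    {"pos": [95, 150, 21, 30]},
]
rows, cols = row_col_bounds(cells, chunks)

# The separator below row 0 sits at that row's ymax and spans all columns.
separator = [[cols[0]["xmin"], rows[0]["ymax"]], [cols[1]["xmax"], rows[0]["ymax"]]]
print(rows, cols, separator)
```

Because the separator is placed at the bottom of the text boxes rather than at the true cell border, the drawn line hugs the text of the row above it.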

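Example

The original image:


PMC1181812_008_00.png

The line-only image:


PMC1181812_008_00.png
The lines overlaid on the original image:
PMC1181812_008_00_line.png

The result is not perfect, because the coordinates in the source data are only the bounding boxes of the text inside each cell. This will be improved later.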


Original article: https://www.haomeiwen.com/subject/wrumnktx.html