Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 176 additions & 95 deletions wired_table_rec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,16 @@
import logging
import time
import traceback
from dataclasses import dataclass, asdict
from enum import Enum
from pathlib import Path
from typing import List, Optional, Union, Dict, Any
from typing import List, Optional, Tuple, Union, Dict, Any
import numpy as np
import cv2

from wired_table_rec.table_structure_cycle_center_net import TSRCycleCenterNet
from wired_table_rec.table_structure_unet import TSRUnet
from wired_table_rec.utils.download_model import DownloadModel
from wired_table_rec.table_line_rec import TableLineRecognition
from wired_table_rec.table_line_rec_plus import TableLineRecognitionPlus
from .table_recover import TableRecover
from .utils.utils import InputType, LoadImage
from wired_table_rec.utils.utils_table_recover import (
from .utils import InputType, LoadImage
from .utils_table_recover import (
match_ocr_cell,
plot_html_table,
box_4_2_poly_to_box_4_1,
Expand All @@ -27,73 +24,54 @@
gather_ocr_list_by_row,
)


class ModelType(Enum):
    """Keys of the table-structure models that have a download URL below."""

    CYCLE_CENTER_NET = "cycle_center_net"
    UNET = "unet"


# Base URL of the hosted ONNX models. It already ends with "/", so file names
# are appended directly; adding another "/" produced ".../master//<file>".
ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"
# Maps each ModelType value to the URL its ONNX file is downloaded from.
KEY_TO_MODEL_URL = {
    ModelType.CYCLE_CENTER_NET.value: f"{ROOT_URL}cycle_center_net.onnx",
    ModelType.UNET.value: f"{ROOT_URL}unet.onnx",
}


@dataclass
class WiredTableInput:
"""Configuration object accepted by ``WiredTableRecognition.__init__``."""
# Model key; the constructor rejects values that are not keys of
# KEY_TO_MODEL_URL. Defaults to the "unet" entry of ModelType.
model_type: Optional[str] = ModelType.UNET.value
# Explicit model location. When None, get_model_path downloads the model
# registered for model_type and the result is stored back on this field.
model_path: Union[str, Path, None, Dict[str, str]] = None
# NOTE(review): presumably enables GPU inference in the model backend; the
# consumer is not visible here (the config is forwarded via asdict) - confirm.
use_cuda: bool = False
# NOTE(review): presumably the device name read by the model backend - confirm.
device: str = "cpu"


@dataclass
class WiredTableOutput:
"""Result container returned by ``WiredTableRecognition.__call__``."""
# HTML rendering of the recognised table; "" on the failure paths.
pred_html: Optional[str] = None
# Cell polygons; on the full recognition path this is the polygon array
# reshaped to (-1, 8). None when recognition fails.
cell_bboxes: Optional[np.ndarray] = None
# Per-cell logic boxes as an array; None when recognition fails.
# NOTE(review): presumably [row_start, row_end, col_start, col_end] per cell - confirm.
logic_points: Optional[np.ndarray] = None
# Elapsed time measured with time.perf_counter(); 0.0 on the failure paths.
elapse: Optional[float] = None
cur_dir = Path(__file__).resolve().parent
default_model_path = cur_dir / "models" / "cycle_center_net_v1.onnx"
default_model_path_v2 = cur_dir / "models" / "cycle_center_net_v2.onnx"


class WiredTableRecognition:
def __init__(self, config: WiredTableInput):
self.model_type = config.model_type
if self.model_type not in KEY_TO_MODEL_URL:
model_list = ",".join(KEY_TO_MODEL_URL)
raise ValueError(
f"{self.model_type} is not supported. The currently supported models are {model_list}."
)

config.model_path = self.get_model_path(config.model_type, config.model_path)
if self.model_type == ModelType.CYCLE_CENTER_NET.value:
self.table_structure = TSRCycleCenterNet(asdict(config))
else:
self.table_structure = TSRUnet(asdict(config))

def __init__(self, table_model_path: Union[str, Path] = None, version="v2"):
self.load_img = LoadImage()
if version == "v2":
model_path = table_model_path if table_model_path else default_model_path_v2
self.table_line_rec = TableLineRecognitionPlus(str(model_path))
else:
model_path = table_model_path if table_model_path else default_model_path
self.table_line_rec = TableLineRecognition(str(model_path))

self.table_recover = TableRecover()

try:
self.ocr = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
except ModuleNotFoundError:
self.ocr = None

def __call__(
self,
img: InputType,
ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
**kwargs,
) -> WiredTableOutput:
) -> Tuple[str, float, Any, Any, Any]:
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The __call__ method return type annotation is Tuple[str, float, Any, Any, Any] (5 elements), but the actual successful return statement on lines 126-134 returns a tuple with 6 elements: table_str, table_elapse, sorted_polygons, sorted_logi_points, sorted_ocr_boxes_res, and adjust_dict. The annotation should be Tuple[str, float, Any, Any, Any, Any] (6 elements) to match the actual return.

Suggested change
) -> Tuple[str, float, Any, Any, Any]:
) -> Tuple[str, float, Any, Any, Any, Any]:

Copilot uses AI. Check for mistakes.
if self.ocr is None and ocr_result is None:
raise ValueError(
"One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed."
)

s = time.perf_counter()
rec_again = True
need_ocr = True
col_threshold = 15
row_threshold = 10
if kwargs:
rec_again = kwargs.get("rec_again", True)
need_ocr = kwargs.get("need_ocr", True)
col_threshold = kwargs.get("col_threshold", 15)
row_threshold = kwargs.get("row_threshold", 10)
img = self.load_img(img)
polygons, rotated_polygons = self.table_structure(img, **kwargs)
polygons, rotated_polygons = self.table_line_rec(img, **kwargs)
if polygons is None:
logging.warning("polygons is None.")
return WiredTableOutput("", None, None, 0.0)
return "", 0.0, None, None, None

try:
table_res, logi_points = self.table_recover(
Expand All @@ -108,34 +86,52 @@ def __call__(
sorted_polygons, idx_list = sorted_ocr_boxes(
[box_4_2_poly_to_box_4_1(box) for box in polygons]
)
return WiredTableOutput(
return (
"",
time.perf_counter() - s,
sorted_polygons,
logi_points[idx_list],
time.perf_counter() - s,
[],
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The early returns in __call__ are inconsistent in the number of elements they return compared to the successful path:

  • Line 74 returns a 5-element tuple: ("", 0.0, None, None, None)
  • Lines 89-95 (when need_ocr=False) return a 5-element tuple: ("", elapsed, sorted_polygons, logi_points[idx_list], [])
  • Line 125 returns a 5-element tuple: ("", 0.0, None, None, None)
  • Lines 126-134 (success path) return a 6-element tuple including adjust_dict

Callers must handle varying-length tuples, making unpacking error-prone. All return paths should consistently return the same number of elements.

Suggested change
[],
[],
None,

Copilot uses AI. Check for mistakes.
)
if ocr_result is None and need_ocr:
ocr_result, _ = self.ocr(img)
cell_box_det_map, not_match_orc_boxes = match_ocr_cell(ocr_result, polygons)
# 如果有识别框没有ocr结果,直接进行rec补充
cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map)
cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map, rec_again)
# 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理
t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
t_rec_ocr_list_dict = self.transform_res(cell_box_det_map, polygons, logi_points)
# 第一行或者第一列为空时,调整代码
#adjust_dict = self.adjust_table_cells(t_rec_ocr_list_dict)
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The commented-out line 104 (#adjust_dict = self.adjust_table_cells(t_rec_ocr_list_dict)) should be cleaned up. Leaving commented-out code in production code is a maintainability concern, especially when the active alternative (process_ocr_result) is right below it.

Suggested change
#adjust_dict = self.adjust_table_cells(t_rec_ocr_list_dict)

Copilot uses AI. Check for mistakes.
adjust_dict = self.process_ocr_result(t_rec_ocr_list_dict)
# 将每个单元格中的ocr识别结果排序和同行合并,输出的html能完整保留文字的换行格式
t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)
t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list_dict)
Comment on lines +102 to +107
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The process_ocr_result method mutates the t_logic_box lists of the entries in the list passed to it (lines 177–178 and 190–192: entry['t_logic_box'][0] -= 1 etc.). Its return value is stored separately as adjust_dict, but the entries it decrements are the same dict objects still held by t_rec_ocr_list_dict, which is then passed to sort_and_gather_ocr_res on line 107. As a result, sort_and_gather_ocr_res (and the logi_points built from its output) sees shifted logic boxes for the surviving cells alongside the untouched boxes of the cells that process_ocr_result filtered out. This unintended in-place mutation can leave the logic boxes inconsistent, and could lead to double-adjustment if they are adjusted again downstream.

Copilot uses AI. Check for mistakes.
# cell_box_map =
logi_points = [t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list]
cell_box_det_map = {
i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]]
for i, t_box_ocr in enumerate(t_rec_ocr_list)
}
pred_html = plot_html_table(logi_points, cell_box_det_map)
polygons = np.array(polygons).reshape(-1, 8)
logi_points = np.array(logi_points)
elapse = time.perf_counter() - s
table_str = plot_html_table(logi_points, cell_box_det_map)
ocr_boxes_res = [
box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_result
]
sorted_ocr_boxes_res, _ = sorted_ocr_boxes(ocr_boxes_res)
sorted_polygons = [box_4_2_poly_to_box_4_1(box) for box in polygons]
sorted_logi_points = logi_points
table_elapse = time.perf_counter() - s

except Exception:
logging.warning(traceback.format_exc())
return WiredTableOutput("", None, None, 0.0)
return WiredTableOutput(pred_html, polygons, logi_points, elapse)
return "", 0.0, None, None, None
return (
table_str,
table_elapse,
sorted_polygons,
sorted_logi_points,
sorted_ocr_boxes_res,
adjust_dict

)

def transform_res(
self,
Expand Down Expand Up @@ -166,6 +162,102 @@ def transform_res(
res.append(dict_res)
return res

def process_ocr_result(self, ocr_result):
# 删除第一行的字典,并调整其余字典的行数
first_row_empty = [entry for entry in ocr_result if
entry['t_logic_box'][0] == 0 and entry['t_logic_box'][1] == 0 and entry['t_ocr_res'][0][
1] == '']

if len(first_row_empty) == len(
[entry for entry in ocr_result if entry['t_logic_box'][0] == 0 and entry['t_logic_box'][1] == 0]):
# 如果第一行的所有单元格都为空,删除第一行
ocr_result = [entry for entry in ocr_result if entry['t_logic_box'][0] != 0 or entry['t_logic_box'][1] != 0]
# 调整剩余字典的行数
for entry in ocr_result:
entry['t_logic_box'][0] -= 1
entry['t_logic_box'][1] -= 1

# 删除第一列的字典,并调整其余字典的列数
first_col_empty = [entry for entry in ocr_result if
entry['t_logic_box'][2] == 0 and entry['t_logic_box'][3] == 0 and entry['t_ocr_res'][0][
1] == '']

if len(first_col_empty) == len(
[entry for entry in ocr_result if entry['t_logic_box'][2] == 0 and entry['t_logic_box'][3] == 0]):
# 如果第一列的所有单元格都为空,删除第一列
ocr_result = [entry for entry in ocr_result if entry['t_logic_box'][2] != 0 or entry['t_logic_box'][3] != 0]
Comment on lines +167 to +188
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The process_ocr_result method checks for first-row emptiness using entry['t_logic_box'][0] == 0 and entry['t_logic_box'][1] == 0. However, t_logic_box is [row_start, row_end, col_start, col_end], so t_logic_box[1] is row_end. Cells in the first row that span multiple rows (e.g. a merged cell spanning rows 0–1) have row_start=0 but row_end=1, so they would be excluded from first_row_empty. This means the condition len(first_row_empty) == len([...]) would fail for merged-cell first rows, causing the first row to never be removed. The check should likely use only entry['t_logic_box'][0] == 0 to identify all cells whose row starts at row 0. The same issue applies to the first-column check on line 182 (t_logic_box[2] == 0 and t_logic_box[3] == 0).

Suggested change
first_row_empty = [entry for entry in ocr_result if
entry['t_logic_box'][0] == 0 and entry['t_logic_box'][1] == 0 and entry['t_ocr_res'][0][
1] == '']
if len(first_row_empty) == len(
[entry for entry in ocr_result if entry['t_logic_box'][0] == 0 and entry['t_logic_box'][1] == 0]):
# 如果第一行的所有单元格都为空,删除第一行
ocr_result = [entry for entry in ocr_result if entry['t_logic_box'][0] != 0 or entry['t_logic_box'][1] != 0]
# 调整剩余字典的行数
for entry in ocr_result:
entry['t_logic_box'][0] -= 1
entry['t_logic_box'][1] -= 1
# 删除第一列的字典,并调整其余字典的列数
first_col_empty = [entry for entry in ocr_result if
entry['t_logic_box'][2] == 0 and entry['t_logic_box'][3] == 0 and entry['t_ocr_res'][0][
1] == '']
if len(first_col_empty) == len(
[entry for entry in ocr_result if entry['t_logic_box'][2] == 0 and entry['t_logic_box'][3] == 0]):
# 如果第一列的所有单元格都为空,删除第一列
ocr_result = [entry for entry in ocr_result if entry['t_logic_box'][2] != 0 or entry['t_logic_box'][3] != 0]
first_row_empty = [
entry
for entry in ocr_result
if entry['t_logic_box'][0] == 0 and entry['t_ocr_res'][0][1] == ''
]
if len(first_row_empty) == len(
[entry for entry in ocr_result if entry['t_logic_box'][0] == 0]
):
# 如果第一行的所有单元格都为空,删除第一行
ocr_result = [
entry for entry in ocr_result if entry['t_logic_box'][0] != 0
]
# 调整剩余字典的行数
for entry in ocr_result:
entry['t_logic_box'][0] -= 1
entry['t_logic_box'][1] -= 1
# 删除第一列的字典,并调整其余字典的列数
first_col_empty = [
entry
for entry in ocr_result
if entry['t_logic_box'][2] == 0 and entry['t_ocr_res'][0][1] == ''
]
if len(first_col_empty) == len(
[entry for entry in ocr_result if entry['t_logic_box'][2] == 0]
):
# 如果第一列的所有单元格都为空,删除第一列
ocr_result = [
entry for entry in ocr_result if entry['t_logic_box'][2] != 0
]

Copilot uses AI. Check for mistakes.
# 调整剩余字典的列数
for entry in ocr_result:
entry['t_logic_box'][2] -= 1
entry['t_logic_box'][3] -= 1
Comment on lines +174 to +192
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In process_ocr_result, the filter on line 174 keeps every entry where t_logic_box[0] != 0 or t_logic_box[1] != 0, so it drops only the entries where BOTH values are 0. Using or here is incorrect: a first-row entry with t_logic_box[0] == 0 (start of first row) is kept as long as t_logic_box[1] != 0 (e.g. a merged cell spanning rows 0–1), and the loop that follows then decrements its row start to -1. The correct logic to filter out first-row entries would be entry['t_logic_box'][0] != 0 (keep only entries whose row start is not 0). Similarly, line 188 uses entry['t_logic_box'][2] != 0 or entry['t_logic_box'][3] != 0 instead of entry['t_logic_box'][2] != 0.

Copilot uses AI. Check for mistakes.

return ocr_result

def adjust_table_cells(self, t_rec_ocr_list_dict):
"""
Adjust table cells by dropping the cells of the first row and/or the first
column, and update the row/column start and end positions of the rest.

Args:
t_rec_ocr_list_dict (list): original table-cell recognition results, in the form
[
{
"t_box": [xmin, ymin, xmax, ymax],
"t_logic_box": [row_start, row_end, col_start, col_end],
"t_ocr_res": [[box, text], ...]
},
...
]

Returns:
list: the adjusted results. NOTE(review): the original docstring promised the
same format as the input, but the code appends one *list* per input entry,
holding one dict per kept OCR item - confirm the intended shape.
"""
# Accumulator for the adjusted output.
adjusted_result = []

# Flags recording whether the first row / first column is to be dropped.
remove_first_row = False
remove_first_col = False

# Drop the first row when every OCR item of entry 0 is truthy and has an
# empty text field (cell[1]).
# NOTE(review): only entry 0 is inspected, not every entry whose row_start
# is 0, and all() is also True when entry 0 has no OCR items - confirm.
if all(cell and not cell[1] for cell in t_rec_ocr_list_dict[0].get("t_ocr_res", [])):
remove_first_row = True

# Drop the first column when every entry has OCR results and the text of
# its first OCR item is empty.
# NOTE(review): this tests every entry, not only those with col_start == 0.
if all(row.get("t_ocr_res") and not row["t_ocr_res"][0][1] for row in t_rec_ocr_list_dict):
remove_first_col = True

# Walk the input entries: i is the entry index, j (below) the index of an
# OCR item inside that entry.
for i, row in enumerate(t_rec_ocr_list_dict):
adjusted_row = []

# Skip entry 0 entirely when the first row is being dropped.
if remove_first_row and i == 0:
continue

for j, cell in enumerate(row.get("t_ocr_res", [])):
# Skip the first OCR item of each entry when the first column is being dropped.
if remove_first_col and j == 0:
continue

# Build one output dict per kept OCR item, reusing the entry's t_box.
# Row indices are decremented when i > 0 and column indices when j > 0.
# NOTE(review): the shift does not depend on remove_first_row /
# remove_first_col, so it is applied even when nothing was dropped - confirm.
adjusted_cell = {
"t_box": row.get("t_box"),
"t_logic_box": [
row["t_logic_box"][0] - 1 if i > 0 else row["t_logic_box"][0],
row["t_logic_box"][1] - 1 if i > 0 else row["t_logic_box"][1],
row["t_logic_box"][2] - 1 if j > 0 else row["t_logic_box"][2],
row["t_logic_box"][3] - 1 if j > 0 else row["t_logic_box"][3]
],
"t_ocr_res": cell
}
adjusted_row.append(adjusted_cell)

# Entries that produced no items are omitted from the result.
# NOTE(review): indentation is not visible here; this check presumably runs
# once per entry (inside the outer loop) - confirm.
if adjusted_row:
adjusted_result.append(adjusted_row)

return adjusted_result

Comment on lines +196 to +260
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The adjust_table_cells method is defined but never called (the only call to it, on line 104, is commented out in favor of process_ocr_result). This is dead code that adds unnecessary complexity to the file. It should either be removed or its purpose compared with process_ocr_result to decide which one to keep.

Suggested change
def adjust_table_cells(self, t_rec_ocr_list_dict):
"""
调整表格单元格去掉第一行和/或第一列的单元格
并更新剩余单元格的行列起始和结束位置
参数:
t_rec_ocr_list_dict (list): 原始表格单元格识别结果格式为
[
{
"t_box": [xmin, ymin, xmax, ymax],
"t_logic_box": [row_start, row_end, col_start, col_end],
"t_ocr_res": [[box, text], ...]
},
...
]
返回:
list: 调整后的表格单元格识别结果格式与输入相同
"""
# 新的结果列表
adjusted_result = []
# 记录是否第一行和第一列的单元格已被删除
remove_first_row = False
remove_first_col = False
# 检查并移除第一行
if all(cell and not cell[1] for cell in t_rec_ocr_list_dict[0].get("t_ocr_res", [])):
remove_first_row = True
# 检查并移除第一列
if all(row.get("t_ocr_res") and not row["t_ocr_res"][0][1] for row in t_rec_ocr_list_dict):
remove_first_col = True
# 遍历原始结果进行调整
for i, row in enumerate(t_rec_ocr_list_dict):
adjusted_row = []
# 如果是第一行并且需要删除,跳过这行
if remove_first_row and i == 0:
continue
for j, cell in enumerate(row.get("t_ocr_res", [])):
# 如果是第一列并且需要删除,跳过这一列
if remove_first_col and j == 0:
continue
# 更新当前单元格的逻辑位置
adjusted_cell = {
"t_box": row.get("t_box"),
"t_logic_box": [
row["t_logic_box"][0] - 1 if i > 0 else row["t_logic_box"][0],
row["t_logic_box"][1] - 1 if i > 0 else row["t_logic_box"][1],
row["t_logic_box"][2] - 1 if j > 0 else row["t_logic_box"][2],
row["t_logic_box"][3] - 1 if j > 0 else row["t_logic_box"][3]
],
"t_ocr_res": cell
}
adjusted_row.append(adjusted_cell)
if adjusted_row:
adjusted_result.append(adjusted_row)
return adjusted_result

Copilot uses AI. Check for mistakes.
def sort_and_gather_ocr_res(self, res):
for i, dict_res in enumerate(res):
_, sorted_idx = sorted_ocr_boxes(
Expand All @@ -177,19 +269,30 @@ def sort_and_gather_ocr_res(self, res):
)
return res

def fill_blank_rec(
def re_rec(
self,
img: np.ndarray,
sorted_polygons: np.ndarray,
cell_box_map: Dict[int, List[str]],
rec_again=True,
) -> Dict[int, List[Any]]:
"""找到poly对应为空的框,尝试将直接将poly框直接送到识别中"""
for i in range(sorted_polygons.shape[0]):
if cell_box_map.get(i):
continue
if not rec_again:
box = sorted_polygons[i]
cell_box_map[i] = [[box, "", 1]]
continue
crop_img = get_rotate_crop_image(img, sorted_polygons[i])
pad_img = cv2.copyMakeBorder(
crop_img, 5, 5, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255)
)
rec_res, _ = self.ocr(pad_img, use_det=False, use_cls=True, use_rec=True)
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In re_rec, when rec_again=True, the code calls self.ocr(...) on line 291. However, self.ocr can be None if rapidocr_onnxruntime is not installed (set to None in __init__). While the guard at the top of __call__ raises an error if self.ocr is None and ocr_result is None, it does NOT prevent re_rec with rec_again=True from being called when ocr_result is supplied externally but self.ocr is None. This will result in a TypeError: 'NoneType' object is not callable at line 291.

Copilot uses AI. Check for mistakes.
box = sorted_polygons[i]
cell_box_map[i] = [[box, "", 1]]
continue
text = [rec[0] for rec in rec_res]
scores = [rec[1] for rec in rec_res]
cell_box_map[i] = [[box, "".join(text), min(scores)]]
return cell_box_map

def re_rec_high_precise(
Expand Down Expand Up @@ -222,46 +325,24 @@ def re_rec_high_precise(
]
return cell_box_map

@staticmethod
def get_model_path(
model_type: str, model_path: Union[str, Path, None]
) -> Union[str, Dict[str, str]]:
"""Resolve the model path, downloading the model when no path is supplied.

Args:
model_type: Key looked up in KEY_TO_MODEL_URL when ``model_path`` is None.
model_path: Caller-supplied path; returned unchanged when it is not None.

Returns:
``model_path`` itself when given. Otherwise the value returned by
``DownloadModel.download`` for a single URL, or a dict mapping each key
of a dict-valued URL entry to its download result.

Raises:
ValueError: If the KEY_TO_MODEL_URL entry is neither a str nor a dict.
This includes an unknown ``model_type``, for which the lookup yields None.
"""
# An explicit path always wins; nothing is downloaded.
if model_path is not None:
return model_path

model_url = KEY_TO_MODEL_URL.get(model_type, None)
# Single-file model: download it and return the result.
if isinstance(model_url, str):
model_path = DownloadModel.download(model_url)
return model_path

# Multi-file model: download every file, prefixing the saved name with the
# model type.
if isinstance(model_url, dict):
model_paths = {}
for k, url in model_url.items():
model_paths[k] = DownloadModel.download(
url, save_model_name=f"{model_type}_{Path(url).name}"
)
return model_paths

raise ValueError(f"Model URL: {type(model_url)} is not between str and dict.")


def main():
parser = argparse.ArgumentParser()
parser.add_argument("-img", "--img_path", type=str, required=True)
args = parser.parse_args()

try:
ocr_engine = importlib.import_module("rapidocr").RapidOCR()
ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"Please install the rapidocr by pip install rapidocr."
"Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime."
) from exc
input_args = WiredTableInput()
table_rec = WiredTableRecognition(input_args)

table_rec = WiredTableRecognition()
ocr_result, _ = ocr_engine(args.img_path)
table_results = table_rec(args.img_path, ocr_result)
print(table_results.pred_html)
print(f"cost: {table_results.elapse:.5f}")
table_str, elapse = table_rec(args.img_path, ocr_result)
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In main(), line 343 unpacks the return value of table_rec(args.img_path, ocr_result) into only 2 variables (table_str, elapse), but __call__ returns a 6-element tuple on success: (table_str, table_elapse, sorted_polygons, sorted_logi_points, sorted_ocr_boxes_res, adjust_dict). This will raise a ValueError: too many values to unpack at runtime whenever the call succeeds.

Suggested change
table_str, elapse = table_rec(args.img_path, ocr_result)
table_str, elapse, _, _, _, _ = table_rec(args.img_path, ocr_result)

Copilot uses AI. Check for mistakes.
print(table_str)
print(f"cost: {elapse:.5f}")


if __name__ == "__main__":
Expand Down