processing.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. """
  3. Processor class for GeoLayoutLM.
  4. """
  5. from collections import defaultdict
  6. from typing import Dict, Iterable, List, Union
  7. import cv2
  8. import numpy as np
  9. import PIL
  10. import torch
  11. from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  12. from torchvision import transforms
  13. from modelscope.preprocessors.image import LoadImage
  14. def custom_tokenize(tokenizer, text):
  15. toks = tokenizer.tokenize('pad ' + text)[1:]
  16. toks2 = toks[1:] if len(toks) > 0 and toks[0] == '▁' else toks
  17. return toks2
  18. class ImageProcessor(object):
  19. r"""
  20. Construct a GeoLayoutLM image processor
  21. Args:
  22. do_preprocess (`bool`): whether to do preprocess to unify the image format,
  23. resize and convert to tensor.
  24. do_rescale: only works when we disable do_preprocess.
  25. """
  26. def __init__(self,
  27. do_preprocess: bool = True,
  28. do_resize: bool = False,
  29. image_size: Dict[str, int] = None,
  30. do_rescale: bool = False,
  31. rescale_factor: float = 1. / 255,
  32. do_normalize: bool = True,
  33. image_mean: Union[float, Iterable[float]] = None,
  34. image_std: Union[float, Iterable[float]] = None,
  35. apply_ocr: bool = True,
  36. **kwargs) -> None:
  37. self.do_preprocess = do_preprocess
  38. self.do_resize = do_resize
  39. self.size = image_size if image_size is not None else {
  40. 'height': 768,
  41. 'width': 768
  42. }
  43. self.do_rescale = do_rescale and (not do_preprocess)
  44. self.rescale_factor = rescale_factor
  45. self.do_normalize = do_normalize
  46. image_mean = IMAGENET_DEFAULT_MEAN if image_mean is None else image_mean
  47. image_std = IMAGENET_DEFAULT_STD if image_std is None else image_std
  48. self.image_mean = (image_mean, image_mean, image_mean) if isinstance(
  49. image_mean, float) else image_mean
  50. self.image_std = (image_std, image_std, image_std) if isinstance(
  51. image_std, float) else image_std
  52. self.apply_ocr = apply_ocr
  53. self.kwargs = kwargs
  54. self.totensor = transforms.ToTensor()
  55. def preprocess(self, image: Union[np.ndarray, PIL.Image.Image]):
  56. """ unify the image format, resize and convert to tensor.
  57. """
  58. image = LoadImage.convert_to_ndarray(image)[:, :, ::-1]
  59. size_raw = image.shape[:2]
  60. if self.do_resize:
  61. image = cv2.resize(image,
  62. (self.size['width'], self.size['height']))
  63. # convert to pytorch tensor
  64. image_pt = self.totensor(image)
  65. return image_pt, size_raw
  66. def __call__(self, images: Union[list, np.ndarray, PIL.Image.Image, str]):
  67. """
  68. Args:
  69. images: list of np.ndarrays, PIL images or image tensors.
  70. """
  71. if not isinstance(images, list):
  72. images = [images]
  73. sizes_raw = []
  74. if self.do_preprocess:
  75. for i in range(len(images)):
  76. images[i], size_raw = self.preprocess(images[i])
  77. sizes_raw.append(size_raw)
  78. images_pt = torch.stack(images, dim=0) # [b, c, h, w]
  79. if self.do_rescale:
  80. images_pt = images_pt * self.rescale_factor
  81. if self.do_normalize:
  82. mu = torch.tensor(self.image_mean).view(1, 3, 1, 1)
  83. std = torch.tensor(self.image_std).view(1, 3, 1, 1)
  84. images_pt = (images_pt - mu) / (std + 1e-8)
  85. # TODO: apply OCR
  86. ocr_infos = None
  87. if self.apply_ocr:
  88. raise NotImplementedError('OCR service is not available yet!')
  89. if len(sizes_raw) == 0:
  90. sizes_raw = None
  91. data = {
  92. 'images': images_pt,
  93. 'ocr_infos': ocr_infos,
  94. 'sizes_raw': sizes_raw
  95. }
  96. return data
  97. class OCRUtils(object):
  98. def __init__(self):
  99. self.version = 'v0'
  100. def __call__(self, ocr_infos):
  101. """
  102. sort boxes, filtering or other preprocesses
  103. should return sorted ocr_infos
  104. """
  105. raise NotImplementedError
  106. def bound_box(box, height, width):
  107. # box: [x_tl, y_tl, x_br, y_br] or ...
  108. assert len(box) == 4 or len(box) == 8
  109. for i in range(len(box)):
  110. if i & 1:
  111. box[i] = max(0, min(box[i], height))
  112. else:
  113. box[i] = max(0, min(box[i], width))
  114. return box
  115. def bbox2pto4p(box2p):
  116. box4p = [
  117. box2p[0], box2p[1], box2p[2], box2p[1], box2p[2], box2p[3], box2p[0],
  118. box2p[3]
  119. ]
  120. return box4p
  121. def bbox4pto2p(box4p):
  122. box2p = [
  123. min(box4p[0], box4p[2], box4p[4], box4p[6]),
  124. min(box4p[1], box4p[3], box4p[5], box4p[7]),
  125. max(box4p[0], box4p[2], box4p[4], box4p[6]),
  126. max(box4p[1], box4p[3], box4p[5], box4p[7]),
  127. ]
  128. return box2p
  129. def stack_tensor_dict(tensor_dicts: List[Dict[str, torch.Tensor]]):
  130. one_dict = defaultdict(list)
  131. for td in tensor_dicts:
  132. for k, v in td.items():
  133. one_dict[k].append(v)
  134. res_dict = {}
  135. for k, v in one_dict.items():
  136. res_dict[k] = torch.stack(v, dim=0)
  137. return res_dict
  138. class TextLayoutSerializer(object):
  139. def __init__(self,
  140. max_seq_length: int,
  141. max_block_num: int,
  142. tokenizer,
  143. width=768,
  144. height=768,
  145. use_roberta_tokenizer: bool = True,
  146. ocr_utils: OCRUtils = None):
  147. self.version = 'v0'
  148. self.max_seq_length = max_seq_length
  149. self.max_block_num = max_block_num
  150. self.tokenizer = tokenizer
  151. self.width = width
  152. self.height = height
  153. self.use_roberta_tokenizer = use_roberta_tokenizer
  154. self.ocr_utils = ocr_utils
  155. self.pad_token_id = tokenizer.pad_token_id
  156. self.cls_token_id = tokenizer.bos_token_id
  157. self.sep_token_id = tokenizer.eos_token_id
  158. self.unk_token_id = tokenizer.unk_token_id
  159. self.cls_bbs_word = [0.0] * 8
  160. self.cls_bbs_line = [0] * 4
  161. def label2seq(self, ocr_info: list, label_info: list):
  162. raise NotImplementedError
  163. def serialize_single(
  164. self,
  165. ocr_info: list = None,
  166. input_ids: list = None,
  167. bbox_line: List[List] = None,
  168. bbox_word: List[List] = None,
  169. width: int = 768,
  170. height: int = 768,
  171. ):
  172. r"""
  173. Either ocr_info or (input_ids, bbox_line, bbox_word)
  174. should be provided.
  175. If (input_ids, bbox_line, bbox_word) is provided,
  176. convenient plug into the serialization (customization)
  177. is offered. The tokens must be organised by blocks and words.
  178. Else, ocr_info must be provided, to be parsed
  179. to sequences directly (the simplest way).
  180. Args:
  181. ocr_info: [
  182. {"text": "xx", "box": [a,b,c,d],
  183. "words": [{"text": "x", "box": [e,f,g,h]}, ...]},
  184. ...
  185. ]
  186. bbox_line: the coordinate value should match the original image
  187. (i.e., not be normalized).
  188. """
  189. if input_ids is not None:
  190. assert len(input_ids) == len(bbox_line)
  191. assert len(input_ids) == len(bbox_word)
  192. input_ids, bbs_word, bbs_line, first_token_idxes, \
  193. line_rank_ids, line_rank_inner_ids, word_rank_ids = \
  194. self.halfseq2seq(input_ids, bbox_line, bbox_word, width, height)
  195. else:
  196. assert ocr_info is not None
  197. input_ids, bbs_word, bbs_line, first_token_idxes, \
  198. line_rank_ids, line_rank_inner_ids, word_rank_ids = \
  199. self.ocr_info2seq(ocr_info, width, height)
  200. token_seq = {}
  201. token_seq['input_ids'] = torch.ones(
  202. self.max_seq_length, dtype=torch.int64) * self.pad_token_id
  203. token_seq['attention_mask'] = torch.zeros(
  204. self.max_seq_length, dtype=torch.int64)
  205. token_seq['first_token_idxes'] = torch.zeros(
  206. self.max_block_num, dtype=torch.int64)
  207. token_seq['first_token_idxes_mask'] = torch.zeros(
  208. self.max_block_num, dtype=torch.int64)
  209. token_seq['bbox_4p_normalized'] = torch.zeros(
  210. self.max_seq_length, 8, dtype=torch.float32)
  211. token_seq['bbox'] = torch.zeros(
  212. self.max_seq_length, 4, dtype=torch.float32)
  213. token_seq['line_rank_id'] = torch.zeros(
  214. self.max_seq_length, dtype=torch.int64) # start from 1
  215. token_seq['line_rank_inner_id'] = torch.ones(
  216. self.max_seq_length, dtype=torch.int64) # 1 2 2 3
  217. token_seq['word_rank_id'] = torch.zeros(
  218. self.max_seq_length, dtype=torch.int64) # start from 1
  219. # expand using cls and sep tokens
  220. sep_bbs_word = [width, height] * 4
  221. sep_bbs_line = [width, height] * 2
  222. input_ids = [self.cls_token_id] + input_ids + [self.sep_token_id]
  223. bbs_line = [self.cls_bbs_line] + bbs_line + [sep_bbs_line]
  224. bbs_word = [self.cls_bbs_word] + bbs_word + [sep_bbs_word]
  225. # assign
  226. len_tokens = len(input_ids)
  227. len_lines = len(first_token_idxes)
  228. token_seq['input_ids'][:len_tokens] = torch.tensor(input_ids)
  229. token_seq['attention_mask'][:len_tokens] = 1
  230. token_seq['first_token_idxes'][:len_lines] = torch.tensor(
  231. first_token_idxes)
  232. token_seq['first_token_idxes_mask'][:len_lines] = 1
  233. token_seq['line_rank_id'][1:len_tokens
  234. - 1] = torch.tensor(line_rank_ids)
  235. token_seq['line_rank_inner_id'][1:len_tokens - 1] = torch.tensor(
  236. line_rank_inner_ids)
  237. token_seq['line_rank_inner_id'] = token_seq[
  238. 'line_rank_inner_id'] * token_seq['attention_mask']
  239. token_seq['word_rank_id'][1:len_tokens
  240. - 1] = torch.tensor(word_rank_ids)
  241. token_seq['bbox_4p_normalized'][:len_tokens, :] = torch.tensor(
  242. bbs_word)
  243. # word bbox normalization -> [0, 1]
  244. token_seq['bbox_4p_normalized'][:, [0, 2, 4, 6]] = \
  245. token_seq['bbox_4p_normalized'][:, [0, 2, 4, 6]] / width
  246. token_seq['bbox_4p_normalized'][:, [1, 3, 5, 7]] = \
  247. token_seq['bbox_4p_normalized'][:, [1, 3, 5, 7]] / height
  248. token_seq['bbox'][:len_tokens, :] = torch.tensor(bbs_line)
  249. # line bbox -> [0, 1000)
  250. token_seq['bbox'][:,
  251. [0, 2]] = token_seq['bbox'][:, [0, 2]] / width * 1000
  252. token_seq['bbox'][:,
  253. [1, 3]] = token_seq['bbox'][:,
  254. [1, 3]] / height * 1000
  255. token_seq['bbox'] = token_seq['bbox'].long()
  256. return token_seq
  257. def ocr_info2seq(self, ocr_info: list, width: int, height: int):
  258. input_ids = []
  259. bbs_word = []
  260. bbs_line = []
  261. first_token_idxes = []
  262. line_rank_ids = []
  263. line_rank_inner_ids = []
  264. word_rank_ids = []
  265. early_stop = False
  266. for line_idx, line in enumerate(ocr_info):
  267. if line_idx == self.max_block_num:
  268. early_stop = True
  269. if early_stop:
  270. break
  271. lbox = line['box']
  272. lbox = bound_box(lbox, height, width)
  273. is_first_word = True
  274. for word_id, word_info in enumerate(line['words']):
  275. wtext = word_info['text']
  276. wbox = word_info['box']
  277. wbox = bound_box(wbox, height, width)
  278. wbox4p = bbox2pto4p(wbox)
  279. if self.use_roberta_tokenizer:
  280. wtokens = custom_tokenize(self.tokenizer, wtext)
  281. else:
  282. wtokens = self.tokenizer.tokenize(wtext)
  283. wtoken_ids = self.tokenizer.convert_tokens_to_ids(wtokens)
  284. if len(wtoken_ids) == 0:
  285. wtoken_ids.append(self.unk_token_id)
  286. n_tokens = len(wtoken_ids)
  287. # reserve for cls and sep
  288. if len(input_ids) + n_tokens > self.max_seq_length - 2:
  289. early_stop = True
  290. break # chunking early for long documents
  291. if is_first_word:
  292. first_token_idxes.append(len(input_ids) + 1)
  293. input_ids.extend(wtoken_ids)
  294. bbs_word.extend([wbox4p] * n_tokens)
  295. bbs_line.extend([lbox] * n_tokens)
  296. word_rank_ids.extend([word_id + 1] * n_tokens)
  297. line_rank_ids.extend([line_idx + 1] * n_tokens)
  298. if is_first_word:
  299. if len(line_rank_inner_ids
  300. ) > 0 and line_rank_inner_ids[-1] == 2:
  301. line_rank_inner_ids[-1] = 3
  302. line_rank_inner_ids.extend([1] + (n_tokens - 1) * [2])
  303. is_first_word = False
  304. else:
  305. line_rank_inner_ids.extend(n_tokens * [2])
  306. if len(line_rank_inner_ids) > 0 and line_rank_inner_ids[-1] == 2:
  307. line_rank_inner_ids[-1] = 3
  308. return input_ids, bbs_word, bbs_line, first_token_idxes, line_rank_ids, \
  309. line_rank_inner_ids, word_rank_ids
  310. def halfseq2seq(self, input_ids: list, bbox_line: List[List],
  311. bbox_word: List[List], width: int, height: int):
  312. """
  313. for convenient plug into the serialization, given the 3 customized sequences.
  314. They should not contain special tokens like [CLS] or [SEP].
  315. """
  316. bbs_word = []
  317. bbs_line = []
  318. first_token_idxes = []
  319. line_rank_ids = []
  320. line_rank_inner_ids = []
  321. word_rank_ids = []
  322. n_real_tokens = len(input_ids)
  323. lb_prev, wb_prev = None, None
  324. line_id = 0
  325. word_id = 1
  326. for i in range(n_real_tokens):
  327. lb_now = bbox_line[i]
  328. wb_now = bbox_word[i]
  329. line_start = lb_prev is None or lb_now != lb_prev
  330. word_start = wb_prev is None or wb_now != wb_prev
  331. lb_prev, wb_prev = lb_now, wb_now
  332. if len(lb_now) == 8:
  333. lb_now = bbox4pto2p(lb_now)
  334. assert len(lb_now) == 4
  335. lb_now = bound_box(lb_now, height, width)
  336. if len(wb_now) == 4:
  337. wb_now = bbox2pto4p(wb_now)
  338. assert len(wb_now) == 8
  339. wb_now = bound_box(wb_now, height, width)
  340. bbs_word.append(wb_now)
  341. bbs_line.append(lb_now)
  342. if word_start:
  343. word_id += 1
  344. if line_start:
  345. line_id += 1
  346. first_token_idxes.append(i + 1)
  347. if len(line_rank_inner_ids
  348. ) > 0 and line_rank_inner_ids[-1] == 2:
  349. line_rank_inner_ids[-1] = 3
  350. line_rank_inner_ids.append(1)
  351. word_id = 1
  352. else:
  353. line_rank_inner_ids.append(2)
  354. line_rank_ids.append(line_id)
  355. word_rank_ids.append(word_id)
  356. if len(line_rank_inner_ids) > 0 and line_rank_inner_ids[-1] == 2:
  357. line_rank_inner_ids[-1] = 3
  358. return input_ids, bbs_word, bbs_line, first_token_idxes, \
  359. line_rank_ids, line_rank_inner_ids, word_rank_ids
  360. def __call__(
  361. self,
  362. ocr_infos: List[List] = None,
  363. input_ids: list = None,
  364. bboxes_line: List[List] = None,
  365. bboxes_word: List[List] = None,
  366. sizes_raw: list = None,
  367. **kwargs,
  368. ):
  369. n_samples = len(ocr_infos) if ocr_infos is not None else len(input_ids)
  370. if sizes_raw is None:
  371. sizes_raw = [(self.height, self.width)] * n_samples
  372. seqs = []
  373. if input_ids is not None:
  374. assert len(input_ids) == len(bboxes_line)
  375. assert len(input_ids) == len(bboxes_word)
  376. for input_id, bbox_line, bbox_word, size_raw in zip(
  377. input_ids, bboxes_line, bboxes_word, sizes_raw):
  378. height, width = size_raw
  379. token_seq = self.serialize_single(None, input_id, bbox_line,
  380. bbox_word, width, height)
  381. seqs.append(token_seq)
  382. else:
  383. assert ocr_infos is not None, 'For serialization, ocr_infos must not be NoneType!'
  384. if self.ocr_utils is not None:
  385. ocr_infos = self.ocr_utils(ocr_infos)
  386. for ocr_info, size_raw in zip(ocr_infos, sizes_raw):
  387. height, width = size_raw
  388. token_seq = self.serialize_single(
  389. ocr_info, width=width, height=height)
  390. seqs.append(token_seq)
  391. pt_seqs = stack_tensor_dict(seqs)
  392. return pt_seqs
  393. class Processor(object):
  394. r"""Construct a GeoLayoutLM processor.
  395. Args:
  396. max_seq_length: max length for token
  397. max_block_num: max number of text lines (blocks or segments)
  398. img_processor: type of ImageProcessor.
  399. tokenizer: to tokenize strings.
  400. use_roberta_tokenizer: Whether the tokenizer is originated from RoBerta tokenizer
  401. (True by default).
  402. ocr_utils: a tool to preprocess ocr_infos.
  403. width: default width. It can be used only when all the images are of the same shape.
  404. height: default height. It can be used only when all the images are of the same shape.
  405. In `serialize_from_tokens`, the 3 sequences (i.e., `input_ids`, `bboxes_line`, `bboxes_word`)
  406. must not contain special tokens like [CLS] or [SEP].
  407. The boxes in `bboxes_line` and `bboxes_word` can be presented by either 2 points or 4 points.
  408. The value in boxes should keep original.
  409. Here is an example of the 3 arguments:
  410. ```
  411. input_ids ->
  412. [[6, 2391, 6, 31833, 6, 10132, 6, 2283, 6, 17730, 6, 2698, 152]]
  413. bboxes_line ->
  414. [[[230, 1, 353, 38], [230, 1, 353, 38], [230, 1, 353, 38], [230, 1, 353, 38],
  415. [230, 1, 353, 38], [230, 1, 353, 38], [230, 1, 353, 38], [230, 1, 353, 38],
  416. [257, 155, 338, 191], [257, 155, 338, 191], [257, 155, 338, 191], [257, 155, 338, 191],
  417. [257, 155, 338, 191]]]
  418. bboxes_word ->
  419. [[[231, 2, 267, 2, 267, 38, 231, 38], [231, 2, 267, 2, 267, 38, 231, 38],
  420. [264, 7, 298, 7, 298, 36, 264, 36], [264, 7, 298, 7, 298, 36, 264, 36],
  421. [293, 3, 329, 3, 329, 41, 293, 41], [293, 3, 329, 3, 329, 41, 293, 41],
  422. [330, 4, 354, 4, 354, 39, 330, 39], [330, 4, 354, 4, 354, 39, 330, 39],
  423. [258, 156, 289, 156, 289, 193, 258, 193], [258, 156, 289, 156, 289, 193, 258, 193],
  424. [288, 158, 321, 158, 321, 192, 288, 192], [288, 158, 321, 158, 321, 192, 288, 192],
  425. [321, 156, 336, 156, 336, 190, 321, 190]]]
  426. ```
  427. """
  428. def __init__(self,
  429. max_seq_length,
  430. max_block_num,
  431. img_processor: ImageProcessor,
  432. tokenizer=None,
  433. use_roberta_tokenizer: bool = True,
  434. ocr_utils: OCRUtils = None,
  435. width=768,
  436. height=768,
  437. **kwargs):
  438. self.img_processor = img_processor
  439. self.tokenizer = tokenizer
  440. self.kwargs = kwargs
  441. self.serializer = TextLayoutSerializer(
  442. max_seq_length,
  443. max_block_num,
  444. tokenizer,
  445. width,
  446. height,
  447. use_roberta_tokenizer=use_roberta_tokenizer,
  448. ocr_utils=ocr_utils)
  449. def __call__(
  450. self,
  451. images: Union[list, np.ndarray, PIL.Image.Image, str],
  452. ocr_infos: List[List] = None,
  453. token_seqs: dict = None,
  454. sizes_raw: list = None,
  455. ):
  456. img_data = self.img_processor(images)
  457. images = img_data['images']
  458. ocr_infos = img_data['ocr_infos'] if ocr_infos is None else ocr_infos
  459. sizes_raw = img_data['sizes_raw'] if sizes_raw is None else sizes_raw
  460. if token_seqs is None:
  461. token_seqs = self.serializer(ocr_infos, sizes_raw=sizes_raw)
  462. else:
  463. token_seqs = self.serializer(
  464. None, sizes_raw=sizes_raw, **token_seqs)
  465. assert token_seqs is not None, 'token_seqs must not be NoneType!'
  466. batch = {}
  467. batch['image'] = images
  468. for k, v in token_seqs.items():
  469. batch[k] = token_seqs[k]
  470. return batch
  471. def serialize_from_tokens(self,
  472. images,
  473. input_ids,
  474. bboxes_line,
  475. bboxes_word,
  476. sizes_raw=None):
  477. half_batch = {}
  478. half_batch['input_ids'] = input_ids
  479. half_batch['bboxes_line'] = bboxes_line
  480. half_batch['bboxes_word'] = bboxes_word
  481. return self(images, None, half_batch, sizes_raw)