ocr_recognition.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict
  3. import torch
  4. import unicodedata2
  5. from torchvision import transforms
  6. from torchvision.transforms import InterpolationMode
  7. from torchvision.transforms import functional as F
  8. from zhconv import convert
  9. from modelscope.utils.constant import ModeKeys
  10. from .base import OfaBasePreprocessor
  def ocr_resize(img, patch_image_size, is_document=False):
      r"""
      Image resize function for OCR tasks.

      The image is converted to RGB, resized with bicubic interpolation and
      edge-padded so that the result is a ``patch_image_size`` x
      ``patch_image_size`` square.

      * Scene-text mode (``is_document=False``): the longer side is resized to
        ``max(64, patch_image_size)`` and the shorter side proportionally (but
        to at least 64 px); the shorter dimension is then padded on both sides.
      * Document mode (``is_document=True``): the image is resized to a fixed
        64 x 1920 (height x width) text line. That line is cut into 4 strips of
        equal width, which are stacked top-to-bottom (giving 256 x 480) and then
        padded vertically. This only yields a square output when
        ``patch_image_size == 480``.

      Args:
          img: a PIL image (anything providing ``convert`` and ``size``).
          patch_image_size (int): side length of the square output image.
          is_document (bool): use the document-line layout described above.

      Returns:
          The resized and padded PIL image, of size
          ``(patch_image_size, patch_image_size)``.

      Raises:
          AssertionError: if the padded image is not a square of the requested
              size. For example, this happens when ``patch_image_size < 64``, or
              in document mode with ``patch_image_size != 480``.
      """
      img = img.convert('RGB')
      width, height = img.size

      if is_document:
          new_height, new_width = 64, 1920
      elif width >= height:
          new_width = max(64, patch_image_size)
          new_height = max(64, int(patch_image_size * (height / width)))
          top = (patch_image_size - new_height) // 2
          bottom = patch_image_size - new_height - top
          left, right = 0, 0
      else:
          new_height = max(64, patch_image_size)
          new_width = max(64, int(patch_image_size * (width / height)))
          left = (patch_image_size - new_width) // 2
          right = patch_image_size - new_width - left
          top, bottom = 0, 0

      img_new = F.resize(
          img,
          (new_height, new_width),
          interpolation=InterpolationMode.BICUBIC,
      )

      if is_document:
          # Split the long text line into 4 strips along the width and stack them
          # along the height. This works on the uint8 tensor so pixel values are
          # preserved exactly. A ToTensor()/ToPILImage() float round trip divides
          # by 255, multiplies back and truncates, which can shift values by one.
          strips = F.pil_to_tensor(img_new).chunk(4, dim=-1)
          img_new = F.to_pil_image(torch.cat(strips, dim=-2))
          new_width, new_height = img_new.size
          top = (patch_image_size - new_height) // 2
          bottom = patch_image_size - new_height - top
          left, right = 0, 0

      # torchvision padding order is [left, top, right, bottom]; 'edge' repeats
      # the border pixels instead of inserting a constant colour.
      img_new = F.pad(
          img_new, padding=[left, top, right, bottom], padding_mode='edge')

      # Explicit raise instead of ``assert`` so the check survives ``python -O``.
      # The exception type stays AssertionError for existing callers.
      if img_new.size != (patch_image_size, patch_image_size):
          raise AssertionError(
              f'ocr_resize produced an image of size {img_new.size}, expected '
              f'({patch_image_size}, {patch_image_size})')
      return img_new
  class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
      r"""
      OFA preprocessor for OCR recognition tasks.

      Turns a raw sample into model inputs:
      - a tokenized prompt (taken from ``cfg.model.prompt``, with a Chinese
        default);
      - a normalized, square image patch produced by ``ocr_resize``;
      - in training mode, target and previous-output token sequences built
        from the text label.
      """

      def __init__(self,
                   cfg,
                   model_dir,
                   mode=ModeKeys.INFERENCE,
                   *args,
                   **kwargs):
          """Preprocess the data.

          Args:
              cfg (modelscope.utils.config.ConfigDict): model config.
              model_dir (str): model path.
              mode: preprocessor mode (model mode).
          """
          super().__init__(cfg, model_dir, mode, *args, **kwargs)
          # A bound method (rather than a lambda) keeps the transform picklable;
          # lambdas cannot be pickled, which breaks e.g. spawned DataLoader
          # workers. Like the lambda, the method still reads `patch_image_size`
          # and `is_document` lazily on every call.
          self.patch_resize_transform = transforms.Compose([
              self._ocr_resize,
              transforms.ToTensor(),
              transforms.Normalize(mean=self.mean, std=self.std),
          ])

      def _ocr_resize(self, image):
          """Apply ``ocr_resize`` with the current patch size and config."""
          return ocr_resize(
              image,
              self.patch_image_size,
              is_document=self.cfg.model.get('is_document', False))

      def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
          """Build a training sample in TRAIN mode, otherwise an inference one."""
          if self.mode == ModeKeys.TRAIN:
              return self._build_train_sample(data)
          else:
              return self._build_infer_sample(data)

      def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
          r"""
          Build a training sample.

          step 1. Preprocess the data with `_build_infer_sample`. The result must
              contain a `label` entry.
          step 2. Preprocess the label:
              - strip it and collapse whitespace runs into single spaces,
                keeping at most `max_tgt_length` whitespace-separated pieces.
                NOTE(review): a label without spaces (e.g. plain Chinese text)
                is a single piece and is not shortened here;
              - tokenize it without a `bos` token as `target`;
              - prepend `bos` and drop the last token of `target` to form
                `prev_output_tokens` (presumably the dropped token is `eos`).

          Args:
              data (`Dict[str, Any]`): Input data. Must contain the image column
                  and the text column (via `column_map`) that supplies the
                  supervised label.

          Return:
              A dict containing source, patch image, patch mask, label, target
              tokens and previous output tokens.

          Raises:
              KeyError: if no label could be extracted from ``data``.
          """
          sample = self._build_infer_sample(data)
          if 'label' not in sample:
              raise KeyError(
                  "label: training requires a text label; make sure "
                  "column_map['text'] names a key present in the input data")
          target = sample['label']
          target = ' '.join(target.strip().split()[:self.max_tgt_length])
          sample['target'] = self.tokenize_text(target, add_bos=False)
          sample['prev_output_tokens'] = torch.cat(
              [self.bos_item, sample['target'][:-1]])
          return sample

      def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
          r"""
          Build an inference sample.

          step 1. Load the PIL image from the column mapped to `image`.
          step 2. Resize, convert to tensor and normalize it as the image input.
          step 3. Tokenize the prompt from the model config (default
              '图片上的文字是什么?', "What is the text in the picture?") as
              the text input.
          step 4. If the column mapped to `text` is present in `data`, attach
              it as `label`. The label is converted to simplified Chinese and
              NFKC-normalized.

          Args:
              data (`Dict[str, Any]`): Input data. Must contain the image column;
                  the text column is optional.

          Return:
              A dict containing source, patch image, patch mask and (optionally)
              label.
          """
          image = self.get_img_pil(data[self.column_map['image']])
          patch_image = self.patch_resize_transform(image)
          prompt = self.cfg.model.get('prompt', '图片上的文字是什么?')
          inputs = self.tokenize_text(prompt)
          sample = {
              'source': inputs,
              'patch_image': patch_image,
              'patch_mask': torch.tensor([True])
          }
          if 'text' in self.column_map and self.column_map['text'] in data:
              target = data[self.column_map['text']]
              # zh-hans: traditional -> simplified Chinese; NFKC also folds
              # compatibility forms such as full-width ASCII to half-width.
              sample['label'] = unicodedata2.normalize(
                  'NFKC', convert(target, 'zh-hans'))
          return sample