model.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import torch
  4. import torch.nn.functional as F
  5. from modelscope.metainfo import Models
  6. from modelscope.models.base.base_torch_model import TorchModel
  7. from modelscope.models.builder import MODELS
  8. from modelscope.utils.config import Config
  9. from modelscope.utils.constant import ModelFile, Tasks
  10. from modelscope.utils.logger import get_logger
  11. from .modules.ConvNextViT.main_model import ConvNextViT
  12. from .modules.CRNN.main_model import CRNN
  13. from .modules.LightweightEdge.main_model import LightweightEdge
  14. LOGGER = get_logger()
  15. def flatten_label(target):
  16. label_flatten = []
  17. label_length = []
  18. label_dict = []
  19. for i in range(0, target.size()[0]):
  20. cur_label = target[i].tolist()
  21. temp_label = cur_label[:cur_label.index(0)]
  22. label_flatten += temp_label
  23. label_dict.append(temp_label)
  24. label_length.append(len(temp_label))
  25. label_flatten = torch.LongTensor(label_flatten)
  26. label_length = torch.IntTensor(label_length)
  27. return (label_dict, label_length, label_flatten)
  28. class cha_encdec():
  29. def __init__(self, charMapping, case_sensitive=True):
  30. self.case_sensitive = case_sensitive
  31. self.text_seq_len = 160
  32. self.charMapping = charMapping
  33. def encode(self, label_batch):
  34. max_len = max([len(s) for s in label_batch])
  35. out = torch.zeros(len(label_batch), max_len + 1).long()
  36. for i in range(0, len(label_batch)):
  37. if not self.case_sensitive:
  38. cur_encoded = torch.tensor([
  39. self.charMapping[char.lower()] - 1 if char.lower()
  40. in self.charMapping else len(self.charMapping)
  41. for char in label_batch[i]
  42. ]) + 1
  43. else:
  44. cur_encoded = torch.tensor([
  45. self.charMapping[char]
  46. - 1 if char in self.charMapping else len(self.charMapping)
  47. for char in label_batch[i]
  48. ]) + 1
  49. out[i][0:len(cur_encoded)] = cur_encoded
  50. out = torch.cat(
  51. (out, torch.zeros(
  52. (out.size(0), self.text_seq_len - out.size(1))).type_as(out)),
  53. dim=1)
  54. label_dict, label_length, label_flatten = flatten_label(out)
  55. return label_dict, label_length, label_flatten
  56. @MODELS.register_module(
  57. Tasks.ocr_recognition, module_name=Models.ocr_recognition)
  58. class OCRRecognition(TorchModel):
  59. def __init__(self, model_dir: str, **kwargs):
  60. """initialize the ocr recognition model from the `model_dir` path.
  61. Args:
  62. model_dir (str): the model path.
  63. """
  64. super().__init__(model_dir, **kwargs)
  65. model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
  66. cfgs = Config.from_file(
  67. os.path.join(model_dir, ModelFile.CONFIGURATION))
  68. self.do_chunking = cfgs.model.inference_kwargs.do_chunking
  69. self.target_height = cfgs.model.inference_kwargs.img_height
  70. self.target_width = cfgs.model.inference_kwargs.img_width
  71. self.recognizer = None
  72. if cfgs.model.recognizer == 'ConvNextViT':
  73. self.recognizer = ConvNextViT()
  74. elif cfgs.model.recognizer == 'CRNN':
  75. self.recognizer = CRNN()
  76. elif cfgs.model.recognizer == 'LightweightEdge':
  77. self.recognizer = LightweightEdge()
  78. else:
  79. raise TypeError(
  80. f'recognizer should be either ConvNextViT, CRNN, but got {cfgs.model.recognizer}'
  81. )
  82. if model_path != '':
  83. params_pretrained = torch.load(model_path, map_location='cpu')
  84. model_dict = self.recognizer.state_dict()
  85. # remove prefix for finetuned models
  86. check_point = {
  87. k.replace('recognizer.', '').replace('module.', ''): v
  88. for k, v in params_pretrained.items()
  89. }
  90. model_dict.update(check_point)
  91. self.recognizer.load_state_dict(model_dict)
  92. dict_path = os.path.join(model_dir, ModelFile.VOCAB_FILE)
  93. self.labelMapping = dict()
  94. self.charMapping = dict()
  95. with open(dict_path, 'r', encoding='utf-8') as f:
  96. lines = f.readlines()
  97. cnt = 1
  98. # ConvNextViT and LightweightEdge model start from index=2
  99. if cfgs.model.recognizer == 'ConvNextViT' or cfgs.model.recognizer == 'LightweightEdge':
  100. cnt += 1
  101. for line in lines:
  102. line = line.strip('\n')
  103. self.labelMapping[cnt] = line
  104. self.charMapping[line] = cnt
  105. cnt += 1
  106. self.encdec = cha_encdec(self.charMapping)
  107. self.criterion_CTC = torch.nn.CTCLoss(zero_infinity=True)
  108. def forward(self, inputs):
  109. """
  110. Args:
  111. img (`torch.Tensor`): batched image tensor,
  112. shape of each tensor is [N, 1, H, W].
  113. Return:
  114. `probs [T, N, Classes] of the sequence feature`
  115. """
  116. return self.recognizer(inputs)
  117. def do_step(self, batch):
  118. inputs = batch['images']
  119. labels = batch['labels']
  120. bs = inputs.shape[0]
  121. if self.do_chunking:
  122. inputs = inputs.view(bs * 3, 3, self.target_height, 300)
  123. else:
  124. inputs = inputs.view(bs, 3, self.target_height, self.target_width)
  125. output = self(inputs)
  126. probs = output['probs'].permute(1, 0, 2)
  127. _, label_length, label_flatten = self.encdec.encode(labels)
  128. probs_sizes = torch.IntTensor([probs.size(0)] * probs.size(1))
  129. loss = self.criterion_CTC(
  130. probs.log_softmax(2), label_flatten, probs_sizes, label_length)
  131. output = dict(loss=loss, preds=output['preds'])
  132. return output
  133. def postprocess(self, inputs):
  134. outprobs = inputs
  135. outprobs = F.softmax(outprobs, dim=-1)
  136. preds = torch.argmax(outprobs, -1)
  137. batchSize, length = preds.shape
  138. final_str_list = []
  139. for i in range(batchSize):
  140. pred_idx = preds[i].cpu().data.tolist()
  141. last_p = 0
  142. str_pred = []
  143. for p in pred_idx:
  144. if p != last_p and p != 0:
  145. str_pred.append(self.labelMapping[p])
  146. last_p = p
  147. final_str = ''.join(str_pred)
  148. final_str_list.append(final_str)
  149. return {'preds': final_str_list, 'probs': inputs}