base.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import io
  3. import os
  4. import re
  5. import string
  6. from os import path as osp
  7. import json
  8. import numpy as np
  9. import torch
  10. import torchaudio
  11. from PIL import Image
  12. from modelscope.fileio import File
  13. from modelscope.models.multi_modal.ofa import OFATokenizer, OFATokenizerZH
  14. from modelscope.preprocessors.image import load_image
  15. from modelscope.utils.trie import Trie
  16. from .utils.audio_helper import (_get_kaldi_fbank, _get_torchaudio_fbank,
  17. convert_waveform)
  18. from .utils.constant import OFA_TASK_KEY_MAPPING
  19. from .utils.random_help import set_torch_seed
  20. class OfaBasePreprocessor:
  21. r"""
  22. OFA base preprocessor for
  23. """
  24. def __init__(self, cfg, model_dir, mode, *args, **kwargs):
  25. """preprocess the data via the vocab.txt from the `model_dir` path
  26. Args:
  27. cfg(modelscope.utils.config.ConfigDict) : model config
  28. model_dir (str): model path
  29. """
  30. self.cfg = cfg
  31. self.mode = mode
  32. self.language = self.cfg.model.get('language', 'en')
  33. if os.path.exists(model_dir):
  34. model_dir = os.path.abspath(model_dir)
  35. if self.language == 'en':
  36. tokenizer = OFATokenizer.from_pretrained(model_dir)
  37. elif self.language in ['zh', 'cn']:
  38. tokenizer = OFATokenizerZH.from_pretrained(model_dir)
  39. else:
  40. raise NotImplementedError
  41. # there is some diff between here and our ofa code,
  42. # there will be no need to use param: use_bpe
  43. tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
  44. tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
  45. if self.cfg.model.get('multimodal_type', 'default') == 'text2sql':
  46. tokenizer.add_tokens(['>=', '<='])
  47. self.tokenizer = tokenizer
  48. self.bos_item = torch.LongTensor([tokenizer.bos_token_id])
  49. self.pad_item = torch.LongTensor([tokenizer.pad_token_id])
  50. self.eos_item = torch.LongTensor([tokenizer.eos_token_id])
  51. self.tgt_dict = self.src_dict = {
  52. value: key
  53. for key, value in tokenizer.get_vocab().items()
  54. }
  55. self.max_src_length = cfg.model.get('max_src_length', 256)
  56. self.max_tgt_length = cfg.model.get('max_tgt_length', 256)
  57. self.max_image_size = cfg.model.get('max_image_size', 512)
  58. self.language = self.cfg.model.get('language', 'en')
  59. self.prompt_type = self.cfg.model.get('prompt_type', 'none')
  60. seed = self.cfg.model.get('seed', 7)
  61. np.random.seed(seed)
  62. set_torch_seed(seed)
  63. imagenet_default_mean_and_std = self.cfg.model.get(
  64. 'imagenet_default_mean_and_std', False)
  65. if imagenet_default_mean_and_std:
  66. self.mean = [0.485, 0.456, 0.406]
  67. self.std = [0.229, 0.224, 0.225]
  68. else:
  69. self.mean = [0.5, 0.5, 0.5]
  70. self.std = [0.5, 0.5, 0.5]
  71. self.patch_image_size = self.cfg.model.get('patch_image_size', 480)
  72. self.column_map = {
  73. key: key
  74. for key in OFA_TASK_KEY_MAPPING[self.cfg.task]
  75. }
  76. if hasattr(self.cfg,
  77. 'dataset') and self.cfg.dataset.column_map is not None:
  78. for k, v in self.cfg.dataset.column_map.items():
  79. self.column_map[k] = v
  80. self.transtab = str.maketrans(
  81. {key: None
  82. for key in string.punctuation})
  83. self.constraint_trie = None
  84. if self.cfg.model.get('answer2label', None):
  85. ans2label_file = osp.join(model_dir, self.cfg.model.answer2label)
  86. with open(ans2label_file, 'r', encoding='utf-8') as reader:
  87. ans2label_dict = json.load(reader)
  88. self.ans2label = ans2label_dict
  89. self.label2ans = {v: k for k, v in self.ans2label.items()}
  90. self.constraint_trie = Trie(tokenizer.eos_token_id)
  91. for i, answer in enumerate(ans2label_dict.keys()):
  92. answer_item = self.tokenize_text(
  93. ' ' + answer, add_bos=False, add_eos=False)
  94. self.constraint_trie.insert([tokenizer.bos_token_id]
  95. + answer_item.tolist()
  96. + [tokenizer.eos_token_id])
  97. self.train_audio_feature_transforms = None
  98. self.test_audio_feature_transforms = None
  99. def tokenize_text(self, text, add_bos=True, add_eos=True):
  100. r"""
  101. Using `OFATokenizer` to tokenize text input.
  102. Args:
  103. text (`str`): Input text.
  104. add_bos ('bool', **optional**, default to `True`)
  105. Whether or not to add beginning of sentence token in
  106. the front of sentence.
  107. add_eos ('bool', **optional**, default to `True`)
  108. Whether or not to add ending of sentence token in
  109. the end of sentence.
  110. Returns:
  111. A list of tokens with the max length of `max_src_length + 2`
  112. """
  113. if text is None:
  114. return None
  115. inputs = self.tokenizer(
  116. text,
  117. max_length=self.max_src_length,
  118. add_special_tokens=False,
  119. truncation=True,
  120. return_tensors='pt')['input_ids'].squeeze(0)
  121. if add_bos:
  122. inputs = torch.cat([self.bos_item, inputs])
  123. if add_eos:
  124. inputs = torch.cat([inputs, self.eos_item])
  125. return inputs
  126. @staticmethod
  127. def pre_caption(caption, max_words=None):
  128. r"""
  129. Preprocessing for text sentence.
  130. step 1. Get the lower case of input text.
  131. step 2. Remove the words within `,.!?*#:;~ ` in the beginning
  132. of the sentence.
  133. step 3. Replace the words within `-/` or pattern `\s{2,}` with word ` `
  134. and replace tag `<person>` with `person`.
  135. step 4. Remove the `\n` in the end of the sentence.
  136. step 5. Split the sentence with token ` `, If `max_words` is not None,
  137. make a length truncation.
  138. Args:
  139. caption (`str`): Input text.
  140. max_words (`int`, **optional**, default `None`):
  141. The max length of input text. If None, do nothing, else
  142. make a truncation.
  143. Returns:
  144. A sequence of `str`.
  145. """
  146. caption = caption.lower().lstrip(',.!?*#:;~').replace('-', ' ') \
  147. .replace('/', ' ').replace('<person>', 'person')
  148. caption = re.sub(
  149. r'\s{2,}',
  150. ' ',
  151. caption,
  152. )
  153. caption = caption.rstrip('\n')
  154. caption = caption.strip(' ')
  155. # truncate caption
  156. caption_words = caption.split(' ')
  157. if max_words is not None and len(caption_words) > max_words:
  158. caption = ' '.join(caption_words[:max_words])
  159. return caption
  160. @staticmethod
  161. def pre_question(question, max_ques_words):
  162. r"""
  163. Preprocessing for text sentence.
  164. Note that this function is very similar to `pre_caption`, should be merged in the future version.
  165. step 1. Get the lower case of input text.
  166. step 2. Remove the words within `,.!?*#:;~ ` in the beginning
  167. of the sentence.
  168. step 3. Replace the words within `-/` or pattern `\s{2,}` with word ` `.
  169. step 4. Remove the `\n` in the end of the sentence.
  170. step 5. Split the sentence with token ` `, If `max_words` is not None,
  171. make a length truncation.
  172. Args:
  173. question (`str`): Input text.
  174. max_ques_words (`int`, **optional**, default `None`):
  175. The max length of input text. If None, do nothing, else
  176. make a truncation.
  177. Returns:
  178. A sequence of `str`.
  179. """
  180. question = question.lower().lstrip(',.!?*#:;~').replace('-',
  181. ' ').replace(
  182. '/', ' ')
  183. question = re.sub(
  184. r'\s{2,}',
  185. ' ',
  186. question,
  187. )
  188. question = question.rstrip('\n')
  189. question = question.strip(' ')
  190. # truncate question
  191. question_words = question.split(' ')
  192. if len(question_words) > max_ques_words:
  193. question = ' '.join(question_words[:max_ques_words])
  194. return question
  195. def add_constraint_mask(self, sample):
  196. r"""
  197. Add constraint mask.
  198. """
  199. target_itm = sample['target']
  200. len_label_itm = target_itm.ne(self.pad_item).sum(dim=0).item()
  201. if self.constraint_trie:
  202. constraint_mask = torch.zeros(
  203. (len(target_itm), len(self.tgt_dict))).bool()
  204. start_idx = len(target_itm) - len_label_itm
  205. for i in range(start_idx, len(target_itm)):
  206. constraint_prefix_token = self.bos_item.tolist(
  207. ) + target_itm[start_idx:i].tolist()
  208. constraint_nodes = self.constraint_trie.get_next_layer(
  209. constraint_prefix_token)
  210. constraint_mask[i][constraint_nodes] = True
  211. sample['constraint_mask'] = constraint_mask
  212. def get_img_pil(self, path_or_url_or_pil):
  213. r"""
  214. Get the pillow image. If the input is not a pillow image ,it will load
  215. image from a local path or an external url.
  216. Args:
  217. path_or_url_or_pil (`Union[str, Image]`):
  218. Can be:
  219. - A path or url reference to an image
  220. - A pillow image.
  221. Returns:
  222. A pillow image.
  223. """
  224. image = path_or_url_or_pil if isinstance(path_or_url_or_pil, Image.Image) \
  225. else load_image(path_or_url_or_pil)
  226. return image
  227. def get_audio_bytes(self, path_or_url):
  228. if isinstance(path_or_url, bytes):
  229. audio_bytes = io.BytesIO(path_or_url)
  230. elif isinstance(path_or_url, str):
  231. file_bytes = File.read(path_or_url)
  232. audio_bytes = io.BytesIO(file_bytes)
  233. else:
  234. raise TypeError(f'Unsupported input type: {type(path_or_url)}.')
  235. return audio_bytes
  236. def prepare_fbank(self,
  237. waveform,
  238. sample_rate,
  239. speed,
  240. target_sample_rate=16000,
  241. is_train=False):
  242. waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
  243. waveform, sample_rate,
  244. [['speed', str(speed)], ['rate', str(target_sample_rate)]])
  245. _waveform, _ = convert_waveform(
  246. waveform, sample_rate, to_mono=True, normalize_volume=True)
  247. # Kaldi compliance: 16-bit signed integers
  248. _waveform = _waveform * (2**15)
  249. _waveform = _waveform.numpy()
  250. fbank = _get_kaldi_fbank(_waveform, sample_rate, 80)
  251. if fbank is None:
  252. fbank = _get_torchaudio_fbank(_waveform, sample_rate, 80)
  253. if fbank is None:
  254. raise ImportError(
  255. 'Please install pyKaldi or torchaudio to enable fbank feature extraction'
  256. )
  257. if is_train and self.train_audio_feature_transforms is not None:
  258. fbank = self.train_audio_feature_transforms(fbank)
  259. elif ~is_train and self.test_audio_feature_transforms(
  260. fbank) is not None:
  261. fbank = self.test_audio_feature_transforms(fbank)
  262. fbank = torch.from_numpy(fbank).float()
  263. fbank = self.pack_frames(fbank)
  264. return fbank
  265. def pack_frames(self, feature: torch.Tensor):
  266. if self.cfg.n_frames_per_step == 1:
  267. return feature
  268. n_packed_frames = feature.shape[0] // self.cfg.n_frames_per_step
  269. feature = feature[:self.cfg.n_frames_per_step * n_packed_frames]
  270. return feature.reshape(n_packed_frames, -1)