multi_modal.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os.path as osp
  3. import re
  4. from io import BytesIO
  5. from typing import Any, Dict, List, Tuple, Union
  6. import decord
  7. import json
  8. import numpy as np
  9. import torch
  10. from PIL import Image
  11. from timm.data import create_transform
  12. from torchvision import transforms
  13. from torchvision.datasets import ImageFolder
  14. from torchvision.transforms import Compose, Normalize, Resize, ToTensor
  15. from modelscope.hub.snapshot_download import snapshot_download
  16. from modelscope.metainfo import Preprocessors
  17. from modelscope.pipelines.base import Input
  18. from modelscope.pipelines.cv.cmdssl_video_embedding_pipeline import (
  19. VCenterCrop, VCompose, VNormalize, VRescale, VToTensor)
  20. from modelscope.preprocessors import load_image
  21. from modelscope.utils.config import Config
  22. from modelscope.utils.constant import (Fields, Invoke, ModeKeys, ModelFile,
  23. Tasks)
  24. from .base import Preprocessor
  25. from .builder import PREPROCESSORS
  26. from .ofa import * # noqa
  27. from .ofa.utils.collate import collate_fn
  28. from .ofa.utils.constant import OFA_TASK_KEY_MAPPING
  29. __all__ = [
  30. 'DiffusionImageGenerationPreprocessor', 'OfaPreprocessor',
  31. 'MPlugPreprocessor', 'HiTeAPreprocessor', 'MplugOwlPreprocessor'
  32. ]
  33. @PREPROCESSORS.register_module(
  34. Fields.multi_modal,
  35. module_name=Preprocessors.diffusion_image_generation_preprocessor)
  36. class DiffusionImageGenerationPreprocessor(Preprocessor):
  37. """ Preprocessor the data with the combination of image and text.
  38. Args:
  39. data: process the value as an image for keys ending with 'FILE'
  40. or existing in preprocessor_image_keys and pass-through the values of other keys.
  41. """
  42. def __init__(self, *args, **kwargs):
  43. super().__init__(*args, **kwargs)
  44. self.preprocessor_resolution = kwargs.pop('resolution', 512)
  45. self.preprocessor_mean = kwargs.pop('mean', [0.5])
  46. self.preprocessor_std = kwargs.pop('std', [0.5])
  47. self.preprocessor_image_keys = set(kwargs.pop('image_keys', []))
  48. self.center_crop = kwargs.pop('center_crop', True)
  49. self.transform_input = transforms.Compose([
  50. transforms.Resize(
  51. self.preprocessor_resolution,
  52. interpolation=transforms.InterpolationMode.BILINEAR),
  53. transforms.CenterCrop(self.preprocessor_resolution)
  54. if self.center_crop else transforms.RandomCrop(
  55. self.preprocessor_resolution),
  56. transforms.ToTensor(),
  57. transforms.Normalize(self.preprocessor_mean,
  58. self.preprocessor_std),
  59. ])
  60. def __call__(self, data) -> Dict[str, Any]:
  61. results = {}
  62. for key, value in data.items():
  63. if key.endswith(':FILE') or key in self.preprocessor_image_keys:
  64. image = load_image(value)
  65. img = self.transform_input(image)
  66. results[key.replace(':FILE', '').lower()] = img
  67. else:
  68. results[key.lower()] = value if value else ''
  69. return results
  70. @PREPROCESSORS.register_module(
  71. Fields.multi_modal, module_name=Preprocessors.ofa_tasks_preprocessor)
  72. class OfaPreprocessor(Preprocessor):
  73. def __init__(self,
  74. model_dir: str,
  75. mode=ModeKeys.INFERENCE,
  76. *args,
  77. **kwargs):
  78. """preprocess the data
  79. Args:
  80. model_dir (str): model path
  81. mode: preprocessor mode (model mode)
  82. """
  83. super().__init__(*args, **kwargs)
  84. preprocess_mapping = {
  85. Tasks.ocr_recognition: OfaOcrRecognitionPreprocessor,
  86. Tasks.image_captioning: OfaImageCaptioningPreprocessor,
  87. Tasks.visual_grounding: OfaVisualGroundingPreprocessor,
  88. Tasks.visual_question_answering:
  89. OfaVisualQuestionAnsweringPreprocessor,
  90. Tasks.visual_entailment: OfaVisualEntailmentPreprocessor,
  91. Tasks.image_classification: OfaImageClassificationPreprocessor,
  92. Tasks.text_classification: OfaTextClassificationPreprocessor,
  93. Tasks.text_summarization: OfaSummarizationPreprocessor,
  94. Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor,
  95. Tasks.auto_speech_recognition: OfaASRPreprocessor,
  96. Tasks.sudoku: OfaSudokuPreprocessor,
  97. Tasks.text2sql: OfaTextToSqlPreprocessor
  98. }
  99. model_dir = model_dir if osp.exists(model_dir) else snapshot_download(
  100. model_dir, user_agent={Invoke.KEY: Invoke.PREPROCESSOR})
  101. self.cfg = Config.from_file(
  102. osp.join(model_dir, ModelFile.CONFIGURATION))
  103. self.preprocess = preprocess_mapping[self.cfg.task](
  104. cfg=self.cfg, model_dir=model_dir, mode=mode)
  105. self.keys = OFA_TASK_KEY_MAPPING[self.cfg.task]
  106. self.tokenizer = self.preprocess.tokenizer
  107. if kwargs.get('no_collate', None):
  108. self.no_collate = True
  109. else:
  110. self.no_collate = False
  111. # just for modelscope demo
  112. def _build_dict(self, input: Union[Input, List[Input]]) -> Dict[str, Any]:
  113. data = dict()
  114. if not isinstance(input, tuple) and not isinstance(input, list):
  115. input = (input, )
  116. for key, item in zip(self.keys, input):
  117. data[key] = item
  118. return data
  119. def _ofa_input_compatibility_conversion(self, data): # fake
  120. if 'image' in data and self.cfg.model.get('type', None) == 'ofa':
  121. if isinstance(data['image'], str):
  122. image = load_image(data['image'])
  123. else:
  124. image = data['image']
  125. if image.mode != 'RGB':
  126. image = image.convert('RGB')
  127. img_buffer = BytesIO()
  128. image.save(img_buffer, format='JPEG')
  129. data['image'] = Image.open(img_buffer)
  130. return data
  131. def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args,
  132. **kwargs) -> Dict[str, Any]:
  133. if isinstance(input, dict):
  134. data = input
  135. else:
  136. data = self._build_dict(input)
  137. sample = self.preprocess(data)
  138. str_data = dict()
  139. for k, v in data.items():
  140. str_data[k] = str(v)
  141. sample['sample'] = str_data
  142. if self.no_collate:
  143. return sample
  144. else:
  145. return collate_fn([sample],
  146. pad_idx=self.tokenizer.pad_token_id,
  147. eos_idx=self.tokenizer.eos_token_id)
  148. def _convert_to_rgb(image):
  149. return image.convert('RGB')
  150. @PREPROCESSORS.register_module(
  151. Fields.multi_modal, module_name=Preprocessors.clip_preprocessor)
  152. class CLIPPreprocessor(Preprocessor):
  153. def __init__(self,
  154. model_dir: str,
  155. mode=ModeKeys.INFERENCE,
  156. *args,
  157. **kwargs):
  158. """preprocess the data
  159. Args:
  160. model_dir (str): model path
  161. mode: preprocessor mode (model mode)
  162. """
  163. super().__init__(*args, **kwargs)
  164. model_dir = model_dir if osp.exists(model_dir) else snapshot_download(
  165. model_dir, user_agent={Invoke.KEY: Invoke.PREPROCESSOR})
  166. self.mode = mode
  167. # text tokenizer
  168. from modelscope.models.multi_modal.clip.bert_tokenizer import FullTokenizer
  169. if 'tokenizer' in kwargs and isinstance(kwargs['tokenizer'],
  170. FullTokenizer):
  171. self.tokenizer = kwargs['tokenizer']
  172. else:
  173. vocab_file = f'{model_dir}/{ModelFile.VOCAB_FILE}'
  174. self.tokenizer = FullTokenizer(vocab_file=vocab_file)
  175. # image preprocessor
  176. if 'resolution' in kwargs and isinstance(kwargs['resolution'], int):
  177. self.image_resolution = kwargs['resolution']
  178. else:
  179. self.image_resolution = json.load(
  180. open(
  181. '{}/vision_model_config.json'.format(model_dir),
  182. encoding='utf-8'))['image_resolution']
  183. self.img_preprocess = self._build_image_transform()
  184. # key mapping
  185. # specify the input keys, compatible with training and inference whose key names may be different
  186. self.input_keys = {'img': 'img', 'text': 'text'}
  187. def _build_image_transform(self):
  188. if self.mode == ModeKeys.TRAIN:
  189. transform = create_transform(
  190. input_size=self.image_resolution,
  191. scale=(0.9, 1.0),
  192. is_training=True,
  193. color_jitter=None,
  194. auto_augment='original',
  195. interpolation='bicubic',
  196. mean=(0.48145466, 0.4578275, 0.40821073),
  197. std=(0.26862954, 0.26130258, 0.27577711),
  198. )
  199. transform = Compose(transform.transforms[:-3] + [_convert_to_rgb]
  200. + transform.transforms[-3:])
  201. else:
  202. transform = Compose([
  203. Resize((self.image_resolution, self.image_resolution),
  204. interpolation=Image.BICUBIC),
  205. _convert_to_rgb,
  206. ToTensor(),
  207. Normalize((0.48145466, 0.4578275, 0.40821073),
  208. (0.26862954, 0.26130258, 0.27577711)),
  209. ])
  210. return transform
  211. def tokenize(self,
  212. texts: Union[str, List[str]],
  213. context_length: int = 52) -> torch.LongTensor:
  214. """
  215. Returns the tokenized representation of given input string(s)
  216. Parameters
  217. ----------
  218. texts : Union[str, List[str]]
  219. An input string or a list of input strings to tokenize
  220. context_length : int
  221. The context length to use; all baseline models use 24 as the context length
  222. Returns
  223. -------
  224. A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
  225. """
  226. if isinstance(texts, str):
  227. texts = [texts]
  228. all_tokens = []
  229. for text in texts:
  230. all_tokens.append(
  231. [self.tokenizer.vocab['[CLS]']]
  232. + self.tokenizer.convert_tokens_to_ids(
  233. self.tokenizer.tokenize(text))[:context_length - 2]
  234. + [self.tokenizer.vocab['[SEP]']])
  235. result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
  236. for i, tokens in enumerate(all_tokens):
  237. assert len(tokens) <= context_length
  238. result[i, :len(tokens)] = torch.tensor(tokens)
  239. return result
  240. def set_input_img_key(self, new_key: str):
  241. self.input_keys['img'] = new_key
  242. def set_input_text_key(self, new_key: str):
  243. self.input_keys['text'] = new_key
  244. def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args,
  245. **kwargs) -> Dict[str, Any]:
  246. output = {}
  247. # preprocess the image input
  248. input_img_key = self.input_keys['img']
  249. if input_img_key in input and input[input_img_key] is not None:
  250. image_input = input[input_img_key]
  251. # single image input
  252. if isinstance(image_input, Image.Image):
  253. image_tensor = self.img_preprocess(image_input).unsqueeze(0)
  254. # multi images input
  255. elif isinstance(image_input, list):
  256. if all([isinstance(elem, Image.Image)
  257. for elem in image_input]):
  258. image_tensor = torch.stack(
  259. [self.img_preprocess(elem)
  260. for elem in image_input], # noqa
  261. dim=0) # noqa
  262. else:
  263. unsupported_elem_type = [
  264. type(elem) for elem in image_input
  265. if not isinstance(elem, Image.Image)
  266. ][0]
  267. raise TypeError(
  268. f'img should be PIL.Image or List[PIL.Image], \
  269. but got a List containing one {unsupported_elem_type}'
  270. )
  271. # others
  272. else:
  273. raise TypeError(
  274. f'img should be PIL.Image or List[PIL.Image], but got {type(image_input)}'
  275. )
  276. output['img'] = image_tensor
  277. # preprocess the text input
  278. input_text_key = self.input_keys['text']
  279. if input_text_key in input and input[input_text_key] is not None:
  280. text_input = input[input_text_key]
  281. # single text input
  282. if isinstance(text_input, str):
  283. text_tensor = self.tokenize(text_input)
  284. # multi texts input
  285. elif isinstance(text_input, list):
  286. if all([isinstance(elem, str) for elem in text_input]):
  287. text_tensor = self.tokenize(text_input)
  288. else:
  289. unsupported_elem_type = [
  290. type(elem) for elem in text_input
  291. if not isinstance(elem, str)
  292. ][0]
  293. raise TypeError(
  294. f'text should be str or List[str], but got a List containing one {unsupported_elem_type}'
  295. )
  296. # others
  297. else:
  298. raise TypeError(
  299. f'text should be str or List[str], but got {type(text_input)}'
  300. )
  301. output['text'] = text_tensor
  302. return output
  303. @PREPROCESSORS.register_module(
  304. Fields.multi_modal, module_name=Preprocessors.mplug_tasks_preprocessor)
  305. class MPlugPreprocessor(Preprocessor):
  306. def __init__(self,
  307. model_dir: str,
  308. mode: str = ModeKeys.INFERENCE,
  309. tokenizer_max_length: int = 25,
  310. *args,
  311. **kwargs):
  312. super().__init__(*args, **kwargs)
  313. self.model_dir = model_dir
  314. self.mode = mode
  315. self.tokenizer_max_length = tokenizer_max_length
  316. self._tokenizer = None
  317. self._patch_resize_transform = None
  318. self._image_map = {}
  319. @property
  320. def tokenizer(self):
  321. from transformers import BertTokenizer
  322. if self._tokenizer is None:
  323. self._tokenizer = BertTokenizer.from_pretrained(self.model_dir)
  324. return self._tokenizer
  325. @property
  326. def patch_resize_transform(self):
  327. if self._patch_resize_transform is None:
  328. from torchvision import transforms
  329. from modelscope.models.multi_modal.mplug import CONFIG_NAME, MPlugConfig
  330. config = MPlugConfig.from_yaml_file(
  331. osp.join(self.model_dir, CONFIG_NAME))
  332. mean = (0.48145466, 0.4578275, 0.40821073)
  333. std = (0.26862954, 0.26130258, 0.27577711)
  334. self._patch_resize_transform = transforms.Compose([
  335. transforms.Resize((config.image_res, config.image_res),
  336. interpolation=Image.BICUBIC),
  337. transforms.ToTensor(),
  338. transforms.Normalize(mean=mean, std=std),
  339. ])
  340. return self._patch_resize_transform
  341. def image_open(self, path: str) -> Tuple[Image.Image, int]:
  342. if path not in self._image_map:
  343. index = len(self._image_map)
  344. self._image_map[path] = (load_image(path), index)
  345. return self._image_map[path]
  346. def __call__(
  347. self, data: Union[Image.Image, tuple,
  348. Dict[str, Any]]) -> Dict[str, Any]:
  349. self.cfg = Config.from_file(
  350. osp.join(self.model_dir, ModelFile.CONFIGURATION))
  351. if isinstance(data, (Image.Image, str)):
  352. image = data
  353. elif isinstance(data, tuple):
  354. image = data[0]
  355. else:
  356. image = data['image']
  357. index = 0
  358. if isinstance(image, str):
  359. image, index = self.image_open(image)
  360. image = image.convert('RGB')
  361. image = self.patch_resize_transform(image)
  362. question = '' if self.cfg.task == Tasks.image_captioning \
  363. else data[1 if isinstance(data, tuple)
  364. else ('text' if 'text' in data else 'question')]
  365. question = self.tokenizer(
  366. question.lower(),
  367. padding='max_length',
  368. truncation=True,
  369. max_length=self.tokenizer_max_length,
  370. return_tensors='pt')
  371. if self.mode == ModeKeys.INFERENCE:
  372. image = torch.stack([image], dim=0)
  373. return {'image': image, 'question': question}
  374. else:
  375. answer = data['answer']
  376. answer = self.tokenizer(
  377. answer,
  378. padding='max_length',
  379. truncation=True,
  380. max_length=self.tokenizer_max_length,
  381. return_tensors='pt')
  382. output = {
  383. 'image': image,
  384. 'question_input_ids': question.input_ids.squeeze(),
  385. 'question_attention_mask': question.attention_mask.squeeze(),
  386. 'answer_input_ids': answer.input_ids.squeeze(),
  387. 'answer_attention_mask': answer.attention_mask.squeeze(),
  388. }
  389. if self.cfg.task == Tasks.image_text_retrieval:
  390. output['index'] = index
  391. return output
  392. @PREPROCESSORS.register_module(
  393. Fields.multi_modal, module_name=Preprocessors.vldoc_preprocessor)
  394. class VLDocPreprocessor(Preprocessor):
  395. def __init__(self,
  396. model_dir: str,
  397. mode: str = ModeKeys.INFERENCE,
  398. *args,
  399. **kwargs):
  400. """Preprocess data for the model `VLDocForDocVLEmbedding`.
  401. Args:
  402. model_dir (str): model path in model hub.
  403. mode (str): model mode, in ('train', 'eval', 'inference').
  404. """
  405. super().__init__(*args, **kwargs)
  406. self.model_dir = model_dir
  407. self.mode = mode
  408. model_cfg_path = osp.join(model_dir, 'config.json')
  409. with open(model_cfg_path, 'r', encoding='utf-8') as f:
  410. model_cfg = json.load(f)
  411. from modelscope.models.multi_modal.vldoc.tokenization import VLDocXLMTokenizer
  412. tokenizer_path = osp.join(model_dir, ModelFile.TOKENIZER_FOLDER)
  413. self.tokenizer = VLDocXLMTokenizer.from_pretrained(tokenizer_path)
  414. from modelscope.models.multi_modal.vldoc.processing import Processor, ImageProcessor
  415. self.img_proc = ImageProcessor(
  416. do_preprocess=True,
  417. do_resize=True,
  418. image_size={
  419. 'height': model_cfg['image_size'][0],
  420. 'width': model_cfg['image_size'][1],
  421. },
  422. do_normalize=True,
  423. apply_ocr=False)
  424. self.proc = Processor(
  425. max_seq_length=model_cfg['max_seq_length'],
  426. max_block_num=model_cfg['max_block_num'],
  427. img_processor=self.img_proc,
  428. tokenizer=self.tokenizer,
  429. width=model_cfg['image_size'][1],
  430. height=model_cfg['image_size'][0],
  431. )
  432. def __call__(self, input: Dict[str, Any], *args,
  433. **kwargs) -> Dict[str, Any]:
  434. """
  435. Args:
  436. input: {
  437. 'images': ['img_path1', 'img_path2', ...],
  438. 'ocr_info_paths': ['json_path1', 'json_path2', ...]
  439. }
  440. Return:
  441. encodings: Dict[str, Tensor]
  442. """
  443. ocr_infos = []
  444. for one_ocr_info_path in input['ocr_info_paths']:
  445. with open(one_ocr_info_path, 'r') as f:
  446. ocr_info = json.load(f)
  447. ocr_info = ocr_info['form']
  448. ocr_infos.append(ocr_info)
  449. proc_input = {'images': input['images'], 'ocr_infos': ocr_infos}
  450. encodings = self.proc(**proc_input)
  451. return encodings
  452. @PREPROCESSORS.register_module(
  453. Fields.multi_modal, module_name=Preprocessors.hitea_tasks_preprocessor)
  454. class HiTeAPreprocessor(Preprocessor):
  455. def __init__(self,
  456. model_dir: str,
  457. mode: str = ModeKeys.INFERENCE,
  458. tokenizer_max_length: int = 25,
  459. *args,
  460. **kwargs):
  461. super().__init__(*args, **kwargs)
  462. self.model_dir = model_dir
  463. self.mode = mode
  464. self.tokenizer_max_length = tokenizer_max_length
  465. self._tokenizer = None
  466. self._patch_resize_transform = None
  467. self._num_frames = None
  468. self._video_map = {}
  469. @property
  470. def tokenizer(self):
  471. from transformers import BertTokenizer
  472. if self._tokenizer is None:
  473. self._tokenizer = BertTokenizer.from_pretrained(self.model_dir)
  474. return self._tokenizer
  475. @property
  476. def patch_resize_transform(self):
  477. if self._patch_resize_transform is None:
  478. from torchvision import transforms
  479. from modelscope.models.multi_modal.mplug import CONFIG_NAME, HiTeAConfig
  480. config = HiTeAConfig.from_yaml_file(
  481. osp.join(self.model_dir, CONFIG_NAME))
  482. mean = (0.48145466, 0.4578275, 0.40821073)
  483. std = (0.26862954, 0.26130258, 0.27577711)
  484. self._patch_resize_transform = transforms.Compose([
  485. transforms.Resize((config.image_res, config.image_res),
  486. interpolation=Image.BICUBIC),
  487. transforms.ToTensor(),
  488. transforms.Normalize(mean=mean, std=std),
  489. ])
  490. return self._patch_resize_transform
  491. @property
  492. def num_frames(self):
  493. if self._num_frames is None:
  494. from torchvision import transforms
  495. from modelscope.models.multi_modal.mplug import CONFIG_NAME, HiTeAConfig
  496. config = HiTeAConfig.from_yaml_file(
  497. osp.join(self.model_dir, CONFIG_NAME))
  498. self._num_frames = config.num_frames
  499. return self._num_frames
  500. def video_open(self, path: str) -> Tuple[decord.VideoReader, int]:
  501. if path not in self._video_map:
  502. index = len(self._video_map)
  503. vr = decord.VideoReader(path, ctx=decord.cpu(0))
  504. self._video_map[path] = (vr, index)
  505. return self._video_map[path]
  506. def sample_frames(self, num_frames: int, vlen: int) -> List[int]:
  507. acc_samples = min(num_frames, vlen)
  508. # split the video into `acc_samples` intervals, and sample from each interval.
  509. intervals = np.linspace(
  510. start=0, stop=vlen, num=acc_samples + 1).astype(int)
  511. ranges = []
  512. for idx, interv in enumerate(intervals[:-1]):
  513. ranges.append((interv, intervals[idx + 1] - 1))
  514. frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
  515. if len(frame_indices) < num_frames: # padded with last frame
  516. padded_frame_indices = [frame_indices[-1]] * num_frames
  517. padded_frame_indices[:len(frame_indices)] = frame_indices
  518. frame_indices = padded_frame_indices
  519. return frame_indices
  520. def __call__(
  521. self, data: Union[decord.VideoReader, tuple,
  522. Dict[str, Any]]) -> Dict[str, Any]:
  523. self.cfg = Config.from_file(
  524. osp.join(self.model_dir, ModelFile.CONFIGURATION))
  525. if isinstance(data, (decord.VideoReader, str)):
  526. video = data
  527. elif isinstance(data, tuple):
  528. video = data[0]
  529. else:
  530. video = data['video']
  531. index = 0
  532. if isinstance(video, str):
  533. video, index = self.video_open(video)
  534. frame_indices = self.sample_frames(self.num_frames, len(video))
  535. video.seek(0)
  536. video = torch.from_numpy(video.get_batch(frame_indices).asnumpy())
  537. video = [
  538. self.patch_resize_transform(Image.fromarray(f))
  539. for f in video.numpy()
  540. ]
  541. video = torch.stack(video, dim=0)
  542. question = '' if self.cfg.task == Tasks.video_captioning \
  543. else data[1 if isinstance(data, tuple)
  544. else ('text' if 'text' in data else 'question')]
  545. question = self.tokenizer(
  546. question.lower(),
  547. padding='max_length',
  548. truncation=True,
  549. max_length=self.tokenizer_max_length,
  550. return_tensors='pt')
  551. if self.mode == ModeKeys.INFERENCE:
  552. video = torch.stack([video], dim=0)
  553. return {'video': video, 'question': question}
  554. else:
  555. answer = data['answer']
  556. answer = self.tokenizer(
  557. answer,
  558. padding='max_length',
  559. truncation=True,
  560. max_length=self.tokenizer_max_length,
  561. return_tensors='pt')
  562. output = {
  563. 'video': video,
  564. 'question_input_ids': question.input_ids.squeeze(),
  565. 'question_attention_mask': question.attention_mask.squeeze(),
  566. 'answer_input_ids': answer.input_ids.squeeze(),
  567. 'answer_attention_mask': answer.attention_mask.squeeze(),
  568. }
  569. return output
  570. @PREPROCESSORS.register_module(
  571. Fields.multi_modal, module_name=Preprocessors.mplug_owl_preprocessor)
  572. class MplugOwlPreprocessor(Preprocessor):
  573. def __init__(self,
  574. model_dir: str,
  575. mode: str = ModeKeys.INFERENCE,
  576. *args,
  577. **kwargs):
  578. super().__init__(*args, **kwargs)
  579. self.model_dir = model_dir
  580. self.mode = mode
  581. self._tokenizer = None
  582. self._patch_resize_transform = None
  583. self.media_token = {'<|image|>': 65}
  584. self._image_map = {}
  585. @property
  586. def tokenizer(self):
  587. from modelscope.models.nlp.llama import LlamaTokenizer
  588. if self._tokenizer is None:
  589. self._tokenizer = LlamaTokenizer.from_pretrained(self.model_dir)
  590. return self._tokenizer
  591. @property
  592. def patch_resize_transform(self):
  593. if self._patch_resize_transform is None:
  594. from torchvision import transforms
  595. mean = (0.48145466, 0.4578275, 0.40821073)
  596. std = (0.26862954, 0.26130258, 0.27577711)
  597. self._patch_resize_transform = transforms.Compose([
  598. transforms.Resize((224, 224), interpolation=Image.BICUBIC),
  599. transforms.ToTensor(),
  600. transforms.Normalize(mean=mean, std=std),
  601. ])
  602. return self._patch_resize_transform
  603. def image_open(self, path: str) -> Tuple[Image.Image, int]:
  604. if path not in self._image_map:
  605. index = len(self._image_map)
  606. self._image_map[path] = (load_image(path), index)
  607. return self._image_map[path]
  608. def tokenize_text(self, text: str) -> List[int]:
  609. media_tokens = {
  610. k: -int(i + 1)
  611. for i, k in enumerate(self.media_token.keys())
  612. }
  613. media_lengths = self.media_token.copy()
  614. prompt_chunk = [self.tokenizer.bos_token_id]
  615. # Pure Text
  616. condition = [
  617. media_token not in text for media_token in media_tokens.keys()
  618. ]
  619. if all(condition):
  620. enc_chunk = prompt_chunk + \
  621. self.tokenizer(text, add_special_tokens=False)['input_ids']
  622. # Multi-Modal Text
  623. else:
  624. enc_chunk = prompt_chunk
  625. pattern = '|'.join(map(re.escape, list(media_tokens.keys())))
  626. chunk_strs = re.split(f'({pattern})', text)
  627. chunk_strs = [x for x in chunk_strs if len(x) > 0]
  628. for idx, chunk_str in enumerate(chunk_strs):
  629. if chunk_str in media_tokens:
  630. enc_chunk += [media_tokens[chunk_str]] * \
  631. media_lengths[chunk_str]
  632. else:
  633. tmp_chunk = self.tokenizer(
  634. chunk_str, add_special_tokens=False)['input_ids']
  635. enc_chunk += tmp_chunk
  636. return enc_chunk
  637. def convert(self, messages: Dict[str, List[Dict]]) -> str:
  638. texts = []
  639. image = []
  640. messages = messages['messages']
  641. for turn in messages:
  642. if turn['role'] == 'system':
  643. role = ''
  644. elif turn['role'] == 'user':
  645. role = 'Human: '
  646. else:
  647. role = 'AI: '
  648. if isinstance(turn['content'], str):
  649. text = f"{role}{turn['content']}"
  650. texts.append(text)
  651. else:
  652. for t in turn['content']:
  653. if isinstance(t, str):
  654. text = f'{role}{t}'
  655. else:
  656. text = f'{role}<|image|>'
  657. image.append(t['image'])
  658. texts.append(text)
  659. texts = '\n'.join(texts)
  660. texts += '\nAI: '
  661. return image, texts
  662. def __call__(self, messages: Dict[str, Any],
  663. **forward_params) -> Dict[str, Any]:
  664. """
  665. Args:
  666. messages: {[
  667. {'role': 'system', 'content': 'message1'},
  668. {'role': 'user', 'content': 'message2'},
  669. {'role': 'user', 'content': ['message2', {"image": 'image_path'}, 'message3', ...]},
  670. ]}
  671. The 'role' should be choose from ['system', 'user', 'assistant'].
  672. The 'content' can be either str or List[Union[str, Dict]]
  673. Return:
  674. output: Dict[str, Tensor]
  675. """
  676. output = {}
  677. images, text = self.convert(messages)
  678. if len(images) > 0:
  679. pixel_values = []
  680. for image in images:
  681. pixel_values.append(
  682. self.patch_resize_transform(self.image_open(image)[0]))
  683. pixel_values = torch.stack(pixel_values, dim=0)
  684. else:
  685. pixel_values = None
  686. input_ids = self.tokenize_text(text)
  687. input_ids = torch.LongTensor([input_ids])
  688. output = {
  689. 'pixel_values': pixel_values,
  690. 'input_ids': input_ids,
  691. **forward_params
  692. }
  693. return output
  694. @PREPROCESSORS.register_module(
  695. Fields.multi_modal,
  696. module_name=Preprocessors.image_captioning_clip_interrogator_preprocessor)
  697. class ImageCaptioningClipInterrogatorPreprocessor(Preprocessor):
  698. def __init__(self, **kwargs):
  699. super().__init__(**kwargs)
  700. def __call__(self, data) -> Dict[str, Any]:
  701. image = load_image(data)
  702. data = np.array(image).transpose(2, 0, 1)
  703. return data