visual_entailment.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict
  3. import torch
  4. from PIL import Image
  5. from torchvision import transforms
  6. from modelscope.preprocessors.image import load_image
  7. from modelscope.utils.constant import ModeKeys
  8. from .base import OfaBasePreprocessor
  9. class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
  10. r"""
  11. OFA preprocessor for visual entailment tasks.
  12. """
  13. def __init__(self,
  14. cfg,
  15. model_dir,
  16. mode=ModeKeys.INFERENCE,
  17. *args,
  18. **kwargs):
  19. """preprocess the data
  20. Args:
  21. cfg(modelscope.utils.config.ConfigDict) : model config
  22. model_dir (str): model path,
  23. mode: preprocessor mode (model mode)
  24. """
  25. super(OfaVisualEntailmentPreprocessor,
  26. self).__init__(cfg, model_dir, mode, *args, **kwargs)
  27. # Initialize transform
  28. self.patch_resize_transform = transforms.Compose([
  29. lambda image: image.convert('RGB'),
  30. transforms.Resize(
  31. (self.patch_image_size, self.patch_image_size),
  32. interpolation=transforms.InterpolationMode.BICUBIC),
  33. transforms.ToTensor(),
  34. transforms.Normalize(mean=self.mean, std=self.std),
  35. ])
  36. def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
  37. if self.mode == ModeKeys.TRAIN:
  38. return self._build_train_sample(data)
  39. else:
  40. return self._build_infer_sample(data)
  41. def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
  42. r"""
  43. Building training samples.
  44. step 1. Preprocess the data using the logic of `_build_infer_sample`
  45. and make sure the label data in the result.
  46. step 2. Preprocess the label data to generate the `target` and
  47. `prev_output_tokens`.
  48. - tokenize the label data.
  49. - calculate the target item.
  50. 1) if `promp_type` is `None`, using tokenized label data.
  51. 2) if `promp_type` is `src`, concatenating the `source` data
  52. and tokenized label data.
  53. 3) if `promp_type` is `prev_output`, concatenating the `source`
  54. data without eos token and tokenized label data
  55. step 3. Add constraint mask
  56. Args:
  57. data (`Dict[str, Any]`): Input data, should contains the key of `text`
  58. `text2` and `label` are optional.
  59. Return:
  60. A dict object, contains source text input, patch images, patch masks
  61. with `Tensor([True])` value, decoder prompt, label, target, previous
  62. output tokens and constraint mask.
  63. """
  64. sample = self._build_infer_sample(data)
  65. target = ' {}'.format(sample['label'])
  66. sample['ref_dict'] = {sample['label']: 1.0}
  67. tgt_item = self.tokenize_text(target, add_bos=False, add_eos=False)
  68. if self.prompt_type == 'none':
  69. prev_output_item = torch.cat([self.bos_item, tgt_item])
  70. target_item = torch.cat([prev_output_item[1:], self.eos_item])
  71. elif self.prompt_type == 'src':
  72. prev_output_item = torch.cat([sample['source'], tgt_item])
  73. target_item = torch.cat([prev_output_item[1:], self.eos_item])
  74. elif self.prompt_type == 'prev_output':
  75. prev_output_item = torch.cat([sample['source'][:-1], tgt_item])
  76. target_item = torch.cat([prev_output_item[1:], self.eos_item])
  77. else:
  78. raise NotImplementedError
  79. target_item[:-len(tgt_item) - 1] = self.tokenizer.pad_token_id
  80. sample['target'] = target_item
  81. sample['prev_output_tokens'] = prev_output_item
  82. if self.constraint_trie is not None:
  83. constraint_mask = torch.zeros(
  84. (len(target_item), len(self.tgt_dict))).bool()
  85. start_idx = len(target_item) - len(tgt_item) - 1
  86. for i in range(
  87. len(target_item) - len(tgt_item) - 1, len(target_item)):
  88. constraint_prefix_token = [
  89. self.tgt_dict.bos()
  90. ] + target_item[start_idx:i].tolist()
  91. constraint_nodes = self.constraint_trie.get_next_layer(
  92. constraint_prefix_token)
  93. constraint_mask[i][constraint_nodes] = True
  94. sample['constraint_mask'] = constraint_mask
  95. return sample
  96. def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
  97. r"""
  98. Building inference samples.
  99. step 1. Preprocessing the image as model's image input.
  100. - get the pillow image input from `data`
  101. - do some transforms to the pillow image, such as resize, normalize etc.
  102. step 2. Building the instruction as model's source text input.
  103. - use text input to build instruction. so far, we support two kind of
  104. input form, we will take different examples to both of them to explain
  105. how to use them.
  106. 1) only `text` input in data. this setting can solve the tasks which
  107. judge whether or not the input `text` describe the input image.
  108. 2) both `text` and `text2` input in data. this setting can solve the
  109. tasks which judge whether or not the `text` together with input image
  110. can imply the `text2`
  111. - tokenize the instruction above.
  112. step 3. Calculate the decoder prompt input.
  113. step 4. Whether or not to add label data.
  114. Args:
  115. data (`Dict[str, Any]`): Input data, should contains the key of `text`
  116. `text2` and `label` are optional.
  117. Return:
  118. A dict object, contains source text input, patch images, patch masks
  119. with `Tensor([True])` value, decoder prompt and label.
  120. """
  121. image = self.get_img_pil(data[self.column_map['image']])
  122. patch_image = self.patch_resize_transform(image)
  123. if 'text2' not in data:
  124. hypothesis = self.pre_caption(data[self.column_map['text']],
  125. self.max_src_length)
  126. prompt = self.cfg.model.get('prompt',
  127. ' does the image describe " {} "?')
  128. text = prompt.format(hypothesis)
  129. else:
  130. assert 'text' in data, f'text must be in the input {data.keys()}'
  131. caption = self.pre_caption(data[self.column_map['text2']],
  132. self.max_src_length)
  133. hypothesis = self.pre_caption(data[self.column_map['text']],
  134. self.max_src_length)
  135. prompt = self.cfg.model.get(
  136. 'prompt', ' can image and text1 " {} " imply text2 " {} "?')
  137. text = prompt.format(caption, hypothesis)
  138. inputs = self.tokenize_text(text)
  139. if self.prompt_type == 'none':
  140. prefix_token = []
  141. decoder_prompt = self.bos_item
  142. elif self.prompt_type == 'prev_output':
  143. prefix_token = inputs[:-1] # remove eos
  144. decoder_prompt = inputs[:-1]
  145. else:
  146. raise NotImplementedError
  147. sample = {
  148. 'source': inputs,
  149. 'patch_image': patch_image,
  150. 'patch_mask': torch.tensor([True]),
  151. 'prefix_token': prefix_token,
  152. 'decoder_prompt': decoder_prompt,
  153. }
  154. if 'relation' in self.column_map and self.column_map[
  155. 'relation'] in data:
  156. sample['label'] = data[self.column_map['relation']]
  157. return sample