visual_grounding.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict
  3. import numpy as np
  4. import torch
  5. from PIL import Image
  6. from torchvision import transforms
  7. from modelscope.preprocessors.image import load_image
  8. from modelscope.utils.constant import ModeKeys
  9. from .base import OfaBasePreprocessor
  10. from .utils import transforms as T
  11. class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
  12. r"""
  13. OFA preprocessor for visual grounding tasks.
  14. """
  15. def __init__(self,
  16. cfg,
  17. model_dir,
  18. mode=ModeKeys.INFERENCE,
  19. *args,
  20. **kwargs):
  21. """preprocess the data
  22. Args:
  23. cfg(modelscope.utils.config.ConfigDict) : model config
  24. model_dir (str): model path,
  25. mode: preprocessor mode (model mode)
  26. """
  27. super(OfaVisualGroundingPreprocessor,
  28. self).__init__(cfg, model_dir, mode, *args, **kwargs)
  29. self.num_bins = self.cfg.model.get('num_bins', 1000)
  30. if self.mode == ModeKeys.TRAIN:
  31. # for positioning
  32. self.positioning_transform = T.Compose([
  33. T.RandomResize([self.patch_image_size],
  34. max_size=self.patch_image_size),
  35. T.ToTensor(),
  36. T.Normalize(
  37. mean=self.mean,
  38. std=self.std,
  39. max_image_size=self.max_image_size)
  40. ])
  41. else:
  42. # Initialize transform
  43. self.patch_resize_transform = transforms.Compose([
  44. lambda image: image.convert('RGB'),
  45. transforms.Resize(
  46. (self.patch_image_size, self.patch_image_size),
  47. interpolation=transforms.InterpolationMode.BICUBIC),
  48. transforms.ToTensor(),
  49. transforms.Normalize(mean=self.mean, std=self.std),
  50. ])
  51. def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
  52. if self.mode == ModeKeys.TRAIN:
  53. return self._build_train_sample(data)
  54. else:
  55. return self._build_infer_sample(data)
  56. def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
  57. r"""
  58. Building training samples.
  59. step 1. Preprocessing the image input for model's image input.
  60. - get the pillow image.
  61. - calculate the target boxes using for getting the exact area
  62. in the pillow image for input text by input `region_coord`. in
  63. training setting, `region_coord` will be a label data.
  64. - getting the target image as patch images and do some transforms
  65. such as resize, normalize etc.
  66. step 2. Preprocessing the text input for model's source text input.
  67. - do the str preprocessing to text input by function `pre_caption`.
  68. - build the instruction. the default instruction is
  69. ` which region does the text " {} " describe?`, `{}` refer to the
  70. text input.
  71. - tokenize the instruction as source text input.
  72. step 3. Preprocessing the patch image boxes for model's target text input.
  73. - quantize the coordinate of selected patch images
  74. - concatenate the quantization results by blank
  75. - tokenize the result above as target text input.
  76. step 4. Get the previous output tokens using target item without eos token.
  77. Args:
  78. data (`Dict[str, Any]`): Input data, should contains the key of `image`
  79. `text` and `region_coord`.
  80. Return:
  81. A dict object, contains source text input, patch images, patch masks
  82. with `Tensor([True])` value, target, previous output tokens,
  83. width scale ratio, height scale ratio and region coordinate.
  84. """
  85. image = self.get_img_pil(data[self.column_map['image']])
  86. w, h = image.size
  87. boxes_target = {
  88. 'boxes': [],
  89. 'labels': [],
  90. 'area': [],
  91. 'size': torch.tensor([h, w])
  92. }
  93. x0, y0, x1, y1 = data[self.column_map['region_coord']].strip().split(
  94. ',')
  95. region = torch.tensor([float(x0), float(y0), float(x1), float(y1)])
  96. boxes_target['boxes'] = torch.tensor(
  97. [[float(x0), float(y0), float(x1),
  98. float(y1)]])
  99. boxes_target['labels'] = np.array([0])
  100. area = [(float(x1) - float(x0)) * (float(y1) - float(y0))]
  101. boxes_target['area'] = torch.tensor(area)
  102. patch_image, patch_boxes = self.positioning_transform(
  103. image, boxes_target)
  104. resize_h, resize_w = patch_boxes['size'][0], patch_boxes['size'][1]
  105. quant_x0 = '<bin_{}>'.format(
  106. int((patch_boxes['boxes'][0][0] * (self.num_bins - 1)).round()))
  107. quant_y0 = '<bin_{}>'.format(
  108. int((patch_boxes['boxes'][0][1] * (self.num_bins - 1)).round()))
  109. quant_x1 = '<bin_{}>'.format(
  110. int((patch_boxes['boxes'][0][2] * (self.num_bins - 1)).round()))
  111. quant_y1 = '<bin_{}>'.format(
  112. int((patch_boxes['boxes'][0][3] * (self.num_bins - 1)).round()))
  113. region_coord = '{} {} {} {}'.format(quant_x0, quant_y0, quant_x1,
  114. quant_y1)
  115. src_caption = self.pre_caption(data[self.column_map['text']],
  116. self.max_src_length)
  117. prompt = self.cfg.model.get(
  118. 'prompt', ' which region does the text " {} " describe?')
  119. text = prompt.format(src_caption)
  120. src_item = self.tokenize_text(text)
  121. target_item = self.tokenize_text(
  122. region_coord, add_bos=False) # !!! use_bpe=False
  123. prev_output_item = torch.cat([self.bos_item, target_item[:-1]])
  124. sample = {
  125. 'source': src_item,
  126. 'patch_image': patch_image,
  127. 'patch_mask': torch.tensor([True]),
  128. 'target': target_item,
  129. 'prev_output_tokens': prev_output_item,
  130. 'w_resize_ratio': resize_w / w,
  131. 'h_resize_ratio': resize_h / h,
  132. 'region_coord': region
  133. }
  134. return sample
  135. def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
  136. r"""
  137. Building inference samples.
  138. step 1. Preprocessing image input for model's image input.
  139. - get pillow image from data.
  140. - do some transforms to the pillow image, such as resize, normalize etc.
  141. step 2. Preprocessing the text input for model's text input.
  142. - do the str preprocessing to text input by function `pre_caption`.
  143. - build the instruction. the default instruction is
  144. ` which region does the text " {} " describe?`, `{}` refer to the
  145. text input.
  146. - tokenize the instruction as source text input.
  147. step 3. Whether or not to add label data which refer to a region coordinate
  148. in this task.
  149. Args:
  150. data (`Dict[str, Any]`): Input data, should contains the key of `image`
  151. `text`.
  152. Return:
  153. A dict object, contains source text input, patch images, patch masks
  154. with `Tensor([True])` value, width scale ratio, height scale ratio
  155. and label.
  156. """
  157. image = self.get_img_pil(data[self.column_map['image']])
  158. w, h = image.size
  159. patch_image = self.patch_resize_transform(image)
  160. w_resize_ratio = torch.tensor(self.patch_image_size / w)
  161. h_resize_ratio = torch.tensor(self.patch_image_size / h)
  162. src_caption = self.pre_caption(data[self.column_map['text']],
  163. self.max_src_length)
  164. prompt = self.cfg.model.get(
  165. 'prompt', ' which region does the text " {} " describe?')
  166. text = prompt.format(src_caption)
  167. src_item = self.tokenize_text(text)
  168. sample = {
  169. 'source': src_item,
  170. 'patch_image': patch_image,
  171. 'patch_mask': torch.tensor([True]),
  172. 'w_resize_ratio': w_resize_ratio,
  173. 'h_resize_ratio': h_resize_ratio,
  174. }
  175. if 'region_coord' in self.column_map and self.column_map[
  176. 'region_coord'] in data:
  177. x0, y0, x1, y1 = data[
  178. self.column_map['region_coord']].strip().split(',')
  179. sample['label'] = [float(x0), float(y0), float(x1), float(y1)]
  180. return sample