image_classification.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import functools
  3. from typing import Any, Dict
  4. import torch
  5. from PIL import Image, ImageFile
  6. from timm.data import create_transform
  7. from torchvision import transforms
  8. from modelscope.preprocessors.image import load_image
  9. from modelscope.utils.constant import ModeKeys
  10. from .base import OfaBasePreprocessor
  11. from .utils.vision_helper import RandomAugment
  12. ImageFile.LOAD_TRUNCATED_IMAGES = True
  13. ImageFile.MAX_IMAGE_PIXELS = None
  14. Image.MAX_IMAGE_PIXELS = None
  15. class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
  16. r"""
  17. OFA preprocessor for image classification task.
  18. """
  19. def __init__(self,
  20. cfg,
  21. model_dir,
  22. mode=ModeKeys.INFERENCE,
  23. *args,
  24. **kwargs):
  25. """preprocess the data
  26. Args:
  27. cfg(modelscope.utils.config.ConfigDict) : model config
  28. model_dir (str): model path,
  29. mode: preprocessor mode (model mode)
  30. """
  31. super(OfaImageClassificationPreprocessor,
  32. self).__init__(cfg, model_dir, mode, *args, **kwargs)
  33. # Initialize transform
  34. if self.mode != ModeKeys.TRAIN:
  35. self.patch_resize_transform = transforms.Compose([
  36. lambda image: image.convert('RGB'),
  37. transforms.Resize(
  38. (self.patch_image_size, self.patch_image_size),
  39. interpolation=transforms.InterpolationMode.BICUBIC),
  40. transforms.ToTensor(),
  41. transforms.Normalize(mean=self.mean, std=self.std),
  42. ])
  43. else:
  44. self.patch_resize_transform = create_transform(
  45. input_size=self.patch_image_size,
  46. is_training=True,
  47. color_jitter=0.4,
  48. auto_augment='rand-m9-mstd0.5-inc1',
  49. interpolation='bicubic',
  50. re_prob=0.25,
  51. re_mode='pixel',
  52. re_count=1,
  53. mean=self.mean,
  54. std=self.std)
  55. self.patch_resize_transform = transforms.Compose(
  56. functools.reduce(lambda x, y: x + y, [
  57. [
  58. lambda image: image.convert('RGB'),
  59. ],
  60. self.patch_resize_transform.transforms[:2],
  61. [self.patch_resize_transform.transforms[2]],
  62. [
  63. RandomAugment(
  64. 2,
  65. 7,
  66. isPIL=True,
  67. augs=[
  68. 'Identity', 'AutoContrast', 'Equalize',
  69. 'Brightness', 'Sharpness', 'ShearX', 'ShearY',
  70. 'TranslateX', 'TranslateY', 'Rotate'
  71. ]),
  72. ],
  73. self.patch_resize_transform.transforms[3:],
  74. ]))
  75. def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
  76. if self.mode == ModeKeys.TRAIN:
  77. return self._build_train_sample(data)
  78. else:
  79. return self._build_infer_sample(data)
  80. def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
  81. r"""
  82. Building training samples.
  83. step 1. Preprocess the data using the logic of `_build_infer_sample`
  84. and make sure the label data in the result.
  85. step 2. Preprocess the label data. Contains:
  86. - add ` ` before the label value and add `ref_dict` value
  87. - tokenize the label as `target` value without `bos` token.
  88. - add `bos` token and remove `eos` token of `target` as `prev_output_tokens`.
  89. - add constraints mask.
  90. Args:
  91. data (`Dict[str, Any]`): Input data, should contains the key of `image`,
  92. `prompt` and `label`, `image` refers the image input data, `prompt`
  93. refers the text input data the `label` is the supervised data for training.
  94. Return:
  95. A dict object, contains source, image, mask, label, target tokens,
  96. and previous output tokens data.
  97. """
  98. sample = self._build_infer_sample(data)
  99. target = ' {}'.format(sample['label'])
  100. sample['ref_dict'] = {sample['label']: 1.0}
  101. sample['target'] = self.tokenize_text(target, add_bos=False)
  102. sample['prev_output_tokens'] = torch.cat(
  103. [self.bos_item, sample['target'][:-1]])
  104. if self.constraint_trie is not None:
  105. constraint_mask = torch.zeros((len(sample['prev_output_tokens']),
  106. len(self.tgt_dict))).bool()
  107. for i in range(len(sample['prev_output_tokens'])):
  108. constraint_prefix_token = sample[
  109. 'prev_output_tokens'][:i + 1].tolist()
  110. constraint_nodes = self.constraint_trie.get_next_layer(
  111. constraint_prefix_token)
  112. constraint_mask[i][constraint_nodes] = True
  113. sample['constraint_mask'] = constraint_mask
  114. return sample
  115. def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
  116. r"""
  117. Building inference samples.
  118. step 1. Get the pillow image.
  119. step 2. Do some transforms for the pillow image as the image input,
  120. such as resize, normalize, to tensor etc.
  121. step 3. Tokenize the prompt as text input.
  122. step 4. Determine Whether or not to add labels to the sample.
  123. Args:
  124. data (`Dict[str, Any]`): Input data, should contains the key of `image` and `prompt`,
  125. the former refers the image input data, and the later refers the text input data.
  126. Return:
  127. A dict object, contains source, image, mask and label data.
  128. """
  129. image = self.get_img_pil(data[self.column_map['image']])
  130. patch_image = self.patch_resize_transform(image)
  131. prompt = self.cfg.model.get('prompt', ' what does the image describe?')
  132. inputs = self.tokenize_text(prompt)
  133. sample = {
  134. 'source': inputs,
  135. 'patch_image': patch_image,
  136. 'patch_mask': torch.tensor([True]),
  137. 'decoder_prompt': self.bos_item,
  138. }
  139. if 'text' in self.column_map and self.column_map['text'] in data:
  140. sample['label'] = data[self.column_map['text']]
  141. return sample