gridvlp_pipeline.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os.path as osp
  3. import time
  4. import traceback
  5. from typing import Any, Dict, Optional
  6. import json
  7. import numpy as np
  8. import torch
  9. from PIL import Image
  10. from transformers import BertTokenizer
  11. from modelscope.hub.snapshot_download import snapshot_download
  12. from modelscope.metainfo import Pipelines
  13. from modelscope.pipelines import Pipeline
  14. from modelscope.pipelines.builder import PIPELINES
  15. from modelscope.preprocessors.image import load_image
  16. from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Frameworks,
  17. Invoke, Tasks)
  18. from modelscope.utils.logger import get_logger
  19. logger = get_logger()
  20. def cost(end, begin):
  21. return '{:.2f}ms'.format((end - begin) * 1000)
  22. class Config:
  23. SCALE = 1 / 255.0
  24. MEAN = np.require([0.485, 0.456, 0.406], dtype=np.float32)[:, np.newaxis,
  25. np.newaxis]
  26. STD = np.require([0.229, 0.224, 0.225], dtype=np.float32)[:, np.newaxis,
  27. np.newaxis]
  28. # RESIZE_HEIGHT = int(224*1.14)
  29. RESIZE_HEIGHT = int(256)
  30. # RESIZE_WIDTH = int(224*1.14)
  31. RESIZE_WIDTH = int(256)
  32. CROP_SIZE = 224
  33. def pre_processor(img):
  34. img = img.convert('RGB')
  35. w, h = img.size
  36. if (w <= h and w == Config.RESIZE_WIDTH) \
  37. or (h <= w and h == Config.RESIZE_WIDTH):
  38. img = img
  39. if w < h:
  40. ow = Config.RESIZE_WIDTH
  41. oh = int(Config.RESIZE_WIDTH * h / w)
  42. img = img.resize((ow, oh), Image.BILINEAR)
  43. else:
  44. oh = Config.RESIZE_WIDTH
  45. ow = int(Config.RESIZE_WIDTH * w / h)
  46. img = img.resize((ow, oh), Image.BILINEAR)
  47. w, h = img.size
  48. crop_top = int(round((h - Config.CROP_SIZE) / 2.))
  49. crop_left = int(round((w - Config.CROP_SIZE) / 2.))
  50. img = img.crop((crop_left, crop_top, crop_left + Config.CROP_SIZE,
  51. crop_top + Config.CROP_SIZE))
  52. _img = np.array(img, dtype=np.float32)
  53. _img = np.require(_img.transpose((2, 0, 1)), dtype=np.float32)
  54. _img *= Config.SCALE
  55. _img -= Config.MEAN
  56. _img /= Config.STD
  57. return _img
  58. class GridVlpPipeline(Pipeline):
  59. """ Pipeline for gridvlp, including classification and embedding."""
  60. def __init__(self, model_name_or_path: str, **kwargs):
  61. """ Pipeline for gridvlp, including classification and embedding.
  62. Args:
  63. model: path to local model directory.
  64. """
  65. # download model from modelscope to local model dir
  66. logger.info(f'load checkpoint from modelscope {model_name_or_path}')
  67. if osp.exists(model_name_or_path):
  68. local_model_dir = model_name_or_path
  69. else:
  70. invoked_by = '%s/%s' % (Invoke.KEY, Invoke.PIPELINE)
  71. local_model_dir = snapshot_download(
  72. model_name_or_path,
  73. DEFAULT_MODEL_REVISION,
  74. user_agent=invoked_by)
  75. self.local_model_dir = local_model_dir
  76. # load model from cpu and torch jit model
  77. logger.info(f'load model from {local_model_dir}')
  78. self.model = torch.jit.load(
  79. osp.join(local_model_dir, 'pytorch_model.pt'))
  80. self.framework = Frameworks.torch
  81. self.device_name = 'cpu'
  82. self._model_prepare = True
  83. self._auto_collate = False
  84. # load tokenizer
  85. logger.info(f'load tokenizer from {local_model_dir}')
  86. self.tokenizer = BertTokenizer.from_pretrained(local_model_dir)
  87. def preprocess(self, inputs: Dict[str, Any], max_seq_length=49):
  88. # fetch input params
  89. image = inputs.get('image', '')
  90. text = inputs.get('text', '')
  91. s1 = time.time()
  92. # download image and preprocess
  93. try:
  94. # load PIL image
  95. img = load_image(image)
  96. s2 = time.time()
  97. # image preprocess
  98. image_data = pre_processor(img)
  99. s3 = time.time()
  100. except Exception:
  101. image_data = np.zeros((3, 224, 224), dtype=np.float32)
  102. s2 = time.time()
  103. s3 = time.time()
  104. logger.info(traceback.print_exc())
  105. # text process
  106. if text is None or text.isspace() or not text.strip():
  107. logger.info('text is empty!')
  108. text = ''
  109. inputs = self.tokenizer(
  110. text,
  111. padding='max_length',
  112. truncation=True,
  113. max_length=max_seq_length)
  114. s4 = time.time()
  115. logger.info(f'example. text: {text} image: {image}')
  116. logger.info(
  117. f'preprocess. Img_Download:{cost(s2, s1)}, Img_Pre:{cost(s3, s2)}, Txt_Pre:{cost(s4, s3)}'
  118. )
  119. input_dict = {
  120. 'image': image_data,
  121. 'input_ids': inputs['input_ids'],
  122. 'input_mask': inputs['attention_mask'],
  123. 'segment_ids': inputs['token_type_ids']
  124. }
  125. return input_dict
  126. @PIPELINES.register_module(
  127. Tasks.visual_question_answering,
  128. module_name=Pipelines.gridvlp_multi_modal_classification)
  129. class GridVlpClassificationPipeline(GridVlpPipeline):
  130. """ Pipeline for gridvlp classification, including cate classification and
  131. brand classification.
  132. Example:
  133. ```python
  134. >>> from modelscope.pipelines.multi_modal.gridvlp_pipeline import \
  135. GridVlpClassificationPipeline
  136. >>> pipeline = GridVlpClassificationPipeline('rgtjf1/multi-modal_gridvlp_classification_chinese-base-ecom-cate')
  137. >>> output = pipeline({'text': '女装快干弹力轻型短裤448575',\
  138. 'image':'https://yejiabo-public.oss-cn-zhangjiakou.aliyuncs.com/alinlp/clothes.png'})
  139. >>> output['text'][0]
  140. {'label': {'cate_name': '休闲裤', 'cate_path': '女装>>裤子>>休闲裤>>休闲裤'}, 'score': 0.4146, 'rank': 0}
  141. ```
  142. """
  143. def __init__(self, model_name_or_path: str, **kwargs):
  144. """ Pipeline for gridvlp classification, including cate classification and
  145. brand classification.
  146. Args:
  147. model: path to local model directory.
  148. """
  149. super().__init__(model_name_or_path, **kwargs)
  150. # load label mapping
  151. logger.info(f'load label mapping from {self.local_model_dir}')
  152. self.label_mapping = json.load(
  153. open(osp.join(self.local_model_dir, 'label_mapping.json')))
  154. def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  155. s4 = time.time()
  156. box_tensor = torch.zeros(1, dtype=torch.float32)
  157. output = self.model(
  158. torch.tensor(inputs['image']).unsqueeze(0),
  159. box_tensor.unsqueeze(0),
  160. torch.tensor(inputs['input_ids'], dtype=torch.long).unsqueeze(0),
  161. torch.tensor(inputs['input_mask'], dtype=torch.long).unsqueeze(0),
  162. torch.tensor(inputs['segment_ids'], dtype=torch.long).unsqueeze(0))
  163. output = output[0].detach().numpy()
  164. s5 = time.time()
  165. logger.info(f'forward. Infer:{cost(s5, s4)}')
  166. # 返回结果
  167. return output
  168. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  169. s5 = time.time()
  170. output = inputs
  171. index = np.argsort(-output)
  172. out_sort = output[index]
  173. top_k = []
  174. for i in range(min(10, len(self.label_mapping))):
  175. label = self.label_mapping[str(index[i])]
  176. top_k.append({
  177. 'label': label,
  178. 'score': round(float(out_sort[i]), 4),
  179. 'rank': i
  180. })
  181. s6 = time.time()
  182. logger.info(f'postprocess. Post: {cost(s6, s5)}')
  183. return {'text': top_k}
  184. @PIPELINES.register_module(
  185. Tasks.multi_modal_embedding,
  186. module_name=Pipelines.gridvlp_multi_modal_embedding)
  187. class GridVlpEmbeddingPipeline(GridVlpPipeline):
  188. """ Pipeline for gridvlp embedding. These only generate unified multi-modal
  189. embeddings and output it in `text_embedding` or `img_embedding`.
  190. Example:
  191. ```python
  192. >>> from modelscope.pipelines.multi_modal.gridvlp_pipeline import \
  193. GridVlpEmbeddingPipeline
  194. >>> pipeline = GridVlpEmbeddingPipeline('rgtjf1/multi-modal_gridvlp_classification_chinese-base-ecom-embedding')
  195. >>> outputs = pipeline({'text': '女装快干弹力轻型短裤448575',\
  196. 'image':'https://yejiabo-public.oss-cn-zhangjiakou.aliyuncs.com/alinlp/clothes.png'})
  197. >>> outputs["text_embedding"].shape
  198. (768,)
  199. ```
  200. """
  201. def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  202. s4 = time.time()
  203. box_tensor = torch.zeros(1, dtype=torch.float32)
  204. output = self.model(
  205. torch.tensor(inputs['image']).unsqueeze(0),
  206. box_tensor.unsqueeze(0),
  207. torch.tensor(inputs['input_ids'], dtype=torch.long).unsqueeze(0),
  208. torch.tensor(inputs['input_mask'], dtype=torch.long).unsqueeze(0),
  209. torch.tensor(inputs['segment_ids'], dtype=torch.long).unsqueeze(0))
  210. s5 = time.time()
  211. output = output[0].detach().numpy()
  212. s6 = time.time()
  213. logger.info(f'forward. Infer:{cost(s5, s4)}, Post: {cost(s6, s5)}')
  214. # 返回结果
  215. return output
  216. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  217. outputs = {
  218. 'img_embedding': inputs,
  219. 'text_embedding': inputs,
  220. }
  221. return outputs