base.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from abc import ABC, abstractmethod
  4. from typing import Any, Callable, Dict, Optional, Sequence, Union
  5. from modelscope.metainfo import Models, Preprocessors, TaskModels
  6. from modelscope.utils.config import Config, ConfigDict
  7. from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke,
  8. ModeKeys, Tasks)
  9. from modelscope.utils.hub import read_config, snapshot_download
  10. from modelscope.utils.logger import get_logger
  11. from .builder import build_preprocessor
logger = get_logger()

# Fallback lookup table for `Preprocessor.from_pretrained`: when the selected
# preprocessor config has no `type` key, the preprocessor name is looked up
# here by the exact pair (model type, task). The model type is taken from
# `cfg.model.type` (or `cfg.model.model_type`), and the task from `cfg.task`
# or the `task` kwarg. Pairs that are absent make `from_pretrained` return
# None instead of building a preprocessor.
# Keys: (Models.* / TaskModels.* value, Tasks.* value); values: Preprocessors.* names.
PREPROCESSOR_MAP = {
    # nlp
    (Models.canmt, Tasks.competency_aware_translation):
    Preprocessors.canmt_translation,
    # bart
    (Models.bart, Tasks.text_error_correction):
    Preprocessors.text_error_correction,
    # bert
    (Models.bert, Tasks.backbone):
    Preprocessors.sen_cls_tokenizer,
    (Models.bert, Tasks.document_segmentation):
    Preprocessors.document_segmentation,
    (Models.bert, Tasks.fill_mask):
    Preprocessors.fill_mask,
    (Models.bert, Tasks.sentence_embedding):
    Preprocessors.sentence_embedding,
    (Models.bert, Tasks.text_classification):
    Preprocessors.sen_cls_tokenizer,
    (Models.bert, Tasks.speaker_diarization_dialogue_detection):
    Preprocessors.sen_cls_tokenizer,
    (Models.bert, Tasks.nli):
    Preprocessors.sen_cls_tokenizer,
    (Models.bert, Tasks.sentiment_classification):
    Preprocessors.sen_cls_tokenizer,
    (Models.bert, Tasks.sentence_similarity):
    Preprocessors.sen_cls_tokenizer,
    (Models.bert, Tasks.zero_shot_classification):
    Preprocessors.sen_cls_tokenizer,
    (Models.bert, Tasks.text_ranking):
    Preprocessors.text_ranking,
    (Models.bert, Tasks.part_of_speech):
    Preprocessors.token_cls_tokenizer,
    (Models.bert, Tasks.token_classification):
    Preprocessors.token_cls_tokenizer,
    (Models.bert, Tasks.speaker_diarization_semantic_speaker_turn_detection):
    Preprocessors.token_cls_tokenizer,
    (Models.bert, Tasks.word_segmentation):
    Preprocessors.token_cls_tokenizer,
    # bloom
    (Models.bloom, Tasks.backbone):
    Preprocessors.text_gen_tokenizer,
    # gpt_neo
    # gpt_neo may have different preprocessors, but now only one
    (Models.gpt_neo, Tasks.backbone):
    Preprocessors.sentence_piece,
    # gpt3 has different preprocessors by different sizes of models, so they are not listed here.
    # palm_v2
    (Models.palm, Tasks.backbone):
    Preprocessors.text_gen_tokenizer,
    # T5
    (Models.T5, Tasks.backbone):
    Preprocessors.text2text_gen_preprocessor,
    (Models.T5, Tasks.text2text_generation):
    Preprocessors.text2text_gen_preprocessor,
    # deberta_v2
    (Models.deberta_v2, Tasks.backbone):
    Preprocessors.sen_cls_tokenizer,
    (Models.deberta_v2, Tasks.fill_mask):
    Preprocessors.fill_mask,
    # ponet
    (Models.ponet, Tasks.fill_mask):
    Preprocessors.fill_mask_ponet,
    # structbert
    (Models.structbert, Tasks.backbone):
    Preprocessors.sen_cls_tokenizer,
    (Models.structbert, Tasks.fill_mask):
    Preprocessors.fill_mask,
    (Models.structbert, Tasks.faq_question_answering):
    Preprocessors.faq_question_answering_preprocessor,
    (Models.structbert, Tasks.text_classification):
    Preprocessors.sen_cls_tokenizer,
    (Models.structbert, Tasks.nli):
    Preprocessors.sen_cls_tokenizer,
    (Models.structbert, Tasks.sentiment_classification):
    Preprocessors.sen_cls_tokenizer,
    (Models.structbert, Tasks.sentence_similarity):
    Preprocessors.sen_cls_tokenizer,
    (Models.structbert, Tasks.zero_shot_classification):
    Preprocessors.sen_cls_tokenizer,
    (Models.structbert, Tasks.part_of_speech):
    Preprocessors.token_cls_tokenizer,
    # NOTE(review): this key uses Models.token_classification_for_ner rather
    # than Models.structbert, even though it sits in the structbert section.
    (Models.token_classification_for_ner, Tasks.named_entity_recognition):
    Preprocessors.token_cls_tokenizer,
    (Models.structbert, Tasks.token_classification):
    Preprocessors.token_cls_tokenizer,
    (Models.structbert, Tasks.word_segmentation):
    Preprocessors.token_cls_tokenizer,
    # doc2bot
    (Models.doc2bot, Tasks.document_grounded_dialog_generate):
    Preprocessors.document_grounded_dialog_generate,
    (Models.doc2bot, Tasks.document_grounded_dialog_rerank):
    Preprocessors.document_grounded_dialog_rerank,
    (Models.doc2bot, Tasks.document_grounded_dialog_retrieval):
    Preprocessors.document_grounded_dialog_retrieval,
    # veco
    (Models.veco, Tasks.backbone):
    Preprocessors.sen_cls_tokenizer,
    (Models.veco, Tasks.fill_mask):
    Preprocessors.fill_mask,
    (Models.veco, Tasks.text_classification):
    Preprocessors.sen_cls_tokenizer,
    (Models.veco, Tasks.nli):
    Preprocessors.sen_cls_tokenizer,
    (Models.veco, Tasks.sentiment_classification):
    Preprocessors.sen_cls_tokenizer,
    (Models.veco, Tasks.sentence_similarity):
    Preprocessors.sen_cls_tokenizer,
    # ner models
    (Models.lcrf, Tasks.named_entity_recognition):
    Preprocessors.sequence_labeling_tokenizer,
    (Models.lcrf, Tasks.word_segmentation):
    Preprocessors.sequence_labeling_tokenizer,
    (Models.lcrf, Tasks.part_of_speech):
    Preprocessors.sequence_labeling_tokenizer,
    (Models.lcrf_wseg, Tasks.word_segmentation):
    Preprocessors.sequence_labeling_tokenizer,
    (Models.tcrf_wseg, Tasks.word_segmentation):
    Preprocessors.sequence_labeling_tokenizer,
    (Models.tcrf, Tasks.named_entity_recognition):
    Preprocessors.sequence_labeling_tokenizer,
    # task models (keyed on TaskModels.* rather than Models.*)
    (TaskModels.token_classification, Tasks.token_classification):
    Preprocessors.sequence_labeling_tokenizer,
    (TaskModels.token_classification, Tasks.part_of_speech):
    Preprocessors.sequence_labeling_tokenizer,
    (TaskModels.token_classification, Tasks.named_entity_recognition):
    Preprocessors.sequence_labeling_tokenizer,
    (TaskModels.text_classification, Tasks.text_classification):
    Preprocessors.sen_cls_tokenizer,
    (TaskModels.fill_mask, Tasks.fill_mask):
    Preprocessors.fill_mask,
    (TaskModels.feature_extraction, Tasks.feature_extraction):
    Preprocessors.feature_extraction,
    (TaskModels.information_extraction, Tasks.information_extraction):
    Preprocessors.re_tokenizer,
    (TaskModels.text_ranking, Tasks.text_ranking):
    Preprocessors.text_ranking,
    (TaskModels.text_generation, Tasks.text_generation):
    Preprocessors.text_gen_tokenizer,
    # cv
    (Models.tinynas_detection, Tasks.image_object_detection):
    Preprocessors.object_detection_tinynas_preprocessor,
    (Models.tinynas_damoyolo, Tasks.image_object_detection):
    Preprocessors.object_detection_tinynas_preprocessor,
    (Models.tinynas_damoyolo, Tasks.domain_specific_object_detection):
    Preprocessors.object_detection_tinynas_preprocessor,
    (Models.controllable_image_generation, Tasks.controllable_image_generation):
    Preprocessors.controllable_image_generation_preprocessor,
}
  162. class Preprocessor(ABC):
  163. """Base of preprocessors.
  164. """
  165. def __init__(self, mode=ModeKeys.INFERENCE, *args, **kwargs):
  166. self._mode = mode
  167. assert self._mode in (ModeKeys.INFERENCE, ModeKeys.TRAIN,
  168. ModeKeys.EVAL)
  169. self.device = int(
  170. os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else None
  171. pass
  172. @abstractmethod
  173. def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
  174. pass
  175. @property
  176. def mode(self):
  177. return self._mode
  178. @mode.setter
  179. def mode(self, value):
  180. self._mode = value
  181. @classmethod
  182. def from_pretrained(cls,
  183. model_name_or_path: str,
  184. revision: Optional[str] = DEFAULT_MODEL_REVISION,
  185. cfg_dict: Config = None,
  186. preprocessor_mode=ModeKeys.INFERENCE,
  187. **kwargs):
  188. """Instantiate a preprocessor from local directory or remote model repo. Note
  189. that when loading from remote, the model revision can be specified.
  190. Args:
  191. model_name_or_path(str): A model dir or a model id used to load the preprocessor out.
  192. revision(str, `optional`): The revision used when the model_name_or_path is
  193. a model id of the remote hub. default `master`.
  194. cfg_dict(Config, `optional`): An optional config. If provided, it will replace
  195. the config read out of the `model_name_or_path`
  196. preprocessor_mode(str, `optional`): Specify the working mode of the preprocessor, can be `train`, `eval`,
  197. or `inference`. Default value `inference`.
  198. The preprocessor field in the config may contain two sub preprocessors:
  199. >>> {
  200. >>> "train": {
  201. >>> "type": "some-train-preprocessor"
  202. >>> },
  203. >>> "val": {
  204. >>> "type": "some-eval-preprocessor"
  205. >>> }
  206. >>> }
  207. In this scenario, the `train` preprocessor will be loaded in the `train` mode, the `val` preprocessor
  208. will be loaded in the `eval` or `inference` mode. The `mode` field in the preprocessor class
  209. will be assigned in all the modes.
  210. Or just one:
  211. >>> {
  212. >>> "type": "some-train-preprocessor"
  213. >>> }
  214. In this scenario, the sole preprocessor will be loaded in all the modes,
  215. and the `mode` field in the preprocessor class will be assigned.
  216. **kwargs:
  217. task(str, `optional`): The `Tasks` enumeration value to replace the task value
  218. read out of config in the `model_name_or_path`.
  219. This is useful when the preprocessor does not have a `type` field and the task to be used is not
  220. equal to the task of which the model is saved.
  221. Other kwargs will be directly fed into the preprocessor, to replace the default configs.
  222. Returns:
  223. The preprocessor instance.
  224. Examples:
  225. >>> from modelscope.preprocessors import Preprocessor
  226. >>> Preprocessor.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-base')
  227. """
  228. if not os.path.exists(model_name_or_path):
  229. model_dir = snapshot_download(
  230. model_name_or_path,
  231. revision=revision,
  232. user_agent={Invoke.KEY: Invoke.PREPROCESSOR},
  233. ignore_file_pattern=[
  234. '.*.bin',
  235. '.*.ts',
  236. '.*.pt',
  237. '.*.data-00000-of-00001',
  238. '.*.onnx',
  239. '.*.meta',
  240. '.*.pb',
  241. '.*.index',
  242. ])
  243. else:
  244. model_dir = model_name_or_path
  245. if cfg_dict is None:
  246. cfg = read_config(model_dir)
  247. else:
  248. cfg = cfg_dict
  249. task = cfg.task
  250. if 'task' in kwargs:
  251. task = kwargs.pop('task')
  252. field_name = Tasks.find_field_by_task(task)
  253. if 'field' in kwargs:
  254. field_name = kwargs.pop('field')
  255. sub_key = 'train' if preprocessor_mode == ModeKeys.TRAIN else 'val'
  256. if not hasattr(cfg, 'preprocessor') or len(cfg.preprocessor) == 0:
  257. logger.warning('No preprocessor field found in cfg.')
  258. preprocessor_cfg = ConfigDict()
  259. else:
  260. preprocessor_cfg = cfg.preprocessor
  261. if 'type' not in preprocessor_cfg:
  262. if sub_key in preprocessor_cfg:
  263. sub_cfg = getattr(preprocessor_cfg, sub_key)
  264. else:
  265. logger.warning(
  266. f'No {sub_key} key and type key found in '
  267. f'preprocessor domain of configuration.json file.')
  268. sub_cfg = preprocessor_cfg
  269. else:
  270. sub_cfg = preprocessor_cfg
  271. # TODO @wenmeng.zwm refine this logic when preprocessor has no model_dir param
  272. # for cv models.
  273. sub_cfg.update({'model_dir': model_dir})
  274. sub_cfg.update(kwargs)
  275. if 'type' in sub_cfg:
  276. if isinstance(sub_cfg, Sequence):
  277. # TODO: for Sequence, need adapt to `mode` and `mode_dir` args,
  278. # and add mode for Compose or other plans
  279. raise NotImplementedError('Not supported yet!')
  280. preprocessor = build_preprocessor(sub_cfg, field_name)
  281. else:
  282. logger.warning(
  283. f'Cannot find available config to build preprocessor at mode {preprocessor_mode}, '
  284. f'current config: {sub_cfg}. trying to build by task and model information.'
  285. )
  286. model_cfg = getattr(cfg, 'model', ConfigDict())
  287. model_type = model_cfg.type if hasattr(
  288. model_cfg, 'type') else getattr(model_cfg, 'model_type', None)
  289. if task is None or model_type is None:
  290. logger.warning(
  291. f'Find task: {task}, model type: {model_type}. '
  292. f'Insufficient information to build preprocessor, skip building preprocessor'
  293. )
  294. return None
  295. if (model_type, task) not in PREPROCESSOR_MAP:
  296. logger.info(
  297. f'No preprocessor key {(model_type, task)} found in PREPROCESSOR_MAP, '
  298. f'skip building preprocessor. If the pipeline runs normally, please ignore this log.'
  299. )
  300. return None
  301. sub_cfg = ConfigDict({
  302. 'type': PREPROCESSOR_MAP[(model_type, task)],
  303. **sub_cfg
  304. })
  305. preprocessor = build_preprocessor(sub_cfg, field_name)
  306. preprocessor.mode = preprocessor_mode
  307. sub_cfg.pop('model_dir', None)
  308. if not hasattr(preprocessor, 'cfg'):
  309. preprocessor.cfg = cfg
  310. return preprocessor
  311. def save_pretrained(self,
  312. target_folder: Union[str, os.PathLike],
  313. config: Optional[dict] = None,
  314. save_config_function: Callable = None):
  315. """Save the preprocessor, its configuration and other related files to a directory,
  316. so that it can be re-loaded
  317. By default, this method will save the preprocessor's config with mode `inference`.
  318. Args:
  319. target_folder (Union[str, os.PathLike]):
  320. Directory to which to save. Will be created if it doesn't exist.
  321. config (Optional[dict], optional):
  322. The config for the configuration.json
  323. save_config_function (Callable): The function used to save the configuration, call this function
  324. after the config is updated.
  325. """
  326. if config is None and hasattr(self, 'cfg'):
  327. config = self.cfg
  328. if config is not None:
  329. # Update the mode to `inference` in the preprocessor field.
  330. if 'preprocessor' in config and config['preprocessor'] is not None:
  331. if 'mode' in config['preprocessor']:
  332. config['preprocessor']['mode'] = 'inference'
  333. elif 'val' in config['preprocessor'] and 'mode' in config[
  334. 'preprocessor']['val']:
  335. config['preprocessor']['val']['mode'] = 'inference'
  336. if save_config_function is None:
  337. from modelscope.utils.checkpoint import save_configuration
  338. save_config_function = save_configuration
  339. save_config_function(target_folder, config)