automodel_utils.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. import inspect
  2. import os
  3. from types import MethodType
  4. from typing import Any, List, Optional
  5. from modelscope import get_logger
  6. from modelscope.metainfo import Tasks
  7. from modelscope.utils.ast_utils import INDEX_KEY
  8. from modelscope.utils.import_utils import (LazyImportModule,
  9. is_torch_available,
  10. is_transformers_available)
  11. logger = get_logger()
  12. def can_load_by_ms(model_dir: str, task_name: Optional[str],
  13. model_type: Optional[str]) -> bool:
  14. if model_type is None or task_name is None:
  15. return False
  16. if ('MODELS', task_name,
  17. model_type) in LazyImportModule.get_ast_index()[INDEX_KEY]:
  18. return True
  19. ms_wrapper_path = os.path.join(model_dir, 'ms_wrapper.py')
  20. if os.path.exists(ms_wrapper_path):
  21. return True
  22. return False
  23. def fix_upgrade(module_obj: Any):
  24. from transformers import PreTrainedModel
  25. if hasattr(module_obj, '_set_gradient_checkpointing') \
  26. and 'value' in inspect.signature(
  27. module_obj._set_gradient_checkpointing).parameters.keys() \
  28. and 'modelscope.' in str(module_obj.__class__):
  29. module_obj._set_gradient_checkpointing = MethodType(
  30. PreTrainedModel._set_gradient_checkpointing, module_obj)
  31. def post_init(self, *args, **kwargs):
  32. fix_upgrade(self)
  33. self.post_init_origin(*args, **kwargs)
  34. def fix_transformers_upgrade():
  35. if is_transformers_available() and is_torch_available():
  36. # from 4.35.0, transformers changes its arguments of _set_gradient_checkpointing
  37. import transformers
  38. from transformers import PreTrainedModel
  39. from packaging import version
  40. if version.parse(transformers.__version__) >= version.parse('4.35.0') \
  41. and not hasattr(PreTrainedModel, 'post_init_origin'):
  42. PreTrainedModel.post_init_origin = PreTrainedModel.post_init
  43. PreTrainedModel.post_init = post_init
  44. def _can_load_by_hf_automodel(automodel_class: type, config) -> bool:
  45. automodel_class_name = automodel_class.__name__
  46. if type(config) in automodel_class._model_mapping.keys():
  47. return True
  48. if hasattr(config, 'auto_map') and automodel_class_name in config.auto_map:
  49. return True
  50. return False
  51. def get_default_automodel(config) -> Optional[type]:
  52. import modelscope.utils.hf_util as hf_util
  53. if not hasattr(config, 'auto_map'):
  54. return None
  55. auto_map = config.auto_map
  56. automodel_list = [k for k in auto_map.keys() if k.startswith('AutoModel')]
  57. if len(automodel_list) == 1:
  58. return getattr(hf_util, automodel_list[0])
  59. if len(automodel_list) > 1 and len(
  60. set([auto_map[k] for k in automodel_list])) == 1:
  61. return getattr(hf_util, automodel_list[0])
  62. return None
  63. def get_hf_automodel_class(model_dir: str,
  64. task_name: Optional[str]) -> Optional[type]:
  65. from modelscope import (AutoConfig, AutoModel, AutoModelForCausalLM,
  66. AutoModelForSeq2SeqLM,
  67. AutoModelForTokenClassification,
  68. AutoModelForSequenceClassification)
  69. automodel_mapping = {
  70. Tasks.backbone: AutoModel,
  71. Tasks.chat: AutoModelForCausalLM,
  72. Tasks.text_generation: AutoModelForCausalLM,
  73. Tasks.text_classification: AutoModelForSequenceClassification,
  74. Tasks.token_classification: AutoModelForTokenClassification,
  75. }
  76. config_path = os.path.join(model_dir, 'config.json')
  77. if not os.path.exists(config_path):
  78. return None
  79. try:
  80. config = AutoConfig.from_pretrained(model_dir, trust_remote_code=False)
  81. if task_name is None:
  82. automodel_class = get_default_automodel(config)
  83. else:
  84. automodel_class = automodel_mapping.get(task_name, None)
  85. if automodel_class is None:
  86. return None
  87. if _can_load_by_hf_automodel(automodel_class, config):
  88. return automodel_class
  89. if (automodel_class is AutoModelForCausalLM
  90. and _can_load_by_hf_automodel(AutoModelForSeq2SeqLM, config)):
  91. return AutoModelForSeq2SeqLM
  92. return None
  93. except Exception:
  94. return None
  95. def try_to_load_hf_model(model_dir: str, task_name: str,
  96. use_hf: Optional[bool], **kwargs):
  97. automodel_class = get_hf_automodel_class(model_dir, task_name)
  98. if use_hf and automodel_class is None:
  99. raise ValueError(f'Model import failed. You used `use_hf={use_hf}`, '
  100. 'but the model is not a model of hf.')
  101. model = None
  102. if automodel_class is not None:
  103. # use hf
  104. model = automodel_class.from_pretrained(model_dir, **kwargs)
  105. return model
  106. def check_model_from_owner_group(model_dir: str,
  107. owner_group: List[str] = None) -> bool:
  108. """This checking is for the torch.load, this function may eval malicious code into memory
  109. Args:
  110. model_dir: The local model_dir
  111. owner_group: The owner group to trust
  112. Returns:
  113. bool: Whether the group can be trusted
  114. """
  115. if not model_dir:
  116. return False
  117. if owner_group is None:
  118. owner_group = ['iic', 'damo']
  119. model_dir = model_dir.rstrip('/').rstrip('\\')
  120. model_dir = os.path.dirname(model_dir)
  121. group = os.path.basename(model_dir)
  122. return group in owner_group