import inspect
import os
from types import MethodType
from typing import Any, List, Optional

from modelscope import get_logger
from modelscope.metainfo import Tasks
from modelscope.utils.ast_utils import INDEX_KEY
from modelscope.utils.import_utils import (LazyImportModule,
                                           is_torch_available,
                                           is_transformers_available)

logger = get_logger()


def can_load_by_ms(model_dir: str, task_name: Optional[str],
                   model_type: Optional[str]) -> bool:
    """Tell whether a model can be loaded through the modelscope registry.

    Args:
        model_dir: The local model directory.
        task_name: The task name, or None when unknown.
        model_type: The model type, or None when unknown.

    Returns:
        bool: False when either `task_name` or `model_type` is None.
            Otherwise True when `('MODELS', task_name, model_type)` is found
            in the AST index, or when `model_dir` contains an `ms_wrapper.py`
            file.
    """
    if model_type is None or task_name is None:
        return False
    if ('MODELS', task_name,
            model_type) in LazyImportModule.get_ast_index()[INDEX_KEY]:
        return True
    # Fall back to a wrapper script shipped inside the model directory.
    return os.path.exists(os.path.join(model_dir, 'ms_wrapper.py'))


def fix_upgrade(module_obj: Any):
    """Rebind `_set_gradient_checkpointing` on a modelscope model object.

    The method is replaced by `PreTrainedModel._set_gradient_checkpointing`
    only when all of the following hold:
      * `module_obj` has a `_set_gradient_checkpointing` attribute,
      * that callable's signature has a parameter named `value`,
      * the string form of the object's class contains 'modelscope.'.

    Args:
        module_obj: The object to patch in place.
    """
    from transformers import PreTrainedModel

    if not hasattr(module_obj, '_set_gradient_checkpointing'):
        return
    signature = inspect.signature(module_obj._set_gradient_checkpointing)
    if 'value' not in signature.parameters:
        return
    if 'modelscope.' not in str(module_obj.__class__):
        return
    module_obj._set_gradient_checkpointing = MethodType(
        PreTrainedModel._set_gradient_checkpointing, module_obj)


def post_init(self, *args, **kwargs):
    """Replacement for `PreTrainedModel.post_init`.

    Applies `fix_upgrade` to the instance and then delegates to the saved
    original implementation, `post_init_origin`.
    """
    fix_upgrade(self)
    self.post_init_origin(*args, **kwargs)


def fix_transformers_upgrade():
    """Patch `PreTrainedModel.post_init` for transformers >= 4.35.0.

    Does nothing unless both transformers and torch are available. The patch
    is applied at most once: the presence of `post_init_origin` on
    `PreTrainedModel` marks it as already patched.
    """
    if not (is_transformers_available() and is_torch_available()):
        return

    # from 4.35.0, transformers changes its arguments of _set_gradient_checkpointing
    import transformers
    from transformers import PreTrainedModel
    from packaging import version

    if version.parse(transformers.__version__) < version.parse('4.35.0'):
        return
    if hasattr(PreTrainedModel, 'post_init_origin'):
        return
    PreTrainedModel.post_init_origin = PreTrainedModel.post_init
    PreTrainedModel.post_init = post_init


def _can_load_by_hf_automodel(automodel_class: type, config) -> bool:
    """Tell whether `automodel_class` can load a model with this config.

    Args:
        automodel_class: An auto-model class exposing `_model_mapping`.
        config: The model config object.

    Returns:
        bool: True when the config's type is a key of
            `automodel_class._model_mapping`, or when the config has an
            `auto_map` that contains the auto-model class name.
    """
    if type(config) in automodel_class._model_mapping.keys():
        return True
    return (hasattr(config, 'auto_map')
            and automodel_class.__name__ in config.auto_map)


def get_default_automodel(config) -> Optional[type]:
    """Pick the auto-model class named by `config.auto_map`, if unambiguous.

    Only `auto_map` keys starting with 'AutoModel' are considered. The choice
    is unambiguous when there is exactly one such key, or when several such
    keys all map to the same value; the first key is then looked up by name
    in `modelscope.utils.hf_util`.

    Args:
        config: The model config object.

    Returns:
        The matching class from `modelscope.utils.hf_util`, or None when the
        config has no `auto_map` or the choice is ambiguous.
    """
    import modelscope.utils.hf_util as hf_util

    if not hasattr(config, 'auto_map'):
        return None
    auto_map = config.auto_map
    automodel_list = [k for k in auto_map.keys() if k.startswith('AutoModel')]
    if len(automodel_list) == 1:
        return getattr(hf_util, automodel_list[0])
    if len(automodel_list) > 1 and len(
            {auto_map[k]
             for k in automodel_list}) == 1:
        return getattr(hf_util, automodel_list[0])
    return None


def get_hf_automodel_class(model_dir: str,
                           task_name: Optional[str]) -> Optional[type]:
    """Resolve the auto-model class able to load the model in `model_dir`.

    Args:
        model_dir: The local model directory; it must contain `config.json`.
        task_name: The task name. When None, the class is derived from the
            config's `auto_map` via `get_default_automodel`.

    Returns:
        The auto-model class, or None when there is no `config.json`, no
        class matches the task/config, or any exception is raised while
        probing. Probing is best-effort: exceptions are logged at debug level
        and never propagated.
    """
    from modelscope import (AutoConfig, AutoModel, AutoModelForCausalLM,
                            AutoModelForSeq2SeqLM,
                            AutoModelForTokenClassification,
                            AutoModelForSequenceClassification)
    automodel_mapping = {
        Tasks.backbone: AutoModel,
        Tasks.chat: AutoModelForCausalLM,
        Tasks.text_generation: AutoModelForCausalLM,
        Tasks.text_classification: AutoModelForSequenceClassification,
        Tasks.token_classification: AutoModelForTokenClassification,
    }
    config_path = os.path.join(model_dir, 'config.json')
    if not os.path.exists(config_path):
        return None
    try:
        config = AutoConfig.from_pretrained(model_dir, trust_remote_code=False)
        if task_name is None:
            automodel_class = get_default_automodel(config)
        else:
            automodel_class = automodel_mapping.get(task_name, None)

        if automodel_class is None:
            return None
        if _can_load_by_hf_automodel(automodel_class, config):
            return automodel_class
        # A causal-LM task is retried with the seq2seq auto-model class.
        if (automodel_class is AutoModelForCausalLM
                and _can_load_by_hf_automodel(AutoModelForSeq2SeqLM, config)):
            return AutoModelForSeq2SeqLM
        return None
    except Exception as e:
        # Keep the best-effort contract (return None), but leave a trace so
        # that a failed probe can be diagnosed.
        logger.debug('Failed to resolve the hf automodel class for %s: %s',
                     model_dir, e)
        return None


def try_to_load_hf_model(model_dir: str, task_name: str,
                         use_hf: Optional[bool], **kwargs):
    """Load the model with a HF auto-model class when one is available.

    Args:
        model_dir: The local model directory.
        task_name: The task name used to pick the auto-model class.
        use_hf: When truthy, failing to find an auto-model class is an error.
        **kwargs: Forwarded to `from_pretrained`.

    Returns:
        The loaded model, or None when no auto-model class was found.

    Raises:
        ValueError: If `use_hf` is truthy and no auto-model class was found.
    """
    automodel_class = get_hf_automodel_class(model_dir, task_name)

    if use_hf and automodel_class is None:
        raise ValueError(f'Model import failed. You used `use_hf={use_hf}`, '
                         'but the model is not a model of hf.')

    model = None
    if automodel_class is not None:
        # use hf
        model = automodel_class.from_pretrained(model_dir, **kwargs)
    return model


def check_model_from_owner_group(
        model_dir: str, owner_group: Optional[List[str]] = None) -> bool:
    """Check whether the model directory sits under a trusted owner group.

    This check guards `torch.load`, which may evaluate malicious code into
    memory. The group is taken to be the name of the parent directory of
    `model_dir`.

    Args:
        model_dir: The local model_dir
        owner_group: The owner groups to trust, defaults to
            `['iic', 'damo']` when None.

    Returns:
        bool: Whether the group can be trusted. False when `model_dir` is
            empty.
    """
    if not model_dir:
        return False
    if owner_group is None:
        owner_group = ['iic', 'damo']
    # Strip every trailing separator in one pass; stripping '/' and then '\\'
    # separately leaves a separator behind for a path ending in '/\\', which
    # makes dirname() return the model directory itself.
    model_dir = model_dir.rstrip('/\\')
    group = os.path.basename(os.path.dirname(model_dir))
    return group in owner_group