| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- # Part of the implementation is borrowed from huggingface/transformers.
import ast
import functools
import importlib
import importlib.util
import inspect
import logging
import os
import os.path as osp
import sys
from collections import OrderedDict
from importlib import import_module
from itertools import chain
from pathlib import Path
from types import ModuleType
from typing import Any

from modelscope.utils.ast_utils import (INDEX_KEY, MODULE_KEY, REQUIREMENT_KEY,
                                        load_index)
from modelscope.utils.error import *  # noqa
from modelscope.utils.logger import get_logger

if sys.version_info < (3, 8):
    import importlib_metadata
else:
    import importlib.metadata as importlib_metadata
# Module-level logger for this file, created with log_level=WARNING.
# NOTE(review): presumably this hides the logger.info / logger.debug calls
# below by default; confirm against modelscope.utils.logger.get_logger.
logger = get_logger(log_level=logging.WARNING)
def import_modules_from_file(py_file: str):
    """Import a module from a specific python file.

    The directory containing ``py_file`` is temporarily prepended to
    ``sys.path`` so that the file can be imported by its base name. The
    entry is removed again afterwards, including when syntax validation or
    the import itself fails.

    Args:
        py_file: path to a python file to be imported.

    Returns:
        tuple: ``(module_name, module)``. ``module_name`` is the file's base
        name without extension; ``module`` is the imported module object.

    Raises:
        SyntaxError: if the file does not parse (via ``validate_py_syntax``).
        ImportError: if the module cannot be imported.
    """
    dirname, basefile = os.path.split(py_file)
    if dirname == '':
        # sys.path entries must be strings: the import system skips non-str
        # entries, so inserting a Path object would have no effect.
        dirname = str(Path.cwd())
    module_name = osp.splitext(basefile)[0]
    sys.path.insert(0, dirname)
    try:
        validate_py_syntax(py_file)
        mod = import_module(module_name)
    finally:
        # Remove the entry added above even on failure. Removing by value
        # (rather than pop(0)) stays correct if the imported module itself
        # prepended entries to sys.path.
        try:
            sys.path.remove(dirname)
        except ValueError:
            # The imported module already removed it; nothing left to undo.
            pass
    return module_name, mod
def is_method_overridden(method, base_class, derived_class):
    """Tell whether ``derived_class`` overrides ``base_class``'s ``method``.

    Args:
        method (str): name of the method to compare.
        base_class (type): the base class itself (instances are rejected).
        derived_class (type | Any): the derived class, or an instance of it,
            in which case the instance's ``__class__`` is used.

    Returns:
        bool: True when the attribute looked up on the derived class is not
        equal to the one looked up on the base class.
    """
    assert isinstance(base_class, type), \
        "base_class doesn't accept instance, Please pass class instead."

    klass = (derived_class if isinstance(derived_class, type) else
             derived_class.__class__)

    original_impl = getattr(base_class, method)
    candidate_impl = getattr(klass, method)
    return candidate_impl != original_impl
def has_method(obj: object, method: str) -> bool:
    """Report whether ``obj`` exposes a callable attribute named ``method``.

    Args:
        obj (object): The object to inspect.
        method (str): The attribute name to look for.

    Returns:
        bool: False if the attribute is missing or not callable, else True.
    """
    if not hasattr(obj, method):
        return False
    return callable(getattr(obj, method))
def import_modules(imports, allow_failed_imports=False):
    """Import modules from the given list of strings.

    Args:
        imports (list | str | None): The given module names to be imported.
        allow_failed_imports (bool): If True, a module that fails to import
            yields None (and a warning is logged). Otherwise an ImportError
            naming the module is raised, chained to the original error.
            Default: False.

    Returns:
        list[module] | module | None: The imported modules. A single module
        is returned when ``imports`` is a str; None when ``imports`` is empty.

    Raises:
        TypeError: if ``imports`` is not a str/list, or an element is not str.
        ImportError: if a module fails to import and
            ``allow_failed_imports`` is False.

    Examples:
        >>> osp, sys = import_modules(
        ...     ['os.path', 'sys'])
        >>> import os.path as osp_
        >>> import sys as sys_
        >>> assert osp == osp_
        >>> assert sys == sys_
    """
    if not imports:
        return None

    single_import = isinstance(imports, str)
    if single_import:
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(
            f'custom_imports must be a list but got type {type(imports)}')

    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(
                f'{imp} is of type {type(imp)} and cannot be imported.')
        try:
            imported_tmp = import_module(imp)
        except ImportError as e:
            if not allow_failed_imports:
                # Keep the failing module name and the underlying reason;
                # a bare `raise ImportError` would discard both.
                raise ImportError(f'Failed to import {imp}: {e}') from e
            logger.warning(f'{imp} failed to import and is ignored.')
            imported_tmp = None
        imported.append(imported_tmp)

    return imported[0] if single_import else imported
def validate_py_syntax(filename):
    """Check that a python source file parses.

    Args:
        filename (str): path of the file to check; read as UTF-8.

    Raises:
        SyntaxError: if the file contains syntax errors. The message names
            the file, and the parser's original error is chained as
            ``__cause__``.
    """
    # Setting encoding explicitly to resolve coding issue on windows
    with open(filename, 'r', encoding='utf-8') as f:
        content = f.read()
    try:
        ast.parse(content)
    except SyntaxError as e:
        # Chain explicitly so the traceback shows a direct cause rather than
        # "during handling of the above exception, another occurred".
        raise SyntaxError('There are syntax errors in config '
                          f'file {filename}: {e}') from e
# The framework-detection code below borrows its implementation from
# huggingface/transformers.

# Upper-cased USE_TF / USE_TORCH values that mean "explicitly enabled".
ENV_VARS_TRUE_VALUES = {'1', 'ON', 'YES', 'TRUE'}
# Enabled values plus AUTO, the default for both variables.
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES | {'AUTO'}

USE_TF = os.getenv('USE_TF', 'AUTO').upper()
USE_TORCH = os.getenv('USE_TORCH', 'AUTO').upper()
# PyTorch is only considered when USE_TORCH is true/auto and USE_TF is not
# explicitly true. It counts as available when it is both importable and has
# installed distribution metadata.
_torch_version = 'N/A'
if (USE_TORCH not in ENV_VARS_TRUE_AND_AUTO_VALUES
        or USE_TF in ENV_VARS_TRUE_VALUES):
    logger.info('Disabling PyTorch because USE_TF is set')
    _torch_available = False
else:
    _torch_available = importlib.util.find_spec('torch') is not None
    if _torch_available:
        try:
            _torch_version = importlib_metadata.version('torch')
        except importlib_metadata.PackageNotFoundError:
            _torch_available = False
        else:
            logger.info(f'PyTorch version {_torch_version} Found.')
# timm counts as available only when it is importable AND has installed
# distribution metadata. _timm_version defaults to 'N/A' so it is always
# defined, matching _torch_version / _tf_version.
_timm_version = 'N/A'
_timm_available = importlib.util.find_spec('timm') is not None
if _timm_available:
    # Only query package metadata when the module can actually be found;
    # the result would be discarded otherwise.
    try:
        _timm_version = importlib_metadata.version('timm')
        logger.debug(f'Successfully imported timm version {_timm_version}')
    except importlib_metadata.PackageNotFoundError:
        _timm_available = False
# TensorFlow detection. It is only considered when USE_TF is true/auto and
# USE_TORCH is not explicitly true. Note that `candidates`, `pkg` and
# `version` below become module-level names as a side effect.
_tf_version = 'N/A'
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
    _tf_available = importlib.util.find_spec('tensorflow') is not None
    if _tf_available:
        # TensorFlow can be installed under any of these distribution names;
        # the first one with installed metadata supplies the version.
        candidates = (
            'tensorflow',
            'tensorflow-cpu',
            'tensorflow-gpu',
            'tf-nightly',
            'tf-nightly-cpu',
            'tf-nightly-gpu',
            'intel-tensorflow',
            'intel-tensorflow-avx512',
            'tensorflow-rocm',
            'tensorflow-macos',
        )
        _tf_version = None
        # For the metadata, we have to look for both tensorflow and tensorflow-cpu
        for pkg in candidates:
            try:
                _tf_version = importlib_metadata.version(pkg)
                break
            except importlib_metadata.PackageNotFoundError:
                pass
        # If no candidate has metadata, TF is treated as unavailable and
        # _tf_version is left as None (not 'N/A').
        _tf_available = _tf_version is not None
    if _tf_available:
        from packaging import version
        # Unlike upstream transformers, TF 1.x is NOT disabled here: this
        # branch is a no-op, so _tf_available stays True for versions < 2.
        if version.parse(_tf_version) < version.parse('2'):
            pass
        else:
            logger.info(f'TensorFlow version {_tf_version} Found.')
else:
    # NOTE(review): this branch is also taken when USE_TF itself is set to a
    # disabling value, in which case the message below is inaccurate.
    logger.info('Disabling Tensorflow because USE_TORCH is set')
    _tf_available = False
def is_scipy_available():
    """Return True if the ``scipy`` package can be located for import."""
    scipy_spec = importlib.util.find_spec('scipy')
    return scipy_spec is not None
def is_sklearn_available():
    """Return True if scikit-learn (with ``sklearn.metrics``) and scipy are
    both available.

    The top-level ``sklearn`` lookup comes first because ``find_spec`` on the
    dotted name ``sklearn.metrics`` imports the parent package and would
    raise if it were missing.

    Returns:
        bool: availability flag (previously a ``ModuleSpec``/``None`` leaked
        out of the final ``find_spec`` call).
    """
    if importlib.util.find_spec('sklearn') is None:
        return False
    return (is_scipy_available()
            and importlib.util.find_spec('sklearn.metrics') is not None)
def is_sentencepiece_available():
    """Return True if the ``sentencepiece`` package can be located."""
    spm_spec = importlib.util.find_spec('sentencepiece')
    return spm_spec is not None
def is_protobuf_available():
    """Return True if ``google.protobuf`` can be located.

    The ``google`` namespace is checked first because ``find_spec`` on a
    dotted name imports the parent and raises if it is missing.
    """
    return (importlib.util.find_spec('google') is not None
            and importlib.util.find_spec('google.protobuf') is not None)
def is_tokenizers_available():
    """Return True if the ``tokenizers`` package can be located."""
    tok_spec = importlib.util.find_spec('tokenizers')
    return tok_spec is not None
def is_timm_available():
    """Return the timm availability flag computed once at import time.

    Returns:
        bool: the module-level ``_timm_available`` value.
    """
    return _timm_available
def is_torch_available():
    """Return the PyTorch availability flag computed once at import time.

    The flag honours the USE_TORCH / USE_TF environment variables.

    Returns:
        bool: the module-level ``_torch_available`` value.
    """
    return _torch_available
def is_torch_cuda_available():
    """Return True if PyTorch is available and reports a usable CUDA device.

    ``torch`` is imported lazily, only after the availability check passes.
    """
    if not is_torch_available():
        return False
    import torch
    return torch.cuda.is_available()
def is_wenetruntime_available():
    """Return True if the ``wenetruntime`` package can be located."""
    wenet_spec = importlib.util.find_spec('wenetruntime')
    return wenet_spec is not None
def is_swift_available():
    """Return True if the ``swift`` package can be located."""
    swift_spec = importlib.util.find_spec('swift')
    return swift_spec is not None
def is_tf_available():
    """Return the TensorFlow availability flag computed once at import time.

    The flag honours the USE_TF / USE_TORCH environment variables.

    Returns:
        bool: the module-level ``_tf_available`` value.
    """
    return _tf_available
def is_opencv_available():
    """Return True if OpenCV (importable as ``cv2``) can be located."""
    cv2_spec = importlib.util.find_spec('cv2')
    return cv2_spec is not None
def is_pillow_available():
    """Return True if Pillow's ``PIL.Image`` module can be located.

    ``importlib.util.find_spec`` imports the parent package of a dotted name
    and raises ``ModuleNotFoundError`` when that parent is missing. The
    top-level ``PIL`` package is therefore checked first, so a missing Pillow
    reports False instead of raising (same guard as the protobuf/sklearn
    checks).
    """
    if importlib.util.find_spec('PIL') is None:
        return False
    return importlib.util.find_spec('PIL.Image') is not None
- def _is_package_available_fn(pkg_name):
- return importlib.util.find_spec(pkg_name) is not None
def is_package_available(pkg_name):
    """Build a zero-argument availability check for ``pkg_name``.

    Args:
        pkg_name (str): module name to look up when the check is called.

    Returns:
        functools.partial: calling it with no arguments returns
        ``_is_package_available_fn(pkg_name)``.
    """
    checker = functools.partial(_is_package_available_fn, pkg_name)
    return checker
def is_espnet_available(pkg_name=None):
    """Return True if both ``espnet2`` and ``espnet`` can be located.

    Args:
        pkg_name: unused; kept only so existing callers that pass a name keep
            working. It now defaults to None because ``requires`` invokes
            every registered availability check with no arguments, which
            previously raised ``TypeError`` for the espnet entries.

    Returns:
        bool: availability flag (previously a ``ModuleSpec``/``None`` leaked
        out of the second lookup).
    """
    return (importlib.util.find_spec('espnet2') is not None
            and importlib.util.find_spec('espnet') is not None)
def is_vllm_available():
    """Return True if the ``vllm`` package can be located."""
    vllm_spec = importlib.util.find_spec('vllm')
    return vllm_spec is not None
def is_transformers_available():
    """Return True if the ``transformers`` package can be located."""
    hf_spec = importlib.util.find_spec('transformers')
    return hf_spec is not None
def is_diffusers_available():
    """Return True if the ``diffusers`` package can be located."""
    diffusers_spec = importlib.util.find_spec('diffusers')
    return diffusers_spec is not None
def is_tensorrt_llm_available():
    """Return True if the ``tensorrt_llm`` package can be located."""
    trt_llm_spec = importlib.util.find_spec('tensorrt_llm')
    return trt_llm_spec is not None
# Registry of known requirement names -> (availability_check, error_message).
# `requires` calls each availability_check with no arguments and, when it
# returns a falsy value, formats error_message with `str.format(name)` where
# name identifies the object that needed the requirement.
# Entries are inserted in a fixed order; OrderedDict preserves it.
REQUIREMENTS_MAAPING = OrderedDict()
REQUIREMENTS_MAAPING['protobuf'] = (is_protobuf_available,
                                    PROTOBUF_IMPORT_ERROR)
REQUIREMENTS_MAAPING['sentencepiece'] = (is_sentencepiece_available,
                                         SENTENCEPIECE_IMPORT_ERROR)
REQUIREMENTS_MAAPING['sklearn'] = (is_sklearn_available, SKLEARN_IMPORT_ERROR)
# 'tf' and 'tensorflow' are aliases for the same check.
REQUIREMENTS_MAAPING['tf'] = (is_tf_available, TENSORFLOW_IMPORT_ERROR)
REQUIREMENTS_MAAPING['tensorflow'] = (is_tf_available,
                                      TENSORFLOW_IMPORT_ERROR)
REQUIREMENTS_MAAPING['timm'] = (is_timm_available, TIMM_IMPORT_ERROR)
REQUIREMENTS_MAAPING['tokenizers'] = (is_tokenizers_available,
                                      TOKENIZERS_IMPORT_ERROR)
REQUIREMENTS_MAAPING['torch'] = (is_torch_available, PYTORCH_IMPORT_ERROR)
# The wenetruntime message embeds the torch version detected at import time.
REQUIREMENTS_MAAPING['wenetruntime'] = (
    is_wenetruntime_available,
    WENETRUNTIME_IMPORT_ERROR.replace('TORCH_VER', _torch_version))
REQUIREMENTS_MAAPING['scipy'] = (is_scipy_available, SCIPY_IMPORT_ERROR)
REQUIREMENTS_MAAPING['cv2'] = (is_opencv_available, OPENCV_IMPORT_ERROR)
REQUIREMENTS_MAAPING['PIL'] = (is_pillow_available, PILLOW_IMPORT_ERROR)
# The requirement name differs from the importable module name here.
REQUIREMENTS_MAAPING['pai-easynlp'] = (is_package_available('easynlp'),
                                       EASYNLP_IMPORT_ERROR)
REQUIREMENTS_MAAPING['espnet2'] = (
    is_espnet_available, GENERAL_IMPORT_ERROR.replace('REQ', 'espnet'))
REQUIREMENTS_MAAPING['espnet'] = (
    is_espnet_available, GENERAL_IMPORT_ERROR.replace('REQ', 'espnet'))
REQUIREMENTS_MAAPING['funasr'] = (is_package_available('funasr'),
                                  AUDIO_IMPORT_ERROR)
REQUIREMENTS_MAAPING['kwsbp'] = (is_package_available('kwsbp'),
                                 AUDIO_IMPORT_ERROR)
REQUIREMENTS_MAAPING['decord'] = (is_package_available('decord'),
                                  DECORD_IMPORT_ERROR)
REQUIREMENTS_MAAPING['deepspeed'] = (is_package_available('deepspeed'),
                                     DEEPSPEED_IMPORT_ERROR)
REQUIREMENTS_MAAPING['fairseq'] = (is_package_available('fairseq'),
                                   FAIRSEQ_IMPORT_ERROR)
REQUIREMENTS_MAAPING['fasttext'] = (is_package_available('fasttext'),
                                    FASTTEXT_IMPORT_ERROR)
REQUIREMENTS_MAAPING['megatron_util'] = (is_package_available('megatron_util'),
                                         MEGATRON_UTIL_IMPORT_ERROR)
REQUIREMENTS_MAAPING['text2sql_lgesql'] = (
    is_package_available('text2sql_lgesql'), TEXT2SQL_LGESQL_IMPORT_ERROR)
REQUIREMENTS_MAAPING['mpi4py'] = (is_package_available('mpi4py'),
                                  MPI4PY_IMPORT_ERROR)
REQUIREMENTS_MAAPING['open_clip'] = (is_package_available('open_clip'),
                                     OPENCLIP_IMPORT_ERROR)
REQUIREMENTS_MAAPING['taming'] = (is_package_available('taming'),
                                  TAMING_IMPORT_ERROR)
REQUIREMENTS_MAAPING['xformers'] = (is_package_available('xformers'),
                                    XFORMERS_IMPORT_ERROR)
REQUIREMENTS_MAAPING['swift'] = (is_package_available('swift'),
                                 SWIFT_IMPORT_ERROR)
# Requirement names that `requires` skips without any availability check.
SYSTEM_PACKAGE = {'os', 'sys', 'typing'}
def requires(obj, requirements):
    """Raise ImportError if any requirement of ``obj`` is not available.

    Args:
        obj: a string, or any object (its ``__name__``, falling back to its
            class name, is used) identifying what needs the requirements.
            It is only used to format the error message.
        requirements (str | list | tuple): requirement name(s).
            - Empty names and names in ``SYSTEM_PACKAGE`` are skipped.
            - Names registered in ``REQUIREMENTS_MAAPING`` use their
              registered check and message.
            - Any other name is looked up on the import path and reported
              with ``GENERAL_IMPORT_ERROR``.

    Raises:
        ImportError: containing the concatenated messages of every failed
            check.
    """
    if not isinstance(requirements, (list, tuple)):
        requirements = [requirements]

    if isinstance(obj, str):
        name = obj
    elif hasattr(obj, '__name__'):
        name = obj.__name__
    else:
        name = obj.__class__.__name__

    # First collect (check, message) pairs for every relevant requirement...
    pending = []
    for req in requirements:
        if req == '' or req in SYSTEM_PACKAGE:
            continue
        if req in REQUIREMENTS_MAAPING:
            pending.append(REQUIREMENTS_MAAPING[req])
        else:
            pending.append((is_package_available(req),
                            GENERAL_IMPORT_ERROR.replace('REQ', req)))

    # ...then run the checks in order and gather the messages of failures.
    failed = []
    for available, msg in pending:
        if not available():
            failed.append(msg.format(name))

    if failed:
        raise ImportError(''.join(failed))
def torch_required(func):
    """Wrap ``func`` so that calling it raises ImportError without PyTorch.

    Availability is checked on every call, not when decorating.
    This decorator deliberately has a different name from the one used in
    tests, to make clear that the two are not the same.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_torch_available():
            raise ImportError(f'Method `{func.__name__}` requires PyTorch.')
        return func(*args, **kwargs)

    return wrapper
def tf_required(func):
    """Wrap ``func`` so that calling it raises ImportError without TensorFlow.

    Availability is checked on every call, not when decorating.
    This decorator deliberately has a different name from the one used in
    tests, to make clear that the two are not the same.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_tf_available():
            raise ImportError(f'Method `{func.__name__}` requires TF.')
        return func(*args, **kwargs)

    return wrapper
class LazyImportModule(ModuleType):
    """A module object whose submodules and exported names load on demand.

    ``import_structure`` maps submodule names (relative to this module) to
    the lists of names each submodule exports. The first access to a
    submodule, or to one of those names, imports the submodule. The result
    is then stored as a real attribute via ``setattr``, so later lookups do
    not go through ``__getattr__`` again.
    """
    # AST index shared by all instances; loaded on first use by
    # get_ast_index() and cached on the class.
    _AST_INDEX = None

    def __init__(self,
                 name,
                 module_file,
                 import_structure,
                 module_spec=None,
                 extra_objects=None,
                 try_to_pre_import=False,
                 extra_import_func=None):
        """
        Args:
            name (str): fully qualified name of this module.
            module_file (str): file backing this module; its directory is
                used as ``__path__``.
            import_structure (dict[str, list[str]]): submodule name ->
                names exported by that submodule.
            module_spec: stored as ``__spec__``.
            extra_objects (dict | None): names served directly by
                ``__getattr__`` without any import.
            try_to_pre_import (bool): if True, resolve every exported name
                immediately (failures are only logged).
            extra_import_func (callable | None): fallback ``f(name)`` used
                for names not in ``import_structure``; returning None means
                the attribute does not exist.
        """
        super().__init__(name)
        self._modules = set(import_structure.keys())
        # Reverse index: exported name -> submodule defining it.
        self._class_to_module = {}
        for key, values in import_structure.items():
            for value in values:
                self._class_to_module[value] = key
        # Needed for autocompletion in an IDE
        self.__all__ = list(import_structure.keys()) + list(
            chain(*import_structure.values()))
        self.__file__ = module_file
        self.__spec__ = module_spec
        self.__path__ = [os.path.dirname(module_file)]
        self._objects = {} if extra_objects is None else extra_objects
        self._name = name
        self._import_structure = import_structure
        self._extra_import_func = extra_import_func
        if try_to_pre_import:
            self._try_to_import()

    def _try_to_import(self):
        """Eagerly resolve every exported name, logging failures as warnings.

        Note: this iterates the exported names (keys of _class_to_module),
        not the submodule names; each getattr imports the owning submodule.
        """
        for sub_module in self._class_to_module.keys():
            try:
                getattr(self, sub_module)
            except Exception as e:
                logger.warning(
                    f'pre load module {sub_module} error, please check {e}')

    # Needed for autocompletion in an IDE
    def __dir__(self):
        result = super().__dir__()
        # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
        # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
        for attr in self.__all__:
            if attr not in result:
                result.append(attr)
        return result

    def __getattr__(self, name: str) -> Any:
        """Resolve ``name`` lazily; only called when normal lookup fails.

        Lookup order: extra objects, submodules, names exported by a
        submodule, then the optional extra_import_func fallback. Resolved
        values (except extra objects) are cached on the instance.

        Raises:
            AttributeError: if no source provides ``name``.
        """
        if name in self._objects:
            return self._objects[name]
        if name in self._modules:
            value = self._get_module(name)
        elif name in self._class_to_module.keys():
            module = self._get_module(self._class_to_module[name])
            value = getattr(module, name)
        elif self._extra_import_func is not None:
            value = self._extra_import_func(name)
            if value is None:
                raise AttributeError(
                    f'module {self.__name__} has no attribute {name}')
        else:
            raise AttributeError(
                f'module {self.__name__} has no attribute {name}')
        # Cache so subsequent accesses bypass __getattr__.
        setattr(self, name, value)
        return value

    def _get_module(self, module_name: str):
        """Import the submodule ``module_name`` relative to this module.

        Except for modules under modelscope.hub / utils / version / fileio,
        the requirements recorded for the submodule in the AST index are
        checked via ``requires`` before importing.

        Raises:
            RuntimeError: wrapping any exception raised while checking
                requirements or importing (chained as ``__cause__``).
        """
        try:
            module_name_full = self.__name__ + '.' + module_name
            # NOTE(review): plain startswith also matches e.g.
            # 'modelscope.utilsX'; confirm no such module names exist.
            if not any(
                    module_name_full.startswith(f'modelscope.{prefix}')
                    for prefix in ['hub', 'utils', 'version', 'fileio']):
                # check requirements before module import
                ast_index = self.get_ast_index()
                if module_name_full in ast_index[REQUIREMENT_KEY]:
                    requirements = ast_index[REQUIREMENT_KEY][module_name_full]
                    requires(module_name_full, requirements)
            return importlib.import_module('.' + module_name, self.__name__)
        except Exception as e:
            raise RuntimeError(
                f'Failed to import {self.__name__}.{module_name} because of the following error '
                f'(look up to see its traceback):\n{e}') from e

    def __reduce__(self):
        # Pickling re-creates the module from name, file and import_structure
        # only. NOTE(review): module_spec, extra_objects and extra_import_func
        # are not carried over by this tuple.
        return self.__class__, (self._name, self.__file__,
                                self._import_structure)

    @staticmethod
    def get_ast_index():
        """Return the AST index, loading it via load_index() on first call."""
        if LazyImportModule._AST_INDEX is None:
            LazyImportModule._AST_INDEX = load_index()
        return LazyImportModule._AST_INDEX

    @staticmethod
    def import_module(signature):
        """ import a lazy import module using signature
        Args:
            signature (tuple): a tuple of str, (registry_name, registry_group_name, module_name)

        The module registered under ``signature`` in the AST index has its
        requirements checked via ``requires`` and is then imported. If the
        signature is not in the index, only a warning is logged.
        """
        ast_index = LazyImportModule.get_ast_index()
        if signature in ast_index[INDEX_KEY]:
            mod_index = ast_index[INDEX_KEY][signature]
            module_name = mod_index[MODULE_KEY]
            if module_name in ast_index[REQUIREMENT_KEY]:
                requirements = ast_index[REQUIREMENT_KEY][module_name]
                requires(module_name, requirements)
            importlib.import_module(module_name)
        else:
            logger.warning(f'{signature} not found in ast index file')
def has_attr_in_class(cls, attribute_name) -> bool:
    """Check whether ``attribute_name`` is a parameter of ``cls.__init__``.

    Only the constructor signature is inspected (including ``self`` and any
    ``*args`` / ``**kwargs`` names). Attributes defined elsewhere on the
    class are not detected.

    Args:
        cls: target class.
        attribute_name: the parameter name to look for.

    Returns:
        bool: True if ``cls.__init__`` declares a parameter with that name.
    """
    init_params = inspect.signature(cls.__init__).parameters
    return attribute_name in tuple(init_params)
|