import_utils.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. # Part of the implementation is borrowed from huggingface/transformers.
  3. import ast
  4. import functools
  5. import importlib
  6. import inspect
  7. import logging
  8. import os
  9. import os.path as osp
  10. import sys
  11. from collections import OrderedDict
  12. from importlib import import_module
  13. from itertools import chain
  14. from pathlib import Path
  15. from types import ModuleType
  16. from typing import Any
  17. from modelscope.utils.ast_utils import (INDEX_KEY, MODULE_KEY, REQUIREMENT_KEY,
  18. load_index)
  19. from modelscope.utils.error import * # noqa
  20. from modelscope.utils.logger import get_logger
  21. if sys.version_info < (3, 8):
  22. import importlib_metadata
  23. else:
  24. import importlib.metadata as importlib_metadata
  25. logger = get_logger(log_level=logging.WARNING)
  26. def import_modules_from_file(py_file: str):
  27. """ Import module from a certrain file
  28. Args:
  29. py_file: path to a python file to be imported
  30. Return:
  31. """
  32. dirname, basefile = os.path.split(py_file)
  33. if dirname == '':
  34. dirname = Path.cwd()
  35. module_name = osp.splitext(basefile)[0]
  36. sys.path.insert(0, dirname)
  37. validate_py_syntax(py_file)
  38. mod = import_module(module_name)
  39. sys.path.pop(0)
  40. return module_name, mod
  41. def is_method_overridden(method, base_class, derived_class):
  42. """Check if a method of base class is overridden in derived class.
  43. Args:
  44. method (str): the method name to check.
  45. base_class (type): the class of the base class.
  46. derived_class (type | Any): the class or instance of the derived class.
  47. """
  48. assert isinstance(base_class, type), \
  49. "base_class doesn't accept instance, Please pass class instead."
  50. if not isinstance(derived_class, type):
  51. derived_class = derived_class.__class__
  52. base_method = getattr(base_class, method)
  53. derived_method = getattr(derived_class, method)
  54. return derived_method != base_method
  55. def has_method(obj: object, method: str) -> bool:
  56. """Check whether the object has a method.
  57. Args:
  58. method (str): The method name to check.
  59. obj (object): The object to check.
  60. Returns:
  61. bool: True if the object has the method else False.
  62. """
  63. return hasattr(obj, method) and callable(getattr(obj, method))
  64. def import_modules(imports, allow_failed_imports=False):
  65. """Import modules from the given list of strings.
  66. Args:
  67. imports (list | str | None): The given module names to be imported.
  68. allow_failed_imports (bool): If True, the failed imports will return
  69. None. Otherwise, an ImportError is raise. Default: False.
  70. Returns:
  71. list[module] | module | None: The imported modules.
  72. Examples:
  73. >>> osp, sys = import_modules(
  74. ... ['os.path', 'sys'])
  75. >>> import os.path as osp_
  76. >>> import sys as sys_
  77. >>> assert osp == osp_
  78. >>> assert sys == sys_
  79. """
  80. if not imports:
  81. return
  82. single_import = False
  83. if isinstance(imports, str):
  84. single_import = True
  85. imports = [imports]
  86. if not isinstance(imports, list):
  87. raise TypeError(
  88. f'custom_imports must be a list but got type {type(imports)}')
  89. imported = []
  90. for imp in imports:
  91. if not isinstance(imp, str):
  92. raise TypeError(
  93. f'{imp} is of type {type(imp)} and cannot be imported.')
  94. try:
  95. imported_tmp = import_module(imp)
  96. except ImportError:
  97. if allow_failed_imports:
  98. logger.warning(f'{imp} failed to import and is ignored.')
  99. imported_tmp = None
  100. else:
  101. raise ImportError
  102. imported.append(imported_tmp)
  103. if single_import:
  104. imported = imported[0]
  105. return imported
  106. def validate_py_syntax(filename):
  107. with open(filename, 'r', encoding='utf-8') as f:
  108. # Setting encoding explicitly to resolve coding issue on windows
  109. content = f.read()
  110. try:
  111. ast.parse(content)
  112. except SyntaxError as e:
  113. raise SyntaxError('There are syntax errors in config '
  114. f'file {filename}: {e}')
  115. # following code borrows implementation from huggingface/transformers
  116. ENV_VARS_TRUE_VALUES = {'1', 'ON', 'YES', 'TRUE'}
  117. ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({'AUTO'})
  118. USE_TF = os.environ.get('USE_TF', 'AUTO').upper()
  119. USE_TORCH = os.environ.get('USE_TORCH', 'AUTO').upper()
  120. _torch_version = 'N/A'
  121. if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
  122. _torch_available = importlib.util.find_spec('torch') is not None
  123. if _torch_available:
  124. try:
  125. _torch_version = importlib_metadata.version('torch')
  126. logger.info(f'PyTorch version {_torch_version} Found.')
  127. except importlib_metadata.PackageNotFoundError:
  128. _torch_available = False
  129. else:
  130. logger.info('Disabling PyTorch because USE_TF is set')
  131. _torch_available = False
  132. _timm_available = importlib.util.find_spec('timm') is not None
  133. try:
  134. _timm_version = importlib_metadata.version('timm')
  135. logger.debug(f'Successfully imported timm version {_timm_version}')
  136. except importlib_metadata.PackageNotFoundError:
  137. _timm_available = False
  138. _tf_version = 'N/A'
  139. if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
  140. _tf_available = importlib.util.find_spec('tensorflow') is not None
  141. if _tf_available:
  142. candidates = (
  143. 'tensorflow',
  144. 'tensorflow-cpu',
  145. 'tensorflow-gpu',
  146. 'tf-nightly',
  147. 'tf-nightly-cpu',
  148. 'tf-nightly-gpu',
  149. 'intel-tensorflow',
  150. 'intel-tensorflow-avx512',
  151. 'tensorflow-rocm',
  152. 'tensorflow-macos',
  153. )
  154. _tf_version = None
  155. # For the metadata, we have to look for both tensorflow and tensorflow-cpu
  156. for pkg in candidates:
  157. try:
  158. _tf_version = importlib_metadata.version(pkg)
  159. break
  160. except importlib_metadata.PackageNotFoundError:
  161. pass
  162. _tf_available = _tf_version is not None
  163. if _tf_available:
  164. from packaging import version
  165. if version.parse(_tf_version) < version.parse('2'):
  166. pass
  167. else:
  168. logger.info(f'TensorFlow version {_tf_version} Found.')
  169. else:
  170. logger.info('Disabling Tensorflow because USE_TORCH is set')
  171. _tf_available = False
  172. def is_scipy_available():
  173. return importlib.util.find_spec('scipy') is not None
  174. def is_sklearn_available():
  175. if importlib.util.find_spec('sklearn') is None:
  176. return False
  177. return is_scipy_available() and importlib.util.find_spec('sklearn.metrics')
  178. def is_sentencepiece_available():
  179. return importlib.util.find_spec('sentencepiece') is not None
  180. def is_protobuf_available():
  181. if importlib.util.find_spec('google') is None:
  182. return False
  183. return importlib.util.find_spec('google.protobuf') is not None
  184. def is_tokenizers_available():
  185. return importlib.util.find_spec('tokenizers') is not None
  186. def is_timm_available():
  187. return _timm_available
  188. def is_torch_available():
  189. return _torch_available
  190. def is_torch_cuda_available():
  191. if is_torch_available():
  192. import torch
  193. return torch.cuda.is_available()
  194. else:
  195. return False
  196. def is_wenetruntime_available():
  197. return importlib.util.find_spec('wenetruntime') is not None
  198. def is_swift_available():
  199. return importlib.util.find_spec('swift') is not None
  200. def is_tf_available():
  201. return _tf_available
  202. def is_opencv_available():
  203. return importlib.util.find_spec('cv2') is not None
  204. def is_pillow_available():
  205. return importlib.util.find_spec('PIL.Image') is not None
  206. def _is_package_available_fn(pkg_name):
  207. return importlib.util.find_spec(pkg_name) is not None
  208. def is_package_available(pkg_name):
  209. return functools.partial(_is_package_available_fn, pkg_name)
  210. def is_espnet_available(pkg_name):
  211. return importlib.util.find_spec('espnet2') is not None \
  212. and importlib.util.find_spec('espnet')
  213. def is_vllm_available():
  214. return importlib.util.find_spec('vllm') is not None
  215. def is_transformers_available():
  216. return importlib.util.find_spec('transformers') is not None
  217. def is_diffusers_available():
  218. return importlib.util.find_spec('diffusers') is not None
  219. def is_tensorrt_llm_available():
  220. return importlib.util.find_spec('tensorrt_llm') is not None
  221. REQUIREMENTS_MAAPING = OrderedDict([
  222. ('protobuf', (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
  223. ('sentencepiece', (is_sentencepiece_available,
  224. SENTENCEPIECE_IMPORT_ERROR)),
  225. ('sklearn', (is_sklearn_available, SKLEARN_IMPORT_ERROR)),
  226. ('tf', (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
  227. ('tensorflow', (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
  228. ('timm', (is_timm_available, TIMM_IMPORT_ERROR)),
  229. ('tokenizers', (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
  230. ('torch', (is_torch_available, PYTORCH_IMPORT_ERROR)),
  231. ('wenetruntime',
  232. (is_wenetruntime_available,
  233. WENETRUNTIME_IMPORT_ERROR.replace('TORCH_VER', _torch_version))),
  234. ('scipy', (is_scipy_available, SCIPY_IMPORT_ERROR)),
  235. ('cv2', (is_opencv_available, OPENCV_IMPORT_ERROR)),
  236. ('PIL', (is_pillow_available, PILLOW_IMPORT_ERROR)),
  237. ('pai-easynlp', (is_package_available('easynlp'), EASYNLP_IMPORT_ERROR)),
  238. ('espnet2', (is_espnet_available,
  239. GENERAL_IMPORT_ERROR.replace('REQ', 'espnet'))),
  240. ('espnet', (is_espnet_available,
  241. GENERAL_IMPORT_ERROR.replace('REQ', 'espnet'))),
  242. ('funasr', (is_package_available('funasr'), AUDIO_IMPORT_ERROR)),
  243. ('kwsbp', (is_package_available('kwsbp'), AUDIO_IMPORT_ERROR)),
  244. ('decord', (is_package_available('decord'), DECORD_IMPORT_ERROR)),
  245. ('deepspeed', (is_package_available('deepspeed'), DEEPSPEED_IMPORT_ERROR)),
  246. ('fairseq', (is_package_available('fairseq'), FAIRSEQ_IMPORT_ERROR)),
  247. ('fasttext', (is_package_available('fasttext'), FASTTEXT_IMPORT_ERROR)),
  248. ('megatron_util', (is_package_available('megatron_util'),
  249. MEGATRON_UTIL_IMPORT_ERROR)),
  250. ('text2sql_lgesql', (is_package_available('text2sql_lgesql'),
  251. TEXT2SQL_LGESQL_IMPORT_ERROR)),
  252. ('mpi4py', (is_package_available('mpi4py'), MPI4PY_IMPORT_ERROR)),
  253. ('open_clip', (is_package_available('open_clip'), OPENCLIP_IMPORT_ERROR)),
  254. ('taming', (is_package_available('taming'), TAMING_IMPORT_ERROR)),
  255. ('xformers', (is_package_available('xformers'), XFORMERS_IMPORT_ERROR)),
  256. ('swift', (is_package_available('swift'), SWIFT_IMPORT_ERROR)),
  257. ])
  258. SYSTEM_PACKAGE = set(['os', 'sys', 'typing'])
  259. def requires(obj, requirements):
  260. if not isinstance(requirements, (list, tuple)):
  261. requirements = [requirements]
  262. if isinstance(obj, str):
  263. name = obj
  264. else:
  265. name = obj.__name__ if hasattr(obj,
  266. '__name__') else obj.__class__.__name__
  267. checks = []
  268. for req in requirements:
  269. if req == '' or req in SYSTEM_PACKAGE:
  270. continue
  271. if req in REQUIREMENTS_MAAPING:
  272. check = REQUIREMENTS_MAAPING[req]
  273. else:
  274. check_fn = is_package_available(req)
  275. err_msg = GENERAL_IMPORT_ERROR.replace('REQ', req)
  276. check = (check_fn, err_msg)
  277. checks.append(check)
  278. failed = [msg.format(name) for available, msg in checks if not available()]
  279. if failed:
  280. raise ImportError(''.join(failed))
  281. def torch_required(func):
  282. # Chose a different decorator name than in tests so it's clear they are not the same.
  283. @functools.wraps(func)
  284. def wrapper(*args, **kwargs):
  285. if is_torch_available():
  286. return func(*args, **kwargs)
  287. else:
  288. raise ImportError(f'Method `{func.__name__}` requires PyTorch.')
  289. return wrapper
  290. def tf_required(func):
  291. # Chose a different decorator name than in tests so it's clear they are not the same.
  292. @functools.wraps(func)
  293. def wrapper(*args, **kwargs):
  294. if is_tf_available():
  295. return func(*args, **kwargs)
  296. else:
  297. raise ImportError(f'Method `{func.__name__}` requires TF.')
  298. return wrapper
  299. class LazyImportModule(ModuleType):
  300. _AST_INDEX = None
  301. def __init__(self,
  302. name,
  303. module_file,
  304. import_structure,
  305. module_spec=None,
  306. extra_objects=None,
  307. try_to_pre_import=False,
  308. extra_import_func=None):
  309. super().__init__(name)
  310. self._modules = set(import_structure.keys())
  311. self._class_to_module = {}
  312. for key, values in import_structure.items():
  313. for value in values:
  314. self._class_to_module[value] = key
  315. # Needed for autocompletion in an IDE
  316. self.__all__ = list(import_structure.keys()) + list(
  317. chain(*import_structure.values()))
  318. self.__file__ = module_file
  319. self.__spec__ = module_spec
  320. self.__path__ = [os.path.dirname(module_file)]
  321. self._objects = {} if extra_objects is None else extra_objects
  322. self._name = name
  323. self._import_structure = import_structure
  324. self._extra_import_func = extra_import_func
  325. if try_to_pre_import:
  326. self._try_to_import()
  327. def _try_to_import(self):
  328. for sub_module in self._class_to_module.keys():
  329. try:
  330. getattr(self, sub_module)
  331. except Exception as e:
  332. logger.warning(
  333. f'pre load module {sub_module} error, please check {e}')
  334. # Needed for autocompletion in an IDE
  335. def __dir__(self):
  336. result = super().__dir__()
  337. # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
  338. # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
  339. for attr in self.__all__:
  340. if attr not in result:
  341. result.append(attr)
  342. return result
  343. def __getattr__(self, name: str) -> Any:
  344. if name in self._objects:
  345. return self._objects[name]
  346. if name in self._modules:
  347. value = self._get_module(name)
  348. elif name in self._class_to_module.keys():
  349. module = self._get_module(self._class_to_module[name])
  350. value = getattr(module, name)
  351. elif self._extra_import_func is not None:
  352. value = self._extra_import_func(name)
  353. if value is None:
  354. raise AttributeError(
  355. f'module {self.__name__} has no attribute {name}')
  356. else:
  357. raise AttributeError(
  358. f'module {self.__name__} has no attribute {name}')
  359. setattr(self, name, value)
  360. return value
  361. def _get_module(self, module_name: str):
  362. try:
  363. module_name_full = self.__name__ + '.' + module_name
  364. if not any(
  365. module_name_full.startswith(f'modelscope.{prefix}')
  366. for prefix in ['hub', 'utils', 'version', 'fileio']):
  367. # check requirements before module import
  368. ast_index = self.get_ast_index()
  369. if module_name_full in ast_index[REQUIREMENT_KEY]:
  370. requirements = ast_index[REQUIREMENT_KEY][module_name_full]
  371. requires(module_name_full, requirements)
  372. return importlib.import_module('.' + module_name, self.__name__)
  373. except Exception as e:
  374. raise RuntimeError(
  375. f'Failed to import {self.__name__}.{module_name} because of the following error '
  376. f'(look up to see its traceback):\n{e}') from e
  377. def __reduce__(self):
  378. return self.__class__, (self._name, self.__file__,
  379. self._import_structure)
  380. @staticmethod
  381. def get_ast_index():
  382. if LazyImportModule._AST_INDEX is None:
  383. LazyImportModule._AST_INDEX = load_index()
  384. return LazyImportModule._AST_INDEX
  385. @staticmethod
  386. def import_module(signature):
  387. """ import a lazy import module using signature
  388. Args:
  389. signature (tuple): a tuple of str, (registry_name, registry_group_name, module_name)
  390. """
  391. ast_index = LazyImportModule.get_ast_index()
  392. if signature in ast_index[INDEX_KEY]:
  393. mod_index = ast_index[INDEX_KEY][signature]
  394. module_name = mod_index[MODULE_KEY]
  395. if module_name in ast_index[REQUIREMENT_KEY]:
  396. requirements = ast_index[REQUIREMENT_KEY][module_name]
  397. requires(module_name, requirements)
  398. importlib.import_module(module_name)
  399. else:
  400. logger.warning(f'{signature} not found in ast index file')
  401. def has_attr_in_class(cls, attribute_name) -> bool:
  402. """
  403. Determine if attribute in specific class.
  404. Args:
  405. cls: target class.
  406. attribute_name: the attribute name.
  407. Returns:
  408. The attribute in the class or not.
  409. """
  410. init_method = cls.__init__
  411. signature = inspect.signature(init_method)
  412. parameters = signature.parameters
  413. param_names = list(parameters.keys())
  414. return attribute_name in param_names