ast_utils.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import ast
  3. import hashlib
  4. import logging
  5. import os
  6. import os.path as osp
  7. import time
  8. import traceback
  9. from datetime import datetime
  10. from functools import reduce
  11. from pathlib import Path
  12. from typing import Union
  13. import json
  14. from modelscope import version
  15. # do not delete
  16. from modelscope.metainfo import (CustomDatasets, Heads, Hooks, LR_Schedulers,
  17. Metrics, Models, Optimizers, Pipelines,
  18. Preprocessors, TaskModels, Trainers)
  19. from modelscope.utils.constant import Fields, Tasks
  20. from modelscope.utils.file_utils import get_modelscope_cache_dir
  21. from modelscope.utils.registry import default_group
  22. p = Path(__file__)
  23. # get the path of package 'modelscope'
  24. SKIP_FUNCTION_SCANNING = True
  25. MODELSCOPE_PATH = p.resolve().parents[1]
  26. INDEXER_FILE_DIR = get_modelscope_cache_dir()
  27. REGISTER_MODULE = 'register_module'
  28. IGNORED_PACKAGES = ['modelscope', '.']
  29. SCAN_SUB_FOLDERS = [
  30. 'models', 'metrics', 'pipelines', 'preprocessors', 'trainers',
  31. 'msdatasets', 'exporters'
  32. ]
  33. INDEXER_FILE = 'ast_indexer'
  34. DECORATOR_KEY = 'decorators'
  35. EXPRESS_KEY = 'express'
  36. FROM_IMPORT_KEY = 'from_imports'
  37. IMPORT_KEY = 'imports'
  38. FILE_NAME_KEY = 'filepath'
  39. MODELSCOPE_PATH_KEY = 'modelscope_path'
  40. VERSION_KEY = 'version'
  41. MD5_KEY = 'md5'
  42. INDEX_KEY = 'index'
  43. FILES_MTIME_KEY = 'files_mtime'
  44. REQUIREMENT_KEY = 'requirements'
  45. MODULE_KEY = 'module'
  46. CLASS_NAME = 'class_name'
  47. GROUP_KEY = 'group_key'
  48. MODULE_NAME = 'module_name'
  49. MODULE_CLS = 'module_cls'
  50. TEMPLATE_PATH = 'TEMPLATE_PATH'
  51. TEMPLATE_FILE = 'ast_index_file.py'
  52. def get_ast_logger():
  53. ast_logger = logging.getLogger('modelscope.ast')
  54. ast_logger.setLevel(logging.INFO)
  55. return ast_logger
  56. logger = get_ast_logger()
  57. class AstScanning(object):
  58. def __init__(self) -> None:
  59. self.result_import = dict()
  60. self.result_from_import = dict()
  61. self.result_decorator = []
  62. self.express = []
  63. def _is_sub_node(self, node: object) -> bool:
  64. return isinstance(node,
  65. ast.AST) and not isinstance(node, ast.expr_context)
  66. def _is_leaf(self, node: ast.AST) -> bool:
  67. for field in node._fields:
  68. attr = getattr(node, field)
  69. if self._is_sub_node(attr):
  70. return False
  71. elif isinstance(attr, (list, tuple)):
  72. for val in attr:
  73. if self._is_sub_node(val):
  74. return False
  75. else:
  76. return True
  77. def _skip_function(self, node: Union[ast.AST, 'str']) -> bool:
  78. if SKIP_FUNCTION_SCANNING:
  79. if type(node).__name__ == 'FunctionDef' or node == 'FunctionDef':
  80. return True
  81. return False
  82. def _fields(self, n: ast.AST, show_offsets: bool = True) -> tuple:
  83. if show_offsets:
  84. return n._attributes + n._fields
  85. else:
  86. return n._fields
  87. def _leaf(self, node: ast.AST, show_offsets: bool = True) -> str:
  88. output = dict()
  89. if isinstance(node, ast.AST):
  90. local_dict = dict()
  91. for field in self._fields(node, show_offsets=show_offsets):
  92. field_output = self._leaf(
  93. getattr(node, field), show_offsets=show_offsets)
  94. local_dict[field] = field_output
  95. output[type(node).__name__] = local_dict
  96. return output
  97. else:
  98. return node
  99. def _refresh(self):
  100. self.result_import = dict()
  101. self.result_from_import = dict()
  102. self.result_decorator = []
  103. self.result_express = []
  104. def scan_ast(self, node: Union[ast.AST, None, str]):
  105. self._setup_global()
  106. self.scan_import(node, indent=' ', show_offsets=False)
  107. def scan_import(
  108. self,
  109. node: Union[ast.AST, None, str],
  110. show_offsets: bool = True,
  111. parent_node_name: str = '',
  112. ) -> tuple:
  113. if node is None:
  114. return node
  115. elif self._is_leaf(node):
  116. return self._leaf(node, show_offsets=show_offsets)
  117. else:
  118. def _scan_import(el: Union[ast.AST, None, str],
  119. parent_node_name: str = '') -> str:
  120. return self.scan_import(
  121. el,
  122. show_offsets=show_offsets,
  123. parent_node_name=parent_node_name)
  124. outputs = dict()
  125. # add relative path expression
  126. if type(node).__name__ == 'ImportFrom':
  127. level = getattr(node, 'level')
  128. if level >= 1:
  129. path_level = ''.join(['.'] * level)
  130. setattr(node, 'level', 0)
  131. module_name = getattr(node, 'module')
  132. if module_name is None:
  133. setattr(node, 'module', path_level)
  134. else:
  135. setattr(node, 'module', path_level + module_name)
  136. for field in self._fields(node, show_offsets=show_offsets):
  137. attr = getattr(node, field)
  138. if attr == []:
  139. outputs[field] = []
  140. elif self._skip_function(parent_node_name):
  141. continue
  142. elif (isinstance(attr, list) and len(attr) == 1
  143. and isinstance(attr[0], ast.AST)
  144. and self._is_leaf(attr[0])):
  145. local_out = _scan_import(attr[0])
  146. outputs[field] = local_out
  147. elif isinstance(attr, list):
  148. el_dict = dict()
  149. for el in attr:
  150. local_out = _scan_import(el, type(el).__name__)
  151. name = type(el).__name__
  152. if (name == 'Import' or name == 'ImportFrom'
  153. or parent_node_name == 'ImportFrom'
  154. or parent_node_name == 'Import'):
  155. if name not in el_dict:
  156. el_dict[name] = []
  157. el_dict[name].append(local_out)
  158. outputs[field] = el_dict
  159. elif isinstance(attr, ast.AST):
  160. output = _scan_import(attr)
  161. outputs[field] = output
  162. else:
  163. outputs[field] = attr
  164. if (type(node).__name__ == 'Import'
  165. or type(node).__name__ == 'ImportFrom'):
  166. if type(node).__name__ == 'ImportFrom':
  167. if field == 'module':
  168. self.result_from_import[outputs[field]] = dict()
  169. if field == 'names':
  170. if isinstance(outputs[field]['alias'], list):
  171. item_name = []
  172. for item in outputs[field]['alias']:
  173. local_name = item['alias']['name']
  174. item_name.append(local_name)
  175. self.result_from_import[
  176. outputs['module']] = item_name
  177. else:
  178. local_name = outputs[field]['alias']['name']
  179. self.result_from_import[outputs['module']] = [
  180. local_name
  181. ]
  182. if type(node).__name__ == 'Import':
  183. final_dict = outputs[field]['alias']
  184. if isinstance(final_dict, list):
  185. for item in final_dict:
  186. self.result_import[item['alias']
  187. ['name']] = item['alias']
  188. else:
  189. self.result_import[outputs[field]['alias']
  190. ['name']] = final_dict
  191. if 'decorator_list' == field and attr != []:
  192. for item in attr:
  193. setattr(item, CLASS_NAME, node.name)
  194. self.result_decorator.extend(attr)
  195. if attr != [] and type(
  196. attr
  197. ).__name__ == 'Call' and parent_node_name == 'Expr':
  198. self.result_express.append(attr)
  199. return {
  200. IMPORT_KEY: self.result_import,
  201. FROM_IMPORT_KEY: self.result_from_import,
  202. DECORATOR_KEY: self.result_decorator,
  203. EXPRESS_KEY: self.result_express
  204. }
  205. def _parse_decorator(self, node: ast.AST) -> tuple:
  206. def _get_attribute_item(node: ast.AST) -> tuple:
  207. value, id, attr = None, None, None
  208. if type(node).__name__ == 'Attribute':
  209. value = getattr(node, 'value')
  210. id = getattr(value, 'id')
  211. attr = getattr(node, 'attr')
  212. if type(node).__name__ == 'Name':
  213. id = getattr(node, 'id')
  214. return id, attr
  215. def _get_args_name(nodes: list) -> list:
  216. result = []
  217. for node in nodes:
  218. if type(node).__name__ == 'Str':
  219. result.append((node.s, None))
  220. elif type(node).__name__ == 'Constant':
  221. result.append((node.value, None))
  222. else:
  223. result.append(_get_attribute_item(node))
  224. return result
  225. def _get_keyword_name(nodes: ast.AST) -> list:
  226. result = []
  227. for node in nodes:
  228. if type(node).__name__ == 'keyword':
  229. attribute_node = getattr(node, 'value')
  230. if type(attribute_node).__name__ == 'Str':
  231. result.append((getattr(node,
  232. 'arg'), attribute_node.s, None))
  233. elif type(attribute_node).__name__ == 'Constant':
  234. result.append(
  235. (getattr(node, 'arg'), attribute_node.value, None))
  236. else:
  237. result.append((getattr(node, 'arg'), )
  238. + _get_attribute_item(attribute_node))
  239. return result
  240. functions = _get_attribute_item(node.func)
  241. args_list = _get_args_name(node.args)
  242. keyword_list = _get_keyword_name(node.keywords)
  243. return functions, args_list, keyword_list
  244. def _get_registry_value(self, key_item):
  245. if key_item is None:
  246. return None
  247. if key_item == 'default_group':
  248. return default_group
  249. split_list = key_item.split('.')
  250. # in the case, the key_item is raw data, not registered
  251. if len(split_list) == 1:
  252. return key_item
  253. else:
  254. return getattr(eval(split_list[0]), split_list[1])
  255. def _registry_indexer(self, parsed_input: tuple, class_name: str) -> tuple:
  256. """format registry information to a tuple indexer
  257. Return:
  258. tuple: (MODELS, Tasks.text-classification, Models.structbert)
  259. """
  260. functions, args_list, keyword_list = parsed_input
  261. # ignore decorators other than register_module
  262. if REGISTER_MODULE != functions[1]:
  263. return None
  264. output = [functions[0]]
  265. if len(args_list) == 0 and len(keyword_list) == 0:
  266. args_list.append(default_group)
  267. if len(keyword_list) == 0 and len(args_list) == 1:
  268. args_list.append(class_name)
  269. if len(keyword_list) > 0 and len(args_list) == 0:
  270. remove_group_item = None
  271. for item in keyword_list:
  272. key, name, attr = item
  273. if key == GROUP_KEY:
  274. args_list.append((name, attr))
  275. remove_group_item = item
  276. if remove_group_item is not None:
  277. keyword_list.remove(remove_group_item)
  278. if len(args_list) == 0:
  279. args_list.append(default_group)
  280. for item in keyword_list:
  281. key, name, attr = item
  282. if key == MODULE_CLS:
  283. class_name = name
  284. else:
  285. args_list.append((name, attr))
  286. for item in args_list:
  287. # the case empty input
  288. if item is None:
  289. output.append(None)
  290. # the case (default_group)
  291. elif item[1] is None:
  292. output.append(item[0])
  293. elif isinstance(item, str):
  294. output.append(item)
  295. else:
  296. output.append('.'.join(item))
  297. return (output[0], self._get_registry_value(output[1]),
  298. self._get_registry_value(output[2]))
  299. def parse_decorators(self, nodes: list) -> list:
  300. """parse the AST nodes of decorators object to registry indexer
  301. Args:
  302. nodes (list): list of AST decorator nodes
  303. Returns:
  304. list: list of registry indexer
  305. """
  306. results = []
  307. for node in nodes:
  308. if type(node).__name__ != 'Call':
  309. continue
  310. class_name = getattr(node, CLASS_NAME, None)
  311. func = getattr(node, 'func')
  312. if getattr(func, 'attr', None) != REGISTER_MODULE:
  313. continue
  314. parse_output = self._parse_decorator(node)
  315. index = self._registry_indexer(parse_output, class_name)
  316. if None is not index:
  317. results.append(index)
  318. return results
  319. def generate_ast(self, file):
  320. self._refresh()
  321. with open(file, 'r', encoding='utf8') as code:
  322. data = code.readlines()
  323. data = ''.join(data)
  324. node = ast.parse(data)
  325. output = self.scan_import(node, show_offsets=False)
  326. output[DECORATOR_KEY] = self.parse_decorators(output[DECORATOR_KEY])
  327. output[EXPRESS_KEY] = self.parse_decorators(output[EXPRESS_KEY])
  328. output[DECORATOR_KEY].extend(output[EXPRESS_KEY])
  329. return output
  330. class FilesAstScanning(object):
  331. def __init__(self) -> None:
  332. self.astScaner = AstScanning()
  333. self.file_dirs = []
  334. self.requirement_dirs = []
  335. def _parse_import_path(self,
  336. import_package: str,
  337. current_path: str = None) -> str:
  338. """
  339. Args:
  340. import_package (str): relative import or abs import
  341. current_path (str): path/to/current/file
  342. """
  343. if import_package.startswith(IGNORED_PACKAGES[0]):
  344. return MODELSCOPE_PATH + '/' + '/'.join(
  345. import_package.split('.')[1:]) + '.py'
  346. elif import_package.startswith(IGNORED_PACKAGES[1]):
  347. current_path_list = current_path.split('/')
  348. import_package_list = import_package.split('.')
  349. level = 0
  350. for index, item in enumerate(import_package_list):
  351. if item != '':
  352. level = index
  353. break
  354. abs_path_list = current_path_list[0:-level]
  355. abs_path_list.extend(import_package_list[index:])
  356. return '/' + '/'.join(abs_path_list) + '.py'
  357. else:
  358. return current_path
  359. def _traversal_import(
  360. self,
  361. import_abs_path,
  362. ):
  363. pass
  364. def parse_import(self, scan_result: dict) -> list:
  365. """parse import and from import dicts to a third party package list
  366. Args:
  367. scan_result (dict): including the import and from import result
  368. Returns:
  369. list: a list of package ignored 'modelscope' and relative path import
  370. """
  371. output = []
  372. output.extend(list(scan_result[IMPORT_KEY].keys()))
  373. output.extend(list(scan_result[FROM_IMPORT_KEY].keys()))
  374. # get the package name
  375. for index, item in enumerate(output):
  376. if '' == item.split('.')[0]:
  377. output[index] = '.'
  378. else:
  379. output[index] = item.split('.')[0]
  380. ignored = set()
  381. for item in output:
  382. for ignored_package in IGNORED_PACKAGES:
  383. if item.startswith(ignored_package):
  384. ignored.add(item)
  385. return list(set(output) - set(ignored))
  386. def traversal_files(self, path, check_sub_dir=None, include_init=False):
  387. self.file_dirs = []
  388. if check_sub_dir is None or len(check_sub_dir) == 0:
  389. self._traversal_files(path, include_init=include_init)
  390. else:
  391. for item in check_sub_dir:
  392. sub_dir = os.path.join(path, item)
  393. if os.path.isdir(sub_dir):
  394. self._traversal_files(sub_dir, include_init=include_init)
  395. def _traversal_files(self, path, include_init=False):
  396. dir_list = os.scandir(path)
  397. for item in dir_list:
  398. if item.name == '__init__.py' and not include_init:
  399. continue
  400. elif (item.name.startswith('__')
  401. and item.name != '__init__.py') or item.name.endswith(
  402. '.json') or item.name.endswith('.md'):
  403. continue
  404. if item.is_dir():
  405. self._traversal_files(item.path, include_init=include_init)
  406. elif item.is_file() and item.name.endswith('.py'):
  407. self.file_dirs.append(item.path)
  408. elif item.is_file() and 'requirement' in item.name:
  409. self.requirement_dirs.append(item.path)
  410. def _get_single_file_scan_result(self, file):
  411. try:
  412. output = self.astScaner.generate_ast(file)
  413. except Exception as e:
  414. detail = traceback.extract_tb(e.__traceback__)
  415. raise Exception(
  416. f'During ast indexing the file {file}, a related error excepted '
  417. f'in the file {detail[-1].filename} at line: '
  418. f'{detail[-1].lineno}: "{detail[-1].line}" with error msg: '
  419. f'"{type(e).__name__}: {e}", please double check the origin file {file} '
  420. f'to see whether the file is correctly edited.')
  421. import_list = self.parse_import(output)
  422. return output[DECORATOR_KEY], import_list
  423. def _inverted_index(self, forward_index):
  424. inverted_index = dict()
  425. for index in forward_index:
  426. for item in forward_index[index][DECORATOR_KEY]:
  427. inverted_index[item] = {
  428. FILE_NAME_KEY: index,
  429. IMPORT_KEY: forward_index[index][IMPORT_KEY],
  430. MODULE_KEY: forward_index[index][MODULE_KEY],
  431. }
  432. return inverted_index
  433. def _module_import(self, forward_index):
  434. module_import = dict()
  435. for index, value_dict in forward_index.items():
  436. module_import[value_dict[MODULE_KEY]] = value_dict[IMPORT_KEY]
  437. return module_import
  438. def _ignore_useless_keys(self, inverted_index):
  439. if ('OPTIMIZERS', 'default', 'name') in inverted_index:
  440. del inverted_index[('OPTIMIZERS', 'default', 'name')]
  441. if ('LR_SCHEDULER', 'default', 'name') in inverted_index:
  442. del inverted_index[('LR_SCHEDULER', 'default', 'name')]
  443. return inverted_index
  444. def get_files_scan_results(self,
  445. target_file_list=None,
  446. target_dir=MODELSCOPE_PATH,
  447. target_folders=SCAN_SUB_FOLDERS):
  448. """the entry method of the ast scan method
  449. Args:
  450. target_file_list can override the dir and folders combine
  451. target_dir (str, optional): the absolute path of the target directory to be scanned. Defaults to None.
  452. target_folder (list, optional): the list of
  453. sub-folders to be scanned in the target folder.
  454. Defaults to SCAN_SUB_FOLDERS.
  455. Returns:
  456. dict: indexer of registry
  457. """
  458. start = time.time()
  459. if target_file_list is not None:
  460. self.file_dirs = target_file_list
  461. else:
  462. self.traversal_files(target_dir, target_folders)
  463. logger.info(
  464. f'AST-Scanning the path "{target_dir}" with the following sub folders {target_folders}'
  465. )
  466. result = dict()
  467. for file in self.file_dirs:
  468. filepath = file[file.rfind('modelscope'):]
  469. module_name = filepath.replace(osp.sep, '.').replace('.py', '')
  470. decorator_list, import_list = self._get_single_file_scan_result(
  471. file)
  472. result[file] = {
  473. DECORATOR_KEY: decorator_list,
  474. IMPORT_KEY: import_list,
  475. MODULE_KEY: module_name
  476. }
  477. inverted_index_with_results = self._inverted_index(result)
  478. inverted_index_with_results = self._ignore_useless_keys(
  479. inverted_index_with_results)
  480. module_import = self._module_import(result)
  481. index = {
  482. INDEX_KEY: inverted_index_with_results,
  483. REQUIREMENT_KEY: module_import
  484. }
  485. logger.info(
  486. f'Scanning done! A number of {len(inverted_index_with_results)} '
  487. f'components indexed or updated! Time consumed {time.time()-start}s'
  488. )
  489. return index
  490. def files_mtime_md5(self,
  491. target_path=MODELSCOPE_PATH,
  492. target_subfolder=SCAN_SUB_FOLDERS,
  493. file_list=None):
  494. self.file_dirs = []
  495. if file_list and isinstance(file_list, list):
  496. self.file_dirs = file_list
  497. else:
  498. self.traversal_files(target_path, target_subfolder)
  499. files_mtime = []
  500. files_mtime_dict = dict()
  501. for item in self.file_dirs:
  502. mtime = os.path.getmtime(item)
  503. files_mtime.append(mtime)
  504. files_mtime_dict[item] = mtime
  505. result_str = reduce(lambda x, y: str(x) + str(y), files_mtime, '')
  506. md5 = hashlib.md5(result_str.encode())
  507. return md5.hexdigest(), files_mtime_dict
  508. file_scanner = FilesAstScanning()
  509. def ensure_write(obj: bytes, filepath: Union[str, Path]) -> None:
  510. """Write data to a given ``filepath`` with 'wb' mode.
  511. Note:
  512. ``write`` will create a directory if the directory of ``filepath``
  513. does not exist.
  514. Args:
  515. obj (bytes): Data to be written.
  516. filepath (str or Path): Path to write data.
  517. """
  518. dirname = os.path.dirname(filepath)
  519. if dirname and not os.path.exists(dirname):
  520. os.makedirs(dirname, exist_ok=True)
  521. with open(filepath, 'wb') as f:
  522. f.write(obj)
  523. def _save_index(index, file_path, file_list=None, with_template=False):
  524. # convert tuple key to str key
  525. index[INDEX_KEY] = {str(k): v for k, v in index[INDEX_KEY].items()}
  526. from modelscope.version import __version__
  527. index[VERSION_KEY] = __version__
  528. index[MD5_KEY], index[FILES_MTIME_KEY] = file_scanner.files_mtime_md5(
  529. file_list=file_list)
  530. index[MODELSCOPE_PATH_KEY] = MODELSCOPE_PATH.as_posix()
  531. json_index = json.dumps(index)
  532. if with_template:
  533. json_index = json_index.replace(MODELSCOPE_PATH.as_posix(),
  534. TEMPLATE_PATH)
  535. ensure_write(json_index.encode(), file_path)
  536. index[INDEX_KEY] = {
  537. ast.literal_eval(k): v
  538. for k, v in index[INDEX_KEY].items()
  539. }
  540. def _load_index(file_path, with_template=False):
  541. with open(file_path, 'rb') as f:
  542. bytes_index = f.read()
  543. if with_template:
  544. bytes_index = bytes_index.decode().replace(TEMPLATE_PATH,
  545. MODELSCOPE_PATH.as_posix())
  546. wrapped_index = json.loads(bytes_index)
  547. # convert str key to tuple key
  548. wrapped_index[INDEX_KEY] = {
  549. ast.literal_eval(k): v
  550. for k, v in wrapped_index[INDEX_KEY].items()
  551. }
  552. return wrapped_index
  553. def _update_index(index, files_mtime):
  554. # inplace update index
  555. origin_files_mtime = index[FILES_MTIME_KEY]
  556. new_files = list(set(files_mtime) - set(origin_files_mtime))
  557. removed_files = list(set(origin_files_mtime) - set(files_mtime))
  558. updated_files = []
  559. for file in origin_files_mtime:
  560. if file not in removed_files and \
  561. (origin_files_mtime[file] != files_mtime[file]):
  562. updated_files.append(file)
  563. removed_files.extend(updated_files)
  564. updated_files.extend(new_files)
  565. # remove deleted index
  566. if len(removed_files) > 0:
  567. remove_index_keys = []
  568. remove_requirement_keys = []
  569. for key in index[INDEX_KEY]:
  570. if index[INDEX_KEY][key][FILE_NAME_KEY] in removed_files:
  571. remove_index_keys.append(key)
  572. remove_requirement_keys.append(
  573. index[INDEX_KEY][key][MODULE_KEY])
  574. for key in remove_index_keys:
  575. del index[INDEX_KEY][key]
  576. for key in remove_requirement_keys:
  577. if key in index[REQUIREMENT_KEY]:
  578. del index[REQUIREMENT_KEY][key]
  579. # add new index
  580. updated_index = file_scanner.get_files_scan_results(updated_files)
  581. index[INDEX_KEY].update(updated_index[INDEX_KEY])
  582. index[REQUIREMENT_KEY].update(updated_index[REQUIREMENT_KEY])
  583. def __is_develop_model():
  584. # use the trick of release time check is in development
  585. release_timestamp = int(
  586. round(
  587. datetime.strptime(version.__release_datetime__,
  588. '%Y-%m-%d %H:%M:%S').timestamp()))
  589. SECONDS_PER_YEAR = 24 * 365 * 60 * 60
  590. current_timestamp = int(round(datetime.now().timestamp()))
  591. if release_timestamp > current_timestamp + SECONDS_PER_YEAR:
  592. return True
  593. return False
  594. def load_index(
  595. file_list=None,
  596. force_rebuild=False,
  597. indexer_file_dir=INDEXER_FILE_DIR,
  598. indexer_file=INDEXER_FILE,
  599. ):
  600. """get the index from scan results or cache
  601. Args:
  602. file_list: load indexer only from the file lists if provided, default as None
  603. force_rebuild: If set true, rebuild and load index, default as False,
  604. indexer_file_dir: The dir where the indexer file saved, default as INDEXER_FILE_DIR
  605. indexer_file: The indexer file name, default as INDEXER_FILE
  606. Returns:
  607. dict: the index information for all registered modules, including key:
  608. index, requirements, files last modified time, modelscope home path,
  609. version and md5, the detail is shown below example: {
  610. 'index': {
  611. ('MODELS', 'nlp', 'bert'):{
  612. 'filepath' : 'path/to/the/registered/model', 'imports':
  613. ['os', 'torch', 'typing'] 'module':
  614. 'modelscope.models.nlp.bert'
  615. },
  616. ...
  617. }, 'requirements': {
  618. 'modelscope.models.nlp.bert': ['os', 'torch', 'typing'],
  619. 'modelscope.models.nlp.structbert': ['os', 'torch', 'typing'],
  620. ...
  621. }, 'files_mtime' : {
  622. '/User/Path/To/Your/Modelscope/modelscope/preprocessors/nlp/text_generation_preprocessor.py':
  623. 16554565445, ...
  624. },'version': '0.2.3', 'md5': '8616924970fe6bc119d1562832625612',
  625. 'modelscope_path': '/User/Path/To/Your/Modelscope'
  626. }
  627. """
  628. # env variable override
  629. cache_dir = os.getenv('MODELSCOPE_CACHE', indexer_file_dir)
  630. index_file = os.getenv('MODELSCOPE_INDEX_FILE', indexer_file)
  631. file_path = os.path.join(cache_dir, index_file)
  632. index = None
  633. if force_rebuild:
  634. logger.info('Force rebuilding ast index from scanning every file!')
  635. index = file_scanner.get_files_scan_results(file_list)
  636. return index
  637. # when developing, we need to generator as need.
  638. if __is_develop_model():
  639. logger.info(f'Loading ast index from {file_path}')
  640. if os.path.exists(file_path): # already exist, check it's latest
  641. wrapped_index = _load_index(file_path)
  642. md5, files_mtime = file_scanner.files_mtime_md5(
  643. file_list=file_list)
  644. index = wrapped_index
  645. from modelscope.version import __version__
  646. if (wrapped_index[VERSION_KEY] == __version__
  647. and wrapped_index[MD5_KEY] != md5) or \
  648. wrapped_index[VERSION_KEY] != __version__:
  649. logger.info(
  650. 'Updating the files for the changes of local files, '
  651. 'first time updating will take longer time! Please wait till updating done!'
  652. )
  653. _update_index(index, files_mtime)
  654. _save_index(index, file_path, file_list)
  655. else:
  656. logger.info(
  657. f'No valid ast index found from {file_path}, generating ast index from scratch!'
  658. )
  659. index = file_scanner.get_files_scan_results(
  660. file_list) # generate new
  661. _save_index(index, file_path, file_list) # save to generate path.
  662. logger.info(
  663. f'Loading done! Current index file version is {index[VERSION_KEY]}, '
  664. f'with md5 {index[MD5_KEY]} and a total number of '
  665. f'{len(index[INDEX_KEY])} components indexed')
  666. else: # just load the prebuild index file.
  667. index = load_from_prebuilt()
  668. return index
  669. def load_from_prebuilt(file_path=None):
  670. if file_path is None:
  671. local_path = p.resolve().parents[0]
  672. file_path = os.path.join(local_path, TEMPLATE_FILE)
  673. if os.path.exists(file_path):
  674. index = _load_index(file_path, with_template=True)
  675. else:
  676. index = generate_ast_template()
  677. return index
  678. def generate_ast_template(file_path=None, force_rebuild=True):
  679. index = load_index(force_rebuild=force_rebuild)
  680. if file_path is None:
  681. local_path = p.resolve().parents[0]
  682. file_path = os.path.join(local_path, TEMPLATE_FILE)
  683. _save_index(index, file_path, with_template=True)
  684. if not os.path.exists(file_path):
  685. raise Exception(
  686. 'The index file is not create correctly, please double check')
  687. return index
if __name__ == '__main__':
    # Script entry point: force a full rescan (file_list is None, so the
    # default traversal of get_files_scan_results is used) and print the
    # resulting index. The rebuilt index is not saved here.
    index = load_index(force_rebuild=True)
    print(index)