# Copyright (c) Alibaba, Inc. and its affiliates.
import ast
import hashlib
import logging
import os
import os.path as osp
import time
import traceback
from datetime import datetime
from functools import reduce
from pathlib import Path
from typing import Union

import json

from modelscope import version
# do not delete: the names below are resolved by name at runtime through
# ``eval`` in ``AstScanning._get_registry_value``.
from modelscope.metainfo import (CustomDatasets, Heads, Hooks, LR_Schedulers,
                                 Metrics, Models, Optimizers, Pipelines,
                                 Preprocessors, TaskModels, Trainers)
from modelscope.utils.constant import Fields, Tasks
from modelscope.utils.file_utils import get_modelscope_cache_dir
from modelscope.utils.registry import default_group

p = Path(__file__)

# When True, the bodies of function definitions are not scanned.
SKIP_FUNCTION_SCANNING = True
# get the path of package 'modelscope'
MODELSCOPE_PATH = p.resolve().parents[1]
INDEXER_FILE_DIR = get_modelscope_cache_dir()
REGISTER_MODULE = 'register_module'
IGNORED_PACKAGES = ['modelscope', '.']
SCAN_SUB_FOLDERS = [
    'models', 'metrics', 'pipelines', 'preprocessors', 'trainers',
    'msdatasets', 'exporters'
]
INDEXER_FILE = 'ast_indexer'
DECORATOR_KEY = 'decorators'
EXPRESS_KEY = 'express'
FROM_IMPORT_KEY = 'from_imports'
IMPORT_KEY = 'imports'
FILE_NAME_KEY = 'filepath'
MODELSCOPE_PATH_KEY = 'modelscope_path'
VERSION_KEY = 'version'
MD5_KEY = 'md5'
INDEX_KEY = 'index'
FILES_MTIME_KEY = 'files_mtime'
REQUIREMENT_KEY = 'requirements'
MODULE_KEY = 'module'
CLASS_NAME = 'class_name'
GROUP_KEY = 'group_key'
MODULE_NAME = 'module_name'
MODULE_CLS = 'module_cls'
TEMPLATE_PATH = 'TEMPLATE_PATH'
TEMPLATE_FILE = 'ast_index_file.py'


def get_ast_logger():
    """Return the 'modelscope.ast' logger with its level set to INFO."""
    ast_logger = logging.getLogger('modelscope.ast')
    ast_logger.setLevel(logging.INFO)
    return ast_logger


logger = get_ast_logger()


class AstScanning(object):
    """Scan one python source file through its AST.

    The scanner collects the ``import`` / ``from ... import`` statements,
    the class decorators and the bare call expressions of a module, and
    turns the ``register_module`` calls among them into registry index
    tuples.
    """

    def __init__(self) -> None:
        self.result_import = dict()
        self.result_from_import = dict()
        self.result_decorator = []
        # ``scan_import`` appends to ``result_express``; it has to exist
        # even when ``_refresh`` was never called.
        self.result_express = []
        # Kept for backward compatibility; not used inside this file.
        self.express = []

    def _is_sub_node(self, node: object) -> bool:
        """Tell whether ``node`` is an AST node other than an expr context."""
        return isinstance(node,
                          ast.AST) and not isinstance(node, ast.expr_context)

    def _is_leaf(self, node: ast.AST) -> bool:
        """Tell whether ``node`` has no AST child in any of its fields."""
        for field in node._fields:
            attr = getattr(node, field)
            if self._is_sub_node(attr):
                return False
            elif isinstance(attr, (list, tuple)):
                for val in attr:
                    if self._is_sub_node(val):
                        return False
        return True

    def _skip_function(self, node: Union[ast.AST, 'str']) -> bool:
        """Tell whether ``node`` is a function definition to be skipped.

        ``node`` may be an AST node or the name of a node type.
        """
        if SKIP_FUNCTION_SCANNING:
            if type(node).__name__ == 'FunctionDef' or node == 'FunctionDef':
                return True
        return False

    def _fields(self, n: ast.AST, show_offsets: bool = True) -> tuple:
        """Return the field names of ``n``, optionally with its attributes."""
        if show_offsets:
            return n._attributes + n._fields
        else:
            return n._fields

    def _leaf(self, node: ast.AST, show_offsets: bool = True):
        """Convert a leaf node to ``{type name: {field: value}}``.

        Values that are not AST nodes are returned unchanged.
        """
        output = dict()
        if isinstance(node, ast.AST):
            local_dict = dict()
            for field in self._fields(node, show_offsets=show_offsets):
                field_output = self._leaf(
                    getattr(node, field), show_offsets=show_offsets)
                local_dict[field] = field_output
            output[type(node).__name__] = local_dict
            return output
        else:
            return node

    def _refresh(self):
        """Reset every result container before scanning a new file."""
        self.result_import = dict()
        self.result_from_import = dict()
        self.result_decorator = []
        self.result_express = []

    def scan_ast(self, node: Union[ast.AST, None, str]):
        """Reset the scanner state and scan ``node``.

        Returns:
            The output of ``scan_import`` for ``node``.
        """
        # The previous implementation called a method that does not exist
        # and passed an ``indent`` argument ``scan_import`` does not accept,
        # so every call failed.
        self._refresh()
        return self.scan_import(node, show_offsets=False)

    def scan_import(
        self,
        node: Union[ast.AST, None, str],
        show_offsets: bool = True,
        parent_node_name: str = '',
    ) -> Union[dict, None]:
        """Recursively walk ``node`` and collect imports, decorators and calls.

        Args:
            node: the AST node to scan.
            show_offsets: whether node attributes are included in the fields
                that get visited.
            parent_node_name: the type name of the node owning ``node``.

        Returns:
            None when ``node`` is None, the ``_leaf`` output for a leaf node,
            otherwise a dict holding the accumulated imports, from-imports,
            decorators and call expressions.
        """
        if node is None:
            return node
        elif self._is_leaf(node):
            return self._leaf(node, show_offsets=show_offsets)
        else:

            def _scan_import(el: Union[ast.AST, None, str],
                             parent_node_name: str = '') -> str:
                return self.scan_import(
                    el,
                    show_offsets=show_offsets,
                    parent_node_name=parent_node_name)

            outputs = dict()
            # add relative path expression: fold the relative level into the
            # module name, e.g. level 2 + 'a.b' becomes '..a.b'
            if type(node).__name__ == 'ImportFrom':
                level = getattr(node, 'level')
                if level >= 1:
                    path_level = ''.join(['.'] * level)
                    setattr(node, 'level', 0)
                    module_name = getattr(node, 'module')
                    if module_name is None:
                        setattr(node, 'module', path_level)
                    else:
                        setattr(node, 'module', path_level + module_name)
            for field in self._fields(node, show_offsets=show_offsets):
                attr = getattr(node, field)
                if attr == []:
                    outputs[field] = []
                elif self._skip_function(parent_node_name):
                    continue
                elif (isinstance(attr, list) and len(attr) == 1
                      and isinstance(attr[0], ast.AST)
                      and self._is_leaf(attr[0])):
                    local_out = _scan_import(attr[0])
                    outputs[field] = local_out
                elif isinstance(attr, list):
                    el_dict = dict()
                    for el in attr:
                        local_out = _scan_import(el, type(el).__name__)
                        name = type(el).__name__
                        if (name == 'Import' or name == 'ImportFrom'
                                or parent_node_name == 'ImportFrom'
                                or parent_node_name == 'Import'):
                            if name not in el_dict:
                                el_dict[name] = []
                            el_dict[name].append(local_out)
                    outputs[field] = el_dict
                elif isinstance(attr, ast.AST):
                    output = _scan_import(attr)
                    outputs[field] = output
                else:
                    outputs[field] = attr

                if (type(node).__name__ == 'Import'
                        or type(node).__name__ == 'ImportFrom'):
                    if type(node).__name__ == 'ImportFrom':
                        if field == 'module':
                            self.result_from_import[outputs[field]] = dict()
                        if field == 'names':
                            if isinstance(outputs[field]['alias'], list):
                                item_name = []
                                for item in outputs[field]['alias']:
                                    local_name = item['alias']['name']
                                    item_name.append(local_name)
                                self.result_from_import[
                                    outputs['module']] = item_name
                            else:
                                local_name = outputs[field]['alias']['name']
                                self.result_from_import[outputs['module']] = [
                                    local_name
                                ]

                    if type(node).__name__ == 'Import':
                        final_dict = outputs[field]['alias']
                        if isinstance(final_dict, list):
                            for item in final_dict:
                                self.result_import[
                                    item['alias']['name']] = item['alias']
                        else:
                            self.result_import[
                                outputs[field]['alias']['name']] = final_dict

                if 'decorator_list' == field and attr != []:
                    # remember which definition each decorator belongs to
                    for item in attr:
                        setattr(item, CLASS_NAME, node.name)
                    self.result_decorator.extend(attr)

                if (attr != [] and type(attr).__name__ == 'Call'
                        and parent_node_name == 'Expr'):
                    self.result_express.append(attr)

        return {
            IMPORT_KEY: self.result_import,
            FROM_IMPORT_KEY: self.result_from_import,
            DECORATOR_KEY: self.result_decorator,
            EXPRESS_KEY: self.result_express
        }

    def _parse_decorator(self, node: ast.AST) -> tuple:
        """Split a call node into its function, positional and keyword parts.

        Returns:
            tuple: ``(functions, args_list, keyword_list)`` where
            ``functions`` is an ``(id, attr)`` pair, ``args_list`` a list of
            ``(value, attr)`` pairs and ``keyword_list`` a list of
            ``(keyword, value, attr)`` triples.
        """

        def _get_attribute_item(node: ast.AST) -> tuple:
            name_id, attr = None, None
            if type(node).__name__ == 'Attribute':
                value = getattr(node, 'value')
                name_id = getattr(value, 'id')
                attr = getattr(node, 'attr')
            if type(node).__name__ == 'Name':
                name_id = getattr(node, 'id')
            return name_id, attr

        def _get_args_name(nodes: list) -> list:
            result = []
            for node in nodes:
                if type(node).__name__ == 'Str':
                    result.append((node.s, None))
                elif type(node).__name__ == 'Constant':
                    result.append((node.value, None))
                else:
                    result.append(_get_attribute_item(node))
            return result

        def _get_keyword_name(nodes: ast.AST) -> list:
            result = []
            for node in nodes:
                if type(node).__name__ == 'keyword':
                    attribute_node = getattr(node, 'value')
                    if type(attribute_node).__name__ == 'Str':
                        result.append((getattr(node,
                                               'arg'), attribute_node.s, None))
                    elif type(attribute_node).__name__ == 'Constant':
                        result.append(
                            (getattr(node, 'arg'), attribute_node.value, None))
                    else:
                        result.append((getattr(node, 'arg'), )
                                      + _get_attribute_item(attribute_node))
            return result

        functions = _get_attribute_item(node.func)
        args_list = _get_args_name(node.args)
        keyword_list = _get_keyword_name(node.keywords)
        return functions, args_list, keyword_list

    def _get_registry_value(self, key_item):
        """Resolve a dotted name such as ``Tasks.xxx`` to its value.

        None is returned for None, ``default_group`` for the literal
        'default_group', and a name without a dot is returned unchanged.
        """
        if key_item is None:
            return None
        if key_item == 'default_group':
            return default_group
        split_list = key_item.split('.')
        # in the case, the key_item is raw data, not registered
        if len(split_list) == 1:
            return key_item
        else:
            # NOTE(review): ``eval`` resolves a name taken from the scanned
            # source file; this is only acceptable while the scanned files
            # are trusted.
            return getattr(eval(split_list[0]), split_list[1])

    def _registry_indexer(self, parsed_input: tuple, class_name: str) -> tuple:
        """format registry information to a tuple indexer

        Args:
            parsed_input (tuple): the output of ``_parse_decorator``.
            class_name (str): the name of the decorated definition, used as
                the module name when the call only gives a group.

        Return:
            tuple: (MODELS, Tasks.text-classification, Models.structbert),
            or None when the call is not a ``register_module`` call.
        """
        functions, args_list, keyword_list = parsed_input

        # ignore decorators other than register_module
        if REGISTER_MODULE != functions[1]:
            return None
        output = [functions[0]]

        if len(args_list) == 0 and len(keyword_list) == 0:
            args_list.append(default_group)
        if len(keyword_list) == 0 and len(args_list) == 1:
            args_list.append(class_name)

        if len(keyword_list) > 0 and len(args_list) == 0:
            # the group is given as a keyword: move it to the first position
            remove_group_item = None
            for item in keyword_list:
                key, name, attr = item
                if key == GROUP_KEY:
                    args_list.append((name, attr))
                    remove_group_item = item
            if remove_group_item is not None:
                keyword_list.remove(remove_group_item)

            if len(args_list) == 0:
                args_list.append(default_group)

        for item in keyword_list:
            key, name, attr = item
            if key == MODULE_CLS:
                class_name = name
            else:
                args_list.append((name, attr))

        for item in args_list:
            # the case empty input
            if item is None:
                output.append(None)
            # plain strings (default group, class name) are tested before
            # any indexing, so one-character names do not raise IndexError
            elif isinstance(item, str):
                output.append(item)
            # the case (value, None)
            elif item[1] is None:
                output.append(item[0])
            else:
                output.append('.'.join(item))
        return (output[0], self._get_registry_value(output[1]),
                self._get_registry_value(output[2]))

    def parse_decorators(self, nodes: list) -> list:
        """parse the AST nodes of decorators object to registry indexer

        Args:
            nodes (list): list of AST decorator nodes

        Returns:
            list: list of registry indexer
        """
        results = []
        for node in nodes:
            if type(node).__name__ != 'Call':
                continue
            class_name = getattr(node, CLASS_NAME, None)
            func = getattr(node, 'func')

            if getattr(func, 'attr', None) != REGISTER_MODULE:
                continue

            parse_output = self._parse_decorator(node)
            index = self._registry_indexer(parse_output, class_name)
            if index is not None:
                results.append(index)
        return results

    def generate_ast(self, file):
        """Scan ``file`` and return its imports and registry indexers.

        The indexers found in bare call expressions are appended to the ones
        found in decorators.
        """
        self._refresh()
        with open(file, 'r', encoding='utf8') as code:
            data = code.read()

        node = ast.parse(data)
        output = self.scan_import(node, show_offsets=False)
        output[DECORATOR_KEY] = self.parse_decorators(output[DECORATOR_KEY])
        output[EXPRESS_KEY] = self.parse_decorators(output[EXPRESS_KEY])
        output[DECORATOR_KEY].extend(output[EXPRESS_KEY])
        return output


class FilesAstScanning(object):
    """Scan a set of python files and build the registry index from them."""

    def __init__(self) -> None:
        self.astScaner = AstScanning()
        self.file_dirs = []
        self.requirement_dirs = []

    def _parse_import_path(self,
                           import_package: str,
                           current_path: str = None) -> str:
        """
        Args:
            import_package (str): relative import or abs import
            current_path (str): path/to/current/file
        """
        if import_package.startswith(IGNORED_PACKAGES[0]):
            # MODELSCOPE_PATH is a ``Path``; it cannot be concatenated to a
            # ``str`` with ``+``, so convert it first.
            return MODELSCOPE_PATH.as_posix() + '/' + '/'.join(
                import_package.split('.')[1:]) + '.py'
        elif import_package.startswith(IGNORED_PACKAGES[1]):
            current_path_list = current_path.split('/')
            import_package_list = import_package.split('.')
            level = 0
            for index, item in enumerate(import_package_list):
                if item != '':
                    level = index
                    break

            abs_path_list = current_path_list[0:-level]
            abs_path_list.extend(import_package_list[index:])
            return '/' + '/'.join(abs_path_list) + '.py'
        else:
            return current_path

    def _traversal_import(
        self,
        import_abs_path,
    ):
        pass

    def parse_import(self, scan_result: dict) -> list:
        """parse import and from import dicts to a third party package list

        Args:
            scan_result (dict): including the import and from import result

        Returns:
            list: a list of package ignored 'modelscope' and relative path
            import
        """
        output = []
        output.extend(list(scan_result[IMPORT_KEY].keys()))
        output.extend(list(scan_result[FROM_IMPORT_KEY].keys()))

        # get the package name
        for index, item in enumerate(output):
            if '' == item.split('.')[0]:
                output[index] = '.'
            else:
                output[index] = item.split('.')[0]

        ignored = set()
        for item in output:
            for ignored_package in IGNORED_PACKAGES:
                if item.startswith(ignored_package):
                    ignored.add(item)
        return list(set(output) - set(ignored))

    def traversal_files(self, path, check_sub_dir=None, include_init=False):
        """Collect the python files under ``path`` into ``self.file_dirs``.

        Args:
            path: the root directory.
            check_sub_dir: when given and not empty, only these sub
                directories of ``path`` are walked.
            include_init: whether ``__init__.py`` files are collected.
        """
        self.file_dirs = []
        if check_sub_dir is None or len(check_sub_dir) == 0:
            self._traversal_files(path, include_init=include_init)
        else:
            for item in check_sub_dir:
                sub_dir = os.path.join(path, item)
                if os.path.isdir(sub_dir):
                    self._traversal_files(sub_dir, include_init=include_init)

    def _traversal_files(self, path, include_init=False):
        """Recursively collect python and requirement files under ``path``."""
        # the context manager closes the directory iterator deterministically
        with os.scandir(path) as dir_list:
            for item in dir_list:
                if item.name == '__init__.py' and not include_init:
                    continue
                elif (item.name.startswith('__')
                      and item.name != '__init__.py') or item.name.endswith(
                          '.json') or item.name.endswith('.md'):
                    continue
                if item.is_dir():
                    self._traversal_files(
                        item.path, include_init=include_init)
                elif item.is_file() and item.name.endswith('.py'):
                    self.file_dirs.append(item.path)
                elif item.is_file() and 'requirement' in item.name:
                    self.requirement_dirs.append(item.path)

    def _get_single_file_scan_result(self, file):
        """Scan one file; return its registry indexers and imported packages.

        Raises:
            Exception: when the scan fails; the message names the file and
                the location of the failure, and the original exception is
                chained as the cause.
        """
        try:
            output = self.astScaner.generate_ast(file)
        except Exception as e:
            detail = traceback.extract_tb(e.__traceback__)
            raise Exception(
                f'During ast indexing the file {file}, a related error excepted '
                f'in the file {detail[-1].filename} at line: '
                f'{detail[-1].lineno}: "{detail[-1].line}" with error msg: '
                f'"{type(e).__name__}: {e}", please double check the origin file {file} '
                f'to see whether the file is correctly edited.') from e

        import_list = self.parse_import(output)
        return output[DECORATOR_KEY], import_list

    def _inverted_index(self, forward_index):
        """Turn ``{file: info}`` into ``{registry indexer: info}``."""
        inverted_index = dict()
        for index in forward_index:
            for item in forward_index[index][DECORATOR_KEY]:
                inverted_index[item] = {
                    FILE_NAME_KEY: index,
                    IMPORT_KEY: forward_index[index][IMPORT_KEY],
                    MODULE_KEY: forward_index[index][MODULE_KEY],
                }
        return inverted_index

    def _module_import(self, forward_index):
        """Map every scanned module name to its imported packages."""
        module_import = dict()
        for index, value_dict in forward_index.items():
            module_import[value_dict[MODULE_KEY]] = value_dict[IMPORT_KEY]
        return module_import

    def _ignore_useless_keys(self, inverted_index):
        """Drop two fixed placeholder keys from ``inverted_index`` if present."""
        if ('OPTIMIZERS', 'default', 'name') in inverted_index:
            del inverted_index[('OPTIMIZERS', 'default', 'name')]
        if ('LR_SCHEDULER', 'default', 'name') in inverted_index:
            del inverted_index[('LR_SCHEDULER', 'default', 'name')]
        return inverted_index

    def get_files_scan_results(self,
                               target_file_list=None,
                               target_dir=MODELSCOPE_PATH,
                               target_folders=SCAN_SUB_FOLDERS):
        """the entry method of the ast scan method

        Args:
            target_file_list: can override the dir and folders combine
            target_dir (str, optional): the absolute path of the target
                directory to be scanned. Defaults to MODELSCOPE_PATH.
            target_folders (list, optional): the list of
                sub-folders to be scanned in the target folder.
                Defaults to SCAN_SUB_FOLDERS.

        Returns:
            dict: indexer of registry
        """
        start = time.time()
        if target_file_list is not None:
            self.file_dirs = target_file_list
        else:
            self.traversal_files(target_dir, target_folders)
        logger.info(
            f'AST-Scanning the path "{target_dir}" with the following sub folders {target_folders}'
        )

        result = dict()
        for file in self.file_dirs:
            filepath = file[file.rfind('modelscope'):]
            # Only strip the trailing '.py'. Replacing every '.py' would also
            # damage a package whose name starts with 'py', once the path
            # separators have become dots.
            if filepath.endswith('.py'):
                filepath = filepath[:-len('.py')]
            module_name = filepath.replace(osp.sep, '.')
            decorator_list, import_list = self._get_single_file_scan_result(
                file)
            result[file] = {
                DECORATOR_KEY: decorator_list,
                IMPORT_KEY: import_list,
                MODULE_KEY: module_name
            }
        inverted_index_with_results = self._inverted_index(result)
        inverted_index_with_results = self._ignore_useless_keys(
            inverted_index_with_results)
        module_import = self._module_import(result)
        index = {
            INDEX_KEY: inverted_index_with_results,
            REQUIREMENT_KEY: module_import
        }
        logger.info(
            f'Scanning done! A number of {len(inverted_index_with_results)} '
            f'components indexed or updated! Time consumed {time.time()-start}s'
        )
        return index

    def files_mtime_md5(self,
                        target_path=MODELSCOPE_PATH,
                        target_subfolder=SCAN_SUB_FOLDERS,
                        file_list=None):
        """Fingerprint the scanned files through their modification times.

        Args:
            target_path: the root directory walked when ``file_list`` is not
                a non-empty list.
            target_subfolder: the sub directories of ``target_path`` to walk.
            file_list: an explicit list of files used instead of walking.

        Returns:
            tuple: the md5 hex digest of the concatenated modification times
            and a dict mapping every file to its modification time.
        """
        self.file_dirs = []
        if file_list and isinstance(file_list, list):
            self.file_dirs = file_list
        else:
            self.traversal_files(target_path, target_subfolder)
        files_mtime = []
        files_mtime_dict = dict()
        for item in self.file_dirs:
            mtime = os.path.getmtime(item)
            files_mtime.append(mtime)
            files_mtime_dict[item] = mtime
        # same string as repeated concatenation, built in a single pass
        result_str = ''.join(str(mtime) for mtime in files_mtime)
        md5 = hashlib.md5(result_str.encode())
        return md5.hexdigest(), files_mtime_dict


file_scanner = FilesAstScanning()


def ensure_write(obj: bytes, filepath: Union[str, Path]) -> None:
    """Write data to a given ``filepath`` with 'wb' mode.

    Note:
        ``write`` will create a directory if the directory of ``filepath``
        does not exist.

    Args:
        obj (bytes): Data to be written.
        filepath (str or Path): Path to write data.
    """
    dirname = os.path.dirname(filepath)
    if dirname and not os.path.exists(dirname):
        os.makedirs(dirname, exist_ok=True)
    with open(filepath, 'wb') as f:
        f.write(obj)


def _save_index(index, file_path, file_list=None, with_template=False):
    """Complete ``index`` with its metadata and write it as json.

    ``index`` receives the version, md5, files mtime and modelscope path
    entries in place.

    Args:
        index (dict): the index to save.
        file_path: where the json is written.
        file_list: forwarded to ``files_mtime_md5``.
        with_template (bool): replace the modelscope path by TEMPLATE_PATH
            in the written text.
    """
    from modelscope.version import __version__
    index[VERSION_KEY] = __version__
    index[MD5_KEY], index[FILES_MTIME_KEY] = file_scanner.files_mtime_md5(
        file_list=file_list)
    index[MODELSCOPE_PATH_KEY] = MODELSCOPE_PATH.as_posix()
    # convert tuple key to str key on a shallow copy, so that the caller's
    # index keeps its tuple keys even when dumping or writing fails
    serializable = dict(index)
    serializable[INDEX_KEY] = {
        str(k): v
        for k, v in index[INDEX_KEY].items()
    }
    json_index = json.dumps(serializable)
    if with_template:
        json_index = json_index.replace(MODELSCOPE_PATH.as_posix(),
                                        TEMPLATE_PATH)
    ensure_write(json_index.encode(), file_path)


def _load_index(file_path, with_template=False):
    """Read an index written by ``_save_index``.

    Args:
        file_path: the json file to read.
        with_template (bool): replace TEMPLATE_PATH by the modelscope path
            before parsing.

    Returns:
        dict: the index, with the keys of its 'index' entry converted back
        to tuples.
    """
    with open(file_path, 'rb') as f:
        bytes_index = f.read()
    if with_template:
        bytes_index = bytes_index.decode().replace(TEMPLATE_PATH,
                                                   MODELSCOPE_PATH.as_posix())
    wrapped_index = json.loads(bytes_index)
    # convert str key to tuple key
    wrapped_index[INDEX_KEY] = {
        ast.literal_eval(k): v
        for k, v in wrapped_index[INDEX_KEY].items()
    }
    return wrapped_index


def _update_index(index, files_mtime):
    """Update ``index`` in place from the current modification times.

    Entries of deleted or modified files are removed, then the modified and
    the new files are scanned again and merged into ``index``.

    Args:
        index (dict): a loaded index, holding the previous modification
            times under FILES_MTIME_KEY.
        files_mtime (dict): the current ``{file: mtime}`` mapping.
    """
    origin_files_mtime = index[FILES_MTIME_KEY]
    new_files = list(set(files_mtime) - set(origin_files_mtime))
    deleted_files = set(origin_files_mtime) - set(files_mtime)
    updated_files = [
        file for file, mtime in origin_files_mtime.items()
        if file not in deleted_files and mtime != files_mtime[file]
    ]
    # entries of deleted and of modified files are both out of date
    stale_files = deleted_files | set(updated_files)
    updated_files.extend(new_files)

    # remove deleted index
    if stale_files:
        stale_keys = [
            key for key, value in index[INDEX_KEY].items()
            if value[FILE_NAME_KEY] in stale_files
        ]
        for key in stale_keys:
            module = index[INDEX_KEY].pop(key)[MODULE_KEY]
            index[REQUIREMENT_KEY].pop(module, None)

    # add new index
    updated_index = file_scanner.get_files_scan_results(updated_files)
    index[INDEX_KEY].update(updated_index[INDEX_KEY])
    index[REQUIREMENT_KEY].update(updated_index[REQUIREMENT_KEY])


def __is_develop_model():
    """Tell whether the package looks like a development checkout.

    It is the case when the release datetime declared by ``version`` lies
    more than one year in the future.
    """
    # use the trick of release time check is in development
    release_timestamp = int(
        round(
            datetime.strptime(version.__release_datetime__,
                              '%Y-%m-%d %H:%M:%S').timestamp()))
    SECONDS_PER_YEAR = 24 * 365 * 60 * 60
    current_timestamp = int(round(datetime.now().timestamp()))
    if release_timestamp > current_timestamp + SECONDS_PER_YEAR:
        return True
    return False


def load_index(
    file_list=None,
    force_rebuild=False,
    indexer_file_dir=INDEXER_FILE_DIR,
    indexer_file=INDEXER_FILE,
):
    """get the index from scan results or cache

    Args:
        file_list: load indexer only from the file lists if provided, default
            as None
        force_rebuild: If set true, rebuild and load index, default as False,
        indexer_file_dir: The dir where the indexer file saved, default as
            INDEXER_FILE_DIR
        indexer_file: The indexer file name, default as INDEXER_FILE
    Returns:
        dict: the index information for all registered modules, including key:
        index, requirements, files last modified time, modelscope home path,
        version and md5, the detail is shown below example: {
            'index': {
                ('MODELS', 'nlp', 'bert'):{
                    'filepath' : 'path/to/the/registered/model', 'imports':
                    ['os', 'torch', 'typing'] 'module':
                    'modelscope.models.nlp.bert'
                },
                ...
            }, 'requirements': {
                'modelscope.models.nlp.bert': ['os', 'torch', 'typing'],
                'modelscope.models.nlp.structbert': ['os', 'torch', 'typing'],
                ...
            }, 'files_mtime' : {
                '/User/Path/To/Your/Modelscope/modelscope/preprocessors/nlp/text_generation_preprocessor.py':
                16554565445, ...
            },'version': '0.2.3', 'md5': '8616924970fe6bc119d1562832625612',
            'modelscope_path': '/User/Path/To/Your/Modelscope'
        }
    """
    # env variable override
    cache_dir = os.getenv('MODELSCOPE_CACHE', indexer_file_dir)
    index_file = os.getenv('MODELSCOPE_INDEX_FILE', indexer_file)
    file_path = os.path.join(cache_dir, index_file)
    index = None
    if force_rebuild:
        logger.info('Force rebuilding ast index from scanning every file!')
        index = file_scanner.get_files_scan_results(file_list)
        return index

    # when developing, we need to generator as need.
    if __is_develop_model():
        logger.info(f'Loading ast index from {file_path}')
        if os.path.exists(file_path):
            # already exist, check it's latest
            wrapped_index = _load_index(file_path)
            md5, files_mtime = file_scanner.files_mtime_md5(
                file_list=file_list)
            index = wrapped_index
            from modelscope.version import __version__
            # out of date when either the version or the md5 differs
            if (wrapped_index[VERSION_KEY] != __version__
                    or wrapped_index[MD5_KEY] != md5):
                logger.info(
                    'Updating the files for the changes of local files, '
                    'first time updating will take longer time! Please wait till updating done!'
                )
                _update_index(index, files_mtime)
                _save_index(index, file_path, file_list)
        else:
            logger.info(
                f'No valid ast index found from {file_path}, generating ast index from scratch!'
            )
            index = file_scanner.get_files_scan_results(
                file_list)  # generate new
            _save_index(index, file_path,
                        file_list)  # save to generate path.
        logger.info(
            f'Loading done! Current index file version is {index[VERSION_KEY]}, '
            f'with md5 {index[MD5_KEY]} and a total number of '
            f'{len(index[INDEX_KEY])} components indexed')
    else:
        # just load the prebuild index file.
        index = load_from_prebuilt()
    return index


def load_from_prebuilt(file_path=None):
    """Load the prebuilt index, generating it when the file is missing.

    Args:
        file_path: the prebuilt index file; defaults to TEMPLATE_FILE in the
            directory of this module.
    """
    if file_path is None:
        local_path = p.resolve().parents[0]
        file_path = os.path.join(local_path, TEMPLATE_FILE)
    if os.path.exists(file_path):
        index = _load_index(file_path, with_template=True)
    else:
        index = generate_ast_template()
    return index


def generate_ast_template(file_path=None, force_rebuild=True):
    """Build the index and save it as a template file.

    Args:
        file_path: where the template is written; defaults to TEMPLATE_FILE
            in the directory of this module.
        force_rebuild: forwarded to ``load_index``.

    Raises:
        Exception: when the template file does not exist after saving.
    """
    index = load_index(force_rebuild=force_rebuild)
    if file_path is None:
        local_path = p.resolve().parents[0]
        file_path = os.path.join(local_path, TEMPLATE_FILE)
    _save_index(index, file_path, with_template=True)
    if not os.path.exists(file_path):
        raise Exception(
            'The index file is not create correctly, please double check')
    return index


if __name__ == '__main__':
    index = load_index(force_rebuild=True)
    print(index)