registry.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import importlib
  3. import inspect
  4. from typing import List, Tuple, Union
  5. from modelscope.utils.logger import get_logger
  # Module-level logger used by the registry code in this file.
  logger = get_logger()
  # Config key whose value names (or directly is) the object build_from_cfg builds.
  TYPE_NAME = 'type'
  # Group that modules are registered to / looked up in when no group is given.
  default_group = 'default'
  class Registry(object):
      """Registry that maps names to modules, organized into named groups.

      Each group is a dict from module name to the registered object (a class
      or a function). If a group name is not provided, modules are registered
      to the default group.
      """

      def __init__(self, name: str):
          """Create an empty registry called ``name``.

          Args:
              name: Registry name, shown in ``repr`` and in error messages.
          """
          self._name = name
          # group_key -> {module_name: module}; the default group always exists.
          self._modules = {default_group: {}}

      def __repr__(self):
          """Return the registry name followed by one line per group listing
          the names of the modules registered in it."""
          format_str = self.__class__.__name__ + f' ({self._name})\n'
          for group_name, group in self._modules.items():
              format_str += (f'group_name={group_name}, '
                             f'modules={list(group.keys())}\n')
          return format_str

      @property
      def name(self):
          """The registry name given at construction."""
          return self._name

      @property
      def modules(self):
          """The underlying ``{group_key: {module_name: module}}`` mapping.

          The internal dict is returned directly, not a copy.
          """
          return self._modules

      def list(self):
          """Log every group and the names of the modules registered in it."""
          for group_name, group in self._modules.items():
              logger.info(f'group_name={group_name}')
              for m in group.keys():
                  logger.info(f'\t{m}')
              logger.info('')

      def get(self, module_key, group_key=default_group):
          """Look up a registered module.

          Args:
              module_key: Name the module was registered under.
              group_key: Group to search; defaults to the default group.

          Returns:
              The registered object, or ``None`` if the group or the name is
              unknown.
          """
          if group_key not in self._modules:
              return None
          return self._modules[group_key].get(module_key, None)

      def _register_module(self,
                           group_key=default_group,
                           module_name=None,
                           module_cls=None,
                           force=False):
          """Insert ``module_cls`` into group ``group_key`` under ``module_name``.

          ``module_name`` defaults to ``module_cls.__name__``. The registered
          object is also tagged with a ``group_key`` attribute.

          Raises:
              TypeError: If ``group_key`` is not a str.
              KeyError: If the name is already registered in the group and
                  ``force`` is False.
          """
          # Raise explicitly instead of using ``assert`` so the check is not
          # stripped when Python runs with -O.
          if not isinstance(group_key, str):
              raise TypeError('group_key is required and must be str')
          # Registered objects may be functions as well as classes, so no
          # ``inspect.isclass`` check is performed.
          if module_name is None:
              module_name = module_cls.__name__
          # Create the group only after the name is resolved, so a failed
          # ``__name__`` lookup does not leave an empty group behind.
          group = self._modules.setdefault(group_key, {})
          if module_name in group and not force:
              raise KeyError(f'{module_name} is already registered in '
                             f'{self._name}[{group_key}]')
          group[module_name] = module_cls
          module_cls.group_key = group_key

      def register_module(self,
                          group_key: str = default_group,
                          module_name: str = None,
                          module_cls: type = None,
                          force=False):
          """Register a module, directly or as a decorator.

          When ``module_cls`` is given it is registered immediately. Otherwise
          a decorator is returned that registers the decorated object.

          Example:
              >>> models = Registry('models')
              >>> @models.register_module('image-classification', 'SwinT')
              ... class SwinTransformer:
              ...     pass
              >>> @models.register_module(module_name='SwinDefault')
              ... class SwinTransformerDefaultGroup:
              ...     pass
              >>> class SwinTransformer2:
              ...     pass
              >>> models.register_module('image-classification',
              ...                        module_name='SwinT2',
              ...                        module_cls=SwinTransformer2)

          Args:
              group_key: Group the module is registered to; the default group
                  name is 'default'. Note that it is the FIRST positional
                  argument.
              module_name: Module name; defaults to ``module_cls.__name__``.
              module_cls: Module class (or function) object. If None, a
                  decorator is returned instead.
              force (bool, optional): Whether to override an existing module
                  with the same name. Default: False.

          Returns:
              ``module_cls`` when it is given, otherwise a decorator that
              registers and returns the decorated object.

          Raises:
              TypeError: If ``module_name`` is neither None nor a str, or if
                  ``group_key`` is not a str (checked at registration time).
              KeyError: If the name is already registered and ``force`` is
                  False.
          """
          if not (module_name is None or isinstance(module_name, str)):
              raise TypeError(f'module_name must be either of None, str, '
                              f'got {type(module_name)}')
          if module_cls is not None:
              self._register_module(
                  group_key=group_key,
                  module_name=module_name,
                  module_cls=module_cls,
                  force=force)
              return module_cls

          # module_cls is None: return a decorator that performs the
          # registration when applied.
          def _register(module_cls):
              self._register_module(
                  group_key=group_key,
                  module_name=module_name,
                  module_cls=module_cls,
                  force=force)
              return module_cls

          return _register
  def build_from_cfg(cfg,
                     registry: Registry,
                     group_key: str = default_group,
                     default_args: dict = None) -> object:
      """Build a module from config dict when it is a class configuration, or
      call a function from config dict when it is a function configuration.

      Example:
          >>> models = Registry('models')
          >>> @models.register_module('image-classification', 'SwinT')
          ... class SwinTransformer:
          ...     pass
          >>> swint = build_from_cfg(dict(type='SwinT'), models,
          ...                        'image-classification')
          >>> # Returns an instantiated object
          >>>
          >>> @models.register_module()
          ... def swin_transformer():
          ...     pass
          >>> result = build_from_cfg(dict(type='swin_transformer'), models)
          >>> # Returns the result of calling the function

      Args:
          cfg (dict): Config dict. It must contain the key "type" unless
              ``default_args`` supplies it. The value is either a registered
              name (str) or a class/function used directly.
          registry (:obj:`Registry`): The registry to search the type from.
          group_key (str, optional): The name of the registry group in which
              the module is searched. ``None`` means the default group.
          default_args (dict, optional): Default initialization arguments;
              keys already present in ``cfg`` take precedence.

      Returns:
          object: The constructed object (via the class's ``_instantiate``
          hook when it defines one), or the return value of the function.

      Raises:
          TypeError: If ``cfg``, ``registry`` or ``default_args`` has the
              wrong type, or the type value is not a str, class or function.
          KeyError: If no type is given, or the named type is not registered
              in the group.
      """
      # Validate argument types before using them (``default_args`` is probed
      # with ``in`` below, so it must be checked first).
      if not isinstance(cfg, dict):
          raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
      if not isinstance(registry, Registry):
          raise TypeError('registry must be an modelscope.Registry object, '
                          f'but got {type(registry)}')
      if not (isinstance(default_args, dict) or default_args is None):
          raise TypeError('default_args must be a dict or None, '
                          f'but got {type(default_args)}')
      if TYPE_NAME not in cfg and (default_args is None
                                   or TYPE_NAME not in default_args):
          raise KeyError(
              f'`cfg` or `default_args` must contain the key "{TYPE_NAME}", '
              f'but got {cfg}\n{default_args}')
      # Normalize before the lazy-import signature is built so both the lazy
      # import and the registry lookup use the same group.
      if group_key is None:
          group_key = default_group

      args = cfg.copy()
      if default_args is not None:
          for name, value in default_args.items():
              args.setdefault(name, value)
      # The type may come from default_args, so read it from the merged args
      # rather than from cfg directly.
      obj_type = args.pop(TYPE_NAME)

      # dynamic load installation requirements for this module
      from modelscope.utils.import_utils import LazyImportModule
      sig = (registry.name.upper(), group_key, obj_type)
      LazyImportModule.import_module(sig)

      if isinstance(obj_type, str):
          obj_cls = registry.get(obj_type, group_key=group_key)
          if obj_cls is None:
              raise KeyError(
                  f'{obj_type} is not in the {registry.name}'
                  f' registry group {group_key}. Please make'
                  f' sure the correct version of ModelScope library is used.')
          obj_cls.group_key = group_key
      elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
          obj_cls = obj_type
      else:
          raise TypeError(
              f'type must be a str or valid type, but got {type(obj_type)}')

      try:
          if hasattr(obj_cls, '_instantiate'):
              return obj_cls._instantiate(**args)
          return obj_cls(**args)
      except Exception as e:
          # Normal TypeError does not print class name, so re-raise the same
          # exception type with the name prefixed to the message.
          obj_name = getattr(obj_cls, '__name__', repr(obj_cls))
          try:
              wrapped = type(e)(f'{obj_name}: {e}')
          except Exception:
              # Some exception types cannot be rebuilt from a single message
              # (e.g. UnicodeDecodeError requires five arguments); propagate
              # the original instead of masking it with a constructor error.
              wrapped = None
          if wrapped is None:
              raise  # re-raises ``e`` with its original traceback
          raise wrapped from e
  178. except Exception as e:
  179. # Normal TypeError does not print class name.
  180. raise type(e)(f'{obj_cls.__name__}: {e}') from e