# hub.py (5.7 KB)
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import os.path as osp
  4. from typing import List, Optional, Union
  5. from requests import HTTPError
  6. from modelscope.hub.constants import Licenses, ModelVisibility
  7. from modelscope.hub.file_download import model_file_download
  8. from modelscope.hub.snapshot_download import snapshot_download
  9. from modelscope.utils.config import Config
  10. from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields,
  11. ModelFile)
  12. from .logger import get_logger
  13. logger = get_logger()
  14. def create_model_if_not_exist(
  15. api,
  16. model_id: str,
  17. chinese_name: str,
  18. visibility: Optional[int] = ModelVisibility.PUBLIC,
  19. license: Optional[str] = Licenses.APACHE_V2):
  20. if api.repo_exists(model_id):
  21. logger.info(f'model {model_id} already exists, skip creation.')
  22. return False
  23. else:
  24. api.create_model(
  25. model_id=model_id,
  26. visibility=visibility,
  27. license=license,
  28. chinese_name=chinese_name,
  29. )
  30. logger.info(f'model {model_id} successfully created.')
  31. return True
  32. def read_config(model_id_or_path: str,
  33. revision: Optional[str] = DEFAULT_MODEL_REVISION):
  34. """ Read config from hub or local path
  35. Args:
  36. model_id_or_path (str): Model repo name or local directory path.
  37. revision: revision of the model when getting from the hub
  38. Return:
  39. config (:obj:`Config`): config object
  40. """
  41. if not os.path.exists(model_id_or_path):
  42. local_path = model_file_download(
  43. model_id_or_path, ModelFile.CONFIGURATION, revision=revision)
  44. elif os.path.isdir(model_id_or_path):
  45. local_path = os.path.join(model_id_or_path, ModelFile.CONFIGURATION)
  46. elif os.path.isfile(model_id_or_path):
  47. local_path = model_id_or_path
  48. else:
  49. return None
  50. return Config.from_file(local_path)
  51. def auto_load(model: Union[str, List[str]]):
  52. if isinstance(model, str):
  53. if not osp.exists(model):
  54. model = snapshot_download(model)
  55. else:
  56. model = [
  57. snapshot_download(m) if not osp.exists(m) else m for m in model
  58. ]
  59. return model
  60. def get_model_type(model_dir):
  61. """Get the model type from the configuration.
  62. This method will try to get the model type from 'model.backbone.type',
  63. 'model.type' or 'model.model_type' field in the configuration.json file. If
  64. this file does not exist, the method will try to get the 'model_type' field
  65. from the config.json.
  66. Args:
  67. model_dir: The local model dir to use. @return: The model type
  68. string, returns None if nothing is found.
  69. """
  70. try:
  71. configuration_file = osp.join(model_dir, ModelFile.CONFIGURATION)
  72. config_file = osp.join(model_dir, 'config.json')
  73. if osp.isfile(configuration_file):
  74. cfg = Config.from_file(configuration_file)
  75. if hasattr(cfg.model, 'backbone'):
  76. return cfg.model.backbone.type
  77. elif hasattr(cfg.model,
  78. 'model_type') and not hasattr(cfg.model, 'type'):
  79. return cfg.model.model_type
  80. else:
  81. return cfg.model.type
  82. elif osp.isfile(config_file):
  83. cfg = Config.from_file(config_file)
  84. return cfg.model_type if hasattr(cfg, 'model_type') else None
  85. except Exception as e:
  86. logger.error(f'parse config file failed with error: {e}')
  87. def parse_label_mapping(model_dir):
  88. """Get the label mapping from the model dir.
  89. This method will do:
  90. 1. Try to read label-id mapping from the label_mapping.json
  91. 2. Try to read label-id mapping from the configuration.json
  92. 3. Try to read label-id mapping from the config.json
  93. Args:
  94. model_dir: The local model dir to use.
  95. Returns:
  96. The label2id mapping if found.
  97. """
  98. import json
  99. import os
  100. label2id = None
  101. label_path = os.path.join(model_dir, ModelFile.LABEL_MAPPING)
  102. if os.path.exists(label_path):
  103. with open(label_path, encoding='utf-8') as f:
  104. label_mapping = json.load(f)
  105. label2id = {name: idx for name, idx in label_mapping.items()}
  106. if label2id is None:
  107. config_path = os.path.join(model_dir, ModelFile.CONFIGURATION)
  108. config = Config.from_file(config_path)
  109. if hasattr(config, ConfigFields.model) and hasattr(
  110. config[ConfigFields.model], 'label2id'):
  111. label2id = config[ConfigFields.model].label2id
  112. elif hasattr(config, ConfigFields.model) and hasattr(
  113. config[ConfigFields.model], 'id2label'):
  114. id2label = config[ConfigFields.model].id2label
  115. label2id = {label: id for id, label in id2label.items()}
  116. elif hasattr(config, ConfigFields.preprocessor) and hasattr(
  117. config[ConfigFields.preprocessor], 'label2id'):
  118. label2id = config[ConfigFields.preprocessor].label2id
  119. elif hasattr(config, ConfigFields.preprocessor) and hasattr(
  120. config[ConfigFields.preprocessor], 'id2label'):
  121. id2label = config[ConfigFields.preprocessor].id2label
  122. label2id = {label: id for id, label in id2label.items()}
  123. config_path = os.path.join(model_dir, 'config.json')
  124. if label2id is None and os.path.exists(config_path):
  125. config = Config.from_file(config_path)
  126. if hasattr(config, 'label2id'):
  127. label2id = config.label2id
  128. elif hasattr(config, 'id2label'):
  129. id2label = config.id2label
  130. label2id = {label: id for id, label in id2label.items()}
  131. if label2id is not None:
  132. label2id = {label: int(id) for label, id in label2id.items()}
  133. return label2id