# util.py
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os.path as osp
  3. from typing import List, Optional, Union
  4. from modelscope.hub.api import HubApi
  5. from modelscope.hub.file_download import model_file_download
  6. from modelscope.utils.config import Config
  7. from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
  8. from modelscope.utils.logger import get_logger
  9. logger = get_logger()
  10. def is_config_has_model(cfg_file):
  11. try:
  12. cfg = Config.from_file(cfg_file)
  13. return hasattr(cfg, 'model') or hasattr(cfg, 'model_type')
  14. except Exception as e:
  15. logger.error(f'parse config file {cfg_file} failed: {e}')
  16. return False
  17. def is_official_hub_path(path: Union[str, List],
  18. revision: Optional[str] = DEFAULT_MODEL_REVISION):
  19. """ Whether path is an official hub name or a valid local
  20. path to official hub directory.
  21. """
  22. def is_official_hub_impl(path):
  23. if osp.exists(path):
  24. cfg_file = osp.join(path, ModelFile.CONFIGURATION)
  25. return osp.exists(cfg_file)
  26. else:
  27. try:
  28. _ = HubApi().get_model(path, revision=revision)
  29. return True
  30. except Exception as e:
  31. raise ValueError(f'invalid model repo path {e}')
  32. if isinstance(path, str):
  33. return is_official_hub_impl(path)
  34. else:
  35. results = [is_official_hub_impl(m) for m in path]
  36. all_true = all(results)
  37. any_true = any(results)
  38. if any_true and not all_true:
  39. raise ValueError(
  40. f'some model are hub address, some are not, model list: {path}'
  41. )
  42. return all_true
  43. def is_model(path: Union[str, List]):
  44. """ whether path is a valid modelhub path and containing model config
  45. """
  46. def is_modelhub_path_impl(path):
  47. if osp.exists(path):
  48. cfg_file = osp.join(path, ModelFile.CONFIGURATION)
  49. hf_cfg_file = osp.join(path, ModelFile.CONFIG)
  50. if osp.exists(cfg_file):
  51. return is_config_has_model(cfg_file)
  52. elif osp.exists(hf_cfg_file):
  53. return is_config_has_model(hf_cfg_file)
  54. else:
  55. return False
  56. else:
  57. try:
  58. cfg_file = model_file_download(path, ModelFile.CONFIGURATION)
  59. if is_config_has_model(cfg_file):
  60. return True
  61. else:
  62. hf_cfg_file = model_file_download(path, ModelFile.CONFIG)
  63. return is_config_has_model(hf_cfg_file)
  64. except Exception:
  65. return False
  66. if isinstance(path, str):
  67. return is_modelhub_path_impl(path)
  68. else:
  69. results = [is_modelhub_path_impl(m) for m in path]
  70. all_true = all(results)
  71. any_true = any(results)
  72. if any_true and not all_true:
  73. raise ValueError(
  74. f'some models are hub address, some are not, model list: {path}'
  75. )
  76. return all_true
  77. def batch_process(model, data):
  78. import torch
  79. if model.__class__.__name__ == 'OfaForAllTasks':
  80. # collate batch data due to the nested data structure
  81. assert isinstance(data, list)
  82. batch_data = {
  83. 'nsentences': len(data),
  84. 'samples': [d['samples'][0] for d in data],
  85. 'net_input': {}
  86. }
  87. for k in data[0]['net_input'].keys():
  88. batch_data['net_input'][k] = torch.cat(
  89. [d['net_input'][k] for d in data])
  90. if 'w_resize_ratios' in data[0]:
  91. batch_data['w_resize_ratios'] = torch.cat(
  92. [d['w_resize_ratios'] for d in data])
  93. if 'h_resize_ratios' in data[0]:
  94. batch_data['h_resize_ratios'] = torch.cat(
  95. [d['h_resize_ratios'] for d in data])
  96. return batch_data