# deploy_checker.py (2.9 KB)
  1. import argparse
  2. import traceback
  3. from typing import List, Union
  4. from modelscope.hub.api import HubApi
  5. from modelscope.hub.file_download import model_file_download
  6. from modelscope.pipelines import pipeline
  7. from modelscope.utils.config import Config
  8. from modelscope.utils.constant import ModelFile
  9. from modelscope.utils.input_output import (
  10. call_pipeline_with_json, get_pipeline_information_by_pipeline,
  11. get_task_input_examples, pipeline_output_to_service_base64_output)
  12. from modelscope.utils.logger import get_logger
  13. logger = get_logger()
  14. class DeployChecker:
  15. def __init__(self):
  16. self.api = HubApi()
  17. def check_model(self, model_id: str, model_revision=None):
  18. # get model_revision & task info
  19. if not model_revision:
  20. model_revisions = self.api.list_model_revisions(model_id)
  21. logger.info(
  22. f'All model_revisions of `{model_id}`: {model_revisions}')
  23. if len(model_revisions):
  24. model_revision = model_revisions[0]
  25. else:
  26. logger.error(f'{model_id} has no revision.')
  27. configuration_file = model_file_download(
  28. model_id=model_id,
  29. file_path=ModelFile.CONFIGURATION,
  30. revision=model_revision)
  31. cfg = Config.from_file(configuration_file)
  32. task = cfg.safe_get('task')
  33. # init pipeline
  34. ppl = pipeline(
  35. task=task,
  36. model=model_id,
  37. model_revision=model_revision,
  38. external_engine_for_llm=True)
  39. pipeline_info = get_pipeline_information_by_pipeline(ppl)
  40. # call pipeline
  41. data = get_task_input_examples(task)
  42. infer_result = call_pipeline_with_json(pipeline_info, ppl, data)
  43. result = pipeline_output_to_service_base64_output(task, infer_result)
  44. return result
  45. def check_deploy(models: Union[str, List], revisions: Union[str, List] = None):
  46. if not isinstance(models, list):
  47. models = [models]
  48. if not isinstance(revisions, list):
  49. revisions = [revisions] * (1 if revisions else len(models))
  50. if len(models) != len(revisions):
  51. logger.error(
  52. f'The number of models and revisions need to be equal: The number of models'
  53. f' is {len(model)} while the number of revisions is {len(revision)}.'
  54. )
  55. checker = DeployChecker()
  56. for model, revision in zip(models, revisions):
  57. try:
  58. res = checker.check_model(model, revision)
  59. logger.info(f'{model} {revision}: Deploy pre-check pass. {res}\n')
  60. except BaseException as e:
  61. logger.info(
  62. f'{model} {revision}: Deploy pre-check failed: {e}. {traceback.print_exc()}\n'
  63. )
  64. if __name__ == '__main__':
  65. parser = argparse.ArgumentParser()
  66. parser.add_argument('--model_id', type=str)
  67. parser.add_argument('--revision', type=str, default=None)
  68. args = parser.parse_args()
  69. check_deploy(args.model_id, args.revision)