pipeline.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import logging
  3. import os
  4. from argparse import ArgumentParser
  5. from string import Template
  6. from modelscope.cli.base import CLICommand
  7. from modelscope.utils.logger import get_logger
# Module logger, requested at WARNING level.
# NOTE(review): the pipeline command below reports its results with
# logger.info, which a WARNING-level logger may filter out — confirm intended.
logger = get_logger(log_level=logging.WARNING)
# Absolute directory containing this module.
current_path = os.path.dirname(os.path.abspath(__file__))
# Directory searched (by bare file name) for bundled pipeline templates.
template_path = os.path.join(current_path, 'template')
  11. def subparser_func(args):
  12. """ Function which will be called for a specific sub parser.
  13. """
  14. return PipelineCMD(args)
  15. class PipelineCMD(CLICommand):
  16. name = 'pipeline'
  17. def __init__(self, args):
  18. self.args = args
  19. @staticmethod
  20. def define_args(parsers: ArgumentParser):
  21. """ define args for create pipeline template command.
  22. """
  23. parser = parsers.add_parser(PipelineCMD.name)
  24. parser.add_argument(
  25. '-act',
  26. '--action',
  27. type=str,
  28. required=True,
  29. choices=['create'],
  30. help='the action of command pipeline[create]')
  31. parser.add_argument(
  32. '-tpl',
  33. '--tpl_file_path',
  34. type=str,
  35. default='template.tpl',
  36. help='the template be selected for ModelScope[template.tpl]')
  37. parser.add_argument(
  38. '-s',
  39. '--save_file_path',
  40. type=str,
  41. default='./',
  42. help='the name of custom template be saved for ModelScope')
  43. parser.add_argument(
  44. '-f',
  45. '--filename',
  46. type=str,
  47. default='ms_wrapper.py',
  48. help='the init name of custom template be saved for ModelScope')
  49. parser.add_argument(
  50. '-t',
  51. '--task_name',
  52. type=str,
  53. required=True,
  54. help='the unique task_name for ModelScope')
  55. parser.add_argument(
  56. '-m',
  57. '--model_name',
  58. type=str,
  59. default='MyCustomModel',
  60. help='the class of model name for ModelScope')
  61. parser.add_argument(
  62. '-p',
  63. '--preprocessor_name',
  64. type=str,
  65. default='MyCustomPreprocessor',
  66. help='the class of preprocessor name for ModelScope')
  67. parser.add_argument(
  68. '-pp',
  69. '--pipeline_name',
  70. type=str,
  71. default='MyCustomPipeline',
  72. help='the class of pipeline name for ModelScope')
  73. parser.add_argument(
  74. '-config',
  75. '--configuration_path',
  76. type=str,
  77. default='./',
  78. help='the path of configuration.json for ModelScope')
  79. parser.set_defaults(func=subparser_func)
  80. def create_template(self):
  81. if self.args.tpl_file_path not in os.listdir(template_path):
  82. tpl_file_path = self.args.tpl_file_path
  83. else:
  84. tpl_file_path = os.path.join(template_path,
  85. self.args.tpl_file_path)
  86. if not os.path.exists(tpl_file_path):
  87. raise ValueError('%s not exists!' % tpl_file_path)
  88. save_file_path = self.args.save_file_path if self.args.save_file_path != './' else os.getcwd(
  89. )
  90. os.makedirs(save_file_path, exist_ok=True)
  91. if not self.args.filename.endswith('.py'):
  92. raise ValueError('the FILENAME must end with .py ')
  93. save_file_name = self.args.filename
  94. save_pkl_path = os.path.join(save_file_path, save_file_name)
  95. if not self.args.configuration_path.endswith('/'):
  96. self.args.configuration_path = self.args.configuration_path + '/'
  97. lines = []
  98. with open(tpl_file_path) as tpl_file:
  99. tpl = Template(tpl_file.read())
  100. lines.append(tpl.substitute(**vars(self.args)))
  101. with open(save_pkl_path, 'w') as save_file:
  102. save_file.writelines(lines)
  103. logger.info('>>> Configuration be saved in %s/%s' %
  104. (self.args.configuration_path, 'configuration.json'))
  105. logger.info('>>> Task_name: %s, Created in %s' %
  106. (self.args.task_name, save_pkl_path))
  107. logger.info('Open the file < %s >, update and run it.' % save_pkl_path)
  108. def execute(self):
  109. if self.args.action == 'create':
  110. self.create_template()
  111. else:
  112. raise ValueError('The parameter of action must be in [create]')