train.py 607 B

12345678910111213141516171819202122232425
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import argparse
  3. from modelscope.trainers import build_trainer
  4. def parse_args():
  5. parser = argparse.ArgumentParser(description='Train a model')
  6. parser.add_argument('config', help='config file path', type=str)
  7. parser.add_argument(
  8. 'trainer_name', help='name for trainer', type=str, default=None)
  9. args = parser.parse_args()
  10. return args
  11. def main():
  12. args = parse_args()
  13. kwargs = dict(cfg_file=args.config)
  14. trainer = build_trainer(args.trainer_name, kwargs)
  15. trainer.train()
  16. if __name__ == '__main__':
  17. main()