train.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import os
  2. import sys
  3. import pathlib
  4. __dir__ = pathlib.Path(os.path.abspath(__file__))
  5. sys.path.append(str(__dir__))
  6. sys.path.append(str(__dir__.parent.parent))
  7. import paddle
  8. import paddle.distributed as dist
  9. from utils import Config, ArgsParser
  10. def init_args():
  11. parser = ArgsParser()
  12. args = parser.parse_args()
  13. return args
  14. def main(config, profiler_options):
  15. from models import build_model, build_loss
  16. from data_loader import get_dataloader
  17. from trainer import Trainer
  18. from post_processing import get_post_processing
  19. from utils import get_metric
  20. if paddle.device.cuda.device_count() > 1:
  21. dist.init_parallel_env()
  22. config["distributed"] = True
  23. else:
  24. config["distributed"] = False
  25. train_loader = get_dataloader(config["dataset"]["train"], config["distributed"])
  26. assert train_loader is not None
  27. if "validate" in config["dataset"]:
  28. validate_loader = get_dataloader(config["dataset"]["validate"], False)
  29. else:
  30. validate_loader = None
  31. criterion = build_loss(config["loss"])
  32. config["arch"]["backbone"]["in_channels"] = (
  33. 3 if config["dataset"]["train"]["dataset"]["args"]["img_mode"] != "GRAY" else 1
  34. )
  35. model = build_model(config["arch"])
  36. # set @to_static for benchmark, skip this by default.
  37. post_p = get_post_processing(config["post_processing"])
  38. metric = get_metric(config["metric"])
  39. trainer = Trainer(
  40. config=config,
  41. model=model,
  42. criterion=criterion,
  43. train_loader=train_loader,
  44. post_process=post_p,
  45. metric_cls=metric,
  46. validate_loader=validate_loader,
  47. profiler_options=profiler_options,
  48. )
  49. trainer.train()
  50. if __name__ == "__main__":
  51. args = init_args()
  52. assert os.path.exists(args.config_file)
  53. config = Config(args.config_file)
  54. config.merge_dict(args.opt)
  55. main(config.cfg, args.profiler_options)