| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- import os
- import sys
- import pathlib
# Make this script's directory and its parent directory importable regardless
# of the current working directory. The original code bound the *file* path to
# a module-level ``__dir__``. That appended a file (not a directory) to
# sys.path, and under PEP 562 a module-level ``__dir__`` is called by
# ``dir(module)``, which a Path object cannot satisfy.
_SCRIPT_DIR = pathlib.Path(os.path.abspath(__file__)).parent
sys.path.append(str(_SCRIPT_DIR))
sys.path.append(str(_SCRIPT_DIR.parent))
- import paddle
- import paddle.distributed as dist
- from utils import Config, ArgsParser
def init_args():
    """Parse the command line with the project's ``ArgsParser``.

    Returns:
        The parsed-arguments object. The ``__main__`` section of this script
        reads ``config_file``, ``opt`` and ``profiler_options`` from it.
    """
    return ArgsParser().parse_args()
def main(config, profiler_options):
    """Build the training components described by ``config`` and run training.

    Builds the train (and optional validate) data loaders, the loss, the
    model, the post-processing step and the metric, then hands them to
    ``Trainer`` and calls ``train()``.

    Args:
        config: Nested configuration mapping (the ``cfg`` of a ``Config``).
            It is mutated in place: ``config["distributed"]`` and
            ``config["arch"]["backbone"]["in_channels"]`` are set here before
            the model and trainer are built.
        profiler_options: Forwarded unchanged to ``Trainer``.

    Raises:
        ValueError: If ``get_dataloader`` returns ``None`` for the training
            split.
    """
    from models import build_model, build_loss
    from data_loader import get_dataloader
    from trainer import Trainer
    from post_processing import get_post_processing
    from utils import get_metric

    # NOTE(review): distributed mode is chosen from the number of visible GPUs,
    # not from the launcher's world size. Confirm this matches how multi-GPU
    # training is launched.
    if paddle.device.cuda.device_count() > 1:
        dist.init_parallel_env()
        config["distributed"] = True
    else:
        config["distributed"] = False

    dataset_cfg = config["dataset"]
    train_loader = get_dataloader(dataset_cfg["train"], config["distributed"])
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if train_loader is None:
        raise ValueError(
            "get_dataloader returned None for config['dataset']['train']; "
            "check the training dataset configuration"
        )

    # The validation loader is optional and is always built with the
    # distributed flag turned off.
    validate_loader = (
        get_dataloader(dataset_cfg["validate"], False)
        if "validate" in dataset_cfg
        else None
    )

    criterion = build_loss(config["loss"])

    # "GRAY" training images get a 1-channel backbone input; any other
    # img_mode gets 3 channels.
    img_mode = dataset_cfg["train"]["dataset"]["args"]["img_mode"]
    config["arch"]["backbone"]["in_channels"] = 1 if img_mode == "GRAY" else 3
    model = build_model(config["arch"])

    # NOTE(review): an earlier comment here said "set @to_static for benchmark",
    # but no @to_static conversion is applied in this function. Confirm where
    # (or whether) benchmark runs perform it.
    post_p = get_post_processing(config["post_processing"])
    metric = get_metric(config["metric"])

    trainer = Trainer(
        config=config,
        model=model,
        criterion=criterion,
        train_loader=train_loader,
        post_process=post_p,
        metric_cls=metric,
        validate_loader=validate_loader,
        profiler_options=profiler_options,
    )
    trainer.train()
if __name__ == "__main__":
    args = init_args()
    # Explicit check instead of ``assert`` so a missing config file is still
    # reported clearly when Python runs with -O (which strips assert statements).
    if not os.path.exists(args.config_file):
        raise FileNotFoundError(f"config file not found: {args.config_file}")
    config = Config(args.config_file)
    # Merge the command-line options (args.opt) into the loaded configuration.
    config.merge_dict(args.opt)
    main(config.cfg, args.profiler_options)
|