general.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import os
  2. import logging
  3. import wandb
  4. import torch
  5. def set_logging(name=None, verbose=True):
  6. for handler in logging.root.handlers[:]:
  7. logging.root.removeHandler(handler)
  8. # Sets level and returns logger
  9. rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
  10. logging.basicConfig(format="%(message)s", level=logging.INFO if (verbose and rank in (-1, 0)) else logging.WARNING)
  11. return logging.getLogger(name)
  12. LOGGER = set_logging(__name__) # define globally (used in train.py, val.py, detect.py, etc.)
  13. LOGGERS = ('csv', 'tb', 'wandb')
  14. CUDA = True if torch.cuda.is_available() else False
  15. DEVICE = 'cuda' if CUDA else 'cpu'
  16. LOGGER_WANDB = 'wandb'
  17. LOGGER_TENSORBOARD = 'tb'
  18. class Loggers():
  19. def __init__(self, hyp):
  20. self.type = hyp['logger']['type']
  21. self.epochs = hyp['train']['epochs']
  22. self.wandb = None
  23. self.writer = None
  24. if self.type == LOGGER_WANDB:
  25. if hyp['logger']['project'] == '':
  26. project = 'ComicTextDetector'
  27. else:
  28. project = hyp['logger']['project']
  29. if hyp['logger']['run_id'] == '':
  30. self.wandb = wandb.init(project=project, config=hyp, resume='allow')
  31. else:
  32. self.wandb = wandb.init(project=project, config=hyp, resume='must', id=hyp['logger']['run_id'])
  33. elif self.type == LOGGER_TENSORBOARD:
  34. from torch.utils.tensorboard import SummaryWriter
  35. self.writer = SummaryWriter(hyp['data']['save_dir'])
  36. def on_train_batch_end(self, metrics):
  37. # Callback runs on train batch end
  38. if self.wandb:
  39. self.wandb.log(metrics)
  40. pass
  41. def on_train_epoch_end(self, epoch, metrics):
  42. LOGGER.info(f'fin epoch {epoch}/{self.epochs}, metrics: {metrics}')
  43. if self.type == LOGGER_WANDB:
  44. self.wandb.log(metrics)
  45. elif self.type == LOGGER_TENSORBOARD:
  46. for key in metrics.keys():
  47. self.writer.add_scalar(key, metrics[key], epoch)
  48. def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
  49. # Callback runs on model save event
  50. if self.wandb:
  51. if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
  52. self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)