wandb_logger.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import os
  2. from .base_logger import BaseLogger
  3. from ppocr.utils.logging import get_logger
  4. class WandbLogger(BaseLogger):
  5. def __init__(
  6. self,
  7. project=None,
  8. name=None,
  9. id=None,
  10. entity=None,
  11. save_dir=None,
  12. config=None,
  13. **kwargs,
  14. ):
  15. try:
  16. import wandb
  17. self.wandb = wandb
  18. except ModuleNotFoundError:
  19. raise ModuleNotFoundError("Please install wandb using `pip install wandb`")
  20. self.project = project
  21. self.name = name
  22. self.id = id
  23. self.save_dir = save_dir
  24. self.config = config
  25. self.kwargs = kwargs
  26. self.entity = entity
  27. self._run = None
  28. self._wandb_init = dict(
  29. project=self.project,
  30. name=self.name,
  31. id=self.id,
  32. entity=self.entity,
  33. dir=self.save_dir,
  34. resume="allow",
  35. )
  36. self._wandb_init.update(**kwargs)
  37. self.logger = get_logger()
  38. _ = self.run
  39. if self.config:
  40. self.run.config.update(self.config)
  41. @property
  42. def run(self):
  43. if self._run is None:
  44. if self.wandb.run is not None:
  45. self.logger.info(
  46. "There is a wandb run already in progress "
  47. "and newly created instances of `WandbLogger` will reuse"
  48. " this run. If this is not desired, call `wandb.finish()`"
  49. "before instantiating `WandbLogger`."
  50. )
  51. self._run = self.wandb.run
  52. else:
  53. self._run = self.wandb.init(**self._wandb_init)
  54. return self._run
  55. def log_metrics(self, metrics, prefix=None, step=None):
  56. if not prefix:
  57. prefix = ""
  58. updated_metrics = {prefix.lower() + "/" + k: v for k, v in metrics.items()}
  59. self.run.log(updated_metrics, step=step)
  60. def log_model(self, is_best, prefix, metadata=None):
  61. model_path = os.path.join(self.save_dir, prefix + ".pdparams")
  62. artifact = self.wandb.Artifact(
  63. "model-{}".format(self.run.id), type="model", metadata=metadata
  64. )
  65. artifact.add_file(model_path, name="model_ckpt.pdparams")
  66. aliases = [prefix]
  67. if is_best:
  68. aliases.append("best")
  69. self.run.log_artifact(artifact, aliases=aliases)
  70. def close(self):
  71. self.run.finish()