stats.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import threading
  2. from typing import MutableMapping, NamedTuple
  3. from wandb.sdk.lib import filenames
  4. class FileStats(NamedTuple):
  5. deduped: bool
  6. total: int
  7. uploaded: int
  8. failed: bool
  9. artifact_file: bool
  10. class Summary(NamedTuple):
  11. uploaded_bytes: int
  12. total_bytes: int
  13. deduped_bytes: int
  14. class FileCountsByCategory(NamedTuple):
  15. artifact: int
  16. wandb: int
  17. media: int
  18. other: int
  19. class Stats:
  20. def __init__(self) -> None:
  21. self._stats: MutableMapping[str, FileStats] = {}
  22. self._lock = threading.Lock()
  23. def init_file(
  24. self, save_name: str, size: int, is_artifact_file: bool = False
  25. ) -> None:
  26. with self._lock:
  27. self._stats[save_name] = FileStats(
  28. deduped=False,
  29. total=size,
  30. uploaded=0,
  31. failed=False,
  32. artifact_file=is_artifact_file,
  33. )
  34. def set_file_deduped(self, save_name: str) -> None:
  35. with self._lock:
  36. orig = self._stats[save_name]
  37. self._stats[save_name] = orig._replace(
  38. deduped=True,
  39. uploaded=orig.total,
  40. )
  41. def update_uploaded_file(self, save_name: str, total_uploaded: int) -> None:
  42. with self._lock:
  43. self._stats[save_name] = self._stats[save_name]._replace(
  44. uploaded=total_uploaded,
  45. )
  46. def update_failed_file(self, save_name: str) -> None:
  47. with self._lock:
  48. self._stats[save_name] = self._stats[save_name]._replace(
  49. uploaded=0,
  50. failed=True,
  51. )
  52. def summary(self) -> Summary:
  53. # Need to use list to ensure we get a copy, since other threads may
  54. # modify this while we iterate
  55. with self._lock:
  56. stats = list(self._stats.values())
  57. return Summary(
  58. uploaded_bytes=sum(f.uploaded for f in stats),
  59. total_bytes=sum(f.total for f in stats),
  60. deduped_bytes=sum(f.total for f in stats if f.deduped),
  61. )
  62. def file_counts_by_category(self) -> FileCountsByCategory:
  63. artifact_files = 0
  64. wandb_files = 0
  65. media_files = 0
  66. other_files = 0
  67. # Need to use list to ensure we get a copy, since other threads may
  68. # modify this while we iterate
  69. with self._lock:
  70. file_stats = list(self._stats.items())
  71. for save_name, stats in file_stats:
  72. if stats.artifact_file:
  73. artifact_files += 1
  74. elif filenames.is_wandb_file(save_name):
  75. wandb_files += 1
  76. elif save_name.startswith("media"):
  77. media_files += 1
  78. else:
  79. other_files += 1
  80. return FileCountsByCategory(
  81. artifact=artifact_files,
  82. wandb=wandb_files,
  83. media=media_files,
  84. other=other_files,
  85. )