progress.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. """progress."""
  2. import os
  3. from typing import IO, TYPE_CHECKING, Optional
  4. from wandb.errors import CommError
  5. if TYPE_CHECKING:
  6. from typing import Protocol
  7. class ProgressFn(Protocol):
  8. def __call__(self, new_bytes: int, total_bytes: int) -> None:
  9. pass
  10. class Progress:
  11. """A helper class for displaying progress."""
  12. ITER_BYTES = 1024 * 1024
  13. def __init__(
  14. self, file: IO[bytes], callback: Optional["ProgressFn"] = None
  15. ) -> None:
  16. self.file = file
  17. if callback is None:
  18. def callback_(new_bytes: int, total_bytes: int) -> None:
  19. pass
  20. callback = callback_
  21. self.callback: ProgressFn = callback
  22. self.bytes_read = 0
  23. self.len = os.fstat(file.fileno()).st_size
  24. def read(self, size=-1):
  25. """Read bytes and call the callback."""
  26. bites = self.file.read(size)
  27. self.bytes_read += len(bites)
  28. if not bites and self.bytes_read < self.len:
  29. # Files shrinking during uploads causes request timeouts. Maybe
  30. # we could avoid those by updating the self.len in real-time, but
  31. # files getting truncated while uploading seems like something
  32. # that shouldn't really be happening anyway.
  33. raise CommError(
  34. f"File {self.file.name} size shrank from {self.len} to {self.bytes_read} while it was being uploaded."
  35. )
  36. # Growing files are also likely to be bad, but our code didn't break
  37. # on those in the past, so it's riskier to make that an error now.
  38. self.callback(len(bites), self.bytes_read)
  39. return bites
  40. def rewind(self) -> None:
  41. self.callback(-self.bytes_read, 0)
  42. self.bytes_read = 0
  43. self.file.seek(0)
  44. def __getattr__(self, name):
  45. """Fallback to the file object for attrs not defined here."""
  46. if hasattr(self.file, name):
  47. return getattr(self.file, name)
  48. else:
  49. raise AttributeError
  50. def __iter__(self):
  51. return self
  52. def __next__(self):
  53. bites = self.read(self.ITER_BYTES)
  54. if len(bites) == 0:
  55. raise StopIteration
  56. return bites
  57. def __len__(self):
  58. return self.len
  59. next = __next__