hashutil.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. from __future__ import annotations
  2. import base64
  3. import hashlib
  4. import logging
  5. import mmap
  6. import sys
  7. import time
  8. from typing import TYPE_CHECKING
  9. from typing_extensions import TypeAlias
  10. from wandb.sdk.lib.paths import StrPath
  11. if TYPE_CHECKING:
  12. import _hashlib # type: ignore[import-not-found]
  # String aliases naming the encodings of digests handled by this module.
  # They are plain ``str`` aliases: they document intent for readers and type
  # checkers, but perform no validation at runtime.
  #
  # Possible future direction: let pydantic validate these, for example via
  #   - Base64Str: https://docs.pydantic.dev/latest/api/types/#pydantic.types.Base64Str
  #   - a custom EncodedStr + Encoder impl: https://docs.pydantic.dev/latest/api/types/#pydantic.types.EncodedStr
  # While Pydantic v1 is still supported, either option would need a compatible
  # shim/backport, because neither type exists in Pydantic v1.
  ETag: TypeAlias = str  # opaque entity-tag string; no format is assumed here
  HexMD5: TypeAlias = str  # MD5 digest as hex (md5_file_hex, b64_to_hex_id)
  B64MD5: TypeAlias = str  # MD5 digest as standard base64 (md5_string, md5_file_b64)

  # Module-level logger; used to report unusually slow file hashing.
  logger = logging.getLogger(__name__)
  def _md5(data: bytes = b"") -> _hashlib.HASH:
      """Create an MD5 hasher pre-fed with ``data`` that stays usable under FIPS.

      On Python 3.9+ the hasher is requested with ``usedforsecurity=False``,
      which marks this MD5 use as non-cryptographic so FIPS-restricted builds
      still permit it. Older interpreters do not accept that keyword, so the
      plain constructor is used there instead.
      """
      if sys.version_info < (3, 9):
          return hashlib.md5(data)
      return hashlib.md5(data, usedforsecurity=False)
  def md5_string(string: str) -> B64MD5:
      """Return the base64-encoded MD5 digest of ``string``.

      The text is encoded as UTF-8 before it is hashed.
      """
      utf8_bytes = string.encode("utf-8")
      hasher = _md5(utf8_bytes)
      return _b64_from_hasher(hasher)
  def _b64_from_hasher(hasher: _hashlib.HASH) -> B64MD5:
      """Return the current digest of ``hasher`` as a base64 ASCII string."""
      raw_digest = hasher.digest()
      encoded = base64.b64encode(raw_digest)
      return B64MD5(encoded.decode("ascii"))
  def b64_to_hex_id(string: B64MD5) -> HexMD5:
      """Convert a standard-base64 digest into its hexadecimal string form."""
      raw_bytes = base64.standard_b64decode(string)
      return HexMD5(raw_bytes.hex())
  def hex_to_b64_id(encoded_string: str | bytes) -> B64MD5:
      """Convert a hexadecimal digest into its standard-base64 string form.

      ``encoded_string`` may be ``str`` or ``bytes``; bytes are decoded as UTF-8
      before the hex digits are parsed.

      Raises:
          ValueError: if the text is not valid hexadecimal (from ``bytes.fromhex``).
      """
      hex_text = (
          encoded_string.decode("utf-8")
          if isinstance(encoded_string, bytes)
          else encoded_string
      )
      raw_bytes = bytes.fromhex(hex_text)
      return B64MD5(base64.standard_b64encode(raw_bytes).decode("utf-8"))
  def md5_file_b64(*paths: StrPath) -> B64MD5:
      """Return the base64-encoded MD5 of the contents of the file(s) at ``paths``.

      If hashing (including the base64 encoding) takes longer than one second,
      a DEBUG record with the paths and the elapsed milliseconds is logged.
      """
      started = time.monotonic()
      hasher = _md5_file_hasher(*paths)
      digest = _b64_from_hasher(hasher)
      elapsed = time.monotonic() - started

      if elapsed <= 1.0:
          return digest

      logger.debug(
          "Computed MD5 hash for file. paths=%s, hashTimeMs=%d",
          paths,
          int(elapsed * 1000),
      )
      return digest
  def md5_file_hex(*paths: StrPath) -> HexMD5:
      """Return the hexadecimal MD5 of the contents of the file(s) at ``paths``.

      Unlike ``md5_file_b64``, this does not time or log the hashing.
      """
      hasher = _md5_file_hasher(*paths)
      return HexMD5(hasher.hexdigest())
  # Byte-size unit, and the read size (128 KiB) used when a file has to be
  # streamed in pieces instead of being memory-mapped.
  _KB: int = 1024
  _CHUNKSIZE: int = 128 * _KB
  def _md5_file_hasher(*paths: StrPath) -> _hashlib.HASH:
      """Return an MD5 hasher fed with the contents of every file in ``paths``.

      Files are processed in sorted order of their string paths, so the result
      does not depend on the order in which ``paths`` are passed. The contents
      of all files are fed into one hasher; file names themselves are not hashed.

      Each file is memory-mapped when possible. When mapping fails, the file is
      streamed in ``_CHUNKSIZE`` pieces instead, which yields the same digest.

      Raises:
          OSError: if a path cannot be opened or read.
      """
      md5_hash = _md5()
      # Note: We use str paths (instead of pathlib.Path objs) for minor perf improvements.
      for path in sorted(map(str, paths)):
          with open(path, "rb") as f:
              # Only the mapping call is guarded, so a failure can never occur
              # after part of the file has already been fed to the hasher.
              try:
                  mview = mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ)
              except (OSError, ValueError):
                  # OSError: the OS refuses the mapping, e.g. the file is on a
                  # different/mounted filesystem or is not a regular file.
                  # ValueError: CPython refuses to map a regular file whose size
                  # is reported as 0 ("cannot mmap an empty file") or one that is
                  # too large to map (e.g. on 32-bit builds). Streaming instead of
                  # skipping keeps files that report size 0 but still have content
                  # (e.g. procfs entries) in the hash; a truly empty file reads
                  # nothing, so its digest is unaffected.
                  # See: https://github.com/python/cpython/blob/986a4e1b6fcae7fe7a1d0a26aea446107dd58dd2/Modules/mmapmodule.c#L1589
                  # iter(callable, sentinel) avoids the walrus operator, keeping
                  # support for python 3.7.
                  for chunk in iter(lambda: f.read(_CHUNKSIZE), b""):
                      md5_hash.update(chunk)
              else:
                  with mview:
                      md5_hash.update(mview)
      return md5_hash