_package_pickler.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # mypy: allow-untyped-defs
  2. from pickle import ( # type: ignore[attr-defined]
  3. _compat_pickle,
  4. _extension_registry,
  5. _getattribute,
  6. _Pickler,
  7. EXT1,
  8. EXT2,
  9. EXT4,
  10. GLOBAL,
  11. PicklingError,
  12. STACK_GLOBAL,
  13. )
  14. from struct import pack
  15. from types import FunctionType
  16. from .importer import Importer, ObjMismatchError, ObjNotFoundError, sys_importer
  17. class _PyTorchLegacyPickler(_Pickler):
  18. def __init__(self, *args, **kwargs):
  19. super().__init__(*args, **kwargs)
  20. self._persistent_id = None
  21. def persistent_id(self, obj):
  22. if self._persistent_id is None:
  23. return super().persistent_id(obj)
  24. return self._persistent_id(obj)
  25. class PackagePickler(_PyTorchLegacyPickler):
  26. """Package-aware pickler.
  27. This behaves the same as a normal pickler, except it uses an `Importer`
  28. to find objects and modules to save.
  29. """
  30. def __init__(self, importer: Importer, *args, **kwargs):
  31. self.importer = importer
  32. super().__init__(*args, **kwargs)
  33. # Make sure the dispatch table copied from _Pickler is up-to-date.
  34. # Previous issues have been encountered where a library (e.g. dill)
  35. # mutate _Pickler.dispatch, PackagePickler makes a copy when this lib
  36. # is imported, then the offending library removes its dispatch entries,
  37. # leaving PackagePickler with a stale dispatch table that may cause
  38. # unwanted behavior.
  39. self.dispatch = _Pickler.dispatch.copy() # type: ignore[misc]
  40. self.dispatch[FunctionType] = PackagePickler.save_global # type: ignore[assignment]
  41. def save_global(self, obj, name=None):
  42. # ruff: noqa: F841
  43. # unfortunately the pickler code is factored in a way that
  44. # forces us to copy/paste this function. The only change is marked
  45. # CHANGED below.
  46. write = self.write # type: ignore[attr-defined]
  47. memo = self.memo # type: ignore[attr-defined]
  48. # CHANGED: import module from module environment instead of __import__
  49. try:
  50. module_name, name = self.importer.get_name(obj, name)
  51. except (ObjNotFoundError, ObjMismatchError) as err:
  52. raise PicklingError(f"Can't pickle {obj}: {str(err)}") from err
  53. module = self.importer.import_module(module_name)
  54. _, parent = _getattribute(module, name)
  55. # END CHANGED
  56. if self.proto >= 2: # type: ignore[attr-defined]
  57. code = _extension_registry.get((module_name, name))
  58. if code:
  59. assert code > 0
  60. if code <= 0xFF:
  61. write(EXT1 + pack("<B", code))
  62. elif code <= 0xFFFF:
  63. write(EXT2 + pack("<H", code))
  64. else:
  65. write(EXT4 + pack("<i", code))
  66. return
  67. lastname = name.rpartition(".")[2]
  68. if parent is module:
  69. name = lastname
  70. # Non-ASCII identifiers are supported only with protocols >= 3.
  71. if self.proto >= 4: # type: ignore[attr-defined]
  72. self.save(module_name) # type: ignore[attr-defined]
  73. self.save(name) # type: ignore[attr-defined]
  74. write(STACK_GLOBAL)
  75. elif parent is not module:
  76. self.save_reduce(getattr, (parent, lastname)) # type: ignore[attr-defined]
  77. elif self.proto >= 3: # type: ignore[attr-defined]
  78. write(
  79. GLOBAL
  80. + bytes(module_name, "utf-8")
  81. + b"\n"
  82. + bytes(name, "utf-8")
  83. + b"\n"
  84. )
  85. else:
  86. if self.fix_imports: # type: ignore[attr-defined]
  87. r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
  88. r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
  89. if (module_name, name) in r_name_mapping:
  90. module_name, name = r_name_mapping[(module_name, name)]
  91. elif module_name in r_import_mapping:
  92. module_name = r_import_mapping[module_name]
  93. try:
  94. write(
  95. GLOBAL
  96. + bytes(module_name, "ascii")
  97. + b"\n"
  98. + bytes(name, "ascii")
  99. + b"\n"
  100. )
  101. except UnicodeEncodeError as exc:
  102. raise PicklingError(
  103. f"can't pickle global identifier '{module}.{name}' using "
  104. f"pickle protocol {self.proto:d}" # type: ignore[attr-defined]
  105. ) from exc
  106. self.memoize(obj) # type: ignore[attr-defined]
  107. def create_pickler(data_buf, importer, protocol=4):
  108. if importer is sys_importer:
  109. # if we are using the normal import library system, then
  110. # we can use the C implementation of pickle which is faster
  111. return _PyTorchLegacyPickler(data_buf, protocol=protocol)
  112. else:
  113. return PackagePickler(importer, data_buf, protocol=protocol)