# pickle_compat.py
"""
Pickle compatibility back to pandas version 1.0.
"""
  4. from __future__ import annotations
  5. import contextlib
  6. import io
  7. import pickle
  8. from typing import (
  9. TYPE_CHECKING,
  10. Any,
  11. )
  12. import numpy as np
  13. from pandas._libs.arrays import NDArrayBacked
  14. from pandas._libs.tslibs import BaseOffset
  15. from pandas.core.arrays import (
  16. DatetimeArray,
  17. PeriodArray,
  18. TimedeltaArray,
  19. )
  20. from pandas.core.internals import BlockManager
  21. if TYPE_CHECKING:
  22. from collections.abc import Generator
  23. # If classes are moved, provide compat here.
  24. _class_locations_map = {
  25. # Re-routing unpickle block logic to go through _unpickle_block instead
  26. # for pandas <= 1.3.5
  27. ("pandas.core.internals.blocks", "new_block"): (
  28. "pandas._libs.internals",
  29. "_unpickle_block",
  30. ),
  31. # Avoid Cython's warning "contradiction to Python 'class private name' rules"
  32. ("pandas._libs.tslibs.nattype", "__nat_unpickle"): (
  33. "pandas._libs.tslibs.nattype",
  34. "_nat_unpickle",
  35. ),
  36. # 50775, remove Int64Index, UInt64Index & Float64Index from codebase
  37. ("pandas.core.indexes.numeric", "Int64Index"): (
  38. "pandas.core.indexes.base",
  39. "Index",
  40. ),
  41. ("pandas.core.indexes.numeric", "UInt64Index"): (
  42. "pandas.core.indexes.base",
  43. "Index",
  44. ),
  45. ("pandas.core.indexes.numeric", "Float64Index"): (
  46. "pandas.core.indexes.base",
  47. "Index",
  48. ),
  49. ("pandas.core.arrays.sparse.dtype", "SparseDtype"): (
  50. "pandas.core.dtypes.dtypes",
  51. "SparseDtype",
  52. ),
  53. }
  54. # our Unpickler sub-class to override methods and some dispatcher
  55. # functions for compat and uses a non-public class of the pickle module.
  56. class Unpickler(pickle._Unpickler):
  57. def find_class(self, module: str, name: str) -> Any:
  58. key = (module, name)
  59. module, name = _class_locations_map.get(key, key)
  60. return super().find_class(module, name)
  61. dispatch = pickle._Unpickler.dispatch.copy()
  62. def load_reduce(self) -> None:
  63. stack = self.stack # type: ignore[attr-defined]
  64. args = stack.pop()
  65. func = stack[-1]
  66. try:
  67. stack[-1] = func(*args)
  68. except TypeError:
  69. # If we have a deprecated function,
  70. # try to replace and try again.
  71. if args and isinstance(args[0], type) and issubclass(args[0], BaseOffset):
  72. # TypeError: object.__new__(Day) is not safe, use Day.__new__()
  73. cls = args[0]
  74. stack[-1] = cls.__new__(*args)
  75. return
  76. elif args and issubclass(args[0], PeriodArray):
  77. cls = args[0]
  78. stack[-1] = NDArrayBacked.__new__(*args)
  79. return
  80. raise
  81. dispatch[pickle.REDUCE[0]] = load_reduce # type: ignore[assignment]
  82. def load_newobj(self) -> None:
  83. args = self.stack.pop() # type: ignore[attr-defined]
  84. cls = self.stack.pop() # type: ignore[attr-defined]
  85. # compat
  86. if issubclass(cls, DatetimeArray) and not args:
  87. arr = np.array([], dtype="M8[ns]")
  88. obj = cls.__new__(cls, arr, arr.dtype)
  89. elif issubclass(cls, TimedeltaArray) and not args:
  90. arr = np.array([], dtype="m8[ns]")
  91. obj = cls.__new__(cls, arr, arr.dtype)
  92. elif cls is BlockManager and not args:
  93. obj = cls.__new__(cls, (), [], False)
  94. else:
  95. obj = cls.__new__(cls, *args)
  96. self.append(obj) # type: ignore[attr-defined]
  97. dispatch[pickle.NEWOBJ[0]] = load_newobj # type: ignore[assignment]
  98. def loads(
  99. bytes_object: bytes,
  100. *,
  101. fix_imports: bool = True,
  102. encoding: str = "ASCII",
  103. errors: str = "strict",
  104. ) -> Any:
  105. """
  106. Analogous to pickle._loads.
  107. """
  108. fd = io.BytesIO(bytes_object)
  109. return Unpickler(
  110. fd, fix_imports=fix_imports, encoding=encoding, errors=errors
  111. ).load()
  112. @contextlib.contextmanager
  113. def patch_pickle() -> Generator[None]:
  114. """
  115. Temporarily patch pickle to use our unpickler.
  116. """
  117. orig_loads = pickle.loads
  118. try:
  119. setattr(pickle, "loads", loads)
  120. yield
  121. finally:
  122. setattr(pickle, "loads", orig_loads)