_io.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. from __future__ import annotations
  2. import gzip
  3. import io
  4. import pathlib
  5. import tarfile
  6. from typing import (
  7. TYPE_CHECKING,
  8. Any,
  9. Callable,
  10. )
  11. import uuid
  12. import zipfile
  13. from pandas.compat import (
  14. get_bz2_file,
  15. get_lzma_file,
  16. )
  17. from pandas.compat._optional import import_optional_dependency
  18. import pandas as pd
  19. from pandas._testing.contexts import ensure_clean
  20. if TYPE_CHECKING:
  21. from pandas._typing import (
  22. FilePath,
  23. ReadPickleBuffer,
  24. )
  25. from pandas import (
  26. DataFrame,
  27. Series,
  28. )
  29. # ------------------------------------------------------------------
  30. # File-IO
  31. def round_trip_pickle(
  32. obj: Any, path: FilePath | ReadPickleBuffer | None = None
  33. ) -> DataFrame | Series:
  34. """
  35. Pickle an object and then read it again.
  36. Parameters
  37. ----------
  38. obj : any object
  39. The object to pickle and then re-read.
  40. path : str, path object or file-like object, default None
  41. The path where the pickled object is written and then read.
  42. Returns
  43. -------
  44. pandas object
  45. The original object that was pickled and then re-read.
  46. """
  47. _path = path
  48. if _path is None:
  49. _path = f"__{uuid.uuid4()}__.pickle"
  50. with ensure_clean(_path) as temp_path:
  51. pd.to_pickle(obj, temp_path)
  52. return pd.read_pickle(temp_path)
  53. def round_trip_pathlib(writer, reader, path: str | None = None):
  54. """
  55. Write an object to file specified by a pathlib.Path and read it back
  56. Parameters
  57. ----------
  58. writer : callable bound to pandas object
  59. IO writing function (e.g. DataFrame.to_csv )
  60. reader : callable
  61. IO reading function (e.g. pd.read_csv )
  62. path : str, default None
  63. The path where the object is written and then read.
  64. Returns
  65. -------
  66. pandas object
  67. The original object that was serialized and then re-read.
  68. """
  69. Path = pathlib.Path
  70. if path is None:
  71. path = "___pathlib___"
  72. with ensure_clean(path) as path:
  73. writer(Path(path)) # type: ignore[arg-type]
  74. obj = reader(Path(path)) # type: ignore[arg-type]
  75. return obj
  76. def round_trip_localpath(writer, reader, path: str | None = None):
  77. """
  78. Write an object to file specified by a py.path LocalPath and read it back.
  79. Parameters
  80. ----------
  81. writer : callable bound to pandas object
  82. IO writing function (e.g. DataFrame.to_csv )
  83. reader : callable
  84. IO reading function (e.g. pd.read_csv )
  85. path : str, default None
  86. The path where the object is written and then read.
  87. Returns
  88. -------
  89. pandas object
  90. The original object that was serialized and then re-read.
  91. """
  92. import pytest
  93. LocalPath = pytest.importorskip("py.path").local
  94. if path is None:
  95. path = "___localpath___"
  96. with ensure_clean(path) as path:
  97. writer(LocalPath(path))
  98. obj = reader(LocalPath(path))
  99. return obj
  100. def write_to_compressed(compression, path, data, dest: str = "test") -> None:
  101. """
  102. Write data to a compressed file.
  103. Parameters
  104. ----------
  105. compression : {'gzip', 'bz2', 'zip', 'xz', 'zstd'}
  106. The compression type to use.
  107. path : str
  108. The file path to write the data.
  109. data : str
  110. The data to write.
  111. dest : str, default "test"
  112. The destination file (for ZIP only)
  113. Raises
  114. ------
  115. ValueError : An invalid compression value was passed in.
  116. """
  117. args: tuple[Any, ...] = (data,)
  118. mode = "wb"
  119. method = "write"
  120. compress_method: Callable
  121. if compression == "zip":
  122. compress_method = zipfile.ZipFile
  123. mode = "w"
  124. args = (dest, data)
  125. method = "writestr"
  126. elif compression == "tar":
  127. compress_method = tarfile.TarFile
  128. mode = "w"
  129. file = tarfile.TarInfo(name=dest)
  130. bytes = io.BytesIO(data)
  131. file.size = len(data)
  132. args = (file, bytes)
  133. method = "addfile"
  134. elif compression == "gzip":
  135. compress_method = gzip.GzipFile
  136. elif compression == "bz2":
  137. compress_method = get_bz2_file()
  138. elif compression == "zstd":
  139. compress_method = import_optional_dependency("zstandard").open
  140. elif compression == "xz":
  141. compress_method = get_lzma_file()
  142. else:
  143. raise ValueError(f"Unrecognized compression type: {compression}")
  144. with compress_method(path, mode=mode) as f:
  145. getattr(f, method)(*args)