_util.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. from __future__ import annotations
  2. from typing import (
  3. TYPE_CHECKING,
  4. Literal,
  5. )
  6. import numpy as np
  7. from pandas._config import using_string_dtype
  8. from pandas._libs import lib
  9. from pandas.compat import (
  10. pa_version_under18p0,
  11. pa_version_under19p0,
  12. )
  13. from pandas.compat._optional import import_optional_dependency
  14. import pandas as pd
  15. if TYPE_CHECKING:
  16. from collections.abc import Callable
  17. import pyarrow
  18. from pandas._typing import DtypeBackend
  19. def _arrow_dtype_mapping() -> dict:
  20. pa = import_optional_dependency("pyarrow")
  21. return {
  22. pa.int8(): pd.Int8Dtype(),
  23. pa.int16(): pd.Int16Dtype(),
  24. pa.int32(): pd.Int32Dtype(),
  25. pa.int64(): pd.Int64Dtype(),
  26. pa.uint8(): pd.UInt8Dtype(),
  27. pa.uint16(): pd.UInt16Dtype(),
  28. pa.uint32(): pd.UInt32Dtype(),
  29. pa.uint64(): pd.UInt64Dtype(),
  30. pa.bool_(): pd.BooleanDtype(),
  31. pa.string(): pd.StringDtype(),
  32. pa.float32(): pd.Float32Dtype(),
  33. pa.float64(): pd.Float64Dtype(),
  34. pa.string(): pd.StringDtype(),
  35. pa.large_string(): pd.StringDtype(),
  36. }
  37. def _arrow_string_types_mapper() -> Callable:
  38. pa = import_optional_dependency("pyarrow")
  39. mapping = {
  40. pa.string(): pd.StringDtype(na_value=np.nan),
  41. pa.large_string(): pd.StringDtype(na_value=np.nan),
  42. }
  43. if not pa_version_under18p0:
  44. mapping[pa.string_view()] = pd.StringDtype(na_value=np.nan)
  45. return mapping.get
  46. def arrow_table_to_pandas(
  47. table: pyarrow.Table,
  48. dtype_backend: DtypeBackend | Literal["numpy"] | lib.NoDefault = lib.no_default,
  49. null_to_int64: bool = False,
  50. to_pandas_kwargs: dict | None = None,
  51. ) -> pd.DataFrame:
  52. if to_pandas_kwargs is None:
  53. to_pandas_kwargs = {}
  54. pa = import_optional_dependency("pyarrow")
  55. types_mapper: type[pd.ArrowDtype] | None | Callable
  56. if dtype_backend == "numpy_nullable":
  57. mapping = _arrow_dtype_mapping()
  58. if null_to_int64:
  59. # Modify the default mapping to also map null to Int64
  60. # (to match other engines - only for CSV parser)
  61. mapping[pa.null()] = pd.Int64Dtype()
  62. types_mapper = mapping.get
  63. elif dtype_backend == "pyarrow":
  64. types_mapper = pd.ArrowDtype
  65. elif using_string_dtype():
  66. if pa_version_under19p0:
  67. types_mapper = _arrow_string_types_mapper()
  68. else:
  69. types_mapper = None
  70. elif dtype_backend is lib.no_default or dtype_backend == "numpy":
  71. types_mapper = None
  72. else:
  73. raise NotImplementedError
  74. df = table.to_pandas(types_mapper=types_mapper, **to_pandas_kwargs)
  75. return df