_util.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. from __future__ import annotations
  2. from typing import (
  3. TYPE_CHECKING,
  4. Literal,
  5. )
  6. import numpy as np
  7. from pandas._config import using_string_dtype
  8. from pandas._libs import lib
  9. from pandas.compat import (
  10. pa_version_under18p0,
  11. pa_version_under19p0,
  12. )
  13. from pandas.compat._optional import import_optional_dependency
  14. from pandas.core.dtypes.common import pandas_dtype
  15. import pandas as pd
  16. if TYPE_CHECKING:
  17. from collections.abc import (
  18. Callable,
  19. Hashable,
  20. Sequence,
  21. )
  22. import pyarrow
  23. from pandas._typing import (
  24. DtypeArg,
  25. DtypeBackend,
  26. )
  27. def _arrow_dtype_mapping() -> dict:
  28. pa = import_optional_dependency("pyarrow")
  29. return {
  30. pa.int8(): pd.Int8Dtype(),
  31. pa.int16(): pd.Int16Dtype(),
  32. pa.int32(): pd.Int32Dtype(),
  33. pa.int64(): pd.Int64Dtype(),
  34. pa.uint8(): pd.UInt8Dtype(),
  35. pa.uint16(): pd.UInt16Dtype(),
  36. pa.uint32(): pd.UInt32Dtype(),
  37. pa.uint64(): pd.UInt64Dtype(),
  38. pa.bool_(): pd.BooleanDtype(),
  39. pa.string(): pd.StringDtype(),
  40. pa.float32(): pd.Float32Dtype(),
  41. pa.float64(): pd.Float64Dtype(),
  42. pa.string(): pd.StringDtype(),
  43. pa.large_string(): pd.StringDtype(),
  44. }
  45. def _arrow_string_types_mapper() -> Callable:
  46. pa = import_optional_dependency("pyarrow")
  47. mapping = {
  48. pa.string(): pd.StringDtype(na_value=np.nan),
  49. pa.large_string(): pd.StringDtype(na_value=np.nan),
  50. }
  51. if not pa_version_under18p0:
  52. mapping[pa.string_view()] = pd.StringDtype(na_value=np.nan)
  53. return mapping.get
  54. def arrow_table_to_pandas(
  55. table: pyarrow.Table,
  56. dtype_backend: DtypeBackend | Literal["numpy"] | lib.NoDefault = lib.no_default,
  57. null_to_int64: bool = False,
  58. to_pandas_kwargs: dict | None = None,
  59. dtype: DtypeArg | None = None,
  60. names: Sequence[Hashable] | None = None,
  61. ) -> pd.DataFrame:
  62. pa = import_optional_dependency("pyarrow")
  63. to_pandas_kwargs = {} if to_pandas_kwargs is None else to_pandas_kwargs
  64. types_mapper: type[pd.ArrowDtype] | None | Callable
  65. if dtype_backend == "numpy_nullable":
  66. mapping = _arrow_dtype_mapping()
  67. if null_to_int64:
  68. # Modify the default mapping to also map null to Int64
  69. # (to match other engines - only for CSV parser)
  70. mapping[pa.null()] = pd.Int64Dtype()
  71. types_mapper = mapping.get
  72. elif dtype_backend == "pyarrow":
  73. types_mapper = pd.ArrowDtype
  74. elif using_string_dtype():
  75. if pa_version_under19p0:
  76. types_mapper = _arrow_string_types_mapper()
  77. elif dtype is not None:
  78. # GH#56136 Avoid lossy conversion to float64
  79. # We'll convert to numpy below if
  80. types_mapper = {
  81. pa.int8(): pd.Int8Dtype(),
  82. pa.int16(): pd.Int16Dtype(),
  83. pa.int32(): pd.Int32Dtype(),
  84. pa.int64(): pd.Int64Dtype(),
  85. }.get
  86. else:
  87. types_mapper = None
  88. elif dtype_backend is lib.no_default or dtype_backend == "numpy":
  89. if dtype is not None:
  90. # GH#56136 Avoid lossy conversion to float64
  91. # We'll convert to numpy below if
  92. types_mapper = {
  93. pa.int8(): pd.Int8Dtype(),
  94. pa.int16(): pd.Int16Dtype(),
  95. pa.int32(): pd.Int32Dtype(),
  96. pa.int64(): pd.Int64Dtype(),
  97. }.get
  98. else:
  99. types_mapper = None
  100. else:
  101. raise NotImplementedError
  102. df = table.to_pandas(types_mapper=types_mapper, **to_pandas_kwargs)
  103. return _post_convert_dtypes(df, dtype_backend, dtype, names)
  104. def _post_convert_dtypes(
  105. df: pd.DataFrame,
  106. dtype_backend: DtypeBackend | Literal["numpy"] | lib.NoDefault,
  107. dtype: DtypeArg | None,
  108. names: Sequence[Hashable] | None,
  109. ) -> pd.DataFrame:
  110. if dtype is not None and (
  111. dtype_backend is lib.no_default or dtype_backend == "numpy"
  112. ):
  113. # GH#56136 apply any user-provided dtype, and convert any IntegerDtype
  114. # columns the user didn't explicitly ask for.
  115. if isinstance(dtype, dict):
  116. if names is not None:
  117. df.columns = names
  118. cmp_dtypes = {
  119. pd.Int8Dtype(),
  120. pd.Int16Dtype(),
  121. pd.Int32Dtype(),
  122. pd.Int64Dtype(),
  123. }
  124. for col in df.columns:
  125. if col not in dtype and df[col].dtype in cmp_dtypes:
  126. # Any key that the user didn't explicitly specify
  127. # that got converted to IntegerDtype now gets converted
  128. # to numpy dtype.
  129. dtype[col] = df[col].dtype.numpy_dtype
  130. # Ignore non-existent columns from dtype mapping
  131. # like other parsers do
  132. dtype = {
  133. key: pandas_dtype(dtype[key]) for key in dtype if key in df.columns
  134. }
  135. else:
  136. dtype = pandas_dtype(dtype)
  137. try:
  138. df = df.astype(dtype)
  139. except TypeError as err:
  140. # GH#44901 reraise to keep api consistent
  141. raise ValueError(str(err)) from err
  142. return df