__init__.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635
  1. from __future__ import annotations
  2. from decimal import Decimal
  3. import operator
  4. import os
  5. from sys import byteorder
  6. from typing import (
  7. TYPE_CHECKING,
  8. Callable,
  9. ContextManager,
  10. )
  11. import warnings
  12. import numpy as np
  13. from pandas._config import using_string_dtype
  14. from pandas._config.localization import (
  15. can_set_locale,
  16. get_locales,
  17. set_locale,
  18. )
  19. from pandas.compat import pa_version_under10p1
  20. import pandas as pd
  21. from pandas import (
  22. ArrowDtype,
  23. DataFrame,
  24. Index,
  25. MultiIndex,
  26. RangeIndex,
  27. Series,
  28. )
  29. from pandas._testing._io import (
  30. round_trip_localpath,
  31. round_trip_pathlib,
  32. round_trip_pickle,
  33. write_to_compressed,
  34. )
  35. from pandas._testing._warnings import (
  36. assert_produces_warning,
  37. maybe_produces_warning,
  38. )
  39. from pandas._testing.asserters import (
  40. assert_almost_equal,
  41. assert_attr_equal,
  42. assert_categorical_equal,
  43. assert_class_equal,
  44. assert_contains_all,
  45. assert_copy,
  46. assert_datetime_array_equal,
  47. assert_dict_equal,
  48. assert_equal,
  49. assert_extension_array_equal,
  50. assert_frame_equal,
  51. assert_index_equal,
  52. assert_indexing_slices_equivalent,
  53. assert_interval_array_equal,
  54. assert_is_sorted,
  55. assert_is_valid_plot_return_object,
  56. assert_metadata_equivalent,
  57. assert_numpy_array_equal,
  58. assert_period_array_equal,
  59. assert_series_equal,
  60. assert_sp_array_equal,
  61. assert_timedelta_array_equal,
  62. raise_assert_detail,
  63. )
  64. from pandas._testing.compat import (
  65. get_dtype,
  66. get_obj,
  67. )
  68. from pandas._testing.contexts import (
  69. assert_cow_warning,
  70. decompress_file,
  71. ensure_clean,
  72. raises_chained_assignment_error,
  73. set_timezone,
  74. use_numexpr,
  75. with_csv_dialect,
  76. )
  77. from pandas.core.arrays import (
  78. ArrowExtensionArray,
  79. BaseMaskedArray,
  80. NumpyExtensionArray,
  81. )
  82. from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
  83. from pandas.core.construction import extract_array
  84. if TYPE_CHECKING:
  85. from pandas._typing import (
  86. Dtype,
  87. NpDtype,
  88. )
  89. UNSIGNED_INT_NUMPY_DTYPES: list[NpDtype] = ["uint8", "uint16", "uint32", "uint64"]
  90. UNSIGNED_INT_EA_DTYPES: list[Dtype] = ["UInt8", "UInt16", "UInt32", "UInt64"]
  91. SIGNED_INT_NUMPY_DTYPES: list[NpDtype] = [int, "int8", "int16", "int32", "int64"]
  92. SIGNED_INT_EA_DTYPES: list[Dtype] = ["Int8", "Int16", "Int32", "Int64"]
  93. ALL_INT_NUMPY_DTYPES = UNSIGNED_INT_NUMPY_DTYPES + SIGNED_INT_NUMPY_DTYPES
  94. ALL_INT_EA_DTYPES = UNSIGNED_INT_EA_DTYPES + SIGNED_INT_EA_DTYPES
  95. ALL_INT_DTYPES: list[Dtype] = [*ALL_INT_NUMPY_DTYPES, *ALL_INT_EA_DTYPES]
  96. FLOAT_NUMPY_DTYPES: list[NpDtype] = [float, "float32", "float64"]
  97. FLOAT_EA_DTYPES: list[Dtype] = ["Float32", "Float64"]
  98. ALL_FLOAT_DTYPES: list[Dtype] = [*FLOAT_NUMPY_DTYPES, *FLOAT_EA_DTYPES]
  99. COMPLEX_DTYPES: list[Dtype] = [complex, "complex64", "complex128"]
  100. if using_string_dtype():
  101. STRING_DTYPES: list[Dtype] = ["U"]
  102. else:
  103. STRING_DTYPES: list[Dtype] = [str, "str", "U"] # type: ignore[no-redef]
  104. COMPLEX_FLOAT_DTYPES: list[Dtype] = [*COMPLEX_DTYPES, *FLOAT_NUMPY_DTYPES]
  105. DATETIME64_DTYPES: list[Dtype] = ["datetime64[ns]", "M8[ns]"]
  106. TIMEDELTA64_DTYPES: list[Dtype] = ["timedelta64[ns]", "m8[ns]"]
  107. BOOL_DTYPES: list[Dtype] = [bool, "bool"]
  108. BYTES_DTYPES: list[Dtype] = [bytes, "bytes"]
  109. OBJECT_DTYPES: list[Dtype] = [object, "object"]
  110. ALL_REAL_NUMPY_DTYPES = FLOAT_NUMPY_DTYPES + ALL_INT_NUMPY_DTYPES
  111. ALL_REAL_EXTENSION_DTYPES = FLOAT_EA_DTYPES + ALL_INT_EA_DTYPES
  112. ALL_REAL_DTYPES: list[Dtype] = [*ALL_REAL_NUMPY_DTYPES, *ALL_REAL_EXTENSION_DTYPES]
  113. ALL_NUMERIC_DTYPES: list[Dtype] = [*ALL_REAL_DTYPES, *COMPLEX_DTYPES]
  114. ALL_NUMPY_DTYPES = (
  115. ALL_REAL_NUMPY_DTYPES
  116. + COMPLEX_DTYPES
  117. + STRING_DTYPES
  118. + DATETIME64_DTYPES
  119. + TIMEDELTA64_DTYPES
  120. + BOOL_DTYPES
  121. + OBJECT_DTYPES
  122. + BYTES_DTYPES
  123. )
  124. NARROW_NP_DTYPES = [
  125. np.float16,
  126. np.float32,
  127. np.int8,
  128. np.int16,
  129. np.int32,
  130. np.uint8,
  131. np.uint16,
  132. np.uint32,
  133. ]
  134. PYTHON_DATA_TYPES = [
  135. str,
  136. int,
  137. float,
  138. complex,
  139. list,
  140. tuple,
  141. range,
  142. dict,
  143. set,
  144. frozenset,
  145. bool,
  146. bytes,
  147. bytearray,
  148. memoryview,
  149. ]
  150. ENDIAN = {"little": "<", "big": ">"}[byteorder]
  151. NULL_OBJECTS = [None, np.nan, pd.NaT, float("nan"), pd.NA, Decimal("NaN")]
  152. NP_NAT_OBJECTS = [
  153. cls("NaT", unit)
  154. for cls in [np.datetime64, np.timedelta64]
  155. for unit in [
  156. "Y",
  157. "M",
  158. "W",
  159. "D",
  160. "h",
  161. "m",
  162. "s",
  163. "ms",
  164. "us",
  165. "ns",
  166. "ps",
  167. "fs",
  168. "as",
  169. ]
  170. ]
  171. if not pa_version_under10p1:
  172. import pyarrow as pa
  173. UNSIGNED_INT_PYARROW_DTYPES = [pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()]
  174. SIGNED_INT_PYARROW_DTYPES = [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
  175. ALL_INT_PYARROW_DTYPES = UNSIGNED_INT_PYARROW_DTYPES + SIGNED_INT_PYARROW_DTYPES
  176. ALL_INT_PYARROW_DTYPES_STR_REPR = [
  177. str(ArrowDtype(typ)) for typ in ALL_INT_PYARROW_DTYPES
  178. ]
  179. # pa.float16 doesn't seem supported
  180. # https://github.com/apache/arrow/blob/master/python/pyarrow/src/arrow/python/helpers.cc#L86
  181. FLOAT_PYARROW_DTYPES = [pa.float32(), pa.float64()]
  182. FLOAT_PYARROW_DTYPES_STR_REPR = [
  183. str(ArrowDtype(typ)) for typ in FLOAT_PYARROW_DTYPES
  184. ]
  185. DECIMAL_PYARROW_DTYPES = [pa.decimal128(7, 3)]
  186. STRING_PYARROW_DTYPES = [pa.string()]
  187. BINARY_PYARROW_DTYPES = [pa.binary()]
  188. TIME_PYARROW_DTYPES = [
  189. pa.time32("s"),
  190. pa.time32("ms"),
  191. pa.time64("us"),
  192. pa.time64("ns"),
  193. ]
  194. DATE_PYARROW_DTYPES = [pa.date32(), pa.date64()]
  195. DATETIME_PYARROW_DTYPES = [
  196. pa.timestamp(unit=unit, tz=tz)
  197. for unit in ["s", "ms", "us", "ns"]
  198. for tz in [None, "UTC", "US/Pacific", "US/Eastern"]
  199. ]
  200. TIMEDELTA_PYARROW_DTYPES = [pa.duration(unit) for unit in ["s", "ms", "us", "ns"]]
  201. BOOL_PYARROW_DTYPES = [pa.bool_()]
  202. # TODO: Add container like pyarrow types:
  203. # https://arrow.apache.org/docs/python/api/datatypes.html#factory-functions
  204. ALL_PYARROW_DTYPES = (
  205. ALL_INT_PYARROW_DTYPES
  206. + FLOAT_PYARROW_DTYPES
  207. + DECIMAL_PYARROW_DTYPES
  208. + STRING_PYARROW_DTYPES
  209. + BINARY_PYARROW_DTYPES
  210. + TIME_PYARROW_DTYPES
  211. + DATE_PYARROW_DTYPES
  212. + DATETIME_PYARROW_DTYPES
  213. + TIMEDELTA_PYARROW_DTYPES
  214. + BOOL_PYARROW_DTYPES
  215. )
  216. ALL_REAL_PYARROW_DTYPES_STR_REPR = (
  217. ALL_INT_PYARROW_DTYPES_STR_REPR + FLOAT_PYARROW_DTYPES_STR_REPR
  218. )
  219. else:
  220. FLOAT_PYARROW_DTYPES_STR_REPR = []
  221. ALL_INT_PYARROW_DTYPES_STR_REPR = []
  222. ALL_PYARROW_DTYPES = []
  223. ALL_REAL_PYARROW_DTYPES_STR_REPR = []
  224. ALL_REAL_NULLABLE_DTYPES = (
  225. FLOAT_NUMPY_DTYPES + ALL_REAL_EXTENSION_DTYPES + ALL_REAL_PYARROW_DTYPES_STR_REPR
  226. )
  227. arithmetic_dunder_methods = [
  228. "__add__",
  229. "__radd__",
  230. "__sub__",
  231. "__rsub__",
  232. "__mul__",
  233. "__rmul__",
  234. "__floordiv__",
  235. "__rfloordiv__",
  236. "__truediv__",
  237. "__rtruediv__",
  238. "__pow__",
  239. "__rpow__",
  240. "__mod__",
  241. "__rmod__",
  242. ]
  243. comparison_dunder_methods = ["__eq__", "__ne__", "__le__", "__lt__", "__ge__", "__gt__"]
  244. # -----------------------------------------------------------------------------
  245. # Comparators
  246. def box_expected(expected, box_cls, transpose: bool = True):
  247. """
  248. Helper function to wrap the expected output of a test in a given box_class.
  249. Parameters
  250. ----------
  251. expected : np.ndarray, Index, Series
  252. box_cls : {Index, Series, DataFrame}
  253. Returns
  254. -------
  255. subclass of box_cls
  256. """
  257. if box_cls is pd.array:
  258. if isinstance(expected, RangeIndex):
  259. # pd.array would return an IntegerArray
  260. expected = NumpyExtensionArray(np.asarray(expected._values))
  261. else:
  262. expected = pd.array(expected, copy=False)
  263. elif box_cls is Index:
  264. with warnings.catch_warnings():
  265. warnings.filterwarnings("ignore", "Dtype inference", category=FutureWarning)
  266. expected = Index(expected)
  267. elif box_cls is Series:
  268. with warnings.catch_warnings():
  269. warnings.filterwarnings("ignore", "Dtype inference", category=FutureWarning)
  270. expected = Series(expected)
  271. elif box_cls is DataFrame:
  272. with warnings.catch_warnings():
  273. warnings.filterwarnings("ignore", "Dtype inference", category=FutureWarning)
  274. expected = Series(expected).to_frame()
  275. if transpose:
  276. # for vector operations, we need a DataFrame to be a single-row,
  277. # not a single-column, in order to operate against non-DataFrame
  278. # vectors of the same length. But convert to two rows to avoid
  279. # single-row special cases in datetime arithmetic
  280. expected = expected.T
  281. expected = pd.concat([expected] * 2, ignore_index=True)
  282. elif box_cls is np.ndarray or box_cls is np.array:
  283. expected = np.array(expected)
  284. elif box_cls is to_array:
  285. expected = to_array(expected)
  286. else:
  287. raise NotImplementedError(box_cls)
  288. return expected
  289. def to_array(obj):
  290. """
  291. Similar to pd.array, but does not cast numpy dtypes to nullable dtypes.
  292. """
  293. # temporary implementation until we get pd.array in place
  294. dtype = getattr(obj, "dtype", None)
  295. if dtype is None:
  296. return np.asarray(obj)
  297. return extract_array(obj, extract_numpy=True)
  298. class SubclassedSeries(Series):
  299. _metadata = ["testattr", "name"]
  300. @property
  301. def _constructor(self):
  302. # For testing, those properties return a generic callable, and not
  303. # the actual class. In this case that is equivalent, but it is to
  304. # ensure we don't rely on the property returning a class
  305. # See https://github.com/pandas-dev/pandas/pull/46018 and
  306. # https://github.com/pandas-dev/pandas/issues/32638 and linked issues
  307. return lambda *args, **kwargs: SubclassedSeries(*args, **kwargs)
  308. @property
  309. def _constructor_expanddim(self):
  310. return lambda *args, **kwargs: SubclassedDataFrame(*args, **kwargs)
  311. class SubclassedDataFrame(DataFrame):
  312. _metadata = ["testattr"]
  313. @property
  314. def _constructor(self):
  315. return lambda *args, **kwargs: SubclassedDataFrame(*args, **kwargs)
  316. @property
  317. def _constructor_sliced(self):
  318. return lambda *args, **kwargs: SubclassedSeries(*args, **kwargs)
  319. def convert_rows_list_to_csv_str(rows_list: list[str]) -> str:
  320. """
  321. Convert list of CSV rows to single CSV-formatted string for current OS.
  322. This method is used for creating expected value of to_csv() method.
  323. Parameters
  324. ----------
  325. rows_list : List[str]
  326. Each element represents the row of csv.
  327. Returns
  328. -------
  329. str
  330. Expected output of to_csv() in current OS.
  331. """
  332. sep = os.linesep
  333. return sep.join(rows_list) + sep
  334. def external_error_raised(expected_exception: type[Exception]) -> ContextManager:
  335. """
  336. Helper function to mark pytest.raises that have an external error message.
  337. Parameters
  338. ----------
  339. expected_exception : Exception
  340. Expected error to raise.
  341. Returns
  342. -------
  343. Callable
  344. Regular `pytest.raises` function with `match` equal to `None`.
  345. """
  346. import pytest
  347. return pytest.raises(expected_exception, match=None)
  348. cython_table = pd.core.common._cython_table.items()
  349. def get_cython_table_params(ndframe, func_names_and_expected):
  350. """
  351. Combine frame, functions from com._cython_table
  352. keys and expected result.
  353. Parameters
  354. ----------
  355. ndframe : DataFrame or Series
  356. func_names_and_expected : Sequence of two items
  357. The first item is a name of a NDFrame method ('sum', 'prod') etc.
  358. The second item is the expected return value.
  359. Returns
  360. -------
  361. list
  362. List of three items (DataFrame, function, expected result)
  363. """
  364. results = []
  365. for func_name, expected in func_names_and_expected:
  366. results.append((ndframe, func_name, expected))
  367. results += [
  368. (ndframe, func, expected)
  369. for func, name in cython_table
  370. if name == func_name
  371. ]
  372. return results
  373. def get_op_from_name(op_name: str) -> Callable:
  374. """
  375. The operator function for a given op name.
  376. Parameters
  377. ----------
  378. op_name : str
  379. The op name, in form of "add" or "__add__".
  380. Returns
  381. -------
  382. function
  383. A function performing the operation.
  384. """
  385. short_opname = op_name.strip("_")
  386. try:
  387. op = getattr(operator, short_opname)
  388. except AttributeError:
  389. # Assume it is the reverse operator
  390. rop = getattr(operator, short_opname[1:])
  391. op = lambda x, y: rop(y, x)
  392. return op
  393. # -----------------------------------------------------------------------------
  394. # Indexing test helpers
  395. def getitem(x):
  396. return x
  397. def setitem(x):
  398. return x
  399. def loc(x):
  400. return x.loc
  401. def iloc(x):
  402. return x.iloc
  403. def at(x):
  404. return x.at
  405. def iat(x):
  406. return x.iat
  407. # -----------------------------------------------------------------------------
  408. _UNITS = ["s", "ms", "us", "ns"]
  409. def get_finest_unit(left: str, right: str):
  410. """
  411. Find the higher of two datetime64 units.
  412. """
  413. if _UNITS.index(left) >= _UNITS.index(right):
  414. return left
  415. return right
def shares_memory(left, right) -> bool:
    """
    Pandas-compat for np.shares_memory.

    Recursively unwraps pandas containers (Index, Series, a DataFrame backed
    by a single array, and several extension arrays) on the ``left`` side
    until it reaches numpy arrays. It then compares those with
    ``np.shares_memory``. Arrow-backed and masked arrays have special-cased
    comparisons.

    Parameters
    ----------
    left, right : np.ndarray or pandas object

    Returns
    -------
    bool

    Raises
    ------
    NotImplementedError
        If ``left`` (after unwrapping) is not one of the handled types, e.g. a
        masked array compared with a non-masked one, or a DataFrame whose
        manager holds more than one array.
    """
    if isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
        return np.shares_memory(left, right)
    elif isinstance(left, np.ndarray):
        # Call with reversed args to get to unpacking logic below.
        return shares_memory(right, left)

    # A RangeIndex is always reported as sharing nothing (presumably because
    # it is not backed by a stored array).
    if isinstance(left, RangeIndex):
        return False
    # NOTE(review): only the codes are compared. If ``_codes`` is a sequence of
    # arrays rather than a single array, this recursion would reach
    # NotImplementedError below — confirm before relying on it.
    if isinstance(left, MultiIndex):
        return shares_memory(left._codes, right)
    if isinstance(left, (Index, Series)):
        if isinstance(right, (Index, Series)):
            return shares_memory(left._values, right._values)
        return shares_memory(left._values, right)

    if isinstance(left, NDArrayBackedExtensionArray):
        return shares_memory(left._ndarray, right)
    if isinstance(left, pd.core.arrays.SparseArray):
        # Only the stored (non-fill) values are compared.
        return shares_memory(left.sp_values, right)
    if isinstance(left, pd.core.arrays.IntervalArray):
        # Shares memory if either endpoint array does.
        return shares_memory(left._left, right) or shares_memory(left._right, right)

    if isinstance(left, ArrowExtensionArray):
        if isinstance(right, ArrowExtensionArray):
            # https://github.com/pandas-dev/pandas/pull/43930#discussion_r736862669
            # Compares the address of buffer index 1 of the first chunk on
            # each side. NOTE(review): other chunks are not inspected, and an
            # array with zero chunks would fail at ``chunk(0)``.
            left_pa_data = left._pa_array
            right_pa_data = right._pa_array
            left_buf1 = left_pa_data.chunk(0).buffers()[1]
            right_buf1 = right_pa_data.chunk(0).buffers()[1]
            return left_buf1.address == right_buf1.address
        else:
            # if we have one one ArrowExtensionArray and one other array, assume
            # they can only share memory if they share the same numpy buffer
            return np.shares_memory(left, right)

    if isinstance(left, BaseMaskedArray) and isinstance(right, BaseMaskedArray):
        # By convention, we'll say these share memory if they share *either*
        # the _data or the _mask
        return np.shares_memory(left._data, right._data) or np.shares_memory(
            left._mask, right._mask
        )

    # A DataFrame is only handled when its manager holds exactly one array.
    if isinstance(left, DataFrame) and len(left._mgr.arrays) == 1:
        arr = left._mgr.arrays[0]
        return shares_memory(arr, right)

    raise NotImplementedError(type(left), type(right))
  461. __all__ = [
  462. "ALL_INT_EA_DTYPES",
  463. "ALL_INT_NUMPY_DTYPES",
  464. "ALL_NUMPY_DTYPES",
  465. "ALL_REAL_NUMPY_DTYPES",
  466. "assert_almost_equal",
  467. "assert_attr_equal",
  468. "assert_categorical_equal",
  469. "assert_class_equal",
  470. "assert_contains_all",
  471. "assert_copy",
  472. "assert_datetime_array_equal",
  473. "assert_dict_equal",
  474. "assert_equal",
  475. "assert_extension_array_equal",
  476. "assert_frame_equal",
  477. "assert_index_equal",
  478. "assert_indexing_slices_equivalent",
  479. "assert_interval_array_equal",
  480. "assert_is_sorted",
  481. "assert_is_valid_plot_return_object",
  482. "assert_metadata_equivalent",
  483. "assert_numpy_array_equal",
  484. "assert_period_array_equal",
  485. "assert_produces_warning",
  486. "assert_series_equal",
  487. "assert_sp_array_equal",
  488. "assert_timedelta_array_equal",
  489. "assert_cow_warning",
  490. "at",
  491. "BOOL_DTYPES",
  492. "box_expected",
  493. "BYTES_DTYPES",
  494. "can_set_locale",
  495. "COMPLEX_DTYPES",
  496. "convert_rows_list_to_csv_str",
  497. "DATETIME64_DTYPES",
  498. "decompress_file",
  499. "ENDIAN",
  500. "ensure_clean",
  501. "external_error_raised",
  502. "FLOAT_EA_DTYPES",
  503. "FLOAT_NUMPY_DTYPES",
  504. "get_cython_table_params",
  505. "get_dtype",
  506. "getitem",
  507. "get_locales",
  508. "get_finest_unit",
  509. "get_obj",
  510. "get_op_from_name",
  511. "iat",
  512. "iloc",
  513. "loc",
  514. "maybe_produces_warning",
  515. "NARROW_NP_DTYPES",
  516. "NP_NAT_OBJECTS",
  517. "NULL_OBJECTS",
  518. "OBJECT_DTYPES",
  519. "raise_assert_detail",
  520. "raises_chained_assignment_error",
  521. "round_trip_localpath",
  522. "round_trip_pathlib",
  523. "round_trip_pickle",
  524. "setitem",
  525. "set_locale",
  526. "set_timezone",
  527. "shares_memory",
  528. "SIGNED_INT_EA_DTYPES",
  529. "SIGNED_INT_NUMPY_DTYPES",
  530. "STRING_DTYPES",
  531. "SubclassedDataFrame",
  532. "SubclassedSeries",
  533. "TIMEDELTA64_DTYPES",
  534. "to_array",
  535. "UNSIGNED_INT_EA_DTYPES",
  536. "UNSIGNED_INT_NUMPY_DTYPES",
  537. "use_numexpr",
  538. "with_csv_dialect",
  539. "write_to_compressed",
  540. ]