dim2.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. """
  2. Tests for 2D compatibility.
  3. """
  4. import numpy as np
  5. import pytest
  6. from pandas._libs.missing import is_matching_na
  7. from pandas.core.dtypes.common import (
  8. is_bool_dtype,
  9. is_integer_dtype,
  10. )
  11. import pandas as pd
  12. import pandas._testing as tm
  13. from pandas.core.arrays.integer import NUMPY_INT_TO_DTYPE
  class Dim2CompatTests:
      """
      Mixin of tests for ExtensionArray subclasses that support 2D arrays.

      These are ONLY for ExtensionArray subclasses that support 2D arrays,
      i.e. not for pyarrow-backed EAs. When ``dtype._supports_2d`` is False,
      the tests defined directly on this class are skipped by the autouse
      ``skip_if_doesnt_support_2d`` fixture.
      """

      @pytest.fixture(autouse=True)
      def skip_if_doesnt_support_2d(self, dtype, request):
          # Skip 2D-specific tests for dtypes that do not support 2D arrays.
          if not dtype._supports_2d:
              node = request.node
              # In cases where we are mixed in to ExtensionTests, we only want to
              #  skip tests that are defined in Dim2CompatTests
              test_func = node._obj
              if test_func.__qualname__.startswith("Dim2CompatTests"):
                  # TODO: is there a less hacky way of checking this?
                  pytest.skip(f"{dtype} does not support 2D.")

      def test_transpose(self, data):
          """``.T`` on a non-square 2D array reverses its shape."""
          arr2d = data.repeat(2).reshape(-1, 2)
          shape = arr2d.shape
          assert shape[0] != shape[-1]  # otherwise the rest of the test is useless

          assert arr2d.T.shape == shape[::-1]

      def test_frame_from_2d_array(self, data):
          """A DataFrame built from a 2D EA gets one column per axis-1 position."""
          arr2d = data.repeat(2).reshape(-1, 2)

          df = pd.DataFrame(arr2d)
          expected = pd.DataFrame({0: arr2d[:, 0], 1: arr2d[:, 1]})
          tm.assert_frame_equal(df, expected)

      def test_swapaxes(self, data):
          """``swapaxes(0, 1)`` on a 2D array matches its transpose."""
          arr2d = data.repeat(2).reshape(-1, 2)

          result = arr2d.swapaxes(0, 1)
          expected = arr2d.T
          tm.assert_extension_array_equal(result, expected)

      def test_delete_2d(self, data):
          """``delete`` removes a row (axis=0) or a column (axis=1)."""
          arr2d = data.repeat(3).reshape(-1, 3)

          # axis = 0: drop the second row
          result = arr2d.delete(1, axis=0)
          expected = data.delete(1).repeat(3).reshape(-1, 3)
          tm.assert_extension_array_equal(result, expected)

          # axis = 1: every row holds identical values, so dropping one column
          #  leaves the 2-column version of the same data
          result = arr2d.delete(1, axis=1)
          expected = data.repeat(2).reshape(-1, 2)
          tm.assert_extension_array_equal(result, expected)

      def test_take_2d(self, data):
          """``take`` along axis=0 matches a 1D take followed by reshape."""
          arr2d = data.reshape(-1, 1)

          result = arr2d.take([0, 0, -1], axis=0)

          expected = data.take([0, 0, -1]).reshape(-1, 1)
          tm.assert_extension_array_equal(result, expected)

      def test_repr_2d(self, data):
          """The repr of a 2D array contains the class-name header exactly once."""
          # this could fail in a corner case where an element contained the name
          res = repr(data.reshape(1, -1))
          assert res.count(f"<{type(data).__name__}") == 1

          res = repr(data.reshape(-1, 1))
          assert res.count(f"<{type(data).__name__}") == 1

      def test_reshape(self, data):
          """
          ``reshape`` accepts the shape as separate ints or as a tuple, and
          raises ValueError for a shape with an incompatible size.
          """
          arr2d = data.reshape(-1, 1)
          assert arr2d.shape == (data.size, 1)
          assert len(arr2d) == len(data)

          arr2d = data.reshape((-1, 1))
          assert arr2d.shape == (data.size, 1)
          assert len(arr2d) == len(data)

          with pytest.raises(ValueError):
              data.reshape((data.size, 2))
          with pytest.raises(ValueError):
              data.reshape(data.size, 2)

      def test_getitem_2d(self, data):
          """Integer, slice and newaxis indexing on 2D (and 1D -> 2D) arrays."""
          arr2d = data.reshape(1, -1)

          result = arr2d[0]
          tm.assert_extension_array_equal(result, data)

          # only one row exists, so 1 and -2 are out of bounds
          with pytest.raises(IndexError):
              arr2d[1]

          with pytest.raises(IndexError):
              arr2d[-2]

          result = arr2d[:]
          tm.assert_extension_array_equal(result, arr2d)

          result = arr2d[:, :]
          tm.assert_extension_array_equal(result, arr2d)

          result = arr2d[:, 0]
          expected = data[[0]]
          tm.assert_extension_array_equal(result, expected)

          # dimension-expanding getitem on 1D
          result = data[:, np.newaxis]
          tm.assert_extension_array_equal(result, arr2d.T)

      def test_iter_2d(self, data):
          """Iterating a 2D array yields one 1D array (same type/dtype) per row."""
          arr2d = data.reshape(1, -1)

          objs = list(iter(arr2d))
          assert len(objs) == arr2d.shape[0]

          for obj in objs:
              assert isinstance(obj, type(data))
              assert obj.dtype == data.dtype
              assert obj.ndim == 1
              assert len(obj) == arr2d.shape[1]

      def test_tolist_2d(self, data):
          """``tolist`` on a 2D array returns a list of per-row lists."""
          arr2d = data.reshape(1, -1)

          result = arr2d.tolist()
          expected = [data.tolist()]

          assert isinstance(result, list)
          assert all(isinstance(x, list) for x in result)

          assert result == expected

      def test_concat_2d(self, data):
          """``_concat_same_type`` stacks 2D arrays along axis 0 or 1 only."""
          left = type(data)._concat_same_type([data, data]).reshape(-1, 2)
          right = left.copy()

          # axis=0
          result = left._concat_same_type([left, right], axis=0)
          expected = data._concat_same_type([data] * 4).reshape(-1, 2)
          tm.assert_extension_array_equal(result, expected)

          # axis=1
          result = left._concat_same_type([left, right], axis=1)
          assert result.shape == (len(data), 4)
          tm.assert_extension_array_equal(result[:, :2], left)
          tm.assert_extension_array_equal(result[:, 2:], right)

          # axis > 1 -> invalid
          msg = "axis 2 is out of bounds for array of dimension 2"
          with pytest.raises(ValueError, match=msg):
              left._concat_same_type([left, right], axis=2)

      @pytest.mark.parametrize("method", ["backfill", "pad"])
      def test_fillna_2d_method(self, data_missing, method):
          """``_pad_or_backfill`` on a 2D array fills along axis=0."""
          # pad_or_backfill is always along axis=0
          arr = data_missing.repeat(2).reshape(2, 2)
          assert arr[0].isna().all()
          assert not arr[1].isna().any()

          result = arr._pad_or_backfill(method=method, limit=None)

          expected = (
              data_missing._pad_or_backfill(method=method).repeat(2).reshape(2, 2)
          )
          tm.assert_extension_array_equal(result, expected)

          # Reverse so that backfill is not a no-op.
          arr2 = arr[::-1]
          assert not arr2[0].isna().any()
          assert arr2[1].isna().all()

          result2 = arr2._pad_or_backfill(method=method, limit=None)

          expected2 = (
              data_missing[::-1]._pad_or_backfill(method=method).repeat(2).reshape(2, 2)
          )
          tm.assert_extension_array_equal(result2, expected2)

      @pytest.mark.parametrize("method", ["mean", "median", "var", "std", "sum", "prod"])
      def test_reductions_2d_axis_none(self, data, method):
          """
          A full (axis=None) reduction of a single-row 2D array matches the 1D
          reduction, or both raise the same exception type.
          """
          arr2d = data.reshape(1, -1)

          err_expected = None
          err_result = None
          try:
              expected = getattr(data, method)()
          except Exception as err:
              # if the 1D reduction is invalid, the 2D reduction should be as well
              err_expected = err
              try:
                  result = getattr(arr2d, method)(axis=None)
              except Exception as err2:
                  err_result = err2

          else:
              result = getattr(arr2d, method)(axis=None)

          if err_result is not None or err_expected is not None:
              # exact exception class must match (not merely a subclass)
              assert type(err_result) is type(err_expected), (err_expected, err_result)
              return

          assert is_matching_na(result, expected) or result == expected

      @pytest.mark.parametrize("method", ["mean", "median", "var", "std", "sum", "prod"])
      @pytest.mark.parametrize("min_count", [0, 1])
      def test_reductions_2d_axis0(self, data, method, min_count):
          """
          Reducing a single-row 2D array along axis=0 returns (a cast of) that
          row, or raises the same exception type as the 1D reduction.
          """
          if min_count == 1 and method not in ["sum", "prod"]:
              pytest.skip(f"min_count not relevant for {method}")

          arr2d = data.reshape(1, -1)

          kwargs = {}
          if method in ["std", "var"]:
              # pass ddof=0 so we get all-zero std instead of all-NA std
              kwargs["ddof"] = 0
          elif method in ["prod", "sum"]:
              kwargs["min_count"] = min_count

          try:
              result = getattr(arr2d, method)(axis=0, **kwargs)
          except Exception as err:
              try:
                  getattr(data, method)()
              except Exception as err2:
                  assert type(err) is type(err2), (err, err2)
                  return
              else:
                  raise AssertionError("Both reductions should raise or neither")

          def get_reduction_result_dtype(dtype):
              # windows and 32bit builds will in some cases have int32/uint32
              #  where other builds will have int64/uint64.
              if dtype.itemsize == 8:
                  return dtype
              elif dtype.kind in "ib":
                  return NUMPY_INT_TO_DTYPE[np.dtype(int)]
              else:
                  # i.e. dtype.kind == "u"
                  return NUMPY_INT_TO_DTYPE[np.dtype("uint")]

          if method in ["sum", "prod"]:
              # Each column holds a single value, so the result is the row itself,
              #  with ints/bools upcast to the platform reduction dtype.
              expected = data
              if data.dtype.kind in "iub":
                  dtype = get_reduction_result_dtype(data.dtype)
                  expected = data.astype(dtype)
                  assert dtype == expected.dtype

              if min_count == 0:
                  # with min_count=0 a missing-only column reduces to the
                  #  identity element instead of NA
                  fill_value = 1 if method == "prod" else 0
                  expected = expected.fillna(fill_value)

              tm.assert_extension_array_equal(result, expected)
          elif method == "median":
              # the median of a single value is that value, dtype unchanged
              expected = data
              tm.assert_extension_array_equal(result, expected)
          elif method in ["mean", "std", "var"]:
              # mean/std/var are not dtype-preserving for ints/bools
              if is_integer_dtype(data) or is_bool_dtype(data):
                  data = data.astype("Float64")
              if method == "mean":
                  tm.assert_extension_array_equal(result, data)
              else:
                  # ddof=0 over one value -> all zeros, missing entries preserved
                  tm.assert_extension_array_equal(result, data - data)

      @pytest.mark.parametrize("method", ["mean", "median", "var", "std", "sum", "prod"])
      def test_reductions_2d_axis1(self, data, method):
          """
          Reducing a single-row 2D array along axis=1 gives one value that
          matches the 1D reduction, or both raise the same exception type.
          """
          arr2d = data.reshape(1, -1)

          try:
              result = getattr(arr2d, method)(axis=1)
          except Exception as err:
              try:
                  getattr(data, method)()
              except Exception as err2:
                  assert type(err) is type(err2), (err, err2)
                  return
              else:
                  raise AssertionError("Both reductions should raise or neither")

          # not necessarily type/dtype-preserving, so weaker assertions
          assert result.shape == (1,)
          expected_scalar = getattr(data, method)()
          res = result[0]
          assert is_matching_na(res, expected_scalar) or res == expected_scalar
  231. expected_scalar = getattr(data, method)()
  232. res = result[0]
  233. assert is_matching_na(res, expected_scalar) or res == expected_scalar
  234. class NDArrayBacked2DTests(Dim2CompatTests):
  235. # More specific tests for NDArrayBackedExtensionArray subclasses
  236. def test_copy_order(self, data):
  237. # We should be matching numpy semantics for the "order" keyword in 'copy'
  238. arr2d = data.repeat(2).reshape(-1, 2)
  239. assert arr2d._ndarray.flags["C_CONTIGUOUS"]
  240. res = arr2d.copy()
  241. assert res._ndarray.flags["C_CONTIGUOUS"]
  242. res = arr2d[::2, ::2].copy()
  243. assert res._ndarray.flags["C_CONTIGUOUS"]
  244. res = arr2d.copy("F")
  245. assert not res._ndarray.flags["C_CONTIGUOUS"]
  246. assert res._ndarray.flags["F_CONTIGUOUS"]
  247. res = arr2d.copy("K")
  248. assert res._ndarray.flags["C_CONTIGUOUS"]
  249. res = arr2d.T.copy("K")
  250. assert not res._ndarray.flags["C_CONTIGUOUS"]
  251. assert res._ndarray.flags["F_CONTIGUOUS"]
  252. # order not accepted by numpy
  253. msg = r"order must be one of 'C', 'F', 'A', or 'K' \(got 'Q'\)"
  254. with pytest.raises(ValueError, match=msg):
  255. arr2d.copy("Q")
  256. # neither contiguity
  257. arr_nc = arr2d[::2]
  258. assert not arr_nc._ndarray.flags["C_CONTIGUOUS"]
  259. assert not arr_nc._ndarray.flags["F_CONTIGUOUS"]
  260. assert arr_nc.copy()._ndarray.flags["C_CONTIGUOUS"]
  261. assert not arr_nc.copy()._ndarray.flags["F_CONTIGUOUS"]
  262. assert arr_nc.copy("C")._ndarray.flags["C_CONTIGUOUS"]
  263. assert not arr_nc.copy("C")._ndarray.flags["F_CONTIGUOUS"]
  264. assert not arr_nc.copy("F")._ndarray.flags["C_CONTIGUOUS"]
  265. assert arr_nc.copy("F")._ndarray.flags["F_CONTIGUOUS"]
  266. assert arr_nc.copy("K")._ndarray.flags["C_CONTIGUOUS"]
  267. assert not arr_nc.copy("K")._ndarray.flags["F_CONTIGUOUS"]