# test_take.py
from datetime import datetime

import numpy as np
import pytest

from pandas._libs import iNaT

from pandas import array
import pandas._testing as tm
import pandas.core.algorithms as algos
  8. @pytest.fixture(
  9. params=[
  10. (np.int8, np.int16(127), np.int8),
  11. (np.int8, np.int16(128), np.int16),
  12. (np.int32, 1, np.int32),
  13. (np.int32, 2.0, np.float64),
  14. (np.int32, 3.0 + 4.0j, np.complex128),
  15. (np.int32, True, np.object_),
  16. (np.int32, "", np.object_),
  17. (np.float64, 1, np.float64),
  18. (np.float64, 2.0, np.float64),
  19. (np.float64, 3.0 + 4.0j, np.complex128),
  20. (np.float64, True, np.object_),
  21. (np.float64, "", np.object_),
  22. (np.complex128, 1, np.complex128),
  23. (np.complex128, 2.0, np.complex128),
  24. (np.complex128, 3.0 + 4.0j, np.complex128),
  25. (np.complex128, True, np.object_),
  26. (np.complex128, "", np.object_),
  27. (np.bool_, 1, np.object_),
  28. (np.bool_, 2.0, np.object_),
  29. (np.bool_, 3.0 + 4.0j, np.object_),
  30. (np.bool_, True, np.bool_),
  31. (np.bool_, "", np.object_),
  32. ]
  33. )
  34. def dtype_fill_out_dtype(request):
  35. return request.param
  36. class TestTake:
  37. def test_1d_fill_nonna(self, dtype_fill_out_dtype):
  38. dtype, fill_value, out_dtype = dtype_fill_out_dtype
  39. data = np.random.default_rng(2).integers(0, 2, 4).astype(dtype)
  40. indexer = [2, 1, 0, -1]
  41. result = algos.take_nd(data, indexer, fill_value=fill_value)
  42. assert (result[[0, 1, 2]] == data[[2, 1, 0]]).all()
  43. assert result[3] == fill_value
  44. assert result.dtype == out_dtype
  45. indexer = [2, 1, 0, 1]
  46. result = algos.take_nd(data, indexer, fill_value=fill_value)
  47. assert (result[[0, 1, 2, 3]] == data[indexer]).all()
  48. assert result.dtype == dtype
  49. def test_2d_fill_nonna(self, dtype_fill_out_dtype):
  50. dtype, fill_value, out_dtype = dtype_fill_out_dtype
  51. data = np.random.default_rng(2).integers(0, 2, (5, 3)).astype(dtype)
  52. indexer = [2, 1, 0, -1]
  53. result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value)
  54. assert (result[[0, 1, 2], :] == data[[2, 1, 0], :]).all()
  55. assert (result[3, :] == fill_value).all()
  56. assert result.dtype == out_dtype
  57. result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value)
  58. assert (result[:, [0, 1, 2]] == data[:, [2, 1, 0]]).all()
  59. assert (result[:, 3] == fill_value).all()
  60. assert result.dtype == out_dtype
  61. indexer = [2, 1, 0, 1]
  62. result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value)
  63. assert (result[[0, 1, 2, 3], :] == data[indexer, :]).all()
  64. assert result.dtype == dtype
  65. result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value)
  66. assert (result[:, [0, 1, 2, 3]] == data[:, indexer]).all()
  67. assert result.dtype == dtype
  68. def test_3d_fill_nonna(self, dtype_fill_out_dtype):
  69. dtype, fill_value, out_dtype = dtype_fill_out_dtype
  70. data = np.random.default_rng(2).integers(0, 2, (5, 4, 3)).astype(dtype)
  71. indexer = [2, 1, 0, -1]
  72. result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value)
  73. assert (result[[0, 1, 2], :, :] == data[[2, 1, 0], :, :]).all()
  74. assert (result[3, :, :] == fill_value).all()
  75. assert result.dtype == out_dtype
  76. result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value)
  77. assert (result[:, [0, 1, 2], :] == data[:, [2, 1, 0], :]).all()
  78. assert (result[:, 3, :] == fill_value).all()
  79. assert result.dtype == out_dtype
  80. result = algos.take_nd(data, indexer, axis=2, fill_value=fill_value)
  81. assert (result[:, :, [0, 1, 2]] == data[:, :, [2, 1, 0]]).all()
  82. assert (result[:, :, 3] == fill_value).all()
  83. assert result.dtype == out_dtype
  84. indexer = [2, 1, 0, 1]
  85. result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value)
  86. assert (result[[0, 1, 2, 3], :, :] == data[indexer, :, :]).all()
  87. assert result.dtype == dtype
  88. result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value)
  89. assert (result[:, [0, 1, 2, 3], :] == data[:, indexer, :]).all()
  90. assert result.dtype == dtype
  91. result = algos.take_nd(data, indexer, axis=2, fill_value=fill_value)
  92. assert (result[:, :, [0, 1, 2, 3]] == data[:, :, indexer]).all()
  93. assert result.dtype == dtype
  94. def test_1d_other_dtypes(self):
  95. arr = np.random.default_rng(2).standard_normal(10).astype(np.float32)
  96. indexer = [1, 2, 3, -1]
  97. result = algos.take_nd(arr, indexer)
  98. expected = arr.take(indexer)
  99. expected[-1] = np.nan
  100. tm.assert_almost_equal(result, expected)
  101. def test_2d_other_dtypes(self):
  102. arr = np.random.default_rng(2).standard_normal((10, 5)).astype(np.float32)
  103. indexer = [1, 2, 3, -1]
  104. # axis=0
  105. result = algos.take_nd(arr, indexer, axis=0)
  106. expected = arr.take(indexer, axis=0)
  107. expected[-1] = np.nan
  108. tm.assert_almost_equal(result, expected)
  109. # axis=1
  110. result = algos.take_nd(arr, indexer, axis=1)
  111. expected = arr.take(indexer, axis=1)
  112. expected[:, -1] = np.nan
  113. tm.assert_almost_equal(result, expected)
  114. def test_1d_bool(self):
  115. arr = np.array([0, 1, 0], dtype=bool)
  116. result = algos.take_nd(arr, [0, 2, 2, 1])
  117. expected = arr.take([0, 2, 2, 1])
  118. tm.assert_numpy_array_equal(result, expected)
  119. result = algos.take_nd(arr, [0, 2, -1])
  120. assert result.dtype == np.object_
  121. def test_2d_bool(self):
  122. arr = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 1]], dtype=bool)
  123. result = algos.take_nd(arr, [0, 2, 2, 1])
  124. expected = arr.take([0, 2, 2, 1], axis=0)
  125. tm.assert_numpy_array_equal(result, expected)
  126. result = algos.take_nd(arr, [0, 2, 2, 1], axis=1)
  127. expected = arr.take([0, 2, 2, 1], axis=1)
  128. tm.assert_numpy_array_equal(result, expected)
  129. result = algos.take_nd(arr, [0, 2, -1])
  130. assert result.dtype == np.object_
  131. def test_2d_float32(self):
  132. arr = np.random.default_rng(2).standard_normal((4, 3)).astype(np.float32)
  133. indexer = [0, 2, -1, 1, -1]
  134. # axis=0
  135. result = algos.take_nd(arr, indexer, axis=0)
  136. expected = arr.take(indexer, axis=0)
  137. expected[[2, 4], :] = np.nan
  138. tm.assert_almost_equal(result, expected)
  139. # axis=1
  140. result = algos.take_nd(arr, indexer, axis=1)
  141. expected = arr.take(indexer, axis=1)
  142. expected[:, [2, 4]] = np.nan
  143. tm.assert_almost_equal(result, expected)
  144. def test_2d_datetime64(self):
  145. # 2005/01/01 - 2006/01/01
  146. arr = (
  147. np.random.default_rng(2).integers(11_045_376, 11_360_736, (5, 3))
  148. * 100_000_000_000
  149. )
  150. arr = arr.view(dtype="datetime64[ns]")
  151. indexer = [0, 2, -1, 1, -1]
  152. # axis=0
  153. result = algos.take_nd(arr, indexer, axis=0)
  154. expected = arr.take(indexer, axis=0)
  155. expected.view(np.int64)[[2, 4], :] = iNaT
  156. tm.assert_almost_equal(result, expected)
  157. result = algos.take_nd(arr, indexer, axis=0, fill_value=datetime(2007, 1, 1))
  158. expected = arr.take(indexer, axis=0)
  159. expected[[2, 4], :] = datetime(2007, 1, 1)
  160. tm.assert_almost_equal(result, expected)
  161. # axis=1
  162. result = algos.take_nd(arr, indexer, axis=1)
  163. expected = arr.take(indexer, axis=1)
  164. expected.view(np.int64)[:, [2, 4]] = iNaT
  165. tm.assert_almost_equal(result, expected)
  166. result = algos.take_nd(arr, indexer, axis=1, fill_value=datetime(2007, 1, 1))
  167. expected = arr.take(indexer, axis=1)
  168. expected[:, [2, 4]] = datetime(2007, 1, 1)
  169. tm.assert_almost_equal(result, expected)
  170. def test_take_axis_0(self):
  171. arr = np.arange(12).reshape(4, 3)
  172. result = algos.take(arr, [0, -1])
  173. expected = np.array([[0, 1, 2], [9, 10, 11]])
  174. tm.assert_numpy_array_equal(result, expected)
  175. # allow_fill=True
  176. result = algos.take(arr, [0, -1], allow_fill=True, fill_value=0)
  177. expected = np.array([[0, 1, 2], [0, 0, 0]])
  178. tm.assert_numpy_array_equal(result, expected)
  179. def test_take_axis_1(self):
  180. arr = np.arange(12).reshape(4, 3)
  181. result = algos.take(arr, [0, -1], axis=1)
  182. expected = np.array([[0, 2], [3, 5], [6, 8], [9, 11]])
  183. tm.assert_numpy_array_equal(result, expected)
  184. # allow_fill=True
  185. result = algos.take(arr, [0, -1], axis=1, allow_fill=True, fill_value=0)
  186. expected = np.array([[0, 0], [3, 0], [6, 0], [9, 0]])
  187. tm.assert_numpy_array_equal(result, expected)
  188. # GH#26976 make sure we validate along the correct axis
  189. with pytest.raises(IndexError, match="indices are out-of-bounds"):
  190. algos.take(arr, [0, 3], axis=1, allow_fill=True, fill_value=0)
  191. def test_take_non_hashable_fill_value(self):
  192. arr = np.array([1, 2, 3])
  193. indexer = np.array([1, -1])
  194. with pytest.raises(ValueError, match="fill_value must be a scalar"):
  195. algos.take(arr, indexer, allow_fill=True, fill_value=[1])
  196. # with object dtype it is allowed
  197. arr = np.array([1, 2, 3], dtype=object)
  198. result = algos.take(arr, indexer, allow_fill=True, fill_value=[1])
  199. expected = np.array([2, [1]], dtype=object)
  200. tm.assert_numpy_array_equal(result, expected)
  201. class TestExtensionTake:
  202. # The take method found in pd.api.extensions
  203. def test_bounds_check_large(self):
  204. arr = np.array([1, 2])
  205. msg = "indices are out-of-bounds"
  206. with pytest.raises(IndexError, match=msg):
  207. algos.take(arr, [2, 3], allow_fill=True)
  208. msg = "index 2 is out of bounds for( axis 0 with)? size 2"
  209. with pytest.raises(IndexError, match=msg):
  210. algos.take(arr, [2, 3], allow_fill=False)
  211. def test_bounds_check_small(self):
  212. arr = np.array([1, 2, 3], dtype=np.int64)
  213. indexer = [0, -1, -2]
  214. msg = r"'indices' contains values less than allowed \(-2 < -1\)"
  215. with pytest.raises(ValueError, match=msg):
  216. algos.take(arr, indexer, allow_fill=True)
  217. result = algos.take(arr, indexer)
  218. expected = np.array([1, 3, 2], dtype=np.int64)
  219. tm.assert_numpy_array_equal(result, expected)
  220. @pytest.mark.parametrize("allow_fill", [True, False])
  221. def test_take_empty(self, allow_fill):
  222. arr = np.array([], dtype=np.int64)
  223. # empty take is ok
  224. result = algos.take(arr, [], allow_fill=allow_fill)
  225. tm.assert_numpy_array_equal(arr, result)
  226. msg = "|".join(
  227. [
  228. "cannot do a non-empty take from an empty axes.",
  229. "indices are out-of-bounds",
  230. ]
  231. )
  232. with pytest.raises(IndexError, match=msg):
  233. algos.take(arr, [0], allow_fill=allow_fill)
  234. def test_take_na_empty(self):
  235. result = algos.take(np.array([]), [-1, -1], allow_fill=True, fill_value=0.0)
  236. expected = np.array([0.0, 0.0])
  237. tm.assert_numpy_array_equal(result, expected)
  238. def test_take_coerces_list(self):
  239. # GH#52981 coercing is deprecated, disabled in 3.0
  240. arr = [1, 2, 3]
  241. msg = (
  242. "pd.api.extensions.take requires a numpy.ndarray, ExtensionArray, "
  243. "Index, Series, or NumpyExtensionArray got list"
  244. )
  245. with pytest.raises(TypeError, match=msg):
  246. algos.take(arr, [0, 0])
  247. def test_take_NumpyExtensionArray(self):
  248. # GH#59177
  249. arr = array([1 + 1j, 2, 3]) # NumpyEADtype('complex128') (NumpyExtensionArray)
  250. assert algos.take(arr, [2]) == 2
  251. arr = array([1, 2, 3]) # Int64Dtype() (ExtensionArray)
  252. assert algos.take(arr, [2]) == 2