# test_take.py
from datetime import datetime

import numpy as np
import pytest

from pandas._libs import iNaT

import pandas._testing as tm
import pandas.core.algorithms as algos
  7. @pytest.fixture(
  8. params=[
  9. (np.int8, np.int16(127), np.int8),
  10. (np.int8, np.int16(128), np.int16),
  11. (np.int32, 1, np.int32),
  12. (np.int32, 2.0, np.float64),
  13. (np.int32, 3.0 + 4.0j, np.complex128),
  14. (np.int32, True, np.object_),
  15. (np.int32, "", np.object_),
  16. (np.float64, 1, np.float64),
  17. (np.float64, 2.0, np.float64),
  18. (np.float64, 3.0 + 4.0j, np.complex128),
  19. (np.float64, True, np.object_),
  20. (np.float64, "", np.object_),
  21. (np.complex128, 1, np.complex128),
  22. (np.complex128, 2.0, np.complex128),
  23. (np.complex128, 3.0 + 4.0j, np.complex128),
  24. (np.complex128, True, np.object_),
  25. (np.complex128, "", np.object_),
  26. (np.bool_, 1, np.object_),
  27. (np.bool_, 2.0, np.object_),
  28. (np.bool_, 3.0 + 4.0j, np.object_),
  29. (np.bool_, True, np.bool_),
  30. (np.bool_, "", np.object_),
  31. ]
  32. )
  33. def dtype_fill_out_dtype(request):
  34. return request.param
  35. class TestTake:
  36. def test_1d_fill_nonna(self, dtype_fill_out_dtype):
  37. dtype, fill_value, out_dtype = dtype_fill_out_dtype
  38. data = np.random.default_rng(2).integers(0, 2, 4).astype(dtype)
  39. indexer = [2, 1, 0, -1]
  40. result = algos.take_nd(data, indexer, fill_value=fill_value)
  41. assert (result[[0, 1, 2]] == data[[2, 1, 0]]).all()
  42. assert result[3] == fill_value
  43. assert result.dtype == out_dtype
  44. indexer = [2, 1, 0, 1]
  45. result = algos.take_nd(data, indexer, fill_value=fill_value)
  46. assert (result[[0, 1, 2, 3]] == data[indexer]).all()
  47. assert result.dtype == dtype
  48. def test_2d_fill_nonna(self, dtype_fill_out_dtype):
  49. dtype, fill_value, out_dtype = dtype_fill_out_dtype
  50. data = np.random.default_rng(2).integers(0, 2, (5, 3)).astype(dtype)
  51. indexer = [2, 1, 0, -1]
  52. result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value)
  53. assert (result[[0, 1, 2], :] == data[[2, 1, 0], :]).all()
  54. assert (result[3, :] == fill_value).all()
  55. assert result.dtype == out_dtype
  56. result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value)
  57. assert (result[:, [0, 1, 2]] == data[:, [2, 1, 0]]).all()
  58. assert (result[:, 3] == fill_value).all()
  59. assert result.dtype == out_dtype
  60. indexer = [2, 1, 0, 1]
  61. result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value)
  62. assert (result[[0, 1, 2, 3], :] == data[indexer, :]).all()
  63. assert result.dtype == dtype
  64. result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value)
  65. assert (result[:, [0, 1, 2, 3]] == data[:, indexer]).all()
  66. assert result.dtype == dtype
  67. def test_3d_fill_nonna(self, dtype_fill_out_dtype):
  68. dtype, fill_value, out_dtype = dtype_fill_out_dtype
  69. data = np.random.default_rng(2).integers(0, 2, (5, 4, 3)).astype(dtype)
  70. indexer = [2, 1, 0, -1]
  71. result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value)
  72. assert (result[[0, 1, 2], :, :] == data[[2, 1, 0], :, :]).all()
  73. assert (result[3, :, :] == fill_value).all()
  74. assert result.dtype == out_dtype
  75. result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value)
  76. assert (result[:, [0, 1, 2], :] == data[:, [2, 1, 0], :]).all()
  77. assert (result[:, 3, :] == fill_value).all()
  78. assert result.dtype == out_dtype
  79. result = algos.take_nd(data, indexer, axis=2, fill_value=fill_value)
  80. assert (result[:, :, [0, 1, 2]] == data[:, :, [2, 1, 0]]).all()
  81. assert (result[:, :, 3] == fill_value).all()
  82. assert result.dtype == out_dtype
  83. indexer = [2, 1, 0, 1]
  84. result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value)
  85. assert (result[[0, 1, 2, 3], :, :] == data[indexer, :, :]).all()
  86. assert result.dtype == dtype
  87. result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value)
  88. assert (result[:, [0, 1, 2, 3], :] == data[:, indexer, :]).all()
  89. assert result.dtype == dtype
  90. result = algos.take_nd(data, indexer, axis=2, fill_value=fill_value)
  91. assert (result[:, :, [0, 1, 2, 3]] == data[:, :, indexer]).all()
  92. assert result.dtype == dtype
  93. def test_1d_other_dtypes(self):
  94. arr = np.random.default_rng(2).standard_normal(10).astype(np.float32)
  95. indexer = [1, 2, 3, -1]
  96. result = algos.take_nd(arr, indexer)
  97. expected = arr.take(indexer)
  98. expected[-1] = np.nan
  99. tm.assert_almost_equal(result, expected)
  100. def test_2d_other_dtypes(self):
  101. arr = np.random.default_rng(2).standard_normal((10, 5)).astype(np.float32)
  102. indexer = [1, 2, 3, -1]
  103. # axis=0
  104. result = algos.take_nd(arr, indexer, axis=0)
  105. expected = arr.take(indexer, axis=0)
  106. expected[-1] = np.nan
  107. tm.assert_almost_equal(result, expected)
  108. # axis=1
  109. result = algos.take_nd(arr, indexer, axis=1)
  110. expected = arr.take(indexer, axis=1)
  111. expected[:, -1] = np.nan
  112. tm.assert_almost_equal(result, expected)
  113. def test_1d_bool(self):
  114. arr = np.array([0, 1, 0], dtype=bool)
  115. result = algos.take_nd(arr, [0, 2, 2, 1])
  116. expected = arr.take([0, 2, 2, 1])
  117. tm.assert_numpy_array_equal(result, expected)
  118. result = algos.take_nd(arr, [0, 2, -1])
  119. assert result.dtype == np.object_
  120. def test_2d_bool(self):
  121. arr = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 1]], dtype=bool)
  122. result = algos.take_nd(arr, [0, 2, 2, 1])
  123. expected = arr.take([0, 2, 2, 1], axis=0)
  124. tm.assert_numpy_array_equal(result, expected)
  125. result = algos.take_nd(arr, [0, 2, 2, 1], axis=1)
  126. expected = arr.take([0, 2, 2, 1], axis=1)
  127. tm.assert_numpy_array_equal(result, expected)
  128. result = algos.take_nd(arr, [0, 2, -1])
  129. assert result.dtype == np.object_
  130. def test_2d_float32(self):
  131. arr = np.random.default_rng(2).standard_normal((4, 3)).astype(np.float32)
  132. indexer = [0, 2, -1, 1, -1]
  133. # axis=0
  134. result = algos.take_nd(arr, indexer, axis=0)
  135. expected = arr.take(indexer, axis=0)
  136. expected[[2, 4], :] = np.nan
  137. tm.assert_almost_equal(result, expected)
  138. # axis=1
  139. result = algos.take_nd(arr, indexer, axis=1)
  140. expected = arr.take(indexer, axis=1)
  141. expected[:, [2, 4]] = np.nan
  142. tm.assert_almost_equal(result, expected)
  143. def test_2d_datetime64(self):
  144. # 2005/01/01 - 2006/01/01
  145. arr = (
  146. np.random.default_rng(2).integers(11_045_376, 11_360_736, (5, 3))
  147. * 100_000_000_000
  148. )
  149. arr = arr.view(dtype="datetime64[ns]")
  150. indexer = [0, 2, -1, 1, -1]
  151. # axis=0
  152. result = algos.take_nd(arr, indexer, axis=0)
  153. expected = arr.take(indexer, axis=0)
  154. expected.view(np.int64)[[2, 4], :] = iNaT
  155. tm.assert_almost_equal(result, expected)
  156. result = algos.take_nd(arr, indexer, axis=0, fill_value=datetime(2007, 1, 1))
  157. expected = arr.take(indexer, axis=0)
  158. expected[[2, 4], :] = datetime(2007, 1, 1)
  159. tm.assert_almost_equal(result, expected)
  160. # axis=1
  161. result = algos.take_nd(arr, indexer, axis=1)
  162. expected = arr.take(indexer, axis=1)
  163. expected.view(np.int64)[:, [2, 4]] = iNaT
  164. tm.assert_almost_equal(result, expected)
  165. result = algos.take_nd(arr, indexer, axis=1, fill_value=datetime(2007, 1, 1))
  166. expected = arr.take(indexer, axis=1)
  167. expected[:, [2, 4]] = datetime(2007, 1, 1)
  168. tm.assert_almost_equal(result, expected)
  169. def test_take_axis_0(self):
  170. arr = np.arange(12).reshape(4, 3)
  171. result = algos.take(arr, [0, -1])
  172. expected = np.array([[0, 1, 2], [9, 10, 11]])
  173. tm.assert_numpy_array_equal(result, expected)
  174. # allow_fill=True
  175. result = algos.take(arr, [0, -1], allow_fill=True, fill_value=0)
  176. expected = np.array([[0, 1, 2], [0, 0, 0]])
  177. tm.assert_numpy_array_equal(result, expected)
  178. def test_take_axis_1(self):
  179. arr = np.arange(12).reshape(4, 3)
  180. result = algos.take(arr, [0, -1], axis=1)
  181. expected = np.array([[0, 2], [3, 5], [6, 8], [9, 11]])
  182. tm.assert_numpy_array_equal(result, expected)
  183. # allow_fill=True
  184. result = algos.take(arr, [0, -1], axis=1, allow_fill=True, fill_value=0)
  185. expected = np.array([[0, 0], [3, 0], [6, 0], [9, 0]])
  186. tm.assert_numpy_array_equal(result, expected)
  187. # GH#26976 make sure we validate along the correct axis
  188. with pytest.raises(IndexError, match="indices are out-of-bounds"):
  189. algos.take(arr, [0, 3], axis=1, allow_fill=True, fill_value=0)
  190. def test_take_non_hashable_fill_value(self):
  191. arr = np.array([1, 2, 3])
  192. indexer = np.array([1, -1])
  193. with pytest.raises(ValueError, match="fill_value must be a scalar"):
  194. algos.take(arr, indexer, allow_fill=True, fill_value=[1])
  195. # with object dtype it is allowed
  196. arr = np.array([1, 2, 3], dtype=object)
  197. result = algos.take(arr, indexer, allow_fill=True, fill_value=[1])
  198. expected = np.array([2, [1]], dtype=object)
  199. tm.assert_numpy_array_equal(result, expected)
  200. class TestExtensionTake:
  201. # The take method found in pd.api.extensions
  202. def test_bounds_check_large(self):
  203. arr = np.array([1, 2])
  204. msg = "indices are out-of-bounds"
  205. with pytest.raises(IndexError, match=msg):
  206. algos.take(arr, [2, 3], allow_fill=True)
  207. msg = "index 2 is out of bounds for( axis 0 with)? size 2"
  208. with pytest.raises(IndexError, match=msg):
  209. algos.take(arr, [2, 3], allow_fill=False)
  210. def test_bounds_check_small(self):
  211. arr = np.array([1, 2, 3], dtype=np.int64)
  212. indexer = [0, -1, -2]
  213. msg = r"'indices' contains values less than allowed \(-2 < -1\)"
  214. with pytest.raises(ValueError, match=msg):
  215. algos.take(arr, indexer, allow_fill=True)
  216. result = algos.take(arr, indexer)
  217. expected = np.array([1, 3, 2], dtype=np.int64)
  218. tm.assert_numpy_array_equal(result, expected)
  219. @pytest.mark.parametrize("allow_fill", [True, False])
  220. def test_take_empty(self, allow_fill):
  221. arr = np.array([], dtype=np.int64)
  222. # empty take is ok
  223. result = algos.take(arr, [], allow_fill=allow_fill)
  224. tm.assert_numpy_array_equal(arr, result)
  225. msg = "|".join(
  226. [
  227. "cannot do a non-empty take from an empty axes.",
  228. "indices are out-of-bounds",
  229. ]
  230. )
  231. with pytest.raises(IndexError, match=msg):
  232. algos.take(arr, [0], allow_fill=allow_fill)
  233. def test_take_na_empty(self):
  234. result = algos.take(np.array([]), [-1, -1], allow_fill=True, fill_value=0.0)
  235. expected = np.array([0.0, 0.0])
  236. tm.assert_numpy_array_equal(result, expected)
  237. def test_take_coerces_list(self):
  238. arr = [1, 2, 3]
  239. msg = "take accepting non-standard inputs is deprecated"
  240. with tm.assert_produces_warning(FutureWarning, match=msg):
  241. result = algos.take(arr, [0, 0])
  242. expected = np.array([1, 1])
  243. tm.assert_numpy_array_equal(result, expected)