# test_col.py (9.1 KB)
  1. from datetime import datetime
  2. import numpy as np
  3. import pytest
  4. from pandas._libs.properties import cache_readonly
  5. import pandas as pd
  6. import pandas._testing as tm
  7. from pandas.api.typing import Expression
  8. from pandas.tests.test_register_accessor import ensure_removed
  9. @pytest.mark.parametrize(
  10. ("expr", "expected_values", "expected_str"),
  11. [
  12. (pd.col("a"), [1, 2], "col('a')"),
  13. (pd.col("a") * 2, [2, 4], "col('a') * 2"),
  14. (pd.col("a").sum(), [3, 3], "col('a').sum()"),
  15. (pd.col("a") + 1, [2, 3], "col('a') + 1"),
  16. (1 + pd.col("a"), [2, 3], "1 + col('a')"),
  17. (pd.col("a") - 1, [0, 1], "col('a') - 1"),
  18. (1 - pd.col("a"), [0, -1], "1 - col('a')"),
  19. (pd.col("a") * 1, [1, 2], "col('a') * 1"),
  20. (1 * pd.col("a"), [1, 2], "1 * col('a')"),
  21. (pd.col("a") / 1, [1.0, 2.0], "col('a') / 1"),
  22. (1 / pd.col("a"), [1.0, 0.5], "1 / col('a')"),
  23. (pd.col("a") // 1, [1, 2], "col('a') // 1"),
  24. (1 // pd.col("a"), [1, 0], "1 // col('a')"),
  25. (pd.col("a") % 1, [0, 0], "col('a') % 1"),
  26. (1 % pd.col("a"), [0, 1], "1 % col('a')"),
  27. (pd.col("a") > 1, [False, True], "col('a') > 1"),
  28. (pd.col("a") >= 1, [True, True], "col('a') >= 1"),
  29. (pd.col("a") < 1, [False, False], "col('a') < 1"),
  30. (pd.col("a") <= 1, [True, False], "col('a') <= 1"),
  31. (pd.col("a") == 1, [True, False], "col('a') == 1"),
  32. (np.power(pd.col("a"), 2), [1, 4], "power(col('a'), 2)"),
  33. (np.divide(pd.col("a"), pd.col("a")), [1.0, 1.0], "divide(col('a'), col('a'))"),
  34. (
  35. (pd.col("a") + 1) * (pd.col("b") + 2),
  36. [10, 18],
  37. "(col('a') + 1) * (col('b') + 2)",
  38. ),
  39. (
  40. (pd.col("a") - 1).astype("bool"),
  41. [False, True],
  42. "(col('a') - 1).astype('bool')",
  43. ),
  44. ],
  45. )
  46. def test_col_simple(
  47. expr: Expression, expected_values: list[object], expected_str: str
  48. ) -> None:
  49. df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
  50. result = df.assign(c=expr)
  51. expected = pd.DataFrame({"a": [1, 2], "b": [3, 4], "c": expected_values})
  52. tm.assert_frame_equal(result, expected)
  53. assert str(expr) == expected_str
  54. def test_frame_getitem() -> None:
  55. # https://github.com/pandas-dev/pandas/pull/63439
  56. df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
  57. expr = pd.col("a") == 2
  58. result = df[expr]
  59. expected = df.iloc[[1]]
  60. tm.assert_frame_equal(result, expected)
  61. def test_frame_setitem() -> None:
  62. # https://github.com/pandas-dev/pandas/pull/63439
  63. df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
  64. expr = pd.col("a") == 2
  65. result = df.copy()
  66. result[expr] = 100
  67. expected = pd.DataFrame({"a": [1, 100], "b": [3, 100]})
  68. tm.assert_frame_equal(result, expected)
  69. def test_frame_loc() -> None:
  70. # https://github.com/pandas-dev/pandas/pull/63439
  71. df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
  72. expr = pd.col("a") == 2
  73. result = df.copy()
  74. result.loc[expr, "b"] = 100
  75. expected = pd.DataFrame({"a": [1, 2], "b": [3, 100]})
  76. tm.assert_frame_equal(result, expected)
  77. def test_frame_iloc() -> None:
  78. # https://github.com/pandas-dev/pandas/pull/63439
  79. df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
  80. expr = pd.col("a") == 2
  81. result = df.copy()
  82. result.iloc[expr, 1] = 100
  83. expected = pd.DataFrame({"a": [1, 2], "b": [3, 100]})
  84. tm.assert_frame_equal(result, expected)
  85. @pytest.mark.parametrize(
  86. ("expr", "expected_values", "expected_str"),
  87. [
  88. (pd.col("a").dt.year, [2020], "col('a').dt.year"),
  89. (pd.col("a").dt.strftime("%B"), ["January"], "col('a').dt.strftime('%B')"),
  90. (pd.col("b").str.upper(), ["FOO"], "col('b').str.upper()"),
  91. ],
  92. )
  93. def test_namespaces(
  94. expr: Expression, expected_values: list[object], expected_str: str
  95. ) -> None:
  96. df = pd.DataFrame({"a": [datetime(2020, 1, 1)], "b": ["foo"]})
  97. result = df.assign(c=expr)
  98. expected = pd.DataFrame(
  99. {"a": [datetime(2020, 1, 1)], "b": ["foo"], "c": expected_values}
  100. )
  101. tm.assert_frame_equal(result, expected, check_dtype=False)
  102. assert str(expr) == expected_str
  103. def test_invalid() -> None:
  104. df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
  105. with pytest.raises(ValueError, match=r"did you mean one of \['a', 'b'\] instead"):
  106. df.assign(c=pd.col("c").mean())
  107. df = pd.DataFrame({f"col_{i}": [0] for i in range(11)})
  108. msg = (
  109. "did you mean one of "
  110. r"\['col_0', 'col_1', 'col_2', 'col_3', "
  111. "'col_4', 'col_5', 'col_6', 'col_7', "
  112. r"'col_8', 'col_9',\.\.\.\] instead"
  113. )
  114. ""
  115. with pytest.raises(ValueError, match=msg):
  116. df.assign(c=pd.col("c").mean())
  117. def test_custom_accessor() -> None:
  118. df = pd.DataFrame({"a": [1, 2, 3]})
  119. class XYZAccessor:
  120. def __init__(self, pandas_obj):
  121. self._obj = pandas_obj
  122. def mean(self):
  123. return self._obj.mean()
  124. with ensure_removed(pd.Series, "xyz"):
  125. pd.api.extensions.register_series_accessor("xyz")(XYZAccessor)
  126. result = df.assign(b=pd.col("a").xyz.mean())
  127. expected = pd.DataFrame({"a": [1, 2, 3], "b": [2.0, 2.0, 2.0]})
  128. tm.assert_frame_equal(result, expected)
  129. @pytest.mark.parametrize(
  130. ("expr", "expected_values", "expected_str"),
  131. [
  132. (
  133. pd.col("a") & pd.col("b"),
  134. [False, False, True, False],
  135. "col('a') & col('b')",
  136. ),
  137. (
  138. pd.col("a") & True,
  139. [True, False, True, False],
  140. "col('a') & True",
  141. ),
  142. (
  143. pd.col("a") | pd.col("b"),
  144. [True, True, True, True],
  145. "col('a') | col('b')",
  146. ),
  147. (
  148. pd.col("a") | False,
  149. [True, False, True, False],
  150. "col('a') | False",
  151. ),
  152. (
  153. pd.col("a") ^ pd.col("b"),
  154. [True, True, False, True],
  155. "col('a') ^ col('b')",
  156. ),
  157. (
  158. pd.col("a") ^ True,
  159. [False, True, False, True],
  160. "col('a') ^ True",
  161. ),
  162. (
  163. ~pd.col("a"),
  164. [False, True, False, True],
  165. "~col('a')",
  166. ),
  167. ],
  168. )
  169. def test_col_logical_ops(
  170. expr: Expression, expected_values: list[bool], expected_str: str
  171. ) -> None:
  172. # https://github.com/pandas-dev/pandas/issues/63322
  173. df = pd.DataFrame({"a": [True, False, True, False], "b": [False, True, True, True]})
  174. result = df.assign(c=expr)
  175. expected = pd.DataFrame(
  176. {
  177. "a": [True, False, True, False],
  178. "b": [False, True, True, True],
  179. "c": expected_values,
  180. }
  181. )
  182. tm.assert_frame_equal(result, expected)
  183. assert str(expr) == expected_str
  184. # Test that the expression works with .loc
  185. result = df.loc[expr]
  186. expected = df[expected_values]
  187. tm.assert_frame_equal(result, expected)
  188. def test_expression_getitem() -> None:
  189. # https://github.com/pandas-dev/pandas/pull/63439
  190. df = pd.DataFrame({"a": [1, 2, 3]})
  191. expr = pd.col("a")[1]
  192. expected_str = "col('a')[1]"
  193. assert str(expr) == expected_str
  194. result = df.assign(b=expr)
  195. expected = pd.DataFrame({"a": [1, 2, 3], "b": [2, 2, 2]})
  196. tm.assert_frame_equal(result, expected)
  197. def test_property() -> None:
  198. # https://github.com/pandas-dev/pandas/pull/63439
  199. df = pd.DataFrame({"a": [1, 2, 3]})
  200. expr = pd.col("a").index
  201. expected_str = "col('a').index"
  202. assert str(expr) == expected_str
  203. result = df.assign(b=expr)
  204. expected = pd.DataFrame({"a": [1, 2, 3], "b": [0, 1, 2]})
  205. tm.assert_frame_equal(result, expected)
  206. def test_cached_property() -> None:
  207. # https://github.com/pandas-dev/pandas/pull/63439
  208. # Ensure test is valid
  209. assert isinstance(pd.Index.dtype, cache_readonly)
  210. df = pd.DataFrame({"a": [1, 2, 3]})
  211. expr = pd.col("a").index.dtype
  212. expected_str = "col('a').index.dtype"
  213. assert str(expr) == expected_str
  214. result = df.assign(b=expr)
  215. expected = pd.DataFrame({"a": [1, 2, 3], "b": np.int64})
  216. tm.assert_frame_equal(result, expected)
  217. def test_qcut() -> None:
  218. # https://github.com/pandas-dev/pandas/pull/63439
  219. df = pd.DataFrame({"a": [1, 2, 3]})
  220. expr = pd.qcut(pd.col("a"), 3)
  221. expected_str = "qcut(x=col('a'), q=3, labels=None, retbins=False, precision=3)"
  222. assert str(expr) == expected_str, str(expr)
  223. result = df.assign(b=expr)
  224. expected = pd.DataFrame({"a": [1, 2, 3], "b": pd.qcut(df["a"], 3)})
  225. tm.assert_frame_equal(result, expected)
  226. def test_where() -> None:
  227. # https://github.com/pandas-dev/pandas/pull/63439
  228. df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
  229. expr = pd.col("a").where(pd.col("b") == 5, 100)
  230. expected_str = "col('a').where(col('b') == 5, 100)"
  231. assert str(expr) == expected_str, str(expr)
  232. result = df.assign(c=expr)
  233. expected = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [100, 2, 100]})
  234. tm.assert_frame_equal(result, expected)
  235. expr = pd.col("a").where(pd.col("b") == 5, pd.col("a") + 1)
  236. expected_str = "col('a').where(col('b') == 5, col('a') + 1)"
  237. assert str(expr) == expected_str, str(expr)
  238. result = df.assign(c=expr)
  239. expected = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [2, 2, 4]})
  240. tm.assert_frame_equal(result, expected)