test_coercion.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941
  1. from __future__ import annotations
  2. from datetime import (
  3. datetime,
  4. timedelta,
  5. )
  6. import itertools
  7. import numpy as np
  8. import pytest
  9. from pandas.compat import (
  10. IS64,
  11. is_platform_windows,
  12. )
  13. from pandas.compat.numpy import np_version_gt2
  14. import pandas as pd
  15. import pandas._testing as tm
  16. ###############################################################
  17. # Index / Series common tests which may trigger dtype coercions
  18. ###############################################################
  19. @pytest.fixture(autouse=True, scope="class")
  20. def check_comprehensiveness(request):
  21. # Iterate over combination of dtype, method and klass
  22. # and ensure that each are contained within a collected test
  23. cls = request.cls
  24. combos = itertools.product(cls.klasses, cls.dtypes, [cls.method])
  25. def has_test(combo):
  26. klass, dtype, method = combo
  27. cls_funcs = request.node.session.items
  28. return any(
  29. klass in x.name and dtype in x.name and method in x.name for x in cls_funcs
  30. )
  31. opts = request.config.option
  32. if opts.lf or opts.keyword:
  33. # If we are running with "last-failed" or -k foo, we expect to only
  34. # run a subset of tests.
  35. yield
  36. else:
  37. for combo in combos:
  38. if not has_test(combo):
  39. raise AssertionError(
  40. f"test method is not defined: {cls.__name__}, {combo}"
  41. )
  42. yield
  43. class CoercionBase:
  44. klasses = ["index", "series"]
  45. dtypes = [
  46. "object",
  47. "int64",
  48. "float64",
  49. "complex128",
  50. "bool",
  51. "datetime64",
  52. "datetime64tz",
  53. "timedelta64",
  54. "period",
  55. ]
  56. @property
  57. def method(self):
  58. raise NotImplementedError(self)
  59. class TestSetitemCoercion(CoercionBase):
  60. method = "setitem"
  61. # disable comprehensiveness tests, as most of these have been moved to
  62. # tests.series.indexing.test_setitem in SetitemCastingEquivalents subclasses.
  63. klasses: list[str] = []
  64. def test_setitem_series_no_coercion_from_values_list(self):
  65. # GH35865 - int casted to str when internally calling np.array(ser.values)
  66. ser = pd.Series(["a", 1])
  67. ser[:] = list(ser.values)
  68. expected = pd.Series(["a", 1])
  69. tm.assert_series_equal(ser, expected)
  70. def _assert_setitem_index_conversion(
  71. self, original_series, loc_key, expected_index, expected_dtype
  72. ):
  73. """test index's coercion triggered by assign key"""
  74. temp = original_series.copy()
  75. # GH#33469 pre-2.0 with int loc_key and temp.index.dtype == np.float64
  76. # `temp[loc_key] = 5` treated loc_key as positional
  77. temp[loc_key] = 5
  78. exp = pd.Series([1, 2, 3, 4, 5], index=expected_index)
  79. tm.assert_series_equal(temp, exp)
  80. # check dtype explicitly for sure
  81. assert temp.index.dtype == expected_dtype
  82. temp = original_series.copy()
  83. temp.loc[loc_key] = 5
  84. exp = pd.Series([1, 2, 3, 4, 5], index=expected_index)
  85. tm.assert_series_equal(temp, exp)
  86. # check dtype explicitly for sure
  87. assert temp.index.dtype == expected_dtype
  88. @pytest.mark.parametrize(
  89. "val,exp_dtype", [("x", object), (5, IndexError), (1.1, object)]
  90. )
  91. def test_setitem_index_object(self, val, exp_dtype):
  92. obj = pd.Series([1, 2, 3, 4], index=pd.Index(list("abcd"), dtype=object))
  93. assert obj.index.dtype == object
  94. if exp_dtype is IndexError:
  95. temp = obj.copy()
  96. warn_msg = "Series.__setitem__ treating keys as positions is deprecated"
  97. msg = "index 5 is out of bounds for axis 0 with size 4"
  98. with pytest.raises(exp_dtype, match=msg):
  99. with tm.assert_produces_warning(FutureWarning, match=warn_msg):
  100. temp[5] = 5
  101. else:
  102. exp_index = pd.Index(list("abcd") + [val], dtype=object)
  103. self._assert_setitem_index_conversion(obj, val, exp_index, exp_dtype)
  104. @pytest.mark.parametrize(
  105. "val,exp_dtype", [(5, np.int64), (1.1, np.float64), ("x", object)]
  106. )
  107. def test_setitem_index_int64(self, val, exp_dtype):
  108. obj = pd.Series([1, 2, 3, 4])
  109. assert obj.index.dtype == np.int64
  110. exp_index = pd.Index([0, 1, 2, 3, val])
  111. self._assert_setitem_index_conversion(obj, val, exp_index, exp_dtype)
  112. @pytest.mark.parametrize(
  113. "val,exp_dtype", [(5, np.float64), (5.1, np.float64), ("x", object)]
  114. )
  115. def test_setitem_index_float64(self, val, exp_dtype, request):
  116. obj = pd.Series([1, 2, 3, 4], index=[1.1, 2.1, 3.1, 4.1])
  117. assert obj.index.dtype == np.float64
  118. exp_index = pd.Index([1.1, 2.1, 3.1, 4.1, val])
  119. self._assert_setitem_index_conversion(obj, val, exp_index, exp_dtype)
  120. @pytest.mark.xfail(reason="Test not implemented")
  121. def test_setitem_series_period(self):
  122. raise NotImplementedError
  123. @pytest.mark.xfail(reason="Test not implemented")
  124. def test_setitem_index_complex128(self):
  125. raise NotImplementedError
  126. @pytest.mark.xfail(reason="Test not implemented")
  127. def test_setitem_index_bool(self):
  128. raise NotImplementedError
  129. @pytest.mark.xfail(reason="Test not implemented")
  130. def test_setitem_index_datetime64(self):
  131. raise NotImplementedError
  132. @pytest.mark.xfail(reason="Test not implemented")
  133. def test_setitem_index_datetime64tz(self):
  134. raise NotImplementedError
  135. @pytest.mark.xfail(reason="Test not implemented")
  136. def test_setitem_index_timedelta64(self):
  137. raise NotImplementedError
  138. @pytest.mark.xfail(reason="Test not implemented")
  139. def test_setitem_index_period(self):
  140. raise NotImplementedError
  141. class TestInsertIndexCoercion(CoercionBase):
  142. klasses = ["index"]
  143. method = "insert"
  144. def _assert_insert_conversion(self, original, value, expected, expected_dtype):
  145. """test coercion triggered by insert"""
  146. target = original.copy()
  147. res = target.insert(1, value)
  148. tm.assert_index_equal(res, expected)
  149. assert res.dtype == expected_dtype
  150. @pytest.mark.parametrize(
  151. "insert, coerced_val, coerced_dtype",
  152. [
  153. (1, 1, object),
  154. (1.1, 1.1, object),
  155. (False, False, object),
  156. ("x", "x", object),
  157. ],
  158. )
  159. def test_insert_index_object(self, insert, coerced_val, coerced_dtype):
  160. obj = pd.Index(list("abcd"), dtype=object)
  161. assert obj.dtype == object
  162. exp = pd.Index(["a", coerced_val, "b", "c", "d"], dtype=object)
  163. self._assert_insert_conversion(obj, insert, exp, coerced_dtype)
  164. @pytest.mark.parametrize(
  165. "insert, coerced_val, coerced_dtype",
  166. [
  167. (1, 1, None),
  168. (1.1, 1.1, np.float64),
  169. (False, False, object), # GH#36319
  170. ("x", "x", object),
  171. ],
  172. )
  173. def test_insert_int_index(
  174. self, any_int_numpy_dtype, insert, coerced_val, coerced_dtype
  175. ):
  176. dtype = any_int_numpy_dtype
  177. obj = pd.Index([1, 2, 3, 4], dtype=dtype)
  178. coerced_dtype = coerced_dtype if coerced_dtype is not None else dtype
  179. exp = pd.Index([1, coerced_val, 2, 3, 4], dtype=coerced_dtype)
  180. self._assert_insert_conversion(obj, insert, exp, coerced_dtype)
  181. @pytest.mark.parametrize(
  182. "insert, coerced_val, coerced_dtype",
  183. [
  184. (1, 1.0, None),
  185. # When float_numpy_dtype=float32, this is not the case
  186. # see the correction below
  187. (1.1, 1.1, np.float64),
  188. (False, False, object), # GH#36319
  189. ("x", "x", object),
  190. ],
  191. )
  192. def test_insert_float_index(
  193. self, float_numpy_dtype, insert, coerced_val, coerced_dtype
  194. ):
  195. dtype = float_numpy_dtype
  196. obj = pd.Index([1.0, 2.0, 3.0, 4.0], dtype=dtype)
  197. coerced_dtype = coerced_dtype if coerced_dtype is not None else dtype
  198. if np_version_gt2 and dtype == "float32" and coerced_val == 1.1:
  199. # Hack, in the 2nd test case, since 1.1 can be losslessly cast to float32
  200. # the expected dtype will be float32 if the original dtype was float32
  201. coerced_dtype = np.float32
  202. exp = pd.Index([1.0, coerced_val, 2.0, 3.0, 4.0], dtype=coerced_dtype)
  203. self._assert_insert_conversion(obj, insert, exp, coerced_dtype)
  204. @pytest.mark.parametrize(
  205. "fill_val,exp_dtype",
  206. [
  207. (pd.Timestamp("2012-01-01"), "datetime64[ns]"),
  208. (pd.Timestamp("2012-01-01", tz="US/Eastern"), "datetime64[ns, US/Eastern]"),
  209. ],
  210. ids=["datetime64", "datetime64tz"],
  211. )
  212. @pytest.mark.parametrize(
  213. "insert_value",
  214. [pd.Timestamp("2012-01-01"), pd.Timestamp("2012-01-01", tz="Asia/Tokyo"), 1],
  215. )
  216. def test_insert_index_datetimes(self, fill_val, exp_dtype, insert_value):
  217. obj = pd.DatetimeIndex(
  218. ["2011-01-01", "2011-01-02", "2011-01-03", "2011-01-04"], tz=fill_val.tz
  219. ).as_unit("ns")
  220. assert obj.dtype == exp_dtype
  221. exp = pd.DatetimeIndex(
  222. ["2011-01-01", fill_val.date(), "2011-01-02", "2011-01-03", "2011-01-04"],
  223. tz=fill_val.tz,
  224. ).as_unit("ns")
  225. self._assert_insert_conversion(obj, fill_val, exp, exp_dtype)
  226. if fill_val.tz:
  227. # mismatched tzawareness
  228. ts = pd.Timestamp("2012-01-01")
  229. result = obj.insert(1, ts)
  230. expected = obj.astype(object).insert(1, ts)
  231. assert expected.dtype == object
  232. tm.assert_index_equal(result, expected)
  233. ts = pd.Timestamp("2012-01-01", tz="Asia/Tokyo")
  234. result = obj.insert(1, ts)
  235. # once deprecation is enforced:
  236. expected = obj.insert(1, ts.tz_convert(obj.dtype.tz))
  237. assert expected.dtype == obj.dtype
  238. tm.assert_index_equal(result, expected)
  239. else:
  240. # mismatched tzawareness
  241. ts = pd.Timestamp("2012-01-01", tz="Asia/Tokyo")
  242. result = obj.insert(1, ts)
  243. expected = obj.astype(object).insert(1, ts)
  244. assert expected.dtype == object
  245. tm.assert_index_equal(result, expected)
  246. item = 1
  247. result = obj.insert(1, item)
  248. expected = obj.astype(object).insert(1, item)
  249. assert expected[1] == item
  250. assert expected.dtype == object
  251. tm.assert_index_equal(result, expected)
  252. def test_insert_index_timedelta64(self):
  253. obj = pd.TimedeltaIndex(["1 day", "2 day", "3 day", "4 day"])
  254. assert obj.dtype == "timedelta64[ns]"
  255. # timedelta64 + timedelta64 => timedelta64
  256. exp = pd.TimedeltaIndex(["1 day", "10 day", "2 day", "3 day", "4 day"])
  257. self._assert_insert_conversion(
  258. obj, pd.Timedelta("10 day"), exp, "timedelta64[ns]"
  259. )
  260. for item in [pd.Timestamp("2012-01-01"), 1]:
  261. result = obj.insert(1, item)
  262. expected = obj.astype(object).insert(1, item)
  263. assert expected.dtype == object
  264. tm.assert_index_equal(result, expected)
  265. @pytest.mark.parametrize(
  266. "insert, coerced_val, coerced_dtype",
  267. [
  268. (pd.Period("2012-01", freq="M"), "2012-01", "period[M]"),
  269. (pd.Timestamp("2012-01-01"), pd.Timestamp("2012-01-01"), object),
  270. (1, 1, object),
  271. ("x", "x", object),
  272. ],
  273. )
  274. def test_insert_index_period(self, insert, coerced_val, coerced_dtype):
  275. obj = pd.PeriodIndex(["2011-01", "2011-02", "2011-03", "2011-04"], freq="M")
  276. assert obj.dtype == "period[M]"
  277. data = [
  278. pd.Period("2011-01", freq="M"),
  279. coerced_val,
  280. pd.Period("2011-02", freq="M"),
  281. pd.Period("2011-03", freq="M"),
  282. pd.Period("2011-04", freq="M"),
  283. ]
  284. if isinstance(insert, pd.Period):
  285. exp = pd.PeriodIndex(data, freq="M")
  286. self._assert_insert_conversion(obj, insert, exp, coerced_dtype)
  287. # string that can be parsed to appropriate PeriodDtype
  288. self._assert_insert_conversion(obj, str(insert), exp, coerced_dtype)
  289. else:
  290. result = obj.insert(0, insert)
  291. expected = obj.astype(object).insert(0, insert)
  292. tm.assert_index_equal(result, expected)
  293. # TODO: ATM inserting '2012-01-01 00:00:00' when we have obj.freq=="M"
  294. # casts that string to Period[M], not clear that is desirable
  295. if not isinstance(insert, pd.Timestamp):
  296. # non-castable string
  297. result = obj.insert(0, str(insert))
  298. expected = obj.astype(object).insert(0, str(insert))
  299. tm.assert_index_equal(result, expected)
  300. @pytest.mark.xfail(reason="Test not implemented")
  301. def test_insert_index_complex128(self):
  302. raise NotImplementedError
  303. @pytest.mark.xfail(reason="Test not implemented")
  304. def test_insert_index_bool(self):
  305. raise NotImplementedError
  306. class TestWhereCoercion(CoercionBase):
  307. method = "where"
  308. _cond = np.array([True, False, True, False])
  309. def _assert_where_conversion(
  310. self, original, cond, values, expected, expected_dtype
  311. ):
  312. """test coercion triggered by where"""
  313. target = original.copy()
  314. res = target.where(cond, values)
  315. tm.assert_equal(res, expected)
  316. assert res.dtype == expected_dtype
  317. def _construct_exp(self, obj, klass, fill_val, exp_dtype):
  318. if fill_val is True:
  319. values = klass([True, False, True, True])
  320. elif isinstance(fill_val, (datetime, np.datetime64)):
  321. values = pd.date_range(fill_val, periods=4)
  322. else:
  323. values = klass(x * fill_val for x in [5, 6, 7, 8])
  324. exp = klass([obj[0], values[1], obj[2], values[3]], dtype=exp_dtype)
  325. return values, exp
  326. def _run_test(self, obj, fill_val, klass, exp_dtype):
  327. cond = klass(self._cond)
  328. exp = klass([obj[0], fill_val, obj[2], fill_val], dtype=exp_dtype)
  329. self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype)
  330. values, exp = self._construct_exp(obj, klass, fill_val, exp_dtype)
  331. self._assert_where_conversion(obj, cond, values, exp, exp_dtype)
  332. @pytest.mark.parametrize(
  333. "fill_val,exp_dtype",
  334. [(1, object), (1.1, object), (1 + 1j, object), (True, object)],
  335. )
  336. def test_where_object(self, index_or_series, fill_val, exp_dtype):
  337. klass = index_or_series
  338. obj = klass(list("abcd"), dtype=object)
  339. assert obj.dtype == object
  340. self._run_test(obj, fill_val, klass, exp_dtype)
  341. @pytest.mark.parametrize(
  342. "fill_val,exp_dtype",
  343. [(1, np.int64), (1.1, np.float64), (1 + 1j, np.complex128), (True, object)],
  344. )
  345. def test_where_int64(self, index_or_series, fill_val, exp_dtype, request):
  346. klass = index_or_series
  347. obj = klass([1, 2, 3, 4])
  348. assert obj.dtype == np.int64
  349. self._run_test(obj, fill_val, klass, exp_dtype)
  350. @pytest.mark.parametrize(
  351. "fill_val, exp_dtype",
  352. [(1, np.float64), (1.1, np.float64), (1 + 1j, np.complex128), (True, object)],
  353. )
  354. def test_where_float64(self, index_or_series, fill_val, exp_dtype, request):
  355. klass = index_or_series
  356. obj = klass([1.1, 2.2, 3.3, 4.4])
  357. assert obj.dtype == np.float64
  358. self._run_test(obj, fill_val, klass, exp_dtype)
  359. @pytest.mark.parametrize(
  360. "fill_val,exp_dtype",
  361. [
  362. (1, np.complex128),
  363. (1.1, np.complex128),
  364. (1 + 1j, np.complex128),
  365. (True, object),
  366. ],
  367. )
  368. def test_where_complex128(self, index_or_series, fill_val, exp_dtype):
  369. klass = index_or_series
  370. obj = klass([1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=np.complex128)
  371. assert obj.dtype == np.complex128
  372. self._run_test(obj, fill_val, klass, exp_dtype)
  373. @pytest.mark.parametrize(
  374. "fill_val,exp_dtype",
  375. [(1, object), (1.1, object), (1 + 1j, object), (True, np.bool_)],
  376. )
  377. def test_where_series_bool(self, index_or_series, fill_val, exp_dtype):
  378. klass = index_or_series
  379. obj = klass([True, False, True, False])
  380. assert obj.dtype == np.bool_
  381. self._run_test(obj, fill_val, klass, exp_dtype)
  382. @pytest.mark.parametrize(
  383. "fill_val,exp_dtype",
  384. [
  385. (pd.Timestamp("2012-01-01"), "datetime64[ns]"),
  386. (pd.Timestamp("2012-01-01", tz="US/Eastern"), object),
  387. ],
  388. ids=["datetime64", "datetime64tz"],
  389. )
  390. def test_where_datetime64(self, index_or_series, fill_val, exp_dtype):
  391. klass = index_or_series
  392. obj = klass(pd.date_range("2011-01-01", periods=4, freq="D")._with_freq(None))
  393. assert obj.dtype == "datetime64[ns]"
  394. fv = fill_val
  395. # do the check with each of the available datetime scalars
  396. if exp_dtype == "datetime64[ns]":
  397. for scalar in [fv, fv.to_pydatetime(), fv.to_datetime64()]:
  398. self._run_test(obj, scalar, klass, exp_dtype)
  399. else:
  400. for scalar in [fv, fv.to_pydatetime()]:
  401. self._run_test(obj, fill_val, klass, exp_dtype)
  402. @pytest.mark.xfail(reason="Test not implemented")
  403. def test_where_index_complex128(self):
  404. raise NotImplementedError
  405. @pytest.mark.xfail(reason="Test not implemented")
  406. def test_where_index_bool(self):
  407. raise NotImplementedError
  408. @pytest.mark.xfail(reason="Test not implemented")
  409. def test_where_series_timedelta64(self):
  410. raise NotImplementedError
  411. @pytest.mark.xfail(reason="Test not implemented")
  412. def test_where_series_period(self):
  413. raise NotImplementedError
  414. @pytest.mark.parametrize(
  415. "value", [pd.Timedelta(days=9), timedelta(days=9), np.timedelta64(9, "D")]
  416. )
  417. def test_where_index_timedelta64(self, value):
  418. tdi = pd.timedelta_range("1 Day", periods=4)
  419. cond = np.array([True, False, False, True])
  420. expected = pd.TimedeltaIndex(["1 Day", value, value, "4 Days"])
  421. result = tdi.where(cond, value)
  422. tm.assert_index_equal(result, expected)
  423. # wrong-dtyped NaT
  424. dtnat = np.datetime64("NaT", "ns")
  425. expected = pd.Index([tdi[0], dtnat, dtnat, tdi[3]], dtype=object)
  426. assert expected[1] is dtnat
  427. result = tdi.where(cond, dtnat)
  428. tm.assert_index_equal(result, expected)
  429. def test_where_index_period(self):
  430. dti = pd.date_range("2016-01-01", periods=3, freq="QS")
  431. pi = dti.to_period("Q")
  432. cond = np.array([False, True, False])
  433. # Passing a valid scalar
  434. value = pi[-1] + pi.freq * 10
  435. expected = pd.PeriodIndex([value, pi[1], value])
  436. result = pi.where(cond, value)
  437. tm.assert_index_equal(result, expected)
  438. # Case passing ndarray[object] of Periods
  439. other = np.asarray(pi + pi.freq * 10, dtype=object)
  440. result = pi.where(cond, other)
  441. expected = pd.PeriodIndex([other[0], pi[1], other[2]])
  442. tm.assert_index_equal(result, expected)
  443. # Passing a mismatched scalar -> casts to object
  444. td = pd.Timedelta(days=4)
  445. expected = pd.Index([td, pi[1], td], dtype=object)
  446. result = pi.where(cond, td)
  447. tm.assert_index_equal(result, expected)
  448. per = pd.Period("2020-04-21", "D")
  449. expected = pd.Index([per, pi[1], per], dtype=object)
  450. result = pi.where(cond, per)
  451. tm.assert_index_equal(result, expected)
  452. class TestFillnaSeriesCoercion(CoercionBase):
  453. # not indexing, but place here for consistency
  454. method = "fillna"
  455. @pytest.mark.xfail(reason="Test not implemented")
  456. def test_has_comprehensive_tests(self):
  457. raise NotImplementedError
  458. def _assert_fillna_conversion(self, original, value, expected, expected_dtype):
  459. """test coercion triggered by fillna"""
  460. target = original.copy()
  461. res = target.fillna(value)
  462. tm.assert_equal(res, expected)
  463. assert res.dtype == expected_dtype
  464. @pytest.mark.parametrize(
  465. "fill_val, fill_dtype",
  466. [(1, object), (1.1, object), (1 + 1j, object), (True, object)],
  467. )
  468. def test_fillna_object(self, index_or_series, fill_val, fill_dtype):
  469. klass = index_or_series
  470. obj = klass(["a", np.nan, "c", "d"], dtype=object)
  471. assert obj.dtype == object
  472. exp = klass(["a", fill_val, "c", "d"], dtype=object)
  473. self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)
  474. @pytest.mark.parametrize(
  475. "fill_val,fill_dtype",
  476. [(1, np.float64), (1.1, np.float64), (1 + 1j, np.complex128), (True, object)],
  477. )
  478. def test_fillna_float64(self, index_or_series, fill_val, fill_dtype):
  479. klass = index_or_series
  480. obj = klass([1.1, np.nan, 3.3, 4.4])
  481. assert obj.dtype == np.float64
  482. exp = klass([1.1, fill_val, 3.3, 4.4])
  483. self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)
  484. @pytest.mark.parametrize(
  485. "fill_val,fill_dtype",
  486. [
  487. (1, np.complex128),
  488. (1.1, np.complex128),
  489. (1 + 1j, np.complex128),
  490. (True, object),
  491. ],
  492. )
  493. def test_fillna_complex128(self, index_or_series, fill_val, fill_dtype):
  494. klass = index_or_series
  495. obj = klass([1 + 1j, np.nan, 3 + 3j, 4 + 4j], dtype=np.complex128)
  496. assert obj.dtype == np.complex128
  497. exp = klass([1 + 1j, fill_val, 3 + 3j, 4 + 4j])
  498. self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)
  499. @pytest.mark.parametrize(
  500. "fill_val,fill_dtype",
  501. [
  502. (pd.Timestamp("2012-01-01"), "datetime64[ns]"),
  503. (pd.Timestamp("2012-01-01", tz="US/Eastern"), object),
  504. (1, object),
  505. ("x", object),
  506. ],
  507. ids=["datetime64", "datetime64tz", "object", "object"],
  508. )
  509. def test_fillna_datetime(self, index_or_series, fill_val, fill_dtype):
  510. klass = index_or_series
  511. obj = klass(
  512. [
  513. pd.Timestamp("2011-01-01"),
  514. pd.NaT,
  515. pd.Timestamp("2011-01-03"),
  516. pd.Timestamp("2011-01-04"),
  517. ]
  518. )
  519. assert obj.dtype == "datetime64[ns]"
  520. exp = klass(
  521. [
  522. pd.Timestamp("2011-01-01"),
  523. fill_val,
  524. pd.Timestamp("2011-01-03"),
  525. pd.Timestamp("2011-01-04"),
  526. ]
  527. )
  528. self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)
  529. @pytest.mark.parametrize(
  530. "fill_val,fill_dtype",
  531. [
  532. (pd.Timestamp("2012-01-01", tz="US/Eastern"), "datetime64[ns, US/Eastern]"),
  533. (pd.Timestamp("2012-01-01"), object),
  534. # pre-2.0 with a mismatched tz we would get object result
  535. (pd.Timestamp("2012-01-01", tz="Asia/Tokyo"), "datetime64[ns, US/Eastern]"),
  536. (1, object),
  537. ("x", object),
  538. ],
  539. )
  540. def test_fillna_datetime64tz(self, index_or_series, fill_val, fill_dtype):
  541. klass = index_or_series
  542. tz = "US/Eastern"
  543. obj = klass(
  544. [
  545. pd.Timestamp("2011-01-01", tz=tz),
  546. pd.NaT,
  547. pd.Timestamp("2011-01-03", tz=tz),
  548. pd.Timestamp("2011-01-04", tz=tz),
  549. ]
  550. )
  551. assert obj.dtype == "datetime64[ns, US/Eastern]"
  552. if getattr(fill_val, "tz", None) is None:
  553. fv = fill_val
  554. else:
  555. fv = fill_val.tz_convert(tz)
  556. exp = klass(
  557. [
  558. pd.Timestamp("2011-01-01", tz=tz),
  559. fv,
  560. pd.Timestamp("2011-01-03", tz=tz),
  561. pd.Timestamp("2011-01-04", tz=tz),
  562. ]
  563. )
  564. self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)
  565. @pytest.mark.parametrize(
  566. "fill_val",
  567. [
  568. 1,
  569. 1.1,
  570. 1 + 1j,
  571. True,
  572. pd.Interval(1, 2, closed="left"),
  573. pd.Timestamp("2012-01-01", tz="US/Eastern"),
  574. pd.Timestamp("2012-01-01"),
  575. pd.Timedelta(days=1),
  576. pd.Period("2016-01-01", "D"),
  577. ],
  578. )
  579. def test_fillna_interval(self, index_or_series, fill_val):
  580. ii = pd.interval_range(1.0, 5.0, closed="right").insert(1, np.nan)
  581. assert isinstance(ii.dtype, pd.IntervalDtype)
  582. obj = index_or_series(ii)
  583. exp = index_or_series([ii[0], fill_val, ii[2], ii[3], ii[4]], dtype=object)
  584. fill_dtype = object
  585. self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)
  586. @pytest.mark.xfail(reason="Test not implemented")
  587. def test_fillna_series_int64(self):
  588. raise NotImplementedError
  589. @pytest.mark.xfail(reason="Test not implemented")
  590. def test_fillna_index_int64(self):
  591. raise NotImplementedError
  592. @pytest.mark.xfail(reason="Test not implemented")
  593. def test_fillna_series_bool(self):
  594. raise NotImplementedError
  595. @pytest.mark.xfail(reason="Test not implemented")
  596. def test_fillna_index_bool(self):
  597. raise NotImplementedError
  598. @pytest.mark.xfail(reason="Test not implemented")
  599. def test_fillna_series_timedelta64(self):
  600. raise NotImplementedError
  601. @pytest.mark.parametrize(
  602. "fill_val",
  603. [
  604. 1,
  605. 1.1,
  606. 1 + 1j,
  607. True,
  608. pd.Interval(1, 2, closed="left"),
  609. pd.Timestamp("2012-01-01", tz="US/Eastern"),
  610. pd.Timestamp("2012-01-01"),
  611. pd.Timedelta(days=1),
  612. pd.Period("2016-01-01", "W"),
  613. ],
  614. )
  615. def test_fillna_series_period(self, index_or_series, fill_val):
  616. pi = pd.period_range("2016-01-01", periods=4, freq="D").insert(1, pd.NaT)
  617. assert isinstance(pi.dtype, pd.PeriodDtype)
  618. obj = index_or_series(pi)
  619. exp = index_or_series([pi[0], fill_val, pi[2], pi[3], pi[4]], dtype=object)
  620. fill_dtype = object
  621. self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)
  622. @pytest.mark.xfail(reason="Test not implemented")
  623. def test_fillna_index_timedelta64(self):
  624. raise NotImplementedError
  625. @pytest.mark.xfail(reason="Test not implemented")
  626. def test_fillna_index_period(self):
  627. raise NotImplementedError
  628. class TestReplaceSeriesCoercion(CoercionBase):
  629. klasses = ["series"]
  630. method = "replace"
  631. rep: dict[str, list] = {}
  632. rep["object"] = ["a", "b"]
  633. rep["int64"] = [4, 5]
  634. rep["float64"] = [1.1, 2.2]
  635. rep["complex128"] = [1 + 1j, 2 + 2j]
  636. rep["bool"] = [True, False]
  637. rep["datetime64[ns]"] = [pd.Timestamp("2011-01-01"), pd.Timestamp("2011-01-03")]
  638. for tz in ["UTC", "US/Eastern"]:
  639. # to test tz => different tz replacement
  640. key = f"datetime64[ns, {tz}]"
  641. rep[key] = [
  642. pd.Timestamp("2011-01-01", tz=tz),
  643. pd.Timestamp("2011-01-03", tz=tz),
  644. ]
  645. rep["timedelta64[ns]"] = [pd.Timedelta("1 day"), pd.Timedelta("2 day")]
  646. @pytest.fixture(params=["dict", "series"])
  647. def how(self, request):
  648. return request.param
  649. @pytest.fixture(
  650. params=[
  651. "object",
  652. "int64",
  653. "float64",
  654. "complex128",
  655. "bool",
  656. "datetime64[ns]",
  657. "datetime64[ns, UTC]",
  658. "datetime64[ns, US/Eastern]",
  659. "timedelta64[ns]",
  660. ]
  661. )
  662. def from_key(self, request):
  663. return request.param
  664. @pytest.fixture(
  665. params=[
  666. "object",
  667. "int64",
  668. "float64",
  669. "complex128",
  670. "bool",
  671. "datetime64[ns]",
  672. "datetime64[ns, UTC]",
  673. "datetime64[ns, US/Eastern]",
  674. "timedelta64[ns]",
  675. ],
  676. ids=[
  677. "object",
  678. "int64",
  679. "float64",
  680. "complex128",
  681. "bool",
  682. "datetime64",
  683. "datetime64tz",
  684. "datetime64tz",
  685. "timedelta64",
  686. ],
  687. )
  688. def to_key(self, request):
  689. return request.param
  690. @pytest.fixture
  691. def replacer(self, how, from_key, to_key):
  692. """
  693. Object we will pass to `Series.replace`
  694. """
  695. if how == "dict":
  696. replacer = dict(zip(self.rep[from_key], self.rep[to_key]))
  697. elif how == "series":
  698. replacer = pd.Series(self.rep[to_key], index=self.rep[from_key])
  699. else:
  700. raise ValueError
  701. return replacer
  702. def test_replace_series(self, how, to_key, from_key, replacer, using_infer_string):
  703. index = pd.Index([3, 4], name="xxx")
  704. obj = pd.Series(self.rep[from_key], index=index, name="yyy")
  705. obj = obj.astype(from_key)
  706. assert obj.dtype == from_key
  707. if from_key.startswith("datetime") and to_key.startswith("datetime"):
  708. # tested below
  709. return
  710. elif from_key in ["datetime64[ns, US/Eastern]", "datetime64[ns, UTC]"]:
  711. # tested below
  712. return
  713. if (from_key == "float64" and to_key in ("int64")) or (
  714. from_key == "complex128" and to_key in ("int64", "float64")
  715. ):
  716. if not IS64 or is_platform_windows():
  717. pytest.skip(f"32-bit platform buggy: {from_key} -> {to_key}")
  718. # Expected: do not downcast by replacement
  719. exp = pd.Series(self.rep[to_key], index=index, name="yyy", dtype=from_key)
  720. else:
  721. exp = pd.Series(self.rep[to_key], index=index, name="yyy")
  722. if using_infer_string and exp.dtype == "string":
  723. # with infer_string, we disable the deprecated downcasting behavior
  724. exp = exp.astype(object)
  725. msg = "Downcasting behavior in `replace`"
  726. warn = FutureWarning
  727. if (
  728. exp.dtype == obj.dtype
  729. or exp.dtype == object
  730. or (exp.dtype.kind in "iufc" and obj.dtype.kind in "iufc")
  731. ):
  732. warn = None
  733. with tm.assert_produces_warning(warn, match=msg):
  734. result = obj.replace(replacer)
  735. tm.assert_series_equal(result, exp)
  736. @pytest.mark.parametrize(
  737. "to_key",
  738. ["timedelta64[ns]", "bool", "object", "complex128", "float64", "int64"],
  739. indirect=True,
  740. )
  741. @pytest.mark.parametrize(
  742. "from_key", ["datetime64[ns, UTC]", "datetime64[ns, US/Eastern]"], indirect=True
  743. )
  744. def test_replace_series_datetime_tz(
  745. self, how, to_key, from_key, replacer, using_infer_string
  746. ):
  747. index = pd.Index([3, 4], name="xyz")
  748. obj = pd.Series(self.rep[from_key], index=index, name="yyy")
  749. assert obj.dtype == from_key
  750. exp = pd.Series(self.rep[to_key], index=index, name="yyy")
  751. if using_infer_string and exp.dtype == "string":
  752. # with infer_string, we disable the deprecated downcasting behavior
  753. exp = exp.astype(object)
  754. else:
  755. assert exp.dtype == to_key
  756. msg = "Downcasting behavior in `replace`"
  757. warn = FutureWarning if exp.dtype != object else None
  758. with tm.assert_produces_warning(warn, match=msg):
  759. result = obj.replace(replacer)
  760. tm.assert_series_equal(result, exp)
  761. @pytest.mark.parametrize(
  762. "to_key",
  763. ["datetime64[ns]", "datetime64[ns, UTC]", "datetime64[ns, US/Eastern]"],
  764. indirect=True,
  765. )
  766. @pytest.mark.parametrize(
  767. "from_key",
  768. ["datetime64[ns]", "datetime64[ns, UTC]", "datetime64[ns, US/Eastern]"],
  769. indirect=True,
  770. )
  771. def test_replace_series_datetime_datetime(self, how, to_key, from_key, replacer):
  772. index = pd.Index([3, 4], name="xyz")
  773. obj = pd.Series(self.rep[from_key], index=index, name="yyy")
  774. assert obj.dtype == from_key
  775. exp = pd.Series(self.rep[to_key], index=index, name="yyy")
  776. warn = FutureWarning
  777. if isinstance(obj.dtype, pd.DatetimeTZDtype) and isinstance(
  778. exp.dtype, pd.DatetimeTZDtype
  779. ):
  780. # with mismatched tzs, we retain the original dtype as of 2.0
  781. exp = exp.astype(obj.dtype)
  782. warn = None
  783. else:
  784. assert exp.dtype == to_key
  785. if to_key == from_key:
  786. warn = None
  787. msg = "Downcasting behavior in `replace`"
  788. with tm.assert_produces_warning(warn, match=msg):
  789. result = obj.replace(replacer)
  790. tm.assert_series_equal(result, exp)
  791. @pytest.mark.xfail(reason="Test not implemented")
  792. def test_replace_series_period(self):
  793. raise NotImplementedError