test_truncate.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import numpy as np
  2. import pytest
  3. import pandas as pd
  4. from pandas import (
  5. DataFrame,
  6. DatetimeIndex,
  7. Index,
  8. Series,
  9. date_range,
  10. )
  11. import pandas._testing as tm
  12. class TestDataFrameTruncate:
  13. def test_truncate(self, datetime_frame, frame_or_series):
  14. ts = datetime_frame[::3]
  15. ts = tm.get_obj(ts, frame_or_series)
  16. start, end = datetime_frame.index[3], datetime_frame.index[6]
  17. start_missing = datetime_frame.index[2]
  18. end_missing = datetime_frame.index[7]
  19. # neither specified
  20. truncated = ts.truncate()
  21. tm.assert_equal(truncated, ts)
  22. # both specified
  23. expected = ts[1:3]
  24. truncated = ts.truncate(start, end)
  25. tm.assert_equal(truncated, expected)
  26. truncated = ts.truncate(start_missing, end_missing)
  27. tm.assert_equal(truncated, expected)
  28. # start specified
  29. expected = ts[1:]
  30. truncated = ts.truncate(before=start)
  31. tm.assert_equal(truncated, expected)
  32. truncated = ts.truncate(before=start_missing)
  33. tm.assert_equal(truncated, expected)
  34. # end specified
  35. expected = ts[:3]
  36. truncated = ts.truncate(after=end)
  37. tm.assert_equal(truncated, expected)
  38. truncated = ts.truncate(after=end_missing)
  39. tm.assert_equal(truncated, expected)
  40. # corner case, empty series/frame returned
  41. truncated = ts.truncate(after=ts.index[0] - ts.index.freq)
  42. assert len(truncated) == 0
  43. truncated = ts.truncate(before=ts.index[-1] + ts.index.freq)
  44. assert len(truncated) == 0
  45. msg = "Truncate: 2000-01-06 00:00:00 must be after 2000-05-16 00:00:00"
  46. with pytest.raises(ValueError, match=msg):
  47. ts.truncate(
  48. before=ts.index[-1] - ts.index.freq, after=ts.index[0] + ts.index.freq
  49. )
  50. def test_truncate_nonsortedindex(self, frame_or_series):
  51. # GH#17935
  52. obj = DataFrame({"A": ["a", "b", "c", "d", "e"]}, index=[5, 3, 2, 9, 0])
  53. obj = tm.get_obj(obj, frame_or_series)
  54. msg = "truncate requires a sorted index"
  55. with pytest.raises(ValueError, match=msg):
  56. obj.truncate(before=3, after=9)
  57. def test_sort_values_nonsortedindex(self):
  58. rng = date_range("2011-01-01", "2012-01-01", freq="W")
  59. ts = DataFrame(
  60. {
  61. "A": np.random.default_rng(2).standard_normal(len(rng)),
  62. "B": np.random.default_rng(2).standard_normal(len(rng)),
  63. },
  64. index=rng,
  65. )
  66. decreasing = ts.sort_values("A", ascending=False)
  67. msg = "truncate requires a sorted index"
  68. with pytest.raises(ValueError, match=msg):
  69. decreasing.truncate(before="2011-11", after="2011-12")
  70. def test_truncate_nonsortedindex_axis1(self):
  71. # GH#17935
  72. df = DataFrame(
  73. {
  74. 3: np.random.default_rng(2).standard_normal(5),
  75. 20: np.random.default_rng(2).standard_normal(5),
  76. 2: np.random.default_rng(2).standard_normal(5),
  77. 0: np.random.default_rng(2).standard_normal(5),
  78. },
  79. columns=[3, 20, 2, 0],
  80. )
  81. msg = "truncate requires a sorted index"
  82. with pytest.raises(ValueError, match=msg):
  83. df.truncate(before=2, after=20, axis=1)
  84. @pytest.mark.parametrize(
  85. "before, after, indices",
  86. [(1, 2, [2, 1]), (None, 2, [2, 1, 0]), (1, None, [3, 2, 1])],
  87. )
  88. @pytest.mark.parametrize("dtyp", [*tm.ALL_REAL_NUMPY_DTYPES, "datetime64[ns]"])
  89. def test_truncate_decreasing_index(
  90. self, before, after, indices, dtyp, frame_or_series
  91. ):
  92. # https://github.com/pandas-dev/pandas/issues/33756
  93. idx = Index([3, 2, 1, 0], dtype=dtyp)
  94. if isinstance(idx, DatetimeIndex):
  95. before = pd.Timestamp(before) if before is not None else None
  96. after = pd.Timestamp(after) if after is not None else None
  97. indices = [pd.Timestamp(i) for i in indices]
  98. values = frame_or_series(range(len(idx)), index=idx)
  99. result = values.truncate(before=before, after=after)
  100. expected = values.loc[indices]
  101. tm.assert_equal(result, expected)
  102. def test_truncate_multiindex(self, frame_or_series):
  103. # GH 34564
  104. mi = pd.MultiIndex.from_product([[1, 2, 3, 4], ["A", "B"]], names=["L1", "L2"])
  105. s1 = DataFrame(range(mi.shape[0]), index=mi, columns=["col"])
  106. s1 = tm.get_obj(s1, frame_or_series)
  107. result = s1.truncate(before=2, after=3)
  108. df = DataFrame.from_dict(
  109. {"L1": [2, 2, 3, 3], "L2": ["A", "B", "A", "B"], "col": [2, 3, 4, 5]}
  110. )
  111. expected = df.set_index(["L1", "L2"])
  112. expected = tm.get_obj(expected, frame_or_series)
  113. tm.assert_equal(result, expected)
  114. def test_truncate_index_only_one_unique_value(self, frame_or_series):
  115. # GH 42365
  116. obj = Series(0, index=date_range("2021-06-30", "2021-06-30")).repeat(5)
  117. if frame_or_series is DataFrame:
  118. obj = obj.to_frame(name="a")
  119. truncated = obj.truncate("2021-06-28", "2021-07-01")
  120. tm.assert_equal(truncated, obj)