test_overlaps.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import pytest
  2. from pandas import (
  3. Interval,
  4. Timedelta,
  5. Timestamp,
  6. )
  7. @pytest.fixture(
  8. params=[
  9. (Timedelta("0 days"), Timedelta("1 day")),
  10. (Timestamp("2018-01-01"), Timedelta("1 day")),
  11. (0, 1),
  12. ],
  13. ids=lambda x: type(x[0]).__name__,
  14. )
  15. def start_shift(request):
  16. """
  17. Fixture for generating intervals of types from a start value and a shift
  18. value that can be added to start to generate an endpoint
  19. """
  20. return request.param
  21. class TestOverlaps:
  22. def test_overlaps_self(self, start_shift, closed):
  23. start, shift = start_shift
  24. interval = Interval(start, start + shift, closed)
  25. assert interval.overlaps(interval)
  26. def test_overlaps_nested(self, start_shift, closed, other_closed):
  27. start, shift = start_shift
  28. interval1 = Interval(start, start + 3 * shift, other_closed)
  29. interval2 = Interval(start + shift, start + 2 * shift, closed)
  30. # nested intervals should always overlap
  31. assert interval1.overlaps(interval2)
  32. def test_overlaps_disjoint(self, start_shift, closed, other_closed):
  33. start, shift = start_shift
  34. interval1 = Interval(start, start + shift, other_closed)
  35. interval2 = Interval(start + 2 * shift, start + 3 * shift, closed)
  36. # disjoint intervals should never overlap
  37. assert not interval1.overlaps(interval2)
  38. def test_overlaps_endpoint(self, start_shift, closed, other_closed):
  39. start, shift = start_shift
  40. interval1 = Interval(start, start + shift, other_closed)
  41. interval2 = Interval(start + shift, start + 2 * shift, closed)
  42. # overlap if shared endpoint is closed for both (overlap at a point)
  43. result = interval1.overlaps(interval2)
  44. expected = interval1.closed_right and interval2.closed_left
  45. assert result == expected
  46. @pytest.mark.parametrize(
  47. "other",
  48. [10, True, "foo", Timedelta("1 day"), Timestamp("2018-01-01")],
  49. ids=lambda x: type(x).__name__,
  50. )
  51. def test_overlaps_invalid_type(self, other):
  52. interval = Interval(0, 1)
  53. msg = f"`other` must be an Interval, got {type(other).__name__}"
  54. with pytest.raises(TypeError, match=msg):
  55. interval.overlaps(other)