test_interval.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. """
  2. This file contains a minimal set of tests for compliance with the extension
  3. array interface test suite, and should contain no other tests.
  4. The test suite for the full functionality of the array is located in
  5. `pandas/tests/arrays/`.
  6. The tests in this file are inherited from the BaseExtensionTests, and only
  7. minimal tweaks should be applied to get the tests passing (by overwriting a
  8. parent method).
  9. Additional tests should either be added to one of the BaseExtensionTests
  10. classes (if they are relevant for the extension interface for all dtypes), or
  11. be added to the array-specific tests in `pandas/tests/arrays/`.
  12. """
  13. from __future__ import annotations
  14. from typing import TYPE_CHECKING
  15. import numpy as np
  16. import pytest
  17. from pandas.core.dtypes.dtypes import IntervalDtype
  18. from pandas import Interval
  19. from pandas.core.arrays import IntervalArray
  20. from pandas.tests.extension import base
  21. if TYPE_CHECKING:
  22. import pandas as pd
  23. def make_data():
  24. N = 100
  25. left_array = np.random.default_rng(2).uniform(size=N).cumsum()
  26. right_array = left_array + np.random.default_rng(2).uniform(size=N)
  27. return [Interval(left, right) for left, right in zip(left_array, right_array)]
  28. @pytest.fixture
  29. def dtype():
  30. return IntervalDtype()
  31. @pytest.fixture
  32. def data():
  33. """Length-100 PeriodArray for semantics test."""
  34. return IntervalArray(make_data())
  35. @pytest.fixture
  36. def data_missing():
  37. """Length 2 array with [NA, Valid]"""
  38. return IntervalArray.from_tuples([None, (0, 1)])
  39. @pytest.fixture
  40. def data_for_twos():
  41. pytest.skip("Interval is not a numeric dtype")
  42. @pytest.fixture
  43. def data_for_sorting():
  44. return IntervalArray.from_tuples([(1, 2), (2, 3), (0, 1)])
  45. @pytest.fixture
  46. def data_missing_for_sorting():
  47. return IntervalArray.from_tuples([(1, 2), None, (0, 1)])
  48. @pytest.fixture
  49. def data_for_grouping():
  50. a = (0, 1)
  51. b = (1, 2)
  52. c = (2, 3)
  53. return IntervalArray.from_tuples([b, b, None, None, a, a, b, c])
  54. class TestIntervalArray(base.ExtensionTests):
  55. divmod_exc = TypeError
  56. def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
  57. return op_name in ["min", "max"]
  58. @pytest.mark.xfail(
  59. reason="Raises with incorrect message bc it disallows *all* listlikes "
  60. "instead of just wrong-length listlikes"
  61. )
  62. def test_fillna_length_mismatch(self, data_missing):
  63. super().test_fillna_length_mismatch(data_missing)
  64. @pytest.mark.filterwarnings(
  65. "ignore:invalid value encountered in cast:RuntimeWarning"
  66. )
  67. def test_hash_pandas_object(self, data):
  68. super().test_hash_pandas_object(data)
  69. @pytest.mark.filterwarnings(
  70. "ignore:invalid value encountered in cast:RuntimeWarning"
  71. )
  72. def test_hash_pandas_object_works(self, data, as_frame):
  73. super().test_hash_pandas_object_works(data, as_frame)
  74. @pytest.mark.filterwarnings(
  75. "ignore:invalid value encountered in cast:RuntimeWarning"
  76. )
  77. @pytest.mark.parametrize("engine", ["c", "python"])
  78. def test_EA_types(self, engine, data, request):
  79. super().test_EA_types(engine, data, request)
  80. @pytest.mark.filterwarnings(
  81. "ignore:invalid value encountered in cast:RuntimeWarning"
  82. )
  83. def test_astype_str(self, data):
  84. super().test_astype_str(data)
  85. # TODO: either belongs in tests.arrays.interval or move into base tests.
  86. def test_fillna_non_scalar_raises(data_missing):
  87. msg = "can only insert Interval objects and NA into an IntervalArray"
  88. with pytest.raises(TypeError, match=msg):
  89. data_missing.fillna([1, 1])