test_register_accessor.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from collections.abc import Generator
  2. import contextlib
  3. import pytest
  4. import pandas as pd
  5. import pandas._testing as tm
  6. from pandas.core import accessor
  7. def test_dirname_mixin() -> None:
  8. # GH37173
  9. class X(accessor.DirNamesMixin):
  10. x = 1
  11. y: int
  12. def __init__(self) -> None:
  13. self.z = 3
  14. result = [attr_name for attr_name in dir(X()) if not attr_name.startswith("_")]
  15. assert result == ["x", "z"]
  16. @contextlib.contextmanager
  17. def ensure_removed(obj, attr) -> Generator[None, None, None]:
  18. """Ensure that an attribute added to 'obj' during the test is
  19. removed when we're done
  20. """
  21. try:
  22. yield
  23. finally:
  24. try:
  25. delattr(obj, attr)
  26. except AttributeError:
  27. pass
  28. obj._accessors.discard(attr)
  29. class MyAccessor:
  30. def __init__(self, obj) -> None:
  31. self.obj = obj
  32. self.item = "item"
  33. @property
  34. def prop(self):
  35. return self.item
  36. def method(self):
  37. return self.item
  38. @pytest.mark.parametrize(
  39. "obj, registrar",
  40. [
  41. (pd.Series, pd.api.extensions.register_series_accessor),
  42. (pd.DataFrame, pd.api.extensions.register_dataframe_accessor),
  43. (pd.Index, pd.api.extensions.register_index_accessor),
  44. ],
  45. )
  46. def test_register(obj, registrar):
  47. with ensure_removed(obj, "mine"):
  48. before = set(dir(obj))
  49. registrar("mine")(MyAccessor)
  50. o = obj([]) if obj is not pd.Series else obj([], dtype=object)
  51. assert o.mine.prop == "item"
  52. after = set(dir(obj))
  53. assert (before ^ after) == {"mine"}
  54. assert "mine" in obj._accessors
  55. def test_accessor_works():
  56. with ensure_removed(pd.Series, "mine"):
  57. pd.api.extensions.register_series_accessor("mine")(MyAccessor)
  58. s = pd.Series([1, 2])
  59. assert s.mine.obj is s
  60. assert s.mine.prop == "item"
  61. assert s.mine.method() == "item"
  62. def test_overwrite_warns():
  63. match = r".*MyAccessor.*fake.*Series.*"
  64. with tm.assert_produces_warning(UserWarning, match=match):
  65. with ensure_removed(pd.Series, "fake"):
  66. setattr(pd.Series, "fake", 123)
  67. pd.api.extensions.register_series_accessor("fake")(MyAccessor)
  68. s = pd.Series([1, 2])
  69. assert s.fake.prop == "item"
  70. def test_raises_attribute_error():
  71. with ensure_removed(pd.Series, "bad"):
  72. @pd.api.extensions.register_series_accessor("bad")
  73. class Bad:
  74. def __init__(self, data) -> None:
  75. raise AttributeError("whoops")
  76. with pytest.raises(AttributeError, match="whoops"):
  77. pd.Series([], dtype=object).bad