# test_common.py
import numpy as np
import pytest

from pandas.core.dtypes import dtypes
from pandas.core.dtypes.common import is_extension_array_dtype

import pandas as pd
import pandas._testing as tm
from pandas.core.arrays import ExtensionArray
  8. class DummyDtype(dtypes.ExtensionDtype):
  9. pass
  10. class DummyArray(ExtensionArray):
  11. def __init__(self, data) -> None:
  12. self.data = data
  13. def __array__(self, dtype=None, copy=None):
  14. return self.data
  15. @property
  16. def dtype(self):
  17. return DummyDtype()
  18. def astype(self, dtype, copy=True):
  19. # we don't support anything but a single dtype
  20. if isinstance(dtype, DummyDtype):
  21. if copy:
  22. return type(self)(self.data)
  23. return self
  24. elif not copy:
  25. return np.asarray(self, dtype=dtype)
  26. else:
  27. return np.array(self, dtype=dtype, copy=copy)
  28. class TestExtensionArrayDtype:
  29. @pytest.mark.parametrize(
  30. "values",
  31. [
  32. pd.Categorical([]),
  33. pd.Categorical([]).dtype,
  34. pd.Series(pd.Categorical([])),
  35. DummyDtype(),
  36. DummyArray(np.array([1, 2])),
  37. ],
  38. )
  39. def test_is_extension_array_dtype(self, values):
  40. assert is_extension_array_dtype(values)
  41. @pytest.mark.parametrize("values", [np.array([]), pd.Series(np.array([]))])
  42. def test_is_not_extension_array_dtype(self, values):
  43. assert not is_extension_array_dtype(values)
  44. def test_astype():
  45. arr = DummyArray(np.array([1, 2, 3]))
  46. expected = np.array([1, 2, 3], dtype=object)
  47. result = arr.astype(object)
  48. tm.assert_numpy_array_equal(result, expected)
  49. result = arr.astype("object")
  50. tm.assert_numpy_array_equal(result, expected)
  51. def test_astype_no_copy():
  52. arr = DummyArray(np.array([1, 2, 3], dtype=np.int64))
  53. result = arr.astype(arr.dtype, copy=False)
  54. assert arr is result
  55. result = arr.astype(arr.dtype)
  56. assert arr is not result
  57. @pytest.mark.parametrize("dtype", [dtypes.CategoricalDtype(), dtypes.IntervalDtype()])
  58. def test_is_extension_array_dtype(dtype):
  59. assert isinstance(dtype, dtypes.ExtensionDtype)
  60. assert is_extension_array_dtype(dtype)
  61. class CapturingStringArray(pd.arrays.StringArray):
  62. """Extend StringArray to capture arguments to __getitem__"""
  63. def __getitem__(self, item):
  64. self.last_item_arg = item
  65. return super().__getitem__(item)
  66. def test_ellipsis_index():
  67. # GH#42430 1D slices over extension types turn into N-dimensional slices
  68. # over ExtensionArrays
  69. df = pd.DataFrame(
  70. {"col1": CapturingStringArray(np.array(["hello", "world"], dtype=object))}
  71. )
  72. _ = df.iloc[:1]
  73. # String comparison because there's no native way to compare slices.
  74. # Before the fix for GH#42430, last_item_arg would get set to the 2D slice
  75. # (Ellipsis, slice(None, 1, None))
  76. out = df["col1"].array.last_item_arg
  77. assert str(out) == "slice(None, 1, None)"