# test_to_xarray.py -- tests for DataFrame.to_xarray and Series.to_xarray

  1. import numpy as np
  2. import pytest
  3. from pandas import (
  4. Categorical,
  5. DataFrame,
  6. MultiIndex,
  7. Series,
  8. StringDtype,
  9. date_range,
  10. )
  11. import pandas._testing as tm
  12. from pandas.util.version import Version
# Skip every test in this module when xarray cannot be imported. The returned
# module object is kept so that its ``__version__`` can be inspected by the
# tests (see the xfail window in TestSeriesToXArray).
xarray = pytest.importorskip("xarray")
  14. class TestDataFrameToXArray:
  15. @pytest.fixture
  16. def df(self):
  17. return DataFrame(
  18. {
  19. "a": list("abcd"),
  20. "b": list(range(1, 5)),
  21. "c": np.arange(3, 7).astype("u1"),
  22. "d": np.arange(4.0, 8.0, dtype="float64"),
  23. "e": [True, False, True, False],
  24. "f": Categorical(list("abcd")),
  25. "g": date_range("20130101", periods=4),
  26. "h": date_range("20130101", periods=4, tz="US/Eastern"),
  27. }
  28. )
  29. def test_to_xarray_index_types(self, index_flat, df, using_infer_string):
  30. index = index_flat
  31. # MultiIndex is tested in test_to_xarray_with_multiindex
  32. if len(index) == 0:
  33. pytest.skip("Test doesn't make sense for empty index")
  34. from xarray import Dataset
  35. df.index = index[:4]
  36. df.index.name = "foo"
  37. df.columns.name = "bar"
  38. result = df.to_xarray()
  39. assert result.sizes["foo"] == 4
  40. assert len(result.coords) == 1
  41. assert len(result.data_vars) == 8
  42. tm.assert_almost_equal(list(result.coords.keys()), ["foo"])
  43. assert isinstance(result, Dataset)
  44. # idempotency
  45. # datetimes w/tz are preserved
  46. # column names are lost
  47. expected = df.copy()
  48. expected["f"] = expected["f"].astype(
  49. object if not using_infer_string else "str"
  50. )
  51. expected.columns.name = None
  52. tm.assert_frame_equal(result.to_dataframe(), expected)
  53. def test_to_xarray_empty(self, df):
  54. from xarray import Dataset
  55. df.index.name = "foo"
  56. result = df[0:0].to_xarray()
  57. assert result.sizes["foo"] == 0
  58. assert isinstance(result, Dataset)
  59. def test_to_xarray_with_multiindex(self, df, using_infer_string):
  60. from xarray import Dataset
  61. # MultiIndex
  62. df.index = MultiIndex.from_product([["a"], range(4)], names=["one", "two"])
  63. result = df.to_xarray()
  64. assert result.sizes["one"] == 1
  65. assert result.sizes["two"] == 4
  66. assert len(result.coords) == 2
  67. assert len(result.data_vars) == 8
  68. tm.assert_almost_equal(list(result.coords.keys()), ["one", "two"])
  69. assert isinstance(result, Dataset)
  70. result = result.to_dataframe()
  71. expected = df.copy()
  72. expected["f"] = expected["f"].astype(
  73. object if not using_infer_string else "str"
  74. )
  75. expected.columns.name = None
  76. tm.assert_frame_equal(result, expected)
  77. class TestSeriesToXArray:
  78. def test_to_xarray_index_types(self, index_flat, request):
  79. index = index_flat
  80. if (
  81. isinstance(index.dtype, StringDtype)
  82. and index.dtype.storage == "pyarrow"
  83. and Version(xarray.__version__) > Version("2024.9.0")
  84. and Version(xarray.__version__) < Version("2025.6.0")
  85. ):
  86. request.applymarker(
  87. pytest.mark.xfail(
  88. reason="xarray calling reshape of ArrowExtensionArray",
  89. raises=NotImplementedError,
  90. )
  91. )
  92. # MultiIndex is tested in test_to_xarray_with_multiindex
  93. from xarray import DataArray
  94. ser = Series(range(len(index)), index=index, dtype="int64")
  95. ser.index.name = "foo"
  96. result = ser.to_xarray()
  97. repr(result)
  98. assert len(result) == len(index)
  99. assert len(result.coords) == 1
  100. tm.assert_almost_equal(list(result.coords.keys()), ["foo"])
  101. assert isinstance(result, DataArray)
  102. # idempotency
  103. tm.assert_series_equal(result.to_series(), ser)
  104. def test_to_xarray_empty(self):
  105. from xarray import DataArray
  106. ser = Series([], dtype=object)
  107. ser.index.name = "foo"
  108. result = ser.to_xarray()
  109. assert len(result) == 0
  110. assert len(result.coords) == 1
  111. tm.assert_almost_equal(list(result.coords.keys()), ["foo"])
  112. assert isinstance(result, DataArray)
  113. def test_to_xarray_with_multiindex(self):
  114. from xarray import DataArray
  115. mi = MultiIndex.from_product([["a", "b"], range(3)], names=["one", "two"])
  116. ser = Series(range(6), dtype="int64", index=mi)
  117. result = ser.to_xarray()
  118. assert len(result) == 2
  119. tm.assert_almost_equal(list(result.coords.keys()), ["one", "two"])
  120. assert isinstance(result, DataArray)
  121. res = result.to_series()
  122. tm.assert_series_equal(res, ser)