conftest.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. from __future__ import annotations
  2. import os
  3. import pytest
  4. from pandas.compat import HAS_PYARROW
  5. from pandas.compat._optional import VERSIONS
  6. from pandas import (
  7. read_csv,
  8. read_table,
  9. )
  10. import pandas._testing as tm
  11. class BaseParser:
  12. engine: str | None = None
  13. low_memory = True
  14. float_precision_choices: list[str | None] = []
  15. def update_kwargs(self, kwargs):
  16. kwargs = kwargs.copy()
  17. kwargs.update({"engine": self.engine, "low_memory": self.low_memory})
  18. return kwargs
  19. def read_csv(self, *args, **kwargs):
  20. kwargs = self.update_kwargs(kwargs)
  21. return read_csv(*args, **kwargs)
  22. def read_csv_check_warnings(
  23. self,
  24. warn_type: type[Warning],
  25. warn_msg: str,
  26. *args,
  27. raise_on_extra_warnings=True,
  28. check_stacklevel: bool = True,
  29. **kwargs,
  30. ):
  31. # We need to check the stacklevel here instead of in the tests
  32. # since this is where read_csv is called and where the warning
  33. # should point to.
  34. kwargs = self.update_kwargs(kwargs)
  35. with tm.assert_produces_warning(
  36. warn_type,
  37. match=warn_msg,
  38. raise_on_extra_warnings=raise_on_extra_warnings,
  39. check_stacklevel=check_stacklevel,
  40. ):
  41. return read_csv(*args, **kwargs)
  42. def read_table(self, *args, **kwargs):
  43. kwargs = self.update_kwargs(kwargs)
  44. return read_table(*args, **kwargs)
  45. def read_table_check_warnings(
  46. self,
  47. warn_type: type[Warning],
  48. warn_msg: str,
  49. *args,
  50. raise_on_extra_warnings=True,
  51. **kwargs,
  52. ):
  53. # We need to check the stacklevel here instead of in the tests
  54. # since this is where read_table is called and where the warning
  55. # should point to.
  56. kwargs = self.update_kwargs(kwargs)
  57. with tm.assert_produces_warning(
  58. warn_type, match=warn_msg, raise_on_extra_warnings=raise_on_extra_warnings
  59. ):
  60. return read_table(*args, **kwargs)
  61. class CParser(BaseParser):
  62. engine = "c"
  63. float_precision_choices = [None, "high", "round_trip"]
  64. class CParserHighMemory(CParser):
  65. low_memory = False
  66. class CParserLowMemory(CParser):
  67. low_memory = True
  68. class PythonParser(BaseParser):
  69. engine = "python"
  70. float_precision_choices = [None]
  71. class PyArrowParser(BaseParser):
  72. engine = "pyarrow"
  73. float_precision_choices = [None]
  74. @pytest.fixture
  75. def csv_dir_path(datapath):
  76. """
  77. The directory path to the data files needed for parser tests.
  78. """
  79. return datapath("io", "parser", "data")
  80. @pytest.fixture
  81. def csv1(datapath):
  82. """
  83. The path to the data file "test1.csv" needed for parser tests.
  84. """
  85. return os.path.join(datapath("io", "data", "csv"), "test1.csv")
  86. _cParserHighMemory = CParserHighMemory
  87. _cParserLowMemory = CParserLowMemory
  88. _pythonParser = PythonParser
  89. _pyarrowParser = PyArrowParser
  90. _py_parsers_only = [_pythonParser]
  91. _c_parsers_only = [_cParserHighMemory, _cParserLowMemory]
  92. _pyarrow_parsers_only = [
  93. pytest.param(
  94. _pyarrowParser,
  95. marks=[
  96. pytest.mark.single_cpu,
  97. pytest.mark.skipif(not HAS_PYARROW, reason="pyarrow is not installed"),
  98. ],
  99. )
  100. ]
  101. _all_parsers = [*_c_parsers_only, *_py_parsers_only, *_pyarrow_parsers_only]
  102. _py_parser_ids = ["python"]
  103. _c_parser_ids = ["c_high", "c_low"]
  104. _pyarrow_parsers_ids = ["pyarrow"]
  105. _all_parser_ids = [*_c_parser_ids, *_py_parser_ids, *_pyarrow_parsers_ids]
  106. @pytest.fixture(params=_all_parsers, ids=_all_parser_ids)
  107. def all_parsers(request):
  108. """
  109. Fixture all of the CSV parsers.
  110. """
  111. parser = request.param()
  112. if parser.engine == "pyarrow":
  113. pytest.importorskip("pyarrow", VERSIONS["pyarrow"])
  114. # Try finding a way to disable threads all together
  115. # for more stable CI runs
  116. import pyarrow
  117. pyarrow.set_cpu_count(1)
  118. return parser
  119. @pytest.fixture(params=_c_parsers_only, ids=_c_parser_ids)
  120. def c_parser_only(request):
  121. """
  122. Fixture all of the CSV parsers using the C engine.
  123. """
  124. return request.param()
  125. @pytest.fixture(params=_py_parsers_only, ids=_py_parser_ids)
  126. def python_parser_only(request):
  127. """
  128. Fixture all of the CSV parsers using the Python engine.
  129. """
  130. return request.param()
  131. @pytest.fixture(params=_pyarrow_parsers_only, ids=_pyarrow_parsers_ids)
  132. def pyarrow_parser_only(request):
  133. """
  134. Fixture all of the CSV parsers using the Pyarrow engine.
  135. """
  136. return request.param()
  137. def _get_all_parser_float_precision_combinations():
  138. """
  139. Return all allowable parser and float precision
  140. combinations and corresponding ids.
  141. """
  142. params = []
  143. ids = []
  144. for parser, parser_id in zip(_all_parsers, _all_parser_ids):
  145. if hasattr(parser, "values"):
  146. # Wrapped in pytest.param, get the actual parser back
  147. parser = parser.values[0]
  148. for precision in parser.float_precision_choices:
  149. # Re-wrap in pytest.param for pyarrow
  150. mark = (
  151. [
  152. pytest.mark.single_cpu,
  153. pytest.mark.skipif(
  154. not HAS_PYARROW, reason="pyarrow is not installed"
  155. ),
  156. ]
  157. if parser.engine == "pyarrow"
  158. else ()
  159. )
  160. param = pytest.param((parser(), precision), marks=mark)
  161. params.append(param)
  162. ids.append(f"{parser_id}-{precision}")
  163. return {"params": params, "ids": ids}
  164. @pytest.fixture(
  165. params=_get_all_parser_float_precision_combinations()["params"],
  166. ids=_get_all_parser_float_precision_combinations()["ids"],
  167. )
  168. def all_parsers_all_precisions(request):
  169. """
  170. Fixture for all allowable combinations of parser
  171. and float precision
  172. """
  173. return request.param
  174. _utf_values = [8, 16, 32]
  175. _encoding_seps = ["", "-", "_"]
  176. _encoding_prefixes = ["utf", "UTF"]
  177. _encoding_fmts = [
  178. f"{prefix}{sep}{{0}}" for sep in _encoding_seps for prefix in _encoding_prefixes
  179. ]
  180. @pytest.fixture(params=_utf_values)
  181. def utf_value(request):
  182. """
  183. Fixture for all possible integer values for a UTF encoding.
  184. """
  185. return request.param
  186. @pytest.fixture(params=_encoding_fmts)
  187. def encoding_fmt(request):
  188. """
  189. Fixture for all possible string formats of a UTF encoding.
  190. """
  191. return request.param
@pytest.fixture(
    params=[
        # positive cases; the expected value is a float
        ("-1,0", -1.0),
        ("-1,2e0", -1.2),
        ("-1e0", -1.0),
        ("+1e0", 1.0),
        ("+1e+0", 1.0),
        ("+1e-1", 0.1),
        ("+,1e1", 1.0),
        ("+1,e0", 1.0),
        ("-,1e1", -1.0),
        ("-1,e0", -1.0),
        ("0,1", 0.1),
        ("1,", 1.0),
        (",1", 0.1),
        ("-,1", -0.1),
        ("1_,", 1.0),
        ("1_234,56", 1234.56),
        ("1_234,56e0", 1234.56),
        # negative cases; must not parse as float
        ("_", "_"),
        ("-_", "-_"),
        ("-_1", "-_1"),
        ("-_1e0", "-_1e0"),
        ("_1", "_1"),
        ("_1,", "_1,"),
        ("_1,_", "_1,_"),
        ("_1e0", "_1e0"),
        ("1,2e_1", "1,2e_1"),
        ("1,2e1_0", "1,2e1_0"),
        ("1,_2", "1,_2"),
        (",1__2", ",1__2"),
        (",1e", ",1e"),
        ("-,1e", "-,1e"),
        ("1_000,000_000", "1_000,000_000"),
        ("1,e1_2", "1,e1_2"),
        ("e11,2", "e11,2"),
        ("1e11,2", "1e11,2"),
        ("1,2,2", "1,2,2"),
        ("1,2_1", "1,2_1"),
        ("1,2e-10e1", "1,2e-10e1"),
        ("--1,2", "--1,2"),
        ("1a_2,1", "1a_2,1"),
        # positive cases again (upper-case exponent); kept at the end so the
        # order of the existing parameters is unchanged
        ("1,2E-1", 0.12),
        ("1,2E1", 12.0),
    ]
)
def numeric_decimal(request):
    """
    Fixture of ``(text, expected)`` pairs for numeric parsing. The first entry
    represents the value to read while the second represents the expected result.

    Positive cases expect a float. Negative cases must not parse as a float,
    so their expected result is the input string itself.
    """
    return request.param
  245. @pytest.fixture
  246. def pyarrow_xfail(request):
  247. """
  248. Fixture that xfails a test if the engine is pyarrow.
  249. Use if failure is do to unsupported keywords or inconsistent results.
  250. """
  251. if "all_parsers" in request.fixturenames:
  252. parser = request.getfixturevalue("all_parsers")
  253. elif "all_parsers_all_precisions" in request.fixturenames:
  254. # Return value is tuple of (engine, precision)
  255. parser = request.getfixturevalue("all_parsers_all_precisions")[0]
  256. else:
  257. return
  258. if parser.engine == "pyarrow":
  259. mark = pytest.mark.xfail(reason="pyarrow doesn't support this.")
  260. request.applymarker(mark)
  261. @pytest.fixture
  262. def pyarrow_skip(request):
  263. """
  264. Fixture that skips a test if the engine is pyarrow.
  265. Use if failure is do a parsing failure from pyarrow.csv.read_csv
  266. """
  267. if "all_parsers" in request.fixturenames:
  268. parser = request.getfixturevalue("all_parsers")
  269. elif "all_parsers_all_precisions" in request.fixturenames:
  270. # Return value is tuple of (engine, precision)
  271. parser = request.getfixturevalue("all_parsers_all_precisions")[0]
  272. else:
  273. return
  274. if parser.engine == "pyarrow":
  275. pytest.skip(reason="https://github.com/apache/arrow/issues/38676")