test_dialect.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. """
  2. Tests that dialects are properly handled during parsing
  3. for all of the parsers defined in parsers.py
  4. """
  5. import csv
  6. from io import StringIO
  7. import pytest
  8. from pandas.errors import ParserWarning
  9. from pandas import DataFrame
  10. import pandas._testing as tm
  11. pytestmark = pytest.mark.filterwarnings(
  12. "ignore:Passing a BlockManager to DataFrame:DeprecationWarning"
  13. )
  14. @pytest.fixture
  15. def custom_dialect():
  16. dialect_name = "weird"
  17. dialect_kwargs = {
  18. "doublequote": False,
  19. "escapechar": "~",
  20. "delimiter": ":",
  21. "skipinitialspace": False,
  22. "quotechar": "`",
  23. "quoting": 3,
  24. }
  25. return dialect_name, dialect_kwargs
  26. def test_dialect(all_parsers):
  27. parser = all_parsers
  28. data = """\
  29. label1,label2,label3
  30. index1,"a,c,e
  31. index2,b,d,f
  32. """
  33. dia = csv.excel()
  34. dia.quoting = csv.QUOTE_NONE
  35. if parser.engine == "pyarrow":
  36. msg = "The 'dialect' option is not supported with the 'pyarrow' engine"
  37. with pytest.raises(ValueError, match=msg):
  38. parser.read_csv(StringIO(data), dialect=dia)
  39. return
  40. df = parser.read_csv(StringIO(data), dialect=dia)
  41. data = """\
  42. label1,label2,label3
  43. index1,a,c,e
  44. index2,b,d,f
  45. """
  46. exp = parser.read_csv(StringIO(data))
  47. exp.replace("a", '"a', inplace=True)
  48. tm.assert_frame_equal(df, exp)
  49. def test_dialect_str(all_parsers):
  50. dialect_name = "mydialect"
  51. parser = all_parsers
  52. data = """\
  53. fruit:vegetable
  54. apple:broccoli
  55. pear:tomato
  56. """
  57. exp = DataFrame({"fruit": ["apple", "pear"], "vegetable": ["broccoli", "tomato"]})
  58. with tm.with_csv_dialect(dialect_name, delimiter=":"):
  59. if parser.engine == "pyarrow":
  60. msg = "The 'dialect' option is not supported with the 'pyarrow' engine"
  61. with pytest.raises(ValueError, match=msg):
  62. parser.read_csv(StringIO(data), dialect=dialect_name)
  63. return
  64. df = parser.read_csv(StringIO(data), dialect=dialect_name)
  65. tm.assert_frame_equal(df, exp)
  66. def test_invalid_dialect(all_parsers):
  67. class InvalidDialect:
  68. pass
  69. data = "a\n1"
  70. parser = all_parsers
  71. msg = "Invalid dialect"
  72. with pytest.raises(ValueError, match=msg):
  73. parser.read_csv(StringIO(data), dialect=InvalidDialect)
  74. @pytest.mark.parametrize(
  75. "arg",
  76. [None, "doublequote", "escapechar", "skipinitialspace", "quotechar", "quoting"],
  77. )
  78. @pytest.mark.parametrize("value", ["dialect", "default", "other"])
  79. def test_dialect_conflict_except_delimiter(all_parsers, custom_dialect, arg, value):
  80. # see gh-23761.
  81. dialect_name, dialect_kwargs = custom_dialect
  82. parser = all_parsers
  83. expected = DataFrame({"a": [1], "b": [2]})
  84. data = "a:b\n1:2"
  85. warning_klass = None
  86. kwds = {}
  87. # arg=None tests when we pass in the dialect without any other arguments.
  88. if arg is not None:
  89. if value == "dialect": # No conflict --> no warning.
  90. kwds[arg] = dialect_kwargs[arg]
  91. elif value == "default": # Default --> no warning.
  92. from pandas.io.parsers.base_parser import parser_defaults
  93. kwds[arg] = parser_defaults[arg]
  94. else: # Non-default + conflict with dialect --> warning.
  95. warning_klass = ParserWarning
  96. kwds[arg] = "blah"
  97. with tm.with_csv_dialect(dialect_name, **dialect_kwargs):
  98. if parser.engine == "pyarrow":
  99. msg = "The 'dialect' option is not supported with the 'pyarrow' engine"
  100. with pytest.raises(ValueError, match=msg):
  101. parser.read_csv_check_warnings(
  102. # No warning bc we raise
  103. None,
  104. "Conflicting values for",
  105. StringIO(data),
  106. dialect=dialect_name,
  107. **kwds,
  108. )
  109. return
  110. result = parser.read_csv_check_warnings(
  111. warning_klass,
  112. "Conflicting values for",
  113. StringIO(data),
  114. dialect=dialect_name,
  115. **kwds,
  116. )
  117. tm.assert_frame_equal(result, expected)
  118. @pytest.mark.parametrize(
  119. "kwargs,warning_klass",
  120. [
  121. ({"sep": ","}, None), # sep is default --> sep_override=True
  122. ({"sep": "."}, ParserWarning), # sep isn't default --> sep_override=False
  123. ({"delimiter": ":"}, None), # No conflict
  124. ({"delimiter": None}, None), # Default arguments --> sep_override=True
  125. ({"delimiter": ","}, ParserWarning), # Conflict
  126. ({"delimiter": "."}, ParserWarning), # Conflict
  127. ],
  128. ids=[
  129. "sep-override-true",
  130. "sep-override-false",
  131. "delimiter-no-conflict",
  132. "delimiter-default-arg",
  133. "delimiter-conflict",
  134. "delimiter-conflict2",
  135. ],
  136. )
  137. def test_dialect_conflict_delimiter(all_parsers, custom_dialect, kwargs, warning_klass):
  138. # see gh-23761.
  139. dialect_name, dialect_kwargs = custom_dialect
  140. parser = all_parsers
  141. expected = DataFrame({"a": [1], "b": [2]})
  142. data = "a:b\n1:2"
  143. with tm.with_csv_dialect(dialect_name, **dialect_kwargs):
  144. if parser.engine == "pyarrow":
  145. msg = "The 'dialect' option is not supported with the 'pyarrow' engine"
  146. with pytest.raises(ValueError, match=msg):
  147. parser.read_csv_check_warnings(
  148. # no warning bc we raise
  149. None,
  150. "Conflicting values for 'delimiter'",
  151. StringIO(data),
  152. dialect=dialect_name,
  153. **kwargs,
  154. )
  155. return
  156. result = parser.read_csv_check_warnings(
  157. warning_klass,
  158. "Conflicting values for 'delimiter'",
  159. StringIO(data),
  160. dialect=dialect_name,
  161. **kwargs,
  162. )
  163. tm.assert_frame_equal(result, expected)