csvs.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. """
  2. Module for formatting output data into CSV files.
  3. """
  4. from __future__ import annotations
  5. from collections.abc import (
  6. Hashable,
  7. Iterable,
  8. Iterator,
  9. Sequence,
  10. )
  11. import csv as csvlib
  12. import os
  13. from typing import (
  14. TYPE_CHECKING,
  15. Any,
  16. cast,
  17. )
  18. import numpy as np
  19. from pandas._libs import writers as libwriters
  20. from pandas._typing import SequenceNotStr
  21. from pandas.util._decorators import cache_readonly
  22. from pandas.core.dtypes.generic import (
  23. ABCDatetimeIndex,
  24. ABCIndex,
  25. ABCMultiIndex,
  26. ABCPeriodIndex,
  27. )
  28. from pandas.core.dtypes.missing import notna
  29. from pandas.core.indexes.api import Index
  30. from pandas.io.common import get_handle
  31. if TYPE_CHECKING:
  32. from pandas._typing import (
  33. CompressionOptions,
  34. FilePath,
  35. FloatFormatType,
  36. IndexLabel,
  37. StorageOptions,
  38. WriteBuffer,
  39. npt,
  40. )
  41. from pandas.io.formats.format import DataFrameFormatter
  42. _DEFAULT_CHUNKSIZE_CELLS = 100_000
  43. class CSVFormatter:
  44. cols: npt.NDArray[np.object_]
  45. def __init__(
  46. self,
  47. formatter: DataFrameFormatter,
  48. path_or_buf: FilePath | WriteBuffer[str] | WriteBuffer[bytes] = "",
  49. sep: str = ",",
  50. cols: Sequence[Hashable] | None = None,
  51. index_label: IndexLabel | None = None,
  52. mode: str = "w",
  53. encoding: str | None = None,
  54. errors: str = "strict",
  55. compression: CompressionOptions = "infer",
  56. quoting: int | None = None,
  57. lineterminator: str | None = "\n",
  58. chunksize: int | None = None,
  59. quotechar: str | None = '"',
  60. date_format: str | None = None,
  61. doublequote: bool = True,
  62. escapechar: str | None = None,
  63. storage_options: StorageOptions | None = None,
  64. ) -> None:
  65. self.fmt = formatter
  66. self.obj = self.fmt.frame
  67. self.filepath_or_buffer = path_or_buf
  68. self.encoding = encoding
  69. self.compression: CompressionOptions = compression
  70. self.mode = mode
  71. self.storage_options = storage_options
  72. self.sep = sep
  73. self.index_label = self._initialize_index_label(index_label)
  74. self.errors = errors
  75. self.quoting = quoting or csvlib.QUOTE_MINIMAL
  76. self.quotechar = self._initialize_quotechar(quotechar)
  77. self.doublequote = doublequote
  78. self.escapechar = escapechar
  79. self.lineterminator = lineterminator or os.linesep
  80. self.date_format = date_format
  81. self.cols = self._initialize_columns(cols)
  82. self.chunksize = self._initialize_chunksize(chunksize)
  83. @property
  84. def na_rep(self) -> str:
  85. return self.fmt.na_rep
  86. @property
  87. def float_format(self) -> FloatFormatType | None:
  88. return self.fmt.float_format
  89. @property
  90. def decimal(self) -> str:
  91. return self.fmt.decimal
  92. @property
  93. def header(self) -> bool | SequenceNotStr[str]:
  94. return self.fmt.header
  95. @property
  96. def index(self) -> bool:
  97. return self.fmt.index
  98. def _initialize_index_label(self, index_label: IndexLabel | None) -> IndexLabel:
  99. if index_label is not False:
  100. if index_label is None:
  101. return self._get_index_label_from_obj()
  102. elif not isinstance(index_label, (list, tuple, np.ndarray, ABCIndex)):
  103. # given a string for a DF with Index
  104. return [index_label]
  105. return index_label
  106. def _get_index_label_from_obj(self) -> Sequence[Hashable]:
  107. if isinstance(self.obj.index, ABCMultiIndex):
  108. return self._get_index_label_multiindex()
  109. else:
  110. return self._get_index_label_flat()
  111. def _get_index_label_multiindex(self) -> Sequence[Hashable]:
  112. return [name or "" for name in self.obj.index.names]
  113. def _get_index_label_flat(self) -> Sequence[Hashable]:
  114. index_label = self.obj.index.name
  115. return [""] if index_label is None else [index_label]
  116. def _initialize_quotechar(self, quotechar: str | None) -> str | None:
  117. if self.quoting != csvlib.QUOTE_NONE:
  118. # prevents crash in _csv
  119. return quotechar
  120. return None
  121. @property
  122. def has_mi_columns(self) -> bool:
  123. return bool(isinstance(self.obj.columns, ABCMultiIndex))
  124. def _initialize_columns(
  125. self, cols: Iterable[Hashable] | None
  126. ) -> npt.NDArray[np.object_]:
  127. # validate mi options
  128. if self.has_mi_columns:
  129. if cols is not None:
  130. msg = "cannot specify cols with a MultiIndex on the columns"
  131. raise TypeError(msg)
  132. if cols is not None:
  133. if isinstance(cols, ABCIndex):
  134. cols = cols._get_values_for_csv(**self._number_format)
  135. else:
  136. cols = list(cols)
  137. self.obj = self.obj.loc[:, cols]
  138. # update columns to include possible multiplicity of dupes
  139. # and make sure cols is just a list of labels
  140. new_cols = self.obj.columns
  141. return new_cols._get_values_for_csv(**self._number_format)
  142. def _initialize_chunksize(self, chunksize: int | None) -> int:
  143. if chunksize is None:
  144. return (_DEFAULT_CHUNKSIZE_CELLS // (len(self.cols) or 1)) or 1
  145. return int(chunksize)
  146. @property
  147. def _number_format(self) -> dict[str, Any]:
  148. """Dictionary used for storing number formatting settings."""
  149. return {
  150. "na_rep": self.na_rep,
  151. "float_format": self.float_format,
  152. "date_format": self.date_format,
  153. "quoting": self.quoting,
  154. "decimal": self.decimal,
  155. }
  156. @cache_readonly
  157. def data_index(self) -> Index:
  158. data_index = self.obj.index
  159. if (
  160. isinstance(data_index, (ABCDatetimeIndex, ABCPeriodIndex))
  161. and self.date_format is not None
  162. ):
  163. data_index = Index(
  164. [x.strftime(self.date_format) if notna(x) else "" for x in data_index]
  165. )
  166. elif isinstance(data_index, ABCMultiIndex):
  167. data_index = data_index.remove_unused_levels()
  168. return data_index
  169. @property
  170. def nlevels(self) -> int:
  171. if self.index:
  172. return getattr(self.data_index, "nlevels", 1)
  173. else:
  174. return 0
  175. @property
  176. def _has_aliases(self) -> bool:
  177. return isinstance(self.header, (tuple, list, np.ndarray, ABCIndex))
  178. @property
  179. def _need_to_save_header(self) -> bool:
  180. return bool(self._has_aliases or self.header)
  181. @property
  182. def write_cols(self) -> SequenceNotStr[Hashable]:
  183. if self._has_aliases:
  184. assert not isinstance(self.header, bool)
  185. if len(self.header) != len(self.cols):
  186. raise ValueError(
  187. f"Writing {len(self.cols)} cols but got {len(self.header)} aliases"
  188. )
  189. return self.header
  190. else:
  191. # self.cols is an ndarray derived from Index._get_values_for_csv,
  192. # so its entries are strings, i.e. hashable
  193. return cast(SequenceNotStr[Hashable], self.cols)
  194. @property
  195. def encoded_labels(self) -> list[Hashable]:
  196. encoded_labels: list[Hashable] = []
  197. if self.index and self.index_label:
  198. assert isinstance(self.index_label, Sequence)
  199. encoded_labels = list(self.index_label)
  200. if not self.has_mi_columns or self._has_aliases:
  201. encoded_labels += list(self.write_cols)
  202. return encoded_labels
  203. def save(self) -> None:
  204. """
  205. Create the writer & save.
  206. """
  207. # apply compression and byte/text conversion
  208. with get_handle(
  209. self.filepath_or_buffer,
  210. self.mode,
  211. encoding=self.encoding,
  212. errors=self.errors,
  213. compression=self.compression,
  214. storage_options=self.storage_options,
  215. ) as handles:
  216. # Note: self.encoding is irrelevant here
  217. self.writer = csvlib.writer(
  218. handles.handle,
  219. lineterminator=self.lineterminator,
  220. delimiter=self.sep,
  221. quoting=self.quoting,
  222. doublequote=self.doublequote,
  223. escapechar=self.escapechar,
  224. quotechar=self.quotechar,
  225. )
  226. self._save()
  227. def _save(self) -> None:
  228. if self._need_to_save_header:
  229. self._save_header()
  230. self._save_body()
  231. def _save_header(self) -> None:
  232. if not self.has_mi_columns or self._has_aliases:
  233. self.writer.writerow(self.encoded_labels)
  234. else:
  235. for row in self._generate_multiindex_header_rows():
  236. self.writer.writerow(row)
  237. def _generate_multiindex_header_rows(self) -> Iterator[list[Hashable]]:
  238. columns = self.obj.columns
  239. for i in range(columns.nlevels):
  240. # we need at least 1 index column to write our col names
  241. col_line = []
  242. if self.index:
  243. # name is the first column
  244. col_line.append(columns.names[i])
  245. if isinstance(self.index_label, list) and len(self.index_label) > 1:
  246. col_line.extend([""] * (len(self.index_label) - 1))
  247. col_line.extend(columns._get_level_values(i))
  248. yield col_line
  249. # Write out the index line if it's not empty.
  250. # Otherwise, we will print out an extraneous
  251. # blank line between the mi and the data rows.
  252. if self.encoded_labels and set(self.encoded_labels) != {""}:
  253. yield self.encoded_labels + [""] * len(columns)
  254. def _save_body(self) -> None:
  255. nrows = len(self.data_index)
  256. chunks = (nrows // self.chunksize) + 1
  257. for i in range(chunks):
  258. start_i = i * self.chunksize
  259. end_i = min(start_i + self.chunksize, nrows)
  260. if start_i >= end_i:
  261. break
  262. self._save_chunk(start_i, end_i)
  263. def _save_chunk(self, start_i: int, end_i: int) -> None:
  264. # create the data for a chunk
  265. slicer = slice(start_i, end_i)
  266. df = self.obj.iloc[slicer]
  267. res = df._get_values_for_csv(**self._number_format)
  268. data = list(res._iter_column_arrays())
  269. ix = self.data_index[slicer]._get_values_for_csv(**self._number_format)
  270. libwriters.write_csv_rows(
  271. data,
  272. ix,
  273. self.nlevels,
  274. self.cols,
  275. self.writer,
  276. )