common.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. from __future__ import annotations
  2. from typing import TYPE_CHECKING
  3. from pandas import (
  4. DataFrame,
  5. concat,
  6. )
  7. if TYPE_CHECKING:
  8. from pandas._typing import AxisInt
  9. def _check_mixed_float(df, dtype=None):
  10. # float16 are most likely to be upcasted to float32
  11. dtypes = {"A": "float32", "B": "float32", "C": "float16", "D": "float64"}
  12. if isinstance(dtype, str):
  13. dtypes = {k: dtype for k, v in dtypes.items()}
  14. elif isinstance(dtype, dict):
  15. dtypes.update(dtype)
  16. if dtypes.get("A"):
  17. assert df.dtypes["A"] == dtypes["A"]
  18. if dtypes.get("B"):
  19. assert df.dtypes["B"] == dtypes["B"]
  20. if dtypes.get("C"):
  21. assert df.dtypes["C"] == dtypes["C"]
  22. if dtypes.get("D"):
  23. assert df.dtypes["D"] == dtypes["D"]
  24. def _check_mixed_int(df, dtype=None):
  25. dtypes = {"A": "int32", "B": "uint64", "C": "uint8", "D": "int64"}
  26. if isinstance(dtype, str):
  27. dtypes = {k: dtype for k, v in dtypes.items()}
  28. elif isinstance(dtype, dict):
  29. dtypes.update(dtype)
  30. if dtypes.get("A"):
  31. assert df.dtypes["A"] == dtypes["A"]
  32. if dtypes.get("B"):
  33. assert df.dtypes["B"] == dtypes["B"]
  34. if dtypes.get("C"):
  35. assert df.dtypes["C"] == dtypes["C"]
  36. if dtypes.get("D"):
  37. assert df.dtypes["D"] == dtypes["D"]
  38. def zip_frames(frames: list[DataFrame], axis: AxisInt = 1) -> DataFrame:
  39. """
  40. take a list of frames, zip them together under the
  41. assumption that these all have the first frames' index/columns.
  42. Returns
  43. -------
  44. new_frame : DataFrame
  45. """
  46. if axis == 1:
  47. columns = frames[0].columns
  48. zipped = [f.loc[:, c] for c in columns for f in frames]
  49. return concat(zipped, axis=1)
  50. else:
  51. index = frames[0].index
  52. zipped = [f.loc[i, :] for i in index for f in frames]
  53. return DataFrame(zipped)