utils.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. """Shared utilities for the modules in wandb.sklearn."""
  2. from collections.abc import Iterable, Sequence
  3. import numpy as np
  4. import pandas as pd
  5. import scipy
  6. import sklearn
  7. import wandb
  8. chart_limit = 1000
  9. def check_against_limit(count, chart, limit=None):
  10. if limit is None:
  11. limit = chart_limit
  12. if count > limit:
  13. warn_chart_limit(limit, chart)
  14. return True
  15. else:
  16. return False
  17. def warn_chart_limit(limit, chart):
  18. warning = f"using only the first {limit} datapoints to create chart {chart}"
  19. wandb.termwarn(warning)
  20. def encode_labels(df):
  21. le = sklearn.preprocessing.LabelEncoder()
  22. # apply le on categorical feature columns
  23. categorical_cols = df.select_dtypes(
  24. exclude=["int", "float", "float64", "float32", "int32", "int64"]
  25. ).columns
  26. df[categorical_cols] = df[categorical_cols].apply(lambda col: le.fit_transform(col))
  27. def test_types(**kwargs):
  28. test_passed = True
  29. for k, v in kwargs.items():
  30. # check for incorrect types
  31. if (
  32. (k == "X")
  33. or (k == "X_test")
  34. or (k == "y")
  35. or (k == "y_test")
  36. or (k == "y_true")
  37. or (k == "y_probas")
  38. ):
  39. # FIXME: do this individually
  40. if not isinstance(
  41. v,
  42. (
  43. Sequence,
  44. Iterable,
  45. np.ndarray,
  46. np.generic,
  47. pd.DataFrame,
  48. pd.Series,
  49. list,
  50. ),
  51. ):
  52. wandb.termerror(f"{k} is not an array. Please try again.")
  53. test_passed = False
  54. # check for classifier types
  55. if k == "model":
  56. if (not sklearn.base.is_classifier(v)) and (
  57. not sklearn.base.is_regressor(v)
  58. ):
  59. wandb.termerror(
  60. f"{k} is not a classifier or regressor. Please try again."
  61. )
  62. test_passed = False
  63. elif k == "clf" or k == "binary_clf":
  64. if not (sklearn.base.is_classifier(v)):
  65. wandb.termerror(f"{k} is not a classifier. Please try again.")
  66. test_passed = False
  67. elif k == "regressor":
  68. if not sklearn.base.is_regressor(v):
  69. wandb.termerror(f"{k} is not a regressor. Please try again.")
  70. test_passed = False
  71. elif k == "clusterer":
  72. if not (getattr(v, "_estimator_type", None) == "clusterer"):
  73. wandb.termerror(f"{k} is not a clusterer. Please try again.")
  74. test_passed = False
  75. return test_passed
  76. def test_fitted(model):
  77. try:
  78. model.predict(np.zeros((7, 3)))
  79. except sklearn.exceptions.NotFittedError:
  80. wandb.termerror("Please fit the model before passing it in.")
  81. return False
  82. except AttributeError:
  83. # Some clustering models (LDA, PCA, Agglomerative) don't implement ``predict``
  84. try:
  85. sklearn.utils.validation.check_is_fitted(
  86. model,
  87. [
  88. "coef_",
  89. "estimator_",
  90. "labels_",
  91. "n_clusters_",
  92. "children_",
  93. "components_",
  94. "n_components_",
  95. "n_iter_",
  96. "n_batch_iter_",
  97. "explained_variance_",
  98. "singular_values_",
  99. "mean_",
  100. ],
  101. all_or_any=any,
  102. )
  103. except sklearn.exceptions.NotFittedError:
  104. wandb.termerror("Please fit the model before passing it in.")
  105. return False
  106. else:
  107. return True
  108. except Exception:
  109. # Assume it's fitted, since ``NotFittedError`` wasn't raised
  110. return True
# Test assumptions about plotting parameters and datasets
  112. def test_missing(**kwargs):
  113. test_passed = True
  114. for k, v in kwargs.items():
  115. # Missing/empty params/datapoint arrays
  116. if v is None:
  117. wandb.termerror(f"{k} is None. Please try again.")
  118. test_passed = False
  119. if (k == "X") or (k == "X_test"):
  120. if isinstance(v, scipy.sparse.csr.csr_matrix):
  121. v = v.toarray()
  122. elif isinstance(v, (pd.DataFrame, pd.Series)):
  123. v = v.to_numpy()
  124. elif isinstance(v, list):
  125. v = np.asarray(v)
  126. # Warn the user about missing values
  127. missing = 0
  128. missing = np.count_nonzero(pd.isnull(v))
  129. if missing > 0:
  130. wandb.termwarn(f"{k} contains {missing} missing values. ")
  131. test_passed = False
  132. # Ensure the dataset contains only integers
  133. non_nums = 0
  134. if v.ndim == 1:
  135. non_nums = sum(
  136. 1
  137. for val in v
  138. if (
  139. not isinstance(val, (int, float, complex))
  140. and not isinstance(val, np.number)
  141. )
  142. )
  143. else:
  144. non_nums = sum(
  145. 1
  146. for sl in v
  147. for val in sl
  148. if (
  149. not isinstance(val, (int, float, complex))
  150. and not isinstance(val, np.number)
  151. )
  152. )
  153. if non_nums > 0:
  154. wandb.termerror(
  155. f"{k} contains values that are not numbers. Please vectorize, label encode or one hot encode {k} "
  156. "and call the plotting function again."
  157. )
  158. test_passed = False
  159. return test_passed
  160. def round_3(n):
  161. return round(n, 3)
  162. def round_2(n):
  163. return round(n, 2)