# _validators.py (34 KB)
  1. # pyright: basic
  2. from __future__ import annotations
  3. import os
  4. import sys
  5. from typing import Any, TypeVar, Callable, Optional, NamedTuple
  6. from typing_extensions import TypeAlias
  7. from .._extras import pandas as pd
  8. class Remediation(NamedTuple):
  9. name: str
  10. immediate_msg: Optional[str] = None
  11. necessary_msg: Optional[str] = None
  12. necessary_fn: Optional[Callable[[Any], Any]] = None
  13. optional_msg: Optional[str] = None
  14. optional_fn: Optional[Callable[[Any], Any]] = None
  15. error_msg: Optional[str] = None
  16. OptionalDataFrameT = TypeVar("OptionalDataFrameT", bound="Optional[pd.DataFrame]")
  17. def num_examples_validator(df: pd.DataFrame) -> Remediation:
  18. """
  19. This validator will only print out the number of examples and recommend to the user to increase the number of examples if less than 100.
  20. """
  21. MIN_EXAMPLES = 100
  22. optional_suggestion = (
  23. ""
  24. if len(df) >= MIN_EXAMPLES
  25. else ". In general, we recommend having at least a few hundred examples. We've found that performance tends to linearly increase for every doubling of the number of examples"
  26. )
  27. immediate_msg = f"\n- Your file contains {len(df)} prompt-completion pairs{optional_suggestion}"
  28. return Remediation(name="num_examples", immediate_msg=immediate_msg)
  29. def necessary_column_validator(df: pd.DataFrame, necessary_column: str) -> Remediation:
  30. """
  31. This validator will ensure that the necessary column is present in the dataframe.
  32. """
  33. def lower_case_column(df: pd.DataFrame, column: Any) -> pd.DataFrame:
  34. cols = [c for c in df.columns if str(c).lower() == column]
  35. df.rename(columns={cols[0]: column.lower()}, inplace=True)
  36. return df
  37. immediate_msg = None
  38. necessary_fn = None
  39. necessary_msg = None
  40. error_msg = None
  41. if necessary_column not in df.columns:
  42. if necessary_column in [str(c).lower() for c in df.columns]:
  43. def lower_case_column_creator(df: pd.DataFrame) -> pd.DataFrame:
  44. return lower_case_column(df, necessary_column)
  45. necessary_fn = lower_case_column_creator
  46. immediate_msg = f"\n- The `{necessary_column}` column/key should be lowercase"
  47. necessary_msg = f"Lower case column name to `{necessary_column}`"
  48. else:
  49. error_msg = f"`{necessary_column}` column/key is missing. Please make sure you name your columns/keys appropriately, then retry"
  50. return Remediation(
  51. name="necessary_column",
  52. immediate_msg=immediate_msg,
  53. necessary_msg=necessary_msg,
  54. necessary_fn=necessary_fn,
  55. error_msg=error_msg,
  56. )
  57. def additional_column_validator(df: pd.DataFrame, fields: list[str] = ["prompt", "completion"]) -> Remediation:
  58. """
  59. This validator will remove additional columns from the dataframe.
  60. """
  61. additional_columns = []
  62. necessary_msg = None
  63. immediate_msg = None
  64. necessary_fn = None # type: ignore
  65. if len(df.columns) > 2:
  66. additional_columns = [c for c in df.columns if c not in fields]
  67. warn_message = ""
  68. for ac in additional_columns:
  69. dups = [c for c in additional_columns if ac in c]
  70. if len(dups) > 0:
  71. warn_message += f"\n WARNING: Some of the additional columns/keys contain `{ac}` in their name. These will be ignored, and the column/key `{ac}` will be used instead. This could also result from a duplicate column/key in the provided file."
  72. immediate_msg = f"\n- The input file should contain exactly two columns/keys per row. Additional columns/keys present are: {additional_columns}{warn_message}"
  73. necessary_msg = f"Remove additional columns/keys: {additional_columns}"
  74. def necessary_fn(x: Any) -> Any:
  75. return x[fields]
  76. return Remediation(
  77. name="additional_column",
  78. immediate_msg=immediate_msg,
  79. necessary_msg=necessary_msg,
  80. necessary_fn=necessary_fn,
  81. )
  82. def non_empty_field_validator(df: pd.DataFrame, field: str = "completion") -> Remediation:
  83. """
  84. This validator will ensure that no completion is empty.
  85. """
  86. necessary_msg = None
  87. necessary_fn = None # type: ignore
  88. immediate_msg = None
  89. if df[field].apply(lambda x: x == "").any() or df[field].isnull().any():
  90. empty_rows = (df[field] == "") | (df[field].isnull())
  91. empty_indexes = df.reset_index().index[empty_rows].tolist()
  92. immediate_msg = f"\n- `{field}` column/key should not contain empty strings. These are rows: {empty_indexes}"
  93. def necessary_fn(x: Any) -> Any:
  94. return x[x[field] != ""].dropna(subset=[field])
  95. necessary_msg = f"Remove {len(empty_indexes)} rows with empty {field}s"
  96. return Remediation(
  97. name=f"empty_{field}",
  98. immediate_msg=immediate_msg,
  99. necessary_msg=necessary_msg,
  100. necessary_fn=necessary_fn,
  101. )
  102. def duplicated_rows_validator(df: pd.DataFrame, fields: list[str] = ["prompt", "completion"]) -> Remediation:
  103. """
  104. This validator will suggest to the user to remove duplicate rows if they exist.
  105. """
  106. duplicated_rows = df.duplicated(subset=fields)
  107. duplicated_indexes = df.reset_index().index[duplicated_rows].tolist()
  108. immediate_msg = None
  109. optional_msg = None
  110. optional_fn = None # type: ignore
  111. if len(duplicated_indexes) > 0:
  112. immediate_msg = f"\n- There are {len(duplicated_indexes)} duplicated {'-'.join(fields)} sets. These are rows: {duplicated_indexes}"
  113. optional_msg = f"Remove {len(duplicated_indexes)} duplicate rows"
  114. def optional_fn(x: Any) -> Any:
  115. return x.drop_duplicates(subset=fields)
  116. return Remediation(
  117. name="duplicated_rows",
  118. immediate_msg=immediate_msg,
  119. optional_msg=optional_msg,
  120. optional_fn=optional_fn,
  121. )
  122. def long_examples_validator(df: pd.DataFrame) -> Remediation:
  123. """
  124. This validator will suggest to the user to remove examples that are too long.
  125. """
  126. immediate_msg = None
  127. optional_msg = None
  128. optional_fn = None # type: ignore
  129. ft_type = infer_task_type(df)
  130. if ft_type != "open-ended generation":
  131. def get_long_indexes(d: pd.DataFrame) -> Any:
  132. long_examples = d.apply(lambda x: len(x.prompt) + len(x.completion) > 10000, axis=1)
  133. return d.reset_index().index[long_examples].tolist()
  134. long_indexes = get_long_indexes(df)
  135. if len(long_indexes) > 0:
  136. immediate_msg = f"\n- There are {len(long_indexes)} examples that are very long. These are rows: {long_indexes}\nFor conditional generation, and for classification the examples shouldn't be longer than 2048 tokens."
  137. optional_msg = f"Remove {len(long_indexes)} long examples"
  138. def optional_fn(x: Any) -> Any:
  139. long_indexes_to_drop = get_long_indexes(x)
  140. if long_indexes != long_indexes_to_drop:
  141. sys.stdout.write(
  142. f"The indices of the long examples has changed as a result of a previously applied recommendation.\nThe {len(long_indexes_to_drop)} long examples to be dropped are now at the following indices: {long_indexes_to_drop}\n"
  143. )
  144. return x.drop(long_indexes_to_drop)
  145. return Remediation(
  146. name="long_examples",
  147. immediate_msg=immediate_msg,
  148. optional_msg=optional_msg,
  149. optional_fn=optional_fn,
  150. )
  151. def common_prompt_suffix_validator(df: pd.DataFrame) -> Remediation:
  152. """
  153. This validator will suggest to add a common suffix to the prompt if one doesn't already exist in case of classification or conditional generation.
  154. """
  155. error_msg = None
  156. immediate_msg = None
  157. optional_msg = None
  158. optional_fn = None # type: ignore
  159. # Find a suffix which is not contained within the prompt otherwise
  160. suggested_suffix = "\n\n### =>\n\n"
  161. suffix_options = [
  162. " ->",
  163. "\n\n###\n\n",
  164. "\n\n===\n\n",
  165. "\n\n---\n\n",
  166. "\n\n===>\n\n",
  167. "\n\n--->\n\n",
  168. ]
  169. for suffix_option in suffix_options:
  170. if suffix_option == " ->":
  171. if df.prompt.str.contains("\n").any():
  172. continue
  173. if df.prompt.str.contains(suffix_option, regex=False).any():
  174. continue
  175. suggested_suffix = suffix_option
  176. break
  177. display_suggested_suffix = suggested_suffix.replace("\n", "\\n")
  178. ft_type = infer_task_type(df)
  179. if ft_type == "open-ended generation":
  180. return Remediation(name="common_suffix")
  181. def add_suffix(x: Any, suffix: Any) -> Any:
  182. x["prompt"] += suffix
  183. return x
  184. common_suffix = get_common_xfix(df.prompt, xfix="suffix")
  185. if (df.prompt == common_suffix).all():
  186. error_msg = f"All prompts are identical: `{common_suffix}`\nConsider leaving the prompts blank if you want to do open-ended generation, otherwise ensure prompts are different"
  187. return Remediation(name="common_suffix", error_msg=error_msg)
  188. if common_suffix != "":
  189. common_suffix_new_line_handled = common_suffix.replace("\n", "\\n")
  190. immediate_msg = f"\n- All prompts end with suffix `{common_suffix_new_line_handled}`"
  191. if len(common_suffix) > 10:
  192. immediate_msg += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`"
  193. if df.prompt.str[: -len(common_suffix)].str.contains(common_suffix, regex=False).any():
  194. immediate_msg += f"\n WARNING: Some of your prompts contain the suffix `{common_suffix}` more than once. We strongly suggest that you review your prompts and add a unique suffix"
  195. else:
  196. immediate_msg = "\n- Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples. If you intend to do open-ended generation, then you should leave the prompts empty"
  197. if common_suffix == "":
  198. optional_msg = f"Add a suffix separator `{display_suggested_suffix}` to all prompts"
  199. def optional_fn(x: Any) -> Any:
  200. return add_suffix(x, suggested_suffix)
  201. return Remediation(
  202. name="common_completion_suffix",
  203. immediate_msg=immediate_msg,
  204. optional_msg=optional_msg,
  205. optional_fn=optional_fn,
  206. error_msg=error_msg,
  207. )
  208. def common_prompt_prefix_validator(df: pd.DataFrame) -> Remediation:
  209. """
  210. This validator will suggest to remove a common prefix from the prompt if a long one exist.
  211. """
  212. MAX_PREFIX_LEN = 12
  213. immediate_msg = None
  214. optional_msg = None
  215. optional_fn = None # type: ignore
  216. common_prefix = get_common_xfix(df.prompt, xfix="prefix")
  217. if common_prefix == "":
  218. return Remediation(name="common_prefix")
  219. def remove_common_prefix(x: Any, prefix: Any) -> Any:
  220. x["prompt"] = x["prompt"].str[len(prefix) :]
  221. return x
  222. if (df.prompt == common_prefix).all():
  223. # already handled by common_suffix_validator
  224. return Remediation(name="common_prefix")
  225. if common_prefix != "":
  226. immediate_msg = f"\n- All prompts start with prefix `{common_prefix}`"
  227. if MAX_PREFIX_LEN < len(common_prefix):
  228. immediate_msg += ". Fine-tuning doesn't require the instruction specifying the task, or a few-shot example scenario. Most of the time you should only add the input data into the prompt, and the desired output into the completion"
  229. optional_msg = f"Remove prefix `{common_prefix}` from all prompts"
  230. def optional_fn(x: Any) -> Any:
  231. return remove_common_prefix(x, common_prefix)
  232. return Remediation(
  233. name="common_prompt_prefix",
  234. immediate_msg=immediate_msg,
  235. optional_msg=optional_msg,
  236. optional_fn=optional_fn,
  237. )
  238. def common_completion_prefix_validator(df: pd.DataFrame) -> Remediation:
  239. """
  240. This validator will suggest to remove a common prefix from the completion if a long one exist.
  241. """
  242. MAX_PREFIX_LEN = 5
  243. common_prefix = get_common_xfix(df.completion, xfix="prefix")
  244. ws_prefix = len(common_prefix) > 0 and common_prefix[0] == " "
  245. if len(common_prefix) < MAX_PREFIX_LEN:
  246. return Remediation(name="common_prefix")
  247. def remove_common_prefix(x: Any, prefix: Any, ws_prefix: Any) -> Any:
  248. x["completion"] = x["completion"].str[len(prefix) :]
  249. if ws_prefix:
  250. # keep the single whitespace as prefix
  251. x["completion"] = f" {x['completion']}"
  252. return x
  253. if (df.completion == common_prefix).all():
  254. # already handled by common_suffix_validator
  255. return Remediation(name="common_prefix")
  256. immediate_msg = f"\n- All completions start with prefix `{common_prefix}`. Most of the time you should only add the output data into the completion, without any prefix"
  257. optional_msg = f"Remove prefix `{common_prefix}` from all completions"
  258. def optional_fn(x: Any) -> Any:
  259. return remove_common_prefix(x, common_prefix, ws_prefix)
  260. return Remediation(
  261. name="common_completion_prefix",
  262. immediate_msg=immediate_msg,
  263. optional_msg=optional_msg,
  264. optional_fn=optional_fn,
  265. )
  266. def common_completion_suffix_validator(df: pd.DataFrame) -> Remediation:
  267. """
  268. This validator will suggest to add a common suffix to the completion if one doesn't already exist in case of classification or conditional generation.
  269. """
  270. error_msg = None
  271. immediate_msg = None
  272. optional_msg = None
  273. optional_fn = None # type: ignore
  274. ft_type = infer_task_type(df)
  275. if ft_type == "open-ended generation" or ft_type == "classification":
  276. return Remediation(name="common_suffix")
  277. common_suffix = get_common_xfix(df.completion, xfix="suffix")
  278. if (df.completion == common_suffix).all():
  279. error_msg = f"All completions are identical: `{common_suffix}`\nEnsure completions are different, otherwise the model will just repeat `{common_suffix}`"
  280. return Remediation(name="common_suffix", error_msg=error_msg)
  281. # Find a suffix which is not contained within the completion otherwise
  282. suggested_suffix = " [END]"
  283. suffix_options = [
  284. "\n",
  285. ".",
  286. " END",
  287. "***",
  288. "+++",
  289. "&&&",
  290. "$$$",
  291. "@@@",
  292. "%%%",
  293. ]
  294. for suffix_option in suffix_options:
  295. if df.completion.str.contains(suffix_option, regex=False).any():
  296. continue
  297. suggested_suffix = suffix_option
  298. break
  299. display_suggested_suffix = suggested_suffix.replace("\n", "\\n")
  300. def add_suffix(x: Any, suffix: Any) -> Any:
  301. x["completion"] += suffix
  302. return x
  303. if common_suffix != "":
  304. common_suffix_new_line_handled = common_suffix.replace("\n", "\\n")
  305. immediate_msg = f"\n- All completions end with suffix `{common_suffix_new_line_handled}`"
  306. if len(common_suffix) > 10:
  307. immediate_msg += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`"
  308. if df.completion.str[: -len(common_suffix)].str.contains(common_suffix, regex=False).any():
  309. immediate_msg += f"\n WARNING: Some of your completions contain the suffix `{common_suffix}` more than once. We suggest that you review your completions and add a unique ending"
  310. else:
  311. immediate_msg = "\n- Your data does not contain a common ending at the end of your completions. Having a common ending string appended to the end of the completion makes it clearer to the fine-tuned model where the completion should end. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples."
  312. if common_suffix == "":
  313. optional_msg = f"Add a suffix ending `{display_suggested_suffix}` to all completions"
  314. def optional_fn(x: Any) -> Any:
  315. return add_suffix(x, suggested_suffix)
  316. return Remediation(
  317. name="common_completion_suffix",
  318. immediate_msg=immediate_msg,
  319. optional_msg=optional_msg,
  320. optional_fn=optional_fn,
  321. error_msg=error_msg,
  322. )
  323. def completions_space_start_validator(df: pd.DataFrame) -> Remediation:
  324. """
  325. This validator will suggest to add a space at the start of the completion if it doesn't already exist. This helps with tokenization.
  326. """
  327. def add_space_start(x: Any) -> Any:
  328. x["completion"] = x["completion"].apply(lambda s: ("" if s.startswith(" ") else " ") + s)
  329. return x
  330. optional_msg = None
  331. optional_fn = None
  332. immediate_msg = None
  333. if df.completion.str[:1].nunique() != 1 or df.completion.values[0][0] != " ":
  334. immediate_msg = "\n- The completion should start with a whitespace character (` `). This tends to produce better results due to the tokenization we use. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details"
  335. optional_msg = "Add a whitespace character to the beginning of the completion"
  336. optional_fn = add_space_start
  337. return Remediation(
  338. name="completion_space_start",
  339. immediate_msg=immediate_msg,
  340. optional_msg=optional_msg,
  341. optional_fn=optional_fn,
  342. )
  343. def lower_case_validator(df: pd.DataFrame, column: Any) -> Remediation | None:
  344. """
  345. This validator will suggest to lowercase the column values, if more than a third of letters are uppercase.
  346. """
  347. def lower_case(x: Any) -> Any:
  348. x[column] = x[column].str.lower()
  349. return x
  350. count_upper = df[column].apply(lambda x: sum(1 for c in x if c.isalpha() and c.isupper())).sum()
  351. count_lower = df[column].apply(lambda x: sum(1 for c in x if c.isalpha() and c.islower())).sum()
  352. if count_upper * 2 > count_lower:
  353. return Remediation(
  354. name="lower_case",
  355. immediate_msg=f"\n- More than a third of your `{column}` column/key is uppercase. Uppercase {column}s tends to perform worse than a mixture of case encountered in normal language. We recommend to lower case the data if that makes sense in your domain. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details",
  356. optional_msg=f"Lowercase all your data in column/key `{column}`",
  357. optional_fn=lower_case,
  358. )
  359. return None
  360. def read_any_format(
  361. fname: str, fields: list[str] = ["prompt", "completion"]
  362. ) -> tuple[pd.DataFrame | None, Remediation]:
  363. """
  364. This function will read a file saved in .csv, .json, .txt, .xlsx or .tsv format using pandas.
  365. - for .xlsx it will read the first sheet
  366. - for .txt it will assume completions and split on newline
  367. """
  368. remediation = None
  369. necessary_msg = None
  370. immediate_msg = None
  371. error_msg = None
  372. df = None
  373. if os.path.isfile(fname):
  374. try:
  375. if fname.lower().endswith(".csv") or fname.lower().endswith(".tsv"):
  376. file_extension_str, separator = ("CSV", ",") if fname.lower().endswith(".csv") else ("TSV", "\t")
  377. immediate_msg = (
  378. f"\n- Based on your file extension, your file is formatted as a {file_extension_str} file"
  379. )
  380. necessary_msg = f"Your format `{file_extension_str}` will be converted to `JSONL`"
  381. df = pd.read_csv(fname, sep=separator, dtype=str).fillna("")
  382. elif fname.lower().endswith(".xlsx"):
  383. immediate_msg = "\n- Based on your file extension, your file is formatted as an Excel file"
  384. necessary_msg = "Your format `XLSX` will be converted to `JSONL`"
  385. xls = pd.ExcelFile(fname)
  386. sheets = xls.sheet_names
  387. if len(sheets) > 1:
  388. immediate_msg += "\n- Your Excel file contains more than one sheet. Please either save as csv or ensure all data is present in the first sheet. WARNING: Reading only the first sheet..."
  389. df = pd.read_excel(fname, dtype=str).fillna("")
  390. elif fname.lower().endswith(".txt"):
  391. immediate_msg = "\n- Based on your file extension, you provided a text file"
  392. necessary_msg = "Your format `TXT` will be converted to `JSONL`"
  393. with open(fname, "r") as f:
  394. content = f.read()
  395. df = pd.DataFrame(
  396. [["", line] for line in content.split("\n")],
  397. columns=fields,
  398. dtype=str,
  399. ).fillna("")
  400. elif fname.lower().endswith(".jsonl"):
  401. df = pd.read_json(fname, lines=True, dtype=str).fillna("") # type: ignore
  402. if len(df) == 1: # type: ignore
  403. # this is NOT what we expect for a .jsonl file
  404. immediate_msg = "\n- Your JSONL file appears to be in a JSON format. Your file will be converted to JSONL format"
  405. necessary_msg = "Your format `JSON` will be converted to `JSONL`"
  406. df = pd.read_json(fname, dtype=str).fillna("") # type: ignore
  407. else:
  408. pass # this is what we expect for a .jsonl file
  409. elif fname.lower().endswith(".json"):
  410. try:
  411. # to handle case where .json file is actually a .jsonl file
  412. df = pd.read_json(fname, lines=True, dtype=str).fillna("") # type: ignore
  413. if len(df) == 1: # type: ignore
  414. # this code path corresponds to a .json file that has one line
  415. df = pd.read_json(fname, dtype=str).fillna("") # type: ignore
  416. else:
  417. # this is NOT what we expect for a .json file
  418. immediate_msg = "\n- Your JSON file appears to be in a JSONL format. Your file will be converted to JSONL format"
  419. necessary_msg = "Your format `JSON` will be converted to `JSONL`"
  420. except ValueError:
  421. # this code path corresponds to a .json file that has multiple lines (i.e. it is indented)
  422. df = pd.read_json(fname, dtype=str).fillna("") # type: ignore
  423. else:
  424. error_msg = (
  425. "Your file must have one of the following extensions: .CSV, .TSV, .XLSX, .TXT, .JSON or .JSONL"
  426. )
  427. if "." in fname:
  428. error_msg += f" Your file `{fname}` ends with the extension `.{fname.split('.')[-1]}` which is not supported."
  429. else:
  430. error_msg += f" Your file `{fname}` is missing a file extension."
  431. except (ValueError, TypeError):
  432. file_extension_str = fname.split(".")[-1].upper()
  433. error_msg = f"Your file `{fname}` does not appear to be in valid {file_extension_str} format. Please ensure your file is formatted as a valid {file_extension_str} file."
  434. else:
  435. error_msg = f"File {fname} does not exist."
  436. remediation = Remediation(
  437. name="read_any_format",
  438. necessary_msg=necessary_msg,
  439. immediate_msg=immediate_msg,
  440. error_msg=error_msg,
  441. )
  442. return df, remediation
  443. def format_inferrer_validator(df: pd.DataFrame) -> Remediation:
  444. """
  445. This validator will infer the likely fine-tuning format of the data, and display it to the user if it is classification.
  446. It will also suggest to use ada and explain train/validation split benefits.
  447. """
  448. ft_type = infer_task_type(df)
  449. immediate_msg = None
  450. if ft_type == "classification":
  451. immediate_msg = f"\n- Based on your data it seems like you're trying to fine-tune a model for {ft_type}\n- For classification, we recommend you try one of the faster and cheaper models, such as `ada`\n- For classification, you can estimate the expected model performance by keeping a held out dataset, which is not used for training"
  452. return Remediation(name="num_examples", immediate_msg=immediate_msg)
  453. def apply_necessary_remediation(df: OptionalDataFrameT, remediation: Remediation) -> OptionalDataFrameT:
  454. """
  455. This function will apply a necessary remediation to a dataframe, or print an error message if one exists.
  456. """
  457. if remediation.error_msg is not None:
  458. sys.stderr.write(f"\n\nERROR in {remediation.name} validator: {remediation.error_msg}\n\nAborting...")
  459. sys.exit(1)
  460. if remediation.immediate_msg is not None:
  461. sys.stdout.write(remediation.immediate_msg)
  462. if remediation.necessary_fn is not None:
  463. df = remediation.necessary_fn(df)
  464. return df
  465. def accept_suggestion(input_text: str, auto_accept: bool) -> bool:
  466. sys.stdout.write(input_text)
  467. if auto_accept:
  468. sys.stdout.write("Y\n")
  469. return True
  470. return input().lower() != "n"
  471. def apply_optional_remediation(
  472. df: pd.DataFrame, remediation: Remediation, auto_accept: bool
  473. ) -> tuple[pd.DataFrame, bool]:
  474. """
  475. This function will apply an optional remediation to a dataframe, based on the user input.
  476. """
  477. optional_applied = False
  478. input_text = f"- [Recommended] {remediation.optional_msg} [Y/n]: "
  479. if remediation.optional_msg is not None:
  480. if accept_suggestion(input_text, auto_accept):
  481. assert remediation.optional_fn is not None
  482. df = remediation.optional_fn(df)
  483. optional_applied = True
  484. if remediation.necessary_msg is not None:
  485. sys.stdout.write(f"- [Necessary] {remediation.necessary_msg}\n")
  486. return df, optional_applied
  487. def estimate_fine_tuning_time(df: pd.DataFrame) -> None:
  488. """
  489. Estimate the time it'll take to fine-tune the dataset
  490. """
  491. ft_format = infer_task_type(df)
  492. expected_time = 1.0
  493. if ft_format == "classification":
  494. num_examples = len(df)
  495. expected_time = num_examples * 1.44
  496. else:
  497. size = df.memory_usage(index=True).sum()
  498. expected_time = size * 0.0515
  499. def format_time(time: float) -> str:
  500. if time < 60:
  501. return f"{round(time, 2)} seconds"
  502. elif time < 3600:
  503. return f"{round(time / 60, 2)} minutes"
  504. elif time < 86400:
  505. return f"{round(time / 3600, 2)} hours"
  506. else:
  507. return f"{round(time / 86400, 2)} days"
  508. time_string = format_time(expected_time + 140)
  509. sys.stdout.write(
  510. f"Once your model starts training, it'll approximately take {time_string} to train a `curie` model, and less for `ada` and `babbage`. Queue will approximately take half an hour per job ahead of you.\n"
  511. )
  512. def get_outfnames(fname: str, split: bool) -> list[str]:
  513. suffixes = ["_train", "_valid"] if split else [""]
  514. i = 0
  515. while True:
  516. index_suffix = f" ({i})" if i > 0 else ""
  517. candidate_fnames = [f"{os.path.splitext(fname)[0]}_prepared{suffix}{index_suffix}.jsonl" for suffix in suffixes]
  518. if not any(os.path.isfile(f) for f in candidate_fnames):
  519. return candidate_fnames
  520. i += 1
  521. def get_classification_hyperparams(df: pd.DataFrame) -> tuple[int, object]:
  522. n_classes = df.completion.nunique()
  523. pos_class = None
  524. if n_classes == 2:
  525. pos_class = df.completion.value_counts().index[0]
  526. return n_classes, pos_class
def write_out_file(df: pd.DataFrame, fname: str, any_remediations: bool, auto_accept: bool) -> None:
    """
    Write the prepared data to new JSONL file(s), if the user agrees, and print the suggested fine-tuning command.

    For classification data the user is first asked whether to split the data into
    train/valid files. When splitting, the suggested command gets `-v <valid file>` plus
    classification-metric flags.

    If there were no remediations and no split, nothing is written. The command for the
    original `fname` is printed instead.

    Args:
        df: the validated data; its `prompt` and `completion` columns are written.
        fname: path of the user's input file; output names are derived from it.
        any_remediations: whether any remediation changed or would change the data.
        auto_accept: answer "yes" to every prompt without reading stdin.
    """
    ft_format = infer_task_type(df)
    common_prompt_suffix = get_common_xfix(df.prompt, xfix="suffix")
    common_completion_suffix = get_common_xfix(df.completion, xfix="suffix")

    # Splitting into train/valid is only offered for classification data.
    split = False
    input_text = "- [Recommended] Would you like to split into training and validation set? [Y/n]: "
    if ft_format == "classification":
        if accept_suggestion(input_text, auto_accept):
            split = True

    additional_params = ""
    # Escape newlines so the separators print on one line in the suggested command.
    common_prompt_suffix_new_line_handled = common_prompt_suffix.replace("\n", "\\n")
    common_completion_suffix_new_line_handled = common_completion_suffix.replace("\n", "\\n")
    optional_ending_string = (
        f' Make sure to include `stop=["{common_completion_suffix_new_line_handled}"]` so that the generated texts ends at the expected place.'
        if len(common_completion_suffix_new_line_handled) > 0
        else ""
    )

    input_text = "\n\nYour data will be written to a new JSONL file. Proceed [Y/n]: "

    if not any_remediations and not split:
        # Nothing changed: the user's original file can be used as-is.
        sys.stdout.write(
            f'\nYou can use your file for fine-tuning:\n> openai api fine_tunes.create -t "{fname}"{additional_params}\n\nAfter you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt.{optional_ending_string}\n'
        )
        estimate_fine_tuning_time(df)

    elif accept_suggestion(input_text, auto_accept):
        fnames = get_outfnames(fname, split)
        if split:
            assert len(fnames) == 2 and "train" in fnames[0] and "valid" in fnames[1]
            # Validation set: rows not sampled for training, i.e. min(MAX_VALID_EXAMPLES,
            # len(df) - int(0.8 * len(df))) rows.
            MAX_VALID_EXAMPLES = 1000
            n_train = max(len(df) - MAX_VALID_EXAMPLES, int(len(df) * 0.8))
            # Fixed random_state makes the split reproducible across runs.
            df_train = df.sample(n=n_train, random_state=42)
            # NOTE(review): drops by index label; assumes df's index labels are unique.
            df_valid = df.drop(df_train.index)
            df_train[["prompt", "completion"]].to_json(  # type: ignore
                fnames[0], lines=True, orient="records", force_ascii=False, indent=None
            )
            df_valid[["prompt", "completion"]].to_json(
                fnames[1], lines=True, orient="records", force_ascii=False, indent=None
            )

            # Binary classification names the positive class (most frequent completion);
            # otherwise the number of classes is passed.
            n_classes, pos_class = get_classification_hyperparams(df)
            additional_params += " --compute_classification_metrics"
            if n_classes == 2:
                additional_params += f' --classification_positive_class "{pos_class}"'
            else:
                additional_params += f" --classification_n_classes {n_classes}"
        else:
            assert len(fnames) == 1
            df[["prompt", "completion"]].to_json(
                fnames[0], lines=True, orient="records", force_ascii=False, indent=None
            )

        # Add -v VALID_FILE if we split the file into train / valid
        files_string = ("s" if split else "") + " to `" + ("` and `".join(fnames))
        valid_string = f' -v "{fnames[1]}"' if split else ""
        separator_reminder = (
            ""
            if len(common_prompt_suffix_new_line_handled) == 0
            else f"After you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt."
        )
        sys.stdout.write(
            f'\nWrote modified file{files_string}`\nFeel free to take a look!\n\nNow use that file when fine-tuning:\n> openai api fine_tunes.create -t "{fnames[0]}"{valid_string}{additional_params}\n\n{separator_reminder}{optional_ending_string}\n'
        )
        estimate_fine_tuning_time(df)
    else:
        # User declined to write the prepared file.
        sys.stdout.write("Aborting... did not write the file\n")
  593. def infer_task_type(df: pd.DataFrame) -> str:
  594. """
  595. Infer the likely fine-tuning task type from the data
  596. """
  597. CLASSIFICATION_THRESHOLD = 3 # min_average instances of each class
  598. if sum(df.prompt.str.len()) == 0:
  599. return "open-ended generation"
  600. if len(df.completion.unique()) < len(df) / CLASSIFICATION_THRESHOLD:
  601. return "classification"
  602. return "conditional generation"
  603. def get_common_xfix(series: Any, xfix: str = "suffix") -> str:
  604. """
  605. Finds the longest common suffix or prefix of all the values in a series
  606. """
  607. common_xfix = ""
  608. while True:
  609. common_xfixes = (
  610. series.str[-(len(common_xfix) + 1) :] if xfix == "suffix" else series.str[: len(common_xfix) + 1]
  611. ) # first few or last few characters
  612. if common_xfixes.nunique() != 1: # we found the character at which we don't have a unique xfix anymore
  613. break
  614. elif common_xfix == common_xfixes.values[0]: # the entire first row is a prefix of every other row
  615. break
  616. else: # the first or last few characters are still common across all rows - let's try to add one more
  617. common_xfix = common_xfixes.values[0]
  618. return common_xfix
  619. Validator: TypeAlias = "Callable[[pd.DataFrame], Remediation | None]"
  620. def get_validators() -> list[Validator]:
  621. return [
  622. num_examples_validator,
  623. lambda x: necessary_column_validator(x, "prompt"),
  624. lambda x: necessary_column_validator(x, "completion"),
  625. additional_column_validator,
  626. non_empty_field_validator,
  627. format_inferrer_validator,
  628. duplicated_rows_validator,
  629. long_examples_validator,
  630. lambda x: lower_case_validator(x, "prompt"),
  631. lambda x: lower_case_validator(x, "completion"),
  632. common_prompt_suffix_validator,
  633. common_prompt_prefix_validator,
  634. common_completion_prefix_validator,
  635. common_completion_suffix_validator,
  636. completions_space_start_validator,
  637. ]
  638. def apply_validators(
  639. df: pd.DataFrame,
  640. fname: str,
  641. remediation: Remediation | None,
  642. validators: list[Validator],
  643. auto_accept: bool,
  644. write_out_file_func: Callable[..., Any],
  645. ) -> None:
  646. optional_remediations: list[Remediation] = []
  647. if remediation is not None:
  648. optional_remediations.append(remediation)
  649. for validator in validators:
  650. remediation = validator(df)
  651. if remediation is not None:
  652. optional_remediations.append(remediation)
  653. df = apply_necessary_remediation(df, remediation)
  654. any_optional_or_necessary_remediations = any(
  655. [
  656. remediation
  657. for remediation in optional_remediations
  658. if remediation.optional_msg is not None or remediation.necessary_msg is not None
  659. ]
  660. )
  661. any_necessary_applied = any(
  662. [remediation for remediation in optional_remediations if remediation.necessary_msg is not None]
  663. )
  664. any_optional_applied = False
  665. if any_optional_or_necessary_remediations:
  666. sys.stdout.write("\n\nBased on the analysis we will perform the following actions:\n")
  667. for remediation in optional_remediations:
  668. df, optional_applied = apply_optional_remediation(df, remediation, auto_accept)
  669. any_optional_applied = any_optional_applied or optional_applied
  670. else:
  671. sys.stdout.write("\n\nNo remediations found.\n")
  672. any_optional_or_necessary_applied = any_optional_applied or any_necessary_applied
  673. write_out_file_func(df, fname, any_optional_or_necessary_applied, auto_accept)