integration_utils.py 110 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Integrations with other Python libraries.
  16. """
  17. import copy
  18. import functools
  19. import importlib.metadata
  20. import importlib.util
  21. import json
  22. import numbers
  23. import os
  24. import pickle
  25. import re
  26. import shutil
  27. import sys
  28. import tempfile
  29. from dataclasses import asdict, fields
  30. from enum import Enum
  31. from pathlib import Path
  32. from typing import TYPE_CHECKING, Any, Literal, Optional, Union
  33. import numpy as np
  34. import packaging.version
  35. from transformers.utils.import_utils import _is_package_available
  36. if os.getenv("WANDB_MODE") == "offline":
  37. print("⚙️ Running in WANDB offline mode")
  38. from .. import PreTrainedModel, TrainingArguments
  39. from .. import __version__ as version
  40. from ..utils import (
  41. PushToHubMixin,
  42. flatten_dict,
  43. is_datasets_available,
  44. is_pandas_available,
  45. is_tf_available,
  46. is_torch_available,
  47. logging,
  48. )
# Module-level logger from the project's `logging` utilities, shared by the helpers in this module.
logger = logging.get_logger(__name__)
  50. if is_tf_available():
  51. from .. import TFPreTrainedModel
  52. if is_torch_available():
  53. import torch
  54. import torch.distributed as dist
  55. # comet_ml requires to be imported before any ML frameworks
  56. _MIN_COMET_VERSION = "3.43.2"
  57. try:
  58. _comet_version = importlib.metadata.version("comet_ml")
  59. _is_comet_installed = True
  60. _is_comet_recent_enough = packaging.version.parse(_comet_version) >= packaging.version.parse(_MIN_COMET_VERSION)
  61. # Check if the Comet API Key is set
  62. import comet_ml
  63. if comet_ml.config.get_config("comet.api_key") is not None:
  64. _is_comet_configured = True
  65. else:
  66. _is_comet_configured = False
  67. except (importlib.metadata.PackageNotFoundError, ImportError, ValueError, TypeError, AttributeError, KeyError):
  68. _comet_version = None
  69. _is_comet_installed = False
  70. _is_comet_recent_enough = False
  71. _is_comet_configured = False
  72. _has_neptune = (
  73. importlib.util.find_spec("neptune") is not None or importlib.util.find_spec("neptune-client") is not None
  74. )
  75. if TYPE_CHECKING and _has_neptune:
  76. try:
  77. _neptune_version = importlib.metadata.version("neptune")
  78. logger.info(f"Neptune version {_neptune_version} available.")
  79. except importlib.metadata.PackageNotFoundError:
  80. try:
  81. _neptune_version = importlib.metadata.version("neptune-client")
  82. logger.info(f"Neptune-client version {_neptune_version} available.")
  83. except importlib.metadata.PackageNotFoundError:
  84. _has_neptune = False
  85. from .. import modelcard # noqa: E402
  86. from ..trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
  87. from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
  88. from ..training_args import ParallelMode # noqa: E402
  89. from ..utils import ENV_VARS_TRUE_VALUES, is_torch_xla_available # noqa: E402
  90. # Integration functions:
  91. def is_wandb_available():
  92. # any value of WANDB_DISABLED disables wandb
  93. if os.getenv("WANDB_DISABLED", "").upper() in ENV_VARS_TRUE_VALUES:
  94. logger.warning(
  95. "Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the "
  96. "--report_to flag to control the integrations used for logging result (for instance --report_to none)."
  97. )
  98. return False
  99. if importlib.util.find_spec("wandb") is not None:
  100. import wandb
  101. # wandb might still be detected by find_spec after an uninstall (leftover files or metadata), but not actually
  102. # import correctly. To confirm it's fully installed and usable, we check for a key attribute like "run".
  103. return hasattr(wandb, "run")
  104. else:
  105. return False
  106. def is_trackio_available():
  107. return importlib.util.find_spec("trackio") is not None
  108. def is_clearml_available():
  109. return importlib.util.find_spec("clearml") is not None
  110. def is_comet_available():
  111. if os.getenv("COMET_MODE", "").upper() == "DISABLED":
  112. logger.warning(
  113. "Using the `COMET_MODE=DISABLED` environment variable is deprecated and will be removed in v5. Use the "
  114. "--report_to flag to control the integrations used for logging result (for instance --report_to none)."
  115. )
  116. return False
  117. if _is_comet_installed is False:
  118. return False
  119. if _is_comet_recent_enough is False:
  120. logger.warning(
  121. "comet_ml version %s is installed, but version %s or higher is required. "
  122. "Please update comet_ml to the latest version to enable Comet logging with pip install 'comet-ml>=%s'.",
  123. _comet_version,
  124. _MIN_COMET_VERSION,
  125. _MIN_COMET_VERSION,
  126. )
  127. return False
  128. if _is_comet_configured is False:
  129. logger.warning(
  130. "comet_ml is installed but the Comet API Key is not configured. "
  131. "Please set the `COMET_API_KEY` environment variable to enable Comet logging. "
  132. "Check out the documentation for other ways of configuring it: "
  133. "https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#set-the-api-key"
  134. )
  135. return False
  136. return True
  137. def is_tensorboard_available():
  138. return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None
  139. def is_optuna_available():
  140. return importlib.util.find_spec("optuna") is not None
  141. def is_ray_available():
  142. return importlib.util.find_spec("ray") is not None
  143. def is_ray_tune_available():
  144. if not is_ray_available():
  145. return False
  146. return importlib.util.find_spec("ray.tune") is not None
  147. def is_sigopt_available():
  148. return importlib.util.find_spec("sigopt") is not None
  149. def is_azureml_available():
  150. if importlib.util.find_spec("azureml") is None:
  151. return False
  152. if importlib.util.find_spec("azureml.core") is None:
  153. return False
  154. return importlib.util.find_spec("azureml.core.run") is not None
  155. def is_mlflow_available():
  156. if os.getenv("DISABLE_MLFLOW_INTEGRATION", "FALSE").upper() == "TRUE":
  157. return False
  158. return importlib.util.find_spec("mlflow") is not None
  159. def is_dagshub_available():
  160. return None not in [importlib.util.find_spec("dagshub"), importlib.util.find_spec("mlflow")]
  161. def is_neptune_available():
  162. return _has_neptune
  163. def is_codecarbon_available():
  164. return importlib.util.find_spec("codecarbon") is not None
  165. def is_flytekit_available():
  166. return importlib.util.find_spec("flytekit") is not None
  167. def is_flyte_deck_standard_available():
  168. if not is_flytekit_available():
  169. return False
  170. return importlib.util.find_spec("flytekitplugins.deck") is not None
  171. def is_dvclive_available():
  172. return importlib.util.find_spec("dvclive") is not None
  173. def is_swanlab_available():
  174. return importlib.util.find_spec("swanlab") is not None
  175. def hp_params(trial):
  176. if is_optuna_available():
  177. import optuna
  178. if isinstance(trial, optuna.trial.BaseTrial):
  179. return trial.params
  180. if is_ray_tune_available():
  181. if isinstance(trial, dict):
  182. return trial
  183. if is_sigopt_available():
  184. if isinstance(trial, dict):
  185. return trial
  186. if is_wandb_available():
  187. if isinstance(trial, dict):
  188. return trial
  189. raise RuntimeError(f"Unknown type for trial {trial.__class__}")
  190. def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
  191. import optuna
  192. from accelerate.utils.memory import release_memory
  193. if trainer.args.process_index == 0:
  194. def _objective(trial: optuna.Trial, checkpoint_dir=None):
  195. checkpoint = None
  196. if checkpoint_dir:
  197. for subdir in os.listdir(checkpoint_dir):
  198. if subdir.startswith(PREFIX_CHECKPOINT_DIR):
  199. checkpoint = os.path.join(checkpoint_dir, subdir)
  200. trainer.objective = None
  201. if trainer.args.world_size > 1:
  202. if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
  203. raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
  204. trainer.hp_space(trial)
  205. fixed_trial = optuna.trial.FixedTrial(trial.params, trial.number)
  206. trial_main_rank_list = [fixed_trial]
  207. torch.distributed.broadcast_object_list(trial_main_rank_list, src=0)
  208. trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
  209. else:
  210. trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
  211. # If there hasn't been any evaluation during the training loop.
  212. if getattr(trainer, "objective", None) is None:
  213. metrics = trainer.evaluate()
  214. trainer.objective = trainer.compute_objective(metrics)
  215. # Free GPU memory
  216. trainer.model_wrapped, trainer.model = release_memory(trainer.model_wrapped, trainer.model)
  217. trainer.accelerator.clear()
  218. return trainer.objective
  219. timeout = kwargs.pop("timeout", None)
  220. n_jobs = kwargs.pop("n_jobs", 1)
  221. gc_after_trial = kwargs.pop("gc_after_trial", False)
  222. directions = direction if isinstance(direction, list) else None
  223. direction = None if directions is not None else direction
  224. study = optuna.create_study(direction=direction, directions=directions, **kwargs)
  225. study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs, gc_after_trial=gc_after_trial)
  226. if not study._is_multi_objective():
  227. best_trial = study.best_trial
  228. return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
  229. else:
  230. best_trials = study.best_trials
  231. return [BestRun(str(best.number), best.values, best.params) for best in best_trials]
  232. else:
  233. for i in range(n_trials):
  234. trainer.objective = None
  235. trial_main_rank_list = [None]
  236. if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
  237. raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
  238. torch.distributed.broadcast_object_list(trial_main_rank_list, src=0)
  239. trainer.train(resume_from_checkpoint=None, trial=trial_main_rank_list[0])
  240. # If there hasn't been any evaluation during the training loop.
  241. if getattr(trainer, "objective", None) is None:
  242. metrics = trainer.evaluate()
  243. trainer.objective = trainer.compute_objective(metrics)
  244. return None
  245. def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
  246. import ray
  247. import ray.train
  248. def _objective(trial: dict, local_trainer):
  249. try:
  250. from transformers.utils.notebook import NotebookProgressCallback
  251. if local_trainer.pop_callback(NotebookProgressCallback):
  252. local_trainer.add_callback(ProgressCallback)
  253. except ModuleNotFoundError:
  254. pass
  255. local_trainer.objective = None
  256. checkpoint = ray.train.get_checkpoint()
  257. if checkpoint:
  258. # Upon trial resume, the local_trainer's objective gets reset to None.
  259. # If `local_trainer.train` is a noop (training has already reached
  260. # the target number of epochs/steps), then this would
  261. # trigger an unnecessary extra checkpoint at the end of training.
  262. # -> Set the objective to a dummy value upon resume as a workaround.
  263. local_trainer.objective = "objective"
  264. with checkpoint.as_directory() as checkpoint_dir:
  265. checkpoint_path = next(Path(checkpoint_dir).glob(f"{PREFIX_CHECKPOINT_DIR}*")).as_posix()
  266. local_trainer.train(resume_from_checkpoint=checkpoint_path, trial=trial)
  267. else:
  268. local_trainer.train(trial=trial)
  269. # If there hasn't been any evaluation during the training loop.
  270. if getattr(local_trainer, "objective", None) is None:
  271. metrics = local_trainer.evaluate()
  272. local_trainer.objective = local_trainer.compute_objective(metrics)
  273. metrics.update({"objective": local_trainer.objective, "done": True})
  274. with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
  275. local_trainer._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
  276. checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
  277. ray.train.report(metrics, checkpoint=checkpoint)
  278. if not trainer._memory_tracker.skip_memory_metrics:
  279. from ..trainer_utils import TrainerMemoryTracker
  280. logger.warning(
  281. "Memory tracking for your Trainer is currently "
  282. "enabled. Automatically disabling the memory tracker "
  283. "since the memory tracker is not serializable."
  284. )
  285. trainer._memory_tracker = TrainerMemoryTracker(skip_memory_metrics=True)
  286. # The model and TensorBoard writer do not pickle so we have to remove them (if they exists)
  287. # while doing the ray hp search.
  288. _tb_writer = trainer.pop_callback(TensorBoardCallback)
  289. trainer.model = None
  290. # Setup default `resources_per_trial`.
  291. if "resources_per_trial" not in kwargs:
  292. # Default to 1 CPU and 1 GPU (if applicable) per trial.
  293. kwargs["resources_per_trial"] = {"cpu": 1}
  294. if trainer.args.n_gpu > 0:
  295. kwargs["resources_per_trial"]["gpu"] = 1
  296. resource_msg = "1 CPU" + (" and 1 GPU" if trainer.args.n_gpu > 0 else "")
  297. logger.info(
  298. "No `resources_per_trial` arg was passed into "
  299. "`hyperparameter_search`. Setting it to a default value "
  300. f"of {resource_msg} for each trial."
  301. )
  302. # Make sure each trainer only uses GPUs that were allocated per trial.
  303. gpus_per_trial = kwargs["resources_per_trial"].get("gpu", 0)
  304. trainer.args._n_gpu = gpus_per_trial
  305. # Setup default `progress_reporter`.
  306. if "progress_reporter" not in kwargs:
  307. from ray.tune import CLIReporter
  308. kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
  309. if "scheduler" in kwargs:
  310. from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining
  311. # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
  312. if isinstance(
  313. kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
  314. ) and (not trainer.args.do_eval or trainer.args.eval_strategy == IntervalStrategy.NO):
  315. raise RuntimeError(
  316. "You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
  317. "This means your trials will not report intermediate results to Ray Tune, and "
  318. "can thus not be stopped early or used to exploit other trials parameters. "
  319. "If this is what you want, do not use {cls}. If you would like to use {cls}, "
  320. "make sure you pass `do_eval=True` and `eval_strategy='steps'` in the "
  321. "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
  322. )
  323. trainable = ray.tune.with_parameters(_objective, local_trainer=trainer)
  324. @functools.wraps(trainable)
  325. def dynamic_modules_import_trainable(*args, **kwargs):
  326. """
  327. Wrapper around `tune.with_parameters` to ensure datasets_modules are loaded on each Actor.
  328. Without this, an ImportError will be thrown. See https://github.com/huggingface/transformers/issues/11565.
  329. Assumes that `_objective`, defined above, is a function.
  330. """
  331. if is_datasets_available():
  332. import datasets.load
  333. dynamic_modules_path = os.path.join(datasets.load.init_dynamic_modules(), "__init__.py")
  334. # load dynamic_modules from path
  335. spec = importlib.util.spec_from_file_location("datasets_modules", dynamic_modules_path)
  336. datasets_modules = importlib.util.module_from_spec(spec)
  337. sys.modules[spec.name] = datasets_modules
  338. spec.loader.exec_module(datasets_modules)
  339. return trainable(*args, **kwargs)
  340. # special attr set by tune.with_parameters
  341. if hasattr(trainable, "__mixins__"):
  342. dynamic_modules_import_trainable.__mixins__ = trainable.__mixins__
  343. analysis = ray.tune.run(
  344. dynamic_modules_import_trainable,
  345. config=trainer.hp_space(None),
  346. num_samples=n_trials,
  347. **kwargs,
  348. )
  349. best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3], scope=trainer.args.ray_scope)
  350. best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config, analysis)
  351. if _tb_writer is not None:
  352. trainer.add_callback(_tb_writer)
  353. return best_run
  354. def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
  355. import sigopt
  356. if trainer.args.process_index == 0:
  357. if importlib.metadata.version("sigopt") >= "8.0.0":
  358. sigopt.set_project("huggingface")
  359. experiment = sigopt.create_experiment(
  360. name="huggingface-tune",
  361. type="offline",
  362. parameters=trainer.hp_space(None),
  363. metrics=[{"name": "objective", "objective": direction, "strategy": "optimize"}],
  364. parallel_bandwidth=1,
  365. budget=n_trials,
  366. )
  367. logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")
  368. for run in experiment.loop():
  369. with run:
  370. trainer.objective = None
  371. if trainer.args.world_size > 1:
  372. if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
  373. raise RuntimeError("only support DDP Sigopt HPO for ParallelMode.DISTRIBUTED currently.")
  374. trainer._hp_search_setup(run.run)
  375. torch.distributed.broadcast_object_list(pickle.dumps(trainer.args), src=0)
  376. trainer.train(resume_from_checkpoint=None)
  377. else:
  378. trainer.train(resume_from_checkpoint=None, trial=run.run)
  379. # If there hasn't been any evaluation during the training loop.
  380. if getattr(trainer, "objective", None) is None:
  381. metrics = trainer.evaluate()
  382. trainer.objective = trainer.compute_objective(metrics)
  383. run.log_metric("objective", trainer.objective)
  384. best = list(experiment.get_best_runs())[0]
  385. best_run = BestRun(best.id, best.values["objective"].value, best.assignments)
  386. else:
  387. from sigopt import Connection
  388. conn = Connection()
  389. proxies = kwargs.pop("proxies", None)
  390. if proxies is not None:
  391. conn.set_proxies(proxies)
  392. experiment = conn.experiments().create(
  393. name="huggingface-tune",
  394. parameters=trainer.hp_space(None),
  395. metrics=[{"name": "objective", "objective": direction, "strategy": "optimize"}],
  396. parallel_bandwidth=1,
  397. observation_budget=n_trials,
  398. project="huggingface",
  399. )
  400. logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")
  401. while experiment.progress.observation_count < experiment.observation_budget:
  402. suggestion = conn.experiments(experiment.id).suggestions().create()
  403. trainer.objective = None
  404. if trainer.args.world_size > 1:
  405. if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
  406. raise RuntimeError("only support DDP Sigopt HPO for ParallelMode.DISTRIBUTED currently.")
  407. trainer._hp_search_setup(suggestion)
  408. torch.distributed.broadcast_object_list(pickle.dumps(trainer.args), src=0)
  409. trainer.train(resume_from_checkpoint=None)
  410. else:
  411. trainer.train(resume_from_checkpoint=None, trial=suggestion)
  412. # If there hasn't been any evaluation during the training loop.
  413. if getattr(trainer, "objective", None) is None:
  414. metrics = trainer.evaluate()
  415. trainer.objective = trainer.compute_objective(metrics)
  416. values = [{"name": "objective", "value": trainer.objective}]
  417. obs = conn.experiments(experiment.id).observations().create(suggestion=suggestion.id, values=values)
  418. logger.info(f"[suggestion_id, observation_id]: [{suggestion.id}, {obs.id}]")
  419. experiment = conn.experiments(experiment.id).fetch()
  420. best = list(conn.experiments(experiment.id).best_assignments().fetch().iterate_pages())[0]
  421. best_run = BestRun(best.id, best.value, best.assignments)
  422. return best_run
  423. else:
  424. for i in range(n_trials):
  425. trainer.objective = None
  426. args_main_rank = list(pickle.dumps(trainer.args))
  427. if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
  428. raise RuntimeError("only support DDP Sigopt HPO for ParallelMode.DISTRIBUTED currently.")
  429. torch.distributed.broadcast_object_list(args_main_rank, src=0)
  430. args = pickle.loads(bytes(args_main_rank))
  431. for key, value in asdict(args).items():
  432. if key != "local_rank":
  433. setattr(trainer.args, key, value)
  434. trainer.train(resume_from_checkpoint=None)
  435. # If there hasn't been any evaluation during the training loop.
  436. if getattr(trainer, "objective", None) is None:
  437. metrics = trainer.evaluate()
  438. trainer.objective = trainer.compute_objective(metrics)
  439. return None
  440. def run_hp_search_wandb(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
  441. if not is_wandb_available():
  442. raise ImportError("This function needs wandb installed: `pip install wandb`")
  443. import wandb
  444. # add WandbCallback if not already added in trainer callbacks
  445. reporting_to_wandb = False
  446. for callback in trainer.callback_handler.callbacks:
  447. if isinstance(callback, WandbCallback):
  448. reporting_to_wandb = True
  449. break
  450. if not reporting_to_wandb:
  451. trainer.add_callback(WandbCallback())
  452. trainer.args.report_to = ["wandb"]
  453. best_trial = {"run_id": None, "objective": None, "hyperparameters": None}
  454. sweep_id = kwargs.pop("sweep_id", None)
  455. project = kwargs.pop("project", None)
  456. name = kwargs.pop("name", None)
  457. entity = kwargs.pop("entity", None)
  458. metric = kwargs.pop("metric", "eval/loss")
  459. sweep_config = trainer.hp_space(None)
  460. sweep_config["metric"]["goal"] = direction
  461. sweep_config["metric"]["name"] = metric
  462. if name:
  463. sweep_config["name"] = name
  464. def _objective():
  465. run = wandb.run if wandb.run else wandb.init()
  466. trainer.state.trial_name = run.name
  467. run.config.update({"assignments": {}, "metric": metric})
  468. config = wandb.config
  469. trainer.objective = None
  470. trainer.train(resume_from_checkpoint=None, trial=vars(config)["_items"])
  471. # If there hasn't been any evaluation during the training loop.
  472. if getattr(trainer, "objective", None) is None:
  473. metrics = trainer.evaluate()
  474. trainer.objective = trainer.compute_objective(metrics)
  475. format_metrics = rewrite_logs(metrics)
  476. if metric not in format_metrics:
  477. logger.warning(
  478. f"Provided metric {metric} not found. This might result in unexpected sweeps charts. The available"
  479. f" metrics are {format_metrics.keys()}"
  480. )
  481. best_score = False
  482. if best_trial["run_id"] is not None:
  483. if direction == "minimize":
  484. best_score = trainer.objective < best_trial["objective"]
  485. elif direction == "maximize":
  486. best_score = trainer.objective > best_trial["objective"]
  487. if best_score or best_trial["run_id"] is None:
  488. best_trial["run_id"] = run.id
  489. best_trial["objective"] = trainer.objective
  490. best_trial["hyperparameters"] = dict(config)
  491. return trainer.objective
  492. if not sweep_id:
  493. sweep_id = wandb.sweep(sweep_config, project=project, entity=entity)
  494. else:
  495. import wandb.env
  496. if entity:
  497. wandb.env.set_entity(entity)
  498. wandb.env.set_project(project)
  499. logger.info(f"wandb sweep id - {sweep_id}")
  500. wandb.agent(sweep_id, function=_objective, count=n_trials)
  501. return BestRun(best_trial["run_id"], best_trial["objective"], best_trial["hyperparameters"], sweep_id)
  502. def get_available_reporting_integrations():
  503. integrations = []
  504. if is_azureml_available() and not is_mlflow_available():
  505. integrations.append("azure_ml")
  506. if is_comet_available():
  507. integrations.append("comet_ml")
  508. if is_dagshub_available():
  509. integrations.append("dagshub")
  510. if is_dvclive_available():
  511. integrations.append("dvclive")
  512. if is_mlflow_available():
  513. integrations.append("mlflow")
  514. if is_neptune_available():
  515. integrations.append("neptune")
  516. if is_tensorboard_available():
  517. integrations.append("tensorboard")
  518. if is_wandb_available():
  519. integrations.append("wandb")
  520. if is_codecarbon_available():
  521. integrations.append("codecarbon")
  522. if is_clearml_available():
  523. integrations.append("clearml")
  524. if is_swanlab_available():
  525. integrations.append("swanlab")
  526. if is_trackio_available():
  527. integrations.append("trackio")
  528. return integrations
  529. def rewrite_logs(d):
  530. new_d = {}
  531. eval_prefix = "eval_"
  532. eval_prefix_len = len(eval_prefix)
  533. test_prefix = "test_"
  534. test_prefix_len = len(test_prefix)
  535. for k, v in d.items():
  536. if k.startswith(eval_prefix):
  537. new_d["eval/" + k[eval_prefix_len:]] = v
  538. elif k.startswith(test_prefix):
  539. new_d["test/" + k[test_prefix_len:]] = v
  540. else:
  541. new_d["train/" + k] = v
  542. return new_d
  543. class TensorBoardCallback(TrainerCallback):
  544. """
  545. A [`TrainerCallback`] that sends the logs to [TensorBoard](https://www.tensorflow.org/tensorboard).
  546. Args:
  547. tb_writer (`SummaryWriter`, *optional*):
  548. The writer to use. Will instantiate one if not set.
  549. """
  550. def __init__(self, tb_writer=None):
  551. has_tensorboard = is_tensorboard_available()
  552. if not has_tensorboard:
  553. raise RuntimeError(
  554. "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or"
  555. " install tensorboardX."
  556. )
  557. if has_tensorboard:
  558. try:
  559. from torch.utils.tensorboard import SummaryWriter
  560. self._SummaryWriter = SummaryWriter
  561. except ImportError:
  562. try:
  563. from tensorboardX import SummaryWriter
  564. self._SummaryWriter = SummaryWriter
  565. except ImportError:
  566. self._SummaryWriter = None
  567. else:
  568. self._SummaryWriter = None
  569. self.tb_writer = tb_writer
  570. def _init_summary_writer(self, args, log_dir=None):
  571. log_dir = log_dir or args.logging_dir
  572. if self._SummaryWriter is not None:
  573. self.tb_writer = self._SummaryWriter(log_dir=log_dir)
  574. def on_train_begin(self, args, state, control, **kwargs):
  575. if not state.is_world_process_zero:
  576. return
  577. log_dir = None
  578. if state.is_hyper_param_search:
  579. trial_name = state.trial_name
  580. if trial_name is not None:
  581. log_dir = os.path.join(args.logging_dir, trial_name)
  582. if self.tb_writer is None:
  583. self._init_summary_writer(args, log_dir)
  584. if self.tb_writer is not None:
  585. self.tb_writer.add_text("args", args.to_json_string())
  586. if "model" in kwargs:
  587. model = kwargs["model"]
  588. if hasattr(model, "config") and model.config is not None:
  589. model_config_json = model.config.to_json_string()
  590. self.tb_writer.add_text("model_config", model_config_json)
  591. def on_log(self, args, state, control, logs=None, **kwargs):
  592. if not state.is_world_process_zero:
  593. return
  594. if self.tb_writer is None:
  595. self._init_summary_writer(args)
  596. if self.tb_writer is not None:
  597. logs = rewrite_logs(logs)
  598. for k, v in logs.items():
  599. if isinstance(v, (int, float)):
  600. self.tb_writer.add_scalar(k, v, state.global_step)
  601. elif isinstance(v, str):
  602. self.tb_writer.add_text(k, v, state.global_step)
  603. else:
  604. logger.warning(
  605. "Trainer is attempting to log a value of "
  606. f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
  607. "This invocation of Tensorboard's writer.add_scalar() "
  608. "is incorrect so we dropped this attribute."
  609. )
  610. self.tb_writer.flush()
  611. def on_train_end(self, args, state, control, **kwargs):
  612. if self.tb_writer:
  613. self.tb_writer.close()
  614. self.tb_writer = None
  615. def save_model_architecture_to_file(model: Any, output_dir: str):
  616. with open(f"{output_dir}/model_architecture.txt", "w+") as f:
  617. if isinstance(model, PreTrainedModel):
  618. print(model, file=f)
  619. elif is_tf_available() and isinstance(model, TFPreTrainedModel):
  620. def print_to_file(s):
  621. print(s, file=f)
  622. model.summary(print_fn=print_to_file)
  623. elif is_torch_available() and (
  624. isinstance(model, (torch.nn.Module, PushToHubMixin)) and hasattr(model, "base_model")
  625. ):
  626. print(model, file=f)
  627. class WandbLogModel(str, Enum):
  628. """Enum of possible log model values in W&B."""
  629. CHECKPOINT = "checkpoint"
  630. END = "end"
  631. FALSE = "false"
  632. @property
  633. def is_enabled(self) -> bool:
  634. """Check if the value corresponds to a state where the `WANDB_LOG_MODEL` setting is enabled."""
  635. return self in (WandbLogModel.CHECKPOINT, WandbLogModel.END)
  636. @classmethod
  637. def _missing_(cls, value: Any) -> "WandbLogModel":
  638. if not isinstance(value, str):
  639. raise TypeError(f"Expecting to have a string `WANDB_LOG_MODEL` setting, but got {type(value)}")
  640. if value.upper() in ENV_VARS_TRUE_VALUES:
  641. raise DeprecationWarning(
  642. f"Setting `WANDB_LOG_MODEL` as {os.getenv('WANDB_LOG_MODEL')} is deprecated and will be removed in "
  643. "version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
  644. )
  645. logger.info(f"Setting `WANDB_LOG_MODEL` from {os.getenv('WANDB_LOG_MODEL')} to `end` instead")
  646. return WandbLogModel.END
  647. logger.warning(
  648. f"Received unrecognized `WANDB_LOG_MODEL` setting value={value}; so disabling `WANDB_LOG_MODEL`"
  649. )
  650. return WandbLogModel.FALSE
class WandbCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).

    Metrics are sent from the main process only. When `WANDB_LOG_MODEL` is `"end"` or `"checkpoint"`, the initial
    model architecture is logged during setup and the final model at the end of training. With `"checkpoint"`,
    every saved checkpoint directory is logged as well.
    """

    def __init__(self):
        has_wandb = is_wandb_available()
        if not has_wandb:
            # Check if wandb is actually installed but disabled via WANDB_DISABLED
            if importlib.util.find_spec("wandb") is not None:
                # wandb is installed; give a targeted error if WANDB_DISABLED is what made it unavailable
                wandb_disabled = os.getenv("WANDB_DISABLED", "").upper() in ENV_VARS_TRUE_VALUES
                if wandb_disabled:
                    raise RuntimeError(
                        "You specified `report_to='wandb'` but also set the `WANDB_DISABLED` environment variable.\n"
                        "This disables wandb logging, even though it was explicitly requested.\n\n"
                        "- To enable wandb logging: unset `WANDB_DISABLED`.\n"
                        "- To disable logging: use `report_to='none'`.\n\n"
                        "Note: WANDB_DISABLED is deprecated and will be removed in v5."
                    )
            # Reached when wandb is not installed, or is installed but unavailable for another reason.
            raise RuntimeError("WandbCallback requires wandb to be installed. Run `pip install wandb`.")
        # Always true here (the branch above raises), so `self._wandb` is always set on a constructed instance.
        if has_wandb:
            import wandb

            self._wandb = wandb
        self._initialized = False
        # Parsed once at construction; unrecognized values are resolved by `WandbLogModel._missing_`.
        self._log_model = WandbLogModel(os.getenv("WANDB_LOG_MODEL", "false"))

    def setup(self, args, state, model, **kwargs):
        """
        Setup the optional Weights & Biases (*wandb*) integration.

        One can subclass and override this method to customize the setup if needed. Find more information
        [here](https://docs.wandb.ai/guides/integrations/huggingface). You can also override the following environment
        variables:

        Environment:
        - **WANDB_LOG_MODEL** (`str`, *optional*, defaults to `"false"`):
            Whether to log model and checkpoints during training. Can be `"end"`, `"checkpoint"` or `"false"`. If set
            to `"end"`, the model will be uploaded at the end of training. If set to `"checkpoint"`, the checkpoint
            will be uploaded every `args.save_steps` . If set to `"false"`, the model will not be uploaded. Use along
            with [`~transformers.TrainingArguments.load_best_model_at_end`] to upload best model.

            <Deprecated version="5.0">

            Setting `WANDB_LOG_MODEL` as `bool` will be deprecated in version 5 of 🤗 Transformers.

            </Deprecated>
        - **WANDB_WATCH** (`str`, *optional* defaults to `"false"`):
            Can be `"gradients"`, `"all"`, `"parameters"`, or `"false"`. Set to `"all"` to log gradients and
            parameters.
        - **WANDB_PROJECT** (`str`, *optional*, defaults to `"huggingface"`):
            Set this to a custom string to store results in a different project.
        - **WANDB_DISABLED** (`bool`, *optional*, defaults to `False`):
            Whether to disable wandb entirely. Set `WANDB_DISABLED=true` to disable.

        Every process marks the callback as initialized. Only the main process does the following:
        - starts a wandb run (unless one is already active);
        - stores the training arguments, merged with the model / PEFT configuration, as the run config;
        - optionally watches the model;
        - when model logging is enabled, uploads the initial model architecture as an artifact.
        """
        if self._wandb is None:
            return
        self._initialized = True

        # prepare to handle potential configuration issues during setup
        from wandb.sdk.lib.config_util import ConfigError as WandbConfigError

        if state.is_world_process_zero:
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            combined_dict = {**args.to_dict()}

            # Training arguments take precedence over model-config keys with the same name.
            if hasattr(model, "config") and model.config is not None:
                model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}
            if hasattr(model, "peft_config") and model.peft_config is not None:
                peft_config = model.peft_config
                combined_dict = {**{"peft_config": peft_config}, **combined_dict}
            trial_name = state.trial_name
            init_args = {}
            # HPO trials are named after the trial and grouped under the run name (or output dir).
            if trial_name is not None:
                init_args["name"] = trial_name
                init_args["group"] = args.run_name or args.output_dir
            elif args.run_name is not None:
                init_args["name"] = args.run_name
                if args.run_name == args.output_dir:
                    self._wandb.termwarn(
                        "The `run_name` is currently set to the same value as `TrainingArguments.output_dir`. If this was "
                        "not intended, please specify a different run name by setting the `TrainingArguments.run_name` parameter.",
                        repeat=False,
                    )

            if self._wandb.run is None:
                self._wandb.init(
                    project=os.getenv("WANDB_PROJECT", "huggingface"),
                    **init_args,
                )
            # add config parameters (run may have been created manually)
            self._wandb.config.update(combined_dict or {}, allow_val_change=True)

            # define default x-axis (for latest wandb versions)
            if getattr(self._wandb, "define_metric", None):
                self._wandb.define_metric("train/global_step")
                self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

            # keep track of model topology and gradients, unsupported on TPU
            _watch_model = os.getenv("WANDB_WATCH", "false")
            if not is_torch_xla_available() and _watch_model in ("all", "parameters", "gradients"):
                self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))
            self._wandb.run._label(code="transformers_trainer")

            # add number of model parameters to wandb config
            try:
                self._wandb.config["model/num_parameters"] = model.num_parameters()
            except AttributeError:
                logger.info(
                    "Could not log the number of model parameters in Weights & Biases due to an AttributeError."
                )
            except WandbConfigError:
                logger.warning(
                    "A ConfigError was raised whilst setting the number of model parameters in Weights & Biases config."
                )

            # log the initial model architecture to an artifact
            if self._log_model.is_enabled:
                with tempfile.TemporaryDirectory() as temp_dir:
                    # Artifact name uses the run id unless an explicit run name (different from output_dir) was set.
                    model_name = (
                        f"model-{self._wandb.run.id}"
                        if (args.run_name is None or args.run_name == args.output_dir)
                        else f"model-{self._wandb.run.name}"
                    )
                    model_artifact = self._wandb.Artifact(
                        name=model_name,
                        type="model",
                        metadata={
                            "model_config": model.config.to_dict() if hasattr(model, "config") else None,
                            "num_parameters": self._wandb.config.get("model/num_parameters"),
                            "initial_model": True,
                        },
                    )
                    # add the architecture to a separate text file
                    save_model_architecture_to_file(model, temp_dir)

                    # Copy every top-level file of the temporary directory into the artifact.
                    for f in Path(temp_dir).glob("*"):
                        if f.is_file():
                            with model_artifact.new_file(f.name, mode="wb") as fa:
                                fa.write(f.read_bytes())
                    self._wandb.run.log_artifact(model_artifact, aliases=["base_model"])

                    # Append a "Visualize in Weights & Biases" badge linking to this run to the module-level
                    # autogenerated model-card comment.
                    badge_markdown = (
                        f'[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge'
                        f'-28.svg" alt="Visualize in Weights & Biases" width="20'
                        f'0" height="32"/>]({self._wandb.run.url})'
                    )

                    modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if self._wandb is None:
            return
        hp_search = state.is_hyper_param_search
        # Each HPO trial gets its own run: finish the previous run, force a new setup, and clear `args.run_name`
        # so the trial name is used instead.
        if hp_search:
            self._wandb.finish()
            self._initialized = False
            args.run_name = None
        if not self._initialized:
            self.setup(args, state, model, **kwargs)

    def on_train_end(self, args: TrainingArguments, state, control, model=None, processing_class=None, **kwargs):
        if self._wandb is None:
            return
        if self._log_model.is_enabled and self._initialized and state.is_world_process_zero:
            from ..trainer import Trainer

            # A throwaway Trainer (DeepSpeed disabled on a copy of the args) is built only to reuse `save_model`
            # for writing the final model to a temporary directory.
            args_for_fake = copy.deepcopy(args)
            args_for_fake.deepspeed = None
            args_for_fake.deepspeed_plugin = None
            fake_trainer = Trainer(
                args=args_for_fake, model=model, processing_class=processing_class, eval_dataset=["fake"]
            )
            with tempfile.TemporaryDirectory() as temp_dir:
                fake_trainer.save_model(temp_dir)
                # Without `load_best_model_at_end`, the metadata holds every numeric, non-underscore entry of the
                # run summary. Otherwise it holds the best metric, total FLOs and parameter count.
                # NOTE(review): the FLOs key is spelled "train/total_floss"; it is a runtime key, so kept as-is.
                metadata = (
                    {
                        k: v
                        for k, v in dict(self._wandb.summary).items()
                        if isinstance(v, numbers.Number) and not k.startswith("_")
                    }
                    if not args.load_best_model_at_end
                    else {
                        f"eval/{args.metric_for_best_model}": state.best_metric,
                        "train/total_floss": state.total_flos,
                        "model/num_parameters": self._wandb.config.get("model/num_parameters"),
                    }
                )
                metadata["final_model"] = True
                logger.info("Logging model artifacts. ...")
                model_name = (
                    f"model-{self._wandb.run.id}"
                    if (args.run_name is None or args.run_name == args.output_dir)
                    else f"model-{self._wandb.run.name}"
                )
                # add the model architecture to a separate text file
                save_model_architecture_to_file(model, temp_dir)

                artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata)
                for f in Path(temp_dir).glob("*"):
                    if f.is_file():
                        with artifact.new_file(f.name, mode="wb") as fa:
                            fa.write(f.read_bytes())
                self._wandb.run.log_artifact(artifact, aliases=["final_model"])

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        # Run-level values that are written to the run summary instead of the logged history.
        single_value_scalars = [
            "train_runtime",
            "train_samples_per_second",
            "train_steps_per_second",
            "train_loss",
            "total_flos",
        ]

        if self._wandb is None:
            return
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            # NOTE(review): `logs` defaults to None but is iterated unconditionally; callers must always pass it.
            for k, v in logs.items():
                if k in single_value_scalars:
                    self._wandb.run.summary[k] = v
            non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars}
            non_scalar_logs = rewrite_logs(non_scalar_logs)
            self._wandb.log({**non_scalar_logs, "train/global_step": state.global_step})

    def on_save(self, args, state, control, **kwargs):
        if self._log_model == WandbLogModel.CHECKPOINT and self._initialized and state.is_world_process_zero:
            checkpoint_metadata = {
                k: v
                for k, v in dict(self._wandb.summary).items()
                if isinstance(v, numbers.Number) and not k.startswith("_")
            }
            checkpoint_metadata["model/num_parameters"] = self._wandb.config.get("model/num_parameters")

            # Uploads `<output_dir>/checkpoint-<global_step>`; presumably written by the Trainer before this hook
            # fires — TODO confirm.
            ckpt_dir = f"checkpoint-{state.global_step}"
            artifact_path = os.path.join(args.output_dir, ckpt_dir)
            logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. ...")
            checkpoint_name = (
                f"model-{self._wandb.run.id}"
                if (args.run_name is None or args.run_name == args.output_dir)
                else f"model-{self._wandb.run.name}"
            )
            artifact = self._wandb.Artifact(name=checkpoint_name, type="model", metadata=checkpoint_metadata)
            artifact.add_dir(artifact_path)
            self._wandb.log_artifact(
                artifact, aliases=[f"epoch_{round(state.epoch, 2)}", f"checkpoint_global_step_{state.global_step}"]
            )

    def on_predict(self, args, state, control, metrics, **kwargs):
        if self._wandb is None:
            return
        # `setup` requires `model`; it must arrive through `kwargs` here.
        if not self._initialized:
            self.setup(args, state, **kwargs)
        if state.is_world_process_zero:
            metrics = rewrite_logs(metrics)
            self._wandb.log(metrics)
  885. class TrackioCallback(TrainerCallback):
  886. """
  887. A [`TrainerCallback`] that logs metrics to Trackio.
  888. It records training metrics, model (and PEFT) configuration, and GPU memory usage.
  889. If `nvidia-ml-py` is installed, GPU power consumption is also tracked.
  890. **Requires**:
  891. ```bash
  892. pip install trackio
  893. ```
  894. """
  895. def __init__(self):
  896. has_trackio = is_trackio_available()
  897. if not has_trackio:
  898. raise RuntimeError("TrackioCallback requires trackio to be installed. Run `pip install trackio`.")
  899. if has_trackio:
  900. import trackio
  901. self._trackio = trackio
  902. self._initialized = False
  903. def setup(self, args, state, model, **kwargs):
  904. """
  905. Setup the optional Trackio integration.
  906. To customize the setup you can also set the arguments `project`, `trackio_space_id` and `hub_private_repo` in
  907. [`TrainingArguments`]. Please refer to the docstring of for more details.
  908. """
  909. if state.is_world_process_zero:
  910. if os.getenv("TRACKIO_PROJECT"):
  911. logger.warning(
  912. "The `TRACKIO_PROJECT` environment variable is deprecated and will be removed in a future "
  913. "version. Use TrainingArguments.project instead."
  914. )
  915. project = os.getenv("TRACKIO_PROJECT")
  916. else:
  917. project = args.project
  918. if os.getenv("TRACKIO_SPACE_ID"):
  919. logger.warning(
  920. "The `TRACKIO_SPACE_ID` environment variable is deprecated and will be removed in a future "
  921. "version. Use TrainingArguments.trackio_space_id instead."
  922. )
  923. space_id = os.getenv("TRACKIO_SPACE_ID")
  924. else:
  925. space_id = args.trackio_space_id
  926. combined_dict = {**args.to_dict()}
  927. if hasattr(model, "config") and model.config is not None:
  928. model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
  929. combined_dict = {**model_config, **combined_dict}
  930. if hasattr(model, "peft_config") and model.peft_config is not None:
  931. peft_config = model.peft_config
  932. combined_dict = {**{"peft_config": peft_config}, **combined_dict}
  933. self._trackio.init(
  934. project=project,
  935. name=args.run_name,
  936. space_id=space_id,
  937. resume="allow",
  938. private=args.hub_private_repo,
  939. )
  940. # Add config parameters (run may have been created manually)
  941. self._trackio.config.update(combined_dict, allow_val_change=True)
  942. # Add number of model parameters to trackio config
  943. try:
  944. self._trackio.config["model/num_parameters"] = model.num_parameters()
  945. except AttributeError:
  946. logger.info("Could not log the number of model parameters in Trackio due to an AttributeError.")
  947. self._initialized = True
  948. def on_train_begin(self, args, state, control, model=None, **kwargs):
  949. if not self._initialized:
  950. self.setup(args, state, model, **kwargs)
  951. def on_train_end(self, args: TrainingArguments, state, control, model=None, processing_class=None, **kwargs):
  952. if state.is_world_process_zero and self._initialized:
  953. self._trackio.finish()
  954. def on_log(self, args, state, control, model=None, logs=None, **kwargs):
  955. single_value_scalars = [
  956. "train_runtime",
  957. "train_samples_per_second",
  958. "train_steps_per_second",
  959. "train_loss",
  960. "total_flos",
  961. ]
  962. if is_torch_available() and torch.cuda.is_available():
  963. device_idx = torch.cuda.current_device()
  964. total_memory = torch.cuda.get_device_properties(device_idx).total_memory
  965. memory_allocated = torch.cuda.memory_allocated(device_idx)
  966. gpu_memory_logs = {
  967. f"gpu/{device_idx}/allocated_memory": memory_allocated / (1024**3), # GB
  968. f"gpu/{device_idx}/memory_usage": memory_allocated / total_memory, # ratio
  969. }
  970. if _is_package_available("pynvml"):
  971. power = torch.cuda.power_draw(device_idx)
  972. gpu_memory_logs[f"gpu/{device_idx}/power"] = power / 1000 # Watts
  973. if dist.is_available() and dist.is_initialized():
  974. gathered_logs = [None] * dist.get_world_size()
  975. dist.all_gather_object(gathered_logs, gpu_memory_logs)
  976. gpu_memory_logs = {k: v for d in gathered_logs for k, v in d.items()}
  977. else:
  978. gpu_memory_logs = {}
  979. if not self._initialized:
  980. self.setup(args, state, model)
  981. if state.is_world_process_zero:
  982. non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars}
  983. non_scalar_logs = rewrite_logs(non_scalar_logs)
  984. self._trackio.log({**non_scalar_logs, **gpu_memory_logs, "train/global_step": state.global_step})
  985. def on_save(self, args, state, control, **kwargs):
  986. return
  987. def on_predict(self, args, state, control, metrics, **kwargs):
  988. if self._trackio is None:
  989. return
  990. if not self._initialized:
  991. self.setup(args, state, **kwargs)
  992. if state.is_world_process_zero:
  993. metrics = rewrite_logs(metrics)
  994. self._trackio.log(metrics)
  995. class CometCallback(TrainerCallback):
  996. """
  997. A [`TrainerCallback`] that sends the logs to [Comet ML](https://www.comet.com/site/).
  998. """
  999. def __init__(self):
  1000. if _is_comet_installed is False or _is_comet_recent_enough is False:
  1001. raise RuntimeError(
  1002. f"CometCallback requires comet-ml>={_MIN_COMET_VERSION} to be installed. Run `pip install comet-ml>={_MIN_COMET_VERSION}`."
  1003. )
  1004. self._initialized = False
  1005. self._log_assets = False
  1006. self._experiment = None
  1007. def setup(self, args, state, model):
  1008. """
  1009. Setup the optional Comet integration.
  1010. Environment:
  1011. - **COMET_MODE** (`str`, *optional*, default to `get_or_create`):
  1012. Control whether to create and log to a new Comet experiment or append to an existing experiment.
  1013. It accepts the following values:
  1014. * `get_or_create`: Decides automatically depending if
  1015. `COMET_EXPERIMENT_KEY` is set and whether an Experiment
  1016. with that key already exists or not.
  1017. * `create`: Always create a new Comet Experiment.
  1018. * `get`: Always try to append to an Existing Comet Experiment.
  1019. Requires `COMET_EXPERIMENT_KEY` to be set.
  1020. * `ONLINE`: **deprecated**, used to create an online
  1021. Experiment. Use `COMET_START_ONLINE=1` instead.
  1022. * `OFFLINE`: **deprecated**, used to created an offline
  1023. Experiment. Use `COMET_START_ONLINE=0` instead.
  1024. * `DISABLED`: **deprecated**, used to disable Comet logging.
  1025. Use the `--report_to` flag to control the integrations used
  1026. for logging result instead.
  1027. - **COMET_PROJECT_NAME** (`str`, *optional*):
  1028. Comet project name for experiments.
  1029. - **COMET_LOG_ASSETS** (`str`, *optional*, defaults to `TRUE`):
  1030. Whether or not to log training assets (tf event logs, checkpoints, etc), to Comet. Can be `TRUE`, or
  1031. `FALSE`.
  1032. For a number of configurable items in the environment, see
  1033. [here](https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options).
  1034. """
  1035. self._initialized = True
  1036. log_assets = os.getenv("COMET_LOG_ASSETS", "FALSE").upper()
  1037. if log_assets in {"TRUE", "1"}:
  1038. self._log_assets = True
  1039. if state.is_world_process_zero:
  1040. comet_old_mode = os.getenv("COMET_MODE")
  1041. mode = None
  1042. online = None
  1043. if comet_old_mode is not None:
  1044. comet_old_mode = comet_old_mode.lower()
  1045. if comet_old_mode == "online":
  1046. online = True
  1047. elif comet_old_mode == "offline":
  1048. online = False
  1049. elif comet_old_mode in ("get", "get_or_create", "create"):
  1050. mode = comet_old_mode
  1051. elif comet_old_mode:
  1052. logger.warning("Invalid COMET_MODE env value %r, Comet logging is disabled", comet_old_mode)
  1053. return
  1054. # For HPO, we always create a new experiment for each trial
  1055. if state.is_hyper_param_search:
  1056. if mode is not None:
  1057. logger.warning(
  1058. "Hyperparameter Search is enabled, forcing the creation of new experiments, COMET_MODE value %r is ignored",
  1059. comet_old_mode,
  1060. )
  1061. mode = "create"
  1062. import comet_ml
  1063. experiment_config = comet_ml.ExperimentConfig(name=args.run_name)
  1064. self._experiment = comet_ml.start(online=online, mode=mode, experiment_config=experiment_config)
  1065. self._experiment.__internal_api__set_model_graph__(model, framework="transformers")
  1066. params = {"args": args.to_dict()}
  1067. if hasattr(model, "config") and model.config is not None:
  1068. model_config = model.config.to_dict()
  1069. params["config"] = model_config
  1070. if hasattr(model, "peft_config") and model.peft_config is not None:
  1071. peft_config = model.peft_config
  1072. params["peft_config"] = peft_config
  1073. self._experiment.__internal_api__log_parameters__(
  1074. params, framework="transformers", source="manual", flatten_nested=True
  1075. )
  1076. if state.is_hyper_param_search:
  1077. optimization_id = getattr(state, "trial_name", None)
  1078. optimization_params = getattr(state, "trial_params", None)
  1079. self._experiment.log_optimization(optimization_id=optimization_id, parameters=optimization_params)
  1080. def on_train_begin(self, args, state, control, model=None, **kwargs):
  1081. if not self._initialized:
  1082. self.setup(args, state, model)
  1083. def on_log(self, args, state, control, model=None, logs=None, **kwargs):
  1084. if not self._initialized:
  1085. self.setup(args, state, model)
  1086. if state.is_world_process_zero:
  1087. if self._experiment is not None:
  1088. rewritten_logs = rewrite_logs(logs)
  1089. self._experiment.__internal_api__log_metrics__(
  1090. rewritten_logs, step=state.global_step, epoch=state.epoch, framework="transformers"
  1091. )
  1092. def on_train_end(self, args, state, control, **kwargs):
  1093. if self._initialized and state.is_world_process_zero:
  1094. if self._experiment is not None:
  1095. if self._log_assets is True:
  1096. logger.info("Logging checkpoints. This may take time.")
  1097. self._experiment.log_asset_folder(
  1098. args.output_dir, recursive=True, log_file_name=True, step=state.global_step
  1099. )
  1100. # We create one experiment per trial in HPO mode
  1101. if state.is_hyper_param_search:
  1102. self._experiment.clean()
  1103. self._initialized = False
  1104. def on_predict(self, args, state, control, metrics, **kwargs):
  1105. if not self._initialized:
  1106. self.setup(args, state, model=None)
  1107. if state.is_world_process_zero and self._experiment is not None:
  1108. rewritten_metrics = rewrite_logs(metrics)
  1109. self._experiment.__internal_api__log_metrics__(
  1110. rewritten_metrics, step=state.global_step, epoch=state.epoch, framework="transformers"
  1111. )
  1112. class AzureMLCallback(TrainerCallback):
  1113. """
  1114. A [`TrainerCallback`] that sends the logs to [AzureML](https://pypi.org/project/azureml-sdk/).
  1115. """
  1116. def __init__(self, azureml_run=None):
  1117. if not is_azureml_available():
  1118. raise RuntimeError("AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`.")
  1119. self.azureml_run = azureml_run
  1120. def on_init_end(self, args, state, control, **kwargs):
  1121. from azureml.core.run import Run
  1122. if self.azureml_run is None and state.is_world_process_zero:
  1123. self.azureml_run = Run.get_context()
  1124. def on_log(self, args, state, control, logs=None, **kwargs):
  1125. if self.azureml_run and state.is_world_process_zero:
  1126. for k, v in logs.items():
  1127. if isinstance(v, (int, float)):
  1128. self.azureml_run.log(k, v, description=k)
  1129. class MLflowCallback(TrainerCallback):
  1130. """
  1131. A [`TrainerCallback`] that sends the logs to [MLflow](https://www.mlflow.org/). Can be disabled by setting
  1132. environment variable `DISABLE_MLFLOW_INTEGRATION = TRUE`.
  1133. """
  1134. def __init__(self):
  1135. if not is_mlflow_available():
  1136. raise RuntimeError("MLflowCallback requires mlflow to be installed. Run `pip install mlflow`.")
  1137. import mlflow
  1138. self._MAX_PARAM_VAL_LENGTH = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
  1139. self._MAX_PARAMS_TAGS_PER_BATCH = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH
  1140. self._initialized = False
  1141. self._auto_end_run = False
  1142. self._log_artifacts = False
  1143. self._ml_flow = mlflow
  1144. def setup(self, args, state, model):
  1145. """
  1146. Setup the optional MLflow integration.
  1147. Environment:
  1148. - **HF_MLFLOW_LOG_ARTIFACTS** (`str`, *optional*):
  1149. Whether to use MLflow `.log_artifact()` facility to log artifacts. This only makes sense if logging to a
  1150. remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy each saved checkpoint on each save in
  1151. [`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it without a remote
  1152. storage will just copy the files to your artifact location.
  1153. - **MLFLOW_TRACKING_URI** (`str`, *optional*):
  1154. Whether to store runs at a specific path or remote server. Unset by default, which skips setting the
  1155. tracking URI entirely.
  1156. - **MLFLOW_EXPERIMENT_NAME** (`str`, *optional*, defaults to `None`):
  1157. Whether to use an MLflow experiment_name under which to launch the run. Default to `None` which will point
  1158. to the `Default` experiment in MLflow. Otherwise, it is a case sensitive name of the experiment to be
  1159. activated. If an experiment with this name does not exist, a new experiment with this name is created.
  1160. - **MLFLOW_TAGS** (`str`, *optional*):
  1161. A string dump of a dictionary of key/value pair to be added to the MLflow run as tags. Example:
  1162. `os.environ['MLFLOW_TAGS']='{"release.candidate": "RC1", "release.version": "2.2.0"}'`.
  1163. - **MLFLOW_NESTED_RUN** (`str`, *optional*):
  1164. Whether to use MLflow nested runs. If set to `True` or *1*, will create a nested run inside the current
  1165. run.
  1166. - **MLFLOW_RUN_ID** (`str`, *optional*):
  1167. Allow to reattach to an existing run which can be useful when resuming training from a checkpoint. When
  1168. `MLFLOW_RUN_ID` environment variable is set, `start_run` attempts to resume a run with the specified run ID
  1169. and other parameters are ignored.
  1170. - **MLFLOW_FLATTEN_PARAMS** (`str`, *optional*, defaults to `False`):
  1171. Whether to flatten the parameters dictionary before logging.
  1172. - **MLFLOW_MAX_LOG_PARAMS** (`int`, *optional*):
  1173. Set the maximum number of parameters to log in the run.
  1174. """
  1175. self._log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
  1176. self._nested_run = os.getenv("MLFLOW_NESTED_RUN", "FALSE").upper() in ENV_VARS_TRUE_VALUES
  1177. self._tracking_uri = os.getenv("MLFLOW_TRACKING_URI", None)
  1178. self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
  1179. self._flatten_params = os.getenv("MLFLOW_FLATTEN_PARAMS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
  1180. self._run_id = os.getenv("MLFLOW_RUN_ID", None)
  1181. self._max_log_params = os.getenv("MLFLOW_MAX_LOG_PARAMS", None)
  1182. # "synchronous" flag is only available with mlflow version >= 2.8.0
  1183. # https://github.com/mlflow/mlflow/pull/9705
  1184. # https://github.com/mlflow/mlflow/releases/tag/v2.8.0
  1185. self._async_log = packaging.version.parse(self._ml_flow.__version__) >= packaging.version.parse("2.8.0")
  1186. logger.debug(
  1187. f"MLflow experiment_name={self._experiment_name}, run_name={args.run_name}, nested={self._nested_run},"
  1188. f" tracking_uri={self._tracking_uri}"
  1189. )
  1190. if state.is_world_process_zero:
  1191. if not self._ml_flow.is_tracking_uri_set():
  1192. if self._tracking_uri:
  1193. self._ml_flow.set_tracking_uri(self._tracking_uri)
  1194. logger.debug(f"MLflow tracking URI is set to {self._tracking_uri}")
  1195. else:
  1196. logger.debug(
  1197. "Environment variable `MLFLOW_TRACKING_URI` is not provided and therefore will not be"
  1198. " explicitly set."
  1199. )
  1200. else:
  1201. logger.debug(f"MLflow tracking URI is set to {self._ml_flow.get_tracking_uri()}")
  1202. if self._ml_flow.active_run() is None or self._nested_run or self._run_id:
  1203. if self._experiment_name:
  1204. # Use of set_experiment() ensure that Experiment is created if not exists
  1205. self._ml_flow.set_experiment(self._experiment_name)
  1206. self._ml_flow.start_run(run_name=args.run_name, nested=self._nested_run)
  1207. logger.debug(f"MLflow run started with run_id={self._ml_flow.active_run().info.run_id}")
  1208. self._auto_end_run = True
  1209. combined_dict = args.to_dict()
  1210. if hasattr(model, "config") and model.config is not None:
  1211. model_config = model.config.to_dict()
  1212. combined_dict = {**model_config, **combined_dict}
  1213. combined_dict = flatten_dict(combined_dict) if self._flatten_params else combined_dict
  1214. # remove params that are too long for MLflow
  1215. for name, value in list(combined_dict.items()):
  1216. # internally, all values are converted to str in MLflow
  1217. if len(str(value)) > self._MAX_PARAM_VAL_LENGTH:
  1218. logger.warning(
  1219. f'Trainer is attempting to log a value of "{value}" for key "{name}" as a parameter. MLflow\'s'
  1220. " log_param() only accepts values no longer than 250 characters so we dropped this attribute."
  1221. " You can use `MLFLOW_FLATTEN_PARAMS` environment variable to flatten the parameters and"
  1222. " avoid this message."
  1223. )
  1224. del combined_dict[name]
  1225. # MLflow cannot log more than 100 values in one go, so we have to split it
  1226. combined_dict_items = list(combined_dict.items())
  1227. if self._max_log_params and self._max_log_params.isdigit():
  1228. max_log_params = int(self._max_log_params)
  1229. if max_log_params < len(combined_dict_items):
  1230. logger.debug(
  1231. f"Reducing the number of parameters to log from {len(combined_dict_items)} to {max_log_params}."
  1232. )
  1233. combined_dict_items = combined_dict_items[:max_log_params]
  1234. for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH):
  1235. if self._async_log:
  1236. self._ml_flow.log_params(
  1237. dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]), synchronous=False
  1238. )
  1239. else:
  1240. self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]))
  1241. mlflow_tags = os.getenv("MLFLOW_TAGS", None)
  1242. if mlflow_tags:
  1243. mlflow_tags = json.loads(mlflow_tags)
  1244. self._ml_flow.set_tags(mlflow_tags)
  1245. self._initialized = True
  1246. def on_train_begin(self, args, state, control, model=None, **kwargs):
  1247. if not self._initialized:
  1248. self.setup(args, state, model)
  1249. def on_log(self, args, state, control, logs, model=None, **kwargs):
  1250. if not self._initialized:
  1251. self.setup(args, state, model)
  1252. if state.is_world_process_zero:
  1253. metrics = {}
  1254. for k, v in logs.items():
  1255. if isinstance(v, (int, float)):
  1256. metrics[k] = v
  1257. elif isinstance(v, torch.Tensor) and v.numel() == 1:
  1258. metrics[k] = v.item()
  1259. else:
  1260. logger.warning(
  1261. f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. '
  1262. "MLflow's log_metric() only accepts float and int types so we dropped this attribute."
  1263. )
  1264. # sanitize metric names to replace unsupported characters like parentheses
  1265. sanitized_metrics = {re.sub(r"[^0-9A-Za-z_\-\.\ :/]", "_", k): v for k, v in metrics.items()}
  1266. if self._async_log:
  1267. self._ml_flow.log_metrics(metrics=sanitized_metrics, step=state.global_step, synchronous=False)
  1268. else:
  1269. self._ml_flow.log_metrics(metrics=sanitized_metrics, step=state.global_step)
  1270. def on_train_end(self, args, state, control, **kwargs):
  1271. if self._initialized and state.is_world_process_zero:
  1272. if self._auto_end_run and self._ml_flow.active_run():
  1273. self._ml_flow.end_run()
  1274. def on_save(self, args, state, control, **kwargs):
  1275. if self._initialized and state.is_world_process_zero and self._log_artifacts:
  1276. ckpt_dir = f"checkpoint-{state.global_step}"
  1277. artifact_path = os.path.join(args.output_dir, ckpt_dir)
  1278. logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. This may take time.")
  1279. self._ml_flow.pyfunc.log_model(
  1280. ckpt_dir,
  1281. artifacts={"model_path": artifact_path},
  1282. python_model=self._ml_flow.pyfunc.PythonModel(),
  1283. )
  1284. def __del__(self):
  1285. # if the previous run is not terminated correctly, the fluent API will
  1286. # not let you start a new run before the previous one is killed
  1287. if (
  1288. self._auto_end_run
  1289. and callable(getattr(self._ml_flow, "active_run", None))
  1290. and self._ml_flow.active_run() is not None
  1291. ):
  1292. self._ml_flow.end_run()
  1293. class DagsHubCallback(MLflowCallback):
  1294. """
  1295. A [`TrainerCallback`] that logs to [DagsHub](https://dagshub.com/). Extends [`MLflowCallback`]
  1296. """
  1297. def __init__(self):
  1298. super().__init__()
  1299. if not is_dagshub_available():
  1300. raise ImportError("DagsHubCallback requires dagshub to be installed. Run `pip install dagshub`.")
  1301. from dagshub.upload import Repo
  1302. self.Repo = Repo
  1303. def setup(self, *args, **kwargs):
  1304. """
  1305. Setup the DagsHub's Logging integration.
  1306. Environment:
  1307. - **HF_DAGSHUB_LOG_ARTIFACTS** (`str`, *optional*):
  1308. Whether to save the data and model artifacts for the experiment. Default to `False`.
  1309. """
  1310. self.log_artifacts = os.getenv("HF_DAGSHUB_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
  1311. self.name = os.getenv("HF_DAGSHUB_MODEL_NAME") or "main"
  1312. self.remote = os.getenv("MLFLOW_TRACKING_URI")
  1313. self.repo = self.Repo(
  1314. owner=self.remote.split(os.sep)[-2],
  1315. name=self.remote.split(os.sep)[-1].split(".")[0],
  1316. branch=os.getenv("BRANCH") or "main",
  1317. )
  1318. self.path = Path("artifacts")
  1319. if self.remote is None:
  1320. raise RuntimeError(
  1321. "DagsHubCallback requires the `MLFLOW_TRACKING_URI` environment variable to be set. Did you run"
  1322. " `dagshub.init()`?"
  1323. )
  1324. super().setup(*args, **kwargs)
  1325. def on_train_end(self, args, state, control, **kwargs):
  1326. if self.log_artifacts:
  1327. if getattr(self, "train_dataloader", None):
  1328. torch.save(self.train_dataloader.dataset, os.path.join(args.output_dir, "dataset.pt"))
  1329. self.repo.directory(str(self.path)).add_dir(args.output_dir)
class NeptuneMissingConfiguration(Exception):
    """
    Error raised by `NeptuneCallback` when a new Neptune run cannot be created because the Neptune API token or
    project name is missing.

    Takes no arguments; the message below is fixed.
    """

    def __init__(self):
        super().__init__(
            """
        ------ Unsupported ---- We were not able to create new runs. You provided a custom Neptune run to
        `NeptuneCallback` with the `run` argument. For the integration to work fully, provide your `api_token` and
        `project` by saving them as environment variables or passing them to the callback.
        """
        )
  1339. class NeptuneCallback(TrainerCallback):
  1340. """TrainerCallback that sends the logs to [Neptune](https://app.neptune.ai).
  1341. Args:
  1342. api_token (`str`, *optional*): Neptune API token obtained upon registration.
  1343. You can leave this argument out if you have saved your token to the `NEPTUNE_API_TOKEN` environment
  1344. variable (strongly recommended). See full setup instructions in the
  1345. [docs](https://docs.neptune.ai/setup/installation).
  1346. project (`str`, *optional*): Name of an existing Neptune project, in the form "workspace-name/project-name".
  1347. You can find and copy the name in Neptune from the project settings -> Properties. If None (default), the
  1348. value of the `NEPTUNE_PROJECT` environment variable is used.
  1349. name (`str`, *optional*): Custom name for the run.
  1350. base_namespace (`str`, *optional*, defaults to "finetuning"): In the Neptune run, the root namespace
  1351. that will contain all of the metadata logged by the callback.
  1352. log_parameters (`bool`, *optional*, defaults to `True`):
  1353. If True, logs all Trainer arguments and model parameters provided by the Trainer.
  1354. log_checkpoints (`str`, *optional*): If "same", uploads checkpoints whenever they are saved by the Trainer.
  1355. If "last", uploads only the most recently saved checkpoint. If "best", uploads the best checkpoint (among
  1356. the ones saved by the Trainer). If `None`, does not upload checkpoints.
  1357. run (`Run`, *optional*): Pass a Neptune run object if you want to continue logging to an existing run.
  1358. Read more about resuming runs in the [docs](https://docs.neptune.ai/logging/to_existing_object).
  1359. **neptune_run_kwargs (*optional*):
  1360. Additional keyword arguments to be passed directly to the
  1361. [`neptune.init_run()`](https://docs.neptune.ai/api/neptune#init_run) function when a new run is created.
  1362. For instructions and examples, see the [Transformers integration
  1363. guide](https://docs.neptune.ai/integrations/transformers) in the Neptune documentation.
  1364. """
  1365. integration_version_key = "source_code/integrations/transformers"
  1366. model_parameters_key = "model_parameters"
  1367. trial_name_key = "trial"
  1368. trial_params_key = "trial_params"
  1369. trainer_parameters_key = "trainer_parameters"
  1370. flat_metrics = {"train/epoch"}
  1371. def __init__(
  1372. self,
  1373. *,
  1374. api_token: Optional[str] = None,
  1375. project: Optional[str] = None,
  1376. name: Optional[str] = None,
  1377. base_namespace: str = "finetuning",
  1378. run=None,
  1379. log_parameters: bool = True,
  1380. log_checkpoints: Optional[str] = None,
  1381. **neptune_run_kwargs,
  1382. ):
  1383. if not is_neptune_available():
  1384. raise ValueError(
  1385. "NeptuneCallback requires the Neptune client library to be installed. "
  1386. "To install the library, run `pip install neptune`."
  1387. )
  1388. try:
  1389. from neptune import Run
  1390. from neptune.internal.utils import verify_type
  1391. except ImportError:
  1392. from neptune.new.internal.utils import verify_type
  1393. from neptune.new.metadata_containers.run import Run
  1394. verify_type("api_token", api_token, (str, type(None)))
  1395. verify_type("project", project, (str, type(None)))
  1396. verify_type("name", name, (str, type(None)))
  1397. verify_type("base_namespace", base_namespace, str)
  1398. verify_type("run", run, (Run, type(None)))
  1399. verify_type("log_parameters", log_parameters, bool)
  1400. verify_type("log_checkpoints", log_checkpoints, (str, type(None)))
  1401. self._base_namespace_path = base_namespace
  1402. self._log_parameters = log_parameters
  1403. self._log_checkpoints = log_checkpoints
  1404. self._initial_run: Optional[Run] = run
  1405. self._run = None
  1406. self._is_monitoring_run = False
  1407. self._run_id = None
  1408. self._force_reset_monitoring_run = False
  1409. self._init_run_kwargs = {"api_token": api_token, "project": project, "name": name, **neptune_run_kwargs}
  1410. self._volatile_checkpoints_dir = None
  1411. self._should_upload_checkpoint = self._log_checkpoints is not None
  1412. self._recent_checkpoint_path = None
  1413. if self._log_checkpoints in {"last", "best"}:
  1414. self._target_checkpoints_namespace = f"checkpoints/{self._log_checkpoints}"
  1415. self._should_clean_recently_uploaded_checkpoint = True
  1416. else:
  1417. self._target_checkpoints_namespace = "checkpoints"
  1418. self._should_clean_recently_uploaded_checkpoint = False
  1419. def _stop_run_if_exists(self):
  1420. if self._run:
  1421. self._run.stop()
  1422. del self._run
  1423. self._run = None
  1424. def _initialize_run(self, **additional_neptune_kwargs):
  1425. try:
  1426. from neptune import init_run
  1427. from neptune.exceptions import NeptuneMissingApiTokenException, NeptuneMissingProjectNameException
  1428. except ImportError:
  1429. from neptune.new import init_run
  1430. from neptune.new.exceptions import NeptuneMissingApiTokenException, NeptuneMissingProjectNameException
  1431. self._stop_run_if_exists()
  1432. try:
  1433. run_params = additional_neptune_kwargs.copy()
  1434. run_params.update(self._init_run_kwargs)
  1435. self._run = init_run(**run_params)
  1436. self._run_id = self._run["sys/id"].fetch()
  1437. except (NeptuneMissingProjectNameException, NeptuneMissingApiTokenException) as e:
  1438. raise NeptuneMissingConfiguration() from e
  1439. def _use_initial_run(self):
  1440. self._run = self._initial_run
  1441. self._is_monitoring_run = True
  1442. self._run_id = self._run["sys/id"].fetch()
  1443. self._initial_run = None
  1444. def _ensure_run_with_monitoring(self):
  1445. if self._initial_run is not None:
  1446. self._use_initial_run()
  1447. else:
  1448. if not self._force_reset_monitoring_run and self._is_monitoring_run:
  1449. return
  1450. if self._run and not self._is_monitoring_run and not self._force_reset_monitoring_run:
  1451. self._initialize_run(with_id=self._run_id)
  1452. self._is_monitoring_run = True
  1453. else:
  1454. self._initialize_run()
  1455. self._force_reset_monitoring_run = False
  1456. def _ensure_at_least_run_without_monitoring(self):
  1457. if self._initial_run is not None:
  1458. self._use_initial_run()
  1459. else:
  1460. if not self._run:
  1461. self._initialize_run(
  1462. with_id=self._run_id,
  1463. capture_stdout=False,
  1464. capture_stderr=False,
  1465. capture_hardware_metrics=False,
  1466. capture_traceback=False,
  1467. )
  1468. self._is_monitoring_run = False
  1469. @property
  1470. def run(self):
  1471. if self._run is None:
  1472. self._ensure_at_least_run_without_monitoring()
  1473. return self._run
  1474. @property
  1475. def _metadata_namespace(self):
  1476. return self.run[self._base_namespace_path]
  1477. def _log_integration_version(self):
  1478. self.run[NeptuneCallback.integration_version_key] = version
  1479. def _log_trainer_parameters(self, args):
  1480. self._metadata_namespace[NeptuneCallback.trainer_parameters_key] = args.to_sanitized_dict()
  1481. def _log_model_parameters(self, model):
  1482. from neptune.utils import stringify_unsupported
  1483. if model and hasattr(model, "config") and model.config is not None:
  1484. self._metadata_namespace[NeptuneCallback.model_parameters_key] = stringify_unsupported(
  1485. model.config.to_dict()
  1486. )
  1487. def _log_hyper_param_search_parameters(self, state):
  1488. if state and hasattr(state, "trial_name"):
  1489. self._metadata_namespace[NeptuneCallback.trial_name_key] = state.trial_name
  1490. if state and hasattr(state, "trial_params") and state.trial_params is not None:
  1491. self._metadata_namespace[NeptuneCallback.trial_params_key] = state.trial_params
  1492. def _log_model_checkpoint(self, source_directory: str, checkpoint: str):
  1493. target_path = relative_path = os.path.join(source_directory, checkpoint)
  1494. if self._volatile_checkpoints_dir is not None:
  1495. consistent_checkpoint_path = os.path.join(self._volatile_checkpoints_dir, checkpoint)
  1496. try:
  1497. # Remove leading ../ from a relative path.
  1498. cpkt_path = relative_path.replace("..", "").lstrip(os.path.sep)
  1499. copy_path = os.path.join(consistent_checkpoint_path, cpkt_path)
  1500. shutil.copytree(relative_path, copy_path)
  1501. target_path = consistent_checkpoint_path
  1502. except OSError as e:
  1503. logger.warning(
  1504. f"NeptuneCallback was unable to made a copy of checkpoint due to I/O exception: '{e}'. "
  1505. "Could fail trying to upload."
  1506. )
  1507. self._metadata_namespace[self._target_checkpoints_namespace].upload_files(target_path)
  1508. if self._should_clean_recently_uploaded_checkpoint and self._recent_checkpoint_path is not None:
  1509. self._metadata_namespace[self._target_checkpoints_namespace].delete_files(self._recent_checkpoint_path)
  1510. self._recent_checkpoint_path = relative_path
  1511. def on_init_end(self, args, state, control, **kwargs):
  1512. self._volatile_checkpoints_dir = None
  1513. if self._log_checkpoints and (args.overwrite_output_dir or args.save_total_limit is not None):
  1514. self._volatile_checkpoints_dir = tempfile.TemporaryDirectory().name
  1515. if self._log_checkpoints == "best" and not args.load_best_model_at_end:
  1516. raise ValueError("To save the best model checkpoint, the load_best_model_at_end argument must be enabled.")
  1517. def on_train_begin(self, args, state, control, model=None, **kwargs):
  1518. if not state.is_world_process_zero:
  1519. return
  1520. self._ensure_run_with_monitoring()
  1521. self._force_reset_monitoring_run = True
  1522. self._log_integration_version()
  1523. if self._log_parameters:
  1524. self._log_trainer_parameters(args)
  1525. self._log_model_parameters(model)
  1526. if state.is_hyper_param_search:
  1527. self._log_hyper_param_search_parameters(state)
  1528. def on_train_end(self, args, state, control, **kwargs):
  1529. self._stop_run_if_exists()
  1530. def __del__(self):
  1531. if self._volatile_checkpoints_dir is not None:
  1532. shutil.rmtree(self._volatile_checkpoints_dir, ignore_errors=True)
  1533. self._stop_run_if_exists()
  1534. def on_save(self, args, state, control, **kwargs):
  1535. if self._should_upload_checkpoint:
  1536. self._log_model_checkpoint(args.output_dir, f"checkpoint-{state.global_step}")
  1537. def on_evaluate(self, args, state, control, metrics=None, **kwargs):
  1538. if self._log_checkpoints == "best":
  1539. best_metric_name = args.metric_for_best_model
  1540. if not best_metric_name.startswith("eval_"):
  1541. best_metric_name = f"eval_{best_metric_name}"
  1542. metric_value = metrics.get(best_metric_name)
  1543. operator = np.greater if args.greater_is_better else np.less
  1544. self._should_upload_checkpoint = state.best_metric is None or operator(metric_value, state.best_metric)
  1545. @classmethod
  1546. def get_run(cls, trainer):
  1547. for callback in trainer.callback_handler.callbacks:
  1548. if isinstance(callback, cls):
  1549. return callback.run
  1550. raise Exception("The trainer doesn't have a NeptuneCallback configured.")
  1551. def on_log(self, args, state, control, logs: Optional[dict[str, float]] = None, **kwargs):
  1552. if not state.is_world_process_zero:
  1553. return
  1554. if logs is not None:
  1555. for name, value in rewrite_logs(logs).items():
  1556. if isinstance(value, (int, float)):
  1557. if name in NeptuneCallback.flat_metrics:
  1558. self._metadata_namespace[name] = value
  1559. else:
  1560. self._metadata_namespace[name].log(value, step=state.global_step)
  1561. class CodeCarbonCallback(TrainerCallback):
  1562. """
  1563. A [`TrainerCallback`] that tracks the CO2 emission of training.
  1564. """
  1565. def __init__(self):
  1566. if not is_codecarbon_available():
  1567. raise RuntimeError(
  1568. "CodeCarbonCallback requires `codecarbon` to be installed. Run `pip install codecarbon`."
  1569. )
  1570. elif torch.version.hip:
  1571. raise RuntimeError(
  1572. "CodeCarbonCallback requires `codecarbon` package, which is not compatible with AMD ROCm (https://github.com/mlco2/codecarbon/pull/490). When using the Trainer, please specify the `report_to` argument (https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments.report_to) to disable CodeCarbonCallback."
  1573. )
  1574. import codecarbon
  1575. self._codecarbon = codecarbon
  1576. self.tracker = None
  1577. def on_init_end(self, args, state, control, **kwargs):
  1578. if self.tracker is None and state.is_local_process_zero:
  1579. # CodeCarbon will automatically handle environment variables for configuration
  1580. self.tracker = self._codecarbon.EmissionsTracker(output_dir=args.output_dir)
  1581. def on_train_begin(self, args, state, control, model=None, **kwargs):
  1582. if self.tracker and state.is_local_process_zero:
  1583. self.tracker.start()
  1584. def on_train_end(self, args, state, control, **kwargs):
  1585. if self.tracker and state.is_local_process_zero:
  1586. self.tracker.stop()
  1587. class ClearMLCallback(TrainerCallback):
  1588. """
  1589. A [`TrainerCallback`] that sends the logs to [ClearML](https://clear.ml/).
  1590. Environment:
  1591. - **CLEARML_PROJECT** (`str`, *optional*, defaults to `HuggingFace Transformers`):
  1592. ClearML project name.
  1593. - **CLEARML_TASK** (`str`, *optional*, defaults to `Trainer`):
  1594. ClearML task name.
  1595. - **CLEARML_LOG_MODEL** (`bool`, *optional*, defaults to `False`):
  1596. Whether to log models as artifacts during training.
  1597. """
  1598. log_suffix = ""
  1599. _hparams_section = "Transformers"
  1600. _model_config_section = "Model Configuration"
  1601. _ignore_hparams_overrides = "_ignore_hparams_ui_overrides_"
  1602. _ignoge_model_config_overrides = "_ignore_model_config_ui_overrides_"
  1603. _model_config_description = "The configuration of model number {}."
  1604. _model_config_description_note = (
  1605. "Note that, when cloning this task and running it remotely,"
  1606. " the configuration might be applied to another model instead of this one."
  1607. " To avoid this, initialize the task externally by calling `Task.init`"
  1608. " before the `ClearMLCallback` is instantiated."
  1609. )
  1610. _train_run_counter = 0
  1611. _model_connect_counter = 0
  1612. _task_created_in_callback = False
  1613. _should_close_on_train_end = None
  1614. def __init__(self):
  1615. if is_clearml_available():
  1616. import clearml
  1617. self._clearml = clearml
  1618. else:
  1619. raise RuntimeError("ClearMLCallback requires 'clearml' to be installed. Run `pip install clearml`.")
  1620. self._initialized = False
  1621. self._clearml_task = None
  1622. self._log_model = False
  1623. self._checkpoints_saved = []
  1624. def setup(self, args, state, model, processing_class, **kwargs):
  1625. if self._clearml is None:
  1626. return
  1627. if self._initialized:
  1628. return
  1629. ClearMLCallback._train_run_counter += 1
  1630. ClearMLCallback._model_connect_counter += 1
  1631. ClearMLCallback.log_suffix = (
  1632. "" if ClearMLCallback._train_run_counter == 1 else "_" + str(ClearMLCallback._train_run_counter)
  1633. )
  1634. if state.is_world_process_zero:
  1635. logger.info("Automatic ClearML logging enabled.")
  1636. if self._clearml_task is None:
  1637. if ClearMLCallback._should_close_on_train_end is None:
  1638. if not self._clearml.Task.running_locally() or self._clearml.Task.current_task():
  1639. ClearMLCallback._should_close_on_train_end = False
  1640. else:
  1641. ClearMLCallback._should_close_on_train_end = True
  1642. # This might happen when running inside of a pipeline, where the task is already initialized
  1643. # from outside of Hugging Face
  1644. if self._clearml.Task.running_locally() and self._clearml.Task.current_task():
  1645. self._clearml_task = self._clearml.Task.current_task()
  1646. self._log_model = os.getenv(
  1647. "CLEARML_LOG_MODEL",
  1648. "FALSE" if not ClearMLCallback._task_created_in_callback else "TRUE",
  1649. ).upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
  1650. logger.info("External ClearML Task has been connected.")
  1651. else:
  1652. self._clearml_task = self._clearml.Task.init(
  1653. project_name=os.getenv("CLEARML_PROJECT", "HuggingFace Transformers"),
  1654. task_name=os.getenv("CLEARML_TASK", "Trainer"),
  1655. auto_connect_frameworks={"tensorboard": False, "pytorch": False},
  1656. output_uri=True,
  1657. )
  1658. self._log_model = os.getenv("CLEARML_LOG_MODEL", "TRUE").upper() in ENV_VARS_TRUE_VALUES.union(
  1659. {"TRUE"}
  1660. )
  1661. ClearMLCallback._task_created_in_callback = True
  1662. logger.info("ClearML Task has been initialized.")
  1663. self._initialized = True
  1664. suffixed_hparams_section = ClearMLCallback._hparams_section + ClearMLCallback.log_suffix
  1665. ignore_hparams_config_section = suffixed_hparams_section + "/" + ClearMLCallback._ignore_hparams_overrides
  1666. if self._clearml.Task.running_locally():
  1667. self._copy_training_args_as_hparams(args, suffixed_hparams_section)
  1668. self._clearml_task.set_parameter(
  1669. name=ignore_hparams_config_section,
  1670. value=True,
  1671. value_type=bool,
  1672. description=(
  1673. "If True, ignore Transformers hyperparameters overrides done in the UI/backend "
  1674. + "when running remotely. Otherwise, the overrides will be applied when running remotely"
  1675. ),
  1676. )
  1677. elif not self._clearml_task.get_parameter(ignore_hparams_config_section, default=True, cast=True):
  1678. self._clearml_task.connect(args, suffixed_hparams_section)
  1679. else:
  1680. self._copy_training_args_as_hparams(
  1681. args, ClearMLCallback._hparams_section + ClearMLCallback.log_suffix
  1682. )
  1683. if getattr(model, "config", None) is not None:
  1684. ignore_model_config_section = (
  1685. suffixed_hparams_section + "/" + ClearMLCallback._ignoge_model_config_overrides
  1686. )
  1687. configuration_object_description = ClearMLCallback._model_config_description.format(
  1688. ClearMLCallback._model_connect_counter
  1689. )
  1690. if ClearMLCallback._model_connect_counter != ClearMLCallback._train_run_counter:
  1691. configuration_object_description += " " + ClearMLCallback._model_config_description_note
  1692. if self._clearml.Task.running_locally():
  1693. self._clearml_task.set_parameter(
  1694. name=ignore_model_config_section,
  1695. value=True,
  1696. value_type=bool,
  1697. description=(
  1698. "If True, ignore Transformers model configuration overrides done in the UI/backend "
  1699. + "when running remotely. Otherwise, the overrides will be applied when running remotely"
  1700. ),
  1701. )
  1702. self._clearml_task.set_configuration_object(
  1703. name=ClearMLCallback._model_config_section + ClearMLCallback.log_suffix,
  1704. config_dict=model.config.to_dict(),
  1705. description=configuration_object_description,
  1706. )
  1707. elif not self._clearml_task.get_parameter(ignore_model_config_section, default=True, cast=True):
  1708. model.config = model.config.from_dict(
  1709. self._clearml_task.get_configuration_object_as_dict(
  1710. ClearMLCallback._model_config_section + ClearMLCallback.log_suffix
  1711. )
  1712. )
  1713. else:
  1714. self._clearml_task.set_configuration_object(
  1715. name=ClearMLCallback._model_config_section + ClearMLCallback.log_suffix,
  1716. config_dict=model.config.to_dict(),
  1717. description=configuration_object_description,
  1718. )
  1719. def on_train_begin(self, args, state, control, model=None, processing_class=None, **kwargs):
  1720. if self._clearml is None:
  1721. return
  1722. self._checkpoints_saved = []
  1723. if state.is_hyper_param_search:
  1724. self._initialized = False
  1725. if not self._initialized:
  1726. self.setup(args, state, model, processing_class, **kwargs)
  1727. def on_train_end(self, args, state, control, **kwargs):
  1728. if ClearMLCallback._should_close_on_train_end:
  1729. self._clearml_task.close()
  1730. ClearMLCallback._train_run_counter = 0
  1731. def on_log(self, args, state, control, model=None, processing_class=None, logs=None, **kwargs):
  1732. if self._clearml is None:
  1733. return
  1734. if not self._initialized:
  1735. self.setup(args, state, model, processing_class, **kwargs)
  1736. if state.is_world_process_zero:
  1737. eval_prefix = "eval_"
  1738. eval_prefix_len = len(eval_prefix)
  1739. test_prefix = "test_"
  1740. test_prefix_len = len(test_prefix)
  1741. single_value_scalars = [
  1742. "train_runtime",
  1743. "train_samples_per_second",
  1744. "train_steps_per_second",
  1745. "train_loss",
  1746. "total_flos",
  1747. "epoch",
  1748. ]
  1749. for k, v in logs.items():
  1750. if isinstance(v, (int, float)):
  1751. if k in single_value_scalars:
  1752. self._clearml_task.get_logger().report_single_value(
  1753. name=k + ClearMLCallback.log_suffix, value=v
  1754. )
  1755. elif k.startswith(eval_prefix):
  1756. self._clearml_task.get_logger().report_scalar(
  1757. title="eval" + ClearMLCallback.log_suffix,
  1758. series=k[eval_prefix_len:],
  1759. value=v,
  1760. iteration=state.global_step,
  1761. )
  1762. elif k.startswith(test_prefix):
  1763. self._clearml_task.get_logger().report_scalar(
  1764. title="test" + ClearMLCallback.log_suffix,
  1765. series=k[test_prefix_len:],
  1766. value=v,
  1767. iteration=state.global_step,
  1768. )
  1769. else:
  1770. self._clearml_task.get_logger().report_scalar(
  1771. title="train" + ClearMLCallback.log_suffix,
  1772. series=k,
  1773. value=v,
  1774. iteration=state.global_step,
  1775. )
  1776. else:
  1777. logger.warning(
  1778. "Trainer is attempting to log a value of "
  1779. f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
  1780. "This invocation of ClearML logger's report_scalar() "
  1781. "is incorrect so we dropped this attribute."
  1782. )
  1783. def on_save(self, args, state, control, **kwargs):
  1784. if self._log_model and self._clearml_task and state.is_world_process_zero:
  1785. ckpt_dir = f"checkpoint-{state.global_step}"
  1786. artifact_path = os.path.join(args.output_dir, ckpt_dir)
  1787. name = ckpt_dir + ClearMLCallback.log_suffix
  1788. logger.info(f"Logging checkpoint artifact `{name}`. This may take some time.")
  1789. output_model = self._clearml.OutputModel(task=self._clearml_task, name=name)
  1790. output_model.connect(task=self._clearml_task, name=name)
  1791. output_model.update_weights_package(
  1792. weights_path=artifact_path,
  1793. target_filename=ckpt_dir,
  1794. iteration=state.global_step,
  1795. auto_delete_file=False,
  1796. )
  1797. self._checkpoints_saved.append(output_model)
  1798. while args.save_total_limit and args.save_total_limit < len(self._checkpoints_saved):
  1799. try:
  1800. self._clearml.model.Model.remove(
  1801. self._checkpoints_saved[0],
  1802. delete_weights_file=True,
  1803. force=True,
  1804. raise_on_errors=True,
  1805. )
  1806. except Exception as e:
  1807. logger.warning(
  1808. f"Could not remove checkpoint `{self._checkpoints_saved[0].name}` after going over the `save_total_limit`. Error is: {e}"
  1809. )
  1810. break
  1811. self._checkpoints_saved = self._checkpoints_saved[1:]
  1812. def _copy_training_args_as_hparams(self, training_args, prefix):
  1813. as_dict = {
  1814. field.name: getattr(training_args, field.name)
  1815. for field in fields(training_args)
  1816. if field.init and not field.name.endswith("_token")
  1817. }
  1818. flat_dict = {str(k): v for k, v in self._clearml.utilities.proxy_object.flatten_dictionary(as_dict).items()}
  1819. self._clearml_task._arguments.copy_from_dict(flat_dict, prefix=prefix)
  1820. class FlyteCallback(TrainerCallback):
  1821. """A [`TrainerCallback`] that sends the logs to [Flyte](https://flyte.org/).
  1822. NOTE: This callback only works within a Flyte task.
  1823. Args:
  1824. save_log_history (`bool`, *optional*, defaults to `True`):
  1825. When set to True, the training logs are saved as a Flyte Deck.
  1826. sync_checkpoints (`bool`, *optional*, defaults to `True`):
  1827. When set to True, checkpoints are synced with Flyte and can be used to resume training in the case of an
  1828. interruption.
  1829. Example:
  1830. ```python
  1831. # Note: This example skips over some setup steps for brevity.
  1832. from flytekit import current_context, task
  1833. @task
  1834. def train_hf_transformer():
  1835. cp = current_context().checkpoint
  1836. trainer = Trainer(..., callbacks=[FlyteCallback()])
  1837. output = trainer.train(resume_from_checkpoint=cp.restore())
  1838. ```
  1839. """
  1840. def __init__(self, save_log_history: bool = True, sync_checkpoints: bool = True):
  1841. super().__init__()
  1842. if not is_flytekit_available():
  1843. raise ImportError("FlyteCallback requires flytekit to be installed. Run `pip install flytekit`.")
  1844. if not is_flyte_deck_standard_available() or not is_pandas_available():
  1845. logger.warning(
  1846. "Syncing log history requires both flytekitplugins-deck-standard and pandas to be installed. "
  1847. "Run `pip install flytekitplugins-deck-standard pandas` to enable this feature."
  1848. )
  1849. save_log_history = False
  1850. from flytekit import current_context
  1851. self.cp = current_context().checkpoint
  1852. self.save_log_history = save_log_history
  1853. self.sync_checkpoints = sync_checkpoints
  1854. def on_save(self, args, state, control, **kwargs):
  1855. if self.sync_checkpoints and state.is_world_process_zero:
  1856. ckpt_dir = f"checkpoint-{state.global_step}"
  1857. artifact_path = os.path.join(args.output_dir, ckpt_dir)
  1858. logger.info(f"Syncing checkpoint in {ckpt_dir} to Flyte. This may take time.")
  1859. self.cp.save(artifact_path)
  1860. def on_train_end(self, args, state, control, **kwargs):
  1861. if self.save_log_history:
  1862. import pandas as pd
  1863. from flytekit import Deck
  1864. from flytekitplugins.deck.renderer import TableRenderer
  1865. log_history_df = pd.DataFrame(state.log_history)
  1866. Deck("Log History", TableRenderer().to_html(log_history_df))
  1867. class DVCLiveCallback(TrainerCallback):
  1868. """
  1869. A [`TrainerCallback`] that sends the logs to [DVCLive](https://www.dvc.org/doc/dvclive).
  1870. Use the environment variables below in `setup` to configure the integration. To customize this callback beyond
  1871. those environment variables, see [here](https://dvc.org/doc/dvclive/ml-frameworks/huggingface).
  1872. Args:
  1873. live (`dvclive.Live`, *optional*, defaults to `None`):
  1874. Optional Live instance. If None, a new instance will be created using **kwargs.
  1875. log_model (Union[Literal["all"], bool], *optional*, defaults to `None`):
  1876. Whether to use `dvclive.Live.log_artifact()` to log checkpoints created by [`Trainer`]. If set to `True`,
  1877. the final checkpoint is logged at the end of training. If set to `"all"`, the entire
  1878. [`TrainingArguments`]'s `output_dir` is logged at each checkpoint.
  1879. """
  1880. def __init__(
  1881. self,
  1882. live: Optional[Any] = None,
  1883. log_model: Optional[Union[Literal["all"], bool]] = None,
  1884. **kwargs,
  1885. ):
  1886. if not is_dvclive_available():
  1887. raise RuntimeError("DVCLiveCallback requires dvclive to be installed. Run `pip install dvclive`.")
  1888. from dvclive import Live
  1889. self._initialized = False
  1890. self.live = None
  1891. if isinstance(live, Live):
  1892. self.live = live
  1893. elif live is not None:
  1894. raise RuntimeError(f"Found class {live.__class__} for live, expected dvclive.Live")
  1895. self._log_model = log_model
  1896. if self._log_model is None:
  1897. log_model_env = os.getenv("HF_DVCLIVE_LOG_MODEL", "FALSE")
  1898. if log_model_env.upper() in ENV_VARS_TRUE_VALUES:
  1899. self._log_model = True
  1900. elif log_model_env.lower() == "all":
  1901. self._log_model = "all"
  1902. def setup(self, args, state, model):
  1903. """
  1904. Setup the optional DVCLive integration. To customize this callback beyond the environment variables below, see
  1905. [here](https://dvc.org/doc/dvclive/ml-frameworks/huggingface).
  1906. Environment:
  1907. - **HF_DVCLIVE_LOG_MODEL** (`str`, *optional*):
  1908. Whether to use `dvclive.Live.log_artifact()` to log checkpoints created by [`Trainer`]. If set to `True` or
  1909. *1*, the final checkpoint is logged at the end of training. If set to `all`, the entire
  1910. [`TrainingArguments`]'s `output_dir` is logged at each checkpoint.
  1911. """
  1912. from dvclive import Live
  1913. self._initialized = True
  1914. if state.is_world_process_zero:
  1915. if not self.live:
  1916. self.live = Live()
  1917. self.live.log_params(args.to_dict())
  1918. def on_train_begin(self, args, state, control, model=None, **kwargs):
  1919. if not self._initialized:
  1920. self.setup(args, state, model)
  1921. def on_log(self, args, state, control, model=None, logs=None, **kwargs):
  1922. if not self._initialized:
  1923. self.setup(args, state, model)
  1924. if state.is_world_process_zero:
  1925. from dvclive.plots import Metric
  1926. from dvclive.utils import standardize_metric_name
  1927. for key, value in logs.items():
  1928. if Metric.could_log(value):
  1929. self.live.log_metric(standardize_metric_name(key, "dvclive.huggingface"), value)
  1930. else:
  1931. logger.warning(
  1932. "Trainer is attempting to log a value of "
  1933. f'"{value}" of type {type(value)} for key "{key}" as a scalar. '
  1934. "This invocation of DVCLive's Live.log_metric() "
  1935. "is incorrect so we dropped this attribute."
  1936. )
  1937. self.live.next_step()
  1938. def on_save(self, args, state, control, **kwargs):
  1939. if self._log_model == "all" and self._initialized and state.is_world_process_zero:
  1940. self.live.log_artifact(args.output_dir)
  1941. def on_train_end(self, args, state, control, **kwargs):
  1942. if self._initialized and state.is_world_process_zero:
  1943. from transformers.trainer import Trainer
  1944. if self._log_model is True:
  1945. fake_trainer = Trainer(
  1946. args=args,
  1947. model=kwargs.get("model"),
  1948. processing_class=kwargs.get("processing_class"),
  1949. eval_dataset=["fake"],
  1950. )
  1951. name = "best" if args.load_best_model_at_end else "last"
  1952. output_dir = os.path.join(args.output_dir, name)
  1953. fake_trainer.save_model(output_dir)
  1954. self.live.log_artifact(output_dir, name=name, type="model", copy=True)
  1955. self.live.end()
  1956. class SwanLabCallback(TrainerCallback):
  1957. """
  1958. A [`TrainerCallback`] that logs metrics, media, model checkpoints to [SwanLab](https://swanlab.cn/).
  1959. """
  1960. def __init__(self):
  1961. if not is_swanlab_available():
  1962. raise RuntimeError("SwanLabCallback requires swanlab to be installed. Run `pip install swanlab`.")
  1963. import swanlab
  1964. self._swanlab = swanlab
  1965. self._initialized = False
  1966. self._log_model = os.getenv("SWANLAB_LOG_MODEL", None)
  1967. def setup(self, args, state, model, **kwargs):
  1968. """
  1969. Setup the optional SwanLab (*swanlab*) integration.
  1970. One can subclass and override this method to customize the setup if needed. Find more information
  1971. [here](https://docs.swanlab.cn/guide_cloud/integration/integration-huggingface-transformers.html).
  1972. You can also override the following environment variables. Find more information about environment
  1973. variables [here](https://docs.swanlab.cn/en/api/environment-variable.html#environment-variables)
  1974. Environment:
  1975. - **SWANLAB_API_KEY** (`str`, *optional*, defaults to `None`):
  1976. Cloud API Key. During login, this environment variable is checked first. If it doesn't exist, the system
  1977. checks if the user is already logged in. If not, the login process is initiated.
  1978. - If a string is passed to the login interface, this environment variable is ignored.
  1979. - If the user is already logged in, this environment variable takes precedence over locally stored
  1980. login information.
  1981. - **SWANLAB_PROJECT** (`str`, *optional*, defaults to `None`):
  1982. Set this to a custom string to store results in a different project. If not specified, the name of the current
  1983. running directory is used.
  1984. - **SWANLAB_LOG_DIR** (`str`, *optional*, defaults to `swanlog`):
  1985. This environment variable specifies the storage path for log files when running in local mode.
  1986. By default, logs are saved in a folder named swanlog under the working directory.
  1987. - **SWANLAB_MODE** (`Literal["local", "cloud", "disabled"]`, *optional*, defaults to `cloud`):
  1988. SwanLab's parsing mode, which involves callbacks registered by the operator. Currently, there are three modes:
  1989. local, cloud, and disabled. Note: Case-sensitive. Find more information
  1990. [here](https://docs.swanlab.cn/en/api/py-init.html#swanlab-init)
  1991. - **SWANLAB_LOG_MODEL** (`str`, *optional*, defaults to `None`):
  1992. SwanLab does not currently support the save mode functionality.This feature will be available in a future
  1993. release
  1994. - **SWANLAB_WEB_HOST** (`str`, *optional*, defaults to `None`):
  1995. Web address for the SwanLab cloud environment for private version (its free)
  1996. - **SWANLAB_API_HOST** (`str`, *optional*, defaults to `None`):
  1997. API address for the SwanLab cloud environment for private version (its free)
  1998. """
  1999. self._initialized = True
  2000. if state.is_world_process_zero:
  2001. logger.info('Automatic SwanLab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"')
  2002. combined_dict = {**args.to_dict()}
  2003. if hasattr(model, "config") and model.config is not None:
  2004. model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
  2005. combined_dict = {**model_config, **combined_dict}
  2006. if hasattr(model, "peft_config") and model.peft_config is not None:
  2007. peft_config = model.peft_config
  2008. combined_dict = {**{"peft_config": peft_config}, **combined_dict}
  2009. trial_name = state.trial_name
  2010. init_args = {}
  2011. if trial_name is not None and args.run_name is not None:
  2012. init_args["experiment_name"] = f"{args.run_name}-{trial_name}"
  2013. elif args.run_name is not None:
  2014. init_args["experiment_name"] = args.run_name
  2015. elif trial_name is not None:
  2016. init_args["experiment_name"] = trial_name
  2017. init_args["project"] = os.getenv("SWANLAB_PROJECT", None)
  2018. if self._swanlab.get_run() is None:
  2019. self._swanlab.init(
  2020. **init_args,
  2021. )
  2022. # show transformers logo!
  2023. self._swanlab.config["FRAMEWORK"] = "🤗transformers"
  2024. # add config parameters (run may have been created manually)
  2025. self._swanlab.config.update(combined_dict)
  2026. # add number of model parameters to swanlab config
  2027. try:
  2028. self._swanlab.config.update({"model_num_parameters": model.num_parameters()})
  2029. # get peft model parameters
  2030. if type(model).__name__ == "PeftModel" or type(model).__name__ == "PeftMixedModel":
  2031. trainable_params, all_param = model.get_nb_trainable_parameters()
  2032. self._swanlab.config.update({"peft_model_trainable_params": trainable_params})
  2033. self._swanlab.config.update({"peft_model_all_param": all_param})
  2034. except AttributeError:
  2035. logger.info("Could not log the number of model parameters in SwanLab due to an AttributeError.")
  2036. # log the initial model architecture to an artifact
  2037. if self._log_model is not None:
  2038. logger.warning(
  2039. "SwanLab does not currently support the save mode functionality. "
  2040. "This feature will be available in a future release."
  2041. )
  2042. badge_markdown = (
  2043. f'[<img src="https://raw.githubusercontent.com/SwanHubX/assets/main/badge1.svg"'
  2044. f' alt="Visualize in SwanLab" height="28'
  2045. f'0" height="32"/>]({self._swanlab.get_run().public.cloud.experiment_url})'
  2046. )
  2047. modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
  2048. def on_train_begin(self, args, state, control, model=None, **kwargs):
  2049. if not self._initialized:
  2050. self.setup(args, state, model, **kwargs)
  2051. def on_train_end(self, args, state, control, model=None, processing_class=None, **kwargs):
  2052. if self._log_model is not None and self._initialized and state.is_world_process_zero:
  2053. logger.warning(
  2054. "SwanLab does not currently support the save mode functionality. "
  2055. "This feature will be available in a future release."
  2056. )
  2057. def on_log(self, args, state, control, model=None, logs=None, **kwargs):
  2058. single_value_scalars = [
  2059. "train_runtime",
  2060. "train_samples_per_second",
  2061. "train_steps_per_second",
  2062. "train_loss",
  2063. "total_flos",
  2064. ]
  2065. if not self._initialized:
  2066. self.setup(args, state, model)
  2067. if state.is_world_process_zero:
  2068. for k, v in logs.items():
  2069. if k in single_value_scalars:
  2070. self._swanlab.log({f"single_value/{k}": v}, step=state.global_step)
  2071. non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars}
  2072. non_scalar_logs = rewrite_logs(non_scalar_logs)
  2073. self._swanlab.log({**non_scalar_logs, "train/global_step": state.global_step}, step=state.global_step)
  2074. def on_save(self, args, state, control, **kwargs):
  2075. if self._log_model is not None and self._initialized and state.is_world_process_zero:
  2076. logger.warning(
  2077. "SwanLab does not currently support the save mode functionality. "
  2078. "This feature will be available in a future release."
  2079. )
  2080. def on_predict(self, args, state, control, metrics, **kwargs):
  2081. if not self._initialized:
  2082. self.setup(args, state, **kwargs)
  2083. if state.is_world_process_zero:
  2084. metrics = rewrite_logs(metrics)
  2085. self._swanlab.log(metrics)
  2086. INTEGRATION_TO_CALLBACK = {
  2087. "azure_ml": AzureMLCallback,
  2088. "comet_ml": CometCallback,
  2089. "mlflow": MLflowCallback,
  2090. "neptune": NeptuneCallback,
  2091. "tensorboard": TensorBoardCallback,
  2092. "trackio": TrackioCallback,
  2093. "wandb": WandbCallback,
  2094. "codecarbon": CodeCarbonCallback,
  2095. "clearml": ClearMLCallback,
  2096. "dagshub": DagsHubCallback,
  2097. "flyte": FlyteCallback,
  2098. "dvclive": DVCLiveCallback,
  2099. "swanlab": SwanLabCallback,
  2100. }
  2101. def get_reporting_integration_callbacks(report_to):
  2102. if report_to is None:
  2103. return []
  2104. if isinstance(report_to, str):
  2105. if "none" == report_to:
  2106. return []
  2107. elif "all" == report_to:
  2108. report_to = get_available_reporting_integrations()
  2109. else:
  2110. report_to = [report_to]
  2111. for integration in report_to:
  2112. if integration not in INTEGRATION_TO_CALLBACK:
  2113. raise ValueError(
  2114. f"{integration} is not supported, only {', '.join(INTEGRATION_TO_CALLBACK.keys())} are supported."
  2115. )
  2116. return [INTEGRATION_TO_CALLBACK[integration] for integration in report_to]