trainer_utils.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911
  1. # Copyright 2020-present the HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. PyTorch-independent utilities for the Trainer class.
  16. """
  17. import copy
  18. import functools
  19. import gc
  20. import inspect
  21. import os
  22. import random
  23. import re
  24. import threading
  25. import time
  26. from typing import Any, Callable, NamedTuple, Optional, Union
  27. import numpy as np
  28. from .utils import (
  29. ExplicitEnum,
  30. is_psutil_available,
  31. is_tf_available,
  32. is_torch_available,
  33. is_torch_cuda_available,
  34. is_torch_hpu_available,
  35. is_torch_mlu_available,
  36. is_torch_mps_available,
  37. is_torch_musa_available,
  38. is_torch_npu_available,
  39. is_torch_xla_available,
  40. is_torch_xpu_available,
  41. requires_backends,
  42. )
  43. if is_torch_available():
  44. import torch
  45. def seed_worker(worker_id: int, num_workers: int, rank: int):
  46. """
  47. Helper function to set worker seed during Dataloader initialization.
  48. """
  49. init_seed = torch.initial_seed() % 2**32
  50. worker_seed = num_workers * rank + init_seed
  51. set_seed(worker_seed)
  52. def enable_full_determinism(seed: int, warn_only: bool = False):
  53. """
  54. Helper function for reproducible behavior during distributed training. See
  55. - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
  56. - https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism for tensorflow
  57. """
  58. # set seed first
  59. set_seed(seed)
  60. if is_torch_available():
  61. # Enable PyTorch deterministic mode. This potentially requires either the environment
  62. # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
  63. # depending on the CUDA version, so we set them both here
  64. os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
  65. os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
  66. # The environment variable required to enable deterministic mode on Ascend NPUs.
  67. os.environ["ASCEND_LAUNCH_BLOCKING"] = "1"
  68. os.environ["HCCL_DETERMINISTIC"] = "1"
  69. os.environ["FLASH_ATTENTION_DETERMINISTIC"] = "1"
  70. torch.use_deterministic_algorithms(True, warn_only=warn_only)
  71. # Enable CUDNN deterministic mode
  72. torch.backends.cudnn.deterministic = True
  73. torch.backends.cudnn.benchmark = False
  74. if is_tf_available():
  75. import tensorflow as tf
  76. tf.config.experimental.enable_op_determinism()
  77. def set_seed(seed: int, deterministic: bool = False):
  78. """
  79. Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` and/or `tf` (if installed).
  80. Args:
  81. seed (`int`):
  82. The seed to set.
  83. deterministic (`bool`, *optional*, defaults to `False`):
  84. Whether to use deterministic algorithms where available. Can slow down training.
  85. """
  86. random.seed(seed)
  87. np.random.seed(seed)
  88. if is_torch_available():
  89. torch.manual_seed(seed)
  90. torch.cuda.manual_seed_all(seed)
  91. # ^^ safe to call this function even if cuda is not available
  92. if deterministic:
  93. torch.use_deterministic_algorithms(True)
  94. if is_torch_mlu_available():
  95. torch.mlu.manual_seed_all(seed)
  96. if is_torch_musa_available():
  97. torch.musa.manual_seed_all(seed)
  98. if is_torch_npu_available():
  99. torch.npu.manual_seed_all(seed)
  100. if is_torch_hpu_available():
  101. torch.hpu.manual_seed_all(seed)
  102. if is_torch_xpu_available():
  103. torch.xpu.manual_seed_all(seed)
  104. if is_tf_available():
  105. import tensorflow as tf
  106. tf.random.set_seed(seed)
  107. if deterministic:
  108. tf.config.experimental.enable_op_determinism()
  109. def neftune_post_forward_hook(module, input, output):
  110. """
  111. Implements the NEFTune forward pass for the model using forward hooks. Note this works only for torch.nn.Embedding
  112. layers. This method is slightly adapted from the original source code that can be found here:
  113. https://github.com/neelsjain/NEFTune Simply add it to your model as follows:
  114. ```python
  115. model = ...
  116. model.embed_tokens.neftune_noise_alpha = 0.1
  117. model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
  118. ```
  119. Args:
  120. module (`torch.nn.Module`):
  121. The embedding module where the hook is attached. Note that you need to set `module.neftune_noise_alpha` to
  122. the desired noise alpha value.
  123. input (`torch.Tensor`):
  124. The input tensor to the model.
  125. output (`torch.Tensor`):
  126. The output tensor of the model (i.e. the embeddings).
  127. """
  128. if module.training:
  129. dims = torch.tensor(output.size(1) * output.size(2))
  130. mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
  131. output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
  132. return output
  133. class EvalPrediction:
  134. """
  135. Evaluation output (always contains labels), to be used to compute metrics.
  136. Parameters:
  137. predictions (`np.ndarray`): Predictions of the model.
  138. label_ids (`np.ndarray`): Targets to be matched.
  139. inputs (`np.ndarray`, *optional*): Input data passed to the model.
  140. losses (`np.ndarray`, *optional*): Loss values computed during evaluation.
  141. """
  142. def __init__(
  143. self,
  144. predictions: Union[np.ndarray, tuple[np.ndarray]],
  145. label_ids: Union[np.ndarray, tuple[np.ndarray]],
  146. inputs: Optional[Union[np.ndarray, tuple[np.ndarray]]] = None,
  147. losses: Optional[Union[np.ndarray, tuple[np.ndarray]]] = None,
  148. ):
  149. self.predictions = predictions
  150. self.label_ids = label_ids
  151. self.inputs = inputs
  152. self.losses = losses
  153. self.elements = (self.predictions, self.label_ids)
  154. if self.inputs is not None:
  155. self.elements += (self.inputs,)
  156. if self.losses is not None:
  157. self.elements += (self.losses,)
  158. def __iter__(self):
  159. return iter(self.elements)
  160. def __getitem__(self, idx):
  161. if idx < 0 or idx >= len(self.elements):
  162. raise IndexError("tuple index out of range")
  163. return self.elements[idx]
class EvalLoopOutput(NamedTuple):
    """
    Immutable bundle of results produced by an evaluation loop.

    Fields:
        predictions: model predictions (an array or a tuple of arrays).
        label_ids: targets (array or tuple of arrays), or `None`.
        metrics: mapping of metric name to value, or `None`.
        num_samples: sample count (presumably the number of evaluated samples — confirm with the caller), or `None`.
    """

    predictions: Union[np.ndarray, tuple[np.ndarray]]
    label_ids: Optional[Union[np.ndarray, tuple[np.ndarray]]]
    metrics: Optional[dict[str, float]]
    num_samples: Optional[int]
class PredictionOutput(NamedTuple):
    """
    Immutable bundle of prediction results.

    Fields:
        predictions: model predictions (an array or a tuple of arrays).
        label_ids: targets (array or tuple of arrays), or `None` when not available.
        metrics: mapping of metric name to value, or `None`.
    """

    predictions: Union[np.ndarray, tuple[np.ndarray]]
    label_ids: Optional[Union[np.ndarray, tuple[np.ndarray]]]
    metrics: Optional[dict[str, float]]
class TrainOutput(NamedTuple):
    """
    Immutable summary of a training run.

    Fields:
        global_step: step counter (presumably the number of optimization steps performed — confirm with the caller).
        training_loss: training loss value as a float.
        metrics: mapping of metric name to value.
    """

    global_step: int
    training_loss: float
    metrics: dict[str, float]
# Name prefix of checkpoint sub-directories; `get_last_checkpoint` looks for folders named `checkpoint-<number>`.
PREFIX_CHECKPOINT_DIR = "checkpoint"
# Matches an entire directory name of the form `checkpoint-<digits>` and captures the digits as group 1.
_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
  179. def get_last_checkpoint(folder):
  180. content = os.listdir(folder)
  181. checkpoints = [
  182. path
  183. for path in content
  184. if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))
  185. ]
  186. if len(checkpoints) == 0:
  187. return
  188. return os.path.join(folder, max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
class IntervalStrategy(ExplicitEnum):
    """
    Allowed string values `"no"`, `"steps"` and `"epoch"` for an interval-type strategy option.

    NOTE(review): presumably selects whether a periodic action is disabled, runs every N steps, or runs at each
    epoch end — confirm against `TrainingArguments`.
    """

    NO = "no"
    STEPS = "steps"
    EPOCH = "epoch"
class SaveStrategy(ExplicitEnum):
    """
    Allowed string values for the checkpoint-saving strategy: `"no"`, `"steps"`, `"epoch"` and `"best"`.

    NOTE(review): `"best"` presumably saves when a new best metric is reached — confirm against `TrainingArguments`.
    """

    NO = "no"
    STEPS = "steps"
    EPOCH = "epoch"
    BEST = "best"
class EvaluationStrategy(ExplicitEnum):
    """
    Allowed string values for the evaluation strategy; the members and values are identical to `IntervalStrategy`.
    """

    NO = "no"
    STEPS = "steps"
    EPOCH = "epoch"
class HubStrategy(ExplicitEnum):
    """
    Allowed string values for the Hub push strategy: `"end"`, `"every_save"`, `"checkpoint"`, `"all_checkpoints"`.

    NOTE(review): the exact push behavior of each value is implemented elsewhere (presumably in `Trainer`) —
    confirm there before relying on it.
    """

    END = "end"
    EVERY_SAVE = "every_save"
    CHECKPOINT = "checkpoint"
    ALL_CHECKPOINTS = "all_checkpoints"
class BestRun(NamedTuple):
    """
    The best run found by a hyperparameter search (see [`~Trainer.hyperparameter_search`]).

    Parameters:
        run_id (`str`):
            The id of the best run (if models were saved, the corresponding checkpoint will be in the folder ending
            with run-{run_id}).
        objective (`float` or `list[float]`):
            The objective that was obtained for this run (a list when several objectives were optimized).
        hyperparameters (`dict[str, Any]`):
            The hyperparameters picked to get this run.
        run_summary (`Optional[Any]`, defaults to `None`):
            A summary of tuning experiments. `ray.tune.ExperimentAnalysis` object for Ray backend.
    """

    run_id: str
    objective: Union[float, list[float]]
    hyperparameters: dict[str, Any]
    run_summary: Optional[Any] = None
  225. def default_compute_objective(metrics: dict[str, float]) -> float:
  226. """
  227. The default objective to maximize/minimize when doing an hyperparameter search. It is the evaluation loss if no
  228. metrics are provided to the [`Trainer`], the sum of all metrics otherwise.
  229. Args:
  230. metrics (`dict[str, float]`): The metrics returned by the evaluate method.
  231. Return:
  232. `float`: The objective to minimize or maximize
  233. """
  234. metrics = copy.deepcopy(metrics)
  235. loss = metrics.pop("eval_loss", None)
  236. _ = metrics.pop("epoch", None)
  237. # Remove speed metrics
  238. speed_metrics = [
  239. m for m in metrics if m.endswith("_runtime") or m.endswith("_per_second") or m.endswith("_compilation_time")
  240. ]
  241. for sm in speed_metrics:
  242. _ = metrics.pop(sm, None)
  243. return loss if len(metrics) == 0 else sum(metrics.values())
  244. def default_hp_space_optuna(trial) -> dict[str, float]:
  245. from .integrations import is_optuna_available
  246. assert is_optuna_available(), "This function needs Optuna installed: `pip install optuna`"
  247. return {
  248. "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
  249. "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
  250. "seed": trial.suggest_int("seed", 1, 40),
  251. "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16, 32, 64]),
  252. }
  253. def default_hp_space_ray(trial) -> dict[str, Any]:
  254. from .integrations import is_ray_tune_available
  255. assert is_ray_tune_available(), "This function needs ray installed: `pip install ray[tune]`"
  256. from ray import tune
  257. return {
  258. "learning_rate": tune.loguniform(1e-6, 1e-4),
  259. "num_train_epochs": tune.choice(list(range(1, 6))),
  260. "seed": tune.uniform(1, 40),
  261. "per_device_train_batch_size": tune.choice([4, 8, 16, 32, 64]),
  262. }
  263. def default_hp_space_sigopt(trial):
  264. return [
  265. {"bounds": {"min": 1e-6, "max": 1e-4}, "name": "learning_rate", "type": "double", "transformation": "log"},
  266. {"bounds": {"min": 1, "max": 6}, "name": "num_train_epochs", "type": "int"},
  267. {"bounds": {"min": 1, "max": 40}, "name": "seed", "type": "int"},
  268. {
  269. "categorical_values": ["4", "8", "16", "32", "64"],
  270. "name": "per_device_train_batch_size",
  271. "type": "categorical",
  272. },
  273. ]
  274. def default_hp_space_wandb(trial) -> dict[str, Any]:
  275. from .integrations import is_wandb_available
  276. if not is_wandb_available():
  277. raise ImportError("This function needs wandb installed: `pip install wandb`")
  278. return {
  279. "method": "random",
  280. "metric": {"name": "objective", "goal": "minimize"},
  281. "parameters": {
  282. "learning_rate": {"distribution": "uniform", "min": 1e-6, "max": 1e-4},
  283. "num_train_epochs": {"distribution": "int_uniform", "min": 1, "max": 6},
  284. "seed": {"distribution": "int_uniform", "min": 1, "max": 40},
  285. "per_device_train_batch_size": {"values": [4, 8, 16, 32, 64]},
  286. },
  287. }
class HPSearchBackend(ExplicitEnum):
    """
    Names of the supported hyperparameter-search backends.

    Each value has a matching default search space in this module (`default_hp_space_optuna`,
    `default_hp_space_ray`, `default_hp_space_sigopt`, `default_hp_space_wandb`).
    """

    OPTUNA = "optuna"
    RAY = "ray"
    SIGOPT = "sigopt"
    WANDB = "wandb"
  293. def is_main_process(local_rank):
  294. """
  295. Whether or not the current process is the local process, based on `xr.global_ordinal()` (for TPUs) first, then on
  296. `local_rank`.
  297. """
  298. if is_torch_xla_available():
  299. import torch_xla.runtime as xr
  300. return xr.global_ordinal() == 0
  301. return local_rank in [-1, 0]
  302. def total_processes_number(local_rank):
  303. """
  304. Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
  305. """
  306. if is_torch_xla_available():
  307. import torch_xla.runtime as xr
  308. return xr.world_size()
  309. elif local_rank != -1 and is_torch_available():
  310. import torch
  311. return torch.distributed.get_world_size()
  312. return 1
  313. def speed_metrics(split, start_time, num_samples=None, num_steps=None, num_tokens=None):
  314. """
  315. Measure and return speed performance metrics.
  316. This function requires a time snapshot `start_time` before the operation to be measured starts and this function
  317. should be run immediately after the operation to be measured has completed.
  318. Args:
  319. - split: name to prefix metric (like train, eval, test...)
  320. - start_time: operation start time
  321. - num_samples: number of samples processed
  322. - num_steps: number of steps processed
  323. - num_tokens: number of tokens processed
  324. """
  325. runtime = time.time() - start_time
  326. result = {f"{split}_runtime": round(runtime, 4)}
  327. if runtime == 0:
  328. return result
  329. if num_samples is not None:
  330. samples_per_second = num_samples / runtime
  331. result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
  332. if num_steps is not None:
  333. steps_per_second = num_steps / runtime
  334. result[f"{split}_steps_per_second"] = round(steps_per_second, 3)
  335. if num_tokens is not None:
  336. tokens_per_second = num_tokens / runtime
  337. result[f"{split}_tokens_per_second"] = round(tokens_per_second, 3)
  338. return result
class SchedulerType(ExplicitEnum):
    """
    Scheduler names for the parameter `lr_scheduler_type` in [`TrainingArguments`].
    By default, it uses "linear". Internally, this retrieves `get_linear_schedule_with_warmup` scheduler from [`Trainer`].

    Scheduler types:
       - "linear" = [`get_linear_schedule_with_warmup`]
       - "cosine" = [`get_cosine_schedule_with_warmup`]
       - "cosine_with_restarts" = [`get_cosine_with_hard_restarts_schedule_with_warmup`]
       - "polynomial" = [`get_polynomial_decay_schedule_with_warmup`]
       - "constant" = [`get_constant_schedule`]
       - "constant_with_warmup" = [`get_constant_schedule_with_warmup`]
       - "inverse_sqrt" = [`get_inverse_sqrt_schedule`]
       - "reduce_lr_on_plateau" = [`get_reduce_on_plateau_schedule`]
       - "cosine_with_min_lr" = [`get_cosine_with_min_lr_schedule_with_warmup`]
       - "cosine_warmup_with_min_lr" = [`get_cosine_with_min_lr_schedule_with_warmup_lr_rate`]
       - "warmup_stable_decay" = [`get_wsd_schedule`]
    """

    # Member values are the strings accepted for `lr_scheduler_type`; note that the member
    # REDUCE_ON_PLATEAU maps to the string "reduce_lr_on_plateau".
    LINEAR = "linear"
    COSINE = "cosine"
    COSINE_WITH_RESTARTS = "cosine_with_restarts"
    POLYNOMIAL = "polynomial"
    CONSTANT = "constant"
    CONSTANT_WITH_WARMUP = "constant_with_warmup"
    INVERSE_SQRT = "inverse_sqrt"
    REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"
    COSINE_WITH_MIN_LR = "cosine_with_min_lr"
    COSINE_WARMUP_WITH_MIN_LR = "cosine_warmup_with_min_lr"
    WARMUP_STABLE_DECAY = "warmup_stable_decay"
  367. class TrainerMemoryTracker:
  368. """
  369. A helper class that tracks cpu and gpu memory.
  370. This class will silently skip unless `psutil` is available. Install with `pip install psutil`.
  371. When a stage completes, it can pass metrics dict to update with the memory metrics gathered during this stage.
  372. Example :
  373. ```python
  374. self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
  375. self._memory_tracker.start()
  376. # code ...
  377. metrics = {"train_runtime": 10.5}
  378. self._memory_tracker.stop_and_update_metrics(metrics)
  379. ```
  380. At the moment GPU tracking is only for `pytorch`, but can be extended to support `tensorflow`.
  381. To understand this class' intricacies please read the documentation of [`~Trainer.log_metrics`].
  382. """
  383. # map trainer methods to metrics prefix
  384. stages = {
  385. "__init__": "init",
  386. "train": "train",
  387. "_inner_training_loop": "train",
  388. "evaluate": "eval",
  389. "predict": "test",
  390. }
  391. def __init__(self, skip_memory_metrics=False):
  392. self.skip_memory_metrics = skip_memory_metrics
  393. if not is_psutil_available():
  394. # soft dependency on psutil
  395. self.skip_memory_metrics = True
  396. if self.skip_memory_metrics:
  397. return
  398. import psutil
  399. if is_torch_cuda_available() or is_torch_mlu_available() or is_torch_musa_available():
  400. import torch
  401. self.torch = torch
  402. self.gpu = {}
  403. elif is_torch_mps_available():
  404. import torch
  405. self.torch = torch
  406. self.gpu = {}
  407. elif is_torch_xpu_available():
  408. import torch
  409. self.torch = torch
  410. self.gpu = {}
  411. elif is_torch_npu_available():
  412. import torch
  413. self.torch = torch
  414. self.gpu = {}
  415. elif is_torch_hpu_available():
  416. import torch
  417. self.torch = torch
  418. self.gpu = {}
  419. else:
  420. self.torch = None
  421. self.process = psutil.Process()
  422. self.cur_stage = None
  423. self.cpu = {}
  424. self.init_reported = False
  425. def derive_stage(self):
  426. """derives the stage/caller name automatically"""
  427. caller = inspect.currentframe().f_back.f_back.f_code.co_name
  428. if caller in self.stages:
  429. return self.stages[caller]
  430. else:
  431. raise ValueError(
  432. f"was called from {caller}, but only expect to be called from one of {self.stages.keys()}"
  433. )
  434. def cpu_mem_used(self):
  435. """get resident set size memory for the current process"""
  436. return self.process.memory_info().rss
  437. def peak_monitor_func(self):
  438. self.cpu_mem_used_peak = -1
  439. while True:
  440. self.cpu_mem_used_peak = max(self.cpu_mem_used(), self.cpu_mem_used_peak)
  441. # can't sleep or will not catch the peak right (this comment is here on purpose)
  442. # time.sleep(0.001) # 1msec
  443. if not self.peak_monitoring:
  444. break
  445. def start(self):
  446. """start tracking for the caller's stage"""
  447. if self.skip_memory_metrics:
  448. return
  449. stage = self.derive_stage()
  450. # deal with nested calls of eval during train - simply ignore those
  451. if self.cur_stage is not None and self.cur_stage != stage:
  452. return
  453. self.cur_stage = stage
  454. gc.collect()
  455. if self.torch is not None:
  456. if torch.cuda.is_available():
  457. self.torch.cuda.reset_peak_memory_stats()
  458. self.torch.cuda.empty_cache()
  459. elif is_torch_mlu_available():
  460. self.torch.mlu.reset_peak_memory_stats()
  461. self.torch.mlu.empty_cache()
  462. elif is_torch_musa_available():
  463. self.torch.musa.reset_peak_memory_stats()
  464. self.torch.musa.empty_cache()
  465. elif is_torch_xpu_available():
  466. self.torch.xpu.reset_peak_memory_stats()
  467. self.torch.xpu.empty_cache()
  468. elif is_torch_npu_available():
  469. self.torch.npu.reset_peak_memory_stats()
  470. self.torch.npu.empty_cache()
  471. elif is_torch_hpu_available():
  472. self.torch.hpu.reset_peak_memory_stats()
  473. # not available on hpu as it reserves all device memory for the current process
  474. # self.torch.hpu.empty_cache()
  475. elif is_torch_mps_available():
  476. self.torch.mps.empty_cache()
  477. # gpu
  478. if self.torch is not None:
  479. if torch.cuda.is_available():
  480. self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated()
  481. elif is_torch_mlu_available():
  482. self.gpu_mem_used_at_start = self.torch.mlu.memory_allocated()
  483. elif is_torch_musa_available():
  484. self.gpu_mem_used_at_start = self.torch.musa.memory_allocated()
  485. elif is_torch_xpu_available():
  486. self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated()
  487. elif is_torch_npu_available():
  488. self.gpu_mem_used_at_start = self.torch.npu.memory_allocated()
  489. elif is_torch_hpu_available():
  490. self.gpu_mem_used_at_start = self.torch.hpu.memory_allocated()
  491. elif is_torch_mps_available():
  492. self.gpu_mem_used_at_start = self.torch.mps.current_allocated_memory()
  493. # cpu
  494. self.cpu_mem_used_at_start = self.cpu_mem_used()
  495. self.peak_monitoring = True
  496. peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
  497. peak_monitor_thread.daemon = True
  498. peak_monitor_thread.start()
  499. def stop(self, stage):
  500. """stop tracking for the passed stage"""
  501. # deal with nested calls of eval during train - simply ignore those
  502. if self.cur_stage is not None and self.cur_stage != stage:
  503. return
  504. # this sends a signal to peak_monitor_func to complete its loop
  505. self.peak_monitoring = False
  506. # first ensure all objects get collected and their memory is freed
  507. gc.collect()
  508. if self.torch is not None:
  509. if torch.cuda.is_available():
  510. self.torch.cuda.empty_cache()
  511. elif is_torch_mlu_available():
  512. self.torch.mlu.empty_cache()
  513. elif is_torch_musa_available():
  514. self.torch.musa.empty_cache()
  515. elif is_torch_xpu_available():
  516. self.torch.xpu.empty_cache()
  517. elif is_torch_npu_available():
  518. self.torch.npu.empty_cache()
  519. elif is_torch_hpu_available():
  520. # not available on hpu as it reserves all device memory for the current process
  521. # self.torch.npu.empty_cache()
  522. pass
  523. elif is_torch_mps_available():
  524. self.torch.mps.empty_cache()
  525. # concepts:
  526. # - alloc_delta: the difference of allocated memory between the end and the start
  527. # - peaked_delta: the difference between the peak memory and the current memory
  528. # in order to know how much memory the measured code consumed one needs to sum these two
  529. # gpu
  530. if self.torch is not None:
  531. if torch.cuda.is_available():
  532. self.gpu_mem_used_now = self.torch.cuda.memory_allocated()
  533. self.gpu_mem_used_peak = self.torch.cuda.max_memory_allocated()
  534. elif is_torch_mlu_available():
  535. self.gpu_mem_used_now = self.torch.mlu.memory_allocated()
  536. self.gpu_mem_used_peak = self.torch.mlu.max_memory_allocated()
  537. elif is_torch_musa_available():
  538. self.gpu_mem_used_now = self.torch.musa.memory_allocated()
  539. self.gpu_mem_used_peak = self.torch.musa.max_memory_allocated()
  540. elif is_torch_xpu_available():
  541. self.gpu_mem_used_now = self.torch.xpu.memory_allocated()
  542. self.gpu_mem_used_peak = self.torch.xpu.max_memory_allocated()
  543. elif is_torch_npu_available():
  544. self.gpu_mem_used_now = self.torch.npu.memory_allocated()
  545. self.gpu_mem_used_peak = self.torch.npu.max_memory_allocated()
  546. elif is_torch_hpu_available():
  547. self.gpu_mem_used_now = self.torch.hpu.memory_allocated()
  548. self.gpu_mem_used_peak = self.torch.hpu.max_memory_allocated()
  549. elif is_torch_mps_available():
  550. self.gpu_mem_used_now = self.torch.mps.current_allocated_memory()
  551. # self.torch.mps.max_memory_allocated() does not exist yet
  552. self.gpu_mem_used_peak = None
  553. else:
  554. raise ValueError("No available GPU device found!")
  555. self.gpu[self.cur_stage] = {
  556. "begin": self.gpu_mem_used_at_start,
  557. "end": self.gpu_mem_used_now,
  558. "alloc": (self.gpu_mem_used_now - self.gpu_mem_used_at_start),
  559. }
  560. if self.gpu_mem_used_peak is not None:
  561. self.gpu[self.cur_stage]["peaked"] = max(0, self.gpu_mem_used_peak - self.gpu_mem_used_now)
  562. else:
  563. self.gpu[self.cur_stage]["peaked"] = "Not available"
  564. # cpu
  565. self.cpu_mem_used_now = self.cpu_mem_used()
  566. self.cpu[self.cur_stage] = {
  567. "begin": self.cpu_mem_used_at_start,
  568. "end": self.cpu_mem_used_now,
  569. "alloc": (self.cpu_mem_used_now - self.cpu_mem_used_at_start),
  570. "peaked": max(0, self.cpu_mem_used_peak - self.cpu_mem_used_now),
  571. }
  572. # reset - cycle finished
  573. self.cur_stage = None
  574. def update_metrics(self, stage, metrics):
  575. """updates the metrics"""
  576. if self.skip_memory_metrics:
  577. return
  578. # deal with nested calls of eval during train - simply ignore those
  579. if self.cur_stage is not None and self.cur_stage != stage:
  580. return
  581. # since we don't have a way to return init metrics, we push them into the first of train/val/predict
  582. stages = [stage]
  583. if not self.init_reported:
  584. stages.insert(0, "init")
  585. self.init_reported = True
  586. for stage in stages:
  587. for t in ["alloc", "peaked"]:
  588. if stage in self.cpu and t in self.cpu[stage]:
  589. metrics[f"{stage}_mem_cpu_{t}_delta"] = self.cpu[stage][t]
  590. if self.torch is not None and stage in self.gpu and t in self.gpu[stage]:
  591. metrics[f"{stage}_mem_gpu_{t}_delta"] = self.gpu[stage][t]
  592. # if we need additional debug info, enable the following
  593. # for t in ["begin", "end"]:
  594. # if stage in self.cpu and t in self.cpu[stage]:
  595. # metrics[f"{stage}_mem_cpu_{t}"] = self.cpu[stage][t]
  596. # if self.torch is not None and stage in self.gpu and t in self.gpu[stage]:
  597. # metrics[f"{stage}_mem_gpu_{t}"] = self.gpu[stage][t]
  598. # since memory can be allocated before init, and it might be difficult to track overall
  599. # memory usage, in particular for GPU, let's report memory usage at the point init was called
  600. if stages[0] == "init":
  601. metrics["before_init_mem_cpu"] = self.cpu["init"]["begin"]
  602. if self.torch is not None:
  603. metrics["before_init_mem_gpu"] = self.gpu["init"]["begin"]
  604. # if we also wanted to report any additional memory allocations in between init and
  605. # whatever the next stage was we could also report this:
  606. # if self.cpu["init"]["end"] != self.cpu[stage]["begin"]:
  607. # metrics[f"after_init_mem_cpu_delta"] = self.cpu[stage]["begin"] - self.cpu["init"]["end"]
  608. # if self.torch is not None and self.gpu["init"]["end"] != self.gpu[stage]["begin"]:
  609. # metrics[f"after_init_mem_gpu_delta"] = self.gpu[stage]["begin"] - self.gpu["init"]["end"]
  610. def stop_and_update_metrics(self, metrics=None):
  611. """combine stop and metrics update in one call for simpler code"""
  612. if self.skip_memory_metrics:
  613. return
  614. stage = self.derive_stage()
  615. self.stop(stage)
  616. # init doesn't have metrics to update so we just save that data for later stages to retrieve
  617. if metrics is not None:
  618. self.update_metrics(stage, metrics)
  619. def has_length(dataset):
  620. """
  621. Checks if the dataset implements __len__() and it doesn't raise an error
  622. """
  623. try:
  624. return len(dataset) is not None
  625. except TypeError:
  626. # TypeError: len() of unsized object
  627. return False
  628. except AttributeError:
  629. # Ray DataSets raises an AttributeError: https://github.com/ray-project/ray/blob/master/python/ray/data/dataset.py#L5616
  630. return False
  631. def denumpify_detensorize(metrics):
  632. """
  633. Recursively calls `.item()` on the element of the dictionary passed
  634. """
  635. if isinstance(metrics, (list, tuple)):
  636. return type(metrics)(denumpify_detensorize(m) for m in metrics)
  637. elif isinstance(metrics, dict):
  638. return type(metrics)({k: denumpify_detensorize(v) for k, v in metrics.items()})
  639. elif isinstance(metrics, np.generic):
  640. return metrics.item()
  641. elif is_torch_available() and isinstance(metrics, torch.Tensor) and metrics.numel() == 1:
  642. return metrics.item()
  643. return metrics
  644. def number_of_arguments(func):
  645. """
  646. Return the number of arguments of the passed function, even if it's a partial function.
  647. """
  648. if isinstance(func, functools.partial):
  649. total_args = len(inspect.signature(func.func).parameters)
  650. return total_args - len(func.args) - len(func.keywords)
  651. return len(inspect.signature(func).parameters)
  652. def find_executable_batch_size(
  653. function: Optional[Callable] = None, starting_batch_size: int = 128, auto_find_batch_size: bool = False
  654. ):
  655. """
  656. Args:
  657. A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
  658. CUDNN, the batch size is multiplied by 0.9 and passed to `function`. `function` must take in a `batch_size` parameter as
  659. its first argument.
  660. function (`Callable`, *optional*)
  661. A function to wrap
  662. starting_batch_size (`int`, *optional*)
  663. The batch size to try and fit into memory
  664. auto_find_batch_size (`bool`, *optional*)
  665. If False, will just execute `function`
  666. """
  667. if function is None:
  668. return functools.partial(
  669. find_executable_batch_size,
  670. starting_batch_size=starting_batch_size,
  671. auto_find_batch_size=auto_find_batch_size,
  672. )
  673. if auto_find_batch_size:
  674. requires_backends(find_executable_batch_size, "accelerate")
  675. from accelerate.utils import find_executable_batch_size as accelerate_find_executable_batch_size
  676. return accelerate_find_executable_batch_size(function=function, starting_batch_size=starting_batch_size)
  677. return functools.partial(function, batch_size=starting_batch_size)
class FSDPOption(ExplicitEnum):
    """
    Enumeration of the FSDP option strings. Each member's value is its own name in lowercase.
    """

    # NOTE(review): the sharding members below presumably mirror PyTorch FSDP sharding strategies -- confirm
    # against the code that consumes these options.
    FULL_SHARD = "full_shard"
    SHARD_GRAD_OP = "shard_grad_op"
    NO_SHARD = "no_shard"
    HYBRID_SHARD = "hybrid_shard"
    HYBRID_SHARD_ZERO2 = "hybrid_shard_zero2"
    # NOTE(review): OFFLOAD / AUTO_WRAP look like modifiers combined with a sharding option rather than
    # standalone strategies -- verify against the option parser.
    OFFLOAD = "offload"
    AUTO_WRAP = "auto_wrap"
  686. class RemoveColumnsCollator:
  687. """Wrap the data collator to remove unused columns before they are passed to the collator."""
  688. def __init__(
  689. self,
  690. data_collator,
  691. signature_columns,
  692. logger=None,
  693. model_name: Optional[str] = None,
  694. description: Optional[str] = None,
  695. ):
  696. self.data_collator = data_collator
  697. self.signature_columns = signature_columns
  698. self.logger = logger
  699. self.description = description
  700. self.model_name = model_name
  701. self.message_logged = False
  702. def _remove_columns(self, feature: dict) -> dict:
  703. if not isinstance(feature, dict):
  704. return feature
  705. if not self.message_logged and self.logger and self.model_name:
  706. ignored_columns = list(set(feature.keys()) - set(self.signature_columns))
  707. if len(ignored_columns) > 0:
  708. dset_description = "" if self.description is None else f"in the {self.description} set"
  709. self.logger.info(
  710. f"The following columns {dset_description} don't have a corresponding argument in "
  711. f"`{self.model_name}.forward` and have been ignored: {', '.join(ignored_columns)}."
  712. f" If {', '.join(ignored_columns)} are not expected by `{self.model_name}.forward`, "
  713. " you can safely ignore this message."
  714. )
  715. self.message_logged = True
  716. return {k: v for k, v in feature.items() if k in self.signature_columns}
  717. def __call__(self, features: list[dict]):
  718. features = [self._remove_columns(feature) for feature in features]
  719. return self.data_collator(features)
  720. def check_target_module_exists(optim_target_modules, key: str, return_is_regex: bool = False):
  721. """A helper method to check if the passed module's key name matches any of the target modules in the optim_target_modules.
  722. Args:
  723. optim_target_modules (`Union[str, list[str]]`):
  724. A list of strings to try to match. Can be also a full string.
  725. key (`str`):
  726. A key to search any matches in optim_target_modules
  727. return_is_regex (`bool`):
  728. If set to `True`, the method will return whether the passed `optim_target_modules`
  729. is a regex or not.
  730. Returns:
  731. `bool` : True of match object if key matches any target modules from config, False or
  732. None if no match found
  733. `bool` : If the matched target module is a regex to silence out the warnings in Trainer
  734. for extra modules being found (only if `target_module_found=True` for an array of regex).
  735. """
  736. target_module_found = False
  737. is_regex = False
  738. if isinstance(optim_target_modules, str):
  739. target_module_found = bool(re.fullmatch(optim_target_modules, key))
  740. is_regex = optim_target_modules != key
  741. elif key in optim_target_modules: # from here, target_module_found must be a list of str
  742. # this module is specified directly in target_modules
  743. target_module_found = True
  744. elif any(target_key in key for target_key in optim_target_modules):
  745. target_module_found = True
  746. elif any(bool(re.fullmatch(optim_target_module, key)) for optim_target_module in optim_target_modules):
  747. target_module_found = True
  748. is_regex = True
  749. if return_is_regex:
  750. return target_module_found, is_regex
  751. return target_module_found