trainer_seq2seq.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import contextlib
  15. from copy import deepcopy
  16. from pathlib import Path
  17. from typing import TYPE_CHECKING, Any, Callable, Optional, Union
  18. import torch
  19. from torch import nn
  20. from torch.distributed.fsdp import FullyShardedDataParallel
  21. from torch.utils.data import Dataset
  22. from .generation.configuration_utils import GenerationConfig
  23. from .integrations.deepspeed import is_deepspeed_zero3_enabled
  24. from .integrations.fsdp import is_fsdp_managed_module
  25. from .trainer import Trainer
  26. from .utils import is_datasets_available, logging
  27. from .utils.deprecation import deprecate_kwarg
  28. if is_datasets_available():
  29. import datasets
  30. if TYPE_CHECKING:
  31. from torch.utils.data import IterableDataset
  32. from .data.data_collator import DataCollator
  33. from .feature_extraction_utils import FeatureExtractionMixin
  34. from .image_processing_utils import BaseImageProcessor
  35. from .modeling_utils import PreTrainedModel
  36. from .processing_utils import ProcessorMixin
  37. from .tokenization_utils_base import PreTrainedTokenizerBase
  38. from .trainer_callback import TrainerCallback
  39. from .trainer_utils import EvalPrediction, PredictionOutput
  40. from .training_args import TrainingArguments
# Module-level logger for this file, created through the library's own `logging` utilities (imported from `.utils`).
logger = logging.get_logger(__name__)
  42. class Seq2SeqTrainer(Trainer):
  43. @deprecate_kwarg("tokenizer", new_name="processing_class", version="5.0.0", raise_if_both_names=True)
  44. def __init__(
  45. self,
  46. model: Optional[Union["PreTrainedModel", nn.Module]] = None,
  47. args: Optional["TrainingArguments"] = None,
  48. data_collator: Optional["DataCollator"] = None,
  49. train_dataset: Optional[Union[Dataset, "IterableDataset", "datasets.Dataset"]] = None,
  50. eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
  51. processing_class: Optional[
  52. Union["PreTrainedTokenizerBase", "BaseImageProcessor", "FeatureExtractionMixin", "ProcessorMixin"]
  53. ] = None,
  54. model_init: Optional[Callable[[], "PreTrainedModel"]] = None,
  55. compute_loss_func: Optional[Callable] = None,
  56. compute_metrics: Optional[Callable[["EvalPrediction"], dict]] = None,
  57. callbacks: Optional[list["TrainerCallback"]] = None,
  58. optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
  59. preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
  60. ):
  61. super().__init__(
  62. model=model,
  63. args=args,
  64. data_collator=data_collator,
  65. train_dataset=train_dataset,
  66. eval_dataset=eval_dataset,
  67. processing_class=processing_class,
  68. model_init=model_init,
  69. compute_loss_func=compute_loss_func,
  70. compute_metrics=compute_metrics,
  71. callbacks=callbacks,
  72. optimizers=optimizers,
  73. preprocess_logits_for_metrics=preprocess_logits_for_metrics,
  74. )
  75. # Override self.model.generation_config if a GenerationConfig is specified in args.
  76. # Priority: args.generation_config > model.generation_config > default GenerationConfig.
  77. if self.args.generation_config is not None:
  78. gen_config = self.load_generation_config(self.args.generation_config)
  79. self.model.generation_config = gen_config
  80. @staticmethod
  81. def load_generation_config(gen_config_arg: Union[str, GenerationConfig]) -> GenerationConfig:
  82. """
  83. Loads a `~generation.GenerationConfig` from the `Seq2SeqTrainingArguments.generation_config` arguments.
  84. Args:
  85. gen_config_arg (`str` or [`~generation.GenerationConfig]`):
  86. `Seq2SeqTrainingArguments.generation_config` argument.
  87. Returns:
  88. A `~generation.GenerationConfig`.
  89. """
  90. # GenerationConfig provided, nothing to do
  91. if isinstance(gen_config_arg, GenerationConfig):
  92. gen_config = deepcopy(gen_config_arg)
  93. else:
  94. # str or Path
  95. pretrained_model_name = Path(gen_config_arg) if isinstance(gen_config_arg, str) else gen_config_arg
  96. config_file_name = None
  97. # Figuring if it is path pointing to a file, pointing to a directory or else a model id or URL
  98. # This step is required in order to determine config_file_name
  99. if pretrained_model_name.is_file():
  100. config_file_name = pretrained_model_name.name
  101. pretrained_model_name = pretrained_model_name.parent
  102. # dir path
  103. elif pretrained_model_name.is_dir():
  104. pass
  105. # model id or URL
  106. else:
  107. pretrained_model_name = gen_config_arg
  108. gen_config = GenerationConfig.from_pretrained(pretrained_model_name, config_file_name)
  109. # Strict validation to fail early. `GenerationConfig.save_pretrained()`, run at the end of training, throws
  110. # an exception if there are warnings at validation time.
  111. try:
  112. gen_config.validate(strict=True)
  113. except ValueError as exc:
  114. raise ValueError(str(exc) + "\n\nFix these issues to train your model.")
  115. return gen_config
  116. def evaluate(
  117. self,
  118. eval_dataset: Optional[Dataset] = None,
  119. ignore_keys: Optional[list[str]] = None,
  120. metric_key_prefix: str = "eval",
  121. **gen_kwargs,
  122. ) -> dict[str, float]:
  123. """
  124. Run evaluation and returns metrics.
  125. The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
  126. (pass it to the init `compute_metrics` argument).
  127. You can also subclass and override this method to inject custom behavior.
  128. Args:
  129. eval_dataset (`Dataset`, *optional*):
  130. Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
  131. not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
  132. method.
  133. ignore_keys (`list[str]`, *optional*):
  134. A list of keys in the output of your model (if it is a dictionary) that should be ignored when
  135. gathering predictions.
  136. metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
  137. An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
  138. "eval_bleu" if the prefix is `"eval"` (default)
  139. max_length (`int`, *optional*):
  140. The maximum target length to use when predicting with the generate method.
  141. num_beams (`int`, *optional*):
  142. Number of beams for beam search that will be used when predicting with the generate method. 1 means no
  143. beam search.
  144. gen_kwargs:
  145. Additional `generate` specific kwargs.
  146. Returns:
  147. A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
  148. dictionary also contains the epoch number which comes from the training state.
  149. """
  150. gen_kwargs = gen_kwargs.copy()
  151. # Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
  152. # training args
  153. if (
  154. gen_kwargs.get("max_length") is None
  155. and gen_kwargs.get("max_new_tokens") is None
  156. and self.args.generation_max_length is not None
  157. ):
  158. gen_kwargs["max_length"] = self.args.generation_max_length
  159. if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
  160. gen_kwargs["num_beams"] = self.args.generation_num_beams
  161. # We don't want to drop samples in general
  162. self.gather_function = self.accelerator.gather
  163. self._gen_kwargs = gen_kwargs
  164. return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
  165. def predict(
  166. self,
  167. test_dataset: Dataset,
  168. ignore_keys: Optional[list[str]] = None,
  169. metric_key_prefix: str = "test",
  170. **gen_kwargs,
  171. ) -> "PredictionOutput":
  172. """
  173. Run prediction and returns predictions and potential metrics.
  174. Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
  175. will also return metrics, like in `evaluate()`.
  176. Args:
  177. test_dataset (`Dataset`):
  178. Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the
  179. `model.forward()` method are automatically removed. Has to implement the method `__len__`
  180. ignore_keys (`list[str]`, *optional*):
  181. A list of keys in the output of your model (if it is a dictionary) that should be ignored when
  182. gathering predictions.
  183. metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
  184. An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
  185. "eval_bleu" if the prefix is `"eval"` (default)
  186. max_length (`int`, *optional*):
  187. The maximum target length to use when predicting with the generate method.
  188. num_beams (`int`, *optional*):
  189. Number of beams for beam search that will be used when predicting with the generate method. 1 means no
  190. beam search.
  191. gen_kwargs:
  192. Additional `generate` specific kwargs.
  193. <Tip>
  194. If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
  195. padding in a token classification task) the predictions will be padded (on the right) to allow for
  196. concatenation into one array. The padding index is -100.
  197. </Tip>
  198. Returns: *NamedTuple* A namedtuple with the following keys:
  199. - predictions (`np.ndarray`): The predictions on `test_dataset`.
  200. - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
  201. - metrics (`dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
  202. labels).
  203. """
  204. gen_kwargs = gen_kwargs.copy()
  205. # Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
  206. # training args
  207. if (
  208. gen_kwargs.get("max_length") is None
  209. and gen_kwargs.get("max_new_tokens") is None
  210. and self.args.generation_max_length is not None
  211. ):
  212. gen_kwargs["max_length"] = self.args.generation_max_length
  213. if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
  214. gen_kwargs["num_beams"] = self.args.generation_num_beams
  215. self.gather_function = self.accelerator.gather
  216. self._gen_kwargs = gen_kwargs
  217. return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
  218. def prediction_step(
  219. self,
  220. model: nn.Module,
  221. inputs: dict[str, Union[torch.Tensor, Any]],
  222. prediction_loss_only: bool,
  223. ignore_keys: Optional[list[str]] = None,
  224. **gen_kwargs,
  225. ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
  226. """
  227. Perform an evaluation step on `model` using `inputs`.
  228. Subclass and override to inject custom behavior.
  229. Args:
  230. model (`nn.Module`):
  231. The model to evaluate.
  232. inputs (`dict[str, Union[torch.Tensor, Any]]`):
  233. The inputs and targets of the model.
  234. The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
  235. argument `labels`. Check your model's documentation for all accepted arguments.
  236. prediction_loss_only (`bool`):
  237. Whether or not to return the loss only.
  238. gen_kwargs:
  239. Additional `generate` specific kwargs.
  240. Return:
  241. tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
  242. labels (each being optional).
  243. """
  244. if not self.args.predict_with_generate or prediction_loss_only:
  245. return super().prediction_step(
  246. model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
  247. )
  248. has_labels = "labels" in inputs
  249. inputs = self._prepare_inputs(inputs)
  250. # Priority (handled in generate):
  251. # non-`None` gen_kwargs > model.generation_config > default GenerationConfig()
  252. if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):
  253. gen_kwargs = self._gen_kwargs.copy()
  254. if "num_beams" in gen_kwargs and gen_kwargs["num_beams"] is None:
  255. gen_kwargs.pop("num_beams")
  256. if "max_length" in gen_kwargs and gen_kwargs["max_length"] is None:
  257. gen_kwargs.pop("max_length")
  258. default_synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self.model)
  259. gen_kwargs["synced_gpus"] = gen_kwargs.get("synced_gpus", default_synced_gpus)
  260. generation_inputs = inputs.copy()
  261. # If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate
  262. # (otherwise, it would continue generating from the padded `decoder_input_ids`)
  263. if (
  264. "labels" in generation_inputs
  265. and "decoder_input_ids" in generation_inputs
  266. and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape
  267. ):
  268. generation_inputs = {
  269. k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask")
  270. }
  271. summon_full_params_context = (
  272. FullyShardedDataParallel.summon_full_params(self.model)
  273. if isinstance(self.model, FullyShardedDataParallel)
  274. else contextlib.nullcontext()
  275. )
  276. with summon_full_params_context:
  277. generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
  278. # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
  279. # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
  280. # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
  281. if self.model.generation_config._from_model_config:
  282. self.model.generation_config._from_model_config = False
  283. # Retrieves GenerationConfig from model.generation_config
  284. gen_config = self.model.generation_config
  285. # in case the batch is shorter than max length, the output should be padded
  286. if generated_tokens.shape[-1] < gen_config.max_length:
  287. generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length)
  288. elif gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1:
  289. generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1)
  290. with torch.no_grad():
  291. if has_labels:
  292. with self.compute_loss_context_manager():
  293. outputs = model(**inputs)
  294. if self.label_smoother is not None:
  295. loss = self.label_smoother(outputs, inputs["labels"]).detach().mean()
  296. else:
  297. loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).detach().mean()
  298. else:
  299. loss = None
  300. if self.args.prediction_loss_only:
  301. return loss, None, None
  302. if has_labels:
  303. labels = inputs["labels"]
  304. if labels.shape[-1] < gen_config.max_length:
  305. labels = self._pad_tensors_to_max_len(labels, gen_config.max_length)
  306. elif gen_config.max_new_tokens is not None and labels.shape[-1] < gen_config.max_new_tokens + 1:
  307. labels = self._pad_tensors_to_max_len(labels, gen_config.max_new_tokens + 1)
  308. else:
  309. labels = None
  310. return loss, generated_tokens, labels
  311. def _pad_tensors_to_max_len(self, tensor, max_length):
  312. if self.processing_class is not None and hasattr(self.processing_class, "pad_token_id"):
  313. # If PAD token is not defined at least EOS token has to be defined
  314. pad_token_id = (
  315. self.processing_class.pad_token_id
  316. if self.processing_class.pad_token_id is not None
  317. else self.processing_class.eos_token_id
  318. )
  319. else:
  320. if self.model.config.pad_token_id is not None:
  321. pad_token_id = self.model.config.pad_token_id
  322. else:
  323. raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
  324. padded_tensor = pad_token_id * torch.ones(
  325. (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
  326. )
  327. padded_tensor[:, : tensor.shape[-1]] = tensor
  328. return padded_tensor