training_args_tf.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import warnings
  15. from dataclasses import dataclass, field
  16. from functools import cached_property
  17. from typing import Optional
  18. from .training_args import TrainingArguments
  19. from .utils import is_tf_available, logging, requires_backends
# Module-level logger, created through the library's own `logging` helper (imported from `.utils`).
logger = logging.get_logger(__name__)

# TensorFlow and the `keras` handle exposed by `modeling_tf_utils` are imported only when `is_tf_available()`
# reports TensorFlow as present. Code below that touches `tf`/`keras` guards itself with `requires_backends`.
if is_tf_available():
    import tensorflow as tf

    from .modeling_tf_utils import keras
  24. @dataclass
  25. class TFTrainingArguments(TrainingArguments):
  26. """
  27. TrainingArguments is the subset of the arguments we use in our example scripts **which relate to the training loop
  28. itself**.
  29. Using [`HfArgumentParser`] we can turn this class into
  30. [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
  31. command line.
  32. Parameters:
  33. output_dir (`str`):
  34. The output directory where the model predictions and checkpoints will be written.
  35. overwrite_output_dir (`bool`, *optional*, defaults to `False`):
  36. If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir`
  37. points to a checkpoint directory.
  38. do_train (`bool`, *optional*, defaults to `False`):
  39. Whether to run training or not. This argument is not directly used by [`Trainer`], it's intended to be used
  40. by your training/evaluation scripts instead. See the [example
  41. scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
  42. do_eval (`bool`, *optional*):
  43. Whether to run evaluation on the validation set or not. Will be set to `True` if `eval_strategy` is
  44. different from `"no"`. This argument is not directly used by [`Trainer`], it's intended to be used by your
  45. training/evaluation scripts instead. See the [example
  46. scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
  47. do_predict (`bool`, *optional*, defaults to `False`):
  48. Whether to run predictions on the test set or not. This argument is not directly used by [`Trainer`], it's
  49. intended to be used by your training/evaluation scripts instead. See the [example
  50. scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
  51. eval_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`):
  52. The evaluation strategy to adopt during training. Possible values are:
  53. - `"no"`: No evaluation is done during training.
  54. - `"steps"`: Evaluation is done (and logged) every `eval_steps`.
  55. - `"epoch"`: Evaluation is done at the end of each epoch.
  56. per_device_train_batch_size (`int`, *optional*, defaults to 8):
  57. The batch size per GPU/TPU core/CPU for training.
  58. per_device_eval_batch_size (`int`, *optional*, defaults to 8):
  59. The batch size per GPU/TPU core/CPU for evaluation.
  60. gradient_accumulation_steps (`int`, *optional*, defaults to 1):
  61. Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
  62. <Tip warning={true}>
  63. When using gradient accumulation, one step is counted as one step with backward pass. Therefore, logging,
  64. evaluation, save will be conducted every `gradient_accumulation_steps * xxx_step` training examples.
  65. </Tip>
  66. learning_rate (`float`, *optional*, defaults to 5e-5):
  67. The initial learning rate for Adam.
  68. weight_decay (`float`, *optional*, defaults to 0):
  69. The weight decay to apply (if not zero).
  70. adam_beta1 (`float`, *optional*, defaults to 0.9):
  71. The beta1 hyperparameter for the Adam optimizer.
  72. adam_beta2 (`float`, *optional*, defaults to 0.999):
  73. The beta2 hyperparameter for the Adam optimizer.
  74. adam_epsilon (`float`, *optional*, defaults to 1e-8):
  75. The epsilon hyperparameter for the Adam optimizer.
  76. max_grad_norm (`float`, *optional*, defaults to 1.0):
  77. Maximum gradient norm (for gradient clipping).
  78. num_train_epochs(`float`, *optional*, defaults to 3.0):
  79. Total number of training epochs to perform.
  80. max_steps (`int`, *optional*, defaults to -1):
  81. If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`.
  82. For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until
  83. `max_steps` is reached.
  84. warmup_ratio (`float`, *optional*, defaults to 0.0):
  85. Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.
  86. warmup_steps (`int`, *optional*, defaults to 0):
  87. Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`.
  88. logging_dir (`str`, *optional*):
  89. [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to
  90. *runs/**CURRENT_DATETIME_HOSTNAME***.
  91. logging_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
  92. The logging strategy to adopt during training. Possible values are:
  93. - `"no"`: No logging is done during training.
  94. - `"epoch"`: Logging is done at the end of each epoch.
  95. - `"steps"`: Logging is done every `logging_steps`.
  96. logging_first_step (`bool`, *optional*, defaults to `False`):
  97. Whether to log and evaluate the first `global_step` or not.
  98. logging_steps (`int`, *optional*, defaults to 500):
  99. Number of update steps between two logs if `logging_strategy="steps"`.
  100. save_strategy (`str` or [`~trainer_utils.SaveStrategy`], *optional*, defaults to `"steps"`):
  101. The checkpoint save strategy to adopt during training. Possible values are:
  102. - `"no"`: No save is done during training.
  103. - `"epoch"`: Save is done at the end of each epoch.
  104. - `"steps"`: Save is done every `save_steps`.
  105. save_steps (`int`, *optional*, defaults to 500):
  106. Number of updates steps before two checkpoint saves if `save_strategy="steps"`.
  107. save_total_limit (`int`, *optional*):
  108. If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
  109. `output_dir`.
  110. no_cuda (`bool`, *optional*, defaults to `False`):
  111. Whether to not use CUDA even when it is available or not.
  112. seed (`int`, *optional*, defaults to 42):
  113. Random seed that will be set at the beginning of training.
  114. fp16 (`bool`, *optional*, defaults to `False`):
  115. Whether to use 16-bit (mixed) precision training (through NVIDIA Apex) instead of 32-bit training.
  116. fp16_opt_level (`str`, *optional*, defaults to 'O1'):
  117. For `fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details on
  118. the [Apex documentation](https://nvidia.github.io/apex/amp).
  119. local_rank (`int`, *optional*, defaults to -1):
  120. During distributed training, the rank of the process.
  121. tpu_num_cores (`int`, *optional*):
  122. When training on TPU, the number of TPU cores (automatically passed by launcher script).
  123. debug (`bool`, *optional*, defaults to `False`):
  124. Whether to activate the trace to record computation graphs and profiling information or not.
  125. dataloader_drop_last (`bool`, *optional*, defaults to `False`):
  126. Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
  127. or not.
  128. eval_steps (`int`, *optional*, defaults to 1000):
  129. Number of update steps before two evaluations.
  130. past_index (`int`, *optional*, defaults to -1):
  131. Some models like [TransformerXL](../model_doc/transformerxl) or :doc*XLNet <../model_doc/xlnet>* can make
  132. use of the past hidden states for their predictions. If this argument is set to a positive int, the
  133. `Trainer` will use the corresponding output (usually index 2) as the past state and feed it to the model at
  134. the next training step under the keyword argument `mems`.
  135. tpu_name (`str`, *optional*):
  136. The name of the TPU the process is running on.
  137. tpu_zone (`str`, *optional*):
  138. The zone of the TPU the process is running on. If not specified, we will attempt to automatically detect
  139. from metadata.
  140. gcp_project (`str`, *optional*):
  141. Google Cloud Project name for the Cloud TPU-enabled project. If not specified, we will attempt to
  142. automatically detect from metadata.
  143. run_name (`str`, *optional*):
  144. A descriptor for the run. Notably used for trackio, wandb, mlflow, comet and swanlab logging.
  145. xla (`bool`, *optional*):
  146. Whether to activate the XLA compilation or not.
  147. """
  148. framework = "tf"
  149. tpu_name: Optional[str] = field(
  150. default=None,
  151. metadata={"help": "Name of TPU"},
  152. )
  153. tpu_zone: Optional[str] = field(
  154. default=None,
  155. metadata={"help": "Zone of TPU"},
  156. )
  157. gcp_project: Optional[str] = field(
  158. default=None,
  159. metadata={"help": "Name of Cloud TPU-enabled project"},
  160. )
  161. poly_power: float = field(
  162. default=1.0,
  163. metadata={"help": "Power for the Polynomial decay LR scheduler."},
  164. )
  165. xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"})
  166. @cached_property
  167. def _setup_strategy(self) -> tuple["tf.distribute.Strategy", int]:
  168. requires_backends(self, ["tf"])
  169. logger.info("Tensorflow: setting up strategy")
  170. gpus = tf.config.list_physical_devices("GPU")
  171. # Set to float16 at first
  172. if self.fp16:
  173. keras.mixed_precision.set_global_policy("mixed_float16")
  174. if self.no_cuda:
  175. strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
  176. else:
  177. try:
  178. if self.tpu_name:
  179. tpu = tf.distribute.cluster_resolver.TPUClusterResolver(
  180. self.tpu_name, zone=self.tpu_zone, project=self.gcp_project
  181. )
  182. else:
  183. tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
  184. except ValueError:
  185. if self.tpu_name:
  186. raise RuntimeError(f"Couldn't connect to TPU {self.tpu_name}!")
  187. else:
  188. tpu = None
  189. if tpu:
  190. # Set to bfloat16 in case of TPU
  191. if self.fp16:
  192. keras.mixed_precision.set_global_policy("mixed_bfloat16")
  193. tf.config.experimental_connect_to_cluster(tpu)
  194. tf.tpu.experimental.initialize_tpu_system(tpu)
  195. strategy = tf.distribute.TPUStrategy(tpu)
  196. elif len(gpus) == 0:
  197. strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
  198. elif len(gpus) == 1:
  199. strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
  200. elif len(gpus) > 1:
  201. # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
  202. strategy = tf.distribute.MirroredStrategy()
  203. else:
  204. raise ValueError("Cannot find the proper strategy, please check your environment properties.")
  205. return strategy
  206. @property
  207. def strategy(self) -> "tf.distribute.Strategy":
  208. """
  209. The strategy used for distributed training.
  210. """
  211. requires_backends(self, ["tf"])
  212. return self._setup_strategy
  213. @property
  214. def n_replicas(self) -> int:
  215. """
  216. The number of replicas (CPUs, GPUs or TPU cores) used in this training.
  217. """
  218. requires_backends(self, ["tf"])
  219. return self._setup_strategy.num_replicas_in_sync
  220. @property
  221. def should_log(self):
  222. """
  223. Whether or not the current process should produce log.
  224. """
  225. return False # TF Logging is handled by Keras not the Trainer
  226. @property
  227. def train_batch_size(self) -> int:
  228. """
  229. The actual batch size for training (may differ from `per_gpu_train_batch_size` in distributed training).
  230. """
  231. if self.per_gpu_train_batch_size:
  232. logger.warning(
  233. "Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future "
  234. "version. Using `--per_device_train_batch_size` is preferred."
  235. )
  236. per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size
  237. return per_device_batch_size * self.n_replicas
  238. @property
  239. def eval_batch_size(self) -> int:
  240. """
  241. The actual batch size for evaluation (may differ from `per_gpu_eval_batch_size` in distributed training).
  242. """
  243. if self.per_gpu_eval_batch_size:
  244. logger.warning(
  245. "Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future "
  246. "version. Using `--per_device_eval_batch_size` is preferred."
  247. )
  248. per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size
  249. return per_device_batch_size * self.n_replicas
  250. @property
  251. def n_gpu(self) -> int:
  252. """
  253. The number of replicas (CPUs, GPUs or TPU cores) used in this training.
  254. """
  255. requires_backends(self, ["tf"])
  256. warnings.warn(
  257. "The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.",
  258. FutureWarning,
  259. )
  260. return self._setup_strategy.num_replicas_in_sync