optimization_tf.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. # Copyright 2019 The TensorFlow Authors, The Hugging Face Team. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Functions and classes related to optimization (weight updates)."""
  16. from typing import Callable, Optional, Union
  17. import tensorflow as tf
  18. try:
  19. from tf_keras.optimizers.legacy import Adam
  20. except (ImportError, ModuleNotFoundError):
  21. from tensorflow.keras.optimizers.legacy import Adam
  22. from .modeling_tf_utils import keras
  23. # This block because Keras loves randomly moving things to different places - this changed somewhere between 2.10 - 2.15
  24. if hasattr(keras.optimizers.schedules, "learning_rate_schedule"):
  25. schedules = keras.optimizers.schedules.learning_rate_schedule
  26. else:
  27. schedules = keras.optimizers.schedules
  28. class WarmUp(schedules.LearningRateSchedule):
  29. """
  30. Applies a warmup schedule on a given learning rate decay schedule.
  31. Args:
  32. initial_learning_rate (`float`):
  33. The initial learning rate for the schedule after the warmup (so this will be the learning rate at the end
  34. of the warmup).
  35. decay_schedule_fn (`Callable`):
  36. The schedule function to apply after the warmup for the rest of training.
  37. warmup_steps (`int`):
  38. The number of steps for the warmup part of training.
  39. power (`float`, *optional*, defaults to 1.0):
  40. The power to use for the polynomial warmup (defaults is a linear warmup).
  41. name (`str`, *optional*):
  42. Optional name prefix for the returned tensors during the schedule.
  43. """
  44. def __init__(
  45. self,
  46. initial_learning_rate: float,
  47. decay_schedule_fn: Callable,
  48. warmup_steps: int,
  49. power: float = 1.0,
  50. name: Optional[str] = None,
  51. ):
  52. super().__init__()
  53. self.initial_learning_rate = initial_learning_rate
  54. self.warmup_steps = warmup_steps
  55. self.power = power
  56. self.decay_schedule_fn = decay_schedule_fn
  57. self.name = name
  58. def __call__(self, step):
  59. with tf.name_scope(self.name or "WarmUp") as name:
  60. # Implements polynomial warmup. i.e., if global_step < warmup_steps, the
  61. # learning rate will be `global_step/num_warmup_steps * init_lr`.
  62. global_step_float = tf.cast(step, tf.float32)
  63. warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
  64. warmup_percent_done = global_step_float / warmup_steps_float
  65. warmup_learning_rate = self.initial_learning_rate * tf.math.pow(warmup_percent_done, self.power)
  66. return tf.cond(
  67. global_step_float < warmup_steps_float,
  68. lambda: warmup_learning_rate,
  69. lambda: self.decay_schedule_fn(step - self.warmup_steps),
  70. name=name,
  71. )
  72. def get_config(self):
  73. return {
  74. "initial_learning_rate": self.initial_learning_rate,
  75. "decay_schedule_fn": self.decay_schedule_fn,
  76. "warmup_steps": self.warmup_steps,
  77. "power": self.power,
  78. "name": self.name,
  79. }
  80. def create_optimizer(
  81. init_lr: float,
  82. num_train_steps: int,
  83. num_warmup_steps: int,
  84. min_lr_ratio: float = 0.0,
  85. adam_beta1: float = 0.9,
  86. adam_beta2: float = 0.999,
  87. adam_epsilon: float = 1e-8,
  88. adam_clipnorm: Optional[float] = None,
  89. adam_global_clipnorm: Optional[float] = None,
  90. weight_decay_rate: float = 0.0,
  91. power: float = 1.0,
  92. include_in_weight_decay: Optional[list[str]] = None,
  93. ):
  94. """
  95. Creates an optimizer with a learning rate schedule using a warmup phase followed by a linear decay.
  96. Args:
  97. init_lr (`float`):
  98. The desired learning rate at the end of the warmup phase.
  99. num_train_steps (`int`):
  100. The total number of training steps.
  101. num_warmup_steps (`int`):
  102. The number of warmup steps.
  103. min_lr_ratio (`float`, *optional*, defaults to 0):
  104. The final learning rate at the end of the linear decay will be `init_lr * min_lr_ratio`.
  105. adam_beta1 (`float`, *optional*, defaults to 0.9):
  106. The beta1 to use in Adam.
  107. adam_beta2 (`float`, *optional*, defaults to 0.999):
  108. The beta2 to use in Adam.
  109. adam_epsilon (`float`, *optional*, defaults to 1e-8):
  110. The epsilon to use in Adam.
  111. adam_clipnorm (`float`, *optional*, defaults to `None`):
  112. If not `None`, clip the gradient norm for each weight tensor to this value.
  113. adam_global_clipnorm (`float`, *optional*, defaults to `None`)
  114. If not `None`, clip gradient norm to this value. When using this argument, the norm is computed over all
  115. weight tensors, as if they were concatenated into a single vector.
  116. weight_decay_rate (`float`, *optional*, defaults to 0):
  117. The weight decay to use.
  118. power (`float`, *optional*, defaults to 1.0):
  119. The power to use for PolynomialDecay.
  120. include_in_weight_decay (`list[str]`, *optional*):
  121. List of the parameter names (or re patterns) to apply weight decay to. If none is passed, weight decay is
  122. applied to all parameters except bias and layer norm parameters.
  123. """
  124. # Implements linear decay of the learning rate.
  125. lr_schedule = schedules.PolynomialDecay(
  126. initial_learning_rate=init_lr,
  127. decay_steps=num_train_steps - num_warmup_steps,
  128. end_learning_rate=init_lr * min_lr_ratio,
  129. power=power,
  130. )
  131. if num_warmup_steps:
  132. lr_schedule = WarmUp(
  133. initial_learning_rate=init_lr,
  134. decay_schedule_fn=lr_schedule,
  135. warmup_steps=num_warmup_steps,
  136. )
  137. if weight_decay_rate > 0.0:
  138. optimizer = AdamWeightDecay(
  139. learning_rate=lr_schedule,
  140. weight_decay_rate=weight_decay_rate,
  141. beta_1=adam_beta1,
  142. beta_2=adam_beta2,
  143. epsilon=adam_epsilon,
  144. clipnorm=adam_clipnorm,
  145. global_clipnorm=adam_global_clipnorm,
  146. exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
  147. include_in_weight_decay=include_in_weight_decay,
  148. )
  149. else:
  150. optimizer = keras.optimizers.Adam(
  151. learning_rate=lr_schedule,
  152. beta_1=adam_beta1,
  153. beta_2=adam_beta2,
  154. epsilon=adam_epsilon,
  155. clipnorm=adam_clipnorm,
  156. global_clipnorm=adam_global_clipnorm,
  157. )
  158. # We return the optimizer and the LR scheduler in order to better track the
  159. # evolution of the LR independently of the optimizer.
  160. return optimizer, lr_schedule
  161. class AdamWeightDecay(Adam):
  162. """
  163. Adam enables L2 weight decay and clip_by_global_norm on gradients. Just adding the square of the weights to the
  164. loss function is *not* the correct way of using L2 regularization/weight decay with Adam, since that will interact
  165. with the m and v parameters in strange ways as shown in [Decoupled Weight Decay
  166. Regularization](https://huggingface.co/papers/1711.05101).
  167. Instead we want to decay the weights in a manner that doesn't interact with the m/v parameters. This is equivalent
  168. to adding the square of the weights to the loss with plain (non-momentum) SGD.
  169. Args:
  170. learning_rate (`Union[float, LearningRateSchedule]`, *optional*, defaults to 0.001):
  171. The learning rate to use or a schedule.
  172. beta_1 (`float`, *optional*, defaults to 0.9):
  173. The beta1 parameter in Adam, which is the exponential decay rate for the 1st momentum estimates.
  174. beta_2 (`float`, *optional*, defaults to 0.999):
  175. The beta2 parameter in Adam, which is the exponential decay rate for the 2nd momentum estimates.
  176. epsilon (`float`, *optional*, defaults to 1e-07):
  177. The epsilon parameter in Adam, which is a small constant for numerical stability.
  178. amsgrad (`bool`, *optional*, defaults to `False`):
  179. Whether to apply AMSGrad variant of this algorithm or not, see [On the Convergence of Adam and
  180. Beyond](https://huggingface.co/papers/1904.09237).
  181. weight_decay_rate (`float`, *optional*, defaults to 0.0):
  182. The weight decay to apply.
  183. include_in_weight_decay (`list[str]`, *optional*):
  184. List of the parameter names (or re patterns) to apply weight decay to. If none is passed, weight decay is
  185. applied to all parameters by default (unless they are in `exclude_from_weight_decay`).
  186. exclude_from_weight_decay (`list[str]`, *optional*):
  187. List of the parameter names (or re patterns) to exclude from applying weight decay to. If a
  188. `include_in_weight_decay` is passed, the names in it will supersede this list.
  189. name (`str`, *optional*, defaults to `"AdamWeightDecay"`):
  190. Optional name for the operations created when applying gradients.
  191. kwargs (`dict[str, Any]`, *optional*):
  192. Keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
  193. norm; `clipvalue` is clip gradients by value, `decay` is included for backward compatibility to allow time
  194. inverse decay of learning rate. `lr` is included for backward compatibility, recommended to use
  195. `learning_rate` instead.
  196. """
  197. def __init__(
  198. self,
  199. learning_rate: Union[float, schedules.LearningRateSchedule] = 0.001,
  200. beta_1: float = 0.9,
  201. beta_2: float = 0.999,
  202. epsilon: float = 1e-7,
  203. amsgrad: bool = False,
  204. weight_decay_rate: float = 0.0,
  205. include_in_weight_decay: Optional[list[str]] = None,
  206. exclude_from_weight_decay: Optional[list[str]] = None,
  207. name: str = "AdamWeightDecay",
  208. **kwargs,
  209. ):
  210. super().__init__(learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
  211. self.weight_decay_rate = weight_decay_rate
  212. self._include_in_weight_decay = include_in_weight_decay
  213. self._exclude_from_weight_decay = exclude_from_weight_decay
  214. @classmethod
  215. def from_config(cls, config):
  216. """Creates an optimizer from its config with WarmUp custom object."""
  217. custom_objects = {"WarmUp": WarmUp}
  218. return super().from_config(config, custom_objects=custom_objects)
  219. def _prepare_local(self, var_device, var_dtype, apply_state):
  220. super()._prepare_local(var_device, var_dtype, apply_state)
  221. apply_state[(var_device, var_dtype)]["weight_decay_rate"] = tf.constant(
  222. self.weight_decay_rate, name="adam_weight_decay_rate"
  223. )
  224. def _decay_weights_op(self, var, learning_rate, apply_state):
  225. do_decay = self._do_use_weight_decay(var.name)
  226. if do_decay:
  227. return var.assign_sub(
  228. learning_rate * var * apply_state[(var.device, var.dtype.base_dtype)]["weight_decay_rate"],
  229. use_locking=self._use_locking,
  230. )
  231. return tf.no_op()
  232. def apply_gradients(self, grads_and_vars, name=None, **kwargs):
  233. grads, tvars = list(zip(*grads_and_vars))
  234. return super().apply_gradients(zip(grads, tvars), name=name, **kwargs)
  235. def _get_lr(self, var_device, var_dtype, apply_state):
  236. """Retrieves the learning rate with the given state."""
  237. if apply_state is None:
  238. return self._decayed_lr_t[var_dtype], {}
  239. apply_state = apply_state or {}
  240. coefficients = apply_state.get((var_device, var_dtype))
  241. if coefficients is None:
  242. coefficients = self._fallback_apply_state(var_device, var_dtype)
  243. apply_state[(var_device, var_dtype)] = coefficients
  244. return coefficients["lr_t"], {"apply_state": apply_state}
  245. def _resource_apply_dense(self, grad, var, apply_state=None):
  246. lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
  247. decay = self._decay_weights_op(var, lr_t, apply_state)
  248. with tf.control_dependencies([decay]):
  249. return super()._resource_apply_dense(grad, var, **kwargs)
  250. def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
  251. lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
  252. decay = self._decay_weights_op(var, lr_t, apply_state)
  253. with tf.control_dependencies([decay]):
  254. return super()._resource_apply_sparse(grad, var, indices, **kwargs)
  255. def get_config(self):
  256. config = super().get_config()
  257. config.update({"weight_decay_rate": self.weight_decay_rate})
  258. return config
  259. def _do_use_weight_decay(self, param_name):
  260. """Whether to use L2 weight decay for `param_name`."""
  261. if self.weight_decay_rate == 0:
  262. return False
  263. if self._include_in_weight_decay:
  264. for r in self._include_in_weight_decay:
  265. if r in param_name:
  266. return True
  267. if self._exclude_from_weight_decay:
  268. for r in self._exclude_from_weight_decay:
  269. if r in param_name:
  270. return False
  271. return True
  272. # Extracted from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
  273. class GradientAccumulator:
  274. """
  275. Gradient accumulation utility. When used with a distribution strategy, the accumulator should be called in a
  276. replica context. Gradients will be accumulated locally on each replica and without synchronization. Users should
  277. then call `.gradients`, scale the gradients if required, and pass the result to `apply_gradients`.
  278. """
  279. # We use the ON_READ synchronization policy so that no synchronization is
  280. # performed on assignment. To get the value, we call .value() which returns the
  281. # value on the current replica without synchronization.
  282. def __init__(self):
  283. """Initializes the accumulator."""
  284. self._gradients = []
  285. self._accum_steps = None
  286. @property
  287. def step(self):
  288. """Number of accumulated steps."""
  289. if self._accum_steps is None:
  290. self._accum_steps = tf.Variable(
  291. tf.constant(0, dtype=tf.int64),
  292. trainable=False,
  293. synchronization=tf.VariableSynchronization.ON_READ,
  294. aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
  295. )
  296. return self._accum_steps.value()
  297. @property
  298. def gradients(self):
  299. """The accumulated gradients on the current replica."""
  300. if not self._gradients:
  301. raise ValueError("The accumulator should be called first to initialize the gradients")
  302. return [gradient.value() if gradient is not None else gradient for gradient in self._gradients]
  303. def __call__(self, gradients):
  304. """Accumulates `gradients` on the current replica."""
  305. if not self._gradients:
  306. _ = self.step # Create the step variable.
  307. self._gradients.extend(
  308. [
  309. tf.Variable(
  310. tf.zeros_like(gradient),
  311. trainable=False,
  312. synchronization=tf.VariableSynchronization.ON_READ,
  313. aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
  314. )
  315. if gradient is not None
  316. else gradient
  317. for gradient in gradients
  318. ]
  319. )
  320. if len(gradients) != len(self._gradients):
  321. raise ValueError(f"Expected {len(self._gradients)} gradients, but got {len(gradients)}")
  322. for accum_gradient, gradient in zip(self._gradients, gradients):
  323. if accum_gradient is not None and gradient is not None:
  324. accum_gradient.assign_add(gradient)
  325. self._accum_steps.assign_add(1)
  326. def reset(self):
  327. """Resets the accumulated gradients on the current replica."""
  328. if not self._gradients:
  329. return
  330. self._accum_steps.assign(0)
  331. for gradient in self._gradients:
  332. if gradient is not None:
  333. gradient.assign(tf.zeros_like(gradient))