transformer_engine.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from types import MethodType
  15. import torch.nn as nn
  16. from .imports import is_hpu_available, is_transformer_engine_available
  17. from .operations import GatheredParameters
  18. # Do not import `transformer_engine` at package level to avoid potential issues
  19. def convert_model(model, to_transformer_engine=True, _convert_linear=True, _convert_ln=True):
  20. """
  21. Recursively converts the linear and layernorm layers of a model to their `transformers_engine` counterpart.
  22. """
  23. if not is_transformer_engine_available():
  24. raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
  25. if is_hpu_available():
  26. import intel_transformer_engine as te
  27. if not hasattr(te, "LayerNorm"):
  28. # HPU does not have a LayerNorm implementation in TE
  29. te.LayerNorm = nn.LayerNorm
  30. else:
  31. import transformer_engine.pytorch as te
  32. for name, module in model.named_children():
  33. if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear:
  34. has_bias = module.bias is not None
  35. params_to_gather = [module.weight]
  36. if has_bias:
  37. params_to_gather.append(module.bias)
  38. with GatheredParameters(params_to_gather, modifier_rank=0):
  39. if any(p % 16 != 0 for p in module.weight.shape):
  40. return
  41. te_module = te.Linear(
  42. module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype
  43. )
  44. te_module.weight.copy_(module.weight)
  45. if has_bias:
  46. te_module.bias.copy_(module.bias)
  47. setattr(model, name, te_module)
  48. # Note: @xrsrke (Phuc) found that te.LayerNorm doesn't have any real memory savings or speedups over nn.LayerNorm
  49. elif isinstance(module, nn.LayerNorm) and to_transformer_engine and _convert_ln:
  50. with GatheredParameters([module.weight, module.bias], modifier_rank=0):
  51. has_bias = module.bias is not None
  52. te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
  53. te_module.weight.copy_(module.weight)
  54. if has_bias:
  55. te_module.bias.copy_(module.bias)
  56. setattr(model, name, te_module)
  57. elif isinstance(module, te.Linear) and not to_transformer_engine and _convert_linear:
  58. has_bias = module.bias is not None
  59. new_module = nn.Linear(
  60. module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype
  61. )
  62. new_module.weight.copy_(module.weight)
  63. if has_bias:
  64. new_module.bias.copy_(module.bias)
  65. setattr(model, name, new_module)
  66. elif isinstance(module, te.LayerNorm) and not to_transformer_engine and _convert_ln:
  67. new_module = nn.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
  68. new_module.weight.copy_(module.weight)
  69. new_module.bias.copy_(module.bias)
  70. setattr(model, name, new_module)
  71. else:
  72. convert_model(
  73. module,
  74. to_transformer_engine=to_transformer_engine,
  75. _convert_linear=_convert_linear,
  76. _convert_ln=_convert_ln,
  77. )
  78. def has_transformer_engine_layers(model):
  79. """
  80. Returns whether a given model has some `transformer_engine` layer or not.
  81. """
  82. if not is_transformer_engine_available():
  83. raise ImportError("Using `has_transformer_engine_layers` requires transformer_engine to be installed.")
  84. if is_hpu_available():
  85. import intel_transformer_engine as te
  86. module_cls_to_check = te.Linear
  87. else:
  88. import transformer_engine.pytorch as te
  89. module_cls_to_check = (te.LayerNorm, te.Linear, te.TransformerLayer)
  90. for m in model.modules():
  91. if isinstance(m, module_cls_to_check):
  92. return True
  93. return False
  94. def contextual_fp8_autocast(model_forward, fp8_recipe, use_during_eval=False):
  95. """
  96. Wrapper for a model's forward method to apply FP8 autocast. Is context aware, meaning that by default it will
  97. disable FP8 autocast during eval mode, which is generally better for more accurate metrics.
  98. """
  99. if not is_transformer_engine_available():
  100. raise ImportError("Using `contextual_fp8_autocast` requires transformer_engine to be installed.")
  101. if is_hpu_available():
  102. from intel_transformer_engine import fp8_autocast
  103. else:
  104. from transformer_engine.pytorch import fp8_autocast
  105. def forward(self, *args, **kwargs):
  106. enabled = use_during_eval or self.training
  107. with fp8_autocast(enabled=enabled, fp8_recipe=fp8_recipe):
  108. return model_forward(*args, **kwargs)
  109. # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
  110. forward.__wrapped__ = model_forward
  111. return forward
  112. def apply_fp8_autowrap(model, fp8_recipe_handler):
  113. """
  114. Applies FP8 context manager to the model's forward method
  115. """
  116. if not is_transformer_engine_available():
  117. raise ImportError("Using `apply_fp8_autowrap` requires transformer_engine to be installed.")
  118. if is_hpu_available():
  119. import intel_transformer_engine.recipe as te_recipe
  120. is_fp8_block_scaling_available = False
  121. message = "MXFP8 block scaling is not available on HPU."
  122. else:
  123. import transformer_engine.common.recipe as te_recipe
  124. import transformer_engine.pytorch as te
  125. is_fp8_block_scaling_available, message = te.fp8.check_mxfp8_support()
  126. kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
  127. if "fp8_format" in kwargs:
  128. kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
  129. use_during_eval = kwargs.pop("use_autocast_during_eval", False)
  130. use_mxfp8_block_scaling = kwargs.pop("use_mxfp8_block_scaling", False)
  131. if use_mxfp8_block_scaling and not is_fp8_block_scaling_available:
  132. raise ValueError(f"MXFP8 block scaling is not available: {message}")
  133. if use_mxfp8_block_scaling:
  134. if "amax_compute_algo" in kwargs:
  135. raise ValueError("`amax_compute_algo` is not supported for MXFP8 block scaling.")
  136. if "amax_history_len" in kwargs:
  137. raise ValueError("`amax_history_len` is not supported for MXFP8 block scaling.")
  138. fp8_recipe = te_recipe.MXFP8BlockScaling(**kwargs)
  139. else:
  140. fp8_recipe = te_recipe.DelayedScaling(**kwargs)
  141. new_forward = contextual_fp8_autocast(model.forward, fp8_recipe, use_during_eval)
  142. if hasattr(model.forward, "__func__"):
  143. model.forward = MethodType(new_forward, model)
  144. else:
  145. model.forward = new_forward
  146. return model