inference.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from types import MethodType
  16. from typing import Any, Optional, Union
  17. from .state import PartialState
  18. from .utils import (
  19. calculate_maximum_sizes,
  20. convert_bytes,
  21. copy_tensor_to_devices,
  22. ignorant_find_batch_size,
  23. infer_auto_device_map,
  24. is_pippy_available,
  25. pad_input_tensors,
  26. send_to_device,
  27. )
  28. def generate_device_map(
  29. model, num_processes: int = 1, no_split_module_classes=None, max_memory: Optional[dict] = None
  30. ):
  31. """
  32. Calculates the device map for `model` with an offset for PiPPy
  33. """
  34. if num_processes == 1:
  35. return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)
  36. if max_memory is None:
  37. model_size, shared = calculate_maximum_sizes(model)
  38. # Split into `n` chunks for each GPU
  39. memory = (model_size + shared[0]) / num_processes
  40. memory = convert_bytes(memory)
  41. value, ending = memory.split(" ")
  42. # Add a chunk to deal with potential extra shared memory instances
  43. memory = math.ceil(float(value)) * 1.1
  44. memory = f"{memory} {ending}"
  45. max_memory = {i: memory for i in range(num_processes)}
  46. device_map = infer_auto_device_map(
  47. model,
  48. max_memory=max_memory,
  49. no_split_module_classes=no_split_module_classes,
  50. clean_result=False,
  51. )
  52. return device_map
  53. def find_pippy_batch_size(args, kwargs):
  54. found_batch_size = None
  55. if args is not None:
  56. for arg in args:
  57. found_batch_size = ignorant_find_batch_size(arg)
  58. if found_batch_size is not None:
  59. break
  60. if kwargs is not None and found_batch_size is None:
  61. for kwarg in kwargs.values():
  62. found_batch_size = ignorant_find_batch_size(kwarg)
  63. if found_batch_size is not None:
  64. break
  65. return found_batch_size
  66. def build_pipeline(model, split_points, args, kwargs, num_chunks):
  67. """
  68. Attaches the split points to the model based on `self.device_map` and generates a `PipelineStage`. Requires passing
  69. in needed `args` and `kwargs` as the model needs on the CPU.
  70. Users can pass in custom `num_chunks` as an optional hyper-parameter. By default will use
  71. `AcceleratorState.num_processes`
  72. """
  73. # Note: We import here to reduce import time from general modules, and isolate outside dependencies
  74. from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline
  75. # We need to annotate the split points in the model for PiPPy
  76. state = PartialState()
  77. split_spec = {split_point: SplitPoint.BEGINNING for split_point in split_points}
  78. pipe = pipeline(
  79. model,
  80. mb_args=args,
  81. mb_kwargs=kwargs,
  82. split_spec=split_spec,
  83. )
  84. stage = pipe.build_stage(state.local_process_index, device=state.device)
  85. schedule = ScheduleGPipe(stage, num_chunks)
  86. return schedule
  87. def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
  88. state = PartialState()
  89. output = None
  90. if state.num_processes == 1:
  91. output = forward(*args, **kwargs)
  92. elif state.is_local_main_process:
  93. found_batch_size = find_pippy_batch_size(args, kwargs)
  94. if found_batch_size is None:
  95. raise ValueError("Could not find batch size from args or kwargs")
  96. else:
  97. if found_batch_size != num_chunks:
  98. args = pad_input_tensors(args, found_batch_size, num_chunks)
  99. kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
  100. forward(*args, **kwargs)
  101. elif state.is_last_process:
  102. output = forward()
  103. else:
  104. forward()
  105. if gather_output:
  106. # Each node will get a copy of the full output which is only on the last GPU
  107. output = copy_tensor_to_devices(output)
  108. return output
  109. def prepare_pippy(
  110. model,
  111. split_points: Optional[Union[str, list[str]]] = "auto",
  112. no_split_module_classes: Optional[list[str]] = None,
  113. example_args: Optional[tuple[Any]] = (),
  114. example_kwargs: Optional[dict[str, Any]] = None,
  115. num_chunks: Optional[int] = None,
  116. gather_output: Optional[bool] = False,
  117. ):
  118. """
  119. Wraps `model` for pipeline parallel inference.
  120. Args:
  121. model (`torch.nn.Module`):
  122. A model we want to split for pipeline-parallel inference
  123. split_points (`str` or `List[str]`, defaults to 'auto'):
  124. How to generate the split points and chunk the model across each GPU. 'auto' will find the best balanced
  125. split given any model. Should be a list of layer names in the model to split by otherwise.
  126. no_split_module_classes (`List[str]`):
  127. A list of class names for layers we don't want to be split.
  128. example_args (tuple of model inputs):
  129. The expected inputs for the model that uses order-based inputs for a *single process*. Recommended to use
  130. this method if possible.
  131. example_kwargs (dict of model inputs)
  132. The expected inputs for the model that uses dictionary-based inputs for a *single process*. This is a
  133. *highly* limiting structure that requires the same keys be present at *all* inference calls. Not
  134. recommended unless the prior condition is true for all cases.
  135. num_chunks (`int`, defaults to the number of available GPUs):
  136. The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but
  137. this can be tuned and played with. In general one should have num_chunks >= num_gpus.
  138. gather_output (`bool`, defaults to `False`):
  139. If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.
  140. """
  141. if not is_pippy_available():
  142. raise ImportError("Using `torch.distributed.pipelining` requires PyTorch 2.4.0 or later.")
  143. state = PartialState()
  144. example_args = send_to_device(example_args, "cpu")
  145. example_kwargs = send_to_device(example_kwargs, "cpu")
  146. if num_chunks is None:
  147. num_chunks = state.num_processes
  148. if split_points == "auto":
  149. device_map = generate_device_map(model, num_chunks, no_split_module_classes=no_split_module_classes)
  150. split_points = []
  151. for i in range(1, num_chunks):
  152. split_points.append(next(k for k, v in device_map.items() if v == i))
  153. model.hf_split_points = split_points
  154. stage = build_pipeline(model, split_points, example_args, example_kwargs, num_chunks)
  155. model._original_forward = model.forward
  156. model._original_call = model.__call__
  157. model.pippy_stage = stage
  158. model.hf_split_points = split_points
  159. def forward(*args, **kwargs):
  160. return pippy_forward(stage.step, num_chunks, gather_output, *args, **kwargs)
  161. # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
  162. # Note: creates an infinite recursion loop with `generate`
  163. model_forward = MethodType(forward, model)
  164. forward.__wrapped__ = model_forward
  165. model.forward = forward
  166. return model