large_model_exporter.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. """
  6. Export LLM to onnx
  7. """
  8. import argparse
  9. import inspect
  10. import math
  11. import os
  12. import tempfile
  13. from pathlib import Path
  14. import onnx
  15. import torch
  16. import transformers
  17. from torch import nn
  18. def disable_huggingface_init():
  19. """do not init model twice as it slow initialization"""
  20. torch.nn.init.kaiming_uniform_ = lambda x, *args, **kwargs: x
  21. torch.nn.init.uniform_ = lambda x, *args, **kwargs: x
  22. torch.nn.init.normal_ = lambda x, *args, **kwargs: x
  23. torch.nn.init.constant_ = lambda x, *args, **kwargs: x
  24. torch.nn.init.xavier_uniform_ = lambda x, *args, **kwargs: x
  25. torch.nn.init.xavier_normal_ = lambda x, *args, **kwargs: x
  26. torch.nn.init.kaiming_normal_ = lambda x, *args, **kwargs: x
  27. torch.nn.init.orthogonal_ = lambda x, *args, **kwargs: x
  28. def get_model_parameter_size(model: nn.Module):
  29. """to calculate how much memory this model needs"""
  30. param_size = 0
  31. param_sum = 0
  32. for param in model.parameters():
  33. param_size += param.nelement() * param.element_size()
  34. param_sum += param.nelement()
  35. buffer_size = 0
  36. buffer_sum = 0
  37. for buffer in model.buffers():
  38. buffer_size += buffer.nelement() * buffer.element_size()
  39. buffer_sum += buffer.nelement()
  40. all_size = (param_size + buffer_size) / 1024 / 1024
  41. return all_size
  42. def initialize_model_and_sample_inputs(hf_model: str, cache_dir: str | None, tokenizer=None):
  43. """
  44. get the pretrained torch model from hugginface,
  45. and sample model-inputs
  46. """
  47. disable_huggingface_init()
  48. model = transformers.AutoModelForCausalLM.from_pretrained( # type: ignore
  49. hf_model, torch_dtype=torch.float16, cache_dir=cache_dir, trust_remote_code=True
  50. )
  51. if tokenizer is None:
  52. tokenizer = hf_model
  53. tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer) # type: ignore
  54. sample_inputs = tuple(tokenizer("Hello, my dog is cute", return_tensors="pt").values())
  55. return model, sample_inputs
  56. def auto_pipeline_parallel(model: nn.Module, gpulist: list, sample_inputs: tuple):
  57. """Make the model executable across multiple GPUs."""
  58. def input_gpu_device_hook(mod, inputs, kwargs):
  59. modifyed_inputs = []
  60. first_dev = None
  61. for layer_input in inputs:
  62. if type(layer_input) is not torch.Tensor:
  63. modifyed_inputs.append(layer_input)
  64. elif hasattr(mod, "weight"):
  65. modifyed_inputs.append(layer_input.to(mod.weight.device))
  66. elif hasattr(mod, "parameters"):
  67. device = next(mod.parameters(), layer_input).device
  68. modifyed_inputs.append(layer_input.to(device))
  69. elif hasattr(next(mod.children(), None), "weight"):
  70. modifyed_inputs.append(layer_input.to(next(mod.children()).weight.device))
  71. elif first_dev is not None and layer_input.device != first_dev:
  72. modifyed_inputs.append(layer_input.to(first_dev))
  73. else:
  74. modifyed_inputs.append(layer_input)
  75. if first_dev is None:
  76. first_dev = modifyed_inputs[0].device
  77. for key, value in kwargs.items():
  78. if type(value) is torch.Tensor:
  79. kwargs[key] = value.to(first_dev)
  80. return (tuple(modifyed_inputs), kwargs)
  81. def move_layer_to_device_rurc(mod, dev):
  82. mod.to(dev)
  83. for layer in mod.named_children():
  84. move_layer_to_device_rurc(layer[1], dev)
  85. model = model.half()
  86. all_hooks = []
  87. all_hooks.append(model.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True))
  88. pre_fix = next(iter(model.named_children()))[0]
  89. for top_name, top_module in model.named_children():
  90. for name, module in top_module.named_children():
  91. all_hooks.append(module.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True))
  92. if type(module) in [torch.nn.ModuleList]:
  93. num_layers_on_each_gpu = math.floor(len(module) / len(gpulist))
  94. for idx, attn_layer in enumerate(module):
  95. all_hooks.append(attn_layer.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True))
  96. to_dev = gpulist[min(idx // num_layers_on_each_gpu, len(gpulist))]
  97. attn_layer.to(to_dev)
  98. move_layer_to_device_rurc(attn_layer, to_dev)
  99. print(f"move {pre_fix}.{name}.{idx} to {to_dev}")
  100. else:
  101. module.to(gpulist[0])
  102. print(f"move {pre_fix}.{name} to {gpulist[0]}")
  103. if len(list(top_module.named_children())) == 0:
  104. top_module.to(gpulist[0])
  105. print(f"move {top_name} to {gpulist[0]}")
  106. with torch.no_grad():
  107. model(sample_inputs[0], attention_mask=sample_inputs[1])
  108. return model
  109. def retrieve_onnx_inputs(model: nn.Module, sample_inputs: tuple, with_past: bool):
  110. """
  111. auto retrieve onnx inputs from torch model as we can't enumlate all possibilities
  112. for all models
  113. """
  114. user_inputs = []
  115. def hook_for_inputs(_, inputs, kwargs):
  116. user_inputs.append((inputs, kwargs))
  117. return user_inputs[0]
  118. hook_handle = model.register_forward_pre_hook(hook_for_inputs, with_kwargs=True)
  119. forward_params = inspect.signature(model.forward).parameters
  120. input_keys = list(forward_params.keys())
  121. default_values = [forward_params.get(key).default for key in input_keys]
  122. out = model(sample_inputs[0], attention_mask=sample_inputs[1])
  123. hook_handle.remove()
  124. user_inputs = user_inputs[0]
  125. onnx_inputs = default_values
  126. for idx, _val in enumerate(user_inputs[0]):
  127. onnx_inputs[idx] = user_inputs[0][idx]
  128. for key, value in user_inputs[1].items():
  129. idx = input_keys.index(key)
  130. onnx_inputs[idx] = value
  131. for idx, (key, value) in enumerate(zip(input_keys, onnx_inputs, strict=False)):
  132. if type(value) is torch.Tensor:
  133. value.to(model.device)
  134. if "use_cache" in key:
  135. onnx_inputs[idx] = with_past
  136. out = model(sample_inputs[0], attention_mask=sample_inputs[1], use_cache=with_past) if with_past else out
  137. return input_keys, onnx_inputs, out.past_key_values
  138. def move_to_appropriate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module:
  139. """
  140. According to the model size, we will upload it to
  141. CPU if has no GPU or enough GPU memory,
  142. Single GPU if has only one GPU in local or model size is enough to fit one GPU
  143. Multiple GPU if there is more than one gpu in local and model is too large
  144. """
  145. total_mem_per_cpu = torch.cuda.get_device_properties(0).total_memory / 1024 / 1024
  146. print(f"Model_Size = {get_model_parameter_size(model) / 1024} GB")
  147. print(f"total_mem_per_cpu = {total_mem_per_cpu / 1024} GB")
  148. if get_model_parameter_size(model) > total_mem_per_cpu * 0.45:
  149. device_collection = [torch.device(i) for i in range(torch.cuda.device_count())]
  150. if len(device_collection) > 1:
  151. print(
  152. f"{len(device_collection)} GPUs are used to export onnx, \
  153. Please set CUDA_VISIBLE_DEVICES to use specific GPU group"
  154. )
  155. model = auto_pipeline_parallel(model, device_collection, sample_inputs_tp)
  156. else:
  157. print("!!!! convert model to float and export onnx using CPU")
  158. model = model.cpu().float()
  159. else:
  160. print("Export model on a single GPU")
  161. model = model.cuda().half()
  162. return model
  163. def adapt_inputs_to_device(sample_inputs: tuple, device: torch.device) -> tuple:
  164. """move inputs to device"""
  165. sample_inputs_ = []
  166. for sample_int in sample_inputs:
  167. if isinstance(sample_int, torch.Tensor):
  168. sample_inputs_.append(sample_int.to(device))
  169. else:
  170. sample_inputs_.append(sample_int)
  171. return tuple(sample_inputs_)
  172. def fetch_onnx_inputs_outputs_name(
  173. model: nn.Module,
  174. onnx_inputs: list,
  175. torch_input_names: tuple,
  176. past_key_values: tuple,
  177. with_past: bool,
  178. input_with_past: bool,
  179. ):
  180. """fetch onnx inputs and outputs name"""
  181. num_of_past_key = 0
  182. kv_cache_axis = {0: "batch_size"}
  183. # try get num_of_past_key and shape of past_key_value
  184. if past_key_values is not None:
  185. num_of_past_key = len(past_key_values)
  186. seq_index = (torch.tensor(past_key_values[0][0].shape) == onnx_inputs[0].shape[-1]).nonzero().view(-1)
  187. assert seq_index.numel() == 1
  188. kv_cache_axis = {0: "batch_size", seq_index.item(): "seq_len"}
  189. if not num_of_past_key:
  190. num_of_past_key = model.config.num_hidden_layers
  191. # filter out constant inputs
  192. onnx_inp_names = tuple(
  193. [torch_input_names[i] for i in range(len(torch_input_names)) if isinstance(onnx_inputs[i], torch.Tensor)]
  194. )
  195. assert "input_ids" in onnx_inp_names and "attention_mask" in onnx_inp_names, (
  196. "input_ids and attention_mask must be existed in inputs"
  197. )
  198. onnx_out_names = ("logits",)
  199. onnx_dynamic_axes = {
  200. "input_ids": {0: "batch_size", 1: "seq_len"},
  201. "attention_mask": {0: "batch_size", 1: "seq_len"},
  202. }
  203. # add dyanmic dimensions for the unkonw inputs
  204. for idx, name in enumerate(onnx_inp_names):
  205. if name not in onnx_dynamic_axes:
  206. unknown_dims = {i: f"{idx}__unknown_dims__{i}" for i in range(onnx_inputs[idx].dim())}
  207. onnx_dynamic_axes[name] = unknown_dims
  208. if input_with_past:
  209. for i in range(num_of_past_key):
  210. onnx_inp_names += (f"past_key_values.{i}.key",)
  211. onnx_inp_names += (f"past_key_values.{i}.value",)
  212. onnx_dynamic_axes[onnx_inp_names[-1]] = kv_cache_axis
  213. onnx_dynamic_axes[onnx_inp_names[-2]] = kv_cache_axis
  214. if with_past or input_with_past:
  215. for i in range(num_of_past_key):
  216. onnx_out_names += (f"present.{i}.key",)
  217. onnx_out_names += (f"present.{i}.value",)
  218. for idx, name in enumerate(torch_input_names):
  219. if input_with_past:
  220. if name == "past_key_values":
  221. onnx_inputs[idx] = past_key_values
  222. elif name == "attention_mask":
  223. attn_mask = onnx_inputs[idx]
  224. onnx_inputs[idx] = torch.cat(
  225. (attn_mask, torch.ones((attn_mask.shape[0], 1), device=attn_mask.device, dtype=attn_mask.dtype)),
  226. dim=1,
  227. )
  228. elif name == "input_ids":
  229. input_ids = onnx_inputs[idx]
  230. onnx_inputs[idx] = input_ids[:, -1:]
  231. return onnx_inp_names, onnx_out_names, onnx_dynamic_axes
  232. def do_export_internal(model: nn.Module, onnx_io_tuple: tuple, onnx_inputs: tuple, onnx_path: Path, opset: int):
  233. """do export with torch.onnx.export"""
  234. onnx_model_name = onnx_path.name
  235. onnx_inp_names, onnx_out_names, onnx_dynamic_axes = onnx_io_tuple
  236. # two step to export onnx
  237. # 1. export onnx with lots of pieces of weights
  238. # 2. save all weights to external data
  239. with tempfile.TemporaryDirectory() as tmpdirname:
  240. tmp_onnx = os.path.join(tmpdirname, "tmp.onnx")
  241. torch.onnx.export(
  242. model=model,
  243. args=tuple(onnx_inputs),
  244. f=tmp_onnx,
  245. verbose=False,
  246. opset_version=opset,
  247. input_names=onnx_inp_names,
  248. output_names=onnx_out_names,
  249. dynamic_axes=onnx_dynamic_axes,
  250. )
  251. onnx_path.unlink(missing_ok=True)
  252. (onnx_path.parent / f"{onnx_model_name}_ext.data").unlink(missing_ok=True)
  253. onnx_model = onnx.load(str(tmp_onnx))
  254. onnx.save_model(
  255. onnx_model,
  256. str(onnx_path),
  257. save_as_external_data=(len(os.listdir(tmpdirname)) > 1),
  258. all_tensors_to_one_file=True,
  259. location=f"{onnx_model_name}_ext.data",
  260. size_threshold=1024,
  261. convert_attribute=False,
  262. )
  263. @torch.no_grad()
  264. def export_onnx(hf_model: str, cache_dir: str | None, onnx_path_str: str, with_past: bool, opset: int):
  265. """
  266. do export
  267. model: torch model
  268. onnx_path: where the onnx model saved to
  269. sample_inputs_tp: inputs for torch model
  270. """
  271. model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir)
  272. model = move_to_appropriate_device(model, sample_inputs_tp)
  273. sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device)
  274. # input_keys would be usesful if the model has some special inputs
  275. input_keys, onnx_inputs, past_key_value = retrieve_onnx_inputs(model, sample_inputs, with_past)
  276. onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, False)
  277. onnx_model_name = "model.onnx"
  278. onnx_path: Path = Path(onnx_path_str).absolute()
  279. if onnx_path.suffix != ".onnx":
  280. onnx_path = onnx_path / onnx_model_name
  281. do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset)
  282. if not with_past:
  283. return
  284. onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, True)
  285. onnx_model_name = "model_with_past.onnx"
  286. onnx_path = onnx_path.parent / onnx_model_name
  287. do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset)
  288. def parse_arguments():
  289. """arguments parsing."""
  290. parser = argparse.ArgumentParser()
  291. parser.add_argument(
  292. "-m",
  293. "--model",
  294. required=True,
  295. type=str,
  296. default=["meta-llama/Llama-2-70b-hf"],
  297. help="Pre-trained models in huggingface model hub",
  298. )
  299. parser.add_argument(
  300. "-s",
  301. "--saved_path",
  302. required=False,
  303. type=str,
  304. default="./onnx_models/",
  305. help="where the onnx model will be saved",
  306. )
  307. parser.add_argument(
  308. "--cache_dir",
  309. required=False,
  310. type=str,
  311. default=None,
  312. help=("cache directly of huggingface, by setting this to avoid useless downloading if you have one"),
  313. )
  314. parser.add_argument(
  315. "--with_past",
  316. action="store_true",
  317. default=False,
  318. help=("The tool will export onnx without past-key-value by default"),
  319. )
  320. parser.add_argument(
  321. "--opset",
  322. required=False,
  323. type=int,
  324. default=17,
  325. help=(
  326. "the opset to save onnx model, \
  327. try to increase it if this opset doens't have new features you want"
  328. ),
  329. )
  330. return parser.parse_args()
  331. if __name__ == "__main__":
  332. args = parse_arguments()
  333. export_onnx(args.model, args.cache_dir, args.saved_path, args.with_past, args.opset)