fx.py 56 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502
  1. # Copyright 2021 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import builtins
  15. import collections
  16. import contextlib
  17. import functools
  18. import inspect
  19. import math
  20. import operator
  21. import os
  22. import random
  23. import sys
  24. import warnings
  25. from typing import Any, Callable, Literal, Optional, Union
  26. import torch
  27. import torch.utils._pytree as pytree
  28. from torch import nn
  29. from torch.fx import Graph, GraphModule, Node, Proxy, Tracer
  30. from torch.fx._compatibility import compatibility
  31. from torch.fx._symbolic_trace import is_fx_tracing
  32. from torch.fx.proxy import ParameterProxy
  33. from .. import logging
  34. from ..cache_utils import Cache, DynamicCache, StaticCache
  35. from ..modeling_utils import PretrainedConfig, PreTrainedModel
  36. from ..models.auto import get_values
  37. from ..models.auto.modeling_auto import (
  38. MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
  39. MODEL_FOR_BACKBONE_MAPPING_NAMES,
  40. MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
  41. MODEL_FOR_CTC_MAPPING_NAMES,
  42. MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
  43. MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
  44. MODEL_FOR_IMAGE_MAPPING_NAMES,
  45. MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
  46. MODEL_FOR_MASKED_LM_MAPPING_NAMES,
  47. MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
  48. MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
  49. MODEL_FOR_PRETRAINING_MAPPING_NAMES,
  50. MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
  51. MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
  52. MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
  53. MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
  54. MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
  55. MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
  56. MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES,
  57. MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
  58. MODEL_MAPPING_NAMES,
  59. )
  60. from .import_utils import (
  61. ENV_VARS_TRUE_VALUES,
  62. is_peft_available,
  63. )
  64. if is_peft_available():
  65. from peft import PeftModel
  66. logger = logging.get_logger(__name__)
  67. _IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE_VALUES
  68. def _generate_supported_model_class_names(
  69. model_name: type[PretrainedConfig],
  70. supported_tasks: Optional[Union[str, list[str]]] = None,
  71. ) -> list[str]:
  72. task_mapping = {
  73. "default": MODEL_MAPPING_NAMES,
  74. "pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
  75. "next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
  76. "masked-lm": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
  77. "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
  78. "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
  79. "speech-seq2seq": MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
  80. "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
  81. "document-question-answering": MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
  82. "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
  83. "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
  84. "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
  85. "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
  86. "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
  87. "zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
  88. "ctc": MODEL_FOR_CTC_MAPPING_NAMES,
  89. "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
  90. "semantic-segmentation": MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
  91. "backbone": MODEL_FOR_BACKBONE_MAPPING_NAMES,
  92. "image-feature-extraction": MODEL_FOR_IMAGE_MAPPING_NAMES,
  93. }
  94. if supported_tasks is None:
  95. supported_tasks = task_mapping.keys()
  96. if isinstance(supported_tasks, str):
  97. supported_tasks = [supported_tasks]
  98. model_class_names = []
  99. for task in supported_tasks:
  100. class_name = task_mapping[task].get(model_name, None)
  101. if class_name:
  102. model_class_names.append(class_name)
  103. return model_class_names
  104. _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
  105. "altclip",
  106. "albert",
  107. "bart",
  108. "bert",
  109. "bitnet",
  110. "blenderbot",
  111. "blenderbot-small",
  112. "bloom",
  113. "clip",
  114. "convnext",
  115. "deberta",
  116. "deberta-v2",
  117. "dinov2",
  118. "dinov3_convnext",
  119. "dinov3_vit",
  120. "distilbert",
  121. "donut-swin",
  122. "electra",
  123. "gpt2",
  124. "gpt_neo",
  125. "gptj",
  126. "hiera",
  127. "hubert",
  128. "ijepa",
  129. "layoutlm",
  130. "llama",
  131. "cohere",
  132. "lxmert",
  133. "m2m_100",
  134. "marian",
  135. "mbart",
  136. "megatron-bert",
  137. "ministral",
  138. "mistral",
  139. "mixtral",
  140. "mobilebert",
  141. "mt5",
  142. "nezha",
  143. "opt",
  144. "pegasus",
  145. "plbart",
  146. "qwen2",
  147. "qwen2_moe",
  148. "qwen3",
  149. "qwen3_next",
  150. "qwen3_moe",
  151. "resnet",
  152. "roberta",
  153. "segformer",
  154. "speech_to_text",
  155. "speech_to_text_2",
  156. "swin",
  157. "t5",
  158. "trocr",
  159. "vit",
  160. "vjepa2",
  161. "xglm",
  162. "wav2vec2",
  163. # "xlnet",
  164. ]
  165. _FX_SUPPORTED_MODELS_WITH_KV_CACHE = ["llama", "opt"]
  166. _REGULAR_SUPPORTED_MODELS = []
  167. for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
  168. if isinstance(item, dict):
  169. _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(**item))
  170. else:
  171. _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(item))
  172. _SPECIAL_SUPPORTED_MODELS = [
  173. "CLIPTextModel",
  174. "CLIPTextModelWithProjection",
  175. "CLIPVisionModel",
  176. "CLIPVisionModelWithProjection",
  177. "AltCLIPTextModel",
  178. "AltCLIPVisionModel",
  179. "GitVisionModel",
  180. "GPT2DoubleHeadsModel",
  181. "Speech2Text2Decoder",
  182. "TrOCRDecoder",
  183. "PeftModelForCausalLM",
  184. "PeftModelForSeq2SeqLM",
  185. "VJEPA2ForVideoClassification",
  186. # TODO: add support for them as it should be quite easy to do so (small blocking issues).
  187. # XLNetForQuestionAnswering,
  188. ]
  189. _SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)))
  190. _CURRENT_TRACER = None
  191. def torch_nn_embedding(self, input):
  192. return torch.empty(*input.shape, self.weight.shape[-1], device="meta", dtype=self.weight.dtype)
  193. def torch_nn_functional_embedding(
  194. input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
  195. ):
  196. return torch.empty(*input.shape, weight.shape[-1], device="meta", dtype=weight.dtype)
  197. def torch_nn_layernorm(self, input):
  198. return input
  199. def torch_nn_groupnorm(self, input):
  200. return input
  201. def torch_nn_linear(self, input):
  202. return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
  203. def torch_relu(x):
  204. return x
  205. def torch_nn_relu(self, x):
  206. return x
  207. def torch_nn_functional_relu(x, inplace=False):
  208. if not inplace:
  209. raise ValueError("Don't support in-place functional.relu for MetaTensor analysis")
  210. return x
  211. def torch_where(condition, x, y):
  212. # torch.where returns the broadcasted tensor of condition, x, and y,
  213. # so hack it by using addition
  214. return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
  215. def torch_abs(input, *, out=None):
  216. if out is not None:
  217. raise ValueError("Don't support in-place abs for MetaTensor analysis")
  218. return input
  219. def torch_arange(*args, **kwargs):
  220. n = len(args)
  221. step = 1
  222. if n == 1:
  223. start = 0
  224. end = args[0]
  225. elif n == 2:
  226. start, end = args
  227. else:
  228. start, end, step = args
  229. if isinstance(start, float):
  230. start = int(start)
  231. if isinstance(end, float):
  232. start = int(end)
  233. if isinstance(step, float):
  234. step = int(step)
  235. step = kwargs.get("step", step)
  236. dtype = kwargs.get("dtype")
  237. return torch.empty((end - start) // step, dtype=dtype, device="meta")
  238. def torch_full(*args, **kwargs):
  239. args = list(args)
  240. # We set the fill value to 1 as its value is not important as long as it's not a tensor on the `meta` device.
  241. if len(args) > 1:
  242. args[1] = 1
  243. else:
  244. kwargs["fill_value"] = 1
  245. kwargs_without_device = dict(kwargs)
  246. kwargs_without_device.pop("device", None)
  247. return torch.full(*args, **kwargs_without_device, device="meta")
  248. def torch_cat(tensors, dim=None, axis=None, *, out=None):
  249. if dim is None and axis is None:
  250. dim = 0
  251. if dim is None and axis is not None:
  252. dim = axis
  253. if dim < 0:
  254. dim = tensors[0].dim() + dim
  255. shapes = [t.shape for t in tensors]
  256. shape = list(shapes[0])
  257. concatenated_dim = sum(shape[dim] for shape in shapes)
  258. final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1 :]
  259. return torch.empty(final_shape, device="meta")
  260. def torch_stack(tensors, dim=None, axis=None, *, out=None):
  261. if dim is None and axis is None:
  262. dim = 0
  263. if dim is None and axis is not None:
  264. dim = axis
  265. if dim < 0:
  266. dim = tensors[0].dim() + 1 + dim
  267. shape = list(tensors[0].shape)
  268. shape.insert(dim, len(tensors))
  269. return torch.empty(shape, device="meta")
  270. def torch_add(input, other, *, alpha=1, out=None):
  271. if not isinstance(input, torch.Tensor):
  272. return torch.empty_like(other, device="meta")
  273. if not isinstance(other, torch.Tensor):
  274. return torch.empty_like(input, device="meta")
  275. max_length = max(input.dim(), other.dim())
  276. input_shape = list(input.shape) + [1] * (max_length - input.dim())
  277. other_shape = list(other.shape) + [1] * (max_length - other.dim())
  278. shape = []
  279. for i in range(max_length):
  280. shape.append(max(input_shape[i], other_shape[i]))
  281. return torch.empty(shape, device="meta")
  282. def torch_mul(input, other, *, out=None):
  283. return torch_add(input, other, out=out)
  284. def torch_tensor_mul(self, other):
  285. return torch_mul(self, other)
  286. def torch_matmul(input, other, *, out=None):
  287. d1 = input.dim()
  288. d2 = other.dim()
  289. shape = None
  290. if d1 == 1 and d2 == 1:
  291. shape = None
  292. elif d1 == 2 and d2 == 2:
  293. shape = (input.size(0), other.size(1))
  294. elif d1 == 1 and d2 == 2:
  295. shape = (other.size(1),)
  296. elif d1 == 2 and d1 == 1:
  297. shape = (input.size(0),)
  298. else:
  299. max_length = max(input.dim(), other.dim())
  300. shape1 = list(input.shape)
  301. shape2 = list(other.shape)
  302. if d1 == 1:
  303. shape1 = [1] + shape1
  304. if d2 == 1:
  305. shape2.append(1)
  306. shape1 = [-1] * (max_length - d1) + list(input.shape)
  307. shape2 = [-1] * (max_length - d2) + list(other.shape)
  308. shape = []
  309. for i in range(max_length):
  310. shape.append(max(shape1[i], shape2[i]))
  311. shape[-2] = shape1[-2]
  312. shape[-1] = shape2[-1]
  313. if d1 == 1:
  314. shape.pop(-2)
  315. if d2 == 1:
  316. shape.pop(-1)
  317. if shape is None:
  318. return torch.tensor(0.0, device="meta")
  319. return torch.empty(*shape, device="meta")
  320. def torch_bmm(input, mat2, *, out=None):
  321. if out is not None:
  322. raise ValueError("Don't support in-place bmm for MetaTensor analysis")
  323. batch_size, n, m = input.shape
  324. _, _, p = mat2.shape
  325. return torch.empty(batch_size, n, p, device="meta")
  326. def torch_baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None):
  327. if out is not None:
  328. raise ValueError("Don't support in-place baddbmm for MetaTensor analysis")
  329. return torch_bmm(batch1, batch2)
  330. def torch_tensor_baddbmm(self, batch1, batch2, *, beta=1, alpha=1, out=None):
  331. return torch_baddbmm(self, batch1, batch2, beta=beta, alpha=alpha, out=out)
  332. def torch_einsum(equation, *operands):
  333. # TODO: infer shape without performing the computation, this might be quite hard.
  334. concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands)
  335. return torch.einsum(equation, *concrete_operands).to("meta")
  336. def torch_tensor_repeat(self, *sizes):
  337. shape = list(self.shape)
  338. for i, x in enumerate(sizes):
  339. shape[i] *= x
  340. return torch.empty(shape, device="meta")
  341. def torch_repeat_interleave(*args, dim=None, output_size=None):
  342. num_args = len(args)
  343. if num_args == 1:
  344. shape = [output_size if output_size is not None else args[0].sum()]
  345. else:
  346. shape = list(args[0].shape)
  347. if dim is None:
  348. if num_args > 2:
  349. dim = args[2]
  350. else:
  351. shape = [sum(shape)]
  352. dim = 0
  353. repeats = args[1]
  354. if isinstance(repeats, int) or torch.numel(repeats) == 1:
  355. shape[dim] *= int(repeats)
  356. else:
  357. shape[dim] = output_size if output_size is not None else repeats.sum()
  358. return torch.empty(*shape, device="meta")
  359. def torch_index_select(input, dim, index, *, out=None):
  360. shape = list(input.shape)
  361. shape[dim] = len(index)
  362. return torch.empty(*shape, device="meta")
  363. def torch_tensor_index_select(self, dim, index):
  364. return torch_index_select(self, dim, index)
  365. def torch_gather(input, dim, index, *, sparse_grad=False, out=None):
  366. shape = list(input.shape)
  367. shape[dim] = index.shape[dim]
  368. return torch.empty(*shape, device="meta")
  369. def torch_tensor_gather(self, dim, index):
  370. return torch_gather(self, dim, index)
  371. def torch_roll(input, shifts, dims=None):
  372. return input
  373. def torch_flip(input, dims):
  374. return input
  375. def torch_tensor_flip(self, dims):
  376. return self
  377. def torch_nn_conv1d(self, input):
  378. l_in = input.shape[-1]
  379. shape = None
  380. padding = self.padding
  381. if padding == "valid":
  382. padding = (0, 0)
  383. if padding == "same":
  384. shape = list(input.shape)
  385. if shape is None:
  386. shape = list(input.shape)
  387. l_out = math.floor(
  388. (l_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
  389. )
  390. shape[-1] = l_out
  391. shape[-2] = self.out_channels
  392. return torch.empty(shape, device="meta")
  393. def torch_nn_conv2d(self, input):
  394. h_in, w_in = input.shape[-2:]
  395. shape = None
  396. padding = self.padding
  397. if padding == "valid":
  398. padding = (0, 0)
  399. if padding == "same":
  400. shape = list(input.shape)
  401. if shape is None:
  402. shape = list(input.shape)
  403. h_out = math.floor(
  404. (h_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
  405. )
  406. w_out = math.floor(
  407. (w_in + 2 * padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
  408. )
  409. shape[-2:] = [h_out, w_out]
  410. shape[-3] = self.out_channels
  411. return torch.empty(shape, device="meta")
  412. def torch_squeeze(input, dim=None):
  413. shape = list(input.shape)
  414. if dim is not None:
  415. if dim < 0:
  416. dim = input.dim() + dim
  417. if shape[dim] == 1:
  418. shape.pop(dim)
  419. else:
  420. new_shape = []
  421. for dim_value in shape:
  422. if dim_value == 1:
  423. continue
  424. new_shape.append(dim_value)
  425. shape = new_shape
  426. return torch.empty(shape, device="meta")
  427. def torch_tensor_squeeze(self, dim=None):
  428. return torch_squeeze(self, dim)
  429. def torch_unsqueeze(input, dim):
  430. shape = list(input.shape)
  431. if dim < 0:
  432. dim = input.dim() + 1 + dim
  433. shape.insert(dim, 1)
  434. return torch.empty(shape, device="meta")
  435. def torch_tensor_unsqueeze(self, dim):
  436. return torch_unsqueeze(self, dim)
  437. def torch_unique_consecutive(input, **kwargs):
  438. output = torch.unique_consecutive(torch.zeros_like(input, device="cpu"), **kwargs)
  439. if isinstance(output, torch.Tensor):
  440. return output.to("meta")
  441. else:
  442. return tuple(map(output, lambda x: x.to("meta")))
  443. def torch_nn_functional_one_hot(tensor, num_classes=-1):
  444. if num_classes < 0:
  445. raise ValueError("Don't support automatic num_classes inference for MetaTensor analysis")
  446. shape = list(tensor.shape) + [num_classes]
  447. return torch.empty(shape, device="meta")
  448. def torch_nn_functional_scaled_dot_product_attention(
  449. query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
  450. ):
  451. target_length = query.shape[-2]
  452. head_dim = value.shape[-1]
  453. return torch.empty((*query.shape[:-2], target_length, head_dim), device="meta")
  454. def torch_nn_mseloss(self, input, target):
  455. if self.reduction == "none":
  456. shape = target.shape
  457. else:
  458. shape = (1,)
  459. return torch.empty(shape, device="meta")
  460. def torch_nn_crossentropyloss(self, input, target):
  461. if self.reduction == "none":
  462. shape = target.shape
  463. else:
  464. shape = (1,)
  465. return torch.empty(shape, device="meta")
  466. def torch_nn_bcewithlogitsloss(self, input, target):
  467. if self.reduction == "none":
  468. shape = target.shape
  469. else:
  470. shape = (1,)
  471. return torch.empty(shape, device="meta")
  472. def operator_getitem(a, b):
  473. def to_concrete(t):
  474. if isinstance(t, torch.Tensor):
  475. concrete = torch.ones_like(t, device="cpu")
  476. if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
  477. concrete = concrete.to(torch.int64)
  478. return concrete
  479. return t
  480. if isinstance(a, torch.Tensor):
  481. # TODO: infer shape without performing the computation.
  482. if isinstance(b, tuple):
  483. b = tuple(map(to_concrete, b))
  484. else:
  485. b = to_concrete(b)
  486. return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
  487. return operator.getitem(a, b)
  488. _MANUAL_META_OVERRIDES: dict[Callable, Callable] = {
  489. torch.nn.Embedding: torch_nn_embedding,
  490. torch.nn.functional.embedding: torch_nn_functional_embedding,
  491. torch.nn.LayerNorm: torch_nn_layernorm,
  492. torch.nn.GroupNorm: torch_nn_groupnorm,
  493. torch.nn.Linear: torch_nn_linear,
  494. torch.relu: torch_relu,
  495. torch.nn.functional.relu: torch_nn_functional_relu,
  496. torch.nn.ReLU: torch_nn_relu,
  497. torch.where: torch_where,
  498. torch.abs: torch_abs,
  499. torch.arange: torch_arange,
  500. torch.full: torch_full,
  501. torch.cat: torch_cat,
  502. torch.stack: torch_stack,
  503. torch.add: torch_add,
  504. torch.mul: torch_mul,
  505. torch.Tensor.mul: torch_tensor_mul,
  506. torch.matmul: torch_matmul,
  507. torch.bmm: torch_bmm,
  508. torch.baddbmm: torch_baddbmm,
  509. torch.Tensor.baddbmm: torch_tensor_baddbmm,
  510. torch.einsum: torch_einsum,
  511. torch.Tensor.repeat: torch_tensor_repeat,
  512. torch.repeat_interleave: torch_repeat_interleave,
  513. torch.roll: torch_roll,
  514. torch.flip: torch_flip,
  515. torch.Tensor.flip: torch_tensor_flip,
  516. torch.index_select: torch_index_select,
  517. torch.Tensor.index_select: torch_tensor_index_select,
  518. torch.gather: torch_gather,
  519. torch.Tensor.gather: torch_tensor_gather,
  520. torch.nn.Conv1d: torch_nn_conv1d,
  521. torch.nn.Conv2d: torch_nn_conv2d,
  522. torch.squeeze: torch_squeeze,
  523. torch.Tensor.squeeze: torch_tensor_squeeze,
  524. torch.unsqueeze: torch_unsqueeze,
  525. torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
  526. torch.unique_consecutive: torch_unique_consecutive,
  527. torch.nn.functional.one_hot: torch_nn_functional_one_hot,
  528. torch.nn.MSELoss: torch_nn_mseloss,
  529. torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
  530. torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
  531. operator.getitem: operator_getitem,
  532. }
  533. _MANUAL_META_OVERRIDES[torch.nn.functional.scaled_dot_product_attention] = (
  534. torch_nn_functional_scaled_dot_product_attention
  535. )
  536. class HFProxy(Proxy):
  537. """
  538. Proxy that uses metadata to handle data-dependent control-flow.
  539. """
  540. def install_metadata(self, metadata):
  541. self._metadata = metadata
  542. @property
  543. def shape(self):
  544. return self.tracer.create_proxy("call_method", "size", (self,), {})
  545. @property
  546. def device(self):
  547. # Hack so we can track when devices are used. During meta-tensor propagation,
  548. # replace these values with a constant 'meta'
  549. return MetaDeviceAttribute(self, "device")
  550. def __len__(self):
  551. if hasattr(self, "_metadata") and self._metadata is not None:
  552. return len(self._metadata)
  553. return super().__len__()
  554. def __bool__(self):
  555. if hasattr(self, "_metadata") and self._metadata is not None:
  556. return self._metadata
  557. return super().__bool__()
  558. def __getattr__(self, k):
  559. if k == "_metadata":
  560. return self.__getattribute__(k)
  561. # note: not added to the graph yet, if this is a method call
  562. # we peephole optimize to the method invocation
  563. return HFAttribute(self, k)
  564. def __setitem__(self, indices, values):
  565. return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
  566. def __contains__(self, key):
  567. if hasattr(self, "_metadata") and self._metadata is not None:
  568. return key in self._metadata
  569. return super().__contains__(key)
  570. class HFAttribute(HFProxy):
  571. def __init__(self, root, attr: str):
  572. self.root = root
  573. self.attr = attr
  574. self.tracer = root.tracer
  575. self._node = None
  576. if hasattr(self.root, "_metadata"):
  577. self.install_metadata(getattr(self.root._metadata, attr))
  578. @property
  579. def node(self):
  580. # the node for attributes is added lazily, since most will just be method calls
  581. # which do not rely on the getitem call
  582. if self._node is None:
  583. self._node = self.tracer.create_proxy("call_function", builtins.getattr, (self.root, self.attr), {}).node
  584. return self._node
  585. def __call__(self, *args, **kwargs):
  586. return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
  587. class MetaDeviceAttribute(HFAttribute):
  588. pass
  589. class HFCacheProxy(HFProxy):
  590. """
  591. Proxy that represents an instance of `transformers.cache_utils.Cache`.
  592. """
  593. def install_orig_cache_cls(self, orig_cache_cls: type[Cache]):
  594. self._orig_cache_cls = orig_cache_cls
  595. @property
  596. def __class__(self):
  597. if not hasattr(self, "_orig_cache_cls"):
  598. raise RuntimeError("The original Cache class must be installed to the HFCacheProxy.")
  599. return self.tracer._CLASSES_TO_PATCH[self._orig_cache_cls]
  600. def create_wrapper(
  601. function: Callable,
  602. op_type: Union[Literal["call_function"], Literal["call_method"], Literal["get_attr"]],
  603. proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None,
  604. ) -> Callable:
  605. @functools.wraps(function)
  606. def wrapper(*args, **kwargs):
  607. if not is_fx_tracing():
  608. return function(*args, **kwargs)
  609. found_proxies = []
  610. def check_proxy(a):
  611. if isinstance(a, Proxy):
  612. found_proxies.append(a)
  613. torch.fx.node.map_aggregate(args, check_proxy)
  614. torch.fx.node.map_aggregate(kwargs, check_proxy)
  615. if len(found_proxies) > 0:
  616. tracer = found_proxies[0].tracer
  617. if op_type == "call_function":
  618. target = function
  619. elif op_type == "call_method" or op_type == "get_attr":
  620. target = function.__name__
  621. else:
  622. raise ValueError(f"op_type {op_type} not supported.")
  623. return tracer.create_proxy(op_type, target, args, kwargs, proxy_factory_fn=proxy_factory_fn)
  624. else:
  625. return function(*args, **kwargs)
  626. return wrapper
  627. class HFProxyableClassMeta(type):
  628. """
  629. Metaclass that creates a class with its main methods wrapped to be proxyable.
  630. """
  631. def __new__(
  632. cls,
  633. name: str,
  634. bases: tuple[type, ...],
  635. attrs: dict[str, Any],
  636. proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None,
  637. ):
  638. cls = super().__new__(cls, name, bases, attrs)
  639. for attr_name in dir(cls):
  640. attr = getattr(cls, attr_name, None)
  641. if attr is None:
  642. continue
  643. if attr_name == "__init__":
  644. op_type = "call_function"
  645. elif attr_name.startswith("__"):
  646. op_type = None
  647. elif inspect.ismethod(attr):
  648. op_type = "call_function"
  649. elif inspect.isfunction(attr):
  650. op_type = "call_method"
  651. else:
  652. op_type = None
  653. if op_type is not None:
  654. setattr(cls, attr_name, create_wrapper(attr, op_type, proxy_factory_fn=proxy_factory_fn))
  655. return cls
  656. def gen_constructor_wrapper(target: Callable) -> tuple[Callable, Callable]:
  657. """
  658. Wraps `target` to be proxyable. Used for tensor creators like `torch.ones`, `torch.arange` and so on.
  659. """
  660. wrapper = create_wrapper(target, "call_function")
  661. return wrapper, target
  662. def _proxies_to_metas(v):
  663. """Returns the underlying metadata for HFProxies, and behaves like the identity for the others."""
  664. if isinstance(v, MetaDeviceAttribute):
  665. return "meta"
  666. if isinstance(v, torch.fx.Proxy):
  667. if not (isinstance(v, HFProxy) and hasattr(v, "_metadata")):
  668. raise RuntimeError(f"No metadata was found for {v}")
  669. return v._metadata
  670. return v
  671. def create_cache_proxy_factory_fn(orig_cache_cls: type[Cache]) -> Callable[[Node], HFCacheProxy]:
  672. def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
  673. if not isinstance(_CURRENT_TRACER, HFTracer):
  674. raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
  675. cache_proxy = HFCacheProxy(n, _CURRENT_TRACER)
  676. cache_proxy.install_orig_cache_cls(orig_cache_cls)
  677. return cache_proxy
  678. return cache_proxy_factory_fn
  679. # Proxyable equivalent of the cache classes defined in `transformers.cache_utils`.
  680. ProxyableCache = HFProxyableClassMeta(
  681. "ProxyableCache", (Cache,), {}, proxy_factory_fn=create_cache_proxy_factory_fn(Cache)
  682. )
  683. ProxyableDynamicCache = HFProxyableClassMeta(
  684. "ProxyableDynamicCache",
  685. (DynamicCache,),
  686. {},
  687. proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache),
  688. )
  689. ProxyableStaticCache = HFProxyableClassMeta(
  690. "ProxyableStaticCache",
  691. (StaticCache,),
  692. {},
  693. proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache),
  694. )
  695. def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[list[int]] = None):
  696. if forbidden_values is None:
  697. forbidden_values = []
  698. value = random.randint(low, high)
  699. while value in forbidden_values:
  700. value = random.randint(low, high)
  701. return value
  702. class HFTracer(Tracer):
  703. """
  704. Tracer that is able to symbolically trace models from the library. To do that, it uses the HFProxy instead of the
  705. regular PyTorch torch.fx.Proxy.
  706. """
  707. # Feature flag for proxying accesses to buffer values
  708. proxy_buffer_attributes: bool = True
  709. allow_insert_stateless_mods: bool = True
  710. _TORCH_METHODS_TO_PATCH = [
  711. "arange",
  712. "zeros",
  713. "ones",
  714. "full",
  715. "full_like",
  716. "eye",
  717. "empty",
  718. "tensor",
  719. "clamp",
  720. "finfo",
  721. "tril",
  722. ]
  723. _CLASSES_TO_PATCH = {
  724. Cache: ProxyableCache,
  725. DynamicCache: ProxyableDynamicCache,
  726. StaticCache: ProxyableStaticCache,
  727. }
  728. supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
  729. def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
  730. super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)
  731. def _generate_dummy_input(
  732. self, model: "PreTrainedModel", input_name: str, shape: list[int], input_names: list[str]
  733. ) -> dict[str, torch.Tensor]:
  734. """Generates dummy input for model inference recording."""
  735. # Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
  736. # from pickle, or from the "__class__" attribute in the general case.
  737. model_class_name = getattr(model, "class_for_deserialization", model.__class__).__name__
  738. device = model.device
  739. inputs_dict = {}
  740. # when tracing a model with KV cache, we simply need to unsure that the KV cache length is larger than one to
  741. # rightfully pass certain controlflows (Example: https://github.com/huggingface/transformers/blob/5c8d941d66734811d2ef6f57f15b44f7fb7a98c4/src/transformers/modeling_attn_mask_utils.py#L162).
  742. # After tracing, the model can then still be used with arbitrary lengths different than the one used during tracing.
  743. kv_cache_length = 5
  744. if input_name in ["labels", "start_positions", "end_positions"]:
  745. batch_size = shape[0]
  746. if model_class_name in [
  747. *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
  748. *get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES),
  749. *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
  750. *get_values(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES),
  751. *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
  752. *get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES),
  753. ]:
  754. inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
  755. elif model_class_name in [
  756. *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
  757. *get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES),
  758. "XLNetForQuestionAnswering",
  759. ]:
  760. inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
  761. inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
  762. elif model_class_name in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
  763. if not hasattr(model.config, "problem_type") or model.config.problem_type is None:
  764. raise ValueError(
  765. "Could not retrieve the problem type for the sequence classification task, please set "
  766. 'model.config.problem_type to one of the following values: "regression", '
  767. '"single_label_classification", or "multi_label_classification".'
  768. )
  769. if model.config.problem_type == "regression":
  770. labels_shape = (batch_size, model.config.num_labels)
  771. labels_dtype = torch.float32
  772. elif model.config.problem_type == "single_label_classification":
  773. labels_shape = (batch_size,)
  774. labels_dtype = torch.long
  775. elif model.config.problem_type == "multi_label_classification":
  776. labels_shape = (batch_size, model.config.num_labels)
  777. labels_dtype = torch.float32
  778. else:
  779. raise ValueError(
  780. 'Expected model.config.problem_type to be either: "regression", "single_label_classification"'
  781. f', or "multi_label_classification", but "{model.config.problem_type}" was provided.'
  782. )
  783. inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)
  784. elif model_class_name in [
  785. *get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES),
  786. *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
  787. *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
  788. *get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
  789. *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
  790. *get_values(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES),
  791. "GPT2DoubleHeadsModel",
  792. "PeftModelForCausalLM",
  793. "PeftModelForSeq2SeqLM",
  794. ]:
  795. inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
  796. elif model_class_name in [*get_values(MODEL_FOR_CTC_MAPPING_NAMES)]:
  797. inputs_dict["labels"] = torch.zeros(shape, dtype=torch.float32, device=device)
  798. else:
  799. raise NotImplementedError(
  800. f"Generating the dummy input named {input_name} for {model_class_name} is not supported yet."
  801. )
  802. elif "pixel_values" in input_name:
  803. batch_size = shape[0]
  804. image_size = getattr(model.config, "image_size", None)
  805. if image_size is None:
  806. if hasattr(model.config, "vision_config"):
  807. image_size = model.config.vision_config.image_size
  808. elif hasattr(model.config, "encoder"):
  809. image_size = model.config.encoder.image_size
  810. else:
  811. image_size = (_generate_random_int(), _generate_random_int())
  812. # If no num_channels is in the config, use some arbitrary value.
  813. num_channels = getattr(model.config, "num_channels", 3)
  814. if not isinstance(image_size, collections.abc.Iterable):
  815. image_size = (image_size, image_size)
  816. height, width = image_size
  817. inputs_dict[input_name] = torch.zeros(
  818. batch_size, num_channels, height, width, dtype=torch.float32, device=device
  819. )
  820. elif "bbox" in input_name:
  821. inputs_dict[input_name] = torch.zeros(*shape, 4, dtype=torch.float, device=device)
  822. elif "input_features" in input_name:
  823. inputs_dict[input_name] = torch.zeros(
  824. *shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
  825. )
  826. elif "inputs_embeds" in input_name:
  827. batch_size = shape[0]
  828. if (
  829. getattr(model.config, "embedding_size", None) is not None
  830. and model.config.model_type != "megatron-bert"
  831. ):
  832. embedding_size = model.config.embedding_size
  833. else:
  834. embedding_size = model.config.hidden_size
  835. if len(shape) == 3:
  836. # (batch_size, num_choices, sequence_length, embedding_size)
  837. embedding_shape = (batch_size, shape[1], shape[2], embedding_size)
  838. else:
  839. # (batch_size, sequence_length, embedding_size)
  840. embedding_shape = (batch_size, shape[1], embedding_size)
  841. inputs_dict[input_name] = torch.zeros(embedding_shape, dtype=torch.float, device=device)
  842. elif "visual_feats" in input_name:
  843. inputs_dict[input_name] = torch.zeros(
  844. shape
  845. + [
  846. model.config.visual_feat_dim,
  847. ],
  848. dtype=torch.float,
  849. device=device,
  850. )
  851. elif "visual_pos" in input_name:
  852. inputs_dict[input_name] = torch.zeros(
  853. shape
  854. + [
  855. model.config.visual_pos_dim,
  856. ],
  857. dtype=torch.float,
  858. device=device,
  859. )
  860. elif "inputs" in input_name:
  861. inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device)
  862. elif "input_values" in input_name:
  863. batch_size, _ = shape
  864. # Generating big sequence length for audio inputs.
  865. seq_length = _generate_random_int(low=10000, high=20000)
  866. inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device)
  867. elif "mask" in input_name:
  868. if "past_key_values" in input_names:
  869. mask_shape = [shape[0], shape[1] + kv_cache_length]
  870. else:
  871. mask_shape = shape
  872. inputs_dict[input_name] = torch.zeros(mask_shape, dtype=torch.long, device=device)
  873. elif "ids" in input_name:
  874. inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
  875. elif "past_key_values" in input_name:
  876. if model.config.model_type not in _FX_SUPPORTED_MODELS_WITH_KV_CACHE:
  877. raise NotImplementedError(
  878. f"Symbolic trace with past_key_values input is not supported yet for the model {model.config.model_type}. Please open an issue or a PR in Transformers repository if you would like to see the support added."
  879. )
  880. num_heads = model.config.num_attention_heads
  881. head_dim = model.config.hidden_size // model.config.num_attention_heads
  882. cache_shape = (shape[0], num_heads, kv_cache_length, head_dim)
  883. pkv = tuple(
  884. (
  885. torch.rand(cache_shape, dtype=torch.float, device=device),
  886. torch.rand(cache_shape, dtype=torch.float, device=device),
  887. )
  888. for i in range(model.config.num_hidden_layers)
  889. )
  890. inputs_dict[input_name] = pkv
  891. else:
  892. shape_with_hidden_size = shape + [model.config.hidden_size]
  893. inputs_dict[input_name] = torch.zeros(shape_with_hidden_size, dtype=torch.float, device=device)
  894. return inputs_dict
  895. def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
  896. rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
  897. if kind == "placeholder" and target in self.meta_args:
  898. rv.install_metadata(self.meta_args[target])
  899. return rv
  900. if target in self.orig_fns:
  901. # NOTE: tensor constructors in PyTorch define the `device` argument as
  902. # *kwargs-only*. That is why this works. If you add methods to
  903. # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
  904. # this will break and you will likely see issues where we cannot infer
  905. # the size of the output.
  906. if "device" in kwargs:
  907. kwargs["device"] = "meta"
  908. try:
  909. args_metas = torch.fx.node.map_aggregate(args, _proxies_to_metas)
  910. kwargs_metas = torch.fx.node.map_aggregate(kwargs, _proxies_to_metas)
  911. should_install_metadata = True
  912. self._disable_module_getattr = True
  913. self._disable_call_module = True
  914. if kind == "call_function":
  915. meta_target = _MANUAL_META_OVERRIDES.get(target, target)
  916. meta_out = meta_target(*args_metas, **kwargs_metas)
  917. if isinstance(meta_out, torch.Tensor):
  918. meta_out = meta_out.to(device="meta")
  919. elif kind == "call_method":
  920. method = getattr(args_metas[0].__class__, target)
  921. meta_target = _MANUAL_META_OVERRIDES.get(method, method)
  922. meta_out = meta_target(*args_metas, **kwargs_metas)
  923. elif kind == "call_module":
  924. if not hasattr(self, "orig_forward"):
  925. raise AttributeError(f"{self} does not have an attribute called orig_forward")
  926. mod = self.root.get_submodule(target)
  927. mod_type = type(mod)
  928. if mod_type in _MANUAL_META_OVERRIDES:
  929. meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
  930. else:
  931. meta_out = self.orig_forward(*args_metas, **kwargs_metas)
  932. elif kind == "get_attr":
  933. attr_itr = self.root
  934. atoms = target.split(".")
  935. for atom in atoms:
  936. attr_itr = getattr(attr_itr, atom)
  937. if isinstance(attr_itr, torch.Tensor):
  938. meta_out = attr_itr.to(device="meta")
  939. else:
  940. meta_out = attr_itr
  941. else:
  942. should_install_metadata = False
  943. if should_install_metadata:
  944. if not isinstance(rv, Proxy):
  945. raise ValueError("Don't support composite output yet")
  946. rv.install_metadata(meta_out)
  947. except Exception as e:
  948. if _IS_IN_DEBUG_MODE:
  949. warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
  950. self._disable_module_getattr = False
  951. self._disable_call_module = False
  952. return rv
  953. # Replaced by .getattr from PyTorch 1.13
  954. def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
  955. if getattr(self, "_disable_module_getattr", False):
  956. return attr_val
  957. else:
  958. def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
  959. for n, p in collection_to_search:
  960. if attr_val is p:
  961. if n not in parameter_proxy_cache:
  962. kwargs = {}
  963. if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
  964. kwargs["proxy_factory_fn"] = (
  965. None
  966. if not self.param_shapes_constant
  967. else lambda node: ParameterProxy(self, node, n, attr_val)
  968. )
  969. val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
  970. parameter_proxy_cache[n] = val_proxy
  971. return parameter_proxy_cache[n]
  972. return None
  973. if isinstance(attr_val, torch.nn.Parameter):
  974. maybe_parameter_proxy = maybe_get_proxy_for_attr(
  975. attr_val, self.root.named_parameters(), parameter_proxy_cache
  976. )
  977. if maybe_parameter_proxy is not None:
  978. return maybe_parameter_proxy
  979. if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
  980. maybe_buffer_proxy = maybe_get_proxy_for_attr(
  981. attr_val, self.root.named_buffers(), parameter_proxy_cache
  982. )
  983. if maybe_buffer_proxy is not None:
  984. return maybe_buffer_proxy
  985. return attr_val
  986. # Needed for PyTorch 1.13+
  987. def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: dict[str, Any]):
  988. return self._module_getattr(attr, attr_val, parameter_proxy_cache)
  989. def call_module(self, m, forward, args, kwargs):
  990. if getattr(self, "_disable_call_module", False):
  991. return forward(*args, **kwargs)
  992. self.orig_forward = forward
  993. return super().call_module(m, forward, args, kwargs)
  994. def proxy(self, node):
  995. return HFProxy(node, self)
  996. @contextlib.contextmanager
  997. def patch_for_tracing(self, root: Union[torch.nn.Module, Callable[..., Any]]):
  998. # Patching torch functions
  999. self.patched_torch_methods = {
  1000. target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
  1001. }
  1002. self.orig_fns = set()
  1003. for name, (wrapper, orig) in self.patched_torch_methods.items():
  1004. setattr(torch, name, wrapper)
  1005. self.orig_fns.add(orig)
  1006. # Patching classes
  1007. patched = []
  1008. module_of_model = inspect.getmodule(root)
  1009. for name, mod in sys.modules.items():
  1010. if module_of_model is not None and mod is not module_of_model:
  1011. continue
  1012. if not name.startswith("transformers"):
  1013. continue
  1014. for orig_cls, patched_cls in self._CLASSES_TO_PATCH.items():
  1015. for attr_name, attr in mod.__dict__.items():
  1016. if attr is orig_cls:
  1017. patched.append((mod, attr_name, orig_cls))
  1018. setattr(mod, attr_name, patched_cls)
  1019. yield
  1020. # Restoring patched functions and classes.
  1021. for name, (_, orig) in self.patched_torch_methods.items():
  1022. setattr(torch, name, orig)
  1023. self.patched_torch_methods = {}
  1024. self.orig_fns = set()
  1025. for mod, attr_name, orig_cls in patched:
  1026. setattr(mod, attr_name, orig_cls)
  1027. def trace(
  1028. self,
  1029. root: Union[torch.nn.Module, Callable[..., Any]],
  1030. concrete_args: Optional[dict[str, Any]] = None,
  1031. dummy_inputs: Optional[dict[str, Any]] = None,
  1032. complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
  1033. ) -> Graph:
  1034. """
  1035. Traces `root` and returns the corresponding FX `torch.fx.Graph` representation. `root` can either be a
  1036. `torch.nn.Module` instance or a Python callable. Note that after this call, `self.root` may be different from
  1037. the `root` passed in here. For example, when a free function is passed to `trace()`, we will create a
  1038. `torch.nn.Module` instance to use as the root and add embedded constants to.
  1039. Args:
  1040. root (`torch.nn.Module` or `Callable`):
  1041. Either a `torch.nn.Module`` or a function to be traced through. If root is not a
  1042. [`~transformers.PreTrainedModel`], then `dummy_inputs` must be passed, otherwise tracing will fail.
  1043. concrete_args (`dict[str, Any], *optional*):
  1044. Concrete arguments that should not be treated as Proxies
  1045. dummy_inputs (`dict[str, Any]`, *optional*):
  1046. The dummy inputs needed to handle data-dependent control-flow if `root` is not a
  1047. [`~transformers.PreTrainedModel`]. It can also be used when `root` is a
  1048. [`~transformers.PreTrainedModel`] to specify custom dummy inputs for a subset or all the model inputs.
  1049. complete_concrete_args_with_inputs_not_in_dummy_inputs (`bool`, *optional*, defaults to `True`):
  1050. If `True`, and `dummy_inputs` is specified, every argument that `root` can take that is not in
  1051. `dummy_inputs` and not in `concrete_args` will be added to `concrete_args`, otherwise does nothing.
  1052. Returns:
  1053. `torch.fx.Graph`:
  1054. A FX `torch.fx.Graph` representing the semantics of the passed-in `root`.
  1055. """
  1056. sig = inspect.signature(root.forward if isinstance(root, torch.nn.Module) else root)
  1057. if concrete_args is None:
  1058. concrete_args = {}
  1059. if dummy_inputs is not None and complete_concrete_args_with_inputs_not_in_dummy_inputs:
  1060. for param in sig.parameters.values():
  1061. if param.name in dummy_inputs:
  1062. continue
  1063. if param.default is inspect.Parameter.empty:
  1064. raise ValueError(f"You need to specify a default value for the parameter {param.name}.")
  1065. concrete_args.update(
  1066. {
  1067. p.name: p.default
  1068. for p in sig.parameters.values()
  1069. if (p.name not in dummy_inputs and p.name not in concrete_args)
  1070. }
  1071. )
  1072. input_names = sig.parameters.keys() - concrete_args.keys()
  1073. # Creating a random input shape to generate dummy inputs.
  1074. batch_size = _generate_random_int()
  1075. sequence_length = _generate_random_int()
  1076. shape = [batch_size, sequence_length]
  1077. if root.__class__.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
  1078. num_choices = _generate_random_int(low=2, high=5)
  1079. shape.insert(1, num_choices)
  1080. inputs = dict(dummy_inputs) if dummy_inputs is not None else {}
  1081. for input_name in input_names:
  1082. if input_name in inputs:
  1083. continue
  1084. # We enforce that root must either be a PreTrainedModel or deserialized from a serialized traced model to
  1085. # be able to use HFTracer._generate_dummy_input.
  1086. if isinstance(root, self.supported_archs) or type(root).__qualname__.startswith(
  1087. ("_deserialize_graph_module", "_CodeOnlyModule")
  1088. ):
  1089. inputs.update(self._generate_dummy_input(root, input_name, shape, input_names=input_names))
  1090. else:
  1091. raise RuntimeError(
  1092. f"Could not generate input named {input_name} for because root is not a"
  1093. " transformers.PreTrainedModel."
  1094. )
  1095. def to_meta(value):
  1096. if isinstance(value, torch.Tensor):
  1097. return value.to("meta")
  1098. return value
  1099. concrete_metas = pytree.tree_map(to_meta, inputs)
  1100. for param in sig.parameters.values():
  1101. if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
  1102. concrete_metas[f"**{param.name}"] = {}
  1103. self.meta_args = concrete_metas
  1104. global _CURRENT_TRACER
  1105. _CURRENT_TRACER = self
  1106. with self.patch_for_tracing(root):
  1107. try:
  1108. self.graph = super().trace(root, concrete_args=concrete_args)
  1109. finally:
  1110. _CURRENT_TRACER = None
  1111. # This is necessary because concrete args are added as input to the traced module since
  1112. # https://github.com/pytorch/pytorch/pull/55888.
  1113. for node in self.graph.nodes:
  1114. if node.op == "placeholder":
  1115. # Removing default values for inputs as the forward pass will fail with them.
  1116. if node.target in input_names:
  1117. node.args = ()
  1118. # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
  1119. # It cannot infer on the attributes and methods the input should have, and fails.
  1120. node.type = torch.Tensor
  1121. # It is a concrete arg so it is not used and should be removed.
  1122. else:
  1123. to_visit = [node]
  1124. to_delete = collections.OrderedDict()
  1125. while to_visit:
  1126. n = to_visit.pop(0)
  1127. to_delete[n] = None
  1128. to_visit += list(n.users.keys())
  1129. for user in reversed(to_delete.keys()):
  1130. self.graph.erase_node(user)
  1131. # TODO: solves GraphModule creation.
  1132. # Without this, return type annotation "Tuple" is causing code execution failure.
  1133. if node.op == "output":
  1134. node.type = None
  1135. return self.graph
  1136. def _stateless_mod_instantiation_depends_on_proxies(self, mod: nn.Module) -> bool:
  1137. """
  1138. Whether the module was instantiated with Proxies. If that is the case, such module cannot be a leaf module
  1139. because its attributes are input-dependent.
  1140. """
  1141. return any(isinstance(attr, Proxy) for attr in mod.__dict__.values())
  1142. def _insert_module_as_submodule(self, mod: nn.Module) -> str:
  1143. """
  1144. Helper method which tries to insert a module that was not declared as submodule.
  1145. """
  1146. # If one of the module attributes is a Proxy, it means that its instantiation is input-dependent.
  1147. # It is not possible to insert such modules, those should be traced through.
  1148. if self._stateless_mod_instantiation_depends_on_proxies(mod):
  1149. return ""
  1150. idx = 0
  1151. mod_name = mod.__class__.__name__.lower()
  1152. path = f"{mod_name}_{idx}"
  1153. already_inserted = False
  1154. while hasattr(self.root, path):
  1155. if getattr(self.root, path) is mod:
  1156. already_inserted = True
  1157. break
  1158. path = f"{mod_name}_{idx}"
  1159. idx += 1
  1160. # No need to add multiple instances of the same module.
  1161. if not already_inserted:
  1162. self.root.add_module(path, mod)
  1163. return path
  1164. def path_of_module(self, mod: nn.Module) -> str:
  1165. """
  1166. Helper method to find the qualified name of `mod` in the Module hierarchy of `root`. For example, if `root` has
  1167. a submodule named `foo`, which has a submodule named `bar`, passing `bar` into this function will return the
  1168. string "foo.bar".
  1169. Args:
  1170. mod (str): The `Module` to retrieve the qualified name for.
  1171. """
  1172. try:
  1173. return super().path_of_module(mod)
  1174. except NameError as e:
  1175. if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
  1176. path = self._insert_module_as_submodule(mod)
  1177. return path
  1178. raise e
  1179. def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
  1180. return (not self._stateless_mod_instantiation_depends_on_proxies(m)) and super().is_leaf_module(
  1181. m, module_qualified_name
  1182. )
  1183. @compatibility(is_backward_compatible=True)
  1184. def keys(self, obj: "Proxy") -> Any:
  1185. """Called when a proxy object is has the keys() method called.
  1186. This is what happens when ** is called on a proxy. This should return an iterator if ** is supposed to work in
  1187. your custom tracer.
  1188. """
  1189. attribute = HFAttribute(obj, "keys")()
  1190. if obj.node.target.startswith("**"):
  1191. return attribute._metadata
  1192. return attribute
  1193. def get_concrete_args(model: nn.Module, input_names: list[str]):
  1194. sig = inspect.signature(model.forward)
  1195. if not (set(input_names) <= set(sig.parameters.keys())):
  1196. formatted_input_names = input_names[0] if len(input_names) == 1 else ", ".join(input_names)
  1197. formatted_allowed_input_names = ", ".join(sig.parameters.keys())
  1198. raise ValueError(
  1199. f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
  1200. f" {formatted_allowed_input_names}"
  1201. )
  1202. return {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
  1203. def is_model_supported(model: "PreTrainedModel"):
  1204. return model.__class__.__name__ in _SUPPORTED_MODELS
  1205. def check_if_model_is_supported(model: "PreTrainedModel"):
  1206. if not is_model_supported(model):
  1207. supported_model_names = ", ".join(_SUPPORTED_MODELS)
  1208. raise NotImplementedError(
  1209. f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
  1210. )
  1211. def symbolic_trace(
  1212. model: "PreTrainedModel",
  1213. input_names: Optional[list[str]] = None,
  1214. disable_check: bool = False,
  1215. tracer_cls: type[HFTracer] = HFTracer,
  1216. ) -> GraphModule:
  1217. """
  1218. Performs symbolic tracing on the model.
  1219. Args:
  1220. model ([`PretrainedModel`]):
  1221. The model to trace.
  1222. input_names (`list[str]`, *optional*):
  1223. The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
  1224. disable_check (`bool`, *optional*, defaults to `False`):
  1225. If `True`, no check is done before trying to trace the model, this is mostly usesul for debugging purposes.
  1226. tracer_cls (`Type[HFTracer]`, *optional*, defaults to `HFTracer`):
  1227. The tracer class to use for instantiating the tracer. If unset, `HFTracer` is used instead.
  1228. Returns:
  1229. `torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.
  1230. Example:
  1231. ```python
  1232. from transformers.utils.fx import symbolic_trace
  1233. traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
  1234. ```
  1235. """
  1236. if input_names is None:
  1237. input_names = model.dummy_inputs.keys()
  1238. input_names = list(input_names)
  1239. concrete_args = get_concrete_args(model, input_names)
  1240. if not disable_check:
  1241. check_if_model_is_supported(model)
  1242. if "past_key_values" in input_names and not getattr(model.config, "use_cache", False):
  1243. logger.warning(
  1244. "`past_key_values` were specified as input names, but model.config.use_cache = False, this might lead to "
  1245. "unexpected behavior."
  1246. )
  1247. if "past_key_values" not in input_names and getattr(model.config, "use_cache", False):
  1248. logger.warning(
  1249. "`past_key_values` were not specified as input names, but model.config.use_cache = True. Setting "
  1250. "model.config.use_cache = False."
  1251. )
  1252. model.config.use_cache = False
  1253. # Tracing.
  1254. tracer = tracer_cls()
  1255. traced_graph = tracer.trace(model, concrete_args=concrete_args)
  1256. traced = torch.fx.GraphModule(model, traced_graph)
  1257. traced.config = model.config
  1258. # The model class must be stored as an attribute to allow model deserialization, which uses trace, and thus
  1259. # _generate_dummy_input, where the model class is needed.
  1260. traced.class_for_deserialization = model.__class__
  1261. traced.device = model.device
  1262. return traced