graph_signature.py 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724
  1. # mypy: allow-untyped-defs
  2. import dataclasses
  3. from collections.abc import Collection, Mapping
  4. from enum import auto, Enum
  5. from typing import Optional, TYPE_CHECKING, Union
  6. from torch._library.fake_class_registry import FakeScriptObject
  7. from torch._subclasses.fake_tensor import is_fake
  8. if TYPE_CHECKING:
  9. import torch
  10. from torch._functorch._aot_autograd.schemas import GraphSignature
# Public names exported by ``from <this module> import *``.
# NOTE: ``TokenArgument`` and the ``ArgumentSpec`` alias are defined in this
# file but are not listed here.
__all__ = [
    "ConstantArgument",
    "CustomObjArgument",
    "ExportBackwardSignature",
    "ExportGraphSignature",
    "InputKind",
    "InputSpec",
    "OutputKind",
    "OutputSpec",
    "SymIntArgument",
    "SymFloatArgument",
    "SymBoolArgument",
    "TensorArgument",
]
@dataclasses.dataclass
class TensorArgument:
    """Argument that refers to a tensor-valued graph node by name.

    ``_make_argument_spec`` creates one for nodes whose ``val`` metadata
    satisfies ``is_fake``.
    """

    # Name of the graph node this argument refers to.
    name: str
@dataclasses.dataclass
class TokenArgument:
    """Argument that refers to a token graph node by name.

    ``_make_argument_spec`` creates one when the node's name is among the
    token names it is given.
    """

    # Name of the graph node this argument refers to.
    name: str
@dataclasses.dataclass
class SymIntArgument:
    """Argument that refers to a graph node whose ``val`` is a ``torch.SymInt``."""

    # Name of the graph node this argument refers to.
    name: str
@dataclasses.dataclass
class SymFloatArgument:
    """Argument that refers to a graph node whose ``val`` is a ``torch.SymFloat``."""

    # Name of the graph node this argument refers to.
    name: str
@dataclasses.dataclass
class SymBoolArgument:
    """Argument that refers to a graph node whose ``val`` is a ``torch.SymBool``."""

    # Name of the graph node this argument refers to.
    name: str
@dataclasses.dataclass
class CustomObjArgument:
    """Argument that refers to a custom (script) object graph node.

    ``_make_argument_spec`` creates one for ``ScriptObject`` and
    ``FakeScriptObject`` values.
    """

    # Name of the graph node this argument refers to.
    name: str
    # Qualified class name of the object (``val._type().qualified_name()`` for
    # a ScriptObject, ``val.script_class_name`` for a FakeScriptObject).
    class_fqn: str
    # The fake object itself; only populated for FakeScriptObject values.
    fake_val: Optional[FakeScriptObject] = None
@dataclasses.dataclass
class ConstantArgument:
    """Argument that carries a plain Python constant.

    ``ExportGraphSignature.user_inputs`` / ``user_outputs`` report ``value``
    (not ``name``) for arguments of this type.
    """

    # Node name; ``_make_argument_spec`` uses "" when it is handed a bare
    # Python constant instead of a node.
    name: str
    # The constant itself.
    value: Union[int, float, bool, str, None]
# Union of every argument dataclass that may appear as ``InputSpec.arg`` or
# ``OutputSpec.arg``; both ``__post_init__`` validators check against the same
# set of classes.
ArgumentSpec = Union[
    TensorArgument,
    SymIntArgument,
    SymFloatArgument,
    SymBoolArgument,
    ConstantArgument,
    CustomObjArgument,
    TokenArgument,
]
class InputKind(Enum):
    """Category of a graph input, stored in ``InputSpec.kind``."""

    # Input supplied by the caller of the original program.
    USER_INPUT = auto()
    # Lifted parameter; ``InputSpec.target`` holds the parameter name.
    PARAMETER = auto()
    # Lifted buffer; ``InputSpec.persistent`` must be set for this kind.
    BUFFER = auto()
    # Lifted constant tensor (see ``lifted_tensor_constants``).
    CONSTANT_TENSOR = auto()
    # Lifted custom object (see ``lifted_custom_objs``).
    CUSTOM_OBJ = auto()
    # Token input; its ``arg`` is a ``TokenArgument``.
    TOKEN = auto()
  65. @dataclasses.dataclass
  66. class InputSpec:
  67. kind: InputKind
  68. arg: ArgumentSpec
  69. target: Optional[str]
  70. persistent: Optional[bool] = None
  71. def __post_init__(self):
  72. if self.kind == InputKind.BUFFER:
  73. assert self.persistent is not None, (
  74. "Failed to specify persistent flag on BUFFER."
  75. )
  76. assert isinstance(
  77. self.arg,
  78. (
  79. TensorArgument,
  80. SymIntArgument,
  81. SymFloatArgument,
  82. SymBoolArgument,
  83. ConstantArgument,
  84. CustomObjArgument,
  85. TokenArgument,
  86. ),
  87. ), f"got {type(self.arg)}"
  88. def __str__(self):
  89. target = "" if self.target is None else f" target='{self.target}'"
  90. persistent = "" if self.persistent is None else f" persistent={self.persistent}"
  91. return f"{str(self.arg.name)}: {str(self.kind.name)}{target}{persistent}"
class OutputKind(Enum):
    """Category of a graph output, stored in ``OutputSpec.kind``."""

    # Output returned to the caller of the original program.
    USER_OUTPUT = auto()
    # Loss output; also reported by ``ExportGraphSignature.user_outputs``.
    LOSS_OUTPUT = auto()
    # Updated value of a buffer; ``OutputSpec.target`` names the buffer.
    BUFFER_MUTATION = auto()
    # Updated value of a parameter; ``OutputSpec.target`` names the parameter.
    PARAMETER_MUTATION = auto()
    # Gradient output; ``OutputSpec.target`` names the parameter.
    GRADIENT_TO_PARAMETER = auto()
    # Gradient output; ``OutputSpec.target`` names the user input.
    GRADIENT_TO_USER_INPUT = auto()
    # Updated value of a user input; ``OutputSpec.target`` names the input.
    USER_INPUT_MUTATION = auto()
    # Token output; its ``arg`` is a ``TokenArgument``.
    TOKEN = auto()
  101. @dataclasses.dataclass
  102. class OutputSpec:
  103. kind: OutputKind
  104. arg: ArgumentSpec
  105. target: Optional[str]
  106. def __post_init__(self):
  107. assert isinstance(
  108. self.arg,
  109. (
  110. TensorArgument,
  111. SymIntArgument,
  112. SymFloatArgument,
  113. SymBoolArgument,
  114. ConstantArgument,
  115. TokenArgument,
  116. CustomObjArgument,
  117. ),
  118. ), self.arg
  119. def __str__(self):
  120. target = "" if self.target is None else f" target='{self.target}'"
  121. return f"{str(self.arg.name)}: {str(self.kind.name)}{target}"
@dataclasses.dataclass
class ExportBackwardSignature:
    """Backward-related outputs of a graph.

    Built by ``ExportGraphSignature.backward_signature`` from the output
    specs; all keys and ``loss_output`` are graph output node names.
    """

    # Gradient output name -> target of the GRADIENT_TO_PARAMETER output spec.
    gradients_to_parameters: dict[str, str]
    # Gradient output name -> target of the GRADIENT_TO_USER_INPUT output spec.
    gradients_to_user_inputs: dict[str, str]
    # Name of the output whose kind is LOSS_OUTPUT.
    loss_output: str
  127. @dataclasses.dataclass
  128. class ExportGraphSignature:
  129. """
  130. :class:`ExportGraphSignature` models the input/output signature of Export Graph,
  131. which is a fx.Graph with stronger invariants guarantees.
  132. Export Graph is functional and does not access "states" like parameters
  133. or buffers within the graph via ``getattr`` nodes. Instead, :func:`export`
  134. guarantees that parameters, buffers, and constant tensors are lifted out of
  135. the graph as inputs. Similarly, any mutations to buffers are not included
  136. in the graph either, instead the updated values of mutated buffers are
  137. modeled as additional outputs of Export Graph.
  138. The ordering of all inputs and outputs are::
  139. Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
  140. Outputs = [*mutated_inputs, *flattened_user_outputs]
  141. e.g. If following module is exported::
  142. class CustomModule(nn.Module):
  143. def __init__(self) -> None:
  144. super(CustomModule, self).__init__()
  145. # Define a parameter
  146. self.my_parameter = nn.Parameter(torch.tensor(2.0))
  147. # Define two buffers
  148. self.register_buffer("my_buffer1", torch.tensor(3.0))
  149. self.register_buffer("my_buffer2", torch.tensor(4.0))
  150. def forward(self, x1, x2):
  151. # Use the parameter, buffers, and both inputs in the forward method
  152. output = (
  153. x1 + self.my_parameter
  154. ) * self.my_buffer1 + x2 * self.my_buffer2
  155. # Mutate one of the buffers (e.g., increment it by 1)
  156. self.my_buffer2.add_(1.0) # In-place addition
  157. return output
  158. mod = CustomModule()
  159. ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))
  160. Resulting Graph is non-functional::
  161. graph():
  162. %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
  163. %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
  164. %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
  165. %x1 : [num_users=1] = placeholder[target=x1]
  166. %x2 : [num_users=1] = placeholder[target=x2]
  167. %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
  168. %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
  169. %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
  170. %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
  171. %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
  172. return (add_1,)
  173. Resulting ExportGraphSignature of the non-functional Graph would be::
  174. # inputs
  175. p_my_parameter: PARAMETER target='my_parameter'
  176. b_my_buffer1: BUFFER target='my_buffer1' persistent=True
  177. b_my_buffer2: BUFFER target='my_buffer2' persistent=True
  178. x1: USER_INPUT
  179. x2: USER_INPUT
  180. # outputs
  181. add_1: USER_OUTPUT
  182. To get a functional Graph, you can use :func:`run_decompositions`::
  183. mod = CustomModule()
  184. ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))
  185. ep = ep.run_decompositions()
  186. Resulting Graph is functional::
  187. graph():
  188. %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
  189. %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
  190. %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
  191. %x1 : [num_users=1] = placeholder[target=x1]
  192. %x2 : [num_users=1] = placeholder[target=x2]
  193. %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
  194. %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
  195. %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
  196. %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
  197. %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
  198. return (add_2, add_1)
  199. Resulting ExportGraphSignature of the functional Graph would be::
  200. # inputs
  201. p_my_parameter: PARAMETER target='my_parameter'
  202. b_my_buffer1: BUFFER target='my_buffer1' persistent=True
  203. b_my_buffer2: BUFFER target='my_buffer2' persistent=True
  204. x1: USER_INPUT
  205. x2: USER_INPUT
  206. # outputs
  207. add_2: BUFFER_MUTATION target='my_buffer2'
  208. add_1: USER_OUTPUT
  209. """
  210. input_specs: list[InputSpec]
  211. output_specs: list[OutputSpec]
  212. # A list of parameters uniquely identified by mangled fully qualified name
  213. @property
  214. def parameters(self) -> Collection[str]:
  215. return tuple(
  216. s.target
  217. for s in self.input_specs
  218. if s.kind == InputKind.PARAMETER
  219. if isinstance(s.target, str)
  220. )
  221. # A list of buffers uniquely identified by mangled fully qualified name
  222. @property
  223. def buffers(self) -> Collection[str]:
  224. return tuple(
  225. s.target
  226. for s in self.input_specs
  227. if s.kind == InputKind.BUFFER
  228. if isinstance(s.target, str)
  229. )
  230. @property
  231. def non_persistent_buffers(self) -> Collection[str]:
  232. return tuple(
  233. s.target
  234. for s in self.input_specs
  235. if s.kind == InputKind.BUFFER
  236. if s.persistent is False
  237. if isinstance(s.target, str)
  238. )
  239. # A list of lifted constant tensors
  240. @property
  241. def lifted_tensor_constants(self) -> Collection[str]:
  242. return tuple(
  243. s.target
  244. for s in self.input_specs
  245. if s.kind == InputKind.CONSTANT_TENSOR
  246. if isinstance(s.target, str)
  247. )
  248. @property
  249. def lifted_custom_objs(self) -> Collection[str]:
  250. return tuple(
  251. s.target
  252. for s in self.input_specs
  253. if s.kind == InputKind.CUSTOM_OBJ
  254. if isinstance(s.target, str)
  255. )
  256. # Graph node names of pytree-flattened inputs of original program
  257. @property
  258. def user_inputs(self) -> Collection[Union[int, float, bool, None, str]]:
  259. user_inputs: list[Union[int, float, bool, None, str]] = []
  260. for s in self.input_specs:
  261. if s.kind != InputKind.USER_INPUT:
  262. continue
  263. if isinstance(
  264. s.arg,
  265. (
  266. TensorArgument,
  267. SymIntArgument,
  268. SymFloatArgument,
  269. SymBoolArgument,
  270. CustomObjArgument,
  271. ),
  272. ):
  273. user_inputs.append(s.arg.name)
  274. elif isinstance(s.arg, ConstantArgument):
  275. user_inputs.append(s.arg.value)
  276. else:
  277. raise RuntimeError(f"{s.arg} is not a valid user inputs")
  278. return tuple(user_inputs)
  279. # Graph node names of pytree-flattened outputs of original program
  280. # For joint-graph purposes, will include the loss output.
  281. @property
  282. def user_outputs(self) -> Collection[Union[int, float, bool, None, str]]:
  283. user_outputs: list[Union[int, float, bool, None, str]] = []
  284. for s in self.output_specs:
  285. if s.kind not in [
  286. OutputKind.USER_OUTPUT,
  287. OutputKind.LOSS_OUTPUT,
  288. ]:
  289. continue
  290. if isinstance(
  291. s.arg,
  292. (TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument),
  293. ):
  294. user_outputs.append(s.arg.name)
  295. elif isinstance(s.arg, ConstantArgument):
  296. user_outputs.append(s.arg.value)
  297. elif isinstance(s.arg, CustomObjArgument):
  298. user_outputs.append(s.arg.name)
  299. else:
  300. raise RuntimeError(f"{s.arg} is not a valid user output")
  301. return tuple(user_outputs)
  302. # A dictionary mapping graph input node names to parameters. If a graph input
  303. # name is found in this dictionary, it is guaranteed to be a lifted parameter.
  304. @property
  305. def inputs_to_parameters(self) -> Mapping[str, str]:
  306. return _immutable_dict(
  307. (s.arg.name, s.target)
  308. for s in self.input_specs
  309. if s.kind == InputKind.PARAMETER
  310. and isinstance(s.arg, TensorArgument)
  311. and isinstance(s.target, str)
  312. )
  313. # A dictionary mapping graph input node names to buffers. If a graph input
  314. # name is found in this dictionary, it is guaranteed to be a lifted buffer.
  315. @property
  316. def inputs_to_buffers(self) -> Mapping[str, str]:
  317. return _immutable_dict(
  318. (s.arg.name, s.target) # type: ignore[union-attr, misc]
  319. for s in self.input_specs
  320. if s.kind == InputKind.BUFFER
  321. and isinstance(s.arg, TensorArgument)
  322. and isinstance(s.target, str)
  323. )
  324. # A dictionary mapping graph output node names to buffers that are mutated in the
  325. # original program. Buffers that are not mutated will not be found in this dictionary.
  326. @property
  327. def buffers_to_mutate(self) -> Mapping[str, str]:
  328. return _immutable_dict(
  329. (s.arg.name, s.target)
  330. for s in self.output_specs
  331. if s.kind == OutputKind.BUFFER_MUTATION
  332. and isinstance(s.arg, TensorArgument)
  333. and isinstance(s.target, str)
  334. )
  335. @property
  336. def parameters_to_mutate(self) -> Mapping[str, str]:
  337. return _immutable_dict(
  338. (s.arg.name, s.target)
  339. for s in self.output_specs
  340. if s.kind == OutputKind.PARAMETER_MUTATION
  341. and isinstance(s.arg, TensorArgument)
  342. and isinstance(s.target, str)
  343. )
  344. @property
  345. def user_inputs_to_mutate(self) -> Mapping[str, str]:
  346. return _immutable_dict(
  347. (s.arg.name, s.target)
  348. for s in self.output_specs
  349. if s.kind == OutputKind.USER_INPUT_MUTATION
  350. and isinstance(s.arg, TensorArgument)
  351. and isinstance(s.target, str)
  352. )
  353. # A dictionary mapping graph input node names to lifted tensor constants.
  354. @property
  355. def inputs_to_lifted_tensor_constants(self) -> Mapping[str, str]:
  356. return _immutable_dict(
  357. (s.arg.name, s.target)
  358. for s in self.input_specs
  359. if s.kind == InputKind.CONSTANT_TENSOR
  360. and isinstance(s.arg, TensorArgument)
  361. and isinstance(s.target, str)
  362. )
  363. @property
  364. def inputs_to_lifted_custom_objs(self) -> Mapping[str, str]:
  365. return _immutable_dict(
  366. (s.arg.name, s.target)
  367. for s in self.input_specs
  368. if s.kind == InputKind.CUSTOM_OBJ
  369. and isinstance(s.arg, CustomObjArgument)
  370. and isinstance(s.target, str)
  371. )
  372. @property
  373. def backward_signature(self) -> Optional[ExportBackwardSignature]:
  374. loss_output = None
  375. gradients_to_parameters: dict[str, str] = {}
  376. gradients_to_user_inputs: dict[str, str] = {}
  377. for spec in self.output_specs:
  378. if spec.kind == OutputKind.LOSS_OUTPUT:
  379. assert loss_output is None
  380. assert isinstance(spec.arg, TensorArgument)
  381. loss_output = spec.arg.name
  382. elif spec.kind == OutputKind.GRADIENT_TO_PARAMETER:
  383. assert isinstance(spec.target, str)
  384. assert isinstance(spec.arg, TensorArgument)
  385. gradients_to_parameters[spec.arg.name] = spec.target
  386. elif spec.kind == OutputKind.GRADIENT_TO_USER_INPUT:
  387. assert isinstance(spec.target, str)
  388. assert isinstance(spec.arg, TensorArgument)
  389. gradients_to_user_inputs[spec.arg.name] = spec.target
  390. if loss_output is None:
  391. return None
  392. return ExportBackwardSignature(
  393. loss_output=loss_output,
  394. gradients_to_parameters=gradients_to_parameters,
  395. gradients_to_user_inputs=gradients_to_user_inputs,
  396. )
  397. # Map from assertion dependency token index to assertion dep token output
  398. # name in output. The shape of output after aot_autograd will be like:
  399. # (updated_inputs, user_outputs, dep_token).
  400. @property
  401. def assertion_dep_token(self) -> Optional[Mapping[int, str]]:
  402. return None
  403. @property
  404. def input_tokens(self) -> Collection[str]:
  405. input_tokens = []
  406. for s in self.input_specs:
  407. if s.kind == InputKind.TOKEN:
  408. assert isinstance(s.arg, TokenArgument)
  409. input_tokens.append(s.arg.name)
  410. return tuple(input_tokens)
  411. @property
  412. def output_tokens(self) -> Collection[str]:
  413. output_tokens = []
  414. for s in self.output_specs:
  415. if s.kind == OutputKind.TOKEN:
  416. assert isinstance(s.arg, TokenArgument)
  417. output_tokens.append(s.arg.name)
  418. return tuple(output_tokens)
  419. def __post_init__(self) -> None:
  420. assertion_dep_token = self.assertion_dep_token
  421. if assertion_dep_token is None:
  422. return
  423. assert len(assertion_dep_token) == 1
  424. assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
  425. assert (
  426. len(self.user_outputs) + len(self.buffers_to_mutate)
  427. == assertion_dep_token_index
  428. )
  429. def replace_all_uses(self, old: str, new: str):
  430. """
  431. Replace all uses of the old name with new name in the signature.
  432. """
  433. assert isinstance(old, str)
  434. assert isinstance(new, str)
  435. arg_types = (
  436. TensorArgument,
  437. SymIntArgument,
  438. SymFloatArgument,
  439. SymBoolArgument,
  440. CustomObjArgument,
  441. TokenArgument,
  442. )
  443. for o in self.output_specs:
  444. if isinstance(o.arg, arg_types):
  445. if o.arg.name == old:
  446. o.arg.name = new
  447. for i in self.input_specs:
  448. if isinstance(i.arg, arg_types):
  449. if i.arg.name == old:
  450. i.arg.name = new
  451. def get_replace_hook(self, replace_inputs=False):
  452. def _(old, new, user):
  453. if user.op == "output":
  454. self.replace_all_uses(old.name, new)
  455. if replace_inputs and old.op == "placeholder":
  456. self.replace_all_uses(old.name, new)
  457. return _
  458. def __str__(self):
  459. input_specs = "\n".join(str(s) for s in self.input_specs)
  460. output_specs = "\n".join(str(s) for s in self.output_specs)
  461. return f"\n# inputs\n{input_specs}\n\n# outputs\n{output_specs}\n"
  462. def _immutable_dict(items):
  463. """
  464. Creates a mapping where items cannot be added, deleted, or updated.
  465. NOTE: The immutability is shallow (like tuple is an immutable collection).
  466. """
  467. from types import MappingProxyType
  468. return MappingProxyType(dict(items))
  469. def _make_argument_spec(node, token_names) -> ArgumentSpec:
  470. from torch import ScriptObject, SymBool, SymFloat, SymInt
  471. from torch._library.fake_class_registry import FakeScriptObject
  472. if isinstance(node, (int, bool, float, type(None), str)):
  473. # For const outputs we just directly return this
  474. return ConstantArgument(name="", value=node)
  475. assert "val" in node.meta, (
  476. f"{node} is not a constant or a node with a 'val' metadata field"
  477. )
  478. val = node.meta["val"]
  479. if node.name in token_names:
  480. return TokenArgument(name=node.name)
  481. elif is_fake(val):
  482. return TensorArgument(name=node.name)
  483. elif isinstance(val, SymInt):
  484. return SymIntArgument(name=node.name)
  485. elif isinstance(val, SymFloat):
  486. return SymFloatArgument(name=node.name)
  487. elif isinstance(val, SymBool):
  488. return SymBoolArgument(name=node.name)
  489. elif isinstance(val, ScriptObject):
  490. return CustomObjArgument(name=node.name, class_fqn=val._type().qualified_name()) # type: ignore[attr-defined]
  491. elif isinstance(val, FakeScriptObject):
  492. return CustomObjArgument(
  493. name=node.name, class_fqn=val.script_class_name, fake_val=val
  494. )
  495. elif isinstance(val, (int, bool, str, float, type(None))):
  496. return ConstantArgument(name=node.name, value=val)
  497. else:
  498. raise AssertionError(
  499. f"Encountered an unsupported object of type {type(val)} "
  500. f"while writing the metadata for exported program"
  501. )
  502. def _convert_to_export_graph_signature(
  503. graph_signature: "GraphSignature",
  504. gm: "torch.fx.GraphModule",
  505. non_persistent_buffers: set[str],
  506. ) -> "ExportGraphSignature":
  507. from torch.utils import _pytree as pytree
  508. is_joint = graph_signature.backward_signature is not None
  509. # unpack objects
  510. user_inputs = set(graph_signature.user_inputs)
  511. inputs_to_parameters = graph_signature.inputs_to_parameters
  512. inputs_to_buffers = graph_signature.inputs_to_buffers
  513. user_outputs = set(graph_signature.user_outputs)
  514. buffer_mutations = graph_signature.buffers_to_mutate
  515. parameter_mutations = graph_signature.parameters_to_mutate
  516. user_input_mutations = graph_signature.user_inputs_to_mutate
  517. grad_params = (
  518. graph_signature.backward_signature.gradients_to_parameter # type: ignore[union-attr]
  519. if is_joint
  520. else {}
  521. )
  522. grad_user_inputs = (
  523. graph_signature.backward_signature.gradients_to_user_inputs # type: ignore[union-attr]
  524. if is_joint
  525. else {}
  526. )
  527. loss_output = (
  528. graph_signature.backward_signature.loss_output # type: ignore[union-attr]
  529. if is_joint
  530. else None
  531. )
  532. input_tokens = graph_signature.input_tokens
  533. output_tokens = graph_signature.output_tokens
  534. inputs = [
  535. _make_argument_spec(node, input_tokens)
  536. for node in gm.graph.nodes
  537. if node.op == "placeholder"
  538. ]
  539. outputs = [
  540. _make_argument_spec(node, output_tokens)
  541. for node in pytree.tree_leaves(next(iter(reversed(gm.graph.nodes))).args)
  542. ]
  543. def to_input_spec(inp: ArgumentSpec) -> InputSpec:
  544. if isinstance(inp, TokenArgument):
  545. return InputSpec(kind=InputKind.TOKEN, arg=inp, target=None)
  546. if not isinstance(inp, TensorArgument):
  547. return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None)
  548. name = inp.name
  549. if name in user_inputs:
  550. return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None)
  551. elif name in inputs_to_parameters:
  552. return InputSpec(
  553. kind=InputKind.PARAMETER,
  554. arg=inp,
  555. target=inputs_to_parameters[name], # type: ignore[index]
  556. )
  557. elif name in inputs_to_buffers:
  558. return InputSpec(
  559. kind=InputKind.BUFFER,
  560. arg=inp,
  561. target=inputs_to_buffers[name], # type: ignore[index]
  562. persistent=(inputs_to_buffers[name] not in non_persistent_buffers), # type: ignore[index]
  563. )
  564. else:
  565. raise AssertionError(f"Unknown tensor input kind: {name}")
  566. def to_output_spec(idx: int, o: ArgumentSpec) -> OutputSpec:
  567. if isinstance(o, TokenArgument):
  568. return OutputSpec(kind=OutputKind.TOKEN, arg=o, target=None)
  569. if not isinstance(o, TensorArgument):
  570. return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
  571. name = o.name
  572. if idx < len(buffer_mutations) + len(parameter_mutations) + len(
  573. user_input_mutations
  574. ) + len(output_tokens):
  575. if name in buffer_mutations:
  576. return OutputSpec(
  577. kind=OutputKind.BUFFER_MUTATION,
  578. arg=o,
  579. target=buffer_mutations[name], # type: ignore[index]
  580. )
  581. elif name in parameter_mutations:
  582. return OutputSpec(
  583. kind=OutputKind.PARAMETER_MUTATION,
  584. arg=o,
  585. target=parameter_mutations[name], # type: ignore[index]
  586. )
  587. elif name in user_input_mutations:
  588. return OutputSpec(
  589. kind=OutputKind.USER_INPUT_MUTATION,
  590. arg=o,
  591. target=user_input_mutations[name], # type: ignore[index]
  592. )
  593. else:
  594. raise AssertionError(f"Unknown tensor mutation kind: {name}")
  595. else:
  596. if name in user_outputs:
  597. return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
  598. elif name in grad_params:
  599. return OutputSpec(
  600. kind=OutputKind.GRADIENT_TO_PARAMETER,
  601. arg=o,
  602. target=grad_params[name],
  603. )
  604. elif name in grad_user_inputs:
  605. return OutputSpec(
  606. kind=OutputKind.GRADIENT_TO_USER_INPUT,
  607. arg=o,
  608. target=grad_user_inputs[name],
  609. )
  610. elif name == loss_output:
  611. return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None)
  612. else:
  613. raise AssertionError(f"Unknown tensor output kind: {name}")
  614. input_specs = [to_input_spec(inp) for inp in inputs]
  615. output_specs = [to_output_spec(idx, o) for idx, o in enumerate(outputs)]
  616. return ExportGraphSignature(input_specs=input_specs, output_specs=output_specs)