verifier.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528
  1. # mypy: allow-untyped-defs
  2. import inspect
  3. import math
  4. import operator
  5. from collections.abc import Iterable
  6. from typing import Any, final, TYPE_CHECKING
  7. import torch
  8. from torch._ops import HigherOrderOperator, OpOverload
  9. from torch._subclasses.fake_tensor import FakeTensor
  10. from torch.export.graph_signature import (
  11. CustomObjArgument,
  12. InputKind,
  13. SymBoolArgument,
  14. SymFloatArgument,
  15. SymIntArgument,
  16. TensorArgument,
  17. TokenArgument,
  18. )
  19. from torch.fx import GraphModule
  20. if TYPE_CHECKING:
  21. from torch.export.exported_program import ExportedProgram
  22. class SpecViolationError(Exception):
  23. pass
  24. def is_functional(op: OpOverload) -> bool:
  25. return not op._schema.is_mutable
  26. def _check_has_fake_tensor(node: torch.fx.Node) -> None:
  27. # TODO(angelayi): remove this in favor of _check_val
  28. return _check_val(node)
  29. def _check_val(node: torch.fx.Node) -> None:
  30. from torch.fx.experimental.symbolic_shapes import SymBool, SymFloat, SymInt
  31. def _check_correct_val(val):
  32. if val is None:
  33. return True
  34. elif isinstance(val, (int, bool, str, float)):
  35. return True
  36. elif isinstance(
  37. val, (torch.memory_format, torch.dtype, torch.device, torch.layout)
  38. ):
  39. return True
  40. elif isinstance(
  41. val, (FakeTensor, torch.Tensor)
  42. ): # TODO(zhxchen17) Remove Tensor.
  43. return True
  44. elif isinstance(val, (SymInt, SymFloat, SymBool)):
  45. return True
  46. elif isinstance(val, CustomObjArgument):
  47. return True
  48. elif isinstance(val, Iterable):
  49. return all(_check_correct_val(x) for x in val)
  50. return False
  51. def _no_returns(op):
  52. if not isinstance(op, OpOverload):
  53. return False
  54. return len(op._schema.returns) == 0
  55. if "val" not in node.meta:
  56. if node.op == "call_function" and _no_returns(node.target):
  57. return
  58. raise SpecViolationError(f"Node.meta {node.name} is missing val field.")
  59. val = node.meta["val"]
  60. if not _check_correct_val(val):
  61. raise SpecViolationError(f"Node.meta {node.name} has invalid val field {val}")
  62. def _check_torch_fn(node: torch.fx.Node) -> None:
  63. torch_fn = node.meta.get("torch_fn")
  64. if torch_fn is None:
  65. raise SpecViolationError(
  66. f"Unable to find torch_fn metadata for node {node.name}"
  67. )
  68. if (
  69. not isinstance(torch_fn, tuple)
  70. and isinstance(torch_fn[0], str)
  71. and isinstance(torch_fn[1], str)
  72. ):
  73. raise SpecViolationError(
  74. f"Node.meta {node.name} has invalid torch_fn field {torch_fn}"
  75. )
  76. class _VerifierMeta(type):
  77. _registry: dict[str, type["Verifier"]] = {}
  78. def __new__(metacls, name, bases, attrs):
  79. if bases:
  80. if "check" in attrs or "_check_graph_module" in attrs:
  81. raise SyntaxError("Overriding method check is not allowed.")
  82. assert "dialect" in attrs and attrs["dialect"] != "ATEN"
  83. else:
  84. assert "check" in attrs
  85. assert "_check_graph_module" in attrs
  86. assert attrs["dialect"] == "ATEN"
  87. assert isinstance(attrs["dialect"], str)
  88. ret = type.__new__(metacls, name, bases, attrs)
  89. metacls._registry[attrs["dialect"]] = ret # type: ignore[assignment]
  90. return ret
  91. def getattr_recursive(obj: Any, target: str) -> Any:
  92. target_atoms = target.split(".")
  93. attr_itr = obj
  94. for i, atom in enumerate(target_atoms):
  95. if not hasattr(attr_itr, atom):
  96. raise RuntimeError(
  97. f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
  98. )
  99. attr_itr = getattr(attr_itr, atom)
  100. return attr_itr
  101. class Verifier(metaclass=_VerifierMeta):
  102. dialect = "ATEN"
  103. def allowed_builtin_ops(self) -> list:
  104. return [
  105. operator.getitem,
  106. operator.add,
  107. operator.mul,
  108. operator.sub,
  109. operator.truediv,
  110. operator.ge,
  111. operator.le,
  112. operator.gt,
  113. operator.lt,
  114. operator.eq,
  115. operator.ne,
  116. operator.floordiv,
  117. operator.mod,
  118. operator.and_,
  119. operator.or_,
  120. operator.not_,
  121. operator.pow,
  122. operator.neg,
  123. operator.abs,
  124. operator.lshift,
  125. operator.rshift,
  126. math.ceil,
  127. math.floor,
  128. math.trunc,
  129. round,
  130. ]
  131. def allowed_op_types(self) -> tuple[type[Any], ...]:
  132. return (OpOverload, HigherOrderOperator)
  133. def allowed_getattr_types(self) -> tuple[type[Any], ...]:
  134. return (torch.fx.GraphModule, torch.utils._pytree.TreeSpec)
  135. def allowed_getattr_types_for_subgm(self) -> tuple[type[Any], ...]:
  136. # subgm in HOP's argument could has have getattr(weight) nodes, thus stateful
  137. return (
  138. torch.fx.GraphModule,
  139. torch.nn.parameter.Parameter,
  140. torch.Tensor, # for buffer and constant tensor
  141. torch.utils._pytree.TreeSpec,
  142. )
  143. def check_valid_op(self, op):
  144. pass
  145. def check_additional(self, gm: GraphModule) -> None:
  146. """
  147. Additional checks that are specific to some dialects.
  148. """
  149. @final
  150. def check(self, ep: "ExportedProgram") -> None:
  151. self._check_graph_module(ep.graph_module)
  152. _verify_exported_program_module_call_graph(ep)
  153. _verify_exported_program_signature(ep)
  154. @final
  155. def _check_graph_module(self, gm: torch.fx.GraphModule) -> None:
  156. def _allowed_getattr_types(is_toplevel_gm) -> tuple[type[Any], ...]:
  157. if is_toplevel_gm:
  158. ret = self.allowed_getattr_types()
  159. else:
  160. ret = self.allowed_getattr_types_for_subgm()
  161. assert not any(t is object for t in ret)
  162. return ret
  163. def _check_valid_op(op) -> None:
  164. def _allowed_builtin_ops() -> list:
  165. ret = self.allowed_builtin_ops()
  166. assert all(inspect.isbuiltin(op) for op in ret)
  167. return ret
  168. def _allowed_op_types() -> tuple[type[Any], ...]:
  169. ret = self.allowed_op_types()
  170. assert not any(t is object for t in ret)
  171. return ret
  172. # TODO Remove this allowlist.
  173. _allowed_torch_functions = (
  174. torch.autograd.grad_mode.set_grad_enabled,
  175. torch.sym_int,
  176. torch.sym_float,
  177. torch.sym_ite,
  178. torch.sym_max,
  179. torch.sym_min,
  180. torch.sym_not,
  181. torch.sym_sqrt,
  182. torch.sym_sum,
  183. torch.export.custom_ops._call_custom_autograd_function_in_pre_dispatch,
  184. # TODO (tmanlaibaatar)
  185. # Predispatch export is able to contain autograd ops.
  186. # These will be modeled as HOO later
  187. torch._C._set_grad_enabled,
  188. torch.amp.autocast_mode._enter_autocast,
  189. torch.amp.autocast_mode._exit_autocast,
  190. torch.fx.experimental.symbolic_shapes.cast_symbool_to_symint_guardless,
  191. torch._functorch.predispatch._add_batch_dim,
  192. torch._functorch.predispatch._remove_batch_dim,
  193. torch._functorch.predispatch._vmap_increment_nesting,
  194. torch._functorch.predispatch._vmap_decrement_nesting,
  195. torch._functorch.predispatch.lazy_load_decompositions,
  196. )
  197. if not isinstance(op, _allowed_op_types()):
  198. if (
  199. op not in _allowed_builtin_ops()
  200. and op not in _allowed_torch_functions
  201. ):
  202. raise SpecViolationError(
  203. f"Operator '{op}' is not an allowed operator type: {_allowed_op_types()}\n"
  204. f"Valid builtin ops: {_allowed_builtin_ops()}"
  205. f"Valid torch functions: {_allowed_torch_functions}"
  206. )
  207. if isinstance(op, OpOverload):
  208. # All ops functional
  209. # TODO (tmanlaibaatar) more proper way is needed here
  210. if self.dialect != "TRAINING" and not is_functional(op):
  211. raise SpecViolationError(f"operator '{op}' is not functional")
  212. self.check_valid_op(op)
  213. for mod in gm.modules():
  214. is_toplevel_gm = mod is gm
  215. if not isinstance(mod, torch.fx.GraphModule):
  216. continue
  217. mod.graph.lint()
  218. for node in mod.graph.nodes:
  219. # TODO(T140410192): should have fake tensor for all dialects
  220. if node.op in {"call_module", "call_method"}:
  221. raise SpecViolationError(
  222. f"call_module is not valid: got a class '{node.target}' ",
  223. )
  224. elif node.op == "call_function":
  225. _check_val(node)
  226. _check_valid_op(node.target)
  227. elif node.op == "get_attr":
  228. if not isinstance(node.target, str):
  229. raise SpecViolationError(
  230. f"Expected get_attr target to be string, but got {type(node.target)}"
  231. )
  232. attr = getattr_recursive(mod, node.target)
  233. if isinstance(attr, torch.nn.Module):
  234. def _is_type(name, ty):
  235. return isinstance(getattr(attr, name, None), ty)
  236. if type(attr).__name__ == "LoweredBackendModule":
  237. if (
  238. _is_type("backend_id", str)
  239. and hasattr(attr, "original_module")
  240. and hasattr(attr, "module_name")
  241. and getattr(attr, "backend_id", None) == "aoti"
  242. ):
  243. continue
  244. if (
  245. _is_type("backend_id", str)
  246. and _is_type("processed_bytes", bytes)
  247. and _is_type("compile_specs", list)
  248. and hasattr(attr, "original_module")
  249. ):
  250. continue
  251. else:
  252. backend_id = getattr(attr, "backend_id", None)
  253. processed_bytes = getattr(attr, "processed_bytes", None)
  254. compile_specs = getattr(attr, "compile_specs", None)
  255. raise SpecViolationError(
  256. f"Invalid get_attr type {type(attr)}. \n"
  257. f"LoweredBackendModule fields: "
  258. f"backend_id(str) : {type(backend_id)}, "
  259. f"processed_bytes(bytes) : {type(processed_bytes)}, "
  260. f"compile_specs(list) : {type(compile_specs)}"
  261. )
  262. elif type(attr).__name__ == "AOTInductorEPModule":
  263. continue
  264. elif type(attr).__name__ == "AOTInductorRunnerWrapper":
  265. continue
  266. if not isinstance(attr, _allowed_getattr_types(is_toplevel_gm)):
  267. raise SpecViolationError(
  268. f"Invalid get_attr type {type(attr)} on target {node.target}. \n"
  269. f"Valid get_attr types: {_allowed_getattr_types(is_toplevel_gm)}"
  270. )
  271. elif node.op == "placeholder":
  272. _check_val(node)
  273. # TODO(zhxchen17)
  274. # elif node.op == "output":
  275. # _check_flattened_outputs()
  276. self.check_additional(gm)
  277. class TrainingIRVerifier(Verifier):
  278. dialect = "TRAINING"
  279. def _verify_exported_program_module_call_graph(exported_program) -> None:
  280. module_call_graph = exported_program.module_call_graph
  281. nodes = {node.name for node in exported_program.graph.nodes}
  282. for entry in module_call_graph:
  283. if entry.signature is not None:
  284. for arg in entry.signature.inputs:
  285. if arg.name and arg.name not in nodes:
  286. raise SpecViolationError(
  287. f"Input {arg.name} does not exist in the graph."
  288. )
  289. for arg in entry.signature.outputs:
  290. if arg.name and arg.name not in nodes:
  291. raise SpecViolationError(
  292. f"Output {arg.name} does not exist in the graph."
  293. )
  294. def _verify_exported_program_signature(exported_program) -> None:
  295. # Check ExportedProgram signature matches
  296. gs = exported_program.graph_signature
  297. # Check every node in the signature exists in the graph
  298. input_node_names = [
  299. node.name for node in exported_program.graph.nodes if node.op == "placeholder"
  300. ]
  301. if len(input_node_names) != len(gs.input_specs):
  302. raise SpecViolationError(
  303. f"Number of graph inputs ({len(input_node_names)}) "
  304. f"does not match number of inputs in the graph signature ({len(gs.input_specs)})"
  305. )
  306. for input_spec, node in zip(gs.input_specs, input_node_names):
  307. if isinstance(
  308. input_spec.arg,
  309. (TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument),
  310. ):
  311. if input_spec.arg.name != node:
  312. raise SpecViolationError(
  313. f"Input spec name {input_spec.arg.name} does not match node name {node}"
  314. )
  315. if input_spec.kind == InputKind.USER_INPUT:
  316. continue
  317. elif input_spec.kind == InputKind.PARAMETER:
  318. if not isinstance(input_spec.arg, TensorArgument):
  319. raise SpecViolationError(
  320. f"Parameter {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
  321. )
  322. if input_spec.target is None:
  323. raise SpecViolationError(
  324. f"InputSpec for {input_spec.name} has no target."
  325. )
  326. param = input_spec.target
  327. if param not in exported_program.state_dict:
  328. raise SpecViolationError(f"Parameter {param} is not in the state dict.")
  329. if not isinstance(exported_program.state_dict[param], torch.nn.Parameter):
  330. raise SpecViolationError(
  331. f"State dict entry for parameter {param} is not an instance of torch.nn.Parameter."
  332. )
  333. elif input_spec.kind == InputKind.BUFFER:
  334. if not isinstance(input_spec.arg, TensorArgument):
  335. raise SpecViolationError(
  336. f"Buffer {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
  337. )
  338. if input_spec.target is None:
  339. raise SpecViolationError(
  340. f"InputSpec for {input_spec.name} has no target."
  341. )
  342. buffer = input_spec.target
  343. if input_spec.persistent is None:
  344. raise SpecViolationError(
  345. f"Buffer {buffer} is missing a persistence flag"
  346. )
  347. if (
  348. input_spec.persistent is True
  349. and buffer not in exported_program.state_dict
  350. ):
  351. raise SpecViolationError(f"Buffer {buffer} is not in the state dict.")
  352. if input_spec.persistent is False and buffer in exported_program.state_dict:
  353. raise SpecViolationError(
  354. f"Non-persistent buffer {buffer} is in the state dict, it should not be."
  355. )
  356. elif input_spec.kind == InputKind.CONSTANT_TENSOR:
  357. if not isinstance(input_spec.arg, TensorArgument):
  358. raise SpecViolationError(
  359. f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
  360. )
  361. if input_spec.target is None:
  362. raise SpecViolationError(
  363. f"InputSpec for {input_spec.name} has no target."
  364. )
  365. tensor_const = input_spec.target
  366. if tensor_const not in exported_program.constants:
  367. raise SpecViolationError(
  368. f"Constant tensor {tensor_const} is not in the constants dictionary."
  369. )
  370. elif input_spec.kind == InputKind.CUSTOM_OBJ:
  371. if not isinstance(input_spec.arg, CustomObjArgument):
  372. raise SpecViolationError(
  373. f"Custom object {input_spec.name} is not a custom object argument. Found {input_spec.arg} instead."
  374. )
  375. if input_spec.target is None:
  376. raise SpecViolationError(
  377. f"InputSpec for {input_spec.name} has no target."
  378. )
  379. custom_obj = input_spec.target
  380. if custom_obj not in exported_program.constants:
  381. raise SpecViolationError(
  382. f"Custom object {custom_obj} is not in the constants dictionary."
  383. )
  384. elif input_spec.kind == InputKind.TOKEN:
  385. if not isinstance(input_spec.arg, TokenArgument):
  386. raise SpecViolationError(
  387. f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
  388. )
  389. else:
  390. raise SpecViolationError(f"Unknown InputKind {input_spec.kind}.")
  391. # Check outputs
  392. output_node = list(exported_program.graph.nodes)[-1]
  393. assert output_node.op == "output"
  394. output_nodes = [
  395. arg.name if isinstance(arg, torch.fx.Node) else arg
  396. for arg in output_node.args[0]
  397. ]
  398. if len(output_nodes) != len(gs.output_specs):
  399. raise SpecViolationError(
  400. f"Number of output nodes {len(output_nodes)} is different "
  401. "Than the number of outputs specified by the graph signature: \n"
  402. f"Number of mutated buffers: {len(gs.buffers_to_mutate)}. \n"
  403. f"Number of user outputs: {len(gs.user_outputs)}. \n"
  404. )
  405. num_tokens = len(gs.output_tokens)
  406. end = (
  407. len(gs.buffers_to_mutate)
  408. + len(gs.parameters_to_mutate)
  409. + len(gs.user_inputs_to_mutate)
  410. + num_tokens
  411. )
  412. mutate_nodes: list[str] = output_nodes[num_tokens:end]
  413. user_output_nodes = output_nodes[end : end + len(gs.user_outputs)]
  414. for mutation_node in mutate_nodes:
  415. if mutation_node in gs.buffers_to_mutate:
  416. if gs.buffers_to_mutate[mutation_node] not in gs.buffers:
  417. raise SpecViolationError(
  418. f"Buffer output {mutation_node} does not point to a buffer that exists. \n"
  419. f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n"
  420. f"Buffer nodes available: {gs.buffers} \n"
  421. )
  422. elif mutation_node in gs.parameters_to_mutate:
  423. if gs.parameters_to_mutate[mutation_node] not in gs.parameters:
  424. raise SpecViolationError(
  425. f"Parameter output {mutation_node} does not point to a parameter that exists. \n"
  426. f"Dict of parameters that are mutated, in order: {gs.parameters_to_mutate} \n"
  427. f"Parameter nodes available: {gs.parameters} \n"
  428. )
  429. elif mutation_node in gs.user_inputs_to_mutate:
  430. if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs:
  431. raise SpecViolationError(
  432. f"User input output {mutation_node} does not point to a user input that exists. \n"
  433. f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n"
  434. f"User input nodes available: {gs.user_inputs} \n"
  435. )
  436. else:
  437. raise SpecViolationError(
  438. f"Mutation node {mutation_node} is neither a buffer nor a user input. "
  439. f"Buffers to mutate: {gs.buffers_to_mutate}, User inputs to mutate: {gs.user_inputs_to_mutate}"
  440. )
  441. for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs):
  442. if user_output_node != user_output_name:
  443. raise SpecViolationError(
  444. f"User output {user_output_node} is not in the correct "
  445. "order or is not found in the "
  446. f"exported program's user_output list: {gs.user_outputs}. "
  447. )
  448. def load_verifier(dialect: str) -> type[Verifier]:
  449. if dialect == "ATEN" or dialect == "":
  450. return _VerifierMeta._registry.get(dialect, Verifier)
  451. return _VerifierMeta._registry[dialect]