lowering.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import torch
  2. from torch._inductor.constant_folding import constant_fold
  3. from torch._inductor.fx_passes.freezing_patterns import freezing_passes
  4. __all__ = [
  5. "lower_pt2e_quantized_to_x86",
  6. ]
  7. def lower_pt2e_quantized_to_x86(
  8. model: torch.fx.GraphModule,
  9. example_inputs: tuple[torch.Tensor, ...],
  10. ) -> torch.fx.GraphModule:
  11. """Lower a PT2E-qantized model to x86 backend.
  12. Args:
  13. * `model` (torch.fx.GraphModule): a model quantized by PT2E quantization flow.
  14. * `example_inputs` (tuple[torch.Tensor, ...]): example inputs for the model.
  15. Return:
  16. A GraphModule lowered to x86 backend.
  17. """
  18. def _post_autograd_decomp_table(): # type: ignore[no-untyped-def]
  19. decomp_table = torch.export.default_decompositions()
  20. # if we are post-autograd, we shouldn't
  21. # decomp prim ops.
  22. for k in list(decomp_table.keys()):
  23. if not torch._export.utils._is_cia_op(k):
  24. del decomp_table[k]
  25. return decomp_table
  26. def _node_replace(m): # type: ignore[no-untyped-def]
  27. # Replace aten.t(x) with aten.permute(x, [1, 0])
  28. aten = torch.ops.aten
  29. g = m.graph
  30. for node in g.nodes:
  31. if node.target == aten.t.default:
  32. with g.inserting_before(node):
  33. x = node.args[0]
  34. dims = [1, 0]
  35. perm_node = g.call_function(aten.permute.default, args=(x, dims))
  36. node.replace_all_uses_with(perm_node)
  37. g.erase_node(node)
  38. g.lint()
  39. m.recompile()
  40. lowered_model = (
  41. torch.export.export_for_training(model, example_inputs, strict=True)
  42. .run_decompositions(_post_autograd_decomp_table())
  43. .module()
  44. )
  45. _node_replace(lowered_model)
  46. freezing_passes(lowered_model, example_inputs)
  47. constant_fold(lowered_model)
  48. return lowered_model