| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960 |
- import torch
- from torch._inductor.constant_folding import constant_fold
- from torch._inductor.fx_passes.freezing_patterns import freezing_passes
# Public API of this module.
__all__ = ["lower_pt2e_quantized_to_x86"]
def lower_pt2e_quantized_to_x86(
    model: torch.fx.GraphModule,
    example_inputs: tuple[torch.Tensor, ...],
) -> torch.fx.GraphModule:
    """Lower a PT2E-quantized model to the x86 backend.

    The model is re-exported with ``torch.export.export_for_training``,
    decomposed with a table restricted to ops for which ``_is_cia_op`` is
    true, has every ``aten.t`` call rewritten as ``aten.permute``, and is
    finally run through ``freezing_passes`` and ``constant_fold``.

    Args:
        * `model` (torch.fx.GraphModule): a model quantized by the PT2E
          quantization flow.
        * `example_inputs` (tuple[torch.Tensor, ...]): example inputs for
          the model; used both for export and for the freezing passes.

    Return:
        A GraphModule lowered to the x86 backend.
    """

    def _cia_only_decompositions():  # type: ignore[no-untyped-def]
        # Start from the default export decomposition table and keep only
        # the entries that `_is_cia_op` accepts: post-autograd we should
        # not decompose prim ops.
        table = torch.export.default_decompositions()
        non_cia_ops = [
            op for op in table.keys() if not torch._export.utils._is_cia_op(op)
        ]
        for op in non_cia_ops:
            del table[op]
        return table

    def _replace_transpose_with_permute(gm):  # type: ignore[no-untyped-def]
        # Rewrite aten.t(x) as aten.permute(x, [1, 0]) throughout the graph.
        graph = gm.graph
        transpose_op = torch.ops.aten.t.default
        permute_op = torch.ops.aten.permute.default
        for node in graph.nodes:
            if node.target != transpose_op:
                continue
            # Insert the permute right before the transpose it replaces so
            # that it sits ahead of every user of the old node.
            with graph.inserting_before(node):
                permuted = graph.call_function(
                    permute_op, args=(node.args[0], [1, 0])
                )
            node.replace_all_uses_with(permuted)
            graph.erase_node(node)
        # Validate the edited graph and regenerate the module's code.
        graph.lint()
        gm.recompile()

    exported = torch.export.export_for_training(model, example_inputs, strict=True)
    decomposed = exported.run_decompositions(_cia_only_decompositions())
    lowered_model = decomposed.module()

    _replace_transpose_with_permute(lowered_model)
    freezing_passes(lowered_model, example_inputs)
    constant_fold(lowered_model)
    return lowered_model
|