prepare.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. # mypy: allow-untyped-decorators
  2. # mypy: allow-untyped-defs
  3. from typing import Optional
  4. import torch
  5. from torch.backends._nnapi.serializer import _NnapiSerializer
  6. ANEURALNETWORKS_PREFER_LOW_POWER = 0
  7. ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1
  8. ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2
  9. class NnapiModule(torch.nn.Module):
  10. """Torch Module that wraps an NNAPI Compilation.
  11. This module handles preparing the weights, initializing the
  12. NNAPI TorchBind object, and adjusting the memory formats
  13. of all inputs and outputs.
  14. """
  15. # _nnapi.Compilation is defined
  16. comp: Optional[torch.classes._nnapi.Compilation] # type: ignore[name-defined]
  17. weights: list[torch.Tensor]
  18. out_templates: list[torch.Tensor]
  19. def __init__(
  20. self,
  21. shape_compute_module: torch.nn.Module,
  22. ser_model: torch.Tensor,
  23. weights: list[torch.Tensor],
  24. inp_mem_fmts: list[int],
  25. out_mem_fmts: list[int],
  26. compilation_preference: int,
  27. relax_f32_to_f16: bool,
  28. ):
  29. super().__init__()
  30. self.shape_compute_module = shape_compute_module
  31. self.ser_model = ser_model
  32. self.weights = weights
  33. self.inp_mem_fmts = inp_mem_fmts
  34. self.out_mem_fmts = out_mem_fmts
  35. self.out_templates = []
  36. self.comp = None
  37. self.compilation_preference = compilation_preference
  38. self.relax_f32_to_f16 = relax_f32_to_f16
  39. @torch.jit.export
  40. def init(self, args: list[torch.Tensor]):
  41. assert self.comp is None
  42. self.out_templates = self.shape_compute_module.prepare(self.ser_model, args) # type: ignore[operator]
  43. self.weights = [w.contiguous() for w in self.weights]
  44. comp = torch.classes._nnapi.Compilation()
  45. comp.init2(
  46. self.ser_model,
  47. self.weights,
  48. self.compilation_preference,
  49. self.relax_f32_to_f16,
  50. )
  51. self.comp = comp
  52. def forward(self, args: list[torch.Tensor]) -> list[torch.Tensor]:
  53. if self.comp is None:
  54. self.init(args)
  55. comp = self.comp
  56. assert comp is not None
  57. outs = [torch.empty_like(out) for out in self.out_templates]
  58. assert len(args) == len(self.inp_mem_fmts)
  59. fixed_args = []
  60. for idx in range(len(args)):
  61. fmt = self.inp_mem_fmts[idx]
  62. # These constants match the values in DimOrder in serializer.py
  63. # TODO: See if it's possible to use those directly.
  64. if fmt == 0:
  65. fixed_args.append(args[idx].contiguous())
  66. elif fmt == 1:
  67. fixed_args.append(args[idx].permute(0, 2, 3, 1).contiguous())
  68. else:
  69. raise ValueError("Invalid mem_fmt")
  70. comp.run(fixed_args, outs)
  71. assert len(outs) == len(self.out_mem_fmts)
  72. for idx in range(len(self.out_templates)):
  73. fmt = self.out_mem_fmts[idx]
  74. # These constants match the values in DimOrder in serializer.py
  75. # TODO: See if it's possible to use those directly.
  76. if fmt in (0, 2):
  77. pass
  78. elif fmt == 1:
  79. outs[idx] = outs[idx].permute(0, 3, 1, 2)
  80. else:
  81. raise ValueError("Invalid mem_fmt")
  82. return outs
  83. def convert_model_to_nnapi(
  84. model,
  85. inputs,
  86. serializer=None,
  87. return_shapes=None,
  88. use_int16_for_qint16=False,
  89. compilation_preference=ANEURALNETWORKS_PREFER_SUSTAINED_SPEED,
  90. relax_f32_to_f16=False,
  91. ):
  92. (
  93. shape_compute_module,
  94. ser_model_tensor,
  95. used_weights,
  96. inp_mem_fmts,
  97. out_mem_fmts,
  98. retval_count,
  99. ) = process_for_nnapi(
  100. model, inputs, serializer, return_shapes, use_int16_for_qint16
  101. )
  102. nnapi_model = NnapiModule(
  103. shape_compute_module,
  104. ser_model_tensor,
  105. used_weights,
  106. inp_mem_fmts,
  107. out_mem_fmts,
  108. compilation_preference,
  109. relax_f32_to_f16,
  110. )
  111. class NnapiInterfaceWrapper(torch.nn.Module):
  112. """NNAPI list-ifying and de-list-ifying wrapper.
  113. NNAPI always expects a list of inputs and provides a list of outputs.
  114. This module allows us to accept inputs as separate arguments.
  115. It returns results as either a single tensor or tuple,
  116. matching the original module.
  117. """
  118. def __init__(self, mod):
  119. super().__init__()
  120. self.mod = mod
  121. wrapper_model_py = NnapiInterfaceWrapper(nnapi_model)
  122. wrapper_model = torch.jit.script(wrapper_model_py)
  123. # TODO: Maybe make these names match the original.
  124. arg_list = ", ".join(f"arg_{idx}" for idx in range(len(inputs)))
  125. if retval_count < 0:
  126. ret_expr = "retvals[0]"
  127. else:
  128. ret_expr = "".join(f"retvals[{idx}], " for idx in range(retval_count))
  129. wrapper_model.define(
  130. f"def forward(self, {arg_list}):\n"
  131. f" retvals = self.mod([{arg_list}])\n"
  132. f" return {ret_expr}\n"
  133. )
  134. return wrapper_model
  135. def process_for_nnapi(
  136. model, inputs, serializer=None, return_shapes=None, use_int16_for_qint16=False
  137. ):
  138. model = torch.jit.freeze(model)
  139. if isinstance(inputs, torch.Tensor):
  140. inputs = [inputs]
  141. serializer = serializer or _NnapiSerializer(
  142. config=None, use_int16_for_qint16=use_int16_for_qint16
  143. )
  144. (
  145. ser_model,
  146. used_weights,
  147. inp_mem_fmts,
  148. out_mem_fmts,
  149. shape_compute_lines,
  150. retval_count,
  151. ) = serializer.serialize_model(model, inputs, return_shapes)
  152. ser_model_tensor = torch.tensor(ser_model, dtype=torch.int32)
  153. # We have to create a new class here every time this function is called
  154. # because module.define adds a method to the *class*, not the instance.
  155. class ShapeComputeModule(torch.nn.Module):
  156. """Code-gen-ed module for tensor shape computation.
  157. module.prepare will mutate ser_model according to the computed operand
  158. shapes, based on the shapes of args. Returns a list of output templates.
  159. """
  160. shape_compute_module = torch.jit.script(ShapeComputeModule())
  161. real_shape_compute_lines = [
  162. "def prepare(self, ser_model: torch.Tensor, args: List[torch.Tensor]) -> List[torch.Tensor]:\n",
  163. ] + [f" {line}\n" for line in shape_compute_lines]
  164. shape_compute_module.define("".join(real_shape_compute_lines))
  165. return (
  166. shape_compute_module,
  167. ser_model_tensor,
  168. used_weights,
  169. inp_mem_fmts,
  170. out_mem_fmts,
  171. retval_count,
  172. )