serializer.py 81 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228
  1. # mypy: allow-untyped-defs
  2. import array
  3. import enum
  4. import functools
  5. import logging
  6. import operator
  7. import struct
  8. import sys
  9. from typing import NamedTuple, Optional
  10. import torch
# TODO: Add type annotations
# TODO: Check tensor types for ops
# Module logger. Used for per-node debug messages while serializing and for a
# warning when an operand with a runtime-flexible (-1) dimension is fetched by
# get_tensor_operand_by_jitval_fixed_size.
LOG = logging.getLogger("nnapi_serialize")
class NNAPI_OperandCode:
    """Integer codes for NNAPI operand types.

    NOTE(review): these presumably mirror the ANEURALNETWORKS_* OperandCode
    values in NeuralNetworks.h; confirm before adding new entries.
    """

    # Scalar types.
    FLOAT32 = 0
    INT32 = 1
    UINT32 = 2
    # Tensor types.
    TENSOR_FLOAT32 = 3
    TENSOR_INT32 = 4
    TENSOR_QUANT8_ASYMM = 5
    BOOL = 6
    TENSOR_QUANT16_SYMM = 7
    TENSOR_FLOAT16 = 8
    TENSOR_BOOL8 = 9
    FLOAT16 = 10
    TENSOR_QUANT8_SYMM_PER_CHANNEL = 11
    TENSOR_QUANT16_ASYMM = 12
class NNAPI_OperationCode:
    """Integer codes identifying NNAPI operations.

    These are the opcodes recorded by ``_NnapiSerializer.add_operation``.
    NOTE(review): the values presumably mirror the ANEURALNETWORKS_*
    OperationCode values in NeuralNetworks.h; confirm before editing.
    """

    ADD = 0
    AVERAGE_POOL_2D = 1
    CONCATENATION = 2
    CONV_2D = 3
    DEPTHWISE_CONV_2D = 4
    DEPTH_TO_SPACE = 5
    DEQUANTIZE = 6
    EMBEDDING_LOOKUP = 7
    FLOOR = 8
    FULLY_CONNECTED = 9
    HASHTABLE_LOOKUP = 10
    L2_NORMALIZATION = 11
    L2_POOL_2D = 12
    LOCAL_RESPONSE_NORMALIZATION = 13
    LOGISTIC = 14
    LSH_PROJECTION = 15
    LSTM = 16
    MAX_POOL_2D = 17
    MUL = 18
    RELU = 19
    RELU1 = 20
    RELU6 = 21
    RESHAPE = 22
    RESIZE_BILINEAR = 23
    RNN = 24
    SOFTMAX = 25
    SPACE_TO_DEPTH = 26
    SVDF = 27
    TANH = 28
    BATCH_TO_SPACE_ND = 29
    DIV = 30
    MEAN = 31
    PAD = 32
    SPACE_TO_BATCH_ND = 33
    SQUEEZE = 34
    STRIDED_SLICE = 35
    SUB = 36
    TRANSPOSE = 37
    ABS = 38
    ARGMAX = 39
    ARGMIN = 40
    AXIS_ALIGNED_BBOX_TRANSFORM = 41
    BIDIRECTIONAL_SEQUENCE_LSTM = 42
    BIDIRECTIONAL_SEQUENCE_RNN = 43
    BOX_WITH_NMS_LIMIT = 44
    CAST = 45
    CHANNEL_SHUFFLE = 46
    DETECTION_POSTPROCESSING = 47
    EQUAL = 48
    EXP = 49
    EXPAND_DIMS = 50
    GATHER = 51
    GENERATE_PROPOSALS = 52
    GREATER = 53
    GREATER_EQUAL = 54
    GROUPED_CONV_2D = 55
    HEATMAP_MAX_KEYPOINT = 56
    INSTANCE_NORMALIZATION = 57
    LESS = 58
    LESS_EQUAL = 59
    LOG = 60
    LOGICAL_AND = 61
    LOGICAL_NOT = 62
    LOGICAL_OR = 63
    LOG_SOFTMAX = 64
    MAXIMUM = 65
    MINIMUM = 66
    NEG = 67
    NOT_EQUAL = 68
    PAD_V2 = 69
    POW = 70
    PRELU = 71
    QUANTIZE = 72
    QUANTIZED_16BIT_LSTM = 73
    RANDOM_MULTINOMIAL = 74
    REDUCE_ALL = 75
    REDUCE_ANY = 76
    REDUCE_MAX = 77
    REDUCE_MIN = 78
    REDUCE_PROD = 79
    REDUCE_SUM = 80
    ROI_ALIGN = 81
    ROI_POOLING = 82
    RSQRT = 83
    SELECT = 84
    SIN = 85
    SLICE = 86
    SPLIT = 87
    SQRT = 88
    TILE = 89
    TOPK_V2 = 90
    TRANSPOSE_CONV_2D = 91
    UNIDIRECTIONAL_SEQUENCE_LSTM = 92
    UNIDIRECTIONAL_SEQUENCE_RNN = 93
    RESIZE_NEAREST_NEIGHBOR = 94
  124. class NNAPI_FuseCode:
  125. FUSED_NONE = 0
  126. FUSED_RELU = 1
  127. FUSED_RELU1 = 2
  128. FUSED_RELU6 = 3
  129. class OperandValueSourceType:
  130. IMMEDIATE = 0
  131. NUMBERED_BUFFER = 2
  132. NUMBERED_MEMORY = 3
# Scalar types that appear explicitly in models.
# These must be kept in sync with
# AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
# TODO: Expose these directly to Python to avoid maintaining this list.
class TorchScalarTypes(enum.Enum):
    """Subset of PyTorch scalar-type codes referenced by the serializer."""

    # NOTE(review): assumed to be quint8's index in the c10 ScalarType list
    # named above; confirm when that list changes.
    QUINT8 = 13
  139. def approx_equal(lhs, rhs, tolerance=1e-6):
  140. return abs(lhs - rhs) <= tolerance * min(lhs, rhs)
  141. def tensor_size(op_type, dims):
  142. ITEM_SIZES = {
  143. NNAPI_OperandCode.TENSOR_FLOAT32: 4,
  144. NNAPI_OperandCode.TENSOR_INT32: 4,
  145. NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: 1,
  146. NNAPI_OperandCode.TENSOR_QUANT16_SYMM: 2,
  147. NNAPI_OperandCode.TENSOR_QUANT16_ASYMM: 2,
  148. }
  149. size = ITEM_SIZES[op_type]
  150. for d in dims:
  151. size *= d
  152. return size
  153. def change_element(tup, index, value):
  154. ls = list(tup)
  155. ls[index] = value
  156. return tuple(ls)
class ConvPoolArgs2d(NamedTuple):
    """Configuration arguments for a convolution."""

    # Kernel height / width.
    kernel_h: int
    kernel_w: int
    # Stride along height / width.
    stride_h: int
    stride_w: int
    # Explicit padding: top/bottom apply to height, left/right to width.
    pad_t: int
    pad_b: int
    pad_l: int
    pad_r: int
    # Dilation along height / width.
    dilation_h: int
    dilation_w: int
    # Group count; None when no group argument was supplied.
    group: Optional[int]
class DimOrder(enum.Enum):
    """How an NNAPI operand's dimension order relates to its PyTorch shape."""

    # The NNAPI operand uses exactly the PyTorch shape.
    PRESUMED_CONTIGUOUS = 0
    # The PyTorch shape is NCHW; the NNAPI operand is laid out NHWC.
    CHANNELS_LAST = 1
    # Rank-0 or rank-1 operand; used for immediate operands.
    SCALAR_OR_VECTOR = 2
    # Default layout for constant weights; fix_shape currently passes the
    # shape through unchanged.
    UNKNOWN_CONSTANT = 999
  175. class Operand(NamedTuple):
  176. """Representation of an NNAPI operand."""
  177. # NNAPI operand type. One of NNAPI_OperandCode.
  178. # TODO: Make this an enum.
  179. op_type: int
  180. # This is always the PyTorch shape, which is NCHW for feature maps.
  181. # The actual NNAPI operand might have a transposed shape.
  182. # we use 0 for load time dynamic shapes & -1 for runtime dynamic shapes
  183. shape: tuple[int, ...]
  184. # Specifies how the shape of the operand that we define in NNAPI
  185. # relates to the shape we track above.
  186. # - PRESUMED_CONTIGUOUS: physical NNAPI operand will exactly match
  187. # the shape of the PyTorch tensor.
  188. # - CHANNELS_LAST: The PyTorch tensor is expected to be NCHW, and
  189. # the NNAPI operand will be represented explicitly as NHWC.
  190. dim_order: DimOrder
  191. # Quantization params
  192. scale: float
  193. zero_point: int
  194. def use_nchw(self):
  195. if self.dim_order is DimOrder.PRESUMED_CONTIGUOUS:
  196. return True
  197. if self.dim_order is DimOrder.CHANNELS_LAST:
  198. return False
  199. raise Exception("Unknown dim order") # noqa: TRY002
  200. def broadcast_shapes(shape1, shape2):
  201. assert len(shape1) > 0
  202. assert len(shape2) > 0
  203. s1 = list(shape1)
  204. s2 = list(shape2)
  205. # TODO: Support non-equal-rank broadcast where semantics match.
  206. # This can be tricky for NHWC tensors because dimension orders
  207. # don't match between PT and NNAPI, even though semantics match.
  208. if len(s1) > len(s2):
  209. # s2 = [1] * (len(s1) - len(s2)) + s2
  210. raise Exception( # noqa: TRY002
  211. "Non-equal-rank broadcast is not supported yet."
  212. ) # noqa: TRY002
  213. if len(s2) > len(s1):
  214. # s3 = [1] * (len(s2) - len(s1)) + s1
  215. raise Exception( # noqa: TRY002
  216. "Non-equal-rank broadcast is not supported yet."
  217. ) # noqa: TRY002
  218. ret = []
  219. for d1, d2 in zip(s1, s2):
  220. if d1 == 1:
  221. ret.append(d2)
  222. elif d2 == 1:
  223. ret.append(d1)
  224. elif d1 == d2:
  225. ret.append(d1)
  226. else:
  227. raise Exception( # noqa: TRY002
  228. f"Cannot broadcast shapes: {shape1} and {shape2}"
  229. ) # noqa: TRY002
  230. return tuple(ret)
  231. def get_conv_pool_shape(image_shape, args, out_ch, transpose):
  232. batch, _in_c, in_h, in_w = image_shape
  233. # TODO: Handle dilation
  234. if args.dilation_h != 1 or args.dilation_w != 1:
  235. raise Exception("Dilation not supported yet.") # noqa: TRY002
  236. if transpose:
  237. out_h = (in_h - 1) * args.stride_h + args.kernel_h - args.pad_t - args.pad_b
  238. out_w = (in_w - 1) * args.stride_w + args.kernel_w - args.pad_l - args.pad_l
  239. else:
  240. out_h = (in_h - args.kernel_h + args.pad_t + args.pad_b) // args.stride_h + 1
  241. out_w = (in_w - args.kernel_w + args.pad_l + args.pad_r) // args.stride_w + 1
  242. # Handle variable-sized tensors.
  243. if in_h == 0:
  244. out_h = 0
  245. if in_w == 0:
  246. out_w = 0
  247. out_shape = (batch, out_ch, out_h, out_w)
  248. return out_shape
  249. def fix_shape(shape, dim_order):
  250. # Return the actual shape that an operand should have in NNAPI,
  251. # given a PyTorch shape and dimension order. This is where we
  252. # convert from PyTorch's "always NCHW" shape to explicit NHWC.
  253. if dim_order is DimOrder.PRESUMED_CONTIGUOUS:
  254. return shape
  255. if dim_order is DimOrder.CHANNELS_LAST:
  256. return tuple([shape[0]] + list(shape[2:]) + [shape[1]])
  257. if dim_order is DimOrder.SCALAR_OR_VECTOR:
  258. assert len(shape) == 0 or len(shape) == 1
  259. return shape
  260. if dim_order is DimOrder.UNKNOWN_CONSTANT:
  261. # XXX think this through
  262. return shape
  263. raise Exception(f"Bad dim_order: {dim_order!r}.") # noqa: TRY002
  264. def reverse_map_dim(dim_order, d):
  265. # Return the original PyTorch dimension position for a given dimension.
  266. # d should be the dimension that NNAPI will see.
  267. # reverse_map_dim(PRESUMED_CONTIGUOUS, x) == x
  268. # reverse_map_dim(CHANNELS_LAST, 3) == 1
  269. if dim_order in (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.SCALAR_OR_VECTOR):
  270. return d
  271. assert dim_order is DimOrder.CHANNELS_LAST
  272. return [0, 2, 3, 1][d]
  273. def flex_name(op_id, dim):
  274. # Return the local variable name for the computed flexible size
  275. # for a given op and dimension.
  276. return f"s_{op_id}_{dim}"
  277. class _NnapiSerializer:
  278. def __init__(self, config, use_int16_for_qint16=False):
  279. self.operands = []
  280. self.values = []
  281. self.operations = []
  282. self.value_data = []
  283. self.operation_args = []
  284. self.inputs = []
  285. self.outputs = []
  286. self.flexible_shape_computation_lines = []
  287. self.modules = {}
  288. self.constants = {}
  289. self.tensor_sequences = {}
  290. self.jitval_operand_map = {}
  291. self.cached_immediates = {}
  292. self.used_weights = []
  293. self.weight_offset = 0
  294. self.use_int16_for_qint16 = use_int16_for_qint16
  295. if config is None:
  296. config = {}
  297. def get_next_operand_id(self):
  298. return len(self.operands)
  299. # Add a tensor operand corresponding to a JIT Value.
  300. # Returns the NNAPI operand ID. Can be looked up later with
  301. # get_tensor_operand_by_jitval.
  302. def add_tensor_operand(self, jitval, oper):
  303. assert isinstance(oper, Operand)
  304. if jitval in self.jitval_operand_map:
  305. raise Exception(f"Duplicate tensor: {jitval!r}") # noqa: TRY002
  306. operand_id = self.get_next_operand_id()
  307. self.operands.append(oper)
  308. self.jitval_operand_map[jitval] = operand_id
  309. return operand_id
  310. # Add a tensor operand that does not correspond to a JIT Value.
  311. # Useful for cases where multiple NNAPI operands are required
  312. # to implement one JIT IR node. Returns the NNAPI operand ID.
  313. def add_anonymous_tensor_operand(self, oper):
  314. assert isinstance(oper, Operand)
  315. operand_id = self.get_next_operand_id()
  316. self.operands.append(oper)
  317. return operand_id
  318. def torch_tensor_to_operand(self, tensor, dim_order):
  319. dtype = str(tensor.dtype).replace("torch.", "")
  320. scale = 0.0
  321. zero_point = 0
  322. if dtype == "float32":
  323. op_type = NNAPI_OperandCode.TENSOR_FLOAT32
  324. elif dtype == "int32":
  325. op_type = NNAPI_OperandCode.TENSOR_INT32
  326. elif dtype == "quint8":
  327. op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
  328. scale = tensor.q_scale()
  329. zero_point = tensor.q_zero_point()
  330. elif dtype == "qint32":
  331. op_type = NNAPI_OperandCode.TENSOR_INT32
  332. scale = tensor.q_scale()
  333. zero_point = tensor.q_zero_point()
  334. assert zero_point == 0
  335. elif dtype == "int16":
  336. if self.use_int16_for_qint16:
  337. nnapi_dtype = getattr(tensor, "nnapi_dtype", None)
  338. op_codes = (
  339. NNAPI_OperandCode.TENSOR_QUANT16_SYMM,
  340. NNAPI_OperandCode.TENSOR_QUANT16_ASYMM,
  341. )
  342. if nnapi_dtype in op_codes:
  343. op_type = nnapi_dtype
  344. scale = tensor.nnapi_scale
  345. zero_point = tensor.nnapi_zero_point
  346. else:
  347. raise Exception( # noqa: TRY002
  348. f"`nnapi_type` needs to be one of {op_codes} for `int16`"
  349. )
  350. else:
  351. raise Exception( # noqa: TRY002
  352. "`int16` isn't supported. If you're trying to represent NNAPI"
  353. " qint16 with Pytorch int16, set `use_int16_for_qint16 = True`"
  354. )
  355. else:
  356. raise Exception( # noqa: TRY002
  357. f"Can't handle input with dtype '{tensor.dtype}'"
  358. ) # noqa: TRY002
  359. return Operand(
  360. shape=tuple(tensor.shape),
  361. op_type=op_type,
  362. dim_order=dim_order,
  363. scale=scale,
  364. zero_point=zero_point,
  365. )
  366. def add_tensor_operand_for_input(self, arg_idx, jitval, tensor):
  367. dim_order = (
  368. DimOrder.CHANNELS_LAST
  369. if getattr(tensor, "nnapi_nhwc", False)
  370. else DimOrder.PRESUMED_CONTIGUOUS
  371. )
  372. toper = self.torch_tensor_to_operand(tensor, dim_order)
  373. operand_id = self.add_tensor_operand(jitval, toper)
  374. self.inputs.append(operand_id)
  375. for dim, size in enumerate(tensor.shape):
  376. if size == 0:
  377. self.compute_operand_shape(
  378. operand_id, dim, f"args[{arg_idx}].shape[{dim}]"
  379. )
  380. return operand_id
  381. def add_tensor_operand_for_weight(
  382. self, tensor, dim_order=DimOrder.UNKNOWN_CONSTANT
  383. ):
  384. toper = self.torch_tensor_to_operand(tensor, dim_order)
  385. operand_id = len(self.operands)
  386. self.operands.append(toper)
  387. tsize = tensor_size(toper.op_type, toper.shape)
  388. self.values.append((operand_id, OperandValueSourceType.NUMBERED_BUFFER))
  389. buf_num = len(self.used_weights)
  390. offset = 0
  391. self.value_data.append(struct.pack("iii", buf_num, offset, tsize))
  392. # For NHWC NNAPI op, lay out data in the same dim order by permuting torch tensor
  393. if dim_order == DimOrder.CHANNELS_LAST:
  394. tensor = tensor.permute(0, 2, 3, 1)
  395. self.used_weights.append(tensor)
  396. return operand_id
  397. def add_immediate_operand(self, code, value, dims):
  398. assert isinstance(dims, tuple)
  399. cache_key = (code, value)
  400. if cache_key not in self.cached_immediates:
  401. operand_id = len(self.operands)
  402. self.operands.append(Operand(code, dims, DimOrder.SCALAR_OR_VECTOR, 0.0, 0))
  403. self.values.append((operand_id, OperandValueSourceType.IMMEDIATE))
  404. self.value_data.append(value)
  405. self.cached_immediates[cache_key] = operand_id
  406. return self.cached_immediates[cache_key]
  407. def add_immediate_int_scalar(self, value):
  408. return self.add_immediate_operand(
  409. NNAPI_OperandCode.INT32, struct.pack("i", value), ()
  410. )
  411. def add_immediate_float_scalar(self, value):
  412. return self.add_immediate_operand(
  413. NNAPI_OperandCode.FLOAT32, struct.pack("f", value), ()
  414. )
  415. def add_immediate_bool_scalar(self, value):
  416. return self.add_immediate_operand(
  417. NNAPI_OperandCode.BOOL, b"\x01" if value else b"\x00", ()
  418. )
  419. def add_immediate_int_vector(self, value):
  420. return self.add_immediate_operand(
  421. NNAPI_OperandCode.TENSOR_INT32,
  422. array.array("i", value).tobytes(),
  423. (len(value),),
  424. )
  425. def has_operand_for_jitval(self, jitval):
  426. return jitval in self.jitval_operand_map
  427. def get_tensor_operand_by_jitval(self, jitval):
  428. operand_id = self.jitval_operand_map[jitval]
  429. return (operand_id, self.operands[operand_id])
  430. def get_tensor_operand_by_jitval_fixed_size(self, jitval):
  431. op_id, oper = self.get_tensor_operand_by_jitval(jitval)
  432. for s in oper.shape:
  433. if s == 0:
  434. # TODO: Improve this error message, possibly after converting
  435. # many callsites to support flexible size.
  436. raise Exception( # noqa: TRY002
  437. "Flexible size is not supported for this operand."
  438. ) # noqa: TRY002
  439. if s < 0:
  440. # runtime flex
  441. LOG.warning("Operand %s has runtime flex shape", oper)
  442. return op_id, oper
  443. def get_tensor_operand_or_constant(
  444. self, jitval, dim_order=DimOrder.PRESUMED_CONTIGUOUS
  445. ):
  446. operand_id = self.jitval_operand_map.get(jitval)
  447. if operand_id is None:
  448. _, value = self.get_constant_value(jitval, "TensorType")
  449. operand_id = self.add_tensor_operand_for_weight(value, dim_order)
  450. return (operand_id, self.operands[operand_id])
  451. def get_tensor_operand_for_weight(self, jitval):
  452. _, value = self.get_constant_value(jitval, "TensorType")
  453. operand_id = self.add_tensor_operand_for_weight(value)
  454. return (operand_id, self.operands[operand_id])
  455. def add_operation(self, opcode, inputs, outputs):
  456. self.operations.append((opcode, len(inputs), len(outputs)))
  457. self.operation_args.extend(inputs + outputs)
  458. def add_tensor_sequence(self, jitval, values):
  459. assert jitval not in self.tensor_sequences
  460. self.tensor_sequences[jitval] = values
  461. def add_constant_value(self, jitval, ctype, value):
  462. assert jitval not in self.constants
  463. self.constants[jitval] = (ctype, value)
  464. def get_constant_value(self, jitval, typekind=None):
  465. record = self.constants.get(jitval)
  466. if record is None:
  467. raise Exception( # noqa: TRY002
  468. f"Could not find constant value for '{jitval!r}'."
  469. ) # noqa: TRY002
  470. ctype, _ = record
  471. if typekind is not None and ctype.kind() != typekind:
  472. raise Exception( # noqa: TRY002
  473. f"Expected constant value of type {typekind}, but got {ctype.kind()} for value '{jitval!r}'"
  474. )
  475. return record
  476. def operand_to_template_torchscript(self, op_id, oper, shape=None):
  477. """Return a TorchScript expression to build a template for a given operand."""
  478. if shape is None:
  479. shape = oper.shape
  480. else:
  481. assert len(shape) == len(oper.shape)
  482. shape_parts = ["("]
  483. for d, s in enumerate(shape):
  484. if s > 0:
  485. # Fixed shape dimension: just add the value.
  486. shape_parts.append(str(s))
  487. elif s == 0:
  488. # Load time flexible shape dimension: it should have been computed in a variable.
  489. shape_parts.append(flex_name(op_id, d))
  490. elif s == -1:
  491. # Runtime flexible shape
  492. shape_parts.append("0")
  493. else:
  494. raise Exception( # noqa: TRY002
  495. "Unknown dim value, dimensions should be >= -1"
  496. ) # noqa: TRY002
  497. shape_parts.append(",")
  498. shape_parts.append(")")
  499. shape_code = "".join(shape_parts)
  500. if oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32:
  501. return f"torch.zeros({shape_code}, dtype=torch.float32)"
  502. elif oper.op_type == NNAPI_OperandCode.TENSOR_INT32:
  503. return f"torch.zeros({shape_code}, dtype=torch.int32)"
  504. elif oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
  505. return (
  506. f"torch.quantize_per_tensor("
  507. f"torch.zeros(1), scale={oper.scale}, zero_point={oper.zero_point}, dtype=torch.quint8)"
  508. f".expand({shape_code}).contiguous()"
  509. )
  510. elif oper.op_type in (
  511. NNAPI_OperandCode.TENSOR_QUANT16_ASYMM,
  512. NNAPI_OperandCode.TENSOR_QUANT16_SYMM,
  513. ):
  514. if self.use_int16_for_qint16:
  515. return f"torch.zeros({shape_code}, dtype=torch.int16)"
  516. else:
  517. raise Exception( # noqa: TRY002
  518. "`int16` isn't supported. If you're trying to represent NNAPI"
  519. " qint16 with Pytorch int16, set `use_int16_for_qint16 = True`"
  520. )
  521. raise Exception( # noqa: TRY002
  522. f"Unsupported output operand type: {oper.op_type}"
  523. ) # noqa: TRY002
  524. def forward_operand_shape(self, out_op_id, out_dim, in_op_id, in_dim):
  525. self.compute_operand_shape(out_op_id, out_dim, flex_name(in_op_id, in_dim))
  526. def compute_operand_shape(self, op_id, dim, expr):
  527. self.flexible_shape_computation_lines.append(
  528. f"{flex_name(op_id, dim)} = {expr}"
  529. )
  530. def transpose_to_nhwc(self, in_id, oper):
  531. if oper.shape[2:] != (1, 1):
  532. raise Exception( # noqa: TRY002
  533. "Automatic transpose only supported for H,W == 1,1"
  534. ) # noqa: TRY002
  535. out_oper = oper._replace(dim_order=DimOrder.CHANNELS_LAST)
  536. inputs = [None] * 2
  537. inputs[0] = in_id
  538. inputs[1] = self.add_immediate_int_vector([0, 2, 3, 1])
  539. outputs = [None] * 1
  540. outputs[0] = self.add_anonymous_tensor_operand(out_oper)
  541. self.add_operation(NNAPI_OperationCode.TRANSPOSE, inputs, outputs)
  542. return outputs[0], out_oper
  543. # Transpose inputs as necessary to allow broadcasting.
  544. def transpose_for_broadcast(self, in0_id, in0_oper, in1_id, in1_oper):
  545. if in0_oper.dim_order == in1_oper.dim_order:
  546. return in0_id, in0_oper, in1_id, in1_oper
  547. # Assume NHWC is preferred if there is a mismatch.
  548. orders = (in0_oper.dim_order, in1_oper.dim_order)
  549. if orders == (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.CHANNELS_LAST):
  550. return self.transpose_to_nhwc(in0_id, in0_oper) + (in1_id, in1_oper)
  551. if orders == (DimOrder.CHANNELS_LAST, DimOrder.PRESUMED_CONTIGUOUS):
  552. return (in0_id, in0_oper) + self.transpose_to_nhwc(in1_id, in1_oper)
  553. raise Exception( # noqa: TRY002
  554. f"Automatic transpose not supported for dim_orders: {in0_oper.dim_order!r}, {in1_oper.dim_order!r}"
  555. )
  556. def get_size_arg(self, jitval):
  557. ctype, value = self.get_constant_value(jitval)
  558. if ctype.kind() == "ListType":
  559. assert ctype.getElementType().kind() == "IntType"
  560. return value
  561. raise Exception( # noqa: TRY002
  562. f"Can't handle size arg of type '{ctype!r}' for '{jitval!r}'"
  563. ) # noqa: TRY002
  564. def get_conv_pool_args_2d_from_pack(self, kernel_size, packed_config):
  565. pc = [i.item() for i in packed_config]
  566. assert pc[0] == 2
  567. strides = [pc[1], pc[2]]
  568. paddings = [pc[3], pc[4]]
  569. dilations = [pc[5], pc[6]]
  570. output_padding = [pc[7], pc[8]]
  571. group_num = pc[9]
  572. assert len(pc) == 11
  573. assert output_padding == [0, 0]
  574. return self.get_conv_pool_args_2d_common(
  575. kernel_size, strides, paddings, dilations, group_num
  576. )
  577. def get_conv_pool_args_2d_from_jit(
  578. self, kernel_size, stride, padding, dilation=None, group=None
  579. ):
  580. strides = self.get_size_arg(stride)
  581. paddings = self.get_size_arg(padding)
  582. if dilation is None:
  583. dilations = [1, 1]
  584. else:
  585. dilations = self.get_size_arg(dilation)
  586. if group is not None:
  587. _, group_num = self.get_constant_value(group, "IntType")
  588. else:
  589. group_num = None
  590. return self.get_conv_pool_args_2d_common(
  591. kernel_size, strides, paddings, dilations, group_num
  592. )
  593. def get_conv_pool_args_2d_common(
  594. self, kernel_size, strides, paddings, dilations, group_num
  595. ):
  596. kernels = list(kernel_size)
  597. assert len(kernels) == 2
  598. assert len(strides) == 2
  599. assert len(paddings) == 2
  600. assert len(dilations) == 2
  601. # NNAPI uses 4 values for padding.
  602. ph, pw = paddings
  603. real_paddings = [ph, ph, pw, pw]
  604. return ConvPoolArgs2d(
  605. *(kernels + strides + real_paddings + dilations + [group_num])
  606. )
    def serialize_model(self, model, inputs, return_shapes=None):
        """Lower a TorchScript ``model`` into the serialized NNAPI model blob.

        Args:
            model: TorchScript module whose ``graph`` is lowered. Its first
                graph input is the module itself; the remaining graph inputs
                are paired positionally with ``inputs``.
            inputs: example tensors, one per non-``self`` graph input.
            return_shapes: optional sequence with one entry per returned
                tensor, forwarded to ``operand_to_template_torchscript``.

        Returns:
            Tuple of (model as ``array.array("i")``, ``self.used_weights``,
            input dim-order values, output dim-order values, the generated
            flexible-shape/return TorchScript lines, and the return-value count,
            which is ``-1`` when the graph returns a single tensor rather than
            a tuple).

        Raises:
            Exception: if the graph returns neither a tensor nor a tuple.
        """
        # NOTE(review): presumably these two immediates are emitted first so
        # False/True occupy fixed, early operand slots -- confirm.
        self.add_immediate_bool_scalar(False)
        self.add_immediate_bool_scalar(True)

        inp_dim_orders = []
        out_dim_orders = []

        # Graph input 0 is the module (``self``); register it as a constant so
        # prim::GetAttr nodes can resolve attributes on it.
        self_jitval = next(model.graph.inputs())
        self.add_constant_value(self_jitval, self_jitval.type(), model)

        for arg_idx, (input_value, input_tensor) in enumerate(
            zip(list(model.graph.inputs())[1:], inputs)
        ):
            op_id = self.add_tensor_operand_for_input(
                arg_idx, input_value, input_tensor
            )
            inp_dim_orders.append(self.operands[op_id].dim_order.value)

        # Lower every graph node in order via the ADDER_MAP dispatch.
        for idx, node in enumerate(model.graph.nodes()):
            LOG.debug("Processing node #%d: %r", idx, node)
            self.add_node(node)

        retn = model.graph.return_node()
        assert retn.inputsSize() == 1
        assert retn.outputsSize() == 0
        retn_input = retn.inputsAt(0)
        template_return_lines = ["return ["]
        if retn_input.type().kind() == "TensorType":
            return_values = [retn_input]
            # -1 distinguishes "single tensor" from "tuple of N tensors".
            retval_count = -1
        elif retn_input.type().kind() == "TupleType":
            return_values = self.tensor_sequences[retn_input]
            retval_count = len(return_values)
        else:
            raise Exception(  # noqa: TRY002
                f"Unsupported return type: {retn_input.type()}"
            )  # noqa: TRY002

        if return_shapes is not None:
            assert len(return_shapes) == len(return_values)
        for i, v in enumerate(return_values):
            op_id = self.jitval_operand_map[v]
            self.outputs.append(op_id)
            out_dim_orders.append(self.operands[op_id].dim_order.value)
            shape = return_shapes[i] if return_shapes else None
            template_return_lines.append(
                self.operand_to_template_torchscript(op_id, self.operands[op_id], shape)
                + ","
            )
        template_return_lines.append("]")

        # NOTE: ``model`` is rebound here from the TorchScript module to the
        # list of serialized byte chunks; the module is not used below.
        model = []

        version = 1
        # Header: format version, then operand/value/operation/input/output counts.
        header = struct.pack(
            "iiiiii",
            version,
            len(self.operands),
            len(self.values),
            len(self.operations),
            len(self.inputs),
            len(self.outputs),
        )
        model.append(header)

        serialized_values, serialized_value_data = self.serialize_values()

        # One fixed-size record per operand: type, rank, scale (float), zero point.
        model.extend(
            struct.pack("iifi", t, len(d), s, z) for (t, d, _m, s, z) in self.operands
        )
        model.extend(serialized_values)
        # One 3-int record per operation.
        model.extend(struct.pack("iii", *x) for x in self.operations)

        # Compact the model so we can get its length so far.
        model = [b"".join(model)]
        model_offset = len(model[0])
        # Model offset is the index into the model (in 32-bit words, not bytes)
        # of the next dimension we're about to serialize. If it's 0,
        # generate code to mutate it before passing to NNAPI.
        assert model_offset % 4 == 0
        model_offset = int(model_offset / 4)

        # Per-operand dimension arrays, in NNAPI layout (fix_shape). Each
        # flexible (0) dim gets a generated ``ser_model[offset] = <flex size>``
        # line so the blob can be patched at run time.
        for op_id, (_, dims, dim_order, _, _) in enumerate(self.operands):
            shape = fix_shape(dims, dim_order)
            for d, s in enumerate(shape):
                if s == 0:
                    pt_d = reverse_map_dim(dim_order, d)
                    self.flexible_shape_computation_lines.append(
                        f"ser_model[{model_offset}] = {flex_name(op_id, pt_d)}"
                    )
                model_offset += 1

            # convert runtime flex shape from -1 to 0
            shape = tuple(d if d != -1 else 0 for d in shape)
            model.append(self.serialize_ints(shape))

        # Trailing sections: padded value payloads, operation args, inputs, outputs.
        model.extend(serialized_value_data)
        model.append(self.serialize_ints(self.operation_args))
        model.append(self.serialize_ints(self.inputs))
        model.append(self.serialize_ints(self.outputs))

        self.flexible_shape_computation_lines.extend(template_return_lines)

        return (
            array.array("i", b"".join(model)),
            self.used_weights,
            inp_dim_orders,
            out_dim_orders,
            self.flexible_shape_computation_lines,
            retval_count,
        )
  702. def serialize_values(self):
  703. serialized_values = []
  704. serialized_value_data = []
  705. assert len(self.values) == len(self.value_data)
  706. for (op_index, source_type), data in zip(self.values, self.value_data):
  707. source_length = len(data)
  708. # Pad with 0 bytes out to a multiple of 4 for alignment.
  709. physical_length = ((source_length - 1) | 0x3) + 1
  710. padded_data = data + (b"\0" * (physical_length - source_length))
  711. serialized_values.append(
  712. struct.pack("iii", op_index, source_type, source_length)
  713. )
  714. serialized_value_data.append(padded_data)
  715. return serialized_values, serialized_value_data
  716. @staticmethod
  717. def serialize_ints(ints):
  718. return array.array("i", ints).tobytes()
    # Dispatch table consulted by ``add_node``. It maps a TorchScript node kind
    # (``node.kind()``) to a callable ``(serializer, node)`` that lowers that
    # node. Several kinds share one lowering method and differ only in the
    # NNAPI opcode / fuse code bound here. Kinds missing from this table are
    # rejected by ``add_node`` with an "Unsupported node kind" exception.
    ADDER_MAP = {
        # Graph plumbing: attributes, constants, and list/tuple construction.
        "prim::GetAttr": lambda self, node: self.add_getattr(node),
        "prim::Constant": lambda self, node: self.add_constant_node(node),
        "prim::ListConstruct": lambda self, node: self.add_list_construct(node),
        "prim::TupleConstruct": lambda self, node: self.add_tuple_construct(node),
        # Shape / layout ops.
        "aten::unsqueeze": lambda self, node: self.add_unsqueeze(node),
        "aten::to": lambda self, node: self.add_to(node),
        "aten::detach": lambda self, node: self._identity(node),
        "aten::reshape": lambda self, node: self.add_reshape(node),
        "aten::flatten": lambda self, node: self.add_flatten(node),
        "aten::slice": lambda self, node: self.add_slice(node),
        "aten::size": lambda self, node: self.add_size(node),
        "aten::cat": lambda self, node: self.add_cat(node),
        "aten::mean": lambda self, node: self.add_mean(node),
        "aten::quantize_per_tensor": lambda self, node: self.add_quantize(node),
        "aten::dequantize": lambda self, node: self.add_dequantize(node),
        # Float pointwise ops.
        "aten::add": lambda self, node: self.add_add_sub_op(
            node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE
        ),
        "aten::sub": lambda self, node: self.add_add_sub_op(
            node, NNAPI_OperationCode.SUB, NNAPI_FuseCode.FUSED_NONE
        ),
        "aten::mul": lambda self, node: self.add_pointwise_simple_binary_broadcast_op(
            node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE
        ),
        "aten::div": lambda self, node: self.add_pointwise_simple_binary_broadcast_op(
            node, NNAPI_OperationCode.DIV, NNAPI_FuseCode.FUSED_NONE
        ),
        "aten::relu": lambda self, node: self.add_pointwise_simple_unary_op(
            node, NNAPI_OperationCode.RELU
        ),
        "aten::sigmoid": lambda self, node: self.add_pointwise_simple_unary_op(
            node, NNAPI_OperationCode.LOGISTIC
        ),
        "aten::softmax": lambda self, node: self.add_softmax(node),
        "aten::hardtanh": lambda self, node: self.add_hardtanh(node),
        # Pooling, resampling, and activation-with-weights ops.
        "aten::avg_pool2d": lambda self, node: self.add_avg_pool2d(node),
        "aten::max_pool2d": lambda self, node: self.add_pool2d_node(
            node, NNAPI_OperationCode.MAX_POOL_2D
        ),
        "aten::adaptive_avg_pool2d": lambda self, node: self.add_adaptive_avg_pool2d(
            node
        ),
        "aten::upsample_nearest2d": lambda self, node: self.add_upsample_nearest2d(
            node
        ),
        "aten::prelu": lambda self, node: self.add_prelu_op(node),
        # Linear / convolution ops.
        "aten::addmm": lambda self, node: self.add_addmm(node),
        "aten::linear": lambda self, node: self.add_linear(node),
        "aten::_convolution": lambda self, node: self.add_conv_underscore(node),
        "aten::conv2d": lambda self, node: self.add_conv2d(node),
        "aten::log_softmax": lambda self, node: self.add_log_softmax(node),
        # Quantized ops; fused-ReLU variants differ only in the fuse code.
        "quantized::linear": lambda self, node: self.add_qlinear(node),
        "quantized::conv2d": lambda self, node: self.add_qconv2d(
            node, NNAPI_FuseCode.FUSED_NONE
        ),
        "quantized::conv2d_relu": lambda self, node: self.add_qconv2d(
            node, NNAPI_FuseCode.FUSED_RELU
        ),
        "quantized::conv_transpose2d": lambda self, node: self.add_qconv2d(
            node, NNAPI_FuseCode.FUSED_NONE, transpose=True
        ),
        "quantized::add": lambda self, node: self.add_qadd(
            node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE
        ),
        "quantized::add_relu": lambda self, node: self.add_qadd(
            node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_RELU
        ),
        # quantized::mul reuses add_qadd's (scale, zero_point) handling.
        "quantized::mul": lambda self, node: self.add_qadd(
            node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE
        ),
    }
  791. def add_node(self, node):
  792. adder = self.ADDER_MAP.get(node.kind())
  793. if not adder:
  794. raise Exception( # noqa: TRY002
  795. f"Unsupported node kind ({node.kind()!r}) in node {node!r}"
  796. ) # noqa: TRY002
  797. adder(self, node)
  798. def _identity(self, node):
  799. in_id, _in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
  800. jitval = node.outputsAt(0)
  801. self.jitval_operand_map[jitval] = in_id
  802. def add_getattr(self, node):
  803. assert node.inputsSize() == 1
  804. assert node.outputsSize() == 1
  805. obj_ctype, obj = self.get_constant_value(node.inputsAt(0))
  806. assert str(obj_ctype).startswith("__torch__.")
  807. name = node.s("name")
  808. value = getattr(obj, name)
  809. output = node.outputsAt(0)
  810. ctype = output.type()
  811. self.add_constant_value(output, ctype, value)
  812. def add_constant_node(self, node):
  813. assert node.inputsSize() == 0
  814. assert node.outputsSize() == 1
  815. output = node.outputsAt(0)
  816. ctype = output.type()
  817. value = output.toIValue()
  818. self.add_constant_value(output, ctype, value)
  819. def add_list_construct(self, node):
  820. assert node.outputsSize() == 1
  821. output = node.outputsAt(0)
  822. ctype = output.type()
  823. const_vals: Optional[list] = []
  824. tensors: Optional[list] = []
  825. for inp in node.inputs():
  826. if const_vals is not None and inp in self.constants:
  827. _, val = self.get_constant_value(inp)
  828. const_vals.append(val)
  829. else:
  830. const_vals = None
  831. if tensors is not None and inp.type().kind() == "TensorType":
  832. tensors.append(inp)
  833. else:
  834. tensors = None
  835. if const_vals is not None:
  836. # NOTE: Now that TorchScript supports list constants,
  837. # this code path might not be used anymore.
  838. self.add_constant_value(output, ctype, const_vals)
  839. if tensors is not None:
  840. self.add_tensor_sequence(output, tensors)
  841. if const_vals is None and tensors is None:
  842. raise Exception( # noqa: TRY002
  843. f"Unable to handle ListConstruct node. Neither all constants nor all tensors. {node!r}"
  844. )
  845. def add_tuple_construct(self, node):
  846. assert node.outputsSize() == 1
  847. output = node.outputsAt(0)
  848. values = list(node.inputs())
  849. self.add_tensor_sequence(output, values)
  850. def add_unsqueeze(self, node):
  851. assert node.inputsSize() == 2
  852. assert node.outputsSize() == 1
  853. in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
  854. _, dim = self.get_constant_value(node.inputsAt(1), "IntType")
  855. assert in_oper.dim_order == DimOrder.PRESUMED_CONTIGUOUS
  856. real_dim = dim if dim >= 0 else dim + len(in_oper.shape) + 1
  857. out_shape_list = list(in_oper.shape)
  858. out_shape_list.insert(real_dim, 1)
  859. out_shape = tuple(out_shape_list)
  860. out_oper = in_oper._replace(shape=out_shape)
  861. inputs = [None] * 2
  862. inputs[0] = in_id
  863. inputs[1] = self.add_immediate_int_scalar(dim)
  864. outputs = [None] * 1
  865. outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
  866. self.add_operation(NNAPI_OperationCode.EXPAND_DIMS, inputs, outputs)
  867. def add_to(self, node):
  868. # Handle to("cpu") / to("gpu") case
  869. self._identity(node)
  870. def add_reshape(self, node):
  871. assert node.inputsSize() == 2
  872. assert node.outputsSize() == 1
  873. in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
  874. shape_ctype, shape = self.get_constant_value(node.inputsAt(1))
  875. assert shape_ctype.kind() == "ListType"
  876. assert shape_ctype.getElementType().kind() == "IntType"
  877. is_trivial_reshape = len(shape) == 2 and shape[1] == -1
  878. if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_reshape:
  879. raise Exception( # noqa: TRY002
  880. "Currently, reshape is only supported on NHWC tensors if the target size is [X, -1]."
  881. )
  882. # Bit of a hack here. Use a real tensor to infer the output shape.
  883. out_shape = torch.zeros(1).expand(in_oper.shape).reshape(shape).shape
  884. out_oper = in_oper._replace(
  885. shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS
  886. )
  887. inputs = [None] * 2
  888. inputs[0] = in_id
  889. inputs[1] = self.add_immediate_int_vector(shape)
  890. outputs = [None] * 1
  891. outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
  892. self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs)
    def add_flatten(self, node):
        """Lower ``aten::flatten(input, start_dim, end_dim)`` to NNAPI RESHAPE.

        Dims ``start_dim..end_dim`` (inclusive; negative values normalized)
        are collapsed into their product, and the output is
        PRESUMED_CONTIGUOUS. NHWC input is accepted only when it is rank 4
        with C == 1 or H == W == 1. The collapsed dims must be fixed (non-zero),
        and at most one remaining dim may be flexible (size 0). That flexible
        size is forwarded to the output at run time.

        Raises:
            Exception: on unsupported NHWC input, a flexible collapsed dim, or
                more than one flexible remaining dim.
        """
        assert node.inputsSize() == 3
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))

        _start_ctype, start_dim = self.get_constant_value(node.inputsAt(1), "IntType")
        _end_ctype, end_dim = self.get_constant_value(node.inputsAt(2), "IntType")

        # channels last with channels == 1 or (height & width both 1)
        is_trivial_flatten = len(in_oper.shape) == 4 and (
            in_oper.shape[1] == 1 or (in_oper.shape[2] == 1 and in_oper.shape[3] == 1)
        )
        if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_flatten:
            raise Exception(  # noqa: TRY002
                "Currently, flatten is not supported on NHWC tensors unless C=1 or H=W=1"
            )

        if start_dim < 0:
            start_dim += len(in_oper.shape)
        if end_dim < 0:
            end_dim += len(in_oper.shape)

        # Keep the leading and trailing dims; replace the range with its product.
        out_shape = (
            in_oper.shape[:start_dim]
            + (functools.reduce(operator.mul, in_oper.shape[start_dim : end_dim + 1]),)
            + in_oper.shape[end_dim + 1 :]
        )

        if any(dim == 0 for dim in in_oper.shape[start_dim : end_dim + 1]):
            raise Exception(  # noqa: TRY002
                "Flattening flexible dims is not supported yet"
            )  # noqa: TRY002
        non_flattened_dims = in_oper.shape[:start_dim] + in_oper.shape[end_dim + 1 :]
        if non_flattened_dims.count(0) > 1:
            raise Exception("Only 1 dim can be flexible")  # noqa: TRY002

        out_oper = in_oper._replace(
            shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS
        )
        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
        # At most one input dim is flexible (checked above), so the first 0 in
        # the input shape is the source for the output's flexible dim.
        for idx, dim in enumerate(out_shape):
            if dim == 0:
                self.forward_operand_shape(out_id, idx, in_id, in_oper.shape.index(0))

        # Flexible (0) dims are passed to RESHAPE as -1, presumably NNAPI's
        # "infer this dim" marker -- confirm against the NNAPI RESHAPE spec.
        inputs_1 = tuple(dim if dim != 0 else -1 for dim in out_shape)

        inputs = [None] * 2
        inputs[0] = in_id
        inputs[1] = self.add_immediate_int_vector(inputs_1)

        outputs = [None] * 1
        outputs[0] = out_id

        self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs)
  937. def add_slice(self, node):
  938. assert node.inputsSize() == 5
  939. assert node.outputsSize() == 1
  940. in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
  941. _, dim_value = self.get_constant_value(node.inputsAt(1))
  942. _, start_value = self.get_constant_value(node.inputsAt(2))
  943. _, stop_value = self.get_constant_value(node.inputsAt(3))
  944. _, step_value = self.get_constant_value(node.inputsAt(4))
  945. if start_value is None:
  946. start_value = 0
  947. if stop_value is None:
  948. stop_value = sys.maxsize
  949. if start_value < 0:
  950. start_value += in_oper.shape[dim_value]
  951. elif start_value == sys.maxsize:
  952. start_value = 0
  953. if start_value == 0 and stop_value == sys.maxsize:
  954. self._identity(node)
  955. return
  956. if in_oper.shape[dim_value] == 0:
  957. raise Exception("Unable to slice with flexible shape") # noqa: TRY002
  958. if stop_value < 0:
  959. stop_value += in_oper.shape[dim_value]
  960. elif stop_value == sys.maxsize:
  961. stop_value = in_oper.shape[dim_value]
  962. if start_value >= stop_value:
  963. raise Exception( # noqa: TRY002
  964. "Slice start value should be less than stop value"
  965. ) # noqa: TRY002
  966. out_len = (stop_value - start_value) // step_value
  967. out_shape = tuple(
  968. out_len if i == dim_value else dim for i, dim in enumerate(in_oper.shape)
  969. )
  970. out_id = self.add_tensor_operand(
  971. node.outputsAt(0), in_oper._replace(shape=out_shape)
  972. )
  973. # flex inputs
  974. end_mask = 0
  975. for idx, dim in enumerate(out_shape):
  976. if dim == 0:
  977. self.forward_operand_shape(out_id, idx, in_id, idx)
  978. end_mask |= 1 << idx
  979. inputs = [None] * 7
  980. inputs[0] = in_id
  981. inputs[1] = self.add_immediate_int_vector(
  982. [start_value if i == dim_value else 0 for i in range(len(in_oper.shape))]
  983. )
  984. inputs[2] = self.add_immediate_int_vector(
  985. [
  986. stop_value if i == dim_value else dim
  987. for i, dim in enumerate(in_oper.shape)
  988. ]
  989. )
  990. inputs[3] = self.add_immediate_int_vector(
  991. [step_value if i == dim_value else 1 for i in range(len(in_oper.shape))]
  992. )
  993. inputs[4] = self.add_immediate_int_scalar(0) # begin mask
  994. inputs[5] = self.add_immediate_int_scalar(end_mask)
  995. inputs[6] = self.add_immediate_int_scalar(0) # shrink axis mas
  996. outputs = [None] * 1
  997. outputs[0] = out_id
  998. self.add_operation(NNAPI_OperationCode.STRIDED_SLICE, inputs, outputs)
  999. def add_size(self, node):
  1000. assert node.inputsSize() == 2
  1001. assert node.outputsSize() == 1
  1002. _, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
  1003. _, value = self.constants[node.inputsAt(1)]
  1004. res = in_oper.shape[value]
  1005. output = node.outputsAt(0)
  1006. self.add_constant_value(output, output.type(), res)
  1007. def add_cat(self, node):
  1008. assert node.inputsSize() == 2
  1009. assert node.outputsSize() == 1
  1010. tensors = self.tensor_sequences[node.inputsAt(0)]
  1011. _, dim = self.get_constant_value(node.inputsAt(1), "IntType")
  1012. assert len(tensors) > 0
  1013. in_ids = []
  1014. out_oper = None
  1015. out_dim_size = 0
  1016. for inp in tensors:
  1017. in_id, in_oper = self.get_tensor_operand_by_jitval(inp)
  1018. if out_oper is None:
  1019. out_shape = change_element(in_oper.shape, dim, -1)
  1020. out_oper = in_oper._replace(shape=out_shape)
  1021. assert in_oper.op_type == out_oper.op_type
  1022. assert in_oper.dim_order == out_oper.dim_order
  1023. assert change_element(in_oper.shape, dim, -1) == change_element(
  1024. out_oper.shape, dim, -1
  1025. )
  1026. # TODO: Possibly check scale and zero point.
  1027. in_ids.append(in_id)
  1028. # TODO: Possibly support variable-sized inputs.
  1029. out_dim_size += in_oper.shape[dim]
  1030. assert out_oper is not None
  1031. out_oper = out_oper._replace(
  1032. shape=change_element(out_oper.shape, dim, out_dim_size)
  1033. )
  1034. if in_oper.dim_order == DimOrder.CHANNELS_LAST: # type: ignore[possibly-undefined]
  1035. assert len(out_oper.shape) == 4
  1036. nnapi_dim = [0, 3, 1, 2][dim]
  1037. else:
  1038. nnapi_dim = dim
  1039. out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
  1040. for idx, d in enumerate(out_oper.shape):
  1041. if d == 0:
  1042. if idx == dim:
  1043. shape = " + ".join(flex_name(ip_id, dim) for ip_id in in_ids)
  1044. self.compute_operand_shape(out_id, idx, shape)
  1045. else:
  1046. self.forward_operand_shape(out_id, idx, in_ids[0], idx)
  1047. inputs = in_ids + [self.add_immediate_int_scalar(nnapi_dim)]
  1048. outputs = [None] * 1
  1049. outputs[0] = out_id
  1050. self.add_operation(NNAPI_OperationCode.CONCATENATION, inputs, outputs)
  1051. def add_mean(self, node):
  1052. assert node.inputsSize() == 4
  1053. assert node.outputsSize() == 1
  1054. in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
  1055. dim_ctype, dim = self.get_constant_value(node.inputsAt(1))
  1056. assert dim_ctype.kind() == "ListType"
  1057. assert dim_ctype.getElementType().kind() == "IntType"
  1058. _, keep_dim = self.get_constant_value(node.inputsAt(2), "BoolType")
  1059. # Expect None for dtype
  1060. self.get_constant_value(node.inputsAt(3), "NoneType")
  1061. if in_oper.dim_order == DimOrder.CHANNELS_LAST:
  1062. assert len(in_oper.shape) == 4
  1063. nnapi_dim = [[0, 3, 1, 2][d] for d in dim]
  1064. else:
  1065. nnapi_dim = dim
  1066. collapsed_dims = set()
  1067. for d in dim:
  1068. if d < 0:
  1069. d += len(in_oper.shape)
  1070. collapsed_dims.add(d)
  1071. if in_oper.dim_order == DimOrder.CHANNELS_LAST and not keep_dim:
  1072. assert collapsed_dims.issuperset({2, 3})
  1073. out_dim_order = DimOrder.PRESUMED_CONTIGUOUS
  1074. else:
  1075. out_dim_order = in_oper.dim_order
  1076. out_shape = []
  1077. for i, s in enumerate(in_oper.shape):
  1078. if i not in collapsed_dims:
  1079. out_shape.append(s)
  1080. elif keep_dim:
  1081. out_shape.append(1)
  1082. out_oper = in_oper._replace(shape=out_shape, dim_order=out_dim_order)
  1083. inputs = [None] * 3
  1084. inputs[0] = in_id
  1085. inputs[1] = self.add_immediate_int_vector(nnapi_dim)
  1086. inputs[2] = self.add_immediate_int_scalar(keep_dim)
  1087. outputs = [None] * 1
  1088. outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
  1089. self.add_operation(NNAPI_OperationCode.MEAN, inputs, outputs)
  1090. def add_quantize(self, node):
  1091. assert node.inputsSize() == 4
  1092. assert node.outputsSize() == 1
  1093. in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
  1094. if in_oper.dim_order != DimOrder.CHANNELS_LAST:
  1095. raise Exception( # noqa: TRY002
  1096. "Most hardware backends prefer NHWC quantized tensors. "
  1097. "Try setting `t.nnapi_nhwc = True` on your tensor inputs. "
  1098. )
  1099. _, scale = self.get_constant_value(node.inputsAt(1), "FloatType")
  1100. _, zero_point = self.get_constant_value(node.inputsAt(2), "IntType")
  1101. _, scalar_type = self.get_constant_value(node.inputsAt(3), "IntType")
  1102. if scalar_type != TorchScalarTypes.QUINT8.value:
  1103. raise Exception( # noqa: TRY002
  1104. "PyTorch NNAPI export only supports quantized tensors "
  1105. "with the quint8 dtype."
  1106. )
  1107. op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
  1108. out_oper = in_oper._replace(
  1109. op_type=op_type,
  1110. scale=scale,
  1111. zero_point=zero_point,
  1112. )
  1113. inputs = [None] * 1
  1114. inputs[0] = in_id
  1115. outputs = [None] * 1
  1116. outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
  1117. self.add_operation(NNAPI_OperationCode.QUANTIZE, inputs, outputs)
  1118. def add_dequantize(self, node):
  1119. assert node.inputsSize() == 1
  1120. assert node.outputsSize() == 1
  1121. in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
  1122. out_oper = in_oper._replace(
  1123. op_type=NNAPI_OperandCode.TENSOR_FLOAT32,
  1124. scale=0.0,
  1125. zero_point=0,
  1126. )
  1127. inputs = [None] * 1
  1128. inputs[0] = in_id
  1129. outputs = [None] * 1
  1130. outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
  1131. self.add_operation(NNAPI_OperationCode.DEQUANTIZE, inputs, outputs)
  1132. def add_pointwise_simple_unary_op(self, node, opcode):
  1133. assert node.inputsSize() == 1
  1134. assert node.outputsSize() == 1
  1135. in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
  1136. out_oper = in_oper
  1137. if opcode == NNAPI_OperationCode.LOGISTIC:
  1138. # NNAPI docs: For ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, the scale
  1139. # must be 1.f / 256 and the zeroPoint must be 0.
  1140. # https://fburl.com/h52stoog
  1141. if in_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
  1142. out_oper = in_oper._replace(zero_point=0, scale=1.0 / 256)
  1143. out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
  1144. for idx, dim in enumerate(in_oper.shape):
  1145. if dim == 0:
  1146. self.forward_operand_shape(out_id, idx, in_id, idx)
  1147. inputs = [None] * 1
  1148. inputs[0] = in_id
  1149. outputs = [None] * 1
  1150. outputs[0] = out_id
  1151. self.add_operation(opcode, inputs, outputs)
    def _do_add_binary(self, node, opcode, fuse_code, *, qparams=None):  # noqa: D401
        """Helper for pointwise binary broadcast ops with superfluous extra args.

        Lowers inputs 0 and 1 of ``node`` to one NNAPI ``opcode`` operation
        fused with ``fuse_code``. Any further node inputs (alpha, output
        quantization params) are read by the caller. ``qparams``, if given, is
        a ``(scale, zero_point)`` pair that overrides the output operand's
        quantization parameters.

        Raises:
            Exception: if neither input already has an operand (both constants).
        """
        assert node.outputsSize() == 1
        assert node.inputsAt(0).type().kind() == "TensorType"
        assert node.inputsAt(1).type().kind() == "TensorType"

        # At least one side must already be an operand. The other side may be a
        # constant, obtained via get_tensor_operand_or_constant using the first
        # side's dim order.
        if self.has_operand_for_jitval(node.inputsAt(0)):
            in0_id, in0_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
            in1_id, in1_oper = self.get_tensor_operand_or_constant(
                node.inputsAt(1), in0_oper.dim_order
            )
        elif self.has_operand_for_jitval(node.inputsAt(1)):
            in1_id, in1_oper = self.get_tensor_operand_by_jitval(node.inputsAt(1))
            in0_id, in0_oper = self.get_tensor_operand_or_constant(
                node.inputsAt(0), in1_oper.dim_order
            )
        else:
            raise Exception(  # noqa: TRY002
                f"Can't do a NNAPI binary op: {opcode} on two constants"
            )  # noqa: TRY002

        assert in0_oper.op_type == in1_oper.op_type
        # NOTE(review): presumably reconciles differing dim orders between the
        # two operands before broadcasting -- see transpose_for_broadcast.
        in0_id, in0_oper, in1_id, in1_oper = self.transpose_for_broadcast(
            in0_id, in0_oper, in1_id, in1_oper
        )
        # NOTE: PyTorch and NNAPI have the same broadcast semantics.
        out_shape = broadcast_shapes(in0_oper.shape, in1_oper.shape)
        out_oper = in0_oper._replace(shape=out_shape)
        if qparams is not None:
            scale, zp = qparams
            out_oper = out_oper._replace(scale=scale, zero_point=zp)
        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
        # Forward flexible (0) sizes to the output: take the flexible side when
        # the other is 1, and emit a runtime equality assert when both are
        # flexible.
        # NOTE(review): zip pairs dims from the left, while broadcasting aligns
        # from the right. This is only correct if both shapes have equal rank
        # here -- confirm.
        for idx, (d0, d1) in enumerate(zip(in0_oper.shape, in1_oper.shape)):
            if d0 == 1 and d1 == 0:
                self.forward_operand_shape(out_id, idx, in1_id, idx)
            elif d0 == 0 and d1 == 1:
                self.forward_operand_shape(out_id, idx, in0_id, idx)
            elif d0 == 0 and d1 == 0:
                self.flexible_shape_computation_lines.append(
                    f"assert {flex_name(in0_id, idx)} == {flex_name(in1_id, idx)}"
                )
                self.forward_operand_shape(out_id, idx, in0_id, idx)

        inputs = [None] * 3
        inputs[0] = in0_id
        inputs[1] = in1_id
        inputs[2] = self.add_immediate_int_scalar(fuse_code)

        outputs = [None] * 1
        outputs[0] = out_id

        self.add_operation(opcode, inputs, outputs)
  1199. def add_pointwise_simple_binary_broadcast_op(self, node, opcode, fuse_code):
  1200. assert node.inputsSize() == 2
  1201. self._do_add_binary(node, opcode, fuse_code)
  1202. def add_add_sub_op(self, node, opcode, fuse_code):
  1203. assert node.inputsSize() == 3
  1204. _, alpha = self.get_constant_value(node.inputsAt(2), "IntType")
  1205. if alpha != 1:
  1206. raise Exception( # noqa: TRY002
  1207. "NNAPI does not support add/sub with alpha."
  1208. ) # noqa: TRY002
  1209. self._do_add_binary(node, opcode, fuse_code)
  1210. def add_qadd(self, node, opcode, fuse_code):
  1211. assert node.inputsSize() == 4
  1212. _, scale = self.get_constant_value(node.inputsAt(2), "FloatType")
  1213. _, zero_point = self.get_constant_value(node.inputsAt(3), "IntType")
  1214. self._do_add_binary(node, opcode, fuse_code, qparams=(scale, zero_point))
  1215. def add_softmax(self, node):
  1216. assert node.inputsSize() == 3
  1217. in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
  1218. _, softmax_dim = self.get_constant_value(node.inputsAt(1), "IntType")
  1219. out_id = self.add_tensor_operand(node.outputsAt(0), in_oper)
  1220. for dim, size in enumerate(in_oper.shape):
  1221. if size == 0:
  1222. self.forward_operand_shape(out_id, dim, in_id, dim)
  1223. inputs = [None] * 3
  1224. inputs[0] = in_id
  1225. inputs[1] = self.add_immediate_float_scalar(
  1226. 1.0
  1227. ) # positive scaling factor of exponent, beta
  1228. inputs[2] = self.add_immediate_int_scalar(softmax_dim)
  1229. outputs = [None] * 1
  1230. outputs[0] = out_id
  1231. self.add_operation(NNAPI_OperationCode.SOFTMAX, inputs, outputs)
  1232. def add_hardtanh(self, node):
  1233. assert node.inputsSize() == 3
  1234. assert node.outputsSize() == 1
  1235. in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
  1236. _, min_val = self.get_constant_value(node.inputsAt(1), "FloatType")
  1237. _, max_val = self.get_constant_value(node.inputsAt(2), "FloatType")
  1238. op_map = {
  1239. (-1, 1): NNAPI_OperationCode.RELU1,
  1240. (0, 6): NNAPI_OperationCode.RELU6, # noqa: E201
  1241. }
  1242. opcode = op_map.get((min_val, max_val))
  1243. if opcode is None:
  1244. raise Exception( # noqa: TRY002
  1245. "NNAPI only supports hardtanh with args (-1, 1) or (0, 6)."
  1246. ) # noqa: TRY002
  1247. inputs = [None] * 1
  1248. inputs[0] = in_id
  1249. outputs = [None] * 1
  1250. outputs[0] = self.add_tensor_operand(node.outputsAt(0), in_oper)
  1251. self.add_operation(opcode, inputs, outputs)
  1252. def add_prelu_op(self, node):
  1253. assert node.inputsSize() == 2
  1254. assert node.outputsSize() == 1
  1255. assert node.inputsAt(0).type().kind() == "TensorType"
  1256. assert node.inputsAt(1).type().kind() == "TensorType"
  1257. in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
  1258. w_id, w_oper = self.get_tensor_operand_for_weight(node.inputsAt(1))
  1259. assert len(w_oper.shape) == 1
  1260. assert w_oper.shape[0] > 0
  1261. if w_oper.shape[0] > 1:
  1262. if in_oper.use_nchw():
  1263. # TODO: Support this by adding trailing 1 dims.
  1264. raise Exception( # noqa: TRY002
  1265. "Per-channel PReLU only supports channels_last right now."
  1266. )
  1267. out_id = self.add_tensor_operand(node.outputsAt(0), in_oper)
  1268. for dim, size in enumerate(in_oper.shape):
  1269. if size > 0:
  1270. pass
  1271. elif dim <= 1:
  1272. raise Exception( # noqa: TRY002
  1273. "PReLU requires fixed size for dim 0 and dim 1."
  1274. ) # noqa: TRY002
  1275. else:
  1276. self.forward_operand_shape(out_id, dim, in_id, dim)
  1277. inputs = [None] * 2
  1278. inputs[0] = in_id
  1279. inputs[1] = w_id
  1280. outputs = [None] * 1
  1281. outputs[0] = out_id
  1282. self.add_operation(NNAPI_OperationCode.PRELU, inputs, outputs)
  1283. def add_pool2d_node(self, node, opcode):
  1284. assert node.inputsSize() == 6
  1285. assert node.outputsSize() == 1
  1286. image, kernel, stride, padding, dilation, _ceil_mode = node.inputs()
  1287. stride = stride or kernel
  1288. # TODO: Validate ceil_mode semantics.
  1289. args = self.get_conv_pool_args_2d_from_jit(
  1290. self.get_size_arg(kernel), stride, padding, dilation
  1291. )
  1292. if args.dilation_h != 1 or args.dilation_w != 1:
  1293. raise Exception("NNAPI does not support dilated pooling.") # noqa: TRY002
  1294. image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(image)
  1295. assert len(image_oper.shape) == 4
  1296. out_shape = get_conv_pool_shape(
  1297. image_oper.shape, args, image_oper.shape[1], False
  1298. )
  1299. use_nchw = image_oper.use_nchw()
  1300. inputs = [None] * 11
  1301. inputs[0] = image_id
  1302. inputs[1] = self.add_immediate_int_scalar(args.pad_l)
  1303. inputs[2] = self.add_immediate_int_scalar(args.pad_r)
  1304. inputs[3] = self.add_immediate_int_scalar(args.pad_t)
  1305. inputs[4] = self.add_immediate_int_scalar(args.pad_b)
  1306. inputs[5] = self.add_immediate_int_scalar(args.stride_w)
  1307. inputs[6] = self.add_immediate_int_scalar(args.stride_h)
  1308. inputs[7] = self.add_immediate_int_scalar(args.kernel_w)
  1309. inputs[8] = self.add_immediate_int_scalar(args.kernel_h)
  1310. inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
  1311. inputs[10] = self.add_immediate_bool_scalar(use_nchw)
  1312. outputs = [None] * 1
  1313. outputs[0] = self.add_tensor_operand(
  1314. node.outputsAt(0), image_oper._replace(shape=out_shape)
  1315. )
  1316. self.add_operation(opcode, inputs, outputs)
  def add_avg_pool2d(self, node):
      """Lower an ``avg_pool2d`` node to an NNAPI AVERAGE_POOL_2D operation.

      Only the configuration NNAPI can express is accepted: padded elements
      must be counted (a truthy ``count_include_pad``) and ``divisor_override``
      must be falsy.  ``ceil_mode`` is read from the node but not used.

      Raises:
          Exception: if ``count_include_pad`` is falsy or ``divisor_override``
              is truthy.
      """
      assert node.inputsSize() == 7
      assert node.outputsSize() == 1
      (
          jit_image,
          jit_kernel,
          jit_stride,
          jit_padding,
          _ceil_mode,
          jit_count_include_pad,
          jit_divisor_override,
      ) = node.inputs()

      _, include_pad = self.get_constant_value(jit_count_include_pad)
      _, divisor = self.get_constant_value(jit_divisor_override)
      supported = include_pad and not divisor
      if not supported:
          raise Exception(  # noqa: TRY002
              "NNAPI doesn't support count_include_pad=False or divisor_override"
          )

      args = self.get_conv_pool_args_2d_from_jit(
          self.get_size_arg(jit_kernel), jit_stride, jit_padding
      )
      image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image)
      assert len(image_oper.shape) == 4
      # Pooling keeps the channel count, so dim 1 of the input is the
      # output channel count.
      out_shape = get_conv_pool_shape(
          image_oper.shape, args, image_oper.shape[1], False
      )
      use_nchw = image_oper.use_nchw()

      # Operand order: input, padding (l, r, t, b), strides (w, h),
      # filter size (w, h), fused activation, layout flag.
      scalar_values = (
          args.pad_l,
          args.pad_r,
          args.pad_t,
          args.pad_b,
          args.stride_w,
          args.stride_h,
          args.kernel_w,
          args.kernel_h,
          NNAPI_FuseCode.FUSED_NONE,
      )
      inputs = [image_id]
      inputs.extend(self.add_immediate_int_scalar(v) for v in scalar_values)
      inputs.append(self.add_immediate_bool_scalar(use_nchw))

      out_oper = image_oper._replace(shape=out_shape)
      out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
      self._handle_conv_pool_flexible_input(out_id, jit_image, args, False)
      self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, [out_id])
  def add_adaptive_avg_pool2d(self, node):
      """Lower ``adaptive_avg_pool2d`` for the global-pooling case only.

      The requested output size must be ``[1, 1]``.  It is emitted as an
      AVERAGE_POOL_2D with zero padding, unit strides and a filter whose
      (width, height) equals input dims (3, 2).  The input must have a
      fixed size.

      Raises:
          Exception: if the output size is not ``[1, 1]``.
      """
      assert node.inputsSize() == 2
      assert node.outputsSize() == 1
      image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(
          node.inputsAt(0)
      )
      assert len(image_oper.shape) == 4

      size_ctype, size_arg = self.get_constant_value(node.inputsAt(1))
      assert size_ctype.kind() == "ListType"
      assert size_ctype.getElementType().kind() == "IntType"
      if size_arg != [1, 1]:
          raise Exception(  # noqa: TRY002
              "NNAPI only supports adaptive_avg_pool2d with output size (1, 1)."
          )

      out_shape = image_oper.shape[0:2] + tuple(size_arg)
      use_nchw = image_oper.use_nchw()

      # padding l, r, t, b; stride w, h; filter w, h; fused activation
      pool_scalars = (
          0,
          0,
          0,
          0,
          1,
          1,
          image_oper.shape[3],
          image_oper.shape[2],
          NNAPI_FuseCode.FUSED_NONE,
      )
      inputs = [image_id]
      for value in pool_scalars:
          inputs.append(self.add_immediate_int_scalar(value))
      inputs.append(self.add_immediate_bool_scalar(use_nchw))

      out_id = self.add_tensor_operand(
          node.outputsAt(0), image_oper._replace(shape=out_shape)
      )
      self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, [out_id])
  def add_upsample_nearest2d(self, node):
      """Lower ``upsample_nearest2d`` to NNAPI RESIZE_NEAREST_NEIGHBOR.

      Two node shapes are accepted:

      * 3 inputs ``(image, size, scale)``: exactly one of ``size`` (list of
        ints) or ``scale`` (list of floats) must be non-None.
      * 4 inputs ``(image, size, scale_h, scale_w)``: both per-axis scales
        must be None, so ``size`` drives the resize.

      A one-element size/scale list is used for both height and width.  A
      flexible (size 0) input height/width gets a runtime shape expression
      for the matching output dimension; flexible batch or channels are
      rejected.

      Raises:
          Exception: if size and scale are both given or both None, or if the
              input batch/channel dimension is flexible.
      """
      input_count = node.inputsSize()
      assert input_count in (3, 4)
      assert node.outputsSize() == 1

      if input_count == 3:
          image, size_jit, scale_jit = node.inputs()
          size_ctype, size_arg = self.get_constant_value(size_jit)
          scale_ctype, scale_arg = self.get_constant_value(scale_jit)
      else:
          image, size_jit, scale_h_jit, scale_w_jit = node.inputs()
          size_ctype, size_arg = self.get_constant_value(size_jit)
          scale_h_ctype, scale_h_arg = self.get_constant_value(scale_h_jit)
          scale_w_ctype, _scale_w_arg = self.get_constant_value(scale_w_jit)
          # The 4-argument overload can only have been added to the graph
          # without error if both per-axis scales are None.
          assert scale_h_ctype.kind() == "NoneType"
          assert scale_w_ctype.kind() == "NoneType"
          scale_ctype, scale_arg = scale_h_ctype, scale_h_arg

      image_id, image_oper = self.get_tensor_operand_by_jitval(image)
      assert len(image_oper.shape) == 4

      has_size = size_ctype.kind() != "NoneType"
      has_scale = scale_ctype.kind() != "NoneType"
      if has_size and has_scale:
          raise Exception("Size and scale cannot both be non-None.")  # noqa: TRY002
      if not (has_size or has_scale):
          raise Exception("Size and scale cannot both be None.")  # noqa: TRY002

      if has_size:
          assert size_ctype.kind() == "ListType"
          assert size_ctype.getElementType().kind() == "IntType"
          assert scale_ctype.kind() == "NoneType"
          assert scale_arg is None
          assert isinstance(size_arg, list)
          assert size_arg
          assert all(isinstance(val, int) for val in size_arg)
          if len(size_arg) == 1:
              size_arg = size_arg * 2
          assert len(size_arg) == 2
          out_h, out_w = size_arg
          arg_h = self.add_immediate_int_scalar(out_h)
          arg_w = self.add_immediate_int_scalar(out_w)
      else:
          assert scale_ctype.kind() == "ListType"
          assert scale_ctype.getElementType().kind() == "FloatType"
          assert size_ctype.kind() == "NoneType"
          assert size_arg is None
          assert isinstance(scale_arg, list)
          assert scale_arg
          assert all(isinstance(val, float) for val in scale_arg)
          if len(scale_arg) == 1:
              scale_arg = scale_arg * 2
          assert len(scale_arg) == 2
          scale_h, scale_w = scale_arg
          out_h = int(scale_h * image_oper.shape[2])
          out_w = int(scale_w * image_oper.shape[3])
          arg_h = self.add_immediate_float_scalar(scale_h)
          arg_w = self.add_immediate_float_scalar(scale_w)

      out_shape = (image_oper.shape[0], image_oper.shape[1], out_h, out_w)
      use_nchw = image_oper.use_nchw()
      out_id = self.add_tensor_operand(
          node.outputsAt(0), image_oper._replace(shape=out_shape)
      )

      if image_oper.shape[0] == 0 or image_oper.shape[1] == 0:
          raise Exception("Flexible batch or channels not supported")  # noqa: TRY002

      # Flexible height (dim 2) / width (dim 3): compute the output size at
      # runtime from either the fixed size or the scale factor.
      for dim in (2, 3):
          if image_oper.shape[dim] != 0:
              continue
          axis = dim - 2
          if has_size:
              self.compute_operand_shape(out_id, dim, size_arg[axis])
          else:
              self.compute_operand_shape(
                  out_id,
                  dim,
                  f"int({scale_arg[axis]} * {flex_name(image_id, dim)})",
              )

      # RESIZE_NEAREST_NEIGHBOR takes width before height.
      inputs = [image_id, arg_w, arg_h, self.add_immediate_bool_scalar(use_nchw)]
      self.add_operation(
          NNAPI_OperationCode.RESIZE_NEAREST_NEIGHBOR, inputs, [out_id]
      )
  def add_addmm(self, node):
      """Lower ``addmm(bias, input, weight, beta, alpha)`` to FULLY_CONNECTED.

      NNAPI's fully-connected op has no scaling factors, so ``beta`` and
      ``alpha`` must be constant ints/floats equal to 1.  The weight is
      forwarded with ``transpose_weight=True``.

      Raises:
          Exception: if ``beta`` or ``alpha`` differs from 1.
      """
      assert node.inputsSize() == 5
      assert node.outputsSize() == 1
      bias, mat1, mat2, beta, alpha = node.inputs()

      for scale_jit in (beta, alpha):
          factor_ctype, factor = self.get_constant_value(scale_jit)
          assert factor_ctype.kind() in ("IntType", "FloatType")
          if factor == 1:
              continue
          raise Exception(  # noqa: TRY002
              "NNAPI Fully-Connected does not support alpha and beta."
          )

      self.add_addmm_or_linear(
          node,
          transpose_weight=True,
          jit_input=mat1,
          jit_weight=mat2,
          jit_bias=bias,
      )
  def add_linear(self, node):
      """Lower ``linear(input, weight, bias)`` to FULLY_CONNECTED.

      The weight is forwarded untransposed (``transpose_weight=False``).
      """
      assert node.inputsSize() == 3
      assert node.outputsSize() == 1
      activation, weight, bias = node.inputs()
      self.add_addmm_or_linear(
          node,
          transpose_weight=False,
          jit_input=activation,
          jit_weight=weight,
          jit_bias=bias,
      )
  def add_addmm_or_linear(
      self, node, transpose_weight, jit_input, jit_weight, jit_bias
  ):
      """Emit an NNAPI FULLY_CONNECTED operation for addmm / linear.

      Args:
          node: JIT node being lowered; its output 0 becomes the op output.
          transpose_weight: if true, ``jit_weight`` is transposed before
              serialization (addmm's ``mat2``); otherwise it is used as-is
              (linear's weight).
          jit_input: 2-D input activation.
          jit_weight: constant 2-D weight tensor.
          jit_bias: constant 1-D bias tensor, or a constant None.  For None,
              an all-zeros bias with one entry per output unit is
              synthesized, as the conv lowering already does.
      """
      input_id, input_oper = self.get_tensor_operand_by_jitval(jit_input)
      assert len(input_oper.shape) == 2
      # TODO: Transform at load time to share weights with CPU model.
      _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
      assert len(weight_tensor.shape) == 2
      # Previously a None bias (e.g. linear without bias) failed here because
      # get_tensor_operand_for_weight requires a tensor constant.  The output
      # unit count is weight dim 1 before transposition and dim 0 otherwise,
      # which is exactly how get_optional_bias picks the zero-bias length.
      bias_id, bias_oper = self.get_optional_bias(
          jit_bias, weight_tensor, transpose_weight
      )
      assert len(bias_oper.shape) == 1

      if transpose_weight:
          nnapi_weight_tensor = weight_tensor.t().contiguous()
      else:
          nnapi_weight_tensor = weight_tensor.contiguous()
      weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
      weight_oper = self.operands[weight_id]
      # The output width is weight dim 0; the bias needs one entry per unit.
      assert bias_oper.shape[0] == weight_oper.shape[0]

      out_shape = (input_oper.shape[0], weight_oper.shape[0])
      out_id = self.add_tensor_operand(
          node.outputsAt(0), input_oper._replace(shape=out_shape)
      )
      # A flexible batch dimension is forwarded from the input.
      if input_oper.shape[0] == 0:
          self.forward_operand_shape(out_id, 0, input_id, 0)

      inputs = [None] * 4
      inputs[0] = input_id
      inputs[1] = weight_id
      inputs[2] = bias_id
      inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
      outputs = [None] * 1
      outputs[0] = out_id
      self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs)
  def add_qlinear(self, node):
      """Lower a quantized linear node to an NNAPI FULLY_CONNECTED op.

      Inputs are ``(input, packed_weight, out_scale, out_zero_point)``.  The
      ``LinearPackedParamsBase`` must hold a 2-D per-tensor-affine weight and
      a 1-D bias.  A qint8 weight is re-expressed as quint8 by shifting both
      the stored values and the zero point by 128.  The bias is quantized to
      qint32 with scale ``input_scale * weight_scale`` and zero point 0.

      Raises:
          Exception: if ``input_scale * weight_scale / out_scale >= 1``.
      """
      assert node.inputsSize() == 4
      assert node.outputsSize() == 1
      (
          jit_input,
          jit_packed_weight,
          jit_scale,
          jit_zero_point,
      ) = node.inputs()
      input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input)
      # TODO: Support automatic reshape
      assert len(input_oper.shape) == 2
      _, out_scale = self.get_constant_value(jit_scale, "FloatType")
      _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType")
      weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight)
      assert weight_ctype.name() == "LinearPackedParamsBase"
      raw_weight, raw_bias = packed_weight.__getstate__()[0]
      assert raw_bias is not None
      assert len(raw_weight.shape) == 2
      assert len(raw_bias.shape) == 1
      assert raw_bias.shape[0] == raw_weight.shape[0]
      assert raw_weight.shape[1] == input_oper.shape[1]
      assert raw_weight.qscheme() == torch.per_tensor_affine
      if raw_weight.dtype == torch.quint8:
          unsigned_weight = raw_weight
      else:
          assert raw_weight.dtype == torch.qint8
          # Shifting values and zero point by the same 128 keeps the
          # dequantized weights unchanged.
          unsigned_weight = torch._make_per_tensor_quantized_tensor(
              (raw_weight.int_repr().int() + 128).to(torch.uint8),
              scale=raw_weight.q_scale(),
              zero_point=raw_weight.q_zero_point() + 128,
          )
      weight_scale = unsigned_weight.q_scale()
      bias_scale = input_oper.scale * weight_scale
      int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32)
      bias_id = self.add_tensor_operand_for_weight(int_bias)

      multiplier = input_oper.scale * weight_scale / out_scale
      assert multiplier > 0
      if multiplier >= 1:
          # The message used to say "convolution" (copied from qconv2d) and
          # "greater than 1" although the check also rejects exactly 1.
          raise Exception(  # noqa: TRY002
              "Quantized linear multiplier is greater than or equal to 1. "
              "This is supported by NNAPI, but not by most hardware backends. "
              "Try training a model without quantization-aware training. "
          )

      # TODO: Transform at load time to share weights with CPU model.
      nnapi_weight_tensor = unsigned_weight.contiguous()
      weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
      weight_oper = self.operands[weight_id]

      out_shape = (input_oper.shape[0], weight_oper.shape[0])
      out_oper = input_oper._replace(
          shape=out_shape,
          scale=out_scale,
          zero_point=out_zero_point,
      )

      inputs = [None] * 4
      inputs[0] = input_id
      inputs[1] = weight_id
      inputs[2] = bias_id
      inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
      outputs = [None] * 1
      outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
      self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs)
  def get_optional_bias(self, jit_bias, weight_tensor, transpose=False):
      """Return ``(operand_id, operand)`` for a bias that may be None.

      A non-None constant bias is registered via
      ``get_tensor_operand_for_weight``.  For a constant None, a zero bias
      is synthesized.  It has ``weight_tensor``'s dtype and a length equal to
      weight dim 1 when ``transpose`` is true, dim 0 otherwise.
      """
      bias_ctype, _ = self.get_constant_value(jit_bias)
      if bias_ctype.kind() != "NoneType":
          return self.get_tensor_operand_for_weight(jit_bias)

      bias_len = weight_tensor.size()[1 if transpose else 0]
      zero_bias = torch.zeros(bias_len, dtype=weight_tensor.dtype)
      bias_id = self.add_tensor_operand_for_weight(zero_bias)
      return bias_id, self.operands[bias_id]
  def add_conv2d(self, node):
      """Lower ``conv2d`` via ``add_conv2d_common``.

      Inputs are ``(image, weight, bias, stride, padding, dilation, groups)``.
      The weight must be a constant tensor; a None bias is replaced with
      zeros.  Output scale / zero point are passed as 0.0 / 0, the conv is
      never transposed, and no activation is fused.
      """
      assert node.inputsSize() == 7
      assert node.outputsSize() == 1
      (
          jit_image,
          jit_weight,
          jit_bias,
          jit_stride,
          jit_pad,
          jit_dilation,
          jit_groups,
      ) = node.inputs()

      _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
      bias_id, _ = self.get_optional_bias(jit_bias, weight_tensor)
      kernel_hw = weight_tensor.shape[2:4]
      args = self.get_conv_pool_args_2d_from_jit(
          kernel_hw, jit_stride, jit_pad, jit_dilation, jit_groups
      )
      return self.add_conv2d_common(
          jit_out=node.outputsAt(0),
          out_scale=0.0,
          out_zero_point=0,
          jit_image=jit_image,
          weight_tensor=weight_tensor,
          bias_id=bias_id,
          args=args,
          transpose=False,
          fuse_code=NNAPI_FuseCode.FUSED_NONE,
      )
  def add_conv_underscore(self, node):
      """Lower ``_convolution`` (13 inputs) via ``add_conv2d_common``.

      The following inputs are used:

      * 0-6: image, weight, bias, stride, padding, dilation, transposed;
      * 7: output_padding, checked only for transposed convs;
      * 8: groups.

      Inputs 9-12 are ignored.  ``output_padding`` is not part of the conv
      args handed to ``add_conv2d_common``, so the computed output shape
      never includes it.  A transposed conv with a nonzero
      ``output_padding`` is therefore rejected instead of being serialized
      with a shape that disagrees with PyTorch.

      Raises:
          Exception: for a transposed convolution with nonzero output_padding.
      """
      assert node.inputsSize() == 13
      assert node.outputsSize() == 1
      (
          jit_image,
          jit_weight,
          jit_bias,
          jit_stride,
          jit_pad,
          jit_dilation,
          jit_transpose,
          jit_output_padding,
          jit_groups,
          _,
          _,
          _,
          _,
      ) = node.inputs()
      _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
      _, transpose = self.get_constant_value(jit_transpose)
      if transpose:
          # Only read output_padding when it matters, so non-transposed convs
          # keep working exactly as before.
          output_padding = self.get_size_arg(jit_output_padding)
          if any(output_padding):
              raise Exception(  # noqa: TRY002
                  "NNAPI does not support transposed convolution with "
                  f"nonzero output_padding: {output_padding}"
              )
      bias_id, _bias_oper = self.get_optional_bias(jit_bias, weight_tensor, transpose)
      args = self.get_conv_pool_args_2d_from_jit(
          weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups
      )
      return self.add_conv2d_common(
          node.outputsAt(0),
          0.0,
          0,
          jit_image,
          weight_tensor,
          bias_id,
          args,
          transpose,
          NNAPI_FuseCode.FUSED_NONE,
      )
  def add_log_softmax(self, node):
      """Lower a 3-input ``log_softmax`` node to NNAPI LOG_SOFTMAX.

      The input must have a fixed size, and the output reuses the input's
      operand description.  The exponent scale ``beta`` is fixed at 1.  The
      third node input (presumably ``half_to_float``) is not used.
      """
      assert node.inputsSize() == 3
      assert node.outputsSize() == 1
      jit_input, jit_dim, _unused = node.inputs()
      input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input)
      _, dim = self.get_constant_value(jit_dim, "IntType")

      # NOTE(review): ``dim`` is passed through unchanged; if the operand is not
      # in NCHW order (see use_nchw()), the NNAPI axis may need remapping.
      inputs = [
          input_id,
          self.add_immediate_float_scalar(1),  # beta
          self.add_immediate_int_scalar(dim),  # axis
      ]
      out_oper = input_oper._replace(shape=input_oper.shape)
      outputs = [self.add_tensor_operand(node.outputsAt(0), out_oper)]
      self.add_operation(NNAPI_OperationCode.LOG_SOFTMAX, inputs, outputs)
  def add_qconv2d(self, node, fuse_code, transpose=False):
      """Lower a quantized (optionally transposed) 2-D conv with packed weights.

      Inputs are ``(image, packed_weight, out_scale, out_zero_point)``.  The
      ``Conv2dPackedParamsBase`` must use pack version "2" and hold a
      per-tensor-affine weight plus a non-None bias.  A qint8 weight is
      re-expressed as quint8 by shifting values and zero point by 128.  The
      bias is quantized to qint32 with scale ``input_scale * weight_scale``
      and zero point 0.

      Raises:
          Exception: if ``input_scale * weight_scale / out_scale >= 1``.
      """
      assert node.inputsSize() == 4
      assert node.outputsSize() == 1
      jit_image, jit_packed_weight, jit_scale, jit_zero_point = node.inputs()

      _, out_scale = self.get_constant_value(jit_scale, "FloatType")
      _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType")
      weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight)
      assert weight_ctype.name() == "Conv2dPackedParamsBase"

      pack_version, tensors, opt_tensors = packed_weight.__getstate__()[0]
      assert pack_version == "2"
      packed_config, raw_weight = tensors
      (raw_bias,) = opt_tensors
      assert raw_bias is not None
      args = self.get_conv_pool_args_2d_from_pack(
          raw_weight.shape[2:4], packed_config
      )

      assert raw_weight.qscheme() == torch.per_tensor_affine
      if raw_weight.dtype == torch.quint8:
          unsigned_weight = raw_weight
      else:
          assert raw_weight.dtype == torch.qint8
          # Moving the stored values and the zero point by the same 128 keeps
          # the dequantized weights unchanged.
          shifted_repr = (raw_weight.int_repr().int() + 128).to(torch.uint8)
          unsigned_weight = torch._make_per_tensor_quantized_tensor(
              shifted_repr,
              scale=raw_weight.q_scale(),
              zero_point=raw_weight.q_zero_point() + 128,
          )
      weight_scale = unsigned_weight.q_scale()

      _, image_oper = self.get_tensor_operand_by_jitval(jit_image)
      bias_scale = image_oper.scale * weight_scale
      int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32)
      bias_id = self.add_tensor_operand_for_weight(int_bias)

      multiplier = bias_scale / out_scale
      assert multiplier > 0
      if multiplier >= 1:
          raise Exception(  # noqa: TRY002
              "Quantized convolution multiplier is greater than 1. "
              "This is supported by NNAPI, but not by most hardware backends. "
              "Try training a model without quantization-aware training. "
          )

      return self.add_conv2d_common(
          jit_out=node.outputsAt(0),
          out_scale=out_scale,
          out_zero_point=out_zero_point,
          jit_image=jit_image,
          weight_tensor=unsigned_weight,
          bias_id=bias_id,
          args=args,
          transpose=transpose,
          fuse_code=fuse_code,
      )
  def add_conv2d_common(
      self,
      jit_out,
      out_scale,
      out_zero_point,
      jit_image,
      weight_tensor,
      bias_id,
      args,
      transpose,
      fuse_code,
  ):
      """Emit an NNAPI CONV_2D, DEPTHWISE_CONV_2D or TRANSPOSE_CONV_2D op.

      Args:
          jit_out: JIT value that receives the convolution output.
          out_scale: output quantization scale (the float callers in this
              file pass 0.0).
          out_zero_point: output zero point (the float callers pass 0).
          jit_image: 4-D input activation.
          weight_tensor: 4-D weight; permuted here before serialization.
          bias_id: operand id of an already-registered bias.
          args: conv/pool args (padding, strides, kernel, group).
          transpose: emit a transposed convolution.
          fuse_code: NNAPI fused-activation code.

      Only full (``group == 1``) and depthwise (``group == in_channels``)
      convolutions are supported.  Depthwise is supported only when not
      transposed and only with a channel multiplier of 1.

      Raises:
          Exception: for any other group count, for a transposed depthwise
              convolution, or for an input type other than TENSOR_FLOAT32 /
              TENSOR_QUANT8_ASYMM.
      """
      image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image)
      in_c = image_oper.shape[1]
      if args.group == 1:
          # Full convolution
          depthwise = False
          if transpose:
              weight_permutation = (1, 2, 3, 0)
          else:
              weight_permutation = (0, 2, 3, 1)
      elif args.group == in_c:
          if transpose:
              # Previously this fell through to DEPTHWISE_CONV_2D, a regular
              # (non-transposed) conv with an unflipped kernel, while the
              # output shape below was computed with the transposed formula.
              # That silently produced a wrong model.
              raise Exception(  # noqa: TRY002
                  "Transposed depthwise convolution not supported yet."
              )
          # Depthwise convolution
          depthwise = True
          weight_permutation = (1, 2, 3, 0)
      else:
          raise Exception("Group convolution not supported yet.")  # noqa: TRY002

      # TODO: Transform at load time to share weights with CPU model.
      nnapi_weight_tensor = weight_tensor.permute(*weight_permutation).contiguous()
      weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
      weight_oper = self.operands[weight_id]
      bias_oper = self.operands[bias_id]
      if image_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32:
          assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32
          assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32
      elif image_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
          assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
          assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_INT32
          # Bias scale must equal input scale * weight scale, zero point 0.
          assert approx_equal(image_oper.scale * weight_oper.scale, bias_oper.scale)
          assert bias_oper.zero_point == 0
      else:
          raise Exception(  # noqa: TRY002
              f"Unsupported input type for conv2d: {image_oper.op_type}"
          )

      assert len(image_oper.shape) == 4
      assert len(weight_oper.shape) == 4
      assert len(bias_oper.shape) == 1

      if depthwise:
          # After permutation the depthwise weight is (1, kh, kw, out_c).
          one, _kern_h, _kern_w, out_c = weight_oper.shape
          assert one == 1
          assert out_c % in_c == 0
          channel_multiplier = out_c // in_c
          assert channel_multiplier == 1  # Don't support multiplier
          assert out_c == in_c
      else:
          # Full convolution: weight is (out_c, kh, kw, in_c) after permutation.
          out_c, _kern_h, _kern_w, kern_d = weight_oper.shape
          assert kern_d == in_c
      assert out_c == bias_oper.shape[0]

      use_nchw = image_oper.use_nchw()
      if depthwise:
          num_args = 12
          opcode = NNAPI_OperationCode.DEPTHWISE_CONV_2D
      else:
          num_args = 11
          if transpose:
              opcode = NNAPI_OperationCode.TRANSPOSE_CONV_2D
          else:
              opcode = NNAPI_OperationCode.CONV_2D

      inputs = [None] * num_args
      inputs[0] = image_id
      inputs[1] = weight_id
      inputs[2] = bias_id
      inputs[3] = self.add_immediate_int_scalar(args.pad_l)
      inputs[4] = self.add_immediate_int_scalar(args.pad_r)
      inputs[5] = self.add_immediate_int_scalar(args.pad_t)
      inputs[6] = self.add_immediate_int_scalar(args.pad_b)
      inputs[7] = self.add_immediate_int_scalar(args.stride_w)
      inputs[8] = self.add_immediate_int_scalar(args.stride_h)
      if depthwise:
          # Depthwise takes an extra depth-multiplier operand (always 1 here).
          inputs[9] = self.add_immediate_int_scalar(1)
          inputs[10] = self.add_immediate_int_scalar(fuse_code)
          inputs[11] = self.add_immediate_bool_scalar(use_nchw)
      else:
          inputs[9] = self.add_immediate_int_scalar(fuse_code)
          inputs[10] = self.add_immediate_bool_scalar(use_nchw)

      outputs = [None] * 1
      out_shape = get_conv_pool_shape(image_oper.shape, args, out_c, transpose)
      out_oper = image_oper._replace(
          shape=out_shape,
          scale=out_scale,
          zero_point=out_zero_point,
      )
      out_id = self.add_tensor_operand(jit_out, out_oper)
      self._handle_conv_pool_flexible_input(out_id, jit_image, args, transpose)
      outputs[0] = out_id
      self.add_operation(opcode, inputs, outputs)
  def _handle_conv_pool_flexible_input(self, out_id, jit_image, args, transpose):
      """Propagate flexible (size 0) input dims to a conv/pool output operand.

      A flexible batch is forwarded unchanged and a flexible channel count is
      rejected.  Each flexible spatial dim (2 = height, 3 = width) gets a
      runtime size expression.  The transposed-conv formula is used when
      ``transpose`` is true, and the regular conv/pool formula otherwise.

      Raises:
          Exception: if the input channel dimension is flexible.
      """
      image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image)
      batch, in_ch, in_h, in_w = image_oper.shape
      if batch == 0:
          self.forward_operand_shape(out_id, 0, image_id, 0)
      if in_ch == 0:
          raise Exception("Input channels can't be flexible")  # noqa: TRY002

      # (output dim, input size, stride, kernel, leading pad, trailing pad)
      spatial_specs = (
          (2, in_h, args.stride_h, args.kernel_h, args.pad_t, args.pad_b),
          (3, in_w, args.stride_w, args.kernel_w, args.pad_l, args.pad_r),
      )
      for dim, size, stride, kernel, pad_lead, pad_trail in spatial_specs:
          if size != 0:
              continue
          in_name = flex_name(image_id, dim)
          if transpose:
              expr = f"({in_name} - 1) * {stride} + {kernel} - {pad_lead} - {pad_trail}"
          else:
              expr = f"({in_name} - {kernel} + {pad_lead} + {pad_trail}) // {stride} + 1"
          self.compute_operand_shape(out_id, dim, expr)
  def serialize_model(
      module, inputs, *, config=None, return_shapes=None, use_int16_for_qint16=False
  ):
      """Convert a TorchScript module to NNAPI and serialize it.

      Args:
          module: TorchScript module to convert.
          inputs: Tensors describing the NNAPI inputs.
          config: Optional config to attach to the module.
          return_shapes: Output shapes for sizing NNAPI output buffers when
              the module uses runtime-flexible shapes.
          use_int16_for_qint16: Represent NNAPI qint16 values with PyTorch
              int16.

      Returns:
          The result of ``_NnapiSerializer.serialize_model``.
      """
      serializer = _NnapiSerializer(config, use_int16_for_qint16)
      return serializer.serialize_model(module, inputs, return_shapes)