xnnpack.h 196 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
7477847794780478147824783478447854786478747884789479047914792479347944795479647974798479948004801480248034804480548064807480848094810481148124813481448154816481748184819482048214822482348244825482648274828482948304831483248334834483548364837483848394840484148424843484448454846484748484849485048514852485348544855
  1. // Copyright (c) Facebook, Inc. and its affiliates.
  2. // All rights reserved.
  3. //
  4. // Copyright 2019 Google LLC
  5. //
  6. // This source code is licensed under the BSD-style license found in the
  7. // LICENSE file in the root directory of this source tree.
  8. #pragma once
  9. #include <stdbool.h>
  10. #include <stddef.h>
  11. #include <stdint.h>
  12. #include "pthreadpool.h"
  13. #ifdef __cplusplus
  14. extern "C" {
  15. #endif
  /// Number of bytes that XNNPACK may read past the end of any array it is given.
  ///
  /// Every buffer passed to XNNPACK must be followed by at least this many additional
  /// allocated bytes. Those bytes may be read, but XNNPACK never writes beyond array bounds.
  /// Hexagon builds need 128 bytes of slack; every other target needs 16.
  #if !XNN_ARCH_HEXAGON
  #define XNN_EXTRA_BYTES 16
  #else
  #define XNN_EXTRA_BYTES 128
  #endif  // !XNN_ARCH_HEXAGON
  /// Upper bound on the number of dimensions in a tensor shape.
  #define XNN_MAX_TENSOR_DIMS 6
  /// Sentinel Value ID that can never refer to a valid Value.
  ///
  /// This header defines the macro a second time, with an identical replacement list. C only permits
  /// that when both definitions are token-for-token identical, so keep this spelled exactly `UINT32_MAX`.
  #define XNN_INVALID_VALUE_ID UINT32_MAX
  // Runtime-level flags.

  /// Let XNNPACK consider sparse inference when building a Runtime.
  ///
  /// Note: this is only a hint; sparse inference is not guaranteed.
  #define XNN_FLAG_HINT_SPARSE_INFERENCE 0x00000001
  /// Let XNNPACK consider IEEE FP16 inference when building a Runtime.
  ///
  /// Note: this is only a hint; FP16 inference is not guaranteed.
  #define XNN_FLAG_HINT_FP16_INFERENCE 0x00000002
  /// Require IEEE FP16 inference in a Runtime; Runtime creation fails if FP16 inference is not possible.
  ///
  /// Warning: on x86 systems FP16 arithmetic is emulated, at a substantial performance cost.
  #define XNN_FLAG_FORCE_FP16_INFERENCE 0x00000004
  /// Time the execution of every operator in the Runtime.
  #define XNN_FLAG_BASIC_PROFILING 0x00000008
  /// Turn on the just-in-time compiler.
  ///
  /// NOTE(review): XNN_FLAG_YIELD_WORKERS, defined further down in this header, uses the same
  /// bit (0x00000010); confirm the two flags are never accepted by the same function.
  #define XNN_FLAG_JIT 0x00000010
  // Operator flags.
  //
  // Several flags below share a bit value. For example, 0x00000004 is used by
  // XNN_FLAG_TENSORFLOW_SAME_PADDING, XNN_FLAG_TENSORFLOW_RESHAPE_2D and
  // XNN_FLAG_TENSORFLOW_LEGACY_MODE. Each flag's description names the operator it applies to,
  // so the meaning of a bit depends on where it is passed.
  // NOTE(review): confirm that no single function accepts two flags with the same value.

  /// The convolution is depthwise, and its filter uses the HWGo layout.
  #define XNN_FLAG_DEPTHWISE_CONVOLUTION 0x00000001
  /// The weights of a fully connected operator are stored transposed.
  #define XNN_FLAG_TRANSPOSE_WEIGHTS 0x00000001
  /// The input is in NHWC layout, whatever the layout of the output.
  #define XNN_FLAG_INPUT_NHWC 0x00000002
  /// Emulate TensorFlow "SAME" padding. The exact padding is computed dynamically from the input size.
  #define XNN_FLAG_TENSORFLOW_SAME_PADDING 0x00000004
  /// The B operand of a batch matrix multiply is stored transposed (alias of XNN_FLAG_TRANSPOSE_WEIGHTS).
  #define XNN_FLAG_TRANSPOSE_B XNN_FLAG_TRANSPOSE_WEIGHTS
  /// The A operand (input) of a batch matrix multiply is stored transposed.
  #define XNN_FLAG_TRANSPOSE_A 0x00000002
  /// Implicitly flatten and reshape the input of a Fully Connected operator into a 2D tensor.
  #define XNN_FLAG_TENSORFLOW_RESHAPE_2D 0x00000004
  /// Reproduce the behaviour of TensorFlow 1.x.
  #define XNN_FLAG_TENSORFLOW_LEGACY_MODE 0x00000004
  /// The static weights of an FP16 operator are supplied in FP32 format.
  #define XNN_FLAG_FP32_STATIC_WEIGHTS 0x00000008
  /// The static biases of an FP16 operator are supplied in FP32 format.
  #define XNN_FLAG_FP32_STATIC_BIASES 0x00000080
  /// Resize operations align the corners of the input and output images.
  #define XNN_FLAG_ALIGN_CORNERS 0x00000008
  /// After inference, yield the thread pool's worker threads to the system scheduler.
  #define XNN_FLAG_YIELD_WORKERS 0x00000010
  /// Use a transient indirection buffer to reduce the memory footprint.
  #define XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER 0x00000020
  /// Keep the reduced dimensions, with length 1.
  #define XNN_FLAG_KEEP_DIMS 0x00000040
  // Next unused flag value: 0x00000100.
  /// How many xnn_quantization_params entries XNNPACK may read past the end of such an array.
  ///
  /// Callers must allocate at least this many extra entries after the last one they pass in.
  /// XNNPACK only reads those entries; it never writes beyond array bounds.
  #define XNN_EXTRA_QUANTIZATION_PARAMS 15
  /// Smallest block size supported by blockwise-quantized operators.
  #define XNN_MIN_BLOCKSIZE 32
  /// Marks a declaration as deprecated.
  ///
  /// With compilers that define __GNUC__ this expands to the `deprecated` attribute.
  /// With all other compilers it expands to nothing.
  #if defined(__GNUC__)
  #define XNN_DEPRECATED __attribute__((deprecated))
  #else
  #define XNN_DEPRECATED
  #endif  // defined(__GNUC__)
  /// A pair of quantization parameters: a zero point and a scale.
  ///
  /// NOTE(review): presumably the usual affine mapping real = scale * (q - zero_point);
  /// confirm against the quantized operator implementations.
  struct xnn_quantization_params {
    int32_t zero_point;  ///< Zero point (quantized-domain offset).
    float scale;         ///< Scale factor.
  };
  /// Status code for any XNNPACK function call.
  ///
  /// The conditions under which a particular function returns each non-success code are
  /// described in that function's @retval list (see e.g. xnn_initialize).
  enum xnn_status {
    xnn_status_success = 0,                ///< The call succeeded, and all output arguments now contain valid data.
    xnn_status_uninitialized = 1,
    xnn_status_invalid_parameter = 2,
    xnn_status_invalid_state = 3,
    xnn_status_unsupported_parameter = 4,
    xnn_status_unsupported_hardware = 5,   ///< e.g. the host processor lacks features XNNPACK requires.
    xnn_status_out_of_memory = 6,          ///< An out-of-memory condition was hit.
    xnn_status_reallocation_required = 7,
    xnn_status_deprecated = 8,
  };
  /// Memory-management callbacks that XNNPACK uses when this structure is passed to xnn_initialize.
  /// If no structure is given, the system-provided functions (e.g. malloc/free) are used instead.
  ///
  /// Every callback receives the `context` field of this structure as its first argument.
  struct xnn_allocator {
    /// User-supplied pointer, forwarded unchanged to every callback below.
    void* context;

    /// General memory allocation.
    ///
    /// @param context - the `context` field of this structure.
    /// @param size    - size of the requested block, in bytes.
    ///
    /// @returns A pointer to a block of at least @ref size bytes, or NULL if allocation fails.
    void* (*allocate)(void* context, size_t size);

    /// General memory re-allocation: grow or shrink a block previously obtained from
    /// @ref allocate or @ref reallocate. The old block's content is copied into the new block.
    ///
    /// @param context - the `context` field of this structure.
    /// @param pointer - the block to resize. May be NULL, in which case this call is
    ///                  equivalent to an @ref allocate call.
    /// @param size    - new size of the block, in bytes.
    ///
    /// @returns A pointer to a block of at least @ref size bytes that holds the previous content.
    ///          If allocation fails, the callback must return NULL and must NOT release the previous block.
    void* (*reallocate)(void* context, void* pointer, size_t size);

    /// General memory de-allocation.
    ///
    /// @param context - the `context` field of this structure.
    /// @param pointer - a block obtained from @ref allocate or @ref reallocate. May be NULL,
    ///                  in which case the call is a no-op.
    void (*deallocate)(void* context, void* pointer);

    /// Aligned memory allocation.
    ///
    /// @param context   - the `context` field of this structure.
    /// @param alignment - required alignment of the block, in bytes; always a power of 2.
    /// @param size      - size of the requested block, in bytes.
    ///
    /// @returns A pointer to a block of at least @ref size bytes, or NULL if allocation fails.
    void* (*aligned_allocate)(void* context, size_t alignment, size_t size);

    /// Aligned memory de-allocation.
    ///
    /// @param context - the `context` field of this structure.
    /// @param pointer - a block obtained from @ref aligned_allocate. May be NULL, in which case
    ///                  the call is a no-op.
    void (*aligned_deallocate)(void* context, void* pointer);
  };
  /// Initialize the XNNPACK library.
  ///
  /// XNNPACK must be successfully initialized before use. During initialization, XNNPACK populates internal
  /// structures depending on the host processor. Initialization can be time-consuming.
  ///
  /// @param[in] allocator - structure with the function pointers to be used for memory allocation and
  ///                        de-allocation. If this argument is NULL, system-provided memory management
  ///                        functions (e.g. malloc/free) will be used.
  ///
  /// @retval xnn_status_success - XNNPACK is successfully initialized and ready to use.
  /// @retval xnn_status_out_of_memory - initialization failed due to an out-of-memory condition.
  /// @retval xnn_status_unsupported_hardware - initialization failed because the host processor does not
  ///                                           satisfy the minimum hardware requirements for XNNPACK. E.g. this
  ///                                           may happen on x86 processors without the SSE2 extension, or on
  ///                                           32-bit ARM processors without the NEON SIMD extension.
  enum xnn_status xnn_initialize(const struct xnn_allocator* allocator);

  /// Deinitialize the XNNPACK library.
  ///
  /// To avoid memory and resource leaks, users must call xnn_deinitialize once for each successful
  /// xnn_initialize call.
  ///
  /// @retval xnn_status_success - deinitialization call succeeded.
  enum xnn_status xnn_deinitialize(void);
  // The `(void)` parameter lists make these real prototypes. Before C23, an empty `()`
  // declares a function without parameter information, so calls with wrong arguments
  // would compile silently (and -Wstrict-prototypes warns).

  /// Get the microkernel implementation build identifier's data.
  ///
  /// That identifier is unique for the current set of microkernel implementations.
  ///
  /// @returns A pointer to the current identifier's data. Its length is reported by
  ///          xnn_experimental_get_build_identifier_size().
  const void* xnn_experimental_get_build_identifier_data(void);

  /// Get the microkernel implementation build identifier's data size.
  ///
  /// @returns The size in bytes of the identifier's data.
  size_t xnn_experimental_get_build_identifier_size(void);

  /// Check whether the given data matches this version's build identifier.
  ///
  /// @param data - pointer to the identifier data to check.
  /// @param size - size of @a data, in bytes.
  ///
  /// @returns true if @a data / @a size match this version's identifier, false otherwise.
  bool xnn_experimental_check_build_identifier(const void* data, size_t size);
  /// Handle to a Subgraph: an abstract representation of a neural network model.
  /// A Subgraph is where the Values (tensors) and Nodes (operators) that make up the model are defined.
  struct xnn_subgraph;
  typedef struct xnn_subgraph* xnn_subgraph_t;
  188. /// Create a empty Subgraph object.
  189. ///
  190. /// @param external_value_ids - number of Value IDs to reserve for communication with external graph representation.
  191. /// The Subgraph object would avoid creating internal Value IDs in the
  192. /// [0, reserved_value_ids-1] range.
  193. /// @param flags - binary features of the subgraph. No supported flags are currently defined.
  194. /// @param subgraph_out - pointer to the variable that will be initialized with a handle to the Subgraph object upon
  195. /// successful return.
  196. enum xnn_status xnn_create_subgraph(
  197. uint32_t external_value_ids,
  198. uint32_t flags,
  199. xnn_subgraph_t* subgraph_out);
  200. /// Destroy a Subgraph object, as well as Values, and Nodes associated with the subgraph.
  201. ///
  202. /// @param subgraph - the Subgraph object to destroy.
  203. enum xnn_status xnn_delete_subgraph(
  204. xnn_subgraph_t subgraph);
  // Flags for the `flags` argument of Value-definition functions such as xnn_define_tensor_value.
  // NOTE(review): the one-line descriptions below are inferred from the flag names; confirm them
  // against the implementation.

  /// The Value is an external input of the Subgraph.
  #define XNN_VALUE_FLAG_EXTERNAL_INPUT 0x00000001
  /// The Value is an external output of the Subgraph.
  #define XNN_VALUE_FLAG_EXTERNAL_OUTPUT 0x00000002
  /// The Value is persistent.
  /// NOTE(review): xnn_define_tensor_value documents only the EXTERNAL_INPUT / EXTERNAL_OUTPUT
  /// flags as supported; confirm which functions accept this flag.
  #define XNN_VALUE_FLAG_PERSISTENT 0x00000004

  // Identical redefinition of the sentinel defined near the top of this header. C permits this
  // only because both replacement lists are token-for-token identical; keep them that way.
  #define XNN_INVALID_VALUE_ID UINT32_MAX
  /// Element type of the data held by a Value object.
  enum xnn_datatype {
    xnn_datatype_invalid = 0,   ///< Invalid data type; valid Values never have this datatype.
    xnn_datatype_fp32 = 1,      ///< IEEE754 single-precision floating-point.
    xnn_datatype_fp16 = 2,      ///< IEEE754 half-precision floating-point.
    xnn_datatype_qint8 = 3,     ///< Quantized signed 8-bit integer; quantization parameters shared per Value.
    xnn_datatype_quint8 = 4,    ///< Quantized unsigned 8-bit integer; quantization parameters shared per Value.
    xnn_datatype_qint32 = 5,    ///< Quantized signed 32-bit integer; quantization parameters shared per Value.
    xnn_datatype_qcint8 = 6,    ///< Quantized signed 8-bit integer; quantization parameters shared per channel.
    xnn_datatype_qcint32 = 7,   ///< Quantized signed 32-bit integer; quantization parameters shared per channel.
    xnn_datatype_qcint4 = 8,    ///< Quantized signed 4-bit integer; quantization parameters shared per channel.
    xnn_datatype_qdint8 = 9,    ///< Dynamically quantized signed 8-bit integer; per-batch quantization parameters.
    xnn_datatype_qpint8 = 10,   ///< Dynamically quantized signed 8-bit integers packed with their per-row
                                ///< quantization parameters.
    xnn_datatype_int32 = 11,    ///< 32-bit signed integers.
    xnn_datatype_qbint4 = 12,   ///< Quantized signed 4-bit integer; quantization parameters shared per channel block.
    xnn_datatype_pfp32 = 13,    ///< IEEE754 single-precision packed floating-point.
    xnn_datatype_bf16 = 14,     ///< BFloat16, i.e. the upper 16 bits of a float32.
    xnn_datatype_qduint8 = 15,  ///< Dynamically quantized unsigned 8-bit integer; per-batch quantization parameters.
  };
  254. /// Define a tensor-type Value and add it to a Subgraph.
  255. ///
  256. /// @param subgraph - a Subgraph object that will own the created Value.
  257. /// @param datatype - type of the tensor elements.
  258. /// @param num_dims - number of dimensions in the shape.
  259. /// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
  260. /// XNNPACK does not keep any pointers to this array after the function returns.
  261. /// @param data - pointer to static data used for tensor initialization. If the tensor is not statically initialized,
  262. /// this pointer must be is NULL. If non-NULL, the life-time of the static data must exceed the life-time
  263. /// of the Subgraph object, and of any Runtime objects created from the Subgraph.
  264. /// @param external_id - external ID for the Value. The ID must be within the range of reversed Value IDs specified on
  265. /// the Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be
  266. /// created for the Value.
  267. /// @param flags - binary features of the Value. Supported values are any combination of XNN_VALUE_FLAG_EXTERNAL_INPUT
  268. /// and XNN_VALUE_FLAG_EXTERNAL_OUTPUT.
  269. /// @param id_out - pointer to the variable that will be initialized with the Value ID upon successful return. If a
  270. /// valid @a external_id was provided, the variable will be initialized with the @a external_id value.
  271. enum xnn_status xnn_define_tensor_value(
  272. xnn_subgraph_t subgraph,
  273. enum xnn_datatype datatype,
  274. size_t num_dims,
  275. const size_t* dims,
  276. const void* data,
  277. uint32_t external_id,
  278. uint32_t flags,
  279. uint32_t* id_out);
  280. /// Define a quantized tensor-type Value and add it to a Subgraph.
  281. ///
  282. /// @param subgraph - a Subgraph object that will own the created Value.
  283. /// @param datatype - type of the tensor elements.
  284. /// @param zero_point - offset from zero to subtract from the quantized elements in the Value.
  285. /// @param scale - multiplication factor to convert quantized elements to real representation.
  286. /// @param num_dims - number of dimensions in the shape.
  287. /// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
  288. /// XNNPACK does not keep any pointers to this array after the function returns.
  289. /// @param data - pointer to static data used for tensor initialization. If the tensor is not statically initialized,
  290. /// this pointer must be is NULL. If non-NULL, the life-time of the static data must exceed the life-time
  291. /// of the Subgraph object, and of any Runtime objects created from the Subgraph.
  292. /// @param external_id - external ID for the Value. The ID must be within the range of reversed Value IDs specified on
  293. /// the Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be
  294. /// created for the Value.
  295. /// @param flags - binary features of the Value. Supported values are any combination of XNN_VALUE_FLAG_EXTERNAL_INPUT
  296. /// and XNN_VALUE_FLAG_EXTERNAL_OUTPUT.
  297. /// @param id_out - pointer to the variable that will be initialized with the Value ID upon successful return. If a
  298. /// valid @a external_id was provided, the variable will be initialized with the @a external_id value.
  299. enum xnn_status xnn_define_quantized_tensor_value(
  300. xnn_subgraph_t subgraph,
  301. enum xnn_datatype datatype,
  302. int32_t zero_point,
  303. float scale,
  304. size_t num_dims,
  305. const size_t* dims,
  306. const void* data,
  307. uint32_t external_id,
  308. uint32_t flags,
  309. uint32_t* id_out);
  310. enum xnn_status xnn_define_channelwise_quantized_tensor_value(
  311. xnn_subgraph_t subgraph,
  312. enum xnn_datatype datatype,
  313. const float* scale,
  314. size_t num_dims,
  315. size_t channel_dim,
  316. const size_t* dims,
  317. const void* data,
  318. uint32_t external_id,
  319. uint32_t flags,
  320. uint32_t* id_out);
  /// Validate the datatype, zero point, scale, and dimensions of a quantized tensor.
  ///
  /// @param datatype - type of the tensor elements.
  /// @param zero_point - offset from zero to subtract from the quantized elements in the Value.
  /// @param scale - multiplication factor to convert quantized elements to real representation.
  /// @param num_dims - number of dimensions in the shape.
  /// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
  ///     XNNPACK does not keep any pointers to this array after the function returns.
  /// @return status code reporting whether the quantization parameters and shape are valid.
  enum xnn_status xnn_validate_quantized_tensor(
    enum xnn_datatype datatype,
    int32_t zero_point,
    float scale,
    size_t num_dims,
    const size_t* dims);
  /// Check that a datatype, zero point, set of per-channel scales, channel dimension, and shape together describe a
  /// valid channelwise quantized tensor.
  ///
  /// @param datatype - element type of the tensor.
  /// @param zero_point - offset subtracted from each quantized element before scaling.
  /// @param scale - array of per-channel scale factors mapping quantized elements to real values.
  /// @param num_dims - rank of the shape.
  /// @param channel_dim - which dimension carries the per-channel quantization parameters. For filter tensors this is
  ///     usually dimension 0 in Convolution, Deconvolution, and Fully Connected operators, and the last dimension in
  ///     Depthwise Convolution operators.
  /// @param dims - array of @a num_dims dimension sizes; may be NULL when @a num_dims is 0. The array is not
  ///     retained once the call returns.
  enum xnn_status xnn_validate_channelwise_quantized_tensor(
      enum xnn_datatype datatype, int32_t zero_point, const float* scale,
      size_t num_dims, size_t channel_dim, const size_t* dims);
  354. /// Define a channelwise quantized tensor-type Value and add it to a Subgraph.
  355. ///
  356. /// @param subgraph - a Subgraph object that will own the created Value.
  357. /// @param datatype - type of the tensor elements.
  358. /// @param zero_point - offset from zero to subtract from the quantized elements in the Value.
  359. /// @param scale - per-channel multiplication factors to convert quantized elements to real representation.
  360. /// @param num_dims - number of dimensions in the shape.
  361. /// @param channel_dim - index of the channel dimension in the tensor with per-channel quantization parameters.
  362. /// Typically this is the first dimension (dimension #0) of the filter tensors in the Convolution,
  363. /// Deconvolution, and Fully Connected operators and the last dimension of the filter tensors in
  364. /// the Depthwise Convolution operators.
  365. /// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
  366. /// XNNPACK does not keep any pointers to this array after the function returns.
  367. /// @param data - pointer to static data used for tensor initialization. If the tensor is not statically initialized,
  368. /// this pointer must be is NULL. If non-NULL, the life-time of the static data must exceed the life-time
  369. /// of the Subgraph object, and of any Runtime objects created from the Subgraph.
  370. /// @param external_id - external ID for the Value. The ID must be within the range of reversed Value IDs specified on
  371. /// the Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be
  372. /// created for the Value.
  373. /// @param flags - binary features of the Value. Supported values are any combination of XNN_VALUE_FLAG_EXTERNAL_INPUT
  374. /// and XNN_VALUE_FLAG_EXTERNAL_OUTPUT.
  375. /// @param id_out - pointer to the variable that will be initialized with the Value ID upon successful return. If a
  376. /// valid @a external_id was provided, the variable will be initialized with the @a external_id value.
  377. enum xnn_status xnn_define_channelwise_quantized_tensor_value_v2(
  378. xnn_subgraph_t subgraph,
  379. enum xnn_datatype datatype,
  380. int32_t zero_point,
  381. const float* scale,
  382. size_t num_dims,
  383. size_t channel_dim,
  384. const size_t* dims,
  385. const void* data,
  386. uint32_t external_id,
  387. uint32_t flags,
  388. uint32_t* id_out);
  389. /// Define a blockwise quantized tensor-type Value and add it to a Subgraph.
  390. /// @param block_size - size of a block in the tensor with blockwise quantization parameters. Block is defined as
  391. /// number of input channel element per output channel.
  392. /// For Fully connected operators with 2d filters of size [output_channels, input_channels],
  393. /// expecting number of scale values to be = output_channels * (input_channels / block_size).
  394. enum xnn_status xnn_define_blockwise_quantized_tensor_value(
  395. xnn_subgraph_t subgraph,
  396. enum xnn_datatype datatype,
  397. int32_t zero_point,
  398. const uint16_t* scale,
  399. size_t num_dims,
  400. size_t channel_dim,
  401. size_t block_size,
  402. const size_t* dims,
  403. const void* data,
  404. uint32_t external_id,
  405. uint32_t flags,
  406. uint32_t* id_out);
  407. /// Define a dynamically quantized tensor-type Value and add it to a Subgraph.
  408. ///
  409. /// @param subgraph - a Subgraph object that will own the created Value.
  410. /// @param datatype - type of the tensor elements.
  411. /// @param num_dims - number of dimensions in the shape.
  412. /// @param num_non_batch_dims - number of non-batch dimensions in the shape. The leading (num_dims - num_non_batch_dims)
  413. /// dimensions will be flattened and treated as batch size. A set of quantization parameters
  414. /// will be calculated for each batch element.
  415. /// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
  416. /// XNNPACK does not keep any pointers to this array after the function returns.
  417. /// @param external_id - external ID for the Value. The ID must be within the range of reversed Value IDs specified on
  418. /// the Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be
  419. /// created for the Value.
  420. /// @param flags - binary features of the Value. No supported flags are currently defined.
  421. /// @param id_out - pointer to the variable that will be initialized with the Value ID upon successful return. If a
  422. /// valid @a external_id was provided, the variable will be initialized with the @a external_id value.
  423. enum xnn_status xnn_define_dynamically_quantized_tensor_value(
  424. xnn_subgraph_t subgraph,
  425. enum xnn_datatype datatype,
  426. size_t num_dims,
  427. size_t num_nonbatch_dims,
  428. const size_t* dims,
  429. uint32_t external_id,
  430. uint32_t flags,
  431. uint32_t* id_out);
  /// Kind of element-wise unary operation accepted by xnn_define_unary.
  enum xnn_unary_operator {
    xnn_unary_invalid = -1,
    // Enumerators below are numbered implicitly, starting at 0.
    xnn_unary_convert,
    xnn_unary_clamp,
    xnn_unary_abs,
    xnn_unary_bankers_rounding,
    xnn_unary_ceiling,
    xnn_unary_elu,
    xnn_unary_exp,
    xnn_unary_floor,
    xnn_unary_gelu,
    xnn_unary_hardswish,
    xnn_unary_leaky_relu,
    xnn_unary_log,
    xnn_unary_negate,
    xnn_unary_sigmoid,
    xnn_unary_square,
    xnn_unary_square_root,
    xnn_unary_reciprocal_square_root,
    xnn_unary_tanh,
    // Experimental operators: these may be removed in a future release.
    xnn_unary_cube_root,
    xnn_unary_cosine,
    xnn_unary_sine,
    xnn_unary_count_leading_zeros,
    xnn_unary_bitwise_not,
    xnn_unary_popcount,
    xnn_unary_sign,
  };
  /// Operator-specific parameters passed to xnn_define_unary; only the member relevant to the chosen operator is read.
  union xnn_unary_params {
    /// Clamping parameters: output values are clipped to [min, max].
    struct {
      float min;  ///< lower clipping bound for output values.
      float max;  ///< upper clipping bound for output values.
    } clamp;
    /// ELU parameters.
    struct {
      float alpha;  ///< multiplier applied to negative input elements.
    } elu;
    /// Leaky ReLU parameters.
    struct {
      float negative_slope;  ///< multiplier applied to negative input elements.
    } leaky_relu;
  };
  479. /// Define a unary operator Node and add it to a Subgraph.
  480. ///
  481. /// @param subgraph - a Subgraph object that will own the created Node.
  482. /// @param operator - type of unary operator to define.
  483. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  484. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  485. /// shape must match the shape of the input tensor.
  486. /// @param params - parameters to be interpreted by the specific operator type.
  487. /// @param flags - binary features of the Node. No supported flags are currently defined.
  488. enum xnn_status xnn_define_unary(
  489. xnn_subgraph_t subgraph,
  490. enum xnn_unary_operator type,
  491. const union xnn_unary_params* params,
  492. uint32_t input_id,
  493. uint32_t output_id,
  494. uint32_t flags);
  495. /// Define a Convert Node and add it to a Subgraph.
  496. ///
  497. /// @param subgraph - a Subgraph object that will own the created Node.
  498. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  499. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  500. /// shape must match the shape of the input tensor.
  501. /// @param flags - binary features of the Convert Node. No supported flags are currently defined.
  502. XNN_DEPRECATED enum xnn_status xnn_define_convert(
  503. xnn_subgraph_t subgraph,
  504. uint32_t input_id,
  505. uint32_t output_id,
  506. uint32_t flags);
  507. /// Define a 2D Convolution Node and add it to a Subgraph.
  508. ///
  509. /// @param subgraph - a Subgraph object that will own the created Node.
  510. /// @param input_padding_top - implicit zero-padding above 2D input data. Must be 0 if XNN_FLAG_TENSORFLOW_SAME_PADDING
  511. /// flag is specified.
  512. /// @param input_padding_right - implicit zero-padding to the right of 2D input data. Must be 0 if
  513. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  514. /// @param input_padding_bottom - implicit zero-padding below 2D input data. Must be 0 if
  515. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  516. /// @param input_padding_left - implicit zero-padding to the left of 2D input data. Must be 0 if
  517. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  518. /// @param kernel_height - kernel (filter) height.
  519. /// @param kernel_width - kernel (filter) width.
  520. /// @param subsampling_height - height of subsampling region for convolution output (convolution height stride).
  521. /// @param subsampling_width - width of subsampling region for convolution output (convolution width stride).
  522. /// @param dilation_height - dilation of kernel elements along the height dimension.
  523. /// @param dilation_width - dilation of kernel elements along the width dimension.
  524. /// @param groups - number of convolution groups.
  525. /// @param group_input_channels - number of input channels per group.
  526. /// @param group_output_channels - number of output channels per group.
  527. /// @param output_min - lower bound for clipping output values.
  528. /// @param output_max - upper bound for clipping output values.
  529. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  530. /// with [N, IH, IW, groups * group_input_channels] dimensions
  531. /// @param filter_id - Value ID for the filter tensor. The filter tensor must ge a 4D tensor defined in the @a subgraph
  532. /// with [groups * group_output_channels, kernel_height, kernel_width, group_input_channels]
  533. /// dimensions.
  534. /// @param bias_id - Value ID for the bias tensor, or XNN_INVALID_VALUE_ID for a 2D Convolution Node without a bias. If
  535. /// present, the bias tensor must be a 1D tensor defined in the @a subgraph with [groups *
  536. /// group_output_channels] dimensions.
  537. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  538. /// with [N, OH, OW, groups * group_output_channels] dimensions.
  539. /// @param flags - binary features of the 2D Convolution Node. The only currently supported values is
  540. /// XNN_FLAG_TENSORFLOW_SAME_PADDING.
  541. enum xnn_status xnn_define_convolution_2d(
  542. xnn_subgraph_t subgraph,
  543. uint32_t input_padding_top,
  544. uint32_t input_padding_right,
  545. uint32_t input_padding_bottom,
  546. uint32_t input_padding_left,
  547. uint32_t kernel_height,
  548. uint32_t kernel_width,
  549. uint32_t subsampling_height,
  550. uint32_t subsampling_width,
  551. uint32_t dilation_height,
  552. uint32_t dilation_width,
  553. uint32_t groups,
  554. size_t group_input_channels,
  555. size_t group_output_channels,
  556. float output_min,
  557. float output_max,
  558. uint32_t input_id,
  559. uint32_t filter_id,
  560. uint32_t bias_id,
  561. uint32_t output_id,
  562. uint32_t flags);
  563. /// Define a 2D Deconvolution (Transposed Convolution) Node and add it to a Subgraph.
  564. ///
  565. /// @param subgraph - a Subgraph object that will own the created Node.
  566. /// @param padding_top - implicit padding above 2D output data.
  567. /// @param padding_right - implicit padding to the right of 2D output data.
  568. /// @param padding_bottom - implicit padding below 2D output data.
  569. /// @param padding_left - implicit padding to the left of 2D output data.
  570. /// @param adjustment_height - additional elements in the bottom of the 2D output data.
  571. /// @param adjustment_width - additional elements to the right of the 2D output data.
  572. /// @param kernel_height - kernel (filter) height.
  573. /// @param kernel_width - kernel (filter) width.
  574. /// @param upsampling_height - height of upsampling region for deconvolution input (deconvolution height stride).
  575. /// @param upsampling_width - width of upsampling region for deconvolution input (deconvolution width stride).
  576. /// @param dilation_height - dilation of kernel elements along the height dimension.
  577. /// @param dilation_width - dilation of kernel elements along the width dimension.
  578. /// @param groups - number of convolution groups.
  579. /// @param group_input_channels - number of input channels per group.
  580. /// @param group_output_channels - number of output channels per group.
  581. /// @param output_min - lower bound for clipping output values.
  582. /// @param output_max - upper bound for clipping output values.
  583. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  584. /// with [N, IH, IW, groups * group_input_channels] dimensions
  585. /// @param filter_id - Value ID for the filter tensor. The filter tensor must ge a 4D tensor defined in the @a subgraph
  586. /// with [groups * group_output_channels, kernel_height, kernel_width, group_input_channels]
  587. /// dimensions.
  588. /// @param bias_id - Value ID for the bias tensor, or XNN_INVALID_VALUE_ID for a 2D Convolution Node without a bias. If
  589. /// present, the bias tensor must be a 1D tensor defined in the @a subgraph with
  590. /// [groups * group_output_channels] dimensions.
  591. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  592. /// with [N, OH, OW, groups * group_output_channels] dimensions.
  593. /// @param flags - binary features of the 2D Deconvolution Node. No supported flags are currently defined.
  594. enum xnn_status xnn_define_deconvolution_2d(
  595. xnn_subgraph_t subgraph,
  596. uint32_t padding_top,
  597. uint32_t padding_right,
  598. uint32_t padding_bottom,
  599. uint32_t padding_left,
  600. uint32_t adjustment_height,
  601. uint32_t adjustment_width,
  602. uint32_t kernel_height,
  603. uint32_t kernel_width,
  604. uint32_t upsampling_height,
  605. uint32_t upsampling_width,
  606. uint32_t dilation_height,
  607. uint32_t dilation_width,
  608. uint32_t groups,
  609. size_t group_input_channels,
  610. size_t group_output_channels,
  611. float output_min,
  612. float output_max,
  613. uint32_t input_id,
  614. uint32_t filter_id,
  615. uint32_t bias_id,
  616. uint32_t output_id,
  617. uint32_t flags);
  618. /// Define a 2D Depthwise Convolution Node and add it to a Subgraph.
  619. ///
  620. /// @param subgraph - a Subgraph object that will own the created Node.
  621. /// @param input_padding_top - implicit zero-padding above 2D input data. Must be 0 if XNN_FLAG_TENSORFLOW_SAME_PADDING
  622. /// flag is specified.
  623. /// @param input_padding_right - implicit zero-padding to the right of 2D input data. Must be 0 if
  624. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  625. /// @param input_padding_bottom - implicit zero-padding below 2D input data. Must be 0 if
  626. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  627. /// @param input_padding_left - implicit zero-padding to the left of 2D input data. Must be 0 if
  628. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  629. /// @param kernel_height - kernel (filter) height.
  630. /// @param kernel_width - kernel (filter) width.
  631. /// @param subsampling_height - height of subsampling region for convolution output (convolution height stride).
  632. /// @param subsampling_width - width of subsampling region for convolution output (convolution width stride).
  633. /// @param dilation_height - dilation of kernel elements along the height dimension.
  634. /// @param dilation_width - dilation of kernel elements along the width dimension.
  635. /// @param depth_multiplier - ratio of output channels to input channels.
  636. /// @param input_channels - number of input channels.
  637. /// @param output_min - lower bound for clipping output values.
  638. /// @param output_max - upper bound for clipping output values.
  639. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  640. /// with [N, IH, IW, input_channels] dimensions
  641. /// @param filter_id - Value ID for the filter tensor. The filter tensor must ge a 4D tensor defined in the @a subgraph
  642. /// with [1, kernel_height, kernel_width, input_channels * depth_multiplier] dimensions.
  643. /// @param bias_id - Value ID for the bias tensor, or XNN_INVALID_VALUE_ID for a 2D Depthwise Convolution Node without
  644. /// a bias. If present, the bias tensor must be a 1D tensor defined in the @a subgraph with
  645. /// [input_channels * depth_multiplier] dimensions.
  646. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  647. /// with [N, OH, OW, input_channels * depth_multiplier] dimensions.
  648. /// @param flags - binary features of the 2D Depthwise Convolution Node. The only currently supported values is
  649. /// XNN_FLAG_TENSORFLOW_SAME_PADDING.
  650. enum xnn_status xnn_define_depthwise_convolution_2d(
  651. xnn_subgraph_t subgraph,
  652. uint32_t input_padding_top,
  653. uint32_t input_padding_right,
  654. uint32_t input_padding_bottom,
  655. uint32_t input_padding_left,
  656. uint32_t kernel_height,
  657. uint32_t kernel_width,
  658. uint32_t subsampling_height,
  659. uint32_t subsampling_width,
  660. uint32_t dilation_height,
  661. uint32_t dilation_width,
  662. uint32_t depth_multiplier,
  663. size_t input_channels,
  664. float output_min,
  665. float output_max,
  666. uint32_t input_id,
  667. uint32_t filter_id,
  668. uint32_t bias_id,
  669. uint32_t output_id,
  670. uint32_t flags);
  671. /// Define a Depth To Space Node 2D and add it to a Subgraph.
  672. ///
  673. /// The Depth To Space 2D Node rearranges data from depth into blocks of spatial data (a reverse transform to
  674. /// Space To Depth). For a given input pixel, an output square of pixels with side @a block_size is formed from values
  675. /// in the corresponding number of its channels. The output depth is therefore @a block_size x @a block_size times
  676. /// smaller than that of the input.
  677. ///
  678. /// @param subgraph - a Subgraph object that will own the created Node.
  679. /// @param block_size - the size of the spatial block.
  680. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  681. /// with [N, IH, IW, OC * block_size * block_size] dimensions.
  682. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  683. /// with [N, IH * block_size, IW * block_size, OC] dimensions.
  684. /// @param flags - binary features of the input_channels Node. No supported flags are currently defined.
  685. enum xnn_status xnn_define_depth_to_space_2d(
  686. xnn_subgraph_t subgraph,
  687. uint32_t block_size,
  688. uint32_t input_id,
  689. uint32_t output_id,
  690. uint32_t flags);
  691. enum xnn_status xnn_define_depth_to_space(
  692. xnn_subgraph_t subgraph,
  693. uint32_t input_id,
  694. uint32_t output_id,
  695. uint32_t block_size,
  696. uint32_t flags);
  697. /// Define a 1D Global Average Pooling Node and add it to a Subgraph.
  698. ///
  699. /// @param subgraph - a Subgraph object that will own the created Node.
  700. /// @param output_min - lower bound for clipping output values.
  701. /// @param output_max - upper bound for clipping output values.
  702. /// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with 2 or more dimensions
  703. /// defined in the @a subgraph. Averaging is performed across the second-innermost dimension.
  704. /// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor with 2 or more
  705. /// dimensions defined in the @a subgraph.
  706. /// @param flags - binary features of the 1D Global Average Pooling Node. The only currently supported value is
  707. /// XNN_FLAG_KEEP_DIMS.
  708. XNN_DEPRECATED enum xnn_status xnn_define_global_average_pooling_1d(
  709. xnn_subgraph_t subgraph,
  710. float output_min,
  711. float output_max,
  712. uint32_t input_id,
  713. uint32_t output_id,
  714. uint32_t flags);
  715. /// Define a 2D Global Average Pooling Node and add it to a Subgraph.
  716. ///
  717. /// @param subgraph - a Subgraph object that will own the created Node.
  718. /// @param output_min - lower bound for clipping output values.
  719. /// @param output_max - upper bound for clipping output values.
  720. /// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with 3 or more dimensions
  721. /// defined in the @a subgraph. Averaging is performed across the second- and third-innermost
  722. /// dimensions.
  723. /// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor with 3 or more
  724. /// dimensions defined in the @a subgraph.
  725. /// @param flags - binary features of the 2D Global Average Pooling Node. The only currently supported value is
  726. /// XNN_FLAG_KEEP_DIMS.
  727. XNN_DEPRECATED enum xnn_status xnn_define_global_average_pooling_2d(
  728. xnn_subgraph_t subgraph,
  729. float output_min,
  730. float output_max,
  731. uint32_t input_id,
  732. uint32_t output_id,
  733. uint32_t flags);
  734. /// Define a 1D Global Sum Pooling Node and add it to a Subgraph.
  735. ///
  736. /// @param subgraph - a Subgraph object that will own the created Node.
  737. /// @param output_min - lower bound for clipping output values.
  738. /// @param output_max - upper bound for clipping output values.
  739. /// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with 2 or more dimensions
  740. /// defined in the @a subgraph. Averaging is performed across the second-innermost dimension.
  741. /// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor with 2 or more
  742. /// dimensions defined in the @a subgraph.
  743. /// @param flags - binary features of the 1D Global Sum Pooling Node. The only currently supported value is
  744. /// XNN_FLAG_KEEP_DIMS.
  745. XNN_DEPRECATED enum xnn_status xnn_define_global_sum_pooling_1d(
  746. xnn_subgraph_t subgraph,
  747. float output_min,
  748. float output_max,
  749. uint32_t input_id,
  750. uint32_t output_id,
  751. uint32_t flags);
  752. /// Define a 2D Global Sum Pooling Node and add it to a Subgraph.
  753. ///
  754. /// @param subgraph - a Subgraph object that will own the created Node.
  755. /// @param output_min - lower bound for clipping output values.
  756. /// @param output_max - upper bound for clipping output values.
  757. /// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with 3 or more dimensions
  758. /// defined in the @a subgraph. Averaging is performed across the second- and third-innermost
  759. /// dimensions.
  760. /// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor with 3 or more
  761. /// dimensions defined in the @a subgraph.
  762. /// @param flags - binary features of the 2D Global Sum Pooling Node. The only currently supported value is
  763. /// XNN_FLAG_KEEP_DIMS.
  764. XNN_DEPRECATED enum xnn_status xnn_define_global_sum_pooling_2d(
  765. xnn_subgraph_t subgraph,
  766. float output_min,
  767. float output_max,
  768. uint32_t input_id,
  769. uint32_t output_id,
  770. uint32_t flags);
  771. /// Define a 2D Average Pooling Node and add it to a Subgraph.
  772. ///
  773. /// @param subgraph - a Subgraph object that will own the created Node.
  774. /// @param input_padding_top - implicit zero-padding above 2D input data. Must be 0 if XNN_FLAG_TENSORFLOW_SAME_PADDING
  775. /// flag is specified.
  776. /// @param input_padding_right - implicit zero-padding to the right of 2D input data. Must be 0 if
  777. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  778. /// @param input_padding_bottom - implicit zero-padding below 2D input data. Must be 0 if
  779. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  780. /// @param input_padding_left - implicit zero-padding to the left of 2D input data. Must be 0 if
  781. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  782. /// @param pooling_height - pooling (kernel) height.
  783. /// @param pooling_width - pooling (kernel) width.
  784. /// @param stride_height - displacement of the pooling window in the vertical dimension of the input pixels
  785. /// corresponding to vertically adjacent output pixels.
  786. /// @param stride_width - displacement of the pooling window in the horizontal dimension of the input pixels
  787. /// corresponding to horizontally adjacent output pixels.
  788. /// @param output_min - lower bound for clipping output values.
  789. /// @param output_max - upper bound for clipping output values.
  790. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  791. /// with [N, IH, IW, channels] dimensions.
  792. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  793. /// with [N, OH, OW, channels] dimensions.
  794. /// @param flags - binary features of the 2D Average Pooling Node. The only currently supported value is
  795. /// XNN_FLAG_TENSORFLOW_SAME_PADDING.
  796. enum xnn_status xnn_define_average_pooling_2d(
  797. xnn_subgraph_t subgraph,
  798. uint32_t input_padding_top,
  799. uint32_t input_padding_right,
  800. uint32_t input_padding_bottom,
  801. uint32_t input_padding_left,
  802. uint32_t pooling_height,
  803. uint32_t pooling_width,
  804. uint32_t stride_height,
  805. uint32_t stride_width,
  806. float output_min,
  807. float output_max,
  808. uint32_t input_id,
  809. uint32_t output_id,
  810. uint32_t flags);
  811. /// Define a Fully Connected Node and add it to a Subgraph.
  812. ///
  813. /// @param subgraph - a Subgraph object that will own the created Node.
  814. /// @param output_min - lower bound for clipping output values.
  815. /// @param output_max - upper bound for clipping output values.
  816. /// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the
  817. /// @a subgraph. If XNN_FLAG_TENSORFLOW_RESHAPE_2D is not specified, the input tensor must be at least
  818. /// 1D and its last dimension must equal input_channels (see @a filter_id). In particular, if
  819. /// input is a 2D tensor, it must have [batch_size, input_channels] dimensions.
  820. /// If XNN_FLAG_TENSORFLOW_RESHAPE_2D is specified, the number of elements in the input tensor must be
  821. /// divisible by the input_channels. The tensor will be first flattened into a 1D tensor of
  822. /// [num_input_elements] dimensions, then reshaped into a 2D tensor of
  823. /// [num_input_elements / input_channels, input_channels] dimensions where num_input_elements is the
  824. /// total number of elements in the input tensor.
  825. /// @param filter_id - Value ID for the filter tensor. The filter tensor must be a 2D tensor defined in the @a subgraph.
  826. /// If the XNN_FLAG_TRANSPOSE_WEIGHTS flag is not specified, the filter tensor must have
  827. /// [output_channels, input_channels] dimensions. If the XNN_FLAG_TRANSPOSE_WEIGHTS flag is
  828. /// specified, the filter tensor must have [input_channels, output_channels] dimensions.
  829. /// @param bias_id - Value ID for the bias tensor, or XNN_INVALID_VALUE_ID for a Fully Connected Node without a bias.
  830. /// If present, the bias tensor must be a 1D tensor defined in the @a subgraph with [output_channels]
  831. /// dimensions.
  832. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph.
  833. /// If XNN_FLAG_TENSORFLOW_RESHAPE_2D is not specified, the output tensor must have the same
  834. /// dimensionality as the input tensor, all its dimensions but the last one must match the
  835. /// corresponding dimensions of the input tensor, and the last dimension of the output tensor must
  836. /// equal output_channels (see @a filter_id). In particular, if input is a 2D tensor, output
  837. /// must be a 2D tensor of [batch_size, output_channels] dimensions.
  838. /// If XNN_FLAG_TENSORFLOW_RESHAPE_2D is specified, output must be a 2D tensor of
  839. /// [num_input_elements / input_channels, output_channels] dimensions where num_input_elements is the
  840. /// total number of elements in the input tensor.
  841. /// @param flags - binary features of the Fully Connected Node. The only currently supported values are
  842. /// XNN_FLAG_TENSORFLOW_RESHAPE_2D and XNN_FLAG_TRANSPOSE_WEIGHTS.
  843. enum xnn_status xnn_define_fully_connected(
  844. xnn_subgraph_t subgraph,
  845. float output_min,
  846. float output_max,
  847. uint32_t input_id,
  848. uint32_t filter_id,
  849. uint32_t bias_id,
  850. uint32_t output_id,
  851. uint32_t flags);
  852. /// Define a Sparse Fully Connected Node and add it to a Subgraph.
  853. ///
  854. /// This operator is experimental, and will be removed in the future.
  855. ///
  856. /// @param subgraph - a Subgraph object that will own the created Node.
  857. /// @param output_min - lower bound for clipping output values.
  858. /// @param output_max - upper bound for clipping output values.
  859. /// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the
  860. /// @a subgraph. If XNN_FLAG_TENSORFLOW_RESHAPE_2D is not specified, the input tensor must be at least
  861. /// 1D and its last dimension must equal input_channels (see @a filter_id). In particular, if
  862. /// input is a 2D tensor, it must have [batch_size, input_channels] dimensions.
  863. /// If XNN_FLAG_TENSORFLOW_RESHAPE_2D is specified, the number of elements in the input tensor must be
  864. /// divisible by the input_channels. The tensor will be first flattened into a 1D tensor of
  865. /// [num_input_elements] dimensions, then reshaped into a 2D tensor of
  866. /// [num_input_elements / input_channels, input_channels] dimensions where num_input_elements is the
  867. /// total number of elements in the input tensor.
  868. /// @param filter_id - Value ID for the filter tensor. The filter tensor must be a 2D tensor defined in the @a subgraph.
  869. /// If the XNN_FLAG_TRANSPOSE_WEIGHTS flag is not specified, the filter tensor must have
  870. /// [output_channels, input_channels] dimensions. If the XNN_FLAG_TRANSPOSE_WEIGHTS flag is
  871. /// specified, the filter tensor must have [input_channels, output_channels] dimensions.
  872. /// @param bias_id - Value ID for the bias tensor, or XNN_INVALID_VALUE_ID for a Sparse Fully Connected Node without a bias.
  873. /// If present, the bias tensor must be a 1D tensor defined in the @a subgraph with [output_channels]
  874. /// dimensions.
  875. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph.
  876. /// If XNN_FLAG_TENSORFLOW_RESHAPE_2D is not specified, the output tensor must have the same
  877. /// dimensionality as the input tensor, all its dimensions but the last one must match the
  878. /// corresponding dimensions of the input tensor, and the last dimension of the output tensor must
  879. /// equal output_channels (see @a filter_id). In particular, if input is a 2D tensor, output
  880. /// must be a 2D tensor of [batch_size, output_channels] dimensions.
  881. /// If XNN_FLAG_TENSORFLOW_RESHAPE_2D is specified, output must be a 2D tensor of
  882. /// [num_input_elements / input_channels, output_channels] dimensions where num_input_elements is the
  883. /// total number of elements in the input tensor.
  884. /// @param flags - binary features of the Sparse Fully Connected Node. The only currently supported values are
  885. /// XNN_FLAG_TENSORFLOW_RESHAPE_2D and XNN_FLAG_TRANSPOSE_WEIGHTS.
  886. enum xnn_status xnn_define_fully_connected_sparse(
  887. xnn_subgraph_t subgraph,
  888. float output_min,
  889. float output_max,
  890. uint32_t input_id,
  891. uint32_t filter_id,
  892. uint32_t bias_id,
  893. uint32_t output_id,
  894. uint32_t flags);
  895. /// Define a 2D Max Pooling Node and add it to a Subgraph.
  896. ///
  897. /// @param subgraph - a Subgraph object that will own the created Node.
  898. /// @param input_padding_top - implicit zero-padding above 2D input data. Must be 0 if XNN_FLAG_TENSORFLOW_SAME_PADDING
  899. /// flag is specified.
  900. /// @param input_padding_right - implicit zero-padding to the right of 2D input data. Must be 0 if
  901. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  902. /// @param input_padding_bottom - implicit zero-padding below 2D input data. Must be 0 if
  903. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  904. /// @param input_padding_left - implicit zero-padding to the left of 2D input data. Must be 0 if
  905. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  906. /// @param pooling_height - pooling (kernel) height.
  907. /// @param pooling_width - pooling (kernel) width.
  908. /// @param stride_height - displacement of the pooling window in the vertical dimension of the input pixels
  909. /// corresponding to vertically adjacent output pixels.
  910. /// @param stride_width - displacement of the pooling window in the horizontal dimension of the input pixels
  911. /// corresponding to horizontally adjacent output pixels.
  912. /// @param dilation_height - dilation of pooling elements along the height dimension.
  913. /// @param dilation_width - dilation of pooling elements along the width dimension.
  914. /// @param output_min - lower bound for clipping output values.
  915. /// @param output_max - upper bound for clipping output values.
  916. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  917. /// with [N, IH, IW, channels] dimensions.
  918. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  919. /// with [N, OH, OW, channels] dimensions.
  920. /// @param flags - binary features of the 2D Max Pooling Node. The only currently supported value is
  921. /// XNN_FLAG_TENSORFLOW_SAME_PADDING.
  922. enum xnn_status xnn_define_max_pooling_2d(
  923. xnn_subgraph_t subgraph,
  924. uint32_t input_padding_top,
  925. uint32_t input_padding_right,
  926. uint32_t input_padding_bottom,
  927. uint32_t input_padding_left,
  928. uint32_t pooling_height,
  929. uint32_t pooling_width,
  930. uint32_t stride_height,
  931. uint32_t stride_width,
  932. uint32_t dilation_height,
  933. uint32_t dilation_width,
  934. float output_min,
  935. float output_max,
  936. uint32_t input_id,
  937. uint32_t output_id,
  938. uint32_t flags);
  939. /// Define a 2D ArgMax Pooling Node and add it to a Subgraph.
  940. ///
  941. /// @param subgraph - a Subgraph object that will own the created Node.
  942. /// @param input_padding_top - implicit zero-padding above 2D input data.
  943. /// @param input_padding_right - implicit zero-padding to the right of 2D input data.
  944. /// @param input_padding_bottom - implicit zero-padding below 2D input data.
  945. /// @param input_padding_left - implicit zero-padding to the left of 2D input data.
  946. /// @param pooling_height - pooling (kernel) height. The vertical stride between pooling regions matches this value.
  947. /// @param pooling_width - pooling (kernel) width. The horizontal stride between pooling regions matches this value.
  948. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  949. /// with [N, IH, IW, channels] dimensions.
  950. /// @param output_value_id - Value ID for the output tensor with the maximum values in the pools. The output tensor must
  951. /// be a 4D tensor defined in the @a subgraph with [N, OH, OW, channels] dimensions.
  952. /// @param output_index_id - Value ID for the output tensor with the indexes of the maximum values in the pools. The
  953. /// output tensor must be a 4D tensor defined in the @a subgraph with [N, OH, OW, channels]
  954. /// dimensions.
  955. /// @param flags - binary features of the 2D ArgMax Pooling Node. No supported flags are currently defined.
  956. enum xnn_status xnn_define_argmax_pooling_2d(
  957. xnn_subgraph_t subgraph,
  958. uint32_t input_padding_top,
  959. uint32_t input_padding_right,
  960. uint32_t input_padding_bottom,
  961. uint32_t input_padding_left,
  962. uint32_t pooling_height,
  963. uint32_t pooling_width,
  964. uint32_t input_id,
  965. uint32_t output_value_id,
  966. uint32_t output_index_id,
  967. uint32_t flags);
  968. /// Define a 2D UnPooling Node and add it to a Subgraph.
  969. ///
  970. /// @param subgraph - a Subgraph object that will own the created Node.
  971. /// @param padding_top - implicit padding above 2D output data.
  972. /// @param padding_right - implicit padding to the right of 2D output data.
  973. /// @param padding_bottom - implicit padding below 2D output data.
  974. /// @param padding_left - implicit padding to the left of 2D output data.
  975. /// @param pooling_height - height of the pooling window.
  976. /// @param pooling_width - width of the pooling window.
  977. /// @param input_value_id - Value ID for the input tensor with the max-pooling values to invert. The input value tensor
  978. /// must be a 4D tensor defined in the @a subgraph with [N, IH, IW, channels] dimensions.
  979. /// @param input_index_id - Value ID for the input tensor with the indices of the per-pool maximum values produced by
  980. /// a 2D ArgMax Pooling Node. The input tensor must be a 4D tensor defined in the @a subgraph with
  981. /// [N, IH, IW, channels] dimensions.
  982. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  983. /// with [N, OH, OW, channels] dimensions.
  984. /// @param flags - binary features of the 2D UnPooling Node. No supported flags are currently defined.
  985. enum xnn_status xnn_define_unpooling_2d(
  986. xnn_subgraph_t subgraph,
  987. uint32_t padding_top,
  988. uint32_t padding_right,
  989. uint32_t padding_bottom,
  990. uint32_t padding_left,
  991. uint32_t pooling_height,
  992. uint32_t pooling_width,
  993. uint32_t input_value_id,
  994. uint32_t input_index_id,
  995. uint32_t output_id,
  996. uint32_t flags);
  997. enum xnn_binary_operator {  // Elementwise operator applied by a Node defined with xnn_define_binary().
  998. xnn_binary_invalid = -1,  // NOTE(review): presumably a sentinel rather than a usable operator; confirm.
  999. xnn_binary_add,
  1000. xnn_binary_subtract,
  1001. xnn_binary_multiply,
  1002. xnn_binary_divide,
  1003. xnn_binary_maximum,
  1004. xnn_binary_minimum,
  1005. xnn_binary_copysign,
  1006. xnn_binary_squared_difference,
  1007. xnn_binary_prelu,
  1008. // The following operators are experimental and may be removed.
  1009. xnn_binary_modulus,
  1010. xnn_binary_atan2,
  1011. xnn_binary_pow,
  1012. xnn_binary_bitwise_and,
  1013. xnn_binary_bitwise_or,
  1014. xnn_binary_bitwise_xor,
  1015. xnn_binary_shift_left,
  1016. xnn_binary_shift_right_logical,
  1017. xnn_binary_shift_right_arithmetic,
  1018. };
  1019. struct xnn_binary_params {  // Optional parameters passed to xnn_define_binary().
  1020. /// lower bound for clipping output values.
  1021. double output_min;
  1022. /// upper bound for clipping output values.
  1023. double output_max;
  1024. };
  1025. /// Define a 2-Input binary operator Node and add it to a Subgraph.
  1026. ///
  1027. /// @param subgraph - a Subgraph object that will own the created Node.
  1028. /// @param type - Type of operator to apply to the two inputs.
  1029. /// @param params - Optional parameters for the operator (output clipping bounds; see struct xnn_binary_params).
  1030. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1031. /// the @a subgraph with each dimension either equal to the corresponding dimension of the second
  1032. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcast along
  1033. /// that dimension.
  1034. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
  1035. /// the @a subgraph with each dimension either equal to the corresponding dimension of the first
  1036. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcast along
  1037. /// that dimension.
  1038. /// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
  1039. /// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
  1040. /// of the two inputs.
  1041. /// @param flags - binary features of the Binary Node. No supported flags are currently defined.
  1042. enum xnn_status xnn_define_binary(
  1043. xnn_subgraph_t subgraph,
  1044. enum xnn_binary_operator type,
  1045. const struct xnn_binary_params* params,
  1046. uint32_t input1_id,
  1047. uint32_t input2_id,
  1048. uint32_t output_id,
  1049. uint32_t flags);
  1050. /// Define a 2-Input Add Node and add it to a Subgraph.
  1051. ///
  1052. /// The 2-Input Add Node computes elementwise addition of two tensor inputs with numpy broadcasting rules. Marked XNN_DEPRECATED; xnn_define_binary() with xnn_binary_add appears to be the replacement.
  1053. ///
  1054. /// @param subgraph - a Subgraph object that will own the created Node.
  1055. /// @param output_min - lower bound for clipping output values.
  1056. /// @param output_max - upper bound for clipping output values.
  1057. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1058. /// the @a subgraph with each dimension either equal to the corresponding dimension of the second
  1059. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcast along
  1060. /// that dimension.
  1061. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
  1062. /// the @a subgraph with each dimension either equal to the corresponding dimension of the first
  1063. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcast along
  1064. /// that dimension.
  1065. /// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
  1066. /// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
  1067. /// of the two inputs.
  1068. /// @param flags - binary features of the Add Node. No supported flags are currently defined.
  1069. XNN_DEPRECATED enum xnn_status xnn_define_add2(
  1070. xnn_subgraph_t subgraph,
  1071. float output_min,
  1072. float output_max,
  1073. uint32_t input1_id,
  1074. uint32_t input2_id,
  1075. uint32_t output_id,
  1076. uint32_t flags);
  1077. /// Define a 2-Input Multiply Node and add it to a Subgraph.
  1078. ///
  1079. /// The 2-Input Multiply Node computes elementwise multiplication of two tensor inputs with numpy broadcasting rules. Marked XNN_DEPRECATED; xnn_define_binary() with xnn_binary_multiply appears to be the replacement.
  1080. /// @param subgraph - a Subgraph object that will own the created Node.
  1081. /// @param output_min,output_max - lower and upper bounds for clipping output values.
  1082. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1083. /// the @a subgraph with each dimension either equal to the corresponding dimension of the second
  1084. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcast along
  1085. /// that dimension.
  1086. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
  1087. /// the @a subgraph with each dimension either equal to the corresponding dimension of the first
  1088. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcast along
  1089. /// that dimension.
  1090. /// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
  1091. /// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
  1092. /// of the two inputs.
  1093. /// @param flags - binary features of the Multiply Node. No supported flags are currently defined.
  1094. XNN_DEPRECATED enum xnn_status xnn_define_multiply2(
  1095. xnn_subgraph_t subgraph,
  1096. float output_min,
  1097. float output_max,
  1098. uint32_t input1_id,
  1099. uint32_t input2_id,
  1100. uint32_t output_id,
  1101. uint32_t flags);
  1102. // Cap operation applied to the logits (query dot key) of the scaled dot-product attention operator.
  1103. enum xnn_attention_logits_cap_type {
  1104. // No capping.
  1105. xnn_attention_logits_cap_type_none = 0,
  1106. // Cap the absolute values of logits by tanh: tanh(logits / cap) * cap
  1107. xnn_attention_logits_cap_type_tanh
  1108. };
  1109. // Params when the cap type is xnn_attention_logits_cap_type_tanh (passed as cap_params).
  1110. struct xnn_attention_logits_cap_tanh_params {
  1111. float cap;  // Cap value used in tanh(logits / cap) * cap.
  1112. };
  1113. /// Define a Scaled Dot-Product Attention Node and add it to a Subgraph.
  1114. ///
  1115. /// This operator is experimental.
  1116. ///
  1117. /// The Scaled Dot-Product Attention Node computes a multi-head or multi-query scaled dot-product attention on the query, key,
  1118. /// and value tensors.
  1119. ///
  1120. /// @param subgraph - a Subgraph object that will own the created Node.
  1121. /// @param cap_type - type of cap to be applied to the logits.
  1122. /// @param cap_params - parameters for the cap. Must be a pointer to struct xnn_attention_logits_cap_tanh_params if cap_type
  1123. /// is xnn_attention_logits_cap_type_tanh.
  1124. /// @param query_id - Value ID for the query tensor. The query tensor must be a 3+-dimensional tensor defined in the
  1125. /// @a subgraph with the dimensions as [*, H, T, C], where H/T/C are the heads/tokens/channels, and *
  1126. /// is the 0 or more dimensions treated as batch size.
  1127. /// @param key_id - Value ID for the key tensor. The key tensor must be a 2+-dimensional tensor defined in the
  1128. /// @a subgraph. It can have the same number of dimensions as the query, with the dimensions as
  1129. /// [*, H, U, C] (multi-head), or have 1 less dimension than the query, with the dimensions as
  1130. /// [*, U, C] (multi-query, number of heads omitted implies single head), where H/U/C are the
  1131. /// heads/key_value_tokens/channels, and * is the 0 or more dimensions treated as batch size. These
  1132. /// batch size dimensions must be the same as query.
  1133. /// @param value_id - Value ID for the value tensor. The value tensor must be a 2+-dimensional tensor defined in the
  1134. /// @a subgraph. It can have the same number of dimensions as the query, with the dimensions as
  1135. /// [*, H, U, D] (multi-head), or have 1 less dimension than the query, with the dimensions as
  1136. /// [*, U, D] (multi-query, number of heads omitted implies single head), where H/U/D are the
  1137. /// heads/key_value_tokens/value_channels, and * is the 0 or more dimensions treated as batch size.
  1138. /// These batch size dimensions must be the same as query and key.
  1139. /// @param scale_id - Value ID for the scale tensor. The scale tensor must be a 1D tensor defined in the @a subgraph
  1140. /// with [C] dimensions. The query tensor is multiplied with this scale tensor before the dot product
  1141. /// with the key tensor.
  1142. /// @param mask_id - Value ID for the mask tensor. The mask tensor must be a 2D tensor defined in the @a subgraph with
  1143. /// [T, U] dimensions. The mask tensor is added to the logits (query dot key).
  1144. /// @param output_id - Value ID for the output tensor. The output tensor must be a 3+-dimensional tensor defined in the
  1145. /// @a subgraph with the dimensions as [*, H, T, D], where H/T/D are the heads/tokens/value_channels,
  1146. /// and * is the 0 or more dimensions treated as batch size. These batch size dimensions must be the
  1147. /// same as query, key, and value.
  1148. /// @param flags - binary features of the Scaled Dot Product Attention Node. No supported flags are currently defined.
  1149. enum xnn_status xnn_define_scaled_dot_product_attention(
  1150. xnn_subgraph_t subgraph,
  1151. enum xnn_attention_logits_cap_type cap_type,
  1152. const void* cap_params,
  1153. uint32_t query_id,
  1154. uint32_t key_id,
  1155. uint32_t value_id,
  1156. uint32_t scale_id,
  1157. uint32_t mask_id,
  1158. uint32_t output_id,
  1159. uint32_t flags);
  1160. /// Define a Subtract Node and add it to a Subgraph.
  1161. ///
  1162. /// The Subtract Node computes elementwise subtraction of two tensor inputs with numpy broadcasting rules. Marked XNN_DEPRECATED; xnn_define_binary() with xnn_binary_subtract appears to be the replacement.
  1163. ///
  1164. /// @param subgraph - a Subgraph object that will own the created Node.
  1165. /// @param output_min - lower bound for clipping output values.
  1166. /// @param output_max - upper bound for clipping output values.
  1167. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1168. /// the @a subgraph with each dimension either equal to the corresponding dimension of the second
  1169. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcast along
  1170. /// that dimension.
  1171. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
  1172. /// the @a subgraph with each dimension either equal to the corresponding dimension of the first
  1173. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcast along
  1174. /// that dimension.
  1175. /// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
  1176. /// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
  1177. /// of the two inputs.
  1178. /// @param flags - binary features of the Subtract Node. No supported flags are currently defined.
  1179. XNN_DEPRECATED enum xnn_status xnn_define_subtract(
  1180. xnn_subgraph_t subgraph,
  1181. float output_min,
  1182. float output_max,
  1183. uint32_t input1_id,
  1184. uint32_t input2_id,
  1185. uint32_t output_id,
  1186. uint32_t flags);
  1187. /// Define a Divide Node and add it to a Subgraph.
  1188. ///
  1189. /// The Divide Node computes elementwise division of two tensor inputs with numpy broadcasting rules. Marked XNN_DEPRECATED; xnn_define_binary() with xnn_binary_divide appears to be the replacement.
  1190. ///
  1191. /// @param subgraph - a Subgraph object that will own the created Node.
  1192. /// @param output_min - lower bound for clipping output values.
  1193. /// @param output_max - upper bound for clipping output values.
  1194. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1195. /// the @a subgraph with each dimension either equal to the corresponding dimension of the second
  1196. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcast along
  1197. /// that dimension.
  1198. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
  1199. /// the @a subgraph with each dimension either equal to the corresponding dimension of the first
  1200. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcast along
  1201. /// that dimension.
  1202. /// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
  1203. /// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
  1204. /// of the two inputs.
  1205. /// @param flags - binary features of the Divide Node. No supported flags are currently defined.
  1206. XNN_DEPRECATED enum xnn_status xnn_define_divide(
  1207. xnn_subgraph_t subgraph,
  1208. float output_min,
  1209. float output_max,
  1210. uint32_t input1_id,
  1211. uint32_t input2_id,
  1212. uint32_t output_id,
  1213. uint32_t flags);
  1214. /// Define a 2-Input Maximum Node and add it to a Subgraph.
  1215. ///
  1216. /// The 2-Input Maximum Node computes elementwise maximum of two tensor inputs with numpy broadcasting rules. Marked XNN_DEPRECATED; xnn_define_binary() with xnn_binary_maximum appears to be the replacement.
  1217. ///
  1218. /// @param subgraph - a Subgraph object that will own the created Node.
  1219. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1220. /// the @a subgraph with each dimension either equal to the corresponding dimension of the second
  1221. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcast along
  1222. /// that dimension.
  1223. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
  1224. /// the @a subgraph with each dimension either equal to the corresponding dimension of the first
  1225. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcast along
  1226. /// that dimension.
  1227. /// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
  1228. /// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
  1229. /// of the two inputs.
  1230. /// @param flags - binary features of the Maximum Node. No supported flags are currently defined.
  1231. XNN_DEPRECATED enum xnn_status xnn_define_maximum2(
  1232. xnn_subgraph_t subgraph,
  1233. uint32_t input1_id,
  1234. uint32_t input2_id,
  1235. uint32_t output_id,
  1236. uint32_t flags);
  1237. /// Define a 2-Input Minimum Node and add it to a Subgraph.
  1238. ///
  1239. /// The 2-Input Minimum Node computes elementwise minimum of two tensor inputs with numpy broadcasting rules. Marked XNN_DEPRECATED; xnn_define_binary() with xnn_binary_minimum appears to be the replacement.
  1240. ///
  1241. /// @param subgraph - a Subgraph object that will own the created Node.
  1242. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1243. /// the @a subgraph with each dimension either equal to the corresponding dimension of the second
  1244. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcast along
  1245. /// that dimension.
  1246. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
  1247. /// the @a subgraph with each dimension either equal to the corresponding dimension of the first
  1248. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcast along
  1249. /// that dimension.
  1250. /// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
  1251. /// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
  1252. /// of the two inputs.
  1253. /// @param flags - binary features of the Minimum Node. No supported flags are currently defined.
  1254. XNN_DEPRECATED enum xnn_status xnn_define_minimum2(
  1255. xnn_subgraph_t subgraph,
  1256. uint32_t input1_id,
  1257. uint32_t input2_id,
  1258. uint32_t output_id,
  1259. uint32_t flags);
  1260. /// Define a Squared Difference Node and add it to a Subgraph.
  1261. ///
  1262. /// The Squared Difference Node computes elementwise squared difference of two tensor inputs with numpy broadcasting
  1263. /// rules. Marked XNN_DEPRECATED; xnn_define_binary() with xnn_binary_squared_difference appears to be the replacement.
  1264. ///
  1265. /// @param subgraph - a Subgraph object that will own the created Node.
  1266. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1267. /// the @a subgraph with each dimension either equal to the corresponding dimension of the second
  1268. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcast along
  1269. /// that dimension.
  1270. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
  1271. /// the @a subgraph with each dimension either equal to the corresponding dimension of the first
  1272. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcast along
  1273. /// that dimension.
  1274. /// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
  1275. /// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
  1276. /// of the two inputs.
  1277. /// @param flags - binary features of the Squared Difference Node. No supported flags are currently defined.
  1278. XNN_DEPRECATED enum xnn_status xnn_define_squared_difference(
  1279. xnn_subgraph_t subgraph,
  1280. uint32_t input1_id,
  1281. uint32_t input2_id,
  1282. uint32_t output_id,
  1283. uint32_t flags);
  1284. /// Define a Constant Pad Node with a static padding specification and add it to a Subgraph.
  1285. ///
  1286. /// @param subgraph - a Subgraph object that will own the created Node.
  1287. /// @param pre_paddings - number of padding elements to insert before input elements for every dimension. This array
  1288. /// must have as many elements as the number of dimensions in the input tensor.
  1289. /// @param post_paddings - number of padding elements to insert after input elements for every dimension. This array
  1290. /// must have as many elements as the number of dimensions in the input tensor.
  1291. /// @param padding_value - constant value used to initialize padding elements.
  1292. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1293. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and each
  1294. /// of its dimensions must equal pre_paddings[i] + input dimension i + post_paddings[i].
  1295. /// @param flags - binary features of the Constant Pad Node. No supported flags are currently defined.
  1296. enum xnn_status xnn_define_static_constant_pad(
  1297. xnn_subgraph_t subgraph,
  1298. const size_t* pre_paddings,
  1299. const size_t* post_paddings,
  1300. float padding_value,
  1301. uint32_t input_id,
  1302. uint32_t output_id,
  1303. uint32_t flags);
  1304. /// Define a Expand Dims Node with and add it to a Subgraph.
  1305. ///
  1306. /// @param subgraph - a Subgraph object that will own the created Node.
  1307. /// @param num_new_axes - number of new axes of size 1 to be inserted.
  1308. /// @param new_axes - The axis positions of the new axes in the expanded dimensions.
  1309. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1310. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1311. /// shape must match the shape of the input tensor with padding.
  1312. /// @param flags - binary features of the Constant Pad Node. No supported flags are currently defined.
  1313. enum xnn_status xnn_define_static_expand_dims(
  1314. xnn_subgraph_t subgraph,
  1315. size_t num_new_axes,
  1316. const size_t* new_axes,
  1317. uint32_t input_id,
  1318. uint32_t output_id,
  1319. uint32_t flags);
  1320. /// Define a Mean Node and add it to a Subgraph.
  1321. ///
  1322. /// @param subgraph - a Subgraph object that will own the created Node.
  1323. /// @param num_reduction_axes - number of axes along which mean is computed.
  1324. /// @param reduction_axes - axes along which mean is computed.
  1325. /// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with at least
  1326. /// @a num_reduction_axes dimensions defined in the @a subgraph.
  1327. /// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor defined in the
  1328. /// @a subgraph with @a num_reduction_axes fewer dimensions than the input tensor (if
  1329. /// XNN_FLAG_KEEP_DIMS is not specified), or has same dimension rank but the dimension at
  1330. /// @a reduction_axes reduced to 1 (if XNN_FLAG_KEEP_DIMS is specified).
  1331. /// @param flags - binary features of the Mean Node. The only currently supported value is XNN_FLAG_KEEP_DIMS
  1332. XNN_DEPRECATED enum xnn_status xnn_define_static_mean(
  1333. xnn_subgraph_t subgraph,
  1334. size_t num_reduction_axes,
  1335. const size_t* reduction_axes,
  1336. uint32_t input_id,
  1337. uint32_t output_id,
  1338. uint32_t flags);
  1339. enum xnn_reduce_operator {
  1340. xnn_reduce_invalid = -1,
  1341. xnn_reduce_sum,
  1342. xnn_reduce_mean,
  1343. };
  1344. /// Define a Reduce Node and add it to a Subgraph.
  1345. ///
  1346. /// @param subgraph - a Subgraph object that will own the created Node.
  1347. /// @param num_reduction_axes - number of axes along which reduce is computed.
  1348. /// @param reduction_axes - axes along which reduce is computed.
  1349. /// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with at least
  1350. /// @a num_reduction_axes dimensions defined in the @a subgraph.
  1351. /// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor defined in the
  1352. /// @a subgraph with @a num_reduction_axes fewer dimensions than the input tensor (if
  1353. /// XNN_FLAG_KEEP_DIMS is not specified), or has same dimension rank but the dimension at
  1354. /// @a reduction_axes reduced to 1 (if XNN_FLAG_KEEP_DIMS is specified).
  1355. /// @param flags - binary features of the Reduce Node. The only currently supported value is XNN_FLAG_KEEP_DIMS
  1356. enum xnn_status xnn_define_static_reduce(
  1357. xnn_subgraph_t subgraph,
  1358. enum xnn_reduce_operator reduce_operator_type,
  1359. size_t num_reduction_axes,
  1360. const size_t* reduction_axes,
  1361. uint32_t input_id,
  1362. uint32_t output_id,
  1363. uint32_t flags);
  1364. /// Define a Reduce Node and add it to a Subgraph.
  1365. ///
  1366. /// @param subgraph - a Subgraph object that will own the created Node.
  1367. /// @param num_reduction_axes - number of axes along which reduce is computed.
  1368. /// @param reduction_axes - axes along which reduce is computed. Negative values
  1369. /// are interpreted as offsets from @a
  1370. /// num_reduction_axes.
  1371. /// @param input_id - Value ID for the input tensor. The input tensor must be a
  1372. /// dense tensor with at least @a num_reduction_axes
  1373. /// dimensions defined in the @a subgraph.
  1374. /// @param output_id - Value ID for the output tensor. The output tensor must be
  1375. /// a dense tensor defined in the @a subgraph with @a
  1376. /// num_reduction_axes fewer dimensions than the input tensor
  1377. /// (if XNN_FLAG_KEEP_DIMS is not specified), or has same
  1378. /// dimension rank but the dimension at
  1379. /// @a reduction_axes reduced to 1 (if XNN_FLAG_KEEP_DIMS is
  1380. /// specified).
  1381. /// @param flags - binary features of the Reduce Node. The only currently
  1382. /// supported value is XNN_FLAG_KEEP_DIMS
  1383. enum xnn_status xnn_define_static_reduce_v2( //
  1384. xnn_subgraph_t subgraph, //
  1385. enum xnn_reduce_operator reduce_operator_type, //
  1386. size_t num_reduction_axes, //
  1387. const int64_t* reduction_axes, //
  1388. uint32_t input_id, //
  1389. uint32_t output_id, //
  1390. uint32_t flags);
  1391. /// Define a 2-Input Concatenate Node and add it to a Subgraph.
  1392. ///
  1393. /// The 2-Input Concatenate Node concatenates two tensors along a specified axis.
  1394. ///
  1395. /// @param subgraph - a Subgraph object that will own the created Node.
  1396. /// @param axis - the axis to concatenate the two input tensors along. If this is less than zero, the number of
  1397. /// dimensions is added to it.
  1398. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1399. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1400. /// second input.
  1401. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an N-dimensional tensor defined in
  1402. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1403. /// first input.
  1404. /// @param output_id - Value ID for the output tensor. The output tensor must be a N-dimensional tensor defined
  1405. /// in the @a subgraph with each dimension equal to the dimension of both inputs, except the axis
  1406. /// dimension, where it is the sum of the corresponding dimensions of both inputs.
  1407. /// @param flags - binary features of the Concatenate Node. No supported flags are currently defined.
  1408. enum xnn_status xnn_define_concatenate2(
  1409. xnn_subgraph_t subgraph,
  1410. int32_t axis,
  1411. uint32_t input1_id,
  1412. uint32_t input2_id,
  1413. uint32_t output_id,
  1414. uint32_t flags);
  1415. /// Define a 3-Input Concatenate Node and add it to a Subgraph.
  1416. ///
  1417. /// The 3-Input Concatenate Node concatenates three tensors along a specified axis.
  1418. ///
  1419. /// @param subgraph - a Subgraph object that will own the created Node.
  1420. /// @param axis - the axis to concatenate the two input tensors along. If this is less than zero, the number of
  1421. /// dimensions is added to it.
  1422. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1423. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1424. /// other inputs.
  1425. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an N-dimensional tensor defined in
  1426. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1427. /// other inputs.
  1428. /// @param input3_id - Value ID for the third input tensor. The input tensor must be an N-dimensional tensor defined in
  1429. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1430. /// other inputs.
  1431. /// @param output_id - Value ID for the output tensor. The output tensor must be a N-dimensional tensor defined
  1432. /// in the @a subgraph with each dimension equal to the dimension of all inputs, except the axis
  1433. /// dimension, where it is the sum of the corresponding dimensions of all inputs.
  1434. /// @param flags - binary features of the Concatenate Node. No supported flags are currently defined.
  1435. enum xnn_status xnn_define_concatenate3(
  1436. xnn_subgraph_t subgraph,
  1437. int32_t axis,
  1438. uint32_t input1_id,
  1439. uint32_t input2_id,
  1440. uint32_t input3_id,
  1441. uint32_t output_id,
  1442. uint32_t flags);
  1443. /// Define a 4-Input Concatenate Node and add it to a Subgraph.
  1444. ///
  1445. /// The 4-Input Concatenate Node concatenates four tensors along a specified axis.
  1446. ///
  1447. /// @param subgraph - a Subgraph object that will own the created Node.
  1448. /// @param axis - the axis to concatenate the two input tensors along. If this is less than zero, the number of
  1449. /// dimensions is added to it.
  1450. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1451. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1452. /// other inputs.
  1453. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an N-dimensional tensor defined in
  1454. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1455. /// other inputs.
  1456. /// @param input3_id - Value ID for the third input tensor. The input tensor must be an N-dimensional tensor defined in
  1457. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1458. /// other inputs.
  1459. /// @param input4_id - Value ID for the fourth input tensor. The input tensor must be an N-dimensional tensor defined in
  1460. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1461. /// other inputs.
  1462. /// @param output_id - Value ID for the output tensor. The output tensor must be a N-dimensional tensor defined
  1463. /// in the @a subgraph with each dimension equal to the dimension of all inputs, except the axis
  1464. /// dimension, where it is the sum of the corresponding dimensions of all inputs.
  1465. /// @param flags - binary features of the Concatenate Node. No supported flags are currently defined.
  1466. enum xnn_status xnn_define_concatenate4(
  1467. xnn_subgraph_t subgraph,
  1468. int32_t axis,
  1469. uint32_t input1_id,
  1470. uint32_t input2_id,
  1471. uint32_t input3_id,
  1472. uint32_t input4_id,
  1473. uint32_t output_id,
  1474. uint32_t flags);
  1475. /// Define a 5-Input Concatenate Node and add it to a Subgraph.
  1476. ///
  1477. /// The 5-Input Concatenate Node concatenates four tensors along a specified axis.
  1478. ///
  1479. /// @param subgraph - a Subgraph object that will own the created Node.
  1480. /// @param axis - the axis to concatenate the two input tensors along. If this is less than zero, the number of
  1481. /// dimensions is added to it.
  1482. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1483. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1484. /// other inputs.
  1485. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an N-dimensional tensor defined in
  1486. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1487. /// other inputs.
  1488. /// @param input3_id - Value ID for the third input tensor. The input tensor must be an N-dimensional tensor defined in
  1489. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1490. /// other inputs.
  1491. /// @param input4_id - Value ID for the fourth input tensor. The input tensor must be an N-dimensional tensor defined in
  1492. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1493. /// other inputs.
  1494. /// @param input5_id - Value ID for the fourth input tensor. The input tensor must be an N-dimensional tensor defined in
  1495. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1496. /// other inputs.
  1497. /// @param output_id - Value ID for the output tensor. The output tensor must be a N-dimensional tensor defined
  1498. /// in the @a subgraph with each dimension equal to the dimension of all inputs, except the axis
  1499. /// dimension, where it is the sum of the corresponding dimensions of all inputs.
  1500. enum xnn_status xnn_define_concatenate5(
  1501. xnn_subgraph_t subgraph,
  1502. int32_t axis,
  1503. uint32_t input1_id,
  1504. uint32_t input2_id,
  1505. uint32_t input3_id,
  1506. uint32_t input4_id,
  1507. uint32_t input5_id,
  1508. uint32_t output_id,
  1509. uint32_t flags);
  1510. /// Define a Copy Sign Node and add it to a Subgraph.
  1511. ///
  1512. /// The Copy Sign Node copies the sign of the second input to the first input.
  1513. ///
  1514. /// @param subgraph - a Subgraph object that will own the created Node.
  1515. /// @param input1_id - Value ID for the first input tensor. The input tensor must be defined in the @a subgraph.
  1516. /// @param input2_id - Value ID for the second input tensor. The input tensor must be defined in the @a subgraph.
  1517. /// @param output_id - Value ID for the output tensor.
  1518. /// @param flags - binary features of the Copy Sign Node. No supported flags are currently defined.
  1519. XNN_DEPRECATED enum xnn_status xnn_define_copysign(
  1520. xnn_subgraph_t subgraph,
  1521. uint32_t input1_id,
  1522. uint32_t input2_id,
  1523. uint32_t output_id,
  1524. uint32_t flags);
  1525. /// Define a Copy Node and add it to a Subgraph.
  1526. ///
  1527. /// The Copy Node copies an input tensor to an output tensor.
  1528. ///
  1529. /// @param subgraph - a Subgraph object that will own the created Node.
  1530. /// @param input_id - Value ID for the first input tensor. The input tensor must be defined in the @a subgraph.
  1531. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1532. /// shape must match the shape of the input tensor.
  1533. /// @param flags - binary features of the Copy Node. No supported flags are currently defined.
  1534. enum xnn_status xnn_define_copy(
  1535. xnn_subgraph_t subgraph,
  1536. uint32_t input_id,
  1537. uint32_t output_id,
  1538. uint32_t flags);
  1539. /// Define a 2-Output Split Node and add it to a Subgraph.
  1540. ///
  1541. /// The 2-Output Split Node splits an input tensor into two output tensors along a specified axis evenly.
  1542. ///
  1543. /// @param subgraph - a Subgraph object that will own the created Node.
  1544. /// @param split_dim - the dimension to split the input tensor along. If this is less than zero, the number of
  1545. /// dimensions is added to it.
  1546. /// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the @a
  1547. /// subgraph.
  1548. /// @param output1_id - Value ID for the first output tensor. The output tensor must be an N-dimensional tensor defined
  1549. /// in the @a subgraph with each dimension, except the axis, equal to the corresponding dimension
  1550. /// of the second output. The split_dim dimension is half of the input's split_dim.
  1551. /// @param output2_id - Value ID for the second output tensor. The output tensor must be an N-dimensional tensor
  1552. /// defined in the @a subgraph with each dimension, except the axis, equal to the corresponding
  1553. /// dimension of the first output. The split_dim dimension is half of the input's split_dim.
  1554. /// @param flags - binary features of the Split Node. No supported flags are currently defined.
  1555. enum xnn_status xnn_define_even_split2(
  1556. xnn_subgraph_t subgraph,
  1557. int32_t split_dim,
  1558. uint32_t input_id,
  1559. uint32_t output1_id,
  1560. uint32_t output2_id,
  1561. uint32_t flags);
  1562. /// Define a 3-Output Split Node and add it to a Subgraph.
  1563. ///
  1564. /// The 3-Output Split Node splits an input tensor into three output tensors along a specified axis evenly.
  1565. ///
  1566. /// @param subgraph - a Subgraph object that will own the created Node.
  1567. /// @param split_dim - the dimension to split the input tensor along. If this is less than zero, the number of
  1568. /// dimensions is added to it.
  1569. /// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the @a
  1570. /// subgraph.
  1571. /// @param output1_id - Value ID for the first output tensor. The output tensor must be an N-dimensional tensor defined
  1572. /// in the @a subgraph with each dimension, except the axis, equal to the corresponding dimension
  1573. /// of the second and third output. The split_dim dimension is one third of the input's split_dim.
  1574. /// @param output2_id - Value ID for the second output tensor. The output tensor must be an N-dimensional tensor
  1575. /// defined in the @a subgraph with each dimension, except the axis, equal to the corresponding
  1576. /// dimension of the first and third output. The split_dim dimension is one third of the input's
  1577. /// split_dim.
  1578. /// @param output3_id - Value ID for the third output tensor. The output tensor must be an N-dimensional tensor
  1579. /// defined in the @a subgraph with each dimension, except the axis, equal to the corresponding
  1580. /// dimension of the second and third output. The split_dim dimension is one third of the input's
  1581. /// split_dim.
  1582. /// @param flags - binary features of the Split Node. No supported flags are currently defined.
  1583. enum xnn_status xnn_define_even_split3(
  1584. xnn_subgraph_t subgraph,
  1585. int32_t split_dim,
  1586. uint32_t input_id,
  1587. uint32_t output1_id,
  1588. uint32_t output2_id,
  1589. uint32_t output3_id,
  1590. uint32_t flags);
  1591. /// Define a 4-Output Split Node and add it to a Subgraph.
  1592. ///
  1593. /// The 4-Output Split Node splits an input tensor into four output tensors along a specified axis evenly.
  1594. ///
  1595. /// @param subgraph - a Subgraph object that will own the created Node.
  1596. /// @param split_dim - the dimension to split the input tensor along. If this is less than zero, the number of
  1597. /// dimensions is added to it.
  1598. /// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the @a
  1599. /// subgraph.
  1600. /// @param output1_id - Value ID for the first output tensor. The output tensor must be an N-dimensional tensor defined
  1601. /// in the @a subgraph with each dimension, except the axis, equal to the corresponding dimension
  1602. /// of the other output tensors. The split_dim dimension is one fourth of the input's split_dim.
  1603. /// @param output2_id - Value ID for the second output tensor. The output tensor must be an N-dimensional tensor
  1604. /// defined in the @a subgraph with each dimension, except the axis, equal to the corresponding
  1605. /// dimension of the other output tensors. The split_dim dimension is one fourth of the input's
  1606. /// split_dim.
  1607. /// @param output3_id - Value ID for the third output tensor. The output tensor must be an N-dimensional tensor
  1608. /// defined in the @a subgraph with each dimension, except the axis, equal to the corresponding
  1609. /// dimension of the other output tensors. The split_dim dimension is one fourth of the input's
  1610. /// split_dim.
  1611. /// @param output4_id - Value ID for the fourth output tensor. The output tensor must be an N-dimensional tensor
  1612. /// defined in the @a subgraph with each dimension, except the axis, equal to the corresponding
  1613. /// dimension of the other output tensors. The split_dim dimension is one fourth of the input's
  1614. /// split_dim.
  1615. /// @param flags - binary features of the Split Node. No supported flags are currently defined.
  1616. enum xnn_status xnn_define_even_split4(
  1617. xnn_subgraph_t subgraph,
  1618. int32_t split_dim,
  1619. uint32_t input_id,
  1620. uint32_t output1_id,
  1621. uint32_t output2_id,
  1622. uint32_t output3_id,
  1623. uint32_t output4_id,
  1624. uint32_t flags);
  1625. /// Define a Reshape Node with static shape specification and add it to a Subgraph.
  1626. ///
  1627. /// @param subgraph - a Subgraph object that will own the created Node.
  1628. /// @param num_dims - number of shape dimensions in the output tensor.
  1629. /// @param new_shape - shape dimensions of the output tensor.
  1630. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1631. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1632. /// shape must match the shape of the input tensor with padding.
  1633. /// @param flags - binary features of the Reshape Node. No supported flags are currently defined.
  1634. enum xnn_status xnn_define_static_reshape(
  1635. xnn_subgraph_t subgraph,
  1636. size_t num_dims,
  1637. const size_t* new_shape,
  1638. uint32_t input_id,
  1639. uint32_t output_id,
  1640. uint32_t flags);
  1641. /// Define a 2D Resize Bilinear Node with static output height & width specification and add it to a Subgraph.
  1642. ///
  1643. /// @param subgraph - a Subgraph object that will own the created Node.
  1644. /// @param new_height - height dimension of the output tensor.
  1645. /// @param new_width - width dimension of the output tensor.
  1646. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  1647. /// with [N, H, W, C] dimensions.
  1648. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  1649. /// with [N, new_height, new_width, C] dimensions.
  1650. /// @param flags - binary features of the 2D Resize Bilinear Node. The only currently supported values are
  1651. /// XNN_FLAG_TENSORFLOW_LEGACY_MODE and XNN_FLAG_ALIGN_CORNERS, which are mutually exclusive.
  1652. enum xnn_status xnn_define_static_resize_bilinear_2d(
  1653. xnn_subgraph_t subgraph,
  1654. size_t new_height,
  1655. size_t new_width,
  1656. uint32_t input_id,
  1657. uint32_t output_id,
  1658. uint32_t flags);
  1659. /// Define a PReLU (Parametric ReLU) Node and add it to a Subgraph.
  1660. ///
  1661. /// @param subgraph - a Subgraph object that will own the created Node.
  1662. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  1663. /// with [N, H, W, channels] dimensions.
  1664. /// @param slope_id - Value ID for the slope tensor. The slope tensor must be a 1D tensor defined in the @a subgraph with
  1665. /// either [1] or [channels] dimensions.
  1666. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  1667. /// with [N, H, W, channels] dimensions.
  1668. /// @param flags - binary features of the PReLU Node. No supported flags are currently defined.
  1669. XNN_DEPRECATED enum xnn_status xnn_define_prelu(
  1670. xnn_subgraph_t subgraph,
  1671. uint32_t input_id,
  1672. uint32_t slope_id,
  1673. uint32_t output_id,
  1674. uint32_t flags);
  1675. /// Define a RoPE (Rotary Positional Embeddings) Node and add it to a Subgraph.
  1676. ///
  1677. /// @param subgraph - a Subgraph object that will own the created Node.
  1678. /// @param max_tokens - deprecated.
  1679. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  1680. /// with [batch, tokens, heads, channels] dimensions.
  1681. /// @param weights_id - Value ID for the weights tensor. The weights tensor must be a 2D tensor defined in the
  1682. /// @a subgraph with [max_tokens, channels] dimensions.
  1683. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  1684. /// with [batch, tokens, heads, channels] dimensions.
  1685. /// @param flags - binary features of the RoPE Node. No supported flags are currently defined.
  1686. enum xnn_status xnn_define_rope(
  1687. xnn_subgraph_t subgraph,
  1688. size_t max_sequence_size,
  1689. uint32_t input_id,
  1690. uint32_t weights_id,
  1691. uint32_t output_id,
  1692. uint32_t flags);
  1693. /// Define a Abs Node and add it to a Subgraph.
  1694. ///
  1695. /// @param subgraph - a Subgraph object that will own the created Node.
  1696. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1697. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1698. /// shape must match the shape of the input tensor.
  1699. /// @param flags - binary features of the Abs Node. No supported flags are currently defined.
  1700. XNN_DEPRECATED enum xnn_status xnn_define_abs(
  1701. xnn_subgraph_t subgraph,
  1702. uint32_t input_id,
  1703. uint32_t output_id,
  1704. uint32_t flags);
  1705. /// Define a Bankers' Rounding Node and add it to a Subgraph.
  1706. ///
  1707. /// @param subgraph - a Subgraph object that will own the created Node.
  1708. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1709. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1710. /// shape must match the shape of the input tensor.
  1711. /// @param flags - binary features of the Bankers' Rounding Node. No supported flags are currently defined.
  1712. XNN_DEPRECATED enum xnn_status xnn_define_bankers_rounding(
  1713. xnn_subgraph_t subgraph,
  1714. uint32_t input_id,
  1715. uint32_t output_id,
  1716. uint32_t flags);
  1717. /// Define a Batch Matrix Multiply Node and add it to a Subgraph.
  1718. ///
  1719. /// @param subgraph - a Subgraph object that will own the created Node.
  1720. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1721. /// the @a subgraph. It must be at least 3D. The first N-2 dimensions must match the second input
  1722. /// tensor. The last 2 dimensions are [M, K]. If XNN_FLAG_TRANSPOSE_B is not specified, the last
  1723. /// dimension must match the second last dimension of the second input tensor. If
  1724. /// XNN_FLAG_TRANSPOSE_B is specified, the last dimension must match the last dimension of the
  1725. /// second input tensor.
  1726. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an N-dimensional tensor defined
  1727. /// in the @a subgraph. It must be at least 3D. The first N-2 dimensions must match the first input
  1728. /// tensor. If XNN_FLAG_TRANSPOSE_B is not specified, the last 2 dimensions are [K, N], and the
  1729. /// second last dimension must match the last dimension of the first input tensor. If
  1730. /// XNN_FLAG_TRANSPOSE_B is specified, the last 2 dimensions are [N, K], and the last dimension must
  1731. /// match the last dimension of the first input tensor.
  1732. /// @param output_id - Value ID for the output tensor. The output tensor must be an N-dimensional tensor defined in the
  1733. /// @a subgraph. It must be at least 3D. The first N-2 dimensions must match the first and second
  1734. /// input tensors . The last 2 dimensions must be [M, N].
  1735. /// @param flags - binary features of the Batch Matrix Multiply Node. The only currently supported value is
  1736. /// XNN_FLAG_TRANSPOSE_B.
  1737. enum xnn_status xnn_define_batch_matrix_multiply(
  1738. xnn_subgraph_t subgraph,
  1739. uint32_t input1_id,
  1740. uint32_t input2_id,
  1741. uint32_t output_id,
  1742. uint32_t flags);
  1743. /// Define a Ceiling Node and add it to a Subgraph.
  1744. ///
  1745. /// @param subgraph - a Subgraph object that will own the created Node.
  1746. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1747. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1748. /// shape must match the shape of the input tensor.
  1749. /// @param flags - binary features of the Ceiling Node. No supported flags are currently defined.
  1750. XNN_DEPRECATED enum xnn_status xnn_define_ceiling(
  1751. xnn_subgraph_t subgraph,
  1752. uint32_t input_id,
  1753. uint32_t output_id,
  1754. uint32_t flags);
  1755. /// Define a Clamp Node and add it to a Subgraph.
  1756. ///
  1757. /// @param subgraph - a Subgraph object that will own the created Node.
  1758. /// @param output_min - lower bound for clipping output values.
  1759. /// @param output_max - upper bound for clipping output values.
  1760. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1761. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1762. /// shape must match the shape of the input tensor.
  1763. /// @param flags - binary features of the Clamp Node. No supported flags are currently defined.
  1764. XNN_DEPRECATED enum xnn_status xnn_define_clamp(
  1765. xnn_subgraph_t subgraph,
  1766. float output_min,
  1767. float output_max,
  1768. uint32_t input_id,
  1769. uint32_t output_id,
  1770. uint32_t flags);
  1771. /// Define an ELU (Exponential Linear Unit) Node and add it to a Subgraph.
  1772. ///
  1773. /// @param subgraph - a Subgraph object that will own the created Node.
  1774. /// @param alpha - scale factor for negative output elements.
  1775. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1776. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1777. /// shape must match the shape of the input tensor.
  1778. /// @param flags - binary features of the ELU Node. No supported flags are currently defined.
  1779. XNN_DEPRECATED enum xnn_status xnn_define_elu(
  1780. xnn_subgraph_t subgraph,
  1781. float alpha,
  1782. uint32_t input_id,
  1783. uint32_t output_id,
  1784. uint32_t flags);
  1785. /// Define an Exp Node and add it to a Subgraph.
  1786. ///
  1787. /// @param subgraph - a Subgraph object that will own the created Node.
  1788. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1789. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1790. /// shape must match the shape of the input tensor.
  1791. /// @param flags - binary features of the Exp Node. No supported flags are currently defined.
  1792. XNN_DEPRECATED enum xnn_status xnn_define_exp(
  1793. xnn_subgraph_t subgraph,
  1794. uint32_t input_id,
  1795. uint32_t output_id,
  1796. uint32_t flags);
  1797. /// Define a Floor Node and add it to a Subgraph.
  1798. ///
  1799. /// @param subgraph - a Subgraph object that will own the created Node.
  1800. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1801. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1802. /// shape must match the shape of the input tensor.
  1803. /// @param flags - binary features of the Floor Node. No supported flags are currently defined.
  1804. XNN_DEPRECATED enum xnn_status xnn_define_floor(
  1805. xnn_subgraph_t subgraph,
  1806. uint32_t input_id,
  1807. uint32_t output_id,
  1808. uint32_t flags);
  1809. /// Define a GELU (Gaussian Error Linear Unit) Node and add it to a Subgraph.
  1810. ///
  1811. /// @param subgraph - a Subgraph object that will own the created Node.
  1812. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1813. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1814. /// shape must match the shape of the input tensor.
  1815. /// @param flags - binary features of the GELU Node. No supported flags are currently defined.
  1816. XNN_DEPRECATED enum xnn_status xnn_define_gelu(
  1817. xnn_subgraph_t subgraph,
  1818. uint32_t input_id,
  1819. uint32_t output_id,
  1820. uint32_t flags);
  1821. /// Define a HardSwish Node and add it to a Subgraph.
  1822. ///
  1823. /// @param subgraph - a Subgraph object that will own the created Node.
  1824. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1825. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1826. /// shape must match the shape of the input tensor.
  1827. /// @param flags - binary features of the HardSwish Node. No supported flags are currently defined.
  1828. XNN_DEPRECATED enum xnn_status xnn_define_hardswish(
  1829. xnn_subgraph_t subgraph,
  1830. uint32_t input_id,
  1831. uint32_t output_id,
  1832. uint32_t flags);
  1833. /// Define a Leaky ReLU Node and add it to a Subgraph.
  1834. ///
  1835. /// @param subgraph - a Subgraph object that will own the created Node.
  1836. /// @param negative_slope - scale factor for negative input elements.
  1837. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1838. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1839. /// shape must match the shape of the input tensor.
  1840. /// @param flags - binary features of the Leaky ReLU Node. No supported flags are currently defined.
  1841. XNN_DEPRECATED enum xnn_status xnn_define_leaky_relu(
  1842. xnn_subgraph_t subgraph,
  1843. float negative_slope,
  1844. uint32_t input_id,
  1845. uint32_t output_id,
  1846. uint32_t flags);
  1847. /// Define a Log Node and add it to a Subgraph.
  1848. ///
  1849. /// @param subgraph - a Subgraph object that will own the created Node.
  1850. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1851. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1852. /// shape must match the shape of the input tensor.
  1853. /// @param flags - binary features of the Log Node. No supported flags are currently defined.
  1854. XNN_DEPRECATED enum xnn_status xnn_define_log(
  1855. xnn_subgraph_t subgraph,
  1856. uint32_t input_id,
  1857. uint32_t output_id,
  1858. uint32_t flags);
  1859. /// Define a Negate Node and add it to a Subgraph.
  1860. ///
  1861. /// @param subgraph - a Subgraph object that will own the created Node.
  1862. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1863. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1864. /// shape must match the shape of the input tensor.
  1865. /// @param flags - binary features of the Negate Node. No supported flags are currently defined.
  1866. XNN_DEPRECATED enum xnn_status xnn_define_negate(
  1867. xnn_subgraph_t subgraph,
  1868. uint32_t input_id,
  1869. uint32_t output_id,
  1870. uint32_t flags);
  1871. /// Define a Sigmoid Node and add it to a Subgraph.
  1872. ///
  1873. /// @param subgraph - a Subgraph object that will own the created Node.
  1874. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1875. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1876. /// shape must match the shape of the input tensor.
  1877. /// @param flags - binary features of the Sigmoid Node. No supported flags are currently defined.
  1878. XNN_DEPRECATED enum xnn_status xnn_define_sigmoid(
  1879. xnn_subgraph_t subgraph,
  1880. uint32_t input_id,
  1881. uint32_t output_id,
  1882. uint32_t flags);
  1883. /// Define a SoftMax Node and add it to a Subgraph.
  1884. ///
  1885. /// @param subgraph - a Subgraph object that will own the created Node.
  1886. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph, and have at
  1887. /// least one dimension.
  1888. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1889. /// shape must match the shape of the input tensor.
  1890. /// @param flags - binary features of the SoftMax Node. No supported flags are currently defined.
  1891. enum xnn_status xnn_define_softmax(
  1892. xnn_subgraph_t subgraph,
  1893. uint32_t input_id,
  1894. uint32_t output_id,
  1895. uint32_t flags);
  1896. /// Define a Space To Depth 2D Node and add it to a Subgraph.
  1897. ///
  1898. /// The Space To Depth 2D Node rearranges blocks of spatial data into depth (the reverse transform of Depth To Space 2D).
  1899. /// For each non-overlapping square of input pixels with side @a block_size, a single output pixel is formed whose
  1900. /// channels hold the values of all pixels in that square. The output depth is therefore @a block_size x @a block_size
  1901. /// times greater than that of the input.
  1902. ///
  1903. /// @param subgraph - a Subgraph object that will own the created Node.
  1904. /// @param block_size - the side length of the square spatial block.
  1905. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  1906. /// with [N, IH * block_size, IW * block_size, OC] dimensions.
  1907. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  1908. /// with [N, IH, IW, OC * block_size * block_size] dimensions.
  1909. /// @param flags - binary features of the Space To Depth 2D Node. No supported flags are currently defined.
  1910. enum xnn_status xnn_define_space_to_depth_2d(
  1911. xnn_subgraph_t subgraph,
  1912. uint32_t block_size,
  1913. uint32_t input_id,
  1914. uint32_t output_id,
  1915. uint32_t flags);
  1916. /// Define a Square Node and add it to a Subgraph.
  1917. ///
  1918. /// @param subgraph - a Subgraph object that will own the created Node.
  1919. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1920. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1921. /// shape must match the shape of the input tensor.
  1922. /// @param flags - binary features of the Square Node. No supported flags are currently defined.
  1923. XNN_DEPRECATED enum xnn_status xnn_define_square(
  1924. xnn_subgraph_t subgraph,
  1925. uint32_t input_id,
  1926. uint32_t output_id,
  1927. uint32_t flags);
  1928. /// Define a Square Root Node and add it to a Subgraph.
  1929. ///
  1930. /// @param subgraph - a Subgraph object that will own the created Node.
  1931. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1932. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1933. /// shape must match the shape of the input tensor.
  1934. /// @param flags - binary features of the Square Root Node. No supported flags are currently defined.
  1935. XNN_DEPRECATED enum xnn_status xnn_define_square_root(
  1936. xnn_subgraph_t subgraph,
  1937. uint32_t input_id,
  1938. uint32_t output_id,
  1939. uint32_t flags);
  1940. /// Define a Reciprocal Square Root Node and add it to a Subgraph.
  1941. ///
  1942. /// @param subgraph - a Subgraph object that will own the created Node.
  1943. /// @param input_id - Value ID for the input tensor. The input tensor must be
  1944. /// defined in the @a subgraph.
  1945. /// @param output_id - Value ID for the output tensor. The output tensor must be
  1946. /// defined in the @a subgraph, and its shape must match the shape of the input
  1947. /// tensor.
  1948. /// @param flags - binary features of the Reciprocal Square Root Node. No
  1949. /// supported flags are currently defined.
  1950. XNN_DEPRECATED enum xnn_status xnn_define_reciprocal_square_root(
  1951. xnn_subgraph_t subgraph,
  1952. uint32_t input_id,
  1953. uint32_t output_id,
  1954. uint32_t flags);
  1955. enum xnn_status xnn_define_static_slice(  // NOTE(review): presumably same as xnn_define_static_slice_v2, but offsets are unsigned (size_t), so end-relative (negative) offsets cannot be expressed.
  1956. xnn_subgraph_t subgraph,
  1957. size_t num_dims,
  1958. const size_t* offsets,
  1959. const size_t* sizes,
  1960. uint32_t input_id,
  1961. uint32_t output_id,
  1962. uint32_t flags);
  1963. /// Define a Static Slice Node and add it to a Subgraph.
  1964. ///
  1965. /// @param subgraph - a Subgraph object that will own the created Node.
  1966. /// @param num_dims - number of shape dimensions in the input and output tensors.
  1967. /// @param offsets - offsets in each dimension of the input tensor. This array must have @a num_dims elements. An offset
  1968. /// can be negative, meaning that it is relative to the end of the dimension.
  1969. /// @param sizes - size of each dimension in the output tensor. This array must have @a num_dims elements.
  1970. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1971. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1972. /// dimensions must match @a sizes.
  1973. /// @param flags - binary features of the Static Slice Node. No supported flags are currently defined.
  1974. enum xnn_status xnn_define_static_slice_v2( //
  1975. xnn_subgraph_t subgraph, //
  1976. size_t num_dims, //
  1977. const int64_t* offsets, //
  1978. const size_t* sizes, //
  1979. uint32_t input_id, //
  1980. uint32_t output_id, //
  1981. uint32_t flags);
  1982. /// Define a Static Transpose Node and add it to a Subgraph.
  1983. ///
  1984. /// The Static Transpose Node applies a generalized transpose to the input tensor using the permutation in @a perm.
  1985. ///
  1986. /// @param subgraph - a Subgraph object that will own the created Node.
  1987. /// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in
  1988. /// the @a subgraph.
  1989. /// @param output_id - Value ID for the output tensor. The output tensor must be an N-dimensional tensor defined
  1990. /// in the @a subgraph with each dimension equal to its corresponding permuted input dimension.
  1991. /// @param num_dims - the number of permutation dimensions (N). This must be equal to the number of input dimensions.
  1992. /// @param perm - the permutation of the axes of the input tensor. The perm array must contain each value from 0 to
  1993. /// N-1 exactly once, in permuted order.
  1994. /// @param flags - binary features of the Static Transpose Node. No supported flags are currently defined.
  1995. enum xnn_status xnn_define_static_transpose(
  1996. xnn_subgraph_t subgraph,
  1997. size_t num_dims,
  1998. const size_t* perm,
  1999. uint32_t input_id,
  2000. uint32_t output_id,
  2001. uint32_t flags);
  2002. /// Define a Tanh Node and add it to a Subgraph.
  2003. ///
  2004. /// @param subgraph - a Subgraph object that will own the created Node.
  2005. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  2006. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  2007. /// shape must match the shape of the input tensor.
  2008. /// @param flags - binary features of the Tanh Node. No supported flags are currently defined.
  2009. XNN_DEPRECATED enum xnn_status xnn_define_tanh(
  2010. xnn_subgraph_t subgraph,
  2011. uint32_t input_id,
  2012. uint32_t output_id,
  2013. uint32_t flags);
  2014. /// Code cache is a cache for JIT-generated code.
  2015. typedef struct xnn_code_cache* xnn_code_cache_t;
  2016. /// Weights cache can be finalized in these ways:
  2017. enum xnn_weights_cache_finalization_kind {
  2018. /// Weights cache is finalized; no insert operations into the weights cache are allowed, even if the "inserted"
  2019. /// weights already exist in the cache. Weights cache memory will also be trimmed to page boundary and set to
  2020. /// read-only (to prevent writes).
  2021. xnn_weights_cache_finalization_kind_hard,
  2022. /// Weights cache will be finalized with some extra space at the end; this allows "inserting" into the cache only
  2023. /// if the weights are already in the cache, and errors on inserting uncached weights. There is memory overhead.
  2024. xnn_weights_cache_finalization_kind_soft,
  2025. };
  2026. /// A combination of multiple factors that uniquely locates an entry in the weights cache.
  2027. struct xnn_weights_cache_look_up_key {
  2028. /// The unique seed for each ukernel. It is guaranteed that each ukernel provides
  2029. /// a consistent and identical seed.
  2030. uint32_t seed;
  2031. /// Pointer to the original kernel.
  2032. const void* kernel;
  2033. /// Pointer to the original bias, could be NULL.
  2034. const void* bias;
  2035. };
  2036. /// A group of function pointers to manage weights cache. All functions may be
  2037. /// called from multiple threads.
  2038. struct xnn_weights_cache_provider {
  2039. /// User-specified pointer that will be passed as-is to all functions in this
  2040. /// structure.
  2041. void* context;
  2042. /// Looks up the {seed, kernel, bias} tuple of `cache_key` in the cache. If it is found,
  2043. /// returns the offset to the found entry for reuse. Otherwise, returns SIZE_MAX.
  2044. /// @param context - The user-specified pointer from xnn_weights_cache_provider structure.
  2045. /// @param cache_key - The key used to locate the weights cache entry.
  2046. size_t (*look_up)(void* context, const struct xnn_weights_cache_look_up_key* cache_key);
  2047. /// Ensures that cache has enough space for `n` bytes. Returns the address at which to
  2048. /// store the weights. Returns NULL if it fails to reserve space.
  2049. /// @param context - The user-specified pointer from xnn_weights_cache_provider structure.
  2050. /// @param n - size to be reserved.
  2051. void* (*reserve_space)(void* context, size_t n);
  2052. /// Looks up packed weights at `ptr` in the cache. If they are found, reuses them.
  2053. /// Otherwise, adds them to the cache. Returns the offset into the cache.
  2054. /// @param context - The user-specified pointer from xnn_weights_cache_provider structure.
  2055. /// @param cache_key - The key used to locate the weights cache entry.
  2056. /// @param ptr - pointer pointing to the packed weight.
  2057. /// @param size - size of the packed weight.
  2058. size_t (*look_up_or_insert)(void* context, const struct xnn_weights_cache_look_up_key* cache_key, void* ptr, size_t size);
  2059. /// Returns whether the cache is finalized.
  2060. /// @param context - The user-specified pointer from xnn_weights_cache_provider structure.
  2061. bool (*is_finalized)(void* context);
  2062. /// Returns the absolute pointer corresponding to `offset`, where the offset is returned from
  2063. /// `look_up` or `look_up_or_insert`. This function must be called after finalize.
  2064. /// @param context - The user-specified pointer from xnn_weights_cache_provider structure.
  2065. /// @param offset - offset from the start of the internal buffer.
  2066. void* (*offset_to_addr)(void* context, size_t offset);
  2067. /// Destroy a weights cache object, as well as memory used for the cache.
  2068. /// @param context - The user-specified pointer from xnn_weights_cache_provider structure.
  2069. enum xnn_status (*delete_cache)(void* context);
  2070. };
  2071. /// Weights cache is a cache for packed weights. It can be reused between runtimes.
  2072. typedef struct xnn_weights_cache_provider* xnn_weights_cache_t;
  2073. /// Create a weights cache object specifying the initial size of weights cache (in bytes).
  2074. ///
  2075. /// @param[in] size - initial capacity of the weights cache (in bytes), i.e. it can hold size bytes without growing.
  2076. /// @param[out] weights_cache_out - pointer to the variable that will be initialized to a handle to the weights cache provider
  2077. /// upon successful return. Once created, the weights cache provider can be shared between
  2078. /// different Runtime objects.
  2079. enum xnn_status xnn_create_weights_cache_with_size(size_t size, xnn_weights_cache_t* weights_cache_out);
  2080. enum xnn_status xnn_create_weights_cache(xnn_weights_cache_t* weights_cache_out);  // NOTE(review): presumably xnn_create_weights_cache_with_size with a library-chosen default initial size — confirm.
  2081. /// Finalizes the weights cache. The kind of finalization is specified by `finalization_kind`.
  2082. /// @param weights_cache - the weights cache object to finalize.
  2083. /// @param finalization_kind - the kind of finalization.
  2084. enum xnn_status xnn_finalize_weights_cache(
  2085. xnn_weights_cache_t weights_cache,
  2086. enum xnn_weights_cache_finalization_kind finalization_kind);
  2087. /// Returns whether @a cache is finalized; a wrapper around the `is_finalized` function pointer of `xnn_weights_cache_t`.
  2088. bool xnn_weights_cache_is_finalized(xnn_weights_cache_t cache);
  2089. /// Destroy a weights cache object, as well as memory used for the cache.
  2090. /// @param weights_cache - the weights cache object to destroy.
  2091. enum xnn_status xnn_delete_weights_cache(xnn_weights_cache_t weights_cache);
  2092. typedef struct xnn_workspace* xnn_workspace_t;  // Handle to a workspace holding internal tensors; can be shared between Runtime objects.
  2093. /// Create a workspace object.
  2094. /// @param workspace_out - pointer to the variable that will be initialized to a handle to the workspace object upon
  2095. /// successful return. Once created, the workspace can be shared between different Runtime
  2096. /// objects.
  2097. enum xnn_status xnn_create_workspace(xnn_workspace_t* workspace_out);
  2098. /// Destroy a workspace object, as well as memory used by the workspace. Object destruction can be deferred until all
  2099. /// Runtime objects created with this workspace are destroyed.
  2100. /// @param workspace - the workspace object to destroy.
  2101. enum xnn_status xnn_release_workspace(xnn_workspace_t workspace);
  2102. /// Runtime is a combination of an execution plan for subgraph Nodes and a memory manager for subgraph Values.
  2103. typedef struct xnn_runtime* xnn_runtime_t;
  2104. enum xnn_profile_info {  // Kinds of profiling data that can be queried with xnn_get_runtime_profiling_info.
  2105. /// Returns a size_t containing the number of operators.
  2106. xnn_profile_info_num_operators,
  2107. /// Returns a char[] containing the null-character-separated names of all operators.
  2108. xnn_profile_info_operator_name,
  2109. /// Returns a uint64_t[] with the runtimes of all operators in the same order as xnn_profile_info_operator_name.
  2110. xnn_profile_info_operator_timing,
  2111. };
  2112. /// Return profile information for all operators.
  2113. ///
  2114. /// @param runtime - a Runtime object created with @ref xnn_create_runtime, @ref xnn_create_runtime_v2,
  2115. /// @ref xnn_create_runtime_v3 or @ref xnn_create_runtime_v4.
  2116. /// @param param_name - type of profile information required.
  2117. /// @param param_value_size - the size in bytes of memory pointed to by param_value. If this is not sufficient then
  2118. /// param_value_size_ret will be set to the required size and xnn_status_out_of_memory will be
  2119. /// returned.
  2120. /// @param param_value - a pointer to the memory location where the values for the given param_name will be written.
  2121. /// @param param_value_size_ret - returns number of bytes required to write the result if param_value_size is not
  2122. /// sufficient.
  2123. enum xnn_status xnn_get_runtime_profiling_info(xnn_runtime_t runtime,
  2124. enum xnn_profile_info param_name,
  2125. size_t param_value_size,
  2126. void* param_value,
  2127. size_t* param_value_size_ret);
  2128. /// Create a Runtime object from a subgraph.
  2129. ///
  2130. /// @param subgraph - a Subgraph object with all Values and Nodes that would be handled by the runtime. No Values or
  2131. /// Nodes can be added to the runtime once it is constructed.
  2132. /// @param weights_cache - a cache for packed weights. The runtime will look up and reuse packed weights in this cache,
  2133. /// which reduces memory allocated for packed weights.
  2134. /// @param workspace - a workspace to hold internal tensors. The runtime will allocate space used for internal tensors
  2135. /// and track them using workspace. Workspace can be shared and reused across different runtimes. If
  2136. /// workspace is NULL, there will be no sharing: each runtime has its own workspace.
  2137. /// @param threadpool - the thread pool to be used for parallelization of computations in the runtime. If the thread
  2138. /// pool is NULL, the computation would run on the caller thread without parallelization.
  2139. /// @param flags - binary features of the runtime. The only currently supported values are
  2140. /// XNN_FLAG_HINT_SPARSE_INFERENCE, XNN_FLAG_HINT_FP16_INFERENCE, XNN_FLAG_FORCE_FP16_INFERENCE,
  2141. /// XNN_FLAG_YIELD_WORKERS, and XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER. If XNN_FLAG_YIELD_WORKERS is
  2142. /// specified, worker threads would be yielded to the system scheduler after processing the last operator
  2143. /// in the Runtime. If XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER is specified, convolution operators will
  2144. /// initialize indirection buffers on each inference run using temporary memory in the workspace, instead
  2145. /// of initializing persistent indirection buffers once.
  2146. /// @param runtime_out - pointer to the variable that will be initialized with a handle to the Runtime object upon
  2147. /// successful return. Once constructed, the Runtime object is independent of the Subgraph object
  2148. /// used to create it.
  2149. enum xnn_status xnn_create_runtime_v4(
  2150. xnn_subgraph_t subgraph,
  2151. xnn_weights_cache_t weights_cache,
  2152. xnn_workspace_t workspace,
  2153. pthreadpool_t threadpool,
  2154. uint32_t flags,
  2155. xnn_runtime_t* runtime_out);
  2156. enum xnn_status xnn_create_runtime_v3(  // Like xnn_create_runtime_v4 without the workspace parameter. NOTE(review): presumably equivalent to a NULL workspace — confirm.
  2157. xnn_subgraph_t subgraph,
  2158. xnn_weights_cache_t weights_cache,
  2159. pthreadpool_t threadpool,
  2160. uint32_t flags,
  2161. xnn_runtime_t* runtime_out);
  2162. enum xnn_status xnn_create_runtime_v2(  // Like xnn_create_runtime_v3 without the weights_cache parameter. NOTE(review): presumably no weights cache is used — confirm.
  2163. xnn_subgraph_t subgraph,
  2164. pthreadpool_t threadpool,
  2165. uint32_t flags,
  2166. xnn_runtime_t* runtime_out);
  2167. enum xnn_status xnn_create_runtime(  // Like xnn_create_runtime_v2 without threadpool and flags. NOTE(review): presumably single-threaded with no flags — confirm.
  2168. xnn_subgraph_t subgraph,
  2169. xnn_runtime_t* runtime_out);
  2170. struct xnn_external_value {  // Location information for one external input or output, passed to xnn_setup_runtime*.
  2171. uint32_t id;  // external Value ID
  2172. void* data;  // data buffer bound to that Value
  2173. };
  2174. /// Reshape an external value of @a runtime.
  2175. ///
  2176. /// @param external_id - external ID for the Value. The ID must be within the range of reserved Value IDs specified at
  2177. /// Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be
  2178. /// created for the Value.
  2179. /// @param num_dims - number of dimensions in the shape.
  2180. /// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
  2181. /// XNNPACK does not keep any pointers to this array after the function returns.
  2182. enum xnn_status xnn_reshape_external_value(
  2183. xnn_runtime_t runtime,
  2184. uint32_t external_id,
  2185. size_t num_dims,
  2186. const size_t* dims);
  2187. /// Get the external value shape.
  2188. ///
  2189. /// @param external_id - external ID for the Value. The ID must be within the range of reserved Value IDs specified at
  2190. /// Subgraph creation. The external ID cannot be XNN_INVALID_VALUE_ID.
  2191. /// @param num_dims - A valid pointer into which the number of dimensions in the shape will be written. The number written cannot be larger than XNN_MAX_TENSOR_DIMS.
  2192. /// @param dims - pointer to an array that receives the shape dimensions. This pointer can't be NULL. It must be large enough to hold
  2193. /// all dimensions of the Value (XNN_MAX_TENSOR_DIMS elements always suffice). XNNPACK does not keep any pointers to this array after the function returns.
  2194. enum xnn_status xnn_get_external_value_shape(
  2195. xnn_runtime_t runtime,
  2196. uint32_t external_id,
  2197. size_t* num_dims,
  2198. size_t* dims);
  2199. /// Reshape the XNNPACK runtime.
  2200. ///
  2201. /// Propagates the shapes of input tensors through the graph to determine the shapes of intermediate and output tensors.
  2202. /// Memory is allocated if required. Output tensor shapes can then be queried with xnn_get_external_value_shape.
  2203. ///
  2204. /// @param runtime - a Runtime object created with @ref xnn_create_runtime, @ref xnn_create_runtime_v2, @ref xnn_create_runtime_v3 or @ref xnn_create_runtime_v4.
  2205. enum xnn_status xnn_reshape_runtime(
  2206. xnn_runtime_t runtime);
  2207. /// Deprecated. Use xnn_reshape_runtime and xnn_setup_runtime_v2.
  2208. ///
  2209. /// Set up data pointers for external inputs and outputs in a Runtime object and
  2210. /// allocate memory.
  2211. ///
  2212. /// @param runtime - a Runtime object created with @ref xnn_create_runtime, @ref xnn_create_runtime_v2, @ref xnn_create_runtime_v3 or @ref xnn_create_runtime_v4.
  2213. /// @param num_external_values - the number of external inputs and outputs specified in this call. This number must
  2214. /// match the number of external inputs and outputs in the runtime, i.e. all external
  2215. /// inputs and outputs in the runtime must be specified in one call.
  2216. /// @param external_values - array with location information for all external inputs and outputs in the runtime.
  2217. enum xnn_status xnn_setup_runtime(
  2218. xnn_runtime_t runtime,
  2219. size_t num_external_values,
  2220. const struct xnn_external_value* external_values);
  2221. /// Set up data pointers for external inputs and outputs in a Runtime object.
  2222. /// Should be called after xnn_reshape_runtime.
  2223. ///
  2224. /// @param runtime - a Runtime object created with @ref xnn_create_runtime, @ref xnn_create_runtime_v2, @ref xnn_create_runtime_v3 or @ref xnn_create_runtime_v4.
  2225. /// @param num_external_values - the number of external inputs and outputs specified in this call. This number must
  2226. /// match the number of external inputs and outputs in the runtime, i.e. all external
  2227. /// inputs and outputs in the runtime must be specified in one call.
  2228. /// @param external_values - array with location information for all external inputs and outputs in the runtime.
  2229. enum xnn_status xnn_setup_runtime_v2(
  2230. xnn_runtime_t runtime,
  2231. size_t num_external_values,
  2232. const struct xnn_external_value* external_values);
  2233. /// Execute the forward pass for all operators in the runtime.
  2234. ///
  2235. /// @param runtime - the Runtime object with the execution plan to invoke.
  2236. enum xnn_status xnn_invoke_runtime(
  2237. xnn_runtime_t runtime);
  2238. /// Destroy a Runtime object, as well as operators and memory associated with it.
  2239. ///
  2240. /// @param runtime - the Runtime object to destroy.
  2241. enum xnn_status xnn_delete_runtime(
  2242. xnn_runtime_t runtime);
  2243. typedef struct xnn_operator* xnn_operator_t;  // Handle to an operator produced by the xnn_create_* operator functions.
  2244. enum xnn_status xnn_run_operator(  // Runs @a op using @a threadpool. NOTE(review): presumably requires prior reshape and setup calls — confirm.
  2245. xnn_operator_t op,
  2246. pthreadpool_t threadpool);
  2247. enum xnn_status xnn_delete_operator(  // NOTE(review): presumably destroys @a op and releases its resources — confirm.
  2248. xnn_operator_t op);
  2249. /// Operator API:
  2250. /// - create operator will create and populate an xnn_operator_t
  2251. /// - reshape operator will update fields in xnn_operator_t with shape/dimensions and parallelization information
  2252. /// - setup operator will update pointers to inputs and outputs
  2253. /// Each supported operator must have a create, reshape, and setup function. (Optionally a run function.)
  2254. /// Operators listed below are in alphabetical order by operator name; within each operator, we sort alphabetically by
  2255. /// data layout and type. We also group create, reshape, setup (and optionally run) functions of each operator together.
  2256. enum xnn_status xnn_create_binary_elementwise_nd(  // Create step: creates and populates *binary_op_out.
  2257. enum xnn_binary_operator type,
  2258. enum xnn_datatype datatype,
  2259. const struct xnn_quantization_params* input1_quantization,
  2260. const struct xnn_quantization_params* input2_quantization,
  2261. const struct xnn_quantization_params* output_quantization,
  2262. uint32_t flags,
  2263. xnn_operator_t* binary_op_out);
  2264. enum xnn_status xnn_reshape_binary_elementwise_nd(  // Reshape step: records both input shapes and parallelization info in binary_op.
  2265. xnn_operator_t binary_op,
  2266. size_t num_input1_dims,
  2267. const size_t* input1_shape,
  2268. size_t num_input2_dims,
  2269. const size_t* input2_shape,
  2270. pthreadpool_t threadpool);
  2271. enum xnn_status xnn_setup_binary_elementwise_nd(  // Setup step: binds the input and output data pointers to binary_op.
  2272. xnn_operator_t binary_op,
  2273. const void* input1,
  2274. const void* input2,
  2275. void* output);
  2276. enum xnn_status xnn_run_binary_elementwise_nd(  // One-shot variant taking the create, reshape and setup parameters together.
  2277. enum xnn_binary_operator type,
  2278. enum xnn_datatype datatype,
  2279. const struct xnn_quantization_params* input1_quantization,
  2280. const struct xnn_quantization_params* input2_quantization,
  2281. const struct xnn_quantization_params* output_quantization,
  2282. uint32_t flags,
  2283. size_t num_input1_dims,
  2284. const size_t* input1_shape,
  2285. size_t num_input2_dims,
  2286. const size_t* input2_shape,
  2287. const void* input1,
  2288. const void* input2,
  2289. void* output,
  2290. pthreadpool_t threadpool);
  2291. enum xnn_status xnn_create_unary_elementwise_nc(  // Create step: creates and populates *op_out.
  2292. enum xnn_unary_operator op_type,
  2293. enum xnn_datatype input_datatype,
  2294. enum xnn_datatype output_datatype,
  2295. const union xnn_unary_params* params,
  2296. const struct xnn_quantization_params* input_quantization,
  2297. const struct xnn_quantization_params* output_quantization,
  2298. uint32_t flags,
  2299. xnn_operator_t* op_out);
  2300. enum xnn_status xnn_reshape_unary_elementwise_nc(  // Reshape step: records batch/channel sizes, strides and parallelization info in op.
  2301. xnn_operator_t op,
  2302. size_t batch_size,
  2303. size_t channels,
  2304. size_t input_stride,
  2305. size_t output_stride,
  2306. pthreadpool_t threadpool);
  2307. enum xnn_status xnn_setup_unary_elementwise_nc(
  2308. xnn_operator_t op,
  2309. const void* input,
  2310. void* output);
  2311. enum xnn_status xnn_run_unary_elementwise_nc(
  2312. // create parameters
  2313. enum xnn_unary_operator op_type,
  2314. enum xnn_datatype input_datatype,
  2315. enum xnn_datatype output_datatype,
  2316. const union xnn_unary_params* params,
  2317. const struct xnn_quantization_params* input_quantization,
  2318. const struct xnn_quantization_params* output_quantization,
  2319. uint32_t flags,
  2320. // reshape parameters
  2321. size_t batch_size,
  2322. size_t channels,
  2323. size_t input_stride,
  2324. size_t output_stride,
  2325. pthreadpool_t threadpool,
  2326. // setup parameters
  2327. const void* input,
  2328. void* output);
  // 2-D argmax pooling, f32, NHWC layout. Unlike the average-pooling create
  // functions, no stride is taken. Presumably the stride equals the pooling
  // window size — TODO confirm.
  2329. enum xnn_status xnn_create_argmax_pooling2d_nhwc_f32(
  2330. uint32_t input_padding_top,
  2331. uint32_t input_padding_right,
  2332. uint32_t input_padding_bottom,
  2333. uint32_t input_padding_left,
  2334. uint32_t pooling_height,
  2335. uint32_t pooling_width,
  2336. uint32_t flags,
  2337. xnn_operator_t* argmax_pooling_op_out);
  // Sets the input geometry and pixel strides. Reports the required workspace
  // size/alignment and the output height/width through the pointer
  // out-parameters.
  2338. enum xnn_status xnn_reshape_argmax_pooling2d_nhwc_f32(
  2339. xnn_operator_t argmax_pooling_op,
  2340. size_t batch_size,
  2341. size_t input_height,
  2342. size_t input_width,
  2343. size_t channels,
  2344. size_t input_pixel_stride,
  2345. size_t output_pixel_stride,
  2346. size_t* workspace_size,
  2347. size_t* workspace_alignment,
  2348. size_t* output_height_out,
  2349. size_t* output_width_out,
  2350. pthreadpool_t threadpool);
  // Supplies the workspace and data buffers. Besides the pooled `output`, this
  // takes a uint32_t `index` buffer. That buffer presumably receives the
  // position of each maximum within its pooling window — confirm.
  2351. enum xnn_status xnn_setup_argmax_pooling2d_nhwc_f32(
  2352. xnn_operator_t argmax_pooling_op,
  2353. void* workspace,
  2354. const float* input,
  2355. float* output,
  2356. uint32_t* index);
  // 2-D average pooling, f16, NHWC layout. Output bounds are given as float,
  // while the setup data pointers are untyped (void*).
  2357. enum xnn_status xnn_create_average_pooling2d_nhwc_f16(
  2358. uint32_t input_padding_top,
  2359. uint32_t input_padding_right,
  2360. uint32_t input_padding_bottom,
  2361. uint32_t input_padding_left,
  2362. uint32_t pooling_height,
  2363. uint32_t pooling_width,
  2364. uint32_t stride_height,
  2365. uint32_t stride_width,
  2366. float output_min,
  2367. float output_max,
  2368. uint32_t flags,
  2369. xnn_operator_t* average_pooling_op_out);
  // Sets the input geometry and pixel strides. Reports the workspace
  // size/alignment and the output height/width through the pointer
  // out-parameters. Presumably the caller allocates the workspace accordingly
  // and passes it to setup.
  2370. enum xnn_status xnn_reshape_average_pooling2d_nhwc_f16(
  2371. xnn_operator_t average_pooling_op,
  2372. size_t batch_size,
  2373. size_t input_height,
  2374. size_t input_width,
  2375. size_t channels,
  2376. size_t input_pixel_stride,
  2377. size_t output_pixel_stride,
  2378. size_t* workspace_size,
  2379. size_t* workspace_alignment,
  2380. size_t* output_height_out,
  2381. size_t* output_width_out,
  2382. pthreadpool_t threadpool);
  // Supplies the workspace and the (untyped) input/output buffers.
  2383. enum xnn_status xnn_setup_average_pooling2d_nhwc_f16(
  2384. xnn_operator_t average_pooling_op,
  2385. void* workspace,
  2386. const void* input,
  2387. void* output);
  // 2-D average pooling, f32, NHWC layout, with explicit per-side input
  // padding, window size, stride and float output bounds.
  2388. enum xnn_status xnn_create_average_pooling2d_nhwc_f32(
  2389. uint32_t input_padding_top,
  2390. uint32_t input_padding_right,
  2391. uint32_t input_padding_bottom,
  2392. uint32_t input_padding_left,
  2393. uint32_t pooling_height,
  2394. uint32_t pooling_width,
  2395. uint32_t stride_height,
  2396. uint32_t stride_width,
  2397. float output_min,
  2398. float output_max,
  2399. uint32_t flags,
  2400. xnn_operator_t* average_pooling_op_out);
  // Sets the input geometry and pixel strides. Reports the workspace
  // size/alignment and the output height/width through the pointer
  // out-parameters.
  2401. enum xnn_status xnn_reshape_average_pooling2d_nhwc_f32(
  2402. xnn_operator_t average_pooling_op,
  2403. size_t batch_size,
  2404. size_t input_height,
  2405. size_t input_width,
  2406. size_t channels,
  2407. size_t input_pixel_stride,
  2408. size_t output_pixel_stride,
  2409. size_t* workspace_size,
  2410. size_t* workspace_alignment,
  2411. size_t* output_height_out,
  2412. size_t* output_width_out,
  2413. pthreadpool_t threadpool);
  // Supplies the workspace and the float input/output buffers.
  2414. enum xnn_status xnn_setup_average_pooling2d_nhwc_f32(
  2415. xnn_operator_t average_pooling_op,
  2416. void* workspace,
  2417. const float* input,
  2418. float* output);
  // 2-D average pooling on quantized uint8 data, NHWC layout. Input and output
  // each take a zero point and scale. Unlike the f16/f32 variants,
  // output_min/output_max are uint8_t values, i.e. given in the quantized
  // domain.
  2419. enum xnn_status xnn_create_average_pooling2d_nhwc_qu8(
  2420. uint32_t input_padding_top,
  2421. uint32_t input_padding_right,
  2422. uint32_t input_padding_bottom,
  2423. uint32_t input_padding_left,
  2424. uint32_t pooling_height,
  2425. uint32_t pooling_width,
  2426. uint32_t stride_height,
  2427. uint32_t stride_width,
  2428. uint8_t input_zero_point,
  2429. float input_scale,
  2430. uint8_t output_zero_point,
  2431. float output_scale,
  2432. uint8_t output_min,
  2433. uint8_t output_max,
  2434. uint32_t flags,
  2435. xnn_operator_t* average_pooling_op_out);
  // Sets the input geometry and pixel strides. Reports the workspace
  // size/alignment and the output height/width through the pointer
  // out-parameters.
  2436. enum xnn_status xnn_reshape_average_pooling2d_nhwc_qu8(
  2437. xnn_operator_t average_pooling_op,
  2438. size_t batch_size,
  2439. size_t input_height,
  2440. size_t input_width,
  2441. size_t channels,
  2442. size_t input_pixel_stride,
  2443. size_t output_pixel_stride,
  2444. size_t* workspace_size,
  2445. size_t* workspace_alignment,
  2446. size_t* output_height_out,
  2447. size_t* output_width_out,
  2448. pthreadpool_t threadpool);
  // Supplies the workspace and the uint8 input/output buffers.
  2449. enum xnn_status xnn_setup_average_pooling2d_nhwc_qu8(
  2450. xnn_operator_t average_pooling_op,
  2451. void* workspace,
  2452. const uint8_t* input,
  2453. uint8_t* output);
  // Batched matrix multiply, f16. Only `flags` are taken at creation; all
  // shapes are supplied at reshape time.
  // NOTE(review): the out-parameter is named `batch_matrix_multiply_op`,
  // without the `_out` suffix used by most create functions in this header.
  2454. enum xnn_status xnn_create_batch_matrix_multiply_nc_f16(
  2455. uint32_t flags,
  2456. xnn_operator_t* batch_matrix_multiply_op);
  // Sets the batch dimensions of A and B (a single `num_batch_dims` count
  // applies to both arrays) and the matrix sizes m, k, n. Reports the workspace
  // size/alignment through the pointer out-parameters.
  2457. enum xnn_status xnn_reshape_batch_matrix_multiply_nc_f16(
  2458. xnn_operator_t batch_matrix_multiply_op, size_t num_batch_dims,
  2459. const size_t* batch_dims_a, const size_t* batch_dims_b, size_t m, size_t k,
  2460. size_t n, size_t* workspace_size, size_t* workspace_alignment,
  2461. pthreadpool_t threadpool);
  // Supplies the workspace, both operands and the output (untyped pointers).
  2462. enum xnn_status xnn_setup_batch_matrix_multiply_nc_f16(
  2463. xnn_operator_t batch_matrix_multiply_op, void* workspace,
  2464. const void* input_a, const void* input_b, void* output);
  // Batched matrix multiply, f32, with both operands supplied at setup time.
  2465. enum xnn_status xnn_create_batch_matrix_multiply_nc_f32(
  2466. uint32_t flags, xnn_operator_t* batch_matrix_multiply_op);
  // Variant that receives operand B (`data_b`, described by `batch_size_b`,
  // `k` and `n`) at creation time. How `input_b` at setup is treated for such
  // an operator is not visible here — presumably it is ignored; confirm.
  2467. enum xnn_status xnn_create_batch_matrix_multiply_nc_f32_const_weights(
  2468. size_t batch_size_b, size_t k, size_t n, const float* data_b,
  2469. uint32_t flags, xnn_operator_t* batch_matrix_multiply_op);
  // Sets the batch dimensions of A and B and the matrix sizes m, k, n. Reports
  // the workspace size/alignment through the pointer out-parameters.
  2470. enum xnn_status xnn_reshape_batch_matrix_multiply_nc_f32(
  2471. xnn_operator_t batch_matrix_multiply_op, size_t num_batch_dims,
  2472. const size_t* batch_dims_a, const size_t* batch_dims_b, size_t m, size_t k,
  2473. size_t n, size_t* workspace_size, size_t* workspace_alignment,
  2474. pthreadpool_t threadpool);
  // Supplies the workspace, both float operands and the float output.
  2475. enum xnn_status xnn_setup_batch_matrix_multiply_nc_f32(
  2476. xnn_operator_t batch_matrix_multiply_op, void* workspace,
  2477. const float* input_a, const float* input_b, float* output);
  // Batched matrix multiply with int8 operand A, int8 operand B provided at
  // creation (`data_b` plus a `scale_b` array), and float output.
  // Presumably `scale_b` holds one scale per output column — TODO confirm.
  2478. enum xnn_status xnn_create_batch_matrix_multiply_nc_qd8_f32_qc8w(
  2479. size_t batch_size_b, size_t k, size_t n, const int8_t* data_b,
  2480. const float* scale_b, uint32_t flags,
  2481. xnn_operator_t* batch_matrix_multiply_op);
  // Unlike the f16/f32 variants, this reshape reports no workspace
  // size/alignment.
  2482. enum xnn_status xnn_reshape_batch_matrix_multiply_nc_qd8_f32_qc8w(
  2483. xnn_operator_t batch_matrix_multiply_op, size_t num_batch_dims,
  2484. const size_t* batch_dims_a, const size_t* batch_dims_b, size_t m, size_t k,
  2485. size_t n, pthreadpool_t threadpool);
  // Takes no workspace. It does take `quantization_params` for input A
  // (presumably as produced by a *_qd8 convert operator; confirm).
  2486. enum xnn_status xnn_setup_batch_matrix_multiply_nc_qd8_f32_qc8w(
  2487. xnn_operator_t batch_matrix_multiply_op, const int8_t* input_a,
  2488. const struct xnn_quantization_params* quantization_params,
  2489. float* output);
  // Channel shuffle over `groups` groups of `group_channels` channels, on
  // untyped 8-bit elements (presumably the meaning of the `x8` suffix). Strides
  // are fixed at creation time, so reshape only takes `batch_size`.
  2490. enum xnn_status xnn_create_channel_shuffle_nc_x8(
  2491. size_t groups,
  2492. size_t group_channels,
  2493. size_t input_stride,
  2494. size_t output_stride,
  2495. uint32_t flags,
  2496. xnn_operator_t* channel_shuffle_op_out);
  2497. enum xnn_status xnn_reshape_channel_shuffle_nc_x8(
  2498. xnn_operator_t channel_shuffle_op,
  2499. size_t batch_size,
  2500. pthreadpool_t threadpool);
  // Supplies the input and output buffers.
  2501. enum xnn_status xnn_setup_channel_shuffle_nc_x8(
  2502. xnn_operator_t channel_shuffle_op,
  2503. const void* input,
  2504. void* output);
  // Channel shuffle, same interface as the x8 variant but on untyped 32-bit
  // elements (presumably the meaning of the `x32` suffix). Strides are fixed
  // at creation time, so reshape only takes `batch_size`.
  2505. enum xnn_status xnn_create_channel_shuffle_nc_x32(
  2506. size_t groups,
  2507. size_t group_channels,
  2508. size_t input_stride,
  2509. size_t output_stride,
  2510. uint32_t flags,
  2511. xnn_operator_t* channel_shuffle_op_out);
  2512. enum xnn_status xnn_reshape_channel_shuffle_nc_x32(
  2513. xnn_operator_t channel_shuffle_op,
  2514. size_t batch_size,
  2515. pthreadpool_t threadpool);
  // Supplies the input and output buffers.
  2516. enum xnn_status xnn_setup_channel_shuffle_nc_x32(
  2517. xnn_operator_t channel_shuffle_op,
  2518. const void* input,
  2519. void* output);
  // N-dimensional constant padding on 8-bit elements. The fill value is passed
  // by pointer at creation time.
  2520. enum xnn_status xnn_create_constant_pad_nd_x8(
  2521. const void* padding_value,
  2522. uint32_t flags,
  2523. xnn_operator_t* constant_pad_op_out);
  // `input_shape`, `pre_padding` and `post_padding` are arrays that presumably
  // each hold `num_dims` entries (padding before/after along each dimension).
  2524. enum xnn_status xnn_reshape_constant_pad_nd_x8(
  2525. xnn_operator_t constant_pad_op,
  2526. size_t num_dims,
  2527. const size_t* input_shape,
  2528. const size_t* pre_padding,
  2529. const size_t* post_padding,
  2530. pthreadpool_t threadpool);
  // Supplies the input and output buffers.
  2531. enum xnn_status xnn_setup_constant_pad_nd_x8(
  2532. xnn_operator_t constant_pad_op,
  2533. const void* input,
  2534. void* output);
  // Single-call form. Note the argument order differs from create/reshape:
  // `flags` comes first and `padding_value` comes after `output`. The padding
  // arrays are named `*_paddings` here but `*_padding` in reshape.
  2535. enum xnn_status xnn_run_constant_pad_nd_x8(
  2536. uint32_t flags,
  2537. size_t num_dims,
  2538. const size_t* input_shape,
  2539. const size_t* pre_paddings,
  2540. const size_t* post_paddings,
  2541. const void* input,
  2542. void* output,
  2543. const void* padding_value,
  2544. pthreadpool_t threadpool);
  // N-dimensional constant padding on 16-bit elements. The fill value is
  // passed by pointer at creation time.
  2545. enum xnn_status xnn_create_constant_pad_nd_x16(
  2546. const void* padding_value,
  2547. uint32_t flags,
  2548. xnn_operator_t* constant_pad_op_out);
  // `input_shape`, `pre_padding` and `post_padding` are arrays that presumably
  // each hold `num_dims` entries.
  2549. enum xnn_status xnn_reshape_constant_pad_nd_x16(
  2550. xnn_operator_t constant_pad_op,
  2551. size_t num_dims,
  2552. const size_t* input_shape,
  2553. const size_t* pre_padding,
  2554. const size_t* post_padding,
  2555. pthreadpool_t threadpool);
  // Supplies the input and output buffers.
  2556. enum xnn_status xnn_setup_constant_pad_nd_x16(
  2557. xnn_operator_t constant_pad_op,
  2558. const void* input,
  2559. void* output);
  // Single-call form; `padding_value` comes after `output` in this signature.
  2560. enum xnn_status xnn_run_constant_pad_nd_x16(
  2561. uint32_t flags,
  2562. size_t num_dims,
  2563. const size_t* input_shape,
  2564. const size_t* pre_paddings,
  2565. const size_t* post_paddings,
  2566. const void* input,
  2567. void* output,
  2568. const void* padding_value,
  2569. pthreadpool_t threadpool);
  // N-dimensional constant padding on 32-bit elements. The fill value is
  // passed by pointer at creation time.
  2570. enum xnn_status xnn_create_constant_pad_nd_x32(
  2571. const void* padding_value,
  2572. uint32_t flags,
  2573. xnn_operator_t* constant_pad_op_out);
  // `input_shape`, `pre_padding` and `post_padding` are arrays that presumably
  // each hold `num_dims` entries.
  2574. enum xnn_status xnn_reshape_constant_pad_nd_x32(
  2575. xnn_operator_t constant_pad_op,
  2576. size_t num_dims,
  2577. const size_t* input_shape,
  2578. const size_t* pre_padding,
  2579. const size_t* post_padding,
  2580. pthreadpool_t threadpool);
  // Supplies the input and output buffers.
  2581. enum xnn_status xnn_setup_constant_pad_nd_x32(
  2582. xnn_operator_t constant_pad_op,
  2583. const void* input,
  2584. void* output);
  // Single-call form; `padding_value` comes after `output` in this signature.
  2585. enum xnn_status xnn_run_constant_pad_nd_x32(
  2586. uint32_t flags,
  2587. size_t num_dims,
  2588. const size_t* input_shape,
  2589. const size_t* pre_paddings,
  2590. const size_t* post_paddings,
  2591. const void* input,
  2592. void* output,
  2593. const void* padding_value,
  2594. pthreadpool_t threadpool);
  // Converts f16 input (untyped pointer) to int8 output, and writes
  // quantization parameters for the result. Only `flags` are taken at
  // creation.
  2595. enum xnn_status xnn_create_convert_nc_f16_qd8(
  2596. uint32_t flags,
  2597. xnn_operator_t* convert_op_out);
  // Sets `batch_size` rows of `channels` elements with separate input/output
  // strides.
  2598. enum xnn_status xnn_reshape_convert_nc_f16_qd8(
  2599. xnn_operator_t convert_op,
  2600. size_t batch_size,
  2601. size_t channels,
  2602. size_t input_stride,
  2603. size_t output_stride,
  2604. pthreadpool_t threadpool);
  // `quantization_params` is written by the operator. Presumably it holds one
  // entry per batch row, plus the required padding — TODO confirm.
  2605. // quantization_params must be padded with at least XNN_EXTRA_QUANTIZATION_PARAMS entries.
  2606. enum xnn_status xnn_setup_convert_nc_f16_qd8(
  2607. xnn_operator_t convert_op,
  2608. const void* input,
  2609. int8_t* output,
  2610. struct xnn_quantization_params* quantization_params);
  // Converts f32 input to int8 output, and writes quantization parameters for
  // the result. Only `flags` are taken at creation.
  2611. enum xnn_status xnn_create_convert_nc_f32_qd8(
  2612. uint32_t flags,
  2613. xnn_operator_t* convert_op_out);
  // Sets `batch_size` rows of `channels` elements with separate input/output
  // strides.
  2614. enum xnn_status xnn_reshape_convert_nc_f32_qd8(
  2615. xnn_operator_t convert_op,
  2616. size_t batch_size,
  2617. size_t channels,
  2618. size_t input_stride,
  2619. size_t output_stride,
  2620. pthreadpool_t threadpool);
  // `quantization_params` is written by the operator. Presumably it holds one
  // entry per batch row, plus the required padding — TODO confirm.
  2621. // quantization_params must be padded with at least XNN_EXTRA_QUANTIZATION_PARAMS entries.
  2622. enum xnn_status xnn_setup_convert_nc_f32_qd8(
  2623. xnn_operator_t convert_op,
  2624. const float* input,
  2625. int8_t* output,
  2626. struct xnn_quantization_params* quantization_params);
  // Deprecated single-call f32 -> f16 conversion (the output is untyped).
  // Possibly superseded by xnn_run_unary_elementwise_nc, which takes separate
  // input/output datatypes — TODO confirm.
  // NOTE: the argument order is channels, input_stride, output_stride,
  // batch_size. The reshape functions in this header take batch_size first,
  // so mixing the two orders up would compile silently (all size_t).
  2627. XNN_DEPRECATED enum xnn_status xnn_run_convert_nc_f32_f16(
  2628. size_t channels,
  2629. size_t input_stride,
  2630. size_t output_stride,
  2631. size_t batch_size,
  2632. const float* input,
  2633. void* output,
  2634. uint32_t flags,
  2635. pthreadpool_t threadpool);
  // 2-D convolution, f16, NCHW layout. Kernel and bias are untyped pointers.
  // Weights may be shared through `weights_cache`, and `code_cache` is also
  // accepted.
  2636. enum xnn_status xnn_create_convolution2d_nchw_f16(
  2637. uint32_t input_padding_top,
  2638. uint32_t input_padding_right,
  2639. uint32_t input_padding_bottom,
  2640. uint32_t input_padding_left,
  2641. uint32_t kernel_height,
  2642. uint32_t kernel_width,
  2643. uint32_t subsampling_height,
  2644. uint32_t subsampling_width,
  2645. uint32_t dilation_height,
  2646. uint32_t dilation_width,
  2647. uint32_t groups,
  2648. size_t group_input_channels,
  2649. size_t group_output_channels,
  2650. size_t input_channel_stride,
  2651. size_t output_channel_stride,
  2652. const void* kernel,
  2653. const void* bias,
  2654. float output_min,
  2655. float output_max,
  2656. uint32_t flags,
  2657. xnn_code_cache_t code_cache,
  2658. xnn_weights_cache_t weights_cache,
  2659. xnn_operator_t* convolution_op_out);
  // Unlike the NHWC variants, the NCHW reshape reports no workspace; it only
  // returns the output height/width.
  2660. enum xnn_status xnn_reshape_convolution2d_nchw_f16(
  2661. xnn_operator_t convolution_op,
  2662. size_t batch_size,
  2663. size_t input_height,
  2664. size_t input_width,
  2665. size_t* output_height_out,
  2666. size_t* output_width_out,
  2667. pthreadpool_t threadpool);
  // Supplies the input and output buffers (no workspace argument).
  2668. enum xnn_status xnn_setup_convolution2d_nchw_f16(
  2669. xnn_operator_t convolution_op,
  2670. const void* input,
  2671. void* output);
  // 2-D convolution, f32, NCHW layout, with float kernel/bias and float output
  // bounds.
  2672. enum xnn_status xnn_create_convolution2d_nchw_f32(
  2673. uint32_t input_padding_top,
  2674. uint32_t input_padding_right,
  2675. uint32_t input_padding_bottom,
  2676. uint32_t input_padding_left,
  2677. uint32_t kernel_height,
  2678. uint32_t kernel_width,
  2679. uint32_t subsampling_height,
  2680. uint32_t subsampling_width,
  2681. uint32_t dilation_height,
  2682. uint32_t dilation_width,
  2683. uint32_t groups,
  2684. size_t group_input_channels,
  2685. size_t group_output_channels,
  2686. size_t input_channel_stride,
  2687. size_t output_channel_stride,
  2688. const float* kernel,
  2689. const float* bias,
  2690. float output_min,
  2691. float output_max,
  2692. uint32_t flags,
  2693. xnn_code_cache_t code_cache,
  2694. xnn_weights_cache_t weights_cache,
  2695. xnn_operator_t* convolution_op_out);
  // Reports only the output height/width (no workspace, unlike NHWC).
  2696. enum xnn_status xnn_reshape_convolution2d_nchw_f32(
  2697. xnn_operator_t convolution_op,
  2698. size_t batch_size,
  2699. size_t input_height,
  2700. size_t input_width,
  2701. size_t* output_height_out,
  2702. size_t* output_width_out,
  2703. pthreadpool_t threadpool);
  // Supplies the float input and output buffers (no workspace argument).
  2704. enum xnn_status xnn_setup_convolution2d_nchw_f32(
  2705. xnn_operator_t convolution_op,
  2706. const float* input,
  2707. float* output);
  // 2-D convolution, f16, NHWC layout. Kernel and bias are untyped pointers;
  // output bounds are float.
  2708. enum xnn_status xnn_create_convolution2d_nhwc_f16(
  2709. uint32_t input_padding_top,
  2710. uint32_t input_padding_right,
  2711. uint32_t input_padding_bottom,
  2712. uint32_t input_padding_left,
  2713. uint32_t kernel_height,
  2714. uint32_t kernel_width,
  2715. uint32_t subsampling_height,
  2716. uint32_t subsampling_width,
  2717. uint32_t dilation_height,
  2718. uint32_t dilation_width,
  2719. uint32_t groups,
  2720. size_t group_input_channels,
  2721. size_t group_output_channels,
  2722. size_t input_channel_stride,
  2723. size_t output_channel_stride,
  2724. const void* kernel,
  2725. const void* bias,
  2726. float output_min,
  2727. float output_max,
  2728. uint32_t flags,
  2729. xnn_code_cache_t code_cache,
  2730. xnn_weights_cache_t weights_cache,
  2731. xnn_operator_t* convolution_op_out);
  // Reports the workspace size/alignment and the output height/width through
  // the pointer out-parameters.
  2732. enum xnn_status xnn_reshape_convolution2d_nhwc_f16(
  2733. xnn_operator_t convolution_op,
  2734. size_t batch_size,
  2735. size_t input_height,
  2736. size_t input_width,
  2737. size_t* workspace_size,
  2738. size_t* workspace_alignment,
  2739. size_t* output_height_out,
  2740. size_t* output_width_out,
  2741. pthreadpool_t threadpool);
  // Supplies the workspace and the (untyped) input/output buffers.
  2742. enum xnn_status xnn_setup_convolution2d_nhwc_f16(
  2743. xnn_operator_t convolution_op,
  2744. void* workspace,
  2745. const void* input,
  2746. void* output);
  // 2-D convolution, f32, NHWC layout, with float kernel/bias and float output
  // bounds.
  2747. enum xnn_status xnn_create_convolution2d_nhwc_f32(
  2748. uint32_t input_padding_top,
  2749. uint32_t input_padding_right,
  2750. uint32_t input_padding_bottom,
  2751. uint32_t input_padding_left,
  2752. uint32_t kernel_height,
  2753. uint32_t kernel_width,
  2754. uint32_t subsampling_height,
  2755. uint32_t subsampling_width,
  2756. uint32_t dilation_height,
  2757. uint32_t dilation_width,
  2758. uint32_t groups,
  2759. size_t group_input_channels,
  2760. size_t group_output_channels,
  2761. size_t input_channel_stride,
  2762. size_t output_channel_stride,
  2763. const float* kernel,
  2764. const float* bias,
  2765. float output_min,
  2766. float output_max,
  2767. uint32_t flags,
  2768. xnn_code_cache_t code_cache,
  2769. xnn_weights_cache_t weights_cache,
  2770. xnn_operator_t* convolution_op_out);
  // Same parameters as xnn_create_convolution2d_nhwc_f32, except that kernel
  // and bias are untyped pointers. Presumably they hold f16 weights for an f32
  // operator, as the `_f32_f16` suffix suggests — TODO confirm. No dedicated
  // reshape/setup is declared for this variant in this part of the header.
  2771. enum xnn_status xnn_create_convolution2d_nhwc_f32_f16(
  2772. uint32_t input_padding_top,
  2773. uint32_t input_padding_right,
  2774. uint32_t input_padding_bottom,
  2775. uint32_t input_padding_left,
  2776. uint32_t kernel_height,
  2777. uint32_t kernel_width,
  2778. uint32_t subsampling_height,
  2779. uint32_t subsampling_width,
  2780. uint32_t dilation_height,
  2781. uint32_t dilation_width,
  2782. uint32_t groups,
  2783. size_t group_input_channels,
  2784. size_t group_output_channels,
  2785. size_t input_channel_stride,
  2786. size_t output_channel_stride,
  2787. const void* kernel,
  2788. const void* bias,
  2789. float output_min,
  2790. float output_max,
  2791. uint32_t flags,
  2792. xnn_code_cache_t code_cache,
  2793. xnn_weights_cache_t weights_cache,
  2794. xnn_operator_t* convolution_op_out);
  2795. // Forward declare.
  2796. struct xnn_post_operation;
  // Deprecated f32 NHWC convolution create with an array of post-operations.
  // It takes `num_post_operations`/`post_operations` in place of the
  // output_min/output_max of xnn_create_convolution2d_nhwc_f32.
  // NOTE(review): deprecation is signalled only by the comment below. Unlike
  // xnn_run_convert_nc_f32_f16, no XNN_DEPRECATED attribute is applied, so
  // callers get no compiler warning.
  2797. /// Deprecated
  2798. enum xnn_status xnn_create_fused_convolution2d_nhwc_f32(
  2799. uint32_t input_padding_top,
  2800. uint32_t input_padding_right,
  2801. uint32_t input_padding_bottom,
  2802. uint32_t input_padding_left,
  2803. uint32_t kernel_height,
  2804. uint32_t kernel_width,
  2805. uint32_t subsampling_height,
  2806. uint32_t subsampling_width,
  2807. uint32_t dilation_height,
  2808. uint32_t dilation_width,
  2809. uint32_t groups,
  2810. size_t group_input_channels,
  2811. size_t group_output_channels,
  2812. size_t input_channel_stride,
  2813. size_t output_channel_stride,
  2814. const float* kernel,
  2815. const float* bias,
  2816. size_t num_post_operations,
  2817. struct xnn_post_operation* post_operations,
  2818. uint32_t flags,
  2819. xnn_code_cache_t code_cache,
  2820. xnn_weights_cache_t weights_cache,
  2821. xnn_operator_t* convolution_op_out);
  // Reshape for f32 NHWC convolution. Reports the workspace size/alignment and
  // the output height/width through the pointer out-parameters. Presumably
  // also used for operators from the `_f32_f16` and fused create functions —
  // TODO confirm.
  2822. enum xnn_status xnn_reshape_convolution2d_nhwc_f32(
  2823. xnn_operator_t convolution_op,
  2824. size_t batch_size,
  2825. size_t input_height,
  2826. size_t input_width,
  2827. size_t* workspace_size,
  2828. size_t* workspace_alignment,
  2829. size_t* output_height_out,
  2830. size_t* output_width_out,
  2831. pthreadpool_t threadpool);
  // Supplies the workspace and the float input/output buffers.
  2832. enum xnn_status xnn_setup_convolution2d_nhwc_f32(
  2833. xnn_operator_t convolution_op,
  2834. void* workspace,
  2835. const float* input,
  2836. float* output);
  // Convolution with int8 input, int8 kernel plus a `kernel_scale` array,
  // float bias and float output bounds. The f16 variant's output buffer is
  // untyped (see its setup). Presumably `kernel_scale` has one entry per output
  // channel — TODO confirm.
  2837. enum xnn_status xnn_create_convolution2d_nhwc_qd8_f16_qc8w(
  2838. uint32_t input_padding_top, uint32_t input_padding_right,
  2839. uint32_t input_padding_bottom, uint32_t input_padding_left,
  2840. uint32_t kernel_height, uint32_t kernel_width, uint32_t subsampling_height,
  2841. uint32_t subsampling_width, uint32_t dilation_height,
  2842. uint32_t dilation_width, uint32_t groups, size_t group_input_channels,
  2843. size_t group_output_channels, size_t input_channel_stride,
  2844. size_t output_channel_stride, const float* kernel_scale,
  2845. const int8_t* kernel, const float* bias, float output_min, float output_max,
  2846. uint32_t flags, xnn_code_cache_t code_cache,
  2847. xnn_weights_cache_t weights_cache, xnn_operator_t* convolution_op_out);
  // Same parameters as the qd8_f16_qc8w create; the output is float (see its
  // setup).
  2848. enum xnn_status xnn_create_convolution2d_nhwc_qd8_f32_qc8w(
  2849. uint32_t input_padding_top, uint32_t input_padding_right,
  2850. uint32_t input_padding_bottom, uint32_t input_padding_left,
  2851. uint32_t kernel_height, uint32_t kernel_width, uint32_t subsampling_height,
  2852. uint32_t subsampling_width, uint32_t dilation_height,
  2853. uint32_t dilation_width, uint32_t groups, size_t group_input_channels,
  2854. size_t group_output_channels, size_t input_channel_stride,
  2855. size_t output_channel_stride, const float* kernel_scale,
  2856. const int8_t* kernel, const float* bias, float output_min, float output_max,
  2857. uint32_t flags, xnn_code_cache_t code_cache,
  2858. xnn_weights_cache_t weights_cache, xnn_operator_t* convolution_op_out);
  // Signed 8-bit quantized convolution. The input and output zero points and
  // scales are fixed at creation; there is a single scalar `kernel_scale`, an
  // int32 bias, and int8 output bounds.
  2859. enum xnn_status xnn_create_convolution2d_nhwc_qs8(
  2860. uint32_t input_padding_top,
  2861. uint32_t input_padding_right,
  2862. uint32_t input_padding_bottom,
  2863. uint32_t input_padding_left,
  2864. uint32_t kernel_height,
  2865. uint32_t kernel_width,
  2866. uint32_t subsampling_height,
  2867. uint32_t subsampling_width,
  2868. uint32_t dilation_height,
  2869. uint32_t dilation_width,
  2870. uint32_t groups,
  2871. size_t group_input_channels,
  2872. size_t group_output_channels,
  2873. size_t input_channel_stride,
  2874. size_t output_channel_stride,
  2875. int8_t input_zero_point,
  2876. float input_scale,
  2877. float kernel_scale,
  2878. const int8_t* kernel,
  2879. const int32_t* bias,
  2880. int8_t output_zero_point,
  2881. float output_scale,
  2882. int8_t output_min,
  2883. int8_t output_max,
  2884. uint32_t flags,
  2885. xnn_code_cache_t code_cache,
  2886. xnn_weights_cache_t weights_cache,
  2887. xnn_operator_t* convolution_op_out);
  // The reshape functions below share one shape. Each reports the workspace
  // size/alignment and the output height/width through pointer
  // out-parameters.
  2888. enum xnn_status xnn_reshape_convolution2d_nhwc_qd8_f16_qc8w(
  2889. xnn_operator_t convolution_op, size_t batch_size, size_t input_height,
  2890. size_t input_width, size_t* workspace_size, size_t* workspace_alignment,
  2891. size_t* output_height_out, size_t* output_width_out,
  2892. pthreadpool_t threadpool);
  2893. enum xnn_status xnn_reshape_convolution2d_nhwc_qd8_f32_qc8w(
  2894. xnn_operator_t convolution_op, size_t batch_size, size_t input_height,
  2895. size_t input_width, size_t* workspace_size, size_t* workspace_alignment,
  2896. size_t* output_height_out, size_t* output_width_out,
  2897. pthreadpool_t threadpool);
  2898. enum xnn_status xnn_reshape_convolution2d_nhwc_qs8(
  2899. xnn_operator_t convolution_op,
  2900. size_t batch_size,
  2901. size_t input_height,
  2902. size_t input_width,
  2903. size_t* workspace_size,
  2904. size_t* workspace_alignment,
  2905. size_t* output_height_out,
  2906. size_t* output_width_out,
  2907. pthreadpool_t threadpool);
  // The qd8 setups additionally take `quantization_params` for the int8 input
  // at setup time. Presumably these come from a *_qd8 convert operator;
  // confirm.
  2908. enum xnn_status xnn_setup_convolution2d_nhwc_qd8_f16_qc8w(
  2909. xnn_operator_t convolution_op, void* workspace, const int8_t* input,
  2910. void* output,
  2911. const struct xnn_quantization_params* quantization_params);
  2912. enum xnn_status xnn_setup_convolution2d_nhwc_qd8_f32_qc8w(
  2913. xnn_operator_t convolution_op, void* workspace, const int8_t* input,
  2914. float* output,
  2915. const struct xnn_quantization_params* quantization_params);
  // The qs8 setup takes no quantization parameters; they were fixed at
  // creation.
  2916. enum xnn_status xnn_setup_convolution2d_nhwc_qs8(
  2917. xnn_operator_t convolution_op,
  2918. void* workspace,
  2919. const int8_t* input,
  2920. int8_t* output);
  // Signed 8-bit quantized convolution, like the qs8 variant, except that
  // `kernel_scale` is an array rather than a scalar. Presumably it has one
  // scale per output channel — TODO confirm.
  2921. enum xnn_status xnn_create_convolution2d_nhwc_qs8_qc8w(
  2922. uint32_t input_padding_top,
  2923. uint32_t input_padding_right,
  2924. uint32_t input_padding_bottom,
  2925. uint32_t input_padding_left,
  2926. uint32_t kernel_height,
  2927. uint32_t kernel_width,
  2928. uint32_t subsampling_height,
  2929. uint32_t subsampling_width,
  2930. uint32_t dilation_height,
  2931. uint32_t dilation_width,
  2932. uint32_t groups,
  2933. size_t group_input_channels,
  2934. size_t group_output_channels,
  2935. size_t input_channel_stride,
  2936. size_t output_channel_stride,
  2937. int8_t input_zero_point,
  2938. float input_scale,
  2939. const float* kernel_scale,
  2940. const int8_t* kernel,
  2941. const int32_t* bias,
  2942. int8_t output_zero_point,
  2943. float output_scale,
  2944. int8_t output_min,
  2945. int8_t output_max,
  2946. uint32_t flags,
  2947. xnn_code_cache_t code_cache,
  2948. xnn_weights_cache_t weights_cache,
  2949. xnn_operator_t* convolution_op_out);
  // Reports the workspace size/alignment and the output height/width through
  // the pointer out-parameters.
  2950. enum xnn_status xnn_reshape_convolution2d_nhwc_qs8_qc8w(
  2951. xnn_operator_t convolution_op,
  2952. size_t batch_size,
  2953. size_t input_height,
  2954. size_t input_width,
  2955. size_t* workspace_size,
  2956. size_t* workspace_alignment,
  2957. size_t* output_height_out,
  2958. size_t* output_width_out,
  2959. pthreadpool_t threadpool);
  // Supplies the workspace and the int8 input/output buffers.
  2960. enum xnn_status xnn_setup_convolution2d_nhwc_qs8_qc8w(
  2961. xnn_operator_t convolution_op,
  2962. void* workspace,
  2963. const int8_t* input,
  2964. int8_t* output);
  // Unsigned 8-bit quantized convolution. Unlike the qs8 variants, the kernel
  // also takes a zero point (`kernel_zero_point`). Bias is int32, and the
  // output bounds are uint8 values.
  2965. enum xnn_status xnn_create_convolution2d_nhwc_qu8(
  2966. uint32_t input_padding_top,
  2967. uint32_t input_padding_right,
  2968. uint32_t input_padding_bottom,
  2969. uint32_t input_padding_left,
  2970. uint32_t kernel_height,
  2971. uint32_t kernel_width,
  2972. uint32_t subsampling_height,
  2973. uint32_t subsampling_width,
  2974. uint32_t dilation_height,
  2975. uint32_t dilation_width,
  2976. uint32_t groups,
  2977. size_t group_input_channels,
  2978. size_t group_output_channels,
  2979. size_t input_channel_stride,
  2980. size_t output_channel_stride,
  2981. uint8_t input_zero_point,
  2982. float input_scale,
  2983. uint8_t kernel_zero_point,
  2984. float kernel_scale,
  2985. const uint8_t* kernel,
  2986. const int32_t* bias,
  2987. uint8_t output_zero_point,
  2988. float output_scale,
  2989. uint8_t output_min,
  2990. uint8_t output_max,
  2991. uint32_t flags,
  2992. xnn_code_cache_t code_cache,
  2993. xnn_weights_cache_t weights_cache,
  2994. xnn_operator_t* convolution_op_out);
  // Reports the workspace size/alignment and the output height/width through
  // the pointer out-parameters.
  2995. enum xnn_status xnn_reshape_convolution2d_nhwc_qu8(
  2996. xnn_operator_t convolution_op,
  2997. size_t batch_size,
  2998. size_t input_height,
  2999. size_t input_width,
  3000. size_t* workspace_size,
  3001. size_t* workspace_alignment,
  3002. size_t* output_height_out,
  3003. size_t* output_width_out,
  3004. pthreadpool_t threadpool);
  // Supplies the workspace and the uint8 input/output buffers.
  3005. enum xnn_status xnn_setup_convolution2d_nhwc_qu8(
  3006. xnn_operator_t convolution_op,
  3007. void* workspace,
  3008. const uint8_t* input,
  3009. uint8_t* output);
  // Copy operator on untyped 8-bit elements. Only `flags` are taken at
  // creation.
  3010. enum xnn_status xnn_create_copy_nc_x8(
  3011. uint32_t flags,
  3012. xnn_operator_t* copy_op_out);
  // Sets `batch_size` rows of `channels` elements with separate input/output
  // strides.
  3013. enum xnn_status xnn_reshape_copy_nc_x8(
  3014. xnn_operator_t copy_op,
  3015. size_t batch_size,
  3016. size_t channels,
  3017. size_t input_stride,
  3018. size_t output_stride,
  3019. pthreadpool_t threadpool);
  // Supplies the input and output buffers.
  3020. enum xnn_status xnn_setup_copy_nc_x8(
  3021. xnn_operator_t copy_op,
  3022. const void* input,
  3023. void* output);
  // Copy operator on untyped 16-bit elements. Only `flags` are taken at
  // creation.
  3024. enum xnn_status xnn_create_copy_nc_x16(
  3025. uint32_t flags,
  3026. xnn_operator_t* copy_op_out);
  // Sets `batch_size` rows of `channels` elements with separate input/output
  // strides.
  3027. enum xnn_status xnn_reshape_copy_nc_x16(
  3028. xnn_operator_t copy_op,
  3029. size_t batch_size,
  3030. size_t channels,
  3031. size_t input_stride,
  3032. size_t output_stride,
  3033. pthreadpool_t threadpool);
  // Supplies the input and output buffers.
  3034. enum xnn_status xnn_setup_copy_nc_x16(
  3035. xnn_operator_t copy_op,
  3036. const void* input,
  3037. void* output);
  // Copy operator on 32-bit elements. Only `flags` are taken at creation.
  3038. enum xnn_status xnn_create_copy_nc_x32(
  3039. uint32_t flags,
  3040. xnn_operator_t* copy_op_out);
  // Sets `batch_size` rows of `channels` elements with separate input/output
  // strides.
  3041. enum xnn_status xnn_reshape_copy_nc_x32(
  3042. xnn_operator_t copy_op,
  3043. size_t batch_size,
  3044. size_t channels,
  3045. size_t input_stride,
  3046. size_t output_stride,
  3047. pthreadpool_t threadpool);
  // Supplies the input and output buffers (untyped here).
  3048. enum xnn_status xnn_setup_copy_nc_x32(
  3049. xnn_operator_t copy_op,
  3050. const void* input,
  3051. void* output);
  // Single-call form. Unlike setup, the buffers are typed as uint32_t*.
  // NOTE: the argument order is channels, input_stride, output_stride,
  // batch_size, whereas reshape takes batch_size first. All four are size_t,
  // so a mix-up is not caught by the compiler.
  3052. enum xnn_status xnn_run_copy_nc_x32(
  3053. size_t channels,
  3054. size_t input_stride,
  3055. size_t output_stride,
  3056. size_t batch_size,
  3057. const uint32_t* input,
  3058. uint32_t* output,
  3059. uint32_t flags,
  3060. pthreadpool_t threadpool);
  // 2-D deconvolution, f16, NHWC layout. Compared with the convolution create
  // functions, padding is expressed on the output (`output_padding_*`), `stride_*`
  // replaces `subsampling_*`, and strides are per pixel (`*_pixel_stride`).
  // Kernel and bias are untyped pointers.
  3061. enum xnn_status xnn_create_deconvolution2d_nhwc_f16(
  3062. uint32_t output_padding_top,
  3063. uint32_t output_padding_right,
  3064. uint32_t output_padding_bottom,
  3065. uint32_t output_padding_left,
  3066. uint32_t kernel_height,
  3067. uint32_t kernel_width,
  3068. uint32_t stride_height,
  3069. uint32_t stride_width,
  3070. uint32_t dilation_height,
  3071. uint32_t dilation_width,
  3072. uint32_t groups,
  3073. size_t group_input_channels,
  3074. size_t group_output_channels,
  3075. size_t input_pixel_stride,
  3076. size_t output_pixel_stride,
  3077. const void* kernel,
  3078. const void* bias,
  3079. float output_min,
  3080. float output_max,
  3081. uint32_t flags,
  3082. xnn_code_cache_t code_cache,
  3083. xnn_weights_cache_t weights_cache,
  3084. xnn_operator_t* deconvolution_op_out);
  // Takes `adjustment_height`/`adjustment_width` (presumably extra output rows
  // and columns to resolve the output-size ambiguity of strided
  // deconvolution — TODO confirm). Reports the output height/width; no
  // workspace is reported.
  3085. enum xnn_status xnn_reshape_deconvolution2d_nhwc_f16(
  3086. xnn_operator_t deconvolution_op,
  3087. size_t batch_size,
  3088. size_t input_height,
  3089. size_t input_width,
  3090. uint32_t adjustment_height,
  3091. uint32_t adjustment_width,
  3092. size_t* output_height_out,
  3093. size_t* output_width_out,
  3094. pthreadpool_t threadpool);
  // Supplies the (untyped) input and output buffers; no workspace argument.
  3095. enum xnn_status xnn_setup_deconvolution2d_nhwc_f16(
  3096. xnn_operator_t deconvolution_op,
  3097. const void* input,
  3098. void* output);
  // 2-D deconvolution, f32, NHWC layout, with float kernel/bias and float
  // output bounds. Padding is given on the output side (`output_padding_*`).
  3099. enum xnn_status xnn_create_deconvolution2d_nhwc_f32(
  3100. uint32_t output_padding_top,
  3101. uint32_t output_padding_right,
  3102. uint32_t output_padding_bottom,
  3103. uint32_t output_padding_left,
  3104. uint32_t kernel_height,
  3105. uint32_t kernel_width,
  3106. uint32_t stride_height,
  3107. uint32_t stride_width,
  3108. uint32_t dilation_height,
  3109. uint32_t dilation_width,
  3110. uint32_t groups,
  3111. size_t group_input_channels,
  3112. size_t group_output_channels,
  3113. size_t input_pixel_stride,
  3114. size_t output_pixel_stride,
  3115. const float* kernel,
  3116. const float* bias,
  3117. float output_min,
  3118. float output_max,
  3119. uint32_t flags,
  3120. xnn_code_cache_t code_cache,
  3121. xnn_weights_cache_t weights_cache,
  3122. xnn_operator_t* deconvolution_op_out);
  // Same parameters as xnn_create_deconvolution2d_nhwc_f32, except that kernel
  // and bias are untyped pointers. Presumably they hold f16 weights for an f32
  // operator, as the `_f32_f16` suffix suggests — TODO confirm.
  3123. enum xnn_status xnn_create_deconvolution2d_nhwc_f32_f16(
  3124. uint32_t output_padding_top,
  3125. uint32_t output_padding_right,
  3126. uint32_t output_padding_bottom,
  3127. uint32_t output_padding_left,
  3128. uint32_t kernel_height,
  3129. uint32_t kernel_width,
  3130. uint32_t stride_height,
  3131. uint32_t stride_width,
  3132. uint32_t dilation_height,
  3133. uint32_t dilation_width,
  3134. uint32_t groups,
  3135. size_t group_input_channels,
  3136. size_t group_output_channels,
  3137. size_t input_pixel_stride,
  3138. size_t output_pixel_stride,
  3139. const void* kernel,
  3140. const void* bias,
  3141. float output_min,
  3142. float output_max,
  3143. uint32_t flags,
  3144. xnn_code_cache_t code_cache,
  3145. xnn_weights_cache_t weights_cache,
  3146. xnn_operator_t* deconvolution_op_out);
  3147. enum xnn_status xnn_reshape_deconvolution2d_nhwc_f32(
  3148. xnn_operator_t deconvolution_op,
  3149. size_t batch_size,
  3150. size_t input_height,
  3151. size_t input_width,
  3152. uint32_t adjustment_height,
  3153. uint32_t adjustment_width,
  3154. size_t* output_height_out,
  3155. size_t* output_width_out,
  3156. pthreadpool_t threadpool);
  3157. enum xnn_status xnn_setup_deconvolution2d_nhwc_f32(
  3158. xnn_operator_t deconvolution_op,
  3159. const float* input,
  3160. float* output);
  3161. enum xnn_status xnn_create_deconvolution2d_nhwc_qd8_f32_qc8w(
  3162. uint32_t output_padding_top,
  3163. uint32_t output_padding_right,
  3164. uint32_t output_padding_bottom,
  3165. uint32_t output_padding_left,
  3166. uint32_t kernel_height,
  3167. uint32_t kernel_width,
  3168. uint32_t stride_height,
  3169. uint32_t stride_width,
  3170. uint32_t dilation_height,
  3171. uint32_t dilation_width,
  3172. uint32_t groups,
  3173. size_t group_input_channels,
  3174. size_t group_output_channels,
  3175. size_t input_pixel_stride,
  3176. size_t output_pixel_stride,
  3177. const float* kernel_scale,
  3178. const int8_t* kernel,
  3179. const float* bias,
  3180. float output_min,
  3181. float output_max,
  3182. uint32_t flags,
  3183. xnn_code_cache_t code_cache,
  3184. xnn_weights_cache_t weights_cache,
  3185. xnn_operator_t* deconvolution_op_out);
  3186. enum xnn_status xnn_reshape_deconvolution2d_nhwc_qd8_f32_qc8w(
  3187. xnn_operator_t deconvolution_op,
  3188. size_t batch_size,
  3189. size_t input_height,
  3190. size_t input_width,
  3191. uint32_t adjustment_height,
  3192. uint32_t adjustment_width,
  3193. size_t* output_height_out,
  3194. size_t* output_width_out,
  3195. pthreadpool_t threadpool);
  3196. enum xnn_status xnn_setup_deconvolution2d_nhwc_qd8_f32_qc8w(
  3197. xnn_operator_t deconvolution_op,
  3198. const int8_t* input,
  3199. float* output,
  3200. const struct xnn_quantization_params* quantization_params);
  3201. enum xnn_status xnn_create_deconvolution2d_nhwc_qs8(
  3202. uint32_t output_padding_top,
  3203. uint32_t output_padding_right,
  3204. uint32_t output_padding_bottom,
  3205. uint32_t output_padding_left,
  3206. uint32_t kernel_height,
  3207. uint32_t kernel_width,
  3208. uint32_t stride_height,
  3209. uint32_t stride_width,
  3210. uint32_t dilation_height,
  3211. uint32_t dilation_width,
  3212. uint32_t groups,
  3213. size_t group_input_channels,
  3214. size_t group_output_channels,
  3215. size_t input_pixel_stride,
  3216. size_t output_pixel_stride,
  3217. int8_t input_zero_point,
  3218. float input_scale,
  3219. float kernel_scale,
  3220. const int8_t* kernel,
  3221. const int32_t* bias,
  3222. int8_t output_zero_point,
  3223. float output_scale,
  3224. int8_t output_min,
  3225. int8_t output_max,
  3226. uint32_t flags,
  3227. xnn_code_cache_t code_cache,
  3228. xnn_weights_cache_t weights_cache,
  3229. xnn_operator_t* deconvolution_op_out);
  3230. enum xnn_status xnn_reshape_deconvolution2d_nhwc_qs8(
  3231. xnn_operator_t deconvolution_op,
  3232. size_t batch_size,
  3233. size_t input_height,
  3234. size_t input_width,
  3235. uint32_t adjustment_height,
  3236. uint32_t adjustment_width,
  3237. size_t* output_height_out,
  3238. size_t* output_width_out,
  3239. pthreadpool_t threadpool);
  3240. enum xnn_status xnn_setup_deconvolution2d_nhwc_qs8(
  3241. xnn_operator_t deconvolution_op,
  3242. const int8_t* input,
  3243. int8_t* output);
  3244. enum xnn_status xnn_create_deconvolution2d_nhwc_qs8_qc8w(
  3245. uint32_t output_padding_top,
  3246. uint32_t output_padding_right,
  3247. uint32_t output_padding_bottom,
  3248. uint32_t output_padding_left,
  3249. uint32_t kernel_height,
  3250. uint32_t kernel_width,
  3251. uint32_t stride_height,
  3252. uint32_t stride_width,
  3253. uint32_t dilation_height,
  3254. uint32_t dilation_width,
  3255. uint32_t groups,
  3256. size_t group_input_channels,
  3257. size_t group_output_channels,
  3258. size_t input_pixel_stride,
  3259. size_t output_pixel_stride,
  3260. int8_t input_zero_point,
  3261. float input_scale,
  3262. const float* kernel_scale,
  3263. const int8_t* kernel,
  3264. const int32_t* bias,
  3265. int8_t output_zero_point,
  3266. float output_scale,
  3267. int8_t output_min,
  3268. int8_t output_max,
  3269. uint32_t flags,
  3270. xnn_code_cache_t code_cache,
  3271. xnn_weights_cache_t weights_cache,
  3272. xnn_operator_t* deconvolution_op_out);
  3273. enum xnn_status xnn_reshape_deconvolution2d_nhwc_qs8_qc8w(
  3274. xnn_operator_t deconvolution_op,
  3275. size_t batch_size,
  3276. size_t input_height,
  3277. size_t input_width,
  3278. uint32_t adjustment_height,
  3279. uint32_t adjustment_width,
  3280. size_t* output_height_out,
  3281. size_t* output_width_out,
  3282. pthreadpool_t threadpool);
  3283. enum xnn_status xnn_setup_deconvolution2d_nhwc_qs8_qc8w(
  3284. xnn_operator_t deconvolution_op,
  3285. const int8_t* input,
  3286. int8_t* output);
  3287. enum xnn_status xnn_create_deconvolution2d_nhwc_qu8(
  3288. uint32_t output_padding_top,
  3289. uint32_t output_padding_right,
  3290. uint32_t output_padding_bottom,
  3291. uint32_t output_padding_left,
  3292. uint32_t kernel_height,
  3293. uint32_t kernel_width,
  3294. uint32_t stride_height,
  3295. uint32_t stride_width,
  3296. uint32_t dilation_height,
  3297. uint32_t dilation_width,
  3298. uint32_t groups,
  3299. size_t group_input_channels,
  3300. size_t group_output_channels,
  3301. size_t input_pixel_stride,
  3302. size_t output_pixel_stride,
  3303. uint8_t input_zero_point,
  3304. float input_scale,
  3305. uint8_t kernel_zero_point,
  3306. float kernel_scale,
  3307. const uint8_t* kernel,
  3308. const int32_t* bias,
  3309. uint8_t output_zero_point,
  3310. float output_scale,
  3311. uint8_t output_min,
  3312. uint8_t output_max,
  3313. uint32_t flags,
  3314. xnn_code_cache_t code_cache,
  3315. xnn_weights_cache_t weights_cache,
  3316. xnn_operator_t* deconvolution_op_out);
  3317. enum xnn_status xnn_reshape_deconvolution2d_nhwc_qu8(
  3318. xnn_operator_t deconvolution_op,
  3319. size_t batch_size,
  3320. size_t input_height,
  3321. size_t input_width,
  3322. uint32_t adjustment_height,
  3323. uint32_t adjustment_width,
  3324. size_t* output_height_out,
  3325. size_t* output_width_out,
  3326. pthreadpool_t threadpool);
  3327. enum xnn_status xnn_setup_deconvolution2d_nhwc_qu8(
  3328. xnn_operator_t deconvolution_op,
  3329. const uint8_t* input,
  3330. uint8_t* output);
  3331. enum xnn_status xnn_create_depth_to_space_nchw2nhwc_x16(
  3332. uint32_t block_size,
  3333. uint32_t flags,
  3334. xnn_operator_t* depth_to_space_op_out);
  3335. enum xnn_status xnn_reshape_depth_to_space_nchw2nhwc_x16(
  3336. xnn_operator_t depth_to_space_op,
  3337. size_t batch_size,
  3338. size_t input_height,
  3339. size_t input_width,
  3340. size_t input_channels,
  3341. size_t* output_height_out,
  3342. size_t* output_width_out,
  3343. size_t* output_channels_out,
  3344. pthreadpool_t threadpool);
  3345. enum xnn_status xnn_setup_depth_to_space_nchw2nhwc_x16(
  3346. xnn_operator_t depth_to_space_op,
  3347. const void* input,
  3348. void* output);
  3349. enum xnn_status xnn_create_depth_to_space_nchw2nhwc_x32(
  3350. uint32_t block_size,
  3351. uint32_t flags,
  3352. xnn_operator_t* depth_to_space_op_out);
  3353. enum xnn_status xnn_reshape_depth_to_space_nchw2nhwc_x32(
  3354. xnn_operator_t depth_to_space_op,
  3355. size_t batch_size,
  3356. size_t input_height,
  3357. size_t input_width,
  3358. size_t input_channels,
  3359. size_t* output_height_out,
  3360. size_t* output_width_out,
  3361. size_t* output_channels_out,
  3362. pthreadpool_t threadpool);
  3363. enum xnn_status xnn_setup_depth_to_space_nchw2nhwc_x32(
  3364. xnn_operator_t depth_to_space_op,
  3365. const void* input,
  3366. void* output);
  3367. enum xnn_status xnn_create_depth_to_space_nhwc_x8(
  3368. uint32_t block_size,
  3369. uint32_t flags,
  3370. xnn_operator_t* depth_to_space_op_out);
  3371. enum xnn_status xnn_reshape_depth_to_space_nhwc_x8(
  3372. xnn_operator_t depth_to_space_op,
  3373. size_t batch_size,
  3374. size_t input_height,
  3375. size_t input_width,
  3376. size_t input_channels,
  3377. size_t* output_height_out,
  3378. size_t* output_width_out,
  3379. size_t* output_channels_out,
  3380. pthreadpool_t threadpool);
  3381. enum xnn_status xnn_setup_depth_to_space_nhwc_x8(
  3382. xnn_operator_t depth_to_space_op,
  3383. const void* input,
  3384. void* output);
  3385. enum xnn_status xnn_create_depth_to_space_nhwc_x16(
  3386. uint32_t block_size,
  3387. uint32_t flags,
  3388. xnn_operator_t* depth_to_space_op_out);
  3389. enum xnn_status xnn_reshape_depth_to_space_nhwc_x16(
  3390. xnn_operator_t depth_to_space_op,
  3391. size_t batch_size,
  3392. size_t input_height,
  3393. size_t input_width,
  3394. size_t input_channels,
  3395. size_t* output_height_out,
  3396. size_t* output_width_out,
  3397. size_t* output_channels_out,
  3398. pthreadpool_t threadpool);
  3399. enum xnn_status xnn_setup_depth_to_space_nhwc_x16(
  3400. xnn_operator_t depth_to_space_op,
  3401. const void* input,
  3402. void* output);
  3403. enum xnn_status xnn_create_depth_to_space_nhwc_x32(
  3404. uint32_t block_size,
  3405. uint32_t flags,
  3406. xnn_operator_t* depth_to_space_op_out);
  3407. enum xnn_status xnn_reshape_depth_to_space_nhwc_x32(
  3408. xnn_operator_t depth_to_space_op,
  3409. size_t batch_size,
  3410. size_t input_height,
  3411. size_t input_width,
  3412. size_t input_channels,
  3413. size_t* output_height_out,
  3414. size_t* output_width_out,
  3415. size_t* output_channels_out,
  3416. pthreadpool_t threadpool);
  3417. enum xnn_status xnn_setup_depth_to_space_nhwc_x32(
  3418. xnn_operator_t depth_to_space_op,
  3419. const void* input,
  3420. void* output);
  3421. enum xnn_status xnn_create_dynamic_fully_connected_nc_f16(
  3422. float output_min,
  3423. float output_max,
  3424. uint32_t flags,
  3425. xnn_operator_t* dynamic_fully_connected_op_out);
  3426. enum xnn_status xnn_reshape_dynamic_fully_connected_nc_f16(
  3427. xnn_operator_t dynamic_fully_connected_op,
  3428. size_t batch_size,
  3429. size_t input_channels,
  3430. size_t output_channels,
  3431. size_t input_stride,
  3432. size_t output_stride,
  3433. size_t* workspace_size,
  3434. size_t* workspace_alignment,
  3435. pthreadpool_t threadpool);
  3436. enum xnn_status xnn_setup_dynamic_fully_connected_nc_f16(
  3437. xnn_operator_t dynamic_fully_connected_op,
  3438. void* workspace,
  3439. const void* input,
  3440. const void* kernel,
  3441. const void* bias,
  3442. void* output);
  3443. enum xnn_status xnn_create_dynamic_fully_connected_nc_f32(
  3444. float output_min,
  3445. float output_max,
  3446. uint32_t flags,
  3447. xnn_operator_t* dynamic_fully_connected_op_out);
  3448. enum xnn_status xnn_reshape_dynamic_fully_connected_nc_f32(
  3449. xnn_operator_t dynamic_fully_connected_op,
  3450. size_t batch_size,
  3451. size_t input_channels,
  3452. size_t output_channels,
  3453. size_t input_stride,
  3454. size_t output_stride,
  3455. size_t* workspace_size,
  3456. size_t* workspace_alignment,
  3457. pthreadpool_t threadpool);
  3458. enum xnn_status xnn_setup_dynamic_fully_connected_nc_f32(
  3459. xnn_operator_t dynamic_fully_connected_op,
  3460. void* workspace,
  3461. const float* input,
  3462. const float* kernel,
  3463. const float* bias,
  3464. float* output);
  3465. enum xnn_status xnn_create_fully_connected_nc_f16(
  3466. size_t input_channels,
  3467. size_t output_channels,
  3468. size_t input_stride,
  3469. size_t output_stride,
  3470. const void* kernel,
  3471. const void* bias,
  3472. float output_min,
  3473. float output_max,
  3474. uint32_t flags,
  3475. xnn_code_cache_t code_cache,
  3476. xnn_weights_cache_t weights_cache,
  3477. xnn_operator_t* fully_connected_op_out);
  3478. enum xnn_status xnn_reshape_fully_connected_nc_f16(
  3479. xnn_operator_t fully_connected_op,
  3480. size_t batch_size,
  3481. pthreadpool_t threadpool);
  3482. enum xnn_status xnn_setup_fully_connected_nc_f16(
  3483. xnn_operator_t fully_connected_op,
  3484. const void* input,
  3485. void* output);
  3486. enum xnn_status xnn_create_fully_connected_nc_f32_f16(
  3487. size_t input_channels,
  3488. size_t output_channels,
  3489. size_t input_stride,
  3490. size_t output_stride,
  3491. const void* kernel,
  3492. const void* bias,
  3493. float output_min,
  3494. float output_max,
  3495. uint32_t flags,
  3496. xnn_code_cache_t code_cache,
  3497. xnn_weights_cache_t weights_cache,
  3498. xnn_operator_t* fully_connected_op_out);
  3499. enum xnn_status xnn_create_fully_connected_nc_f32(
  3500. size_t input_channels,
  3501. size_t output_channels,
  3502. size_t input_stride,
  3503. size_t output_stride,
  3504. const float* kernel,
  3505. const float* bias,
  3506. float output_min,
  3507. float output_max,
  3508. uint32_t flags,
  3509. xnn_code_cache_t code_cache,
  3510. xnn_weights_cache_t weights_cache,
  3511. xnn_operator_t* fully_connected_op_out);
  3512. enum xnn_status xnn_reshape_fully_connected_nc_f32_f16(
  3513. xnn_operator_t fully_connected_op,
  3514. size_t batch_size,
  3515. pthreadpool_t threadpool);
  3516. enum xnn_status xnn_reshape_fully_connected_nc_f32(
  3517. xnn_operator_t fully_connected_op,
  3518. size_t batch_size,
  3519. pthreadpool_t threadpool);
  3520. enum xnn_status xnn_setup_fully_connected_nc_f32_f16(
  3521. xnn_operator_t fully_connected_op,
  3522. const float* input,
  3523. float* output);
  3524. enum xnn_status xnn_setup_fully_connected_nc_f32(
  3525. xnn_operator_t fully_connected_op,
  3526. const float* input,
  3527. float* output);
  3528. enum xnn_status xnn_create_fully_connected_nc_f32_qc4w(
  3529. size_t input_channels,
  3530. size_t output_channels,
  3531. size_t input_stride,
  3532. size_t output_stride,
  3533. uint8_t kernel_zero_point,
  3534. const float* kernel_scale,
  3535. const uint8_t* kernel,
  3536. const float* bias,
  3537. float output_min,
  3538. float output_max,
  3539. uint32_t flags,
  3540. xnn_code_cache_t code_cache,
  3541. xnn_weights_cache_t weights_cache,
  3542. xnn_operator_t* fully_connected_op_out);
  3543. enum xnn_status xnn_reshape_fully_connected_nc_f32_qc4w(
  3544. xnn_operator_t fully_connected_op,
  3545. size_t batch_size,
  3546. pthreadpool_t threadpool);
  3547. enum xnn_status xnn_setup_fully_connected_nc_f32_qc4w(
  3548. xnn_operator_t fully_connected_op,
  3549. const float* input,
  3550. float* output);
  3551. enum xnn_status xnn_create_fully_connected_nc_f32_qc8w(
  3552. size_t input_channels,
  3553. size_t output_channels,
  3554. size_t input_stride,
  3555. size_t output_stride,
  3556. const float* kernel_scale,
  3557. const int8_t* kernel,
  3558. const float* bias,
  3559. float output_min,
  3560. float output_max,
  3561. uint32_t flags,
  3562. xnn_code_cache_t code_cache,
  3563. xnn_weights_cache_t weights_cache,
  3564. xnn_operator_t* fully_connected_op_out);
  3565. enum xnn_status xnn_reshape_fully_connected_nc_f32_qc8w(
  3566. xnn_operator_t fully_connected_op,
  3567. size_t batch_size,
  3568. pthreadpool_t threadpool);
  3569. enum xnn_status xnn_setup_fully_connected_nc_f32_qc8w(
  3570. xnn_operator_t fully_connected_op,
  3571. const float* input,
  3572. float* output);
  3573. enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qc4w(
  3574. size_t input_channels,
  3575. size_t output_channels,
  3576. size_t input_stride,
  3577. size_t output_stride,
  3578. uint8_t kernel_zero_point,
  3579. const float* kernel_scale,
  3580. const void* kernel,
  3581. const float* bias,
  3582. float output_min,
  3583. float output_max,
  3584. uint32_t flags,
  3585. xnn_code_cache_t code_cache,
  3586. xnn_weights_cache_t weights_cache,
  3587. xnn_operator_t* fully_connected_op_out);
  3588. enum xnn_status xnn_setup_fully_connected_nc_qd8_f16_qc4w(
  3589. xnn_operator_t fully_connected_op,
  3590. const int8_t* input,
  3591. void* output,
  3592. const struct xnn_quantization_params* quantization_params);
  3593. enum xnn_status xnn_reshape_fully_connected_nc_qd8_f16_qc4w(
  3594. xnn_operator_t fully_connected_op,
  3595. size_t batch_size,
  3596. pthreadpool_t threadpool);
  3597. enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qb4w(
  3598. size_t input_channels,
  3599. size_t output_channels,
  3600. size_t input_stride,
  3601. size_t output_stride,
  3602. size_t block_size,
  3603. uint8_t kernel_zero_point,
  3604. const uint16_t* kernel_scale,
  3605. const void* kernel,
  3606. const float* bias,
  3607. float output_min,
  3608. float output_max,
  3609. uint32_t flags,
  3610. xnn_code_cache_t code_cache,
  3611. xnn_weights_cache_t weights_cache,
  3612. xnn_operator_t* fully_connected_op_out);
  3613. enum xnn_status xnn_reshape_fully_connected_nc_qd8_f16_qb4w(
  3614. xnn_operator_t fully_connected_op,
  3615. size_t batch_size,
  3616. pthreadpool_t threadpool);
  3617. enum xnn_status xnn_setup_fully_connected_nc_qd8_f16_qb4w(
  3618. xnn_operator_t fully_connected_op,
  3619. const int8_t* input,
  3620. void* output,
  3621. const struct xnn_quantization_params* quantization_params);
  3622. enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qc4w(
  3623. size_t input_channels,
  3624. size_t output_channels,
  3625. size_t input_stride,
  3626. size_t output_stride,
  3627. uint8_t kernel_zero_point,
  3628. const float* kernel_scale,
  3629. const void* kernel,
  3630. const float* bias,
  3631. float output_min,
  3632. float output_max,
  3633. uint32_t flags,
  3634. xnn_code_cache_t code_cache,
  3635. xnn_weights_cache_t weights_cache,
  3636. xnn_operator_t* fully_connected_op_out);
  3637. enum xnn_status xnn_setup_fully_connected_nc_qd8_f32_qc4w(
  3638. xnn_operator_t fully_connected_op,
  3639. const int8_t* input,
  3640. float* output,
  3641. const struct xnn_quantization_params* quantization_params);
  3642. enum xnn_status xnn_reshape_fully_connected_nc_qd8_f32_qc4w(
  3643. xnn_operator_t fully_connected_op,
  3644. size_t batch_size,
  3645. pthreadpool_t threadpool);
  3646. enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qb4w(
  3647. size_t input_channels,
  3648. size_t output_channels,
  3649. size_t input_stride,
  3650. size_t output_stride,
  3651. size_t block_size,
  3652. uint8_t kernel_zero_point,
  3653. const uint16_t* kernel_scale,
  3654. const void* kernel,
  3655. const float* bias,
  3656. float output_min,
  3657. float output_max,
  3658. uint32_t flags,
  3659. xnn_code_cache_t code_cache,
  3660. xnn_weights_cache_t weights_cache,
  3661. xnn_operator_t* fully_connected_op_out);
  3662. enum xnn_status xnn_reshape_fully_connected_nc_qd8_f32_qb4w(
  3663. xnn_operator_t fully_connected_op,
  3664. size_t batch_size,
  3665. pthreadpool_t threadpool);
  3666. enum xnn_status xnn_setup_fully_connected_nc_qd8_f32_qb4w(
  3667. xnn_operator_t fully_connected_op,
  3668. const int8_t* input,
  3669. float* output,
  3670. const struct xnn_quantization_params* quantization_params);
  3671. enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qc8w(
  3672. size_t input_channels,
  3673. size_t output_channels,
  3674. size_t input_stride,
  3675. size_t output_stride,
  3676. const float* kernel_scale,
  3677. const int8_t* kernel,
  3678. const float* bias,
  3679. float output_min,
  3680. float output_max,
  3681. uint32_t flags,
  3682. xnn_code_cache_t code_cache,
  3683. xnn_weights_cache_t weights_cache,
  3684. xnn_operator_t* fully_connected_op_out);
  3685. enum xnn_status xnn_setup_fully_connected_nc_qd8_f16_qc8w(
  3686. xnn_operator_t fully_connected_op,
  3687. const int8_t* input,
  3688. void* output,
  3689. const struct xnn_quantization_params* quantization_params);
  3690. enum xnn_status xnn_reshape_fully_connected_nc_qd8_f16_qc8w(
  3691. xnn_operator_t fully_connected_op,
  3692. size_t batch_size,
  3693. pthreadpool_t threadpool);
  3694. enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qc8w(
  3695. size_t input_channels,
  3696. size_t output_channels,
  3697. size_t input_stride,
  3698. size_t output_stride,
  3699. const float* kernel_scale,
  3700. const int8_t* kernel,
  3701. const float* bias,
  3702. float output_min,
  3703. float output_max,
  3704. uint32_t flags,
  3705. xnn_code_cache_t code_cache,
  3706. xnn_weights_cache_t weights_cache,
  3707. xnn_operator_t* fully_connected_op_out);
  3708. enum xnn_status xnn_setup_fully_connected_nc_qd8_f32_qc8w(
  3709. xnn_operator_t fully_connected_op,
  3710. const int8_t* input,
  3711. float* output,
  3712. const struct xnn_quantization_params* quantization_params);
  3713. enum xnn_status xnn_reshape_fully_connected_nc_qd8_f32_qc8w(
  3714. xnn_operator_t fully_connected_op,
  3715. size_t batch_size,
  3716. pthreadpool_t threadpool);
  3717. enum xnn_status xnn_create_fully_connected_nc_qs8(
  3718. size_t input_channels,
  3719. size_t output_channels,
  3720. size_t input_stride,
  3721. size_t output_stride,
  3722. int8_t input_zero_point,
  3723. float input_scale,
  3724. float kernel_scale,
  3725. const int8_t* kernel,
  3726. const int32_t* bias,
  3727. int8_t output_zero_point,
  3728. float output_scale,
  3729. int8_t output_min,
  3730. int8_t output_max,
  3731. uint32_t flags,
  3732. xnn_code_cache_t code_cache,
  3733. xnn_weights_cache_t weights_cache,
  3734. xnn_operator_t* fully_connected_op_out);
  3735. enum xnn_status xnn_reshape_fully_connected_nc_qs8(
  3736. xnn_operator_t fully_connected_op,
  3737. size_t batch_size,
  3738. pthreadpool_t threadpool);
  3739. enum xnn_status xnn_setup_fully_connected_nc_qs8(
  3740. xnn_operator_t fully_connected_op,
  3741. const int8_t* input,
  3742. int8_t* output);
  3743. enum xnn_status xnn_create_fully_connected_nc_qs8_qc8w(
  3744. size_t input_channels,
  3745. size_t output_channels,
  3746. size_t input_stride,
  3747. size_t output_stride,
  3748. int8_t input_zero_point,
  3749. float input_scale,
  3750. const float* kernel_scale,
  3751. const int8_t* kernel,
  3752. const int32_t* bias,
  3753. int8_t output_zero_point,
  3754. float output_scale,
  3755. int8_t output_min,
  3756. int8_t output_max,
  3757. uint32_t flags,
  3758. xnn_code_cache_t code_cache,
  3759. xnn_weights_cache_t weights_cache,
  3760. xnn_operator_t* fully_connected_op_out);
  3761. enum xnn_status xnn_reshape_fully_connected_nc_qs8_qc8w(
  3762. xnn_operator_t fully_connected_op,
  3763. size_t batch_size,
  3764. pthreadpool_t threadpool);
  3765. enum xnn_status xnn_setup_fully_connected_nc_qs8_qc8w(
  3766. xnn_operator_t fully_connected_op,
  3767. const int8_t* input,
  3768. int8_t* output);
  3769. enum xnn_status xnn_create_fully_connected_nc_qu8(
  3770. size_t input_channels,
  3771. size_t output_channels,
  3772. size_t input_stride,
  3773. size_t output_stride,
  3774. uint8_t input_zero_point,
  3775. float input_scale,
  3776. uint8_t kernel_zero_point,
  3777. float kernel_scale,
  3778. const uint8_t* kernel,
  3779. const int32_t* bias,
  3780. uint8_t output_zero_point,
  3781. float output_scale,
  3782. uint8_t output_min,
  3783. uint8_t output_max,
  3784. uint32_t flags,
  3785. xnn_code_cache_t code_cache,
  3786. xnn_weights_cache_t weights_cache,
  3787. xnn_operator_t* fully_connected_op_out);
  3788. enum xnn_status xnn_reshape_fully_connected_nc_qu8(
  3789. xnn_operator_t fully_connected_op,
  3790. size_t batch_size,
  3791. pthreadpool_t threadpool);
  3792. enum xnn_status xnn_setup_fully_connected_nc_qu8(
  3793. xnn_operator_t fully_connected_op,
  3794. const uint8_t* input,
  3795. uint8_t* output);
  3796. enum xnn_status xnn_create_max_pooling2d_nhwc_f16(
  3797. uint32_t input_padding_top,
  3798. uint32_t input_padding_right,
  3799. uint32_t input_padding_bottom,
  3800. uint32_t input_padding_left,
  3801. uint32_t pooling_height,
  3802. uint32_t pooling_width,
  3803. uint32_t stride_height,
  3804. uint32_t stride_width,
  3805. uint32_t dilation_height,
  3806. uint32_t dilation_width,
  3807. float output_min,
  3808. float output_max,
  3809. uint32_t flags,
  3810. xnn_operator_t* max_pooling_op_out);
  3811. enum xnn_status xnn_reshape_max_pooling2d_nhwc_f16(
  3812. xnn_operator_t max_pooling_op,
  3813. size_t batch_size,
  3814. size_t input_height,
  3815. size_t input_width,
  3816. size_t channels,
  3817. size_t input_pixel_stride,
  3818. size_t output_pixel_stride,
  3819. size_t* output_height_out,
  3820. size_t* output_width_out,
  3821. pthreadpool_t threadpool);
  3822. enum xnn_status xnn_setup_max_pooling2d_nhwc_f16(
  3823. xnn_operator_t max_pooling_op,
  3824. const void* input,
  3825. void* output);
  3826. enum xnn_status xnn_create_max_pooling2d_nhwc_f32(
  3827. uint32_t input_padding_top,
  3828. uint32_t input_padding_right,
  3829. uint32_t input_padding_bottom,
  3830. uint32_t input_padding_left,
  3831. uint32_t pooling_height,
  3832. uint32_t pooling_width,
  3833. uint32_t stride_height,
  3834. uint32_t stride_width,
  3835. uint32_t dilation_height,
  3836. uint32_t dilation_width,
  3837. float output_min,
  3838. float output_max,
  3839. uint32_t flags,
  3840. xnn_operator_t* max_pooling_op_out);
  3841. enum xnn_status xnn_reshape_max_pooling2d_nhwc_f32(
  3842. xnn_operator_t max_pooling_op,
  3843. size_t batch_size,
  3844. size_t input_height,
  3845. size_t input_width,
  3846. size_t channels,
  3847. size_t input_pixel_stride,
  3848. size_t output_pixel_stride,
  3849. size_t* output_height_out,
  3850. size_t* output_width_out,
  3851. pthreadpool_t threadpool);
  3852. enum xnn_status xnn_setup_max_pooling2d_nhwc_f32(
  3853. xnn_operator_t max_pooling_op,
  3854. const float* input,
  3855. float* output);
  3856. enum xnn_status xnn_create_max_pooling2d_nhwc_s8(
  3857. uint32_t input_padding_top,
  3858. uint32_t input_padding_right,
  3859. uint32_t input_padding_bottom,
  3860. uint32_t input_padding_left,
  3861. uint32_t pooling_height,
  3862. uint32_t pooling_width,
  3863. uint32_t stride_height,
  3864. uint32_t stride_width,
  3865. uint32_t dilation_height,
  3866. uint32_t dilation_width,
  3867. int8_t output_min,
  3868. int8_t output_max,
  3869. uint32_t flags,
  3870. xnn_operator_t* max_pooling_op_out);
  3871. enum xnn_status xnn_reshape_max_pooling2d_nhwc_s8(
  3872. xnn_operator_t max_pooling_op,
  3873. size_t batch_size,
  3874. size_t input_height,
  3875. size_t input_width,
  3876. size_t channels,
  3877. size_t input_pixel_stride,
  3878. size_t output_pixel_stride,
  3879. size_t* output_height_out,
  3880. size_t* output_width_out,
  3881. pthreadpool_t threadpool);
  3882. enum xnn_status xnn_setup_max_pooling2d_nhwc_s8(
  3883. xnn_operator_t max_pooling_op,
  3884. const int8_t* input,
  3885. int8_t* output);
  3886. enum xnn_status xnn_create_max_pooling2d_nhwc_u8(
  3887. uint32_t input_padding_top,
  3888. uint32_t input_padding_right,
  3889. uint32_t input_padding_bottom,
  3890. uint32_t input_padding_left,
  3891. uint32_t pooling_height,
  3892. uint32_t pooling_width,
  3893. uint32_t stride_height,
  3894. uint32_t stride_width,
  3895. uint32_t dilation_height,
  3896. uint32_t dilation_width,
  3897. uint8_t output_min,
  3898. uint8_t output_max,
  3899. uint32_t flags,
  3900. xnn_operator_t* max_pooling_op_out);
  3901. enum xnn_status xnn_reshape_max_pooling2d_nhwc_u8(
  3902. xnn_operator_t max_pooling_op,
  3903. size_t batch_size,
  3904. size_t input_height,
  3905. size_t input_width,
  3906. size_t channels,
  3907. size_t input_pixel_stride,
  3908. size_t output_pixel_stride,
  3909. size_t* output_height_out,
  3910. size_t* output_width_out,
  3911. pthreadpool_t threadpool);
  3912. enum xnn_status xnn_setup_max_pooling2d_nhwc_u8(
  3913. xnn_operator_t max_pooling_op,
  3914. const uint8_t* input,
  3915. uint8_t* output);
  3916. enum xnn_status xnn_create_reduce_nd(
  3917. enum xnn_reduce_operator reduce_operator_type,
  3918. enum xnn_datatype datatype,
  3919. const struct xnn_quantization_params* input_quantization,
  3920. const struct xnn_quantization_params* output_quantization,
  3921. uint32_t flags,
  3922. xnn_operator_t* reduce_op_out);
  3923. enum xnn_status xnn_reshape_reduce_nd( //
  3924. xnn_operator_t reduce_op, //
  3925. size_t num_reduction_axes, //
  3926. const int64_t* reduction_axes, //
  3927. size_t num_input_dims, //
  3928. const size_t* input_shape, //
  3929. size_t* workspace_size, //
  3930. size_t* workspace_alignment, //
  3931. pthreadpool_t threadpool);
  3932. enum xnn_status xnn_setup_reduce_nd(
  3933. xnn_operator_t reduce_op,
  3934. void* workspace,
  3935. const void* input,
  3936. void* output);
  3937. enum xnn_status xnn_create_resize_bilinear2d_nchw_f32(
  3938. size_t output_height,
  3939. size_t output_width,
  3940. uint32_t flags,
  3941. xnn_operator_t* resize_op_out);
  3942. enum xnn_status xnn_reshape_resize_bilinear2d_nchw_f32(
  3943. xnn_operator_t resize_op,
  3944. size_t batch_size,
  3945. size_t input_height,
  3946. size_t input_width,
  3947. size_t channels,
  3948. size_t input_pixel_stride,
  3949. size_t output_pixel_stride,
  3950. pthreadpool_t threadpool);
  3951. enum xnn_status xnn_setup_resize_bilinear2d_nchw_f32(
  3952. xnn_operator_t resize_op,
  3953. const float* input,
  3954. float* output);
  3955. enum xnn_status xnn_create_resize_bilinear2d_nchw_f16(
  3956. size_t output_height,
  3957. size_t output_width,
  3958. uint32_t flags,
  3959. xnn_operator_t* resize_op_out);
  3960. enum xnn_status xnn_reshape_resize_bilinear2d_nchw_f16(
  3961. xnn_operator_t resize_op,
  3962. size_t batch_size,
  3963. size_t input_height,
  3964. size_t input_width,
  3965. size_t channels,
  3966. size_t input_pixel_stride,
  3967. size_t output_pixel_stride,
  3968. pthreadpool_t threadpool);
  3969. enum xnn_status xnn_setup_resize_bilinear2d_nchw_f16(
  3970. xnn_operator_t resize_op,
  3971. const void* input,
  3972. void* output);
  3973. enum xnn_status xnn_create_resize_bilinear2d_nhwc_f16(
  3974. size_t output_height,
  3975. size_t output_width,
  3976. uint32_t flags,
  3977. xnn_operator_t* resize_op_out);
  3978. enum xnn_status xnn_reshape_resize_bilinear2d_nhwc_f16(
  3979. xnn_operator_t resize_op,
  3980. size_t batch_size,
  3981. size_t input_height,
  3982. size_t input_width,
  3983. size_t channels,
  3984. size_t input_pixel_stride,
  3985. size_t output_pixel_stride,
  3986. size_t* workspace_size,
  3987. size_t* workspace_alignment,
  3988. pthreadpool_t threadpool);
  3989. enum xnn_status xnn_setup_resize_bilinear2d_nhwc_f16(
  3990. xnn_operator_t resize_op,
  3991. void* workspace,
  3992. const void* input,
  3993. void* output);
  3994. enum xnn_status xnn_create_resize_bilinear2d_nhwc_f32(
  3995. size_t output_height,
  3996. size_t output_width,
  3997. uint32_t flags,
  3998. xnn_operator_t* resize_op_out);
  3999. enum xnn_status xnn_reshape_resize_bilinear2d_nhwc_f32(
  4000. xnn_operator_t resize_op,
  4001. size_t batch_size,
  4002. size_t input_height,
  4003. size_t input_width,
  4004. size_t channels,
  4005. size_t input_pixel_stride,
  4006. size_t output_pixel_stride,
  4007. size_t* workspace_size,
  4008. size_t* workspace_alignment,
  4009. pthreadpool_t threadpool);
  4010. enum xnn_status xnn_setup_resize_bilinear2d_nhwc_f32(
  4011. xnn_operator_t resize_op,
  4012. void* workspace,
  4013. const float* input,
  4014. float* output);
  4015. enum xnn_status xnn_create_resize_bilinear2d_nhwc_s8(
  4016. size_t output_height,
  4017. size_t output_width,
  4018. uint32_t flags,
  4019. xnn_operator_t* resize_op_out);
  4020. enum xnn_status xnn_reshape_resize_bilinear2d_nhwc_s8(
  4021. xnn_operator_t resize_op,
  4022. size_t batch_size,
  4023. size_t input_height,
  4024. size_t input_width,
  4025. size_t channels,
  4026. size_t input_pixel_stride,
  4027. size_t output_pixel_stride,
  4028. size_t* workspace_size,
  4029. size_t* workspace,
  4030. pthreadpool_t threadpool);
  4031. enum xnn_status xnn_setup_resize_bilinear2d_nhwc_s8(
  4032. xnn_operator_t resize_op,
  4033. void* workspace,
  4034. const int8_t* input,
  4035. int8_t* output);
  4036. enum xnn_status xnn_create_resize_bilinear2d_nhwc_u8(
  4037. size_t output_height,
  4038. size_t output_width,
  4039. uint32_t flags,
  4040. xnn_operator_t* resize_op_out);
  4041. enum xnn_status xnn_reshape_resize_bilinear2d_nhwc_u8(
  4042. xnn_operator_t resize_op,
  4043. size_t batch_size,
  4044. size_t input_height,
  4045. size_t input_width,
  4046. size_t channels,
  4047. size_t input_pixel_stride,
  4048. size_t output_pixel_stride,
  4049. size_t* workspace_size,
  4050. size_t* workspace_alignment,
  4051. pthreadpool_t threadpool);
  4052. enum xnn_status xnn_setup_resize_bilinear2d_nhwc_u8(
  4053. xnn_operator_t resize_op,
  4054. void* workspace,
  4055. const uint8_t* input,
  4056. uint8_t* output);
  4057. enum xnn_status xnn_create_rope_nthc_f16(
  4058. uint32_t flags,
  4059. xnn_operator_t* rope_op_out);
  4060. enum xnn_status xnn_reshape_rope_nthc_f16(
  4061. xnn_operator_t rope_op,
  4062. size_t batch_size,
  4063. size_t tokens,
  4064. size_t heads,
  4065. size_t channels,
  4066. pthreadpool_t threadpool);
  4067. enum xnn_status xnn_setup_rope_nthc_f16(
  4068. xnn_operator_t rope_op,
  4069. const void* input,
  4070. const void* weights,
  4071. void* output);
  4072. enum xnn_status xnn_create_rope_nthc_f32(
  4073. uint32_t flags,
  4074. xnn_operator_t* rope_op_out);
  4075. enum xnn_status xnn_reshape_rope_nthc_f32(
  4076. xnn_operator_t rope_op,
  4077. size_t batch_size,
  4078. size_t tokens,
  4079. size_t heads,
  4080. size_t channels,
  4081. pthreadpool_t threadpool);
  4082. enum xnn_status xnn_setup_rope_nthc_f32(
  4083. xnn_operator_t rope_op,
  4084. const float* input,
  4085. const float* weights,
  4086. float* output);
  4087. // N: batch size
  4088. // H: number of heads
  4089. // T: tokens (sequence length)
  4090. // C: channels (head dimension)
  4091. enum xnn_status xnn_create_scaled_dot_product_attention_nhtc_f16(
  4092. enum xnn_attention_logits_cap_type cap_type,
  4093. const void* cap_params,
  4094. uint32_t flags,
  4095. xnn_operator_t* attention_op_out);
  4096. enum xnn_status xnn_reshape_scaled_dot_product_attention_nhtc_f16(
  4097. xnn_operator_t attention_op,
  4098. size_t batch_size,
  4099. size_t query_heads,
  4100. // Number of tokens in query.
  4101. size_t query_tokens,
  4102. size_t key_value_heads,
  4103. // Number of tokens in key/value. For self-attention, this is same as tokens.
  4104. size_t key_value_tokens,
  4105. size_t query_key_channels,
  4106. size_t value_channels,
  4107. size_t* workspace_size,
  4108. size_t* workspace_alignment,
  4109. pthreadpool_t threadpool);
  4110. // Query is of dimension [batch_size, query_heads, query_tokens, channels].
  4111. // Key and value are of dimension [batch_size, key_value_heads, key_value_tokens, channels].
  4112. // Scale is of dimension [channels].
  4113. // Mask is of dimension [query_tokens, key_value_tokens].
  4114. enum xnn_status xnn_setup_scaled_dot_product_attention_nhtc_f16(
  4115. xnn_operator_t attention_op,
  4116. void* workspace,
  4117. const void* query,
  4118. const void* key,
  4119. const void* value,
  4120. const void* scale,
  4121. const void* mask,
  4122. void* output);
  4123. // N: batch size
  4124. // H: number of heads
  4125. // T: tokens (sequence length)
  4126. // C: channels (head dimension)
  4127. enum xnn_status xnn_create_scaled_dot_product_attention_nhtc_f32(
  4128. enum xnn_attention_logits_cap_type cap_type,
  4129. const void* cap_params,
  4130. uint32_t flags,
  4131. xnn_operator_t* attention_op_out);
  4132. enum xnn_status xnn_reshape_scaled_dot_product_attention_nhtc_f32(
  4133. xnn_operator_t attention_op,
  4134. size_t batch_size,
  4135. size_t query_heads,
  4136. // Number of tokens in query.
  4137. size_t query_tokens,
  4138. size_t key_value_heads,
  4139. // Number of tokens in key/value. For self-attention, this is same as tokens.
  4140. size_t key_value_tokens,
  4141. size_t query_key_channels,
  4142. size_t value_channels,
  4143. size_t* workspace_size,
  4144. size_t* workspace_alignment,
  4145. pthreadpool_t threadpool);
  4146. // Query is of dimension [batch_size, query_heads, query_tokens, query_key_channels].
  4147. // Key and value are of dimension [batch_size, key_value_heads, key_value_tokens, query_key_channels].
  4148. // Scale is of dimension [query_key_channels].
  4149. // Mask is of dimension [query_tokens, key_value_tokens].
  4150. // Output is of dimension [batch_size, query_heads, query_tokens, value_channels].
  4151. enum xnn_status xnn_setup_scaled_dot_product_attention_nhtc_f32(
  4152. xnn_operator_t attention_op,
  4153. void* workspace,
  4154. const float* query,
  4155. const float* key,
  4156. const float* value,
  4157. const float* scale,
  4158. const float* mask,
  4159. float* output);
  4160. enum xnn_status xnn_create_slice_nd_x16(
  4161. uint32_t flags,
  4162. xnn_operator_t* slice_op_out);
  4163. enum xnn_status xnn_reshape_slice_nd_x16(
  4164. xnn_operator_t slice_op,
  4165. size_t num_dims,
  4166. const size_t* input_shape,
  4167. const size_t* offsets,
  4168. const size_t* sizes,
  4169. pthreadpool_t threadpool);
  4170. enum xnn_status xnn_setup_slice_nd_x16(
  4171. xnn_operator_t slice_op,
  4172. const void* input,
  4173. void* output);
  4174. enum xnn_status xnn_create_slice_nd_x32(
  4175. uint32_t flags,
  4176. xnn_operator_t* slice_op_out);
  4177. enum xnn_status xnn_reshape_slice_nd_x32(
  4178. xnn_operator_t slice_op,
  4179. size_t num_dims,
  4180. const size_t* input_shape,
  4181. const size_t* offsets,
  4182. const size_t* sizes,
  4183. pthreadpool_t threadpool);
  4184. enum xnn_status xnn_setup_slice_nd_x32(
  4185. xnn_operator_t slice_op,
  4186. const void* input,
  4187. void* output);
  4188. enum xnn_status xnn_run_slice_nd_x32(
  4189. size_t num_dims,
  4190. const size_t* input_shape,
  4191. const size_t* offsets,
  4192. const size_t* sizes,
  4193. const void* input,
  4194. void* output,
  4195. uint32_t flags,
  4196. pthreadpool_t threadpool);
  4197. enum xnn_status xnn_create_softmax_nc_f16(
  4198. uint32_t flags,
  4199. xnn_operator_t* softmax_op_out);
  4200. enum xnn_status xnn_reshape_softmax_nc_f16(
  4201. xnn_operator_t softmax_op,
  4202. size_t channels,
  4203. size_t input_stride,
  4204. size_t output_stride,
  4205. size_t batch_size,
  4206. pthreadpool_t threadpool);
  4207. enum xnn_status xnn_setup_softmax_nc_f16(
  4208. xnn_operator_t softmax_op,
  4209. const void* input,
  4210. void* output);
  4211. enum xnn_status xnn_create_softmax_nc_f32(
  4212. uint32_t flags,
  4213. xnn_operator_t* softmax_op_out);
  4214. enum xnn_status xnn_reshape_softmax_nc_f32(
  4215. xnn_operator_t softmax_op,
  4216. size_t channels,
  4217. size_t input_stride,
  4218. size_t output_stride,
  4219. size_t batch_size,
  4220. pthreadpool_t threadpool);
  4221. enum xnn_status xnn_setup_softmax_nc_f32(
  4222. xnn_operator_t softmax_op,
  4223. const float* input,
  4224. float* output);
  4225. enum xnn_status xnn_create_softmax_nc_qu8(
  4226. float input_scale,
  4227. uint8_t output_zero_point,
  4228. float output_scale,
  4229. uint32_t flags,
  4230. xnn_operator_t* softmax_op_out);
  4231. enum xnn_status xnn_reshape_softmax_nc_qu8(
  4232. xnn_operator_t softmax_op,
  4233. size_t channels,
  4234. size_t input_stride,
  4235. size_t output_stride,
  4236. size_t batch_size,
  4237. pthreadpool_t threadpool);
  4238. enum xnn_status xnn_setup_softmax_nc_qu8(
  4239. xnn_operator_t softmax_op,
  4240. const uint8_t* input,
  4241. uint8_t* output);
  4242. enum xnn_status xnn_create_space_to_depth_nhwc_x16(
  4243. uint32_t block_size,
  4244. uint32_t flags,
  4245. xnn_operator_t* space_to_depth_op_out);
  4246. enum xnn_status xnn_reshape_space_to_depth_nhwc_x16(
  4247. xnn_operator_t space_to_depth_op,
  4248. size_t batch_size,
  4249. size_t input_height,
  4250. size_t input_width,
  4251. size_t input_channels,
  4252. size_t* output_height_out,
  4253. size_t* output_width_out,
  4254. size_t* output_channels_out,
  4255. pthreadpool_t threadpool);
  4256. enum xnn_status xnn_setup_space_to_depth_nhwc_x16(
  4257. xnn_operator_t space_to_depth_op,
  4258. const void* input,
  4259. void* output);
  4260. enum xnn_status xnn_create_space_to_depth_nhwc_x32(
  4261. uint32_t block_size,
  4262. uint32_t flags,
  4263. xnn_operator_t* space_to_depth_op_out);
  4264. enum xnn_status xnn_reshape_space_to_depth_nhwc_x32(
  4265. xnn_operator_t space_to_depth_op,
  4266. size_t batch_size,
  4267. size_t input_height,
  4268. size_t input_width,
  4269. size_t input_channels,
  4270. size_t* output_height_out,
  4271. size_t* output_width_out,
  4272. size_t* output_channels_out,
  4273. pthreadpool_t threadpool);
  4274. enum xnn_status xnn_setup_space_to_depth_nhwc_x32(
  4275. xnn_operator_t space_to_depth_op,
  4276. const void* input,
  4277. void* output);
  4278. enum xnn_status xnn_create_transpose_nd_x8(
  4279. uint32_t flags,
  4280. xnn_operator_t* transpose_op_out);
  4281. enum xnn_status xnn_reshape_transpose_nd_x8(
  4282. xnn_operator_t transpose_op,
  4283. size_t num_dims,
  4284. const size_t* input_shape,
  4285. const size_t* output_perm,
  4286. pthreadpool_t threadpool);
  4287. enum xnn_status xnn_setup_transpose_nd_x8(
  4288. xnn_operator_t transpose_op,
  4289. const void* input,
  4290. void* output);
  4291. enum xnn_status xnn_run_transpose_nd_x8(
  4292. const void* input,
  4293. void* output,
  4294. size_t num_dims,
  4295. const size_t* input_shape,
  4296. const size_t* output_perm,
  4297. uint32_t flags,
  4298. pthreadpool_t threadpool);
  4299. enum xnn_status xnn_create_transpose_nd_x16(
  4300. uint32_t flags,
  4301. xnn_operator_t* transpose_op_out);
  4302. enum xnn_status xnn_reshape_transpose_nd_x16(
  4303. xnn_operator_t transpose_op,
  4304. size_t num_dims,
  4305. const size_t* input_shape,
  4306. const size_t* output_perm,
  4307. pthreadpool_t threadpool);
  4308. enum xnn_status xnn_setup_transpose_nd_x16(
  4309. xnn_operator_t transpose_op,
  4310. const void* input,
  4311. void* output);
  4312. enum xnn_status xnn_run_transpose_nd_x16(
  4313. const void* input,
  4314. void* output,
  4315. size_t num_dims,
  4316. const size_t* input_shape,
  4317. const size_t* output_perm,
  4318. uint32_t flags,
  4319. pthreadpool_t threadpool);
  4320. enum xnn_status xnn_create_transpose_nd_x32(
  4321. uint32_t flags,
  4322. xnn_operator_t* transpose_op_out);
  4323. enum xnn_status xnn_reshape_transpose_nd_x32(
  4324. xnn_operator_t transpose_op,
  4325. size_t num_dims,
  4326. const size_t* input_shape,
  4327. const size_t* output_perm,
  4328. pthreadpool_t threadpool);
  4329. enum xnn_status xnn_setup_transpose_nd_x32(
  4330. xnn_operator_t transpose_op,
  4331. const void* input,
  4332. void* output);
  4333. enum xnn_status xnn_run_transpose_nd_x32(
  4334. const void* input,
  4335. void* output,
  4336. size_t num_dims,
  4337. const size_t* input_shape,
  4338. const size_t* output_perm,
  4339. uint32_t flags,
  4340. pthreadpool_t threadpool);
  4341. enum xnn_status xnn_create_transpose_nd_x64(
  4342. uint32_t flags,
  4343. xnn_operator_t* transpose_op_out);
  4344. enum xnn_status xnn_reshape_transpose_nd_x64(
  4345. xnn_operator_t transpose_op,
  4346. size_t num_dims,
  4347. const size_t* input_shape,
  4348. const size_t* output_perm,
  4349. pthreadpool_t threadpool);
  4350. enum xnn_status xnn_setup_transpose_nd_x64(
  4351. xnn_operator_t transpose_op,
  4352. const void* input,
  4353. void* output);
  4354. enum xnn_status xnn_run_transpose_nd_x64(
  4355. const void* input,
  4356. void* output,
  4357. size_t num_dims,
  4358. const size_t* input_shape,
  4359. const size_t* output_perm,
  4360. uint32_t flags,
  4361. pthreadpool_t threadpool);
  4362. enum xnn_status xnn_create_unpooling2d_nhwc_x32(
  4363. uint32_t input_padding_top,
  4364. uint32_t input_padding_right,
  4365. uint32_t input_padding_bottom,
  4366. uint32_t input_padding_left,
  4367. uint32_t pooling_height,
  4368. uint32_t pooling_width,
  4369. size_t channels,
  4370. size_t input_pixel_stride,
  4371. size_t output_pixel_stride,
  4372. uint32_t flags,
  4373. xnn_operator_t* unpooling_op_out);
  4374. enum xnn_status xnn_reshape_unpooling2d_nhwc_x32(
  4375. xnn_operator_t unpooling_op,
  4376. size_t batch_size,
  4377. size_t input_height,
  4378. size_t input_width,
  4379. size_t* output_height_out,
  4380. size_t* output_width_out,
  4381. pthreadpool_t threadpool);
  4382. enum xnn_status xnn_setup_unpooling2d_nhwc_x32(
  4383. xnn_operator_t unpooling_op,
  4384. const void* input,
  4385. const uint32_t* index,
  4386. void* output);
  4387. enum xnn_status xnn_create_slice_nd_x8(
  4388. uint32_t flags,
  4389. xnn_operator_t* slice_op_out);
  4390. enum xnn_status xnn_reshape_slice_nd_x8(
  4391. xnn_operator_t slice_op,
  4392. size_t num_dims,
  4393. const size_t* input_shape,
  4394. const size_t* offsets,
  4395. const size_t* sizes,
  4396. pthreadpool_t threadpool);
  4397. enum xnn_status xnn_setup_slice_nd_x8(
  4398. xnn_operator_t slice_op,
  4399. const void* input,
  4400. void* output);
  4401. enum xnn_status xnn_create_space_to_depth_nhwc_x8(
  4402. uint32_t block_size,
  4403. uint32_t flags,
  4404. xnn_operator_t* space_to_depth_op_out);
  4405. enum xnn_status xnn_reshape_space_to_depth_nhwc_x8(
  4406. xnn_operator_t space_to_depth_op,
  4407. size_t batch_size,
  4408. size_t input_height,
  4409. size_t input_width,
  4410. size_t input_channels,
  4411. size_t* output_height_out,
  4412. size_t* output_width_out,
  4413. size_t* output_channels_out,
  4414. pthreadpool_t threadpool);
  4415. enum xnn_status xnn_setup_space_to_depth_nhwc_x8(
  4416. xnn_operator_t space_to_depth_op,
  4417. const void* input,
  4418. void* output);
  4419. #ifdef __cplusplus
  4420. } // extern "C"
  4421. #endif