dnnl.h 180 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988
  1. /*******************************************************************************
  2. * Copyright 2016-2024 Intel Corporation
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. *******************************************************************************/
  16. /// @file
  17. /// C API
  18. #ifndef ONEAPI_DNNL_DNNL_H
  19. #define ONEAPI_DNNL_DNNL_H
  20. #include "oneapi/dnnl/dnnl_common.h"
  21. #include "oneapi/dnnl/dnnl_config.h"
  22. #include "oneapi/dnnl/dnnl_types.h"
  23. #include "oneapi/dnnl/dnnl_version.h"
  24. #ifdef __cplusplus
  25. extern "C" {
  26. #endif
  27. /// @addtogroup dnnl_api
  28. /// @{
  29. /// @addtogroup dnnl_api_primitives
  30. /// @{
  31. /// @addtogroup dnnl_api_primitives_common
  32. /// @{
  33. /// Changes the primitive descriptor to point to the next available
  34. /// implementation.
  35. ///
  36. /// @param primitive_desc A primitive descriptor to change.
  37. /// @returns #dnnl_success on success and a status describing the error
  38. /// otherwise.
  39. /// @returns #dnnl_last_impl_reached if no more implementations available,
  40. /// in which case the primitive descriptor itself is kept unchanged.
  41. dnnl_status_t DNNL_API dnnl_primitive_desc_next_impl(
  42. dnnl_primitive_desc_t primitive_desc);
  43. /// Clones a primitive descriptor. The resulting primitive descriptor must be
  44. /// destroyed separately.
  45. ///
  46. /// @param primitive_desc Output primitive descriptor.
  47. /// @param existing_primitive_desc Primitive descriptor to clone.
  48. /// @returns #dnnl_success on success and a status describing the error
  49. /// otherwise.
  50. dnnl_status_t DNNL_API dnnl_primitive_desc_clone(
  51. dnnl_primitive_desc_t *primitive_desc,
  52. const_dnnl_primitive_desc_t existing_primitive_desc);
  53. /// Returns a constant reference to the attributes of a primitive descriptor.
  54. ///
  55. /// @warning
  56. /// It is an error to destroy the resulting @p attr.
  57. ///
  58. /// @warning
  59. /// The lifetime of an @p attr is the same as that of a @p
  60. /// primitive_desc, so it is an error to use the @p attr once the @p
  61. /// primitive_desc has been destroyed.
  62. ///
  63. /// @param primitive_desc Primitive descriptor.
  64. /// @param attr Output primitive attributes.
  65. /// @returns #dnnl_success on success and a status describing the error
  66. /// otherwise.
  67. dnnl_status_t DNNL_API dnnl_primitive_desc_get_attr(
  68. const_dnnl_primitive_desc_t primitive_desc,
  69. const_dnnl_primitive_attr_t *attr);
  70. /// Destroys a primitive descriptor.
  71. ///
  72. /// @param primitive_desc Primitive descriptor to destroy.
  73. /// @returns #dnnl_success on success and a status describing the error
  74. /// otherwise.
  75. dnnl_status_t DNNL_API dnnl_primitive_desc_destroy(
  76. dnnl_primitive_desc_t primitive_desc);
  77. /// Queries a primitive descriptor for various pieces of information.
  78. ///
  79. /// The most common use case is to query a primitive descriptor, created with
  80. /// source, weights, and destination memory descriptors with format tags set
  81. /// to #dnnl_format_tag_any, for the corresponding memory descriptors (in this
  82. /// case the @p what is set to #dnnl_query_src_md, #dnnl_query_weights_md, and
  83. /// #dnnl_query_dst_md respectively) so that it is possible to create memory
  84. /// objects and reorder primitives if necessary.
  85. ///
  86. /// Another typical use case is to query a primitive descriptor for workspace
  87. /// memory descriptor (with @p what set to #dnnl_query_workspace_md). If this
  88. /// query returns #dnnl_not_required status, then workspace memory is not
  89. /// required.
  90. ///
  91. /// @note
  92. /// When querying for a memory descriptor for a scratchpad, a workspace,
  93. /// or an optional parameter, the query will return a pointer to a zero
  94. /// memory descriptor if the parameter is not needed.
  95. ///
  96. /// A few other use cases:
  97. /// - query a primitive descriptor for the implementation information string
  98. /// (#dnnl_query_impl_info_str)
  99. /// - query a primitive descriptor for the number of inputs and outputs
  100. /// (#dnnl_query_num_of_inputs_s32 and #dnnl_query_num_of_outputs_s32
  101. /// respectively)
  102. ///
  103. /// @sa dnnl_query_t for more options
  104. ///
  105. /// @param primitive_desc Primitive descriptor.
  106. /// @param what Parameter to query.
  107. /// @param index Index of the parameter to query for.
  108. /// @param result Output result. The type depends on the query. For example,
  109. /// it must be a @c dnnl_memory_desc_t* if querying for a memory
  110. /// descriptor.
  111. /// @returns #dnnl_success on success and a status describing the error
  112. /// otherwise.
  113. dnnl_status_t DNNL_API dnnl_primitive_desc_query(
  114. const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
  115. int index, void *result);
  116. /// Queries primitive descriptor for a memory descriptor.
  117. ///
  118. /// @note
  119. /// This function is a convenience version of
  120. /// #dnnl_primitive_desc_query().
  121. ///
  122. /// @param primitive_desc Primitive descriptor.
  123. /// @param what Kind of memory descriptor parameter to query for.
  124. /// @param index Index of the parameter to query.
  125. /// @returns A pointer to the requested memory descriptor.
  126. /// @returns A pointer to a zero memory descriptor if the parameter is not
  127. /// needed.
  128. /// @returns NULL in case of any error.
  129. ///
  130. const_dnnl_memory_desc_t DNNL_API dnnl_primitive_desc_query_md(
  131. const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
  132. int index);
  133. /// Queries primitive descriptor for a signed 32bit int.
  134. ///
  135. /// @note
  136. /// This function is a convenience version of
  137. /// #dnnl_primitive_desc_query().
  138. ///
  139. /// @param primitive_desc Primitive descriptor.
  140. /// @param what Kind of the value to query for.
  141. /// @param index Index of the parameter to query.
  142. /// @returns The requested value.
  143. /// @returns 0 in case of any error (in particular if the queried entity is
  144. /// not of type int32_t). Note that 0 may also be the actual returned
  145. /// value.
  146. int DNNL_API dnnl_primitive_desc_query_s32(
  147. const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
  148. int index);
  149. /// Creates a primitive.
  150. ///
  151. /// @param primitive Output primitive.
  152. /// @param primitive_desc Primitive descriptor used to create the primitive.
  153. /// @returns #dnnl_success on success and a status describing the error
  154. /// otherwise.
  155. dnnl_status_t DNNL_API dnnl_primitive_create(dnnl_primitive_t *primitive,
  156. const_dnnl_primitive_desc_t primitive_desc);
  157. /// Creates a primitive from a cache blob.
  158. ///
  159. /// @param primitive Output primitive.
  160. /// @param primitive_desc Primitive descriptor used to create the primitive.
  161. /// @param size Size of the cache blob in bytes.
  162. /// @param cache_blob Cache blob of size @p size.
  163. /// @returns #dnnl_success on success and a status describing the error
  164. /// otherwise.
  165. dnnl_status_t DNNL_API dnnl_primitive_create_from_cache_blob(
  166. dnnl_primitive_t *primitive, const_dnnl_primitive_desc_t primitive_desc,
  167. size_t size, const uint8_t *cache_blob);
  168. /// Executes a primitive.
  169. ///
  170. /// @param primitive Primitive to execute.
  171. /// @param stream Stream to use.
  172. /// @param nargs Number of arguments.
  173. /// @param args Array of arguments. Each argument is an
  174. /// <index, #dnnl_memory_t> pair. The index is one of the `DNNL_ARG_*`
  175. /// values such as `DNNL_ARG_SRC`. Unless runtime shapes are used (see
  176. /// #DNNL_RUNTIME_DIM_VAL), the memory object must have the same memory
  177. /// descriptor as that returned by
  178. /// #dnnl_primitive_desc_query_md(#dnnl_query_exec_arg_md, index).
  179. /// @returns #dnnl_success on success and a status describing the error
  180. /// otherwise.
  181. /// @note If any argument in @p args is padded (padded_dims >
  182. /// dims), the primitive execution will assume properly zero-padded
  183. /// input arguments, and produce zero-padded output arguments.
  184. dnnl_status_t DNNL_API dnnl_primitive_execute(const_dnnl_primitive_t primitive,
  185. dnnl_stream_t stream, int nargs, const dnnl_exec_arg_t *args);
  186. /// Retrieves a constant reference to the primitive descriptor of a given
  187. /// primitive.
  188. ///
  189. /// @warning
  190. /// It is an error to destroy the returned object. It is owned by the
  191. /// primitive. The @c const qualifier of the returned object prevents
  192. /// such attempts.
  193. ///
  194. /// @param primitive Primitive to query for the primitive descriptor.
  195. /// @param primitive_desc Output primitive descriptor.
  196. /// @returns #dnnl_success on success and a status describing the error
  197. /// otherwise.
  198. dnnl_status_t DNNL_API dnnl_primitive_get_primitive_desc(
  199. const_dnnl_primitive_t primitive,
  200. const_dnnl_primitive_desc_t *primitive_desc);
  201. /// Retrieves a cache blob associated with the given primitive.
  202. ///
  203. /// @param primitive Primitive to query for the cache blob.
  204. /// @param size Size of the cache blob in bytes.
  205. /// @param cache_blob Cache blob of size @p size. If the @p cache_blob is
  206. /// nullptr then the size of the cache blob is returned in @p size.
  207. /// @returns #dnnl_success on success and a status describing the error
  208. /// otherwise.
  209. ///
  210. /// @note The cache blob can be empty. It's the user's responsibility to check
  211. /// whether it's empty prior to passing it to
  212. /// #dnnl_primitive_create_from_cache_blob().
  213. dnnl_status_t DNNL_API dnnl_primitive_get_cache_blob(
  214. const_dnnl_primitive_t primitive, size_t *size, uint8_t *cache_blob);
  215. /// Destroys a primitive.
  216. ///
  217. /// @param primitive The primitive to destroy.
  218. /// @returns #dnnl_success on success and a status describing the error
  219. /// otherwise.
  220. dnnl_status_t DNNL_API dnnl_primitive_destroy(dnnl_primitive_t primitive);
  221. /// @} dnnl_api_primitives_common
  222. /// @addtogroup dnnl_api_attributes
  223. /// @{
  224. /// Creates an empty (default) primitive attributes with all the parameters
  225. /// set to their default values.
  226. ///
  227. /// Empty attributes are implied whenever the respective argument is NULL.
  228. ///
  229. /// @param attr Output primitive attributes.
  230. /// @returns #dnnl_success on success and a status describing the error
  231. /// otherwise.
  232. dnnl_status_t DNNL_API dnnl_primitive_attr_create(dnnl_primitive_attr_t *attr);
  233. /// Clones primitive attributes.
  234. ///
  235. /// @param attr Output primitive attributes.
  236. /// @param existing_attr Primitive attributes to clone.
  237. /// @returns #dnnl_success on success and a status describing the error
  238. /// otherwise.
  239. dnnl_status_t DNNL_API dnnl_primitive_attr_clone(
  240. dnnl_primitive_attr_t *attr, const_dnnl_primitive_attr_t existing_attr);
  241. /// Destroys primitive attributes.
  242. ///
  243. /// @param attr Primitive attributes to destroy.
  244. /// @returns #dnnl_success on success and a status describing the error
  245. /// otherwise.
  246. dnnl_status_t DNNL_API dnnl_primitive_attr_destroy(dnnl_primitive_attr_t attr);
  247. /// Returns probability for output dropout primitive attribute.
  248. ///
  249. /// @param attr Primitive attributes.
  250. /// @param dropout_desc Output dropout memory descriptor
  251. /// @returns #dnnl_success on success and a status describing the error
  252. /// otherwise.
  253. dnnl_status_t DNNL_API dnnl_primitive_attr_get_dropout(
  254. const_dnnl_primitive_attr_t attr,
  255. const_dnnl_memory_desc_t *dropout_desc);
  256. /// Sets probability for output dropout primitive attribute.
  257. ///
  258. /// @param attr Primitive attributes.
  259. /// @param dropout_desc Output dropout memory descriptor
  260. /// @returns #dnnl_success on success and a status describing the error
  261. /// otherwise.
  262. dnnl_status_t DNNL_API dnnl_primitive_attr_set_dropout(
  263. dnnl_primitive_attr_t attr, const_dnnl_memory_desc_t dropout_desc);
  264. /// Returns the floating-point math mode primitive attribute.
  265. ///
  266. /// @param attr Primitive attributes.
  267. /// @param mode Output FP math mode.
  268. /// @returns #dnnl_success on success and a status describing the error
  269. /// otherwise.
  270. dnnl_status_t DNNL_API dnnl_primitive_attr_get_fpmath_mode(
  271. const_dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t *mode);
  272. /// Sets the floating-point math mode primitive attributes.
  273. ///
  274. /// @param attr Primitive attributes.
  275. /// @param mode FP math mode. The possible values are:
  276. /// #dnnl_fpmath_mode_strict (default),
  277. /// #dnnl_fpmath_mode_bf16,
  278. /// #dnnl_fpmath_mode_f16,
  279. /// #dnnl_fpmath_mode_tf32,
  280. /// #dnnl_fpmath_mode_any.
  281. /// @returns #dnnl_success on success and a status describing the error
  282. /// otherwise.
  283. dnnl_status_t DNNL_API dnnl_primitive_attr_set_fpmath_mode(
  284. dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t mode);
  285. /// Returns the floating-point math mode primitive attribute.
  286. ///
  287. /// @param attr Primitive attributes.
  288. /// @param mode Output FP math mode.
  289. /// @param apply_to_int Output use floating-point arithmetic for integer primitives.
  290. /// @returns #dnnl_success on success and a status describing the error
  291. /// otherwise.
  292. dnnl_status_t DNNL_API dnnl_primitive_attr_get_fpmath_mode_v2(
  293. const_dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t *mode,
  294. int *apply_to_int);
  295. /// Sets the floating-point math mode primitive attributes.
  296. ///
  297. /// @param attr Primitive attributes.
  298. /// @param mode FP math mode. The possible values are:
  299. /// #dnnl_fpmath_mode_strict (default),
  300. /// #dnnl_fpmath_mode_bf16,
  301. /// #dnnl_fpmath_mode_f16,
  302. /// #dnnl_fpmath_mode_tf32,
  303. /// #dnnl_fpmath_mode_any.
  304. /// @param apply_to_int Boolean. Use of floating-point arithmetic for integer primitives.
  305. /// @returns #dnnl_success on success and a status describing the error
  306. /// otherwise.
  307. dnnl_status_t DNNL_API dnnl_primitive_attr_set_fpmath_mode_v2(
  308. dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t mode, int apply_to_int);
  309. /// Returns the deterministic primitive attribute value.
  310. ///
  311. /// @param attr Primitive attributes.
  312. /// @param value Output deterministic attribute value
  313. /// @returns #dnnl_success on success and a status describing the error
  314. /// otherwise.
  315. dnnl_status_t DNNL_API dnnl_primitive_attr_get_deterministic(
  316. const_dnnl_primitive_attr_t attr, int *value);
  317. /// Sets the deterministic primitive attribute value.
  318. ///
  319. /// @param attr Primitive attributes.
  320. /// @param value Boolean value to set deterministic attribute.
  321. /// @returns #dnnl_success on success and a status describing the error
  322. /// otherwise.
  323. dnnl_status_t DNNL_API dnnl_primitive_attr_set_deterministic(
  324. dnnl_primitive_attr_t attr, int value);
  325. /// Returns the accumulation mode primitive attribute.
  326. ///
  327. /// @param attr Primitive attributes.
  328. /// @param mode Output accumulation mode.
  329. /// @returns #dnnl_success on success and a status describing the error
  330. /// otherwise.
  331. dnnl_status_t DNNL_API dnnl_primitive_attr_get_accumulation_mode(
  332. const_dnnl_primitive_attr_t attr, dnnl_accumulation_mode_t *mode);
  333. /// Sets the accumulation mode primitive attribute.
  334. ///
  335. /// @param attr Primitive attributes.
  336. /// @param mode Accumulation mode. The possible values are:
  337. /// #dnnl_accumulation_mode_strict (default), which is s32 for quantized primitives, f32/f64 otherwise
  338. /// #dnnl_accumulation_mode_relaxed, which is same as strict but allows intermediate accumulators to be in src/dst datatype
  339. /// #dnnl_accumulation_mode_any, which allows accumulators to be src/dst datatype or any wider type.
  340. /// #dnnl_accumulation_mode_f32,
  341. /// #dnnl_accumulation_mode_s32,
  342. /// #dnnl_accumulation_mode_f16.
  343. /// @returns #dnnl_success on success and a status describing the error
  344. /// otherwise.
  345. dnnl_status_t DNNL_API dnnl_primitive_attr_set_accumulation_mode(
  346. dnnl_primitive_attr_t attr, dnnl_accumulation_mode_t mode);
  347. /// Returns the primitive attributes scratchpad mode.
  348. ///
  349. /// @param attr Primitive attributes.
  350. /// @param mode Output scratchpad mode.
  351. /// @returns #dnnl_success on success and a status describing the error
  352. /// otherwise.
  353. dnnl_status_t DNNL_API dnnl_primitive_attr_get_scratchpad_mode(
  354. const_dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t *mode);
  355. /// Sets primitive attributes scratchpad mode.
  356. ///
  357. /// @param attr Primitive attributes.
  358. /// @param mode Scratchpad mode. The possible values are:
  359. /// #dnnl_scratchpad_mode_library (default) and
  360. /// #dnnl_scratchpad_mode_user.
  361. /// @returns #dnnl_success on success and a status describing the error
  362. /// otherwise.
  363. dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode(
  364. dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t mode);
  365. /// Sets primitive attributes scaling factors for primitive operations for a
  366. /// given memory argument. The scaling factors must be passed at execution time
  367. /// as an argument with index #DNNL_ARG_ATTR_SCALES | arg.
  368. ///
  369. /// @sa dnnl_primitive_attr_set_scales
  370. ///
  371. ///
  372. /// @param attr Primitive attributes.
  373. /// @param arg Parameter argument index as passed to the
  374. /// dnnl_primitive_execute() call.
  375. /// @param mask Scaling factors correspondence mask that defines the
  376. /// correspondence between the tensor dimensions and the scaling factors
  377. /// passed at execution time. The set i-th bit indicates that a dedicated
  378. /// scaling factor is used for each index along that dimension. Set the
  379. /// mask to 0 to use a common scaling factor for the whole tensor.
  380. /// @returns #dnnl_success on success and a status describing the error
  381. /// otherwise.
  382. dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_mask(
  383. dnnl_primitive_attr_t attr, int arg, int mask);
  384. /// Sets primitive attributes scaling factors for primitive operations for a
  385. /// given memory argument. The scaling factors must be passed at execution time
  386. /// as an argument with index #DNNL_ARG_ATTR_SCALES | arg.
  387. ///
  388. /// @sa dnnl_primitive_attr_set_scales_mask
  389. ///
  390. ///
  391. /// @param attr Primitive attributes.
  392. /// @param arg Parameter argument index as passed to the
  393. /// dnnl_primitive_execute() call.
  394. /// @param mask Scaling factors correspondence mask that defines the
  395. /// correspondence between the tensor dimensions and the scaling factors
  396. /// passed at execution time. The set i-th bit indicates that a dedicated
  397. /// scaling factor is used for each index along that dimension. Set the
  398. /// mask to 0 to use a common scaling factor for the whole tensor.
  399. /// @param ndims Number of group dimensions.
  400. /// @param group_dims Scaling factors correspondence groups that define the
  401. /// correspondence between the tensor dimensions and the scaling factors.
  402. /// The group dimensions should only be provided for each logical dimension
  403. /// whose corresponding bit is set in @p mask.
  404. /// @param data_type Scaling factors data type.
  405. /// @returns #dnnl_success on success and a status describing the error
  406. /// otherwise.
  407. dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales(
  408. dnnl_primitive_attr_t attr, int arg, int mask, int ndims,
  409. const dnnl_dims_t group_dims, dnnl_data_type_t data_type);
  410. /// Sets primitive attributes zero points for primitive operations for a given
  411. /// memory argument. The zero points must be passed at execution time
  412. /// as an argument with index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
  413. ///
  414. /// @sa dnnl_primitive_attr_set_zero_points
  415. ///
  416. ///
  417. /// @param attr Primitive attributes.
  418. /// @param arg Parameter argument index as passed to the
  419. /// dnnl_primitive_execute() call.
  420. /// @param mask Zero point correspondence mask that defines the
  421. /// correspondence between the tensor dimensions and the zero points
  422. /// passed at execution time. The set i-th bit indicates that a dedicated
  423. /// zero point is used for each index along that dimension. Set the
  424. /// mask to 0 to use a common zero point for the whole tensor.
  425. /// @returns #dnnl_success on success and a status describing the error
  426. /// otherwise.
  427. dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points_mask(
  428. dnnl_primitive_attr_t attr, int arg, int mask);
  429. /// Sets primitive attributes zero points for primitive operations for a given
  430. /// memory argument. The zero points must be passed at execution time
  431. /// as an argument with index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
  432. ///
  433. /// @sa dnnl_primitive_attr_set_zero_points_mask
  434. ///
  435. ///
  436. /// @param attr Primitive attributes.
  437. /// @param arg Parameter argument index as passed to the
  438. /// dnnl_primitive_execute() call.
  439. /// @param mask Zero point correspondence mask that defines the
  440. /// correspondence between the tensor dimensions and the zero points
  441. /// passed at execution time. The set i-th bit indicates that a dedicated
  442. /// zero point is used for each index along that dimension. Set the
  443. /// mask to 0 to use a common zero point for the whole tensor.
  444. /// @param ndims Number of group dimensions.
  445. /// @param group_dims Zero points correspondence groups that define the
  446. /// correspondence between the tensor dimensions and the zero points.
  447. /// The group dimensions should only be provided for each logical dimension
  448. /// whose corresponding bit is set in @p mask.
  449. /// @param data_type Zero points data type.
  450. /// @returns #dnnl_success on success and a status describing the error
  451. /// otherwise.
  452. dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points(
  453. dnnl_primitive_attr_t attr, int arg, int mask, int ndims,
  454. const dnnl_dims_t group_dims, dnnl_data_type_t data_type);
  455. /// Sets the rounding mode attribute value for a given argument.
  456. ///
  457. /// @param attr Primitive attributes.
  458. /// @param arg Argument for which rounding mode should be set.
  459. /// @param mode Rounding mode to apply to the argument.
  460. /// @returns #dnnl_success on success and a status describing the error
  461. /// otherwise.
  462. dnnl_status_t DNNL_API dnnl_primitive_attr_set_rounding(
  463. dnnl_primitive_attr_t attr, int arg, dnnl_rounding_mode_t mode);
  464. /// Returns the rounding mode attribute value for a given argument.
  465. ///
  466. /// @param attr Primitive attributes.
  467. /// @param arg Argument for which rounding mode query applies.
  468. /// @param mode Output rounding mode.
  469. /// @returns #dnnl_success on success and a status describing the error
  470. /// otherwise.
  471. dnnl_status_t DNNL_API dnnl_primitive_attr_get_rounding(
  472. dnnl_primitive_attr_t attr, int arg, dnnl_rounding_mode_t *mode);
  473. /// Returns primitive attributes post-ops.
  474. ///
  475. /// @warning
  476. /// The output @p post_ops points to the internal @p attr field, so it is
  477. /// an error to modify or destroy them. The lifetime of @p post_ops is
  478. /// the same as that of the @p attr it belongs to, so it is an error to
  479. /// use @p post_ops after @p attr has been destroyed.
  480. ///
  481. /// @param attr Primitive attributes.
  482. /// @param post_ops Output post-ops.
  483. /// @returns #dnnl_success on success and a status describing the error
  484. /// otherwise.
  485. dnnl_status_t DNNL_API dnnl_primitive_attr_get_post_ops(
  486. const_dnnl_primitive_attr_t attr, const_dnnl_post_ops_t *post_ops);
  487. /// Sets primitive attributes post-ops.
  488. ///
  489. /// @note
  490. /// There is no way to check whether the post-ops would be supported by
  491. /// the target primitive. Any error will be reported by the
  492. /// dnnl_<primitive name>_[propagation kind]_primitive_desc_create() function call.
  493. ///
  494. /// @param attr Primitive attributes.
  495. /// @param post_ops Post-ops to set.
  496. /// @returns #dnnl_success on success and a status describing the error
  497. /// otherwise.
  498. dnnl_status_t DNNL_API dnnl_primitive_attr_set_post_ops(
  499. dnnl_primitive_attr_t attr, const_dnnl_post_ops_t post_ops);
  500. /// Creates empty post-ops sequence.
  501. ///
  502. /// @param post_ops Output post-ops.
  503. /// @returns #dnnl_success on success and a status describing the error
  504. /// otherwise.
  505. dnnl_status_t DNNL_API dnnl_post_ops_create(dnnl_post_ops_t *post_ops);
  506. /// Clones post-ops primitive attribute.
  507. ///
  508. /// @param post_ops Output post-ops primitive attribute.
  509. /// @param existing_post_ops Post-ops primitive attribute to clone.
  510. /// @returns #dnnl_success on success and a status describing the error
  511. /// otherwise.
  512. dnnl_status_t DNNL_API dnnl_post_ops_clone(
  513. dnnl_post_ops_t *post_ops, const_dnnl_post_ops_t existing_post_ops);
  514. /// Destroys post-ops.
  515. ///
  516. /// @param post_ops Post-ops to destroy.
  517. /// @returns #dnnl_success on success and a status describing the error
  518. /// otherwise.
  519. dnnl_status_t DNNL_API dnnl_post_ops_destroy(dnnl_post_ops_t post_ops);
  520. /// Returns the length of post-ops.
  521. ///
  522. /// @param post_ops Post-ops.
  523. /// @returns The number of post-ops entries.
  524. int DNNL_API dnnl_post_ops_len(const_dnnl_post_ops_t post_ops);
  525. /// Returns the kind of a post-op entry.
  526. ///
  527. /// @param post_ops Post-ops.
  528. /// @param index Post-op entry index.
  529. /// @returns The kind of the post-op with the specified index.
  530. /// @returns #dnnl_undefined_primitive if there is no post-op at the specified
  531. /// index.
  532. dnnl_primitive_kind_t DNNL_API dnnl_post_ops_get_kind(
  533. const_dnnl_post_ops_t post_ops, int index);
  534. /// Appends an accumulation (sum) post-op. Prior to accumulating the
  535. /// result, the zero point is subtracted from the previous value and the
  536. /// difference is multiplied by the scale.
  537. ///
  538. /// The kind of this post-op is #dnnl_sum.
  539. ///
  540. /// This feature may improve performance for cases like dequantizing the
  541. /// asymmetrically quantized sum's src1 tensor to the f32 domain before
  542. /// performing the sum operation, by subtracting @p zero_point before scaling.
  543. ///
  544. /// In the simplest case where accumulation is the only post-op, the
  545. /// computations will be:
  546. ///
  547. /// dst[:] <- scale * (dst[:] - zero_point) + op(...)
  548. /// // instead of dst[:] <- op(...)
  549. ///
  550. /// If @p data_type is specified, the original dst tensor will be reinterpreted
  551. /// as a tensor with the provided data type. Since this is a reinterpretation,
  552. /// @p data_type and the dst data type should have the same size.
  553. /// As a result, computations will be:
  554. ///
  555. /// dst[:] <- scale * (as_data_type(dst[:]) - zero_point) + op(...)
  556. /// // instead of dst[:] <- op(...)
  557. /// @note
  558. /// This post-op executes in-place and does not change the
  559. /// destination layout.
  560. ///
  561. /// @param post_ops Post-ops.
  562. /// @param scale Accumulation scaling factor.
  563. /// @param zero_point Single scalar int32_t value of zero point.
  564. /// @param data_type Accumulation data_type.
  565. /// @returns #dnnl_success on success and a status describing the error
  566. /// otherwise.
  567. dnnl_status_t DNNL_API dnnl_post_ops_append_sum(dnnl_post_ops_t post_ops,
  568. float scale, int32_t zero_point, dnnl_data_type_t data_type);
  569. /// Returns the parameters of an accumulation (sum) post-op with
  570. /// zero point and data type parameter.
  571. ///
  572. /// @param post_ops Post-ops.
  573. /// @param index Index of the sum post-op.
  574. /// @param scale Output accumulation scaling factor.
  575. /// @param zero_point Output zero point.
  576. /// @param data_type Output data type for accumulation.
  577. /// @returns #dnnl_success on success and a status describing the error
  578. /// otherwise.
  579. dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum(
  580. const_dnnl_post_ops_t post_ops, int index, float *scale,
  581. int32_t *zero_point, dnnl_data_type_t *data_type);
  582. /// Appends an elementwise post-op.
  583. ///
  584. /// The kind of this post operation is #dnnl_eltwise.
  585. ///
  586. /// In the simplest case when the elementwise is the only post operation, the
  587. /// computations would be:
  588. ///
  589. /// dst[:] <- eltwise_op (op(...)) // instead of dst[:] <- op(...)
  590. ///
  591. /// where eltwise_op is configured with the given parameters.
  592. ///
  593. /// @param post_ops Post-ops.
  594. /// @param alg_kind Elementwise algorithm for the post-op.
  595. /// @param alpha Alpha parameter for the elementwise algorithm.
  596. /// @param beta Beta parameter for the elementwise algorithm.
  597. /// @returns #dnnl_success on success and a status describing the error
  598. /// otherwise.
  599. dnnl_status_t DNNL_API dnnl_post_ops_append_eltwise(dnnl_post_ops_t post_ops,
  600. dnnl_alg_kind_t alg_kind, float alpha, float beta);
  601. /// Returns the parameters of an elementwise post-op.
  602. ///
  603. /// @param post_ops Post-ops.
  604. /// @param index Index of the elementwise post-op.
  605. /// @param alg_kind Output elementwise algorithm kind.
  606. /// @param alpha Output alpha parameter for the elementwise algorithm.
  607. /// @param beta Output beta parameter for the elementwise algorithm.
  608. /// @returns #dnnl_success on success and a status describing the error
  609. /// otherwise.
  610. /// @returns #dnnl_invalid_arguments if @p index does not refer to an
  611. /// elementwise post-op.
  612. dnnl_status_t DNNL_API dnnl_post_ops_get_params_eltwise(
  613. const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind,
  614. float *alpha, float *beta);
  615. /// Appends a depthwise post-op convolution.
  616. ///
  617. /// This post-op can only be fused with a 2D 1x1 convolution (convolution with
  618. /// weights spatial dimensions equal to 1 i.e., kh=kw=1).
  619. ///
  620. /// The kind of this post-op is #dnnl_convolution.
  621. ///
  622. /// The number of outputs for primitive with fusion is one. The output spatial
  623. /// size can be derived as below:
  624. ///
  625. /// output_height = ceil(output_height_1x1_convolution / stride)
  626. /// output_width = ceil(output_width_1x1_convolution / stride)
  627. ///
  628. /// See @ref dev_guide_attributes_post_ops_depthwise and
  629. /// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
  630. ///
  631. /// @param post_ops Post-ops.
  632. /// @param weights_data_type Weights data type of depthwise post-op
  633. /// @param bias_data_type Bias data type of depthwise post-op
  634. /// @param dst_data_type Output data type of depthwise post-op
  635. /// @param kernel_size Size of kernel of depthwise post-op
  636. /// @param stride_size Size of stride of depthwise post-op
  637. /// @param padding_l_size Size of left and top paddings of depthwise post-op
  638. /// @returns #dnnl_success on success and a status describing the error
  639. /// otherwise.
  640. dnnl_status_t DNNL_API dnnl_post_ops_append_dw(dnnl_post_ops_t post_ops,
  641. dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type,
  642. dnnl_data_type_t dst_data_type, dnnl_dim_t kernel_size,
  643. dnnl_dim_t stride_size, dnnl_dim_t padding_l_size);
  644. /// Returns the parameters of a depthwise post-op.
  645. ///
  646. /// @param post_ops Post-ops.
  647. /// @param index Index of the depthwise post-op.
  648. /// @param weights_data_type Output weights data type of depthwise post-op.
  649. /// @param bias_data_type Output bias data type of depthwise post-op.
  650. /// @param dst_data_type Output destination data type of depthwise post-op.
  651. /// @param kernel_size Output size of kernel of depthwise post-op.
  652. /// @param stride_size Output size of stride of depthwise post-op.
  653. /// @param padding_l_size Output size of left and top paddings of depthwise post-op.
  654. /// @returns #dnnl_success on success and a status describing the error
  655. /// otherwise.
  656. dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw(
  657. const_dnnl_post_ops_t post_ops, int index,
  658. dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type,
  659. dnnl_data_type_t *dst_data_type, dnnl_dim_t *kernel_size,
  660. dnnl_dim_t *stride_size, dnnl_dim_t *padding_l_size);
  661. /// Appends a binary post-op.
  662. ///
  663. /// The kind of this post operation is #dnnl_binary.
  664. ///
  665. /// In the simplest case when the binary is the only post operation, the
  666. /// computations would be:
  667. ///
  668. /// dst[:] <- binary_op (dst[:], another_input[:])
  669. ///
  670. /// where binary_op is configured with the given parameters. binary_op supports
  671. /// broadcast semantics for a second operand.
  672. ///
  673. /// @param post_ops Post-ops.
  674. /// @param alg_kind Binary algorithm for the post-op.
  675. /// @param src1_desc Memory descriptor of a second operand.
  676. /// @returns #dnnl_success on success and a status describing the error
  677. /// otherwise.
  678. dnnl_status_t DNNL_API dnnl_post_ops_append_binary(dnnl_post_ops_t post_ops,
  679. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src1_desc);
  680. /// Returns the parameters of a binary post-op.
  681. ///
  682. /// @param post_ops Post-ops.
  683. /// @param index Index of the binary post-op.
  684. /// @param alg_kind Output binary algorithm kind.
  685. /// @param src1_desc Output memory descriptor of a second operand.
  686. /// @returns #dnnl_success on success and a status describing the error
  687. /// otherwise.
  688. /// @returns #dnnl_invalid_arguments if @p index does not refer to a binary
  689. /// post-op.
  690. dnnl_status_t DNNL_API dnnl_post_ops_get_params_binary(
  691. const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind,
  692. const_dnnl_memory_desc_t *src1_desc);
  693. /// Appends a prelu forward post-op.
  694. ///
  695. /// The kind of this post-op is #dnnl_prelu.
  696. ///
  697. /// The post-op can be defined as:
  698. ///
  699. /// dst[:] <- prelu(dst[:], weights[:])
  700. /// prelu:
  701. /// dst[:] <- dst[:] if dst[:] > 0
  702. /// dst[:] <- dst[:] * weights[:] if dst[:] <= 0
  703. ///
  704. ///
  705. /// @note
  706. /// The order of dimensions does not depend on how elements are laid
  707. /// out in memory. For example:
  708. /// - for a 2D CNN activations tensor the order is always (n, c)
  709. /// - for a 4D CNN activations tensor the order is always (n, c, h, w)
  710. /// - for a 5D CNN weights tensor the order is always
  711. /// (g, oc, ic, kh, kw)
  712. ///
  713. /// The prelu weights tensor is passed at execution time. Its data type
  714. /// is implicitly assumed to be f32 and its layout to be plain
  715. /// (a, ab, acb, acdb, acdeb).
  716. ///
  717. /// @param post_ops Post-ops.
  718. /// @param mask Defines the correspondence between the output tensor
  719. /// dimensions and the prelu weights tensor. The set i-th bit indicates
  720. /// that a dedicated weights value is used for each index along that
  721. /// dimension. Set the mask to 0 to use a common weights value
  722. /// for the whole output tensor.
  723. /// @returns #dnnl_success on success and a status describing the error
  724. /// otherwise.
  725. dnnl_status_t DNNL_API dnnl_post_ops_append_prelu(
  726. dnnl_post_ops_t post_ops, int mask);
  727. /// Returns the parameters of a prelu post-op.
  728. ///
  729. /// @param post_ops Post-ops.
  730. /// @param index Index of the prelu post-op.
  731. /// @param mask Output mask of the prelu post-op.
  732. /// @returns #dnnl_success on success and a status describing the error
  733. /// otherwise.
  734. dnnl_status_t DNNL_API dnnl_post_ops_get_params_prelu(
  735. const_dnnl_post_ops_t post_ops, int index, int *mask);
  736. /// @} dnnl_api_attributes
  737. /// @} dnnl_api_primitives
  738. /// @addtogroup dnnl_api_memory
  739. /// @{
  740. /// Destroys a memory descriptor.
  741. ///
  742. /// @param memory_desc Memory descriptor to destroy.
  743. /// @returns #dnnl_success on success and a status describing the error
  744. /// otherwise.
  745. dnnl_status_t DNNL_API dnnl_memory_desc_destroy(dnnl_memory_desc_t memory_desc);
  746. /// Clones a memory descriptor. The resulting memory descriptor must be
  747. /// destroyed separately.
  748. ///
  749. /// @param memory_desc Output memory descriptor.
  750. /// @param existing_memory_desc Memory descriptor to clone.
  751. /// @returns #dnnl_success on success and a status describing the error
  752. /// otherwise.
  753. dnnl_status_t DNNL_API dnnl_memory_desc_clone(dnnl_memory_desc_t *memory_desc,
  754. const_dnnl_memory_desc_t existing_memory_desc);
  755. /// Retrieves a binary blob associated with the given memory descriptor.
  756. ///
  757. /// @param blob Output pointer to binary blob.
  758. /// If not NULL, `*size` bytes of the memory descriptor blob are written.
  759. /// @param size Output pointer to the size of the binary blob in bytes.
  760. /// Size is written if @p blob is NULL.
  761. /// @param memory_desc Input memory descriptor to serialize.
  762. /// @returns #dnnl_success on success and a status describing the error
  763. /// otherwise.
  764. dnnl_status_t DNNL_API dnnl_memory_desc_get_blob(
  765. uint8_t *blob, size_t *size, const_dnnl_memory_desc_t memory_desc);
  766. /// Creates a memory descriptor from a memory descriptor binary blob.
  767. ///
  768. /// @param memory_desc Output pointer to a newly allocated memory descriptor.
  769. /// @param blob Pointer to a memory descriptor binary blob.
  770. /// @returns #dnnl_success on success and a status describing the error
  771. /// otherwise.
  772. dnnl_status_t DNNL_API dnnl_memory_desc_create_with_blob(
  773. dnnl_memory_desc_t *memory_desc, const uint8_t *blob);
  774. /// Creates a memory descriptor using dimensions and strides.
  775. ///
  776. /// @note
  777. /// As always, the logical order of dimensions corresponds to the `abc...`
  778. /// format tag, and the physical meaning of the dimensions depends on both
  779. /// the primitive that consumes the memory and the context of that
  780. /// consumption.
  781. ///
  782. /// @param memory_desc Output memory descriptor.
  783. /// @param ndims Number of dimensions
  784. /// @param dims Array of dimensions.
  785. /// @param data_type Elements data type.
  786. /// @param strides Strides in each dimension.
  787. /// @returns #dnnl_success on success and a status describing the error
  788. /// otherwise.
  789. dnnl_status_t DNNL_API dnnl_memory_desc_create_with_strides(
  790. dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
  791. dnnl_data_type_t data_type, const dnnl_dims_t strides);
  792. /// Creates a memory descriptor using dimensions and memory format tag.
  793. ///
  794. /// @note
  795. /// As always, the logical order of dimensions corresponds to the `abc...`
  796. /// format tag, and the physical meaning of the dimensions depends on both
  797. /// the primitive that consumes the memory and the context of that
  798. /// consumption.
  799. ///
  800. /// @param memory_desc Output memory descriptor.
  801. /// @param ndims Number of dimensions
  802. /// @param dims Array of dimensions.
  803. /// @param data_type Elements data type.
  804. /// @param tag Memory format tag. Can be #dnnl_format_tag_any which would
  805. /// allow a primitive to choose the final memory format. In this case the
  806. /// format_kind field of the memory descriptor would be set to
  807. /// #dnnl_format_kind_any.
  808. /// @returns #dnnl_success on success and a status describing the error
  809. /// otherwise.
  810. dnnl_status_t DNNL_API dnnl_memory_desc_create_with_tag(
  811. dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
  812. dnnl_data_type_t data_type, dnnl_format_tag_t tag);
  813. #ifdef DNNL_EXPERIMENTAL_SPARSE
  814. /// Creates a memory descriptor for CSR encoding.
  815. ///
  816. /// @param memory_desc Output memory descriptor.
  817. /// @param ndims Number of dimensions
  818. /// @param dims Array of dimensions.
  819. /// @param data_type Elements data type.
  820. /// @param nnz Number of non-zero entries.
  821. /// @param indices_dt Data type of indices.
  822. /// @param pointers_dt Data type of pointers.
  823. /// @returns #dnnl_success on success and a status describing the error
  824. /// otherwise.
  825. dnnl_status_t DNNL_API dnnl_memory_desc_create_with_csr_encoding(
  826. dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
  827. dnnl_data_type_t data_type, dnnl_dim_t nnz, dnnl_data_type_t indices_dt,
  828. dnnl_data_type_t pointers_dt);
  829. /// Creates a memory descriptor for COO encoding.
  830. ///
  831. /// The created memory descriptor will describe a memory object that
  832. /// contains n+1 buffers for an n-dimensional tensor.
  833. /// The buffers have the following meaning and assigned numbers (index):
  834. /// - 0: values
  835. /// - 1: indices for dimension 0
  836. /// - 2: indices for dimension 1 ...
  837. /// - n: indices for dimension n-1
  838. ///
  839. /// @param memory_desc Output memory descriptor.
  840. /// @param ndims Number of dimensions.
  841. /// @param dims Array of dimensions.
  842. /// @param data_type Elements data type.
  843. /// @param nnz Number of non-zero entries.
  844. /// @param indices_dt Data type of indices.
  845. /// @returns #dnnl_success on success and a status describing the error
  846. /// otherwise.
  847. dnnl_status_t DNNL_API dnnl_memory_desc_create_with_coo_encoding(
  848. dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
  849. dnnl_data_type_t data_type, dnnl_dim_t nnz,
  850. dnnl_data_type_t indices_dt);
  851. /// Creates a memory descriptor for packed sparse encoding.
  852. ///
  853. /// The created memory descriptor cannot be used to create a memory
  854. /// object. It can only be used to create a primitive descriptor to
  855. /// query the actual memory descriptor (similar to the format tag
  856. /// `any`).
  857. ///
  858. /// @warning
  859. /// The meaning and content of the handles of the memory object that
  860. /// is created using the queried memory descriptor are unspecified;
  861. /// therefore, using the content is undefined behavior.
  862. ///
  863. /// @param memory_desc Output memory descriptor.
  864. /// @param ndims Number of dimensions
  865. /// @param dims Array of dimensions.
  866. /// @param data_type Elements data type.
  867. /// @param nnz Number of non-zero entries.
  868. /// @returns #dnnl_success on success and a status describing the error
  869. /// otherwise.
  870. dnnl_status_t DNNL_API dnnl_memory_desc_create_with_packed_encoding(
  871. dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
  872. dnnl_data_type_t data_type, dnnl_dim_t nnz);
  873. #endif
  874. /// Creates a memory descriptor for a region inside an area
  875. /// described by an existing memory descriptor.
  876. ///
  877. /// @warning
  878. /// Some combinations of physical memory layout and/or offsets or dims may
  879. /// result in a failure to create a submemory.
  880. ///
  881. /// @param memory_desc Output memory descriptor.
  882. /// @param parent_memory_desc An existing memory descriptor.
  883. /// @param dims Sizes of the region.
  884. /// @param offsets Offsets to the region from the encompassing
  885. /// memory object in each dimension
  886. /// @returns #dnnl_success on success and a status describing the error
  887. /// otherwise.
  888. dnnl_status_t DNNL_API dnnl_memory_desc_create_submemory(
  889. dnnl_memory_desc_t *memory_desc,
  890. const_dnnl_memory_desc_t parent_memory_desc, const dnnl_dims_t dims,
  891. const dnnl_dims_t offsets);
  892. /// Creates a memory descriptor by reshaping an existing one. The new
  893. /// memory descriptor inherits the data type. This operation is valid only for
  894. /// memory descriptors that have format_kind #dnnl_blocked or
  895. /// #dnnl_format_kind_any.
  896. ///
  897. /// The resulting memory descriptor must be destroyed separately.
  898. ///
  899. /// The operation ensures the transformation of the physical memory format
  900. /// corresponds to the transformation of the logical dimensions. If such
  901. /// transformation is impossible, the function returns #dnnl_invalid_arguments.
  902. ///
  903. /// The reshape operation can be described as a combination of the following
  904. /// basic operations:
  905. /// 1. Add a dimension of size `1`. This is always possible.
  906. /// 2. Remove a dimension of size `1`. This is possible only if the dimension
  907. /// has no padding (i.e. `padded_dims[dim] == dims[dim] && dims[dim] == 1`).
  908. /// 3. Split a dimension into multiple ones. This is possible only if the size
  909. /// of the dimension is exactly equal to the product of the split ones and
  910. /// the dimension does not have padding (i.e.
  911. /// `padded_dims[dim] == dims[dim]`).
  912. /// 4. Join multiple consecutive dimensions into a single one. As in the
  913. /// cases above, this requires that the dimensions do not have padding and
  914. /// that the memory format is such that in physical memory these dimensions
  915. /// are dense and have the same order as their logical counterparts. This
  916. /// also assumes that these dimensions are not blocked.
  917. /// - Here, dense means:
  918. /// `stride for dim[i] == (stride for dim[i + 1]) * dim[i + 1]`;
  919. /// - And same order means:
  920. /// `i < j` if and only if `stride for dim[j] <= stride for dim[i]`.
  921. ///
  922. /// @warning
  923. /// Some combinations of physical memory layout and/or offsets or
  924. /// dimensions may result in a failure to make a reshape.
  925. ///
  926. /// @param out_memory_desc Output memory descriptor.
  927. /// @param in_memory_desc An existing memory descriptor. Must have format_kind
  928. /// set to #dnnl_blocked or #dnnl_format_kind_any.
  929. /// @param ndims Number of dimensions for the output memory descriptor.
  930. /// @param dims Dimensions for the output memory descriptor.
  931. /// @returns #dnnl_success on success and a status describing the error
  932. /// otherwise.
  933. dnnl_status_t DNNL_API dnnl_memory_desc_reshape(
  934. dnnl_memory_desc_t *out_memory_desc,
  935. const_dnnl_memory_desc_t in_memory_desc, int ndims,
  936. const dnnl_dims_t dims);
  937. /// Creates a memory descriptor by permuting axes in an existing one.
  938. ///
  939. /// The physical memory layout representation is adjusted accordingly to
  940. /// maintain the consistency between the logical and physical parts of the
  941. /// memory descriptor.
  942. ///
  943. /// The resulting memory descriptor must be destroyed separately.
  944. ///
  945. /// The new memory descriptor inherits the data type. This operation is valid
  946. /// only for memory descriptors that have format_kind set to #dnnl_blocked or
  947. /// #dnnl_format_kind_any.
  948. ///
  949. /// The logical axes will be permuted in the following manner:
  950. /// ```
  951. /// for (i: 0 .. in_memory_desc->ndims)
  952. /// out_memory_desc->dims[permutation[i]] = in_memory_desc->dims[i];
  953. /// ```
  954. ///
  955. /// Example:
  956. /// @code
  957. /// dnnl_memory_desc_t in_md, out_md, expect_out_md;
  958. ///
  959. /// const int permutation[] = {1, 0}; // swap the first and the second axes
  960. ///
  961. /// dnnl_dims_t in_dims = {2, 3}, out_dims = {3, 2};
  962. /// dnnl_format_tag_t in_tag = dnnl_ab, out_tag = dnnl_ba;
  963. ///
  964. /// dnnl_memory_desc_create_with_tag(
  965. /// &in_md, 2, in_dims, data_type, in_tag);
  966. /// dnnl_memory_desc_create_with_tag(
  967. /// &expect_out_md, 2, out_dims, data_type, out_tag);
  968. ///
  969. /// dnnl_memory_desc_permute_axes(&out_md, in_md, permutation);
  970. /// assert(dnnl_memory_desc_equal(out_md, expect_out_md));
  971. ///
  972. /// dnnl_memory_desc_destroy(in_md);
  973. /// dnnl_memory_desc_destroy(out_md);
  974. /// dnnl_memory_desc_destroy(expect_out_md);
  975. /// @endcode
  976. ///
  977. /// @param out_memory_desc Output memory descriptor.
  978. /// @param in_memory_desc An existing memory descriptor. Must have format_kind
  979. /// set to #dnnl_blocked or #dnnl_format_kind_any.
  980. /// @param permutation Axes permutation (of size `in_memory_desc->ndims`).
  981. /// @returns #dnnl_success on success and a status describing the error
  982. /// otherwise.
  983. dnnl_status_t DNNL_API dnnl_memory_desc_permute_axes(
  984. dnnl_memory_desc_t *out_memory_desc,
  985. const_dnnl_memory_desc_t in_memory_desc, const int *permutation);
  986. /// Queries a memory descriptor for various pieces of information.
  987. ///
  988. /// The following information can be queried:
  989. /// - Number of dimensions (#dnnl_query_ndims_s32)
  990. /// - Dimensions (#dnnl_query_dims) in the following order:
  991. /// - CNN data tensors: mini-batch, channel, spatial
  992. /// (<code>{N, C, [[D,] H,] W}</code>)
  993. /// - CNN weight tensors: group (optional), output channel, input channel,
  994. /// spatial (<code>{[G,] O, I, [[D,] H,] W}</code>)
  995. /// - RNN data tensors: time, mini-batch, channels (<code>{T, N, C}</code>)
  996. /// or layers, directions, states, mini-batch, channels
  997. /// (<code>{L, D, S, N, C}</code>)
  998. /// - RNN weight tensor: layers, directions, input channel, gates, output
  999. /// channels (<code>{L, D, I, G, O}</code>)
  1000. /// - Data type of the tensor elements (#dnnl_query_data_type)
  1001. /// - Padded dimensions (#dnnl_query_padded_dims) - size of the data including
  1002. /// padding in each dimension
  1003. /// - Padded offsets (#dnnl_query_padded_offsets) - per-dimension offset from
  1004. /// the padding to actual data, the top-level tensor with offsets applied
  1005. /// must lie within the padding area.
  1006. /// - Submemory offset (#dnnl_query_submemory_offset_s64) - offset from memory
  1007. /// origin to the current block, non-zero only in a description of a memory
  1008. /// sub-block.
  1009. /// - Format kind (#dnnl_query_format_kind) - memory format kind
  1010. ///
  1011. /// @note
  1012. /// The order of dimensions does not depend on the memory format, so
  1013. /// whether the data is laid out in #dnnl_nchw or #dnnl_nhwc
  1014. /// the dims for 4D CN data tensor would be <code>{N, C, H, W}</code>.
  1015. ///
  1016. /// The following queries are applicable only to format kind #dnnl_blocked.
  1017. /// - Strides (#dnnl_query_strides) between the outermost blocks or in case
  1018. /// of plain (non-blocked) formats the strides between dimensions
  1019. /// - Number of innermost blocks (#dnnl_query_inner_nblks_s32), e.g.
  1020. /// `{4, 16, 4}` in case of `OIhw_4i16o4i`
  1021. /// - Size of the innermost blocks (#dnnl_query_inner_blks), e.g. 3 in case
  1022. /// of `OIhw_4i16o4i_`
  1023. /// - Logical indices of the blocks (#dnnl_query_inner_idxs), e.g. `{1, 0, 1}`
  1024. /// in case of `4i16o4i`, because `i` is the 1st dim and `o` is the 0st dim
  1025. ///
  1026. /// @param memory_desc Memory descriptor.
  1027. /// @param what Parameter to query.
  1028. /// @param result Output result. The type depends on the query. For example,
  1029. /// it must be a @c dnnl_dims_t** if querying for a strides.
  1030. /// @returns #dnnl_success on success and a status describing the error
  1031. /// otherwise.
  1032. dnnl_status_t DNNL_API dnnl_memory_desc_query(
  1033. const_dnnl_memory_desc_t memory_desc, dnnl_query_t what, void *result);
  1034. #ifdef DNNL_EXPERIMENTAL_SPARSE
  1035. /// Queries a memory descriptor for various pieces of information. This version
  1036. /// support additional queries #dnnl_query_sparse_encoding, #dnnl_query_nnz_s64
  1037. /// #dnnl_query_num_handles_s32 and #dnnl_query_data_type for a particular
  1038. /// buffer.
  1039. ///
  1040. /// The following information can be queried:
  1041. /// - Number of dimensions (#dnnl_query_ndims_s32)
  1042. /// - Dimensions (#dnnl_query_dims) in the following order:
  1043. /// - CNN data tensors: mini-batch, channel, spatial
  1044. /// (<code>{N, C, [[D,] H,] W}</code>)
  1045. /// - CNN weight tensors: group (optional), output channel, input channel,
  1046. /// spatial (<code>{[G,] O, I, [[D,] H,] W}</code>)
  1047. /// - RNN data tensors: time, mini-batch, channels (<code>{T, N, C}</code>)
  1048. /// or layers, directions, states, mini-batch, channels
  1049. /// (<code>{L, D, S, N, C}</code>)
  1050. /// - RNN weight tensor: layers, directions, input channel, gates, output
  1051. /// channels (<code>{L, D, I, G, O}</code>)
  1052. /// - Data type of the tensor elements (#dnnl_query_data_type)
  1053. /// - Padded dimensions (#dnnl_query_padded_dims) - size of the data including
  1054. /// padding in each dimension
  1055. /// - Padded offsets (#dnnl_query_padded_offsets) - per-dimension offset from
  1056. /// the padding to actual data, the top-level tensor with offsets applied
  1057. /// must lie within the padding area.
  1058. /// - Submemory offset (#dnnl_query_submemory_offset_s64) - offset from memory
  1059. /// origin to the current block, non-zero only in a description of a memory
  1060. /// sub-block.
  1061. /// - Format kind (#dnnl_query_format_kind) - memory format kind
  1062. ///
  1063. /// @note
  1064. /// The order of dimensions does not depend on the memory format, so
  1065. /// whether the data is laid out in #dnnl_nchw or #dnnl_nhwc
  1066. /// the dims for 4D CN data tensor would be <code>{N, C, H, W}</code>.
  1067. ///
  1068. /// The following queries are applicable only to format kind #dnnl_blocked.
  1069. /// - Strides (#dnnl_query_strides) between the outermost blocks or in case
  1070. /// of plain (non-blocked) formats the strides between dimensions
  1071. /// - Number of innermost blocks (#dnnl_query_inner_nblks_s32), e.g.
  1072. /// `{4, 16, 4}` in case of `OIhw_4i16o4i`
  1073. /// - Size of the innermost blocks (#dnnl_query_inner_blks), e.g. 3 in case
  1074. /// of `OIhw_4i16o4i_`
  1075. /// - Logical indices of the blocks (#dnnl_query_inner_idxs), e.g. `{1, 0, 1}`
  1076. /// in case of `4i16o4i`, because `i` is the 1st dim and `o` is the 0st dim
  1077. ///
  1078. /// @param memory_desc Memory descriptor.
  1079. /// @param what Parameter to query.
  1080. /// @param index Index of the parameter to query for. It is mostly used with
  1081. /// #dnnl_query_data_type to specify which data type is being queried.
  1082. /// The main data type (data type of values) has always index 0. For other
  1083. /// indices please refer to the API for creating a memory descriptor for
  1084. /// sparse encoding.
  1085. /// @param result Output result. The type depends on the query. For example,
  1086. /// it must be a @c dnnl_dims_t** if querying for a strides.
  1087. /// @returns #dnnl_success on success and a status describing the error
  1088. /// otherwise.
  1089. dnnl_status_t DNNL_API dnnl_memory_desc_query_v2(
  1090. const_dnnl_memory_desc_t memory_desc, dnnl_query_t what, int index,
  1091. void *result);
  1092. #endif
  1093. /// Compares two memory descriptors.
  1094. ///
  1095. /// Use this function to identify whether a reorder is required between the
  1096. /// two memories
  1097. ///
  1098. /// @param lhs Left-hand side of the comparison.
  1099. /// @param rhs Right-hand side of the comparison.
  1100. /// @returns 1 if the descriptors are the same.
  1101. /// @returns 0 if the descriptors are different.
  1102. int DNNL_API dnnl_memory_desc_equal(
  1103. const_dnnl_memory_desc_t lhs, const_dnnl_memory_desc_t rhs);
  1104. /// Returns the size of a memory descriptor.
  1105. ///
  1106. /// @param memory_desc Memory descriptor.
  1107. /// @returns The number of bytes required for memory described by a memory
  1108. /// descriptor.
  1109. size_t DNNL_API dnnl_memory_desc_get_size(const_dnnl_memory_desc_t memory_desc);
  1110. #ifdef DNNL_EXPERIMENTAL_SPARSE
  1111. /// Returns the size of the data that corresponds to the given index.
  1112. ///
  1113. /// @param memory_desc Memory descriptor.
  1114. /// @param index Index of the buffer.
  1115. ///
  1116. /// @returns The number of bytes required for the requested data.
  1117. size_t DNNL_API dnnl_memory_desc_get_size_v2(
  1118. const_dnnl_memory_desc_t memory_desc, int index);
  1119. #endif
  1120. /// Returns the size of data type.
  1121. ///
  1122. /// @param data_type Data type.
  1123. /// @returns The number of bytes occupied by data type.
  1124. size_t DNNL_API dnnl_data_type_size(dnnl_data_type_t data_type);
  1125. /// Creates a memory object.
  1126. ///
  1127. /// Unless @p handle is equal to DNNL_MEMORY_NONE, the constructed memory
  1128. /// object will have the underlying buffer set. In this case, the buffer will
  1129. /// be initialized as if dnnl_memory_set_data_handle() had been called.
  1130. ///
  1131. /// @sa dnnl_memory_set_data_handle()
  1132. ///
  1133. /// @param memory Output memory object.
  1134. /// @param memory_desc Memory descriptor.
  1135. /// @param engine Engine to use.
  1136. /// @param handle Handle of the memory buffer to use as an underlying storage.
  1137. /// - A pointer to the user-allocated buffer. In this case the library
  1138. /// doesn't own the buffer.
  1139. /// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
  1140. /// allocate the buffer for the memory object. In this case the library
  1141. /// owns the buffer.
  1142. /// - DNNL_MEMORY_NONE to create dnnl_memory without an underlying buffer.
  1143. /// @returns #dnnl_success on success and a status describing the error
  1144. /// otherwise.
  1145. dnnl_status_t DNNL_API dnnl_memory_create(dnnl_memory_t *memory,
  1146. const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine,
  1147. void *handle);
  1148. #ifdef DNNL_EXPERIMENTAL_SPARSE
  1149. /// Creates a memory object with multiple handles.
  1150. ///
  1151. /// @param memory Output memory object.
  1152. /// @param memory_desc Memory descriptor.
  1153. /// @param engine Engine to use.
  1154. /// @param nhandles Number of handles.
  1155. /// @param handles Handles of the memory buffers to use as underlying storages.
  1156. /// For each element of the @p handles array the following applies:
  1157. /// - A pointer to the user-allocated buffer. In this case the library
  1158. /// doesn't own the buffer.
  1159. /// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
  1160. /// allocate the buffer for the memory object. In this case the library
  1161. /// owns the buffer.
  1162. /// - DNNL_MEMORY_NONE Instructs the library to skip allocation of the
  1163. /// memory buffer.
  1164. /// @returns #dnnl_success on success and a status describing the error
  1165. /// otherwise.
  1166. dnnl_status_t DNNL_API dnnl_memory_create_v2(dnnl_memory_t *memory,
  1167. const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine,
  1168. int nhandles, void **handles);
  1169. #endif
  1170. /// Returns the memory descriptor for a memory object.
  1171. ///
  1172. /// @param memory Memory object.
  1173. /// @param memory_desc Output memory descriptor (a copy).
  1174. /// @returns #dnnl_success on success and a status describing the error
  1175. /// otherwise.
  1176. dnnl_status_t DNNL_API dnnl_memory_get_memory_desc(
  1177. const_dnnl_memory_t memory, const_dnnl_memory_desc_t *memory_desc);
  1178. /// Returns the engine of a memory object.
  1179. ///
  1180. /// @param memory Memory object.
  1181. /// @param engine Output engine on which the memory is located.
  1182. /// @returns #dnnl_success on success and a status describing the error
  1183. /// otherwise.
  1184. dnnl_status_t DNNL_API dnnl_memory_get_engine(
  1185. const_dnnl_memory_t memory, dnnl_engine_t *engine);
  1186. /// Maps a memory object and returns a host-side pointer to a memory buffer
  1187. /// with a copy of its contents.
  1188. ///
  1189. /// Mapping enables explicit direct access to memory contents for the engines
  1190. /// that do not support it implicitly.
  1191. ///
  1192. /// Mapping is an exclusive operation - a memory object cannot be used in
  1193. /// other operations until this memory object is unmapped.
  1194. ///
  1195. /// @note
  1196. /// Any primitives working with @p memory should be completed before
  1197. /// the memory is mapped. Use dnnl_stream_wait to synchronize the
  1198. /// corresponding execution stream.
  1199. ///
  1200. /// @note
  1201. /// The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are
  1202. /// mainly provided for debug and testing purposes, and their performance
  1203. /// may be suboptimal.
  1204. ///
  1205. /// @param memory Memory object.
  1206. /// @param mapped_ptr Output pointer to the mapped buffer.
  1207. /// @returns #dnnl_success on success and a status describing the error
  1208. /// otherwise.
  1209. dnnl_status_t DNNL_API dnnl_memory_map_data(
  1210. const_dnnl_memory_t memory, void **mapped_ptr);
  1211. #ifdef DNNL_EXPERIMENTAL_SPARSE
  1212. /// Maps a memory object and returns a host-side pointer to a memory buffer
  1213. /// with a copy of its contents. The memory buffer corresponds to the given
  1214. /// index.
  1215. ///
  1216. /// Mapping enables explicit direct access to memory contents for the engines
  1217. /// that do not support it implicitly.
  1218. ///
  1219. /// Mapping is an exclusive operation - a memory object cannot be used in
  1220. /// other operations until this memory object is unmapped.
  1221. ///
  1222. /// @note
  1223. /// Any primitives working with @p memory should be completed before
  1224. /// the memory is mapped. Use dnnl_stream_wait to synchronize the
  1225. /// corresponding execution stream.
  1226. ///
  1227. /// @note
  1228. /// The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are
  1229. /// mainly provided for debug and testing purposes, and their performance
  1230. /// may be suboptimal.
  1231. ///
  1232. /// @param memory Memory object.
  1233. /// @param mapped_ptr Output pointer to the mapped buffer.
  1234. /// @param index Index of the buffer.
  1235. /// @returns #dnnl_success on success and a status describing the error
  1236. /// otherwise.
  1237. dnnl_status_t DNNL_API dnnl_memory_map_data_v2(
  1238. const_dnnl_memory_t memory, void **mapped_ptr, int index);
  1239. #endif
  1240. /// Unmaps a memory object and writes back any changes made to the previously
  1241. /// mapped memory buffer. The pointer to the mapped buffer must be obtained
  1242. /// via the dnnl_memory_map_data() call.
  1243. ///
  1244. /// @note
  1245. /// The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are
  1246. /// mainly provided for debug and testing purposes, and their performance
  1247. /// may be suboptimal.
  1248. ///
  1249. /// @param memory Memory object.
  1250. /// @param mapped_ptr Pointer to the mapped buffer that must have been
  1251. /// obtained using the dnnl_memory_map_data() function.
  1252. /// @returns #dnnl_success on success and a status describing the error
  1253. /// otherwise.
  1254. dnnl_status_t DNNL_API dnnl_memory_unmap_data(
  1255. const_dnnl_memory_t memory, void *mapped_ptr);
  1256. #ifdef DNNL_EXPERIMENTAL_SPARSE
  1257. /// Unmaps a memory object and writes back any changes made to the previously
  1258. /// mapped memory buffer. The pointer to the mapped buffer must be obtained
  1259. /// via the dnnl_memory_map_data() call. The buffer corresponds to the given
  1260. /// index.
  1261. ///
  1262. /// @note
  1263. /// The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are
  1264. /// mainly provided for debug and testing purposes, and their performance
  1265. /// may be suboptimal.
  1266. ///
  1267. /// @param memory Memory object.
  1268. /// @param mapped_ptr Pointer to the mapped buffer that must have been
  1269. /// obtained using the dnnl_memory_map_data() function.
  1270. /// @param index Index of the buffer.
  1271. /// @returns #dnnl_success on success and a status describing the error
  1272. /// otherwise.
  1273. dnnl_status_t DNNL_API dnnl_memory_unmap_data_v2(
  1274. const_dnnl_memory_t memory, void *mapped_ptr, int index);
  1275. #endif
  1276. /// Returns memory object's data handle.
  1277. ///
  1278. /// @param memory Memory object.
  1279. /// @param handle Output data handle. For the CPU engine, the data handle is a
  1280. /// pointer to the actual data. For OpenCL it is a cl_mem.
  1281. /// @returns #dnnl_success on success and a status describing the error
  1282. /// otherwise.
  1283. dnnl_status_t DNNL_API dnnl_memory_get_data_handle(
  1284. const_dnnl_memory_t memory, void **handle);
  1285. /// Sets the underlying memory buffer.
  1286. ///
  1287. /// @param memory Memory object.
  1288. /// @param handle Data handle. For the CPU engine or when USM is used, the
  1289. /// memory buffer is a pointer to the actual data. For OpenCL it is a
  1290. /// `cl_mem`.
  1291. /// @returns #dnnl_success on success and a status describing the error
  1292. /// otherwise.
  1293. dnnl_status_t DNNL_API dnnl_memory_set_data_handle(
  1294. dnnl_memory_t memory, void *handle);
  1295. #ifdef DNNL_EXPERIMENTAL_SPARSE
  1296. /// Returns an underlying memory buffer that corresponds to the given index.
  1297. ///
  1298. /// @param memory Memory object.
  1299. /// @param handle Data handle. For the CPU engine or when USM is used, the
  1300. /// memory buffer is a pointer to the actual data. For OpenCL it is a
  1301. /// `cl_mem`.
  1302. /// @param index Index of the buffer.
  1303. /// @returns #dnnl_success on success and a status describing the error
  1304. /// otherwise.
  1305. dnnl_status_t DNNL_API dnnl_memory_get_data_handle_v2(
  1306. const_dnnl_memory_t memory, void **handle, int index);
  1307. /// Sets an underlying memory buffer that corresponds to the given index.
  1308. ///
  1309. /// @param memory Memory object.
  1310. /// @param handle Data handle. For the CPU engine or when USM is used, the
  1311. /// memory buffer is a pointer to the actual data. For OpenCL it is a
  1312. /// `cl_mem`.
  1313. /// @param index Index of the buffer.
  1314. /// @returns #dnnl_success on success and a status describing the error
  1315. /// otherwise.
  1316. dnnl_status_t DNNL_API dnnl_memory_set_data_handle_v2(
  1317. dnnl_memory_t memory, void *handle, int index);
  1318. #endif
  1319. /// Destroys a memory object.
  1320. ///
  1321. /// @param memory Memory object to destroy.
  1322. /// @returns #dnnl_success on success and a status describing the error
  1323. /// otherwise.
  1324. dnnl_status_t DNNL_API dnnl_memory_destroy(dnnl_memory_t memory);
  1325. /// @} dnnl_api_memory
  1326. /// @addtogroup dnnl_api_primitives
  1327. /// @{
  1328. /// @addtogroup dnnl_api_reorder
  1329. /// @{
  1330. /// Creates a primitive descriptor for a reorder primitive.
  1331. ///
  1332. /// @param reorder_primitive_desc Output primitive descriptor.
  1333. /// @param src_desc Source memory descriptor.
  1334. /// @param src_engine Engine on which the source memory object will be
  1335. /// located.
  1336. /// @param dst_desc Destination memory descriptor.
  1337. /// @param dst_engine Engine on which the destination memory object
  1338. /// will be located.
  1339. /// @param attr Primitive attributes to use (can be NULL).
  1340. /// @returns #dnnl_success on success and a status describing the error
  1341. /// otherwise.
  1342. dnnl_status_t DNNL_API dnnl_reorder_primitive_desc_create(
  1343. dnnl_primitive_desc_t *reorder_primitive_desc,
  1344. const_dnnl_memory_desc_t src_desc, dnnl_engine_t src_engine,
  1345. const_dnnl_memory_desc_t dst_desc, dnnl_engine_t dst_engine,
  1346. const_dnnl_primitive_attr_t attr);
  1347. /// @} dnnl_api_reorder
  1348. /// @addtogroup dnnl_api_concat
  1349. /// @{
  1350. /// Creates a primitive descriptor for an out-of-place concatenation
  1351. /// primitive.
  1352. ///
  1353. /// @param concat_primitive_desc Output primitive descriptor.
  1354. /// @param dst_desc Destination memory descriptor.
  1355. /// @param n Number of source parameters.
  1356. /// @param concat_dimension Source tensors will be concatenated over
  1357. /// dimension with this index. Note that order of dimensions does
  1358. /// not depend on memory format.
  1359. /// @param src_descs Array of source memory descriptors with @p n elements.
  1360. /// @param attr Primitive attributes to use (can be NULL).
  1361. /// @param engine Engine to use.
  1362. /// @returns #dnnl_success on success and a status describing the error
  1363. /// otherwise.
  1364. dnnl_status_t DNNL_API dnnl_concat_primitive_desc_create(
  1365. dnnl_primitive_desc_t *concat_primitive_desc, dnnl_engine_t engine,
  1366. const_dnnl_memory_desc_t dst_desc, int n, int concat_dimension,
  1367. const_dnnl_memory_desc_t const *src_descs,
  1368. const_dnnl_primitive_attr_t attr);
  1369. /// @} dnnl_api_concat
  1370. /// @addtogroup dnnl_api_sum
  1371. /// @{
  1372. /// Creates a primitive descriptor for an (out-of-place) sum primitive.
  1373. ///
  1374. /// @param sum_primitive_desc Output primitive descriptor.
  1375. /// @param dst_desc Destination memory descriptor.
  1376. /// @param n Number of source parameters.
  1377. /// @param scales Vector of scales to multiply data in each source
  1378. /// memory by.
  1379. /// @param src_descs Array of source memory descriptors having @p n elements.
  1380. /// @param attr Primitive attributes to use (can be NULL).
  1381. /// @param engine Engine to use.
  1382. /// @returns #dnnl_success on success and a status describing the error
  1383. /// otherwise.
  1384. dnnl_status_t DNNL_API dnnl_sum_primitive_desc_create(
  1385. dnnl_primitive_desc_t *sum_primitive_desc, dnnl_engine_t engine,
  1386. const_dnnl_memory_desc_t dst_desc, int n, const float *scales,
  1387. const_dnnl_memory_desc_t const *src_descs,
  1388. const_dnnl_primitive_attr_t attr);
  1389. /// @} dnnl_api_sum
  1390. /// @addtogroup dnnl_api_binary
  1391. /// @{
  1392. /// Creates a primitive descriptor for a binary primitive.
  1393. ///
  1394. /// @note
  1395. /// Memory descriptors @p src1_desc and @p dst_desc are allowed to be
  1396. /// initialized with #dnnl_format_tag_any or with format_kind set to
  1397. /// #dnnl_format_kind_any.
  1398. ///
  1399. /// @note
  1400. /// Both memory descriptors must have the same number of dimensions.
  1401. /// Element broadcasting is supported for memory descriptor @p src1_desc
  1402. /// and are applied to @p src1_desc dimensions that have size equal to 1.
  1403. ///
  1404. /// @param primitive_desc Output primitive descriptor.
  1405. /// @param engine Engine to use.
  1406. /// @param alg_kind Algorithm kind. Valid values are #dnnl_binary_add,
  1407. /// #dnnl_binary_mul, #dnnl_binary_max, #dnnl_binary_min, #dnnl_binary_div,
  1408. /// #dnnl_binary_sub, #dnnl_binary_ge, #dnnl_binary_gt, #dnnl_binary_le,
  1409. /// #dnnl_binary_lt, #dnnl_binary_eq and #dnnl_binary_ne.
  1410. /// @param src0_desc Source 0 memory descriptor.
  1411. /// @param src1_desc Source 1 memory descriptor.
  1412. /// @param dst_desc Destination memory descriptor.
  1413. /// @param attr Primitive attributes (can be NULL).
  1414. /// @returns #dnnl_success on success and a status describing the error
  1415. /// otherwise.
  1416. dnnl_status_t DNNL_API dnnl_binary_primitive_desc_create(
  1417. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1418. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src0_desc,
  1419. const_dnnl_memory_desc_t src1_desc, const_dnnl_memory_desc_t dst_desc,
  1420. const_dnnl_primitive_attr_t attr);
  1421. /// Creates a primitive descriptor for a binary primitive with support of
  1422. /// ternary operators.
  1423. ///
  1424. /// @note
  1425. /// Memory descriptors @p src1_desc, @p src2_desc and @p dst_desc are
  1426. /// allowed to be initialized with #dnnl_format_tag_any or with format_kind
  1427. /// set to #dnnl_format_kind_any.
  1428. ///
  1429. /// @note
  1430. /// All memory descriptors must have the same number of dimensions.
  1431. /// Element broadcasting is supported for memory descriptor @p src1_desc
  1432. /// and is applied to @p src1_desc dimensions that have a size equal to 1.
  1433. /// There is no broadcasting support for @p src2_desc.
  1434. ///
  1435. /// @param primitive_desc Output primitive descriptor.
  1436. /// @param engine Engine to use.
  1437. /// @param alg_kind Algorithm kind.
  1438. /// @param src0_desc Source 0 memory descriptor.
  1439. /// @param src1_desc Source 1 memory descriptor.
  1440. /// @param src2_desc Source memory descriptor for ternary operations. Might
  1441. /// be empty.
  1442. /// @param dst_desc Destination memory descriptor.
  1443. /// @param attr Primitive attributes.
  1444. /// @returns #dnnl_success on success and a status describing the error
  1445. /// otherwise.
  1446. dnnl_status_t DNNL_API dnnl_binary_primitive_desc_create_v2(
  1447. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1448. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src0_desc,
  1449. const_dnnl_memory_desc_t src1_desc, const_dnnl_memory_desc_t src2_desc,
  1450. const_dnnl_memory_desc_t dst_desc, const_dnnl_primitive_attr_t attr);
  1451. /// @} dnnl_api_binary
  1452. /// @addtogroup dnnl_api_convolution
  1453. /// @{
  1454. /// Creates a primitive descriptor for a convolution forward propagation
  1455. /// primitive.
  1456. ///
  1457. /// @note
  1458. /// Memory descriptors can be initialized with
  1459. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  1460. ///
  1461. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
  1462. /// values for spatial dimensions only and hence must have the same number of
  1463. /// elements as there are spatial dimensions. The order of values is the same
  1464. /// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
  1465. /// and width.
  1466. ///
  1467. /// @param primitive_desc Output primitive descriptor.
  1468. /// @param engine Engine to use.
  1469. /// @param prop_kind Propagation kind. Possible values are
  1470. /// #dnnl_forward_training and #dnnl_forward_inference.
  1471. /// @param alg_kind Convolution algorithm. Possible values are
  1472. /// #dnnl_convolution_direct, #dnnl_convolution_winograd,
  1473. /// #dnnl_convolution_auto.
  1474. /// @param src_desc Source memory descriptor.
  1475. /// @param weights_desc Weights memory descriptor.
  1476. /// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
  1477. /// descriptor, or a memory descriptor with format_kind set to
  1478. /// #dnnl_format_kind_undef disables the bias term.
  1479. /// @param dst_desc Destination memory descriptor.
  1480. /// @param strides Array of strides for spatial dimension.
  1481. /// @param dilates Array of dilations for spatial dimension. A zero value
  1482. /// means no dilation in the corresponding dimension.
  1483. /// @param padding_l Array of padding values for low indices for each spatial
  1484. /// dimension `([[front,] top,] left)`.
  1485. /// @param padding_r Array of padding values for high indices for each spatial
  1486. /// dimension `([[back,] bottom,] right)`. Can be NULL in which case
  1487. /// padding is considered to be symmetrical.
  1488. /// @param attr Primitive attributes (can be NULL).
  1489. /// @returns #dnnl_success on success and a status describing the error
  1490. /// otherwise.
  1491. dnnl_status_t DNNL_API dnnl_convolution_forward_primitive_desc_create(
  1492. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1493. dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
  1494. const_dnnl_memory_desc_t src_desc,
  1495. const_dnnl_memory_desc_t weights_desc,
  1496. const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
  1497. const dnnl_dims_t strides, const dnnl_dims_t dilates,
  1498. const dnnl_dims_t padding_l, const dnnl_dims_t padding_r,
  1499. const_dnnl_primitive_attr_t attr);
  1500. /// Creates a primitive descriptor for a convolution backward propagation
  1501. /// primitive.
  1502. ///
  1503. /// @note
  1504. /// Memory descriptors can be initialized with
  1505. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  1506. ///
  1507. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
  1508. /// values for spatial dimensions only and hence must have the same number of
  1509. /// elements as there are spatial dimensions. The order of values is the same
  1510. /// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
  1511. /// and width.
  1512. ///
  1513. /// @param primitive_desc Output primitive descriptor.
  1514. /// @param engine Engine to use.
  1515. /// @param alg_kind Convolution algorithm. Possible values are
  1516. /// #dnnl_convolution_direct, #dnnl_convolution_winograd,
  1517. /// #dnnl_convolution_auto.
  1518. /// @param diff_src_desc Diff source memory descriptor.
  1519. /// @param weights_desc Weights memory descriptor.
  1520. /// @param diff_dst_desc Diff destination memory descriptor.
  1521. /// @param strides Array of strides for spatial dimension.
  1522. /// @param dilates Array of dilations for spatial dimension. A zero value
  1523. /// means no dilation in the corresponding dimension.
  1524. /// @param padding_l Array of padding values for low indices for each spatial
  1525. /// dimension `([[front,] top,] left)`.
  1526. /// @param padding_r Array of padding values for high indices for each spatial
  1527. /// dimension `([[back,] bottom,] right)`. Can be NULL in which case
  1528. /// padding is considered to be symmetrical.
  1529. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  1530. /// primitive.
  1531. /// @param attr Primitive attributes (can be NULL).
  1532. /// @returns #dnnl_success on success and a status describing the error
  1533. /// otherwise.
  1534. dnnl_status_t DNNL_API dnnl_convolution_backward_data_primitive_desc_create(
  1535. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1536. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
  1537. const_dnnl_memory_desc_t weights_desc,
  1538. const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
  1539. const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
  1540. const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd,
  1541. const_dnnl_primitive_attr_t attr);
  1542. /// Creates a primitive descriptor for a convolution weights gradient primitive.
  1543. ///
  1544. /// @note
  1545. /// Memory descriptors can be initialized with
  1546. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  1547. ///
  1548. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
  1549. /// values for spatial dimensions only and hence must have the same number of
  1550. /// elements as there are spatial dimensions. The order of values is the same
  1551. /// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
  1552. /// and width.
  1553. ///
  1554. /// @param primitive_desc Output primitive descriptor.
  1555. /// @param engine Engine to use.
  1556. /// @param alg_kind Convolution algorithm. Possible values are
  1557. /// #dnnl_convolution_direct, #dnnl_convolution_winograd, and
  1558. /// #dnnl_convolution_auto.
  1559. /// @param src_desc Source memory descriptor.
  1560. /// @param diff_weights_desc Diff weights memory descriptor.
  1561. /// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
  1562. /// memory descriptor, or a memory descriptor with format_kind set to
  1563. /// #dnnl_format_kind_undef disables the bias term.
  1564. /// @param diff_dst_desc Diff destination memory descriptor.
  1565. /// @param strides Array of strides for each spatial dimension.
  1566. /// @param dilates Array of dilations for each spatial dimension. A zero
  1567. /// value means no dilation in the corresponding dimension.
  1568. /// @param padding_l Array of padding values for low indices for each spatial
  1569. /// dimension `([[front,] top,] left)`.
  1570. /// @param padding_r Array of padding values for high indices for each spatial
  1571. /// dimension `([[back,] bottom,] right)`. Can be NULL, in which case
  1572. /// padding is considered to be symmetrical.
  1573. /// @param hint_fwd_pd Primitive descriptor for the corresponding forward
  1574. /// propagation primitive.
  1575. /// @param attr Primitive attributes (can be NULL).
  1576. /// @returns #dnnl_success on success and a status describing the error
  1577. /// otherwise.
  1578. dnnl_status_t DNNL_API dnnl_convolution_backward_weights_primitive_desc_create(
  1579. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1580. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src_desc,
  1581. const_dnnl_memory_desc_t diff_weights_desc,
  1582. const_dnnl_memory_desc_t diff_bias_desc,
  1583. const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
  1584. const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
  1585. const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd,
  1586. const_dnnl_primitive_attr_t attr);
  1587. /// @} dnnl_api_convolution
  1588. /// @addtogroup dnnl_api_deconvolution
  1589. /// @{
  1590. /// Creates a primitive descriptor for a deconvolution forward propagation
  1591. /// primitive.
  1592. ///
  1593. /// @note
  1594. /// Memory descriptors can be initialized with
  1595. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  1596. ///
  1597. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
  1598. /// values for spatial dimensions only and hence must have the same number of
  1599. /// elements as there are spatial dimensions. The order of values is the same
  1600. /// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
  1601. /// and width.
  1602. ///
  1603. /// @param primitive_desc Output primitive descriptor.
  1604. /// @param engine Engine to use.
  1605. /// @param prop_kind Propagation kind. Possible values are
  1606. /// #dnnl_forward_training and #dnnl_forward_inference.
  1607. /// @param alg_kind Deconvolution algorithm. Possible values are
  1608. /// #dnnl_deconvolution_direct and #dnnl_deconvolution_winograd.
  1609. /// @param src_desc Source memory descriptor.
  1610. /// @param weights_desc Weights memory descriptor.
  1611. /// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
  1612. /// descriptor, or a memory descriptor with format_kind set to
  1613. /// #dnnl_format_kind_undef disables the bias term.
  1614. /// @param dst_desc Destination memory descriptor.
  1615. /// @param strides Array of strides for each spatial dimension.
  1616. /// @param dilates Array of dilations for each spatial dimension. A zero
  1617. /// value means no dilation in the corresponding dimension.
  1618. /// @param padding_l Array of padding values for low indices for each spatial
  1619. /// dimension `([[front,] top,] left)`.
  1620. /// @param padding_r Array of padding values for high indices for each spatial
  1621. /// dimension `([[back,] bottom,] right)`. Can be NULL, in which case
  1622. /// padding is considered to be symmetrical.
  1623. /// @param attr Primitive attributes (can be NULL).
  1624. /// @returns #dnnl_success on success and a status describing the error
  1625. /// otherwise.
  1626. dnnl_status_t DNNL_API dnnl_deconvolution_forward_primitive_desc_create(
  1627. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1628. dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
  1629. const_dnnl_memory_desc_t src_desc,
  1630. const_dnnl_memory_desc_t weights_desc,
  1631. const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
  1632. const dnnl_dims_t strides, const dnnl_dims_t dilates,
  1633. const dnnl_dims_t padding_l, const dnnl_dims_t padding_r,
  1634. const_dnnl_primitive_attr_t attr);
  1635. /// Creates a primitive descriptor for a deconvolution backward propagation
  1636. /// primitive.
  1637. ///
  1638. /// @note
  1639. /// Memory descriptors can be initialized with
  1640. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  1641. ///
  1642. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
  1643. /// values for spatial dimensions only and hence must have the same number of
  1644. /// elements as there are spatial dimensions. The order of values is the same
  1645. /// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
  1646. /// and width.
  1647. ///
  1648. /// @param primitive_desc Output primitive descriptor.
  1649. /// @param engine Engine to use.
  1650. /// @param alg_kind Deconvolution algorithm. Possible values are
  1651. /// #dnnl_deconvolution_direct and #dnnl_deconvolution_winograd.
  1652. /// @param diff_src_desc Diff source memory descriptor.
  1653. /// @param weights_desc Weights memory descriptor.
  1654. /// @param diff_dst_desc Diff destination memory descriptor.
  1655. /// @param strides Array of strides for each spatial dimension.
  1656. /// @param dilates Array of dilations for each spatial dimension. A zero
  1657. /// value means no dilation in the corresponding dimension.
  1658. /// @param padding_l Array of padding values for low indices for each spatial
  1659. /// dimension `([[front,] top,] left)`.
  1660. /// @param padding_r Array of padding values for high indices for each spatial
  1661. /// dimension `([[back,] bottom,] right)`. Can be NULL, in which case
  1662. /// padding is considered to be symmetrical.
  1663. /// @param hint_fwd_pd Primitive descriptor for the corresponding forward
  1664. /// propagation primitive.
  1665. /// @param attr Primitive attributes (can be NULL).
  1666. /// @returns #dnnl_success on success and a status describing the error
  1667. /// otherwise.
  1668. dnnl_status_t DNNL_API dnnl_deconvolution_backward_data_primitive_desc_create(
  1669. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1670. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
  1671. const_dnnl_memory_desc_t weights_desc,
  1672. const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
  1673. const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
  1674. const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd,
  1675. const_dnnl_primitive_attr_t attr);
  1676. /// Creates a primitive descriptor for a deconvolution weights gradient
  1677. /// primitive.
  1678. ///
  1679. /// @note
  1680. /// Memory descriptors can be initialized with
  1681. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  1682. ///
  1683. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
  1684. /// values for spatial dimensions only and hence must have the same number of
  1685. /// elements as there are spatial dimensions. The order of values is the same
  1686. /// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
  1687. /// and width.
  1688. ///
  1689. /// @param primitive_desc Output primitive descriptor.
  1690. /// @param engine Engine to use.
  1691. /// @param alg_kind Deconvolution algorithm. Possible values are
  1692. /// #dnnl_deconvolution_direct and #dnnl_deconvolution_winograd.
  1693. /// @param src_desc Source memory descriptor.
  1694. /// @param diff_weights_desc Diff weights memory descriptor.
  1695. /// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
  1696. /// memory descriptor, or a memory descriptor with format_kind set to
  1697. /// #dnnl_format_kind_undef disables the bias term.
  1698. /// @param diff_dst_desc Diff destination memory descriptor.
  1699. /// @param strides Array of strides for each spatial dimension.
  1700. /// @param dilates Array of dilations for each spatial dimension. A zero
  1701. /// value means no dilation in the corresponding dimension.
  1702. /// @param padding_l Array of padding values for low indices for each spatial
  1703. /// dimension `([[front,] top,] left)`.
  1704. /// @param padding_r Array of padding values for high indices for each spatial
  1705. /// dimension `([[back,] bottom,] right)`. Can be NULL, in which case
  1706. /// padding is considered to be symmetrical.
  1707. /// @param hint_fwd_pd Primitive descriptor for the corresponding forward
  1708. /// propagation primitive.
  1709. /// @param attr Primitive attributes (can be NULL).
  1710. /// @returns #dnnl_success on success and a status describing the error
  1711. /// otherwise.
  1712. dnnl_status_t DNNL_API
  1713. dnnl_deconvolution_backward_weights_primitive_desc_create(
  1714. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1715. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src_desc,
  1716. const_dnnl_memory_desc_t diff_weights_desc,
  1717. const_dnnl_memory_desc_t diff_bias_desc,
  1718. const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
  1719. const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
  1720. const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd,
  1721. const_dnnl_primitive_attr_t attr);
  1722. /// @} dnnl_api_deconvolution
  1723. /// @addtogroup dnnl_api_shuffle
  1724. /// @{
  1725. /// Creates a primitive descriptor for a shuffle forward propagation primitive.
  1726. ///
  1727. /// @param primitive_desc Output primitive descriptor.
  1728. /// @param engine Engine to use.
  1729. /// @param prop_kind Propagation kind. Possible values are
  1730. /// #dnnl_forward_training and #dnnl_forward_inference.
  1731. /// @param src_desc Source memory descriptor.
  1732. /// @param dst_desc Destination memory descriptor.
  1733. /// @param axis The axis along which the data is shuffled.
  1734. /// @param group_size Shuffle group size.
  1735. /// @param attr Primitive attributes (can be NULL).
  1736. /// @returns #dnnl_success on success and a status describing the error
  1737. /// otherwise.
  1738. dnnl_status_t DNNL_API dnnl_shuffle_forward_primitive_desc_create(
  1739. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1740. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
  1741. const_dnnl_memory_desc_t dst_desc, int axis, dnnl_dim_t group_size,
  1742. const_dnnl_primitive_attr_t attr);
  1743. /// Creates a primitive descriptor for a shuffle backward propagation primitive.
  1744. ///
  1745. /// @param primitive_desc Output primitive descriptor.
  1746. /// @param engine Engine to use.
  1747. /// @param diff_src_desc Diff source memory descriptor.
  1748. /// @param diff_dst_desc Diff destination memory descriptor.
  1749. /// @param axis The axis along which the data is shuffled.
  1750. /// @param group_size Shuffle group size.
  1751. /// @param hint_fwd_pd Primitive descriptor for the corresponding forward
  1752. /// propagation primitive.
  1753. /// @param attr Primitive attributes (can be NULL).
  1754. /// @returns #dnnl_success on success and a status describing the error
  1755. /// otherwise.
  1756. dnnl_status_t DNNL_API dnnl_shuffle_backward_primitive_desc_create(
  1757. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1758. const_dnnl_memory_desc_t diff_src_desc,
  1759. const_dnnl_memory_desc_t diff_dst_desc, int axis, dnnl_dim_t group_size,
  1760. const_dnnl_primitive_desc_t hint_fwd_pd,
  1761. const_dnnl_primitive_attr_t attr);
  1762. /// @} dnnl_api_shuffle
  1763. /// @addtogroup dnnl_api_eltwise
  1764. /// @{
  1765. /// Creates a primitive descriptor for an eltwise forward propagation primitive.
  1766. ///
  1767. /// @param primitive_desc Output primitive descriptor.
  1768. /// @param engine Engine to use.
  1769. /// @param prop_kind Propagation kind. Possible values are
  1770. /// #dnnl_forward_training and #dnnl_forward_inference.
  1771. /// @param alg_kind Elementwise algorithm kind.
  1772. /// @param src_desc Source memory descriptor.
  1773. /// @param dst_desc Destination memory descriptor.
  1774. /// @param alpha The alpha parameter for the elementwise operation. Its
  1775. /// specific meaning depends on the algorithm.
  1776. /// @param beta The beta parameter for the elementwise operation. Its
  1777. /// specific meaning depends on the algorithm.
  1778. /// @param attr Primitive attributes (can be NULL).
  1779. /// @returns #dnnl_success on success and a status describing the error
  1780. /// otherwise.
  1781. dnnl_status_t DNNL_API dnnl_eltwise_forward_primitive_desc_create(
  1782. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1783. dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
  1784. const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc,
  1785. float alpha, float beta, const_dnnl_primitive_attr_t attr);
  1786. /// Creates a primitive descriptor for an eltwise backward propagation
  1787. /// primitive.
  1788. ///
  1789. /// @param primitive_desc Output primitive descriptor.
  1790. /// @param engine Engine to use.
  1791. /// @param alg_kind Elementwise algorithm kind.
  1792. /// @param diff_src_desc Diff source memory descriptor.
  1793. /// @param diff_dst_desc Diff destination memory descriptor.
  1794. /// @param data_desc Destination memory descriptor if one of the
  1795. /// "use_dst_for_bwd" algorithms is used (such as
  1796. /// #dnnl_eltwise_relu_use_dst_for_bwd); source memory descriptor otherwise.
  1797. /// @param alpha The alpha parameter for the elementwise operation. Its
  1798. /// specific meaning depends on the algorithm.
  1799. /// @param beta The beta parameter for the elementwise operation. Its
  1800. /// specific meaning depends on the algorithm.
  1801. /// @param hint_fwd_pd Primitive descriptor for the corresponding forward
  1802. /// propagation primitive.
  1803. /// @param attr Primitive attributes (can be NULL).
  1804. /// @returns #dnnl_success on success and a status describing the error
  1805. /// otherwise.
  1806. dnnl_status_t DNNL_API dnnl_eltwise_backward_primitive_desc_create(
  1807. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1808. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
  1809. const_dnnl_memory_desc_t diff_dst_desc,
  1810. const_dnnl_memory_desc_t data_desc, float alpha, float beta,
  1811. const_dnnl_primitive_desc_t hint_fwd_pd,
  1812. const_dnnl_primitive_attr_t attr);
  1813. /// @} dnnl_api_eltwise
  1814. /// @addtogroup dnnl_api_softmax
  1815. /// @{
  1816. /// Creates a primitive descriptor for a softmax forward propagation primitive.
  1817. ///
  1818. /// @param primitive_desc Output primitive descriptor.
  1819. /// @param engine Engine to use.
  1820. /// @param prop_kind Propagation kind. Possible values are
  1821. /// #dnnl_forward_training and #dnnl_forward_inference.
  1822. /// @param alg_kind Softmax algorithm kind: either #dnnl_softmax_accurate or
  1823. /// #dnnl_softmax_log.
  1824. /// @param src_desc Source memory descriptor.
  1825. /// @param dst_desc Destination memory descriptor.
  1826. /// @param softmax_axis Axis over which softmax is computed.
  1827. /// @param attr Primitive attributes (can be NULL).
  1828. /// @returns #dnnl_success on success and a status describing the error
  1829. /// otherwise.
  1830. dnnl_status_t DNNL_API dnnl_softmax_forward_primitive_desc_create(
  1831. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1832. dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
  1833. const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc,
  1834. int softmax_axis, const_dnnl_primitive_attr_t attr);
  1835. /// Creates a primitive descriptor for a softmax backward propagation primitive.
  1836. ///
  1837. /// @param primitive_desc Output primitive descriptor.
  1838. /// @param engine Engine to use.
  1839. /// @param alg_kind Softmax algorithm kind: either #dnnl_softmax_accurate or
  1840. /// #dnnl_softmax_log.
  1841. /// @param diff_src_desc Diff source memory descriptor.
  1842. /// @param diff_dst_desc Diff destination memory descriptor.
  1843. /// @param dst_desc Destination memory descriptor.
  1844. /// @param softmax_axis Axis over which softmax is computed.
  1845. /// @param hint_fwd_pd Primitive descriptor for the corresponding forward
  1846. /// propagation primitive.
  1847. /// @param attr Primitive attributes (can be NULL).
  1848. /// @returns #dnnl_success on success and a status describing the error
  1849. /// otherwise.
  1850. dnnl_status_t DNNL_API dnnl_softmax_backward_primitive_desc_create(
  1851. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1852. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
  1853. const_dnnl_memory_desc_t diff_dst_desc,
  1854. const_dnnl_memory_desc_t dst_desc, int softmax_axis,
  1855. const_dnnl_primitive_desc_t hint_fwd_pd,
  1856. const_dnnl_primitive_attr_t attr);
  1857. /// @} dnnl_api_softmax
  1858. /// @addtogroup dnnl_api_pooling
  1859. /// @{
  1860. /// Creates a primitive descriptor for a pooling forward propagation
  1861. /// primitive.
  1862. ///
  1863. /// Arrays @p strides, @p kernel, @p dilation, @p padding_l, and @p padding_r
  1864. /// contain values for spatial dimensions only and hence must have the same
  1865. /// number of elements as there are spatial dimensions. The order of values
  1866. /// is the same as in the tensor: depth (for 3D tensors),
  1867. /// height (for 3D and 2D tensors), and width.
  1868. ///
  1869. /// @param primitive_desc Output primitive descriptor.
  1870. /// @param engine Engine to use.
  1871. /// @param prop_kind Propagation kind. Possible values are
  1872. /// #dnnl_forward_training and #dnnl_forward_inference.
  1873. /// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max,
  1874. /// #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg_exclude_padding.
  1875. /// @param src_desc Source memory descriptor.
  1876. /// @param dst_desc Destination memory descriptor.
  1877. /// @param strides Array of strides for each spatial dimension.
  1878. /// @param kernel Array of kernel spatial dimensions.
  1879. /// @param dilation Array of dilations for each spatial dimension.
  1880. /// @param padding_l Array of padding values for low indices for each spatial
  1881. /// dimension `([[front,] top,] left)`.
  1882. /// @param padding_r Array of padding values for high indices for each spatial
  1883. /// dimension `([[back,] bottom,] right)`. Can be NULL, in which case
  1884. /// padding is considered to be symmetrical.
  1885. /// @param attr Primitive attributes (can be NULL).
  1886. /// @returns #dnnl_success on success and a status describing the error
  1887. /// otherwise.
  1888. dnnl_status_t DNNL_API dnnl_pooling_forward_primitive_desc_create(
  1889. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1890. dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
  1891. const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc,
  1892. const dnnl_dims_t strides, const dnnl_dims_t kernel,
  1893. const dnnl_dims_t dilation, const dnnl_dims_t padding_l,
  1894. const dnnl_dims_t padding_r, const_dnnl_primitive_attr_t attr);
  1895. /// Creates a primitive descriptor for a pooling backward propagation
  1896. /// primitive.
  1897. ///
  1898. /// Arrays @p strides, @p kernel, @p dilation, @p padding_l, and @p padding_r
  1899. /// contain values for spatial dimensions only and hence must have the same
  1900. /// number of elements as there are spatial dimensions. The order of values
  1901. /// is the same as in the tensor: depth (for 3D tensors),
  1902. /// height (for 3D and 2D tensors), and width.
  1903. ///
  1904. /// @param primitive_desc Output primitive descriptor.
  1905. /// @param engine Engine to use.
  1906. /// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max,
  1907. /// #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg_exclude_padding.
  1908. /// @param diff_src_desc Diff source memory descriptor.
  1909. /// @param diff_dst_desc Diff destination memory descriptor.
  1910. /// @param strides Array of strides for each spatial dimension.
  1911. /// @param kernel Array of kernel spatial dimensions.
  1912. /// @param dilation Array of dilations for each spatial dimension.
  1913. /// @param padding_l Array of padding values for low indices for each spatial
  1914. /// dimension `([[front,] top,] left)`.
  1915. /// @param padding_r Array of padding values for high indices for each spatial
  1916. /// dimension `([[back,] bottom,] right)`. Can be NULL, in which case
  1917. /// padding is considered to be symmetrical.
  1918. /// @param hint_fwd_pd Primitive descriptor for the corresponding forward
  1919. /// propagation primitive.
  1920. /// @param attr Primitive attributes (can be NULL).
  1921. /// @returns #dnnl_success on success and a status describing the error
  1922. /// otherwise.
  1923. dnnl_status_t DNNL_API dnnl_pooling_backward_primitive_desc_create(
  1924. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1925. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
  1926. const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
  1927. const dnnl_dims_t kernel, const dnnl_dims_t dilation,
  1928. const dnnl_dims_t padding_l, const dnnl_dims_t padding_r,
  1929. const_dnnl_primitive_desc_t hint_fwd_pd,
  1930. const_dnnl_primitive_attr_t attr);
  1931. /// @} dnnl_api_pooling
  1932. /// @addtogroup dnnl_api_prelu
  1933. /// @{
  1934. /// Creates a primitive descriptor for a PReLU (leaky ReLU with trainable
  1935. /// alpha parameter) forward propagation primitive.
  1936. ///
  1937. /// @note
  1938. /// The weights descriptor is allowed to be initialized with
  1939. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  1940. ///
  1941. /// @param primitive_desc Output primitive descriptor.
  1942. /// @param engine Engine to use.
  1943. /// @param prop_kind Propagation kind. Possible values are
  1944. /// #dnnl_forward_training and #dnnl_forward_inference.
  1945. /// @param src_desc Source memory descriptor.
  1946. /// @param weights_desc Alpha parameters memory descriptor.
  1947. /// @param dst_desc Destination memory descriptor.
  1948. /// @param attr Primitive attributes (can be NULL).
  1949. /// @returns #dnnl_success on success and a status describing the error
  1950. /// otherwise.
  1951. dnnl_status_t DNNL_API dnnl_prelu_forward_primitive_desc_create(
  1952. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1953. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
  1954. const_dnnl_memory_desc_t weights_desc,
  1955. const_dnnl_memory_desc_t dst_desc, const_dnnl_primitive_attr_t attr);
  1956. /// Creates a primitive descriptor for a PReLU (leaky ReLU with trainable
  1957. /// alpha parameter) backward propagation primitive.
  1958. ///
  1959. /// @note
  1960. /// The weights and diff_weights descriptors are allowed
  1961. /// to be initialized with #dnnl_format_tag_any or with format_kind
  1962. /// set to #dnnl_format_kind_any.
  1963. ///
  1964. /// @param primitive_desc Output primitive descriptor.
  1965. /// @param engine Engine to use.
  1966. /// @param src_desc Source memory descriptor.
  1967. /// @param weights_desc Alpha parameters memory descriptor.
  1968. /// @param diff_src_desc Diff source memory descriptor.
  1969. /// @param diff_weights_desc Diff alpha parameters memory descriptor.
  1970. /// @param diff_dst_desc Diff destination memory descriptor.
  1971. /// @param hint_fwd_pd Primitive descriptor for the corresponding forward
  1972. /// propagation primitive.
  1973. /// @param attr Primitive attributes (can be NULL).
  1974. /// @returns #dnnl_success on success and a status describing the error
  1975. /// otherwise.
  1976. dnnl_status_t DNNL_API dnnl_prelu_backward_primitive_desc_create(
  1977. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1978. const_dnnl_memory_desc_t src_desc,
  1979. const_dnnl_memory_desc_t weights_desc,
  1980. const_dnnl_memory_desc_t diff_src_desc,
  1981. const_dnnl_memory_desc_t diff_weights_desc,
  1982. const_dnnl_memory_desc_t diff_dst_desc,
  1983. const_dnnl_primitive_desc_t hint_fwd_pd,
  1984. const_dnnl_primitive_attr_t attr);
  1985. /// @} dnnl_api_prelu
  1986. /// @addtogroup dnnl_api_lrn
  1987. /// @{
  1988. /// Creates a primitive descriptor for an LRN forward propagation primitive.
  1989. ///
  1990. /// @param primitive_desc Output primitive descriptor.
  1991. /// @param engine Engine to use.
  1992. /// @param prop_kind Propagation kind. Possible values are
  1993. /// #dnnl_forward_training and #dnnl_forward_inference.
  1994. /// @param alg_kind LRN algorithm kind: either #dnnl_lrn_across_channels or
  1995. /// #dnnl_lrn_within_channel.
  1996. /// @param src_desc Source memory descriptor.
  1997. /// @param dst_desc Destination memory descriptor.
  1998. /// @param local_size Regularization local size.
  1999. /// @param alpha The alpha regularization parameter.
  2000. /// @param beta The beta regularization parameter.
  2001. /// @param k The k regularization parameter.
  2002. /// @param attr Primitive attributes (can be NULL).
  2003. /// @returns #dnnl_success on success and a status describing the error
  2004. /// otherwise.
  2005. dnnl_status_t DNNL_API dnnl_lrn_forward_primitive_desc_create(
  2006. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2007. dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
  2008. const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc,
  2009. dnnl_dim_t local_size, float alpha, float beta, float k,
  2010. const_dnnl_primitive_attr_t attr);
  2011. /// Creates a primitive descriptor for an LRN backward propagation primitive.
  2012. ///
  2013. /// @param primitive_desc Output primitive descriptor.
  2014. /// @param engine Engine to use.
  2015. /// @param alg_kind LRN algorithm kind: either #dnnl_lrn_across_channels or
  2016. /// #dnnl_lrn_within_channel.
  2017. /// @param diff_src_desc Diff source memory descriptor.
  2018. /// @param diff_dst_desc Diff destination memory descriptor.
  2019. /// @param src_desc Source memory descriptor.
  2020. /// @param local_size Regularization local size.
  2021. /// @param alpha The alpha regularization parameter.
  2022. /// @param beta The beta regularization parameter.
  2023. /// @param k The k regularization parameter.
  2024. /// @param hint_fwd_pd Primitive descriptor for the corresponding forward
  2025. /// propagation primitive.
  2026. /// @param attr Primitive attributes (can be NULL).
  2027. /// @returns #dnnl_success on success and a status describing the error
  2028. /// otherwise.
  2029. dnnl_status_t DNNL_API dnnl_lrn_backward_primitive_desc_create(
  2030. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2031. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
  2032. const_dnnl_memory_desc_t diff_dst_desc,
  2033. const_dnnl_memory_desc_t src_desc, dnnl_dim_t local_size, float alpha,
  2034. float beta, float k, const_dnnl_primitive_desc_t hint_fwd_pd,
  2035. const_dnnl_primitive_attr_t attr);
  2036. /// @} dnnl_api_lrn
  2037. /// @addtogroup dnnl_api_batch_normalization
  2038. /// @{
  2039. /// Creates a primitive descriptor for a batch normalization forward propagation
  2040. /// primitive.
  2041. ///
  2042. /// @note
  2043. /// In-place operation is supported: the dst can refer to the same memory
  2044. /// as the src.
  2045. ///
  2046. /// @param primitive_desc Output primitive descriptor.
  2047. /// @param engine Engine to use.
  2048. /// @param prop_kind Propagation kind. Possible values are
  2049. /// #dnnl_forward_training and #dnnl_forward_inference.
  2050. /// @param src_desc Source memory descriptor.
  2051. /// @param dst_desc Destination memory descriptor.
  2052. /// @param epsilon Batch normalization epsilon parameter.
  2053. /// @param flags Batch normalization flags (@ref dnnl_normalization_flags_t).
  2054. /// @param attr Primitive attributes (can be NULL).
  2055. /// @returns #dnnl_success on success and a status describing the error
  2056. /// otherwise.
  2057. dnnl_status_t DNNL_API dnnl_batch_normalization_forward_primitive_desc_create(
  2058. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2059. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
  2060. const_dnnl_memory_desc_t dst_desc, float epsilon, unsigned flags,
  2061. const_dnnl_primitive_attr_t attr);
  2062. /// Creates a primitive descriptor for a batch normalization backward
  2063. /// propagation primitive.
  2064. ///
  2065. /// @note
  2066. /// In-place operation is supported: the diff_dst can refer to the same
  2067. /// memory as the diff_src.
  2068. ///
  2069. /// @param primitive_desc Output primitive descriptor.
  2070. /// @param engine Engine to use.
  2071. /// @param prop_kind Propagation kind. Possible values are
  2072. /// #dnnl_backward_data and #dnnl_backward (diffs for all parameters are
  2073. /// computed in this case).
  2074. /// @param diff_src_desc Diff source memory descriptor.
  2075. /// @param diff_dst_desc Diff destination memory descriptor.
  2076. /// @param src_desc Source memory descriptor.
  2077. /// @param epsilon Batch normalization epsilon parameter.
  2078. /// @param flags Batch normalization flags (@ref dnnl_normalization_flags_t).
  2079. /// @param hint_fwd_pd Primitive descriptor for the corresponding forward
  2080. /// propagation primitive.
  2081. /// @param attr Primitive attributes (can be NULL).
  2082. /// @returns #dnnl_success on success and a status describing the error
  2083. /// otherwise.
  2084. dnnl_status_t DNNL_API dnnl_batch_normalization_backward_primitive_desc_create(
  2085. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2086. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc,
  2087. const_dnnl_memory_desc_t diff_dst_desc,
  2088. const_dnnl_memory_desc_t src_desc, float epsilon, unsigned flags,
  2089. const_dnnl_primitive_desc_t hint_fwd_pd,
  2090. const_dnnl_primitive_attr_t attr);
  2091. /// @} dnnl_api_batch_normalization
  2092. /// @addtogroup dnnl_api_group_normalization
  2093. /// @{
  2094. /// Creates a primitive descriptor for a group normalization forward propagation
  2095. /// primitive.
  2096. ///
  2097. /// @note
  2098. /// In-place operation is supported: the dst can refer to the same memory
  2099. /// as the src.
  2100. ///
  2101. /// @param primitive_desc Output primitive descriptor.
  2102. /// @param engine Engine to use.
  2103. /// @param prop_kind Propagation kind. Possible values are
  2104. /// #dnnl_forward_training and #dnnl_forward_inference.
  2105. /// @param src_desc Source memory descriptor.
  2106. /// @param dst_desc Destination memory descriptor.
  2107. /// @param groups Group normalization groups parameter.
  2108. /// @param epsilon Group normalization epsilon parameter.
  2109. /// @param flags Group normalization flags (@ref dnnl_normalization_flags_t).
  2110. /// @param attr Primitive attributes (can be NULL).
  2111. /// @returns #dnnl_success on success and a status describing the error
  2112. /// otherwise.
  2113. dnnl_status_t DNNL_API dnnl_group_normalization_forward_primitive_desc_create(
  2114. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2115. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
  2116. const_dnnl_memory_desc_t dst_desc, dnnl_dim_t groups, float epsilon,
  2117. unsigned flags, const_dnnl_primitive_attr_t attr);
  2118. /// Creates a primitive descriptor for a group normalization backward
  2119. /// propagation primitive.
  2120. ///
  2121. /// @note
  2122. /// In-place operation is supported: the diff_dst can refer to the same
  2123. /// memory as the diff_src.
  2124. ///
  2125. /// @param primitive_desc Output primitive_descriptor.
  2126. /// @param engine Engine to use.
  2127. /// @param prop_kind Propagation kind. Possible values are
  2128. /// #dnnl_backward_data and #dnnl_backward (diffs for all parameters are
  2129. /// computed in this case).
  2130. /// @param diff_src_desc Diff source memory descriptor.
  2131. /// @param diff_dst_desc Diff destination memory descriptor.
  2132. /// @param src_desc Source memory descriptor.
  2133. /// @param groups Group normalization groups parameter.
  2134. /// @param epsilon Group normalization epsilon parameter.
  2135. /// @param flags Group normalization flags (@ref dnnl_normalization_flags_t).
  2136. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  2137. /// primitive.
  2138. /// @param attr Primitive attributes (can be NULL).
  2139. /// @returns #dnnl_success on success and a status describing the error
  2140. /// otherwise.
  2141. dnnl_status_t DNNL_API dnnl_group_normalization_backward_primitive_desc_create(
  2142. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2143. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc,
  2144. const_dnnl_memory_desc_t diff_dst_desc,
  2145. const_dnnl_memory_desc_t src_desc, dnnl_dim_t groups, float epsilon,
  2146. unsigned flags, const_dnnl_primitive_desc_t hint_fwd_pd,
  2147. const_dnnl_primitive_attr_t attr);
  2148. /// @} dnnl_api_group_normalization
  2149. /// @addtogroup dnnl_api_layer_normalization
  2150. /// @{
  2151. /// Creates a primitive descriptor for a layer normalization forward propagation
  2152. /// primitive.
  2153. ///
  2154. /// @note
  2155. /// In-place operation is supported: the dst can refer to the same memory
  2156. /// as the src.
  2157. ///
  2158. /// @param primitive_desc Output primitive_descriptor.
  2159. /// @param engine Engine to use.
  2160. /// @param prop_kind Propagation kind. Possible values are
  2161. /// #dnnl_forward_training and #dnnl_forward_inference.
  2162. /// @param src_desc Source memory descriptor.
  2163. /// @param dst_desc Destination memory descriptor.
  2164. /// @param stat_desc Memory descriptor for mean and variance. If this
  2165. /// parameter is NULL, a zero memory descriptor, or a memory descriptor
  2166. /// with format_kind set to #dnnl_format_kind_undef, then the memory
  2167. /// descriptor for stats is derived from @p src_desc by removing the last
  2168. /// dimension.
  2169. /// @param epsilon Layer normalization epsilon parameter.
  2170. /// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
  2171. /// @param attr Primitive attributes (can be NULL).
  2172. /// @returns #dnnl_success on success and a status describing the error
  2173. /// otherwise.
  2174. dnnl_status_t DNNL_API dnnl_layer_normalization_forward_primitive_desc_create(
  2175. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2176. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
  2177. const_dnnl_memory_desc_t dst_desc, const_dnnl_memory_desc_t stat_desc,
  2178. float epsilon, unsigned flags, const_dnnl_primitive_attr_t attr);
  2179. /// Creates a primitive descriptor for a layer normalization backward
  2180. /// propagation primitive.
  2181. ///
  2182. /// @note
  2183. /// In-place operation is supported: the diff_dst can refer to the same
  2184. /// memory as the diff_src.
  2185. ///
  2186. /// @param primitive_desc Output primitive_descriptor.
  2187. /// @param engine Engine to use.
  2188. /// @param prop_kind Propagation kind. Possible values are
  2189. /// #dnnl_backward_data and #dnnl_backward (diffs for all parameters are
  2190. /// computed in this case).
  2191. /// @param diff_src_desc Diff source memory descriptor.
  2192. /// @param diff_dst_desc Diff destination memory descriptor.
  2193. /// @param src_desc Source memory descriptor.
  2194. /// @param stat_desc Memory descriptor for mean and variance. If this
  2195. /// parameter is NULL, a zero memory descriptor, or a memory descriptor
  2196. /// with format_kind set to #dnnl_format_kind_undef, then the memory
  2197. /// descriptor for stats is derived from @p src_desc by removing the last
  2198. /// dimension.
  2199. /// @param epsilon Layer normalization epsilon parameter.
  2200. /// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
  2201. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  2202. /// primitive.
  2203. /// @param attr Primitive attributes (can be NULL).
  2204. /// @returns #dnnl_success on success and a status describing the error
  2205. /// otherwise.
  2206. dnnl_status_t DNNL_API dnnl_layer_normalization_backward_primitive_desc_create(
  2207. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2208. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc,
  2209. const_dnnl_memory_desc_t diff_dst_desc,
  2210. const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t stat_desc,
  2211. float epsilon, unsigned flags, const_dnnl_primitive_desc_t hint_fwd_pd,
  2212. const_dnnl_primitive_attr_t attr);
  2213. /// Creates a primitive descriptor for a layer normalization forward propagation
  2214. /// primitive with a user-provided data type for the scale and shift
  2215. /// memory objects.
  2216. ///
  2217. /// @note
  2218. /// In-place operation is supported: the dst can refer to the same memory
  2219. /// as the src.
  2220. ///
  2221. /// @param primitive_desc Output primitive_descriptor.
  2222. /// @param engine Engine to use.
  2223. /// @param prop_kind Propagation kind. Possible values are
  2224. /// #dnnl_forward_training and #dnnl_forward_inference.
  2225. /// @param src_desc Source memory descriptor.
  2226. /// @param dst_desc Destination memory descriptor.
  2227. /// @param stat_desc Memory descriptor for mean and variance. If this
  2228. /// parameter is NULL, a zero memory descriptor, or a memory descriptor
  2229. /// with format_kind set to #dnnl_format_kind_undef, then the memory
  2230. /// descriptor for stats is derived from @p src_desc by removing the last
  2231. /// dimension.
  2232. /// @param scale_shift_data_type Data type of scale and shift memory. If neither scale
  2233. /// nor shift flag are specified the parameter is ignored.
  2234. /// @param epsilon Layer normalization epsilon parameter.
  2235. /// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
  2236. /// @param attr Primitive attributes (can be NULL).
  2237. /// @returns #dnnl_success on success and a status describing the error
  2238. /// otherwise.
  2239. dnnl_status_t DNNL_API
  2240. dnnl_layer_normalization_forward_primitive_desc_create_v2(
  2241. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2242. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
  2243. const_dnnl_memory_desc_t dst_desc, const_dnnl_memory_desc_t stat_desc,
  2244. dnnl_data_type_t scale_shift_data_type, float epsilon, unsigned flags,
  2245. const_dnnl_primitive_attr_t attr);
  2246. /// Creates a primitive descriptor for a layer normalization backward
  2247. /// propagation primitive with a user-provided data type for the
  2248. /// scale and shift memory objects.
  2249. ///
  2250. /// @note
  2251. /// In-place operation is supported: the diff_dst can refer to the same
  2252. /// memory as the diff_src.
  2253. ///
  2254. /// @param primitive_desc Output primitive_descriptor.
  2255. /// @param engine Engine to use.
  2256. /// @param prop_kind Propagation kind. Possible values are
  2257. /// #dnnl_backward_data and #dnnl_backward (diffs for all parameters are
  2258. /// computed in this case).
  2259. /// @param diff_src_desc Diff source memory descriptor.
  2260. /// @param diff_dst_desc Diff destination memory descriptor.
  2261. /// @param src_desc Source memory descriptor.
  2262. /// @param stat_desc Memory descriptor for mean and variance. If this
  2263. /// parameter is NULL, a zero memory descriptor, or a memory descriptor
  2264. /// with format_kind set to #dnnl_format_kind_undef, then the memory
  2265. /// descriptor for stats is derived from @p src_desc by removing the last
  2266. /// dimension.
  2267. /// @param diff_scale_shift_data_type Data type of diff scale and shift memory. If neither scale
  2268. /// nor shift flag are specified the parameter is ignored.
  2269. /// @param scale_shift_data_type Data type of scale and shift memory. If neither scale
  2270. /// nor shift flag are specified the parameter is ignored.
  2271. /// @param epsilon Layer normalization epsilon parameter.
  2272. /// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
  2273. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  2274. /// primitive.
  2275. /// @param attr Primitive attributes (can be NULL).
  2276. /// @returns #dnnl_success on success and a status describing the error
  2277. /// otherwise.
  2278. dnnl_status_t DNNL_API
  2279. dnnl_layer_normalization_backward_primitive_desc_create_v2(
  2280. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2281. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc,
  2282. const_dnnl_memory_desc_t diff_dst_desc,
  2283. const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t stat_desc,
  2284. dnnl_data_type_t diff_scale_shift_data_type,
  2285. dnnl_data_type_t scale_shift_data_type, float epsilon, unsigned flags,
  2286. const_dnnl_primitive_desc_t hint_fwd_pd,
  2287. const_dnnl_primitive_attr_t attr);
  2288. /// @} dnnl_api_layer_normalization
  2289. /// @addtogroup dnnl_api_inner_product
  2290. /// @{
  2291. /// Creates a primitive descriptor for an inner product forward propagation
  2292. /// primitive.
  2293. ///
  2294. /// @note
  2295. /// Memory descriptors can be initialized with
  2296. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  2297. ///
  2298. /// @param primitive_desc Output primitive_descriptor.
  2299. /// @param engine Engine to use.
  2300. /// @param prop_kind Propagation kind. Possible values are
  2301. /// #dnnl_forward_training and #dnnl_forward_inference.
  2302. /// @param src_desc Source memory descriptor.
  2303. /// @param weights_desc Weights memory descriptor.
  2304. /// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
  2305. /// descriptor, or a memory descriptor with format_kind set to
  2306. /// #dnnl_format_kind_undef disables the bias term.
  2307. /// @param dst_desc Destination memory descriptor.
  2308. /// @param attr Primitive attributes (can be NULL).
  2309. /// @returns #dnnl_success on success and a status describing the error
  2310. /// otherwise.
  2311. dnnl_status_t DNNL_API dnnl_inner_product_forward_primitive_desc_create(
  2312. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2313. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
  2314. const_dnnl_memory_desc_t weights_desc,
  2315. const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
  2316. const_dnnl_primitive_attr_t attr);
  2317. /// Creates a primitive descriptor for an inner product backward propagation
  2318. /// primitive.
  2319. ///
  2320. /// @note
  2321. /// Memory descriptors can be initialized with
  2322. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  2323. ///
  2324. /// @param primitive_desc Output primitive_descriptor.
  2325. /// @param engine Engine to use.
  2326. /// @param diff_src_desc Diff source memory descriptor.
  2327. /// @param weights_desc Weights memory descriptor.
  2328. /// @param diff_dst_desc Diff destination memory descriptor.
  2329. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  2330. /// primitive.
  2331. /// @param attr Primitive attributes (can be NULL).
  2332. /// @returns #dnnl_success on success and a status describing the error
  2333. /// otherwise.
  2334. dnnl_status_t DNNL_API dnnl_inner_product_backward_data_primitive_desc_create(
  2335. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2336. const_dnnl_memory_desc_t diff_src_desc,
  2337. const_dnnl_memory_desc_t weights_desc,
  2338. const_dnnl_memory_desc_t diff_dst_desc,
  2339. const_dnnl_primitive_desc_t hint_fwd_pd,
  2340. const_dnnl_primitive_attr_t attr);
  2341. /// Creates a primitive descriptor for an inner product weights gradient
  2342. /// primitive.
  2343. ///
  2344. /// @note
  2345. /// Memory descriptors can be initialized with
  2346. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  2347. ///
  2348. /// @param primitive_desc Output primitive_descriptor.
  2349. /// @param engine Engine to use.
  2350. /// @param src_desc Source memory descriptor.
  2351. /// @param diff_weights_desc Diff weights memory descriptor.
  2352. /// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
  2353. /// memory descriptor, or a memory descriptor with format_kind set to
  2354. /// #dnnl_format_kind_undef disables the bias term.
  2355. /// @param diff_dst_desc Diff destination memory descriptor.
  2356. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  2357. /// primitive.
  2358. /// @param attr Primitive attributes (can be NULL).
  2359. /// @returns #dnnl_success on success and a status describing the error
  2360. /// otherwise.
  2361. dnnl_status_t DNNL_API
  2362. dnnl_inner_product_backward_weights_primitive_desc_create(
  2363. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2364. const_dnnl_memory_desc_t src_desc,
  2365. const_dnnl_memory_desc_t diff_weights_desc,
  2366. const_dnnl_memory_desc_t diff_bias_desc,
  2367. const_dnnl_memory_desc_t diff_dst_desc,
  2368. const_dnnl_primitive_desc_t hint_fwd_pd,
  2369. const_dnnl_primitive_attr_t attr);
  2370. /// @} dnnl_api_inner_product
  2371. /// @addtogroup dnnl_api_attributes
  2372. /// @{
  2373. /// Set quantization scale and shift parameters for RNN data tensors.
  2374. ///
  2375. /// For performance reasons, the low-precision configuration of the RNN
  2376. /// primitives expects input activations to have the unsigned 8-bit integer
  2377. /// data type. The scale and shift parameters are used to quantize
  2378. /// floating-point data to unsigned integer and must be passed to the RNN
  2379. /// primitive using attributes.
  2380. ///
  2381. /// The quantization formula is `scale * data + shift`.
  2382. ///
  2383. /// @note
  2384. /// Quantization scale and shift are common for src_layer, src_iter,
  2385. /// dst_iter, and dst_layer.
  2386. ///
  2387. /// Example usage:
  2388. /// @code
  2389. /// // RNN parameters
  2390. /// int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
  2391. /// // Activations quantization parameters
  2392. /// float scale = 63.f, shift = 64.f;
  2393. ///
  2394. /// dnnl_primitive_attr_t rnn_attr;
  2395. /// // Create default attributes
  2396. /// dnnl_primitive_attr_create(&rnn_attr);
  2397. ///
  2398. /// // Set scale and shift for int8 quantization of activation
  2399. /// dnnl_primitive_attr_set_rnn_data_qparams(rnn_attr, scale, shift);
  2400. ///
  2401. /// // Create an RNN primitive descriptor.
  2402. /// dnnl_primitive_desc_t rnn_pd;
  2403. /// dnnl_vanilla_rnn_forward_primitive_desc_create(&rnn_pd,
  2404. /// engine, /* arguments */, attr);
  2405. /// @endcode
  2406. ///
  2407. /// @param attr Primitive attributes.
  2408. /// @param scale The value to scale the data by.
  2409. /// @param shift The value to shift the data by.
  2410. /// @returns #dnnl_success on success and a status describing the error
  2411. /// otherwise.
  2412. dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_data_qparams(
  2413. dnnl_primitive_attr_t attr, const float scale, const float shift);
  2414. /// Returns the quantization scale and shift parameters for RNN data tensors.
  2415. ///
  2416. /// @note
  2417. /// Quantization scale and shift are common for src_layer, src_iter,
  2418. /// dst_iter, and dst_layer.
  2419. ///
  2420. /// @param attr Primitive attributes.
  2421. /// @param scale The value to scale the data by.
  2422. /// @param shift The value to shift the data by.
  2423. /// @returns #dnnl_success on success and a status describing the error
  2424. /// otherwise.
  2425. dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_data_qparams(
  2426. const_dnnl_primitive_attr_t attr, float *scale, float *shift);
  2427. /// Sets quantization scaling factors for RNN weights tensors. The
  2428. /// low-precision configuration of the RNN primitives expects input weights to
  2429. /// use the signed 8-bit integer data type. The scaling factors are used to
  2430. /// quantize floating-point data to signed integer and must be passed to RNN
  2431. /// primitives using attributes.
  2432. ///
  2433. /// @note
  2434. /// The dimension order is always native and does not depend on the actual
  2435. /// layout used. For example, five-dimensional weights always have (l, d,
  2436. /// i, g, o) logical dimension ordering.
  2437. ///
  2438. /// @note
  2439. /// Quantization scales are common for weights_layer and weights_iteration
  2440. ///
  2441. /// @param attr Primitive attributes.
  2442. /// @param count Number of elements in the @p scales array.
  2443. /// @param mask Scaling factors correspondence mask that defines the
  2444. /// correspondence between the output tensor dimensions and the @p
  2445. /// scales vector. The set i-th bit indicates that a dedicated scaling
  2446. /// factor should be used for each index along that dimension. Set the
  2447. /// mask to 0 to use a common scaling factor for the whole output
  2448. /// tensor.
  2449. /// @param scales Array of output scaling factors that must contain @p count
  2450. /// values and the following equality must hold:
  2451. /// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
  2452. /// Violations can only be detected when the attributes are used to create
  2453. /// a primitive descriptor.
  2454. /// @returns #dnnl_success on success and a status describing the error
  2455. /// otherwise.
  2456. dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_qparams(
  2457. dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask,
  2458. const float *scales);
  2459. /// Returns the quantization scaling factors for RNN weights tensors.
  2460. ///
  2461. /// @param attr Primitive attributes.
  2462. /// @param count Number of elements in the @p scales array.
  2463. /// @param mask Scaling factors correspondence mask that defines the
  2464. /// correspondence between the output tensor dimensions and the @p
  2465. /// scales vector. The set i-th bit indicates that a dedicated scaling
  2466. /// factor should be used for each index along that dimension. Set the
  2467. /// mask to 0 to use a common scaling factor for the whole output
  2468. /// tensor.
  2469. /// @param scales Array of output scaling factors that contain @p count
  2470. /// values and the following equality must hold:
  2471. /// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
  2472. /// @returns #dnnl_success on success and a status describing the error
  2473. /// otherwise.
  2474. dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_qparams(
  2475. const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask,
  2476. const float **scales);
  2477. /// Sets quantization scaling factors for RNN projection weights tensors. The
  2478. /// low-precision configuration of the RNN primitives expects input weights to
  2479. /// use the signed 8-bit integer data type. The scaling factors are used to
  2480. /// quantize floating-point data to signed integer and must be passed to RNN
  2481. /// primitives using attributes.
  2482. ///
  2483. /// @note
  2484. /// The dimension order is always native and does not depend on the actual
  2485. /// layout used. For example, five-dimensional weights always have (l, d,
  2486. /// i, g, o) logical dimension ordering.
  2487. ///
  2488. /// @param attr Primitive attributes.
  2489. /// @param count Number of elements in the @p scales array.
  2490. /// @param mask Scaling factors correspondence mask that defines the
  2491. /// correspondence between the output tensor dimensions and the @p
  2492. /// scales vector. The set i-th bit indicates that a dedicated scaling
  2493. /// factor should be used for each index along that dimension. Set the
  2494. /// mask to 0 to use a common scaling factor for the whole output
  2495. /// tensor.
  2496. /// @param scales Array of output scaling factors that must contain @p count
  2497. /// values and the following equality must hold:
  2498. /// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
  2499. /// Violations can only be detected when the attributes are used to create
  2500. /// a primitive descriptor.
  2501. /// @returns #dnnl_success on success and a status describing the error
  2502. /// otherwise.
  2503. dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_projection_qparams(
  2504. dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask,
  2505. const float *scales);
  2506. /// Returns the quantization scaling factors for RNN projection weights tensors.
  2507. ///
  2508. /// @param attr Primitive attributes.
  2509. /// @param count Number of elements in the @p scales array.
  2510. /// @param mask Scaling factors correspondence mask that defines the
  2511. /// correspondence between the output tensor dimensions and the @p
  2512. /// scales vector. The set i-th bit indicates that a dedicated scaling
  2513. /// factor should be used for each index along that dimension. Set the
  2514. /// mask to 0 to use a common scaling factor for the whole output
  2515. /// tensor.
  2516. /// @param scales Array of output scaling factors that contain @p count
  2517. /// values and the following equality must hold:
  2518. /// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
  2519. /// @returns #dnnl_success on success and a status describing the error
  2520. /// otherwise.
  2521. dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_projection_qparams(
  2522. const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask,
  2523. const float **scales);
  2524. /// @} dnnl_api_attributes
  2525. /// @addtogroup dnnl_api_rnn
  2526. /// @{
  2527. /// Creates a primitive descriptor for vanilla RNN forward propagation
  2528. /// primitive.
  2529. ///
  2530. /// The following arguments may either be @c NULL or point to a zero memory
  2531. /// descriptor:
  2532. /// - @p src_iter_desc,
  2533. /// - @p bias_desc,
  2534. /// - @p dst_iter_desc.
  2535. ///
  2536. /// This would then indicate that the RNN forward propagation primitive should
  2537. /// not use them and should default to zero values instead.
  2538. ///
  2539. /// @note
  2540. /// All memory descriptors can be initialized with
  2541. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  2542. ///
  2543. /// @param primitive_desc Output primitive descriptor.
  2544. /// @param engine Engine to use.
  2545. /// @param prop_kind Propagation kind. Possible values are
  2546. /// #dnnl_forward_training and #dnnl_forward_inference.
  2547. /// @param activation Activation kind. Possible values are #dnnl_eltwise_relu,
  2548. /// #dnnl_eltwise_tanh or #dnnl_eltwise_logistic.
  2549. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  2550. /// info.
  2551. /// @param src_layer_desc Memory descriptor for the input vector.
  2552. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  2553. /// state vector.
  2554. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  2555. /// layer input.
  2556. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  2557. /// recurrent input.
  2558. /// @param bias_desc Bias memory descriptor.
  2559. /// @param dst_layer_desc Memory descriptor for the output vector.
  2560. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  2561. /// state vector.
  2562. /// @param flags Unused.
  2563. /// @param alpha Negative slope if activation is #dnnl_eltwise_relu.
  2564. /// @param beta Unused.
  2565. /// @param attr Primitive attributes (can be NULL).
  2566. /// @returns #dnnl_success on success and a status describing the error
  2567. /// otherwise.
  2568. dnnl_status_t DNNL_API dnnl_vanilla_rnn_forward_primitive_desc_create(
  2569. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2570. dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation,
  2571. const dnnl_rnn_direction_t direction,
  2572. const_dnnl_memory_desc_t src_layer_desc,
  2573. const_dnnl_memory_desc_t src_iter_desc,
  2574. const_dnnl_memory_desc_t weights_layer_desc,
  2575. const_dnnl_memory_desc_t weights_iter_desc,
  2576. const_dnnl_memory_desc_t bias_desc,
  2577. const_dnnl_memory_desc_t dst_layer_desc,
  2578. const_dnnl_memory_desc_t dst_iter_desc, unsigned flags, float alpha,
  2579. float beta, const_dnnl_primitive_attr_t attr);
  2580. /// Creates a primitive descriptor for vanilla RNN backward propagation
  2581. /// primitive.
  2582. ///
  2583. /// The following arguments may either be @c NULL or point to a zero memory
  2584. /// descriptor:
  2585. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  2586. /// - @p bias_desc together with @p diff_bias_desc,
  2587. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  2588. ///
  2589. /// This would then indicate that the RNN backward propagation primitive should
  2590. /// not use the respective data and should use zero values instead.
  2591. ///
  2592. /// @note
  2593. /// All memory descriptors can be initialized with
  2594. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  2595. ///
  2596. /// @param primitive_desc Output primitive descriptor.
  2597. /// @param engine Engine to use.
  2598. /// @param prop_kind Propagation kind. Must be #dnnl_backward.
  2599. /// @param activation Activation kind. Possible values are #dnnl_eltwise_relu,
  2600. /// #dnnl_eltwise_tanh or #dnnl_eltwise_logistic.
  2601. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  2602. /// info.
  2603. /// @param src_layer_desc Memory descriptor for the input vector.
  2604. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  2605. /// state vector.
  2606. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  2607. /// layer input.
  2608. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  2609. /// recurrent input.
  2610. /// @param bias_desc Bias memory descriptor.
  2611. /// @param dst_layer_desc Memory descriptor for the output vector.
  2612. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  2613. /// state vector.
  2614. /// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
  2615. /// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
  2616. /// hidden state vector.
  2617. /// @param diff_weights_layer_desc Memory descriptor for the diff of weights
  2618. /// applied to the layer input.
  2619. /// @param diff_weights_iter_desc Memory descriptor for the diff of weights
  2620. /// applied to the recurrent input.
  2621. /// @param diff_bias_desc Diff bias memory descriptor.
  2622. /// @param diff_dst_layer_desc Memory descriptor for the diff of output
  2623. /// vector.
  2624. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  2625. /// recurrent hidden state vector.
  2626. /// @param flags Unused.
  2627. /// @param alpha Negative slope if activation is #dnnl_eltwise_relu.
  2628. /// @param beta Unused.
  2629. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  2630. /// primitive.
  2631. /// @param attr Primitive attributes (can be NULL).
  2632. /// @returns #dnnl_success on success and a status describing the error
  2633. /// otherwise.
  2634. dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_primitive_desc_create(
  2635. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2636. dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation,
  2637. const dnnl_rnn_direction_t direction,
  2638. const_dnnl_memory_desc_t src_layer_desc,
  2639. const_dnnl_memory_desc_t src_iter_desc,
  2640. const_dnnl_memory_desc_t weights_layer_desc,
  2641. const_dnnl_memory_desc_t weights_iter_desc,
  2642. const_dnnl_memory_desc_t bias_desc,
  2643. const_dnnl_memory_desc_t dst_layer_desc,
  2644. const_dnnl_memory_desc_t dst_iter_desc,
  2645. const_dnnl_memory_desc_t diff_src_layer_desc,
  2646. const_dnnl_memory_desc_t diff_src_iter_desc,
  2647. const_dnnl_memory_desc_t diff_weights_layer_desc,
  2648. const_dnnl_memory_desc_t diff_weights_iter_desc,
  2649. const_dnnl_memory_desc_t diff_bias_desc,
  2650. const_dnnl_memory_desc_t diff_dst_layer_desc,
  2651. const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
  2652. float alpha, float beta, const_dnnl_primitive_desc_t hint_fwd_pd,
  2653. const_dnnl_primitive_attr_t attr);
  2654. /// Creates a primitive descriptor for an LSTM forward propagation primitive.
  2655. ///
  2656. /// The following arguments may either be @c NULL or point to a zero memory
  2657. /// descriptor:
  2658. /// - @p src_iter_desc together with @p src_iter_c_desc,
  2659. /// - @p weights_peephole_desc,
  2660. /// - @p bias_desc,
  2661. /// - @p dst_iter_desc together with @p dst_iter_c_desc.
  2662. ///
  2663. /// This would then indicate that the LSTM forward propagation primitive should
  2664. /// not use them and should default to zero values instead.
  2665. ///
  2666. /// The @p weights_projection_desc could either be @c NULL or point to a zero
  2667. /// memory descriptor. This would then indicate that the LSTM doesn't have
  2668. /// recurrent projection layer.
  2669. ///
  2670. /// @note
  2671. /// All memory descriptors can be initialized with #dnnl_format_tag_any or
  2672. /// with format_kind set to #dnnl_format_kind_any.
  2673. ///
  2674. /// @param primitive_desc Output primitive descriptor.
  2675. /// @param engine Engine to use.
  2676. /// @param prop_kind Propagation kind. Possible values are
  2677. /// #dnnl_forward_training and #dnnl_forward_inference.
  2678. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  2679. /// info.
  2680. /// @param src_layer_desc Memory descriptor for the input vector.
  2681. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  2682. /// state vector.
  2683. /// @param src_iter_c_desc Memory descriptor for the input recurrent cell
  2684. /// state vector.
  2685. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  2686. /// layer input.
  2687. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  2688. /// recurrent input.
  2689. /// @param weights_peephole_desc Memory descriptor for the weights applied to
  2690. /// the cell states (according to the Peephole LSTM formula).
  2691. /// @param weights_projection_desc Memory descriptor for the weights applied to
  2692. /// the hidden states to get the recurrent projection (according to the
  2693. /// Projection LSTM formula).
  2694. /// @param bias_desc Bias memory descriptor.
  2695. /// @param dst_layer_desc Memory descriptor for the output vector.
  2696. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  2697. /// state vector.
  2698. /// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
  2699. /// state vector.
  2700. /// @param flags Unused.
  2701. /// @param attr Primitive attributes (can be NULL).
  2702. /// @returns #dnnl_success on success and a status describing the error
  2703. /// otherwise.
  2704. dnnl_status_t DNNL_API dnnl_lstm_forward_primitive_desc_create(
  2705. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2706. dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
  2707. const_dnnl_memory_desc_t src_layer_desc,
  2708. const_dnnl_memory_desc_t src_iter_desc,
  2709. const_dnnl_memory_desc_t src_iter_c_desc,
  2710. const_dnnl_memory_desc_t weights_layer_desc,
  2711. const_dnnl_memory_desc_t weights_iter_desc,
  2712. const_dnnl_memory_desc_t weights_peephole_desc,
  2713. const_dnnl_memory_desc_t weights_projection_desc,
  2714. const_dnnl_memory_desc_t bias_desc,
  2715. const_dnnl_memory_desc_t dst_layer_desc,
  2716. const_dnnl_memory_desc_t dst_iter_desc,
  2717. const_dnnl_memory_desc_t dst_iter_c_desc, unsigned flags,
  2718. const_dnnl_primitive_attr_t attr);
  2719. /// Creates a primitive descriptor for an LSTM backward propagation primitive.
  2720. ///
  2721. /// The following arguments may either be @c NULL or point to a zero memory
  2722. /// descriptor:
  2723. /// - @p src_iter_desc together with @p src_iter_c_desc, @p diff_src_iter_desc,
  2724. /// and @p diff_src_iter_c_desc,
  2725. /// - @p weights_peephole_desc together with @p diff_weights_peephole_desc,
  2726. /// - @p bias_desc together with @p diff_bias_desc,
  2727. /// - @p dst_iter_desc together with @p dst_iter_c_desc, @p diff_dst_iter_desc,
  2728. /// and @p diff_dst_iter_c_desc.
  2729. ///
  2730. /// This would then indicate that the LSTM backward propagation primitive
  2731. /// should not use them and should default to zero values instead.
  2732. ///
  2733. /// The @p weights_projection_desc together with @p
  2734. /// diff_weights_projection_desc could either be @c NULL or point to a zero
  2735. /// memory descriptor. This would then indicate that the LSTM doesn't have
  2736. /// recurrent projection layer.
  2737. ///
  2738. /// @note
  2739. /// All memory descriptors can be initialized with #dnnl_format_tag_any or
  2740. /// with format_kind set to #dnnl_format_kind_any.
  2741. ///
  2742. /// @param primitive_desc Output primitive descriptor.
  2743. /// @param engine Engine to use.
  2744. /// @param prop_kind Propagation kind. Must be #dnnl_backward.
  2745. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  2746. /// info.
  2747. /// @param src_layer_desc Memory descriptor for the input vector.
  2748. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  2749. /// state vector.
  2750. /// @param src_iter_c_desc Memory descriptor for the input recurrent cell
  2751. /// state vector.
  2752. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  2753. /// layer input.
  2754. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  2755. /// recurrent input.
  2756. /// @param weights_peephole_desc Memory descriptor for the weights applied to
  2757. /// the cell states (according to the Peephole LSTM formula).
  2758. /// @param weights_projection_desc Memory descriptor for the weights applied to
  2759. /// the hidden states to get the recurrent projection (according to the
  2760. /// Projection LSTM formula).
  2761. /// @param bias_desc Bias memory descriptor.
  2762. /// @param dst_layer_desc Memory descriptor for the output vector.
  2763. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  2764. /// state vector.
  2765. /// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
  2766. /// state vector.
  2767. /// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
  2768. /// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
  2769. /// hidden state vector.
  2770. /// @param diff_src_iter_c_desc Memory descriptor for the diff of input
  2771. /// recurrent cell state vector.
  2772. /// @param diff_weights_layer_desc Memory descriptor for the diff of weights
  2773. /// applied to the layer input.
  2774. /// @param diff_weights_iter_desc Memory descriptor for the diff of weights
  2775. /// applied to the recurrent input.
  2776. /// @param diff_weights_peephole_desc Memory descriptor for the diff of weights
  2777. /// applied to the cell states (according to the Peephole LSTM formula).
  2778. /// @param diff_weights_projection_desc Memory descriptor for the diff of
  2779. /// weights applied to the hidden states to get the recurrent projection
  2780. /// (according to the Projection LSTM formula).
  2781. /// @param diff_bias_desc Diff bias memory descriptor.
  2782. /// @param diff_dst_layer_desc Memory descriptor for the diff of output
  2783. /// vector.
  2784. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  2785. /// recurrent hidden state vector.
  2786. /// @param diff_dst_iter_c_desc Memory descriptor for the diff of output
  2787. /// recurrent cell state vector.
  2788. /// @param flags Unused.
  2789. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  2790. /// primitive.
  2791. /// @param attr Primitive attributes (can be NULL).
  2792. /// @returns #dnnl_success on success and a status describing the error
  2793. /// otherwise.
  2794. dnnl_status_t DNNL_API dnnl_lstm_backward_primitive_desc_create(
  2795. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2796. dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
  2797. const_dnnl_memory_desc_t src_layer_desc,
  2798. const_dnnl_memory_desc_t src_iter_desc,
  2799. const_dnnl_memory_desc_t src_iter_c_desc,
  2800. const_dnnl_memory_desc_t weights_layer_desc,
  2801. const_dnnl_memory_desc_t weights_iter_desc,
  2802. const_dnnl_memory_desc_t weights_peephole_desc,
  2803. const_dnnl_memory_desc_t weights_projection_desc,
  2804. const_dnnl_memory_desc_t bias_desc,
  2805. const_dnnl_memory_desc_t dst_layer_desc,
  2806. const_dnnl_memory_desc_t dst_iter_desc,
  2807. const_dnnl_memory_desc_t dst_iter_c_desc,
  2808. const_dnnl_memory_desc_t diff_src_layer_desc,
  2809. const_dnnl_memory_desc_t diff_src_iter_desc,
  2810. const_dnnl_memory_desc_t diff_src_iter_c_desc,
  2811. const_dnnl_memory_desc_t diff_weights_layer_desc,
  2812. const_dnnl_memory_desc_t diff_weights_iter_desc,
  2813. const_dnnl_memory_desc_t diff_weights_peephole_desc,
  2814. const_dnnl_memory_desc_t diff_weights_projection_desc,
  2815. const_dnnl_memory_desc_t diff_bias_desc,
  2816. const_dnnl_memory_desc_t diff_dst_layer_desc,
  2817. const_dnnl_memory_desc_t diff_dst_iter_desc,
  2818. const_dnnl_memory_desc_t diff_dst_iter_c_desc, unsigned flags,
  2819. const_dnnl_primitive_desc_t hint_fwd_pd,
  2820. const_dnnl_primitive_attr_t attr);
  2821. /// Creates a primitive descriptor for GRU forward propagation primitive.
  2822. ///
  2823. /// The following arguments may either be @c NULL or point to a zero memory
  2824. /// descriptor:
  2825. /// - @p src_iter_desc,
  2826. /// - @p bias_desc,
  2827. /// - @p dst_iter_desc.
  2828. ///
  2829. /// This would then indicate that the GRU forward propagation primitive should
  2830. /// not use them and should default to zero values instead.
  2831. ///
  2832. /// @note
  2833. /// All memory descriptors can be initialized with
  2834. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  2835. ///
  2836. /// @param primitive_desc Output primitive descriptor.
  2837. /// @param engine Engine to use.
  2838. /// @param prop_kind Propagation kind. Possible values are
  2839. /// #dnnl_forward_training and #dnnl_forward_inference.
  2840. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  2841. /// info.
  2842. /// @param src_layer_desc Memory descriptor for the input vector.
  2843. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  2844. /// state vector.
  2845. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  2846. /// layer input.
  2847. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  2848. /// recurrent input.
  2849. /// @param bias_desc Bias memory descriptor.
  2850. /// @param dst_layer_desc Memory descriptor for the output vector.
  2851. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  2852. /// state vector.
  2853. /// @param flags Unused.
  2854. /// @param attr Primitive attributes (can be NULL).
  2855. /// @returns #dnnl_success on success and a status describing the error
  2856. /// otherwise.
  2857. dnnl_status_t DNNL_API dnnl_gru_forward_primitive_desc_create(
  2858. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2859. dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
  2860. const_dnnl_memory_desc_t src_layer_desc,
  2861. const_dnnl_memory_desc_t src_iter_desc,
  2862. const_dnnl_memory_desc_t weights_layer_desc,
  2863. const_dnnl_memory_desc_t weights_iter_desc,
  2864. const_dnnl_memory_desc_t bias_desc,
  2865. const_dnnl_memory_desc_t dst_layer_desc,
  2866. const_dnnl_memory_desc_t dst_iter_desc, unsigned flags,
  2867. const_dnnl_primitive_attr_t attr);
  2868. /// Creates a primitive descriptor for GRU backward propagation primitive.
  2869. ///
  2870. /// The following arguments may either be @c NULL or point to a zero memory
  2871. /// descriptor:
  2872. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  2873. /// - @p bias_desc together with @p diff_bias_desc,
  2874. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  2875. ///
  2876. /// This would then indicate that the GRU backward propagation primitive
  2877. /// should not use them and should default to zero values instead.
  2878. ///
  2879. /// @note
  2880. /// All memory descriptors can be initialized with
  2881. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  2882. ///
  2883. /// @param primitive_desc Output primitive descriptor.
  2884. /// @param engine Engine to use.
  2885. /// @param prop_kind Propagation kind. Must be #dnnl_backward.
  2886. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  2887. /// info.
  2888. /// @param src_layer_desc Memory descriptor for the input vector.
  2889. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  2890. /// state vector.
  2891. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  2892. /// layer input.
  2893. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  2894. /// recurrent input.
  2895. /// @param bias_desc Bias memory descriptor.
  2896. /// @param dst_layer_desc Memory descriptor for the output vector.
  2897. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  2898. /// state vector.
  2899. /// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
  2900. /// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
  2901. /// hidden state vector.
  2902. /// @param diff_weights_layer_desc Memory descriptor for the diff of weights
  2903. /// applied to the layer input.
  2904. /// @param diff_weights_iter_desc Memory descriptor for the diff of weights
  2905. /// applied to the recurrent input.
  2906. /// @param diff_bias_desc Diff bias memory descriptor.
  2907. /// @param diff_dst_layer_desc Memory descriptor for the diff of output
  2908. /// vector.
  2909. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  2910. /// recurrent hidden state vector.
  2911. /// @param flags Unused.
  2912. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  2913. /// primitive.
  2914. /// @param attr Primitive attributes (can be NULL).
  2915. /// @returns #dnnl_success on success and a status describing the error
  2916. /// otherwise.
  2917. dnnl_status_t DNNL_API dnnl_gru_backward_primitive_desc_create(
  2918. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2919. dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
  2920. const_dnnl_memory_desc_t src_layer_desc,
  2921. const_dnnl_memory_desc_t src_iter_desc,
  2922. const_dnnl_memory_desc_t weights_layer_desc,
  2923. const_dnnl_memory_desc_t weights_iter_desc,
  2924. const_dnnl_memory_desc_t bias_desc,
  2925. const_dnnl_memory_desc_t dst_layer_desc,
  2926. const_dnnl_memory_desc_t dst_iter_desc,
  2927. const_dnnl_memory_desc_t diff_src_layer_desc,
  2928. const_dnnl_memory_desc_t diff_src_iter_desc,
  2929. const_dnnl_memory_desc_t diff_weights_layer_desc,
  2930. const_dnnl_memory_desc_t diff_weights_iter_desc,
  2931. const_dnnl_memory_desc_t diff_bias_desc,
  2932. const_dnnl_memory_desc_t diff_dst_layer_desc,
  2933. const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
  2934. const_dnnl_primitive_desc_t hint_fwd_pd,
  2935. const_dnnl_primitive_attr_t attr);
  2936. /// Creates a descriptor for LBR GRU forward propagation primitive.
  2937. ///
  2938. /// The following arguments may either be @c NULL or point to a zero memory
  2939. /// descriptor:
  2940. /// - @p src_iter_desc,
  2941. /// - @p bias_desc,
  2942. /// - @p dst_iter_desc.
  2943. ///
  2944. /// This would then indicate that the LBR GRU forward propagation primitive
  2945. /// should not use them and should default to zero values instead.
  2946. ///
  2947. /// @param primitive_desc Output primitive descriptor.
  2948. /// @param engine Engine to use.
  2949. /// @param prop_kind Propagation kind. Possible values are
  2950. /// #dnnl_forward_training and #dnnl_forward_inference.
  2951. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  2952. /// info.
  2953. /// @param src_layer_desc Memory descriptor for the input vector.
  2954. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  2955. /// state vector.
  2956. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  2957. /// layer input.
  2958. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  2959. /// recurrent input.
  2960. /// @param bias_desc Bias memory descriptor.
  2961. /// @param dst_layer_desc Memory descriptor for the output vector.
  2962. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  2963. /// state vector.
  2964. /// @param flags Unused.
  2965. /// @param attr Primitive attributes (can be NULL).
  2966. /// @returns #dnnl_success on success and a status describing the error
  2967. /// otherwise.
  2968. dnnl_status_t DNNL_API dnnl_lbr_gru_forward_primitive_desc_create(
  2969. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2970. dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
  2971. const_dnnl_memory_desc_t src_layer_desc,
  2972. const_dnnl_memory_desc_t src_iter_desc,
  2973. const_dnnl_memory_desc_t weights_layer_desc,
  2974. const_dnnl_memory_desc_t weights_iter_desc,
  2975. const_dnnl_memory_desc_t bias_desc,
  2976. const_dnnl_memory_desc_t dst_layer_desc,
  2977. const_dnnl_memory_desc_t dst_iter_desc, unsigned flags,
  2978. const_dnnl_primitive_attr_t attr);
  2979. /// Creates a primitive descriptor for LBR GRU backward propagation primitive.
  2980. ///
  2981. /// The following arguments may either be @c NULL or point to a zero memory
  2982. /// descriptor:
  2983. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  2984. /// - @p bias_desc together with @p diff_bias_desc,
  2985. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  2986. ///
  2987. /// This would then indicate that the LBR GRU backward propagation primitive
  2988. /// should not use them and should default to zero values instead.
  2989. ///
  2990. /// @note
  2991. /// All memory descriptors can be initialized with
  2992. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  2993. ///
  2994. /// @param primitive_desc Output primitive descriptor.
  2995. /// @param engine Engine to use.
  2996. /// @param prop_kind Propagation kind. Must be #dnnl_backward.
  2997. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  2998. /// info.
  2999. /// @param src_layer_desc Memory descriptor for the input vector.
  3000. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  3001. /// state vector.
  3002. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  3003. /// layer input.
  3004. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  3005. /// recurrent input.
  3006. /// @param bias_desc Bias memory descriptor.
  3007. /// @param dst_layer_desc Memory descriptor for the output vector.
  3008. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  3009. /// state vector.
  3010. /// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
  3011. /// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
  3012. /// hidden state vector.
  3013. /// @param diff_weights_layer_desc Memory descriptor for the diff of weights
  3014. /// applied to the layer input.
  3015. /// @param diff_weights_iter_desc Memory descriptor for the diff of weights
  3016. /// applied to the recurrent input.
  3017. /// @param diff_bias_desc Diff bias memory descriptor.
  3018. /// @param diff_dst_layer_desc Memory descriptor for the diff of output
  3019. /// vector.
  3020. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  3021. /// recurrent hidden state vector.
  3022. /// @param flags Unused.
  3023. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  3024. /// primitive.
  3025. /// @param attr Primitive attributes (can be NULL).
  3026. /// @returns #dnnl_success on success and a status describing the error
  3027. /// otherwise.
  3028. dnnl_status_t DNNL_API dnnl_lbr_gru_backward_primitive_desc_create(
  3029. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3030. dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
  3031. const_dnnl_memory_desc_t src_layer_desc,
  3032. const_dnnl_memory_desc_t src_iter_desc,
  3033. const_dnnl_memory_desc_t weights_layer_desc,
  3034. const_dnnl_memory_desc_t weights_iter_desc,
  3035. const_dnnl_memory_desc_t bias_desc,
  3036. const_dnnl_memory_desc_t dst_layer_desc,
  3037. const_dnnl_memory_desc_t dst_iter_desc,
  3038. const_dnnl_memory_desc_t diff_src_layer_desc,
  3039. const_dnnl_memory_desc_t diff_src_iter_desc,
  3040. const_dnnl_memory_desc_t diff_weights_layer_desc,
  3041. const_dnnl_memory_desc_t diff_weights_iter_desc,
  3042. const_dnnl_memory_desc_t diff_bias_desc,
  3043. const_dnnl_memory_desc_t diff_dst_layer_desc,
  3044. const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
  3045. const_dnnl_primitive_desc_t hint_fwd_pd,
  3046. const_dnnl_primitive_attr_t attr);
  3047. /// Creates a primitive descriptor for AUGRU forward propagation primitive.
  3048. ///
  3049. /// The following arguments may either be @c NULL or point to a zero memory
  3050. /// descriptor:
  3051. /// - @p src_iter_desc,
  3052. /// - @p bias_desc,
  3053. /// - @p dst_iter_desc.
  3054. ///
  3055. /// This would then indicate that the AUGRU forward propagation primitive should
  3056. /// not use them and should default to zero values instead.
  3057. ///
  3058. /// @note
  3059. /// All memory descriptors can be initialized with
  3060. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  3061. ///
  3062. /// @param primitive_desc Output primitive descriptor.
  3063. /// @param engine Engine to use.
  3064. /// @param prop_kind Propagation kind. Possible values are
  3065. /// #dnnl_forward_training and #dnnl_forward_inference.
  3066. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  3067. /// info.
  3068. /// @param src_layer_desc Memory descriptor for the input vector.
  3069. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  3070. /// state vector.
  3071. /// @param attention_desc Memory descriptor for the attention vector.
  3072. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  3073. /// layer input.
  3074. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  3075. /// recurrent input.
  3076. /// @param bias_desc Bias memory descriptor.
  3077. /// @param dst_layer_desc Memory descriptor for the output vector.
  3078. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  3079. /// state vector.
  3080. /// @param flags Unused.
  3081. /// @param attr Primitive attributes (can be NULL).
  3082. /// @returns #dnnl_success on success and a status describing the error
  3083. /// otherwise.
  3084. dnnl_status_t DNNL_API dnnl_augru_forward_primitive_desc_create(
  3085. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3086. dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
  3087. const_dnnl_memory_desc_t src_layer_desc,
  3088. const_dnnl_memory_desc_t src_iter_desc,
  3089. const_dnnl_memory_desc_t attention_desc,
  3090. const_dnnl_memory_desc_t weights_layer_desc,
  3091. const_dnnl_memory_desc_t weights_iter_desc,
  3092. const_dnnl_memory_desc_t bias_desc,
  3093. const_dnnl_memory_desc_t dst_layer_desc,
  3094. const_dnnl_memory_desc_t dst_iter_desc, unsigned flags,
  3095. const_dnnl_primitive_attr_t attr);
  3096. /// Creates a primitive descriptor for AUGRU backward propagation primitive.
  3097. ///
  3098. /// The following arguments may either be @c NULL or point to a zero memory
  3099. /// descriptor:
  3100. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  3101. /// - @p bias_desc together with @p diff_bias_desc,
  3102. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  3103. ///
  3104. /// This would then indicate that the AUGRU backward propagation primitive
  3105. /// should not use them and should default to zero values instead.
  3106. ///
  3107. /// @note
  3108. /// All memory descriptors can be initialized with
  3109. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  3110. ///
  3111. /// @param primitive_desc Output primitive descriptor.
  3112. /// @param engine Engine to use.
  3113. /// @param prop_kind Propagation kind. Must be #dnnl_backward.
  3114. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  3115. /// info.
  3116. /// @param src_layer_desc Memory descriptor for the input vector.
  3117. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  3118. /// state vector.
  3119. /// @param attention_desc Memory descriptor for the attention vector.
  3120. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  3121. /// layer input.
  3122. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  3123. /// recurrent input.
  3124. /// @param bias_desc Bias memory descriptor.
  3125. /// @param dst_layer_desc Memory descriptor for the output vector.
  3126. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  3127. /// state vector.
  3128. /// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
  3129. /// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
  3130. /// hidden state vector.
  3131. /// @param diff_attention_desc Memory descriptor for the diff of attention vector.
  3132. /// @param diff_weights_layer_desc Memory descriptor for the diff of weights
  3133. /// applied to the layer input.
  3134. /// @param diff_weights_iter_desc Memory descriptor for the diff of weights
  3135. /// applied to the recurrent input.
  3136. /// @param diff_bias_desc Diff bias memory descriptor.
  3137. /// @param diff_dst_layer_desc Memory descriptor for the diff of output
  3138. /// vector.
  3139. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  3140. /// recurrent hidden state vector.
  3141. /// @param flags Unused.
  3142. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  3143. /// primitive.
  3144. /// @param attr Primitive attributes (can be NULL).
  3145. /// @returns #dnnl_success on success and a status describing the error
  3146. /// otherwise.
  3147. dnnl_status_t DNNL_API dnnl_augru_backward_primitive_desc_create(
  3148. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3149. dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
  3150. const_dnnl_memory_desc_t src_layer_desc,
  3151. const_dnnl_memory_desc_t src_iter_desc,
  3152. const_dnnl_memory_desc_t attention_desc,
  3153. const_dnnl_memory_desc_t weights_layer_desc,
  3154. const_dnnl_memory_desc_t weights_iter_desc,
  3155. const_dnnl_memory_desc_t bias_desc,
  3156. const_dnnl_memory_desc_t dst_layer_desc,
  3157. const_dnnl_memory_desc_t dst_iter_desc,
  3158. const_dnnl_memory_desc_t diff_src_layer_desc,
  3159. const_dnnl_memory_desc_t diff_src_iter_desc,
  3160. const_dnnl_memory_desc_t diff_attention_desc,
  3161. const_dnnl_memory_desc_t diff_weights_layer_desc,
  3162. const_dnnl_memory_desc_t diff_weights_iter_desc,
  3163. const_dnnl_memory_desc_t diff_bias_desc,
  3164. const_dnnl_memory_desc_t diff_dst_layer_desc,
  3165. const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
  3166. const_dnnl_primitive_desc_t hint_fwd_pd,
  3167. const_dnnl_primitive_attr_t attr);
  3168. /// Creates a primitive descriptor for LBR AUGRU forward propagation primitive.
  3169. ///
  3170. /// The following arguments may either be @c NULL or point to a zero memory
  3171. /// descriptor:
  3172. /// - @p src_iter_desc,
  3173. /// - @p bias_desc,
  3174. /// - @p dst_iter_desc.
  3175. ///
  3176. /// This would then indicate that the LBR AUGRU forward propagation primitive
  3177. /// should not use them and should default to zero values instead.
  3178. ///
  3179. /// @param primitive_desc Output primitive descriptor.
  3180. /// @param engine Engine to use.
  3181. /// @param prop_kind Propagation kind. Possible values are
  3182. /// #dnnl_forward_training and #dnnl_forward_inference.
  3183. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  3184. /// info.
  3185. /// @param src_layer_desc Memory descriptor for the input vector.
  3186. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  3187. /// state vector.
  3188. /// @param attention_desc Memory descriptor for the attention vector.
  3189. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  3190. /// layer input.
  3191. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  3192. /// recurrent input.
  3193. /// @param bias_desc Bias memory descriptor.
  3194. /// @param dst_layer_desc Memory descriptor for the output vector.
  3195. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  3196. /// state vector.
  3197. /// @param flags Unused.
  3198. /// @param attr Primitive attributes (can be NULL).
  3199. /// @returns #dnnl_success on success and a status describing the error
  3200. /// otherwise.
  3201. dnnl_status_t DNNL_API dnnl_lbr_augru_forward_primitive_desc_create(
  3202. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3203. dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
  3204. const_dnnl_memory_desc_t src_layer_desc,
  3205. const_dnnl_memory_desc_t src_iter_desc,
  3206. const_dnnl_memory_desc_t attention_desc,
  3207. const_dnnl_memory_desc_t weights_layer_desc,
  3208. const_dnnl_memory_desc_t weights_iter_desc,
  3209. const_dnnl_memory_desc_t bias_desc,
  3210. const_dnnl_memory_desc_t dst_layer_desc,
  3211. const_dnnl_memory_desc_t dst_iter_desc, unsigned flags,
  3212. const_dnnl_primitive_attr_t attr);
  3213. /// Creates a primitive descriptor for LBR AUGRU backward propagation primitive.
  3214. ///
  3215. /// The following arguments may either be @c NULL or point to a zero memory
  3216. /// descriptor:
  3217. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  3218. /// - @p bias_desc together with @p diff_bias_desc,
  3219. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  3220. ///
  3221. /// This would then indicate that the LBR AUGRU backward propagation primitive
  3222. /// should not use them and should default to zero values instead.
  3223. ///
  3224. /// @note
  3225. /// All memory descriptors can be initialized with
  3226. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  3227. ///
  3228. /// @param primitive_desc Output primitive descriptor.
  3229. /// @param engine Engine to use.
  3230. /// @param prop_kind Propagation kind. Must be #dnnl_backward.
  3231. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  3232. /// info.
  3233. /// @param src_layer_desc Memory descriptor for the input vector.
  3234. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  3235. /// state vector.
  3236. /// @param attention_desc Memory descriptor for the attention vector.
  3237. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  3238. /// layer input.
  3239. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  3240. /// recurrent input.
  3241. /// @param bias_desc Bias memory descriptor.
  3242. /// @param dst_layer_desc Memory descriptor for the output vector.
  3243. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  3244. /// state vector.
  3245. /// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
  3246. /// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
  3247. /// hidden state vector.
  3248. /// @param diff_attention_desc Memory descriptor for the diff of attention vector.
  3249. /// @param diff_weights_layer_desc Memory descriptor for the diff of weights
  3250. /// applied to the layer input.
  3251. /// @param diff_weights_iter_desc Memory descriptor for the diff of weights
  3252. /// applied to the recurrent input.
  3253. /// @param diff_bias_desc Diff bias memory descriptor.
  3254. /// @param diff_dst_layer_desc Memory descriptor for the diff of output
  3255. /// vector.
  3256. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  3257. /// recurrent hidden state vector.
  3258. /// @param flags Unused.
  3259. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  3260. /// primitive.
  3261. /// @param attr Primitive attributes (can be NULL).
  3262. /// @returns #dnnl_success on success and a status describing the error
  3263. /// otherwise.
  3264. dnnl_status_t DNNL_API dnnl_lbr_augru_backward_primitive_desc_create(
  3265. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3266. dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
  3267. const_dnnl_memory_desc_t src_layer_desc,
  3268. const_dnnl_memory_desc_t src_iter_desc,
  3269. const_dnnl_memory_desc_t attention_desc,
  3270. const_dnnl_memory_desc_t weights_layer_desc,
  3271. const_dnnl_memory_desc_t weights_iter_desc,
  3272. const_dnnl_memory_desc_t bias_desc,
  3273. const_dnnl_memory_desc_t dst_layer_desc,
  3274. const_dnnl_memory_desc_t dst_iter_desc,
  3275. const_dnnl_memory_desc_t diff_src_layer_desc,
  3276. const_dnnl_memory_desc_t diff_src_iter_desc,
  3277. const_dnnl_memory_desc_t diff_attention_desc,
  3278. const_dnnl_memory_desc_t diff_weights_layer_desc,
  3279. const_dnnl_memory_desc_t diff_weights_iter_desc,
  3280. const_dnnl_memory_desc_t diff_bias_desc,
  3281. const_dnnl_memory_desc_t diff_dst_layer_desc,
  3282. const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
  3283. const_dnnl_primitive_desc_t hint_fwd_pd,
  3284. const_dnnl_primitive_attr_t attr);
  3285. /// @} dnnl_api_rnn
  3286. /// @addtogroup dnnl_api_matmul
  3287. /// @{
  3288. /// Creates a primitive descriptor for a matrix multiplication primitive.
  3289. ///
  3290. /// @param primitive_desc Output primitive descriptor.
  3291. /// @param engine Engine to use.
  3292. /// @param src_desc Source memory descriptor (matrix A)
  3293. /// @param weights_desc Weights memory descriptor (matrix B)
  3294. /// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
  3295. /// descriptor, or a memory descriptor with format_kind set to
  3296. /// #dnnl_format_kind_undef disables the bias term.
  3297. /// @param dst_desc Destination memory descriptor (matrix C).
  3298. /// @param attr Primitive attributes (can be NULL).
  3299. /// @returns #dnnl_success on success and a status describing the error
  3300. /// otherwise.
  3301. dnnl_status_t DNNL_API dnnl_matmul_primitive_desc_create(
  3302. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3303. const_dnnl_memory_desc_t src_desc,
  3304. const_dnnl_memory_desc_t weights_desc,
  3305. const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
  3306. const_dnnl_primitive_attr_t attr);
  3307. /// @} dnnl_api_matmul
  3308. /// @addtogroup dnnl_api_resampling Resampling
  3309. /// @{
  3310. /// Creates a primitive descriptor for a resampling forward propagation
  3311. /// primitive.
  3312. ///
  3313. /// @note
  3314. /// Destination memory descriptor is allowed to be initialized with
  3315. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  3316. ///
  3317. /// @param primitive_desc Output primitive descriptor.
  3318. /// @param engine Engine to use.
  3319. /// @param prop_kind Propagation kind. Possible values are
  3320. /// #dnnl_forward_training and #dnnl_forward_inference.
  3321. /// @param alg_kind resampling algorithm kind: either #dnnl_resampling_nearest,
  3322. /// or #dnnl_resampling_linear.
  3323. /// @param factors Array of scaling factors for spatial dimension.
  3324. /// @param src_desc Source memory descriptor.
  3325. /// @param dst_desc Destination memory descriptor.
  3326. /// @param attr Primitive attributes (can be NULL).
  3327. /// @returns #dnnl_success on success and a status describing the error
  3328. /// otherwise.
  3329. dnnl_status_t DNNL_API dnnl_resampling_forward_primitive_desc_create(
  3330. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3331. dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
  3332. const float *factors, const_dnnl_memory_desc_t src_desc,
  3333. const_dnnl_memory_desc_t dst_desc, const_dnnl_primitive_attr_t attr);
  3334. /// Creates a primitive descriptor for a resampling backward propagation
  3335. /// primitive.
  3336. ///
  3337. /// @param primitive_desc Output primitive descriptor.
  3338. /// @param engine Engine to use.
  3339. /// @param alg_kind resamplinging algorithm kind: either
  3340. /// #dnnl_resampling_nearest, or #dnnl_resampling_linear.
  3341. /// @param diff_src_desc Diff source memory descriptor.
  3342. /// @param diff_dst_desc Diff destination memory descriptor.
  3343. /// @param factors Array of scaling factors for spatial dimension.
  3344. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  3345. /// primitive.
  3346. /// @param attr Primitive attributes (can be NULL).
  3347. /// @returns #dnnl_success on success and a status describing the error
  3348. /// otherwise.
  3349. ///
  3350. dnnl_status_t DNNL_API dnnl_resampling_backward_primitive_desc_create(
  3351. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3352. dnnl_alg_kind_t alg_kind, const float *factors,
  3353. const_dnnl_memory_desc_t diff_src_desc,
  3354. const_dnnl_memory_desc_t diff_dst_desc,
  3355. const_dnnl_primitive_desc_t hint_fwd_pd,
  3356. const_dnnl_primitive_attr_t attr);
  3357. /// @} dnnl_api_resampling
  3358. /// @addtogroup dnnl_api_reduction Reduction
  3359. /// @{
  3360. /// Creates a primitive descriptor for a reduction primitive.
  3361. ///
  3362. /// @note
  3363. /// Destination memory descriptor is allowed to be initialized with
  3364. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  3365. ///
  3366. /// @param primitive_desc Output primitive descriptor.
  3367. /// @param engine Engine to use.
  3368. /// @param alg_kind reduction algorithm kind. Possible values:
  3369. /// #dnnl_reduction_max, #dnnl_reduction_min, #dnnl_reduction_sum,
  3370. /// #dnnl_reduction_mul, #dnnl_reduction_mean, #dnnl_reduction_norm_lp_max,
  3371. /// #dnnl_reduction_norm_lp_sum, #dnnl_reduction_norm_lp_power_p_max,
  3372. /// #dnnl_reduction_norm_lp_power_p_sum.
  3373. /// @param p Algorithm specific parameter.
  3374. /// @param eps Algorithm specific parameter.
  3375. /// @param src_desc Source memory descriptor.
  3376. /// @param dst_desc Destination memory descriptor.
  3377. /// @param attr Primitive attributes (can be NULL).
  3378. /// @returns #dnnl_success on success and a status describing the error
  3379. /// otherwise.
  3380. dnnl_status_t DNNL_API dnnl_reduction_primitive_desc_create(
  3381. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3382. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src_desc,
  3383. const_dnnl_memory_desc_t dst_desc, float p, float eps,
  3384. const_dnnl_primitive_attr_t attr);
  3385. /// @} dnnl_api_reduction
  3386. /// @} dnnl_api_primitives
  3387. /// @addtogroup dnnl_api_primitive_cache
  3388. /// @{
  3389. /// Returns the number of primitives that can be held in the primitive cache
  3390. /// at the same time.
  3391. ///
  3392. /// @param capacity Primitive cache capacity to query. Concurrently
  3393. /// accessing @p capacity is safe.
  3394. /// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
  3395. /// @p capacity value is invalid, and #dnnl_success/#dnnl::status::success on
  3396. /// success.
  3397. dnnl_status_t DNNL_API dnnl_get_primitive_cache_capacity(int *capacity);
  3398. /// Sets a number of primitives that can be held in the primitive cache
  3399. /// at a time.
  3400. ///
  3401. /// @param capacity Primitive cache capacity to set. If a new @p capacity is
  3402. /// less than a number of primitives that the primitive cache already has
  3403. /// then the excess entries will be evicted. Setting the @p capacity to 0
  3404. /// clears the primitive cache and disables it. Concurrently modifying
  3405. /// @p capacity is safe.
  3406. /// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
  3407. /// @p capacity value is invalid, and #dnnl_success/#dnnl::status::success on
  3408. /// success.
  3409. dnnl_status_t DNNL_API dnnl_set_primitive_cache_capacity(int capacity);
  3410. /// @} dnnl_api_primitive_cache
  3411. /// @addtogroup dnnl_api_service
  3412. /// @{
  3413. /// Configures dumping of JIT-generated code.
  3414. ///
  3415. /// @note
  3416. /// This setting overrides the DNNL_JIT_DUMP environment variable.
  3417. ///
  3418. /// @param enable Flag value. Set to 0 to disable and set to 1 to enable.
  3419. /// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
  3420. /// @p flag value is invalid, and #dnnl_success/#dnnl::status::success on
  3421. /// success.
  3422. dnnl_status_t DNNL_API dnnl_set_jit_dump(int enable);
  3423. /// Sets library profiling flags. The flags define which profilers are
  3424. /// supported.
  3425. ///
  3426. /// @note
  3427. /// This setting overrides DNNL_JIT_PROFILE environment variable.
  3428. ///
  3429. /// @sa @ref dev_guide_profilers
  3430. ///
  3431. /// @param flags Profiling flags that can contain the following bits:
  3432. /// - @ref DNNL_JIT_PROFILE_VTUNE -- integration with VTune Profiler
  3433. /// (on by default)
  3434. /// - @ref DNNL_JIT_PROFILE_LINUX_JITDUMP -- produce Linux-specific
  3435. /// jit-pid.dump output (off by default). The location of the output
  3436. /// is controlled via JITDUMPDIR environment variable or via
  3437. /// dnnl_set_jit_profiling_jitdumpdir() function.
  3438. /// - @ref DNNL_JIT_PROFILE_LINUX_PERFMAP -- produce Linux-specific
  3439. /// perf-pid.map output (off by default). The output is always placed
  3440. /// into /tmp.
  3441. ///
  3442. /// Passing @ref DNNL_JIT_PROFILE_NONE disables profiling completely.
  3443. ///
  3444. /// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
  3445. /// @p flags value is invalid, and #dnnl_success/#dnnl::status::success on
  3446. /// success.
  3447. dnnl_status_t DNNL_API dnnl_set_jit_profiling_flags(unsigned flags);
  3448. /// Sets JIT dump output path. Only applicable to Linux and is only
  3449. /// used when profiling flags have DNNL_JIT_PROFILE_LINUX_PERF bit set.
  3450. ///
  3451. /// After the first JIT kernel is generated, the jitdump output will be placed
  3452. /// into temporary directory created using the mkdtemp template
  3453. /// 'dir/.debug/jit/dnnl.XXXXXX'.
  3454. ///
  3455. /// @sa @ref dev_guide_profilers
  3456. ///
  3457. /// @note
  3458. /// This setting overrides JITDUMPDIR environment variable. If
  3459. /// JITDUMPDIR is not set, and this function is never called, the path
  3460. /// defaults to HOME. Passing NULL reverts the value to default.
  3461. ///
  3462. /// @note
  3463. /// The directory is accessed only when the first JIT kernel is being
  3464. /// created. JIT profiling will be disabled in case of any errors
  3465. /// accessing or creating this directory.
  3466. ///
  3467. /// @param dir JIT dump output path.
  3468. /// @returns #dnnl_success/#dnnl::status::success if the
  3469. /// output directory was set correctly and an error status otherwise.
  3470. /// @returns #dnnl_unimplemented/#dnnl::status::unimplemented on Windows.
  3471. dnnl_status_t DNNL_API dnnl_set_jit_profiling_jitdumpdir(const char *dir);
  3472. /// Sets the maximal ISA the library can dispatch to on the CPU. See
  3473. /// #dnnl_cpu_isa_t and #dnnl::cpu_isa for the list of the values accepted by
  3474. /// the C and C++ API functions respectively.
  3475. ///
  3476. /// This function has effect only once, and returns an error on subsequent
  3477. /// calls. It should also be invoked before any other oneDNN API call, otherwise
  3478. /// it may return an error.
  3479. ///
  3480. /// This function overrides the DNNL_MAX_CPU_ISA environment variable. The
  3481. /// environment variable can be set to the desired maximal ISA name in upper
  3482. /// case and with dnnl_cpu_isa prefix removed. For example:
  3483. /// `DNNL_MAX_CPU_ISA=AVX2`.
  3484. ///
  3485. /// @note
  3486. /// The ISAs are only partially ordered:
  3487. /// - SSE41 < AVX < AVX2 < AVX2_VNNI < AVX2_VNNI_2,
  3488. /// - AVX2 < AVX512_CORE < AVX512_CORE_VNNI < AVX512_CORE_BF16
  3489. /// < AVX10_1_512 < AVX10_1_512_AMX < AVX10_1_512_AMX_FP16,
  3490. /// - AVX2_VNNI < AVX10_1_512.
  3491. /// Aliases:
  3492. /// - AVX512_CORE_FP16 = AVX10_1_512
  3493. /// - AVX512_CORE_AMX = AVX10_1_512_AMX
  3494. /// - AVX512_CORE_AMX_FP16 = AVX10_1_512_AMX_FP16
  3495. ///
  3496. /// @sa @ref dev_guide_cpu_dispatcher_control for more details
  3497. ///
  3498. /// @param isa Maximal ISA the library should dispatch to. Pass
  3499. /// #dnnl_cpu_isa_default/#dnnl::cpu_isa::isa_default to remove ISA restrictions
  3500. /// (except for ISAs with initial support in the library).
  3501. /// @returns #dnnl_success/#dnnl::status::success on success and a
  3502. /// #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the @p isa
  3503. /// parameter is invalid or the ISA cannot be changed at this time.
  3504. /// @returns #dnnl_unimplemented/#dnnl::status::unimplemented if the feature
  3505. /// was disabled at build time (see @ref dev_guide_build_options for more
  3506. /// details).
  3507. dnnl_status_t DNNL_API dnnl_set_max_cpu_isa(dnnl_cpu_isa_t isa);
  3508. /// Gets the maximal ISA the library can dispatch to on the CPU. See
  3509. /// #dnnl_cpu_isa_t and #dnnl::cpu_isa for the list of the values returned by
  3510. /// the C and C++ API functions respectively.
  3511. ///
  3512. /// @sa @ref dev_guide_cpu_dispatcher_control for more details
  3513. ///
  3514. /// @returns #dnnl_cpu_isa_t value reflecting the maximal ISA the library may
  3515. /// dispatch to.
  3516. dnnl_cpu_isa_t DNNL_API dnnl_get_effective_cpu_isa(void);
  3517. /// Sets the hints flag for the CPU ISA. See #dnnl_cpu_isa_hints_t and
  3518. /// #dnnl::cpu_isa_hints for the list of the values accepted by the C and C++
  3519. /// API functions respectively.
  3520. ///
  3521. /// This function has effect only once, and returns an error on subsequent
  3522. /// calls. It should also be invoked before any other oneDNN API call, otherwise
  3523. /// it may return an error.
  3524. ///
  3525. /// This function overrides the DNNL_CPU_ISA_HINTS environment variable.
  3526. /// @sa @ref dev_guide_cpu_isa_hints for more details
  3527. ///
  3528. /// @param isa_hints CPU ISA hints to be passed over to the implementation.
  3529. /// Pass #dnnl_cpu_isa_no_hints/#dnnl::cpu_isa_hints::no_hints to use
  3530. /// default features i.e. no hints.
  3531. /// @returns #dnnl_success/#dnnl::status::success on success and a
  3532. /// #dnnl_runtime_error/#dnnl::status::runtime_error if the ISA hints cannot
  3533. /// be specified at the current time.
  3534. /// @returns #dnnl_unimplemented/#dnnl::status::unimplemented if the feature
  3535. /// was disabled at build time (see @ref dev_guide_build_options for more
  3536. /// details).
  3537. dnnl_status_t DNNL_API dnnl_set_cpu_isa_hints(dnnl_cpu_isa_hints_t isa_hints);
  3538. /// Gets the ISA specific hints that library can follow. See
  3539. /// #dnnl_cpu_isa_hints_t and #dnnl::cpu_isa_hints for the list of the values
  3540. /// returned by the C and C++ API functions respectively.
  3541. ///
  3542. /// @sa @ref dev_guide_cpu_isa_hints for more details
  3543. ///
  3544. /// @returns #dnnl_cpu_isa_hints_t value reflecting the ISA specific hints the
  3545. /// library can follow.
  3546. dnnl_cpu_isa_hints_t DNNL_API dnnl_get_cpu_isa_hints(void);
  3547. /// @} dnnl_api_service
  3548. #ifdef DNNL_EXPERIMENTAL_PROFILING
  3549. /// @addtogroup dnnl_api_profiling Profiling
  3550. /// @{
  3551. /// Resets a profiler's state.
  3552. ///
  3553. /// @param stream Stream associated with the profiler.
  3554. ///
  3555. /// @returns #dnnl_success on success and a status describing the error
  3556. /// otherwise.
  3557. dnnl_status_t DNNL_API dnnl_reset_profiling(dnnl_stream_t stream);
  3558. /// Queries profiling data. The profiling data accumulates for each primitive
  3559. /// execution. The @p num_entries will be equal to the number of executions
  3560. /// since the last `dnnl_reset_profiling` call. In order to query the
  3561. /// @p num_entries the @p data parameter should be NULL. When @p data is NULL
  3562. /// then the @p data_kind parameter is ignored.
  3563. ///
  3564. /// The profiling data can be reset by calling #dnnl_reset_profiling.
  3565. ///
  3566. /// @note
  3567. /// It is required to wait for all submitted primitives to complete
  3568. /// using #dnnl_stream_wait prior to querying profiling data.
  3569. ///
  3570. /// @param stream Stream that was used for executing a primitive that
  3571. /// is being profiled.
  3572. /// @param data_kind Profiling data kind to query.
  3573. /// @param num_entries Number of profiling data entries.
  3574. /// @param data Profiling data.
  3575. ///
  3576. /// @returns #dnnl_success on success and a status describing the error
  3577. /// otherwise.
  3578. dnnl_status_t DNNL_API dnnl_query_profiling_data(dnnl_stream_t stream,
  3579. dnnl_profiling_data_kind_t data_kind, int *num_entries, uint64_t *data);
  3580. /// @} dnnl_api_profiling
  3581. #endif
  3582. /// @addtogroup dnnl_api_blas
  3583. /// @{
  3584. /// Performs single-precision matrix-matrix multiply.
  3585. ///
  3586. /// The operation is defined as:
  3587. ///
  3588. /// `C := alpha * op( A ) * op( B ) + beta * C`
  3589. ///
  3590. /// where
  3591. /// - `op( X ) = X` or `op( X ) = X**T`,
  3592. /// - `alpha` and `beta` are scalars, and
  3593. /// - `A`, `B`, and `C` are matrices:
  3594. /// - `op( A )` is an `MxK` matrix,
  3595. /// - `op( B )` is an `KxN` matrix,
  3596. /// - `C` is an `MxN` matrix.
  3597. ///
  3598. /// The matrices are assumed to be stored in row-major order (the elements in
  3599. /// each of the matrix rows are contiguous in memory).
  3600. ///
  3601. /// @note
  3602. /// This API does not support XERBLA. Instead, unlike the standard BLAS
  3603. /// functions, this one returns a dnnl_status_t value to allow error
  3604. /// handling.
  3605. ///
  3606. /// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not
  3607. /// transposed, and 'T' or 't' means that A is transposed.
  3608. /// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not
  3609. /// transposed, and 'T' or 't' means that B is transposed.
  3610. /// @param M The M dimension.
  3611. /// @param N The N dimension.
  3612. /// @param K The K dimension.
  3613. /// @param alpha The alpha parameter that is used to scale the product of
  3614. /// matrices A and B.
  3615. /// @param A A pointer to the A matrix data.
  3616. /// @param lda The leading dimension for the matrix A.
  3617. /// @param B A pointer to the B matrix data.
  3618. /// @param ldb The leading dimension for the matrix B.
  3619. /// @param beta The beta parameter that is used to scale the matrix C.
  3620. /// @param C A pointer to the C matrix data.
  3621. /// @param ldc The leading dimension for the matrix C.
  3622. /// @returns #dnnl_success/#dnnl::status::success on success and a status
  3623. /// describing the error otherwise.
  3624. dnnl_status_t DNNL_API dnnl_sgemm(char transa, char transb, dnnl_dim_t M,
  3625. dnnl_dim_t N, dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
  3626. const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc);
  3627. /// Performs integer matrix-matrix multiply on 8-bit unsigned matrix A, 8-bit
  3628. /// signed matrix B, and 32-bit signed resulting matrix C.
  3629. ///
  3630. /// The operation is defined as:
  3631. ///
  3632. /// `C := alpha * (op(A) - A_offset) * (op(B) - B_offset) + beta * C + C_offset`
  3633. ///
  3634. /// where
  3635. /// - `op( X ) = X` or `op( X ) = X**T`,
  3636. /// - `alpha` and `beta` are scalars, and
  3637. /// - `A`, `B`, and `C` are matrices:
  3638. /// - `op( A )` is an `MxK` matrix,
  3639. /// - `op( B )` is an `KxN` matrix,
  3640. /// - `C` is an `MxN` matrix.
  3641. /// - `A_offset` is an `MxK` matrix with every element equal the `ao` value,
  3642. /// - `B_offset` is an `KxN` matrix with every element equal the `bo` value,
  3643. /// - `C_offset` is an `MxN` matrix which is defined by the `co` array of size `len`:
  3644. /// - if `offsetc = F`: the `len` must be at least `1`,
  3645. /// - if `offsetc = C`: the `len` must be at least `max(1, m)`,
  3646. /// - if `offsetc = R`: the `len` must be at least `max(1, n)`,
  3647. ///
  3648. /// The matrices are assumed to be stored in row-major order (the elements in
  3649. /// each of the matrix rows are contiguous in memory).
  3650. ///
  3651. /// @note
  3652. /// This API does not support XERBLA. Instead, unlike the standard BLAS
  3653. /// functions, this one returns a dnnl_status_t value to allow error
  3654. /// handling.
  3655. ///
  3656. /// @warning
  3657. /// On some architectures saturation may happen during intermediate
  3658. /// computations, which would lead to unexpected results. For more
  3659. /// details, refer to @ref dev_guide_int8_computations.
  3660. ///
  3661. /// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not
  3662. /// transposed, and 'T' or 't' means that A is transposed.
  3663. /// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not
  3664. /// transposed, and 'T' or 't' means that B is transposed.
  3665. /// @param offsetc Flag specifying how offsets should be applied to matrix C:
  3666. /// - 'F' means that the same offset will be applied to each element of
  3667. /// the matrix C,
  3668. /// - 'C' means that individual offset will be applied to each element
  3669. /// within each column,
  3670. /// - 'R' means that individual offset will be applied to each element
  3671. /// within each row.
  3672. /// @param M The M dimension.
  3673. /// @param N The N dimension.
  3674. /// @param K The K dimension.
  3675. /// @param alpha The alpha parameter that is used to scale the product of
  3676. /// matrices A and B.
  3677. /// @param A A pointer to the A matrix data.
  3678. /// @param lda The leading dimension for the matrix A.
  3679. /// @param ao The offset value for the matrix A.
  3680. /// @param B A pointer to the B matrix data.
  3681. /// @param ldb The leading dimension for the matrix B.
  3682. /// @param bo The offset value for the matrix B.
  3683. /// @param beta The beta parameter that is used to scale the matrix C.
  3684. /// @param C A pointer to the C matrix data.
  3685. /// @param ldc The leading dimension for the matrix C.
  3686. /// @param co An array of offset values for the matrix C. The number of
  3687. /// elements in the array depends on the value of @p offsetc.
  3688. /// @returns #dnnl_success/#dnnl::status::success on success and a status
  3689. /// describing the error otherwise.
  3690. dnnl_status_t DNNL_API dnnl_gemm_u8s8s32(char transa, char transb, char offsetc,
  3691. dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
  3692. dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
  3693. float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co);
  3694. /// Performs integer matrix-matrix multiply on 8-bit signed matrix A, 8-bit
  3695. /// signed matrix B, and 32-bit signed resulting matrix C.
  3696. ///
  3697. /// The operation is defined as:
  3698. ///
  3699. /// `C := alpha * (op(A) - A_offset) * (op(B) - B_offset) + beta * C + C_offset`
  3700. ///
  3701. /// where
  3702. /// - `op( X ) = X` or `op( X ) = X**T`,
  3703. /// - `alpha` and `beta` are scalars, and
  3704. /// - `A`, `B`, and `C` are matrices:
  3705. /// - `op( A )` is an `MxK` matrix,
  3706. /// - `op( B )` is an `KxN` matrix,
  3707. /// - `C` is an `MxN` matrix.
  3708. /// - `A_offset` is an `MxK` matrix with every element equal the `ao` value,
  3709. /// - `B_offset` is an `KxN` matrix with every element equal the `bo` value,
  3710. /// - `C_offset` is an `MxN` matrix which is defined by the `co` array of size `len`:
  3711. /// - if `offsetc = F`: the `len` must be at least `1`,
  3712. /// - if `offsetc = C`: the `len` must be at least `max(1, m)`,
  3713. /// - if `offsetc = R`: the `len` must be at least `max(1, n)`,
  3714. ///
  3715. /// The matrices are assumed to be stored in row-major order (the elements in
  3716. /// each of the matrix rows are contiguous in memory).
  3717. ///
  3718. /// @note
  3719. /// This API does not support XERBLA. Instead, unlike the standard BLAS
  3720. /// functions, this one returns a dnnl_status_t value to allow error
  3721. /// handling.
  3722. ///
  3723. /// @warning
  3724. /// On some architectures saturation may happen during intermediate
  3725. /// computations, which would lead to unexpected results. For more
  3726. /// details, refer to @ref dev_guide_int8_computations.
  3727. ///
  3728. /// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not
  3729. /// transposed, and 'T' or 't' means that A is transposed.
  3730. /// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not
  3731. /// transposed, and 'T' or 't' means that B is transposed.
  3732. /// @param offsetc Flag specifying how offsets should be applied to matrix C:
  3733. /// - 'F' means that the same offset will be applied to each element of
  3734. /// the matrix C,
  3735. /// - 'C' means that individual offset will be applied to each element
  3736. /// within each column,
  3737. /// - 'R' means that individual offset will be applied to each element
  3738. /// within each row.
  3739. /// @param M The M dimension.
  3740. /// @param N The N dimension.
  3741. /// @param K The K dimension.
  3742. /// @param alpha The alpha parameter that is used to scale the product of
  3743. /// matrices A and B.
  3744. /// @param A A pointer to the A matrix data.
  3745. /// @param lda The leading dimension for the matrix A.
  3746. /// @param ao The offset value for the matrix A.
  3747. /// @param B A pointer to the B matrix data.
  3748. /// @param ldb The leading dimension for the matrix B.
  3749. /// @param bo The offset value for the matrix B.
  3750. /// @param beta The beta parameter that is used to scale the matrix C.
  3751. /// @param C A pointer to the C matrix data.
  3752. /// @param ldc The leading dimension for the matrix C.
  3753. /// @param co An array of offset values for the matrix C. The number of
  3754. /// elements in the array depends on the value of @p offsetc.
  3755. /// @returns #dnnl_success/#dnnl::status::success on success and a status
  3756. /// describing the error otherwise.
  3757. dnnl_status_t DNNL_API dnnl_gemm_s8s8s32(char transa, char transb, char offsetc,
  3758. dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
  3759. dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
  3760. float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co);
  3761. /// @} dnnl_api_blas
  3762. /// @} dnnl_api
  3763. #ifdef __cplusplus
  3764. }
  3765. #endif
  3766. #endif /* ONEAPI_DNNL_DNNL_H */