dnnl.hpp 634 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
75278527952805281528252835284528552865287528852895290529152925293529452955296529752985299530053015302530353045305530653075308530953105311531253135314531553165317531853195320532153225323532453255326532753285329533053315332533353345335533653375338533953405341534253435344534553465347534853495350535153525353535453555356535753585359536053615362536353645365536653675368536953705371537253735374537553765377537853795380538153825383538453855386538753885389539053915392539353945395539653975398539954005401540254035404540554065407540854095410541154125413541454155416541754185419542054215422542354245425542654275428542954305431543254335434543554365437543854395440544154425443544454455446544754485449545054515452545354545455545654575458545954605461546254635464546554665467546854695470547154725473547454755476547754785479548054815482548354845485548654875488548954905491549254935494549554965497549854995500550155025503550455055506550755085509551055115512551355145515551655175518551955205521552255235524552555265527552855295530553155325533553455355536553755385539554055415542554355445545554655475548554955505551555255535554555555565557555855595560556155625563556455655566556755685569557055715572557355745575557655775578557955805581558255835584558555865587558855895590559155925593559455955596559755985599560056015602560356045605560656075608560956105611561256135614561556165617561856195620562156225623562456255626562756285629563056315632563356345635563656375638563956405641564256435644564556465647564856495650565156525653565456555656565756585659566056615662566356645665566656675668566956705671567256735674567556765677567856795680568156825683568456855686568756885689569056915692569356945695569656975698569957005701570257035704570557065707570857095710571157125713571457155716571757185719572057215722572357245725572657275728572957305731573257335734573557365737573857395740574157425743574457455746574757485749575057515752575357545755575657575758575957605761576257635764576557665767576857695770577157725773577457755776577
75778577957805781578257835784578557865787578857895790579157925793579457955796579757985799580058015802580358045805580658075808580958105811581258135814581558165817581858195820582158225823582458255826582758285829583058315832583358345835583658375838583958405841584258435844584558465847584858495850585158525853585458555856585758585859586058615862586358645865586658675868586958705871587258735874587558765877587858795880588158825883588458855886588758885889589058915892589358945895589658975898589959005901590259035904590559065907590859095910591159125913591459155916591759185919592059215922592359245925592659275928592959305931593259335934593559365937593859395940594159425943594459455946594759485949595059515952595359545955595659575958595959605961596259635964596559665967596859695970597159725973597459755976597759785979598059815982598359845985598659875988598959905991599259935994599559965997599859996000600160026003600460056006600760086009601060116012601360146015601660176018601960206021602260236024602560266027602860296030603160326033603460356036603760386039604060416042604360446045604660476048604960506051605260536054605560566057605860596060606160626063606460656066606760686069607060716072607360746075607660776078607960806081608260836084608560866087608860896090609160926093609460956096609760986099610061016102610361046105610661076108610961106111611261136114611561166117611861196120612161226123612461256126612761286129613061316132613361346135613661376138613961406141614261436144614561466147614861496150615161526153615461556156615761586159616061616162616361646165616661676168616961706171617261736174617561766177617861796180618161826183618461856186618761886189619061916192619361946195619661976198619962006201620262036204620562066207620862096210621162126213621462156216621762186219622062216222622362246225622662276228622962306231623262336234623562366237623862396240624162426243624462456246624762486249625062516252625362546255625662576258625962606261626262636264626562666267626862696270627162726273627462756276627
76278627962806281628262836284628562866287628862896290629162926293629462956296629762986299630063016302630363046305630663076308630963106311631263136314631563166317631863196320632163226323632463256326632763286329633063316332633363346335633663376338633963406341634263436344634563466347634863496350635163526353635463556356635763586359636063616362636363646365636663676368636963706371637263736374637563766377637863796380638163826383638463856386638763886389639063916392639363946395639663976398639964006401640264036404640564066407640864096410641164126413641464156416641764186419642064216422642364246425642664276428642964306431643264336434643564366437643864396440644164426443644464456446644764486449645064516452645364546455645664576458645964606461646264636464646564666467646864696470647164726473647464756476647764786479648064816482648364846485648664876488648964906491649264936494649564966497649864996500650165026503650465056506650765086509651065116512651365146515651665176518651965206521652265236524652565266527652865296530653165326533653465356536653765386539654065416542654365446545654665476548654965506551655265536554655565566557655865596560656165626563656465656566656765686569657065716572657365746575657665776578657965806581658265836584658565866587658865896590659165926593659465956596659765986599660066016602660366046605660666076608660966106611661266136614661566166617661866196620662166226623662466256626662766286629663066316632663366346635663666376638663966406641664266436644664566466647664866496650665166526653665466556656665766586659666066616662666366646665666666676668666966706671667266736674667566766677667866796680668166826683668466856686668766886689669066916692669366946695669666976698669967006701670267036704670567066707670867096710671167126713671467156716671767186719672067216722672367246725672667276728672967306731673267336734673567366737673867396740674167426743674467456746674767486749675067516752675367546755675667576758675967606761676267636764676567666767676867696770677167726773677467756776677
76778677967806781678267836784678567866787678867896790679167926793679467956796679767986799680068016802680368046805680668076808680968106811681268136814681568166817681868196820682168226823682468256826682768286829683068316832683368346835683668376838683968406841684268436844684568466847684868496850685168526853685468556856685768586859686068616862686368646865686668676868686968706871687268736874687568766877687868796880688168826883688468856886688768886889689068916892689368946895689668976898689969006901690269036904690569066907690869096910691169126913691469156916691769186919692069216922692369246925692669276928692969306931693269336934693569366937693869396940694169426943694469456946694769486949695069516952695369546955695669576958695969606961696269636964696569666967696869696970697169726973697469756976697769786979698069816982698369846985698669876988698969906991699269936994699569966997699869997000700170027003700470057006700770087009701070117012701370147015701670177018701970207021702270237024702570267027702870297030703170327033703470357036703770387039704070417042704370447045704670477048704970507051705270537054705570567057705870597060706170627063706470657066706770687069707070717072707370747075707670777078707970807081708270837084708570867087708870897090709170927093709470957096709770987099710071017102710371047105710671077108710971107111711271137114711571167117711871197120712171227123712471257126712771287129713071317132713371347135713671377138713971407141714271437144714571467147714871497150715171527153715471557156715771587159716071617162716371647165716671677168716971707171717271737174717571767177717871797180718171827183718471857186718771887189719071917192719371947195719671977198719972007201720272037204720572067207720872097210721172127213721472157216721772187219722072217222722372247225722672277228722972307231723272337234723572367237723872397240724172427243724472457246724772487249725072517252725372547255725672577258725972607261726272637264726572667267726872697270727172727273727472757276727
77278727972807281728272837284728572867287728872897290729172927293729472957296729772987299730073017302730373047305730673077308730973107311731273137314731573167317731873197320732173227323732473257326732773287329733073317332733373347335733673377338733973407341734273437344734573467347734873497350735173527353735473557356735773587359736073617362736373647365736673677368736973707371737273737374737573767377737873797380738173827383738473857386738773887389739073917392739373947395739673977398739974007401740274037404740574067407740874097410741174127413741474157416741774187419742074217422742374247425742674277428742974307431743274337434743574367437743874397440744174427443744474457446744774487449745074517452745374547455745674577458745974607461746274637464746574667467746874697470747174727473747474757476747774787479748074817482748374847485748674877488748974907491749274937494749574967497749874997500750175027503750475057506750775087509751075117512751375147515751675177518751975207521752275237524752575267527752875297530753175327533753475357536753775387539754075417542754375447545754675477548754975507551755275537554755575567557755875597560756175627563756475657566756775687569757075717572757375747575757675777578757975807581758275837584758575867587758875897590759175927593759475957596759775987599760076017602760376047605760676077608760976107611761276137614761576167617761876197620762176227623762476257626762776287629763076317632763376347635763676377638763976407641764276437644764576467647764876497650765176527653765476557656765776587659766076617662766376647665766676677668766976707671767276737674767576767677767876797680768176827683768476857686768776887689769076917692769376947695769676977698769977007701770277037704770577067707770877097710771177127713771477157716771777187719772077217722772377247725772677277728772977307731773277337734773577367737773877397740774177427743774477457746774777487749775077517752775377547755775677577758775977607761776277637764776577667767776877697770777177727773777477757776777
77778777977807781778277837784778577867787778877897790779177927793779477957796779777987799780078017802780378047805780678077808780978107811781278137814781578167817781878197820782178227823782478257826782778287829783078317832783378347835783678377838783978407841784278437844784578467847784878497850785178527853785478557856785778587859786078617862786378647865786678677868786978707871787278737874787578767877787878797880788178827883788478857886788778887889789078917892789378947895789678977898789979007901790279037904790579067907790879097910791179127913791479157916791779187919792079217922792379247925792679277928792979307931793279337934793579367937793879397940794179427943794479457946794779487949795079517952795379547955795679577958795979607961796279637964796579667967796879697970797179727973797479757976797779787979798079817982798379847985798679877988798979907991799279937994799579967997799879998000800180028003800480058006800780088009801080118012801380148015801680178018801980208021802280238024802580268027802880298030803180328033803480358036803780388039804080418042804380448045804680478048804980508051805280538054805580568057805880598060806180628063806480658066806780688069807080718072807380748075807680778078807980808081808280838084808580868087808880898090809180928093809480958096809780988099810081018102810381048105810681078108810981108111811281138114811581168117811881198120812181228123812481258126812781288129813081318132813381348135813681378138813981408141814281438144814581468147814881498150815181528153815481558156815781588159816081618162816381648165816681678168816981708171817281738174817581768177817881798180818181828183818481858186818781888189819081918192819381948195819681978198819982008201820282038204820582068207820882098210821182128213821482158216821782188219822082218222822382248225822682278228822982308231823282338234823582368237823882398240824182428243824482458246824782488249825082518252825382548255825682578258825982608261826282638264826582668267826882698270827182728273827482758276827
78278827982808281828282838284828582868287828882898290829182928293829482958296829782988299830083018302830383048305830683078308830983108311831283138314831583168317831883198320832183228323832483258326832783288329833083318332833383348335833683378338833983408341834283438344834583468347834883498350835183528353835483558356835783588359836083618362836383648365836683678368836983708371837283738374837583768377837883798380838183828383838483858386838783888389839083918392839383948395839683978398839984008401840284038404840584068407840884098410841184128413841484158416841784188419842084218422842384248425842684278428842984308431843284338434843584368437843884398440844184428443844484458446844784488449845084518452845384548455845684578458845984608461846284638464846584668467846884698470847184728473847484758476847784788479848084818482848384848485848684878488848984908491849284938494849584968497849884998500850185028503850485058506850785088509851085118512851385148515851685178518851985208521852285238524852585268527852885298530853185328533853485358536853785388539854085418542854385448545854685478548854985508551855285538554855585568557855885598560856185628563856485658566856785688569857085718572857385748575857685778578857985808581858285838584858585868587858885898590859185928593859485958596859785988599860086018602860386048605860686078608860986108611861286138614861586168617861886198620862186228623862486258626862786288629863086318632863386348635863686378638863986408641864286438644864586468647864886498650865186528653865486558656865786588659866086618662866386648665866686678668866986708671867286738674867586768677867886798680868186828683868486858686868786888689869086918692869386948695869686978698869987008701870287038704870587068707870887098710871187128713871487158716871787188719872087218722872387248725872687278728872987308731873287338734873587368737873887398740874187428743874487458746874787488749875087518752875387548755875687578758875987608761876287638764876587668767876887698770877187728773877487758776877
78778877987808781878287838784878587868787878887898790879187928793879487958796879787988799880088018802880388048805880688078808880988108811881288138814881588168817881888198820882188228823882488258826882788288829883088318832883388348835883688378838883988408841884288438844884588468847884888498850885188528853885488558856885788588859886088618862886388648865886688678868886988708871887288738874887588768877887888798880888188828883888488858886888788888889889088918892889388948895889688978898889989008901890289038904890589068907890889098910891189128913891489158916891789188919892089218922892389248925892689278928892989308931893289338934893589368937893889398940894189428943894489458946894789488949895089518952895389548955895689578958895989608961896289638964896589668967896889698970897189728973897489758976897789788979898089818982898389848985898689878988898989908991899289938994899589968997899889999000900190029003900490059006900790089009901090119012901390149015901690179018901990209021902290239024902590269027902890299030903190329033903490359036903790389039904090419042904390449045904690479048904990509051905290539054905590569057905890599060906190629063906490659066906790689069907090719072907390749075907690779078907990809081908290839084908590869087908890899090909190929093909490959096909790989099910091019102910391049105910691079108910991109111911291139114911591169117911891199120912191229123912491259126912791289129913091319132913391349135913691379138913991409141914291439144914591469147914891499150915191529153915491559156915791589159916091619162916391649165916691679168916991709171917291739174917591769177917891799180918191829183918491859186918791889189919091919192919391949195919691979198919992009201920292039204920592069207920892099210921192129213921492159216921792189219922092219222922392249225922692279228922992309231923292339234923592369237923892399240924192429243924492459246924792489249925092519252925392549255925692579258925992609261926292639264926592669267926892699270927192729273927492759276927
79278927992809281928292839284928592869287928892899290929192929293929492959296929792989299930093019302930393049305930693079308930993109311931293139314931593169317931893199320932193229323932493259326932793289329933093319332933393349335933693379338933993409341934293439344934593469347934893499350935193529353935493559356935793589359936093619362936393649365936693679368936993709371937293739374937593769377937893799380938193829383938493859386938793889389939093919392939393949395939693979398939994009401940294039404940594069407940894099410941194129413941494159416941794189419942094219422942394249425942694279428942994309431943294339434943594369437943894399440944194429443944494459446944794489449945094519452945394549455945694579458945994609461946294639464946594669467946894699470947194729473947494759476947794789479948094819482948394849485948694879488948994909491949294939494949594969497949894999500950195029503950495059506950795089509951095119512951395149515951695179518951995209521952295239524952595269527952895299530953195329533953495359536953795389539954095419542954395449545954695479548954995509551955295539554955595569557955895599560956195629563956495659566956795689569957095719572957395749575957695779578957995809581958295839584958595869587958895899590959195929593959495959596959795989599960096019602960396049605960696079608960996109611961296139614961596169617961896199620962196229623962496259626962796289629963096319632963396349635963696379638963996409641964296439644964596469647964896499650965196529653965496559656965796589659966096619662966396649665966696679668966996709671967296739674967596769677967896799680968196829683968496859686968796889689969096919692969396949695969696979698969997009701970297039704970597069707970897099710971197129713971497159716971797189719972097219722972397249725972697279728972997309731973297339734973597369737973897399740974197429743974497459746974797489749975097519752975397549755975697579758975997609761976297639764976597669767976897699770977197729773977497759776977
79778977997809781978297839784978597869787978897899790979197929793979497959796979797989799980098019802980398049805980698079808980998109811981298139814981598169817981898199820982198229823982498259826982798289829983098319832983398349835983698379838983998409841984298439844984598469847984898499850985198529853985498559856985798589859986098619862986398649865986698679868986998709871987298739874987598769877987898799880988198829883988498859886988798889889989098919892989398949895989698979898989999009901990299039904990599069907990899099910991199129913991499159916991799189919992099219922992399249925992699279928992999309931993299339934993599369937993899399940994199429943994499459946994799489949995099519952995399549955995699579958995999609961996299639964996599669967996899699970997199729973997499759976997799789979998099819982998399849985998699879988998999909991999299939994999599969997999899991000010001100021000310004100051000610007100081000910010100111001210013100141001510016100171001810019100201002110022100231002410025100261002710028100291003010031100321003310034100351003610037100381003910040100411004210043100441004510046100471004810049100501005110052100531005410055100561005710058100591006010061100621006310064100651006610067100681006910070100711007210073100741007510076100771007810079100801008110082100831008410085100861008710088100891009010091100921009310094100951009610097100981009910100101011010210103101041010510106101071010810109101101011110112101131011410115101161011710118101191012010121101221012310124101251012610127101281012910130101311013210133101341013510136101371013810139101401014110142101431014410145101461014710148101491015010151101521015310154101551015610157101581015910160101611016210163101641016510166101671016810169101701017110172101731017410175101761017710178101791018010181101821018310184101851018610187101881018910190101911019210193101941019510196101971019810199102001020110202102031020410205102061020710208102091021010211102121021310214102151021610217102181021910220102211
02221022310224102251022610227102281022910230102311023210233102341023510236102371023810239102401024110242102431024410245102461024710248102491025010251102521025310254102551025610257102581025910260102611026210263102641026510266102671026810269102701027110272102731027410275102761027710278102791028010281102821028310284102851028610287102881028910290102911029210293102941029510296102971029810299103001030110302103031030410305103061030710308103091031010311103121031310314103151031610317103181031910320103211032210323103241032510326103271032810329103301033110332103331033410335103361033710338103391034010341103421034310344103451034610347103481034910350103511035210353103541035510356103571035810359103601036110362103631036410365103661036710368103691037010371103721037310374103751037610377103781037910380103811038210383103841038510386103871038810389103901039110392103931039410395103961039710398103991040010401104021040310404104051040610407104081040910410104111041210413104141041510416104171041810419104201042110422104231042410425104261042710428104291043010431104321043310434104351043610437104381043910440104411044210443104441044510446104471044810449104501045110452104531045410455104561045710458104591046010461104621046310464104651046610467104681046910470104711047210473104741047510476104771047810479104801048110482104831048410485104861048710488104891049010491104921049310494104951049610497104981049910500105011050210503105041050510506105071050810509105101051110512105131051410515105161051710518105191052010521105221052310524105251052610527105281052910530105311053210533105341053510536105371053810539105401054110542105431054410545105461054710548105491055010551105521055310554105551055610557105581055910560105611056210563105641056510566105671056810569105701057110572105731057410575105761057710578105791058010581105821058310584105851058610587105881058910590105911059210593105941059510596105971059810599106001060110602106031060410605106061060710608106091061010611106121061310614106151061610617106181061910620106211
06221062310624106251062610627106281062910630106311063210633106341063510636106371063810639106401064110642106431064410645106461064710648106491065010651106521065310654106551065610657106581065910660106611066210663106641066510666106671066810669106701067110672106731067410675106761067710678106791068010681106821068310684106851068610687106881068910690106911069210693106941069510696106971069810699107001070110702107031070410705107061070710708107091071010711107121071310714107151071610717107181071910720107211072210723107241072510726107271072810729107301073110732107331073410735107361073710738107391074010741107421074310744107451074610747107481074910750107511075210753107541075510756107571075810759107601076110762107631076410765107661076710768107691077010771107721077310774107751077610777107781077910780107811078210783107841078510786107871078810789107901079110792107931079410795107961079710798107991080010801108021080310804108051080610807108081080910810108111081210813108141081510816108171081810819108201082110822108231082410825108261082710828108291083010831108321083310834108351083610837108381083910840108411084210843108441084510846108471084810849108501085110852108531085410855108561085710858108591086010861108621086310864108651086610867108681086910870108711087210873108741087510876108771087810879108801088110882108831088410885108861088710888108891089010891108921089310894108951089610897108981089910900109011090210903109041090510906109071090810909109101091110912109131091410915109161091710918109191092010921109221092310924109251092610927109281092910930109311093210933109341093510936109371093810939109401094110942109431094410945109461094710948109491095010951109521095310954109551095610957109581095910960109611096210963109641096510966109671096810969109701097110972109731097410975109761097710978109791098010981109821098310984109851098610987109881098910990109911099210993109941099510996109971099810999110001100111002110031100411005110061100711008110091101011011110121101311014110151101611017110181101911020110211
10221102311024110251102611027110281102911030110311103211033110341103511036110371103811039110401104111042110431104411045110461104711048110491105011051110521105311054110551105611057110581105911060110611106211063110641106511066110671106811069110701107111072110731107411075110761107711078110791108011081110821108311084110851108611087110881108911090110911109211093110941109511096110971109811099111001110111102111031110411105111061110711108111091111011111111121111311114111151111611117111181111911120111211112211123111241112511126111271112811129111301113111132111331113411135111361113711138111391114011141111421114311144111451114611147111481114911150111511115211153111541115511156111571115811159111601116111162111631116411165111661116711168111691117011171111721117311174111751117611177111781117911180111811118211183111841118511186111871118811189111901119111192111931119411195111961119711198111991120011201112021120311204112051120611207112081120911210112111121211213112141121511216112171121811219112201122111222112231122411225112261122711228112291123011231112321123311234112351123611237112381123911240112411124211243112441124511246112471124811249112501125111252112531125411255112561125711258112591126011261112621126311264112651126611267112681126911270112711127211273112741127511276112771127811279112801128111282112831128411285112861128711288112891129011291112921129311294112951129611297112981129911300113011130211303113041130511306113071130811309113101131111312113131131411315113161131711318113191132011321113221132311324113251132611327113281132911330113311133211333113341133511336113371133811339113401134111342113431134411345113461134711348113491135011351113521135311354113551135611357113581135911360113611136211363113641136511366113671136811369113701137111372113731137411375113761137711378113791138011381113821138311384113851138611387113881138911390113911139211393113941139511396113971139811399114001140111402114031140411405114061140711408114091141011411114121141311414114151141611417114181141911420114211
14221142311424114251142611427114281142911430114311143211433114341143511436114371143811439114401144111442114431144411445114461144711448114491145011451114521145311454114551145611457114581145911460114611146211463114641146511466114671146811469114701147111472114731147411475114761147711478114791148011481114821148311484114851148611487114881148911490114911149211493114941149511496114971149811499115001150111502115031150411505115061150711508115091151011511115121151311514115151151611517115181151911520115211152211523115241152511526115271152811529115301153111532115331153411535115361153711538115391154011541115421154311544115451154611547115481154911550115511155211553115541155511556115571155811559115601156111562115631156411565115661156711568115691157011571115721157311574115751157611577115781157911580115811158211583115841158511586115871158811589115901159111592115931159411595115961159711598115991160011601116021160311604116051160611607116081160911610116111161211613116141161511616116171161811619116201162111622116231162411625116261162711628116291163011631116321163311634116351163611637116381163911640116411164211643116441164511646116471164811649116501165111652116531165411655116561165711658116591166011661116621166311664116651166611667116681166911670116711167211673116741167511676116771167811679116801168111682116831168411685116861168711688116891169011691116921169311694116951169611697116981169911700117011170211703117041170511706117071170811709117101171111712117131171411715117161171711718117191172011721117221172311724117251172611727117281172911730117311173211733117341173511736117371173811739117401174111742117431174411745117461174711748117491175011751117521175311754117551175611757117581175911760117611176211763117641176511766117671176811769117701177111772117731177411775117761177711778117791178011781117821178311784117851178611787117881178911790117911179211793117941179511796117971179811799118001180111802118031180411805118061180711808118091181011811118121181311814118151181611817118181181911820118211
18221182311824118251182611827118281182911830118311183211833118341183511836118371183811839118401184111842118431184411845118461184711848118491185011851118521185311854118551185611857118581185911860118611186211863118641186511866118671186811869118701187111872118731187411875118761187711878118791188011881118821188311884118851188611887118881188911890118911189211893118941189511896118971189811899119001190111902119031190411905119061190711908119091191011911119121191311914119151191611917119181191911920119211192211923119241192511926119271192811929119301193111932119331193411935119361193711938119391194011941119421194311944119451194611947119481194911950119511195211953119541195511956119571195811959119601196111962119631196411965119661196711968119691197011971119721197311974119751197611977119781197911980119811198211983119841198511986119871198811989119901199111992119931199411995119961199711998119991200012001120021200312004120051200612007120081200912010120111201212013120141201512016120171201812019120201202112022120231202412025120261202712028120291203012031120321203312034120351203612037120381203912040120411204212043120441204512046120471204812049120501205112052120531205412055120561205712058120591206012061120621206312064120651206612067120681206912070120711207212073120741207512076120771207812079120801208112082120831208412085120861208712088120891209012091120921209312094120951209612097120981209912100121011210212103121041210512106121071210812109121101211112112121131211412115121161211712118121191212012121121221212312124121251212612127121281212912130121311213212133121341213512136121371213812139121401214112142121431214412145121461214712148121491215012151121521215312154121551215612157121581215912160121611216212163121641216512166121671216812169121701217112172121731217412175121761217712178121791218012181121821218312184121851218612187121881218912190121911219212193121941219512196121971219812199122001220112202122031220412205122061220712208122091221012211122121221312214122151221612217122181221912220122211
22221222312224122251222612227122281222912230122311223212233122341223512236122371223812239122401224112242122431224412245122461224712248122491225012251122521225312254122551225612257122581225912260122611226212263122641226512266122671226812269122701227112272122731227412275122761227712278122791228012281122821228312284122851228612287122881228912290122911229212293122941229512296122971229812299123001230112302123031230412305123061230712308123091231012311123121231312314123151231612317123181231912320123211232212323123241232512326123271232812329123301233112332123331233412335123361233712338123391234012341123421234312344123451234612347123481234912350123511235212353123541235512356123571235812359123601236112362123631236412365123661236712368123691237012371123721237312374123751237612377123781237912380123811238212383123841238512386123871238812389123901239112392123931239412395123961239712398123991240012401124021240312404124051240612407124081240912410124111241212413124141241512416124171241812419124201242112422124231242412425124261242712428124291243012431124321243312434124351243612437124381243912440124411244212443124441244512446124471244812449124501245112452124531245412455124561245712458124591246012461124621246312464124651246612467124681246912470124711247212473124741247512476124771247812479124801248112482124831248412485124861248712488124891249012491124921249312494124951249612497124981249912500125011250212503125041250512506125071250812509125101251112512125131251412515125161251712518125191252012521125221252312524125251252612527125281252912530125311253212533125341253512536125371253812539125401254112542125431254412545125461254712548125491255012551125521255312554125551255612557125581255912560125611256212563125641256512566125671256812569125701257112572125731257412575125761257712578125791258012581125821258312584125851258612587125881258912590125911259212593125941259512596125971259812599126001260112602126031260412605126061260712608126091261012611126121261312614126151261612617126181261912620126211
26221262312624126251262612627126281262912630126311263212633126341263512636126371263812639126401264112642126431264412645126461264712648126491265012651126521265312654126551265612657126581265912660126611266212663126641266512666126671266812669126701267112672126731267412675126761267712678126791268012681126821268312684126851268612687126881268912690126911269212693126941269512696126971269812699127001270112702127031270412705127061270712708127091271012711127121271312714127151271612717127181271912720127211272212723127241272512726127271272812729127301273112732127331273412735127361273712738127391274012741127421274312744127451274612747127481274912750127511275212753127541275512756127571275812759127601276112762127631276412765127661276712768127691277012771127721277312774127751277612777127781277912780127811278212783127841278512786127871278812789127901279112792127931279412795127961279712798127991280012801128021280312804128051280612807128081280912810128111281212813128141281512816128171281812819128201282112822128231282412825128261282712828128291283012831128321283312834128351283612837128381283912840128411284212843128441284512846128471284812849128501285112852128531285412855128561285712858128591286012861128621286312864128651286612867128681286912870128711287212873128741287512876128771287812879128801288112882128831288412885128861288712888128891289012891128921289312894128951289612897128981289912900129011290212903129041290512906129071290812909129101291112912129131291412915129161291712918129191292012921129221292312924129251292612927129281292912930129311293212933129341293512936129371293812939129401294112942129431294412945129461294712948129491295012951129521295312954129551295612957129581295912960129611296212963129641296512966129671296812969129701297112972129731297412975129761297712978129791298012981129821298312984129851298612987129881298912990129911299212993129941299512996129971299812999130001300113002130031300413005130061300713008130091301013011130121301313014130151301613017130181301913020130211
30221302313024130251302613027130281302913030130311303213033130341303513036130371303813039130401304113042130431304413045130461304713048130491305013051130521305313054130551305613057130581305913060130611306213063130641306513066130671306813069130701307113072130731307413075130761307713078130791308013081130821308313084130851308613087130881308913090130911309213093130941309513096130971309813099131001310113102131031310413105131061310713108131091311013111131121311313114131151311613117131181311913120131211312213123131241312513126131271312813129131301313113132131331313413135131361313713138131391314013141131421314313144131451314613147131481314913150131511315213153131541315513156131571315813159131601316113162131631316413165131661316713168131691317013171131721317313174131751317613177131781317913180131811318213183131841318513186131871318813189131901319113192131931319413195131961319713198131991320013201132021320313204132051320613207132081320913210132111321213213132141321513216132171321813219132201322113222132231322413225132261322713228132291323013231132321323313234132351323613237132381323913240132411324213243132441324513246132471324813249132501325113252132531325413255132561325713258132591326013261132621326313264132651326613267132681326913270132711327213273132741327513276132771327813279132801328113282132831328413285132861328713288132891329013291132921329313294132951329613297132981329913300133011330213303133041330513306133071330813309133101331113312133131331413315133161331713318133191332013321133221332313324133251332613327133281332913330133311333213333133341333513336133371333813339133401334113342133431334413345133461334713348133491335013351133521335313354133551335613357133581335913360133611336213363133641336513366133671336813369133701337113372133731337413375133761337713378133791338013381133821338313384133851338613387133881338913390133911339213393133941339513396133971339813399134001340113402134031340413405134061340713408134091341013411134121341313414134151341613417134181341913420134211
34221342313424134251342613427134281342913430134311343213433134341343513436134371343813439134401344113442134431344413445134461344713448134491345013451134521345313454134551345613457134581345913460134611346213463134641346513466134671346813469134701347113472134731347413475134761347713478134791348013481134821348313484134851348613487134881348913490134911349213493134941349513496134971349813499135001350113502135031350413505135061350713508135091351013511135121351313514135151351613517135181351913520135211352213523135241352513526135271352813529135301353113532135331353413535135361353713538135391354013541135421354313544135451354613547135481354913550135511355213553135541355513556135571355813559135601356113562135631356413565135661356713568135691357013571135721357313574135751357613577135781357913580135811358213583135841358513586135871358813589135901359113592135931359413595135961359713598135991360013601136021360313604136051360613607136081360913610136111361213613136141361513616136171361813619136201362113622136231362413625136261362713628136291363013631136321363313634136351363613637136381363913640136411364213643136441364513646136471364813649136501365113652136531365413655136561365713658136591366013661136621366313664136651366613667136681366913670136711367213673136741367513676136771367813679136801368113682136831368413685136861368713688136891369013691136921369313694136951369613697136981369913700137011370213703137041370513706137071370813709137101371113712137131371413715137161371713718137191372013721137221372313724137251372613727137281372913730137311373213733137341373513736137371373813739137401374113742137431374413745137461374713748137491375013751137521375313754137551375613757137581375913760137611376213763137641376513766137671376813769137701377113772137731377413775137761377713778137791378013781137821378313784137851378613787137881378913790137911379213793137941379513796137971379813799138001380113802138031380413805138061380713808138091381013811138121381313814138151381613817138181381913820138211
382213823138241382513826138271382813829138301383113832138331383413835138361383713838138391384013841138421384313844138451384613847138481384913850138511385213853138541385513856138571385813859138601386113862138631386413865138661386713868138691387013871138721387313874138751387613877138781387913880138811388213883138841388513886138871388813889138901389113892138931389413895138961389713898138991390013901139021390313904139051390613907139081390913910139111391213913139141391513916139171391813919139201392113922139231392413925139261392713928139291393013931139321393313934139351393613937139381393913940139411394213943139441394513946139471394813949139501395113952139531395413955139561395713958139591396013961139621396313964139651396613967139681396913970139711397213973139741397513976139771397813979139801398113982139831398413985139861398713988139891399013991139921399313994139951399613997139981399914000140011400214003140041400514006140071400814009140101401114012140131401414015140161401714018140191402014021140221402314024140251402614027140281402914030140311403214033140341403514036140371403814039140401404114042140431404414045140461404714048140491405014051140521405314054140551405614057140581405914060140611406214063140641406514066
  1. /*******************************************************************************
  2. * Copyright 2016-2025 Intel Corporation
  3. * Copyright 2024 FUJITSU LIMITED
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *******************************************************************************/
  17. /// @file
  18. /// C++ API
  19. #ifndef ONEAPI_DNNL_DNNL_HPP
  20. #define ONEAPI_DNNL_DNNL_HPP
  21. #include "oneapi/dnnl/dnnl_config.h"
  22. /// @cond DO_NOT_DOCUMENT_THIS
  23. #include <algorithm>
  24. #include <cstdlib>
  25. #include <iterator>
  26. #include <memory>
  27. #include <string>
  28. #include <vector>
  29. #include <unordered_map>
  30. #include "oneapi/dnnl/dnnl.h"
  31. #include "oneapi/dnnl/dnnl_common.hpp"
  32. /// @endcond
  33. /// @addtogroup dnnl_api oneDNN API
  34. /// @{
  35. /// oneDNN namespace
  36. namespace dnnl {
  37. /// @addtogroup dnnl_api_utils Utilities
  38. /// Utility types and definitions.
  39. /// @{
  40. /// @cond DO_NOT_DOCUMENT_THIS
  41. template <typename T>
  42. void validate_container_size(const T &v, const char *error_message,
  43. int min_size = 1, int max_size = -1) {
  44. const int size = (int)v.size();
  45. if (size < min_size || (max_size >= 0 && size > max_size))
  46. DNNL_THROW_ERROR(dnnl_invalid_arguments, error_message);
  47. }
  48. /// @endcond
  49. /// @cond DO_NOT_DOCUMENT_THIS
  50. template <>
  51. struct handle_traits<dnnl_memory_desc_t> {
  52. static dnnl_status_t destructor(dnnl_memory_desc_t p) {
  53. return dnnl_memory_desc_destroy(p);
  54. }
  55. };
  56. template <>
  57. struct handle_traits<dnnl_memory_t> {
  58. static dnnl_status_t destructor(dnnl_memory_t p) {
  59. return dnnl_memory_destroy(p);
  60. }
  61. };
  62. template <>
  63. struct handle_traits<dnnl_primitive_desc_t> {
  64. static dnnl_status_t destructor(dnnl_primitive_desc_t p) {
  65. return dnnl_primitive_desc_destroy(p);
  66. }
  67. };
  68. template <>
  69. struct handle_traits<dnnl_primitive_t> {
  70. static dnnl_status_t destructor(dnnl_primitive_t p) {
  71. return dnnl_primitive_destroy(p);
  72. }
  73. };
  74. /// @endcond
  75. /// @} dnnl_api_utils
  76. struct stream;
  77. struct memory;
  78. struct primitive_desc;
  79. /// @addtogroup dnnl_api_primitives Primitives
  80. /// Compute primitives
  81. /// @sa @ref dev_guide_basic_concepts
  82. /// @{
  83. /// @addtogroup dnnl_api_primitives_common Common
  84. /// Common operations to create, destroy and inspect primitives
  85. /// @{
  86. /// Base class for all computational primitives.
  87. struct primitive : public handle<dnnl_primitive_t> {
  88. /// Kinds of primitives supported by the library.
  89. enum class kind {
  90. /// Undefined primitive
  91. undef = dnnl_undefined_primitive,
  92. /// A reorder primitive.
  93. reorder = dnnl_reorder,
  94. /// A shuffle primitive.
  95. shuffle = dnnl_shuffle,
  96. /// A (out-of-place) tensor concatenation primitive.
  97. concat = dnnl_concat,
  98. /// A summation primitive.
  99. sum = dnnl_sum,
  100. /// A convolution primitive.
  101. convolution = dnnl_convolution,
  102. /// A deconvolution primitive.
  103. deconvolution = dnnl_deconvolution,
  104. /// An element-wise primitive.
  105. eltwise = dnnl_eltwise,
  106. /// An LRN primitive.
  107. lrn = dnnl_lrn,
  108. /// A batch normalization primitive.
  109. batch_normalization = dnnl_batch_normalization,
  110. /// An inner product primitive.
  111. inner_product = dnnl_inner_product,
  112. /// An RNN primitive.
  113. rnn = dnnl_rnn,
  114. /// A binary primitive.
  115. binary = dnnl_binary,
  116. /// A matmul (matrix multiplication) primitive.
  117. matmul = dnnl_matmul,
  118. /// A resampling primitive.
  119. resampling = dnnl_resampling,
  120. /// A pooling primitive.
  121. pooling = dnnl_pooling,
  122. /// A reduction primitive.
  123. reduction = dnnl_reduction,
  124. /// A PReLU primitive.
  125. prelu = dnnl_prelu,
  126. /// A softmax primitive.
  127. softmax = dnnl_softmax,
  128. /// A layer normalization primitive.
  129. layer_normalization = dnnl_layer_normalization,
  130. /// A group normalization primitive
  131. group_normalization = dnnl_group_normalization,
  132. };
  133. using handle::handle;
  134. /// Default constructor. Constructs an empty object.
  135. primitive() = default;
  136. /// Constructs a primitive from a C API primitive descriptor.
  137. ///
  138. /// @param c_pd C API primitive descriptor.
  139. primitive(const_dnnl_primitive_desc_t c_pd);
  140. /// Constructs a primitive from a C API primitive descriptor and a cache blob.
  141. ///
  142. /// @param c_pd C API primitive descriptor.
  143. /// @param cache_blob Cache blob.
  144. primitive(const_dnnl_primitive_desc_t c_pd,
  145. const std::vector<uint8_t> &cache_blob);
  146. /// Constructs a primitive from a primitive descriptor.
  147. ///
  148. /// @param pd Primitive descriptor.
  149. primitive(const primitive_desc &pd);
  150. /// Constructs a primitive from a primitive descriptor and a cache blob.
  151. ///
  152. /// @param pd Primitive descriptor.
  153. /// @param cache_blob Cache blob.
  154. primitive(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob);
  155. /// Returns the C API primitive descriptor of the underlying C API
  156. /// primitive.
  157. ///
  158. /// @returns The underlying C API primitive descriptor.
  159. inline const_dnnl_primitive_desc_t get_primitive_desc() const;
  160. /// Returns the kind of the primitive.
  161. ///
  162. /// @returns The primitive kind.
  163. inline kind get_kind() const;
  164. /// Returns a cache blob for the primitive.
  165. ///
  166. /// @returns Vector containing the cache blob.
  167. ///
  168. /// @note The cache blob can be empty. It's the user's responsibility to
  169. /// check whether it's empty prior to passing it to the primitive
  170. /// constructor.
  171. inline std::vector<uint8_t> get_cache_blob() const;
  172. /// Executes computations specified by the primitive in a specified stream.
  173. ///
  174. /// Arguments are passed via an arguments map containing <index,
  175. /// memory object> pairs. The index must be one of the `DNNL_ARG_*` values
  176. /// such as `DNNL_ARG_SRC`, and the memory must have a memory descriptor
  177. /// matching the one returned by
  178. /// primitive_desc::query_md(#query::exec_arg_md, index) unless using
  179. /// dynamic shapes (see #DNNL_RUNTIME_DIM_VAL).
  180. ///
  181. /// @param astream Stream object. The stream must belong to the same engine
  182. /// as the primitive.
  183. /// @param args Arguments map.
  184. void execute(const stream &astream,
  185. const std::unordered_map<int, memory> &args) const;
  186. };
  187. /// Converts primitive kind enum value from C++ API to C API type.
  188. ///
  189. /// @param akind C++ API primitive kind enum value.
  190. /// @returns Corresponding C API primitive kind enum value.
  191. inline dnnl_primitive_kind_t convert_to_c(primitive::kind akind) {
  192. return static_cast<dnnl_primitive_kind_t>(akind);
  193. }
  194. const_dnnl_primitive_desc_t primitive::get_primitive_desc() const {
  195. const_dnnl_primitive_desc_t pd;
  196. error::wrap_c_api(dnnl_primitive_get_primitive_desc(get(), &pd),
  197. "could not get a primitive descriptor from a primitive");
  198. return pd;
  199. }
  200. dnnl::primitive::kind primitive::get_kind() const {
  201. const_dnnl_primitive_desc_t pd = get_primitive_desc();
  202. // TODO (Roma): the code below is only needed because get_primitive_desc
  203. // returns a C type.
  204. dnnl_primitive_kind_t kind;
  205. error::wrap_c_api(dnnl_primitive_desc_query(
  206. pd, dnnl_query_primitive_kind, 0, (void *)&kind),
  207. "could not get a primitive kind from a primitive descriptor");
  208. return static_cast<dnnl::primitive::kind>(kind);
  209. }
  210. std::vector<uint8_t> primitive::get_cache_blob() const {
  211. size_t size;
  212. error::wrap_c_api(dnnl_primitive_get_cache_blob(get(), &size, nullptr),
  213. "could not get cache blob size from a primitive");
  214. std::vector<uint8_t> cache_blob(size);
  215. error::wrap_c_api(
  216. dnnl_primitive_get_cache_blob(get(), &size, cache_blob.data()),
  217. "could not get a cache blob from a primitive");
  218. return cache_blob;
  219. }
  220. /// @} dnnl_api_primitives_common
  221. /// @addtogroup dnnl_api_attributes
  222. ///
  223. /// A container for parameters that extend primitives behavior.
  224. ///
  225. /// Attributes can also contain Post-ops, which are computations executed
  226. /// after the primitive.
  227. ///
  228. /// @sa @ref dev_guide_attributes
  229. /// @sa @ref dev_guide_attributes_post_ops
  230. ///
  231. /// @{
  232. /// Scratchpad mode
  233. enum class scratchpad_mode {
  234. /// The library manages the scratchpad allocation according to the policy
  235. /// specified by the `DNNL_ENABLE_CONCURRENT_EXEC`
  236. /// [build option](@ref dev_guide_build_options) (default).
  237. ///
  238. /// When `DNNL_ENABLE_CONCURRENT_EXEC=OFF` (default), the library
  239. /// scratchpad is common to all primitives to reduce the memory footprint.
  240. /// This configuration comes with limited thread-safety properties, namely
  241. /// primitives can be created and executed in parallel but cannot migrate
  242. /// between threads (in other words, each primitive should be executed in
  243. /// the same thread it was created in).
  244. ///
  245. /// When `DNNL_ENABLE_CONCURRENT_EXEC=ON`, the library scratchpad is
  246. /// private to each primitive. The memory footprint is larger than when
  247. /// using `DNNL_ENABLE_CONCURRENT_EXEC=OFF` but different primitives can be
  248. /// created and run concurrently (the same primitive cannot be run
  249. /// concurrently from two different threads though).
  250. library = dnnl_scratchpad_mode_library,
  251. /// The user manages the scratchpad allocation by querying and providing
  252. /// the scratchpad memory to primitives. This mode is thread-safe as long
  253. /// as the scratchpad buffers are not used concurrently by two primitive
  254. /// executions.
  255. user = dnnl_scratchpad_mode_user,
  256. };
  257. /// Converts a scratchpad mode enum value from C++ API to C API type.
  258. ///
  259. /// @param mode C++ API scratchpad mode enum value.
  260. /// @returns Corresponding C API scratchpad mode enum value.
  261. inline dnnl_scratchpad_mode_t convert_to_c(scratchpad_mode mode) {
  262. return static_cast<dnnl_scratchpad_mode_t>(mode);
  263. }
  264. /// Rounding mode
  265. enum class rounding_mode {
  266. /// rounding mode dictated by the floating-point environment
  267. environment = dnnl_rounding_mode_environment,
  268. /// stochastic rounding mode where a random bias is added to the
  269. /// trailing mantissa bits before conversion.
  270. stochastic = dnnl_rounding_mode_stochastic
  271. };
  272. /// Converts a rounding mode enum value from C++ API to C API type.
  273. ///
  274. /// @param mode C++ API rounding mode enum value.
  275. /// @returns Corresponding C API rounding mode enum value.
  276. inline dnnl_rounding_mode_t convert_to_c(rounding_mode mode) {
  277. return static_cast<dnnl_rounding_mode_t>(mode);
  278. }
  279. /// Propagation kind.
  280. enum class prop_kind {
  281. /// Undefined propagation kind.
  282. undef = dnnl_prop_kind_undef,
  283. /// Forward data propagation (training mode). In this mode, primitives
  284. /// perform computations necessary for subsequent backward propagation.
  285. forward_training = dnnl_forward_training,
  286. /// Forward data propagation (inference mode). In this mode, primitives
  287. /// perform only computations that are necessary for inference and omit
  288. /// computations that are necessary only for backward propagation.
  289. forward_inference = dnnl_forward_inference,
  290. /// Forward data propagation,
  291. /// alias for #dnnl::prop_kind::forward_training.
  292. forward = dnnl_forward,
  293. /// Backward propagation (with respect to all parameters).
  294. backward = dnnl_backward,
  295. /// Backward data propagation.
  296. backward_data = dnnl_backward_data,
  297. /// Backward weights propagation.
  298. backward_weights = dnnl_backward_weights,
  299. /// Backward bias propagation.
  300. backward_bias = dnnl_backward_bias
  301. };
  302. /// Converts propagation kind enum value from C++ API to C API type.
  303. ///
  304. /// @param akind C++ API propagation kind enum value.
  305. /// @returns Corresponding C API propagation kind enum value.
  306. inline dnnl_prop_kind_t convert_to_c(prop_kind akind) {
  307. return static_cast<dnnl_prop_kind_t>(akind);
  308. }
  309. /// Kinds of algorithms.
  310. enum class algorithm {
  311. /// Undefined algorithm
  312. undef = dnnl_alg_kind_undef,
  313. /// Convolution algorithm that is chosen to be either direct or Winograd
  314. /// automatically
  315. convolution_auto = dnnl_convolution_auto,
  316. /// Direct convolution
  317. convolution_direct = dnnl_convolution_direct,
  318. /// Winograd convolution
  319. convolution_winograd = dnnl_convolution_winograd,
  320. /// Direct deconvolution
  321. deconvolution_direct = dnnl_deconvolution_direct,
  322. /// Winograd deconvolution
  323. deconvolution_winograd = dnnl_deconvolution_winograd,
  324. /// Elementwise: rectified linear unit (ReLU)
  325. eltwise_relu = dnnl_eltwise_relu,
  326. /// Elementwise: hyperbolic tangent non-linearity (tanh)
  327. eltwise_tanh = dnnl_eltwise_tanh,
  328. /// Elementwise: exponential linear unit (ELU)
  329. eltwise_elu = dnnl_eltwise_elu,
  330. /// Elementwise: square
  331. eltwise_square = dnnl_eltwise_square,
  332. /// Elementwise: abs
  333. eltwise_abs = dnnl_eltwise_abs,
  334. /// Elementwise: square root
  335. eltwise_sqrt = dnnl_eltwise_sqrt,
  336. /// Elementwise: swish (\f$x \cdot sigmoid(a \cdot x)\f$)
  337. eltwise_swish = dnnl_eltwise_swish,
  338. /// Elementwise: linear
  339. eltwise_linear = dnnl_eltwise_linear,
  340. /// Elementwise: soft_relu
  341. eltwise_soft_relu = dnnl_eltwise_soft_relu,
  342. /// Elementwise: mish
  343. eltwise_mish = dnnl_eltwise_mish,
  344. /// Elementwise: logistic
  345. eltwise_logistic = dnnl_eltwise_logistic,
  346. /// Elementwise: exponent
  347. eltwise_exp = dnnl_eltwise_exp,
  348. /// Elementwise: tanh-based gelu
  349. eltwise_gelu_tanh = dnnl_eltwise_gelu_tanh,
  350. /// Elementwise: erf-based gelu
  351. eltwise_gelu_erf = dnnl_eltwise_gelu_erf,
  352. /// Elementwise: natural logarithm
  353. eltwise_log = dnnl_eltwise_log,
  354. /// Elementwise: clip
  355. eltwise_clip = dnnl_eltwise_clip,
  356. /// Eltwise: clip version 2
  357. eltwise_clip_v2 = dnnl_eltwise_clip_v2,
  358. /// Elementwise: pow
  359. eltwise_pow = dnnl_eltwise_pow,
  360. /// Elementwise: round
  361. eltwise_round = dnnl_eltwise_round,
  362. /// Elementwise: hardswish
  363. eltwise_hardswish = dnnl_eltwise_hardswish,
  364. /// Elementwise: hardsigmoid
  365. eltwise_hardsigmoid = dnnl_eltwise_hardsigmoid,
  366. /// Elementwise: rectified linar unit (ReLU) (dst for backward)
  367. eltwise_relu_use_dst_for_bwd = dnnl_eltwise_relu_use_dst_for_bwd,
  368. /// Elementwise: hyperbolic tangent non-linearity (tanh) (dst for backward)
  369. eltwise_tanh_use_dst_for_bwd = dnnl_eltwise_tanh_use_dst_for_bwd,
  370. /// Elementwise: exponential linear unit (ELU) (dst for backward)
  371. eltwise_elu_use_dst_for_bwd = dnnl_eltwise_elu_use_dst_for_bwd,
  372. /// Elementwise: square root (dst for backward)
  373. eltwise_sqrt_use_dst_for_bwd = dnnl_eltwise_sqrt_use_dst_for_bwd,
  374. /// Elementwise: logistic (dst for backward)
  375. eltwise_logistic_use_dst_for_bwd = dnnl_eltwise_logistic_use_dst_for_bwd,
  376. /// Elementwise: exponent (dst for backward)
  377. eltwise_exp_use_dst_for_bwd = dnnl_eltwise_exp_use_dst_for_bwd,
  378. /// Elementwise: clip version 2 (dst for backward)
  379. eltwise_clip_v2_use_dst_for_bwd = dnnl_eltwise_clip_v2_use_dst_for_bwd,
  380. /// Local response normalization (LRN) across multiple channels
  381. lrn_across_channels = dnnl_lrn_across_channels,
  382. /// LRN within a single channel
  383. lrn_within_channel = dnnl_lrn_within_channel,
  384. /// Max pooling
  385. pooling_max = dnnl_pooling_max,
  386. /// Average pooling include padding
  387. pooling_avg_include_padding = dnnl_pooling_avg_include_padding,
  388. /// Average pooling exclude padding
  389. pooling_avg_exclude_padding = dnnl_pooling_avg_exclude_padding,
  390. /// RNN cell
  391. vanilla_rnn = dnnl_vanilla_rnn,
  392. /// LSTM cell
  393. vanilla_lstm = dnnl_vanilla_lstm,
  394. /// GRU cell
  395. vanilla_gru = dnnl_vanilla_gru,
  396. /// GRU cell with linear before reset. Differs from the vanilla GRU
  397. /// in how the new memory gate is calculated:
  398. /// \f$c_t = tanh(W_c*x_t + b_{c_x} + r_t*(U_c*h_{t-1}+b_{c_h})) \f$
  399. /// LRB GRU expects 4 bias tensors on input:
  400. /// \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$
  401. lbr_gru = dnnl_lbr_gru,
  402. /// AUGRU cell
  403. vanilla_augru = dnnl_vanilla_augru,
  404. /// AUGRU cell with linear before reset
  405. lbr_augru = dnnl_lbr_augru,
  406. /// Binary add
  407. binary_add = dnnl_binary_add,
  408. /// Binary mul
  409. binary_mul = dnnl_binary_mul,
  410. /// Binary max
  411. binary_max = dnnl_binary_max,
  412. /// Binary min
  413. binary_min = dnnl_binary_min,
  414. /// Binary div
  415. binary_div = dnnl_binary_div,
  416. /// Binary sub
  417. binary_sub = dnnl_binary_sub,
  418. /// Binary greater than or equal
  419. binary_ge = dnnl_binary_ge,
  420. /// Binary greater than
  421. binary_gt = dnnl_binary_gt,
  422. /// Binary less than or equal
  423. binary_le = dnnl_binary_le,
  424. /// Binary less than
  425. binary_lt = dnnl_binary_lt,
  426. /// Binary equal
  427. binary_eq = dnnl_binary_eq,
  428. /// Binary not equal
  429. binary_ne = dnnl_binary_ne,
  430. /// Binary select
  431. binary_select = dnnl_binary_select,
  432. /// Nearest Neighbor resampling method
  433. resampling_nearest = dnnl_resampling_nearest,
  434. /// Linear (Bilinear, Trilinear) resampling method
  435. resampling_linear = dnnl_resampling_linear,
  436. /// Reduction using max operation
  437. reduction_max = dnnl_reduction_max,
  438. /// Reduction using min operation
  439. reduction_min = dnnl_reduction_min,
  440. /// Reduction using sum operation
  441. reduction_sum = dnnl_reduction_sum,
  442. /// Reduction using mul operation
  443. reduction_mul = dnnl_reduction_mul,
  444. /// Reduction using mean operation
  445. reduction_mean = dnnl_reduction_mean,
  446. /// Reduction using norm_lp_max operation
  447. reduction_norm_lp_max = dnnl_reduction_norm_lp_max,
  448. /// Reduction using norm_lp_sum operation
  449. reduction_norm_lp_sum = dnnl_reduction_norm_lp_sum,
  450. /// Reduction using norm_lp_power_p_max operation
  451. reduction_norm_lp_power_p_max = dnnl_reduction_norm_lp_power_p_max,
  452. /// Reduction using norm_lp_power_p_sum operation
  453. reduction_norm_lp_power_p_sum = dnnl_reduction_norm_lp_power_p_sum,
  454. /// Softmax, numerically stable
  455. softmax_accurate = dnnl_softmax_accurate,
  456. /// LogSoftmax, numerically stable
  457. softmax_log = dnnl_softmax_log,
  458. };
  459. /// Converts algorithm kind enum value from C++ API to C API type.
  460. /// @param aalgorithm C++ API algorithm kind enum value.
  461. /// @returns Corresponding C API algorithm kind enum value.
  462. inline dnnl_alg_kind_t convert_to_c(algorithm aalgorithm) {
  463. return static_cast<dnnl_alg_kind_t>(aalgorithm);
  464. }
  465. /// @} dnnl_api_attributes
  466. /// @addtogroup dnnl_api_primitives_common
  467. /// @{
  468. /// Flags for normalization primitives.
  469. enum class normalization_flags : unsigned {
  470. /// Use no normalization flags. If specified, the library computes mean and
  471. /// variance on forward propagation for training and inference, outputs them
  472. /// on forward propagation for training, and computes the respective
  473. /// derivatives on backward propagation.
  474. none = dnnl_normalization_flags_none,
  475. /// Use global statistics. If specified, the library uses mean and
  476. /// variance provided by the user as an input on forward propagation and
  477. /// does not compute their derivatives on backward propagation. Otherwise,
  478. /// the library computes mean and variance on forward propagation for
  479. /// training and inference, outputs them on forward propagation for
  480. /// training, and computes the respective derivatives on backward
  481. /// propagation.
  482. use_global_stats = dnnl_use_global_stats,
  483. /// Use scale parameter. If specified, the user is expected to pass scale as
  484. /// input on forward propagation. On backward propagation of type
  485. /// #dnnl::prop_kind::backward, the library computes its derivative.
  486. use_scale = dnnl_use_scale,
  487. /// Use shift parameter. If specified, the user is expected to pass shift as
  488. /// input on forward propagation. On backward propagation of type
  489. /// #dnnl::prop_kind::backward, the library computes its derivative.
  490. use_shift = dnnl_use_shift,
  491. /// Fuse normalization with ReLU. On training, normalization will require
  492. /// the workspace to implement backward propagation. On inference, the
  493. /// workspace is not required and behavior is the same as when normalization
  494. /// is fused with ReLU using the post-ops API.
  495. fuse_norm_relu = dnnl_fuse_norm_relu,
  496. /// Fuse normalization with elementwise binary Add and then fuse with ReLU.
  497. /// On training, normalization will require the workspace to implement
  498. /// backward propagation. On inference, the workspace is not required.
  499. fuse_norm_add_relu = dnnl_fuse_norm_add_relu,
  500. };
  501. /// Converts normalization flags enum value from C++ API to C API type.
  502. /// @param flags C++ API normalization flags enum value.
  503. /// @returns Corresponding C API normalization flags enum value.
  504. inline dnnl_normalization_flags_t convert_to_c(normalization_flags flags) {
  505. return static_cast<dnnl_normalization_flags_t>(flags);
  506. }
  507. /// @} dnnl_api_primitives_common
  508. /// @addtogroup dnnl_api_rnn
  509. /// @{
  510. /// RNN cell flags.
  511. enum class rnn_flags : unsigned {
  512. /// Undefined RNN flags
  513. undef = dnnl_rnn_flags_undef,
  514. /// Do not add weights gradient to existing diff_weights memory
  515. diff_weights_overwrite = dnnl_rnn_flags_diff_weights_overwrite,
  516. };
  517. /// Converts RNN cell flags enum value from C++ API to C API type.
  518. /// @param flags C++ API RNN cell flags enum value.
  519. /// @returns Corresponding C API RNN cell flags enum value.
  520. inline dnnl_rnn_flags_t convert_to_c(rnn_flags flags) {
  521. return static_cast<dnnl_rnn_flags_t>(flags);
  522. }
  523. DNNL_DEFINE_BITMASK_OPS(normalization_flags)
  524. DNNL_DEFINE_BITMASK_OPS(rnn_flags)
  525. /// A direction of RNN primitive execution
  526. enum class rnn_direction {
  527. /// Undefined RNN direction.
  528. undef = dnnl_rnn_direction_undef,
  529. /// Unidirectional execution of RNN primitive from left to right.
  530. unidirectional_left2right = dnnl_unidirectional_left2right,
  531. /// Unidirectional execution of RNN primitive from right to left.
  532. unidirectional_right2left = dnnl_unidirectional_right2left,
  533. /// Bidirectional execution of RNN primitive with concatenation of the
  534. /// results.
  535. bidirectional_concat = dnnl_bidirectional_concat,
  536. /// Bidirectional execution of RNN primitive with summation of the
  537. /// results.
  538. bidirectional_sum = dnnl_bidirectional_sum,
  539. };
  540. /// Converts RNN direction enum value from C++ API to C API type.
  541. /// @param dir C++ API RNN direction enum value.
  542. /// @returns Corresponding C API RNN direction enum value.
  543. inline dnnl_rnn_direction_t convert_to_c(rnn_direction dir) {
  544. return static_cast<dnnl_rnn_direction_t>(dir);
  545. }
  546. /// @} dnnl_api_rnn
  547. /// @addtogroup dnnl_api_primitives_common
  548. /// @{
  549. /// Primitive descriptor query specification.
  550. ///
  551. /// In general, queries are not used with the C++ API because most queries are
  552. /// implemented as class members.
  553. ///
  554. /// See @ref dnnl_query_t for more information.
  555. enum class query {
  556. /// no query
  557. undef = dnnl_query_undef,
  558. /// execution engine
  559. engine = dnnl_query_engine,
  560. /// primitive kind
  561. primitive_kind = dnnl_query_primitive_kind,
  562. /// number of inputs expected
  563. num_of_inputs_s32 = dnnl_query_num_of_inputs_s32,
  564. /// number of outputs expected
  565. num_of_outputs_s32 = dnnl_query_num_of_outputs_s32,
  566. /// runtime estimation (seconds), unimplemented
  567. time_estimate_f64 = dnnl_query_time_estimate_f64,
  568. /// memory required for scratchpad (bytes)
  569. ///
  570. /// @sa @ref dev_guide_attributes_scratchpad
  571. memory_consumption_s64 = dnnl_query_memory_consumption_s64,
  572. /// scratchpad engine
  573. ///
  574. /// engine to be used for creating scratchpad memory
  575. scratchpad_engine = dnnl_query_scratchpad_engine,
  576. /// reorder source engine
  577. reorder_src_engine = dnnl_query_reorder_src_engine,
  578. /// reorder destination engine
  579. reorder_dst_engine = dnnl_query_reorder_dst_engine,
  580. /// implementation name
  581. impl_info_str = dnnl_query_impl_info_str,
  582. /// propagation kind
  583. prop_kind = dnnl_query_prop_kind,
  584. /// size of cache blob ID in bytes
  585. cache_blob_id_size_s64 = dnnl_query_cache_blob_id_size_s64,
  586. /// cache blob ID (pointer to array)
  587. cache_blob_id = dnnl_query_cache_blob_id,
  588. /// strides
  589. strides = dnnl_query_strides,
  590. /// dilations
  591. dilations = dnnl_query_dilations,
  592. /// left padding
  593. padding_l = dnnl_query_padding_l,
  594. /// right padding
  595. padding_r = dnnl_query_padding_r,
  596. /// epsilon
  597. epsilon_f32 = dnnl_query_epsilon_f32,
  598. /// flags
  599. flags = dnnl_query_flags,
  600. /// algorithm kind
  601. alg_kind = dnnl_query_alg_kind,
  602. /// alpha
  603. alpha_f32 = dnnl_query_alpha_f32,
  604. /// beta
  605. beta_f32 = dnnl_query_beta_f32,
  606. /// axis
  607. axis_s32 = dnnl_query_axis_s32,
  608. /// LRN parameter local size
  609. local_size_s64 = dnnl_query_local_size_s64,
  610. /// LRN parameter K
  611. k_f32 = dnnl_query_k_f32,
  612. /// Reduction parameter P
  613. p_f32 = dnnl_query_p_f32,
  614. /// Resampling parameter factors
  615. factors = dnnl_query_factors,
  616. /// RNN parameter cell kind
  617. cell_kind = dnnl_query_cell_kind,
  618. /// RNN parameter direction
  619. direction = dnnl_query_direction,
  620. /// RNN parameter activation kind
  621. activation_kind = dnnl_query_activation_kind,
  622. /// Pooling parameter kernel
  623. kernel = dnnl_query_kernel,
  624. /// Shuffle parameter group size
  625. group_size_s64 = dnnl_query_group_size_s64,
  626. /// source memory desc
  627. src_md = dnnl_query_src_md,
  628. /// source gradient (diff) memory desc
  629. diff_src_md = dnnl_query_diff_src_md,
  630. /// weights memory descriptor desc
  631. weights_md = dnnl_query_weights_md,
  632. /// weights gradient (diff) memory desc
  633. diff_weights_md = dnnl_query_diff_weights_md,
  634. /// destination memory desc
  635. dst_md = dnnl_query_dst_md,
  636. /// destination gradient (diff) memory desc
  637. diff_dst_md = dnnl_query_diff_dst_md,
  638. /// workspace memory desc
  639. workspace_md = dnnl_query_workspace_md,
  640. /// scratchpad memory desc
  641. scratchpad_md = dnnl_query_scratchpad_md,
  642. /// memory desc of an execute argument
  643. exec_arg_md = dnnl_query_exec_arg_md,
  644. /// number of dimensions
  645. ndims_s32 = dnnl_query_ndims_s32,
  646. /// vector of dimensions
  647. dims = dnnl_query_dims,
  648. /// data type
  649. data_type = dnnl_query_data_type,
  650. /// submemory offset
  651. submemory_offset_s64 = dnnl_query_submemory_offset_s64,
  652. /// vector of padded dimensions
  653. padded_dims = dnnl_query_padded_dims,
  654. /// vector of padded offsets
  655. padded_offsets = dnnl_query_padded_offsets,
  656. /// format kind
  657. format_kind = dnnl_query_format_kind,
  658. /// number of innermost blocks
  659. inner_nblks_s32 = dnnl_query_inner_nblks_s32,
  660. /// vector of sizes of the innermost blocks
  661. inner_blks = dnnl_query_inner_blks,
  662. /// vector of logical indices of the blocks
  663. inner_idxs = dnnl_query_inner_idxs,
  664. #ifdef DNNL_EXPERIMENTAL_SPARSE
  665. /// Sparse encoding
  666. sparse_encoding = dnnl_query_sparse_encoding,
  667. /// Number of non-zero entries
  668. nnz_s64 = dnnl_query_nnz_s64,
  669. /// Number of buffers required for a memory descriptor
  670. num_handles_s32 = dnnl_query_num_handles_s32,
  671. #endif
  672. };
  673. /// Converts query enum value from C++ API to C API type.
  674. /// @param aquery C++ API query enum value.
  675. /// @returns Corresponding C API query enum value.
  676. inline dnnl_query_t convert_to_c(query aquery) {
  677. return static_cast<dnnl_query_t>(aquery);
  678. }
  679. /// @} dnnl_api_primitives_common
  680. /// @} dnnl_api_primitives
  681. /// @addtogroup dnnl_api_memory Memory
  682. ///
  683. /// A container that describes and stores data. Memory objects can contain
  684. /// data of various types and formats. There are two levels of abstraction:
  685. ///
  686. /// 1. **Memory descriptor** -- engine-agnostic logical description of data
  687. /// (number of dimensions, dimension sizes, and data type), and,
  688. /// optionally, the information about the physical format of data in
  689. /// memory. If this information is not known yet, a memory descriptor can
  690. /// be created with #dnnl::memory::format_tag::any. This allows
  691. /// compute-intensive primitives to choose the best format for
  692. /// computation. The user is responsible for reordering the data into the
  693. /// chosen format when formats do not match.
  694. ///
  695. /// A memory descriptor can be initialized either by specifying dimensions
  696. /// and a memory format tag or strides for each of them, or by
  697. /// manipulating the dnnl_memory_desc_t structure directly.
  698. ///
  699. /// @warning
  700. /// The latter approach requires understanding how the physical data
  701. /// representation is mapped to the structure and is discouraged. This
  702. /// topic is discussed in @ref dev_guide_understanding_memory_formats.
  703. ///
  704. /// The user can query the amount of memory required by a memory
  705. /// descriptor using the #dnnl::memory::desc::get_size() function. The
  706. /// size of data in general cannot be computed as the product of
  707. /// dimensions multiplied by the size of the data type. So users are
  708. /// required to use this function for better code portability.
  709. ///
  710. /// Two memory descriptors can be compared using the equality and
  711. /// inequality operators. The comparison is especially useful when
  712. /// checking whether it is necessary to reorder data from the user's data
  713. /// format to a primitive's format.
  714. ///
  715. /// 2. **Memory object** -- an engine-specific object that handles the memory
  716. /// buffer and its description (a memory descriptor). For the CPU engine or
  717. /// with USM, the memory buffer handle is simply a pointer to @c void. The
  718. /// memory buffer can be queried using #dnnl::memory::get_data_handle() and
  719. /// set using #dnnl::memory::set_data_handle(). The underlying SYCL buffer,
  720. /// when used, can be queried using #dnnl::sycl_interop::get_buffer and set
  721. /// using #dnnl::sycl_interop::set_buffer. A memory object can also be
  722. /// queried for the underlying memory descriptor and for its engine using
  723. /// #dnnl::memory::get_desc() and #dnnl::memory::get_engine().
  724. ///
  725. /// Along with ordinary memory descriptors with all dimensions being positive,
  726. /// the library supports *zero-volume* memory descriptors with one or more
  727. /// dimensions set to zero. This is used to support the NumPy\* convention.
  728. /// If a zero-volume memory is passed to a primitive, the primitive typically
  729. /// does not perform any computations with this memory. For example:
  730. ///
  731. /// - A concatenation primitive would ignore all memory objects with zeroes in
  732. /// the concat dimension / axis.
  733. ///
  734. /// - A forward convolution with a source memory object with zero in the
  735. /// minibatch dimension would always produce a destination memory object
  736. /// with a zero in the minibatch dimension and perform no computations.
  737. ///
  738. /// - However, a forward convolution with a zero in one of the weights
  739. /// dimensions is ill-defined and is considered to be an error by the
  740. /// library because there is no clear definition of what the output values
  741. /// should be.
  742. ///
  743. /// Memory buffer of a zero-volume memory is never accessed.
  744. ///
  745. /// @{
  746. /// Memory object.
  747. ///
  748. /// A memory object encapsulates a handle to a memory buffer allocated on a
  749. /// specific engine, tensor dimensions, data type, and memory format, which is
  750. /// the way tensor indices map to offsets in linear memory space. Memory
  751. /// objects are passed to primitives during execution.
  752. struct memory : public handle<dnnl_memory_t> {
  753. using handle::handle;
  754. /// Integer type for representing dimension sizes and indices.
  755. typedef dnnl_dim_t dim;
  756. /// Vector of dimensions. Implementations are free to force a limit on the
  757. /// vector's length.
  758. typedef std::vector<dim> dims;
  759. /// Helper function that validates that an `std::vector` of dimensions can
  760. /// be safely converted to the C API array ::dnnl_dims_t. Throws if
  761. /// validation fails.
  762. ///
  763. /// @param v Vector of dimensions.
  764. /// @param min_size Minimum expected size of the vector.
  765. template <typename T>
  766. static void validate_dims(const std::vector<T> &v, int min_size = 0) {
  767. validate_container_size(
  768. v, "dimensions are invalid", min_size, DNNL_MAX_NDIMS);
  769. }
  770. /// Data type specification.
  771. enum class data_type {
  772. /// Undefined data type (used for empty memory descriptors).
  773. undef = dnnl_data_type_undef,
  774. /// 4-bit float data type with 3-bit exponent and 0 bit mantissa.
  775. f4_e3m0 = dnnl_f4_e3m0,
  776. /// [MX-compliant 4-bit float data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 2-bit exponent and 1 bit mantissa.
  777. f4_e2m1 = dnnl_f4_e2m1,
  778. /// [MX-compliant 8-bit scale data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 8-bit exponent.
  779. e8m0 = dnnl_e8m0,
  780. /// [OFP8 standard 8-bit floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf)
  781. /// with a 5-bit exponent and a 2-bit mantissa.
  782. f8_e5m2 = dnnl_f8_e5m2,
  783. /// [OFP8 standard 8-bit floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf)
  784. /// with a 4-bit exponent and a 3-bit mantissa.
  785. f8_e4m3 = dnnl_f8_e4m3,
  786. /// [16-bit/half-precision floating point](https://en.wikipedia.org/wiki/Half-precision_floating-point_format).
  787. f16 = dnnl_f16,
  788. /// non-standard
  789. /// [16-bit floating point with 7-bit mantissa](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format).
  790. bf16 = dnnl_bf16,
  791. /// [32-bit/single-precision floating point](https://en.wikipedia.org/wiki/Single-precision_floating-point_format).
  792. f32 = dnnl_f32,
  793. /// [64-bit/double-precision floating point](https://en.wikipedia.org/wiki/Double-precision_floating-point_format).
  794. f64 = dnnl_f64,
  795. /// 32-bit signed integer.
  796. s32 = dnnl_s32,
  797. /// 8-bit signed integer.
  798. s8 = dnnl_s8,
  799. /// 8-bit unsigned integer.
  800. u8 = dnnl_u8,
  801. /// 4-bit signed integer.
  802. s4 = dnnl_s4,
  803. /// 4-bit unsigned integer.
  804. u4 = dnnl_u4,
  805. };
  806. /// Returns size of data type in bytes.
  807. /// @returns The number of bytes occupied by data type.
  808. static size_t data_type_size(data_type adata_type) {
  809. return dnnl_data_type_size(convert_to_c(adata_type));
  810. }
  811. /// Memory format kind
  812. enum class format_kind {
  813. /// Undefined memory format kind, used for empty memory descriptors.
  814. undef = dnnl_format_kind_undef,
  815. /// A special format kind that indicates that the actual format will be
  816. /// selected by a primitive automatically.
  817. any = dnnl_format_kind_any,
  818. /// A tensor in a generic format described by the stride and blocking
  819. /// values in each dimension.
  820. blocked = dnnl_blocked,
  821. #ifdef DNNL_EXPERIMENTAL_SPARSE
  822. /// Format kind for sparse tensors.
  823. sparse = dnnl_format_kind_sparse,
  824. #endif
  825. /// A special format kind that indicates that tensor format is opaque.
  826. opaque = dnnl_format_kind_opaque,
  827. };
  #ifdef DNNL_EXPERIMENTAL_SPARSE
  /// Sparse encodings.
  enum class sparse_encoding {
      /// The sparse encoding kind is not defined; used by empty memory
      /// descriptors.
      undef = dnnl_sparse_encoding_undef,
      /// Compressed Sparse Row (CSR) encoding.
      csr = dnnl_csr,
      /// Encoding used for an opaque storage schema for tensors with
      /// unstructured sparsity. No memory object can be created from a
      /// memory descriptor that has the packed encoding; such a descriptor
      /// can only be used to create a primitive descriptor and query the
      /// actual memory descriptor from it (similar to the format tag `any`).
      packed = dnnl_packed,
      /// Coordinate Sparse (COO) encoding.
      coo = dnnl_coo,
  };
  #endif
  845. /// Memory format tag specification.
  846. ///
  847. /// Memory format tags can be further divided into two categories:
  848. ///
  849. /// - Domain-agnostic names, i.e. names that do not depend on the tensor
  850. /// usage in the specific primitive. These names use letters from `a`
  851. /// to `f` to denote logical dimensions and form the order in which the
  852. /// dimensions are laid in memory. For example,
  853. /// #dnnl::memory::format_tag::ab is used to denote a 2D tensor where the
  854. /// second logical dimension (denoted as `b`) is the innermost, i.e.
  855. /// has stride = 1, and the first logical dimension (`a`) is laid out in
  856. /// memory with stride equal to the size of the second dimension. On the
  857. /// other hand, #dnnl::memory::format_tag::ba is the transposed version
  858. /// of the same tensor: the outermost dimension (`a`) becomes the
  859. /// innermost one.
  860. ///
  861. /// - Domain-specific names, i.e. names that make sense only in the
  862. /// context of a certain domain, such as CNN. These names are
  863. /// aliases to the corresponding domain-agnostic tags and used mostly
  864. /// for convenience. For example, #dnnl::memory::format_tag::nc
  865. /// is used to denote 2D CNN activations tensor memory format, where
  866. /// the channels dimension is the innermost one and the batch dimension
  867. /// is the outermost one. Moreover, #dnnl::memory::format_tag::nc is
  868. /// an alias for #dnnl::memory::format_tag::ab, because for
  869. /// CNN primitives the logical dimensions of activations tensors come
  870. /// in order: batch, channels, spatial. In other words, batch
  871. /// corresponds to the first logical dimension (`a`), and channels
  872. /// correspond to the second one (`b`).
  873. ///
  874. /// The following domain-specific notation applies to memory format tags:
  875. /// - @c 'n' denotes the mini-batch dimension
  876. /// - @c 'c' denotes a channels dimension
  877. /// - When there are multiple channel dimensions (for example,
  878. /// in convolution weights tensor), @c 'i' and @c 'o' denote dimensions
  879. /// of input and output channels
  880. /// - @c 'g' denotes a groups dimension for convolution weights
  881. /// - @c 'd', @c 'h', and @c 'w' denote spatial depth, height, and width
  882. /// respectively
  883. ///
  884. /// See @ref dnnl_format_tag_t for a detailed description.
  885. enum class format_tag {
  886. /// Undefined memory format tag
  887. undef = dnnl_format_tag_undef,
  888. /// Placeholder memory format tag. Used to instruct the primitive to
  889. /// select a format automatically.
  890. any = dnnl_format_tag_any,
  891. /// plain 1D tensor
  892. a = dnnl_a,
  893. /// plain 2D tensor
  894. ab = dnnl_ab,
  895. /// permuted 2D tensor
  896. ba = dnnl_ba,
  897. /// plain 3D tensor
  898. abc = dnnl_abc,
  899. /// permuted 3D tensor
  900. acb = dnnl_acb,
  901. /// permuted 3D tensor
  902. bac = dnnl_bac,
  903. /// permuted 3D tensor
  904. bca = dnnl_bca,
  905. /// permuted 3D tensor
  906. cba = dnnl_cba,
  907. /// plain 4D tensor
  908. abcd = dnnl_abcd,
  909. /// permuted 4D tensor
  910. abdc = dnnl_abdc,
  911. /// permuted 4D tensor
  912. acbd = dnnl_acbd,
  913. /// permuted 4D tensor
  914. acdb = dnnl_acdb,
  915. /// permuted 4D tensor
  916. adbc = dnnl_adbc,
  917. /// permuted 4D tensor
  918. bacd = dnnl_bacd,
  919. /// permuted 4D tensor
  920. bcda = dnnl_bcda,
  921. /// permuted 4D tensor
  922. cdba = dnnl_cdba,
  923. /// permuted 4D tensor
  924. dcab = dnnl_dcab,
  925. /// plain 5D tensor
  926. abcde = dnnl_abcde,
  927. /// permuted 5D tensor
  928. abdec = dnnl_abdec,
  929. /// permuted 5D tensor
  930. acbde = dnnl_acbde,
  931. /// permuted 5D tensor
  932. acdeb = dnnl_acdeb,
  933. /// permuted 5D tensor
  934. bacde = dnnl_bacde,
  935. /// permuted 5D tensor
  936. bcdea = dnnl_bcdea,
  937. /// permuted 5D tensor
  938. cdeba = dnnl_cdeba,
  939. /// permuted 5D tensor
  940. decab = dnnl_decab,
  941. /// permuted 5D tensor
  942. abced = dnnl_abced,
  943. /// plain 6D tensor
  944. abcdef = dnnl_abcdef,
  945. /// permuted 6D tensor
  946. abdfce = dnnl_abdfce,
  947. /// permuted 6D tensor
  948. acbdef = dnnl_acbdef,
  949. /// permuted 6D tensor
  950. abdefc = dnnl_abdefc,
  951. /// permuted 6D tensor
  952. defcab = dnnl_defcab,
  953. /// permuted 6D tensor
  954. abcdfe = dnnl_abcdfe,
  955. /// plain 7D tensor
  956. abcdefg = dnnl_abcdefg,
  957. /// permuted 7D tensor
  958. abcdegf = dnnl_abcdegf,
  959. /// plain 8D tensor
  960. abcdefgh = dnnl_abcdefgh,
  961. /// permuted 8D tensor
  962. abcdefhg = dnnl_abcdefhg,
  963. /// plain 9D tensor
  964. abcdefghi = dnnl_abcdefghi,
  965. /// permuted 9D tensor
  966. abcdefgih = dnnl_abcdefgih,
  967. /// plain 10D tensor
  968. abcdefghij = dnnl_abcdefghij,
  969. /// permuted 10D tensor
  970. abcdefghji = dnnl_abcdefghji,
  971. /// plain 11D tensor
  972. abcdefghijk = dnnl_abcdefghijk,
  973. /// permuted 11D tensor
  974. abcdefghikj = dnnl_abcdefghikj,
  975. /// plain 12D tensor
  976. abcdefghijkl = dnnl_abcdefghijkl,
  977. /// permuted 12D tensor
  978. abcdefghijlk = dnnl_abcdefghijlk,
  979. /// 1D tensor; an alias for #dnnl::memory::format_tag::a
  980. x = a,
  981. /// 2D CNN activations tensor; an alias for #dnnl::memory::format_tag::ab
  982. nc = ab,
  983. /// 2D CNN activations tensor; an alias for #dnnl::memory::format_tag::ba
  984. cn = ba,
  985. /// 2D RNN statistics tensor; an alias for #dnnl::memory::format_tag::ab
  986. tn = ab,
  987. /// 2D RNN statistics tensor; an alias for #dnnl::memory::format_tag::ba
  988. nt = ba,
  989. /// 3D CNN activations tensor; an alias for #dnnl::memory::format_tag::abc
  990. ncw = abc,
  991. /// 3D CNN activations tensor; an alias for #dnnl::memory::format_tag::acb
  992. nwc = acb,
  993. /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::abcd
  994. nchw = abcd,
  995. /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::acdb
  996. nhwc = acdb,
  997. /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::bcda
  998. chwn = bcda,
  999. /// 5D CNN activations tensor; an alias for #dnnl::memory::format_tag::abcde
  1000. ncdhw = abcde,
  1001. /// 5D CNN activations tensor; an alias for #dnnl::memory::format_tag::acdeb
  1002. ndhwc = acdeb,
  1003. /// 2D CNN weights tensor; an alias for #dnnl::memory::format_tag::ab
  1004. oi = ab,
  1005. /// 2D CNN weights tensor; an alias for #dnnl::memory::format_tag::ba
  1006. io = ba,
  1007. /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::abc
  1008. oiw = abc,
  1009. /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::acb
  1010. owi = acb,
  1011. /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::cba
  1012. wio = cba,
  1013. /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::bca
  1014. iwo = bca,
  1015. /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::abcd
  1016. oihw = abcd,
  1017. /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::cdba
  1018. hwio = cdba,
  1019. /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::acdb
  1020. ohwi = acdb,
  1021. /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::bcda
  1022. ihwo = bcda,
  1023. /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::bacd
  1024. iohw = bacd,
  1025. /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::abcde
  1026. oidhw = abcde,
  1027. /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::cdeba
  1028. dhwio = cdeba,
  1029. /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::acdeb
  1030. odhwi = acdeb,
  1031. /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::bacde
  1032. iodhw = bacde,
  1033. /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::bcdea
  1034. idhwo = bcdea,
  1035. /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcd
  1036. goiw = abcd,
  1037. /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdc
  1038. gowi = abdc,
  1039. /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::dcab
  1040. wigo = dcab,
  1041. /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdec
  1042. gohwi = abdec,
  1043. /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcde
  1044. goihw = abcde,
  1045. /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::decab
  1046. hwigo = decab,
  1047. /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::acbde
  1048. giohw = acbde,
  1049. /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcdef
  1050. goidhw = abcdef,
  1051. /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::acbdef
  1052. giodhw = acbdef,
  1053. /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdefc
  1054. godhwi = abdefc,
  1055. /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::defcab
  1056. dhwigo = defcab,
  1057. /// 3D RNN data tensor in the format (seq_length, batch, input
  1058. /// channels); an alias for #dnnl::memory::format_tag::abc.
  1059. tnc = abc,
  1060. /// 3D RNN data tensor in the format (batch, seq_length, input
  1061. /// channels); an alias for #dnnl::memory::format_tag::bac.
  1062. ntc = bac,
  1063. /// 4D RNN states tensor in the format (num_layers, num_directions,
  1064. /// batch, state channels); an alias for #dnnl::memory::format_tag::abcd.
  1065. ldnc = abcd,
  1066. /// 5D RNN weights tensor in the format (num_layers, num_directions,
  1067. /// input_channels, num_gates, output_channels);
  1068. /// an alias for #dnnl::memory::format_tag::abcde.
  1069. ///
  1070. /// - For LSTM cells, the gates order is input, forget, candidate
  1071. /// and output gate.
  1072. /// - For GRU cells, the gates order is update, reset and output gate.
  1073. ldigo = abcde,
  1074. /// 5D RNN weights tensor in the format (num_layers, num_directions,
  1075. /// num_gates, output_channels, input_channels);
  1076. /// an alias for #dnnl::memory::format_tag::abdec.
  1077. ///
  1078. /// - For LSTM cells, the gates order is input, forget, candidate
  1079. /// and output gate.
  1080. /// - For GRU cells, the gates order is update, reset and output gate.
  1081. ldgoi = abdec,
  1082. /// 4D LSTM projection tensor in the format (num_layers, num_directions,
  1083. /// num_channels_in_hidden_state, num_channels_in_recurrent_projection);
  1084. /// an alias for #dnnl::memory::format_tag::abcd.
  1085. ldio = abcd,
  1086. /// 4D LSTM projection tensor in the format (num_layers, num_directions,
  1087. /// num_channels_in_recurrent_projection, num_channels_in_hidden_state);
  1088. /// an alias for #dnnl::memory::format_tag::abdc.
  1089. ldoi = abdc,
  1090. /// 4D RNN bias tensor in the format (num_layers, num_directions,
  1091. /// num_gates, output_channels);
  1092. /// an alias for #dnnl::memory::format_tag::abcd.
  1093. ///
  1094. /// - For LSTM cells, the gates order is input, forget, candidate
  1095. /// and output gate.
  1096. /// - For GRU cells, the gates order is update, reset and output gate.
  1097. ldgo = abcd,
  1098. // Opaque blocked formats
  1099. AB16b16a = dnnl_AB16b16a,
  1100. AB16b32a = dnnl_AB16b32a,
  1101. AB16b48a = dnnl_AB16b48a,
  1102. AB16b64a = dnnl_AB16b64a,
  1103. AB8b16a2b = dnnl_AB8b16a2b,
  1104. AB8b32a2b = dnnl_AB8b32a2b,
  1105. AB8b64a2b = dnnl_AB8b64a2b,
  1106. AB4b16a4b = dnnl_AB4b16a4b,
  1107. AB4b32a4b = dnnl_AB4b32a4b,
  1108. AB4b64a4b = dnnl_AB4b64a4b,
  1109. AB16b16a4b = dnnl_AB16b16a4b,
  1110. AB16b32a4b = dnnl_AB16b32a4b,
  1111. AB16b48a4b = dnnl_AB16b48a4b,
  1112. AB16b64a4b = dnnl_AB16b64a4b,
  1113. AB16b16a2b = dnnl_AB16b16a2b,
  1114. AB16b32a2b = dnnl_AB16b32a2b,
  1115. AB16b48a2b = dnnl_AB16b48a2b,
  1116. AB16b64a2b = dnnl_AB16b64a2b,
  1117. Ab4a = dnnl_Ab4a,
  1118. Ab8a = dnnl_Ab8a,
  1119. Ab32a = dnnl_Ab32a,
  1120. Abc16a = dnnl_Abc16a,
  1121. ABc16a16b = dnnl_ABc16a16b,
  1122. ABc4a4b = dnnl_ABc4a4b,
  1123. aBc16b = dnnl_aBc16b,
  1124. aBc32b = dnnl_aBc32b,
  1125. ABc16b16a = dnnl_ABc16b16a,
  1126. AcB16b16a = dnnl_AcB16b16a,
  1127. ABc16b32a = dnnl_ABc16b32a,
  1128. AcB16b32a = dnnl_AcB16b32a,
  1129. ABc16b48a = dnnl_ABc16b48a,
  1130. AcB16b48a = dnnl_AcB16b48a,
  1131. ABc16b64a = dnnl_ABc16b64a,
  1132. AcB16b64a = dnnl_AcB16b64a,
  1133. Abc4a = dnnl_Abc4a,
  1134. aBc4b = dnnl_aBc4b,
  1135. ABc4b16a4b = dnnl_ABc4b16a4b,
  1136. AcB4b16a4b = dnnl_AcB4b16a4b,
  1137. ABc4b32a4b = dnnl_ABc4b32a4b,
  1138. AcB4b32a4b = dnnl_AcB4b32a4b,
  1139. ABc4b64a4b = dnnl_ABc4b64a4b,
  1140. AcB4b64a4b = dnnl_AcB4b64a4b,
  1141. ABc2b8a4b = dnnl_ABc2b8a4b,
  1142. ABc16a16b2a = dnnl_ABc16a16b2a,
  1143. ABc16b16a4b = dnnl_ABc16b16a4b,
  1144. ABc16b32a4b = dnnl_ABc16b32a4b,
  1145. ABc16b48a4b = dnnl_ABc16b48a4b,
  1146. ABc16b64a4b = dnnl_ABc16b64a4b,
  1147. ABc16b16a2b = dnnl_ABc16b16a2b,
  1148. ABc16b32a2b = dnnl_ABc16b32a2b,
  1149. ABc16b48a2b = dnnl_ABc16b48a2b,
  1150. ABc16b64a2b = dnnl_ABc16b64a2b,
  1151. ABc4b4a = dnnl_ABc4b4a,
  1152. ABc8a16b2a = dnnl_ABc8a16b2a,
  1153. ABc8a8b = dnnl_ABc8a8b,
  1154. ABc8a4b = dnnl_ABc8a4b,
  1155. aBc8b = dnnl_aBc8b,
  1156. ABc8b16a2b = dnnl_ABc8b16a2b,
  1157. AcB8b16a2b = dnnl_AcB8b16a2b,
  1158. ABc8b32a2b = dnnl_ABc8b32a2b,
  1159. AcB8b32a2b = dnnl_AcB8b32a2b,
  1160. ABc8b64a2b = dnnl_ABc8b64a2b,
  1161. AcB8b64a2b = dnnl_AcB8b64a2b,
  1162. ABc8b8a = dnnl_ABc8b8a,
  1163. AcB8b8a = dnnl_AcB8b8a,
  1164. Abcd8a = dnnl_Abcd8a,
  1165. Abcd16a = dnnl_Abcd16a,
  1166. Abcd32a = dnnl_Abcd32a,
  1167. ABcd16a16b = dnnl_ABcd16a16b,
  1168. aBcd16b = dnnl_aBcd16b,
  1169. aBcd32b = dnnl_aBcd32b,
  1170. ABcd16b16a = dnnl_ABcd16b16a,
  1171. AcdB16b16a = dnnl_AcdB16b16a,
  1172. ABcd16b32a = dnnl_ABcd16b32a,
  1173. AcdB16b32a = dnnl_AcdB16b32a,
  1174. ABcd16b48a = dnnl_ABcd16b48a,
  1175. AcdB16b48a = dnnl_AcdB16b48a,
  1176. ABcd16b64a = dnnl_ABcd16b64a,
  1177. AcdB16b64a = dnnl_AcdB16b64a,
  1178. aBCd16b16c = dnnl_aBCd16b16c,
  1179. aBCd16c16b = dnnl_aBCd16c16b,
  1180. Abcd4a = dnnl_Abcd4a,
  1181. aBcd4b = dnnl_aBcd4b,
  1182. ABcd4b16a4b = dnnl_ABcd4b16a4b,
  1183. AcdB4b16a4b = dnnl_AcdB4b16a4b,
  1184. ABcd4b32a4b = dnnl_ABcd4b32a4b,
  1185. AcdB4b32a4b = dnnl_AcdB4b32a4b,
  1186. ABcd4b64a4b = dnnl_ABcd4b64a4b,
  1187. AcdB4b64a4b = dnnl_AcdB4b64a4b,
  1188. ABcd2b8a4b = dnnl_ABcd2b8a4b,
  1189. ABcd4b4a = dnnl_ABcd4b4a,
  1190. ABcd4a4b = dnnl_ABcd4a4b,
  1191. aBCd4c16b4c = dnnl_aBCd4c16b4c,
  1192. aBCd2c8b4c = dnnl_aBCd2c8b4c,
  1193. ABcd16a16b2a = dnnl_ABcd16a16b2a,
  1194. ABcd16b16a4b = dnnl_ABcd16b16a4b,
  1195. ABcd16b32a4b = dnnl_ABcd16b32a4b,
  1196. ABcd16b48a4b = dnnl_ABcd16b48a4b,
  1197. ABcd16b64a4b = dnnl_ABcd16b64a4b,
  1198. ABcd16b16a2b = dnnl_ABcd16b16a2b,
  1199. ABcd16b32a2b = dnnl_ABcd16b32a2b,
  1200. ABcd16b48a2b = dnnl_ABcd16b48a2b,
  1201. ABcd16b64a2b = dnnl_ABcd16b64a2b,
  1202. aBCd16b16c2b = dnnl_aBCd16b16c2b,
  1203. aBCd16c16b4c = dnnl_aBCd16c16b4c,
  1204. aBCd16c16b2c = dnnl_aBCd16c16b2c,
  1205. aBCd4c4b = dnnl_aBCd4c4b,
  1206. aBCd4b4c = dnnl_aBCd4b4c,
  1207. ABcd8a16b2a = dnnl_ABcd8a16b2a,
  1208. ABcd8a8b = dnnl_ABcd8a8b,
  1209. ABcd8a4b = dnnl_ABcd8a4b,
  1210. ABcd8a2b = dnnl_ABcd8a2b,
  1211. /// 4D tensor blocked by 2nd dimension with block size 8
  1212. aBcd8b = dnnl_aBcd8b,
  1213. ABcd8b16a2b = dnnl_ABcd8b16a2b,
  1214. AcdB8b16a2b = dnnl_AcdB8b16a2b,
  1215. ABcd8b32a2b = dnnl_ABcd8b32a2b,
  1216. AcdB8b32a2b = dnnl_AcdB8b32a2b,
  1217. ABcd8b64a2b = dnnl_ABcd8b64a2b,
  1218. AcdB8b64a2b = dnnl_AcdB8b64a2b,
  1219. aBCd8b16c2b = dnnl_aBCd8b16c2b,
  1220. /// 4D tensor blocked by 1st and 2nd dimension with block size 8
  1221. ABcd8b8a = dnnl_ABcd8b8a,
  1222. AcdB8b8a = dnnl_AcdB8b8a,
  1223. aBCd8b8c = dnnl_aBCd8b8c,
  1224. aBCd8b4c = dnnl_aBCd8b4c,
  1225. aBCd8c16b2c = dnnl_aBCd8c16b2c,
  1226. aBCd8c8b = dnnl_aBCd8c8b,
  1227. Abcde16a = dnnl_Abcde16a,
  1228. Abcde32a = dnnl_Abcde32a,
  1229. ABcde16a16b = dnnl_ABcde16a16b,
  1230. aBcde16b = dnnl_aBcde16b,
  1231. aBcde32b = dnnl_aBcde32b,
  1232. ABcde16b16a = dnnl_ABcde16b16a,
  1233. AcdeB16b16a = dnnl_AcdeB16b16a,
  1234. ABcde16b32a = dnnl_ABcde16b32a,
  1235. AcdeB16b32a = dnnl_AcdeB16b32a,
  1236. ABcde16b48a = dnnl_ABcde16b48a,
  1237. AcdeB16b48a = dnnl_AcdeB16b48a,
  1238. ABcde16b64a = dnnl_ABcde16b64a,
  1239. AcdeB16b64a = dnnl_AcdeB16b64a,
  1240. aBCde16b16c = dnnl_aBCde16b16c,
  1241. aBCde16c16b = dnnl_aBCde16c16b,
  1242. aBCde2c8b4c = dnnl_aBCde2c8b4c,
  1243. Abcde4a = dnnl_Abcde4a,
  1244. aBcde4b = dnnl_aBcde4b,
  1245. ABcde4b4a = dnnl_ABcde4b4a,
  1246. ABcde4a4b = dnnl_ABcde4a4b,
  1247. aBCde4b4c = dnnl_aBCde4b4c,
  1248. aBCde4c16b4c = dnnl_aBCde4c16b4c,
  1249. aBCde16b16c2b = dnnl_aBCde16b16c2b,
  1250. aBCde16c16b4c = dnnl_aBCde16c16b4c,
  1251. aBCde16c16b2c = dnnl_aBCde16c16b2c,
  1252. aBCdef16c16b2c = dnnl_aBCdef16c16b2c,
  1253. aBCde4c4b = dnnl_aBCde4c4b,
  1254. Abcde8a = dnnl_Abcde8a,
  1255. ABcde8a8b = dnnl_ABcde8a8b,
  1256. ABcde8a4b = dnnl_ABcde8a4b,
  1257. aBcde8b = dnnl_aBcde8b,
  1258. ABcde8b16a2b = dnnl_ABcde8b16a2b,
  1259. AcdeB8b16a2b = dnnl_AcdeB8b16a2b,
  1260. ABcde8b32a2b = dnnl_ABcde8b32a2b,
  1261. AcdeB8b32a2b = dnnl_AcdeB8b32a2b,
  1262. ABcde8b64a2b = dnnl_ABcde8b64a2b,
  1263. AcdeB8b64a2b = dnnl_AcdeB8b64a2b,
  1264. ABcde4b16a4b = dnnl_ABcde4b16a4b,
  1265. AcdeB4b16a4b = dnnl_AcdeB4b16a4b,
  1266. ABcde4b32a4b = dnnl_ABcde4b32a4b,
  1267. AcdeB4b32a4b = dnnl_AcdeB4b32a4b,
  1268. ABcde4b64a4b = dnnl_ABcde4b64a4b,
  1269. AcdeB4b64a4b = dnnl_AcdeB4b64a4b,
  1270. ABcde16b16a4b = dnnl_ABcde16b16a4b,
  1271. ABcde16b32a4b = dnnl_ABcde16b32a4b,
  1272. ABcde16b48a4b = dnnl_ABcde16b48a4b,
  1273. ABcde16b64a4b = dnnl_ABcde16b64a4b,
  1274. ABcde16b16a2b = dnnl_ABcde16b16a2b,
  1275. ABcde16b32a2b = dnnl_ABcde16b32a2b,
  1276. ABcde16b48a2b = dnnl_ABcde16b48a2b,
  1277. ABcde16b64a2b = dnnl_ABcde16b64a2b,
  1278. ABcde2b8a4b = dnnl_ABcde2b8a4b,
  1279. aBCde8b16c2b = dnnl_aBCde8b16c2b,
  1280. ABcde8b8a = dnnl_ABcde8b8a,
  1281. AcdeB8b8a = dnnl_AcdeB8b8a,
  1282. aBCde8b8c = dnnl_aBCde8b8c,
  1283. aBCde8b4c = dnnl_aBCde8b4c,
  1284. ABcd4a8b8a4b = dnnl_ABcd4a8b8a4b,
  1285. ABcd2a8b8a2b = dnnl_ABcd2a8b8a2b,
  1286. aBCde4b8c8b4c = dnnl_aBCde4b8c8b4c,
  1287. aBCde2b8c8b2c = dnnl_aBCde2b8c8b2c,
  1288. aBCde8c16b2c = dnnl_aBCde8c16b2c,
  1289. aBCde8c8b = dnnl_aBCde8c8b,
  1290. aBcdef16b = dnnl_aBcdef16b,
  1291. aBCdef16b16c = dnnl_aBCdef16b16c,
  1292. aBCdef16c16b = dnnl_aBCdef16c16b,
  1293. aBcdef4b = dnnl_aBcdef4b,
  1294. aBCdef2c8b4c = dnnl_aBCdef2c8b4c,
  1295. aBCdef4c4b = dnnl_aBCdef4c4b,
  1296. aBCdef4b4c = dnnl_aBCdef4b4c,
  1297. aBCdef8b8c = dnnl_aBCdef8b8c,
  1298. aBCdef8b4c = dnnl_aBCdef8b4c,
  1299. aBCdef8c16b2c = dnnl_aBCdef8c16b2c,
  1300. aBCdef4c16b4c = dnnl_aBCdef4c16b4c,
  1301. aBCdef8c8b = dnnl_aBCdef8c8b,
  1302. aBdc16b = dnnl_aBdc16b,
  1303. aBdc4b = dnnl_aBdc4b,
  1304. aBdc8b = dnnl_aBdc8b,
  1305. aBdC8b2c = dnnl_aBdC8b2c,
  1306. aBdC8b4c = dnnl_aBdC8b4c,
  1307. aBdec16b = dnnl_aBdec16b,
  1308. aBdec4b = dnnl_aBdec4b,
  1309. aBdec8b = dnnl_aBdec8b,
  1310. aBdeC8b2c = dnnl_aBdeC8b2c,
  1311. aBdeC8b4c = dnnl_aBdeC8b4c,
  1312. aBdefc16b = dnnl_aBdefc16b,
  1313. aCBdef16c16b = dnnl_aCBdef16c16b,
  1314. aCBdef8b8c = dnnl_aCBdef8b8c,
  1315. aCBdef16b16c = dnnl_aCBdef16b16c,
  1316. aBdefc4b = dnnl_aBdefc4b,
  1317. aBdefc8b = dnnl_aBdefc8b,
  1318. aBdefC8b2c = dnnl_aBdefC8b2c,
  1319. aBdefC8b4c = dnnl_aBdefC8b4c,
  1320. Acb16a = dnnl_Acb16a,
  1321. Acb4a = dnnl_Acb4a,
  1322. Acb8a = dnnl_Acb8a,
  1323. AcB8a2b = dnnl_AcB8a2b,
  1324. AcB8a4b = dnnl_AcB8a4b,
  1325. aCBd8b8c = dnnl_aCBd8b8c,
  1326. aCBd16b16c = dnnl_aCBd16b16c,
  1327. aCBd16c16b = dnnl_aCBd16c16b,
  1328. aCBde8b8c = dnnl_aCBde8b8c,
  1329. aCBde16b16c = dnnl_aCBde16b16c,
  1330. aCBde16c16b = dnnl_aCBde16c16b,
  1331. Acdb16a = dnnl_Acdb16a,
  1332. Acdb4a = dnnl_Acdb4a,
  1333. Acdb8a = dnnl_Acdb8a,
  1334. AcdB8a2b = dnnl_AcdB8a2b,
  1335. AcdB8a4b = dnnl_AcdB8a4b,
  1336. Acdeb16a = dnnl_Acdeb16a,
  1337. Acdeb4a = dnnl_Acdeb4a,
  1338. Acdeb8a = dnnl_Acdeb8a,
  1339. AcdeB8a2b = dnnl_AcdeB8a2b,
  1340. AcdeB8a4b = dnnl_AcdeB8a4b,
  1341. BAc8a8b = dnnl_BAc8a8b,
  1342. BAc16a16b = dnnl_BAc16a16b,
  1343. BAc16b16a = dnnl_BAc16b16a,
  1344. BAcd8a8b = dnnl_BAcd8a8b,
  1345. BAcd16a16b = dnnl_BAcd16a16b,
  1346. BAcd16b16a = dnnl_BAcd16b16a,
  1347. ABcd32a32b = dnnl_ABcd32a32b,
  1348. BAcde16b16a = dnnl_BAcde16b16a,
  1349. BAcde8a8b = dnnl_BAcde8a8b,
  1350. BAcde16a16b = dnnl_BAcde16a16b,
  1351. aBdec32b = dnnl_aBdec32b,
  1352. Abcdef16a = dnnl_Abcdef16a,
  1353. Abcdef32a = dnnl_Abcdef32a,
  1354. Acdb32a = dnnl_Acdb32a,
  1355. aBCd2b4c2b = dnnl_aBCd2b4c2b,
  1356. aBCde2b4c2b = dnnl_aBCde2b4c2b,
  1357. aBCdef2b4c2b = dnnl_aBCdef2b4c2b,
  1358. aBCd2c4b2c = dnnl_aBCd2c4b2c,
  1359. aBCde2c4b2c = dnnl_aBCde2c4b2c,
  1360. aBCdef2c4b2c = dnnl_aBCdef2c4b2c,
  1361. aBCd4b8c2b = dnnl_aBCd4b8c2b,
  1362. aBCde4b8c2b = dnnl_aBCde4b8c2b,
  1363. aBCdef4b8c2b = dnnl_aBCdef4b8c2b,
  1364. aBCd4c8b2c = dnnl_aBCd4c8b2c,
  1365. aBCde4c8b2c = dnnl_aBCde4c8b2c,
  1366. aBCdef4c8b2c = dnnl_aBCdef4c8b2c,
  1367. AB32a32b8a4b = dnnl_AB32a32b8a4b,
  1368. AB32a32b8a2b = dnnl_AB32a32b8a2b,
  1369. AB8a4b = dnnl_AB8a4b,
  1370. AB8a2b = dnnl_AB8a2b,
  1371. abDc16d = dnnl_abDc16d,
  1372. abDc32d = dnnl_abDc32d,
  1373. abDC16d4c = dnnl_abDC16d4c,
  1374. abDC32d4c = dnnl_abDC32d4c,
  1375. abCd32c = dnnl_abCd32c,
  1376. abdEc16e = dnnl_abdEc16e,
  1377. abdEc32e = dnnl_abdEc32e,
  1378. abdEC16e4c = dnnl_abdEC16e4c,
  1379. abdEC32e2c = dnnl_abdEC32e2c,
  1380. abdEC32e4c = dnnl_abdEC32e4c,
  1381. abdCe16c = dnnl_abdCe16c,
  1382. abdCe32c = dnnl_abdCe32c,
  1383. abdCE32c2e = dnnl_abdCE32c2e,
  1384. aBCdef16c16b4c = dnnl_aBCdef16c16b4c,
  1385. aBdC16b4c = dnnl_aBdC16b4c,
  1386. aBdeC16b4c = dnnl_aBdeC16b4c,
  1387. AcB16a4b = dnnl_AcB16a4b,
  1388. AcdB16a2b = dnnl_AcdB16a2b,
  1389. aBdefC16b4c = dnnl_aBdefC16b4c,
  1390. AcdeB16a4b = dnnl_AcdeB16a4b,
  1391. Acb32a = dnnl_Acb32a,
  1392. AcB32a2b = dnnl_AcB32a2b,
  1393. AcB32a4b = dnnl_AcB32a4b,
  1394. Acb48a = dnnl_Acb48a,
  1395. AcB48a2b = dnnl_AcB48a2b,
  1396. AcB48a4b = dnnl_AcB48a4b,
  1397. Acb64a = dnnl_Acb64a,
  1398. AcB64a2b = dnnl_AcB64a2b,
  1399. AcB64a4b = dnnl_AcB64a4b,
  1400. cBa2b = dnnl_cBa2b,
  1401. cBa4b = dnnl_cBa4b,
  1402. aBdc32b = dnnl_aBdc32b,
  1403. aBdC32b2c = dnnl_aBdC32b2c,
  1404. aBdC32b4c = dnnl_aBdC32b4c,
  1405. aBdc48b = dnnl_aBdc48b,
  1406. aBdC48b2c = dnnl_aBdC48b2c,
  1407. aBdC48b4c = dnnl_aBdC48b4c,
  1408. aBdc64b = dnnl_aBdc64b,
  1409. aBdC64b2c = dnnl_aBdC64b2c,
  1410. aBdC64b4c = dnnl_aBdC64b4c,
  1411. adcb = dnnl_adcb,
  1412. adCb2c = dnnl_adCb2c,
  1413. adCb4c = dnnl_adCb4c,
  1414. AcdB32a2b = dnnl_AcdB32a2b,
  1415. AcdB32a4b = dnnl_AcdB32a4b,
  1416. Acdb48a = dnnl_Acdb48a,
  1417. AcdB48a2b = dnnl_AcdB48a2b,
  1418. AcdB48a4b = dnnl_AcdB48a4b,
  1419. Acdb64a = dnnl_Acdb64a,
  1420. AcdB64a2b = dnnl_AcdB64a2b,
  1421. AcdB64a4b = dnnl_AcdB64a4b,
  1422. cdBa2b = dnnl_cdBa2b,
  1423. cdBa4b = dnnl_cdBa4b,
  1424. aBdeC32b2c = dnnl_aBdeC32b2c,
  1425. aBdeC32b4c = dnnl_aBdeC32b4c,
  1426. aBdec48b = dnnl_aBdec48b,
  1427. aBdeC48b2c = dnnl_aBdeC48b2c,
  1428. aBdeC48b4c = dnnl_aBdeC48b4c,
  1429. aBdec64b = dnnl_aBdec64b,
  1430. aBdeC64b2c = dnnl_aBdeC64b2c,
  1431. aBdeC64b4c = dnnl_aBdeC64b4c,
  1432. adecb = dnnl_adecb,
  1433. adeCb2c = dnnl_adeCb2c,
  1434. adeCb4c = dnnl_adeCb4c,
  1435. Acdeb32a = dnnl_Acdeb32a,
  1436. AcdeB32a2b = dnnl_AcdeB32a2b,
  1437. AcdeB32a4b = dnnl_AcdeB32a4b,
  1438. Acdeb48a = dnnl_Acdeb48a,
  1439. AcdeB48a2b = dnnl_AcdeB48a2b,
  1440. AcdeB48a4b = dnnl_AcdeB48a4b,
  1441. Acdeb64a = dnnl_Acdeb64a,
  1442. AcdeB64a2b = dnnl_AcdeB64a2b,
  1443. AcdeB64a4b = dnnl_AcdeB64a4b,
  1444. cdeBa2b = dnnl_cdeBa2b,
  1445. cdeBa4b = dnnl_cdeBa4b,
  1446. aBdefc32b = dnnl_aBdefc32b,
  1447. aBdefC32b2c = dnnl_aBdefC32b2c,
  1448. aBdefC32b4c = dnnl_aBdefC32b4c,
  1449. aBdefc48b = dnnl_aBdefc48b,
  1450. aBdefC48b2c = dnnl_aBdefC48b2c,
  1451. aBdefC48b4c = dnnl_aBdefC48b4c,
  1452. aBdefc64b = dnnl_aBdefc64b,
  1453. aBdefC64b2c = dnnl_aBdefC64b2c,
  1454. aBdefC64b4c = dnnl_aBdefC64b4c,
  1455. adefcb = dnnl_adefcb,
  1456. adefCb2c = dnnl_adefCb2c,
  1457. adefCb4c = dnnl_adefCb4c,
  1458. ABc32a32b = dnnl_ABc32a32b,
  1459. BAc8a16b2a = dnnl_BAc8a16b2a,
  1460. BAcd8a16b2a = dnnl_BAcd8a16b2a,
  1461. ABcde8a16b2a = dnnl_ABcde8a16b2a,
  1462. aCBd8b16c2b = dnnl_aCBd8b16c2b,
  1463. BAcde8a16b2a = dnnl_BAcde8a16b2a,
  1464. aCBde8b16c2b = dnnl_aCBde8b16c2b,
  1465. ABcde32a32b = dnnl_ABcde32a32b,
  1466. ABc4a8b8a4b = dnnl_ABc4a8b8a4b,
  1467. ABcde4a8b8a4b = dnnl_ABcde4a8b8a4b,
  1468. BAc4b8a8b4a = dnnl_BAc4b8a8b4a,
  1469. BAcd4b8a8b4a = dnnl_BAcd4b8a8b4a,
  1470. BAcde4b8a8b4a = dnnl_BAcde4b8a8b4a,
  1471. aBCd4b8c8b4c = dnnl_aBCd4b8c8b4c,
  1472. aBCdef4b8c8b4c = dnnl_aBCdef4b8c8b4c,
  1473. aBCdef8b16c2b = dnnl_aBCdef8b16c2b,
  1474. aCBdef8b16c2b = dnnl_aCBdef8b16c2b,
  1475. aBdC16b2c = dnnl_aBdC16b2c,
  1476. aBdeC16b2c = dnnl_aBdeC16b2c,
  1477. aBdefC16b2c = dnnl_aBdefC16b2c,
  1478. aBedc16b = dnnl_aBedc16b,
  1479. AcB16a2b = dnnl_AcB16a2b,
  1480. AcdB16a4b = dnnl_AcdB16a4b,
  1481. AcdeB16a2b = dnnl_AcdeB16a2b,
  1482. Adcb16a = dnnl_Adcb16a,
  1483. aCBd4c8b8c4b = dnnl_aCBd4c8b8c4b,
  1484. aCBde4c8b8c4b = dnnl_aCBde4c8b8c4b,
  1485. aCBdef4c8b8c4b = dnnl_aCBdef4c8b8c4b,
  1486. ABc32a16b = dnnl_ABc32a16b,
  1487. ABcd16a32b = dnnl_ABcd16a32b,
  1488. ABcd32a16b = dnnl_ABcd32a16b,
  1489. ABcde32a16b = dnnl_ABcde32a16b,
  1490. AB48a16b = dnnl_AB48a16b,
  1491. AB48a32b = dnnl_AB48a32b,
  1492. ABc40a16b = dnnl_ABc40a16b,
  1493. ABc40a32b = dnnl_ABc40a32b,
  1494. aBC48b16c = dnnl_aBC48b16c,
  1495. aBC48b32c = dnnl_aBC48b32c,
  1496. ABcd40a16b = dnnl_ABcd40a16b,
  1497. ABcd40a32b = dnnl_ABcd40a32b,
  1498. BA16a16b = dnnl_BA16a16b,
  1499. BA16a32b = dnnl_BA16a32b,
  1500. BA16a48b = dnnl_BA16a48b,
  1501. BA16a64b = dnnl_BA16a64b,
  1502. BA16a16b2a = dnnl_BA16a16b2a,
  1503. BA16a32b2a = dnnl_BA16a32b2a,
  1504. BA16a48b2a = dnnl_BA16a48b2a,
  1505. BA16a64b2a = dnnl_BA16a64b2a,
  1506. BA16a16b4a = dnnl_BA16a16b4a,
  1507. BA16a32b4a = dnnl_BA16a32b4a,
  1508. BA16a48b4a = dnnl_BA16a48b4a,
  1509. BA16a64b4a = dnnl_BA16a64b4a,
  1510. decbA16a = dnnl_decbA16a,
  1511. decbA8a = dnnl_decbA8a,
  1512. defcbA16a = dnnl_defcbA16a,
  1513. defcbA8a = dnnl_defcbA8a,
  1514. aCB16b16c = dnnl_aCB16b16c,
  1515. aCB16b32c = dnnl_aCB16b32c,
  1516. aCB16b48c = dnnl_aCB16b48c,
  1517. aCB16b64c = dnnl_aCB16b64c,
  1518. aCB16b16c2b = dnnl_aCB16b16c2b,
  1519. aCB16b32c2b = dnnl_aCB16b32c2b,
  1520. aCB16b48c2b = dnnl_aCB16b48c2b,
  1521. aCB16b64c2b = dnnl_aCB16b64c2b,
  1522. aCB16b16c4b = dnnl_aCB16b16c4b,
  1523. aCB16b32c4b = dnnl_aCB16b32c4b,
  1524. aCB16b48c4b = dnnl_aCB16b48c4b,
  1525. aCB16b64c4b = dnnl_aCB16b64c4b,
  1526. Acb24a = dnnl_Acb24a,
  1527. Acdb24a = dnnl_Acdb24a,
  1528. Acdeb24a = dnnl_Acdeb24a,
  1529. aBdc24b = dnnl_aBdc24b,
  1530. aBdec24b = dnnl_aBdec24b,
  1531. aBdefc24b = dnnl_aBdefc24b,
  1532. AcB24a2b = dnnl_AcB24a2b,
  1533. AcdB24a2b = dnnl_AcdB24a2b,
  1534. AcdeB24a2b = dnnl_AcdeB24a2b,
  1535. aBdC24b2c = dnnl_aBdC24b2c,
  1536. aBdeC24b2c = dnnl_aBdeC24b2c,
  1537. aBdefC24b2c = dnnl_aBdefC24b2c,
  1538. AcB24a4b = dnnl_AcB24a4b,
  1539. AcdB24a4b = dnnl_AcdB24a4b,
  1540. AcdeB24a4b = dnnl_AcdeB24a4b,
  1541. aBdC24b4c = dnnl_aBdC24b4c,
  1542. aBdeC24b4c = dnnl_aBdeC24b4c,
  1543. aBdefC24b4c = dnnl_aBdefC24b4c,
  1544. AB8b32a = dnnl_AB8b32a,
  1545. ABc8b32a = dnnl_ABc8b32a,
  1546. AcB8b32a = dnnl_AcB8b32a,
  1547. ABcd8b32a = dnnl_ABcd8b32a,
  1548. AcdB8b32a = dnnl_AcdB8b32a,
  1549. ABcde8b32a = dnnl_ABcde8b32a,
  1550. AcdeB8b32a = dnnl_AcdeB8b32a,
  1551. AB8b24a = dnnl_AB8b24a,
  1552. ABc8b24a = dnnl_ABc8b24a,
  1553. AcB8b24a = dnnl_AcB8b24a,
  1554. ABcd8b24a = dnnl_ABcd8b24a,
  1555. AcdB8b24a = dnnl_AcdB8b24a,
  1556. ABcde8b24a = dnnl_ABcde8b24a,
  1557. AcdeB8b24a = dnnl_AcdeB8b24a,
  1558. AB8b16a = dnnl_AB8b16a,
  1559. ABc8b16a = dnnl_ABc8b16a,
  1560. AcB8b16a = dnnl_AcB8b16a,
  1561. ABcd8b16a = dnnl_ABcd8b16a,
  1562. AcdB8b16a = dnnl_AcdB8b16a,
  1563. ABcde8b16a = dnnl_ABcde8b16a,
  1564. AcdeB8b16a = dnnl_AcdeB8b16a,
  1565. AB8b8a = dnnl_AB8b8a,
  1566. format_tag_last = dnnl_format_tag_last,
  1567. nCdhw16c = dnnl_nCdhw16c,
  1568. nCdhw4c = dnnl_nCdhw4c,
  1569. nCdhw8c = dnnl_nCdhw8c,
  1570. nChw16c = dnnl_nChw16c,
  1571. nChw4c = dnnl_nChw4c,
  1572. nChw8c = dnnl_nChw8c,
  1573. nCw16c = dnnl_nCw16c,
  1574. nCw4c = dnnl_nCw4c,
  1575. nCw8c = dnnl_nCw8c,
  1576. NCw16n16c = dnnl_NCw16n16c,
  1577. NChw16n16c = dnnl_NChw16n16c,
  1578. NCdhw16n16c = dnnl_NCdhw16n16c,
  1579. NCdhw32n32c = dnnl_NCdhw32n32c,
  1580. NChw32n32c = dnnl_NChw32n32c,
  1581. IOhw16i16o = dnnl_IOhw16i16o,
  1582. OI16i16o = dnnl_OI16i16o,
  1583. OI16i32o = dnnl_OI16i32o,
  1584. OI16i48o = dnnl_OI16i48o,
  1585. OI16i64o = dnnl_OI16i64o,
  1586. OI8i16o2i = dnnl_OI8i16o2i,
  1587. OI8i32o2i = dnnl_OI8i32o2i,
  1588. OI8i64o2i = dnnl_OI8i64o2i,
  1589. OI4i8o4i = dnnl_OI4i8o4i,
  1590. OI4i16o4i = dnnl_OI4i16o4i,
  1591. OI4i24o4i = dnnl_OI4i24o4i,
  1592. OI4i32o4i = dnnl_OI4i32o4i,
  1593. OI4i64o4i = dnnl_OI4i64o4i,
  1594. Ohwi32o = dnnl_Ohwi32o,
  1595. IOdhw16i16o = dnnl_IOdhw16i16o,
  1596. gIOhw16i16o = dnnl_gIOhw16i16o,
  1597. gOhwi32o = dnnl_gOhwi32o,
  1598. Goidhw16g = dnnl_Goidhw16g,
  1599. IOw8o8i = dnnl_IOw8o8i,
  1600. IOw16o16i = dnnl_IOw16o16i,
  1601. OIw16i16o = dnnl_OIw16i16o,
  1602. OwI16i16o = dnnl_OwI16i16o,
  1603. OIw16i32o = dnnl_OIw16i32o,
  1604. OwI16i32o = dnnl_OwI16i32o,
  1605. OIw16i48o = dnnl_OIw16i48o,
  1606. OwI16i48o = dnnl_OwI16i48o,
  1607. OIw16i64o = dnnl_OIw16i64o,
  1608. OwI16i64o = dnnl_OwI16i64o,
  1609. IOw16i16o = dnnl_IOw16i16o,
  1610. gIOw16i16o = dnnl_gIOw16i16o,
  1611. OIw16o16i = dnnl_OIw16o16i,
  1612. Oiw16o = dnnl_Oiw16o,
  1613. OIw4i8o4i = dnnl_OIw4i8o4i,
  1614. OwI4i8o4i = dnnl_OwI4i8o4i,
  1615. OIw4i16o4i = dnnl_OIw4i16o4i,
  1616. OwI4i16o4i = dnnl_OwI4i16o4i,
  1617. OIw4i24o4i = dnnl_OIw4i24o4i,
  1618. OwI4i24o4i = dnnl_OwI4i24o4i,
  1619. OIw4i32o4i = dnnl_OIw4i32o4i,
  1620. OwI4i32o4i = dnnl_OwI4i32o4i,
  1621. OIw4i64o4i = dnnl_OIw4i64o4i,
  1622. OwI4i64o4i = dnnl_OwI4i64o4i,
  1623. OIw2i8o4i = dnnl_OIw2i8o4i,
  1624. OIw4i4o = dnnl_OIw4i4o,
  1625. OIw4o4i = dnnl_OIw4o4i,
  1626. Oiw4o = dnnl_Oiw4o,
  1627. OIw8i16o2i = dnnl_OIw8i16o2i,
  1628. OwI8i16o2i = dnnl_OwI8i16o2i,
  1629. OIw8i32o2i = dnnl_OIw8i32o2i,
  1630. OwI8i32o2i = dnnl_OwI8i32o2i,
  1631. OIw8i64o2i = dnnl_OIw8i64o2i,
  1632. OwI8i64o2i = dnnl_OwI8i64o2i,
  1633. OIw8i8o = dnnl_OIw8i8o,
  1634. OwI8i8o = dnnl_OwI8i8o,
  1635. OIw8o16i2o = dnnl_OIw8o16i2o,
  1636. OIw8o8i = dnnl_OIw8o8i,
  1637. OIw8o4i = dnnl_OIw8o4i,
  1638. OIw16i16o4i = dnnl_OIw16i16o4i,
  1639. OIw16i32o4i = dnnl_OIw16i32o4i,
  1640. OIw16i48o4i = dnnl_OIw16i48o4i,
  1641. OIw16i64o4i = dnnl_OIw16i64o4i,
  1642. OIw16i16o2i = dnnl_OIw16i16o2i,
  1643. OIw16i32o2i = dnnl_OIw16i32o2i,
  1644. OIw16i48o2i = dnnl_OIw16i48o2i,
  1645. OIw16i64o2i = dnnl_OIw16i64o2i,
  1646. OIw16o16i2o = dnnl_OIw16o16i2o,
  1647. Owi16o = dnnl_Owi16o,
  1648. OwI16o2i = dnnl_OwI16o2i,
  1649. Iwo16i = dnnl_Iwo16i,
  1650. IwO16i2o = dnnl_IwO16i2o,
  1651. IwO16i4o = dnnl_IwO16i4o,
  1652. Owi4o = dnnl_Owi4o,
  1653. Owi8o = dnnl_Owi8o,
  1654. OwI8o2i = dnnl_OwI8o2i,
  1655. OwI8o4i = dnnl_OwI8o4i,
  1656. IOhw8o8i = dnnl_IOhw8o8i,
  1657. IOhw16o16i = dnnl_IOhw16o16i,
  1658. Ohwi16o = dnnl_Ohwi16o,
  1659. OhwI16o2i = dnnl_OhwI16o2i,
  1660. Ihwo16i = dnnl_Ihwo16i,
  1661. IhwO16i2o = dnnl_IhwO16i2o,
  1662. IhwO16i4o = dnnl_IhwO16i4o,
  1663. Ohwi4o = dnnl_Ohwi4o,
  1664. Ohwi8o = dnnl_Ohwi8o,
  1665. OhwI8o2i = dnnl_OhwI8o2i,
  1666. OhwI8o4i = dnnl_OhwI8o4i,
  1667. OIhw16i16o = dnnl_OIhw16i16o,
  1668. OhwI16i16o = dnnl_OhwI16i16o,
  1669. OIhw16i32o = dnnl_OIhw16i32o,
  1670. OhwI16i32o = dnnl_OhwI16i32o,
  1671. OIhw16i48o = dnnl_OIhw16i48o,
  1672. OhwI16i48o = dnnl_OhwI16i48o,
  1673. OIhw16i64o = dnnl_OIhw16i64o,
  1674. OhwI16i64o = dnnl_OhwI16i64o,
  1675. OIhw16o16i = dnnl_OIhw16o16i,
  1676. Oihw16o = dnnl_Oihw16o,
  1677. OIhw4i8o4i = dnnl_OIhw4i8o4i,
  1678. OhwI4i8o4i = dnnl_OhwI4i8o4i,
  1679. OIhw4i16o4i = dnnl_OIhw4i16o4i,
  1680. OhwI4i16o4i = dnnl_OhwI4i16o4i,
  1681. OIhw4i24o4i = dnnl_OIhw4i24o4i,
  1682. OhwI4i24o4i = dnnl_OhwI4i24o4i,
  1683. OIhw4i32o4i = dnnl_OIhw4i32o4i,
  1684. OhwI4i32o4i = dnnl_OhwI4i32o4i,
  1685. OIhw4i64o4i = dnnl_OIhw4i64o4i,
  1686. OhwI4i64o4i = dnnl_OhwI4i64o4i,
  1687. OIhw4i4o = dnnl_OIhw4i4o,
  1688. OIhw4o4i = dnnl_OIhw4o4i,
  1689. Oihw4o = dnnl_Oihw4o,
  1690. OIhw8i16o2i = dnnl_OIhw8i16o2i,
  1691. OhwI8i16o2i = dnnl_OhwI8i16o2i,
  1692. OIhw8i32o2i = dnnl_OIhw8i32o2i,
  1693. OhwI8i32o2i = dnnl_OhwI8i32o2i,
  1694. OIhw8i64o2i = dnnl_OIhw8i64o2i,
  1695. OhwI8i64o2i = dnnl_OhwI8i64o2i,
  1696. OIhw8i8o = dnnl_OIhw8i8o,
  1697. OhwI8i8o = dnnl_OhwI8i8o,
  1698. OIhw8o16i2o = dnnl_OIhw8o16i2o,
  1699. OIhw8o8i = dnnl_OIhw8o8i,
  1700. OIhw8o4i = dnnl_OIhw8o4i,
  1701. OIhw2i8o4i = dnnl_OIhw2i8o4i,
  1702. IOdhw8o8i = dnnl_IOdhw8o8i,
  1703. IOdhw16o16i = dnnl_IOdhw16o16i,
  1704. Odhwi16o = dnnl_Odhwi16o,
  1705. OdhwI16o2i = dnnl_OdhwI16o2i,
  1706. Idhwo16i = dnnl_Idhwo16i,
  1707. IdhwO16i2o = dnnl_IdhwO16i2o,
  1708. IdhwO16i4o = dnnl_IdhwO16i4o,
  1709. Odhwi4o = dnnl_Odhwi4o,
  1710. Odhwi8o = dnnl_Odhwi8o,
  1711. OdhwI8o2i = dnnl_OdhwI8o2i,
  1712. OdhwI8o4i = dnnl_OdhwI8o4i,
  1713. OIdhw16i16o = dnnl_OIdhw16i16o,
  1714. OdhwI16i16o = dnnl_OdhwI16i16o,
  1715. OIdhw16i32o = dnnl_OIdhw16i32o,
  1716. OdhwI16i32o = dnnl_OdhwI16i32o,
  1717. OIdhw16i48o = dnnl_OIdhw16i48o,
  1718. OdhwI16i48o = dnnl_OdhwI16i48o,
  1719. OIdhw16i64o = dnnl_OIdhw16i64o,
  1720. OdhwI16i64o = dnnl_OdhwI16i64o,
  1721. OIdhw16o16i = dnnl_OIdhw16o16i,
  1722. OIdhw16o16i2o = dnnl_OIdhw16o16i2o,
  1723. Oidhw16o = dnnl_Oidhw16o,
  1724. OIdhw4i4o = dnnl_OIdhw4i4o,
  1725. OIdhw4o4i = dnnl_OIdhw4o4i,
  1726. Oidhw4o = dnnl_Oidhw4o,
  1727. OIdhw8i16o2i = dnnl_OIdhw8i16o2i,
  1728. OdhwI8i16o2i = dnnl_OdhwI8i16o2i,
  1729. OIdhw8i32o2i = dnnl_OIdhw8i32o2i,
  1730. OdhwI8i32o2i = dnnl_OdhwI8i32o2i,
  1731. OIdhw8i64o2i = dnnl_OIdhw8i64o2i,
  1732. OdhwI8i64o2i = dnnl_OdhwI8i64o2i,
  1733. OIdhw4i8o4i = dnnl_OIdhw4i8o4i,
  1734. OdhwI4i8o4i = dnnl_OdhwI4i8o4i,
  1735. OIdhw4i16o4i = dnnl_OIdhw4i16o4i,
  1736. OdhwI4i16o4i = dnnl_OdhwI4i16o4i,
  1737. OIdhw16i16o4i = dnnl_OIdhw16i16o4i,
  1738. OIdhw16i32o4i = dnnl_OIdhw16i32o4i,
  1739. OIdhw16i48o4i = dnnl_OIdhw16i48o4i,
  1740. OIdhw16i64o4i = dnnl_OIdhw16i64o4i,
  1741. OIdhw16i16o2i = dnnl_OIdhw16i16o2i,
  1742. OIdhw16i32o2i = dnnl_OIdhw16i32o2i,
  1743. OIdhw16i48o2i = dnnl_OIdhw16i48o2i,
  1744. OIdhw16i64o2i = dnnl_OIdhw16i64o2i,
  1745. OIdhw4i24o4i = dnnl_OIdhw4i24o4i,
  1746. OdhwI4i24o4i = dnnl_OdhwI4i24o4i,
  1747. OIdhw4i32o4i = dnnl_OIdhw4i32o4i,
  1748. OdhwI4i32o4i = dnnl_OdhwI4i32o4i,
  1749. OIdhw4i64o4i = dnnl_OIdhw4i64o4i,
  1750. OdhwI4i64o4i = dnnl_OdhwI4i64o4i,
  1751. OIdhw2i8o4i = dnnl_OIdhw2i8o4i,
  1752. OIdhw8i8o = dnnl_OIdhw8i8o,
  1753. OdhwI8i8o = dnnl_OdhwI8i8o,
  1754. OIdhw8o8i = dnnl_OIdhw8o8i,
  1755. OIdhw8o4i = dnnl_OIdhw8o4i,
  1756. gIOw8o8i = dnnl_gIOw8o8i,
  1757. gIOw16o16i = dnnl_gIOw16o16i,
  1758. gOIw16i16o = dnnl_gOIw16i16o,
  1759. gOIw16o16i = dnnl_gOIw16o16i,
  1760. gOiw16o = dnnl_gOiw16o,
  1761. gOIw4i16o4i = dnnl_gOIw4i16o4i,
  1762. gOIw2i8o4i = dnnl_gOIw2i8o4i,
  1763. gOIw4i4o = dnnl_gOIw4i4o,
  1764. gOIw4o4i = dnnl_gOIw4o4i,
  1765. gOiw4o = dnnl_gOiw4o,
  1766. gOIw8i16o2i = dnnl_gOIw8i16o2i,
  1767. gOIw8i8o = dnnl_gOIw8i8o,
  1768. gOIw8o16i2o = dnnl_gOIw8o16i2o,
  1769. gOIw8o8i = dnnl_gOIw8o8i,
  1770. gOIw8o4i = dnnl_gOIw8o4i,
  1771. gOIw16i16o4i = dnnl_gOIw16i16o4i,
  1772. gOIw16i16o2i = dnnl_gOIw16i16o2i,
  1773. gOIw16o16i2o = dnnl_gOIw16o16i2o,
  1774. gOwi16o = dnnl_gOwi16o,
  1775. gOwI16o2i = dnnl_gOwI16o2i,
  1776. gIwo16i = dnnl_gIwo16i,
  1777. gIwO16i2o = dnnl_gIwO16i2o,
  1778. gIwO16i4o = dnnl_gIwO16i4o,
  1779. gOwi4o = dnnl_gOwi4o,
  1780. gOwi8o = dnnl_gOwi8o,
  1781. gOwI8o2i = dnnl_gOwI8o2i,
  1782. gOwI8o4i = dnnl_gOwI8o4i,
  1783. Goiw8g = dnnl_Goiw8g,
  1784. Goiw16g = dnnl_Goiw16g,
  1785. gIOhw8o8i = dnnl_gIOhw8o8i,
  1786. gIOhw16o16i = dnnl_gIOhw16o16i,
  1787. gOhwi16o = dnnl_gOhwi16o,
  1788. gOhwI16o2i = dnnl_gOhwI16o2i,
  1789. gIhwo16i = dnnl_gIhwo16i,
  1790. gIhwO16i2o = dnnl_gIhwO16i2o,
  1791. gIhwO16i4o = dnnl_gIhwO16i4o,
  1792. gOhwi4o = dnnl_gOhwi4o,
  1793. gOhwi8o = dnnl_gOhwi8o,
  1794. gOhwI8o2i = dnnl_gOhwI8o2i,
  1795. gOhwI8o4i = dnnl_gOhwI8o4i,
  1796. Goihw16g = dnnl_Goihw16g,
  1797. gOIhw16i16o = dnnl_gOIhw16i16o,
  1798. gOIhw16o16i = dnnl_gOIhw16o16i,
  1799. gOihw16o = dnnl_gOihw16o,
  1800. gOIhw4i16o4i = dnnl_gOIhw4i16o4i,
  1801. gOIhw2i8o4i = dnnl_gOIhw2i8o4i,
  1802. gOIhw4i4o = dnnl_gOIhw4i4o,
  1803. gOIhw4o4i = dnnl_gOIhw4o4i,
  1804. gOihw4o = dnnl_gOihw4o,
  1805. Goihw8g = dnnl_Goihw8g,
  1806. gOIhw8i16o2i = dnnl_gOIhw8i16o2i,
  1807. gOIhw8i8o = dnnl_gOIhw8i8o,
  1808. gOIhw8o16i2o = dnnl_gOIhw8o16i2o,
  1809. OIw4o8i8o4i = dnnl_OIw4o8i8o4i,
  1810. OIdhw4o8i8o4i = dnnl_OIdhw4o8i8o4i,
  1811. OIhw4o8i8o4i = dnnl_OIhw4o8i8o4i,
  1812. OIhw2o8i8o2i = dnnl_OIhw2o8i8o2i,
  1813. gOIw4o8i8o4i = dnnl_gOIw4o8i8o4i,
  1814. gOIdhw4o8i8o4i = dnnl_gOIdhw4o8i8o4i,
  1815. gOIhw4o8i8o4i = dnnl_gOIhw4o8i8o4i,
  1816. gOIhw2o8i8o2i = dnnl_gOIhw2o8i8o2i,
  1817. OIhw16i16o4i = dnnl_OIhw16i16o4i,
  1818. OIhw16i32o4i = dnnl_OIhw16i32o4i,
  1819. OIhw16i48o4i = dnnl_OIhw16i48o4i,
  1820. OIhw16i64o4i = dnnl_OIhw16i64o4i,
  1821. OIhw16i16o2i = dnnl_OIhw16i16o2i,
  1822. OIhw16i32o2i = dnnl_OIhw16i32o2i,
  1823. OIhw16i48o2i = dnnl_OIhw16i48o2i,
  1824. OIhw16i64o2i = dnnl_OIhw16i64o2i,
  1825. OIhw16o16i2o = dnnl_OIhw16o16i2o,
  1826. gOIhw16i16o4i = dnnl_gOIhw16i16o4i,
  1827. gOIhw16i16o2i = dnnl_gOIhw16i16o2i,
  1828. gOIhw16o16i2o = dnnl_gOIhw16o16i2o,
  1829. gOIhw8o8i = dnnl_gOIhw8o8i,
  1830. gOIhw8o4i = dnnl_gOIhw8o4i,
  1831. gIOdhw16i16o = dnnl_gIOdhw16i16o,
  1832. gIOdhw8o8i = dnnl_gIOdhw8o8i,
  1833. gIOdhw16o16i = dnnl_gIOdhw16o16i,
  1834. gOdhwi16o = dnnl_gOdhwi16o,
  1835. gOdhwI16o2i = dnnl_gOdhwI16o2i,
  1836. gIdhwo16i = dnnl_gIdhwo16i,
  1837. gIdhwO16i2o = dnnl_gIdhwO16i2o,
  1838. gIdhwO16i4o = dnnl_gIdhwO16i4o,
  1839. gOdhwi4o = dnnl_gOdhwi4o,
  1840. gOdhwi8o = dnnl_gOdhwi8o,
  1841. gOdhwI8o2i = dnnl_gOdhwI8o2i,
  1842. gOdhwI8o4i = dnnl_gOdhwI8o4i,
  1843. gOIdhw16i16o = dnnl_gOIdhw16i16o,
  1844. gOIdhw16o16i = dnnl_gOIdhw16o16i,
  1845. gOIdhw16o16i2o = dnnl_gOIdhw16o16i2o,
  1846. gOidhw16o = dnnl_gOidhw16o,
  1847. gOIdhw4i4o = dnnl_gOIdhw4i4o,
  1848. gOIdhw4o4i = dnnl_gOIdhw4o4i,
  1849. gOidhw4o = dnnl_gOidhw4o,
  1850. gOIdhw8i16o2i = dnnl_gOIdhw8i16o2i,
  1851. gOIdhw4i16o4i = dnnl_gOIdhw4i16o4i,
  1852. gOIdhw16i16o4i = dnnl_gOIdhw16i16o4i,
  1853. gOIdhw16i16o2i = dnnl_gOIdhw16i16o2i,
  1854. gOIdhw2i8o4i = dnnl_gOIdhw2i8o4i,
  1855. gOIdhw8i8o = dnnl_gOIdhw8i8o,
  1856. gOIdhw8o8i = dnnl_gOIdhw8o8i,
  1857. gOIdhw8o4i = dnnl_gOIdhw8o4i,
  1858. gOIw2i4o2i = dnnl_gOIw2i4o2i,
  1859. gOIhw2i4o2i = dnnl_gOIhw2i4o2i,
  1860. gOIdhw2i4o2i = dnnl_gOIdhw2i4o2i,
  1861. gOIw2o4i2o = dnnl_gOIw2o4i2o,
  1862. gOIhw2o4i2o = dnnl_gOIhw2o4i2o,
  1863. gOIdhw2o4i2o = dnnl_gOIdhw2o4i2o,
  1864. gOIw4i8o2i = dnnl_gOIw4i8o2i,
  1865. gOIhw4i8o2i = dnnl_gOIhw4i8o2i,
  1866. gOIdhw4i8o2i = dnnl_gOIdhw4i8o2i,
  1867. gOIw4o8i2o = dnnl_gOIw4o8i2o,
  1868. gOIhw4o8i2o = dnnl_gOIhw4o8i2o,
  1869. gOIdhw4o8i2o = dnnl_gOIdhw4o8i2o,
  1870. ldOi16o = abDc16d,
  1871. ldOi32o = abDc32d,
  1872. ldOI16o4i = abDC16d4c,
  1873. ldOI32o4i = abDC32d4c,
  1874. ldgOi16o = abdEc16e,
  1875. ldgOI16o4i = abdEC16e4c,
  1876. ldgOi32o = abdEc32e,
  1877. ldgOI32o2i = abdEC32e2c,
  1878. ldgOI32o4i = abdEC32e4c,
  1879. OwI16o4i = dnnl_OwI16o4i,
  1880. OhwI16o4i = dnnl_OhwI16o4i,
  1881. gOwI16o4i = dnnl_gOwI16o4i,
  1882. gOhwI16o4i = dnnl_gOhwI16o4i,
  1883. OdhwI16o4i = dnnl_OdhwI16o4i,
  1884. gOdhwI16o4i = dnnl_gOdhwI16o4i,
  1885. Owi32o = dnnl_Owi32o,
  1886. OwI32o2i = dnnl_OwI32o2i,
  1887. OwI32o4i = dnnl_OwI32o4i,
  1888. Owi48o = dnnl_Owi48o,
  1889. OwI48o2i = dnnl_OwI48o2i,
  1890. OwI48o4i = dnnl_OwI48o4i,
  1891. Owi64o = dnnl_Owi64o,
  1892. OwI64o2i = dnnl_OwI64o2i,
  1893. OwI64o4i = dnnl_OwI64o4i,
  1894. Iwo32i = dnnl_Iwo32i,
  1895. IwO32i2o = dnnl_IwO32i2o,
  1896. IwO32i4o = dnnl_IwO32i4o,
  1897. Iwo48i = dnnl_Iwo48i,
  1898. IwO48i2o = dnnl_IwO48i2o,
  1899. IwO48i4o = dnnl_IwO48i4o,
  1900. Iwo64i = dnnl_Iwo64i,
  1901. IwO64i2o = dnnl_IwO64i2o,
  1902. IwO64i4o = dnnl_IwO64i4o,
  1903. wIo2i = dnnl_wIo2i,
  1904. wIo4i = dnnl_wIo4i,
  1905. gOwi32o = dnnl_gOwi32o,
  1906. gOwI32o2i = dnnl_gOwI32o2i,
  1907. gOwI32o4i = dnnl_gOwI32o4i,
  1908. gOwi48o = dnnl_gOwi48o,
  1909. gOwI48o2i = dnnl_gOwI48o2i,
  1910. gOwI48o4i = dnnl_gOwI48o4i,
  1911. gOwi64o = dnnl_gOwi64o,
  1912. gOwI64o2i = dnnl_gOwI64o2i,
  1913. gOwI64o4i = dnnl_gOwI64o4i,
  1914. gIwo32i = dnnl_gIwo32i,
  1915. gIwO32i2o = dnnl_gIwO32i2o,
  1916. gIwO32i4o = dnnl_gIwO32i4o,
  1917. gIwo48i = dnnl_gIwo48i,
  1918. gIwO48i2o = dnnl_gIwO48i2o,
  1919. gIwO48i4o = dnnl_gIwO48i4o,
  1920. gIwo64i = dnnl_gIwo64i,
  1921. gIwO64i2o = dnnl_gIwO64i2o,
  1922. gIwO64i4o = dnnl_gIwO64i4o,
  1923. gwio = dnnl_gwio,
  1924. gwIo2i = dnnl_gwIo2i,
  1925. gwIo4i = dnnl_gwIo4i,
  1926. OhwI32o = dnnl_OhwI32o,
  1927. OhwI32o2i = dnnl_OhwI32o2i,
  1928. OhwI32o4i = dnnl_OhwI32o4i,
  1929. Ohwi48o = dnnl_Ohwi48o,
  1930. OhwI48o2i = dnnl_OhwI48o2i,
  1931. OhwI48o4i = dnnl_OhwI48o4i,
  1932. Ohwi64o = dnnl_Ohwi64o,
  1933. OhwI64o2i = dnnl_OhwI64o2i,
  1934. OhwI64o4i = dnnl_OhwI64o4i,
  1935. Ihwo32i = dnnl_Ihwo32i,
  1936. IhwO32i2o = dnnl_IhwO32i2o,
  1937. IhwO32i4o = dnnl_IhwO32i4o,
  1938. Ihwo48i = dnnl_Ihwo48i,
  1939. IhwO48i2o = dnnl_IhwO48i2o,
  1940. IhwO48i4o = dnnl_IhwO48i4o,
  1941. Ihwo64i = dnnl_Ihwo64i,
  1942. IhwO64i2o = dnnl_IhwO64i2o,
  1943. IhwO64i4o = dnnl_IhwO64i4o,
  1944. hwIo2i = dnnl_hwIo2i,
  1945. hwIo4i = dnnl_hwIo4i,
  1946. gOhwI32o = dnnl_gOhwI32o,
  1947. gOhwI32o2i = dnnl_gOhwI32o2i,
  1948. gOhwI32o4i = dnnl_gOhwI32o4i,
  1949. gOhwi48o = dnnl_gOhwi48o,
  1950. gOhwI48o2i = dnnl_gOhwI48o2i,
  1951. gOhwI48o4i = dnnl_gOhwI48o4i,
  1952. gOhwi64o = dnnl_gOhwi64o,
  1953. gOhwI64o2i = dnnl_gOhwI64o2i,
  1954. gOhwI64o4i = dnnl_gOhwI64o4i,
  1955. gIhwo32i = dnnl_gIhwo32i,
  1956. gIhwO32i2o = dnnl_gIhwO32i2o,
  1957. gIhwO32i4o = dnnl_gIhwO32i4o,
  1958. gIhwo48i = dnnl_gIhwo48i,
  1959. gIhwO48i2o = dnnl_gIhwO48i2o,
  1960. gIhwO48i4o = dnnl_gIhwO48i4o,
  1961. gIhwo64i = dnnl_gIhwo64i,
  1962. gIhwO64i2o = dnnl_gIhwO64i2o,
  1963. gIhwO64i4o = dnnl_gIhwO64i4o,
  1964. ghwio = dnnl_ghwio,
  1965. ghwIo2i = dnnl_ghwIo2i,
  1966. ghwIo4i = dnnl_ghwIo4i,
  1967. Odhwi32o = dnnl_Odhwi32o,
  1968. OdhwI32o2i = dnnl_OdhwI32o2i,
  1969. OdhwI32o4i = dnnl_OdhwI32o4i,
  1970. Odhwi48o = dnnl_Odhwi48o,
  1971. OdhwI48o2i = dnnl_OdhwI48o2i,
  1972. OdhwI48o4i = dnnl_OdhwI48o4i,
  1973. Odhwi64o = dnnl_Odhwi64o,
  1974. OdhwI64o2i = dnnl_OdhwI64o2i,
  1975. OdhwI64o4i = dnnl_OdhwI64o4i,
  1976. Idhwo32i = dnnl_Idhwo32i,
  1977. IdhwO32i2o = dnnl_IdhwO32i2o,
  1978. IdhwO32i4o = dnnl_IdhwO32i4o,
  1979. Idhwo48i = dnnl_Idhwo48i,
  1980. IdhwO48i2o = dnnl_IdhwO48i2o,
  1981. IdhwO48i4o = dnnl_IdhwO48i4o,
  1982. Idhwo64i = dnnl_Idhwo64i,
  1983. IdhwO64i2o = dnnl_IdhwO64i2o,
  1984. IdhwO64i4o = dnnl_IdhwO64i4o,
  1985. dhwIo2i = dnnl_dhwIo2i,
  1986. dhwIo4i = dnnl_dhwIo4i,
  1987. gOdhwi32o = dnnl_gOdhwi32o,
  1988. gOdhwI32o2i = dnnl_gOdhwI32o2i,
  1989. gOdhwI32o4i = dnnl_gOdhwI32o4i,
  1990. gOdhwi48o = dnnl_gOdhwi48o,
  1991. gOdhwI48o2i = dnnl_gOdhwI48o2i,
  1992. gOdhwI48o4i = dnnl_gOdhwI48o4i,
  1993. gOdhwi64o = dnnl_gOdhwi64o,
  1994. gOdhwI64o2i = dnnl_gOdhwI64o2i,
  1995. gOdhwI64o4i = dnnl_gOdhwI64o4i,
  1996. gIdhwo32i = dnnl_gIdhwo32i,
  1997. gIdhwO32i2o = dnnl_gIdhwO32i2o,
  1998. gIdhwO32i4o = dnnl_gIdhwO32i4o,
  1999. gIdhwo48i = dnnl_gIdhwo48i,
  2000. gIdhwO48i2o = dnnl_gIdhwO48i2o,
  2001. gIdhwO48i4o = dnnl_gIdhwO48i4o,
  2002. gIdhwo64i = dnnl_gIdhwo64i,
  2003. gIdhwO64i2o = dnnl_gIdhwO64i2o,
  2004. gIdhwO64i4o = dnnl_gIdhwO64i4o,
  2005. gdhwio = dnnl_gdhwio,
  2006. gdhwIo2i = dnnl_gdhwIo2i,
  2007. gdhwIo4i = dnnl_gdhwIo4i,
  2008. ldIo32i = dnnl_ldIo32i,
  2009. ldgIo16i = dnnl_ldgIo16i,
  2010. ldgIo32i = dnnl_ldgIo32i,
  2011. ldgIO32i2o = dnnl_ldgIO32i2o,
  2012. nCdhw32c = dnnl_nCdhw32c,
  2013. nChw32c = dnnl_nChw32c,
  2014. nCw32c = dnnl_nCw32c,
  2015. NCw32n16c = dnnl_NCw32n16c,
  2016. NChw32n16c = dnnl_NChw32n16c,
  2017. NCdhw32n16c = dnnl_NCdhw32n16c,
  2018. NCw32n32c = dnnl_NCw32n32c,
  2019. OI16i16o4i = dnnl_OI16i16o4i,
  2020. IOw8o16i2o = dnnl_IOw8o16i2o,
  2021. IOhw8o16i2o = dnnl_IOhw8o16i2o,
  2022. Owhi16o = dnnl_Owhi16o,
  2023. OIdhw8o16i2o = dnnl_OIdhw8o16i2o,
  2024. IOdhw8o16i2o = dnnl_IOdhw8o16i2o,
  2025. Goiw4g = dnnl_Goiw4g,
  2026. gIOw8o16i2o = dnnl_gIOw8o16i2o,
  2027. Goiw32g = dnnl_Goiw32g,
  2028. Goihw4g = dnnl_Goihw4g,
  2029. gIOhw8o16i2o = dnnl_gIOhw8o16i2o,
  2030. Goihw32g = dnnl_Goihw32g,
  2031. gOwhi16o = dnnl_gOwhi16o,
  2032. IOw4i8o8i4o = dnnl_IOw4i8o8i4o,
  2033. IOhw4i8o8i4o = dnnl_IOhw4i8o8i4o,
  2034. IOdhw4i8o8i4o = dnnl_IOdhw4i8o8i4o,
  2035. gIOw4i8o8i4o = dnnl_gIOw4i8o8i4o,
  2036. gIOhw4i8o8i4o = dnnl_gIOhw4i8o8i4o,
  2037. gIOdhw4i8o8i4o = dnnl_gIOdhw4i8o8i4o,
  2038. gOIdhw8o16i2o = dnnl_gOIdhw8o16i2o,
  2039. gIOdhw8o16i2o = dnnl_gIOdhw8o16i2o,
  2040. Goidhw32g = dnnl_Goidhw32g,
  2041. OI16i32o4i = dnnl_OI16i32o4i,
  2042. OI16i48o4i = dnnl_OI16i48o4i,
  2043. OI16i64o4i = dnnl_OI16i64o4i,
  2044. OI16i16o2i = dnnl_OI16i16o2i,
  2045. OI16i32o2i = dnnl_OI16i32o2i,
  2046. OI16i48o2i = dnnl_OI16i48o2i,
  2047. OI16i64o2i = dnnl_OI16i64o2i,
  2048. aBdeC16c16b4c = dnnl_aBdeC16c16b4c,
  2049. AcB16b16a2b = dnnl_AcB16b16a2b,
  2050. aBdC16c16b2c = dnnl_aBdC16c16b2c,
  2051. AcB16b16a4b = dnnl_AcB16b16a4b,
  2052. aBdC16c16b4c = dnnl_aBdC16c16b4c,
  2053. AcdB16b16a2b = dnnl_AcdB16b16a2b,
  2054. aBdefC16c16b4c = dnnl_aBdefC16c16b4c,
  2055. AcdeB16b16a4b = dnnl_AcdeB16b16a4b,
  2056. AcB16b32a2b = dnnl_AcB16b32a2b,
  2057. AcB16b32a4b = dnnl_AcB16b32a4b,
  2058. AcB16b48a2b = dnnl_AcB16b48a2b,
  2059. AcB16b48a4b = dnnl_AcB16b48a4b,
  2060. AcB16b64a2b = dnnl_AcB16b64a2b,
  2061. AcB16b64a4b = dnnl_AcB16b64a4b,
  2062. aBdC16c32b2c = dnnl_aBdC16c32b2c,
  2063. aBdC16c32b4c = dnnl_aBdC16c32b4c,
  2064. aBdC16c48b2c = dnnl_aBdC16c48b2c,
  2065. aBdC16c48b4c = dnnl_aBdC16c48b4c,
  2066. aBdC16c64b2c = dnnl_aBdC16c64b2c,
  2067. aBdC16c64b4c = dnnl_aBdC16c64b4c,
  2068. AcdB16b32a2b = dnnl_AcdB16b32a2b,
  2069. AcdB16b32a4b = dnnl_AcdB16b32a4b,
  2070. AcdB16b48a2b = dnnl_AcdB16b48a2b,
  2071. AcdB16b48a4b = dnnl_AcdB16b48a4b,
  2072. AcdB16b64a2b = dnnl_AcdB16b64a2b,
  2073. AcdB16b64a4b = dnnl_AcdB16b64a4b,
  2074. aBdeC16c32b2c = dnnl_aBdeC16c32b2c,
  2075. aBdeC16c32b4c = dnnl_aBdeC16c32b4c,
  2076. aBdeC16c48b2c = dnnl_aBdeC16c48b2c,
  2077. aBdeC16c48b4c = dnnl_aBdeC16c48b4c,
  2078. aBdeC16c64b2c = dnnl_aBdeC16c64b2c,
  2079. aBdeC16c64b4c = dnnl_aBdeC16c64b4c,
  2080. AcdeB16b32a2b = dnnl_AcdeB16b32a2b,
  2081. AcdeB16b32a4b = dnnl_AcdeB16b32a4b,
  2082. AcdeB16b48a2b = dnnl_AcdeB16b48a2b,
  2083. AcdeB16b48a4b = dnnl_AcdeB16b48a4b,
  2084. AcdeB16b64a2b = dnnl_AcdeB16b64a2b,
  2085. AcdeB16b64a4b = dnnl_AcdeB16b64a4b,
  2086. aBdefC16c32b2c = dnnl_aBdefC16c32b2c,
  2087. aBdefC16c32b4c = dnnl_aBdefC16c32b4c,
  2088. aBdefC16c48b2c = dnnl_aBdefC16c48b2c,
  2089. aBdefC16c48b4c = dnnl_aBdefC16c48b4c,
  2090. aBdefC16c64b2c = dnnl_aBdefC16c64b2c,
  2091. aBdefC16c64b4c = dnnl_aBdefC16c64b4c,
  2092. OwI16i16o2i = dnnl_OwI16i16o2i,
  2093. gOwI16i16o2i = dnnl_gOwI16i16o2i,
  2094. OhwI16i16o2i = dnnl_OhwI16i16o2i,
  2095. gOhwI16i16o2i = dnnl_gOhwI16i16o2i,
  2096. OdhwI16i16o2i = dnnl_OdhwI16i16o2i,
  2097. gOdhwI16i16o2i = dnnl_gOdhwI16i16o2i,
  2098. OwI16i16o4i = dnnl_OwI16i16o4i,
  2099. gOwI16i16o4i = dnnl_gOwI16i16o4i,
  2100. OhwI16i16o4i = dnnl_OhwI16i16o4i,
  2101. gOhwI16i16o4i = dnnl_gOhwI16i16o4i,
  2102. OdhwI16i16o4i = dnnl_OdhwI16i16o4i,
  2103. gOdhwI16i16o4i = dnnl_gOdhwI16i16o4i,
  2104. OwI16i32o2i = dnnl_OwI16i32o2i,
  2105. OwI16i32o4i = dnnl_OwI16i32o4i,
  2106. OwI16i48o2i = dnnl_OwI16i48o2i,
  2107. OwI16i48o4i = dnnl_OwI16i48o4i,
  2108. OwI16i64o2i = dnnl_OwI16i64o2i,
  2109. OwI16i64o4i = dnnl_OwI16i64o4i,
  2110. gOwI16i32o2i = dnnl_gOwI16i32o2i,
  2111. gOwI16i32o4i = dnnl_gOwI16i32o4i,
  2112. gOwI16i48o2i = dnnl_gOwI16i48o2i,
  2113. gOwI16i48o4i = dnnl_gOwI16i48o4i,
  2114. gOwI16i64o2i = dnnl_gOwI16i64o2i,
  2115. gOwI16i64o4i = dnnl_gOwI16i64o4i,
  2116. OhwI16i32o2i = dnnl_OhwI16i32o2i,
  2117. OhwI16i32o4i = dnnl_OhwI16i32o4i,
  2118. OhwI16i48o2i = dnnl_OhwI16i48o2i,
  2119. OhwI16i48o4i = dnnl_OhwI16i48o4i,
  2120. OhwI16i64o2i = dnnl_OhwI16i64o2i,
  2121. OhwI16i64o4i = dnnl_OhwI16i64o4i,
  2122. gOhwI16i32o2i = dnnl_gOhwI16i32o2i,
  2123. gOhwI16i32o4i = dnnl_gOhwI16i32o4i,
  2124. gOhwI16i48o2i = dnnl_gOhwI16i48o2i,
  2125. gOhwI16i48o4i = dnnl_gOhwI16i48o4i,
  2126. gOhwI16i64o2i = dnnl_gOhwI16i64o2i,
  2127. gOhwI16i64o4i = dnnl_gOhwI16i64o4i,
  2128. OdhwI16i32o2i = dnnl_OdhwI16i32o2i,
  2129. OdhwI16i32o4i = dnnl_OdhwI16i32o4i,
  2130. OdhwI16i48o2i = dnnl_OdhwI16i48o2i,
  2131. OdhwI16i48o4i = dnnl_OdhwI16i48o4i,
  2132. OdhwI16i64o2i = dnnl_OdhwI16i64o2i,
  2133. OdhwI16i64o4i = dnnl_OdhwI16i64o4i,
  2134. IdhwO16o32i2o = dnnl_IdhwO16o32i2o,
  2135. IdhwO16o32i4o = dnnl_IdhwO16o32i4o,
  2136. IdhwO16o48i2o = dnnl_IdhwO16o48i2o,
  2137. IdhwO16o48i4o = dnnl_IdhwO16o48i4o,
  2138. IdhwO16o64i2o = dnnl_IdhwO16o64i2o,
  2139. IdhwO16o64i4o = dnnl_IdhwO16o64i4o,
  2140. gOdhwI16i32o2i = dnnl_gOdhwI16i32o2i,
  2141. gOdhwI16i32o4i = dnnl_gOdhwI16i32o4i,
  2142. gOdhwI16i48o2i = dnnl_gOdhwI16i48o2i,
  2143. gOdhwI16i48o4i = dnnl_gOdhwI16i48o4i,
  2144. gOdhwI16i64o2i = dnnl_gOdhwI16i64o2i,
  2145. gOdhwI16i64o4i = dnnl_gOdhwI16i64o4i,
  2146. gIdhwO16o32i2o = dnnl_gIdhwO16o32i2o,
  2147. gIdhwO16o32i4o = dnnl_gIdhwO16o32i4o,
  2148. gIdhwO16o48i2o = dnnl_gIdhwO16o48i2o,
  2149. gIdhwO16o48i4o = dnnl_gIdhwO16o48i4o,
  2150. gIdhwO16o64i2o = dnnl_gIdhwO16o64i2o,
  2151. gIdhwO16o64i4o = dnnl_gIdhwO16o64i4o,
  2152. IwO16o16i2o = dnnl_IwO16o16i2o,
  2153. IwO16o16i4o = dnnl_IwO16o16i4o,
  2154. IhwO16o16i2o = dnnl_IhwO16o16i2o,
  2155. IhwO16o16i4o = dnnl_IhwO16o16i4o,
  2156. IdhwO16o16i2o = dnnl_IdhwO16o16i2o,
  2157. IdhwO16o16i4o = dnnl_IdhwO16o16i4o,
  2158. gIwO16o16i2o = dnnl_gIwO16o16i2o,
  2159. gIwO16o16i4o = dnnl_gIwO16o16i4o,
  2160. gIhwO16o16i2o = dnnl_gIhwO16o16i2o,
  2161. gIhwO16o16i4o = dnnl_gIhwO16o16i4o,
  2162. gIdhwO16o16i2o = dnnl_gIdhwO16o16i2o,
  2163. gIdhwO16o16i4o = dnnl_gIdhwO16o16i4o,
  2164. IwO16o32i2o = dnnl_IwO16o32i2o,
  2165. IwO16o32i4o = dnnl_IwO16o32i4o,
  2166. IwO16o48i2o = dnnl_IwO16o48i2o,
  2167. IwO16o48i4o = dnnl_IwO16o48i4o,
  2168. IwO16o64i2o = dnnl_IwO16o64i2o,
  2169. IwO16o64i4o = dnnl_IwO16o64i4o,
  2170. gIwO16o32i2o = dnnl_gIwO16o32i2o,
  2171. gIwO16o32i4o = dnnl_gIwO16o32i4o,
  2172. gIwO16o48i2o = dnnl_gIwO16o48i2o,
  2173. gIwO16o48i4o = dnnl_gIwO16o48i4o,
  2174. gIwO16o64i2o = dnnl_gIwO16o64i2o,
  2175. gIwO16o64i4o = dnnl_gIwO16o64i4o,
  2176. IhwO16o32i2o = dnnl_IhwO16o32i2o,
  2177. IhwO16o32i4o = dnnl_IhwO16o32i4o,
  2178. IhwO16o48i2o = dnnl_IhwO16o48i2o,
  2179. IhwO16o48i4o = dnnl_IhwO16o48i4o,
  2180. IhwO16o64i2o = dnnl_IhwO16o64i2o,
  2181. IhwO16o64i4o = dnnl_IhwO16o64i4o,
  2182. gIhwO16o32i2o = dnnl_gIhwO16o32i2o,
  2183. gIhwO16o32i4o = dnnl_gIhwO16o32i4o,
  2184. gIhwO16o48i2o = dnnl_gIhwO16o48i2o,
  2185. gIhwO16o48i4o = dnnl_gIhwO16o48i4o,
  2186. gIhwO16o64i2o = dnnl_gIhwO16o64i2o,
  2187. gIhwO16o64i4o = dnnl_gIhwO16o64i4o,
  2188. aBdeC16c16b2c = dnnl_aBdeC16c16b2c,
  2189. aBdefC16c16b2c = dnnl_aBdefC16c16b2c,
  2190. AcdB16b16a4b = dnnl_AcdB16b16a4b,
  2191. AcdeB16b16a2b = dnnl_AcdeB16b16a2b,
  2192. hwioG16g = dnnl_hwioG16g,
  2193. hwioG8g = dnnl_hwioG8g,
  2194. dhwioG16g = dnnl_dhwioG16g,
  2195. dhwioG8g = dnnl_dhwioG8g,
  2196. ABc4a2b = dnnl_ABc4a2b,
  2197. ABc8a2b = dnnl_ABc8a2b,
  2198. ABcd4a2b = dnnl_ABcd4a2b,
  2199. ABcde4a2b = dnnl_ABcde4a2b,
  2200. ABcde8a2b = dnnl_ABcde8a2b,
  2201. ABcd4a8b8a2b = dnnl_ABcd4a8b8a2b,
  2202. NCdhw40n32c = dnnl_NCdhw40n32c,
  2203. NChw40n32c = dnnl_NChw40n32c,
  2204. NCw40n32c = dnnl_NCw40n32c,
  2205. OIdhw4o8i8o2i = dnnl_OIdhw4o8i8o2i,
  2206. OIhw4o8i8o2i = dnnl_OIhw4o8i8o2i,
  2207. OIw4o8i8o2i = dnnl_OIw4o8i8o2i,
  2208. gOIdhw4o8i8o2i = dnnl_gOIdhw4o8i8o2i,
  2209. gOIhw4o8i8o2i = dnnl_gOIhw4o8i8o2i,
  2210. gOIw4o8i8o2i = dnnl_gOIw4o8i8o2i,
  2211. IOdhw4i8o8i2o = dnnl_IOdhw4i8o8i2o,
  2212. IOhw4i8o8i2o = dnnl_IOhw4i8o8i2o,
  2213. IOw4i8o8i2o = dnnl_IOw4i8o8i2o,
  2214. gIOdhw4i8o8i2o = dnnl_gIOdhw4i8o8i2o,
  2215. gIOhw4i8o8i2o = dnnl_gIOhw4i8o8i2o,
  2216. gIOw4i8o8i2o = dnnl_gIOw4i8o8i2o,
  2217. aBCd8b2c = dnnl_aBCd8b2c,
  2218. ABcde40a16b = dnnl_ABcde40a16b,
  2219. ABcde40a32b = dnnl_ABcde40a32b,
  2220. aBCde8b2c = dnnl_aBCde8b2c,
  2221. ABcde4a8b8a2b = dnnl_ABcde4a8b8a2b,
  2222. ABc4a8b8a2b = dnnl_ABc4a8b8a2b,
  2223. aBCdef4b8c8b2c = dnnl_aBCdef4b8c8b2c,
  2224. aBCde4b8c8b2c = dnnl_aBCde4b8c8b2c,
  2225. aBCd4b8c8b2c = dnnl_aBCd4b8c8b2c,
  2226. BAcde4b8a8b2a = dnnl_BAcde4b8a8b2a,
  2227. BAcd4b8a8b2a = dnnl_BAcd4b8a8b2a,
  2228. BAc4b8a8b2a = dnnl_BAc4b8a8b2a,
  2229. aCBdef4c8b8c2b = dnnl_aCBdef4c8b8c2b,
  2230. aCBde4c8b8c2b = dnnl_aCBde4c8b8c2b,
  2231. aCBd4c8b8c2b = dnnl_aCBd4c8b8c2b,
  2232. aBCdef8b2c = dnnl_aBCdef8b2c,
  2233. AB32a16b = dnnl_AB32a16b,
  2234. AB32a32b = dnnl_AB32a32b,
  2235. BA4b8a8b2a = dnnl_BA4b8a8b2a,
  2236. BA4b8a8b4a = dnnl_BA4b8a8b4a,
  2237. aBC32b16c = dnnl_aBC32b16c,
  2238. aBC32b32c = dnnl_aBC32b32c,
  2239. aCB4c8b8c2b = dnnl_aCB4c8b8c2b,
  2240. aCB4c8b8c4b = dnnl_aCB4c8b8c4b,
  2241. ABc2b8a16b4a = dnnl_ABc2b8a16b4a,
  2242. ABcd2b8a16b4a = dnnl_ABcd2b8a16b4a,
  2243. ABcde2b8a16b4a = dnnl_ABcde2b8a16b4a,
  2244. ABc2a8b16a4b = dnnl_ABc2a8b16a4b,
  2245. ABc2a8b16a2b = dnnl_ABc2a8b16a2b,
  2246. ABc2b32a8b = dnnl_ABc2b32a8b,
  2247. ABcd2a8b16a4b = dnnl_ABcd2a8b16a4b,
  2248. ABcd2a8b16a2b = dnnl_ABcd2a8b16a2b,
  2249. aCBd2c8b16c2b = dnnl_aCBd2c8b16c2b,
  2250. ABcd2b32a8b = dnnl_ABcd2b32a8b,
  2251. aBCd2c8b16c2b = dnnl_aBCd2c8b16c2b,
  2252. ABcde2a8b16a4b = dnnl_ABcde2a8b16a4b,
  2253. ABcde2a8b16a2b = dnnl_ABcde2a8b16a2b,
  2254. aCBde2c8b16c2b = dnnl_aCBde2c8b16c2b,
  2255. ABcde2b32a8b = dnnl_ABcde2b32a8b,
  2256. aBC2b8c16b2c = dnnl_aBC2b8c16b2c,
  2257. aBCd2b8c16b2c = dnnl_aBCd2b8c16b2c,
  2258. aBCde2b8c16b2c = dnnl_aBCde2b8c16b2c,
  2259. aBCdef2b8c16b2c = dnnl_aBCdef2b8c16b2c,
  2260. BAcde2b8a16b4a = dnnl_BAcde2b8a16b4a,
  2261. BAcd2b8a16b4a = dnnl_BAcd2b8a16b4a,
  2262. BAc2b8a16b4a = dnnl_BAc2b8a16b4a,
  2263. BAcde2b8a16b2a = dnnl_BAcde2b8a16b2a,
  2264. BAcd2b8a16b2a = dnnl_BAcd2b8a16b2a,
  2265. BAc2b8a16b2a = dnnl_BAc2b8a16b2a,
  2266. aBCde2c8b16c2b = dnnl_aBCde2c8b16c2b,
  2267. aBCdef2c8b16c2b = dnnl_aBCdef2c8b16c2b,
  2268. aCBdef2c8b16c2b = dnnl_aCBdef2c8b16c2b,
  2269. aBCd2b8c16b4c = dnnl_aBCd2b8c16b4c,
  2270. aBCde2b8c16b4c = dnnl_aBCde2b8c16b4c,
  2271. NCdhw40n16c = dnnl_NCdhw40n16c,
  2272. NCw40n16c = dnnl_NCw40n16c,
  2273. NChw40n16c = dnnl_NChw40n16c,
  2274. NCw2c32n8c = dnnl_NCw2c32n8c,
  2275. NChw2c32n8c = dnnl_NChw2c32n8c,
  2276. NCdhw2c32n8c = dnnl_NCdhw2c32n8c,
  2277. OIw2i8o16i4o = dnnl_OIw2i8o16i4o,
  2278. OIhw2i8o16i4o = dnnl_OIhw2i8o16i4o,
  2279. OIdhw2i8o16i4o = dnnl_OIdhw2i8o16i4o,
  2280. OIw2o8i16o4i = dnnl_OIw2o8i16o4i,
  2281. OIw2o8i16o2i = dnnl_OIw2o8i16o2i,
  2282. IOw2i8o16i4o = dnnl_IOw2i8o16i4o,
  2283. IOw2i8o16i2o = dnnl_IOw2i8o16i2o,
  2284. OIhw2o8i16o4i = dnnl_OIhw2o8i16o4i,
  2285. OIhw2o8i16o2i = dnnl_OIhw2o8i16o2i,
  2286. IOhw2i8o16i4o = dnnl_IOhw2i8o16i4o,
  2287. IOhw2i8o16i2o = dnnl_IOhw2i8o16i2o,
  2288. OIdhw2o8i16o4i = dnnl_OIdhw2o8i16o4i,
  2289. OIdhw2o8i16o2i = dnnl_OIdhw2o8i16o2i,
  2290. IOdhw2i8o16i4o = dnnl_IOdhw2i8o16i4o,
  2291. IOdhw2i8o16i2o = dnnl_IOdhw2i8o16i2o,
  2292. gOIw2o8i16o2i = dnnl_gOIw2o8i16o2i,
  2293. gIOw2i8o16i2o = dnnl_gIOw2i8o16i2o,
  2294. gIOhw2i8o16i2o = dnnl_gIOhw2i8o16i2o,
  2295. gIOdhw2i8o16i2o = dnnl_gIOdhw2i8o16i2o,
  2296. gOIhw2o8i16o2i = dnnl_gOIhw2o8i16o2i,
  2297. gOIdhw2o8i16o2i = dnnl_gOIdhw2o8i16o2i,
  2298. gOIw2o8i16o4i = dnnl_gOIw2o8i16o4i,
  2299. gOIhw2o8i16o4i = dnnl_gOIhw2o8i16o4i,
  2300. BA4b8a16b2a = dnnl_BA4b8a16b2a,
  2301. BA4b8a16b4a = dnnl_BA4b8a16b4a,
  2302. aCB4c8b16c2b = dnnl_aCB4c8b16c2b,
  2303. aCB4c8b16c4b = dnnl_aCB4c8b16c4b,
  2304. aCB16c2b = dnnl_aCB16c2b,
  2305. aCB16c4b = dnnl_aCB16c4b,
  2306. BA16b2a = dnnl_BA16b2a,
  2307. BA16b4a = dnnl_BA16b4a,
  2308. BA4b4a = dnnl_BA4b4a,
  2309. BA8b4a = dnnl_BA8b4a,
  2310. aBC16b16c = dnnl_aBC16b16c,
  2311. aBC16b32c = dnnl_aBC16b32c,
  2312. AB16a16b = dnnl_AB16a16b,
  2313. AB16a32b = dnnl_AB16a32b,
  2314. ABcde16a16b2a = dnnl_ABcde16a16b2a,
  2315. aBCdef16b16c2b = dnnl_aBCdef16b16c2b,
  2316. Acedb16a = dnnl_Acedb16a,
  2317. aBdfec16b = dnnl_aBdfec16b,
  2318. Odwhi16o = dnnl_Odwhi16o,
  2319. gOdwhi16o = dnnl_gOdwhi16o,
  2320. abdEC64e2c = dnnl_abdEC64e2c,
  2321. abdEC64e4c = dnnl_abdEC64e4c,
  2322. ldgOI64o2i = abdEC64e2c,
  2323. ldgOI64o4i = abdEC64e4c,
  2324. abCd4c = dnnl_abCd4c,
  2325. abCde4c = dnnl_abCde4c,
  2326. abCdef4c = dnnl_abCdef4c,
  2327. abCde32c = dnnl_abCde32c,
  2328. abCdef32c = dnnl_abCdef32c,
  2329. aCdefB16b32c2b = dnnl_aCdefB16b32c2b,
  2330. aCdefB16b32c4b = dnnl_aCdefB16b32c4b,
  2331. aCdefB16b48c2b = dnnl_aCdefB16b48c2b,
  2332. aCdefB16b48c4b = dnnl_aCdefB16b48c4b,
  2333. aCdefB16b64c2b = dnnl_aCdefB16b64c2b,
  2334. aCdefB16b64c4b = dnnl_aCdefB16b64c4b,
  2335. BcdeA16a32b2a = dnnl_BcdeA16a32b2a,
  2336. BcdeA16a32b4a = dnnl_BcdeA16a32b4a,
  2337. BcdeA16a48b2a = dnnl_BcdeA16a48b2a,
  2338. BcdeA16a48b4a = dnnl_BcdeA16a48b4a,
  2339. BcdeA16a64b2a = dnnl_BcdeA16a64b2a,
  2340. BcdeA16a64b4a = dnnl_BcdeA16a64b4a,
  2341. aCdefb32c = dnnl_aCdefb32c,
  2342. aCdefB32c2b = dnnl_aCdefB32c2b,
  2343. aCdefB32c4b = dnnl_aCdefB32c4b,
  2344. aCdefb48c = dnnl_aCdefb48c,
  2345. aCdefB48c2b = dnnl_aCdefB48c2b,
  2346. aCdefB48c4b = dnnl_aCdefB48c4b,
  2347. aCdefb64c = dnnl_aCdefb64c,
  2348. aCdefB64c2b = dnnl_aCdefB64c2b,
  2349. aCdefB64c4b = dnnl_aCdefB64c4b,
  2350. Bcdea32b = dnnl_Bcdea32b,
  2351. BcdeA32b2a = dnnl_BcdeA32b2a,
  2352. BcdeA32b4a = dnnl_BcdeA32b4a,
  2353. Bcdea48b = dnnl_Bcdea48b,
  2354. BcdeA48b2a = dnnl_BcdeA48b2a,
  2355. BcdeA48b4a = dnnl_BcdeA48b4a,
  2356. Bcdea64b = dnnl_Bcdea64b,
  2357. BcdeA64b2a = dnnl_BcdeA64b2a,
  2358. BcdeA64b4a = dnnl_BcdeA64b4a,
  2359. Bca32b = dnnl_Bca32b,
  2360. BcA32b2a = dnnl_BcA32b2a,
  2361. BcA32b4a = dnnl_BcA32b4a,
  2362. Bca48b = dnnl_Bca48b,
  2363. BcA48b2a = dnnl_BcA48b2a,
  2364. BcA48b4a = dnnl_BcA48b4a,
  2365. Bca64b = dnnl_Bca64b,
  2366. BcA64b2a = dnnl_BcA64b2a,
  2367. BcA64b4a = dnnl_BcA64b4a,
  2368. aCdb32c = dnnl_aCdb32c,
  2369. aCdB32c2b = dnnl_aCdB32c2b,
  2370. aCdB32c4b = dnnl_aCdB32c4b,
  2371. aCdb48c = dnnl_aCdb48c,
  2372. aCdB48c2b = dnnl_aCdB48c2b,
  2373. aCdB48c4b = dnnl_aCdB48c4b,
  2374. aCdb64c = dnnl_aCdb64c,
  2375. aCdB64c2b = dnnl_aCdB64c2b,
  2376. aCdB64c4b = dnnl_aCdB64c4b,
  2377. BcA16a16b2a = dnnl_BcA16a16b2a,
  2378. BcA16a16b4a = dnnl_BcA16a16b4a,
  2379. BcdA16a16b2a = dnnl_BcdA16a16b2a,
  2380. BcdA16a16b4a = dnnl_BcdA16a16b4a,
  2381. BcdeA16a16b2a = dnnl_BcdeA16a16b2a,
  2382. BcdeA16a16b4a = dnnl_BcdeA16a16b4a,
  2383. aCdB16b16c2b = dnnl_aCdB16b16c2b,
  2384. aCdB16b16c4b = dnnl_aCdB16b16c4b,
  2385. aCdeB16b16c2b = dnnl_aCdeB16b16c2b,
  2386. aCdeB16b16c4b = dnnl_aCdeB16b16c4b,
  2387. aCdefB16b16c2b = dnnl_aCdefB16b16c2b,
  2388. aCdefB16b16c4b = dnnl_aCdefB16b16c4b,
  2389. BcA16a32b2a = dnnl_BcA16a32b2a,
  2390. BcA16a32b4a = dnnl_BcA16a32b4a,
  2391. BcA16a48b2a = dnnl_BcA16a48b2a,
  2392. BcA16a48b4a = dnnl_BcA16a48b4a,
  2393. BcA16a64b2a = dnnl_BcA16a64b2a,
  2394. BcA16a64b4a = dnnl_BcA16a64b4a,
  2395. aCdB16b32c2b = dnnl_aCdB16b32c2b,
  2396. aCdB16b32c4b = dnnl_aCdB16b32c4b,
  2397. aCdB16b48c2b = dnnl_aCdB16b48c2b,
  2398. aCdB16b48c4b = dnnl_aCdB16b48c4b,
  2399. aCdB16b64c2b = dnnl_aCdB16b64c2b,
  2400. aCdB16b64c4b = dnnl_aCdB16b64c4b,
  2401. BcdA16a32b2a = dnnl_BcdA16a32b2a,
  2402. BcdA16a32b4a = dnnl_BcdA16a32b4a,
  2403. BcdA16a48b2a = dnnl_BcdA16a48b2a,
  2404. BcdA16a48b4a = dnnl_BcdA16a48b4a,
  2405. BcdA16a64b2a = dnnl_BcdA16a64b2a,
  2406. BcdA16a64b4a = dnnl_BcdA16a64b4a,
  2407. aCdeB16b32c2b = dnnl_aCdeB16b32c2b,
  2408. aCdeB16b32c4b = dnnl_aCdeB16b32c4b,
  2409. aCdeB16b48c2b = dnnl_aCdeB16b48c2b,
  2410. aCdeB16b48c4b = dnnl_aCdeB16b48c4b,
  2411. aCdeB16b64c2b = dnnl_aCdeB16b64c2b,
  2412. aCdeB16b64c4b = dnnl_aCdeB16b64c4b,
  2413. Bca16b = dnnl_Bca16b,
  2414. BcA16b2a = dnnl_BcA16b2a,
  2415. BcA16b4a = dnnl_BcA16b4a,
  2416. Bcda16b = dnnl_Bcda16b,
  2417. BcdA16b2a = dnnl_BcdA16b2a,
  2418. BcdA16b4a = dnnl_BcdA16b4a,
  2419. Bcdea16b = dnnl_Bcdea16b,
  2420. BcdeA16b2a = dnnl_BcdeA16b2a,
  2421. BcdeA16b4a = dnnl_BcdeA16b4a,
  2422. aCdb16c = dnnl_aCdb16c,
  2423. aCdB16c2b = dnnl_aCdB16c2b,
  2424. aCdB16c4b = dnnl_aCdB16c4b,
  2425. aCdeb16c = dnnl_aCdeb16c,
  2426. aCdeB16c2b = dnnl_aCdeB16c2b,
  2427. aCdeB16c4b = dnnl_aCdeB16c4b,
  2428. aCdefb16c = dnnl_aCdefb16c,
  2429. aCdefB16c2b = dnnl_aCdefB16c2b,
  2430. aCdefB16c4b = dnnl_aCdefB16c4b,
  2431. Bcda32b = dnnl_Bcda32b,
  2432. BcdA32b2a = dnnl_BcdA32b2a,
  2433. BcdA32b4a = dnnl_BcdA32b4a,
  2434. Bcda48b = dnnl_Bcda48b,
  2435. BcdA48b2a = dnnl_BcdA48b2a,
  2436. BcdA48b4a = dnnl_BcdA48b4a,
  2437. Bcda64b = dnnl_Bcda64b,
  2438. BcdA64b2a = dnnl_BcdA64b2a,
  2439. BcdA64b4a = dnnl_BcdA64b4a,
  2440. aCdeb32c = dnnl_aCdeb32c,
  2441. aCdeB32c2b = dnnl_aCdeB32c2b,
  2442. aCdeB32c4b = dnnl_aCdeB32c4b,
  2443. aCdeb48c = dnnl_aCdeb48c,
  2444. aCdeB48c2b = dnnl_aCdeB48c2b,
  2445. aCdeB48c4b = dnnl_aCdeB48c4b,
  2446. aCdeb64c = dnnl_aCdeb64c,
  2447. aCdeB64c2b = dnnl_aCdeB64c2b,
  2448. aCdeB64c4b = dnnl_aCdeB64c4b,
  2449. NChw16n32c = dnnl_NChw16n32c,
  2450. goIw4i = dnnl_goIw4i,
  2451. goIw32i = dnnl_goIw32i,
  2452. goIhw4i = dnnl_goIhw4i,
  2453. goIhw32i = dnnl_goIhw32i,
  2454. goIdhw4i = dnnl_goIdhw4i,
  2455. goIdhw32i = dnnl_goIdhw32i,
  2456. cab = dnnl_cab,
  2457. cdab = dnnl_cdab,
  2458. cdeab = dnnl_cdeab,
  2459. woi = dnnl_woi,
  2460. hwoi = dnnl_hwoi,
  2461. dhwoi = dnnl_dhwoi,
  2462. Owi24o = dnnl_Owi24o,
  2463. Ohwi24o = dnnl_Ohwi24o,
  2464. Odhwi24o = dnnl_Odhwi24o,
  2465. gOwi24o = dnnl_gOwi24o,
  2466. gOhwi24o = dnnl_gOhwi24o,
  2467. gOdhwi24o = dnnl_gOdhwi24o,
  2468. OwI24o2i = dnnl_OwI24o2i,
  2469. OhwI24o2i = dnnl_OhwI24o2i,
  2470. OdhwI24o2i = dnnl_OdhwI24o2i,
  2471. gOwI24o2i = dnnl_gOwI24o2i,
  2472. gOhwI24o2i = dnnl_gOhwI24o2i,
  2473. gOdhwI24o2i = dnnl_gOdhwI24o2i,
  2474. OwI24o4i = dnnl_OwI24o4i,
  2475. OhwI24o4i = dnnl_OhwI24o4i,
  2476. OdhwI24o4i = dnnl_OdhwI24o4i,
  2477. gOwI24o4i = dnnl_gOwI24o4i,
  2478. gOhwI24o4i = dnnl_gOhwI24o4i,
  2479. gOdhwI24o4i = dnnl_gOdhwI24o4i,
  2480. OI8i32o = dnnl_OI8i32o,
  2481. OIw8i32o = dnnl_OIw8i32o,
  2482. OwI8i32o = dnnl_OwI8i32o,
  2483. OIhw8i32o = dnnl_OIhw8i32o,
  2484. OhwI8i32o = dnnl_OhwI8i32o,
  2485. OIdhw8i32o = dnnl_OIdhw8i32o,
  2486. OdhwI8i32o = dnnl_OdhwI8i32o,
  2487. OI8i24o = dnnl_OI8i24o,
  2488. OIw8i24o = dnnl_OIw8i24o,
  2489. OwI8i24o = dnnl_OwI8i24o,
  2490. OIhw8i24o = dnnl_OIhw8i24o,
  2491. OhwI8i24o = dnnl_OhwI8i24o,
  2492. OIdhw8i24o = dnnl_OIdhw8i24o,
  2493. OdhwI8i24o = dnnl_OdhwI8i24o,
  2494. OI8i16o = dnnl_OI8i16o,
  2495. OIw8i16o = dnnl_OIw8i16o,
  2496. OwI8i16o = dnnl_OwI8i16o,
  2497. OIhw8i16o = dnnl_OIhw8i16o,
  2498. OhwI8i16o = dnnl_OhwI8i16o,
  2499. OIdhw8i16o = dnnl_OIdhw8i16o,
  2500. OdhwI8i16o = dnnl_OdhwI8i16o,
  2501. OI8i8o = dnnl_OI8i8o,
  2502. AB4b8a4b = dnnl_AB4b8a4b,
  2503. AB4b24a4b = dnnl_AB4b24a4b,
  2504. ABc4b8a4b = dnnl_ABc4b8a4b,
  2505. AcB4b8a4b = dnnl_AcB4b8a4b,
  2506. ABc4b24a4b = dnnl_ABc4b24a4b,
  2507. AcB4b24a4b = dnnl_AcB4b24a4b,
  2508. ABcd4b8a4b = dnnl_ABcd4b8a4b,
  2509. AcdB4b8a4b = dnnl_AcdB4b8a4b,
  2510. ABcd4b24a4b = dnnl_ABcd4b24a4b,
  2511. AcdB4b24a4b = dnnl_AcdB4b24a4b,
  2512. ABcde4b8a4b = dnnl_ABcde4b8a4b,
  2513. AcdeB4b8a4b = dnnl_AcdeB4b8a4b,
  2514. ABcde4b24a4b = dnnl_ABcde4b24a4b,
  2515. AcdeB4b24a4b = dnnl_AcdeB4b24a4b,
  2516. Bca8b = dnnl_Bca8b,
  2517. BcA8b2a = dnnl_BcA8b2a,
  2518. Bcda8b = dnnl_Bcda8b,
  2519. BcdA8b2a = dnnl_BcdA8b2a,
  2520. Bcdea8b = dnnl_Bcdea8b,
  2521. BcdeA8b2a = dnnl_BcdeA8b2a,
  2522. aCdb8c = dnnl_aCdb8c,
  2523. aCdB8c2b = dnnl_aCdB8c2b,
  2524. aCdeb8c = dnnl_aCdeb8c,
  2525. aCdeB8c2b = dnnl_aCdeB8c2b,
  2526. aCdefb8c = dnnl_aCdefb8c,
  2527. aCdefB8c2b = dnnl_aCdefB8c2b,
  2528. Bca24b = dnnl_Bca24b,
  2529. BcA24b2a = dnnl_BcA24b2a,
  2530. Bcda24b = dnnl_Bcda24b,
  2531. BcdA24b2a = dnnl_BcdA24b2a,
  2532. Bcdea24b = dnnl_Bcdea24b,
  2533. BcdeA24b2a = dnnl_BcdeA24b2a,
  2534. aCdb24c = dnnl_aCdb24c,
  2535. aCdB24c2b = dnnl_aCdB24c2b,
  2536. aCdeb24c = dnnl_aCdeb24c,
  2537. aCdeB24c2b = dnnl_aCdeB24c2b,
  2538. aCdefb24c = dnnl_aCdefb24c,
  2539. aCdefB24c2b = dnnl_aCdefB24c2b,
  2540. Iwo8i = dnnl_Iwo8i,
  2541. IwO8i2o = dnnl_IwO8i2o,
  2542. Iwo24i = dnnl_Iwo24i,
  2543. IwO24i2o = dnnl_IwO24i2o,
  2544. Ihwo8i = dnnl_Ihwo8i,
  2545. IhwO8i2o = dnnl_IhwO8i2o,
  2546. Ihwo24i = dnnl_Ihwo24i,
  2547. IhwO24i2o = dnnl_IhwO24i2o,
  2548. Idhwo8i = dnnl_Idhwo8i,
  2549. IdhwO8i2o = dnnl_IdhwO8i2o,
  2550. Idhwo24i = dnnl_Idhwo24i,
  2551. IdhwO24i2o = dnnl_IdhwO24i2o,
  2552. gIwo8i = dnnl_gIwo8i,
  2553. gIwO8i2o = dnnl_gIwO8i2o,
  2554. gIwo24i = dnnl_gIwo24i,
  2555. gIwO24i2o = dnnl_gIwO24i2o,
  2556. gIhwo8i = dnnl_gIhwo8i,
  2557. gIhwO8i2o = dnnl_gIhwO8i2o,
  2558. gIhwo24i = dnnl_gIhwo24i,
  2559. gIhwO24i2o = dnnl_gIhwO24i2o,
  2560. gIdhwo8i = dnnl_gIdhwo8i,
  2561. gIdhwO8i2o = dnnl_gIdhwO8i2o,
  2562. gIdhwo24i = dnnl_gIdhwo24i,
  2563. gIdhwO24i2o = dnnl_gIdhwO24i2o,
  2564. OhwI24o = dnnl_OhwI24o,
  2565. gOhwI24o = dnnl_gOhwI24o,
  2566. AB8b24a2b = dnnl_AB8b24a2b,
  2567. ABc8b24a2b = dnnl_ABc8b24a2b,
  2568. AcB8b24a2b = dnnl_AcB8b24a2b,
  2569. ABcd8b24a2b = dnnl_ABcd8b24a2b,
  2570. AcdB8b24a2b = dnnl_AcdB8b24a2b,
  2571. ABcde8b24a2b = dnnl_ABcde8b24a2b,
  2572. AcdeB8b24a2b = dnnl_AcdeB8b24a2b,
  2573. AB8b8a2b = dnnl_AB8b8a2b,
  2574. ABc8b8a2b = dnnl_ABc8b8a2b,
  2575. AcB8b8a2b = dnnl_AcB8b8a2b,
  2576. ABcd8b8a2b = dnnl_ABcd8b8a2b,
  2577. AcdB8b8a2b = dnnl_AcdB8b8a2b,
  2578. ABcde8b8a2b = dnnl_ABcde8b8a2b,
  2579. AcdeB8b8a2b = dnnl_AcdeB8b8a2b,
  2580. OI8i8o2i = dnnl_OI8i8o2i,
  2581. OI8i24o2i = dnnl_OI8i24o2i,
  2582. OIw8i8o2i = dnnl_OIw8i8o2i,
  2583. OwI8i8o2i = dnnl_OwI8i8o2i,
  2584. OIw8i24o2i = dnnl_OIw8i24o2i,
  2585. OwI8i24o2i = dnnl_OwI8i24o2i,
  2586. OIhw8i8o2i = dnnl_OIhw8i8o2i,
  2587. OhwI8i8o2i = dnnl_OhwI8i8o2i,
  2588. OIhw8i24o2i = dnnl_OIhw8i24o2i,
  2589. OhwI8i24o2i = dnnl_OhwI8i24o2i,
  2590. OIdhw8i8o2i = dnnl_OIdhw8i8o2i,
  2591. OdhwI8i8o2i = dnnl_OdhwI8i8o2i,
  2592. OIdhw8i24o2i = dnnl_OIdhw8i24o2i,
  2593. OdhwI8i24o2i = dnnl_OdhwI8i24o2i,
  2594. BcA8b4a = dnnl_BcA8b4a,
  2595. BcdA8b4a = dnnl_BcdA8b4a,
  2596. BcdeA8b4a = dnnl_BcdeA8b4a,
  2597. aCdB8c4b = dnnl_aCdB8c4b,
  2598. aCdeB8c4b = dnnl_aCdeB8c4b,
  2599. aCdefB8c4b = dnnl_aCdefB8c4b,
  2600. BcA24b4a = dnnl_BcA24b4a,
  2601. BcdA24b4a = dnnl_BcdA24b4a,
  2602. BcdeA24b4a = dnnl_BcdeA24b4a,
  2603. aCdB24c4b = dnnl_aCdB24c4b,
  2604. aCdeB24c4b = dnnl_aCdeB24c4b,
  2605. aCdefB24c4b = dnnl_aCdefB24c4b,
  2606. ABc16a4b = dnnl_ABc16a4b,
  2607. ABcd16a4b = dnnl_ABcd16a4b,
  2608. ABcde16a4b = dnnl_ABcde16a4b,
  2609. IwO8i4o = dnnl_IwO8i4o,
  2610. IwO24i4o = dnnl_IwO24i4o,
  2611. IhwO8i4o = dnnl_IhwO8i4o,
  2612. IhwO24i4o = dnnl_IhwO24i4o,
  2613. IdhwO8i4o = dnnl_IdhwO8i4o,
  2614. IdhwO24i4o = dnnl_IdhwO24i4o,
  2615. gIwO8i4o = dnnl_gIwO8i4o,
  2616. gIwO24i4o = dnnl_gIwO24i4o,
  2617. gIhwO8i4o = dnnl_gIhwO8i4o,
  2618. gIhwO24i4o = dnnl_gIhwO24i4o,
  2619. gIdhwO8i4o = dnnl_gIdhwO8i4o,
  2620. gIdhwO24i4o = dnnl_gIdhwO24i4o,
  2621. BA2a24b = dnnl_BA2a24b,
  2622. aCB2b24c = dnnl_aCB2b24c,
  2623. BA2a8b = dnnl_BA2a8b,
  2624. aCB2b8c = dnnl_aCB2b8c,
  2625. BA8a24b = dnnl_BA8a24b,
  2626. aCB8b24c = dnnl_aCB8b24c,
  2627. BA8a16b = dnnl_BA8a16b,
  2628. aCB8b16c = dnnl_aCB8b16c,
  2629. BA8a8b = dnnl_BA8a8b,
  2630. aCB8b8c = dnnl_aCB8b8c,
  2631. bcad = dnnl_bcad,
  2632. cabd = dnnl_cabd,
  2633. dabc = dnnl_dabc,
  2634. };
  2635. /// A memory descriptor.
  2636. struct desc : public handle<dnnl_memory_desc_t> {
  2637. using handle<dnnl_memory_desc_t>::handle;
  2638. friend struct memory;
  2639. /// Constructs a zero (empty) memory descriptor. Such a memory
  2640. /// descriptor can be used to indicate absence of an argument.
  2641. desc() {
  2642. dnnl_memory_desc_t zero_md = nullptr;
  2643. error::wrap_c_api(
  2644. dnnl_memory_desc_create_with_tag(&zero_md, 0, nullptr,
  2645. dnnl_data_type_undef, dnnl_format_tag_undef),
  2646. "could not create a zero memory descriptor");
  2647. reset(zero_md);
  2648. }
  2649. /// Constructs a memory descriptor.
  2650. ///
  2651. /// @note
  2652. /// The logical order of dimensions corresponds to the `abc...`
  2653. /// format tag, and the physical meaning of the dimensions depends
  2654. /// both on the primitive that would operate on this memory and
  2655. /// the operation context.
  2656. ///
  2657. /// @param adims Tensor dimensions.
  2658. /// @param adata_type Data precision/type.
  2659. /// @param aformat_tag Memory format tag.
  2660. /// @param allow_empty A flag signifying whether construction is
  2661. /// allowed to fail without throwing an exception. In this case a
  2662. /// zero memory descriptor will be constructed. This flag is
  2663. /// optional and defaults to false.
  2664. desc(const dims &adims, data_type adata_type, format_tag aformat_tag,
  2665. bool allow_empty = false) {
  2666. validate_dims(adims);
  2667. dnnl_memory_desc_t md = nullptr;
  2668. dnnl_status_t status = dnnl_memory_desc_create_with_tag(&md,
  2669. (int)adims.size(), adims.data(), convert_to_c(adata_type),
  2670. convert_to_c(aformat_tag));
  2671. if (!allow_empty)
  2672. error::wrap_c_api(status,
  2673. "could not construct a memory descriptor using a "
  2674. "format tag");
  2675. reset(md);
  2676. }
  2677. /// Constructs a memory descriptor by strides.
  2678. ///
  2679. /// @note
  2680. /// The logical order of dimensions corresponds to the `abc...`
  2681. /// format tag, and the physical meaning of the dimensions depends
  2682. /// both on the primitive that would operate on this memory and
  2683. /// the operation context.
  2684. ///
  2685. /// @param adims Tensor dimensions.
  2686. /// @param adata_type Data precision/type.
  2687. /// @param strides Strides for each dimension.
  2688. /// @param allow_empty A flag signifying whether construction is
  2689. /// allowed to fail without throwing an exception. In this case a
  2690. /// zero memory descriptor will be constructed. This flag is
  2691. /// optional and defaults to false.
  2692. desc(const dims &adims, data_type adata_type, const dims &strides,
  2693. bool allow_empty = false) {
  2694. validate_dims(adims);
  2695. if (!strides.empty()) validate_dims(strides, (int)adims.size());
  2696. dnnl_memory_desc_t md = nullptr;
  2697. dnnl_status_t status = dnnl_memory_desc_create_with_strides(&md,
  2698. (int)adims.size(), adims.data(), convert_to_c(adata_type),
  2699. strides.empty() ? nullptr : &strides[0]);
  2700. if (!allow_empty)
  2701. error::wrap_c_api(status,
  2702. "could not construct a memory descriptor using "
  2703. "strides");
  2704. reset(md);
  2705. }
  2706. #ifdef DNNL_EXPERIMENTAL_SPARSE
  2707. /// Function for creating a memory descriptor for CSR sparse encoding.
  2708. ///
  2709. /// The created memory descriptor will describe a memory object that
  2710. /// contains 3 buffers. The buffers have the following meaning and
  2711. /// assigned numbers (index):
  2712. /// - 0: values
  2713. /// - 1: indices
  2714. /// - 2: pointers
  2715. ///
  2716. /// @param adims Tensor dimensions.
  2717. /// @param adata_type Data precision/type.
  2718. /// @param nnz Number of non-zero entries.
  2719. /// @param index_dt Data type of indices.
  2720. /// @param pointer_dt Data type of pointers.
  2721. /// @param allow_empty A flag signifying whether construction is
  2722. /// allowed to fail without throwing an exception. In this case a
  2723. /// zero memory descriptor will be constructed. This flag is
  2724. /// optional and defaults to false.
  2725. static desc csr(const dims &adims, data_type adata_type, dim nnz,
  2726. data_type index_dt, data_type pointer_dt,
  2727. bool allow_empty = false) {
  2728. validate_dims(adims);
  2729. dnnl_memory_desc_t md = nullptr;
  2730. dnnl_status_t status = dnnl_memory_desc_create_with_csr_encoding(
  2731. &md, (int)adims.size(), adims.data(),
  2732. convert_to_c(adata_type), nnz, convert_to_c(index_dt),
  2733. convert_to_c(pointer_dt));
  2734. if (!allow_empty)
  2735. error::wrap_c_api(status,
  2736. "could not create a memory descriptor for CSR sparse "
  2737. "encoding");
  2738. return desc {md};
  2739. }
  2740. /// Function for creating a memory descriptor for COO sparse encodings.
  2741. ///
  2742. /// The created memory descriptor will describe a memory object that
  2743. /// contains n+1 buffers for an n-dimensional tensor.
  2744. /// The buffers have the following meaning and assigned numbers (index):
  2745. /// - 0: values
  2746. /// - 1: indices for dimension 0
  2747. /// - 2: indices for dimension 1 ...
  2748. /// - n: indices for dimension n-1
  2749. ///
  2750. /// @param adims Tensor dimensions.
  2751. /// @param adata_type Data precision/type.
  2752. /// @param nnz Number of non-zero entries.
  2753. /// @param index_dt Data type of indices.
  2754. /// @param allow_empty A flag signifying whether construction is
  2755. /// allowed to fail without throwing an exception. In this case a
  2756. /// zero memory descriptor will be constructed. This flag is
  2757. /// optional and defaults to false.
  2758. static desc coo(const dims &adims, data_type adata_type, dim nnz,
  2759. data_type index_dt, bool allow_empty = false) {
  2760. validate_dims(adims);
  2761. dnnl_memory_desc_t md = nullptr;
  2762. dnnl_status_t status = dnnl_memory_desc_create_with_coo_encoding(
  2763. &md, (int)adims.size(), adims.data(),
  2764. convert_to_c(adata_type), nnz, convert_to_c(index_dt));
  2765. if (!allow_empty)
  2766. error::wrap_c_api(status,
  2767. "could not create a memory descriptor for COO sparse "
  2768. "encoding");
  2769. return desc {md};
  2770. }
  2771. /// Function for creating a memory descriptor for packed sparse
  2772. /// encoding.
  2773. ///
  2774. /// The created memory descriptor cannot be used to create a memory
  2775. /// object. It can only be used to create a primitive descriptor to
  2776. /// query the actual memory descriptor (similar to the format tag
  2777. /// `any`).
  2778. ///
  2779. /// @warning
  2780. /// The meaning and content of the handles of the memory object that
  2781. /// is created using the queried memory descriptor are unspecified
  2782. /// therefore using the content is an undefined behavior.
  2783. ///
  2784. /// @param adims Tensor dimensions.
  2785. /// @param adata_type Data precision/type.
  2786. /// @param nnz Number of non-zero entries.
  2787. /// @param allow_empty A flag signifying whether construction is
  2788. /// allowed to fail without throwing an exception. In this case a
  2789. /// zero memory descriptor will be constructed. This flag is
  2790. /// optional and defaults to false.
  2791. static desc packed(const dims &adims, data_type adata_type, dim nnz,
  2792. bool allow_empty = false) {
  2793. validate_dims(adims);
  2794. dnnl_memory_desc_t md = nullptr;
  2795. dnnl_status_t status = dnnl_memory_desc_create_with_packed_encoding(
  2796. &md, (int)adims.size(), adims.data(),
  2797. convert_to_c(adata_type), nnz);
  2798. if (!allow_empty)
  2799. error::wrap_c_api(status,
  2800. "could not create a memory descriptor for packed "
  2801. "sparse encoding");
  2802. return desc {md};
  2803. }
  2804. #endif
  2805. /// Construct a memory descriptor from a C API ::dnnl_memory_desc_t
  2806. /// handle. The resulting handle is not weak and the C handle will be
  2807. /// destroyed during the destruction of the C++ object.
  2808. ///
  2809. /// @param md The C API memory descriptor.
  2810. desc(dnnl_memory_desc_t md) : handle<dnnl_memory_desc_t>(md) {}
  2811. /// Construct a memory descriptor from a binary blob.
  2812. ///
  2813. /// @param blob A binary blob previously queried from a memory descriptor.
  2814. desc(const std::vector<uint8_t> &blob) {
  2815. dnnl_memory_desc_t md = nullptr;
  2816. error::wrap_c_api(
  2817. dnnl_memory_desc_create_with_blob(&md, blob.data()),
  2818. "could not create a memory descriptor from blob");
  2819. reset(md);
  2820. }
  2821. /// Constructs a memory descriptor for a region inside an area
  2822. /// described by this memory descriptor.
  2823. //
  2824. /// @param adims Sizes of the region.
  2825. /// @param offsets Offsets to the region from the encompassing
  2826. /// memory object in each dimension.
  2827. /// @param allow_empty A flag signifying whether construction is
  2828. /// allowed to fail without throwing an exception. In this case a
  2829. /// zero memory descriptor will be returned. This flag is optional
  2830. /// and defaults to false.
  2831. /// @returns A memory descriptor for the region.
  2832. desc submemory_desc(const dims &adims, const dims &offsets,
  2833. bool allow_empty = false) const {
  2834. validate_dims(adims, get_ndims());
  2835. validate_dims(offsets, get_ndims());
  2836. dnnl_memory_desc_t sub_md = nullptr;
  2837. dnnl_status_t status = dnnl_memory_desc_create_submemory(
  2838. &sub_md, get(), adims.data(), offsets.data());
  2839. if (!allow_empty)
  2840. error::wrap_c_api(status, "could not construct a sub-memory");
  2841. return desc(sub_md);
  2842. }
  2843. /// Constructs a memory descriptor by reshaping an existing one. The
  2844. /// new memory descriptor inherits the data type. This operation is
  2845. /// valid only for memory descriptors that have format_kind set to
  2846. /// #dnnl::memory::format_kind::blocked or
  2847. /// #dnnl::memory::format_kind::any.
  2848. ///
  2849. /// The operation ensures that the transformation of the physical memory
  2850. /// format corresponds to the transformation of the logical dimensions.
  2851. /// If such transformation is impossible, the function either throws an
  2852. /// exception (default) or returns a zero memory descriptor depending on
  2853. /// the `allow_empty` flag.
  2854. ///
  2855. /// The reshape operation can be described as a combination of the
  2856. /// following basic operations:
  2857. /// 1. Add a dimension of size `1`. This is always possible.
  2858. /// 2. Remove a dimension of size `1`. This is possible only if the
  2859. /// dimension has no padding (i.e.
  2860. /// `padded_dims[dim] == dims[dim] && dims[dim] == 1`).
  2861. /// 3. Split a dimension into multiple ones. This is possible only if
  2862. /// the product of all tensor dimensions stays constant and the
  2863. /// dimension being split does not have padding (i.e.
  2864. /// `padded_dims[dim] = dims[dim]`).
  2865. /// 4. Join multiple consecutive dimensions into a single one. As in
  2866. /// the cases above, this requires that the dimensions do not have
  2867. /// padding and that the memory format is such that in physical
  2868. /// memory these dimensions are dense and have the same order as
  2869. /// their logical counterparts. This also assumes that these
  2870. /// dimensions are not blocked.
  2871. /// - Here, 'dense' means:
  2872. /// `stride for dim[i] == (stride for dim[i + 1]) * dim[i + 1]`;
  2873. /// - And 'same order' means:
  2874. /// `i < j` if and only if `stride for dim[j] <= stride for dim[i]`.
  2875. ///
  2876. /// @warning
  2877. /// Some combinations of physical memory layout and/or offsets or
  2878. /// dimensions may result in a failure to make a reshape.
  2879. ///
  2880. /// @param adims New dimensions. The product of dimensions must
  2881. /// remain constant.
  2882. /// @param allow_empty A flag signifying whether construction is
  2883. /// allowed to fail without throwing an exception. In this case a
  2884. /// zero memory descriptor will be returned. This flag is optional
  2885. /// and defaults to false.
  2886. /// @returns A new memory descriptor with new dimensions.
  2887. desc reshape(const dims &adims, bool allow_empty = false) const {
  2888. if (get_ndims()) validate_dims(adims, 1);
  2889. dnnl_memory_desc_t out_md = nullptr;
  2890. dnnl_status_t status = dnnl_memory_desc_reshape(
  2891. &out_md, get(), (int)adims.size(), adims.data());
  2892. if (!allow_empty)
  2893. error::wrap_c_api(
  2894. status, "could not reshape a memory descriptor");
  2895. return desc(out_md);
  2896. }
  2897. /// Constructs a memory descriptor by permuting axes in an existing
  2898. /// one.
  2899. ///
  2900. /// The physical memory layout representation is adjusted accordingly
  2901. /// to maintain the consistency between the logical and physical parts
  2902. /// of the memory descriptor. The new memory descriptor inherits the
  2903. /// data type.
  2904. ///
  2905. /// The new memory descriptor inherits the data type. This operation is
  2906. /// valid only for memory descriptors that have format_kind set to
  2907. /// #dnnl::memory::format_kind::blocked or
  2908. /// #dnnl::memory::format_kind::any.
  2909. ///
  2910. /// The logical axes will be permuted in the following manner:
  2911. /// @code
  2912. /// for (i = 0; i < get_ndims(); i++)
  2913. /// new_desc.dims()[permutation[i]] = dims()[i];
  2914. /// @endcode
  2915. ///
  2916. /// Example:
  2917. /// @code
  2918. /// std::vector<int> permutation = {1, 0}; // swap the first and
  2919. /// // the second axes
  2920. /// dnnl::memory::desc in_md(
  2921. /// {2, 3}, data_type, memory::format_tag::ab);
  2922. /// dnnl::memory::desc expect_out_md(
  2923. /// {3, 2}, data_type, memory::format_tag::ba);
  2924. ///
  2925. /// assert(in_md.permute_axes(permutation) == expect_out_md);
  2926. /// @endcode
  2927. ///
  2928. /// @param permutation Axes permutation.
  2929. /// @param allow_empty A flag signifying whether construction is
  2930. /// allowed to fail without throwing an exception. In this case a
  2931. /// zero memory descriptor will be returned. This flag is optional
  2932. /// and defaults to false.
  2933. /// @returns A new memory descriptor with new dimensions.
  2934. desc permute_axes(const std::vector<int> &permutation,
  2935. bool allow_empty = false) const {
  2936. validate_dims(permutation, get_ndims());
  2937. dnnl_memory_desc_t out_md = nullptr;
  2938. dnnl_status_t status = dnnl_memory_desc_permute_axes(
  2939. &out_md, get(), permutation.data());
  2940. if (!allow_empty)
  2941. error::wrap_c_api(status,
  2942. "could not permute axes of a memory descriptor");
  2943. return desc(out_md);
  2944. }
  2945. /// Returns a number of dimensions of the memory descriptor.
  2946. ///
  2947. /// @returns A number of dimensions.
  2948. int get_ndims() const { return query_s32(query::ndims_s32); }
  2949. /// Returns padded dimensions of the memory descriptor.
  2950. ///
  2951. /// @returns A copy of the padded dimensions vector.
  2952. memory::dims get_padded_dims() const {
  2953. return query_dims(query::padded_dims);
  2954. }
  2955. /// Returns padded offsets of the memory descriptor.
  2956. ///
  2957. /// @returns A copy of the padded offsets vector.
  2958. memory::dims get_padded_offsets() const {
  2959. return query_dims(query::padded_offsets);
  2960. }
  2961. /// Returns a submemory offset of the memory descriptor.
  2962. ///
  2963. /// @returns A submemory offset.
  2964. memory::dim get_submemory_offset() const {
  2965. dnnl_dim_t submemory_offset;
  2966. dnnl_status_t status = dnnl_memory_desc_query(
  2967. get(), dnnl_query_submemory_offset_s64, &submemory_offset);
  2968. return status == dnnl_success ? submemory_offset : 0;
  2969. }
  2970. /// Returns strides of the memory descriptor.
  2971. ///
  2972. /// @note
  2973. /// This API is only applicable to memory descriptors with format
  2974. /// kind #dnnl_blocked.
  2975. ///
  2976. /// @returns A copy of the strides vector.
  2977. /// @returns An empty #dnnl::memory::dims if the memory descriptor
  2978. /// does not have strides.
  2979. memory::dims get_strides() const { return query_dims(query::strides); }
  2980. /// Returns a number of inner blocks of the memory descriptor.
  2981. ///
  2982. /// @note
  2983. /// This API is only applicable to memory descriptors with format
  2984. /// kind #dnnl_blocked.
  2985. ///
  2986. /// @returns A number of inner blocks.
  2987. int get_inner_nblks() const {
  2988. return query_s32(query::inner_nblks_s32);
  2989. }
  2990. /// Returns inner blocks of the memory descriptor.
  2991. ///
  2992. /// @note
  2993. /// This API is only applicable to memory descriptors with format
  2994. /// kind #dnnl_blocked.
  2995. ///
  2996. /// @returns A copy of the inner blocks vector.
  2997. /// @returns An empty #dnnl::memory::dims if the memory descriptor
  2998. /// does not have inner blocks.
  2999. memory::dims get_inner_blks() const {
  3000. return query_dims(query::inner_blks);
  3001. }
  3002. /// Returns inner indices of the memory descriptor.
  3003. ///
  3004. /// @note
  3005. /// This API is only applicable to memory descriptors with format
  3006. /// kind #dnnl_blocked.
  3007. ///
  3008. /// @returns A copy of the inner indices vector.
  3009. /// @returns An empty #dnnl::memory::dims if the memory descriptor
  3010. /// does not have inner indices.
  3011. memory::dims get_inner_idxs() const {
  3012. return query_dims(query::inner_idxs);
  3013. }
  3014. #ifdef DNNL_EXPERIMENTAL_SPARSE
  3015. /// Returns number of handles.
  3016. ///
  3017. /// @returns A number of handles.
  3018. int get_num_handles() const {
  3019. int nhandles;
  3020. dnnl_status_t status = dnnl_memory_desc_query_v2(
  3021. get(), dnnl_query_num_handles_s32, 0, &nhandles);
  3022. return status == dnnl_success ? nhandles : 0;
  3023. }
  3024. /// Returns a number of non-zero entries of the memory descriptor.
  3025. ///
  3026. /// @returns A number non-zero entries.
  3027. dim get_nnz() const {
  3028. dnnl_dim_t nnz;
  3029. dnnl_status_t status = dnnl_memory_desc_query_v2(
  3030. get(), dnnl_query_nnz_s64, 0, &nnz);
  3031. return status == dnnl_success ? nnz : 0;
  3032. }
  3033. /// Returns the sparse encoding of the memory descriptor.
  3034. ///
  3035. /// @returns the sparse encoding kind.
  3036. memory::sparse_encoding get_sparse_encoding() const {
  3037. dnnl_sparse_encoding_t sparse_encoding;
  3038. dnnl_status_t status = dnnl_memory_desc_query_v2(
  3039. get(), dnnl_query_sparse_encoding, 0, &sparse_encoding);
  3040. return status == dnnl_success
  3041. ? static_cast<dnnl::memory::sparse_encoding>(
  3042. sparse_encoding)
  3043. : dnnl::memory::sparse_encoding::undef;
  3044. }
  3045. /// Returns the data type of the memory descriptor.
  3046. ///
  3047. /// @returns The data type.
  3048. memory::data_type get_data_type(int index = 0) const {
  3049. return query_data_type(query::data_type, index);
  3050. }
  3051. #else
  3052. /// Returns the data type of the memory descriptor.
  3053. ///
  3054. /// @returns The data type.
  3055. memory::data_type get_data_type() const {
  3056. return query_data_type(query::data_type);
  3057. }
  3058. #endif
  3059. /// Returns the format kind of the memory descriptor.
  3060. ///
  3061. /// @returns the format kind.
  3062. memory::format_kind get_format_kind() const {
  3063. dnnl_format_kind_t format_kind;
  3064. dnnl_status_t status = dnnl_memory_desc_query(
  3065. get(), dnnl_query_format_kind, &format_kind);
  3066. return status == dnnl_success
  3067. ? static_cast<dnnl::memory::format_kind>(format_kind)
  3068. : dnnl::memory::format_kind::undef;
  3069. }
  3070. /// Returns dimensions of the memory descriptor.
  3071. ///
  3072. /// Potentially expensive due to the data copy involved.
  3073. /// @returns A copy of the dimensions vector.
  3074. memory::dims get_dims() const { return query_dims(query::dims); }
  3075. #ifdef DNNL_EXPERIMENTAL_SPARSE
  3076. /// Returns size of the memory descriptor in bytes.
  3077. /// @param index Data index. Defaults to 0.
  3078. /// @returns The number of bytes required to allocate a memory buffer
  3079. /// for data with a particular @p index described by this memory
  3080. /// descriptor including the padding area.
  3081. size_t get_size(int index = 0) const {
  3082. return dnnl_memory_desc_get_size_v2(get(), index);
  3083. }
  3084. #else
  3085. /// Returns size of the memory descriptor in bytes.
  3086. /// @returns The number of bytes required to allocate a memory buffer
  3087. /// for the memory object described by this memory descriptor
  3088. /// including the padding area.
  3089. size_t get_size() const { return dnnl_memory_desc_get_size(get()); }
  3090. #endif
  3091. /// Returns a binary blob associated with the given memory descriptor
  3092. /// @returns The memory descriptor blob associated with the memory descriptor
  3093. std::vector<uint8_t> get_blob() {
  3094. size_t size;
  3095. dnnl_status_t status
  3096. = dnnl_memory_desc_get_blob(nullptr, &size, get());
  3097. error::wrap_c_api(
  3098. status, "could not get memory descriptor blob size");
  3099. std::vector<uint8_t> out_blob(size);
  3100. status = dnnl_memory_desc_get_blob(out_blob.data(), &size, get());
  3101. error::wrap_c_api(status, "could not get memory descriptor blob");
  3102. return out_blob;
  3103. }
  3104. /// Checks whether the memory descriptor is zero (empty).
  3105. /// @returns @c true if the memory descriptor describes an empty
  3106. /// memory and @c false otherwise.
  3107. bool is_zero() const { return get_ndims() == 0; }
  3108. /// An equality operator.
  3109. /// @param other Another memory descriptor.
  3110. /// @returns Whether this and the other memory descriptors have
  3111. /// the same format tag, dimensions, strides, blocking, etc.
  3112. bool operator==(const desc &other) const {
  3113. return dnnl_memory_desc_equal(get(), other.get()) != 0;
  3114. }
  3115. /// An inequality operator.
  3116. /// @param other Another memory descriptor.
  3117. /// @returns Whether this and the other memory descriptors describe
  3118. /// different memory.
  3119. bool operator!=(const desc &other) const { return !operator==(other); }
  3120. private:
  3121. #ifdef DNNL_EXPERIMENTAL_SPARSE
  3122. memory::data_type query_data_type(query what, int index) const {
  3123. dnnl_data_type_t data_type;
  3124. dnnl_status_t status = dnnl_memory_desc_query_v2(
  3125. get(), dnnl::convert_to_c(what), index, &data_type);
  3126. return status == dnnl_success
  3127. ? static_cast<dnnl::memory::data_type>(data_type)
  3128. : dnnl::memory::data_type::undef;
  3129. }
  3130. #else
  3131. memory::data_type query_data_type(query what) const {
  3132. dnnl_data_type_t data_type;
  3133. dnnl_status_t status = dnnl_memory_desc_query(
  3134. get(), dnnl::convert_to_c(what), &data_type);
  3135. return status == dnnl_success
  3136. ? static_cast<dnnl::memory::data_type>(data_type)
  3137. : dnnl::memory::data_type::undef;
  3138. }
  3139. #endif
  3140. int query_s32(query what) const {
  3141. int res;
  3142. dnnl_status_t status = dnnl_memory_desc_query(
  3143. get(), dnnl::convert_to_c(what), &res);
  3144. return status == dnnl_success ? res : 0;
  3145. }
  3146. memory::dims query_dims(query what) const {
  3147. dnnl_dims_t *c_dims;
  3148. dnnl_status_t status = dnnl_memory_desc_query(
  3149. get(), dnnl::convert_to_c(what), &c_dims);
  3150. const int ndims
  3151. = (what == query::inner_idxs || what == query::inner_blks)
  3152. ? get_inner_nblks()
  3153. : get_ndims();
  3154. return status == dnnl_success
  3155. ? memory::dims(*c_dims, *c_dims + ndims)
  3156. : memory::dims {};
  3157. }
  3158. };
  3159. /// Default constructor.
  3160. ///
  3161. /// Constructs an empty memory object, which can be used to indicate
  3162. /// absence of a parameter.
  3163. memory() = default;
  3164. #ifdef DNNL_EXPERIMENTAL_SPARSE
  3165. /// Constructs a memory object.
  3166. ///
  3167. /// Unless @p handle is equal to #DNNL_MEMORY_NONE, the constructed memory
  3168. /// object will have the underlying buffer set. In this case, the buffer
  3169. /// will be initialized as if #dnnl::memory::set_data_handle() had been
  3170. /// called.
  3171. ///
  3172. /// @sa memory::set_data_handle()
  3173. ///
  3174. /// @param md Memory descriptor.
  3175. /// @param aengine Engine to store the data on.
  3176. /// @param handle Handle of the memory buffer to use.
  3177. /// - A pointer to the user-allocated buffer. In this case the library
  3178. /// doesn't own the buffer.
  3179. /// - The #DNNL_MEMORY_ALLOCATE special value. Instructs the library to
  3180. /// allocate the buffer for the memory object. In this case the
  3181. /// library owns the buffer.
  3182. /// - #DNNL_MEMORY_NONE to create dnnl::memory without an underlying
  3183. /// buffer.
  3184. memory(const desc &md, const engine &aengine, void *handle)
  3185. : memory(md, aengine, std::vector<void *> {handle}) {}
  3186. /// Constructs a memory object with multiple handles.
  3187. ///
  /// Unless an element of @p handles is equal to #DNNL_MEMORY_NONE, the
  /// constructed memory object will have the corresponding underlying buffer
  /// set. In this case, the buffer will be initialized as if
  /// #dnnl::memory::set_data_handle() had been called.
  3192. ///
  3193. /// @sa memory::set_data_handle()
  3194. ///
  3195. /// @param md Memory descriptor.
  3196. /// @param aengine Engine to store the data on.
  3197. /// @param handles Handles of the memory buffers to use.
  3198. /// For each element of the @p handles vector the following applies:
  3199. /// - A pointer to the user-allocated buffer. In this case the library
  3200. /// doesn't own the buffer.
  3201. /// - The #DNNL_MEMORY_ALLOCATE special value. Instructs the library to
  3202. /// allocate the buffer for the memory object. In this case the
  3203. /// library owns the buffer.
  3204. /// - #DNNL_MEMORY_NONE Instructs the library to skip allocation of the
  3205. /// memory buffer.
  3206. memory(const desc &md, const engine &aengine, std::vector<void *> handles) {
  3207. dnnl_memory_t result;
  3208. dnnl_status_t status = dnnl_memory_create_v2(&result, md.get(),
  3209. aengine.get(), (int)handles.size(), handles.data());
  3210. error::wrap_c_api(status, "could not create a memory object");
  3211. reset(result);
  3212. }
  3213. /// Constructs a memory object.
  3214. ///
  3215. /// The underlying buffer(s) for the memory will be allocated by the
  3216. /// library.
  3217. /// @param md Memory descriptor.
  3218. /// @param aengine Engine to store the data on.
  3219. memory(const desc &md, const engine &aengine) {
  3220. dnnl_status_t status;
  3221. dnnl_memory_t result;
  3222. const int nhandles = md.get_num_handles();
  3223. std::vector<void *> handles(nhandles, DNNL_MEMORY_ALLOCATE);
  3224. status = dnnl_memory_create_v2(&result, md.get(), aengine.get(),
  3225. (int)handles.size(), handles.data());
  3226. error::wrap_c_api(status, "could not create a memory object");
  3227. reset(result);
  3228. }
  3229. #else
  3230. /// Constructs a memory object.
  3231. ///
  3232. /// Unless @p handle is equal to #DNNL_MEMORY_NONE, the constructed memory
  3233. /// object will have the underlying buffer set. In this case, the buffer
  3234. /// will be initialized as if #dnnl::memory::set_data_handle() had been
  3235. /// called.
  3236. ///
  3237. /// @sa memory::set_data_handle()
  3238. ///
  3239. /// @param md Memory descriptor.
  3240. /// @param aengine Engine to store the data on.
  3241. /// @param handle Handle of the memory buffer to use.
  3242. /// - A pointer to the user-allocated buffer. In this case the library
  3243. /// doesn't own the buffer.
  3244. /// - The #DNNL_MEMORY_ALLOCATE special value. Instructs the library to
  3245. /// allocate the buffer for the memory object. In this case the
  3246. /// library owns the buffer.
  3247. /// - #DNNL_MEMORY_NONE to create dnnl::memory without an underlying
  3248. /// buffer.
  3249. memory(const desc &md, const engine &aengine, void *handle) {
  3250. dnnl_memory_t result;
  3251. error::wrap_c_api(
  3252. dnnl_memory_create(&result, md.get(), aengine.get(), handle),
  3253. "could not create a memory object");
  3254. reset(result);
  3255. }
  3256. /// Constructs a memory object.
  3257. ///
  3258. /// The underlying buffer for the memory will be allocated by the library.
  3259. ///
  3260. /// @param md Memory descriptor.
  3261. /// @param aengine Engine to store the data on.
  3262. memory(const desc &md, const engine &aengine)
  3263. : memory(md, aengine, DNNL_MEMORY_ALLOCATE) {}
  3264. #endif
  3265. /// Returns the associated memory descriptor.
  3266. desc get_desc() const {
  3267. const_dnnl_memory_desc_t cdesc;
  3268. error::wrap_c_api(dnnl_memory_get_memory_desc(get(), &cdesc),
  3269. "could not get a memory descriptor from a memory object");
  3270. dnnl_memory_desc_t cloned_md = nullptr;
  3271. error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
  3272. "could not clone a memory descriptor");
  3273. return desc(cloned_md);
  3274. }
  3275. /// Returns the associated engine.
  3276. engine get_engine() const {
  3277. dnnl_engine_t c_engine;
  3278. error::wrap_c_api(dnnl_memory_get_engine(get(), &c_engine),
  3279. "could not get an engine from a memory object");
  3280. return engine(c_engine, true);
  3281. }
  3282. #ifdef DNNL_EXPERIMENTAL_SPARSE
  3283. /// Returns an underlying memory buffer that corresponds to the given index.
  3284. ///
  3285. /// On the CPU engine, or when using USM, this is a pointer to the
  3286. /// allocated memory.
  3287. void *get_data_handle(int index = 0) const {
  3288. void *handle;
  3289. error::wrap_c_api(dnnl_memory_get_data_handle_v2(get(), &handle, index),
  3290. "could not get a native handle from a memory object");
  3291. return handle;
  3292. }
  3293. /// Sets an underlying memory buffer that corresponds to the given index.
  3294. ///
  3295. /// @param handle Memory buffer to use. On the CPU engine or when USM is
  3296. /// used, the memory buffer is a pointer to the actual data. For OpenCL
  3297. /// it is a cl_mem. It must have at least
  3298. /// #dnnl::memory::desc::get_size() bytes allocated.
  3299. /// @param index Memory index to attach the buffer. Defaults to 0.
  3300. void set_data_handle(void *handle, int index = 0) const {
  3301. error::wrap_c_api(dnnl_memory_set_data_handle_v2(get(), handle, index),
  3302. "could not set native handle of a memory object");
  3303. }
  3304. /// Maps a memory object and returns a host-side pointer to a memory
  3305. /// buffer with a copy of its contents. The memory buffer corresponds to
  3306. /// the given index.
  3307. ///
  3308. /// Mapping enables read/write directly from/to the memory contents for
  3309. /// engines that do not support direct memory access.
  3310. ///
  3311. /// Mapping is an exclusive operation - a memory object cannot be used in
  3312. /// other operations until it is unmapped via #dnnl::memory::unmap_data()
  3313. /// call.
  3314. ///
  3315. /// @note
  3316. /// Any primitives working with the memory should be completed before
  3317. /// the memory is mapped. Use #dnnl::stream::wait() to synchronize the
  3318. /// corresponding execution stream.
  3319. ///
  3320. /// @note
  3321. /// The map_data and unmap_data functions are provided mainly for
  3322. /// debug and testing purposes and their performance may be suboptimal.
  3323. ///
  3324. /// @tparam T Data type to return a pointer to.
  3325. /// @param index Index of the buffer. Defaults to 0.
  3326. /// @returns Pointer to the mapped memory.
  3327. template <typename T = void>
  3328. T *map_data(int index = 0) const {
  3329. void *mapped_ptr;
  3330. error::wrap_c_api(dnnl_memory_map_data_v2(get(), &mapped_ptr, index),
  3331. "could not map memory object data");
  3332. return static_cast<T *>(mapped_ptr);
  3333. }
  3334. /// Unmaps a memory object and writes back any changes made to the
  3335. /// previously mapped memory buffer. The memory buffer corresponds to
  3336. /// the given index.
  3337. ///
  3338. /// @note
  3339. /// The map_data and unmap_data functions are provided mainly for
  3340. /// debug and testing purposes and their performance may be
  3341. /// suboptimal.
  3342. ///
  3343. /// @param mapped_ptr A pointer previously returned by
  3344. /// #dnnl::memory::map_data().
  3345. /// @param index Index of the buffer. Defaults to 0.
  3346. void unmap_data(void *mapped_ptr, int index = 0) const {
  3347. error::wrap_c_api(dnnl_memory_unmap_data_v2(get(), mapped_ptr, index),
  3348. "could not unmap memory object data");
  3349. }
  3350. #else
  3351. /// Returns the underlying memory buffer.
  3352. ///
  3353. /// On the CPU engine, or when using USM, this is a pointer to the
  3354. /// allocated memory.
  3355. void *get_data_handle() const {
  3356. void *handle;
  3357. error::wrap_c_api(dnnl_memory_get_data_handle(get(), &handle),
  3358. "could not get a native handle from a memory object");
  3359. return handle;
  3360. }
  3361. /// Sets the underlying memory buffer.
  3362. ///
  3363. /// @param handle Memory buffer to use. On the CPU engine or when USM is
  3364. /// used, the memory buffer is a pointer to the actual data. For OpenCL
  3365. /// it is a cl_mem. It must have at least
  3366. /// #dnnl::memory::desc::get_size() bytes allocated.
  3367. void set_data_handle(void *handle) const {
  3368. error::wrap_c_api(dnnl_memory_set_data_handle(get(), handle),
  3369. "could not set native handle of a memory object");
  3370. }
  3371. /// Maps a memory object and returns a host-side pointer to a memory
  3372. /// buffer with a copy of its contents.
  3373. ///
  3374. /// Mapping enables read/write directly from/to the memory contents for
  3375. /// engines that do not support direct memory access.
  3376. ///
  3377. /// Mapping is an exclusive operation - a memory object cannot be used in
  3378. /// other operations until it is unmapped via #dnnl::memory::unmap_data()
  3379. /// call.
  3380. ///
  3381. /// @note
  3382. /// Any primitives working with the memory should be completed before
  3383. /// the memory is mapped. Use #dnnl::stream::wait() to synchronize the
  3384. /// corresponding execution stream.
  3385. ///
  3386. /// @note
  3387. /// The map_data and unmap_data functions are provided mainly for
  3388. /// debug and testing purposes and their performance may be suboptimal.
  3389. ///
  3390. /// @tparam T Data type to return a pointer to.
  3391. /// @returns Pointer to the mapped memory.
  3392. template <typename T = void>
  3393. T *map_data() const {
  3394. void *mapped_ptr;
  3395. error::wrap_c_api(dnnl_memory_map_data(get(), &mapped_ptr),
  3396. "could not map memory object data");
  3397. return static_cast<T *>(mapped_ptr);
  3398. }
  3399. /// Unmaps a memory object and writes back any changes made to the
  3400. /// previously mapped memory buffer.
  3401. ///
  3402. /// @note
  3403. /// The map_data and unmap_data functions are provided mainly for
  3404. /// debug and testing purposes and their performance may be
  3405. /// suboptimal.
  3406. ///
  3407. /// @param mapped_ptr A pointer previously returned by
  3408. /// #dnnl::memory::map_data().
  3409. void unmap_data(void *mapped_ptr) const {
  3410. error::wrap_c_api(dnnl_memory_unmap_data(get(), mapped_ptr),
  3411. "could not unmap memory object data");
  3412. }
  3413. #endif
  3414. static dnnl_data_type_t convert_to_c(data_type adata_type) {
  3415. return static_cast<dnnl_data_type_t>(adata_type);
  3416. }
  3417. static dnnl_format_tag_t convert_to_c(format_tag format) {
  3418. return static_cast<dnnl_format_tag_t>(format);
  3419. }
  3420. };
  3421. inline bool operator==(dnnl_data_type_t a, memory::data_type b) {
  3422. return a == memory::convert_to_c(b);
  3423. }
  3424. inline bool operator!=(dnnl_data_type_t a, memory::data_type b) {
  3425. return !(a == b);
  3426. }
  3427. inline bool operator==(memory::data_type a, dnnl_data_type_t b) {
  3428. return b == a;
  3429. }
  3430. inline bool operator!=(memory::data_type a, dnnl_data_type_t b) {
  3431. return !(a == b);
  3432. }
  3433. inline bool operator==(dnnl_format_tag_t a, memory::format_tag b) {
  3434. return a == memory::convert_to_c(b);
  3435. }
  3436. inline bool operator!=(dnnl_format_tag_t a, memory::format_tag b) {
  3437. return !(a == b);
  3438. }
  3439. inline bool operator==(memory::format_tag a, dnnl_format_tag_t b) {
  3440. return b == a;
  3441. }
  3442. inline bool operator!=(memory::format_tag a, dnnl_format_tag_t b) {
  3443. return !(a == b);
  3444. }
  3445. /// @} dnnl_api_memory
  3446. /// @addtogroup dnnl_api_primitives
  3447. /// @{
  3448. /// @addtogroup dnnl_api_attributes Attributes
  3449. ///
  3450. /// A container for parameters that extend primitives behavior.
  3451. ///
  3452. /// @{
  3453. /// @cond DO_NOT_DOCUMENT_THIS
  3454. template <>
  3455. struct handle_traits<dnnl_post_ops_t> {
  3456. static dnnl_status_t destructor(dnnl_post_ops_t p) {
  3457. return dnnl_post_ops_destroy(p);
  3458. }
  3459. };
  3460. /// @endcond
  3461. /// Post-ops.
  3462. ///
  3463. /// Post-ops are computations executed after the main primitive computations
  3464. /// and are attached to the primitive via primitive attributes.
  3465. ///
  3466. /// @sa @ref dev_guide_attributes_post_ops
  3467. ///
  3468. struct post_ops : public handle<dnnl_post_ops_t> {
  3469. using handle<dnnl_post_ops_t>::handle;
  3470. /// Constructs an empty sequence of post-ops.
  3471. post_ops() {
  3472. dnnl_post_ops_t result;
  3473. error::wrap_c_api(
  3474. dnnl_post_ops_create(&result), "could not create post-ops");
  3475. reset(result);
  3476. }
  3477. /// Creates post-ops primitive attribute from a C API ::dnnl_post_ops_t
  3478. /// handle. The resulting handle is not weak and the C handle will be
  3479. /// destroyed during the destruction of the C++ object.
  3480. ///
  3481. /// @param post_ops The C API post-ops primitive attribute.
  3482. post_ops(dnnl_post_ops_t post_ops) : handle<dnnl_post_ops_t>(post_ops) {}
  3483. /// Returns the number of post-ops entries.
  3484. int len() const { return dnnl_post_ops_len(get()); }
  3485. /// Returns the primitive kind of post-op at entry with a certain index.
  3486. /// @param index Index of the post-op to return the kind for.
  3487. /// @returns Primitive kind of the post-op at the specified index.
  3488. primitive::kind kind(int index) const {
  3489. error::wrap_c_api(index < len() ? dnnl_success : dnnl_invalid_arguments,
  3490. "post-ops index is out of range");
  3491. return static_cast<primitive::kind>(
  3492. dnnl_post_ops_get_kind(get(), index));
  3493. }
  3494. /// Appends an accumulation (sum) post-op. Prior to accumulating the
  /// result, the previous value will be reduced by zero point
  3496. /// @p zero_point and multiplied by a scaling factor @p scale.
  3497. ///
  3498. /// The kind of this post-op is #dnnl::primitive::kind::sum.
  3499. ///
  /// This feature may improve performance in cases such as dequantizing the
  /// asymmetrically quantized sum's src1 tensor to the f32 domain before
  /// performing the sum operation, by subtracting @p zero_point before the
  /// scaling.
  3504. ///
  3505. /// In the simplest case when the accumulation is the only post-op,
  3506. /// the computations will be `dst[:] := scale * (dst[:] - zero_point) +
  3507. /// op(...)` instead of `dst[:] := op(...)`.
  3508. ///
  3509. /// If @p data_type is specified, the original dst tensor will be
  3510. /// reinterpreted as a tensor with the provided data type. Because it is a
  3511. /// reinterpretation, data_type and dst data type should have the same size.
  3512. /// As a result, computations will be `dst[:] <- scale *
  3513. /// (as_data_type(dst[:]) - zero_point) + op(...)` instead of
  3514. /// `dst[:] <- op(...)`.
  3515. ///
  3516. /// @note
  3517. /// This post-op executes in-place and does not change the
  3518. /// destination layout.
  3519. ///
  3520. /// @param scale Scaling factor.
  3521. /// @param zero_point Zero point.
  3522. /// @param data_type Data type.
  3523. void append_sum(float scale = 1.f, int32_t zero_point = 0,
  3524. memory::data_type data_type = memory::data_type::undef) {
  3525. error::wrap_c_api(dnnl_post_ops_append_sum(get(), scale, zero_point,
  3526. memory::convert_to_c(data_type)),
  3527. "could not append a sum post-op");
  3528. }
  3529. /// Returns the parameters of an accumulation (sum) post-op.
  3530. ///
  3531. /// @param index Index of the sum post-op.
  3532. /// @param scale Scaling factor of the sum post-op.
  3533. void get_params_sum(int index, float &scale) const {
  3534. error::wrap_c_api(dnnl_post_ops_get_params_sum(
  3535. get(), index, &scale, nullptr, nullptr),
  3536. "could not get parameters of a sum post-op");
  3537. }
  3538. /// Returns the parameters of an accumulation (sum) post-op.
  3539. ///
  3540. /// @param index Index of the sum post-op.
  3541. /// @param scale Scaling factor of the sum post-op.
  3542. /// @param data_type Data type of the sum post-op.
  3543. void get_params_sum(
  3544. int index, float &scale, memory::data_type &data_type) const {
  3545. dnnl_data_type_t c_data_type;
  3546. error::wrap_c_api(dnnl_post_ops_get_params_sum(
  3547. get(), index, &scale, nullptr, &c_data_type),
  3548. "could not get parameters of a sum post-op");
  3549. data_type = static_cast<memory::data_type>(c_data_type);
  3550. }
  3551. /// Returns the parameters of an accumulation (sum) post-op.
  3552. ///
  3553. /// @param index Index of the sum post-op.
  3554. /// @param scale Scaling factor of the sum post-op.
  3555. /// @param zero_point Single scalar int32_t value of zeropoint.
  3556. /// @param data_type Data type of the sum post-op.
  3557. void get_params_sum(int index, float &scale, int32_t &zero_point,
  3558. memory::data_type &data_type) const {
  3559. dnnl_data_type_t c_data_type;
  3560. error::wrap_c_api(dnnl_post_ops_get_params_sum(get(), index, &scale,
  3561. &zero_point, &c_data_type),
  3562. "could not get parameters of a sum post-op");
  3563. data_type = static_cast<memory::data_type>(c_data_type);
  3564. }
  3565. /// Appends an elementwise post-op.
  3566. ///
  3567. /// The kind of this post-op is #dnnl::primitive::kind::eltwise.
  3568. ///
  3569. /// In the simplest case when the elementwise is the only post-op, the
  3570. /// computations would be `dst[:] := eltwise_op (op(...))` instead
  3571. /// of `dst[:] <- op(...)`, where eltwise_op is configured with the given
  3572. /// parameters.
  3573. ///
  3574. /// @param aalgorithm Elementwise algorithm.
  3575. /// @param alpha Alpha parameter for the elementwise algorithm.
  3576. /// @param beta Beta parameter for the elementwise algorithm.
  3577. void append_eltwise(algorithm aalgorithm, float alpha, float beta) {
  3578. error::wrap_c_api(dnnl_post_ops_append_eltwise(
  3579. get(), convert_to_c(aalgorithm), alpha, beta),
  3580. "could not append an elementwise post-op");
  3581. }
  3582. /// Returns parameters of an elementwise post-op.
  3583. ///
  3584. /// @param index Index of the post-op.
  3585. /// @param aalgorithm Output elementwise algorithm kind.
  3586. /// @param alpha Output alpha parameter for the elementwise algorithm.
  3587. /// @param beta Output beta parameter for the elementwise algorithm.
  3588. void get_params_eltwise(
  3589. int index, algorithm &aalgorithm, float &alpha, float &beta) const {
  3590. dnnl_alg_kind_t c_alg;
  3591. error::wrap_c_api(dnnl_post_ops_get_params_eltwise(
  3592. get(), index, &c_alg, &alpha, &beta),
  3593. "could not get parameters of an elementwise post-op");
  3594. aalgorithm = static_cast<dnnl::algorithm>(c_alg);
  3595. }
  3596. /// Appends a depthwise post-op convolution.
  3597. ///
  3598. /// This post-op can only be fused with a 2D 1x1 convolution (convolution
  3599. /// with weights spatial dimension equal to 1 i.e., kh=kw=1).
  3600. ///
  3601. /// The kind of this post-op is #dnnl_convolution.
  3602. ///
  3603. /// The number of outputs for primitive remain same as before. The output
  3604. /// spatial size can be derived as below:
  3605. ///
  3606. /// output_height = ceil(output_height_1x1_convolution, stride)
  3607. /// output_width = ceil(output_width_1x1_convolution, stride)
  3608. ///
  3609. /// See @ref dev_guide_attributes_post_ops_depthwise and
  3610. /// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
  3611. ///
  3612. /// @param weights_data_type Weights data type of depthwise post-op
  3613. /// @param bias_data_type Bias data type of depthwise post-op
  3614. /// @param dst_data_type Output data type of depthwise post-op
  3615. /// @param kernel_size Size of kernel of depthwise post-op
  3616. /// @param stride_size Size of stride of depthwise post-op
  3617. /// @param padding_l_size Size of left and top paddings of depthwise post-op
  3618. void append_dw(memory::data_type weights_data_type,
  3619. memory::data_type bias_data_type, memory::data_type dst_data_type,
  3620. memory::dim kernel_size, memory::dim stride_size,
  3621. memory::dim padding_l_size) {
  3622. error::wrap_c_api(dnnl_post_ops_append_dw(get(),
  3623. memory::convert_to_c(weights_data_type),
  3624. memory::convert_to_c(bias_data_type),
  3625. memory::convert_to_c(dst_data_type),
  3626. kernel_size, stride_size, padding_l_size),
  3627. "could not append depthwise post-op");
  3628. }
  /// Returns the parameters of a depthwise post-op.
  ///
  /// @param index Index of the depthwise post-op.
  3632. /// @param weights_data_type Weights data type of depthwise post-op
  3633. /// @param bias_data_type Bias data type of depthwise post-op
  3634. /// @param dst_data_type Output data type of depthwise post-op
  3635. /// @param kernel_size Size of kernel of depthwise post-op
  3636. /// @param stride_size Size of stride of depthwise post-op
  3637. /// @param padding_l_size Size of left and top paddings of depthwise post-op
  3638. void get_params_dw(int index, memory::data_type &weights_data_type,
  3639. memory::data_type &bias_data_type, memory::data_type &dst_data_type,
  3640. memory::dim &kernel_size, memory::dim &stride_size,
  3641. memory::dim &padding_l_size) const {
  3642. dnnl_data_type_t c_weights_data_type;
  3643. dnnl_data_type_t c_bias_data_type;
  3644. dnnl_data_type_t c_dst_data_type;
  3645. dnnl_dim_t c_kernel_size;
  3646. dnnl_dim_t c_stride_size;
  3647. dnnl_dim_t c_padding_l_size;
  3648. error::wrap_c_api(
  3649. dnnl_post_ops_get_params_dw(get(), index, &c_weights_data_type,
  3650. &c_bias_data_type, &c_dst_data_type, &c_kernel_size,
  3651. &c_stride_size, &c_padding_l_size),
  3652. "could not get parameters of depthwise post-op");
  3653. weights_data_type = static_cast<memory::data_type>(c_weights_data_type);
  3654. bias_data_type = static_cast<memory::data_type>(c_bias_data_type);
  3655. dst_data_type = static_cast<memory::data_type>(c_dst_data_type);
  3656. kernel_size = c_kernel_size;
  3657. stride_size = c_stride_size;
  3658. padding_l_size = c_padding_l_size;
  3659. }
  3660. /// Appends a binary post-op.
  3661. ///
  3662. /// The kind of this post operation is #dnnl_binary.
  3663. ///
  3664. /// In the simplest case when the binary is the only post operation, the
  3665. /// computations would be:
  3666. ///
  3667. /// dst[:] <- binary_op (dst[:], another_input[:])
  3668. ///
  3669. /// where binary_op is configured with the given parameters. binary_op
  3670. /// supports broadcast semantics for a second operand.
  3671. ///
  3672. /// @param aalgorithm Binary algorithm for the post-op.
  3673. /// @param src1_desc Memory descriptor of a second operand.
  3674. void append_binary(algorithm aalgorithm, const memory::desc &src1_desc) {
  3675. error::wrap_c_api(dnnl_post_ops_append_binary(get(),
  3676. convert_to_c(aalgorithm), src1_desc.get()),
  3677. "could not append a binary post-op");
  3678. }
  3679. /// Returns the parameters of a binary post-op.
  3680. ///
  3681. /// @param index Index of the binary post-op.
  3682. /// @param aalgorithm Output binary algorithm kind.
  3683. /// @param src1_desc Output memory descriptor of a second operand.
  3684. void get_params_binary(
  3685. int index, algorithm &aalgorithm, memory::desc &src1_desc) const {
  3686. dnnl_alg_kind_t c_alg;
  3687. const_dnnl_memory_desc_t cdesc;
  3688. error::wrap_c_api(
  3689. dnnl_post_ops_get_params_binary(get(), index, &c_alg, &cdesc),
  3690. "could not get parameters of a binary post-op");
  3691. aalgorithm = static_cast<dnnl::algorithm>(c_alg);
  3692. dnnl_memory_desc_t cloned_md = nullptr;
  3693. error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
  3694. "could not clone a memory descriptor");
  3695. src1_desc = memory::desc(cloned_md);
  3696. }
  3697. /// Appends a prelu forward post-op.
  3698. ///
  3699. /// The kind of this post-op is #dnnl::primitive::kind::prelu.
  3700. ///
  3701. /// The post-op can be defined as:
  3702. ///
  3703. /// dst[:] <- prelu(dst[:], weights[:])
  3704. /// prelu:
  3705. /// dst[:] <- dst[:] if dst[:] > 0
  3706. /// dst[:] <- dst[:] * weights[:] if dst[:] <= 0
  3707. ///
  3708. ///
  3709. /// Example usage:
  3710. /// @code
  3711. /// int mb = 32, oc = 32,
  3712. /// oh = 14, ow = 14; // convolution output params
  3713. /// // unique weights per output channel
  3714. /// vector<float> weights = { ... };
  3715. /// int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ...
  3716. ///
  3717. /// // construct a convolution descriptor
  3718. /// dnnl::convolution::desc conv_d;
  3719. ///
  3720. /// dnnl::primitive_attr attr;
  3721. /// attr.append_prelu(1 << oc_dim);
  3722. ///
  3723. /// dnnl::primitive_desc conv_pd(conv_d, attr, engine);
  3724. /// memory prelu_weights({{1}, dt::f32, {1}}, eng, weights.data());
  3725. ///
  3726. /// std::unordered_map<int, memory> conv_args;
  3727. ///
  3728. /// conv_args.insert(
  3729. /// {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_WEIGHTS, prelu_weights})
  3730. /// @endcode
  3731. ///
  3732. /// @note
  3733. /// The order of dimensions does not depend on how elements are laid
  3734. /// out in memory. For example:
  3735. /// - for a 2D CNN activations tensor the order is always (n, c)
  3736. /// - for a 4D CNN activations tensor the order is always (n, c, h, w)
  3737. /// - for a 5D CNN weights tensor the order is always
  3738. /// (g, oc, ic, kh, kw)
  3739. ///
  3740. /// Prelu weights tensor is passed in runtime execution phase. Prelu
  3741. /// weights tensor data type is implicitly assumed as f32 using plain
  3742. /// layout (a, ab, acb, acdb, acdeb).
  3743. ///
  3744. /// @param mask Defines the correspondence between the output tensor
  3745. /// dimensions and the prelu weights tensor. The set i-th bit indicates
  3746. /// that a dedicated weights value is used for each index along that
  3747. /// dimension. Set the mask to 0 to use a common weights value
  3748. /// for the whole output tensor.
  3749. void append_prelu(int mask) {
  3750. error::wrap_c_api(dnnl_post_ops_append_prelu(get(), mask),
  3751. "could not append a prelu post-op");
  3752. }
  3753. /// Returns the parameters of a prelu post-op.
  3754. ///
  3755. /// @param index Index of the prelu post-op.
  3756. /// @param mask Weights mask of prelu post-op.
  3757. void get_params_prelu(int index, int &mask) const {
  3758. error::wrap_c_api(dnnl_post_ops_get_params_prelu(get(), index, &mask),
  3759. "could not get parameters of a binary post-op");
  3760. }
  3761. };
  3762. /// @cond DO_NOT_DOCUMENT_THIS
  3763. template <>
  3764. struct handle_traits<dnnl_primitive_attr_t> {
  3765. static dnnl_status_t destructor(dnnl_primitive_attr_t p) {
  3766. return dnnl_primitive_attr_destroy(p);
  3767. }
  3768. };
  3769. /// @endcond
  3770. /// Primitive attributes.
  3771. ///
  3772. /// @sa @ref dev_guide_attributes
  3773. struct primitive_attr : public handle<dnnl_primitive_attr_t> {
  3774. using handle<dnnl_primitive_attr_t>::handle;
  3775. /// Constructs default (empty) primitive attributes.
  3776. primitive_attr() {
  3777. dnnl_primitive_attr_t result;
  3778. error::wrap_c_api(dnnl_primitive_attr_create(&result),
  3779. "could not create primitive attribute");
  3780. reset(result);
  3781. }
  3782. /// Creates primitive attributes from a C API ::dnnl_primitive_attr_t
  3783. /// handle. The resulting handle is not weak and the C handle will be
  3784. /// destroyed during the destruction of the C++ object.
  3785. ///
  3786. /// @param attr The C API primitive attributes.
  3787. primitive_attr(dnnl_primitive_attr_t attr)
  3788. : handle<dnnl_primitive_attr_t>(attr) {}
  3789. /// Returns the parameters of a dropout attribute.
  3790. ///
  3791. /// @param mask_desc Output memory descriptor of a dropout mask.
  3792. void get_dropout(memory::desc &mask_desc) const {
  3793. const_dnnl_memory_desc_t cdesc;
  3794. error::wrap_c_api(dnnl_primitive_attr_get_dropout(get(), &cdesc),
  3795. "could not get parameters of a dropout attribute");
  3796. dnnl_memory_desc_t cloned_md = nullptr;
  3797. error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
  3798. "could not clone a memory descriptor");
  3799. mask_desc = memory::desc(cloned_md);
  3800. }
  /// Sets the dropout primitive attribute.
  ///
  /// @param mask_desc Memory descriptor of a dropout mask.
  3804. void set_dropout(const memory::desc &mask_desc) {
  3805. error::wrap_c_api(
  3806. dnnl_primitive_attr_set_dropout(get(), mask_desc.get()),
  3807. "could not set dropout primitive attribute");
  3808. }
  3809. /// Returns the fpmath mode
  3810. fpmath_mode get_fpmath_mode() const {
  3811. dnnl_fpmath_mode_t result;
  3812. error::wrap_c_api(dnnl_primitive_attr_get_fpmath_mode(get(), &result),
  3813. "could not get fpmath mode primitive attribute");
  3814. return fpmath_mode(result);
  3815. }
  3816. /// Returns the fpmath mode
  3817. ///
  3818. /// @param mode Specified fpmath mode.
  3819. /// @param apply_to_int Use floating-point arithmetic for integer primitives.
  3820. void get_fpmath_mode(fpmath_mode &mode, bool &apply_to_int) const {
  3821. dnnl_fpmath_mode_t c_mode;
  3822. int c_apply_to_int;
  3823. error::wrap_c_api(dnnl_primitive_attr_get_fpmath_mode_v2(
  3824. get(), &c_mode, &c_apply_to_int),
  3825. "could not get fpmath mode primitive attribute");
  3826. mode = fpmath_mode(c_mode);
  3827. apply_to_int = static_cast<bool>(c_apply_to_int);
  3828. }
  3829. /// Sets fpmath mode.
  3830. ///
  3831. /// @param mode Specified fpmath mode.
  3832. /// @param apply_to_int Boolean. Use of floating-point arithmetic for integer primitives.
  3833. void set_fpmath_mode(fpmath_mode mode, bool apply_to_int = false) {
  3834. error::wrap_c_api(dnnl_primitive_attr_set_fpmath_mode_v2(get(),
  3835. dnnl::convert_to_c(mode), apply_to_int),
  3836. "could not set fpmath mode primitive attribute");
  3837. }
  3838. /// Returns the accumulation mode
  3839. accumulation_mode get_accumulation_mode() const {
  3840. dnnl_accumulation_mode_t result;
  3841. error::wrap_c_api(
  3842. dnnl_primitive_attr_get_accumulation_mode(get(), &result),
  3843. "could not get accumulation mode primitive attribute");
  3844. return accumulation_mode(result);
  3845. }
  3846. /// Sets accumulation mode.
  3847. ///
  3848. /// @param mode Specified accumulation mode.
  3849. void set_accumulation_mode(accumulation_mode mode) {
  3850. error::wrap_c_api(dnnl_primitive_attr_set_accumulation_mode(
  3851. get(), dnnl::convert_to_c(mode)),
  3852. "could not set accumulation mode primitive attribute");
  3853. }
  3854. /// Returns the deterministic attribute value
  3855. bool get_deterministic() const {
  3856. int result;
  3857. error::wrap_c_api(dnnl_primitive_attr_get_deterministic(get(), &result),
  3858. "could not get deterministic primitive attribute");
  3859. return static_cast<bool>(result);
  3860. }
  3861. /// Sets deterministic attribute value
  3862. ///
  3863. /// @param value Specified deterministic mode.
  3864. void set_deterministic(bool value) {
  3865. error::wrap_c_api(dnnl_primitive_attr_set_deterministic(
  3866. get(), static_cast<int>(value)),
  3867. "could not set deterministic primitive attribute");
  3868. }
  3869. /// Returns the rounding mode attribute value
  3870. ///
  3871. /// @param arg Argument for which rounding mode query applies.
  3872. /// @returns The rounding mode applied to the specified argument.
  3873. rounding_mode get_rounding_mode(int arg) const {
  3874. dnnl_rounding_mode_t result;
  3875. error::wrap_c_api(dnnl_primitive_attr_get_rounding(get(), arg, &result),
  3876. "could not get rounding mode primitive attribute");
  3877. return rounding_mode(result);
  3878. }
  3879. /// Sets the rounding mode attribute value for a given argument
  3880. ///
  3881. /// @param arg Argument for which to set rounding mode.
  3882. /// @param mode Rounding mode to apply.
  3883. void set_rounding_mode(int arg, rounding_mode mode) {
  3884. error::wrap_c_api(dnnl_primitive_attr_set_rounding(
  3885. get(), arg, convert_to_c(mode)),
  3886. "could not set rounding mode primitive attribute");
  3887. }
  3888. /// Returns the scratchpad mode.
  3889. scratchpad_mode get_scratchpad_mode() const {
  3890. dnnl_scratchpad_mode_t result;
  3891. error::wrap_c_api(
  3892. dnnl_primitive_attr_get_scratchpad_mode(get(), &result),
  3893. "could not get scratchpad mode primitive attribute");
  3894. return scratchpad_mode(result);
  3895. }
  3896. /// Sets scratchpad mode.
  3897. ///
  3898. /// @param mode Specified scratchpad mode.
  3899. void set_scratchpad_mode(scratchpad_mode mode) {
  3900. error::wrap_c_api(dnnl_primitive_attr_set_scratchpad_mode(
  3901. get(), dnnl::convert_to_c(mode)),
  3902. "could not set scratchpad mode primitive attribute");
  3903. }
  3904. /// Sets scaling factors for primitive operations for a given memory
  3905. /// argument. The scaling factors must be passed at execution time
  3906. /// as an argument with index #DNNL_ARG_ATTR_SCALES | arg.
  3907. ///
  3908. /// @sa dnnl_primitive_attr_set_scales_mask
  3909. ///
  3910. /// @param arg Parameter argument index as passed to the
  3911. /// primitive::execute() call.
  3912. /// @param mask Scaling factors correspondence mask that defines the
  3913. /// correspondence between the tensor dimensions and the @p scales
  3914. /// vector. The set i-th bit indicates that a dedicated scaling factor
  3915. /// is used for each index along that dimension. Set the mask to 0 to
  3916. /// use a common scaling factor for the whole output tensor.
  3917. void set_scales_mask(int arg, int mask) {
  3918. error::wrap_c_api(dnnl_primitive_attr_set_scales_mask(get(), arg, mask),
  3919. "could not set scales primitive attribute");
  3920. }
  3921. /// Sets scaling factors for primitive operations for a given memory
  3922. /// argument. The scaling factors must be passed at execution time
  3923. /// as an argument with index #DNNL_ARG_ATTR_SCALES | arg.
  3924. ///
  3925. /// @sa dnnl_primitive_attr_set_scales
  3926. ///
  3927. /// @param arg Parameter argument index as passed to the
  3928. /// primitive::execute() call.
  3929. /// @param mask Scales correspondence mask that defines the
  3930. /// correspondence between the tensor dimensions and the @p
  3931. /// scales vector. The set i-th bit indicates that a dedicated
  3932. /// scale is used for each index along that dimension. Set the
  3933. /// mask to 0 to use a common scale for the whole output tensor.
  3934. /// @param groups Scaling factors correspondence groups that define the
  3935. /// correspondence between the tensor dimensions and the scales array.
  3936. /// The set i-th dimension indicates a number of groups of scaling
  3937. /// factors used for that logical dimension in a memory indicated by @p arg.
  3938. /// @param data_type Scaling factors data_type.
  3939. void set_scales(int arg, int mask, const memory::dims &groups,
  3940. memory::data_type data_type = memory::data_type::f32) {
  3941. error::wrap_c_api(dnnl_primitive_attr_set_scales(get(), arg, mask,
  3942. (int)groups.size(), groups.data(),
  3943. memory::convert_to_c(data_type)),
  3944. "could not set scales primitive attribute");
  3945. }
  3946. /// Sets zero points for primitive operations for a given memory argument.
  3947. /// The zero points must be passed at execution time as an argument with
  3948. /// index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
  3949. ///
  3950. /// @sa dnnl_primitive_attr_set_zero_points_mask
  3951. ///
  3952. /// @param arg Parameter argument index as passed to the
  3953. /// primitive::execute() call.
  3954. /// @param mask Zero point correspondence mask that defines the
  3955. /// correspondence between the tensor dimensions and the @p
  3956. /// zero_points vector. The set i-th bit indicates that a dedicated
  3957. /// zero point is used for each index along that dimension. Set the
  3958. /// mask to 0 to use a common zero point for the whole output tensor.
  3959. void set_zero_points_mask(int arg, int mask) {
  3960. error::wrap_c_api(
  3961. dnnl_primitive_attr_set_zero_points_mask(get(), arg, mask),
  3962. "could not set zero points primitive attribute");
  3963. }
  3964. /// Sets zero points for primitive operations for a given memory argument.
  3965. /// The zero points must be passed at execution time as an argument with
  3966. /// index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
  3967. ///
  3968. /// @sa dnnl_primitive_attr_set_zero_points
  3969. ///
  3970. /// @param arg Parameter argument index as passed to the
  3971. /// primitive::execute() call.
  3972. /// @param mask Zero point correspondence mask that defines the
  3973. /// correspondence between the tensor dimensions and the @p
  3974. /// zero_points vector. The set i-th bit indicates that a dedicated
  3975. /// zero point is used for each index along that dimension. Set the
  3976. /// mask to 0 to use a common zero point for the whole output tensor.
  3977. /// @param groups Zero point factors correspondence groups that define the
  3978. /// correspondence between the tensor dimensions and the zero_points array.
  3979. /// The set i-th dimension indicates a number of groups of zero point
  3980. /// factors used for that logical dimension in a memory indicated by @p arg.
  3981. /// @param data_type Zero point factors data_type.
  3982. void set_zero_points(int arg, int mask, const memory::dims &groups,
  3983. memory::data_type data_type = memory::data_type::s32) {
  3984. error::wrap_c_api(dnnl_primitive_attr_set_zero_points(get(), arg, mask,
  3985. (int)groups.size(), groups.data(),
  3986. memory::convert_to_c(data_type)),
  3987. "could not set zero points primitive attribute");
  3988. }
  3989. /// Returns post-ops previously set via set_post_ops().
  3990. ///
  3991. /// @returns Post-ops.
  3992. const post_ops get_post_ops() const {
  3993. const_dnnl_post_ops_t const_c_post_ops;
  3994. error::wrap_c_api(
  3995. dnnl_primitive_attr_get_post_ops(get(), &const_c_post_ops),
  3996. "could not get post-ops primitive attribute");
  3997. dnnl_post_ops_t c_post_ops;
  3998. error::wrap_c_api(dnnl_post_ops_clone(&c_post_ops, const_c_post_ops),
  3999. "could not clone post-ops primitive attribute");
  4000. return post_ops(c_post_ops);
  4001. }
  4002. /// Sets post-ops.
  4003. ///
  4004. /// @note
  4005. /// There is no way to check whether the post-ops would be supported
  4006. /// by the target primitive. Any error will be reported
  4007. /// by the respective primitive descriptor constructor.
  4008. ///
  4009. /// @param ops Post-ops object to copy post-ops from.
  4010. void set_post_ops(const post_ops ops) {
  4011. error::wrap_c_api(dnnl_primitive_attr_set_post_ops(get(), ops.get()),
  4012. "could not set post-ops primitive attribute");
  4013. }
  4014. /// Sets quantization scale and shift parameters for RNN data tensors.
  4015. ///
  4016. /// For performance reasons, the low-precision configuration of the RNN
  4017. /// primitives expect input activations to have the unsigned 8-bit integer
  4018. /// data type. The scale and shift parameters are used to quantize
  4019. /// floating-point data to unsigned integer and must be passed to the RNN
  4020. /// primitive using attributes.
  4021. ///
  4022. /// The quantization formula is `scale * data + shift`.
  4023. ///
  4024. /// Example usage:
  4025. /// @code
  4026. /// // RNN parameters
  4027. /// int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
  4028. /// // Activations quantization parameters
  4029. /// float scale = 63.f, shift = 64.f;
  4030. ///
  4031. /// primitive_attr attr;
  4032. ///
  4033. /// // Set scale and shift for int8 quantization of activation
  4034. /// attr.set_rnn_data_qparams(scale, shift);
  4035. ///
  4036. /// // Create an RNN primitive descriptor.
  4037. /// vanilla_rnn_forward::primitive_desc rnn_d(
  4038. /// engine, /* arguments */, attr);
  4039. /// @endcode
  4040. ///
  4041. /// @note
  4042. /// Quantization scale and shift are common for src_layer, src_iter,
  4043. /// dst_iter, and dst_layer.
  4044. ///
  4045. /// @param scale The value to scale the data by.
  4046. /// @param shift The value to shift the data by.
  4047. void set_rnn_data_qparams(float scale, float shift) {
  4048. error::wrap_c_api(
  4049. dnnl_primitive_attr_set_rnn_data_qparams(get(), scale, shift),
  4050. "could not set RNN data quantization parameters primitive "
  4051. "attribute");
  4052. }
  4053. /// Returns the quantization scale and shift parameters for RNN data
  4054. /// tensors.
  4055. ///
  4056. /// @note
  4057. /// Quantization scale and shift are common for src_layer, src_iter,
  4058. /// dst_iter, and dst_layer.
  4059. ///
  4060. /// @param scale The value to scale the data by.
  4061. /// @param shift The value to shift the data by.
  4062. void get_rnn_data_qparams(float &scale, float &shift) {
  4063. float c_scale, c_shift;
  4064. error::wrap_c_api(dnnl_primitive_attr_get_rnn_data_qparams(
  4065. get(), &c_scale, &c_shift),
  4066. "could not set RNN data quantization parameters primitive "
  4067. "attribute");
  4068. scale = c_scale;
  4069. shift = c_shift;
  4070. }
  4071. /// Sets quantization scaling factors for RNN weights tensors. The
  4072. /// low-precision configuration of the RNN primitives expect input weights
  4073. /// to use the signed 8-bit integer data type. The scaling factors are
  4074. /// used to quantize floating-point data to signed integer and must be
  4075. /// passed to RNN primitives using attributes.
  4076. ///
  4077. /// @note
  4078. /// The dimension order is always native and does not depend on the
  4079. /// actual layout used. For example, five-dimensional weights always
  4080. /// have (l, d, i, g, o) logical dimension ordering.
  4081. ///
  4082. /// @note
  4083. /// Quantization scales are common for weights_layer and
  4084. /// weights_iteration
  4085. ///
  4086. /// @param mask Scaling factors correspondence mask that defines the
  4087. /// correspondence between the output tensor dimensions and the @p
  4088. /// scales vector. The set i-th bit indicates that a dedicated scaling
  4089. /// factor should be used each index along that dimension. Set the
  4090. /// mask to 0 to use a common scaling factor for the whole output
  4091. /// tensor.
  4092. /// @param scales Constant vector of output scaling factors. The following
  4093. /// equality must hold:
  4094. /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$
  4095. /// Violations can only be detected when the attributes are used to
  4096. /// create a primitive descriptor.
  4097. void set_rnn_weights_qparams(int mask, const std::vector<float> &scales) {
  4098. error::wrap_c_api(dnnl_primitive_attr_set_rnn_weights_qparams(get(),
  4099. (int)scales.size(), mask, scales.data()),
  4100. "could not set RNN weights quantization parameters primitive "
  4101. "attribute");
  4102. }
  4103. /// Returns the quantization scaling factors for RNN projection weights
  4104. /// tensors.
  4105. ///
  4106. /// @note
  4107. /// The dimension order is always native and does not depend on the
  4108. /// actual layout used. For example, five-dimensional weights always
  4109. /// have (l, d, i, g, o) logical dimension ordering.
  4110. ///
  4111. /// @param mask Scaling factors correspondence mask that defines the
  4112. /// correspondence between the output tensor dimensions and the @p
  4113. /// scales vector. The set i-th bit indicates that a dedicated scaling
  4114. /// factor should be used each index along that dimension. Set the
  4115. /// mask to 0 to use a common scaling factor for the whole output
  4116. /// tensor.
  4117. /// @param scales Constant vector of output scaling factors. The following
  4118. /// equality must hold:
  4119. /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$
  4120. /// Violations can only be detected when the attributes are used to
  4121. /// create a primitive descriptor.
  4122. void get_rnn_weights_qparams(int &mask, std::vector<float> &scales) {
  4123. dnnl_dim_t count;
  4124. int c_mask;
  4125. const float *c_scales;
  4126. error::wrap_c_api(dnnl_primitive_attr_get_rnn_weights_qparams(
  4127. get(), &count, &c_mask, &c_scales),
  4128. "could not get primitive RNN weights quantization "
  4129. "parameters attributes");
  4130. scales.resize(count);
  4131. mask = c_mask;
  4132. for (dnnl_dim_t c = 0; c < count; c++)
  4133. scales[c] = c_scales[c];
  4134. }
  4135. /// Sets quantization scaling factors for RNN projection weights tensors.
  4136. // The low-precision configuration of the RNN primitives expect input
  4137. // weights to use the signed 8-bit integer data type. The scaling factors
  4138. // are used to quantize floating-point data to signed integer and must be
  4139. /// passed to RNN primitives using attributes.
  4140. ///
  4141. /// @note
  4142. /// The dimension order is always native and does not depend on the
  4143. /// actual layout used. For example, five-dimensional weights always
  4144. /// have (l, d, i, g, o) logical dimension ordering.
  4145. ///
  4146. /// @note
  4147. /// Quantization scales are common for weights_layer and
  4148. /// weights_iteration
  4149. ///
  4150. /// @param mask Scaling factors correspondence mask that defines the
  4151. /// correspondence between the output tensor dimensions and the @p
  4152. /// scales vector. The set i-th bit indicates that a dedicated scaling
  4153. /// factor should be used each index along that dimension. Set the
  4154. /// mask to 0 to use a common scaling factor for the whole output
  4155. /// tensor.
  4156. /// @param scales Constant vector of output scaling factors. The following
  4157. /// equality must hold:
  4158. /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$
  4159. /// Violations can only be detected when the attributes are used to
  4160. /// create a primitive descriptor.
  4161. void set_rnn_weights_projection_qparams(
  4162. int mask, const std::vector<float> &scales) {
  4163. error::wrap_c_api(
  4164. dnnl_primitive_attr_set_rnn_weights_projection_qparams(
  4165. get(), (int)scales.size(), mask, scales.data()),
  4166. "could not set primitive RNN weights projection quantization "
  4167. "parameters attributes");
  4168. }
  4169. /// Returns the quantization scaling factors for RNN projection weights
  4170. /// tensors.
  4171. ///
  4172. /// @note
  4173. /// The dimension order is always native and does not depend on the
  4174. /// actual layout used. For example, five-dimensional weights always
  4175. /// have (l, d, i, g, o) logical dimension ordering.
  4176. ///
  4177. /// @param mask Scaling factors correspondence mask that defines the
  4178. /// correspondence between the output tensor dimensions and the @p
  4179. /// scales vector. The set i-th bit indicates that a dedicated scaling
  4180. /// factor should be used each index along that dimension. Set the
  4181. /// mask to 0 to use a common scaling factor for the whole output
  4182. /// tensor.
  4183. /// @param scales Constant vector of output scaling factors. The following
  4184. /// equality must hold:
  4185. /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$
  4186. /// Violations can only be detected when the attributes are used to
  4187. /// create a primitive descriptor.
  4188. void get_rnn_weights_projection_qparams(
  4189. int &mask, std::vector<float> &scales) {
  4190. dnnl_dim_t count;
  4191. int c_mask;
  4192. const float *c_scales;
  4193. error::wrap_c_api(
  4194. dnnl_primitive_attr_get_rnn_weights_projection_qparams(
  4195. get(), &count, &c_mask, &c_scales),
  4196. "could not get primitive RNN weights projection quantization "
  4197. "parameters attributes");
  4198. scales.resize(count);
  4199. mask = c_mask;
  4200. for (dnnl_dim_t c = 0; c < count; c++)
  4201. scales[c] = c_scales[c];
  4202. }
  4203. };
  4204. /// @} dnnl_api_attributes
  4205. /// @addtogroup dnnl_api_primitives_common
  4206. /// @{
  4207. /// Base class for all primitive descriptors.
  4208. struct primitive_desc_base : public handle<dnnl_primitive_desc_t> {
  4209. using handle<dnnl_primitive_desc_t>::handle;
  4210. /// Default constructor. Produces an empty object.
  4211. primitive_desc_base() = default;
  4212. /// Returns the engine of the primitive descriptor.
  4213. /// @returns The engine of the primitive descriptor.
  4214. engine get_engine() const { return query_engine(query::engine); }
  4215. /// Returns implementation name.
  4216. /// @returns The implementation name.
  4217. const char *impl_info_str() const {
  4218. const char *res;
  4219. error::wrap_c_api(dnnl_primitive_desc_query(
  4220. get(), dnnl_query_impl_info_str, 0, &res),
  4221. "could not retrieve implementation info string from a "
  4222. "primitive descriptor");
  4223. return res;
  4224. }
  4225. /// Returns a memory::dim value (same as int64_t).
  4226. /// @param what The value to query.
  4227. /// @returns The result of the query.
  4228. memory::dim query_s64(query what) const {
  4229. memory::dim res;
  4230. dnnl_status_t status = dnnl_primitive_desc_query(
  4231. get(), dnnl::convert_to_c(what), 0, &res);
  4232. return status == dnnl_success ? res : 0;
  4233. }
  4234. /// Returns strides.
  4235. /// @returns Strides.
  4236. /// @returns An empty #dnnl::memory::dims if the primitive does not have
  4237. /// a strides parameter.
  4238. memory::dims get_strides() const { return query_dims(query::strides); }
  4239. /// Returns dilations.
  4240. /// @returns Dilations.
  4241. /// @returns An empty #dnnl::memory::dims if the primitive does not have
  4242. /// a dilations parameter.
  4243. memory::dims get_dilations() const { return query_dims(query::dilations); }
  4244. /// Returns a left padding.
  4245. /// @returns A left padding.
  4246. /// @returns An empty #dnnl::memory::dims if the primitive does not have
  4247. /// a left padding parameter.
  4248. memory::dims get_padding_l() const { return query_dims(query::padding_l); }
  4249. /// Returns a right padding.
  4250. /// @returns A right padding.
  4251. /// @returns An empty #dnnl::memory::dims if the primitive does not have
  4252. /// a right padding parameter.
  4253. memory::dims get_padding_r() const { return query_dims(query::padding_r); }
  4254. /// Returns an epsilon.
  4255. /// @returns An epsilon.
  4256. /// @returns Zero if the primitive does not have an epsilon parameter.
  4257. float get_epsilon() const { return query_f32(query::epsilon_f32); }
  4258. /// Returns flags.
  4259. /// @tparam T Flags enumeration type.
  4260. /// @returns Flags.
  4261. /// @returns Zero if the primitive does not have a flags parameter.
  4262. template <typename T = unsigned>
  4263. T get_flags() const {
  4264. unsigned res;
  4265. dnnl_status_t status
  4266. = dnnl_primitive_desc_query(get(), dnnl_query_flags, 0, &res);
  4267. return static_cast<T>(status == dnnl_success ? res : 0x0U);
  4268. }
  4269. /// Returns an algorithm kind.
  4270. /// @returns An algorithm kind.
  4271. /// @returns #dnnl::algorithm::undef if the primitive does not have an
  4272. /// algorithm parameter.
  4273. dnnl::algorithm get_algorithm() const { return query_alg(query::alg_kind); }
  4274. /// Returns an alpha.
  4275. /// @returns An alpha.
  4276. /// @returns Zero if the primitive does not have an alpha parameter.
  4277. float get_alpha() const { return query_f32(query::alpha_f32); }
  4278. /// Returns a beta.
  4279. /// @returns A beta.
  4280. /// @returns Zero if the primitive does not have a beta parameter.
  4281. float get_beta() const { return query_f32(query::beta_f32); }
  4282. /// Returns an axis.
  4283. /// @returns An axis.
  4284. /// @returns A negative number if the primitive does not have an axis
  4285. /// parameter.
  4286. int get_axis() const {
  4287. int res;
  4288. dnnl_status_t status = dnnl_primitive_desc_query(
  4289. get(), dnnl_query_axis_s32, 0, &res);
  4290. return status == dnnl_success ? res : -1;
  4291. }
  4292. /// Returns an LRN local size parameter.
  4293. /// @returns An LRN local size parameter.
  4294. /// @returns Zero if the primitive does not have an LRN local size
  4295. /// parameter.
  4296. memory::dim get_local_size() const {
  4297. return query_s64(query::local_size_s64);
  4298. }
  4299. /// Returns an LRN K parameter.
  4300. /// @returns An LRN K parameter.
  4301. /// @returns Zero if the primitive does not have an LRN K parameter.
  4302. float get_k() const { return query_f32(query::k_f32); }
  4303. /// Returns a reduction P parameter.
  4304. /// @returns A reduction P parameter.
  4305. /// @returns Zero if the primitive does not have a reduction P parameter.
  4306. float get_p() const { return query_f32(query::p_f32); }
  4307. /// Returns a resampling factors parameters.
  4308. /// @returns A vector of factors.
  4309. /// @returns An empty vector if the primitive does not have a resampling
  4310. /// factors parameter.
  4311. std::vector<float> get_factors() const {
  4312. float *factors;
  4313. dnnl_status_t status = dnnl_primitive_desc_query(
  4314. get(), dnnl_query_factors, 0, &factors);
  4315. const bool is_backward = get_prop_kind() != prop_kind::forward_training
  4316. && get_prop_kind() != prop_kind::forward_inference;
  4317. const_dnnl_memory_desc_t md = dnnl_primitive_desc_query_md(get(),
  4318. is_backward ? dnnl_query_diff_dst_md : dnnl_query_dst_md, 0);
  4319. int ndims;
  4320. error::wrap_c_api(
  4321. dnnl_memory_desc_query(md, dnnl_query_ndims_s32, &ndims),
  4322. "could not query ndims from a memory descriptor");
  4323. return status == dnnl_success
  4324. ? std::vector<float>(factors, factors + (ndims - 2))
  4325. : std::vector<float> {};
  4326. }
  4327. /// Returns an RNN cell kind parameter.
  4328. /// @returns An RNN cell kind parameter.
  4329. /// @returns #dnnl::algorithm::undef if the primitive does not have an
  4330. /// RNN cell kind parameter.
  4331. dnnl::algorithm get_cell_kind() const {
  4332. return query_alg(query::cell_kind);
  4333. }
  4334. /// Returns an RNN direction parameter.
  4335. /// @returns An RNN direction parameter.
  4336. /// @returns #dnnl::rnn_direction::undef if the primitive does not have
  4337. /// an RNN direction parameter.
  4338. dnnl::rnn_direction get_direction() const {
  4339. dnnl_rnn_direction_t direction;
  4340. dnnl_status_t status = dnnl_primitive_desc_query(
  4341. get(), dnnl_query_direction, 0, &direction);
  4342. return status == dnnl_success
  4343. ? static_cast<dnnl::rnn_direction>(direction)
  4344. : dnnl::rnn_direction::undef;
  4345. }
  4346. /// Returns an RNN activation kind parameter.
  4347. /// @returns An RNN activation kind parameter.
  4348. /// @returns #dnnl::algorithm::undef if the primitive does not have an
  4349. /// RNN activation kind parameter.
  4350. dnnl::algorithm get_activation_kind() const {
  4351. return query_alg(query::activation_kind);
  4352. }
  4353. /// Returns a pooling kernel parameter.
  4354. /// @returns A pooling kernel parameter.
  4355. /// @returns An empty #dnnl::memory::dims if the primitive does not have
  4356. /// a pooling kernel parameter.
  4357. memory::dims get_kernel() const { return query_dims(query::kernel); }
  4358. /// Returns a group size parameter.
  4359. /// @returns A group size parameter.
  4360. /// @returns Zero if the primitive does not have a group size
  4361. /// parameter.
  4362. memory::dim get_group_size() const {
  4363. return query_s64(query::group_size_s64);
  4364. }
  4365. /// Returns a propagation kind.
  4366. /// @returns A propagation kind.
  4367. /// @returns #dnnl::prop_kind::undef if the primitive does not have
  4368. /// a propagation parameter.
  4369. dnnl::prop_kind get_prop_kind() const {
  4370. dnnl_prop_kind_t prop_kind;
  4371. dnnl_status_t status = dnnl_primitive_desc_query(
  4372. get(), dnnl_query_prop_kind, 0, &prop_kind);
  4373. return status == dnnl_success ? static_cast<dnnl::prop_kind>(prop_kind)
  4374. : dnnl::prop_kind::undef;
  4375. }
  4376. /// Returns a memory descriptor.
  4377. ///
  4378. /// @note
  4379. /// There are also convenience methods
  4380. /// #dnnl::primitive_desc_base::src_desc(),
  4381. /// #dnnl::primitive_desc_base::dst_desc(), and others.
  4382. ///
  4383. /// @param what The kind of parameter to query; can be
  4384. /// #dnnl::query::src_md, #dnnl::query::dst_md, etc.
  4385. /// @param idx Index of the parameter. For example, convolution bias can
  4386. /// be queried with what = #dnnl::query::weights_md and idx = 1.
  4387. /// @returns The requested memory descriptor.
  4388. /// @returns A zero memory descriptor if the primitive does not have a
  4389. /// parameter of the specified kind or index.
  4390. memory::desc query_md(query what, int idx = 0) const {
  4391. std::vector<query> valid_q {query::src_md, query::diff_src_md,
  4392. query::weights_md, query::diff_weights_md, query::dst_md,
  4393. query::diff_dst_md, query::workspace_md, query::scratchpad_md,
  4394. query::exec_arg_md};
  4395. if (!std::any_of(valid_q.cbegin(), valid_q.cend(),
  4396. [=](query q) { return what == q; }))
  4397. DNNL_THROW_ERROR(dnnl_invalid_arguments,
  4398. "memory descriptor query is invalid");
  4399. const_dnnl_memory_desc_t cdesc = dnnl_primitive_desc_query_md(
  4400. get(), dnnl::convert_to_c(what), idx);
  4401. if (!cdesc) return memory::desc();
  4402. dnnl_memory_desc_t cloned_md = nullptr;
  4403. error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
  4404. "could not clone a memory descriptor");
  4405. return memory::desc(cloned_md);
  4406. }
  4407. /// Returns a source memory descriptor.
  4408. /// @param idx Source index.
  4409. /// @returns Source memory descriptor.
  4410. /// @returns A zero memory descriptor if the primitive does not have a
  4411. /// source parameter with index @p idx.
  4412. memory::desc src_desc(int idx) const {
  4413. return query_md(query::src_md, idx);
  4414. }
  4415. /// Returns a destination memory descriptor.
  4416. /// @param idx Destination index.
  4417. /// @returns Destination memory descriptor.
  4418. /// @returns A zero memory descriptor if the primitive does not have a
  4419. /// destination parameter with index @p idx.
  4420. memory::desc dst_desc(int idx) const {
  4421. return query_md(query::dst_md, idx);
  4422. }
  4423. /// Returns a weights memory descriptor.
  4424. /// @param idx Weights index.
  4425. /// @returns Weights memory descriptor.
  4426. /// @returns A zero memory descriptor if the primitive does not have a
  4427. /// weights parameter with index @p idx.
  4428. memory::desc weights_desc(int idx) const {
  4429. return query_md(query::weights_md, idx);
  4430. }
  4431. /// Returns a diff source memory descriptor.
  4432. /// @param idx Diff source index.
  4433. /// @returns Diff source memory descriptor.
  4434. /// @returns A zero memory descriptor if the primitive does not have a
  4435. /// diff source parameter with index @p idx.
  4436. memory::desc diff_src_desc(int idx) const {
  4437. return query_md(query::diff_src_md, idx);
  4438. }
  4439. /// Returns a diff destination memory descriptor.
  4440. /// @param idx Diff destination index.
  4441. /// @returns Diff destination memory descriptor.
  4442. /// @returns A zero memory descriptor if the primitive does not have a
  4443. /// diff destination parameter with index @p idx.
  4444. memory::desc diff_dst_desc(int idx) const {
  4445. return query_md(query::diff_dst_md, idx);
  4446. }
  4447. /// Returns a diff weights memory descriptor.
  4448. /// @param idx Diff weights index.
  4449. /// @returns Diff weights memory descriptor.
  4450. /// @returns A zero memory descriptor if the primitive does not have a
  4451. /// diff weights parameter with index @p idx.
  4452. memory::desc diff_weights_desc(int idx) const {
  4453. return query_md(query::diff_weights_md, idx);
  4454. }
  4455. // Separate versions without the index argument for documentation
  4456. // purposes.
  4457. /// Returns a source memory descriptor.
  4458. /// @returns Source memory descriptor.
  4459. /// @returns A zero memory descriptor if the primitive does not have a
  4460. /// source parameter.
  4461. memory::desc src_desc() const { return src_desc(0); }
  4462. /// Returns a destination memory descriptor.
  4463. /// @returns Destination memory descriptor.
  4464. /// @returns A zero memory descriptor if the primitive does not have a
  4465. /// destination parameter.
  4466. memory::desc dst_desc() const { return dst_desc(0); }
  4467. /// Returns a weights memory descriptor.
  4468. /// @returns Weights memory descriptor.
  4469. /// @returns A zero memory descriptor if the primitive does not have a
  4470. /// weights parameter.
  4471. memory::desc weights_desc() const { return weights_desc(0); }
  4472. /// Returns a diff source memory descriptor.
  4473. /// @returns Diff source memory descriptor.
  4474. /// @returns A zero memory descriptor if the primitive does not have a
  4475. /// diff source memory with.
  4476. memory::desc diff_src_desc() const { return diff_src_desc(0); }
  4477. /// Returns a diff destination memory descriptor.
  4478. /// @returns Diff destination memory descriptor.
  4479. /// @returns A zero memory descriptor if the primitive does not have a
  4480. /// diff destination parameter.
  4481. memory::desc diff_dst_desc() const { return diff_dst_desc(0); }
  4482. /// Returns a diff weights memory descriptor.
  4483. /// @returns Diff weights memory descriptor.
  4484. /// @returns A zero memory descriptor if the primitive does not have a
  4485. /// diff weights parameter.
  4486. memory::desc diff_weights_desc() const { return diff_weights_desc(0); }
  4487. /// Returns the workspace memory descriptor.
  4488. /// @returns Workspace memory descriptor.
  4489. /// @returns A zero memory descriptor if the primitive does not require
  4490. /// workspace parameter.
  4491. memory::desc workspace_desc() const {
  4492. return query_md(query::workspace_md, 0);
  4493. }
  4494. /// Returns the scratchpad memory descriptor.
  4495. /// @returns scratchpad memory descriptor.
  4496. /// @returns A zero memory descriptor if the primitive does not require
  4497. /// scratchpad parameter.
  4498. /// @sa @ref dev_guide_attributes_scratchpad
  4499. memory::desc scratchpad_desc() const {
  4500. return query_md(query::scratchpad_md, 0);
  4501. }
  4502. /// Returns the engine on which the scratchpad memory is located.
  4503. /// @returns The engine on which the scratchpad memory is located.
  4504. engine scratchpad_engine() const {
  4505. dnnl_engine_t c_engine;
  4506. error::wrap_c_api(dnnl_primitive_desc_query(get(),
  4507. dnnl::convert_to_c(query::scratchpad_engine),
  4508. 0, &c_engine),
  4509. "could not retrieve scratchpad engine from a primitive "
  4510. "descriptor");
  4511. return engine(c_engine, true);
  4512. }
  4513. /// Returns the primitive attributes.
  4514. /// @returns The primitive attributes.
  4515. primitive_attr get_primitive_attr() const {
  4516. const_dnnl_primitive_attr_t const_c_attr;
  4517. error::wrap_c_api(dnnl_primitive_desc_get_attr(get(), &const_c_attr),
  4518. "could not get attributes from a primitive descriptor");
  4519. dnnl_primitive_attr_t c_attr;
  4520. error::wrap_c_api(dnnl_primitive_attr_clone(&c_attr, const_c_attr),
  4521. "could not clone primitive attributes");
  4522. return primitive_attr(c_attr);
  4523. }
  4524. /// Returns the kind of the primitive descriptor.
  4525. /// @returns The kind of the primitive descriptor.
  4526. dnnl::primitive::kind get_kind() const {
  4527. dnnl_primitive_kind_t kind;
  4528. error::wrap_c_api(dnnl_primitive_desc_query(get(),
  4529. dnnl_query_primitive_kind, 0, (void *)&kind),
  4530. "could not get primitive kind from a primitive descriptor");
  4531. return static_cast<dnnl::primitive::kind>(kind);
  4532. }
  4533. /// Returns the cache blob ID of the primitive descriptor.
  4534. /// @returns The cache blob ID of the primitive descriptor.
  4535. std::vector<uint8_t> get_cache_blob_id() const {
  4536. dnnl_dim_t count;
  4537. const uint8_t *c_id;
  4538. error::wrap_c_api(
  4539. dnnl_primitive_desc_query(get(),
  4540. dnnl::convert_to_c(query::cache_blob_id_size_s64), 0,
  4541. (void *)&count),
  4542. "could not get size of cache blob ID from a primitive "
  4543. "descriptor");
  4544. error::wrap_c_api(dnnl_primitive_desc_query(get(),
  4545. dnnl::convert_to_c(query::cache_blob_id), 0,
  4546. (void **)&c_id),
  4547. "could not get cache blob ID from a primitive descriptor");
  4548. std::vector<uint8_t> id(c_id, c_id + count);
  4549. return id;
  4550. }
  4551. protected:
  4552. /// Returns a float value.
  4553. /// @param what The value to query.
  4554. /// @returns The result of the query.
  4555. /// @returns Zero if the primitive doesn't support the query.
  4556. float query_f32(query what) const {
  4557. float res;
  4558. dnnl_status_t status = dnnl_primitive_desc_query(
  4559. get(), dnnl::convert_to_c(what), 0, &res);
  4560. return status == dnnl_success ? res : 0.0f;
  4561. }
  4562. /// Returns an #dnnl::algorithm value.
  4563. /// @param what The value to query.
  4564. /// @returns The result of the query.
  4565. /// @returns #dnnl::algorithm::undef if the primitive doesn't support
  4566. /// the query.
  4567. algorithm query_alg(query what) const {
  4568. dnnl_alg_kind_t res;
  4569. dnnl_status_t status = dnnl_primitive_desc_query(
  4570. get(), dnnl::convert_to_c(what), 0, &res);
  4571. return status == dnnl_success ? static_cast<dnnl::algorithm>(res)
  4572. : algorithm::undef;
  4573. }
  4574. /// Returns a memory::dims value.
  4575. /// @param what The value to query.
  4576. /// @returns The result of the query.
  4577. /// @returns An empty #dnnl::memory::dims if the primitive doesn't support
  4578. /// the query.
  4579. memory::dims query_dims(query what) const {
  4580. const bool is_backward = get_prop_kind() != prop_kind::forward_training
  4581. && get_prop_kind() != prop_kind::forward_inference;
  4582. const_dnnl_memory_desc_t md = dnnl_primitive_desc_query_md(get(),
  4583. is_backward ? dnnl_query_diff_dst_md : dnnl_query_dst_md, 0);
  4584. int nspatial_dims = 0;
  4585. if (md) {
  4586. int ndims;
  4587. error::wrap_c_api(
  4588. dnnl_memory_desc_query(md, dnnl_query_ndims_s32, &ndims),
  4589. "could not query ndims from a memory descriptor");
  4590. nspatial_dims = ndims - 2;
  4591. }
  4592. dnnl_dims_t *c_dims;
  4593. dnnl_status_t status = dnnl_primitive_desc_query(
  4594. get(), dnnl::convert_to_c(what), 0, &c_dims);
  4595. return status == dnnl_success
  4596. ? memory::dims(*c_dims, *c_dims + nspatial_dims)
  4597. : memory::dims {};
  4598. }
  4599. /// Returns an #dnnl::engine value.
  4600. /// @param what The value to query.
  4601. /// @returns The result of the query.
  4602. /// @returns A weak handle to the engine that the primitive descriptor was
  4603. /// created with.
  4604. engine query_engine(query what) const {
  4605. dnnl_engine_t c_engine;
  4606. error::wrap_c_api(dnnl_primitive_desc_query(get(),
  4607. dnnl::convert_to_c(what), 0, &c_engine),
  4608. "could not get an engine from a primitive_desc");
  4609. return engine(c_engine, true);
  4610. }
  4611. /// Resets the value of the handle to a clone of a C API primitive
  4612. /// descriptor.
  4613. /// @param pd A C API primitive descriptor to clone.
  4614. void reset_with_clone(const_dnnl_primitive_desc_t pd) {
  4615. dnnl_primitive_desc_t new_pd;
  4616. error::wrap_c_api(dnnl_primitive_desc_clone(&new_pd, pd),
  4617. "could not clone a primitive descriptor");
  4618. reset(new_pd);
  4619. }
  4620. /// Constructs a primitive descriptor base object from a clone of a C API
  4621. /// primitive descriptor after verifying that it is what the caller
  4622. /// expects.
  4623. ///
  4624. /// @note
  4625. /// The @p prim_kind should map to a primitive that does not have
  4626. /// different values of propagation kind (e.g. #dnnl::binary).
  4627. /// @note
  4628. /// Primitive descriptor base constructed this way does not support
  4629. /// next_impl() (will throw).
  4630. ///
  4631. /// @param pd C API primitive descriptor to clone.
  4632. /// @param prim_kind Expected primitive kind.
  4633. primitive_desc_base(
  4634. dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind)
  4635. : primitive_desc_base(pd, prim_kind, dnnl::prop_kind::undef) {}
  4636. /// Constructs a primitive descriptor base object from a clone of a C API
  4637. /// primitive descriptor after verifying that it is what the caller
  4638. /// expects.
  4639. ///
  4640. /// @note
  4641. /// Primitive descriptor base constructed this way does not support
  4642. /// next_impl() (will throw).
  4643. ///
  4644. /// @param pd C API primitive descriptor to clone.
  4645. /// @param prim_kind Expected primitive kind.
  4646. /// @param aprop_kind Expected propagation kind.
  4647. primitive_desc_base(dnnl_primitive_desc_t pd,
  4648. dnnl::primitive::kind prim_kind, dnnl::prop_kind aprop_kind)
  4649. : primitive_desc_base(pd, prim_kind, aprop_kind, aprop_kind) {}
  4650. /// Constructs a primitive descriptor base object from a clone of a C API
  4651. /// primitive descriptor after verifying that it is what the caller
  4652. /// expects.
  4653. ///
  4654. /// @note
  4655. /// Primitive descriptor base constructed this way does not support
  4656. /// next_impl() (will throw).
  4657. ///
  4658. /// @param pd C API primitive descriptor to clone.
  4659. /// @param prim_kind Expected primitive kind.
  4660. /// @param prop_kind1 Expected propagation kind (option 1).
  4661. /// @param prop_kind2 Expected propagation kind (option 2). This value is
  4662. /// checked if the check with @p prop_kind1 fails.
  4663. primitive_desc_base(dnnl_primitive_desc_t pd,
  4664. dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind1,
  4665. dnnl::prop_kind prop_kind2) {
  4666. // It is OK to pass an empty primitive descriptor
  4667. if (pd == nullptr) return;
  4668. dnnl_status_t rc;
  4669. dnnl_primitive_kind_t c_prim_kind = convert_to_c(prim_kind);
  4670. dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
  4671. dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
  4672. // Check that primitive kind matches
  4673. dnnl_primitive_kind_t pd_kind;
  4674. rc = dnnl_primitive_desc_query(
  4675. pd, dnnl_query_primitive_kind, 0, (void *)&pd_kind);
  4676. error::wrap_c_api(
  4677. rc, "could not get primitive kind from a primitive descriptor");
  4678. if (pd_kind != c_prim_kind)
  4679. DNNL_THROW_ERROR(dnnl_invalid_arguments,
  4680. "primitive descriptor operation kind mismatch");
  4681. // Check that propagation kind matches
  4682. dnnl_prop_kind_t pd_prop_kind;
  4683. rc = dnnl_primitive_desc_query(
  4684. pd, dnnl_query_prop_kind, 0, (void *)&pd_prop_kind);
  4685. // Something went wrong
  4686. if (rc != dnnl_success && rc != dnnl_unimplemented)
  4687. DNNL_THROW_ERROR(dnnl_invalid_arguments,
  4688. "could not get propagation kind from the primitive "
  4689. "descriptor");
  4690. // Everything is fine
  4691. if ((rc == dnnl_unimplemented && c_prop_kind1 == dnnl_prop_kind_undef)
  4692. || (rc == dnnl_success
  4693. && (pd_prop_kind == c_prop_kind1
  4694. || pd_prop_kind == c_prop_kind2))) {
  4695. reset_with_clone(pd);
  4696. return;
  4697. }
  4698. // We could get the propagation kind but there is a mismatch
  4699. DNNL_THROW_ERROR(dnnl_invalid_arguments,
  4700. "primitive descriptor propagation kind mismatch");
  4701. }
  4702. /// Returns a constant reference to a static instance of default constructed
  4703. /// primitive attributes
  4704. static const primitive_attr &default_attr() {
  4705. static const primitive_attr attr;
  4706. return attr;
  4707. }
  4708. const_dnnl_memory_desc_t optional_arg(const memory::desc *md) {
  4709. return md ? md->get() : nullptr;
  4710. }
  4711. const dnnl_dim_t *optional_arg(const memory::dims *dims) {
  4712. return dims ? dims->data() : nullptr;
  4713. }
  /// Maps an optional vector of floats to a pointer to its elements.
  /// @param arg Pointer to a vector of floats; may be null.
  /// @returns The data pointer of @p arg, or nullptr if @p arg is null.
  const float *optional_arg(const std::vector<float> *arg) {
      if (!arg) return nullptr;
      return arg->data();
  }
  4717. using base = primitive_desc_base;
  4718. };
  4719. /// @} dnnl_api_primitives_common
  4720. /// @addtogroup dnnl_api_reorder Reorder
  4721. ///
  4722. /// A primitive to copy data between two memory objects. This primitive is
  4723. /// typically used to change the way the data is laid out in memory.
  4724. ///
  4725. /// @sa @ref dev_guide_reorder in developer guide
  4726. ///
  4727. /// @{
  4728. /// Reorder primitive.
  4729. struct reorder : public primitive {
  4730. /// Primitive descriptor for a reorder primitive.
  4731. struct primitive_desc : public primitive_desc_base {
  4732. using primitive_desc_base::primitive_desc_base;
  4733. /// Default constructor. Produces an empty object.
  4734. primitive_desc() = default;
  4735. /// Constructs a primitive descriptor for reorder primitive.
  4736. ///
  4737. /// @note
  4738. /// If @p allow_empty is true, the constructor does not throw if a
  4739. /// primitive descriptor cannot be created.
  4740. ///
  4741. /// @param src_engine Engine on which the source memory object will be
  4742. /// located.
  4743. /// @param src_md Source memory descriptor.
  4744. /// @param dst_engine Engine on which the destination memory object
  4745. /// will be located.
  4746. /// @param dst_md Destination memory descriptor.
  4747. /// @param attr Primitive attributes to use. Attributes are optional
  4748. /// and default to empty attributes.
  4749. /// @param allow_empty A flag signifying whether construction is allowed
  4750. /// to fail without throwing an exception. In this case an empty
  4751. /// object will be produced. This flag is optional and defaults to
  4752. /// false.
  4753. primitive_desc(const engine &src_engine, const memory::desc &src_md,
  4754. const engine &dst_engine, const memory::desc &dst_md,
  4755. const primitive_attr &attr = default_attr(),
  4756. bool allow_empty = false) {
  4757. dnnl_primitive_desc_t result;
  4758. dnnl_status_t status = dnnl_reorder_primitive_desc_create(&result,
  4759. src_md.get(), src_engine.get(), dst_md.get(),
  4760. dst_engine.get(), attr.get());
  4761. if (!allow_empty)
  4762. error::wrap_c_api(status,
  4763. "could not create a primitive descriptor for "
  4764. "the reorder primitive. Run workload with "
  4765. "environment variable ONEDNN_VERBOSE=all to get "
  4766. "additional diagnostic information.");
  4767. reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
  4768. }
  4769. /// Constructs a primitive descriptor for reorder primitive.
  4770. ///
  4771. /// @param src Source memory object. It is used to obtain the source
  4772. /// memory descriptor and engine.
  4773. /// @param dst Destination memory object. It is used to obtain the
  4774. /// destination memory descriptor and engine.
  4775. /// @param attr Primitive attributes to use. Attributes are optional
  4776. /// and default to empty attributes.
  4777. /// @param allow_empty A flag signifying whether construction is allowed
  4778. /// to fail without throwing an exception. In this case an empty
  4779. /// object will be produced. This flag is optional and defaults to
  4780. /// false.
  4781. primitive_desc(const memory &src, const memory &dst,
  4782. const primitive_attr &attr = default_attr(),
  4783. bool allow_empty = false) {
  4784. dnnl_primitive_desc_t result;
  4785. auto src_md = src.get_desc();
  4786. auto dst_md = dst.get_desc();
  4787. dnnl_status_t status = dnnl_reorder_primitive_desc_create(&result,
  4788. src_md.get(), src.get_engine().get(), dst_md.get(),
  4789. dst.get_engine().get(), attr.get());
  4790. if (!allow_empty)
  4791. error::wrap_c_api(status,
  4792. "could not create a primitive descriptor for "
  4793. "the reorder primitive. Run workload with "
  4794. "environment variable ONEDNN_VERBOSE=all to get "
  4795. "additional diagnostic information.");
  4796. reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
  4797. }
  4798. /// Constructs a primitive descriptor for reorder primitive from a C
  4799. /// API primitive descriptor which must have a matching kind.
  4800. ///
  4801. /// @param pd C API primitive descriptor for reorder primitive.
  4802. primitive_desc(dnnl_primitive_desc_t pd)
  4803. : primitive_desc_base(pd, dnnl::primitive::kind::reorder) {}
  4804. /// Returns the engine on which the source memory is allocated.
  4805. /// @returns The engine on which the source memory is allocated.
  4806. engine get_src_engine() const {
  4807. return query_engine(dnnl::query::reorder_src_engine);
  4808. }
  4809. /// Returns the engine on which the destination memory is allocated.
  4810. /// @returns The engine on which the destination memory is allocated.
  4811. engine get_dst_engine() const {
  4812. return query_engine(dnnl::query::reorder_dst_engine);
  4813. }
  4814. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  4815. memory::desc src_desc() const { return base::src_desc(0); }
  4816. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  4817. memory::desc dst_desc() const { return base::dst_desc(0); }
  4818. };
  4819. /// Default constructor. Produces an empty object.
  4820. reorder() = default;
  4821. /// Constructs a reorder primitive.
  4822. /// @param pd Primitive descriptor for reorder primitive.
  4823. reorder(const primitive_desc &pd) : primitive(pd.get()) {}
  4824. /// Constructs a reorder primitive from a cache blob.
  4825. /// @param pd Primitive descriptor for reorder primitive.
  4826. /// @param cache_blob Cache blob.
  4827. reorder(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  4828. : primitive(pd.get(), cache_blob) {}
  4829. /// Constructs a reorder primitive that would reorder data between memory
  4830. /// objects having the same memory descriptors as memory objects @p src and
  4831. /// @p dst.
  4832. ///
  4833. /// @param src Source memory object.
  4834. /// @param dst Destination memory object.
  4835. /// @param attr Primitive attributes to use (optional).
  4836. reorder(const memory &src, const memory &dst,
  4837. const primitive_attr &attr = primitive_attr())
  4838. : primitive(primitive_desc(src, dst, attr).get()) {}
  4839. using primitive::execute;
  4840. /// Executes the reorder primitive.
  4841. ///
  4842. /// @param astream Stream object. The stream must belong to the same engine
  4843. /// as the primitive.
  4844. /// @param src Source memory object.
  4845. /// @param dst Destination memory object.
  4846. void execute(const stream &astream, memory &src, memory &dst) const {
  4847. primitive::execute(astream, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}});
  4848. }
  4849. };
  4850. /// @} dnnl_api_reorder
  4851. /// @addtogroup dnnl_api_concat Concat
  4852. ///
  4853. /// A primitive to concatenate data by arbitrary dimension.
  4854. ///
  4855. /// @sa @ref dev_guide_concat in developer guide
  4856. ///
  4857. /// @{
  4858. /// @cond DO_NOT_DOCUMENT_THIS
  4859. inline std::vector<const_dnnl_memory_desc_t> convert_to_c(
  4860. const std::vector<memory::desc> &mds) {
  4861. std::vector<const_dnnl_memory_desc_t> c_mds;
  4862. c_mds.reserve(mds.size());
  4863. for (const auto &md : mds)
  4864. c_mds.push_back(md.get());
  4865. return c_mds;
  4866. }
  4867. /// @endcond
  4868. /// Tensor concatenation (concat) primitive.
  4869. struct concat : public primitive {
  4870. /// Primitive descriptor for a concat primitive.
  4871. struct primitive_desc : public primitive_desc_base {
  4872. using primitive_desc_base::primitive_desc_base;
  4873. /// Default constructor. Produces an empty object.
  4874. primitive_desc() = default;
  4875. /// Constructs a primitive descriptor for an out-of-place concatenation
  4876. /// primitive.
  4877. ///
  4878. /// @param aengine Engine to perform the operation on.
  4879. /// @param dst Destination memory descriptor.
  4880. /// @param concat_dimension Source tensors will be concatenated over
  4881. /// dimension with this index. Note that order of dimensions does
  4882. /// not depend on memory format.
  4883. /// @param srcs Vector of source memory descriptors.
  4884. /// @param attr Primitive attributes to use. Attributes are optional
  4885. /// and default to empty attributes.
  4886. /// @param allow_empty A flag signifying whether construction is
  4887. /// allowed to fail without throwing an exception. In this case an
  4888. /// empty object will be produced. This flag is optional and
  4889. /// defaults to false.
  4890. primitive_desc(const engine &aengine, const memory::desc &dst,
  4891. int concat_dimension, const std::vector<memory::desc> &srcs,
  4892. const primitive_attr &attr = default_attr(),
  4893. bool allow_empty = false) {
  4894. auto c_srcs = convert_to_c(srcs);
  4895. dnnl_primitive_desc_t result;
  4896. dnnl_status_t status = dnnl_concat_primitive_desc_create(&result,
  4897. aengine.get(), dst.get(), (int)c_srcs.size(),
  4898. concat_dimension, c_srcs.data(), attr.get());
  4899. if (!allow_empty)
  4900. error::wrap_c_api(status,
  4901. "could not create a primitive descriptor for "
  4902. "the concat primitive. Run workload with "
  4903. "environment variable ONEDNN_VERBOSE=all to get "
  4904. "additional diagnostic information.");
  4905. reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
  4906. }
  4907. /// Constructs a primitive descriptor for an out-of-place concatenation
  4908. /// primitive.
  4909. ///
  4910. /// This version derives the destination memory descriptor
  4911. /// automatically.
  4912. ///
  4913. /// @param aengine Engine to perform the operation on.
  4914. /// @param concat_dimension Source tensors will be concatenated over
  4915. /// dimension with this index. Note that order of dimensions does
  4916. /// not depend on memory format.
  4917. /// @param srcs Vector of source memory descriptors.
  4918. /// @param attr Primitive attributes to use. Attributes are optional
  4919. /// and default to empty attributes.
  4920. /// @param allow_empty A flag signifying whether construction is
  4921. /// allowed to fail without throwing an exception. In this case an
  4922. /// empty object will be produced. This flag is optional and
  4923. /// defaults to false.
  4924. primitive_desc(const engine &aengine, int concat_dimension,
  4925. const std::vector<memory::desc> &srcs,
  4926. const primitive_attr &attr = default_attr(),
  4927. bool allow_empty = false) {
  4928. auto c_api_srcs = convert_to_c(srcs);
  4929. dnnl_primitive_desc_t result;
  4930. dnnl_status_t status = dnnl_concat_primitive_desc_create(&result,
  4931. aengine.get(), nullptr, (int)c_api_srcs.size(),
  4932. concat_dimension, c_api_srcs.data(), attr.get());
  4933. if (!allow_empty)
  4934. error::wrap_c_api(status,
  4935. "could not create a primitive descriptor for "
  4936. "the concat primitive. Run workload with "
  4937. "environment variable ONEDNN_VERBOSE=all to get "
  4938. "additional diagnostic information.");
  4939. reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
  4940. }
  4941. /// Constructs a primitive descriptor for concat primitive from a C
  4942. /// API primitive descriptor which must have a matching kind.
  4943. ///
  4944. /// @param pd C API primitive descriptor for concat primitive.
  4945. primitive_desc(dnnl_primitive_desc_t pd)
  4946. : primitive_desc_base(pd, dnnl::primitive::kind::concat) {}
  4947. /// @copydoc dnnl::primitive_desc_base::src_desc(int)const
  4948. memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
  4949. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  4950. memory::desc dst_desc() const { return base::dst_desc(0); }
  4951. };
  4952. /// Default constructor. Produces an empty object.
  4953. concat() = default;
  4954. /// Constructs a concatenation primitive.
  4955. /// @param pd Primitive descriptor for concatenation primitive.
  4956. concat(const primitive_desc &pd) : primitive(pd.get()) {}
  4957. /// Constructs a concatenation primitive from a cache blob.
  4958. /// @param pd Primitive descriptor for concatenation primitive.
  4959. /// @param cache_blob Cache blob.
  4960. concat(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  4961. : primitive(pd.get(), cache_blob) {}
  4962. };
  4963. /// @} dnnl_api_concat
  4964. /// @addtogroup dnnl_api_sum Sum
  4965. ///
  4966. /// A primitive to sum multiple tensors.
  4967. ///
  4968. /// @sa @ref dev_guide_sum in developer guide
  4969. ///
  4970. /// @{
  4971. /// Out-of-place summation (sum) primitive.
  4972. struct sum : public primitive {
  4973. /// Primitive descriptor for a sum primitive.
  4974. struct primitive_desc : public primitive_desc_base {
  4975. using primitive_desc_base::primitive_desc_base;
  4976. /// Default constructor. Produces an empty object.
  4977. primitive_desc() = default;
  4978. /// Constructs a primitive descriptor for a sum primitive.
  4979. ///
  4980. /// @param aengine Engine to perform the operation on.
  4981. /// @param dst Destination memory descriptor.
  4982. /// @param scales Vector of scales to multiply data in each source
  4983. /// memory by.
  4984. /// @param srcs Vector of source memory descriptors.
  4985. /// @param attr Primitive attributes to use. Attributes are optional
  4986. /// and default to empty attributes.
  4987. /// @param allow_empty A flag signifying whether construction is
  4988. /// allowed to fail without throwing an exception. In this case an
  4989. /// empty object will be produced. This flag is optional and
  4990. /// defaults to false.
  4991. primitive_desc(const engine &aengine, const memory::desc &dst,
  4992. const std::vector<float> &scales,
  4993. const std::vector<memory::desc> &srcs,
  4994. const primitive_attr &attr = default_attr(),
  4995. bool allow_empty = false) {
  4996. validate_container_size(scales,
  4997. "counts of scales and sources are not equal",
  4998. (int)srcs.size(), (int)srcs.size());
  4999. auto c_api_srcs = convert_to_c(srcs);
  5000. dnnl_primitive_desc_t result;
  5001. dnnl_status_t status = dnnl_sum_primitive_desc_create(&result,
  5002. aengine.get(), dst.get(), (int)c_api_srcs.size(),
  5003. scales.data(), c_api_srcs.data(), attr.get());
  5004. if (!allow_empty)
  5005. error::wrap_c_api(status,
  5006. "could not create a primitive descriptor for "
  5007. "the sum primitive. Run workload with "
  5008. "environment variable ONEDNN_VERBOSE=all to get "
  5009. "additional diagnostic information.");
  5010. reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
  5011. }
  5012. /// Constructs a primitive descriptor for a sum primitive.
  5013. ///
  5014. /// This version derives the destination memory descriptor
  5015. /// automatically.
  5016. ///
  5017. /// @param aengine Engine on which to perform the operation.
  5018. /// @param scales Vector of scales by which to multiply data in each
  5019. /// source memory object.
  5020. /// @param srcs Vector of source memory descriptors.
  5021. /// @param attr Primitive attributes to use. Attributes are optional
  5022. /// and default to empty attributes.
  5023. /// @param allow_empty A flag signifying whether construction is
  5024. /// allowed to fail without throwing an exception. In this case an
  5025. /// empty object will be produced. This flag is optional and
  5026. /// defaults to false.
  5027. primitive_desc(const engine &aengine, const std::vector<float> &scales,
  5028. const std::vector<memory::desc> &srcs,
  5029. const primitive_attr &attr = default_attr(),
  5030. bool allow_empty = false) {
  5031. validate_container_size(scales,
  5032. "counts of scales and sources are not equal",
  5033. (int)srcs.size(), (int)srcs.size());
  5034. auto c_api_srcs = convert_to_c(srcs);
  5035. dnnl_primitive_desc_t result;
  5036. dnnl_status_t status = dnnl_sum_primitive_desc_create(&result,
  5037. aengine.get(), nullptr, (int)c_api_srcs.size(),
  5038. scales.data(), c_api_srcs.data(), attr.get());
  5039. if (!allow_empty)
  5040. error::wrap_c_api(status,
  5041. "could not create a primitive descriptor for "
  5042. "the sum primitive. Run workload with "
  5043. "environment variable ONEDNN_VERBOSE=all to get "
  5044. "additional diagnostic information.");
  5045. reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
  5046. }
  5047. /// Constructs a primitive descriptor for sum primitive from a C API
  5048. /// primitive descriptor which must have a matching kind.
  5049. ///
  5050. /// @param pd C API primitive descriptor for sum primitive.
  5051. primitive_desc(dnnl_primitive_desc_t pd)
  5052. : primitive_desc_base(pd, dnnl::primitive::kind::sum) {}
  5053. /// @copydoc dnnl::primitive_desc_base::src_desc(int)const
  5054. memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
  5055. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  5056. memory::desc dst_desc() const { return base::dst_desc(0); }
  5057. };
  5058. /// Default constructor. Produces an empty object.
  5059. sum() = default;
  5060. /// Constructs a sum primitive.
  5061. /// @param pd Primitive descriptor for sum primitive.
  5062. sum(const primitive_desc &pd) : primitive(pd.get()) {}
  5063. /// Constructs a sum primitive from a cache blob.
  5064. /// @param pd Primitive descriptor for sum primitive.
  5065. /// @param cache_blob Cache blob.
  5066. sum(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  5067. : primitive(pd.get(), cache_blob) {}
  5068. };
  5069. /// @} dnnl_api_sum
  5070. /// @addtogroup dnnl_api_primitives_common
  5071. /// @{
  5072. /// A base class for descriptors of all primitives that support iteration
  5073. /// over multiple implementations.
  5074. struct primitive_desc : public primitive_desc_base {
  5075. using primitive_desc_base::primitive_desc_base;
  5076. primitive_desc() = default;
  5077. /// Changes the primitive descriptor to point to the next available
  5078. /// implementation.
  5079. ///
  5080. /// @returns @c true on success and @c false if the last available
  5081. /// implementation has already been reached. In the latter case, the
  5082. /// primitive descriptor itself is kept unchanged.
  5083. bool next_impl() {
  5084. dnnl_status_t status = dnnl_primitive_desc_next_impl(get());
  5085. if (status == dnnl_last_impl_reached) return false;
  5086. error::wrap_c_api(status, "last available implementation is reached");
  5087. return true;
  5088. }
  5089. };
  5090. /// @} dnnl_api_primitives_common
  5091. /// @addtogroup dnnl_api_convolution Convolution
  5092. ///
  5093. /// A primitive to perform 1D, 2D or 3D convolution. Supported variants are
  5094. /// forward propagation, backward propagation, and weights gradient with or
  5095. /// without bias.
  5096. ///
  5097. /// @sa @ref dev_guide_convolution in developer guide
  5098. ///
  5099. /// @{
  5100. /// Convolution forward propagation primitive.
  5101. struct convolution_forward : public primitive {
  5102. /// Primitive descriptor for a convolution forward propagation primitive.
  5103. struct primitive_desc : public dnnl::primitive_desc {
  5104. /// Default constructor. Produces an empty object.
  5105. primitive_desc() = default;
  5106. /// Constructs a primitive descriptor for a convolution forward
  5107. /// propagation primitive with bias.
  5108. ///
  5109. /// @note
  5110. /// All the memory descriptors may be initialized with the
  5111. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5112. ///
  5113. /// Arrays @p strides, @p padding_l, and @p padding_r contain values
  5114. /// for spatial dimensions only and hence must have the same number of
  5115. /// elements as there are spatial dimensions. The order of values is
  5116. /// the same as in the tensor: depth (for 3D tensors), height (for 3D
  5117. /// and 2D tensors), and width.
  5118. ///
  5119. /// @param aengine Engine to use.
  5120. /// @param aprop_kind Propagation kind. Possible values are
  5121. /// #dnnl::prop_kind::forward_training, and
  5122. /// #dnnl::prop_kind::forward_inference.
  5123. /// @param aalgorithm Convolution algorithm. Possible values are
  5124. /// #dnnl::algorithm::convolution_direct,
  5125. /// #dnnl::algorithm::convolution_winograd, and
  5126. /// #dnnl::algorithm::convolution_auto.
  5127. /// @param src_desc Source memory descriptor.
  5128. /// @param weights_desc Weights memory descriptor.
  5129. /// @param bias_desc Bias memory descriptor. Passing zero memory
  5130. /// descriptor disables the bias term.
  5131. /// @param dst_desc Destination memory descriptor.
  5132. /// @param strides Strides for each spatial dimension.
  5133. /// @param padding_l Vector of padding values for low indices for each
  5134. /// spatial dimension `([[front,] top,] left)`.
  5135. /// @param padding_r Vector of padding values for high indices for
  5136. /// each spatial dimension `([[back,] bottom,] right)`.
  5137. /// @param attr Primitive attributes to use. Attributes are optional
  5138. /// and default to empty attributes.
  5139. /// @param allow_empty A flag signifying whether construction is
  5140. /// allowed to fail without throwing an exception. In this case an
  5141. /// empty object will be produced. This flag is optional and
  5142. /// defaults to false.
  5143. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  5144. algorithm aalgorithm, const memory::desc &src_desc,
  5145. const memory::desc &weights_desc, const memory::desc &bias_desc,
  5146. const memory::desc &dst_desc, const memory::dims &strides,
  5147. const memory::dims &padding_l, const memory::dims &padding_r,
  5148. const primitive_attr &attr = default_attr(),
  5149. bool allow_empty = false)
  5150. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  5151. weights_desc, &bias_desc, dst_desc, strides, nullptr,
  5152. padding_l, padding_r, attr, allow_empty) {}
  5153. /// Constructs a primitive descriptor for a convolution forward
  5154. /// propagation primitive without bias.
  5155. ///
  5156. /// @note
  5157. /// All the memory descriptors may be initialized with the
  5158. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5159. ///
  5160. /// Arrays @p strides, @p padding_l, and @p padding_r contain values
  5161. /// for spatial dimensions only and hence must have the same number of
  5162. /// elements as there are spatial dimensions. The order of values is
  5163. /// the same as in the tensor: depth (for 3D tensors), height (for 3D
  5164. /// and 2D tensors), and width.
  5165. ///
  5166. /// @param aengine Engine to use.
  5167. /// @param aprop_kind Propagation kind. Possible values are
  5168. /// #dnnl::prop_kind::forward_training, and
  5169. /// #dnnl::prop_kind::forward_inference.
  5170. /// @param aalgorithm Convolution algorithm. Possible values are
  5171. /// #dnnl::algorithm::convolution_direct,
  5172. /// #dnnl::algorithm::convolution_winograd, and
  5173. /// #dnnl::algorithm::convolution_auto.
  5174. /// @param src_desc Source memory descriptor.
  5175. /// @param weights_desc Weights memory descriptor.
  5176. /// @param dst_desc Destination memory descriptor.
  5177. /// @param strides Strides for each spatial dimension.
  5178. /// @param padding_l Vector of padding values for low indices for each
  5179. /// spatial dimension `([[front,] top,] left)`.
  5180. /// @param padding_r Vector of padding values for high indices for
  5181. /// each spatial dimension `([[back,] bottom,] right)`.
  5182. /// @param attr Primitive attributes to use. Attributes are optional
  5183. /// and default to empty attributes.
  5184. /// @param allow_empty A flag signifying whether construction is
  5185. /// allowed to fail without throwing an exception. In this case an
  5186. /// empty object will be produced. This flag is optional and
  5187. /// defaults to false.
  5188. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  5189. algorithm aalgorithm, const memory::desc &src_desc,
  5190. const memory::desc &weights_desc, const memory::desc &dst_desc,
  5191. const memory::dims &strides, const memory::dims &padding_l,
  5192. const memory::dims &padding_r,
  5193. const primitive_attr &attr = default_attr(),
  5194. bool allow_empty = false)
  5195. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  5196. weights_desc, nullptr, dst_desc, strides, nullptr,
  5197. padding_l, padding_r, attr, allow_empty) {}
  5198. /// Constructs a primitive descriptor for a convolution forward
  5199. /// propagation primitive with bias.
  5200. ///
  5201. /// @note
  5202. /// All the memory descriptors may be initialized with the
  5203. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5204. ///
  5205. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
  5206. /// contain values for spatial dimensions only and hence must have the
  5207. /// same number of elements as there are spatial dimensions. The order
  5208. /// of values is the same as in the tensor: depth (for 3D tensors),
  5209. /// height (for 3D and 2D tensors), and width.
  5210. ///
  5211. /// @param aengine Engine to use.
  5212. /// @param aprop_kind Propagation kind. Possible values are
  5213. /// #dnnl::prop_kind::forward_training, and
  5214. /// #dnnl::prop_kind::forward_inference.
  5215. /// @param aalgorithm Convolution algorithm. Possible values are
  5216. /// #dnnl::algorithm::convolution_direct,
  5217. /// #dnnl::algorithm::convolution_winograd, and
  5218. /// #dnnl::algorithm::convolution_auto.
  5219. /// @param src_desc Source memory descriptor.
  5220. /// @param weights_desc Weights memory descriptor.
  5221. /// @param bias_desc Bias memory descriptor. Passing zero memory
  5222. /// descriptor disables the bias term.
  5223. /// @param dst_desc Destination memory descriptor.
  5224. /// @param strides Strides for each spatial dimension.
  5225. /// @param dilates Dilations for each spatial dimension. A zero value
  5226. /// means no dilation in the corresponding dimension.
  5227. /// @param padding_l Vector of padding values for low indices for each
  5228. /// spatial dimension `([[front,] top,] left)`.
  5229. /// @param padding_r Vector of padding values for high indices for
  5230. /// each spatial dimension `([[back,] bottom,] right)`.
  5231. /// @param attr Primitive attributes to use. Attributes are optional
  5232. /// and default to empty attributes.
  5233. /// @param allow_empty A flag signifying whether construction is
  5234. /// allowed to fail without throwing an exception. In this case an
  5235. /// empty object will be produced. This flag is optional and
  5236. /// defaults to false.
  5237. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  5238. algorithm aalgorithm, const memory::desc &src_desc,
  5239. const memory::desc &weights_desc, const memory::desc &bias_desc,
  5240. const memory::desc &dst_desc, const memory::dims &strides,
  5241. const memory::dims &dilates, const memory::dims &padding_l,
  5242. const memory::dims &padding_r,
  5243. const primitive_attr &attr = default_attr(),
  5244. bool allow_empty = false)
  5245. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  5246. weights_desc, &bias_desc, dst_desc, strides, &dilates,
  5247. padding_l, padding_r, attr, allow_empty) {}
  5248. /// Constructs a primitive descriptor for a convolution forward
  5249. /// propagation primitive without bias.
  5250. ///
  5251. /// @note
  5252. /// All the memory descriptors may be initialized with the
  5253. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5254. ///
  5255. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
  5256. /// contain values for spatial dimensions only and hence must have the
  5257. /// same number of elements as there are spatial dimensions. The order
  5258. /// of values is the same as in the tensor: depth (for 3D tensors),
  5259. /// height (for 3D and 2D tensors), and width.
  5260. ///
  5261. /// @param aengine Engine to use.
  5262. /// @param aprop_kind Propagation kind. Possible values are
  5263. /// #dnnl::prop_kind::forward_training, and
  5264. /// #dnnl::prop_kind::forward_inference.
  5265. /// @param aalgorithm Convolution algorithm. Possible values are
  5266. /// #dnnl::algorithm::convolution_direct,
  5267. /// #dnnl::algorithm::convolution_winograd, and
  5268. /// #dnnl::algorithm::convolution_auto.
  5269. /// @param src_desc Source memory descriptor.
  5270. /// @param weights_desc Weights memory descriptor.
  5271. /// @param dst_desc Destination memory descriptor.
  5272. /// @param strides Strides for each spatial dimension.
  5273. /// @param dilates Dilations for each spatial dimension. A zero value
  5274. /// means no dilation in the corresponding dimension.
  5275. /// @param padding_l Vector of padding values for low indices for each
  5276. /// spatial dimension `([[front,] top,] left)`.
  5277. /// @param padding_r Vector of padding values for high indices for
  5278. /// each spatial dimension `([[back,] bottom,] right)`.
  5279. /// @param attr Primitive attributes to use. Attributes are optional
  5280. /// and default to empty attributes.
  5281. /// @param allow_empty A flag signifying whether construction is
  5282. /// allowed to fail without throwing an exception. In this case an
  5283. /// empty object will be produced. This flag is optional and
  5284. /// defaults to false.
  5285. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  5286. algorithm aalgorithm, const memory::desc &src_desc,
  5287. const memory::desc &weights_desc, const memory::desc &dst_desc,
  5288. const memory::dims &strides, const memory::dims &dilates,
  5289. const memory::dims &padding_l, const memory::dims &padding_r,
  5290. const primitive_attr &attr = default_attr(),
  5291. bool allow_empty = false)
  5292. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  5293. weights_desc, nullptr, dst_desc, strides, &dilates,
  5294. padding_l, padding_r, attr, allow_empty) {}
  5295. /// Constructs a primitive descriptor for a convolution forward
  5296. /// propagation primitive from a C API primitive descriptor that must
  5297. /// have a matching kind.
  5298. ///
  5299. /// @param pd C API primitive descriptor for a convolution forward
  5300. /// propagation primitive.
  5301. primitive_desc(dnnl_primitive_desc_t pd)
  5302. : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
  5303. dnnl::prop_kind::forward_training,
  5304. dnnl::prop_kind::forward_inference) {}
  5305. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  5306. memory::desc src_desc() const { return base::src_desc(0); }
  5307. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  5308. memory::desc weights_desc() const { return base::weights_desc(0); }
  5309. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  5310. memory::desc dst_desc() const { return base::dst_desc(0); }
  5311. /// Returns the bias memory descriptor.
  5312. /// @returns The bias memory descriptor.
  5313. /// @returns A zero memory descriptor if the primitive does not have a
  5314. /// bias parameter.
  5315. memory::desc bias_desc() const { return base::weights_desc(1); }
  5316. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  5317. algorithm get_algorithm() const { return base::get_algorithm(); }
  5318. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  5319. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  5320. /// @copydoc dnnl::primitive_desc_base::get_strides()const
  5321. memory::dims get_strides() const { return base::get_strides(); }
  5322. /// @copydoc dnnl::primitive_desc_base::get_dilations()const
  5323. memory::dims get_dilations() const { return base::get_dilations(); }
  5324. /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
  5325. memory::dims get_padding_l() const { return base::get_padding_l(); }
  5326. /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
  5327. memory::dims get_padding_r() const { return base::get_padding_r(); }
  5328. private:
  5329. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  5330. algorithm aalgorithm, const memory::desc &src_desc,
  5331. const memory::desc &weights_desc, const memory::desc *bias_desc,
  5332. const memory::desc &dst_desc, const memory::dims &strides,
  5333. const memory::dims *dilates, const memory::dims &padding_l,
  5334. const memory::dims &padding_r, const primitive_attr &attr,
  5335. bool allow_empty) {
  5336. memory::validate_dims(strides, src_desc.get_ndims() - 2);
  5337. memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
  5338. memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
  5339. if (dilates)
  5340. memory::validate_dims(*dilates, src_desc.get_ndims() - 2);
  5341. dnnl_primitive_desc_t pd = nullptr;
  5342. dnnl_status_t status
  5343. = dnnl_convolution_forward_primitive_desc_create(&pd,
  5344. aengine.get(), dnnl::convert_to_c(aprop_kind),
  5345. convert_to_c(aalgorithm), src_desc.get(),
  5346. weights_desc.get(), optional_arg(bias_desc),
  5347. dst_desc.get(), &strides[0], optional_arg(dilates),
  5348. &padding_l[0], &padding_r[0], attr.get());
  5349. if (!allow_empty)
  5350. error::wrap_c_api(status,
  5351. "could not create a primitive descriptor for "
  5352. "the convolution forward propagation primitive. Run "
  5353. "workload with environment variable ONEDNN_VERBOSE=all "
  5354. "to get additional diagnostic information.");
  5355. reset(pd);
  5356. }
  5357. };
  5358. /// Default constructor. Produces an empty object.
  5359. convolution_forward() = default;
  5360. /// Constructs a convolution forward propagation primitive.
  5361. /// @param pd Primitive descriptor for a convolution forward propagation
  5362. /// primitive.
  5363. convolution_forward(const primitive_desc &pd) : primitive(pd) {}
  5364. /// Constructs a convolution forward propagation primitive from a cache
  5365. /// blob.
  5366. /// @param pd Primitive descriptor for a convolution forward propagation
  5367. /// primitive.
  5368. /// @param cache_blob Cache blob.
  5369. convolution_forward(
  5370. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  5371. : primitive(pd, cache_blob) {}
  5372. };
  5373. /// Convolution backward propagation primitive.
  5374. struct convolution_backward_data : public primitive {
  5375. /// Primitive descriptor for a convolution backward propagation primitive.
  5376. struct primitive_desc : public dnnl::primitive_desc {
  5377. /// Default constructor. Produces an empty object.
  5378. primitive_desc() = default;
  5379. /// Constructs a primitive descriptor for a convolution backward
  5380. /// propagation primitive.
  5381. ///
  5382. /// @note
  5383. /// All the memory descriptors may be initialized with the
  5384. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5385. ///
  5386. /// Arrays @p strides, @p padding_l, and @p padding_r contain values
  5387. /// for spatial dimensions only and hence must have the same number of
  5388. /// elements as there are spatial dimensions. The order of values is
  5389. /// the same as in the tensor: depth (for 3D tensors), height (for 3D
  5390. /// and 2D tensors), and width.
  5391. ///
  5392. /// @param aengine Engine to use.
  5393. /// @param aalgorithm Convolution algorithm. Possible values are
  5394. /// #dnnl::algorithm::convolution_direct,
  5395. /// #dnnl::algorithm::convolution_winograd, and
  5396. /// #dnnl::algorithm::convolution_auto.
  5397. /// @param diff_src_desc Diff source memory descriptor.
  5398. /// @param weights_desc Weights memory descriptor.
  5399. /// @param diff_dst_desc Diff destination memory descriptor.
  5400. /// @param strides Strides for each spatial dimension.
  5401. /// @param padding_l Vector of padding values for low indices for each
  5402. /// spatial dimension `([[front,] top,] left)`.
  5403. /// @param padding_r Vector of padding values for high indices for
  5404. /// each spatial dimension `([[back,] bottom,] right)`.
  5405. /// @param hint_fwd_pd Primitive descriptor for a convolution
  5406. /// forward propagation primitive. It is used as a hint for
  5407. /// deciding which memory format to use.
  5408. /// @param attr Primitive attributes to use. Attributes are optional
  5409. /// and default to empty attributes.
  5410. /// @param allow_empty A flag signifying whether construction is
  5411. /// allowed to fail without throwing an exception. In this case an
  5412. /// empty object will be produced. This flag is optional and
  5413. /// defaults to false.
  5414. primitive_desc(const engine &aengine, algorithm aalgorithm,
  5415. const memory::desc &diff_src_desc,
  5416. const memory::desc &weights_desc,
  5417. const memory::desc &diff_dst_desc, const memory::dims &strides,
  5418. const memory::dims &padding_l, const memory::dims &padding_r,
  5419. const convolution_forward::primitive_desc &hint_fwd_pd,
  5420. const primitive_attr &attr = default_attr(),
  5421. bool allow_empty = false)
  5422. : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc,
  5423. diff_dst_desc, strides, nullptr, padding_l, padding_r,
  5424. hint_fwd_pd, attr, allow_empty) {}
  5425. /// Constructs a primitive descriptor for a convolution backward
  5426. /// propagation primitive.
  5427. ///
  5428. /// @note
  5429. /// All the memory descriptors may be initialized with the
  5430. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5431. ///
  5432. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
  5433. /// contain values for spatial dimensions only and hence must have the
  5434. /// same number of elements as there are spatial dimensions. The order
  5435. /// of values is the same as in the tensor: depth (for 3D tensors),
  5436. /// height (for 3D and 2D tensors), and width.
  5437. ///
  5438. /// @param aengine Engine to use.
  5439. /// @param aalgorithm Convolution algorithm. Possible values are
  5440. /// #dnnl::algorithm::convolution_direct,
  5441. /// #dnnl::algorithm::convolution_winograd, and
  5442. /// #dnnl::algorithm::convolution_auto.
  5443. /// @param diff_src_desc Diff source memory descriptor.
  5444. /// @param weights_desc Weights memory descriptor.
  5445. /// @param diff_dst_desc Diff destination memory descriptor.
  5446. /// @param strides Strides for each spatial dimension.
  5447. /// @param dilates Dilations for each spatial dimension. A zero value
  5448. /// means no dilation in the corresponding dimension.
  5449. /// @param padding_l Vector of padding values for low indices for each
  5450. /// spatial dimension `([[front,] top,] left)`.
  5451. /// @param padding_r Vector of padding values for high indices for
  5452. /// each spatial dimension `([[back,] bottom,] right)`.
  5453. /// @param hint_fwd_pd Primitive descriptor for a convolution
  5454. /// forward propagation primitive. It is used as a hint for
  5455. /// deciding which memory format to use.
  5456. /// @param attr Primitive attributes to use. Attributes are optional
  5457. /// and default to empty attributes.
  5458. /// @param allow_empty A flag signifying whether construction is
  5459. /// allowed to fail without throwing an exception. In this case an
  5460. /// empty object will be produced. This flag is optional and
  5461. /// defaults to false.
  5462. primitive_desc(const engine &aengine, algorithm aalgorithm,
  5463. const memory::desc &diff_src_desc,
  5464. const memory::desc &weights_desc,
  5465. const memory::desc &diff_dst_desc, const memory::dims &strides,
  5466. const memory::dims &dilates, const memory::dims &padding_l,
  5467. const memory::dims &padding_r,
  5468. const convolution_forward::primitive_desc &hint_fwd_pd,
  5469. const primitive_attr &attr = default_attr(),
  5470. bool allow_empty = false)
  5471. : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc,
  5472. diff_dst_desc, strides, &dilates, padding_l, padding_r,
  5473. hint_fwd_pd, attr, allow_empty) {}
  5474. /// Constructs a primitive descriptor for a convolution backward
  5475. /// propagation primitive from a C API primitive descriptor that must
  5476. /// have a matching kind.
  5477. ///
  5478. /// @param pd C API primitive descriptor for a convolution backward
  5479. /// propagation primitive.
  5480. primitive_desc(dnnl_primitive_desc_t pd)
  5481. : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
  5482. dnnl::prop_kind::backward_data) {}
  5483. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  5484. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  5485. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  5486. memory::desc weights_desc() const { return base::weights_desc(0); }
  5487. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  5488. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  5489. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  5490. algorithm get_algorithm() const { return base::get_algorithm(); }
  5491. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  5492. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  5493. /// @copydoc dnnl::primitive_desc_base::get_strides()const
  5494. memory::dims get_strides() const { return base::get_strides(); }
  5495. /// @copydoc dnnl::primitive_desc_base::get_dilations()const
  5496. memory::dims get_dilations() const { return base::get_dilations(); }
  5497. /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
  5498. memory::dims get_padding_l() const { return base::get_padding_l(); }
  5499. /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
  5500. memory::dims get_padding_r() const { return base::get_padding_r(); }
  5501. private:
  5502. primitive_desc(const engine &aengine, algorithm aalgorithm,
  5503. const memory::desc &diff_src_desc,
  5504. const memory::desc &weights_desc,
  5505. const memory::desc &diff_dst_desc, const memory::dims &strides,
  5506. const memory::dims *dilates, const memory::dims &padding_l,
  5507. const memory::dims &padding_r,
  5508. const convolution_forward::primitive_desc &hint_fwd_pd,
  5509. const primitive_attr &attr, bool allow_empty) {
  5510. memory::validate_dims(strides, diff_src_desc.get_ndims() - 2);
  5511. memory::validate_dims(padding_l, diff_src_desc.get_ndims() - 2);
  5512. memory::validate_dims(padding_r, diff_src_desc.get_ndims() - 2);
  5513. if (dilates)
  5514. memory::validate_dims(*dilates, diff_src_desc.get_ndims() - 2);
  5515. dnnl_primitive_desc_t pd = nullptr;
  5516. dnnl_status_t status
  5517. = dnnl_convolution_backward_data_primitive_desc_create(&pd,
  5518. aengine.get(), convert_to_c(aalgorithm),
  5519. diff_src_desc.get(), weights_desc.get(),
  5520. diff_dst_desc.get(), &strides[0],
  5521. optional_arg(dilates), &padding_l[0], &padding_r[0],
  5522. hint_fwd_pd.get(), attr.get());
  5523. if (!allow_empty)
  5524. error::wrap_c_api(status,
  5525. "could not create a primitive descriptor for "
  5526. "the convolution backward propagation primitive. Run "
  5527. "workload with environment variable ONEDNN_VERBOSE=all "
  5528. "to get additional diagnostic information.");
  5529. reset(pd);
  5530. }
  5531. };
  5532. /// Default constructor. Produces an empty object.
  5533. convolution_backward_data() = default;
  5534. /// Constructs a convolution backward propagation primitive.
  5535. /// @param pd Primitive descriptor for a convolution backward propagation
  5536. /// primitive.
  5537. convolution_backward_data(const primitive_desc &pd) : primitive(pd) {}
  5538. /// Constructs a convolution backward propagation primitive from a cache
  5539. /// blob.
  5540. /// @param pd Primitive descriptor for a convolution backward propagation
  5541. /// primitive.
  5542. /// @param cache_blob Cache blob.
  5543. convolution_backward_data(
  5544. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  5545. : primitive(pd, cache_blob) {}
  5546. };
  5547. /// Convolution weights gradient primitive.
  5548. struct convolution_backward_weights : public primitive {
  5549. /// Primitive descriptor for a convolution weights gradient primitive.
  5550. struct primitive_desc : public dnnl::primitive_desc {
  5551. /// Default constructor. Produces an empty object.
  5552. primitive_desc() = default;
  5553. /// Constructs a primitive descriptor for a convolution weights gradient
  5554. /// primitive with bias.
  5555. ///
  5556. /// @note
  5557. /// All the memory descriptors may be initialized with the
  5558. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5559. ///
  5560. /// Arrays @p strides, @p padding_l, and @p padding_r contain values
  5561. /// for spatial dimensions only and hence must have the same number of
  5562. /// elements as there are spatial dimensions. The order of values is
  5563. /// the same as in the tensor: depth (for 3D tensors), height (for 3D
  5564. /// and 2D tensors), and width.
  5565. ///
  5566. /// @param aengine Engine to use.
  5567. /// @param aalgorithm Convolution algorithm. Possible values are
  5568. /// #dnnl::algorithm::convolution_direct,
  5569. /// #dnnl::algorithm::convolution_winograd, and
  5570. /// #dnnl::algorithm::convolution_auto.
  5571. /// @param src_desc Source memory descriptor.
  5572. /// @param diff_weights_desc Diff weights memory descriptor.
  5573. /// @param diff_bias_desc Diff bias memory descriptor. Passing zero
  5574. /// memory descriptor disables the bias term.
  5575. /// @param diff_dst_desc Diff destination memory descriptor.
  5576. /// @param strides Strides for each spatial dimension.
  5577. /// @param padding_l Vector of padding values for low indices for each
  5578. /// spatial dimension `([[front,] top,] left)`.
  5579. /// @param padding_r Vector of padding values for high indices for
  5580. /// each spatial dimension `([[back,] bottom,] right)`.
  5581. /// @param hint_fwd_pd Primitive descriptor for a convolution
  5582. /// forward propagation primitive. It is used as a hint for
  5583. /// deciding which memory format to use.
  5584. /// @param attr Primitive attributes to use. Attributes are optional
  5585. /// and default to empty attributes.
  5586. /// @param allow_empty A flag signifying whether construction is
  5587. /// allowed to fail without throwing an exception. In this case an
  5588. /// empty object will be produced. This flag is optional and
  5589. /// defaults to false.
  5590. primitive_desc(const engine &aengine, algorithm aalgorithm,
  5591. const memory::desc &src_desc,
  5592. const memory::desc &diff_weights_desc,
  5593. const memory::desc &diff_bias_desc,
  5594. const memory::desc &diff_dst_desc, const memory::dims &strides,
  5595. const memory::dims &padding_l, const memory::dims &padding_r,
  5596. const convolution_forward::primitive_desc &hint_fwd_pd,
  5597. const primitive_attr &attr = default_attr(),
  5598. bool allow_empty = false)
  5599. : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
  5600. &diff_bias_desc, diff_dst_desc, strides, nullptr, padding_l,
  5601. padding_r, hint_fwd_pd, attr, allow_empty) {}
  5602. /// Constructs a primitive descriptor for a convolution weights gradient
  5603. /// primitive without bias.
  5604. ///
  5605. /// @note
  5606. /// All the memory descriptors may be initialized with the
  5607. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5608. ///
  5609. /// Arrays @p strides, @p padding_l, and @p padding_r contain values
  5610. /// for spatial dimensions only and hence must have the same number of
  5611. /// elements as there are spatial dimensions. The order of values is
  5612. /// the same as in the tensor: depth (for 3D tensors), height (for 3D
  5613. /// and 2D tensors), and width.
  5614. ///
  5615. /// @param aengine Engine to use.
  5616. /// @param aalgorithm Convolution algorithm. Possible values are
  5617. /// #dnnl::algorithm::convolution_direct,
  5618. /// #dnnl::algorithm::convolution_winograd, and
  5619. /// #dnnl::algorithm::convolution_auto.
  5620. /// @param src_desc Source memory descriptor.
  5621. /// @param diff_weights_desc Diff weights memory descriptor.
  5622. /// @param diff_dst_desc Diff destination memory descriptor.
  5623. /// @param strides Strides for each spatial dimension.
  5624. /// @param padding_l Vector of padding values for low indices for each
  5625. /// spatial dimension `([[front,] top,] left)`.
  5626. /// @param padding_r Vector of padding values for high indices for
  5627. /// each spatial dimension `([[back,] bottom,] right)`.
  5628. /// @param hint_fwd_pd Primitive descriptor for a convolution
  5629. /// forward propagation primitive. It is used as a hint for
  5630. /// deciding which memory format to use.
  5631. /// @param attr Primitive attributes to use. Attributes are optional
  5632. /// and default to empty attributes.
  5633. /// @param allow_empty A flag signifying whether construction is
  5634. /// allowed to fail without throwing an exception. In this case an
  5635. /// empty object will be produced. This flag is optional and
  5636. /// defaults to false.
  5637. primitive_desc(const engine &aengine, algorithm aalgorithm,
  5638. const memory::desc &src_desc,
  5639. const memory::desc &diff_weights_desc,
  5640. const memory::desc &diff_dst_desc, const memory::dims &strides,
  5641. const memory::dims &padding_l, const memory::dims &padding_r,
  5642. const convolution_forward::primitive_desc &hint_fwd_pd,
  5643. const primitive_attr &attr = default_attr(),
  5644. bool allow_empty = false)
  5645. : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
  5646. nullptr, diff_dst_desc, strides, nullptr, padding_l,
  5647. padding_r, hint_fwd_pd, attr, allow_empty) {}
  5648. /// Constructs a primitive descriptor for a convolution weights
  5649. /// gradient primitive with bias.
  5650. ///
  5651. /// @note
  5652. /// All the memory descriptors may be initialized with the
  5653. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5654. ///
  5655. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
  5656. /// contain values for spatial dimensions only and hence must have the
  5657. /// same number of elements as there are spatial dimensions. The order
  5658. /// of values is the same as in the tensor: depth (for 3D tensors),
  5659. /// height (for 3D and 2D tensors), and width.
  5660. ///
  5661. /// @param aengine Engine to use.
  5662. /// @param aalgorithm Convolution algorithm. Possible values are
  5663. /// #dnnl::algorithm::convolution_direct,
  5664. /// #dnnl::algorithm::convolution_winograd, and
  5665. /// #dnnl::algorithm::convolution_auto.
  5666. /// @param src_desc Source memory descriptor.
  5667. /// @param diff_weights_desc Diff weights memory descriptor.
  5668. /// @param diff_bias_desc Diff bias memory descriptor. Passing zero
  5669. /// memory descriptor disables the bias term.
  5670. /// @param diff_dst_desc Diff destination memory descriptor.
  5671. /// @param strides Strides for each spatial dimension.
  5672. /// @param dilates Dilations for each spatial dimension. A zero value
  5673. /// means no dilation in the corresponding dimension.
  5674. /// @param padding_l Vector of padding values for low indices for each
  5675. /// spatial dimension `([[front,] top,] left)`.
  5676. /// @param padding_r Vector of padding values for high indices for
  5677. /// each spatial dimension `([[back,] bottom,] right)`.
  5678. /// @param hint_fwd_pd Primitive descriptor for a convolution
  5679. /// forward propagation primitive. It is used as a hint for
  5680. /// deciding which memory format to use.
  5681. /// @param attr Primitive attributes to use. Attributes are optional
  5682. /// and default to empty attributes.
  5683. /// @param allow_empty A flag signifying whether construction is
  5684. /// allowed to fail without throwing an exception. In this case an
  5685. /// empty object will be produced. This flag is optional and
  5686. /// defaults to false.
  5687. primitive_desc(const engine &aengine, algorithm aalgorithm,
  5688. const memory::desc &src_desc,
  5689. const memory::desc &diff_weights_desc,
  5690. const memory::desc &diff_bias_desc,
  5691. const memory::desc &diff_dst_desc, const memory::dims &strides,
  5692. const memory::dims &dilates, const memory::dims &padding_l,
  5693. const memory::dims &padding_r,
  5694. const convolution_forward::primitive_desc &hint_fwd_pd,
  5695. const primitive_attr &attr = default_attr(),
  5696. bool allow_empty = false)
  5697. : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
  5698. &diff_bias_desc, diff_dst_desc, strides, &dilates,
  5699. padding_l, padding_r, hint_fwd_pd, attr, allow_empty) {}
  5700. /// Constructs a primitive descriptor for a convolution weights
  5701. /// gradient primitive without bias.
  5702. ///
  5703. /// @note
  5704. /// All the memory descriptors may be initialized with the
  5705. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5706. ///
  5707. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
  5708. /// contain values for spatial dimensions only and hence must have the
  5709. /// same number of elements as there are spatial dimensions. The order
  5710. /// of values is the same as in the tensor: depth (for 3D tensors),
  5711. /// height (for 3D and 2D tensors), and width.
  5712. ///
  5713. /// @param aengine Engine to use.
  5714. /// @param aalgorithm Convolution algorithm. Possible values are
  5715. /// #dnnl::algorithm::convolution_direct,
  5716. /// #dnnl::algorithm::convolution_winograd, and
  5717. /// #dnnl::algorithm::convolution_auto.
  5718. /// @param src_desc Source memory descriptor.
  5719. /// @param diff_weights_desc Diff weights memory descriptor.
  5720. /// @param diff_dst_desc Diff destination memory descriptor.
  5721. /// @param strides Strides for each spatial dimension.
  5722. /// @param dilates Dilations for each spatial dimension. A zero value
  5723. /// means no dilation in the corresponding dimension.
  5724. /// @param padding_l Vector of padding values for low indices for each
  5725. /// spatial dimension `([[front,] top,] left)`.
  5726. /// @param padding_r Vector of padding values for high indices for
  5727. /// each spatial dimension `([[back,] bottom,] right)`.
  5728. /// @param hint_fwd_pd Primitive descriptor for a convolution
  5729. /// forward propagation primitive. It is used as a hint for
  5730. /// deciding which memory format to use.
  5731. /// @param attr Primitive attributes to use. Attributes are optional
  5732. /// and default to empty attributes.
  5733. /// @param allow_empty A flag signifying whether construction is
  5734. /// allowed to fail without throwing an exception. In this case an
  5735. /// empty object will be produced. This flag is optional and
  5736. /// defaults to false.
  5737. primitive_desc(const engine &aengine, algorithm aalgorithm,
  5738. const memory::desc &src_desc,
  5739. const memory::desc &diff_weights_desc,
  5740. const memory::desc &diff_dst_desc, const memory::dims &strides,
  5741. const memory::dims &dilates, const memory::dims &padding_l,
  5742. const memory::dims &padding_r,
  5743. const convolution_forward::primitive_desc &hint_fwd_pd,
  5744. const primitive_attr &attr = default_attr(),
  5745. bool allow_empty = false)
  5746. : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
  5747. nullptr, diff_dst_desc, strides, &dilates, padding_l,
  5748. padding_r, hint_fwd_pd, attr, allow_empty) {}
  5749. /// Constructs a primitive descriptor for a convolution weights gradient
  5750. /// primitive from a C API primitive descriptor that must have a
  5751. /// matching kind.
  5752. ///
  5753. /// @param pd C API primitive descriptor for a convolution weights
  5754. /// gradient primitive.
  5755. primitive_desc(dnnl_primitive_desc_t pd)
  5756. : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
  5757. dnnl::prop_kind::backward_weights) {}
  5758. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  5759. memory::desc src_desc() const { return base::src_desc(0); }
  5760. /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
  5761. memory::desc diff_weights_desc() const {
  5762. return base::diff_weights_desc(0);
  5763. }
  5764. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  5765. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  5766. /// Returns the diff bias memory descriptor.
  5767. /// @returns The diff bias memory descriptor.
  5768. /// @returns A zero memory descriptor of the primitive does not have a
  5769. /// diff bias parameter.
  5770. memory::desc diff_bias_desc() const {
  5771. return base::diff_weights_desc(1);
  5772. }
  5773. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  5774. algorithm get_algorithm() const { return base::get_algorithm(); }
  5775. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  5776. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  5777. /// @copydoc dnnl::primitive_desc_base::get_strides()const
  5778. memory::dims get_strides() const { return base::get_strides(); }
  5779. /// @copydoc dnnl::primitive_desc_base::get_dilations()const
  5780. memory::dims get_dilations() const { return base::get_dilations(); }
  5781. /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
  5782. memory::dims get_padding_l() const { return base::get_padding_l(); }
  5783. /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
  5784. memory::dims get_padding_r() const { return base::get_padding_r(); }
  5785. private:
  5786. primitive_desc(const engine &aengine, algorithm aalgorithm,
  5787. const memory::desc &src_desc,
  5788. const memory::desc &diff_weights_desc,
  5789. const memory::desc *diff_bias_desc,
  5790. const memory::desc &diff_dst_desc, const memory::dims &strides,
  5791. const memory::dims *dilates, const memory::dims &padding_l,
  5792. const memory::dims &padding_r,
  5793. const convolution_forward::primitive_desc &hint_fwd_pd,
  5794. const primitive_attr &attr, bool allow_empty) {
  5795. memory::validate_dims(strides, src_desc.get_ndims() - 2);
  5796. memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
  5797. memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
  5798. if (dilates)
  5799. memory::validate_dims(*dilates, src_desc.get_ndims() - 2);
  5800. dnnl_primitive_desc_t pd = nullptr;
  5801. dnnl_status_t status
  5802. = dnnl_convolution_backward_weights_primitive_desc_create(
  5803. &pd, aengine.get(), convert_to_c(aalgorithm),
  5804. src_desc.get(), diff_weights_desc.get(),
  5805. optional_arg(diff_bias_desc), diff_dst_desc.get(),
  5806. &strides[0], optional_arg(dilates), &padding_l[0],
  5807. &padding_r[0], hint_fwd_pd.get(), attr.get());
  5808. if (!allow_empty)
  5809. error::wrap_c_api(status,
  5810. "could not create a primitive descriptor for "
  5811. "the convolution weights update primitive. Run "
  5812. "workload with environment variable ONEDNN_VERBOSE=all "
  5813. "to get additional diagnostic information.");
  5814. reset(pd);
  5815. }
  5816. };
  5817. /// Default constructor. Produces an empty object.
  5818. convolution_backward_weights() = default;
  5819. /// Constructs a convolution weights gradient primitive.
  5820. /// @param pd Primitive descriptor for a convolution weights gradient
  5821. /// primitive.
  5822. convolution_backward_weights(const primitive_desc &pd) : primitive(pd) {}
  5823. /// Constructs a convolution weights gradient primitive from a cache blob.
  5824. /// @param pd Primitive descriptor for a convolution weights gradient
  5825. /// primitive.
  5826. /// @param cache_blob Cache blob.
  5827. convolution_backward_weights(
  5828. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  5829. : primitive(pd, cache_blob) {}
  5830. };
  5831. /// @} dnnl_api_convolution
  5832. //
  5833. /// @addtogroup dnnl_api_deconvolution Deconvolution
  5834. ///
  5835. /// A primitive to perform 1D, 2D or 3D deconvolution. Supported variants are
  5836. /// forward propagation, backward propagation, and weights gradient with or
  5837. /// without bias.
  5838. ///
  5839. /// @{
  5840. /// Deconvolution forward propagation primitive.
  5841. struct deconvolution_forward : public primitive {
  5842. /// Primitive descriptor for a deconvolution forward propagation primitive.
  5843. struct primitive_desc : public dnnl::primitive_desc {
  5844. /// Default constructor. Produces an empty object.
  5845. primitive_desc() = default;
  5846. /// Constructs a primitive descriptor for a deconvolution forward
  5847. /// propagation primitive with bias.
  5848. ///
  5849. /// @note
  5850. /// All the memory descriptors may be initialized with the
  5851. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5852. ///
  5853. /// Arrays @p strides, @p padding_l, and @p padding_r contain values
  5854. /// for spatial dimensions only and hence must have the same number of
  5855. /// elements as there are spatial dimensions. The order of values is
  5856. /// the same as in the tensor: depth (for 3D tensors), height (for 3D
  5857. /// and 2D tensors), and width.
  5858. ///
  5859. /// @param aengine Engine to use.
  5860. /// @param aprop_kind Propagation kind. Possible values are
  5861. /// #dnnl::prop_kind::forward_training, and
  5862. /// #dnnl::prop_kind::forward_inference.
  5863. /// @param aalgorithm Deconvolution algorithm:
  5864. /// #dnnl::algorithm::deconvolution_direct, and
  5865. /// #dnnl::algorithm::deconvolution_winograd.
  5866. /// @param src_desc Source memory descriptor.
  5867. /// @param weights_desc Weights memory descriptor.
  5868. /// @param bias_desc Bias memory descriptor. Passing zero memory
  5869. /// descriptor disables the bias term.
  5870. /// @param dst_desc Destination memory descriptor.
  5871. /// @param strides Vector of strides for spatial dimension.
  5872. /// @param padding_l Vector of padding values for low indices for each
  5873. /// spatial dimension `([[front,] top,] left)`.
  5874. /// @param padding_r Vector of padding values for high indices for
  5875. /// each spatial dimension `([[back,] bottom,] right)`.
  5876. /// @param attr Primitive attributes to use. Attributes are optional
  5877. /// and default to empty attributes.
  5878. /// @param allow_empty A flag signifying whether construction is
  5879. /// allowed to fail without throwing an exception. In this case an
  5880. /// empty object will be produced. This flag is optional and
  5881. /// defaults to false.
  5882. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  5883. algorithm aalgorithm, const memory::desc &src_desc,
  5884. const memory::desc &weights_desc, const memory::desc &bias_desc,
  5885. const memory::desc &dst_desc, const memory::dims &strides,
  5886. const memory::dims &padding_l, const memory::dims &padding_r,
  5887. const primitive_attr &attr = default_attr(),
  5888. bool allow_empty = false)
  5889. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  5890. weights_desc, &bias_desc, dst_desc, strides, nullptr,
  5891. padding_l, padding_r, attr, allow_empty) {}
  5892. /// Constructs a primitive descriptor for a deconvolution forward
  5893. /// propagation primitive without bias.
  5894. ///
  5895. /// @note
  5896. /// All the memory descriptors may be initialized with the
  5897. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5898. ///
  5899. /// Arrays @p strides, @p padding_l, and @p padding_r contain values
  5900. /// for spatial dimensions only and hence must have the same number of
  5901. /// elements as there are spatial dimensions. The order of values is
  5902. /// the same as in the tensor: depth (for 3D tensors), height (for 3D
  5903. /// and 2D tensors), and width.
  5904. ///
  5905. /// @param aengine Engine to use.
  5906. /// @param aprop_kind Propagation kind. Possible values are
  5907. /// #dnnl::prop_kind::forward_training, and
  5908. /// #dnnl::prop_kind::forward_inference.
  5909. /// @param aalgorithm Deconvolution algorithm:
  5910. /// #dnnl::algorithm::deconvolution_direct, and
  5911. /// #dnnl::algorithm::deconvolution_winograd.
  5912. /// @param src_desc Source memory descriptor.
  5913. /// @param weights_desc Weights memory descriptor.
  5914. /// @param dst_desc Destination memory descriptor.
  5915. /// @param strides Vector of strides for spatial dimension.
  5916. /// @param padding_l Vector of padding values for low indices for each
  5917. /// spatial dimension `([[front,] top,] left)`.
  5918. /// @param padding_r Vector of padding values for high indices for
  5919. /// each spatial dimension `([[back,] bottom,] right)`.
  5920. /// @param attr Primitive attributes to use. Attributes are optional
  5921. /// and default to empty attributes.
  5922. /// @param allow_empty A flag signifying whether construction is
  5923. /// allowed to fail without throwing an exception. In this case an
  5924. /// empty object will be produced. This flag is optional and
  5925. /// defaults to false.
  5926. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  5927. algorithm aalgorithm, const memory::desc &src_desc,
  5928. const memory::desc &weights_desc, const memory::desc &dst_desc,
  5929. const memory::dims &strides, const memory::dims &padding_l,
  5930. const memory::dims &padding_r,
  5931. const primitive_attr &attr = default_attr(),
  5932. bool allow_empty = false)
  5933. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  5934. weights_desc, nullptr, dst_desc, strides, nullptr,
  5935. padding_l, padding_r, attr, allow_empty) {}
  5936. /// Constructs a primitive descriptor for a deconvolution forward
  5937. /// propagation primitive with bias.
  5938. ///
  5939. /// @note
  5940. /// All the memory descriptors may be initialized with the
  5941. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5942. ///
  5943. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
  5944. /// contain values for spatial dimensions only and hence must have the
  5945. /// same number of elements as there are spatial dimensions. The order
  5946. /// of values is the same as in the tensor: depth (for 3D tensors),
  5947. /// height (for 3D and 2D tensors), and width.
  5948. ///
  5949. /// @param aengine Engine to use.
  5950. /// @param aprop_kind Propagation kind. Possible values are
  5951. /// #dnnl::prop_kind::forward_training, and
  5952. /// #dnnl::prop_kind::forward_inference.
  5953. /// @param aalgorithm Deconvolution algorithm:
  5954. /// #dnnl::algorithm::deconvolution_direct, and
  5955. /// #dnnl::algorithm::deconvolution_winograd.
  5956. /// @param src_desc Source memory descriptor.
  5957. /// @param weights_desc Weights memory descriptor.
  5958. /// @param bias_desc Bias memory descriptor. Passing zero memory
  5959. /// descriptor disables the bias term.
  5960. /// @param dst_desc Destination memory descriptor.
  5961. /// @param strides Vector of strides for spatial dimension.
  5962. /// @param dilates Dilations for each spatial dimension. A zero value
  5963. /// means no dilation in the corresponding dimension.
  5964. /// @param padding_l Vector of padding values for low indices for each
  5965. /// spatial dimension `([[front,] top,] left)`.
  5966. /// @param padding_r Vector of padding values for high indices for
  5967. /// each spatial dimension `([[back,] bottom,] right)`.
  5968. /// @param attr Primitive attributes to use. Attributes are optional
  5969. /// and default to empty attributes.
  5970. /// @param allow_empty A flag signifying whether construction is
  5971. /// allowed to fail without throwing an exception. In this case an
  5972. /// empty object will be produced. This flag is optional and
  5973. /// defaults to false.
  5974. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  5975. algorithm aalgorithm, const memory::desc &src_desc,
  5976. const memory::desc &weights_desc, const memory::desc &bias_desc,
  5977. const memory::desc &dst_desc, const memory::dims &strides,
  5978. const memory::dims &dilates, const memory::dims &padding_l,
  5979. const memory::dims &padding_r,
  5980. const primitive_attr &attr = default_attr(),
  5981. bool allow_empty = false)
  5982. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  5983. weights_desc, &bias_desc, dst_desc, strides, &dilates,
  5984. padding_l, padding_r, attr, allow_empty) {}
  5985. /// Constructs a primitive descriptor for a deconvolution forward
  5986. /// propagation primitive without bias.
  5987. ///
  5988. /// @note
  5989. /// All the memory descriptors may be initialized with the
  5990. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5991. ///
  5992. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
  5993. /// contain values for spatial dimensions only and hence must have the
  5994. /// same number of elements as there are spatial dimensions. The order
  5995. /// of values is the same as in the tensor: depth (for 3D tensors),
  5996. /// height (for 3D and 2D tensors), and width.
  5997. ///
  5998. /// @param aengine Engine to use.
  5999. /// @param aprop_kind Propagation kind. Possible values are
  6000. /// #dnnl::prop_kind::forward_training, and
  6001. /// #dnnl::prop_kind::forward_inference.
  6002. /// @param aalgorithm Deconvolution algorithm:
  6003. /// #dnnl::algorithm::deconvolution_direct, and
  6004. /// #dnnl::algorithm::deconvolution_winograd.
  6005. /// @param src_desc Source memory descriptor.
  6006. /// @param weights_desc Weights memory descriptor.
  6007. /// @param dst_desc Destination memory descriptor.
  6008. /// @param strides Vector of strides for spatial dimension.
  6009. /// @param dilates Dilations for each spatial dimension. A zero value
  6010. /// means no dilation in the corresponding dimension.
  6011. /// @param padding_l Vector of padding values for low indices for each
  6012. /// spatial dimension `([[front,] top,] left)`.
  6013. /// @param padding_r Vector of padding values for high indices for
  6014. /// each spatial dimension `([[back,] bottom,] right)`.
  6015. /// @param attr Primitive attributes to use. Attributes are optional
  6016. /// and default to empty attributes.
  6017. /// @param allow_empty A flag signifying whether construction is
  6018. /// allowed to fail without throwing an exception. In this case an
  6019. /// empty object will be produced. This flag is optional and
  6020. /// defaults to false.
  6021. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  6022. algorithm aalgorithm, const memory::desc &src_desc,
  6023. const memory::desc &weights_desc, const memory::desc &dst_desc,
  6024. const memory::dims &strides, const memory::dims &dilates,
  6025. const memory::dims &padding_l, const memory::dims &padding_r,
  6026. const primitive_attr &attr = default_attr(),
  6027. bool allow_empty = false)
  6028. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  6029. weights_desc, nullptr, dst_desc, strides, &dilates,
  6030. padding_l, padding_r, attr, allow_empty) {}
  6031. /// Constructs a primitive descriptor for a deconvolution forward
  6032. /// propagation primitive from a C API primitive descriptor that must
  6033. /// have a matching kind.
  6034. ///
  6035. /// @param pd C API primitive descriptor for a deconvolution forward
  6036. /// propagation primitive.
  6037. primitive_desc(dnnl_primitive_desc_t pd)
  6038. : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
  6039. dnnl::prop_kind::forward_training,
  6040. dnnl::prop_kind::forward_inference) {}
  6041. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  6042. memory::desc src_desc() const { return base::src_desc(0); }
  6043. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  6044. memory::desc weights_desc() const { return base::weights_desc(0); }
  6045. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  6046. memory::desc dst_desc() const { return base::dst_desc(0); }
  6047. /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
  6048. memory::desc bias_desc() const { return base::weights_desc(1); }
  6049. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  6050. algorithm get_algorithm() const { return base::get_algorithm(); }
  6051. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  6052. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  6053. /// @copydoc dnnl::primitive_desc_base::get_strides()const
  6054. memory::dims get_strides() const { return base::get_strides(); }
  6055. /// @copydoc dnnl::primitive_desc_base::get_dilations()const
  6056. memory::dims get_dilations() const { return base::get_dilations(); }
  6057. /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
  6058. memory::dims get_padding_l() const { return base::get_padding_l(); }
  6059. /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
  6060. memory::dims get_padding_r() const { return base::get_padding_r(); }
  6061. private:
  6062. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  6063. algorithm aalgorithm, const memory::desc &src_desc,
  6064. const memory::desc &weights_desc, const memory::desc *bias_desc,
  6065. const memory::desc &dst_desc, const memory::dims &strides,
  6066. const memory::dims *dilates, const memory::dims &padding_l,
  6067. const memory::dims &padding_r, const primitive_attr &attr,
  6068. bool allow_empty) {
  6069. memory::validate_dims(strides, src_desc.get_ndims() - 2);
  6070. memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
  6071. memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
  6072. if (dilates)
  6073. memory::validate_dims(*dilates, src_desc.get_ndims() - 2);
  6074. dnnl_primitive_desc_t pd = nullptr;
  6075. dnnl_status_t status
  6076. = dnnl_deconvolution_forward_primitive_desc_create(&pd,
  6077. aengine.get(), dnnl::convert_to_c(aprop_kind),
  6078. convert_to_c(aalgorithm), src_desc.get(),
  6079. weights_desc.get(), optional_arg(bias_desc),
  6080. dst_desc.get(), &strides[0], optional_arg(dilates),
  6081. &padding_l[0], &padding_r[0], attr.get());
  6082. if (!allow_empty)
  6083. error::wrap_c_api(status,
  6084. "could not create a primitive descriptor for "
  6085. "the deconvolution forward propagation primitive. Run "
  6086. "workload with environment variable ONEDNN_VERBOSE=all "
  6087. "to get additional diagnostic information.");
  6088. reset(pd);
  6089. }
  6090. };
  6091. /// Default constructor. Produces an empty object.
  6092. deconvolution_forward() = default;
  6093. /// Constructs a deconvolution forward propagation primitive.
  6094. /// @param pd Primitive descriptor for a deconvolution forward propagation
  6095. /// primitive.
  6096. deconvolution_forward(const primitive_desc &pd) : primitive(pd) {}
  6097. /// Constructs a deconvolution forward propagation primitive from a cache
  6098. /// blob.
  6099. /// @param pd Primitive descriptor for a deconvolution forward propagation
  6100. /// primitive.
  6101. /// @param cache_blob Cache blob.
  6102. deconvolution_forward(
  6103. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  6104. : primitive(pd, cache_blob) {}
  6105. };
  6106. /// Deconvolution backward propagation primitive.
  6107. struct deconvolution_backward_data : public primitive {
  6108. /// Primitive descriptor for a deconvolution backward propagation primitive.
  6109. struct primitive_desc : public dnnl::primitive_desc {
  6110. /// Default constructor. Produces an empty object.
  6111. primitive_desc() = default;
  6112. /// Constructs a primitive descriptor for a deconvolution backward
  6113. /// propagation primitive.
  6114. ///
  6115. /// @note
  6116. /// All the memory descriptors may be initialized with the
  6117. /// #dnnl::memory::format_tag::any value of @p format_tag.
  6118. ///
  6119. /// Arrays @p strides, @p padding_l, and @p padding_r contain values
  6120. /// for spatial dimensions only and hence must have the same number of
  6121. /// elements as there are spatial dimensions. The order of values is
  6122. /// the same as in the tensor: depth (for 3D tensors), height (for 3D
  6123. /// and 2D tensors), and width.
  6124. ///
  6125. /// @param aengine Engine to use.
  6126. /// @param aalgorithm Deconvolution algorithm
  6127. /// (#dnnl::algorithm::convolution_direct,
  6128. /// #dnnl::algorithm::convolution_winograd).
  6129. /// @param diff_src_desc Diff source memory descriptor.
  6130. /// @param weights_desc Weights memory descriptor.
  6131. /// @param diff_dst_desc Diff destination memory descriptor.
  6132. /// @param strides Strides for each spatial dimension.
  6133. /// @param padding_l Vector of padding values for low indices for each
  6134. /// spatial dimension `([[front,] top,] left)`.
  6135. /// @param padding_r Vector of padding values for high indices for
  6136. /// each spatial dimension `([[back,] bottom,] right)`.
  6137. /// @param hint_fwd_pd Primitive descriptor for a deconvolution
  6138. /// forward propagation primitive. It is used as a hint for
  6139. /// deciding which memory format to use.
  6140. /// @param attr Primitive attributes to use. Attributes are optional
  6141. /// and default to empty attributes.
  6142. /// @param allow_empty A flag signifying whether construction is
  6143. /// allowed to fail without throwing an exception. In this case an
  6144. /// empty object will be produced. This flag is optional and
  6145. /// defaults to false.
  6146. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6147. const memory::desc &diff_src_desc,
  6148. const memory::desc &weights_desc,
  6149. const memory::desc &diff_dst_desc, const memory::dims &strides,
  6150. const memory::dims &padding_l, const memory::dims &padding_r,
  6151. const deconvolution_forward::primitive_desc &hint_fwd_pd,
  6152. const primitive_attr &attr = default_attr(),
  6153. bool allow_empty = false)
  6154. : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc,
  6155. diff_dst_desc, strides, nullptr, padding_l, padding_r,
  6156. hint_fwd_pd, attr, allow_empty) {}
  6157. /// Constructs a primitive descriptor for a deconvolution backward
  6158. /// propagation primitive.
  6159. ///
  6160. /// @note
  6161. /// All the memory descriptors may be initialized with the
  6162. /// #dnnl::memory::format_tag::any value of @p format_tag.
  6163. ///
  6164. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
  6165. /// contain values for spatial dimensions only and hence must have the
  6166. /// same number of elements as there are spatial dimensions. The order
  6167. /// of values is the same as in the tensor: depth (for 3D tensors),
  6168. /// height (for 3D and 2D tensors), and width.
  6169. ///
  6170. /// @param aengine Engine to use.
  6171. /// @param aalgorithm Deconvolution algorithm
  6172. /// (#dnnl::algorithm::convolution_direct,
  6173. /// #dnnl::algorithm::convolution_winograd).
  6174. /// @param diff_src_desc Diff source memory descriptor.
  6175. /// @param weights_desc Weights memory descriptor.
  6176. /// @param diff_dst_desc Diff destination memory descriptor.
  6177. /// @param strides Strides for each spatial dimension.
  6178. /// @param dilates Dilations for each spatial dimension. A zero value
  6179. /// means no dilation in the corresponding dimension.
  6180. /// @param padding_l Vector of padding values for low indices for each
  6181. /// spatial dimension `([[front,] top,] left)`.
  6182. /// @param padding_r Vector of padding values for high indices for
  6183. /// each spatial dimension `([[back,] bottom,] right)`.
  6184. /// @param hint_fwd_pd Primitive descriptor for a deconvolution
  6185. /// forward propagation primitive. It is used as a hint for
  6186. /// deciding which memory format to use.
  6187. /// @param attr Primitive attributes to use. Attributes are optional
  6188. /// and default to empty attributes.
  6189. /// @param allow_empty A flag signifying whether construction is
  6190. /// allowed to fail without throwing an exception. In this case an
  6191. /// empty object will be produced. This flag is optional and
  6192. /// defaults to false.
  6193. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6194. const memory::desc &diff_src_desc,
  6195. const memory::desc &weights_desc,
  6196. const memory::desc &diff_dst_desc, const memory::dims &strides,
  6197. const memory::dims &dilates, const memory::dims &padding_l,
  6198. const memory::dims &padding_r,
  6199. const deconvolution_forward::primitive_desc &hint_fwd_pd,
  6200. const primitive_attr &attr = default_attr(),
  6201. bool allow_empty = false)
  6202. : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc,
  6203. diff_dst_desc, strides, &dilates, padding_l, padding_r,
  6204. hint_fwd_pd, attr, allow_empty) {}
  6205. /// Constructs a primitive descriptor for a deconvolution backward
  6206. /// propagation primitive from a C API primitive descriptor that must
  6207. /// have a matching kind.
  6208. ///
  6209. /// @param pd C API primitive descriptor for a deconvolution backward
  6210. /// propagation primitive.
  6211. primitive_desc(dnnl_primitive_desc_t pd)
  6212. : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
  6213. dnnl::prop_kind::backward_data) {}
  6214. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  6215. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  6216. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  6217. memory::desc weights_desc() const { return base::weights_desc(0); }
  6218. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  6219. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  6220. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  6221. algorithm get_algorithm() const { return base::get_algorithm(); }
  6222. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  6223. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  6224. /// @copydoc dnnl::primitive_desc_base::get_strides()const
  6225. memory::dims get_strides() const { return base::get_strides(); }
  6226. /// @copydoc dnnl::primitive_desc_base::get_dilations()const
  6227. memory::dims get_dilations() const { return base::get_dilations(); }
  6228. /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
  6229. memory::dims get_padding_l() const { return base::get_padding_l(); }
  6230. /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
  6231. memory::dims get_padding_r() const { return base::get_padding_r(); }
  6232. private:
  6233. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6234. const memory::desc &diff_src_desc,
  6235. const memory::desc &weights_desc,
  6236. const memory::desc &diff_dst_desc, const memory::dims &strides,
  6237. const memory::dims *dilates, const memory::dims &padding_l,
  6238. const memory::dims &padding_r,
  6239. const deconvolution_forward::primitive_desc &hint_fwd_pd,
  6240. const primitive_attr &attr, bool allow_empty) {
  6241. memory::validate_dims(strides, diff_src_desc.get_ndims() - 2);
  6242. memory::validate_dims(padding_l, diff_src_desc.get_ndims() - 2);
  6243. memory::validate_dims(padding_r, diff_src_desc.get_ndims() - 2);
  6244. if (dilates)
  6245. memory::validate_dims(*dilates, diff_src_desc.get_ndims() - 2);
  6246. dnnl_primitive_desc_t pd = nullptr;
  6247. dnnl_status_t status
  6248. = dnnl_deconvolution_backward_data_primitive_desc_create(
  6249. &pd, aengine.get(), convert_to_c(aalgorithm),
  6250. diff_src_desc.get(), weights_desc.get(),
  6251. diff_dst_desc.get(), &strides[0],
  6252. optional_arg(dilates), &padding_l[0], &padding_r[0],
  6253. hint_fwd_pd.get(), attr.get());
  6254. if (!allow_empty)
  6255. error::wrap_c_api(status,
  6256. "could not create a primitive descriptor for "
  6257. "the deconvolution backward propagation primitive. Run "
  6258. "workload with environment variable ONEDNN_VERBOSE=all "
  6259. "to get additional diagnostic information.");
  6260. reset(pd);
  6261. }
  6262. };
  6263. /// Default constructor. Produces an empty object.
  6264. deconvolution_backward_data() = default;
  6265. /// Constructs a deconvolution backward propagation primitive.
  6266. /// @param pd Primitive descriptor for a deconvolution backward propagation
  6267. /// primitive.
  6268. deconvolution_backward_data(const primitive_desc &pd) : primitive(pd) {}
  6269. /// Constructs a deconvolution backward propagation primitive from a cache
  6270. /// blob.
  6271. /// @param pd Primitive descriptor for a deconvolution backward propagation
  6272. /// primitive.
  6273. /// @param cache_blob Cache blob.
  6274. deconvolution_backward_data(
  6275. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  6276. : primitive(pd, cache_blob) {}
  6277. };
  6278. /// Deconvolution weights gradient primitive.
  6279. struct deconvolution_backward_weights : public primitive {
  6280. /// Primitive descriptor for a deconvolution weights gradient primitive.
  6281. struct primitive_desc : public dnnl::primitive_desc {
  6282. /// Default constructor. Produces an empty object.
  6283. primitive_desc() = default;
  6284. /// Constructs a primitive descriptor for a deconvolution weights
  6285. /// gradient primitive with bias.
  6286. ///
  6287. /// @note
  6288. /// All the memory descriptors may be initialized with the
  6289. /// #dnnl::memory::format_tag::any value of @p format_tag.
  6290. ///
  6291. /// Arrays @p strides, @p padding_l, and @p padding_r contain values
  6292. /// for spatial dimensions only and hence must have the same number of
  6293. /// elements as there are spatial dimensions. The order of values is
  6294. /// the same as in the tensor: depth (for 3D tensors), height (for 3D
  6295. /// and 2D tensors), and width.
  6296. ///
  6297. /// @param aengine Engine to use.
  6298. /// @param aalgorithm Deconvolution algorithm. Possible values are
  6299. /// #dnnl::algorithm::deconvolution_direct, and
  6300. /// #dnnl::algorithm::deconvolution_winograd.
  6301. /// @param src_desc Source memory descriptor.
  6302. /// @param diff_weights_desc Diff weights memory descriptor.
  6303. /// @param diff_bias_desc Diff bias memory descriptor. Passing zero
  6304. /// memory descriptor disables the bias term.
  6305. /// @param diff_dst_desc Diff destination memory descriptor.
  6306. /// @param strides Strides for each spatial dimension.
  6307. /// @param padding_l Vector of padding values for low indices for each
  6308. /// spatial dimension `([[front,] top,] left)`.
  6309. /// @param padding_r Vector of padding values for high indices for
  6310. /// each spatial dimension `([[back,] bottom,] right)`.
  6311. /// @param hint_fwd_pd Primitive descriptor for a deconvolution
  6312. /// forward propagation primitive. It is used as a hint for
  6313. /// deciding which memory format to use.
  6314. /// @param attr Primitive attributes to use. Attributes are optional
  6315. /// and default to empty attributes.
  6316. /// @param allow_empty A flag signifying whether construction is
  6317. /// allowed to fail without throwing an exception. In this case an
  6318. /// empty object will be produced. This flag is optional and
  6319. /// defaults to false.
  6320. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6321. const memory::desc &src_desc,
  6322. const memory::desc &diff_weights_desc,
  6323. const memory::desc &diff_bias_desc,
  6324. const memory::desc &diff_dst_desc, const memory::dims &strides,
  6325. const memory::dims &padding_l, const memory::dims &padding_r,
  6326. const deconvolution_forward::primitive_desc &hint_fwd_pd,
  6327. const primitive_attr &attr = default_attr(),
  6328. bool allow_empty = false)
  6329. : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
  6330. &diff_bias_desc, diff_dst_desc, strides, nullptr, padding_l,
  6331. padding_r, hint_fwd_pd, attr, allow_empty) {}
  6332. /// Constructs a primitive descriptor for a deconvolution weights
  6333. /// gradient primitive without bias.
  6334. ///
  6335. /// @note
  6336. /// All the memory descriptors may be initialized with the
  6337. /// #dnnl::memory::format_tag::any value of @p format_tag.
  6338. ///
  6339. /// Arrays @p strides, @p padding_l, and @p padding_r contain values
  6340. /// for spatial dimensions only and hence must have the same number of
  6341. /// elements as there are spatial dimensions. The order of values is
  6342. /// the same as in the tensor: depth (for 3D tensors), height (for 3D
  6343. /// and 2D tensors), and width.
  6344. ///
  6345. /// @param aengine Engine to use.
  6346. /// @param aalgorithm Deconvolution algorithm. Possible values are
  6347. /// #dnnl::algorithm::deconvolution_direct, and
  6348. /// #dnnl::algorithm::deconvolution_winograd.
  6349. /// @param src_desc Source memory descriptor.
  6350. /// @param diff_weights_desc Diff weights memory descriptor.
  6351. /// @param diff_dst_desc Diff destination memory descriptor.
  6352. /// @param strides Strides for each spatial dimension.
  6353. /// @param padding_l Vector of padding values for low indices for each
  6354. /// spatial dimension `([[front,] top,] left)`.
  6355. /// @param padding_r Vector of padding values for high indices for
  6356. /// each spatial dimension `([[back,] bottom,] right)`.
  6357. /// @param hint_fwd_pd Primitive descriptor for a deconvolution
  6358. /// forward propagation primitive. It is used as a hint for
  6359. /// deciding which memory format to use.
  6360. /// @param attr Primitive attributes to use. Attributes are optional
  6361. /// and default to empty attributes.
  6362. /// @param allow_empty A flag signifying whether construction is
  6363. /// allowed to fail without throwing an exception. In this case an
  6364. /// empty object will be produced. This flag is optional and
  6365. /// defaults to false.
  6366. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6367. const memory::desc &src_desc,
  6368. const memory::desc &diff_weights_desc,
  6369. const memory::desc &diff_dst_desc, const memory::dims &strides,
  6370. const memory::dims &padding_l, const memory::dims &padding_r,
  6371. const deconvolution_forward::primitive_desc &hint_fwd_pd,
  6372. const primitive_attr &attr = default_attr(),
  6373. bool allow_empty = false)
  6374. : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
  6375. nullptr, diff_dst_desc, strides, nullptr, padding_l,
  6376. padding_r, hint_fwd_pd, attr, allow_empty) {}
  6377. /// Constructs a primitive descriptor for a deconvolution weights
  6378. /// gradient primitive with bias.
  6379. ///
  6380. /// @note
  6381. /// All the memory descriptors may be initialized with the
  6382. /// #dnnl::memory::format_tag::any value of @p format_tag.
  6383. ///
  6384. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
  6385. /// contain values for spatial dimensions only and hence must have the
  6386. /// same number of elements as there are spatial dimensions. The order
  6387. /// of values is the same as in the tensor: depth (for 3D tensors),
  6388. /// height (for 3D and 2D tensors), and width.
  6389. ///
  6390. /// @param aengine Engine to use.
  6391. /// @param aalgorithm Deconvolution algorithm. Possible values are
  6392. /// #dnnl::algorithm::deconvolution_direct, and
  6393. /// #dnnl::algorithm::deconvolution_winograd.
  6394. /// @param src_desc Source memory descriptor.
  6395. /// @param diff_weights_desc Diff weights memory descriptor.
  6396. /// @param diff_bias_desc Diff bias memory descriptor. Passing zero
  6397. /// memory descriptor disables the bias term.
  6398. /// @param diff_dst_desc Diff destination memory descriptor.
  6399. /// @param strides Strides for each spatial dimension.
  6400. /// @param dilates Dilations for each spatial dimension. A zero value
  6401. /// means no dilation in the corresponding dimension.
  6402. /// @param padding_l Vector of padding values for low indices for each
  6403. /// spatial dimension `([[front,] top,] left)`.
  6404. /// @param padding_r Vector of padding values for high indices for
  6405. /// each spatial dimension `([[back,] bottom,] right)`.
  6406. /// @param hint_fwd_pd Primitive descriptor for a deconvolution
  6407. /// forward propagation primitive. It is used as a hint for
  6408. /// deciding which memory format to use.
  6409. /// @param attr Primitive attributes to use. Attributes are optional
  6410. /// and default to empty attributes.
  6411. /// @param allow_empty A flag signifying whether construction is
  6412. /// allowed to fail without throwing an exception. In this case an
  6413. /// empty object will be produced. This flag is optional and
  6414. /// defaults to false.
  6415. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6416. const memory::desc &src_desc,
  6417. const memory::desc &diff_weights_desc,
  6418. const memory::desc &diff_bias_desc,
  6419. const memory::desc &diff_dst_desc, const memory::dims &strides,
  6420. const memory::dims &dilates, const memory::dims &padding_l,
  6421. const memory::dims &padding_r,
  6422. const deconvolution_forward::primitive_desc &hint_fwd_pd,
  6423. const primitive_attr &attr = default_attr(),
  6424. bool allow_empty = false)
  6425. : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
  6426. &diff_bias_desc, diff_dst_desc, strides, &dilates,
  6427. padding_l, padding_r, hint_fwd_pd, attr, allow_empty) {}
  6428. /// Constructs a primitive descriptor for a deconvolution weights
  6429. /// gradient primitive without bias.
  6430. ///
  6431. /// @note
  6432. /// All the memory descriptors may be initialized with the
  6433. /// #dnnl::memory::format_tag::any value of @p format_tag.
  6434. ///
  6435. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
  6436. /// contain values for spatial dimensions only and hence must have the
  6437. /// same number of elements as there are spatial dimensions. The order
  6438. /// of values is the same as in the tensor: depth (for 3D tensors),
  6439. /// height (for 3D and 2D tensors), and width.
  6440. ///
  6441. /// @param aengine Engine to use.
  6442. /// @param aalgorithm Deconvolution algorithm. Possible values are
  6443. /// #dnnl::algorithm::deconvolution_direct, and
  6444. /// #dnnl::algorithm::deconvolution_winograd.
  6445. /// @param src_desc Source memory descriptor.
  6446. /// @param diff_weights_desc Diff weights memory descriptor.
  6447. /// @param diff_dst_desc Diff destination memory descriptor.
  6448. /// @param strides Strides for each spatial dimension.
  6449. /// @param dilates Dilations for each spatial dimension. A zero value
  6450. /// means no dilation in the corresponding dimension.
  6451. /// @param padding_l Vector of padding values for low indices for each
  6452. /// spatial dimension `([[front,] top,] left)`.
  6453. /// @param padding_r Vector of padding values for high indices for
  6454. /// each spatial dimension `([[back,] bottom,] right)`.
  6455. /// @param hint_fwd_pd Primitive descriptor for a deconvolution
  6456. /// forward propagation primitive. It is used as a hint for
  6457. /// deciding which memory format to use.
  6458. /// @param attr Primitive attributes to use. Attributes are optional
  6459. /// and default to empty attributes.
  6460. /// @param allow_empty A flag signifying whether construction is
  6461. /// allowed to fail without throwing an exception. In this case an
  6462. /// empty object will be produced. This flag is optional and
  6463. /// defaults to false.
  6464. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6465. const memory::desc &src_desc,
  6466. const memory::desc &diff_weights_desc,
  6467. const memory::desc &diff_dst_desc, const memory::dims &strides,
  6468. const memory::dims &dilates, const memory::dims &padding_l,
  6469. const memory::dims &padding_r,
  6470. const deconvolution_forward::primitive_desc &hint_fwd_pd,
  6471. const primitive_attr &attr = default_attr(),
  6472. bool allow_empty = false)
  6473. : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
  6474. nullptr, diff_dst_desc, strides, &dilates, padding_l,
  6475. padding_r, hint_fwd_pd, attr, allow_empty) {}
  6476. /// Constructs a primitive descriptor for a deconvolution weights
  6477. /// gradient primitive from a C API primitive descriptor that must
  6478. /// have a matching kind.
  6479. ///
  6480. /// @param pd C API primitive descriptor for a deconvolution weights
  6481. /// gradient primitive.
  6482. primitive_desc(dnnl_primitive_desc_t pd)
  6483. : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
  6484. dnnl::prop_kind::backward_weights) {}
  6485. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  6486. memory::desc src_desc() const { return base::src_desc(0); }
  6487. /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
  6488. memory::desc diff_weights_desc() const {
  6489. return base::diff_weights_desc(0);
  6490. }
  6491. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  6492. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  6493. /// @copydoc dnnl::convolution_backward_weights::primitive_desc::diff_bias_desc()const
  6494. memory::desc diff_bias_desc() const {
  6495. return base::diff_weights_desc(1);
  6496. }
  6497. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  6498. algorithm get_algorithm() const { return base::get_algorithm(); }
  6499. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  6500. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  6501. /// @copydoc dnnl::primitive_desc_base::get_strides()const
  6502. memory::dims get_strides() const { return base::get_strides(); }
  6503. /// @copydoc dnnl::primitive_desc_base::get_dilations()const
  6504. memory::dims get_dilations() const { return base::get_dilations(); }
  6505. /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
  6506. memory::dims get_padding_l() const { return base::get_padding_l(); }
  6507. /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
  6508. memory::dims get_padding_r() const { return base::get_padding_r(); }
  6509. private:
  6510. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6511. const memory::desc &src_desc,
  6512. const memory::desc &diff_weights_desc,
  6513. const memory::desc *diff_bias_desc,
  6514. const memory::desc &diff_dst_desc, const memory::dims &strides,
  6515. const memory::dims *dilates, const memory::dims &padding_l,
  6516. const memory::dims &padding_r,
  6517. const deconvolution_forward::primitive_desc &hint_fwd_pd,
  6518. const primitive_attr &attr, bool allow_empty) {
  6519. memory::validate_dims(strides, src_desc.get_ndims() - 2);
  6520. memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
  6521. memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
  6522. if (dilates)
  6523. memory::validate_dims(*dilates, src_desc.get_ndims() - 2);
  6524. dnnl_primitive_desc_t pd = nullptr;
  6525. dnnl_status_t status
  6526. = dnnl_deconvolution_backward_weights_primitive_desc_create(
  6527. &pd, aengine.get(), convert_to_c(aalgorithm),
  6528. src_desc.get(), diff_weights_desc.get(),
  6529. optional_arg(diff_bias_desc), diff_dst_desc.get(),
  6530. &strides[0], optional_arg(dilates), &padding_l[0],
  6531. &padding_r[0], hint_fwd_pd.get(), attr.get());
  6532. if (!allow_empty)
  6533. error::wrap_c_api(status,
  6534. "could not create a primitive descriptor for "
  6535. "the deconvolution weights update primitive. Run "
  6536. "workload with environment variable ONEDNN_VERBOSE=all "
  6537. "to get additional diagnostic information.");
  6538. reset(pd);
  6539. }
  6540. };
  6541. /// Default constructor. Produces an empty object.
  6542. deconvolution_backward_weights() = default;
  6543. /// Constructs a deconvolution weights gradient primitive.
  6544. /// @param pd Primitive descriptor for a deconvolution weights gradient
  6545. /// primitive.
  6546. deconvolution_backward_weights(const primitive_desc &pd) : primitive(pd) {}
  6547. /// Constructs a deconvolution weights gradient primitive from a cache
  6548. /// blob.
  6549. /// @param pd Primitive descriptor for a deconvolution weights gradient
  6550. /// primitive.
  6551. /// @param cache_blob Cache blob.
  6552. deconvolution_backward_weights(
  6553. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  6554. : primitive(pd, cache_blob) {}
  6555. };
  6556. /// @} dnnl_api_deconvolution
  6557. /// @addtogroup dnnl_api_lrn LRN
  6558. ///
  6559. /// A primitive to perform local response normalization (LRN) across or within
  6560. /// channels.
  6561. ///
  6562. /// @sa @ref dev_guide_lrn in developer guide
  6563. ///
  6564. /// @{
  6565. /// Local response normalization (LRN) forward propagation primitive.
  6566. struct lrn_forward : public primitive {
  6567. /// Primitive descriptor for an LRN forward propagation primitive.
  6568. struct primitive_desc : public dnnl::primitive_desc {
  6569. /// Default constructor. Produces an empty object.
  6570. primitive_desc() = default;
  6571. /// Constructs a primitive descriptor for an LRN forward propagation
  6572. /// primitive.
  6573. ///
  6574. /// @param aengine Engine to use.
  6575. /// @param aprop_kind Propagation kind. Possible values are
  6576. /// #dnnl::prop_kind::forward_training, and
  6577. /// #dnnl::prop_kind::forward_inference.
  6578. /// @param aalgorithm LRN algorithm kind: either
  6579. /// #dnnl::algorithm::lrn_across_channels, or
  6580. /// #dnnl::algorithm::lrn_within_channel.
  6581. /// @param src_desc Source memory descriptor.
  6582. /// @param dst_desc Destination memory descriptor.
  6583. /// @param local_size Regularization local size.
  6584. /// @param alpha The alpha regularization parameter.
  6585. /// @param beta The beta regularization parameter.
  6586. /// @param k The k regularization parameter.
  6587. /// @param attr Primitive attributes to use. Attributes are optional
  6588. /// and default to empty attributes.
  6589. /// @param allow_empty A flag signifying whether construction is
  6590. /// allowed to fail without throwing an exception. In this case an
  6591. /// empty object will be produced. This flag is optional and
  6592. /// defaults to false.
  6593. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  6594. algorithm aalgorithm, const memory::desc &src_desc,
  6595. const memory::desc &dst_desc, memory::dim local_size,
  6596. float alpha, float beta, float k,
  6597. const primitive_attr &attr = default_attr(),
  6598. bool allow_empty = false) {
  6599. dnnl_primitive_desc_t pd = nullptr;
  6600. dnnl_status_t status = dnnl_lrn_forward_primitive_desc_create(&pd,
  6601. aengine.get(), dnnl::convert_to_c(aprop_kind),
  6602. convert_to_c(aalgorithm), src_desc.get(), dst_desc.get(),
  6603. local_size, alpha, beta, k, attr.get());
  6604. if (!allow_empty)
  6605. error::wrap_c_api(status,
  6606. "could not create a primitive descriptor for "
  6607. "the lrn forward propagation primitive. Run workload "
  6608. "with environment variable ONEDNN_VERBOSE=all to get "
  6609. "additional diagnostic information.");
  6610. reset(pd);
  6611. }
  6612. /// Constructs a primitive descriptor for an LRN forward propagation
  6613. /// primitive from a C API primitive descriptor that must have a
  6614. /// matching kind.
  6615. ///
  6616. /// @param pd C API primitive descriptor for an LRN forward
  6617. /// propagation primitive.
  6618. primitive_desc(dnnl_primitive_desc_t pd)
  6619. : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
  6620. dnnl::prop_kind::forward_training,
  6621. dnnl::prop_kind::forward_inference) {}
  6622. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  6623. memory::desc src_desc() const { return base::src_desc(0); }
  6624. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  6625. memory::desc dst_desc() const { return base::dst_desc(0); }
  6626. /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
  6627. memory::desc workspace_desc() const { return base::workspace_desc(); }
  6628. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  6629. algorithm get_algorithm() const { return base::get_algorithm(); }
  6630. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  6631. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  6632. /// @copydoc dnnl::primitive_desc_base::get_alpha()const
  6633. float get_alpha() const { return base::get_alpha(); }
  6634. /// @copydoc dnnl::primitive_desc_base::get_beta()const
  6635. float get_beta() const { return base::get_beta(); }
  6636. /// @copydoc dnnl::primitive_desc_base::get_local_size()const
  6637. memory::dim get_local_size() const { return base::get_local_size(); }
  6638. /// @copydoc dnnl::primitive_desc_base::get_k()const
  6639. float get_k() const { return base::get_k(); }
  6640. };
  6641. /// Default constructor. Produces an empty object.
  6642. lrn_forward() = default;
  6643. /// Constructs an LRN forward propagation primitive.
  6644. /// @param pd Primitive descriptor for an LRN forward propagation
  6645. /// primitive.
  6646. lrn_forward(const primitive_desc &pd) : primitive(pd) {}
  6647. /// Constructs an LRN forward propagation primitive from a cache blob.
  6648. /// @param pd Primitive descriptor for an LRN forward propagation
  6649. /// primitive.
  6650. /// @param cache_blob Cache blob.
  6651. lrn_forward(
  6652. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  6653. : primitive(pd, cache_blob) {}
  6654. };
  6655. /// Local response normalization (LRN) backward propagation primitive.
  6656. struct lrn_backward : public primitive {
  6657. /// Primitive descriptor for an LRN backward propagation primitive.
  6658. struct primitive_desc : public dnnl::primitive_desc {
  6659. /// Default constructor. Produces an empty object.
  6660. primitive_desc() = default;
  6661. /// Constructs a primitive descriptor for an LRN backward propagation
  6662. /// primitive.
  6663. ///
  6664. /// @param aengine Engine to use.
  6665. /// @param aalgorithm LRN algorithm kind: either
  6666. /// #dnnl::algorithm::lrn_across_channels, or
  6667. /// #dnnl::algorithm::lrn_within_channel.
  6668. /// @param diff_src_desc Diff source memory descriptor.
  6669. /// @param diff_dst_desc Diff destination memory descriptor.
  6670. /// @param src_desc Source memory descriptor.
  6671. /// @param local_size Regularization local size.
  6672. /// @param alpha The alpha regularization parameter.
  6673. /// @param beta The beta regularization parameter.
  6674. /// @param k The k regularization parameter.
  6675. /// @param hint_fwd_pd Primitive descriptor for an LRN forward
  6676. /// propagation primitive. It is used as a hint for deciding which
  6677. /// memory format to use.
  6678. /// @param attr Primitive attributes to use. Attributes are optional
  6679. /// and default to empty attributes.
  6680. /// @param allow_empty A flag signifying whether construction is
  6681. /// allowed to fail without throwing an exception. In this case an
  6682. /// empty object will be produced. This flag is optional and
  6683. /// defaults to false.
  6684. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6685. const memory::desc &diff_src_desc,
  6686. const memory::desc &diff_dst_desc, const memory::desc &src_desc,
  6687. memory::dim local_size, float alpha, float beta, float k,
  6688. const lrn_forward::primitive_desc &hint_fwd_pd,
  6689. const primitive_attr &attr = default_attr(),
  6690. bool allow_empty = false) {
  6691. dnnl_primitive_desc_t pd = nullptr;
  6692. dnnl_status_t status = dnnl_lrn_backward_primitive_desc_create(&pd,
  6693. aengine.get(), convert_to_c(aalgorithm),
  6694. diff_src_desc.get(), diff_dst_desc.get(), src_desc.get(),
  6695. local_size, alpha, beta, k, hint_fwd_pd.get(), attr.get());
  6696. if (!allow_empty)
  6697. error::wrap_c_api(status,
  6698. "could not create a primitive descriptor for "
  6699. "the lrn backward propagation primitive. Run workload "
  6700. "with environment variable ONEDNN_VERBOSE=all to get "
  6701. "additional diagnostic information.");
  6702. reset(pd);
  6703. }
  6704. /// Constructs a primitive descriptor for an LRN backward propagation
  6705. /// primitive from a C API primitive descriptor that must have a
  6706. /// matching kind.
  6707. ///
  6708. /// @param pd C API primitive descriptor for an LRN backward
  6709. /// propagation primitive.
  6710. primitive_desc(dnnl_primitive_desc_t pd)
  6711. : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
  6712. dnnl::prop_kind::backward_data) {}
  6713. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  6714. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  6715. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  6716. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  6717. /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
  6718. memory::desc workspace_desc() const { return base::workspace_desc(); }
  6719. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  6720. algorithm get_algorithm() const { return base::get_algorithm(); }
  6721. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  6722. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  6723. /// @copydoc dnnl::primitive_desc_base::get_alpha()const
  6724. float get_alpha() const { return base::get_alpha(); }
  6725. /// @copydoc dnnl::primitive_desc_base::get_beta()const
  6726. float get_beta() const { return base::get_beta(); }
  6727. /// @copydoc dnnl::primitive_desc_base::get_local_size()const
  6728. memory::dim get_local_size() const { return base::get_local_size(); }
  6729. /// @copydoc dnnl::primitive_desc_base::get_k()const
  6730. float get_k() const { return base::get_k(); }
  6731. };
  6732. /// Default constructor. Produces an empty object.
  6733. lrn_backward() = default;
  6734. /// Constructs an LRN backward propagation primitive.
  6735. /// @param pd Primitive descriptor for an LRN backward propagation
  6736. /// primitive.
  6737. lrn_backward(const primitive_desc &pd) : primitive(pd) {}
  6738. /// Constructs an LRN backward propagation primitive from a cache blob.
  6739. /// @param pd Primitive descriptor for an LRN backward propagation
  6740. /// primitive.
  6741. /// @param cache_blob Cache blob.
  6742. lrn_backward(
  6743. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  6744. : primitive(pd, cache_blob) {}
  6745. };
  6746. /// @} dnnl_api_lrn
  6747. /// @addtogroup dnnl_api_eltwise Eltwise
  6748. ///
  /// A primitive to perform elementwise operations such as the
  /// rectified linear unit (ReLU).
  6751. ///
  6752. /// Both forward and backward propagation primitives support in-place
  6753. /// operation; that is, src and dst can refer to the same memory for forward
  6754. /// propagation, and diff_dst and diff_src can refer to the same memory for
  6755. /// backward propagation.
  6756. ///
  6757. /// @warning
  6758. /// Because the original source data is required for backward propagation,
  6759. /// in-place forward propagation is not generally supported in the
  6760. /// training mode. However, for algorithms supporting destination as input
  6761. /// memory, dst can be used for the backward propagation, which makes it
  6762. /// possible to get performance benefit even in the training mode.
  6763. ///
  6764. /// @sa @ref dev_guide_eltwise in developer guide
  6765. ///
  6766. /// @{
  6767. /// Elementwise unary operation forward propagation primitive.
  6768. struct eltwise_forward : public primitive {
  6769. /// Primitive descriptor for an elementwise forward propagation primitive.
  6770. struct primitive_desc : public dnnl::primitive_desc {
  6771. /// Default constructor. Produces an empty object.
  6772. primitive_desc() = default;
  6773. /// Constructs a primitive descriptor for an elementwise forward
  6774. /// propagation primitive.
  6775. ///
  6776. /// @param aengine Engine to use.
  6777. /// @param aprop_kind Propagation kind. Possible values are
  6778. /// #dnnl::prop_kind::forward_training, and
  6779. /// #dnnl::prop_kind::forward_inference.
  6780. /// @param aalgorithm Elementwise algorithm kind.
  6781. /// @param src_desc Source memory descriptor.
  6782. /// @param dst_desc Destination memory descriptor.
  6783. /// @param attr Primitive attributes to use. Attributes are optional
  6784. /// and default to empty attributes.
  6785. /// @param allow_empty A flag signifying whether construction is
  6786. /// allowed to fail without throwing an exception. In this case an
  6787. /// empty object will be produced. This flag is optional and
  6788. /// defaults to false.
  6789. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  6790. algorithm aalgorithm, const memory::desc &src_desc,
  6791. const memory::desc &dst_desc,
  6792. const primitive_attr &attr = default_attr(),
  6793. bool allow_empty = false)
  6794. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  6795. dst_desc, nullptr, nullptr, attr, allow_empty) {}
  6796. /// Constructs a primitive descriptor for an elementwise forward
  6797. /// propagation primitive with an alpha parameter.
  6798. ///
  6799. /// @param aengine Engine to use.
  6800. /// @param aprop_kind Propagation kind. Possible values are
  6801. /// #dnnl::prop_kind::forward_training, and
  6802. /// #dnnl::prop_kind::forward_inference.
  6803. /// @param aalgorithm Elementwise algorithm kind.
  6804. /// @param src_desc Source memory descriptor.
  6805. /// @param dst_desc Destination memory descriptor.
  6806. /// @param alpha The alpha parameter for the elementwise operation.
  6807. /// Specific meaning depends on the algorithm.
  6808. /// @param attr Primitive attributes to use. Attributes are optional
  6809. /// and default to empty attributes.
  6810. /// @param allow_empty A flag signifying whether construction is
  6811. /// allowed to fail without throwing an exception. In this case an
  6812. /// empty object will be produced. This flag is optional and
  6813. /// defaults to false.
  6814. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  6815. algorithm aalgorithm, const memory::desc &src_desc,
  6816. const memory::desc &dst_desc, float alpha,
  6817. const primitive_attr &attr = default_attr(),
  6818. bool allow_empty = false)
  6819. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  6820. dst_desc, &alpha, nullptr, attr, allow_empty) {}
  6821. /// Constructs a primitive descriptor for an elementwise forward
  6822. /// propagation primitive with an alpha and beta parameters.
  6823. ///
  6824. /// @param aengine Engine to use.
  6825. /// @param aprop_kind Propagation kind. Possible values are
  6826. /// #dnnl::prop_kind::forward_training, and
  6827. /// #dnnl::prop_kind::forward_inference.
  6828. /// @param aalgorithm Elementwise algorithm kind.
  6829. /// @param src_desc Source memory descriptor.
  6830. /// @param dst_desc Destination memory descriptor.
  6831. /// @param alpha The alpha parameter for the elementwise operation.
  6832. /// Specific meaning depends on the algorithm.
  6833. /// @param beta The beta parameter for the elementwise operation.
  6834. /// Specific meaning depends on the algorithm.
  6835. /// @param attr Primitive attributes to use. Attributes are optional
  6836. /// and default to empty attributes.
  6837. /// @param allow_empty A flag signifying whether construction is
  6838. /// allowed to fail without throwing an exception. In this case an
  6839. /// empty object will be produced. This flag is optional and
  6840. /// defaults to false.
  6841. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  6842. algorithm aalgorithm, const memory::desc &src_desc,
  6843. const memory::desc &dst_desc, float alpha, float beta,
  6844. const primitive_attr &attr = default_attr(),
  6845. bool allow_empty = false)
  6846. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  6847. dst_desc, &alpha, &beta, attr, allow_empty) {}
  6848. /// Constructs a primitive descriptor for an eltwise forward
  6849. /// propagation primitive from a C API primitive descriptor that must
  6850. /// have a matching kind.
  6851. ///
  6852. /// @param pd C API primitive descriptor for an eltwise forward
  6853. /// propagation primitive.
  6854. primitive_desc(dnnl_primitive_desc_t pd)
  6855. : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
  6856. dnnl::prop_kind::forward_training,
  6857. dnnl::prop_kind::forward_inference) {}
  6858. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  6859. memory::desc src_desc() const { return base::src_desc(0); }
  6860. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  6861. memory::desc dst_desc() const { return base::dst_desc(0); }
  6862. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  6863. dnnl::algorithm get_algorithm() const { return base::get_algorithm(); }
  6864. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  6865. dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  6866. /// @copydoc dnnl::primitive_desc_base::get_alpha()const
  6867. float get_alpha() const { return base::get_alpha(); }
  6868. /// @copydoc dnnl::primitive_desc_base::get_beta()const
  6869. float get_beta() const { return base::get_beta(); }
  6870. private:
  6871. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  6872. algorithm aalgorithm, const memory::desc &src_desc,
  6873. const memory::desc &dst_desc, const float *alpha,
  6874. const float *beta, const primitive_attr &attr,
  6875. bool allow_empty) {
  6876. dnnl_primitive_desc_t pd = nullptr;
  6877. dnnl_status_t status = dnnl_eltwise_forward_primitive_desc_create(
  6878. &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
  6879. dnnl::convert_to_c(aalgorithm), src_desc.get(),
  6880. dst_desc.get(), alpha ? *alpha : 0.0f, beta ? *beta : 0.0f,
  6881. attr.get());
  6882. if (!allow_empty)
  6883. error::wrap_c_api(status,
  6884. "could not create a primitive descriptor for "
  6885. "the eltwise forward propagation primitive. Run "
  6886. "workload with environment variable ONEDNN_VERBOSE=all "
  6887. "to get additional diagnostic information.");
  6888. reset(pd);
  6889. }
  6890. };
  6891. /// Default constructor. Produces an empty object.
  6892. eltwise_forward() = default;
  6893. /// Constructs an eltwise forward propagation primitive.
  6894. /// @param pd Primitive descriptor for an eltwise forward propagation
  6895. /// primitive.
  6896. eltwise_forward(const primitive_desc &pd) : primitive(pd) {}
  6897. /// Constructs an eltwise forward propagation primitive from a cache blob.
  6898. /// @param pd Primitive descriptor for an eltwise forward propagation
  6899. /// primitive.
  6900. /// @param cache_blob Cache blob.
  6901. eltwise_forward(
  6902. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  6903. : primitive(pd, cache_blob) {}
  6904. };
  6905. /// Elementwise unary operation backward propagation primitive.
  6906. struct eltwise_backward : public primitive {
  6907. /// Primitive descriptor for eltwise backward propagation.
  6908. struct primitive_desc : public dnnl::primitive_desc {
  6909. /// Default constructor. Produces an empty object.
  6910. primitive_desc() = default;
  6911. /// Constructs a primitive descriptor for an elementwise backward
  6912. /// propagation primitive with an alpha parameter.
  6913. ///
  6914. /// @param aengine Engine to use.
  6915. /// @param aalgorithm Elementwise algorithm kind.
  6916. /// @param diff_src_desc Diff source memory descriptor.
  6917. /// @param diff_dst_desc Diff destination memory descriptor.
  6918. /// @param data_desc Destination memory descriptor if one of the
  6919. /// "use_dst_for_bwd" algorithms are used (such as
  6920. /// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor
  6921. /// otherwise.
  6922. /// @param hint_fwd_pd Primitive descriptor for an elementwise
  6923. /// forward propagation primitive. It is used as a hint for
  6924. /// deciding which memory format to use.
  6925. /// @param attr Primitive attributes to use. Attributes are optional
  6926. /// and default to empty attributes.
  6927. /// @param allow_empty A flag signifying whether construction is
  6928. /// allowed to fail without throwing an exception. In this case an
  6929. /// empty object will be produced. This flag is optional and
  6930. /// defaults to false.
  6931. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6932. const memory::desc &diff_src_desc,
  6933. const memory::desc &diff_dst_desc,
  6934. const memory::desc &data_desc,
  6935. const eltwise_forward::primitive_desc &hint_fwd_pd,
  6936. const primitive_attr &attr = default_attr(),
  6937. bool allow_empty = false)
  6938. : primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc,
  6939. data_desc, nullptr, nullptr, hint_fwd_pd, attr,
  6940. allow_empty) {}
  6941. /// Constructs a primitive descriptor for an elementwise backward
  6942. /// propagation primitive with an alpha parameter.
  6943. ///
  6944. /// @param aengine Engine to use.
  6945. /// @param aalgorithm Elementwise algorithm kind.
  6946. /// @param diff_src_desc Diff source memory descriptor.
  6947. /// @param diff_dst_desc Diff destination memory descriptor.
  6948. /// @param data_desc Destination memory descriptor if one of the
  6949. /// "use_dst_for_bwd" algorithms are used (such as
  6950. /// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor
  6951. /// otherwise.
  6952. /// @param alpha The alpha parameter for the elementwise operation.
  6953. /// Specific meaning depends on the algorithm.
  6954. /// @param hint_fwd_pd Primitive descriptor for an elementwise
  6955. /// forward propagation primitive. It is used as a hint for
  6956. /// deciding which memory format to use.
  6957. /// @param attr Primitive attributes to use. Attributes are optional
  6958. /// and default to empty attributes.
  6959. /// @param allow_empty A flag signifying whether construction is
  6960. /// allowed to fail without throwing an exception. In this case an
  6961. /// empty object will be produced. This flag is optional and
  6962. /// defaults to false.
  6963. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6964. const memory::desc &diff_src_desc,
  6965. const memory::desc &diff_dst_desc,
  6966. const memory::desc &data_desc, float alpha,
  6967. const eltwise_forward::primitive_desc &hint_fwd_pd,
  6968. const primitive_attr &attr = default_attr(),
  6969. bool allow_empty = false)
  6970. : primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc,
  6971. data_desc, &alpha, nullptr, hint_fwd_pd, attr,
  6972. allow_empty) {}
  6973. /// Constructs a primitive descriptor for an elementwise backward
  6974. /// propagation primitive with an alpha and beta parameters.
  6975. ///
  6976. /// @param aengine Engine to use.
  6977. /// @param aalgorithm Elementwise algorithm kind.
  6978. /// @param diff_src_desc Diff source memory descriptor.
  6979. /// @param diff_dst_desc Diff destination memory descriptor.
  6980. /// @param data_desc Destination memory descriptor if one of the
  6981. /// "use_dst_for_bwd" algorithms are used (such as
  6982. /// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor
  6983. /// otherwise.
  6984. /// @param alpha The alpha parameter for the elementwise operation.
  6985. /// Specific meaning depends on the algorithm.
  6986. /// @param beta The beta parameter for the elementwise operation.
  6987. /// Specific meaning depends on the algorithm.
  6988. /// @param hint_fwd_pd Primitive descriptor for an elementwise
  6989. /// forward propagation primitive. It is used as a hint for
  6990. /// deciding which memory format to use.
  6991. /// @param attr Primitive attributes to use. Attributes are optional
  6992. /// and default to empty attributes.
  6993. /// @param allow_empty A flag signifying whether construction is
  6994. /// allowed to fail without throwing an exception. In this case an
  6995. /// empty object will be produced. This flag is optional and
  6996. /// defaults to false.
  6997. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6998. const memory::desc &diff_src_desc,
  6999. const memory::desc &diff_dst_desc,
  7000. const memory::desc &data_desc, float alpha, float beta,
  7001. const eltwise_forward::primitive_desc &hint_fwd_pd,
  7002. const primitive_attr &attr = default_attr(),
  7003. bool allow_empty = false)
  7004. : primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc,
  7005. data_desc, &alpha, &beta, hint_fwd_pd, attr, allow_empty) {}
  7006. /// Constructs a primitive descriptor for an eltwise backward
  7007. /// propagation primitive from a C API primitive descriptor that must
  7008. /// have a matching kind.
  7009. ///
  7010. /// @param pd C API primitive descriptor for an eltwise backward
  7011. /// propagation primitive.
  7012. primitive_desc(dnnl_primitive_desc_t pd)
  7013. : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
  7014. dnnl::prop_kind::backward_data) {}
  7015. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  7016. memory::desc src_desc() const { return base::src_desc(0); }
  7017. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  7018. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  7019. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  7020. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  7021. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  7022. dnnl::algorithm get_algorithm() const { return base::get_algorithm(); }
  7023. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  7024. dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  7025. /// @copydoc dnnl::primitive_desc_base::get_alpha()const
  7026. float get_alpha() const { return base::get_alpha(); }
  7027. /// @copydoc dnnl::primitive_desc_base::get_beta()const
  7028. float get_beta() const { return base::get_beta(); }
  7029. private:
  7030. primitive_desc(const engine &aengine, algorithm aalgorithm,
  7031. const memory::desc &diff_src_desc,
  7032. const memory::desc &diff_dst_desc,
  7033. const memory::desc &data_desc, const float *alpha,
  7034. const float *beta,
  7035. const eltwise_forward::primitive_desc &hint_fwd_pd,
  7036. const primitive_attr &attr, bool allow_empty) {
  7037. dnnl_primitive_desc_t pd = nullptr;
  7038. dnnl_status_t status = dnnl_eltwise_backward_primitive_desc_create(
  7039. &pd, aengine.get(), dnnl::convert_to_c(aalgorithm),
  7040. diff_src_desc.get(), diff_dst_desc.get(), data_desc.get(),
  7041. alpha ? *alpha : 0.0f, beta ? *beta : 0.0f,
  7042. hint_fwd_pd.get(), attr.get());
  7043. if (!allow_empty)
  7044. error::wrap_c_api(status,
  7045. "could not create a primitive descriptor for "
  7046. "the eltwise backward propagation primitive. Run "
  7047. "workload with environment variable ONEDNN_VERBOSE=all "
  7048. "to get additional diagnostic information.");
  7049. reset(pd);
  7050. }
  7051. };
  7052. /// Default constructor. Produces an empty object.
  7053. eltwise_backward() = default;
  7054. /// Constructs an eltwise backward propagation primitive.
  7055. /// @param pd Primitive descriptor for an eltwise backward propagation
  7056. /// primitive.
  7057. eltwise_backward(const primitive_desc &pd) : primitive(pd) {}
  7058. /// Constructs an eltwise backward propagation primitive from a cache blob.
  7059. /// @param pd Primitive descriptor for an eltwise backward propagation
  7060. /// primitive.
  7061. /// @param cache_blob Cache blob.
  7062. eltwise_backward(
  7063. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  7064. : primitive(pd, cache_blob) {}
  7065. };
  7066. /// @} dnnl_api_eltwise
  7067. /// @addtogroup dnnl_api_softmax Softmax
  7068. ///
  7069. /// A primitive to perform softmax.
  7070. ///
  7071. /// @sa @ref dev_guide_softmax in developer guide
  7072. ///
  7073. /// @{
  7074. /// Softmax forward propagation primitive.
  7075. struct softmax_forward : public primitive {
  7076. /// Primitive descriptor for a softmax forward propagation primitive.
  7077. struct primitive_desc : public dnnl::primitive_desc {
  7078. /// Default constructor. Produces an empty object.
  7079. primitive_desc() = default;
  7080. /// Constructs a primitive descriptor for a softmax forward propagation
  7081. /// primitive.
  7082. ///
  7083. /// @param aengine Engine to use.
  7084. /// @param aprop_kind Propagation kind. Possible values are
  7085. /// #dnnl::prop_kind::forward_training, and
  7086. /// #dnnl::prop_kind::forward_inference.
  7087. /// @param aalgorithm Softmax algorithm kind: either
  7088. /// #dnnl::algorithm::softmax_accurate,
  7089. /// or #dnnl::algorithm::softmax_log.
  7090. /// @param src_desc Source memory descriptor.
  7091. /// @param dst_desc Destination memory descriptor.
  7092. /// @param axis Axis over which softmax is computed.
  7093. /// @param attr Primitive attributes to use. Attributes are optional
  7094. /// and default to empty attributes.
  7095. /// @param allow_empty A flag signifying whether construction is
  7096. /// allowed to fail without throwing an exception. In this case an
  7097. /// empty object will be produced. This flag is optional and
  7098. /// defaults to false.
  7099. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7100. algorithm aalgorithm, const memory::desc &src_desc,
  7101. const memory::desc &dst_desc, int axis,
  7102. const primitive_attr &attr = default_attr(),
  7103. bool allow_empty = false) {
  7104. dnnl_primitive_desc_t pd = nullptr;
  7105. dnnl_status_t status = dnnl_softmax_forward_primitive_desc_create(
  7106. &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
  7107. dnnl::convert_to_c(aalgorithm), src_desc.get(),
  7108. dst_desc.get(), axis, attr.get());
  7109. if (!allow_empty)
  7110. error::wrap_c_api(status,
  7111. "could not create a primitive descriptor for "
  7112. "the softmax forward propagation primitive. Run "
  7113. "workload with environment variable ONEDNN_VERBOSE=all "
  7114. "to get additional diagnostic information.");
  7115. reset(pd);
  7116. }
  7117. /// Constructs a primitive descriptor for a softmax forward
  7118. /// propagation primitive from a C API primitive descriptor that must
  7119. /// have a matching kind.
  7120. ///
  7121. /// @param pd C API primitive descriptor for a softmax forward
  7122. /// propagation primitive.
  7123. primitive_desc(dnnl_primitive_desc_t pd)
  7124. : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
  7125. dnnl::prop_kind::forward_training,
  7126. dnnl::prop_kind::forward_inference) {}
  7127. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  7128. memory::desc src_desc() const { return base::src_desc(0); }
  7129. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  7130. memory::desc dst_desc() const { return base::dst_desc(0); }
  7131. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  7132. dnnl::algorithm get_algorithm() const { return base::get_algorithm(); }
  7133. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  7134. dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  7135. /// @copydoc dnnl::primitive_desc_base::get_axis()const
  7136. int get_axis() const { return base::get_axis(); }
  7137. };
  7138. /// Default constructor. Produces an empty object.
  7139. softmax_forward() = default;
  7140. /// Constructs a softmax forward propagation primitive.
  7141. /// @param pd Primitive descriptor for a softmax forward propagation
  7142. /// primitive.
  7143. softmax_forward(const primitive_desc &pd) : primitive(pd) {}
  7144. /// Constructs a softmax forward propagation primitive from a cache blob.
  7145. /// @param pd Primitive descriptor for a softmax forward propagation
  7146. /// primitive.
  7147. /// @param cache_blob Cache blob.
  7148. softmax_forward(
  7149. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  7150. : primitive(pd, cache_blob) {}
  7151. };
  7152. /// Softmax backward propagation primitive.
  7153. struct softmax_backward : public primitive {
  7154. /// Primitive descriptor for a softmax backward propagation primitive.
  7155. struct primitive_desc : public dnnl::primitive_desc {
  7156. /// Default constructor. Produces an empty object.
  7157. primitive_desc() = default;
  7158. /// Constructs a primitive descriptor for a softmax backward propagation
  7159. /// primitive.
  7160. ///
  7161. /// @param aengine Engine to use.
  7162. /// @param aalgorithm Softmax algorithm kind: either
  7163. /// #dnnl::algorithm::softmax_accurate,
  7164. /// or #dnnl::algorithm::softmax_log.
  7165. /// @param diff_src_desc Diff source memory descriptor.
  7166. /// @param diff_dst_desc Diff destination memory descriptor.
  7167. /// @param dst_desc Destination memory descriptor.
  7168. /// @param axis Axis over which softmax is computed.
  7169. /// @param hint_fwd_pd Primitive descriptor for a softmax
  7170. /// forward propagation primitive. It is used as a hint for
  7171. /// deciding which memory format to use.
  7172. /// @param attr Primitive attributes to use. Attributes are optional
  7173. /// and default to empty attributes.
  7174. /// @param allow_empty A flag signifying whether construction is
  7175. /// allowed to fail without throwing an exception. In this case an
  7176. /// empty object will be produced. This flag is optional and
  7177. /// defaults to false.
  7178. primitive_desc(const engine &aengine, algorithm aalgorithm,
  7179. const memory::desc &diff_src_desc,
  7180. const memory::desc &diff_dst_desc, const memory::desc &dst_desc,
  7181. int axis, const softmax_forward::primitive_desc &hint_fwd_pd,
  7182. const primitive_attr &attr = default_attr(),
  7183. bool allow_empty = false) {
  7184. dnnl_primitive_desc_t pd = nullptr;
  7185. dnnl_status_t status = dnnl_softmax_backward_primitive_desc_create(
  7186. &pd, aengine.get(), dnnl::convert_to_c(aalgorithm),
  7187. diff_src_desc.get(), diff_dst_desc.get(), dst_desc.get(),
  7188. axis, hint_fwd_pd.get(), attr.get());
  7189. if (!allow_empty)
  7190. error::wrap_c_api(status,
  7191. "could not create a primitive descriptor for "
  7192. "the softmax backward propagation primitive. Run "
  7193. "workload with environment variable ONEDNN_VERBOSE=all "
  7194. "to get additional diagnostic information.");
  7195. reset(pd);
  7196. }
  7197. /// Constructs a primitive descriptor for a softmax backward
  7198. /// propagation primitive from a C API primitive descriptor that must
  7199. /// have a matching kind.
  7200. ///
  7201. /// @param pd C API primitive descriptor for a softmax backward
  7202. /// propagation primitive.
  7203. primitive_desc(dnnl_primitive_desc_t pd)
  7204. : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
  7205. dnnl::prop_kind::backward_data) {}
  7206. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  7207. memory::desc dst_desc() const { return base::dst_desc(0); }
  7208. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  7209. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  7210. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  7211. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  7212. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  7213. dnnl::algorithm get_algorithm() const { return base::get_algorithm(); }
  7214. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  7215. dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  7216. /// @copydoc dnnl::primitive_desc_base::get_axis()const
  7217. int get_axis() const { return base::get_axis(); }
  7218. };
  7219. /// Default constructor. Produces an empty object.
  7220. softmax_backward() = default;
  7221. /// Constructs a softmax backward propagation primitive.
  7222. /// @param pd Primitive descriptor for a softmax backward propagation
  7223. /// primitive.
  7224. softmax_backward(const primitive_desc &pd) : primitive(pd) {}
  7225. /// Constructs a softmax backward propagation primitive from a cache blob.
  7226. /// @param pd Primitive descriptor for a softmax backward propagation
  7227. /// primitive.
  7228. /// @param cache_blob Cache blob.
  7229. softmax_backward(
  7230. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  7231. : primitive(pd, cache_blob) {}
  7232. };
  7233. /// @} dnnl_api_softmax
  7234. /// @addtogroup dnnl_api_batch_normalization Batch Normalization
  7235. ///
  7236. /// A primitive to perform batch normalization.
  7237. ///
  7238. /// Both forward and backward propagation primitives support in-place
  7239. /// operation; that is, src and dst can refer to the same memory for forward
  7240. /// propagation, and diff_dst and diff_src can refer to the same memory for
  7241. /// backward propagation.
  7242. ///
  /// The computations of the batch normalization primitives can be controlled by
  7244. /// specifying different @ref dnnl::normalization_flags values. For example,
  7245. /// batch normalization forward propagation can be configured to either
  7246. /// compute the mean and variance or take them as arguments. It can either
  7247. /// perform scaling and shifting using gamma and beta parameters or not.
  7248. /// Optionally, it can also perform a fused ReLU, which in case of training
  7249. /// would also require a workspace.
  7250. ///
  7251. /// @sa @ref dev_guide_batch_normalization in developer guide
  7252. ///
  7253. /// @{
  7254. /// Batch normalization forward propagation primitive.
  7255. struct batch_normalization_forward : public primitive {
  7256. /// Primitive descriptor for a batch normalization forward propagation
  7257. /// primitive.
  7258. struct primitive_desc : public dnnl::primitive_desc {
  7259. /// Default constructor. Produces an empty object.
  7260. primitive_desc() = default;
  7261. /// Constructs a primitive descriptor for a batch normalization forward
  7262. /// propagation primitive.
  7263. ///
  7264. /// @note
  7265. /// In-place operation is supported: the dst can refer to the same
  7266. /// memory as the src.
  7267. ///
  7268. /// @param aengine Engine to use.
  7269. /// @param aprop_kind Propagation kind. Possible values are
  7270. /// #dnnl::prop_kind::forward_training and
  7271. /// #dnnl::prop_kind::forward_inference.
  7272. /// @param src_desc Source memory descriptor.
  7273. /// @param dst_desc Destination memory descriptor.
  7274. /// @param epsilon Batch normalization epsilon parameter.
  7275. /// @param flags Batch normalization flags (@ref
  7276. /// dnnl::normalization_flags).
  7277. /// @param attr Primitive attributes to use. Attributes are optional
  7278. /// and default to empty attributes.
  7279. /// @param allow_empty A flag signifying whether construction is
  7280. /// allowed to fail without throwing an exception. In this case an
  7281. /// empty object will be produced. This flag is optional and
  7282. /// defaults to false.
  7283. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7284. const memory::desc &src_desc, const memory::desc &dst_desc,
  7285. float epsilon, normalization_flags flags,
  7286. const primitive_attr &attr = default_attr(),
  7287. bool allow_empty = false) {
  7288. dnnl_primitive_desc_t pd = nullptr;
  7289. dnnl_status_t status
  7290. = dnnl_batch_normalization_forward_primitive_desc_create(
  7291. &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
  7292. src_desc.get(), dst_desc.get(), epsilon,
  7293. convert_to_c(flags), attr.get());
  7294. if (!allow_empty)
  7295. error::wrap_c_api(status,
  7296. "could not create a primitive descriptor for "
  7297. "the batch normalization forward propagation "
  7298. "primitive. Run workload with environment variable "
  7299. "ONEDNN_VERBOSE=all to get additional diagnostic "
  7300. "information.");
  7301. reset(pd);
  7302. }
  7303. /// Constructs a primitive descriptor for a batch normalization
  7304. /// forward propagation primitive from a C API primitive descriptor
  7305. /// that must have a matching kind.
  7306. ///
  7307. /// @param pd C API primitive descriptor for a batch normalization
  7308. /// forward propagation primitive.
  7309. primitive_desc(dnnl_primitive_desc_t pd)
  7310. : dnnl::primitive_desc(pd,
  7311. dnnl::primitive::kind::batch_normalization,
  7312. dnnl::prop_kind::forward_training,
  7313. dnnl::prop_kind::forward_inference) {}
  7314. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  7315. memory::desc src_desc() const { return base::src_desc(0); }
  7316. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  7317. memory::desc dst_desc() const { return base::dst_desc(0); }
  7318. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  7319. memory::desc weights_desc() const { return base::weights_desc(0); }
  7320. /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
  7321. memory::desc workspace_desc() const { return base::workspace_desc(); }
  7322. /// Returns memory descriptor for mean.
  7323. /// @returns Memory descriptor for mean.
  7324. memory::desc mean_desc() const { return stat_desc(mean); }
  7325. /// Returns memory descriptor for variance.
  7326. /// @returns Memory descriptor for variance.
  7327. memory::desc variance_desc() const { return stat_desc(var); }
  7328. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  7329. dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  7330. /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
  7331. float get_epsilon() const { return base::get_epsilon(); }
  7332. /// Returns normalization flags.
  7333. /// @return Normalization flags.
  7334. normalization_flags get_flags() const {
  7335. return base::get_flags<normalization_flags>();
  7336. }
  7337. private:
  7338. enum {
  7339. mean = 1,
  7340. var = 2,
  7341. };
  7342. memory::desc stat_desc(int kind) const {
  7343. const bool use_global_stats
  7344. = (get_flags() & normalization_flags::use_global_stats)
  7345. != normalization_flags::none;
  7346. return query_md(
  7347. use_global_stats ? query::src_md : query::dst_md, kind);
  7348. }
  7349. };
  7350. /// Default constructor. Produces an empty object.
  7351. batch_normalization_forward() = default;
  7352. /// Constructs a batch normalization forward propagation primitive.
  7353. /// @param pd Primitive descriptor for a batch normalization forward
  7354. /// propagation primitive.
  7355. batch_normalization_forward(const primitive_desc &pd) : primitive(pd) {}
  7356. /// Constructs a batch normalization forward propagation primitive from
  7357. /// a cache blob.
  7358. /// @param pd Primitive descriptor for a batch normalization forward
  7359. /// propagation primitive.
  7360. /// @param cache_blob Cache blob.
  7361. batch_normalization_forward(
  7362. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  7363. : primitive(pd, cache_blob) {}
  7364. };
  7365. /// Batch normalization backward propagation primitive.
  7366. struct batch_normalization_backward : public primitive {
  7367. /// Primitive descriptor for a batch normalization backward propagation
  7368. /// primitive.
  7369. struct primitive_desc : public dnnl::primitive_desc {
  7370. /// Default constructor. Produces an empty object.
  7371. primitive_desc() = default;
  7372. /// Constructs a primitive descriptor for a batch normalization backward
  7373. /// propagation primitive.
  7374. ///
  7375. /// @param aengine Engine to use.
  7376. /// @param aprop_kind Propagation kind. Possible values are
  7377. /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
  7378. /// (diffs for all parameters are computed in this case).
  7379. /// @param diff_src_desc Diff source memory descriptor.
  7380. /// @param diff_dst_desc Diff destination memory descriptor.
  7381. /// @param src_desc Source memory descriptor.
  7382. /// @param epsilon Batch normalization epsilon parameter.
  7383. /// @param flags Batch normalization flags (@ref
  7384. /// dnnl::normalization_flags).
  7385. /// @param hint_fwd_pd Primitive descriptor for a batch normalization
  7386. /// forward propagation primitive. It is used as a hint for
  7387. /// deciding which memory format to use.
  7388. /// @param attr Primitive attributes to use. Attributes are optional
  7389. /// and default to empty attributes.
  7390. /// @param allow_empty A flag signifying whether construction is
  7391. /// allowed to fail without throwing an exception. In this case an
  7392. /// empty object will be produced. This flag is optional and
  7393. /// defaults to false.
  7394. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7395. const memory::desc &diff_src_desc,
  7396. const memory::desc &diff_dst_desc, const memory::desc &src_desc,
  7397. float epsilon, normalization_flags flags,
  7398. const batch_normalization_forward::primitive_desc &hint_fwd_pd,
  7399. const primitive_attr &attr = default_attr(),
  7400. bool allow_empty = false) {
  7401. dnnl_primitive_desc_t pd = nullptr;
  7402. dnnl_status_t status
  7403. = dnnl_batch_normalization_backward_primitive_desc_create(
  7404. &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
  7405. diff_src_desc.get(), diff_dst_desc.get(),
  7406. src_desc.get(), epsilon, convert_to_c(flags),
  7407. hint_fwd_pd.get(), attr.get());
  7408. if (!allow_empty)
  7409. error::wrap_c_api(status,
  7410. "could not create a primitive descriptor for "
  7411. "the batch normalization backward propagation "
  7412. "primitive. Run workload with environment variable "
  7413. "ONEDNN_VERBOSE=all to get additional diagnostic "
  7414. "information.");
  7415. reset(pd);
  7416. }
  7417. /// Constructs a primitive descriptor for a batch normalization
  7418. /// backward propagation primitive from a C API primitive descriptor
  7419. /// that must have a matching kind.
  7420. ///
  7421. /// @param pd C API primitive descriptor for a batch normalization
  7422. /// backward propagation primitive.
  7423. primitive_desc(dnnl_primitive_desc_t pd)
  7424. : dnnl::primitive_desc(pd,
  7425. dnnl::primitive::kind::batch_normalization,
  7426. dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) {
  7427. }
  7428. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  7429. memory::desc src_desc() const { return base::src_desc(0); }
  7430. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  7431. memory::desc weights_desc() const { return base::weights_desc(0); }
  7432. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  7433. memory::desc dst_desc() const { return base::dst_desc(0); }
  7434. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  7435. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  7436. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  7437. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  7438. /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
  7439. memory::desc diff_weights_desc() const {
  7440. return base::diff_weights_desc(0);
  7441. }
  7442. /// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
  7443. memory::desc mean_desc() const { return query_md(query::src_md, 1); }
  7444. /// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
  7445. memory::desc variance_desc() const {
  7446. return query_md(query::src_md, 2);
  7447. }
  7448. /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
  7449. memory::desc workspace_desc() const { return base::workspace_desc(); }
  7450. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  7451. dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  7452. /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
  7453. float get_epsilon() const { return base::get_epsilon(); }
  7454. /// Returns normalization flags.
  7455. /// @return Normalization flags.
  7456. normalization_flags get_flags() const {
  7457. return base::get_flags<normalization_flags>();
  7458. }
  7459. };
  7460. /// Default constructor. Produces an empty object.
  7461. batch_normalization_backward() = default;
  7462. /// Constructs a batch normalization backward propagation primitive.
  7463. /// @param pd Primitive descriptor for a batch normalization backward
  7464. /// propagation primitive.
  7465. batch_normalization_backward(const primitive_desc &pd) : primitive(pd) {}
  7466. /// Constructs a batch normalization backward propagation primitive from
  7467. /// a cache blob.
  7468. /// @param pd Primitive descriptor for a batch normalization backward
  7469. /// propagation primitive.
  7470. /// @param cache_blob Cache blob.
  7471. batch_normalization_backward(
  7472. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  7473. : primitive(pd, cache_blob) {}
  7474. };
  7475. /// @} dnnl_api_batch_normalization
  7476. /// @addtogroup dnnl_api_group_normalization Group Normalization
  7477. ///
  7478. /// A primitive to perform group normalization.
  7479. ///
  7480. /// Both forward and backward propagation primitives support in-place
  7481. /// operation; that is, src and dst can refer to the same memory for forward
  7482. /// propagation, and diff_dst and diff_src can refer to the same memory for
  7483. /// backward propagation.
  7484. ///
  7485. /// The computations of the group normalization primitives can be controlled by
  7486. /// specifying different @ref dnnl::normalization_flags values. For example,
  7487. /// group normalization forward propagation can be configured to either
  7488. /// compute the mean and variance or take them as arguments. It can either
  7489. /// perform scaling and shifting using gamma and beta parameters or not.
  7490. ///
  7491. /// @sa @ref dev_guide_group_normalization in developer guide
  7492. ///
  7493. /// @{
  7494. /// Group normalization forward propagation primitive.
  7495. struct group_normalization_forward : public primitive {
  7496. /// Primitive descriptor for a group normalization forward propagation
  7497. /// primitive.
  7498. struct primitive_desc : public dnnl::primitive_desc {
  7499. /// Default constructor. Produces an empty object.
  7500. primitive_desc() = default;
  7501. /// Constructs a primitive descriptor for a group normalization forward
  7502. /// propagation primitive.
  7503. ///
  7504. /// @note
  7505. /// In-place operation is supported: the dst can refer to the same
  7506. /// memory as the src.
  7507. ///
  7508. /// @param aengine Engine to use.
  7509. /// @param aprop_kind Propagation kind. Possible values are
  7510. /// #dnnl::prop_kind::forward_training and
  7511. /// #dnnl::prop_kind::forward_inference.
  7512. /// @param src_desc Source memory descriptor.
  7513. /// @param dst_desc Destination memory descriptor.
  7514. /// @param groups Group normalization groups parameter.
  7515. /// @param epsilon Group normalization epsilon parameter.
  7516. /// @param flags Group normalization flags (@ref
  7517. /// dnnl::normalization_flags).
  7518. /// @param attr Primitive attributes to use. Attributes are optional
  7519. /// and default to empty attributes.
  7520. /// @param allow_empty A flag signifying whether construction is
  7521. /// allowed to fail without throwing an exception. In this case an
  7522. /// empty object will be produced. This flag is optional and
  7523. /// defaults to false.
  7524. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7525. const memory::desc &src_desc, const memory::desc &dst_desc,
  7526. memory::dim groups, float epsilon, normalization_flags flags,
  7527. const primitive_attr &attr = default_attr(),
  7528. bool allow_empty = false) {
  7529. dnnl_primitive_desc_t pd = nullptr;
  7530. dnnl_status_t status
  7531. = dnnl_group_normalization_forward_primitive_desc_create(
  7532. &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
  7533. src_desc.get(), dst_desc.get(), groups, epsilon,
  7534. convert_to_c(flags), attr.get());
  7535. if (!allow_empty)
  7536. error::wrap_c_api(status,
  7537. "could not create a primitive descriptor for "
  7538. "the group normalization forward propagation "
  7539. "primitive. Run workload with environment variable "
  7540. "ONEDNN_VERBOSE=all to get additional diagnostic "
  7541. "information.");
  7542. reset(pd);
  7543. }
  7544. /// Constructs a primitive descriptor for a group normalization
  7545. /// forward propagation primitive from a C API primitive descriptor
  7546. /// that must have a matching kind.
  7547. ///
  7548. /// @param pd C API primitive descriptor for a group normalization
  7549. /// forward propagation primitive.
  7550. primitive_desc(dnnl_primitive_desc_t pd)
  7551. : dnnl::primitive_desc(pd,
  7552. dnnl::primitive::kind::group_normalization,
  7553. dnnl::prop_kind::forward_training,
  7554. dnnl::prop_kind::forward_inference) {}
  7555. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  7556. memory::desc src_desc() const { return base::src_desc(0); }
  7557. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  7558. memory::desc dst_desc() const { return base::dst_desc(0); }
  7559. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  7560. memory::desc weights_desc() const { return base::weights_desc(0); }
  7561. /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
  7562. memory::desc workspace_desc() const { return base::workspace_desc(); }
  7563. /// Returns memory descriptor for mean.
  7564. /// @returns Memory descriptor for mean.
  7565. memory::desc mean_desc() const { return stat_desc(mean); }
  7566. /// Returns memory descriptor for variance.
  7567. /// @returns Memory descriptor for variance.
  7568. memory::desc variance_desc() const { return stat_desc(var); }
  7569. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  7570. dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  7571. /// @copydoc dnnl::primitive_desc_base::get_group_size()const
  7572. memory::dim get_group_size() const { return base::get_group_size(); }
  7573. /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
  7574. float get_epsilon() const { return base::get_epsilon(); }
  7575. /// Returns normalization flags.
  7576. /// @return Normalization flags.
  7577. normalization_flags get_flags() const {
  7578. return base::get_flags<normalization_flags>();
  7579. }
  7580. private:
  7581. enum {
  7582. mean = 1,
  7583. var = 2,
  7584. };
  7585. memory::desc stat_desc(int kind) const {
  7586. const bool use_global_stats
  7587. = (get_flags() & normalization_flags::use_global_stats)
  7588. != normalization_flags::none;
  7589. return query_md(
  7590. use_global_stats ? query::src_md : query::dst_md, kind);
  7591. }
  7592. };
  7593. /// Default constructor. Produces an empty object.
  7594. group_normalization_forward() = default;
  7595. /// Constructs a group normalization forward propagation primitive.
  7596. /// @param pd Primitive descriptor for a group normalization forward
  7597. /// propagation primitive.
  7598. group_normalization_forward(const primitive_desc &pd) : primitive(pd) {}
  7599. /// Constructs a group normalization forward propagation primitive from
  7600. /// a cache blob.
  7601. /// @param pd Primitive descriptor for a group normalization forward
  7602. /// propagation primitive.
  7603. /// @param cache_blob Cache blob.
  7604. group_normalization_forward(
  7605. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  7606. : primitive(pd, cache_blob) {}
  7607. };
  7608. /// Group normalization backward propagation primitive.
  7609. struct group_normalization_backward : public primitive {
  7610. /// Primitive descriptor for a group normalization backward propagation
  7611. /// primitive.
  7612. struct primitive_desc : public dnnl::primitive_desc {
  7613. /// Default constructor. Produces an empty object.
  7614. primitive_desc() = default;
  7615. /// Constructs a primitive descriptor for a group normalization backward
  7616. /// propagation primitive.
  7617. ///
  7618. /// @param aengine Engine to use.
  7619. /// @param aprop_kind Propagation kind. Possible values are
  7620. /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
  7621. /// (diffs for all parameters are computed in this case).
  7622. /// @param diff_src_desc Diff source memory descriptor.
  7623. /// @param diff_dst_desc Diff destination memory descriptor.
  7624. /// @param src_desc Source memory descriptor.
  7625. /// @param groups Group normalization groups parameter.
  7626. /// @param epsilon Group normalization epsilon parameter.
  7627. /// @param flags Group normalization flags (@ref
  7628. /// dnnl::normalization_flags).
  7629. /// @param hint_fwd_pd Primitive descriptor for a group normalization
  7630. /// forward propagation primitive. It is used as a hint for
  7631. /// deciding which memory format to use.
  7632. /// @param attr Primitive attributes to use. Attributes are optional
  7633. /// and default to empty attributes.
  7634. /// @param allow_empty A flag signifying whether construction is
  7635. /// allowed to fail without throwing an exception. In this case an
  7636. /// empty object will be produced. This flag is optional and
  7637. /// defaults to false.
  7638. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7639. const memory::desc &diff_src_desc,
  7640. const memory::desc &diff_dst_desc, const memory::desc &src_desc,
  7641. memory::dim groups, float epsilon, normalization_flags flags,
  7642. const group_normalization_forward::primitive_desc &hint_fwd_pd,
  7643. const primitive_attr &attr = default_attr(),
  7644. bool allow_empty = false) {
  7645. dnnl_primitive_desc_t pd = nullptr;
  7646. dnnl_status_t status
  7647. = dnnl_group_normalization_backward_primitive_desc_create(
  7648. &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
  7649. diff_src_desc.get(), diff_dst_desc.get(),
  7650. src_desc.get(), groups, epsilon,
  7651. convert_to_c(flags), hint_fwd_pd.get(), attr.get());
  7652. if (!allow_empty)
  7653. error::wrap_c_api(status,
  7654. "could not create a primitive descriptor for "
  7655. "the group normalization backward propagation "
  7656. "primitive. Run workload with environment variable "
  7657. "ONEDNN_VERBOSE=all to get additional diagnostic "
  7658. "information.");
  7659. reset(pd);
  7660. }
  7661. /// Constructs a primitive descriptor for a group normalization
  7662. /// backward propagation primitive from a C API primitive descriptor
  7663. /// that must have a matching kind.
  7664. ///
  7665. /// @param pd C API primitive descriptor for a group normalization
  7666. /// backward propagation primitive.
  7667. primitive_desc(dnnl_primitive_desc_t pd)
  7668. : dnnl::primitive_desc(pd,
  7669. dnnl::primitive::kind::group_normalization,
  7670. dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) {
  7671. }
  7672. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  7673. memory::desc src_desc() const { return base::src_desc(0); }
  7674. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  7675. memory::desc weights_desc() const { return base::weights_desc(0); }
  7676. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  7677. memory::desc dst_desc() const { return base::dst_desc(0); }
  7678. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  7679. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  7680. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  7681. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  7682. /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
  7683. memory::desc diff_weights_desc() const {
  7684. return base::diff_weights_desc(0);
  7685. }
  7686. /// @copydoc dnnl::group_normalization_forward::primitive_desc::mean_desc()const
  7687. memory::desc mean_desc() const { return query_md(query::src_md, 1); }
  7688. /// @copydoc dnnl::group_normalization_forward::primitive_desc::variance_desc()const
  7689. memory::desc variance_desc() const {
  7690. return query_md(query::src_md, 2);
  7691. }
  7692. /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
  7693. memory::desc workspace_desc() const { return base::workspace_desc(); }
  7694. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  7695. dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  7696. /// @copydoc dnnl::primitive_desc_base::get_group_size()const
  7697. memory::dim get_group_size() const { return base::get_group_size(); }
  7698. /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
  7699. float get_epsilon() const { return base::get_epsilon(); }
  7700. /// Returns normalization flags.
  7701. /// @return Normalization flags.
  7702. normalization_flags get_flags() const {
  7703. return base::get_flags<normalization_flags>();
  7704. }
  7705. };
  7706. /// Default constructor. Produces an empty object.
  7707. group_normalization_backward() = default;
  7708. /// Constructs a group normalization backward propagation primitive.
  7709. /// @param pd Primitive descriptor for a group normalization backward
  7710. /// propagation primitive.
  7711. group_normalization_backward(const primitive_desc &pd) : primitive(pd) {}
  7712. /// Constructs a group normalization backward propagation primitive from
  7713. /// a cache blob.
  7714. /// @param pd Primitive descriptor for a group normalization backward
  7715. /// propagation primitive.
  7716. /// @param cache_blob Cache blob.
  7717. group_normalization_backward(
  7718. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  7719. : primitive(pd, cache_blob) {}
  7720. };
  7721. /// @} dnnl_api_group_normalization
  7722. /// @addtogroup dnnl_api_layer_normalization Layer Normalization
  7723. ///
  7724. /// A primitive to perform layer normalization. Normalization is performed
  7725. /// within the last logical dimension of data tensor.
  7726. ///
  7727. /// Both forward and backward propagation primitives support in-place
  7728. /// operation; that is, src and dst can refer to the same memory for forward
  7729. /// propagation, and diff_dst and diff_src can refer to the same memory for
  7730. /// backward propagation.
  7731. ///
  7732. /// The computations of the layer normalization primitives can be controlled by
  7733. /// specifying different @ref dnnl::normalization_flags values. For example,
  7734. /// layer normalization forward propagation can be configured to either
  7735. /// compute the mean and variance or take them as arguments. It can either
  7736. /// perform scaling and shifting using gamma and beta parameters or not.
  7737. ///
  7738. /// @sa @ref dev_guide_layer_normalization in developer guide
  7739. ///
  7740. /// @{
  7741. /// Layer normalization forward propagation primitive.
  7742. struct layer_normalization_forward : public primitive {
  7743. /// Primitive descriptor for a layer normalization forward propagation
  7744. /// primitive.
  7745. struct primitive_desc : public dnnl::primitive_desc {
  7746. /// Default constructor. Produces an empty object.
  7747. primitive_desc() = default;
  7748. /// Constructs a primitive descriptor for a layer normalization forward
  7749. /// propagation primitive.
  7750. ///
  7751. /// @param aengine Engine to use.
  7752. /// @param aprop_kind Propagation kind. Possible values are
  7753. /// #dnnl::prop_kind::forward_training, and
  7754. /// #dnnl::prop_kind::forward_inference.
  7755. /// @param src_desc Source memory descriptor.
  7756. /// @param dst_desc Destination memory descriptor.
  7757. /// @param stat_desc Statistics memory descriptors.
  7758. /// @param epsilon Layer normalization epsilon parameter.
  7759. /// @param flags Layer normalization flags (@ref
  7760. /// dnnl::normalization_flags).
  7761. /// @param attr Primitive attributes to use. Attributes are optional
  7762. /// and default to empty attributes.
  7763. /// @param allow_empty A flag signifying whether construction is
  7764. /// allowed to fail without throwing an exception. In this case an
  7765. /// empty object will be produced. This flag is optional and
  7766. /// defaults to false.
  7767. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7768. const memory::desc &src_desc, const memory::desc &dst_desc,
  7769. const memory::desc &stat_desc, float epsilon,
  7770. normalization_flags flags,
  7771. const primitive_attr &attr = default_attr(),
  7772. bool allow_empty = false)
  7773. : primitive_desc(aengine, aprop_kind, src_desc, dst_desc,
  7774. &stat_desc, memory::data_type::f32, epsilon, flags, attr,
  7775. allow_empty) {}
  7776. /// Constructs a primitive descriptor for a layer normalization forward
  7777. /// propagation primitive.
  7778. ///
  7779. /// @param aengine Engine to use.
  7780. /// @param aprop_kind Propagation kind. Possible values are
  7781. /// #dnnl::prop_kind::forward_training, and
  7782. /// #dnnl::prop_kind::forward_inference.
  7783. /// @param src_desc Source memory descriptor.
  7784. /// @param dst_desc Destination memory descriptor.
  7785. /// @param epsilon Layer normalization epsilon parameter.
  7786. /// @param flags Layer normalization flags (@ref
  7787. /// dnnl::normalization_flags).
  7788. /// @param attr Primitive attributes to use. Attributes are optional
  7789. /// and default to empty attributes.
  7790. /// @param allow_empty A flag signifying whether construction is
  7791. /// allowed to fail without throwing an exception. In this case an
  7792. /// empty object will be produced. This flag is optional and
  7793. /// defaults to false.
  7794. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7795. const memory::desc &src_desc, const memory::desc &dst_desc,
  7796. float epsilon, normalization_flags flags,
  7797. const primitive_attr &attr = default_attr(),
  7798. bool allow_empty = false)
  7799. : primitive_desc(aengine, aprop_kind, src_desc, dst_desc, nullptr,
  7800. memory::data_type::f32, epsilon, flags, attr, allow_empty) {
  7801. }
  7802. /// Constructs a primitive descriptor for a layer normalization forward
  7803. /// propagation primitive with a user-provided data type for the scale
  7804. /// and shift memory objects.
  7805. ///
  7806. /// @param aengine Engine to use.
  7807. /// @param aprop_kind Propagation kind. Possible values are
  7808. /// #dnnl::prop_kind::forward_training, and
  7809. /// #dnnl::prop_kind::forward_inference.
  7810. /// @param src_desc Source memory descriptor.
  7811. /// @param dst_desc Destination memory descriptor.
  7812. /// @param stat_desc Statistics memory descriptors.
  7813. /// @param scale_shift_data_type Data type of scale and shift memory.
  7814. /// If neither scale nor shift flag are specified the parameter
  7815. /// is ignored.
  7816. /// @param epsilon Layer normalization epsilon parameter.
  7817. /// @param flags Layer normalization flags (@ref
  7818. /// dnnl::normalization_flags).
  7819. /// @param attr Primitive attributes to use. Attributes are optional
  7820. /// and default to empty attributes.
  7821. /// @param allow_empty A flag signifying whether construction is
  7822. /// allowed to fail without throwing an exception. In this case an
  7823. /// empty object will be produced. This flag is optional and
  7824. /// defaults to false.
  7825. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7826. const memory::desc &src_desc, const memory::desc &dst_desc,
  7827. const memory::desc &stat_desc,
  7828. memory::data_type scale_shift_data_type, float epsilon,
  7829. normalization_flags flags,
  7830. const primitive_attr &attr = default_attr(),
  7831. bool allow_empty = false)
  7832. : primitive_desc(aengine, aprop_kind, src_desc, dst_desc,
  7833. &stat_desc, scale_shift_data_type, epsilon, flags, attr,
  7834. allow_empty) {}
  7835. /// Constructs a primitive descriptor for a layer normalization forward
  7836. /// propagation primitive with a user-provided data type for the scale
  7837. /// and shift memory objects.
  7838. ///
  7839. /// @param aengine Engine to use.
  7840. /// @param aprop_kind Propagation kind. Possible values are
  7841. /// #dnnl::prop_kind::forward_training, and
  7842. /// #dnnl::prop_kind::forward_inference.
  7843. /// @param src_desc Source memory descriptor.
  7844. /// @param dst_desc Destination memory descriptor.
  7845. /// @param scale_shift_data_type Data type of scale and shift memory.
  7846. /// If neither scale nor shift flag are specified the parameter
  7847. /// is ignored.
  7848. /// @param epsilon Layer normalization epsilon parameter.
  7849. /// @param flags Layer normalization flags (@ref
  7850. /// dnnl::normalization_flags).
  7851. /// @param attr Primitive attributes to use. Attributes are optional
  7852. /// and default to empty attributes.
  7853. /// @param allow_empty A flag signifying whether construction is
  7854. /// allowed to fail without throwing an exception. In this case an
  7855. /// empty object will be produced. This flag is optional and
  7856. /// defaults to false.
  7857. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7858. const memory::desc &src_desc, const memory::desc &dst_desc,
  7859. memory::data_type scale_shift_data_type, float epsilon,
  7860. normalization_flags flags,
  7861. const primitive_attr &attr = default_attr(),
  7862. bool allow_empty = false)
  7863. : primitive_desc(aengine, aprop_kind, src_desc, dst_desc, nullptr,
  7864. scale_shift_data_type, epsilon, flags, attr, allow_empty) {}
  7865. /// Constructs a primitive descriptor for a layer normalization
  7866. /// forward propagation primitive from a C API primitive descriptor
  7867. /// that must have a matching kind.
  7868. ///
  7869. /// @param pd C API primitive descriptor for a layer normalization
  7870. /// forward propagation primitive.
  7871. primitive_desc(dnnl_primitive_desc_t pd)
  7872. : dnnl::primitive_desc(pd,
  7873. dnnl::primitive::kind::layer_normalization,
  7874. dnnl::prop_kind::forward_training,
  7875. dnnl::prop_kind::forward_inference) {}
  7876. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  7877. memory::desc src_desc() const { return base::src_desc(0); }
  7878. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  7879. memory::desc dst_desc() const { return base::dst_desc(0); }
  7880. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  7881. memory::desc weights_desc() const { return base::weights_desc(0); }
  7882. /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
  7883. memory::desc workspace_desc() const { return base::workspace_desc(); }
  7884. /// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
  7885. memory::desc mean_desc() const { return stat_desc(mean); }
  7886. /// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
  7887. memory::desc variance_desc() const { return stat_desc(var); }
  7888. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  7889. dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  7890. /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
  7891. float get_epsilon() const { return base::get_epsilon(); }
  7892. /// Returns normalization flags.
  7893. /// @return Normalization flags.
  7894. normalization_flags get_flags() const {
  7895. return base::get_flags<normalization_flags>();
  7896. }
  7897. private:
  7898. enum {
  7899. mean = 1,
  7900. var = 2,
  7901. };
  7902. memory::desc stat_desc(int kind) const {
  7903. const bool use_global_stats
  7904. = (get_flags() & normalization_flags::use_global_stats)
  7905. != normalization_flags::none;
  7906. return query_md(
  7907. use_global_stats ? query::src_md : query::dst_md, kind);
  7908. }
  7909. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7910. const memory::desc &src_desc, const memory::desc &dst_desc,
  7911. const memory::desc *stat_desc,
  7912. memory::data_type scale_shift_data_type, float epsilon,
  7913. normalization_flags flags, const primitive_attr &attr,
  7914. bool allow_empty) {
  7915. dnnl_primitive_desc_t pd = nullptr;
  7916. dnnl_status_t status
  7917. = dnnl_layer_normalization_forward_primitive_desc_create_v2(
  7918. &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
  7919. src_desc.get(), dst_desc.get(),
  7920. optional_arg(stat_desc),
  7921. memory::convert_to_c(scale_shift_data_type),
  7922. epsilon, convert_to_c(flags), attr.get());
  7923. if (!allow_empty)
  7924. error::wrap_c_api(status,
  7925. "could not create a primitive descriptor for "
  7926. "the layer normalization forward propagation "
  7927. "primitive. Run workload with environment variable "
  7928. "ONEDNN_VERBOSE=all to get additional diagnostic "
  7929. "information.");
  7930. reset(pd);
  7931. }
  7932. };
  7933. /// Default constructor. Produces an empty object.
  7934. layer_normalization_forward() = default;
  7935. /// Constructs a layer normalization forward propagation primitive.
  7936. /// @param pd Primitive descriptor for a layer normalization forward
  7937. /// propagation primitive.
  7938. layer_normalization_forward(const primitive_desc &pd) : primitive(pd) {}
  7939. /// Constructs a layer normalization forward propagation primitive from
  7940. /// a cache blob.
  7941. /// @param pd Primitive descriptor for a layer normalization forward
  7942. /// propagation primitive.
  7943. /// @param cache_blob Cache blob.
  7944. layer_normalization_forward(
  7945. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  7946. : primitive(pd, cache_blob) {}
  7947. };
  7948. /// Layer normalization backward propagation primitive.
  7949. struct layer_normalization_backward : public primitive {
  7950. /// Primitive descriptor for a layer normalization backward propagation
  7951. /// primitive.
  7952. struct primitive_desc : public dnnl::primitive_desc {
  7953. /// Default constructor. Produces an empty object.
  7954. primitive_desc() = default;
  7955. /// Constructs a primitive descriptor for a layer normalization backward
  7956. /// propagation primitive.
  7957. ///
  7958. /// @param aengine Engine to use.
  7959. /// @param aprop_kind Propagation kind. Possible values are
  7960. /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
  7961. /// (diffs for all parameters are computed in this case).
  7962. /// @param diff_src_desc Diff source memory descriptor.
  7963. /// @param diff_dst_desc Diff destination memory descriptor.
  7964. /// @param src_desc Source memory descriptor.
  7965. /// @param stat_desc Statistics memory descriptors.
  7966. /// @param epsilon Layer normalization epsilon parameter.
  7967. /// @param flags Layer normalization flags (@ref
  7968. /// dnnl::normalization_flags).
  7969. /// @param attr Primitive attributes to use. Attributes are optional
  7970. /// and default to empty attributes.
  7971. /// @param hint_fwd_pd Primitive descriptor for a layer normalization
  7972. /// forward propagation primitive. It is used as a hint for
  7973. /// deciding which memory format to use.
  7974. /// @param allow_empty A flag signifying whether construction is
  7975. /// allowed to fail without throwing an exception. In this case an
  7976. /// empty object will be produced. This flag is optional and
  7977. /// defaults to false.
  7978. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7979. const memory::desc &diff_src_desc,
  7980. const memory::desc &diff_dst_desc, const memory::desc &src_desc,
  7981. const memory::desc &stat_desc, float epsilon,
  7982. normalization_flags flags,
  7983. const layer_normalization_forward::primitive_desc &hint_fwd_pd,
  7984. const primitive_attr &attr = default_attr(),
  7985. bool allow_empty = false)
  7986. : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc,
  7987. src_desc, &stat_desc, memory::data_type::f32,
  7988. memory::data_type::f32, epsilon, flags, hint_fwd_pd, attr,
  7989. allow_empty) {}
  7990. /// Constructs a primitive descriptor for a layer normalization backward
  7991. /// propagation primitive.
  7992. ///
  7993. /// @param aengine Engine to use.
  7994. /// @param aprop_kind Propagation kind. Possible values are
  7995. /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
  7996. /// (diffs for all parameters are computed in this case).
  7997. /// @param diff_src_desc Diff source memory descriptor.
  7998. /// @param diff_dst_desc Diff destination memory descriptor.
  7999. /// @param src_desc Source memory descriptor.
  8000. /// @param epsilon Layer normalization epsilon parameter.
  8001. /// @param flags Layer normalization flags (@ref
  8002. /// dnnl::normalization_flags).
  8003. /// @param attr Primitive attributes to use. Attributes are optional
  8004. /// and default to empty attributes.
  8005. /// @param hint_fwd_pd Primitive descriptor for a layer normalization
  8006. /// forward propagation primitive. It is used as a hint for
  8007. /// deciding which memory format to use.
  8008. /// @param allow_empty A flag signifying whether construction is
  8009. /// allowed to fail without throwing an exception. In this case an
  8010. /// empty object will be produced. This flag is optional and
  8011. /// defaults to false.
  8012. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  8013. const memory::desc &diff_src_desc,
  8014. const memory::desc &diff_dst_desc, const memory::desc &src_desc,
  8015. float epsilon, normalization_flags flags,
  8016. const layer_normalization_forward::primitive_desc &hint_fwd_pd,
  8017. const primitive_attr &attr = default_attr(),
  8018. bool allow_empty = false)
  8019. : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc,
  8020. src_desc, nullptr, memory::data_type::f32,
  8021. memory::data_type::f32, epsilon, flags, hint_fwd_pd, attr,
  8022. allow_empty) {}
  8023. /// Constructs a primitive descriptor for a layer normalization backward
  8024. /// propagation primitive with a user-provided data type for the scale
  8025. /// and shift memory objects.
  8026. ///
  8027. /// @param aengine Engine to use.
  8028. /// @param aprop_kind Propagation kind. Possible values are
  8029. /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
  8030. /// (diffs for all parameters are computed in this case).
  8031. /// @param diff_src_desc Diff source memory descriptor.
  8032. /// @param diff_dst_desc Diff destination memory descriptor.
  8033. /// @param src_desc Source memory descriptor.
  8034. /// @param stat_desc Statistics memory descriptors.
  8035. /// @param diff_scale_shift_data_type Data type of diff scale and shift
  8036. /// memory. If neither scale nor shift flag are specified the
  8037. /// parameter is ignored.
  8038. /// @param scale_shift_data_type Data type of scale and shift memory.
  8039. /// If neither scale nor shift flag are specified the parameter
  8040. /// is ignored.
  8041. /// @param epsilon Layer normalization epsilon parameter.
  8042. /// @param flags Layer normalization flags (@ref
  8043. /// dnnl::normalization_flags).
  8044. /// @param attr Primitive attributes to use. Attributes are optional
  8045. /// and default to empty attributes.
  8046. /// @param hint_fwd_pd Primitive descriptor for a layer normalization
  8047. /// forward propagation primitive. It is used as a hint for
  8048. /// deciding which memory format to use.
  8049. /// @param allow_empty A flag signifying whether construction is
  8050. /// allowed to fail without throwing an exception. In this case an
  8051. /// empty object will be produced. This flag is optional and
  8052. /// defaults to false.
  8053. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  8054. const memory::desc &diff_src_desc,
  8055. const memory::desc &diff_dst_desc, const memory::desc &src_desc,
  8056. const memory::desc &stat_desc,
  8057. memory::data_type diff_scale_shift_data_type,
  8058. memory::data_type scale_shift_data_type, float epsilon,
  8059. normalization_flags flags,
  8060. const layer_normalization_forward::primitive_desc &hint_fwd_pd,
  8061. const primitive_attr &attr = default_attr(),
  8062. bool allow_empty = false)
  8063. : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc,
  8064. src_desc, &stat_desc, diff_scale_shift_data_type,
  8065. scale_shift_data_type, epsilon, flags, hint_fwd_pd, attr,
  8066. allow_empty) {}
  8067. /// Constructs a primitive descriptor for a layer normalization backward
  8068. /// propagation primitive with a user-provided data type for the scale
  8069. /// and shift memory objects.
  8070. ///
  8071. /// @param aengine Engine to use.
  8072. /// @param aprop_kind Propagation kind. Possible values are
  8073. /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
  8074. /// (diffs for all parameters are computed in this case).
  8075. /// @param diff_src_desc Diff source memory descriptor.
  8076. /// @param diff_dst_desc Diff destination memory descriptor.
  8077. /// @param src_desc Source memory descriptor.
  8078. /// @param diff_scale_shift_data_type Data type of diff scale and shift
  8079. /// memory. If neither scale nor shift flag are specified the
  8080. /// parameter is ignored.
  8081. /// @param scale_shift_data_type Data type of scale and shift memory.
  8082. /// If neither scale nor shift flag are specified the parameter
  8083. /// is ignored.
  8084. /// @param epsilon Layer normalization epsilon parameter.
  8085. /// @param flags Layer normalization flags (@ref
  8086. /// dnnl::normalization_flags).
  8087. /// @param attr Primitive attributes to use. Attributes are optional
  8088. /// and default to empty attributes.
  8089. /// @param hint_fwd_pd Primitive descriptor for a layer normalization
  8090. /// forward propagation primitive. It is used as a hint for
  8091. /// deciding which memory format to use.
  8092. /// @param allow_empty A flag signifying whether construction is
  8093. /// allowed to fail without throwing an exception. In this case an
  8094. /// empty object will be produced. This flag is optional and
  8095. /// defaults to false.
  8096. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  8097. const memory::desc &diff_src_desc,
  8098. const memory::desc &diff_dst_desc, const memory::desc &src_desc,
  8099. memory::data_type diff_scale_shift_data_type,
  8100. memory::data_type scale_shift_data_type, float epsilon,
  8101. normalization_flags flags,
  8102. const layer_normalization_forward::primitive_desc &hint_fwd_pd,
  8103. const primitive_attr &attr = default_attr(),
  8104. bool allow_empty = false)
  8105. : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc,
  8106. src_desc, nullptr, diff_scale_shift_data_type,
  8107. scale_shift_data_type, epsilon, flags, hint_fwd_pd, attr,
  8108. allow_empty) {}
  8109. /// Constructs a primitive descriptor for a layer normalization
  8110. /// backward propagation primitive from a C API primitive descriptor
  8111. /// that must have a matching kind.
  8112. ///
  8113. /// @param pd C API primitive descriptor for a layer normalization
  8114. /// backward propagation primitive.
  8115. primitive_desc(dnnl_primitive_desc_t pd)
  8116. : dnnl::primitive_desc(pd,
  8117. dnnl::primitive::kind::layer_normalization,
  8118. dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) {
  8119. }
  8120. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  8121. memory::desc src_desc() const { return base::src_desc(0); }
  8122. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  8123. memory::desc weights_desc() const { return base::weights_desc(0); }
  8124. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  8125. memory::desc dst_desc() const { return base::dst_desc(0); }
  8126. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  8127. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  8128. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  8129. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  8130. /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
  8131. memory::desc diff_weights_desc() const {
  8132. return base::diff_weights_desc(0);
  8133. }
  8134. /// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
  8135. memory::desc mean_desc() const { return query_md(query::src_md, 1); }
  8136. /// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
  8137. memory::desc variance_desc() const {
  8138. return query_md(query::src_md, 2);
  8139. }
  8140. /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
  8141. memory::desc workspace_desc() const { return base::workspace_desc(); }
  8142. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  8143. dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  8144. /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
  8145. float get_epsilon() const { return base::get_epsilon(); }
  8146. /// Returns normalization flags.
  8147. /// @return Normalization flags.
  8148. normalization_flags get_flags() const {
  8149. return base::get_flags<normalization_flags>();
  8150. }
  8151. private:
  8152. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  8153. const memory::desc &diff_src_desc,
  8154. const memory::desc &diff_dst_desc, const memory::desc &src_desc,
  8155. const memory::desc *stat_desc,
  8156. memory::data_type diff_scale_shift_data_type,
  8157. memory::data_type scale_shift_data_type, float epsilon,
  8158. normalization_flags flags,
  8159. const layer_normalization_forward::primitive_desc &hint_fwd_pd,
  8160. const primitive_attr &attr, bool allow_empty) {
  8161. dnnl_primitive_desc_t pd = nullptr;
  8162. dnnl_status_t status
  8163. = dnnl_layer_normalization_backward_primitive_desc_create_v2(
  8164. &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
  8165. diff_src_desc.get(), diff_dst_desc.get(),
  8166. src_desc.get(), optional_arg(stat_desc),
  8167. memory::convert_to_c(diff_scale_shift_data_type),
  8168. memory::convert_to_c(scale_shift_data_type),
  8169. epsilon, convert_to_c(flags), hint_fwd_pd.get(),
  8170. attr.get());
  8171. if (!allow_empty)
  8172. error::wrap_c_api(status,
  8173. "could not create a primitive descriptor for "
  8174. "the layer normalization backward propagation "
  8175. "primitive. Run workload with environment variable "
  8176. "ONEDNN_VERBOSE=all to get additional diagnostic "
  8177. "information.");
  8178. reset(pd);
  8179. }
  8180. };
  8181. /// Default constructor. Produces an empty object.
  8182. layer_normalization_backward() = default;
  8183. /// Constructs a layer normalization backward propagation primitive.
  8184. /// @param pd Primitive descriptor for a layer normalization backward
  8185. /// propagation primitive.
  8186. layer_normalization_backward(const primitive_desc &pd) : primitive(pd) {}
  8187. /// Constructs a layer normalization backward propagation primitive from
  8188. /// a cache blob.
  8189. /// @param pd Primitive descriptor for a layer normalization backward
  8190. /// propagation primitive.
  8191. /// @param cache_blob Cache blob.
  8192. layer_normalization_backward(
  8193. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  8194. : primitive(pd, cache_blob) {}
  8195. };
  8196. /// @} dnnl_api_layer_normalization
  8197. /// @addtogroup dnnl_api_inner_product Inner Product
  8198. ///
  8199. /// A primitive to compute an inner product.
  8200. ///
  8201. /// @sa @ref dev_guide_inner_product in developer guide
  8202. ///
  8203. /// @{
  8204. /// Inner product forward propagation primitive.
  8205. struct inner_product_forward : public primitive {
  8206. /// Primitive descriptor for an inner product forward propagation primitive.
  8207. struct primitive_desc : public dnnl::primitive_desc {
  8208. /// Default constructor. Produces an empty object.
  8209. primitive_desc() = default;
  8210. /// Constructs a primitive descriptor for an inner product forward
  8211. /// propagation primitive with bias.
  8212. ///
  8213. /// @note
  8214. /// All the memory descriptors may be initialized with the
  8215. /// #dnnl::memory::format_tag::any value of @p format_tag.
  8216. ///
  8217. /// @param aengine Engine to use.
  8218. /// @param aprop_kind Propagation kind. Possible values are
  8219. /// #dnnl::prop_kind::forward_training, and
  8220. /// #dnnl::prop_kind::forward_inference.
  8221. /// @param src_desc Memory descriptor for src.
  8222. /// @param weights_desc Memory descriptor for weights.
  8223. /// @param bias_desc Memory descriptor for bias.
  8224. /// @param dst_desc Memory descriptor for dst.
  8225. /// @param attr Primitive attributes to use. Attributes are optional
  8226. /// and default to empty attributes.
  8227. /// @param allow_empty A flag signifying whether construction is
  8228. /// allowed to fail without throwing an exception. In this case an
  8229. /// empty object will be produced. This flag is optional and
  8230. /// defaults to false.
  8231. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  8232. const memory::desc &src_desc, const memory::desc &weights_desc,
  8233. const memory::desc &bias_desc, const memory::desc &dst_desc,
  8234. const primitive_attr &attr = default_attr(),
  8235. bool allow_empty = false)
  8236. : primitive_desc(aengine, aprop_kind, src_desc, weights_desc,
  8237. &bias_desc, dst_desc, attr, allow_empty) {}
  8238. /// Constructs a primitive descriptor for an inner product forward
  8239. /// propagation primitive.
  8240. ///
  8241. /// @note
  8242. /// All the memory descriptors may be initialized with the
  8243. /// #dnnl::memory::format_tag::any value of @p format_tag.
  8244. ///
  8245. /// @param aengine Engine to use.
  8246. /// @param aprop_kind Propagation kind. Possible values are
  8247. /// #dnnl::prop_kind::forward_training, and
  8248. /// #dnnl::prop_kind::forward_inference.
  8249. /// @param src_desc Memory descriptor for src.
  8250. /// @param weights_desc Memory descriptor for weights.
  8251. /// @param dst_desc Memory descriptor for dst.
  8252. /// @param attr Primitive attributes to use. Attributes are optional
  8253. /// and default to empty attributes.
  8254. /// @param allow_empty A flag signifying whether construction is
  8255. /// allowed to fail without throwing an exception. In this case an
  8256. /// empty object will be produced. This flag is optional and
  8257. /// defaults to false.
  8258. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  8259. const memory::desc &src_desc, const memory::desc &weights_desc,
  8260. const memory::desc &dst_desc,
  8261. const primitive_attr &attr = default_attr(),
  8262. bool allow_empty = false)
  8263. : primitive_desc(aengine, aprop_kind, src_desc, weights_desc,
  8264. nullptr, dst_desc, attr, allow_empty) {}
  8265. /// Constructs a primitive descriptor for an inner product forward
  8266. /// propagation primitive from a C API primitive descriptor that must
  8267. /// have a matching kind.
  8268. ///
  8269. /// @param pd C API primitive descriptor for an inner product forward
  8270. /// propagation primitive.
  8271. primitive_desc(dnnl_primitive_desc_t pd)
  8272. : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
  8273. dnnl::prop_kind::forward_training,
  8274. dnnl::prop_kind::forward_inference) {}
  8275. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  8276. memory::desc src_desc() const { return base::src_desc(0); }
  8277. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  8278. memory::desc weights_desc() const { return base::weights_desc(0); }
  8279. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  8280. memory::desc dst_desc() const { return base::dst_desc(0); }
  8281. /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
  8282. memory::desc bias_desc() const { return base::weights_desc(1); }
  8283. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  8284. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  8285. private:
  8286. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  8287. const memory::desc &src_desc, const memory::desc &weights_desc,
  8288. const memory::desc *bias_desc, const memory::desc &dst_desc,
  8289. const primitive_attr &attr, bool allow_empty) {
  8290. dnnl_primitive_desc_t pd = nullptr;
  8291. dnnl_status_t status
  8292. = dnnl_inner_product_forward_primitive_desc_create(&pd,
  8293. aengine.get(), dnnl::convert_to_c(aprop_kind),
  8294. src_desc.get(), weights_desc.get(),
  8295. optional_arg(bias_desc), dst_desc.get(),
  8296. attr.get());
  8297. if (!allow_empty)
  8298. error::wrap_c_api(status,
  8299. "could not create a primitive descriptor for "
  8300. "the inner product forward propagation primitive. Run "
  8301. "workload with environment variable ONEDNN_VERBOSE=all "
  8302. "to get additional diagnostic information.");
  8303. reset(pd);
  8304. }
  8305. };
  8306. /// Default constructor. Produces an empty object.
  8307. inner_product_forward() = default;
  8308. /// Constructs an inner product forward propagation primitive.
  8309. /// @param pd Primitive descriptor for an inner product forward
  8310. /// propagation primitive.
  8311. inner_product_forward(const primitive_desc &pd) : primitive(pd) {}
  8312. /// Constructs an inner product forward propagation primitive from
  8313. /// a cache blob.
  8314. /// @param pd Primitive descriptor for an inner product forward
  8315. /// propagation primitive.
  8316. /// @param cache_blob Cache blob.
  8317. inner_product_forward(
  8318. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  8319. : primitive(pd, cache_blob) {}
  8320. };
  8321. /// Inner product backward propagation primitive.
  8322. struct inner_product_backward_data : public primitive {
  8323. /// Primitive descriptor for an inner product backward propagation
  8324. /// primitive.
  8325. struct primitive_desc : public dnnl::primitive_desc {
  8326. /// Default constructor. Produces an empty object.
  8327. primitive_desc() = default;
  8328. /// Constructs a primitive descriptor for an inner product backward
  8329. /// propagation primitive.
  8330. ///
  8331. /// @note
  8332. /// All the memory descriptors may be initialized with the
  8333. /// #dnnl::memory::format_tag::any value of @p format_tag.
  8334. ///
  8335. /// @param aengine Engine to use.
  8336. /// @param diff_src_desc Memory descriptor for diff src.
  8337. /// @param weights_desc Memory descriptor for weights.
  8338. /// @param diff_dst_desc Memory descriptor for diff dst.
  8339. /// @param hint_fwd_pd Primitive descriptor for an inner product
  8340. /// forward propagation primitive. It is used as a hint for
  8341. /// deciding which memory format to use.
  8342. /// @param attr Primitive attributes to use. Attributes are optional
  8343. /// and default to empty attributes.
  8344. /// @param allow_empty A flag signifying whether construction is
  8345. /// allowed to fail without throwing an exception. In this case an
  8346. /// empty object will be produced. This flag is optional and
  8347. /// defaults to false.
  8348. primitive_desc(const engine &aengine, const memory::desc &diff_src_desc,
  8349. const memory::desc &weights_desc,
  8350. const memory::desc &diff_dst_desc,
  8351. const inner_product_forward::primitive_desc &hint_fwd_pd,
  8352. const primitive_attr &attr = default_attr(),
  8353. bool allow_empty = false) {
  8354. dnnl_primitive_desc_t pd = nullptr;
  8355. dnnl_status_t status
  8356. = dnnl_inner_product_backward_data_primitive_desc_create(
  8357. &pd, aengine.get(), diff_src_desc.get(),
  8358. weights_desc.get(), diff_dst_desc.get(),
  8359. hint_fwd_pd.get(), attr.get());
  8360. if (!allow_empty)
  8361. error::wrap_c_api(status,
  8362. "could not create a primitive descriptor for "
  8363. "the inner product backward propagation primitive. Run "
  8364. "workload with environment variable ONEDNN_VERBOSE=all "
  8365. "to get additional diagnostic information.");
  8366. reset(pd);
  8367. }
  8368. /// Constructs a primitive descriptor for an inner product backward
  8369. /// propagation primitive from a C API primitive descriptor that must
  8370. /// have a matching kind.
  8371. ///
  8372. /// @param pd C API primitive descriptor for an inner product backward
  8373. /// propagation primitive.
  8374. primitive_desc(dnnl_primitive_desc_t pd)
  8375. : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
  8376. dnnl::prop_kind::backward_data) {}
  8377. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  8378. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  8379. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  8380. memory::desc weights_desc() const { return base::weights_desc(0); }
  8381. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  8382. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  8383. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  8384. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  8385. };
  8386. /// Default constructor. Produces an empty object.
  8387. inner_product_backward_data() = default;
  8388. /// Constructs an inner product backward propagation primitive.
  8389. /// @param pd Primitive descriptor for an inner product backward
  8390. /// propagation primitive.
  8391. inner_product_backward_data(const primitive_desc &pd) : primitive(pd) {}
  8392. /// Constructs an inner product backward propagation primitive from
  8393. /// a cache blob.
  8394. /// @param pd Primitive descriptor for an inner product backward
  8395. /// propagation primitive.
  8396. /// @param cache_blob Cache blob.
  8397. inner_product_backward_data(
  8398. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  8399. : primitive(pd, cache_blob) {}
  8400. };
  8401. /// Inner product weights gradient primitive.
  8402. struct inner_product_backward_weights : public primitive {
  8403. /// Primitive descriptor for an inner product weights gradient primitive.
  8404. struct primitive_desc : public dnnl::primitive_desc {
  8405. /// Default constructor. Produces an empty object.
  8406. primitive_desc() = default;
  8407. /// Constructs a primitive descriptor for an inner product weights
  8408. /// update primitive with bias.
  8409. ///
  8410. /// @note
  8411. /// All the memory descriptors may be initialized with the
  8412. /// #dnnl::memory::format_tag::any value of @p format_tag.
  8413. ///
  8414. /// @param aengine Engine to use.
  8415. /// @param src_desc Memory descriptor for src.
  8416. /// @param diff_weights_desc Memory descriptor for diff weights.
  8417. /// @param diff_bias_desc Memory descriptor for diff bias.
  8418. /// @param diff_dst_desc Memory descriptor for diff dst.
  8419. /// @param hint_fwd_pd Primitive descriptor for an inner product
  8420. /// forward propagation primitive. It is used as a hint for
  8421. /// deciding which memory format to use.
  8422. /// @param attr Primitive attributes to use. Attributes are optional
  8423. /// and default to empty attributes.
  8424. /// @param allow_empty A flag signifying whether construction is
  8425. /// allowed to fail without throwing an exception. In this case an
  8426. /// empty object will be produced. This flag is optional and
  8427. /// defaults to false.
  8428. primitive_desc(const engine &aengine, const memory::desc &src_desc,
  8429. const memory::desc &diff_weights_desc,
  8430. const memory::desc &diff_bias_desc,
  8431. const memory::desc &diff_dst_desc,
  8432. const inner_product_forward::primitive_desc &hint_fwd_pd,
  8433. const primitive_attr &attr = default_attr(),
  8434. bool allow_empty = false)
  8435. : primitive_desc(aengine, src_desc, diff_weights_desc,
  8436. &diff_bias_desc, diff_dst_desc, hint_fwd_pd, attr,
  8437. allow_empty) {}
  8438. /// Constructs a primitive descriptor for an inner product weights
  8439. /// update primitive.
  8440. ///
  8441. /// @note
  8442. /// All the memory descriptors may be initialized with the
  8443. /// #dnnl::memory::format_tag::any value of @p format_tag.
  8444. ///
  8445. /// @param aengine Engine to use.
  8446. /// @param src_desc Memory descriptor for src.
  8447. /// @param diff_weights_desc Memory descriptor for diff weights.
  8448. /// @param diff_dst_desc Memory descriptor for diff dst.
  8449. /// @param attr Primitive attributes to use. Attributes are optional
  8450. /// and default to empty attributes.
  8451. /// @param hint_fwd_pd Primitive descriptor for an inner product
  8452. /// forward propagation primitive. It is used as a hint for
  8453. /// deciding which memory format to use.
  8454. /// @param allow_empty A flag signifying whether construction is
  8455. /// allowed to fail without throwing an exception. In this case an
  8456. /// empty object will be produced. This flag is optional and
  8457. /// defaults to false.
  8458. primitive_desc(const engine &aengine, const memory::desc &src_desc,
  8459. const memory::desc &diff_weights_desc,
  8460. const memory::desc &diff_dst_desc,
  8461. const inner_product_forward::primitive_desc &hint_fwd_pd,
  8462. const primitive_attr &attr = default_attr(),
  8463. bool allow_empty = false)
  8464. : primitive_desc(aengine, src_desc, diff_weights_desc, nullptr,
  8465. diff_dst_desc, hint_fwd_pd, attr, allow_empty) {}
  8466. /// Constructs a primitive descriptor for an inner product weights
  8467. /// update primitive from a C API primitive descriptor that must
  8468. /// have a matching kind.
  8469. ///
  8470. /// @param pd C API primitive descriptor for an inner product weights
  8471. /// gradient primitive.
  8472. primitive_desc(dnnl_primitive_desc_t pd)
  8473. : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
  8474. dnnl::prop_kind::backward_weights) {}
  8475. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  8476. memory::desc src_desc() const { return base::src_desc(0); }
  8477. /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
  8478. memory::desc diff_weights_desc() const {
  8479. return base::diff_weights_desc(0);
  8480. }
  8481. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  8482. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  8483. /// @copydoc dnnl::convolution_backward_weights::primitive_desc::diff_bias_desc()const
  8484. memory::desc diff_bias_desc() const {
  8485. return base::diff_weights_desc(1);
  8486. }
  8487. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  8488. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  8489. private:
  8490. primitive_desc(const engine &aengine, const memory::desc &src_desc,
  8491. const memory::desc &diff_weights_desc,
  8492. const memory::desc *diff_bias_desc,
  8493. const memory::desc &diff_dst_desc,
  8494. const inner_product_forward::primitive_desc &hint_fwd_pd,
  8495. const primitive_attr &attr, bool allow_empty) {
  8496. dnnl_primitive_desc_t pd = nullptr;
  8497. dnnl_status_t status
  8498. = dnnl_inner_product_backward_weights_primitive_desc_create(
  8499. &pd, aengine.get(), src_desc.get(),
  8500. diff_weights_desc.get(),
  8501. optional_arg(diff_bias_desc), diff_dst_desc.get(),
  8502. hint_fwd_pd.get(), attr.get());
  8503. if (!allow_empty)
  8504. error::wrap_c_api(status,
  8505. "could not create a primitive descriptor for "
  8506. "the inner product weights gradient primitive. Run "
  8507. "workload with environment variable ONEDNN_VERBOSE=all "
  8508. "to get additional diagnostic information.");
  8509. reset(pd);
  8510. }
  8511. };
  8512. /// Default constructor. Produces an empty object.
  8513. inner_product_backward_weights() = default;
  8514. /// Constructs an inner product weights gradient primitive.
  8515. /// @param pd Primitive descriptor for an inner product weights gradient
  8516. /// primitive.
  8517. inner_product_backward_weights(const primitive_desc &pd) : primitive(pd) {}
  8518. /// Constructs an inner product weights gradient primitive from a cache
  8519. /// blob.
  8520. /// @param pd Primitive descriptor for an inner product weights gradient
  8521. /// primitive.
  8522. /// @param cache_blob Cache blob.
  8523. inner_product_backward_weights(
  8524. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  8525. : primitive(pd, cache_blob) {}
  8526. };
  8527. /// @} dnnl_api_inner_product
  8528. /// @addtogroup dnnl_api_rnn RNN
  8529. ///
  8530. /// A primitive to compute recurrent neural network layers.
  8531. ///
  8532. /// @sa @ref dev_guide_rnn in developer guide
  8533. ///
  8534. /// @{
  8535. /// Base class for primitive descriptors for RNN primitives.
  8536. struct rnn_primitive_desc_base : public primitive_desc {
  8537. using primitive_desc::primitive_desc;
  8538. /// Default constructor. Produces an empty object.
  8539. rnn_primitive_desc_base() = default;
  8540. /// Constructs an RNN primitive descriptor base from a C API primitive
  8541. /// descriptor while checking that it actually describes the expected
  8542. /// primitive by comparing propagation and primitive kinds.
  8543. ///
  8544. /// @param pd C API primitive descriptor.
  8545. /// @param aprop_kind Expected propagation kind.
  8546. /// @param cell_kind Expected cell kind.
  8547. rnn_primitive_desc_base(dnnl_primitive_desc_t pd,
  8548. dnnl::prop_kind aprop_kind, dnnl::algorithm cell_kind)
  8549. : rnn_primitive_desc_base(pd, aprop_kind, aprop_kind, cell_kind) {}
  8550. /// Returns source layer memory descriptor.
  8551. /// @returns Source layer memory descriptor.
  8552. memory::desc src_layer_desc() const {
  8553. return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_LAYER);
  8554. }
  8555. /// Returns AUGRU attention memory descriptor.
  8556. /// @returns AUGRU attention memory descriptor.
  8557. memory::desc augru_attention_desc() const {
  8558. return base::query_md(query::exec_arg_md, DNNL_ARG_AUGRU_ATTENTION);
  8559. }
  8560. /// Returns source iteration memory descriptor.
  8561. /// @returns Source iteration memory descriptor.
  8562. /// @returns A zero memory descriptor if the primitive does not have a
  8563. /// source iteration parameter.
  8564. memory::desc src_iter_desc() const {
  8565. return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_ITER);
  8566. }
  8567. /// Returns source recurrent cell state memory descriptor.
  8568. /// @returns Source recurrent cell state memory descriptor.
  8569. memory::desc src_iter_c_desc() const {
  8570. return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_ITER_C);
  8571. }
  8572. /// Returns weights layer memory descriptor.
  8573. /// @returns Weights layer memory descriptor.
  8574. memory::desc weights_layer_desc() const {
  8575. return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_LAYER);
  8576. }
  8577. /// Returns weights iteration memory descriptor.
  8578. /// @returns Weights iteration memory descriptor.
  8579. memory::desc weights_iter_desc() const {
  8580. return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_ITER);
  8581. }
  8582. /// Returns weights peephole memory descriptor.
  8583. /// @returns Weights peephole memory descriptor.
  8584. memory::desc weights_peephole_desc() const {
  8585. return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_PEEPHOLE);
  8586. }
  8587. /// Returns weights projection memory descriptor.
  8588. /// @returns Weights projection memory descriptor.
  8589. memory::desc weights_projection_desc() const {
  8590. return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_PROJECTION);
  8591. }
  8592. /// Returns bias memory descriptor.
  8593. /// @returns Bias memory descriptor.
  8594. /// @returns A zero memory descriptor if the primitive does not have a
  8595. /// bias parameter.
  8596. memory::desc bias_desc() const {
  8597. return base::query_md(query::exec_arg_md, DNNL_ARG_BIAS);
  8598. }
  8599. /// Returns destination layer memory descriptor.
  8600. /// @returns Destination layer memory descriptor.
  8601. memory::desc dst_layer_desc() const {
  8602. return base::query_md(query::exec_arg_md, DNNL_ARG_DST_LAYER);
  8603. }
  8604. /// Returns destination iteration memory descriptor.
  8605. /// @returns Destination iteration memory descriptor.
  8606. /// @returns A zero memory descriptor if the primitive does not have a
  8607. /// destination iteration parameter.
  8608. memory::desc dst_iter_desc() const {
  8609. return base::query_md(query::exec_arg_md, DNNL_ARG_DST_ITER);
  8610. }
  8611. /// Returns destination recurrent cell state memory descriptor.
  8612. /// @returns Destination recurrent cell state memory descriptor.
  8613. memory::desc dst_iter_c_desc() const {
  8614. return base::query_md(query::exec_arg_md, DNNL_ARG_DST_ITER_C);
  8615. }
  8616. /// Returns diff source layer memory descriptor.
  8617. /// @returns Diff source layer memory descriptor.
  8618. memory::desc diff_src_layer_desc() const {
  8619. return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_LAYER);
  8620. }
  8621. /// Returns diff AUGRU attention memory descriptor.
  8622. /// @returns Diff AUGRU attention memory descriptor.
  8623. memory::desc diff_augru_attention_desc() const {
  8624. return base::query_md(
  8625. query::exec_arg_md, DNNL_ARG_DIFF_AUGRU_ATTENTION);
  8626. }
  8627. /// Returns diff source iteration memory descriptor.
  8628. /// @returns Diff source iteration memory descriptor.
  8629. /// @returns A zero memory descriptor if the primitive does not have a
  8630. /// diff source iteration parameter.
  8631. memory::desc diff_src_iter_desc() const {
  8632. return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_ITER);
  8633. }
  8634. /// Returns diff source recurrent cell state memory descriptor.
  8635. /// @returns Diff source recurrent cell state memory descriptor.
  8636. memory::desc diff_src_iter_c_desc() const {
  8637. return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_ITER_C);
  8638. }
  8639. /// Returns diff weights layer memory descriptor.
  8640. /// @returns Diff weights layer memory descriptor.
  8641. memory::desc diff_weights_layer_desc() const {
  8642. return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_LAYER);
  8643. }
  8644. /// Returns diff weights iteration memory descriptor.
  8645. /// @returns Diff weights iteration memory descriptor.
  8646. memory::desc diff_weights_iter_desc() const {
  8647. return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_ITER);
  8648. }
  8649. /// Returns diff weights peephole memory descriptor.
  8650. /// @returns Diff weights peephole memory descriptor.
  8651. memory::desc diff_weights_peephole_desc() const {
  8652. return base::query_md(
  8653. query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE);
  8654. }
  8655. /// Returns diff weights projection memory descriptor.
  8656. /// @returns Diff weights projection memory descriptor.
  8657. memory::desc diff_weights_projection_desc() const {
  8658. return base::query_md(
  8659. query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_PROJECTION);
  8660. }
  8661. /// Returns diff bias memory descriptor.
  8662. /// @returns Diff bias memory descriptor.
  8663. /// @returns A zero memory descriptor if the primitive does not have a
  8664. /// diff bias parameter.
  8665. memory::desc diff_bias_desc() const {
  8666. return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_BIAS);
  8667. }
  8668. /// Returns diff destination layer memory descriptor.
  8669. /// @returns Diff destination layer memory descriptor.
  8670. memory::desc diff_dst_layer_desc() const {
  8671. return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_LAYER);
  8672. }
  8673. /// Returns diff destination iteration memory descriptor.
  8674. /// @returns Diff destination iteration memory descriptor.
  8675. /// @returns A zero memory descriptor if the primitive does not have a
  8676. /// diff destination iteration parameter.
  8677. memory::desc diff_dst_iter_desc() const {
  8678. return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_ITER);
  8679. }
  8680. /// Returns diff destination recurrent cell state memory descriptor.
  8681. /// @returns Diff destination recurrent cell state memory descriptor.
  8682. memory::desc diff_dst_iter_c_desc() const {
  8683. return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_ITER_C);
  8684. }
  8685. protected:
  8686. using rnn_base = rnn_primitive_desc_base;
  8687. // (Deliberately not using doxygen comments)
  8688. //
  8689. // Constructs an RNN primitive descriptor base from a C API primitive
  8690. // descriptor while checking that it actually describes the expected
  8691. // primitive by comparing propagation and primitive kinds. Caller can
  8692. // pass two options propagation kinds. This is typically used to check
  8693. // that propagation kind is inference or training forward propagation.
  8694. //
  8695. // @param pd C API primitive descriptor.
  8696. // @param prop_kind1 Expected propagation kind.
  8697. // @param prop_kind2 Expected propagation kind.
  8698. // @param cell_kind Expected cell kind.
  8699. rnn_primitive_desc_base(dnnl_primitive_desc_t pd,
  8700. dnnl::prop_kind prop_kind1, dnnl::prop_kind prop_kind2,
  8701. dnnl::algorithm cell_kind) {
  8702. dnnl_status_t rc;
  8703. dnnl_primitive_kind_t q_primitive_kind;
  8704. rc = dnnl_primitive_desc_query(
  8705. pd, dnnl_query_primitive_kind, 0, &q_primitive_kind);
  8706. error::wrap_c_api(rc,
  8707. "could not retrieve a primitive kind from a primitive "
  8708. "descriptor for an RNN primitive");
  8709. dnnl_prop_kind_t q_prop_kind;
  8710. rc = dnnl_primitive_desc_query(
  8711. pd, dnnl_query_prop_kind, 0, &q_prop_kind);
  8712. error::wrap_c_api(rc,
  8713. "could not retrieve a propagation kind from a primitive "
  8714. "descriptor for an RNN primitive");
  8715. dnnl_alg_kind_t q_cell_kind;
  8716. rc = dnnl_primitive_desc_query(
  8717. pd, dnnl_query_cell_kind, 0, &q_cell_kind);
  8718. error::wrap_c_api(rc,
  8719. "could not retrieve a cell kind from a primitive descriptor "
  8720. "for an RNN primitive");
  8721. dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
  8722. dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
  8723. dnnl_alg_kind_t c_cell_kind = convert_to_c(cell_kind);
  8724. bool ok = q_primitive_kind == dnnl_rnn
  8725. && (q_prop_kind == c_prop_kind1 || q_prop_kind == c_prop_kind2)
  8726. && q_cell_kind == c_cell_kind;
  8727. if (!ok)
  8728. DNNL_THROW_ERROR(dnnl_invalid_arguments,
  8729. "mismatch between expected and provided descriptors for an "
  8730. "RNN primitive");
  8731. reset_with_clone(pd);
  8732. }
  8733. // Constructs an RNN forward propagation primitive descriptor base for
  8734. // any cell kind.
  8735. rnn_primitive_desc_base(const engine &aengine, algorithm cell_kind,
  8736. prop_kind aprop_kind, algorithm activation, rnn_direction direction,
  8737. const memory::desc &src_layer_desc,
  8738. const memory::desc &src_iter_desc,
  8739. const memory::desc *src_iter_c_desc,
  8740. const memory::desc *attention_desc,
  8741. const memory::desc &weights_layer_desc,
  8742. const memory::desc &weights_iter_desc,
  8743. const memory::desc *weights_peephole_desc,
  8744. const memory::desc *weights_projection_desc,
  8745. const memory::desc &bias_desc, const memory::desc &dst_layer_desc,
  8746. const memory::desc &dst_iter_desc,
  8747. const memory::desc *dst_iter_c_desc, rnn_flags flags, float alpha,
  8748. float beta, const primitive_attr &attr, bool allow_empty) {
  8749. dnnl_status_t status = dnnl_success;
  8750. const char *msg
  8751. = "could not create a primitive descriptor for a requested "
  8752. "cell kind";
  8753. dnnl_primitive_desc_t pd = nullptr;
  8754. switch (cell_kind) {
  8755. case algorithm::vanilla_rnn:
  8756. status = dnnl_vanilla_rnn_forward_primitive_desc_create(&pd,
  8757. aengine.get(), dnnl::convert_to_c(aprop_kind),
  8758. dnnl::convert_to_c(activation),
  8759. dnnl::convert_to_c(direction), src_layer_desc.get(),
  8760. src_iter_desc.get(), weights_layer_desc.get(),
  8761. weights_iter_desc.get(), bias_desc.get(),
  8762. dst_layer_desc.get(), dst_iter_desc.get(),
  8763. convert_to_c(flags), alpha, beta, attr.get());
  8764. msg = "could not create a primitive descriptor for "
  8765. "the vanilla RNN forward propagation primitive. Run "
  8766. "workload with environment variable ONEDNN_VERBOSE=all "
  8767. "to get additional diagnostic information.";
  8768. break;
  8769. case algorithm::vanilla_lstm:
  8770. status = dnnl_lstm_forward_primitive_desc_create(&pd,
  8771. aengine.get(), dnnl::convert_to_c(aprop_kind),
  8772. dnnl::convert_to_c(direction), src_layer_desc.get(),
  8773. src_iter_desc.get(), optional_arg(src_iter_c_desc),
  8774. weights_layer_desc.get(), weights_iter_desc.get(),
  8775. optional_arg(weights_peephole_desc),
  8776. optional_arg(weights_projection_desc), bias_desc.get(),
  8777. dst_layer_desc.get(), dst_iter_desc.get(),
  8778. optional_arg(dst_iter_c_desc), convert_to_c(flags),
  8779. attr.get());
  8780. msg = "could not create a primitive descriptor for "
  8781. "the LSTM forward propagation primitive. Run workload "
  8782. "with environment variable ONEDNN_VERBOSE=all to get "
  8783. "additional diagnostic information.";
  8784. break;
  8785. case algorithm::vanilla_gru:
  8786. status = dnnl_gru_forward_primitive_desc_create(&pd,
  8787. aengine.get(), dnnl::convert_to_c(aprop_kind),
  8788. dnnl::convert_to_c(direction), src_layer_desc.get(),
  8789. src_iter_desc.get(), weights_layer_desc.get(),
  8790. weights_iter_desc.get(), bias_desc.get(),
  8791. dst_layer_desc.get(), dst_iter_desc.get(),
  8792. convert_to_c(flags), attr.get());
  8793. msg = "could not create a primitive descriptor for "
  8794. "the GRU forward propagation primitive. Run workload "
  8795. "with environment variable ONEDNN_VERBOSE=all to get "
  8796. "additional diagnostic information.";
  8797. break;
  8798. case algorithm::lbr_gru:
  8799. status = dnnl_lbr_gru_forward_primitive_desc_create(&pd,
  8800. aengine.get(), dnnl::convert_to_c(aprop_kind),
  8801. dnnl::convert_to_c(direction), src_layer_desc.get(),
  8802. src_iter_desc.get(), weights_layer_desc.get(),
  8803. weights_iter_desc.get(), bias_desc.get(),
  8804. dst_layer_desc.get(), dst_iter_desc.get(),
  8805. convert_to_c(flags), attr.get());
  8806. msg = "could not create a primitive descriptor for "
  8807. "the LBR GRU forward propagation primitive. Run workload "
  8808. "with environment variable ONEDNN_VERBOSE=all to get "
  8809. "additional diagnostic information.";
  8810. break;
  8811. case algorithm::vanilla_augru:
  8812. status = dnnl_augru_forward_primitive_desc_create(&pd,
  8813. aengine.get(), dnnl::convert_to_c(aprop_kind),
  8814. dnnl::convert_to_c(direction), src_layer_desc.get(),
  8815. src_iter_desc.get(), optional_arg(attention_desc),
  8816. weights_layer_desc.get(), weights_iter_desc.get(),
  8817. bias_desc.get(), dst_layer_desc.get(),
  8818. dst_iter_desc.get(), convert_to_c(flags), attr.get());
  8819. msg = "could not create a primitive descriptor for "
  8820. "the AUGRU forward propagation primitive. Run workload "
  8821. "with environment variable ONEDNN_VERBOSE=all to get "
  8822. "additional diagnostic information.";
  8823. break;
  8824. case algorithm::lbr_augru:
  8825. status = dnnl_lbr_augru_forward_primitive_desc_create(&pd,
  8826. aengine.get(), dnnl::convert_to_c(aprop_kind),
  8827. dnnl::convert_to_c(direction), src_layer_desc.get(),
  8828. src_iter_desc.get(), optional_arg(attention_desc),
  8829. weights_layer_desc.get(), weights_iter_desc.get(),
  8830. bias_desc.get(), dst_layer_desc.get(),
  8831. dst_iter_desc.get(), convert_to_c(flags), attr.get());
  8832. msg = "could not create a primitive descriptor for "
  8833. "the LBR AUGRU forward propagation primitive. Run "
  8834. "workload with environment variable ONEDNN_VERBOSE=all "
  8835. "to get additional diagnostic information.";
  8836. break;
  8837. default: status = dnnl_unimplemented;
  8838. }
  8839. if (!allow_empty) error::wrap_c_api(status, msg);
  8840. reset(pd);
  8841. }
  8842. // Constructs an RNN backward propagation primitive descriptor base for
  8843. // any cell kind.
  8844. rnn_primitive_desc_base(const engine &aengine, algorithm cell_kind,
  8845. prop_kind aprop_kind, algorithm activation, rnn_direction direction,
  8846. const memory::desc &src_layer_desc,
  8847. const memory::desc &src_iter_desc,
  8848. const memory::desc *src_iter_c_desc,
  8849. const memory::desc *attention_desc,
  8850. const memory::desc &weights_layer_desc,
  8851. const memory::desc &weights_iter_desc,
  8852. const memory::desc *weights_peephole_desc,
  8853. const memory::desc *weights_projection_desc,
  8854. const memory::desc &bias_desc, const memory::desc &dst_layer_desc,
  8855. const memory::desc &dst_iter_desc,
  8856. const memory::desc *dst_iter_c_desc,
  8857. const memory::desc &diff_src_layer_desc,
  8858. const memory::desc &diff_src_iter_desc,
  8859. const memory::desc *diff_src_iter_c_desc,
  8860. const memory::desc *diff_attention_desc,
  8861. const memory::desc &diff_weights_layer_desc,
  8862. const memory::desc &diff_weights_iter_desc,
  8863. const memory::desc *diff_weights_peephole_desc,
  8864. const memory::desc *diff_weights_projection_desc,
  8865. const memory::desc &diff_bias_desc,
  8866. const memory::desc &diff_dst_layer_desc,
  8867. const memory::desc &diff_dst_iter_desc,
  8868. const memory::desc *diff_dst_iter_c_desc, rnn_flags flags,
  8869. float alpha, float beta, const rnn_primitive_desc_base &hint_fwd_pd,
  8870. const primitive_attr &attr, bool allow_empty) {
  8871. dnnl_status_t status = dnnl_success;
  8872. const char *msg = "";
  8873. dnnl_primitive_desc_t pd = nullptr;
  8874. switch (cell_kind) {
  8875. case algorithm::vanilla_rnn:
  8876. status = dnnl_vanilla_rnn_backward_primitive_desc_create(&pd,
  8877. aengine.get(), dnnl::convert_to_c(aprop_kind),
  8878. dnnl::convert_to_c(activation),
  8879. dnnl::convert_to_c(direction), src_layer_desc.get(),
  8880. src_iter_desc.get(), weights_layer_desc.get(),
  8881. weights_iter_desc.get(), bias_desc.get(),
  8882. dst_layer_desc.get(), dst_iter_desc.get(),
  8883. diff_src_layer_desc.get(), diff_src_iter_desc.get(),
  8884. diff_weights_layer_desc.get(),
  8885. diff_weights_iter_desc.get(), diff_bias_desc.get(),
  8886. diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
  8887. convert_to_c(flags), alpha, beta, hint_fwd_pd.get(),
  8888. attr.get());
  8889. msg = "could not create a primitive descriptor for "
  8890. "the vanilla RNN backward propagation primitive. Run "
  8891. "workload with environment variable ONEDNN_VERBOSE=all "
  8892. "to get additional diagnostic information.";
  8893. break;
  8894. case algorithm::vanilla_lstm:
  8895. status = dnnl_lstm_backward_primitive_desc_create(&pd,
  8896. aengine.get(), dnnl::convert_to_c(aprop_kind),
  8897. dnnl::convert_to_c(direction), src_layer_desc.get(),
  8898. src_iter_desc.get(), optional_arg(src_iter_c_desc),
  8899. weights_layer_desc.get(), weights_iter_desc.get(),
  8900. optional_arg(weights_peephole_desc),
  8901. optional_arg(weights_projection_desc), bias_desc.get(),
  8902. dst_layer_desc.get(), dst_iter_desc.get(),
  8903. optional_arg(dst_iter_c_desc),
  8904. diff_src_layer_desc.get(), diff_src_iter_desc.get(),
  8905. optional_arg(diff_src_iter_c_desc),
  8906. diff_weights_layer_desc.get(),
  8907. diff_weights_iter_desc.get(),
  8908. optional_arg(diff_weights_peephole_desc),
  8909. optional_arg(diff_weights_projection_desc),
  8910. diff_bias_desc.get(), diff_dst_layer_desc.get(),
  8911. diff_dst_iter_desc.get(),
  8912. optional_arg(diff_dst_iter_c_desc), convert_to_c(flags),
  8913. hint_fwd_pd.get(), attr.get());
  8914. msg = "could not create a primitive descriptor for "
  8915. "the LSTM backward propagation primitive. Run workload "
  8916. "with environment variable ONEDNN_VERBOSE=all to get "
  8917. "additional diagnostic information.";
  8918. break;
  8919. case algorithm::vanilla_gru:
  8920. status = dnnl_gru_backward_primitive_desc_create(&pd,
  8921. aengine.get(), dnnl::convert_to_c(aprop_kind),
  8922. dnnl::convert_to_c(direction), src_layer_desc.get(),
  8923. src_iter_desc.get(), weights_layer_desc.get(),
  8924. weights_iter_desc.get(), bias_desc.get(),
  8925. dst_layer_desc.get(), dst_iter_desc.get(),
  8926. diff_src_layer_desc.get(), diff_src_iter_desc.get(),
  8927. diff_weights_layer_desc.get(),
  8928. diff_weights_iter_desc.get(), diff_bias_desc.get(),
  8929. diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
  8930. convert_to_c(flags), hint_fwd_pd.get(), attr.get());
  8931. msg = "could not create a primitive descriptor for "
  8932. "the GRU backward propagation primitive. Run workload "
  8933. "with environment variable ONEDNN_VERBOSE=all to get "
  8934. "additional diagnostic information.";
  8935. break;
  8936. case algorithm::lbr_gru:
  8937. status = dnnl_lbr_gru_backward_primitive_desc_create(&pd,
  8938. aengine.get(), dnnl::convert_to_c(aprop_kind),
  8939. dnnl::convert_to_c(direction), src_layer_desc.get(),
  8940. src_iter_desc.get(), weights_layer_desc.get(),
  8941. weights_iter_desc.get(), bias_desc.get(),
  8942. dst_layer_desc.get(), dst_iter_desc.get(),
  8943. diff_src_layer_desc.get(), diff_src_iter_desc.get(),
  8944. diff_weights_layer_desc.get(),
  8945. diff_weights_iter_desc.get(), diff_bias_desc.get(),
  8946. diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
  8947. convert_to_c(flags), hint_fwd_pd.get(), attr.get());
  8948. msg = "could not create a primitive descriptor for "
  8949. "the LBR GRU backward propagation primitive. Run "
  8950. "workload with environment variable ONEDNN_VERBOSE=all "
  8951. "to get additional diagnostic information.";
  8952. break;
  8953. case algorithm::vanilla_augru:
  8954. status = dnnl_augru_backward_primitive_desc_create(&pd,
  8955. aengine.get(), dnnl::convert_to_c(aprop_kind),
  8956. dnnl::convert_to_c(direction), src_layer_desc.get(),
  8957. src_iter_desc.get(), optional_arg(attention_desc),
  8958. weights_layer_desc.get(), weights_iter_desc.get(),
  8959. bias_desc.get(), dst_layer_desc.get(),
  8960. dst_iter_desc.get(), diff_src_layer_desc.get(),
  8961. diff_src_iter_desc.get(),
  8962. optional_arg(diff_attention_desc),
  8963. diff_weights_layer_desc.get(),
  8964. diff_weights_iter_desc.get(), diff_bias_desc.get(),
  8965. diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
  8966. convert_to_c(flags), hint_fwd_pd.get(), attr.get());
  8967. msg = "could not create a primitive descriptor for "
  8968. "the AUGRU backward propagation primitive. Run workload "
  8969. "with environment variable ONEDNN_VERBOSE=all to get "
  8970. "additional diagnostic information.";
  8971. break;
  8972. case algorithm::lbr_augru:
  8973. status = dnnl_lbr_augru_backward_primitive_desc_create(&pd,
  8974. aengine.get(), dnnl::convert_to_c(aprop_kind),
  8975. dnnl::convert_to_c(direction), src_layer_desc.get(),
  8976. src_iter_desc.get(), optional_arg(attention_desc),
  8977. weights_layer_desc.get(), weights_iter_desc.get(),
  8978. bias_desc.get(), dst_layer_desc.get(),
  8979. dst_iter_desc.get(), diff_src_layer_desc.get(),
  8980. diff_src_iter_desc.get(),
  8981. optional_arg(diff_attention_desc),
  8982. diff_weights_layer_desc.get(),
  8983. diff_weights_iter_desc.get(), diff_bias_desc.get(),
  8984. diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
  8985. convert_to_c(flags), hint_fwd_pd.get(), attr.get());
  8986. msg = "could not create a primitive descriptor for "
  8987. "the LBR AUGRU backward propagation primitive. Run "
  8988. "workload with environment variable ONEDNN_VERBOSE=all "
  8989. "to get additional diagnostic information.";
  8990. break;
  8991. default: status = dnnl_unimplemented;
  8992. }
  8993. if (!allow_empty) error::wrap_c_api(status, msg);
  8994. reset(pd);
  8995. }
  8996. };
  8997. /// Vanilla RNN forward propagation primitive.
  8998. struct vanilla_rnn_forward : public primitive {
  8999. /// Primitive descriptor for a vanilla RNN forward propagation primitive.
  9000. struct primitive_desc : public rnn_primitive_desc_base {
  9001. /// Default constructor. Produces an empty object.
  9002. primitive_desc() = default;
  9003. /// Constructs a primitive descriptor for a vanilla RNN forward
  9004. /// propagation primitive.
  9005. ///
  9006. /// The following arguments may point to a zero memory descriptor:
  9007. /// - @p src_iter_desc,
  9008. /// - @p bias_desc,
  9009. /// - @p dst_iter_desc.
  9010. ///
  9011. /// This would then indicate that the RNN forward propagation primitive
  9012. /// should not use them and should default to zero values instead.
  9013. ///
  9014. /// @note
  9015. /// All memory descriptors except @p src_iter_desc can be
  9016. /// initialized with an #dnnl::memory::format_tag::any value of @p
  9017. /// format_tag.
  9018. ///
  9019. /// @param aengine Engine to use.
  9020. /// @param aprop_kind Propagation kind. Possible values are
  9021. /// #dnnl::prop_kind::forward_training, and
  9022. /// #dnnl::prop_kind::forward_inference.
  9023. /// @param activation Activation kind. Possible values are
  9024. /// #dnnl::algorithm::eltwise_relu,
  9025. /// #dnnl::algorithm::eltwise_tanh, or
  9026. /// #dnnl::algorithm::eltwise_logistic.
  9027. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  9028. /// more info.
  9029. /// @param src_layer_desc Memory descriptor for the input vector.
  9030. /// @param src_iter_desc Memory descriptor for the input recurrent
  9031. /// hidden state vector.
  9032. /// @param weights_layer_desc Memory descriptor for the weights
  9033. /// applied to the layer input.
  9034. /// @param weights_iter_desc Memory descriptor for the weights applied
  9035. /// to the recurrent input.
  9036. /// @param bias_desc Bias memory descriptor.
  9037. /// @param dst_layer_desc Memory descriptor for the output vector.
  9038. /// @param dst_iter_desc Memory descriptor for the output recurrent
  9039. /// hidden state vector.
  9040. /// @param attr Primitive attributes to use. Attributes are optional
  9041. /// and default to empty attributes.
  9042. /// @param allow_empty A flag signifying whether construction is
  9043. /// allowed to fail without throwing an exception. In this case an
  9044. /// empty object will be produced. This flag is optional and
  9045. /// defaults to false.
  9046. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  9047. algorithm activation, rnn_direction direction,
  9048. const memory::desc &src_layer_desc,
  9049. const memory::desc &src_iter_desc,
  9050. const memory::desc &weights_layer_desc,
  9051. const memory::desc &weights_iter_desc,
  9052. const memory::desc &bias_desc,
  9053. const memory::desc &dst_layer_desc,
  9054. const memory::desc &dst_iter_desc,
  9055. const primitive_attr &attr = default_attr(),
  9056. bool allow_empty = false)
  9057. : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
  9058. aprop_kind, activation, direction, src_layer_desc,
  9059. src_iter_desc, nullptr, nullptr, weights_layer_desc,
  9060. weights_iter_desc, nullptr, nullptr, bias_desc,
  9061. dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
  9062. 0.0f, 0.0f, attr, allow_empty) {}
  9063. /// Constructs a primitive descriptor for a vanilla RNN forward
  9064. /// propagation primitive with alpha parameter.
  9065. ///
  9066. /// The following arguments may point to a zero memory descriptor:
  9067. /// - @p src_iter_desc,
  9068. /// - @p bias_desc,
  9069. /// - @p dst_iter_desc.
  9070. ///
  9071. /// This would then indicate that the RNN forward propagation primitive
  9072. /// should not use them and should default to zero values instead.
  9073. ///
  9074. /// @note
  9075. /// All memory descriptors except @p src_iter_desc can be
  9076. /// initialized with an #dnnl::memory::format_tag::any value of @p
  9077. /// format_tag.
  9078. ///
  9079. /// @param aengine Engine to use.
  9080. /// @param aprop_kind Propagation kind. Possible values are
  9081. /// #dnnl::prop_kind::forward_training, and
  9082. /// #dnnl::prop_kind::forward_inference.
  9083. /// @param activation Activation kind. Possible values are
  9084. /// #dnnl::algorithm::eltwise_relu,
  9085. /// #dnnl::algorithm::eltwise_tanh, or
  9086. /// #dnnl::algorithm::eltwise_logistic.
  9087. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  9088. /// more info.
  9089. /// @param src_layer_desc Memory descriptor for the input vector.
  9090. /// @param src_iter_desc Memory descriptor for the input recurrent
  9091. /// hidden state vector.
  9092. /// @param weights_layer_desc Memory descriptor for the weights
  9093. /// applied to the layer input.
  9094. /// @param weights_iter_desc Memory descriptor for the weights applied
  9095. /// to the recurrent input.
  9096. /// @param bias_desc Bias memory descriptor.
  9097. /// @param dst_layer_desc Memory descriptor for the output vector.
  9098. /// @param dst_iter_desc Memory descriptor for the output recurrent
  9099. /// hidden state vector.
  9100. /// @param alpha Negative slope if activation is
  9101. /// #dnnl::algorithm::eltwise_relu.
  9102. /// @param attr Primitive attributes to use. Attributes are optional
  9103. /// and default to empty attributes.
  9104. /// @param allow_empty A flag signifying whether construction is
  9105. /// allowed to fail without throwing an exception. In this case an
  9106. /// empty object will be produced. This flag is optional and
  9107. /// defaults to false.
  9108. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  9109. algorithm activation, rnn_direction direction,
  9110. const memory::desc &src_layer_desc,
  9111. const memory::desc &src_iter_desc,
  9112. const memory::desc &weights_layer_desc,
  9113. const memory::desc &weights_iter_desc,
  9114. const memory::desc &bias_desc,
  9115. const memory::desc &dst_layer_desc,
  9116. const memory::desc &dst_iter_desc, float alpha,
  9117. const primitive_attr &attr = default_attr(),
  9118. bool allow_empty = false)
  9119. : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
  9120. aprop_kind, activation, direction, src_layer_desc,
  9121. src_iter_desc, nullptr, nullptr, weights_layer_desc,
  9122. weights_iter_desc, nullptr, nullptr, bias_desc,
  9123. dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
  9124. alpha, 0.0f, attr, allow_empty) {}
  9125. /// Constructs a primitive descriptor for a vanilla RNN forward
  9126. /// propagation primitive from a C API primitive descriptor that must
  9127. /// have a matching kind.
  9128. ///
  9129. /// @param pd C API primitive descriptor for a vanilla RNN forward
  9130. /// propagation primitive.
  9131. primitive_desc(dnnl_primitive_desc_t pd)
  9132. : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
  9133. dnnl::prop_kind::forward_inference,
  9134. dnnl::algorithm::vanilla_rnn) {}
  9135. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  9136. memory::desc src_layer_desc() const {
  9137. return rnn_base::src_layer_desc();
  9138. }
  9139. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  9140. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  9141. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  9142. memory::desc weights_layer_desc() const {
  9143. return rnn_base::weights_layer_desc();
  9144. }
  9145. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  9146. memory::desc weights_iter_desc() const {
  9147. return rnn_base::weights_iter_desc();
  9148. }
  9149. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  9150. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  9151. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  9152. memory::desc dst_layer_desc() const {
  9153. return rnn_base::dst_layer_desc();
  9154. }
  9155. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  9156. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  9157. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  9158. memory::desc workspace_desc() const {
  9159. return rnn_base::workspace_desc();
  9160. }
  9161. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  9162. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  9163. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  9164. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  9165. /// @copydoc dnnl::primitive_desc_base::get_activation_kind()const
  9166. algorithm get_activation_kind() const {
  9167. return base::get_activation_kind();
  9168. }
  9169. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  9170. rnn_direction get_direction() const { return base::get_direction(); }
  9171. /// @copydoc dnnl::primitive_desc_base::get_alpha()const
  9172. float get_alpha() const { return base::get_alpha(); }
  9173. /// @copydoc dnnl::primitive_desc_base::get_beta()const
  9174. float get_beta() const { return base::get_beta(); }
  9175. };
  9176. /// Default constructor. Produces an empty object.
  9177. vanilla_rnn_forward() = default;
  9178. /// Constructs a vanilla RNN forward propagation primitive.
  9179. /// @param pd Primitive descriptor for a vanilla RNN forward
  9180. /// propagation primitive.
  9181. vanilla_rnn_forward(const primitive_desc &pd) : primitive(pd) {}
  9182. /// Constructs a vanilla RNN forward propagation primitive from
  9183. /// a cache blob.
  9184. /// @param pd Primitive descriptor for a vanilla RNN forward
  9185. /// propagation primitive.
  9186. /// @param cache_blob Cache blob.
  9187. vanilla_rnn_forward(
  9188. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  9189. : primitive(pd, cache_blob) {}
  9190. };
  9191. /// Vanilla RNN backward propagation primitive.
  9192. struct vanilla_rnn_backward : public primitive {
  9193. /// Primitive descriptor for an RNN backward propagation primitive.
  9194. struct primitive_desc : public rnn_primitive_desc_base {
  9195. /// Default constructor. Produces an empty object.
  9196. primitive_desc() = default;
  9197. /// Constructs a primitive descriptor for a vanilla RNN backward
  9198. /// propagation primitive.
  9199. ///
  9200. /// The following arguments may point to a zero memory descriptor:
  9201. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  9202. /// - @p bias_desc together with @p diff_bias_desc,
  9203. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  9204. ///
  9205. /// This would then indicate that the RNN backward propagation
  9206. /// primitive should not use the respective data and should use zero
  9207. /// values instead.
  9208. ///
  9209. /// @note
  9210. /// All the memory descriptors may be initialized with the
  9211. /// #dnnl::memory::format_tag::any value of @p format_tag.
  9212. ///
  9213. /// @param aengine Engine to use.
  9214. /// @param aprop_kind Propagation kind. Must be
  9215. /// #dnnl::prop_kind::backward.
  9216. /// @param activation Activation kind. Possible values are
  9217. /// #dnnl::algorithm::eltwise_relu,
  9218. /// #dnnl::algorithm::eltwise_tanh, or
  9219. /// #dnnl::algorithm::eltwise_logistic.
  9220. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  9221. /// more info.
  9222. /// @param src_layer_desc Memory descriptor for the input vector.
  9223. /// @param src_iter_desc Memory descriptor for the input recurrent
  9224. /// hidden state vector.
  9225. /// @param weights_layer_desc Memory descriptor for the weights
  9226. /// applied to the layer input.
  9227. /// @param weights_iter_desc Memory descriptor for the weights applied
  9228. /// to the recurrent input.
  9229. /// @param bias_desc Bias memory descriptor.
  9230. /// @param dst_layer_desc Memory descriptor for the output vector.
  9231. /// @param dst_iter_desc Memory descriptor for the output recurrent
  9232. /// hidden state vector.
  9233. /// @param diff_src_layer_desc Memory descriptor for the diff of input
  9234. /// vector.
  9235. /// @param diff_src_iter_desc Memory descriptor for the diff of input
  9236. /// recurrent hidden state vector.
  9237. /// @param diff_weights_layer_desc Memory descriptor for the diff of
  9238. /// weights applied to the layer input.
  9239. /// @param diff_weights_iter_desc Memory descriptor for the diff of
  9240. /// weights applied to the recurrent input.
  9241. /// @param diff_bias_desc Diff bias memory descriptor.
  9242. /// @param diff_dst_layer_desc Memory descriptor for the diff of
  9243. /// output vector.
  9244. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  9245. /// recurrent hidden state vector.
  9246. /// @param hint_fwd_pd Primitive descriptor for a vanilla RNN
  9247. /// forward propagation primitive. It is used as a hint for
  9248. /// deciding which memory format to use.
  9249. /// @param attr Primitive attributes to use. Attributes are optional
  9250. /// and default to empty attributes.
  9251. /// @param allow_empty A flag signifying whether construction is
  9252. /// allowed to fail without throwing an exception. In this case an
  9253. /// empty object will be produced. This flag is optional and
  9254. /// defaults to false.
  9255. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  9256. algorithm activation, rnn_direction direction,
  9257. const memory::desc &src_layer_desc,
  9258. const memory::desc &src_iter_desc,
  9259. const memory::desc &weights_layer_desc,
  9260. const memory::desc &weights_iter_desc,
  9261. const memory::desc &bias_desc,
  9262. const memory::desc &dst_layer_desc,
  9263. const memory::desc &dst_iter_desc,
  9264. const memory::desc &diff_src_layer_desc,
  9265. const memory::desc &diff_src_iter_desc,
  9266. const memory::desc &diff_weights_layer_desc,
  9267. const memory::desc &diff_weights_iter_desc,
  9268. const memory::desc &diff_bias_desc,
  9269. const memory::desc &diff_dst_layer_desc,
  9270. const memory::desc &diff_dst_iter_desc,
  9271. const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
  9272. const primitive_attr &attr = default_attr(),
  9273. bool allow_empty = false)
  9274. : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
  9275. aprop_kind, activation, direction, src_layer_desc,
  9276. src_iter_desc, nullptr, nullptr, weights_layer_desc,
  9277. weights_iter_desc, nullptr, nullptr, bias_desc,
  9278. dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
  9279. diff_src_iter_desc, nullptr, nullptr,
  9280. diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
  9281. nullptr, diff_bias_desc, diff_dst_layer_desc,
  9282. diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
  9283. hint_fwd_pd, attr, allow_empty) {}
  9284. /// Constructs a primitive descriptor for a vanilla RNN backward
  9285. /// propagation primitive with an alpha parameter.
  9286. ///
  9287. /// The following arguments may point to a zero memory descriptor:
  9288. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  9289. /// - @p bias_desc together with @p diff_bias_desc,
  9290. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  9291. ///
  9292. /// This would then indicate that the RNN backward propagation
  9293. /// primitive should not use the respective data and should use zero
  9294. /// values instead.
  9295. ///
  9296. /// @note
  9297. /// All the memory descriptors may be initialized with the
  9298. /// #dnnl::memory::format_tag::any value of @p format_tag.
  9299. ///
  9300. /// @param aengine Engine to use.
  9301. /// @param aprop_kind Propagation kind. Must be
  9302. /// #dnnl::prop_kind::backward.
  9303. /// @param activation Activation kind. Possible values are
  9304. /// #dnnl::algorithm::eltwise_relu,
  9305. /// #dnnl::algorithm::eltwise_tanh, or
  9306. /// #dnnl::algorithm::eltwise_logistic.
  9307. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  9308. /// more info.
  9309. /// @param src_layer_desc Memory descriptor for the input vector.
  9310. /// @param src_iter_desc Memory descriptor for the input recurrent
  9311. /// hidden state vector.
  9312. /// @param weights_layer_desc Memory descriptor for the weights
  9313. /// applied to the layer input.
  9314. /// @param weights_iter_desc Memory descriptor for the weights applied
  9315. /// to the recurrent input.
  9316. /// @param bias_desc Bias memory descriptor.
  9317. /// @param dst_layer_desc Memory descriptor for the output vector.
  9318. /// @param dst_iter_desc Memory descriptor for the output recurrent
  9319. /// hidden state vector.
  9320. /// @param diff_src_layer_desc Memory descriptor for the diff of input
  9321. /// vector.
  9322. /// @param diff_src_iter_desc Memory descriptor for the diff of input
  9323. /// recurrent hidden state vector.
  9324. /// @param diff_weights_layer_desc Memory descriptor for the diff of
  9325. /// weights applied to the layer input.
  9326. /// @param diff_weights_iter_desc Memory descriptor for the diff of
  9327. /// weights applied to the recurrent input.
  9328. /// @param diff_bias_desc Diff bias memory descriptor.
  9329. /// @param diff_dst_layer_desc Memory descriptor for the diff of
  9330. /// output vector.
  9331. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  9332. /// recurrent hidden state vector.
  9333. /// @param alpha Negative slope if activation is
  9334. /// #dnnl::algorithm::eltwise_relu.
  9335. /// @param hint_fwd_pd Primitive descriptor for a vanilla RNN
  9336. /// forward propagation primitive. It is used as a hint for
  9337. /// deciding which memory format to use.
  9338. /// @param attr Primitive attributes to use. Attributes are optional
  9339. /// and default to empty attributes.
  9340. /// @param allow_empty A flag signifying whether construction is
  9341. /// allowed to fail without throwing an exception. In this case an
  9342. /// empty object will be produced. This flag is optional and
  9343. /// defaults to false.
  9344. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  9345. algorithm activation, rnn_direction direction,
  9346. const memory::desc &src_layer_desc,
  9347. const memory::desc &src_iter_desc,
  9348. const memory::desc &weights_layer_desc,
  9349. const memory::desc &weights_iter_desc,
  9350. const memory::desc &bias_desc,
  9351. const memory::desc &dst_layer_desc,
  9352. const memory::desc &dst_iter_desc,
  9353. const memory::desc &diff_src_layer_desc,
  9354. const memory::desc &diff_src_iter_desc,
  9355. const memory::desc &diff_weights_layer_desc,
  9356. const memory::desc &diff_weights_iter_desc,
  9357. const memory::desc &diff_bias_desc,
  9358. const memory::desc &diff_dst_layer_desc,
  9359. const memory::desc &diff_dst_iter_desc, float alpha,
  9360. const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
  9361. const primitive_attr &attr = default_attr(),
  9362. bool allow_empty = false)
  9363. : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
  9364. aprop_kind, activation, direction, src_layer_desc,
  9365. src_iter_desc, nullptr, nullptr, weights_layer_desc,
  9366. weights_iter_desc, nullptr, nullptr, bias_desc,
  9367. dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
  9368. diff_src_iter_desc, nullptr, nullptr,
  9369. diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
  9370. nullptr, diff_bias_desc, diff_dst_layer_desc,
  9371. diff_dst_iter_desc, nullptr, rnn_flags::undef, alpha, 0.0f,
  9372. hint_fwd_pd, attr, allow_empty) {}
  9373. /// Constructs a primitive descriptor for a vanilla RNN backward
  9374. /// propagation primitive from a C API primitive descriptor that must
  9375. /// have a matching kind.
  9376. ///
  9377. /// @param pd C API primitive descriptor for a vanilla RNN backward
  9378. /// propagation primitive.
  9379. primitive_desc(dnnl_primitive_desc_t pd)
  9380. : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
  9381. dnnl::algorithm::vanilla_rnn) {}
  9382. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  9383. memory::desc src_layer_desc() const {
  9384. return rnn_base::src_layer_desc();
  9385. }
  9386. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  9387. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  9388. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  9389. memory::desc weights_layer_desc() const {
  9390. return rnn_base::weights_layer_desc();
  9391. }
  9392. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  9393. memory::desc weights_iter_desc() const {
  9394. return rnn_base::weights_iter_desc();
  9395. }
  9396. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  9397. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  9398. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  9399. memory::desc dst_layer_desc() const {
  9400. return rnn_base::dst_layer_desc();
  9401. }
  9402. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  9403. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  9404. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  9405. memory::desc workspace_desc() const {
  9406. return rnn_base::workspace_desc();
  9407. }
  9408. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
  9409. memory::desc diff_src_layer_desc() const {
  9410. return rnn_base::diff_src_layer_desc();
  9411. }
  9412. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
  9413. memory::desc diff_src_iter_desc() const {
  9414. return rnn_base::diff_src_iter_desc();
  9415. }
  9416. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
  9417. memory::desc diff_weights_layer_desc() const {
  9418. return rnn_base::diff_weights_layer_desc();
  9419. }
  9420. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
  9421. memory::desc diff_weights_iter_desc() const {
  9422. return rnn_base::diff_weights_iter_desc();
  9423. }
  9424. /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
  9425. memory::desc diff_bias_desc() const {
  9426. return rnn_base::diff_bias_desc();
  9427. }
  9428. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
  9429. memory::desc diff_dst_layer_desc() const {
  9430. return rnn_base::diff_dst_layer_desc();
  9431. }
  9432. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
  9433. memory::desc diff_dst_iter_desc() const {
  9434. return rnn_base::diff_dst_iter_desc();
  9435. }
  9436. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  9437. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  9438. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  9439. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  9440. /// @copydoc dnnl::primitive_desc_base::get_activation_kind()const
  9441. algorithm get_activation_kind() const {
  9442. return base::get_activation_kind();
  9443. }
  9444. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  9445. rnn_direction get_direction() const { return base::get_direction(); }
  9446. /// @copydoc dnnl::primitive_desc_base::get_alpha()const
  9447. float get_alpha() const { return base::get_alpha(); }
  9448. /// @copydoc dnnl::primitive_desc_base::get_beta()const
  9449. float get_beta() const { return base::get_beta(); }
  9450. };
  9451. /// Default constructor. Produces an empty object.
  9452. vanilla_rnn_backward() = default;
  9453. /// Constructs a vanilla RNN backward propagation primitive.
  9454. /// @param pd Primitive descriptor for a vanilla RNN backward
  9455. /// propagation primitive.
  9456. vanilla_rnn_backward(const primitive_desc &pd) : primitive(pd) {}
  9457. /// Constructs a vanilla RNN backward propagation primitive from
  9458. /// a cache blob.
  9459. /// @param pd Primitive descriptor for a vanilla RNN backward
  9460. /// propagation primitive.
  9461. /// @param cache_blob Cache blob.
  9462. vanilla_rnn_backward(
  9463. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  9464. : primitive(pd, cache_blob) {}
  9465. };
  9466. /// LSTM forward propagation primitive.
  9467. struct lstm_forward : public primitive {
  9468. /// Primitive descriptor for an LSTM forward propagation primitive.
  9469. struct primitive_desc : public rnn_primitive_desc_base {
  9470. /// Default constructor. Produces an empty object.
  9471. primitive_desc() = default;
  9472. /// Constructs a primitive descriptor for an LSTM (with or without
  9473. /// peephole and with or without projection) forward propagation
  9474. /// primitive.
  9475. ///
  9476. /// The following arguments may point to a zero memory descriptor:
  9477. /// - @p src_iter_desc together with @p src_iter_c_desc,
  9478. /// - @p weights_peephole_desc,
  9479. /// - @p bias_desc,
  9480. /// - @p dst_iter_desc together with @p dst_iter_c_desc.
  9481. ///
  9482. /// This would then indicate that the LSTM forward propagation
  9483. /// primitive should not use them and should default to zero values
  9484. /// instead.
  9485. ///
  9486. /// The @p weights_projection_desc may point to a zero memory
  9487. /// descriptor. This would then indicate that the LSTM doesn't have
  9488. /// recurrent projection layer.
  9489. ///
  9490. /// @note
  9491. /// All memory descriptors can be initialized with an
  9492. /// #dnnl::memory::format_tag::any value of @p format_tag.
  9493. ///
  9494. /// @param aengine Engine to use.
  9495. /// @param aprop_kind Propagation kind. Possible values are
  9496. /// #dnnl::prop_kind::forward_training, and
  9497. /// #dnnl::prop_kind::forward_inference.
  9498. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  9499. /// more info.
  9500. /// @param src_layer_desc Memory descriptor for the input vector.
  9501. /// @param src_iter_desc Memory descriptor for the input recurrent
  9502. /// hidden state vector.
  9503. /// @param src_iter_c_desc Memory descriptor for the input recurrent
  9504. /// cell state vector.
  9505. /// @param weights_layer_desc Memory descriptor for the weights
  9506. /// applied to the layer input.
  9507. /// @param weights_iter_desc Memory descriptor for the weights applied
  9508. /// to the recurrent input.
  9509. /// @param weights_peephole_desc Memory descriptor for the weights
  9510. /// applied to the cell states (according to the Peephole LSTM
  9511. /// formula).
  9512. /// @param weights_projection_desc Memory descriptor for the weights
  9513. /// applied to the hidden states to get the recurrent projection
  9514. /// (according to the Projection LSTM formula).
  9515. /// @param bias_desc Bias memory descriptor.
  9516. /// @param dst_layer_desc Memory descriptor for the output vector.
  9517. /// @param dst_iter_desc Memory descriptor for the output recurrent
  9518. /// hidden state vector.
  9519. /// @param dst_iter_c_desc Memory descriptor for the output recurrent
  9520. /// cell state vector.
  9521. /// @param attr Primitive attributes to use. Attributes are optional
  9522. /// and default to empty attributes.
  9523. /// @param allow_empty A flag signifying whether construction is
  9524. /// allowed to fail without throwing an exception. In this case an
  9525. /// empty object will be produced. This flag is optional and
  9526. /// defaults to false.
  9527. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  9528. rnn_direction direction, const memory::desc &src_layer_desc,
  9529. const memory::desc &src_iter_desc,
  9530. const memory::desc &src_iter_c_desc,
  9531. const memory::desc &weights_layer_desc,
  9532. const memory::desc &weights_iter_desc,
  9533. const memory::desc &weights_peephole_desc,
  9534. const memory::desc &weights_projection_desc,
  9535. const memory::desc &bias_desc,
  9536. const memory::desc &dst_layer_desc,
  9537. const memory::desc &dst_iter_desc,
  9538. const memory::desc &dst_iter_c_desc,
  9539. const primitive_attr &attr = default_attr(),
  9540. bool allow_empty = false)
  9541. : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
  9542. aprop_kind, algorithm::undef, direction, src_layer_desc,
  9543. src_iter_desc, &src_iter_c_desc, nullptr,
  9544. weights_layer_desc, weights_iter_desc,
  9545. &weights_peephole_desc, &weights_projection_desc, bias_desc,
  9546. dst_layer_desc, dst_iter_desc, &dst_iter_c_desc,
  9547. rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {}
  9548. /// Constructs a primitive descriptor for an LSTM (with or without
  9549. /// peephole) forward propagation primitive.
  9550. ///
  9551. /// The following arguments may point to a zero memory descriptor:
  9552. /// - @p src_iter_desc together with @p src_iter_c_desc,
  9553. /// - @p weights_peephole_desc,
  9554. /// - @p bias_desc,
  9555. /// - @p dst_iter_desc together with @p dst_iter_c_desc.
  9556. ///
  9557. /// This would then indicate that the LSTM forward propagation
  9558. /// primitive should not use them and should default to zero values
  9559. /// instead.
  9560. ///
  9561. /// @note
  9562. /// All memory descriptors can be initialized with an
  9563. /// #dnnl::memory::format_tag::any value of @p format_tag.
  9564. ///
  9565. /// @param aengine Engine to use.
  9566. /// @param aprop_kind Propagation kind. Possible values are
  9567. /// #dnnl::prop_kind::forward_training, and
  9568. /// #dnnl::prop_kind::forward_inference.
  9569. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  9570. /// more info.
  9571. /// @param src_layer_desc Memory descriptor for the input vector.
  9572. /// @param src_iter_desc Memory descriptor for the input recurrent
  9573. /// hidden state vector.
  9574. /// @param src_iter_c_desc Memory descriptor for the input recurrent
  9575. /// cell state vector.
  9576. /// @param weights_layer_desc Memory descriptor for the weights
  9577. /// applied to the layer input.
  9578. /// @param weights_iter_desc Memory descriptor for the weights applied
  9579. /// to the recurrent input.
  9580. /// @param weights_peephole_desc Memory descriptor for the weights
  9581. /// applied to the cell states (according to the Peephole LSTM
  9582. /// formula).
  9583. /// @param bias_desc Bias memory descriptor.
  9584. /// @param dst_layer_desc Memory descriptor for the output vector.
  9585. /// @param dst_iter_desc Memory descriptor for the output recurrent
  9586. /// hidden state vector.
  9587. /// @param dst_iter_c_desc Memory descriptor for the output recurrent
  9588. /// cell state vector.
  9589. /// @param attr Primitive attributes to use. Attributes are optional
  9590. /// and default to empty attributes.
  9591. /// @param allow_empty A flag signifying whether construction is
  9592. /// allowed to fail without throwing an exception. In this case an
  9593. /// empty object will be produced. This flag is optional and
  9594. /// defaults to false.
  9595. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  9596. rnn_direction direction, const memory::desc &src_layer_desc,
  9597. const memory::desc &src_iter_desc,
  9598. const memory::desc &src_iter_c_desc,
  9599. const memory::desc &weights_layer_desc,
  9600. const memory::desc &weights_iter_desc,
  9601. const memory::desc &weights_peephole_desc,
  9602. const memory::desc &bias_desc,
  9603. const memory::desc &dst_layer_desc,
  9604. const memory::desc &dst_iter_desc,
  9605. const memory::desc &dst_iter_c_desc,
  9606. const primitive_attr &attr = default_attr(),
  9607. bool allow_empty = false)
  9608. : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
  9609. aprop_kind, algorithm::undef, direction, src_layer_desc,
  9610. src_iter_desc, &src_iter_c_desc, nullptr,
  9611. weights_layer_desc, weights_iter_desc,
  9612. &weights_peephole_desc, nullptr, bias_desc, dst_layer_desc,
  9613. dst_iter_desc, &dst_iter_c_desc, rnn_flags::undef, 0.0f,
  9614. 0.0f, attr, allow_empty) {}
  9615. /// Constructs a primitive descriptor for an LSTM forward propagation
  9616. /// primitive.
  9617. ///
  9618. /// The following arguments may point to a zero memory descriptor:
  9619. /// - @p src_iter_desc together with @p src_iter_c_desc,
  9620. /// - @p bias_desc,
  9621. /// - @p dst_iter_desc together with @p dst_iter_c_desc.
  9622. ///
  9623. /// This would then indicate that the LSTM forward propagation
  9624. /// primitive should not use them and should default to zero values
  9625. /// instead.
  9626. ///
  9627. /// @note
  9628. /// All memory descriptors can be initialized with an
  9629. /// #dnnl::memory::format_tag::any value of @p format_tag.
  9630. ///
  9631. /// @param aengine Engine to use.
  9632. /// @param aprop_kind Propagation kind. Possible values are
  9633. /// #dnnl::prop_kind::forward_training, and
  9634. /// #dnnl::prop_kind::forward_inference.
  9635. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  9636. /// more info.
  9637. /// @param src_layer_desc Memory descriptor for the input vector.
  9638. /// @param src_iter_desc Memory descriptor for the input recurrent
  9639. /// hidden state vector.
  9640. /// @param src_iter_c_desc Memory descriptor for the input recurrent
  9641. /// cell state vector.
  9642. /// @param weights_layer_desc Memory descriptor for the weights
  9643. /// applied to the layer input.
  9644. /// @param weights_iter_desc Memory descriptor for the weights applied
  9645. /// to the recurrent input.
  9646. /// @param bias_desc Bias memory descriptor.
  9647. /// @param dst_layer_desc Memory descriptor for the output vector.
  9648. /// @param dst_iter_desc Memory descriptor for the output recurrent
  9649. /// hidden state vector.
  9650. /// @param dst_iter_c_desc Memory descriptor for the output recurrent
  9651. /// cell state vector.
  9652. /// @param attr Primitive attributes to use. Attributes are optional
  9653. /// and default to empty attributes.
  9654. /// @param allow_empty A flag signifying whether construction is
  9655. /// allowed to fail without throwing an exception. In this case an
  9656. /// empty object will be produced. This flag is optional and
  9657. /// defaults to false.
  9658. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  9659. rnn_direction direction, const memory::desc &src_layer_desc,
  9660. const memory::desc &src_iter_desc,
  9661. const memory::desc &src_iter_c_desc,
  9662. const memory::desc &weights_layer_desc,
  9663. const memory::desc &weights_iter_desc,
  9664. const memory::desc &bias_desc,
  9665. const memory::desc &dst_layer_desc,
  9666. const memory::desc &dst_iter_desc,
  9667. const memory::desc &dst_iter_c_desc,
  9668. const primitive_attr &attr = default_attr(),
  9669. bool allow_empty = false)
  9670. : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
  9671. aprop_kind, algorithm::undef, direction, src_layer_desc,
  9672. src_iter_desc, &src_iter_c_desc, nullptr,
  9673. weights_layer_desc, weights_iter_desc, nullptr, nullptr,
  9674. bias_desc, dst_layer_desc, dst_iter_desc, &dst_iter_c_desc,
  9675. rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {}
  9676. /// Constructs a primitive descriptor for an LSTM forward propagation
  9677. /// primitive from a C API primitive descriptor that must have a
  9678. /// matching kind.
  9679. ///
  9680. /// @param pd C API primitive descriptor for an LSTM forward
  9681. /// propagation primitive.
  9682. primitive_desc(dnnl_primitive_desc_t pd)
  9683. : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
  9684. dnnl::prop_kind::forward_inference,
  9685. dnnl::algorithm::vanilla_lstm) {}
  9686. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  9687. memory::desc src_layer_desc() const {
  9688. return rnn_base::src_layer_desc();
  9689. }
  9690. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  9691. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  9692. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  9693. memory::desc src_iter_c_desc() const {
  9694. return rnn_base::src_iter_c_desc();
  9695. }
  9696. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  9697. memory::desc weights_layer_desc() const {
  9698. return rnn_base::weights_layer_desc();
  9699. }
  9700. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  9701. memory::desc weights_iter_desc() const {
  9702. return rnn_base::weights_iter_desc();
  9703. }
  9704. /// @copydoc dnnl::rnn_primitive_desc_base::weights_peephole_desc()const
  9705. memory::desc weights_peephole_desc() const {
  9706. return rnn_base::weights_peephole_desc();
  9707. }
  9708. /// @copydoc dnnl::rnn_primitive_desc_base::weights_projection_desc()const
  9709. memory::desc weights_projection_desc() const {
  9710. return rnn_base::weights_projection_desc();
  9711. }
  9712. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  9713. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  9714. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  9715. memory::desc dst_layer_desc() const {
  9716. return rnn_base::dst_layer_desc();
  9717. }
  9718. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  9719. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  9720. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  9721. memory::desc dst_iter_c_desc() const {
  9722. return rnn_base::dst_iter_c_desc();
  9723. }
  9724. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  9725. memory::desc workspace_desc() const {
  9726. return rnn_base::workspace_desc();
  9727. }
  9728. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  9729. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  9730. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  9731. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  9732. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  9733. rnn_direction get_direction() const { return base::get_direction(); }
  9734. };
  9735. /// Default constructor. Produces an empty object.
  9736. lstm_forward() = default;
  9737. /// Constructs an LSTM forward propagation primitive.
  9738. /// @param pd Primitive descriptor for an LSTM forward propagation
  9739. /// primitive.
  9740. lstm_forward(const primitive_desc &pd) : primitive(pd) {}
  9741. /// Constructs an LSTM forward propagation primitive from a cache blob.
  9742. /// @param pd Primitive descriptor for an LSTM forward propagation
  9743. /// primitive.
  9744. /// @param cache_blob Cache blob.
  9745. lstm_forward(
  9746. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  9747. : primitive(pd, cache_blob) {}
  9748. };
  9749. /// LSTM backward propagation primitive.
  9750. struct lstm_backward : public primitive {
  9751. /// Primitive descriptor for an LSTM backward propagation primitive.
  9752. struct primitive_desc : public rnn_primitive_desc_base {
  9753. /// Default constructor. Produces an empty object.
  9754. primitive_desc() = default;
  9755. /// Constructs an LSTM (with or without peephole and with or without
  9756. /// projection) primitive descriptor for backward propagation
  9757. /// using @p prop_kind, @p direction, and memory descriptors.
  9758. ///
  9759. /// The following arguments may point to a zero memory descriptor:
  9760. /// - @p src_iter_desc together with @p src_iter_c_desc,
  9761. /// @p diff_src_iter_desc, and @p diff_src_iter_c_desc,
  9762. /// - @p weights_peephole_desc together with
  9763. /// @p diff_weights_peephole_desc
  9764. /// - @p bias_desc together with @p diff_bias_desc,
  9765. /// - @p dst_iter_desc together with @p dst_iter_c_desc,
  9766. /// @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc.
  9767. ///
  9768. /// This would then indicate that the LSTM backward propagation
  9769. /// primitive should not use them and should default to zero values
  9770. /// instead.
  9771. ///
  9772. /// The @p weights_projection_desc together with @p
  9773. /// diff_weights_projection_desc may point to a zero memory descriptor.
  9774. /// This would then indicate that the LSTM doesn't have recurrent
  9775. /// projection layer.
  9776. ///
  9777. /// @note
  9778. /// All memory descriptors can be initialized with
  9779. /// #dnnl::memory::format_tag::any value of @p format_tag.
  9780. ///
  9781. /// @param aengine Engine to use.
  9782. /// @param aprop_kind Propagation kind. Must be
  9783. /// #dnnl::prop_kind::backward.
  9784. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  9785. /// more info.
  9786. /// @param src_layer_desc Memory descriptor for the input vector.
  9787. /// @param src_iter_desc Memory descriptor for the input recurrent
  9788. /// hidden state vector.
  9789. /// @param src_iter_c_desc Memory descriptor for the input recurrent
  9790. /// cell state vector.
  9791. /// @param weights_layer_desc Memory descriptor for the weights
  9792. /// applied to the layer input.
  9793. /// @param weights_iter_desc Memory descriptor for the weights applied
  9794. /// to the recurrent input.
  9795. /// @param weights_peephole_desc Memory descriptor for the weights
  9796. /// applied to the cell states (according to the Peephole LSTM
  9797. /// formula).
  9798. /// @param weights_projection_desc Memory descriptor for the weights
  9799. /// applied to the hidden states to get the recurrent projection
  9800. /// (according to the Projection LSTM formula).
  9801. /// @param bias_desc Bias memory descriptor.
  9802. /// @param dst_layer_desc Memory descriptor for the output vector.
  9803. /// @param dst_iter_desc Memory descriptor for the output recurrent
  9804. /// hidden state vector.
  9805. /// @param dst_iter_c_desc Memory descriptor for the output recurrent
  9806. /// cell state vector.
  9807. /// @param diff_src_layer_desc Memory descriptor for the diff of input
  9808. /// vector.
  9809. /// @param diff_src_iter_desc Memory descriptor for the diff of input
  9810. /// recurrent hidden state vector.
  9811. /// @param diff_src_iter_c_desc Memory descriptor for the diff of
  9812. /// input recurrent cell state vector.
  9813. /// @param diff_weights_layer_desc Memory descriptor for the diff of
  9814. /// weights applied to the layer input.
  9815. /// @param diff_weights_iter_desc Memory descriptor for the diff of
  9816. /// weights applied to the recurrent input.
  9817. /// @param diff_weights_peephole_desc Memory descriptor for the diff of
  9818. /// weights applied to the cell states (according to the Peephole
  9819. /// LSTM formula).
  9820. /// @param diff_weights_projection_desc Memory descriptor for the diff
  9821. /// of weights applied to the hidden states to get the recurrent
  9822. /// projection (according to the Projection LSTM formula).
  9823. /// @param diff_bias_desc Diff bias memory descriptor.
  9824. /// @param diff_dst_layer_desc Memory descriptor for the diff of
  9825. /// output vector.
  9826. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  9827. /// recurrent hidden state vector.
  9828. /// @param diff_dst_iter_c_desc Memory descriptor for the diff of
  9829. /// output recurrent cell state vector.
  9830. /// @param hint_fwd_pd Primitive descriptor for an LSTM
  9831. /// forward propagation primitive. It is used as a hint for
  9832. /// deciding which memory format to use.
  9833. /// @param attr Primitive attributes to use. Attributes are optional
  9834. /// and default to empty attributes.
  9835. /// @param allow_empty A flag signifying whether construction is
  9836. /// allowed to fail without throwing an exception. In this case an
  9837. /// empty object will be produced. This flag is optional and
  9838. /// defaults to false.
  9839. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  9840. rnn_direction direction, const memory::desc &src_layer_desc,
  9841. const memory::desc &src_iter_desc,
  9842. const memory::desc &src_iter_c_desc,
  9843. const memory::desc &weights_layer_desc,
  9844. const memory::desc &weights_iter_desc,
  9845. const memory::desc &weights_peephole_desc,
  9846. const memory::desc &weights_projection_desc,
  9847. const memory::desc &bias_desc,
  9848. const memory::desc &dst_layer_desc,
  9849. const memory::desc &dst_iter_desc,
  9850. const memory::desc &dst_iter_c_desc,
  9851. const memory::desc &diff_src_layer_desc,
  9852. const memory::desc &diff_src_iter_desc,
  9853. const memory::desc &diff_src_iter_c_desc,
  9854. const memory::desc &diff_weights_layer_desc,
  9855. const memory::desc &diff_weights_iter_desc,
  9856. const memory::desc &diff_weights_peephole_desc,
  9857. const memory::desc &diff_weights_projection_desc,
  9858. const memory::desc &diff_bias_desc,
  9859. const memory::desc &diff_dst_layer_desc,
  9860. const memory::desc &diff_dst_iter_desc,
  9861. const memory::desc &diff_dst_iter_c_desc,
  9862. const lstm_forward::primitive_desc &hint_fwd_pd,
  9863. const primitive_attr &attr = default_attr(),
  9864. bool allow_empty = false)
  9865. : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
  9866. aprop_kind, algorithm::undef, direction, src_layer_desc,
  9867. src_iter_desc, &src_iter_c_desc, nullptr,
  9868. weights_layer_desc, weights_iter_desc,
  9869. &weights_peephole_desc, &weights_projection_desc, bias_desc,
  9870. dst_layer_desc, dst_iter_desc, &dst_iter_c_desc,
  9871. diff_src_layer_desc, diff_src_iter_desc,
  9872. &diff_src_iter_c_desc, nullptr, diff_weights_layer_desc,
  9873. diff_weights_iter_desc, &diff_weights_peephole_desc,
  9874. &diff_weights_projection_desc, diff_bias_desc,
  9875. diff_dst_layer_desc, diff_dst_iter_desc,
  9876. &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f,
  9877. hint_fwd_pd, attr, allow_empty) {}
  9878. /// Constructs an LSTM (with or without peephole) primitive descriptor
  9879. /// for backward propagation using @p aprop_kind, @p direction,
  9880. /// and memory descriptors; this overload takes no projection weights.
  9881. ///
  9882. /// The following arguments may point to a zero memory descriptor:
  9883. /// - @p src_iter_desc together with @p src_iter_c_desc,
  9884. /// @p diff_src_iter_desc, and @p diff_src_iter_c_desc,
  9885. /// - @p weights_peephole_desc together with
  9886. /// @p diff_weights_peephole_desc,
  9887. /// - @p bias_desc together with @p diff_bias_desc,
  9888. /// - @p dst_iter_desc together with @p dst_iter_c_desc,
  9889. /// @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc.
  9890. ///
  9891. /// This would then indicate that the LSTM backward propagation
  9892. /// primitive should not use them and should default to zero values
  9893. /// instead.
  9894. ///
  9895. /// @note
  9896. /// All memory descriptors may be initialized with an
  9897. /// #dnnl::memory::format_tag::any value of @p format_tag.
  9898. ///
  9899. /// @param aengine Engine to use.
  9900. /// @param aprop_kind Propagation kind. Must be
  9901. /// #dnnl::prop_kind::backward.
  9902. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  9903. /// more info.
  9904. /// @param src_layer_desc Memory descriptor for the input vector.
  9905. /// @param src_iter_desc Memory descriptor for the input recurrent
  9906. /// hidden state vector.
  9907. /// @param src_iter_c_desc Memory descriptor for the input recurrent
  9908. /// cell state vector.
  9909. /// @param weights_layer_desc Memory descriptor for the weights
  9910. /// applied to the layer input.
  9911. /// @param weights_iter_desc Memory descriptor for the weights applied
  9912. /// to the recurrent input.
  9913. /// @param weights_peephole_desc Memory descriptor for the weights
  9914. /// applied to the cell states (according to the Peephole LSTM
  9915. /// formula).
  9916. /// @param bias_desc Bias memory descriptor.
  9917. /// @param dst_layer_desc Memory descriptor for the output vector.
  9918. /// @param dst_iter_desc Memory descriptor for the output recurrent
  9919. /// hidden state vector.
  9920. /// @param dst_iter_c_desc Memory descriptor for the output recurrent
  9921. /// cell state vector.
  9922. /// @param diff_src_layer_desc Memory descriptor for the diff of input
  9923. /// vector.
  9924. /// @param diff_src_iter_desc Memory descriptor for the diff of input
  9925. /// recurrent hidden state vector.
  9926. /// @param diff_src_iter_c_desc Memory descriptor for the diff of
  9927. /// input recurrent cell state vector.
  9928. /// @param diff_weights_layer_desc Memory descriptor for the diff of
  9929. /// weights applied to the layer input.
  9930. /// @param diff_weights_iter_desc Memory descriptor for the diff of
  9931. /// weights applied to the recurrent input.
  9932. /// @param diff_weights_peephole_desc Memory descriptor for the diff of
  9933. /// weights applied to the cell states (according to the Peephole
  9934. /// LSTM formula).
  9935. /// @param diff_bias_desc Diff bias memory descriptor.
  9936. /// @param diff_dst_layer_desc Memory descriptor for the diff of
  9937. /// output vector.
  9938. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  9939. /// recurrent hidden state vector.
  9940. /// @param diff_dst_iter_c_desc Memory descriptor for the diff of
  9941. /// output recurrent cell state vector.
  9942. /// @param hint_fwd_pd Primitive descriptor for an LSTM
  9943. /// forward propagation primitive. It is used as a hint for
  9944. /// deciding which memory format to use.
  9945. /// @param attr Primitive attributes to use. Attributes are optional
  9946. /// and default to empty attributes.
  9947. /// @param allow_empty A flag signifying whether construction is
  9948. /// allowed to fail without throwing an exception. In this case an
  9949. /// empty object will be produced. This flag is optional and
  9950. /// defaults to false.
  9951. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  9952. rnn_direction direction, const memory::desc &src_layer_desc,
  9953. const memory::desc &src_iter_desc,
  9954. const memory::desc &src_iter_c_desc,
  9955. const memory::desc &weights_layer_desc,
  9956. const memory::desc &weights_iter_desc,
  9957. const memory::desc &weights_peephole_desc,
  9958. const memory::desc &bias_desc,
  9959. const memory::desc &dst_layer_desc,
  9960. const memory::desc &dst_iter_desc,
  9961. const memory::desc &dst_iter_c_desc,
  9962. const memory::desc &diff_src_layer_desc,
  9963. const memory::desc &diff_src_iter_desc,
  9964. const memory::desc &diff_src_iter_c_desc,
  9965. const memory::desc &diff_weights_layer_desc,
  9966. const memory::desc &diff_weights_iter_desc,
  9967. const memory::desc &diff_weights_peephole_desc,
  9968. const memory::desc &diff_bias_desc,
  9969. const memory::desc &diff_dst_layer_desc,
  9970. const memory::desc &diff_dst_iter_desc,
  9971. const memory::desc &diff_dst_iter_c_desc,
  9972. const lstm_forward::primitive_desc &hint_fwd_pd,
  9973. const primitive_attr &attr = default_attr(),
  9974. bool allow_empty = false)
  9975. : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
  9976. aprop_kind, algorithm::undef, direction, src_layer_desc,
  9977. src_iter_desc, &src_iter_c_desc, nullptr,
  9978. weights_layer_desc, weights_iter_desc,
  9979. &weights_peephole_desc, nullptr, bias_desc, dst_layer_desc,
  9980. dst_iter_desc, &dst_iter_c_desc, diff_src_layer_desc,
  9981. diff_src_iter_desc, &diff_src_iter_c_desc, nullptr,
  9982. diff_weights_layer_desc, diff_weights_iter_desc,
  9983. &diff_weights_peephole_desc, nullptr, diff_bias_desc,
  9984. diff_dst_layer_desc, diff_dst_iter_desc,
  9985. &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f,
  9986. hint_fwd_pd, attr, allow_empty) {}
  9987. /// Constructs an LSTM primitive descriptor for backward propagation
  9988. /// using @p aprop_kind, @p direction, and memory descriptors.
  9989. ///
  9990. /// The following arguments may point to a zero memory descriptor:
  9991. /// - @p src_iter_desc together with @p src_iter_c_desc,
  9992. /// @p diff_src_iter_desc, and @p diff_src_iter_c_desc,
  9993. /// - @p bias_desc together with @p diff_bias_desc,
  9994. /// - @p dst_iter_desc together with @p dst_iter_c_desc,
  9995. /// @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc.
  9996. ///
  9997. /// This would then indicate that the LSTM backward propagation
  9998. /// primitive should not use them and should default to zero values
  9999. /// instead.
  10000. ///
  10001. /// @note
  10002. /// All memory descriptors may be initialized with an
  10003. /// #dnnl::memory::format_tag::any value of @p format_tag.
  10004. ///
  10005. /// @param aengine Engine to use.
  10006. /// @param aprop_kind Propagation kind. Must be
  10007. /// #dnnl::prop_kind::backward.
  10008. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  10009. /// more info.
  10010. /// @param src_layer_desc Memory descriptor for the input vector.
  10011. /// @param src_iter_desc Memory descriptor for the input recurrent
  10012. /// hidden state vector.
  10013. /// @param src_iter_c_desc Memory descriptor for the input recurrent
  10014. /// cell state vector.
  10015. /// @param weights_layer_desc Memory descriptor for the weights
  10016. /// applied to the layer input.
  10017. /// @param weights_iter_desc Memory descriptor for the weights applied
  10018. /// to the recurrent input.
  10019. /// @param bias_desc Bias memory descriptor.
  10020. /// @param dst_layer_desc Memory descriptor for the output vector.
  10021. /// @param dst_iter_desc Memory descriptor for the output recurrent
  10022. /// hidden state vector.
  10023. /// @param dst_iter_c_desc Memory descriptor for the output recurrent
  10024. /// cell state vector.
  10025. /// @param diff_src_layer_desc Memory descriptor for the diff of input
  10026. /// vector.
  10027. /// @param diff_src_iter_desc Memory descriptor for the diff of input
  10028. /// recurrent hidden state vector.
  10029. /// @param diff_src_iter_c_desc Memory descriptor for the diff of
  10030. /// input recurrent cell state vector.
  10031. /// @param diff_weights_layer_desc Memory descriptor for the diff of
  10032. /// weights applied to the layer input.
  10033. /// @param diff_weights_iter_desc Memory descriptor for the diff of
  10034. /// weights applied to the recurrent input.
  10035. /// @param diff_bias_desc Diff bias memory descriptor.
  10036. /// @param diff_dst_layer_desc Memory descriptor for the diff of
  10037. /// output vector.
  10038. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  10039. /// recurrent hidden state vector.
  10040. /// @param diff_dst_iter_c_desc Memory descriptor for the diff of
  10041. /// output recurrent cell state vector.
  10042. /// @param hint_fwd_pd Primitive descriptor for an LSTM
  10043. /// forward propagation primitive. It is used as a hint for
  10044. /// deciding which memory format to use.
  10045. /// @param attr Primitive attributes to use. Attributes are optional
  10046. /// and default to empty attributes.
  10047. /// @param allow_empty A flag signifying whether construction is
  10048. /// allowed to fail without throwing an exception. In this case an
  10049. /// empty object will be produced. This flag is optional and
  10050. /// defaults to false.
  10051. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  10052. rnn_direction direction, const memory::desc &src_layer_desc,
  10053. const memory::desc &src_iter_desc,
  10054. const memory::desc &src_iter_c_desc,
  10055. const memory::desc &weights_layer_desc,
  10056. const memory::desc &weights_iter_desc,
  10057. const memory::desc &bias_desc,
  10058. const memory::desc &dst_layer_desc,
  10059. const memory::desc &dst_iter_desc,
  10060. const memory::desc &dst_iter_c_desc,
  10061. const memory::desc &diff_src_layer_desc,
  10062. const memory::desc &diff_src_iter_desc,
  10063. const memory::desc &diff_src_iter_c_desc,
  10064. const memory::desc &diff_weights_layer_desc,
  10065. const memory::desc &diff_weights_iter_desc,
  10066. const memory::desc &diff_bias_desc,
  10067. const memory::desc &diff_dst_layer_desc,
  10068. const memory::desc &diff_dst_iter_desc,
  10069. const memory::desc &diff_dst_iter_c_desc,
  10070. const lstm_forward::primitive_desc &hint_fwd_pd,
  10071. const primitive_attr &attr = default_attr(),
  10072. bool allow_empty = false)
  10073. : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
  10074. aprop_kind, algorithm::undef, direction, src_layer_desc,
  10075. src_iter_desc, &src_iter_c_desc, nullptr,
  10076. weights_layer_desc, weights_iter_desc, nullptr, nullptr,
  10077. bias_desc, dst_layer_desc, dst_iter_desc, &dst_iter_c_desc,
  10078. diff_src_layer_desc, diff_src_iter_desc,
  10079. &diff_src_iter_c_desc, nullptr, diff_weights_layer_desc,
  10080. diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc,
  10081. diff_dst_layer_desc, diff_dst_iter_desc,
  10082. &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f,
  10083. hint_fwd_pd, attr, allow_empty) {}
  10084. /// Constructs a primitive descriptor for an LSTM backward propagation
  10085. /// primitive from a C API primitive descriptor that must have a matching
  10086. /// kind; the base is given prop_kind::backward and algorithm::vanilla_lstm.
  10087. ///
  10088. /// @param pd C API primitive descriptor for an LSTM backward
  10089. /// propagation primitive.
  10090. primitive_desc(dnnl_primitive_desc_t pd)
  10091. : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
  10092. dnnl::algorithm::vanilla_lstm) {}
  10093. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  10094. memory::desc src_layer_desc() const {
  10095. return rnn_base::src_layer_desc();
  10096. }
  10097. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  10098. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  10099. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_c_desc()const
  10100. memory::desc src_iter_c_desc() const {
  10101. return rnn_base::src_iter_c_desc();
  10102. }
  10103. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  10104. memory::desc weights_layer_desc() const {
  10105. return rnn_base::weights_layer_desc();
  10106. }
  10107. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  10108. memory::desc weights_iter_desc() const {
  10109. return rnn_base::weights_iter_desc();
  10110. }
  10111. /// @copydoc dnnl::rnn_primitive_desc_base::weights_peephole_desc()const
  10112. memory::desc weights_peephole_desc() const {
  10113. return rnn_base::weights_peephole_desc();
  10114. }
  10115. /// @copydoc dnnl::rnn_primitive_desc_base::weights_projection_desc()const
  10116. memory::desc weights_projection_desc() const {
  10117. return rnn_base::weights_projection_desc();
  10118. }
  10119. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  10120. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  10121. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  10122. memory::desc dst_layer_desc() const {
  10123. return rnn_base::dst_layer_desc();
  10124. }
  10125. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  10126. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  10127. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_c_desc()const
  10128. memory::desc dst_iter_c_desc() const {
  10129. return rnn_base::dst_iter_c_desc();
  10130. }
  10131. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  10132. memory::desc workspace_desc() const {
  10133. return rnn_base::workspace_desc();
  10134. }
  10135. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
  10136. memory::desc diff_src_layer_desc() const {
  10137. return rnn_base::diff_src_layer_desc();
  10138. }
  10139. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
  10140. memory::desc diff_src_iter_desc() const {
  10141. return rnn_base::diff_src_iter_desc();
  10142. }
  10143. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_c_desc()const
  10144. memory::desc diff_src_iter_c_desc() const {
  10145. return rnn_base::diff_src_iter_c_desc();
  10146. }
  10147. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
  10148. memory::desc diff_weights_layer_desc() const {
  10149. return rnn_base::diff_weights_layer_desc();
  10150. }
  10151. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
  10152. memory::desc diff_weights_iter_desc() const {
  10153. return rnn_base::diff_weights_iter_desc();
  10154. }
  10155. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_peephole_desc()const
  10156. memory::desc diff_weights_peephole_desc() const {
  10157. return rnn_base::diff_weights_peephole_desc();
  10158. }
  10159. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_projection_desc()const
  10160. memory::desc diff_weights_projection_desc() const {
  10161. return rnn_base::diff_weights_projection_desc();
  10162. }
  10163. /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
  10164. memory::desc diff_bias_desc() const {
  10165. return rnn_base::diff_bias_desc();
  10166. }
  10167. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
  10168. memory::desc diff_dst_layer_desc() const {
  10169. return rnn_base::diff_dst_layer_desc();
  10170. }
  10171. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
  10172. memory::desc diff_dst_iter_desc() const {
  10173. return rnn_base::diff_dst_iter_desc();
  10174. }
  10175. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_c_desc()const
  10176. memory::desc diff_dst_iter_c_desc() const {
  10177. return rnn_base::diff_dst_iter_c_desc();
  10178. }
  10179. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  10180. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  10181. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  10182. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  10183. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  10184. rnn_direction get_direction() const { return base::get_direction(); }
  10185. };
  10186. /// Default constructor. Produces an empty object.
  10187. lstm_backward() = default;
  10188. /// Constructs an LSTM backward propagation primitive.
  10189. /// @param pd Primitive descriptor for an LSTM backward propagation
  10190. /// primitive.
  10191. lstm_backward(const primitive_desc &pd) : primitive(pd) {}
  10192. /// Constructs an LSTM backward propagation primitive from a cache blob.
  10193. /// @param pd Primitive descriptor for an LSTM backward propagation
  10194. /// primitive.
  10195. /// @param cache_blob Cache blob as a byte vector; forwarded with @p pd to the base primitive.
  10196. lstm_backward(
  10197. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  10198. : primitive(pd, cache_blob) {}
  10199. };
  10200. /// GRU forward propagation primitive.
  10201. struct gru_forward : public primitive {
  10202. /// Primitive descriptor for a GRU forward propagation primitive.
  10203. struct primitive_desc : public rnn_primitive_desc_base {
  10204. /// Default constructor. Produces an empty object.
  10205. primitive_desc() = default;
  10206. /// Constructs a primitive descriptor for a GRU forward propagation
  10207. /// primitive.
  10208. ///
  10209. /// The following arguments may point to a zero memory descriptor:
  10210. /// - @p src_iter_desc,
  10211. /// - @p bias_desc,
  10212. /// - @p dst_iter_desc.
  10213. ///
  10214. /// This would then indicate that the GRU forward propagation primitive
  10215. /// should not use them and should default to zero values instead.
  10216. ///
  10217. /// @note
  10218. /// All memory descriptors except @p src_iter_desc may be
  10219. /// initialized with an #dnnl::memory::format_tag::any value of @p
  10220. /// format_tag.
  10221. ///
  10222. /// @param aengine Engine to use.
  10223. /// @param aprop_kind Propagation kind. Possible values are
  10224. /// #dnnl::prop_kind::forward_training and
  10225. /// #dnnl::prop_kind::forward_inference.
  10226. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  10227. /// more info.
  10228. /// @param src_layer_desc Memory descriptor for the input vector.
  10229. /// @param src_iter_desc Memory descriptor for the input recurrent
  10230. /// hidden state vector.
  10231. /// @param weights_layer_desc Memory descriptor for the weights
  10232. /// applied to the layer input.
  10233. /// @param weights_iter_desc Memory descriptor for the weights applied
  10234. /// to the recurrent input.
  10235. /// @param bias_desc Bias memory descriptor.
  10236. /// @param dst_layer_desc Memory descriptor for the output vector.
  10237. /// @param dst_iter_desc Memory descriptor for the output recurrent
  10238. /// hidden state vector.
  10239. /// @param attr Primitive attributes to use. Attributes are optional
  10240. /// and default to empty attributes.
  10241. /// @param allow_empty A flag signifying whether construction is
  10242. /// allowed to fail without throwing an exception. In this case an
  10243. /// empty object will be produced. This flag is optional and
  10244. /// defaults to false.
  10245. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  10246. rnn_direction direction, const memory::desc &src_layer_desc,
  10247. const memory::desc &src_iter_desc,
  10248. const memory::desc &weights_layer_desc,
  10249. const memory::desc &weights_iter_desc,
  10250. const memory::desc &bias_desc,
  10251. const memory::desc &dst_layer_desc,
  10252. const memory::desc &dst_iter_desc,
  10253. const primitive_attr &attr = default_attr(),
  10254. bool allow_empty = false)
  10255. : rnn_primitive_desc_base(aengine, algorithm::vanilla_gru,
  10256. aprop_kind, algorithm::undef, direction, src_layer_desc,
  10257. src_iter_desc, nullptr, nullptr, weights_layer_desc,
  10258. weights_iter_desc, nullptr, nullptr, bias_desc,
  10259. dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
  10260. 0.0f, 0.0f, attr, allow_empty) {}
  10261. /// Constructs a primitive descriptor for a GRU forward propagation
  10262. /// primitive from a C API primitive descriptor that must have a matching
  10263. /// kind; the base is given both forward prop kinds and algorithm::vanilla_gru.
  10264. ///
  10265. /// @param pd C API primitive descriptor for a GRU forward
  10266. /// propagation primitive.
  10267. primitive_desc(dnnl_primitive_desc_t pd)
  10268. : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
  10269. dnnl::prop_kind::forward_inference,
  10270. dnnl::algorithm::vanilla_gru) {}
  10271. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  10272. memory::desc src_layer_desc() const {
  10273. return rnn_base::src_layer_desc();
  10274. }
  10275. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  10276. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  10277. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  10278. memory::desc weights_layer_desc() const {
  10279. return rnn_base::weights_layer_desc();
  10280. }
  10281. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  10282. memory::desc weights_iter_desc() const {
  10283. return rnn_base::weights_iter_desc();
  10284. }
  10285. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  10286. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  10287. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  10288. memory::desc dst_layer_desc() const {
  10289. return rnn_base::dst_layer_desc();
  10290. }
  10291. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  10292. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  10293. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  10294. memory::desc workspace_desc() const {
  10295. return rnn_base::workspace_desc();
  10296. }
  10297. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  10298. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  10299. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  10300. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  10301. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  10302. rnn_direction get_direction() const { return base::get_direction(); }
  10303. };
  10304. /// Default constructor. Produces an empty object.
  10305. gru_forward() = default;
  10306. /// Constructs a GRU forward propagation primitive.
  10307. /// @param pd Primitive descriptor for a GRU forward propagation
  10308. /// primitive.
  10309. gru_forward(const primitive_desc &pd) : primitive(pd) {}
  10310. /// Constructs a GRU forward propagation primitive from a cache blob.
  10311. /// @param pd Primitive descriptor for a GRU forward propagation
  10312. /// primitive.
  10313. /// @param cache_blob Cache blob as a byte vector; forwarded with @p pd to the base primitive.
  10314. gru_forward(
  10315. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  10316. : primitive(pd, cache_blob) {}
  10317. };
  10318. /// GRU backward propagation primitive.
  10319. struct gru_backward : public primitive {
  10320. /// Primitive descriptor for a GRU backward propagation primitive.
  10321. struct primitive_desc : public rnn_primitive_desc_base {
  10322. /// Default constructor. Produces an empty object.
  10323. primitive_desc() = default;
  10324. /// Constructs a primitive descriptor for a GRU backward propagation
  10325. /// primitive.
  10326. ///
  10327. /// The following arguments may point to a zero memory descriptor:
  10328. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  10329. /// - @p bias_desc together with @p diff_bias_desc,
  10330. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  10331. ///
  10332. /// This would then indicate that the GRU backward propagation
  10333. /// primitive should not use them and should default to zero values
  10334. /// instead.
  10335. ///
  10336. /// @note
  10337. /// All memory descriptors may be initialized with an
  10338. /// #dnnl::memory::format_tag::any value of @p format_tag.
  10339. ///
  10340. /// @param aengine Engine to use.
  10341. /// @param aprop_kind Propagation kind. Must be
  10342. /// #dnnl::prop_kind::backward.
  10343. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  10344. /// more info.
  10345. /// @param src_layer_desc Memory descriptor for the input vector.
  10346. /// @param src_iter_desc Memory descriptor for the input recurrent
  10347. /// hidden state vector.
  10348. /// @param weights_layer_desc Memory descriptor for the weights
  10349. /// applied to the layer input.
  10350. /// @param weights_iter_desc Memory descriptor for the weights applied
  10351. /// to the recurrent input.
  10352. /// @param bias_desc Bias memory descriptor.
  10353. /// @param dst_layer_desc Memory descriptor for the output vector.
  10354. /// @param dst_iter_desc Memory descriptor for the output recurrent
  10355. /// hidden state vector.
  10356. /// @param diff_src_layer_desc Memory descriptor for the diff of input
  10357. /// vector.
  10358. /// @param diff_src_iter_desc Memory descriptor for the diff of input
  10359. /// recurrent hidden state vector.
  10360. /// @param diff_weights_layer_desc Memory descriptor for the diff of
  10361. /// weights applied to the layer input.
  10362. /// @param diff_weights_iter_desc Memory descriptor for the diff of
  10363. /// weights applied to the recurrent input.
  10364. /// @param diff_bias_desc Diff bias memory descriptor.
  10365. /// @param diff_dst_layer_desc Memory descriptor for the diff of
  10366. /// output vector.
  10367. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  10368. /// recurrent hidden state vector.
  10369. /// @param hint_fwd_pd Primitive descriptor for a GRU
  10370. /// forward propagation primitive. It is used as a hint for
  10371. /// deciding which memory format to use.
  10372. /// @param attr Primitive attributes to use. Attributes are optional
  10373. /// and default to empty attributes.
  10374. /// @param allow_empty A flag signifying whether construction is
  10375. /// allowed to fail without throwing an exception. In this case an
  10376. /// empty object will be produced. This flag is optional and
  10377. /// defaults to false.
  10378. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  10379. rnn_direction direction, const memory::desc &src_layer_desc,
  10380. const memory::desc &src_iter_desc,
  10381. const memory::desc &weights_layer_desc,
  10382. const memory::desc &weights_iter_desc,
  10383. const memory::desc &bias_desc,
  10384. const memory::desc &dst_layer_desc,
  10385. const memory::desc &dst_iter_desc,
  10386. const memory::desc &diff_src_layer_desc,
  10387. const memory::desc &diff_src_iter_desc,
  10388. const memory::desc &diff_weights_layer_desc,
  10389. const memory::desc &diff_weights_iter_desc,
  10390. const memory::desc &diff_bias_desc,
  10391. const memory::desc &diff_dst_layer_desc,
  10392. const memory::desc &diff_dst_iter_desc,
  10393. const gru_forward::primitive_desc &hint_fwd_pd,
  10394. const primitive_attr &attr = default_attr(),
  10395. bool allow_empty = false)
  10396. : rnn_primitive_desc_base(aengine, algorithm::vanilla_gru,
  10397. aprop_kind, algorithm::undef, direction, src_layer_desc,
  10398. src_iter_desc, nullptr, nullptr, weights_layer_desc,
  10399. weights_iter_desc, nullptr, nullptr, bias_desc,
  10400. dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
  10401. diff_src_iter_desc, nullptr, nullptr,
  10402. diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
  10403. nullptr, diff_bias_desc, diff_dst_layer_desc,
  10404. diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
  10405. hint_fwd_pd, attr, allow_empty) {}
  10406. /// Constructs a primitive descriptor for a GRU backward propagation
  10407. /// primitive from a C API primitive descriptor that must have a matching
  10408. /// kind; the base is given prop_kind::backward and algorithm::vanilla_gru.
  10409. ///
  10410. /// @param pd C API primitive descriptor for a GRU backward
  10411. /// propagation primitive.
  10412. primitive_desc(dnnl_primitive_desc_t pd)
  10413. : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
  10414. dnnl::algorithm::vanilla_gru) {}
  10415. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  10416. memory::desc src_layer_desc() const {
  10417. return rnn_base::src_layer_desc();
  10418. }
  10419. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  10420. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  10421. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  10422. memory::desc weights_layer_desc() const {
  10423. return rnn_base::weights_layer_desc();
  10424. }
  10425. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  10426. memory::desc weights_iter_desc() const {
  10427. return rnn_base::weights_iter_desc();
  10428. }
  10429. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  10430. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  10431. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  10432. memory::desc dst_layer_desc() const {
  10433. return rnn_base::dst_layer_desc();
  10434. }
  10435. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  10436. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  10437. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  10438. memory::desc workspace_desc() const {
  10439. return rnn_base::workspace_desc();
  10440. }
  10441. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
  10442. memory::desc diff_src_layer_desc() const {
  10443. return rnn_base::diff_src_layer_desc();
  10444. }
  10445. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
  10446. memory::desc diff_src_iter_desc() const {
  10447. return rnn_base::diff_src_iter_desc();
  10448. }
  10449. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
  10450. memory::desc diff_weights_layer_desc() const {
  10451. return rnn_base::diff_weights_layer_desc();
  10452. }
  10453. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
  10454. memory::desc diff_weights_iter_desc() const {
  10455. return rnn_base::diff_weights_iter_desc();
  10456. }
  10457. /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
  10458. memory::desc diff_bias_desc() const {
  10459. return rnn_base::diff_bias_desc();
  10460. }
  10461. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
  10462. memory::desc diff_dst_layer_desc() const {
  10463. return rnn_base::diff_dst_layer_desc();
  10464. }
  10465. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
  10466. memory::desc diff_dst_iter_desc() const {
  10467. return rnn_base::diff_dst_iter_desc();
  10468. }
  10469. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  10470. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  10471. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  10472. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  10473. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  10474. rnn_direction get_direction() const { return base::get_direction(); }
  10475. };
  10476. /// Default constructor. Produces an empty object.
  10477. gru_backward() = default;
  10478. /// Constructs a GRU backward propagation primitive.
  10479. /// @param pd Primitive descriptor for a GRU backward propagation
  10480. /// primitive.
  10481. gru_backward(const primitive_desc &pd) : primitive(pd) {}
  10482. /// Constructs a GRU backward propagation primitive from a cache blob.
  10483. /// @param pd Primitive descriptor for a GRU backward propagation
  10484. /// primitive.
  10485. /// @param cache_blob Cache blob as a byte vector; forwarded with @p pd to the base primitive.
  10486. gru_backward(
  10487. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  10488. : primitive(pd, cache_blob) {}
  10489. };
  10490. /// LBR GRU forward propagation primitive.
  10491. struct lbr_gru_forward : public primitive {
  10492. /// Primitive descriptor for an LBR GRU forward propagation primitive.
  10493. struct primitive_desc : public rnn_primitive_desc_base {
  10494. /// Default constructor. Produces an empty object.
  10495. primitive_desc() = default;
  10496. /// Constructs a primitive descriptor for an LBR GRU forward
  10497. /// propagation primitive.
  10498. ///
  10499. /// The following arguments may point to a zero memory descriptor:
  10500. /// - @p src_iter_desc,
  10501. /// - @p bias_desc,
  10502. /// - @p dst_iter_desc.
  10503. ///
  10504. /// This would then indicate that the LBR GRU forward propagation
  10505. /// primitive should not use them and should default to zero values
  10506. /// instead.
  10507. ///
  10508. /// @note
  10509. /// All memory descriptors except @p src_iter_desc may be
  10510. /// initialized with an #dnnl::memory::format_tag::any value of @p
  10511. /// format_tag.
  10512. ///
  10513. /// @param aengine Engine to use.
  10514. /// @param aprop_kind Propagation kind. Possible values are
  10515. /// #dnnl::prop_kind::forward_training, and
  10516. /// #dnnl::prop_kind::forward_inference.
  10517. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  10518. /// more info.
  10519. /// @param src_layer_desc Memory descriptor for the input vector.
  10520. /// @param src_iter_desc Memory descriptor for the input recurrent
  10521. /// hidden state vector.
  10522. /// @param weights_layer_desc Memory descriptor for the weights
  10523. /// applied to the layer input.
  10524. /// @param weights_iter_desc Memory descriptor for the weights applied
  10525. /// to the recurrent input.
  10526. /// @param bias_desc Bias memory descriptor.
  10527. /// @param dst_layer_desc Memory descriptor for the output vector.
  10528. /// @param dst_iter_desc Memory descriptor for the output recurrent
  10529. /// hidden state vector.
  10530. /// @param attr Primitive attributes to use. Attributes are optional
  10531. /// and default to empty attributes.
  10532. /// @param allow_empty A flag signifying whether construction is
  10533. /// allowed to fail without throwing an exception. In this case an
  10534. /// empty object will be produced. This flag is optional and
  10535. /// defaults to false.
  10536. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  10537. rnn_direction direction, const memory::desc &src_layer_desc,
  10538. const memory::desc &src_iter_desc,
  10539. const memory::desc &weights_layer_desc,
  10540. const memory::desc &weights_iter_desc,
  10541. const memory::desc &bias_desc,
  10542. const memory::desc &dst_layer_desc,
  10543. const memory::desc &dst_iter_desc,
  10544. const primitive_attr &attr = default_attr(),
  10545. bool allow_empty = false)
  10546. : rnn_primitive_desc_base(aengine, algorithm::lbr_gru, aprop_kind,
  10547. algorithm::undef, direction, src_layer_desc, src_iter_desc,
  10548. nullptr, nullptr, weights_layer_desc, weights_iter_desc, // 2nd nullptr: slot the AUGRU ctors fill with &attention_desc
  10549. nullptr, nullptr, bias_desc, dst_layer_desc, dst_iter_desc,
  10550. nullptr, rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {}
  10551. /// Constructs a primitive descriptor for an LBR GRU forward propagation
  10552. /// primitive from a C API primitive descriptor that must have a
  10553. /// matching kind.
  10554. ///
  10555. /// @param pd C API primitive descriptor for an LBR GRU forward
  10556. /// propagation primitive.
  10557. primitive_desc(dnnl_primitive_desc_t pd)
  10558. : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
  10559. dnnl::prop_kind::forward_inference,
  10560. dnnl::algorithm::lbr_gru) {}
  10561. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  10562. memory::desc src_layer_desc() const {
  10563. return rnn_base::src_layer_desc();
  10564. }
  10565. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  10566. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  10567. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  10568. memory::desc weights_layer_desc() const {
  10569. return rnn_base::weights_layer_desc();
  10570. }
  10571. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  10572. memory::desc weights_iter_desc() const {
  10573. return rnn_base::weights_iter_desc();
  10574. }
  10575. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  10576. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  10577. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  10578. memory::desc dst_layer_desc() const {
  10579. return rnn_base::dst_layer_desc();
  10580. }
  10581. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  10582. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  10583. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  10584. memory::desc workspace_desc() const {
  10585. return rnn_base::workspace_desc();
  10586. }
  10587. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  10588. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  10589. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  10590. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  10591. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  10592. rnn_direction get_direction() const { return base::get_direction(); }
  10593. };
  10594. /// Default constructor. Produces an empty object.
  10595. lbr_gru_forward() = default;
  10596. /// Constructs an LBR GRU forward propagation primitive.
  10597. /// @param pd Primitive descriptor for an LBR GRU forward propagation
  10598. /// primitive.
  10599. lbr_gru_forward(const primitive_desc &pd) : primitive(pd) {}
  10600. /// Constructs an LBR GRU forward propagation primitive from a cache blob.
  10601. /// @param pd Primitive descriptor for an LBR GRU forward propagation
  10602. /// primitive.
  10603. /// @param cache_blob Cache blob, forwarded to the base primitive constructor.
  10604. lbr_gru_forward(
  10605. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  10606. : primitive(pd, cache_blob) {}
  10607. };
  10608. /// LBR GRU backward propagation primitive.
  10609. struct lbr_gru_backward : public primitive {
  10610. /// Primitive descriptor for an LBR GRU backward propagation primitive.
  10611. struct primitive_desc : public rnn_primitive_desc_base {
  10612. /// Default constructor. Produces an empty object.
  10613. primitive_desc() = default;
  10614. /// Constructs a primitive descriptor for an LBR GRU backward
  10615. /// propagation primitive.
  10616. ///
  10617. /// The following arguments may point to a zero memory descriptor:
  10618. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  10619. /// - @p bias_desc together with @p diff_bias_desc,
  10620. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  10621. ///
  10622. /// This would then indicate that the LBR GRU backward propagation
  10623. /// primitive should not use them and should default to zero values
  10624. /// instead.
  10625. ///
  10626. /// @note
  10627. /// All memory descriptors may be initialized with
  10628. /// #dnnl::memory::format_tag::any value of @p format_tag.
  10629. ///
  10630. /// @param aengine Engine to use.
  10631. /// @param aprop_kind Propagation kind. Must be
  10632. /// #dnnl::prop_kind::backward.
  10633. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  10634. /// more info.
  10635. /// @param src_layer_desc Memory descriptor for the input vector.
  10636. /// @param src_iter_desc Memory descriptor for the input recurrent
  10637. /// hidden state vector.
  10638. /// @param weights_layer_desc Memory descriptor for the weights
  10639. /// applied to the layer input.
  10640. /// @param weights_iter_desc Memory descriptor for the weights applied
  10641. /// to the recurrent input.
  10642. /// @param bias_desc Bias memory descriptor.
  10643. /// @param dst_layer_desc Memory descriptor for the output vector.
  10644. /// @param dst_iter_desc Memory descriptor for the output recurrent
  10645. /// hidden state vector.
  10646. /// @param diff_src_layer_desc Memory descriptor for the diff of input
  10647. /// vector.
  10648. /// @param diff_src_iter_desc Memory descriptor for the diff of input
  10649. /// recurrent hidden state vector.
  10650. /// @param diff_weights_layer_desc Memory descriptor for the diff of
  10651. /// weights applied to the layer input.
  10652. /// @param diff_weights_iter_desc Memory descriptor for the diff of
  10653. /// weights applied to the recurrent input.
  10654. /// @param diff_bias_desc Diff bias memory descriptor.
  10655. /// @param diff_dst_layer_desc Memory descriptor for the diff of
  10656. /// output vector.
  10657. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  10658. /// recurrent hidden state vector.
  10659. /// @param hint_fwd_pd Primitive descriptor for an LBR GRU
  10660. /// forward propagation primitive. It is used as a hint for
  10661. /// deciding which memory format to use.
  10662. /// @param attr Primitive attributes to use. Attributes are optional
  10663. /// and default to empty attributes.
  10664. /// @param allow_empty A flag signifying whether construction is
  10665. /// allowed to fail without throwing an exception. In this case an
  10666. /// empty object will be produced. This flag is optional and
  10667. /// defaults to false.
  10668. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  10669. rnn_direction direction, const memory::desc &src_layer_desc,
  10670. const memory::desc &src_iter_desc,
  10671. const memory::desc &weights_layer_desc,
  10672. const memory::desc &weights_iter_desc,
  10673. const memory::desc &bias_desc,
  10674. const memory::desc &dst_layer_desc,
  10675. const memory::desc &dst_iter_desc,
  10676. const memory::desc &diff_src_layer_desc,
  10677. const memory::desc &diff_src_iter_desc,
  10678. const memory::desc &diff_weights_layer_desc,
  10679. const memory::desc &diff_weights_iter_desc,
  10680. const memory::desc &diff_bias_desc,
  10681. const memory::desc &diff_dst_layer_desc,
  10682. const memory::desc &diff_dst_iter_desc,
  10683. const lbr_gru_forward::primitive_desc &hint_fwd_pd,
  10684. const primitive_attr &attr = default_attr(),
  10685. bool allow_empty = false)
  10686. : rnn_primitive_desc_base(aengine, algorithm::lbr_gru, aprop_kind,
  10687. algorithm::undef, direction, src_layer_desc, src_iter_desc,
  10688. nullptr, nullptr, weights_layer_desc, weights_iter_desc, // 2nd nullptr: slot the AUGRU ctors fill with &attention_desc
  10689. nullptr, nullptr, bias_desc, dst_layer_desc, dst_iter_desc,
  10690. nullptr, diff_src_layer_desc, diff_src_iter_desc, nullptr,
  10691. nullptr, diff_weights_layer_desc, diff_weights_iter_desc, // leading nullptr: slot augru_backward fills with &diff_attention_desc
  10692. nullptr, nullptr, diff_bias_desc, diff_dst_layer_desc,
  10693. diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
  10694. hint_fwd_pd, attr, allow_empty) {}
  10695. /// Constructs a primitive descriptor for an LBR GRU backward propagation
  10696. /// primitive from a C API primitive descriptor that must have a
  10697. /// matching kind.
  10698. ///
  10699. /// @param pd C API primitive descriptor for an LBR GRU backward
  10700. /// propagation primitive.
  10701. primitive_desc(dnnl_primitive_desc_t pd)
  10702. : rnn_primitive_desc_base(
  10703. pd, dnnl::prop_kind::backward, dnnl::algorithm::lbr_gru) {}
  10704. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  10705. memory::desc src_layer_desc() const {
  10706. return rnn_base::src_layer_desc();
  10707. }
  10708. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  10709. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  10710. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  10711. memory::desc weights_layer_desc() const {
  10712. return rnn_base::weights_layer_desc();
  10713. }
  10714. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  10715. memory::desc weights_iter_desc() const {
  10716. return rnn_base::weights_iter_desc();
  10717. }
  10718. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  10719. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  10720. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  10721. memory::desc dst_layer_desc() const {
  10722. return rnn_base::dst_layer_desc();
  10723. }
  10724. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  10725. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  10726. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  10727. memory::desc workspace_desc() const {
  10728. return rnn_base::workspace_desc();
  10729. }
  10730. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
  10731. memory::desc diff_src_layer_desc() const {
  10732. return rnn_base::diff_src_layer_desc();
  10733. }
  10734. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
  10735. memory::desc diff_src_iter_desc() const {
  10736. return rnn_base::diff_src_iter_desc();
  10737. }
  10738. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
  10739. memory::desc diff_weights_layer_desc() const {
  10740. return rnn_base::diff_weights_layer_desc();
  10741. }
  10742. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
  10743. memory::desc diff_weights_iter_desc() const {
  10744. return rnn_base::diff_weights_iter_desc();
  10745. }
  10746. /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
  10747. memory::desc diff_bias_desc() const {
  10748. return rnn_base::diff_bias_desc();
  10749. }
  10750. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
  10751. memory::desc diff_dst_layer_desc() const {
  10752. return rnn_base::diff_dst_layer_desc();
  10753. }
  10754. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
  10755. memory::desc diff_dst_iter_desc() const {
  10756. return rnn_base::diff_dst_iter_desc();
  10757. }
  10758. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  10759. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  10760. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  10761. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  10762. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  10763. rnn_direction get_direction() const { return base::get_direction(); }
  10764. };
  10765. /// Default constructor. Produces an empty object.
  10766. lbr_gru_backward() = default;
  10767. /// Constructs an LBR GRU backward propagation primitive.
  10768. /// @param pd Primitive descriptor for an LBR GRU backward propagation
  10769. /// primitive.
  10770. lbr_gru_backward(const primitive_desc &pd) : primitive(pd) {}
  10771. /// Constructs an LBR GRU backward propagation primitive from a cache blob.
  10772. /// @param pd Primitive descriptor for an LBR GRU backward propagation
  10773. /// primitive.
  10774. /// @param cache_blob Cache blob, forwarded to the base primitive constructor.
  10775. lbr_gru_backward(
  10776. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  10777. : primitive(pd, cache_blob) {}
  10778. };
  10779. /// AUGRU forward propagation primitive.
  10780. struct augru_forward : public primitive {
  10781. /// Primitive descriptor for an AUGRU forward propagation primitive.
  10782. struct primitive_desc : public rnn_primitive_desc_base {
  10783. /// Default constructor. Produces an empty object.
  10784. primitive_desc() = default;
  10785. /// Constructs a primitive descriptor for an AUGRU forward propagation
  10786. /// primitive.
  10787. ///
  10788. /// The following arguments may point to a zero memory descriptor:
  10789. /// - @p src_iter_desc,
  10790. /// - @p bias_desc,
  10791. /// - @p dst_iter_desc.
  10792. ///
  10793. /// This would then indicate that the AUGRU forward propagation
  10794. /// primitive should not use them and should default to zero values
  10795. /// instead.
  10796. ///
  10797. /// @note
  10798. /// All memory descriptors except @p src_iter_desc may be
  10799. /// initialized with an #dnnl::memory::format_tag::any value of @p
  10800. /// format_tag.
  10801. ///
  10802. /// @param aengine Engine to use.
  10803. /// @param aprop_kind Propagation kind. Possible values are
  10804. /// #dnnl::prop_kind::forward_training, and
  10805. /// #dnnl::prop_kind::forward_inference.
  10806. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  10807. /// more info.
  10808. /// @param src_layer_desc Memory descriptor for the input vector.
  10809. /// @param src_iter_desc Memory descriptor for the input recurrent
  10810. /// hidden state vector.
  10811. /// @param attention_desc Memory descriptor for the attention vector.
  10812. /// @param weights_layer_desc Memory descriptor for the weights
  10813. /// applied to the layer input.
  10814. /// @param weights_iter_desc Memory descriptor for the weights applied
  10815. /// to the recurrent input.
  10816. /// @param bias_desc Bias memory descriptor.
  10817. /// @param dst_layer_desc Memory descriptor for the output vector.
  10818. /// @param dst_iter_desc Memory descriptor for the output recurrent
  10819. /// hidden state vector.
  10820. /// @param attr Primitive attributes to use. Attributes are optional
  10821. /// and default to empty attributes.
  10822. /// @param allow_empty A flag signifying whether construction is
  10823. /// allowed to fail without throwing an exception. In this case an
  10824. /// empty object will be produced. This flag is optional and
  10825. /// defaults to false.
  10826. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  10827. rnn_direction direction, const memory::desc &src_layer_desc,
  10828. const memory::desc &src_iter_desc,
  10829. const memory::desc &attention_desc,
  10830. const memory::desc &weights_layer_desc,
  10831. const memory::desc &weights_iter_desc,
  10832. const memory::desc &bias_desc,
  10833. const memory::desc &dst_layer_desc,
  10834. const memory::desc &dst_iter_desc,
  10835. const primitive_attr &attr = default_attr(),
  10836. bool allow_empty = false)
  10837. : rnn_primitive_desc_base(aengine, algorithm::vanilla_augru,
  10838. aprop_kind, algorithm::undef, direction, src_layer_desc,
  10839. src_iter_desc, nullptr, &attention_desc, weights_layer_desc, // attention goes in by address: pointer slot (the lbr_gru ctors pass nullptr here)
  10840. weights_iter_desc, nullptr, nullptr, bias_desc,
  10841. dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
  10842. 0.0f, 0.0f, attr, allow_empty) {}
  10843. /// Constructs a primitive descriptor for an AUGRU forward propagation
  10844. /// primitive from a C API primitive descriptor that must have a
  10845. /// matching kind.
  10846. ///
  10847. /// @param pd C API primitive descriptor for an AUGRU forward
  10848. /// propagation primitive.
  10849. primitive_desc(dnnl_primitive_desc_t pd)
  10850. : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
  10851. dnnl::prop_kind::forward_inference,
  10852. dnnl::algorithm::vanilla_augru) {}
  10853. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  10854. memory::desc src_layer_desc() const {
  10855. return rnn_base::src_layer_desc();
  10856. }
  10857. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  10858. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  10859. /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const
  10860. memory::desc attention_desc() const {
  10861. return rnn_base::augru_attention_desc();
  10862. }
  10863. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  10864. memory::desc weights_layer_desc() const {
  10865. return rnn_base::weights_layer_desc();
  10866. }
  10867. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  10868. memory::desc weights_iter_desc() const {
  10869. return rnn_base::weights_iter_desc();
  10870. }
  10871. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  10872. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  10873. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  10874. memory::desc dst_layer_desc() const {
  10875. return rnn_base::dst_layer_desc();
  10876. }
  10877. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  10878. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  10879. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  10880. memory::desc workspace_desc() const {
  10881. return rnn_base::workspace_desc();
  10882. }
  10883. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  10884. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  10885. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  10886. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  10887. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  10888. rnn_direction get_direction() const { return base::get_direction(); }
  10889. };
  10890. /// Default constructor. Produces an empty object.
  10891. augru_forward() = default;
  10892. /// Constructs an AUGRU forward propagation primitive.
  10893. /// @param pd Primitive descriptor for an AUGRU forward propagation
  10894. /// primitive.
  10895. augru_forward(const primitive_desc &pd) : primitive(pd) {}
  10896. /// Constructs an AUGRU forward propagation primitive from a cache blob.
  10897. /// @param pd Primitive descriptor for an AUGRU forward propagation
  10898. /// primitive.
  10899. /// @param cache_blob Cache blob, forwarded to the base primitive constructor.
  10900. augru_forward(
  10901. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  10902. : primitive(pd, cache_blob) {}
  10903. };
  10904. /// AUGRU backward propagation primitive.
  10905. struct augru_backward : public primitive {
  10906. /// Primitive descriptor for an AUGRU backward propagation
  10907. /// primitive.
  10908. struct primitive_desc : public rnn_primitive_desc_base {
  10909. /// Default constructor. Produces an empty object.
  10910. primitive_desc() = default;
  10911. /// Constructs a primitive descriptor for an AUGRU backward propagation
  10912. /// primitive.
  10913. ///
  10914. /// The following arguments may point to a zero memory descriptor:
  10915. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  10916. /// - @p bias_desc together with @p diff_bias_desc,
  10917. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  10918. ///
  10919. /// This would then indicate that the AUGRU backward propagation
  10920. /// primitive should not use them and should default to zero values
  10921. /// instead.
  10922. ///
  10923. /// @note
  10924. /// All memory descriptors may be initialized with
  10925. /// #dnnl::memory::format_tag::any value of @p format_tag.
  10926. ///
  10927. /// @param aengine Engine to use.
  10928. /// @param aprop_kind Propagation kind. Must be
  10929. /// #dnnl::prop_kind::backward.
  10930. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  10931. /// more info.
  10932. /// @param src_layer_desc Memory descriptor for the input vector.
  10933. /// @param src_iter_desc Memory descriptor for the input recurrent
  10934. /// hidden state vector.
  10935. /// @param attention_desc Memory descriptor for the attention vector.
  10936. /// @param weights_layer_desc Memory descriptor for the weights
  10937. /// applied to the layer input.
  10938. /// @param weights_iter_desc Memory descriptor for the weights applied
  10939. /// to the recurrent input.
  10940. /// @param bias_desc Bias memory descriptor.
  10941. /// @param dst_layer_desc Memory descriptor for the output vector.
  10942. /// @param dst_iter_desc Memory descriptor for the output recurrent
  10943. /// hidden state vector.
  10944. /// @param diff_src_layer_desc Memory descriptor for the diff of input
  10945. /// vector.
  10946. /// @param diff_src_iter_desc Memory descriptor for the diff of input
  10947. /// recurrent hidden state vector.
  10948. /// @param diff_attention_desc Memory descriptor for the diff of
  10949. /// attention vector.
  10950. /// @param diff_weights_layer_desc Memory descriptor for the diff of
  10951. /// weights applied to the layer input.
  10952. /// @param diff_weights_iter_desc Memory descriptor for the diff of
  10953. /// weights applied to the recurrent input.
  10954. /// @param diff_bias_desc Diff bias memory descriptor.
  10955. /// @param diff_dst_layer_desc Memory descriptor for the diff of
  10956. /// output vector.
  10957. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  10958. /// recurrent hidden state vector.
  10959. /// @param hint_fwd_pd Primitive descriptor for an AUGRU
  10960. /// forward propagation primitive. It is used as a hint for
  10961. /// deciding which memory format to use.
  10962. /// @param attr Primitive attributes to use. Attributes are optional
  10963. /// and default to empty attributes.
  10964. /// @param allow_empty A flag signifying whether construction is
  10965. /// allowed to fail without throwing an exception. In this case an
  10966. /// empty object will be produced. This flag is optional and
  10967. /// defaults to false.
  10968. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  10969. rnn_direction direction, const memory::desc &src_layer_desc,
  10970. const memory::desc &src_iter_desc,
  10971. const memory::desc &attention_desc,
  10972. const memory::desc &weights_layer_desc,
  10973. const memory::desc &weights_iter_desc,
  10974. const memory::desc &bias_desc,
  10975. const memory::desc &dst_layer_desc,
  10976. const memory::desc &dst_iter_desc,
  10977. const memory::desc &diff_src_layer_desc,
  10978. const memory::desc &diff_src_iter_desc,
  10979. const memory::desc &diff_attention_desc,
  10980. const memory::desc &diff_weights_layer_desc,
  10981. const memory::desc &diff_weights_iter_desc,
  10982. const memory::desc &diff_bias_desc,
  10983. const memory::desc &diff_dst_layer_desc,
  10984. const memory::desc &diff_dst_iter_desc,
  10985. const augru_forward::primitive_desc &hint_fwd_pd,
  10986. const primitive_attr &attr = default_attr(),
  10987. bool allow_empty = false)
  10988. : rnn_primitive_desc_base(aengine, algorithm::vanilla_augru,
  10989. aprop_kind, algorithm::undef, direction, src_layer_desc,
  10990. src_iter_desc, nullptr, &attention_desc, weights_layer_desc, // attention goes in by address: pointer slot (the lbr_gru ctors pass nullptr here)
  10991. weights_iter_desc, nullptr, nullptr, bias_desc,
  10992. dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
  10993. diff_src_iter_desc, nullptr, &diff_attention_desc, // diff attention likewise by address (lbr_gru_backward passes nullptr here)
  10994. diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
  10995. nullptr, diff_bias_desc, diff_dst_layer_desc,
  10996. diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
  10997. hint_fwd_pd, attr, allow_empty) {}
  10998. /// Constructs a primitive descriptor for an AUGRU backward propagation
  10999. /// primitive from a C API primitive descriptor that must have a
  11000. /// matching kind.
  11001. ///
  11002. /// @param pd C API primitive descriptor for an AUGRU backward
  11003. /// propagation primitive.
  11004. primitive_desc(dnnl_primitive_desc_t pd)
  11005. : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
  11006. dnnl::algorithm::vanilla_augru) {}
  11007. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  11008. memory::desc src_layer_desc() const {
  11009. return rnn_base::src_layer_desc();
  11010. }
  11011. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  11012. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  11013. /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const
  11014. memory::desc attention_desc() const {
  11015. return rnn_base::augru_attention_desc();
  11016. }
  11017. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  11018. memory::desc weights_layer_desc() const {
  11019. return rnn_base::weights_layer_desc();
  11020. }
  11021. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  11022. memory::desc weights_iter_desc() const {
  11023. return rnn_base::weights_iter_desc();
  11024. }
  11025. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  11026. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  11027. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  11028. memory::desc dst_layer_desc() const {
  11029. return rnn_base::dst_layer_desc();
  11030. }
  11031. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  11032. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  11033. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  11034. memory::desc workspace_desc() const {
  11035. return rnn_base::workspace_desc();
  11036. }
  11037. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
  11038. memory::desc diff_src_layer_desc() const {
  11039. return rnn_base::diff_src_layer_desc();
  11040. }
  11041. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
  11042. memory::desc diff_src_iter_desc() const {
  11043. return rnn_base::diff_src_iter_desc();
  11044. }
  11045. /// @copydoc dnnl::rnn_primitive_desc_base::diff_augru_attention_desc()const
  11046. memory::desc diff_attention_desc() const {
  11047. return rnn_base::diff_augru_attention_desc();
  11048. }
  11049. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
  11050. memory::desc diff_weights_layer_desc() const {
  11051. return rnn_base::diff_weights_layer_desc();
  11052. }
  11053. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
  11054. memory::desc diff_weights_iter_desc() const {
  11055. return rnn_base::diff_weights_iter_desc();
  11056. }
  11057. /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
  11058. memory::desc diff_bias_desc() const {
  11059. return rnn_base::diff_bias_desc();
  11060. }
  11061. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
  11062. memory::desc diff_dst_layer_desc() const {
  11063. return rnn_base::diff_dst_layer_desc();
  11064. }
  11065. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
  11066. memory::desc diff_dst_iter_desc() const {
  11067. return rnn_base::diff_dst_iter_desc();
  11068. }
  11069. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  11070. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  11071. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  11072. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  11073. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  11074. rnn_direction get_direction() const { return base::get_direction(); }
  11075. };
  11076. /// Default constructor. Produces an empty object.
  11077. augru_backward() = default;
  11078. /// Constructs an AUGRU backward propagation primitive.
  11079. /// @param pd Primitive descriptor for an AUGRU backward propagation
  11080. /// primitive.
  11081. augru_backward(const primitive_desc &pd) : primitive(pd) {}
  11082. /// Constructs an AUGRU backward propagation primitive from a cache blob.
  11083. /// @param pd Primitive descriptor for an AUGRU backward propagation
  11084. /// primitive.
  11085. /// @param cache_blob Cache blob, forwarded to the base primitive constructor.
  11086. augru_backward(
  11087. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  11088. : primitive(pd, cache_blob) {}
  11089. };
  11090. /// LBR AUGRU forward propagation primitive.
  11091. struct lbr_augru_forward : public primitive {
  11092. /// Primitive descriptor for an LBR AUGRU forward propagation
  11093. /// primitive.
  11094. struct primitive_desc : public rnn_primitive_desc_base {
  11095. /// Default constructor. Produces an empty object.
  11096. primitive_desc() = default;
  11097. /// Constructs a primitive descriptor for an LBR AUGRU forward
  11098. /// propagation primitive.
  11099. ///
  11100. /// The following arguments may point to a zero memory descriptor:
  11101. /// - @p src_iter_desc,
  11102. /// - @p bias_desc,
  11103. /// - @p dst_iter_desc.
  11104. ///
  11105. /// This would then indicate that the LBR AUGRU forward propagation
  11106. /// primitive should not use them and should default to zero values
  11107. /// instead.
  11108. ///
  11109. /// @note
  11110. /// All memory descriptors except @p src_iter_desc may be
  11111. /// initialized with an #dnnl::memory::format_tag::any value of @p
  11112. /// format_tag.
  11113. ///
  11114. /// @param aengine Engine to use.
  11115. /// @param aprop_kind Propagation kind. Possible values are
  11116. /// #dnnl::prop_kind::forward_training, and
  11117. /// #dnnl::prop_kind::forward_inference.
  11118. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  11119. /// more info.
  11120. /// @param src_layer_desc Memory descriptor for the input vector.
  11121. /// @param src_iter_desc Memory descriptor for the input recurrent
  11122. /// hidden state vector.
  11123. /// @param attention_desc Memory descriptor for the attention vector.
  11124. /// @param weights_layer_desc Memory descriptor for the weights
  11125. /// applied to the layer input.
  11126. /// @param weights_iter_desc Memory descriptor for the weights applied
  11127. /// to the recurrent input.
  11128. /// @param bias_desc Bias memory descriptor.
  11129. /// @param dst_layer_desc Memory descriptor for the output vector.
  11130. /// @param dst_iter_desc Memory descriptor for the output recurrent
  11131. /// hidden state vector.
  11132. /// @param attr Primitive attributes to use. Attributes are optional
  11133. /// and default to empty attributes.
  11134. /// @param allow_empty A flag signifying whether construction is
  11135. /// allowed to fail without throwing an exception. In this case an
  11136. /// empty object will be produced. This flag is optional and
  11137. /// defaults to false.
  11138. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  11139. rnn_direction direction, const memory::desc &src_layer_desc,
  11140. const memory::desc &src_iter_desc,
  11141. const memory::desc &attention_desc,
  11142. const memory::desc &weights_layer_desc,
  11143. const memory::desc &weights_iter_desc,
  11144. const memory::desc &bias_desc,
  11145. const memory::desc &dst_layer_desc,
  11146. const memory::desc &dst_iter_desc,
  11147. const primitive_attr &attr = default_attr(),
  11148. bool allow_empty = false)
  11149. : rnn_primitive_desc_base(aengine, algorithm::lbr_augru, aprop_kind,
  11150. algorithm::undef, direction, src_layer_desc, src_iter_desc,
  11151. nullptr, &attention_desc, weights_layer_desc,
  11152. weights_iter_desc, nullptr, nullptr, bias_desc,
  11153. dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
  11154. 0.0f, 0.0f, attr, allow_empty) {}
  11155. /// Constructs a primitive descriptor for an LBR AUGRU forward propagation
  11156. /// primitive from a C API primitive descriptor that must have a
  11157. /// matching kind.
  11158. ///
  11159. /// @param pd C API primitive descriptor for an LBR AUGRU forward
  11160. /// propagation primitive.
  11161. primitive_desc(dnnl_primitive_desc_t pd)
  11162. : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
  11163. dnnl::prop_kind::forward_inference,
  11164. dnnl::algorithm::lbr_augru) {}
  11165. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  11166. memory::desc src_layer_desc() const {
  11167. return rnn_base::src_layer_desc();
  11168. }
  11169. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  11170. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  11171. /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const
  11172. memory::desc attention_desc() const {
  11173. return rnn_base::augru_attention_desc();
  11174. }
  11175. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  11176. memory::desc weights_layer_desc() const {
  11177. return rnn_base::weights_layer_desc();
  11178. }
  11179. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  11180. memory::desc weights_iter_desc() const {
  11181. return rnn_base::weights_iter_desc();
  11182. }
  11183. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  11184. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  11185. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  11186. memory::desc dst_layer_desc() const {
  11187. return rnn_base::dst_layer_desc();
  11188. }
  11189. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  11190. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  11191. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  11192. memory::desc workspace_desc() const {
  11193. return rnn_base::workspace_desc();
  11194. }
  11195. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  11196. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  11197. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  11198. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  11199. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  11200. rnn_direction get_direction() const { return base::get_direction(); }
  11201. };
  11202. /// Default constructor. Produces an empty object.
  11203. lbr_augru_forward() = default;
  11204. /// Constructs an LBR AUGRU forward propagation primitive.
  11205. /// @param pd Primitive descriptor for an LBR AUGRU forward propagation
  11206. /// primitive.
  11207. lbr_augru_forward(const primitive_desc &pd) : primitive(pd) {}
  11208. /// Constructs an LBR AUGRU forward propagation primitive from a cache blob.
  11209. /// @param pd Primitive descriptor for an LBR AUGRU forward propagation
  11210. /// primitive.
  11211. /// @param cache_blob Cache blob.
  11212. lbr_augru_forward(
  11213. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  11214. : primitive(pd, cache_blob) {}
  11215. };
  11216. /// LBR AUGRU backward propagation primitive.
  11217. struct lbr_augru_backward : public primitive {
  11218. /// Primitive descriptor for an LBR AUGRU backward propagation primitive.
  11219. struct primitive_desc : public rnn_primitive_desc_base {
  11220. /// Default constructor. Produces an empty object.
  11221. primitive_desc() = default;
  11222. /// Constructs a primitive descriptor for LBR AUGRU backward propagation
  11223. /// primitive.
  11224. ///
  11225. /// The following arguments may point to a zero memory descriptor:
  11226. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  11227. /// - @p bias_desc together with @p diff_bias_desc,
  11228. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  11229. ///
  11230. /// This would then indicate that the LBR AUGRU backward propagation
  11231. /// primitive should not use them and should default to zero values
  11232. /// instead.
  11233. ///
  11234. /// @note
  11235. /// All memory descriptors may be initialized with
  11236. /// #dnnl::memory::format_tag::any value of @p format_tag.
  11237. ///
  11238. /// @param aengine Engine to use.
  11239. /// @param aprop_kind Propagation kind. Must be
  11240. /// #dnnl::prop_kind::backward.
  11241. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  11242. /// more info.
  11243. /// @param src_layer_desc Memory descriptor for the input vector.
  11244. /// @param src_iter_desc Memory descriptor for the input recurrent
  11245. /// hidden state vector.
  11246. /// @param attention_desc Memory descriptor for the attention vector.
  11247. /// @param weights_layer_desc Memory descriptor for the weights
  11248. /// applied to the layer input.
  11249. /// @param weights_iter_desc Memory descriptor for the weights applied
  11250. /// to the recurrent input.
  11251. /// @param bias_desc Bias memory descriptor.
  11252. /// @param dst_layer_desc Memory descriptor for the output vector.
  11253. /// @param dst_iter_desc Memory descriptor for the output recurrent
  11254. /// hidden state vector.
  11255. /// @param diff_src_layer_desc Memory descriptor for the diff of input
  11256. /// vector.
  11257. /// @param diff_src_iter_desc Memory descriptor for the diff of input
  11258. /// recurrent hidden state vector.
  11259. /// @param diff_attention_desc Memory descriptor for the diff of
  11260. /// attention vector.
  11261. /// @param diff_weights_layer_desc Memory descriptor for the diff of
  11262. /// weights applied to the layer input.
  11263. /// @param diff_weights_iter_desc Memory descriptor for the diff of
  11264. /// weights applied to the recurrent input.
  11265. /// @param diff_bias_desc Diff bias memory descriptor.
  11266. /// @param diff_dst_layer_desc Memory descriptor for the diff of
  11267. /// output vector.
  11268. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  11269. /// recurrent hidden state vector.
  11270. /// @param hint_fwd_pd Primitive descriptor for an LBR AUGRU
  11271. /// forward propagation primitive. It is used as a hint for
  11272. /// deciding which memory format to use.
  11273. /// @param attr Primitive attributes to use. Attributes are optional
  11274. /// and default to empty attributes.
  11275. /// @param allow_empty A flag signifying whether construction is
  11276. /// allowed to fail without throwing an exception. In this case an
  11277. /// empty object will be produced. This flag is optional and
  11278. /// defaults to false.
  11279. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  11280. rnn_direction direction, const memory::desc &src_layer_desc,
  11281. const memory::desc &src_iter_desc,
  11282. const memory::desc &attention_desc,
  11283. const memory::desc &weights_layer_desc,
  11284. const memory::desc &weights_iter_desc,
  11285. const memory::desc &bias_desc,
  11286. const memory::desc &dst_layer_desc,
  11287. const memory::desc &dst_iter_desc,
  11288. const memory::desc &diff_src_layer_desc,
  11289. const memory::desc &diff_src_iter_desc,
  11290. const memory::desc &diff_attention_desc,
  11291. const memory::desc &diff_weights_layer_desc,
  11292. const memory::desc &diff_weights_iter_desc,
  11293. const memory::desc &diff_bias_desc,
  11294. const memory::desc &diff_dst_layer_desc,
  11295. const memory::desc &diff_dst_iter_desc,
  11296. const lbr_augru_forward::primitive_desc &hint_fwd_pd,
  11297. const primitive_attr &attr = default_attr(),
  11298. bool allow_empty = false)
  11299. : rnn_primitive_desc_base(aengine, algorithm::lbr_augru, aprop_kind,
  11300. algorithm::undef, direction, src_layer_desc, src_iter_desc,
  11301. nullptr, &attention_desc, weights_layer_desc,
  11302. weights_iter_desc, nullptr, nullptr, bias_desc,
  11303. dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
  11304. diff_src_iter_desc, nullptr, &diff_attention_desc,
  11305. diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
  11306. nullptr, diff_bias_desc, diff_dst_layer_desc,
  11307. diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
  11308. hint_fwd_pd, attr, allow_empty) {}
  11309. /// Constructs a primitive descriptor for an LBR AUGRU backward
  11310. /// propagation primitive from a C API primitive descriptor that must
  11311. /// have a matching kind.
  11312. ///
  11313. /// @param pd C API primitive descriptor for an LBR AUGRU backward
  11314. /// propagation primitive.
  11315. primitive_desc(dnnl_primitive_desc_t pd)
  11316. : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
  11317. dnnl::algorithm::lbr_augru) {}
  11318. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  11319. memory::desc src_layer_desc() const {
  11320. return rnn_base::src_layer_desc();
  11321. }
  11322. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  11323. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  11324. /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const
  11325. memory::desc attention_desc() const {
  11326. return rnn_base::augru_attention_desc();
  11327. }
  11328. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  11329. memory::desc weights_layer_desc() const {
  11330. return rnn_base::weights_layer_desc();
  11331. }
  11332. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  11333. memory::desc weights_iter_desc() const {
  11334. return rnn_base::weights_iter_desc();
  11335. }
  11336. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  11337. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  11338. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  11339. memory::desc dst_layer_desc() const {
  11340. return rnn_base::dst_layer_desc();
  11341. }
  11342. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  11343. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  11344. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  11345. memory::desc workspace_desc() const {
  11346. return rnn_base::workspace_desc();
  11347. }
  11348. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
  11349. memory::desc diff_src_layer_desc() const {
  11350. return rnn_base::diff_src_layer_desc();
  11351. }
  11352. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
  11353. memory::desc diff_src_iter_desc() const {
  11354. return rnn_base::diff_src_iter_desc();
  11355. }
  11356. /// @copydoc dnnl::rnn_primitive_desc_base::diff_augru_attention_desc()const
  11357. memory::desc diff_attention_desc() const {
  11358. return rnn_base::diff_augru_attention_desc();
  11359. }
  11360. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
  11361. memory::desc diff_weights_layer_desc() const {
  11362. return rnn_base::diff_weights_layer_desc();
  11363. }
  11364. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
  11365. memory::desc diff_weights_iter_desc() const {
  11366. return rnn_base::diff_weights_iter_desc();
  11367. }
  11368. /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
  11369. memory::desc diff_bias_desc() const {
  11370. return rnn_base::diff_bias_desc();
  11371. }
  11372. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
  11373. memory::desc diff_dst_layer_desc() const {
  11374. return rnn_base::diff_dst_layer_desc();
  11375. }
  11376. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
  11377. memory::desc diff_dst_iter_desc() const {
  11378. return rnn_base::diff_dst_iter_desc();
  11379. }
  11380. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  11381. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  11382. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  11383. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  11384. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  11385. rnn_direction get_direction() const { return base::get_direction(); }
  11386. };
  11387. /// Default constructor. Produces an empty object.
  11388. lbr_augru_backward() = default;
  11389. /// Constructs an LBR AUGRU backward propagation primitive.
  11390. /// @param pd Primitive descriptor for an LBR AUGRU backward propagation
  11391. /// primitive.
  11392. lbr_augru_backward(const primitive_desc &pd) : primitive(pd) {}
  11393. /// Constructs an LBR AUGRU backward propagation primitive from a cache blob.
  11394. /// @param pd Primitive descriptor for an LBR AUGRU backward propagation
  11395. /// primitive.
  11396. /// @param cache_blob Cache blob.
  11397. lbr_augru_backward(
  11398. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  11399. : primitive(pd, cache_blob) {}
  11400. };
  11401. /// @} dnnl_api_rnn
  11402. /// @addtogroup dnnl_api_shuffle Shuffle
  11403. ///
  11404. /// A primitive to shuffle tensor data along an axis.
  11405. ///
  11406. /// @sa @ref dev_guide_shuffle in developer guide
  11407. ///
  11408. /// @{
  11409. /// Shuffle forward propagation primitive.
  11410. struct shuffle_forward : public primitive {
  11411. /// Primitive descriptor for a shuffle forward propagation primitive.
  11412. struct primitive_desc : public dnnl::primitive_desc {
  11413. /// Default constructor. Produces an empty object.
  11414. primitive_desc() = default;
  11415. /// Constructs a primitive descriptor for a shuffle forward propagation
  11416. /// primitive.
  11417. ///
  11418. /// @param aengine Engine to use.
  11419. /// @param aprop_kind Propagation kind. Possible values are
  11420. /// #dnnl::prop_kind::forward_training, and
  11421. /// #dnnl::prop_kind::forward_inference.
  11422. /// @param src_desc Source memory descriptor.
  11423. /// @param dst_desc Destination memory descriptor.
  11424. /// @param axis The axis along which the data is shuffled.
  11425. /// @param group_size Shuffle group size.
  11426. /// @param attr Primitive attributes to use. Attributes are optional
  11427. /// and default to empty attributes.
  11428. /// @param allow_empty A flag signifying whether construction is
  11429. /// allowed to fail without throwing an exception. In this case an
  11430. /// empty object will be produced. This flag is optional and
  11431. /// defaults to false.
  11432. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  11433. const memory::desc &src_desc, const memory::desc &dst_desc,
  11434. int axis, int group_size,
  11435. const primitive_attr &attr = default_attr(),
  11436. bool allow_empty = false) {
  11437. dnnl_primitive_desc_t pd = nullptr;
  11438. dnnl_status_t status = dnnl_shuffle_forward_primitive_desc_create(
  11439. &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
  11440. src_desc.get(), dst_desc.get(), axis, group_size,
  11441. attr.get());
  11442. if (!allow_empty)
  11443. error::wrap_c_api(status,
  11444. "could not create a primitive descriptor for "
  11445. "the shuffle forward propagation primitive. Run "
  11446. "workload with environment variable ONEDNN_VERBOSE=all "
  11447. "to get additional diagnostic information.");
  11448. reset(pd);
  11449. }
  11450. /// Constructs a primitive descriptor for a shuffle forward propagation
  11451. /// primitive from a C API primitive descriptor that must have a
  11452. /// matching kind.
  11453. ///
  11454. /// @param pd C API primitive descriptor for a shuffle forward
  11455. /// propagation primitive.
  11456. primitive_desc(dnnl_primitive_desc_t pd)
  11457. : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
  11458. dnnl::prop_kind::forward_training,
  11459. dnnl::prop_kind::forward_inference) {}
  11460. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  11461. memory::desc src_desc() const { return base::src_desc(0); }
  11462. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  11463. memory::desc dst_desc() const { return base::dst_desc(0); }
  11464. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  11465. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  11466. /// @copydoc dnnl::primitive_desc_base::get_axis()const
  11467. int get_axis() const { return base::get_axis(); }
  11468. /// @copydoc dnnl::primitive_desc_base::get_group_size()const
  11469. memory::dim get_group_size() const { return base::get_group_size(); }
  11470. };
  11471. /// Default constructor. Produces an empty object.
  11472. shuffle_forward() = default;
  11473. /// Constructs a shuffle forward propagation primitive.
  11474. /// @param pd Primitive descriptor for a shuffle forward propagation
  11475. /// primitive.
  11476. shuffle_forward(const primitive_desc &pd) : primitive(pd) {}
  11477. /// Constructs a shuffle forward propagation primitive from a cache blob.
  11478. /// @param pd Primitive descriptor for a shuffle forward propagation
  11479. /// primitive.
  11480. /// @param cache_blob Cache blob.
  11481. shuffle_forward(
  11482. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  11483. : primitive(pd, cache_blob) {}
  11484. };
  11485. /// Shuffle backward propagation primitive.
  11486. struct shuffle_backward : public primitive {
  11487. /// Primitive descriptor for a shuffle backward propagation primitive.
  11488. struct primitive_desc : public dnnl::primitive_desc {
  11489. /// Default constructor. Produces an empty object.
  11490. primitive_desc() = default;
  11491. /// Constructs a primitive descriptor for a shuffle backward propagation
  11492. /// primitive.
  11493. ///
  11494. /// @param aengine Engine to use.
  11495. /// @param diff_src_desc Diff source memory descriptor.
  11496. /// @param diff_dst_desc Diff destination memory descriptor.
  11497. /// @param axis The axis along which the data is shuffled.
  11498. /// @param group_size Shuffle group size.
  11499. /// @param hint_fwd_pd Primitive descriptor for a shuffle forward
  11500. /// propagation primitive. It is used as a hint for deciding which
  11501. /// memory format to use.
  11502. /// @param attr Primitive attributes to use. Attributes are optional
  11503. /// and default to empty attributes.
  11504. /// @param allow_empty A flag signifying whether construction is
  11505. /// allowed to fail without throwing an exception. In this case an
  11506. /// empty object will be produced. This flag is optional and
  11507. /// defaults to false.
  11508. primitive_desc(const engine &aengine, const memory::desc &diff_src_desc,
  11509. const memory::desc &diff_dst_desc, int axis, int group_size,
  11510. const shuffle_forward::primitive_desc &hint_fwd_pd,
  11511. const primitive_attr &attr = default_attr(),
  11512. bool allow_empty = false) {
  11513. dnnl_primitive_desc_t pd = nullptr;
  11514. dnnl_status_t status = dnnl_shuffle_backward_primitive_desc_create(
  11515. &pd, aengine.get(), diff_src_desc.get(),
  11516. diff_dst_desc.get(), axis, group_size, hint_fwd_pd.get(),
  11517. attr.get());
  11518. if (!allow_empty)
  11519. error::wrap_c_api(status,
  11520. "could not create a primitive descriptor for "
  11521. "the shuffle backward propagation primitive. Run "
  11522. "workload with environment variable ONEDNN_VERBOSE=all "
  11523. "to get additional diagnostic information.");
  11524. reset(pd);
  11525. }
  11526. /// Constructs a primitive descriptor for a shuffle backward
  11527. /// propagation primitive from a C API primitive descriptor that must
  11528. /// have a matching kind.
  11529. ///
  11530. /// @param pd C API primitive descriptor for a shuffle backward
  11531. /// propagation primitive.
  11532. primitive_desc(dnnl_primitive_desc_t pd)
  11533. : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
  11534. dnnl::prop_kind::backward_data) {}
  11535. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  11536. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  11537. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  11538. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  11539. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  11540. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  11541. /// @copydoc dnnl::primitive_desc_base::get_axis()const
  11542. int get_axis() const { return base::get_axis(); }
  11543. /// @copydoc dnnl::primitive_desc_base::get_group_size()const
  11544. memory::dim get_group_size() const { return base::get_group_size(); }
  11545. };
  11546. /// Default constructor. Produces an empty object.
  11547. shuffle_backward() = default;
  11548. /// Constructs a shuffle backward propagation primitive.
  11549. /// @param pd Primitive descriptor for a shuffle backward propagation
  11550. /// primitive.
  11551. shuffle_backward(const primitive_desc &pd) : primitive(pd) {}
  11552. /// Constructs a shuffle backward propagation primitive from a cache blob.
  11553. /// @param pd Primitive descriptor for a shuffle backward propagation
  11554. /// primitive.
  11555. /// @param cache_blob Cache blob.
  11556. shuffle_backward(
  11557. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  11558. : primitive(pd, cache_blob) {}
  11559. };
  11560. /// @} dnnl_api_shuffle
  11561. /// @addtogroup dnnl_api_binary Binary
  11562. ///
  11563. /// A primitive to perform tensor operations over two tensors.
  11564. ///
  11565. /// @sa @ref dev_guide_binary in developer guide
  11566. ///
  11567. /// @{
  11568. /// Elementwise binary operator primitive.
  11569. struct binary : public primitive {
  11570. /// Primitive descriptor for an elementwise binary operator primitive.
  11571. struct primitive_desc : public dnnl::primitive_desc {
  11572. /// Default constructor. Produces an empty object.
  11573. primitive_desc() = default;
  11574. /// Constructs a primitive descriptor for an elementwise binary operator
  11575. /// primitive.
  11576. ///
  11577. /// @param aengine Engine to use.
  11578. /// @param aalgorithm Elementwise binary algorithm.
  11579. /// @param src0 Memory descriptor for source tensor #0.
  11580. /// @param src1 Memory descriptor for source tensor #1.
  11581. /// @param dst Memory descriptor for destination tensor.
  11582. /// @param attr Primitive attributes to use. Attributes are optional
  11583. /// and default to empty attributes.
  11584. /// @param allow_empty A flag signifying whether construction is
  11585. /// allowed to fail without throwing an exception. In this case an
  11586. /// empty object will be produced. This flag is optional and
  11587. /// defaults to false.
  11588. primitive_desc(const engine &aengine, algorithm aalgorithm,
  11589. const memory::desc &src0, const memory::desc &src1,
  11590. const memory::desc &dst,
  11591. const primitive_attr &attr = default_attr(),
  11592. bool allow_empty = false) {
  11593. dnnl_primitive_desc_t pd = nullptr;
  11594. dnnl_status_t status = dnnl_binary_primitive_desc_create(&pd,
  11595. aengine.get(), dnnl::convert_to_c(aalgorithm), src0.get(),
  11596. src1.get(), dst.get(), attr.get());
  11597. if (!allow_empty)
  11598. error::wrap_c_api(status,
  11599. "could not create a primitive descriptor for "
  11600. "the binary operation primitive. Run workload with "
  11601. "environment variable ONEDNN_VERBOSE=all to get "
  11602. "additional diagnostic information.");
  11603. reset(pd);
  11604. }
  11605. /// Constructs a primitive descriptor for an elementwise binary operator
  11606. /// primitive with support of ternary operators.
  11607. ///
  11608. /// @param aengine Engine to use.
  11609. /// @param aalgorithm Elementwise binary algorithm.
  11610. /// @param src0 Memory descriptor for source tensor #0.
  11611. /// @param src1 Memory descriptor for source tensor #1.
  11612. /// @param src2 Memory descriptor for source tensor #2 for ternary
  11613. /// operations. Might be empty.
  11614. /// @param dst Memory descriptor for destination tensor.
  11615. /// @param attr Primitive attributes to use. Attributes are optional
  11616. /// and default to empty attributes.
  11617. /// @param allow_empty A flag signifying whether construction is
  11618. /// allowed to fail without throwing an exception. In this case an
  11619. /// empty object will be produced. This flag is optional and
  11620. /// defaults to false.
  11621. primitive_desc(const engine &aengine, algorithm aalgorithm,
  11622. const memory::desc &src0, const memory::desc &src1,
  11623. const memory::desc &src2, const memory::desc &dst,
  11624. const primitive_attr &attr = default_attr(),
  11625. bool allow_empty = false) {
  11626. dnnl_primitive_desc_t pd = nullptr;
  11627. dnnl_status_t status = dnnl_binary_primitive_desc_create_v2(&pd,
  11628. aengine.get(), dnnl::convert_to_c(aalgorithm), src0.get(),
  11629. src1.get(), src2.get(), dst.get(), attr.get());
  11630. if (!allow_empty)
  11631. error::wrap_c_api(status,
  11632. "could not create a primitive descriptor for "
  11633. "the binary v2 operation primitive. Run workload with "
  11634. "environment variable ONEDNN_VERBOSE=all to get "
  11635. "additional diagnostic information.");
  11636. reset(pd);
  11637. }
  11638. /// Constructs a primitive descriptor for a binary primitive from a C
  11639. /// API primitive descriptor that must have a matching kind.
  11640. ///
  11641. /// @param pd C API primitive descriptor for a binary primitive.
  11642. primitive_desc(dnnl_primitive_desc_t pd)
  11643. : dnnl::primitive_desc(pd, dnnl::primitive::kind::binary) {}
  11644. /// @copydoc dnnl::primitive_desc_base::src_desc(int)const
  11645. memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
  11646. /// Returns the memory descriptor for source #0.
  11647. memory::desc src0_desc() const { return base::src_desc(0); }
  11648. /// Returns the memory descriptor for source #1.
  11649. memory::desc src1_desc() const { return base::src_desc(1); }
  11650. /// Returns the memory descriptor for source #2.
  11651. memory::desc src2_desc() const { return base::src_desc(2); }
  11652. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  11653. memory::desc dst_desc() const { return base::dst_desc(0); }
  11654. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  11655. algorithm get_algorithm() const { return base::get_algorithm(); }
  11656. };
  11657. /// Default constructor. Produces an empty object.
  11658. binary() = default;
  11659. /// Constructs an elementwise binary operation primitive.
  11660. /// @param pd Primitive descriptor for an elementwise binary operation
  11661. /// primitive.
  11662. binary(const primitive_desc &pd) : primitive(pd) {}
  11663. /// Constructs an elementwise binary operation primitive from a cache blob.
  11664. /// @param pd Primitive descriptor for an elementwise binary operation
  11665. /// primitive.
  11666. /// @param cache_blob Cache blob.
  11667. binary(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  11668. : primitive(pd, cache_blob) {}
  11669. };
  11670. /// @} dnnl_api_binary
  11671. /// @addtogroup dnnl_api_matmul Matrix Multiplication
  11672. ///
  11673. /// A primitive to perform matrix-matrix multiplication. The batched mode
  11674. /// is supported with 3D tensors.
  11675. ///
  11676. /// @sa @ref dev_guide_matmul in developer guide
  11677. ///
  11678. ///
  11679. /// @{
  11680. /// Matrix multiplication (matmul) primitive.
  11681. struct matmul : public primitive {
  11682. /// Primitive descriptor for a matmul primitive.
  11683. struct primitive_desc : public dnnl::primitive_desc {
  11684. /// Default constructor. Produces an empty object.
  11685. primitive_desc() = default;
  11686. /// Constructs a primitive descriptor for a matmul primitive
  11687. /// without bias.
  11688. ///
  11689. /// @param aengine Engine to use.
  11690. /// @param src_desc Memory descriptor for source (matrix A).
  11691. /// @param weights_desc Memory descriptor for weights (matrix B).
  11692. /// @param dst_desc Memory descriptor for destination (matrix C).
  11693. /// @param attr Primitive attributes to use. Attributes are optional
  11694. /// and default to empty attributes.
  11695. /// @param allow_empty A flag signifying whether construction is
  11696. /// allowed to fail without throwing an exception. In this case an
  11697. /// empty object will be produced. This flag is optional and
  11698. /// defaults to false.
  11699. primitive_desc(const engine &aengine, const memory::desc &src_desc,
  11700. const memory::desc &weights_desc, const memory::desc &dst_desc,
  11701. const primitive_attr &attr = default_attr(),
  11702. bool allow_empty = false)
  11703. : primitive_desc(aengine, src_desc, weights_desc, nullptr, dst_desc,
  11704. attr, allow_empty) {}
  11705. /// Constructs a primitive descriptor for a matmul primitive with bias.
  11706. ///
  11707. /// @param aengine Engine to use.
  11708. /// @param src_desc Memory descriptor for source (matrix A).
  11709. /// @param weights_desc Memory descriptor for weights (matrix B).
  11710. /// @param dst_desc Memory descriptor for destination (matrix C).
  11711. /// @param bias_desc Memory descriptor for bias.
  11712. /// @param attr Primitive attributes to use. Attributes are optional
  11713. /// and default to empty attributes.
  11714. /// @param allow_empty A flag signifying whether construction is
  11715. /// allowed to fail without throwing an exception. In this case an
  11716. /// empty object will be produced. This flag is optional and
  11717. /// defaults to false.
  11718. primitive_desc(const engine &aengine, const memory::desc &src_desc,
  11719. const memory::desc &weights_desc, const memory::desc &bias_desc,
  11720. const memory::desc &dst_desc,
  11721. const primitive_attr &attr = default_attr(),
  11722. bool allow_empty = false)
  11723. : primitive_desc(aengine, src_desc, weights_desc, &bias_desc,
  11724. dst_desc, attr, allow_empty) {}
  11725. /// Constructs a primitive descriptor for a matmul primitive from a C
  11726. /// API primitive descriptor that must have a matching kind.
  11727. ///
  11728. /// @param pd C API primitive descriptor for a matmul primitive.
  11729. primitive_desc(dnnl_primitive_desc_t pd)
  11730. : dnnl::primitive_desc(pd, dnnl::primitive::kind::matmul) {}
  11731. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  11732. memory::desc src_desc() const { return query_md(query::src_md, 0); }
  11733. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  11734. memory::desc weights_desc() const {
  11735. return query_md(query::weights_md, 0);
  11736. }
  11737. /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
  11738. memory::desc bias_desc() const {
  11739. return query_md(query::weights_md, 1);
  11740. }
  11741. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  11742. memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
  11743. private:
  11744. primitive_desc(const engine &aengine, const memory::desc &src_desc,
  11745. const memory::desc &weights_desc, const memory::desc *bias_desc,
  11746. const memory::desc &dst_desc, const primitive_attr &attr,
  11747. bool allow_empty) {
  11748. dnnl_primitive_desc_t pd = nullptr;
  11749. dnnl_status_t status = dnnl_matmul_primitive_desc_create(&pd,
  11750. aengine.get(), src_desc.get(), weights_desc.get(),
  11751. optional_arg(bias_desc), dst_desc.get(), attr.get());
  11752. if (!allow_empty)
  11753. error::wrap_c_api(status,
  11754. "could not create a primitive descriptor for "
  11755. "the matmul primitive. Run workload with "
  11756. "environment variable ONEDNN_VERBOSE=all to get "
  11757. "additional diagnostic information.");
  11758. reset(pd);
  11759. }
  11760. };
  11761. /// Default constructor. Produces an empty object.
  11762. matmul() = default;
  11763. /// Constructs a matmul primitive.
  11764. /// @param pd Primitive descriptor for a matmul primitive.
  11765. matmul(const primitive_desc &pd) : primitive(pd) {}
  11766. /// Constructs a matmul primitive from a cache blob.
  11767. /// @param pd Primitive descriptor for a matmul primitive.
  11768. /// @param cache_blob Cache blob.
  11769. matmul(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  11770. : primitive(pd, cache_blob) {}
  11771. };
  11772. /// @} dnnl_api_matmul
  11773. /// @addtogroup dnnl_api_resampling Resampling
  11774. ///
  11775. /// A primitive to compute resampling operation on 1D, 2D or 3D data tensor
  11776. /// using Nearest Neighbor, or Linear (Bilinear, Trilinear) interpolation
  11777. /// method.
  11778. ///
  11779. /// @sa @ref dev_guide_resampling in developer guide
  11780. ///
  11781. /// @{
  11782. /// Resampling forward propagation.
  11783. struct resampling_forward : public primitive {
  11784. /// Primitive descriptor for a resampling forward propagation primitive.
  11785. struct primitive_desc : public dnnl::primitive_desc {
  11786. /// Default constructor. Produces an empty object.
  11787. primitive_desc() = default;
  11788. /// Constructs a primitive descriptor for a resampling forward
  11789. /// propagation primitive using source and destination memory
  11790. /// descriptors.
  11791. ///
  11792. /// @note
  11793. /// Destination memory descriptor may be initialized with
  11794. /// #dnnl::memory::format_tag::any value of @p format_tag.
  11795. ///
  11796. /// @param aengine Engine to use.
  11797. /// @param aprop_kind Propagation kind. Possible values are
  11798. /// #dnnl::prop_kind::forward_training, and
  11799. /// #dnnl::prop_kind::forward_inference.
  11800. /// @param aalgorithm resampling algorithm kind: either
  11801. /// #dnnl::algorithm::resampling_nearest, or
  11802. /// #dnnl::algorithm::resampling_linear
  11803. /// @param src_desc Source memory descriptor.
  11804. /// @param dst_desc Destination memory descriptor.
  11805. /// @param attr Primitive attributes to use. Attributes are optional
  11806. /// and default to empty attributes.
  11807. /// @param allow_empty A flag signifying whether construction is
  11808. /// allowed to fail without throwing an exception. In this case an
  11809. /// empty object will be produced. This flag is optional and
  11810. /// defaults to false.
  11811. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  11812. algorithm aalgorithm, const memory::desc &src_desc,
  11813. const memory::desc &dst_desc,
  11814. const primitive_attr &attr = default_attr(),
  11815. bool allow_empty = false)
  11816. : primitive_desc(aengine, aprop_kind, aalgorithm, nullptr, src_desc,
  11817. &dst_desc, attr, allow_empty) {}
  11818. /// Constructs a primitive descriptor for a resampling forward
  11819. /// propagation primitive using source memory descriptor and
  11820. /// factors.
  11821. ///
  11822. /// @param aengine Engine to use.
  11823. /// @param aprop_kind Propagation kind. Possible values are
  11824. /// #dnnl::prop_kind::forward_training, and
  11825. /// #dnnl::prop_kind::forward_inference.
  11826. /// @param aalgorithm resampling algorithm kind: either
  11827. /// #dnnl::algorithm::resampling_nearest, or
  11828. /// #dnnl::algorithm::resampling_linear
  11829. /// @param factors Vector of scaling factors for spatial dimension.
  11830. /// @param src_desc Source memory descriptor.
  11831. /// @param attr Primitive attributes to use. Attributes are optional
  11832. /// and default to empty attributes.
  11833. /// @param allow_empty A flag signifying whether construction is
  11834. /// allowed to fail without throwing an exception. In this case an
  11835. /// empty object will be produced. This flag is optional and
  11836. /// defaults to false.
  11837. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  11838. algorithm aalgorithm, const std::vector<float> &factors,
  11839. const memory::desc &src_desc,
  11840. const primitive_attr &attr = default_attr(),
  11841. bool allow_empty = false)
  11842. : primitive_desc(aengine, aprop_kind, aalgorithm, &factors,
  11843. src_desc, nullptr, attr, allow_empty) {}
  11844. /// Constructs a primitive descriptor for a resampling forward
  11845. /// propagation primitive.
  11846. ///
  11847. /// @note
  11848. /// The destination memory descriptor may be initialized with
  11849. /// #dnnl::memory::format_tag::any value of @p format_tag.
  11850. ///
  11851. /// @param aengine Engine to use.
  11852. /// @param aprop_kind Propagation kind. Possible values are
  11853. /// #dnnl::prop_kind::forward_training, and
  11854. /// #dnnl::prop_kind::forward_inference.
  11855. /// @param aalgorithm resampling algorithm kind: either
  11856. /// #dnnl::algorithm::resampling_nearest, or
  11857. /// #dnnl::algorithm::resampling_linear
  11858. /// @param factors Vector of scaling factors for spatial dimension.
  11859. /// @param src_desc Source memory descriptor.
  11860. /// @param dst_desc Destination memory descriptor.
  11861. /// @param attr Primitive attributes to use. Attributes are optional
  11862. /// and default to empty attributes.
  11863. /// @param allow_empty A flag signifying whether construction is
  11864. /// allowed to fail without throwing an exception. In this case an
  11865. /// empty object will be produced. This flag is optional and
  11866. /// defaults to false.
  11867. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  11868. algorithm aalgorithm, const std::vector<float> &factors,
  11869. const memory::desc &src_desc, const memory::desc &dst_desc,
  11870. const primitive_attr &attr = default_attr(),
  11871. bool allow_empty = false)
  11872. : primitive_desc(aengine, aprop_kind, aalgorithm, &factors,
  11873. src_desc, &dst_desc, attr, allow_empty) {}
  11874. /// Constructs a primitive descriptor for a resampling forward
  11875. /// propagation primitive from a C API primitive descriptor that must
  11876. /// have a matching kind.
  11877. ///
  11878. /// @param pd C API primitive descriptor for a resampling forward
  11879. /// propagation primitive.
  11880. primitive_desc(dnnl_primitive_desc_t pd)
  11881. : dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling,
  11882. dnnl::prop_kind::forward_training,
  11883. dnnl::prop_kind::forward_inference) {}
  11884. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  11885. memory::desc src_desc() const { return base::src_desc(0); }
  11886. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  11887. memory::desc dst_desc() const { return base::dst_desc(0); }
  11888. private:
  11889. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  11890. algorithm aalgorithm, const std::vector<float> *factors,
  11891. const memory::desc &src_desc, const memory::desc *dst_desc,
  11892. const primitive_attr &attr, bool allow_empty) {
  11893. if (factors)
  11894. memory::validate_dims(*factors, src_desc.get_ndims() - 2);
  11895. dnnl_primitive_desc_t pd = nullptr;
  11896. dnnl_status_t status
  11897. = dnnl_resampling_forward_primitive_desc_create(&pd,
  11898. aengine.get(), dnnl::convert_to_c(aprop_kind),
  11899. convert_to_c(aalgorithm), optional_arg(factors),
  11900. src_desc.get(), optional_arg(dst_desc), attr.get());
  11901. if (!allow_empty)
  11902. error::wrap_c_api(status,
  11903. "could not create a primitive descriptor for "
  11904. "the resampling forward propagation primitive. Run "
  11905. "workload with environment variable ONEDNN_VERBOSE=all "
  11906. "to get additional diagnostic information.");
  11907. reset(pd);
  11908. }
  11909. };
  11910. /// Default constructor. Produces an empty object.
  11911. resampling_forward() = default;
  11912. /// Constructs a resampling forward propagation primitive.
  11913. /// @param pd Primitive descriptor for a resampling forward propagation
  11914. /// primitive.
  11915. resampling_forward(const primitive_desc &pd) : primitive(pd) {}
  11916. /// Constructs a resampling forward propagation primitive from a cache
  11917. /// blob.
  11918. /// @param pd Primitive descriptor for a resampling forward propagation
  11919. /// primitive.
  11920. /// @param cache_blob Cache blob.
  11921. resampling_forward(
  11922. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  11923. : primitive(pd, cache_blob) {}
  11924. };
  11925. /// Resampling backward propagation primitive.
  11926. struct resampling_backward : public primitive {
  11927. /// Primitive descriptor for resampling backward propagation primitive.
  11928. struct primitive_desc : public dnnl::primitive_desc {
  11929. /// Default constructor. Produces an empty object.
  11930. primitive_desc() = default;
  11931. /// Constructs a primitive descriptor for a resampling backward
  11932. /// propagation primitive using source and destination memory
  11933. /// descriptors.
  11934. ///
  11935. /// @param aengine Engine to use.
  11936. /// @param aalgorithm resampling algorithm kind: either
  11937. /// #dnnl::algorithm::resampling_nearest, or
  11938. /// #dnnl::algorithm::resampling_linear
  11939. /// @param diff_src_desc Diff source memory descriptor.
  11940. /// @param diff_dst_desc Diff destination memory descriptor.
  11941. /// @param hint_fwd_pd Primitive descriptor for a resampling
  11942. /// forward propagation primitive. It is used as a hint for
  11943. /// deciding which memory format to use.
  11944. /// @param attr Primitive attributes to use. Attributes are optional
  11945. /// and default to empty attributes.
  11946. /// @param allow_empty A flag signifying whether construction is
  11947. /// allowed to fail without throwing an exception. In this case an
  11948. /// empty object will be produced. This flag is optional and
  11949. /// defaults to false.
  11950. primitive_desc(const engine &aengine, algorithm aalgorithm,
  11951. const memory::desc &diff_src_desc,
  11952. const memory::desc &diff_dst_desc,
  11953. const resampling_forward::primitive_desc &hint_fwd_pd,
  11954. const primitive_attr &attr = default_attr(),
  11955. bool allow_empty = false)
  11956. : primitive_desc(aengine, aalgorithm, nullptr, diff_src_desc,
  11957. diff_dst_desc, hint_fwd_pd, attr, allow_empty) {}
  11958. /// Constructs a primitive descriptor for resampling backward
  11959. /// propagation primitive.
  11960. ///
  11961. /// @param aengine Engine to use.
  11962. /// @param aalgorithm resampling algorithm kind: either
  11963. /// #dnnl::algorithm::resampling_nearest, or
  11964. /// #dnnl::algorithm::resampling_linear
  11965. /// @param factors Vector of scaling factors for spatial dimension.
  11966. /// @param diff_src_desc Diff source memory descriptor.
  11967. /// @param diff_dst_desc Diff destination memory descriptor.
  11968. /// @param hint_fwd_pd Primitive descriptor for a resampling
  11969. /// forward propagation primitive. It is used as a hint for
  11970. /// deciding which memory format to use.
  11971. /// @param attr Primitive attributes to use. Attributes are optional
  11972. /// and default to empty attributes.
  11973. /// @param allow_empty A flag signifying whether construction is
  11974. /// allowed to fail without throwing an exception. In this case an
  11975. /// empty object will be produced. This flag is optional and
  11976. /// defaults to false.
  11977. primitive_desc(const engine &aengine, algorithm aalgorithm,
  11978. const std::vector<float> &factors,
  11979. const memory::desc &diff_src_desc,
  11980. const memory::desc &diff_dst_desc,
  11981. const resampling_forward::primitive_desc &hint_fwd_pd,
  11982. const primitive_attr &attr = default_attr(),
  11983. bool allow_empty = false)
  11984. : primitive_desc(aengine, aalgorithm, &factors, diff_src_desc,
  11985. diff_dst_desc, hint_fwd_pd, attr, allow_empty) {}
  11986. /// Constructs a primitive descriptor for a resampling backward
  11987. /// propagation primitive from a C API primitive descriptor that must
  11988. /// have a matching kind.
  11989. ///
  11990. /// @param pd C API primitive descriptor for a resampling backward
  11991. /// propagation primitive.
  11992. primitive_desc(dnnl_primitive_desc_t pd)
  11993. : dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling,
  11994. dnnl::prop_kind::backward_data) {}
  11995. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  11996. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  11997. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  11998. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  11999. private:
  12000. primitive_desc(const engine &aengine, algorithm aalgorithm,
  12001. const std::vector<float> *factors,
  12002. const memory::desc &diff_src_desc,
  12003. const memory::desc &diff_dst_desc,
  12004. const resampling_forward::primitive_desc &hint_fwd_pd,
  12005. const primitive_attr &attr, bool allow_empty) {
  12006. if (factors)
  12007. memory::validate_dims(*factors, diff_src_desc.get_ndims() - 2);
  12008. dnnl_primitive_desc_t pd = nullptr;
  12009. dnnl_status_t status
  12010. = dnnl_resampling_backward_primitive_desc_create(&pd,
  12011. aengine.get(), convert_to_c(aalgorithm),
  12012. optional_arg(factors), diff_src_desc.get(),
  12013. diff_dst_desc.get(), hint_fwd_pd.get(), attr.get());
  12014. if (!allow_empty)
  12015. error::wrap_c_api(status,
  12016. "could not create a primitive descriptor for "
  12017. "the resampling backward propagation primitive. Run "
  12018. "workload with environment variable ONEDNN_VERBOSE=all "
  12019. "to get additional diagnostic information.");
  12020. reset(pd);
  12021. }
  12022. };
  12023. /// Default constructor. Produces an empty object.
  12024. resampling_backward() = default;
  12025. /// Constructs a resampling backward propagation primitive.
  12026. /// @param pd Primitive descriptor for a resampling backward propagation
  12027. /// primitive.
  12028. resampling_backward(const primitive_desc &pd) : primitive(pd) {}
  12029. /// Constructs a resampling backward propagation primitive from a cache
  12030. /// blob.
  12031. /// @param pd Primitive descriptor for a resampling backward propagation
  12032. /// primitive.
  12033. /// @param cache_blob Cache blob.
  12034. resampling_backward(
  12035. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  12036. : primitive(pd, cache_blob) {}
  12037. };
  12038. /// @} dnnl_api_resampling
  12039. /// @addtogroup dnnl_api_pooling Pooling
  12040. ///
  12041. /// A primitive to perform max or average pooling with dilation.
  12042. ///
  12043. /// @sa @ref dev_guide_pooling in developer guide
  12044. ///
  12045. /// @{
  12046. /// Pooling forward propagation primitive.
  12047. struct pooling_forward : public primitive {
  12048. /// Primitive descriptor for a pooling forward propagation primitive.
  12049. struct primitive_desc : public dnnl::primitive_desc {
  12050. /// Default constructor. Produces an empty object.
  12051. primitive_desc() = default;
  12052. /// Constructs a primitive descriptor for pooling forward propagation
  12053. /// primitive.
  12054. ///
  12055. /// Arrays @p strides, @p kernel, @p dilation, @p padding_l
  12056. /// and @p padding_r contain values for spatial dimensions only and
  12057. /// hence must have the same number of elements as there are spatial
  12058. /// dimensions. The order of values is the same as in the tensor:
  12059. /// depth (for 3D tensors), height (for 3D and 2D tensors), and width.
  12060. ///
  12061. /// @param aengine Engine to use.
  12062. /// @param aprop_kind Propagation kind. Possible values are
  12063. /// #dnnl::prop_kind::forward_training, and
  12064. /// #dnnl::prop_kind::forward_inference.
  12065. /// @param aalgorithm Pooling algorithm kind: either
  12066. /// #dnnl::algorithm::pooling_max,
  12067. /// #dnnl::algorithm::pooling_avg_include_padding,
  12068. /// or #dnnl::algorithm::pooling_avg_exclude_padding.
  12069. /// @param src_desc Source memory descriptor.
  12070. /// @param dst_desc Destination memory descriptor.
  12071. /// @param strides Vector of strides for spatial dimension.
  12072. /// @param kernel Vector of kernel spatial dimensions.
  12073. /// @param dilation Array of dilations for spatial dimension.
  12074. /// @param padding_l Vector of padding values for low indices for each
  12075. /// spatial dimension `([[front,] top,] left)`.
  12076. /// @param padding_r Vector of padding values for high indices for
  12077. /// each spatial dimension `([[back,] bottom,] right)`.
  12078. /// @param attr Primitive attributes to use. Attributes are optional
  12079. /// and default to empty attributes.
  12080. /// @param allow_empty A flag signifying whether construction is
  12081. /// allowed to fail without throwing an exception. In this case an
  12082. /// empty object will be produced. This flag is optional and
  12083. /// defaults to false.
  12084. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  12085. algorithm aalgorithm, const memory::desc &src_desc,
  12086. const memory::desc &dst_desc, const memory::dims &strides,
  12087. const memory::dims &kernel, const memory::dims &dilation,
  12088. const memory::dims &padding_l, const memory::dims &padding_r,
  12089. const primitive_attr &attr = default_attr(),
  12090. bool allow_empty = false) {
  12091. memory::validate_dims(strides, src_desc.get_ndims() - 2);
  12092. memory::validate_dims(kernel, src_desc.get_ndims() - 2);
  12093. memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
  12094. memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
  12095. memory::validate_dims(dilation, src_desc.get_ndims() - 2);
  12096. dnnl_primitive_desc_t pd = nullptr;
  12097. dnnl_status_t status = dnnl_pooling_forward_primitive_desc_create(
  12098. &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
  12099. convert_to_c(aalgorithm), src_desc.get(), dst_desc.get(),
  12100. &strides[0], &kernel[0], &dilation[0], &padding_l[0],
  12101. &padding_r[0], attr.get());
  12102. if (!allow_empty)
  12103. error::wrap_c_api(status,
  12104. "could not create a descriptor for a pooling forward "
  12105. "propagation primitive");
  12106. reset(pd);
  12107. }
  12108. /// Constructs a primitive descriptor for a pooling forward propagation
  12109. /// primitive from a C API primitive descriptor that must have a
  12110. /// matching kind.
  12111. ///
  12112. /// @param pd C API primitive descriptor for a pooling forward
  12113. /// propagation primitive.
  12114. primitive_desc(dnnl_primitive_desc_t pd)
  12115. : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
  12116. dnnl::prop_kind::forward_training,
  12117. dnnl::prop_kind::forward_inference) {}
  12118. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  12119. memory::desc src_desc() const { return base::src_desc(0); }
  12120. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  12121. memory::desc dst_desc() const { return base::dst_desc(0); }
  12122. /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
  12123. memory::desc workspace_desc() const { return base::workspace_desc(); }
  12124. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  12125. algorithm get_algorithm() const { return base::get_algorithm(); }
  12126. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  12127. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  12128. /// @copydoc dnnl::primitive_desc_base::get_strides()const
  12129. memory::dims get_strides() const { return base::get_strides(); }
  12130. /// @copydoc dnnl::primitive_desc_base::get_kernel()const
  12131. memory::dims get_kernel() const { return base::get_kernel(); }
  12132. /// @copydoc dnnl::primitive_desc_base::get_dilations()const
  12133. memory::dims get_dilations() const { return base::get_dilations(); }
  12134. /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
  12135. memory::dims get_padding_l() const { return base::get_padding_l(); }
  12136. /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
  12137. memory::dims get_padding_r() const { return base::get_padding_r(); }
  12138. };
  12139. /// Default constructor. Produces an empty object.
  12140. pooling_forward() = default;
  12141. /// Constructs a pooling forward propagation primitive.
  12142. ///
  12143. /// @param pd Primitive descriptor for a pooling forward propagation
  12144. /// primitive.
  12145. pooling_forward(const primitive_desc &pd) : primitive(pd) {}
  12146. /// Constructs a pooling forward propagation primitive from a cache blob.
  12147. ///
  12148. /// @param pd Primitive descriptor for a pooling forward propagation
  12149. /// primitive.
  12150. /// @param cache_blob Cache blob.
  12151. pooling_forward(
  12152. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  12153. : primitive(pd, cache_blob) {}
  12154. };
  12155. /// Pooling backward propagation primitive.
  12156. struct pooling_backward : public primitive {
  12157. /// Primitive descriptor for a pooling backward propagation primitive.
  12158. struct primitive_desc : public dnnl::primitive_desc {
  12159. /// Default constructor. Produces an empty object.
  12160. primitive_desc() = default;
  12161. /// Constructs a primitive descriptor for a pooling backward propagation
  12162. /// primitive.
  12163. ///
  12164. /// Arrays @p strides, @p kernel, @p dilation, @p padding_l
  12165. /// and @p padding_r contain values for spatial dimensions only and
  12166. /// hence must have the same number of elements as there are spatial
  12167. /// dimensions. The order of values is the same as in the tensor:
  12168. /// depth (for 3D tensors), height (for 3D and 2D tensors), and width.
  12169. ///
  12170. /// @param aengine Engine to use.
  12171. /// @param aalgorithm Pooling algorithm kind: either
  12172. /// #dnnl::algorithm::pooling_max,
  12173. /// #dnnl::algorithm::pooling_avg_include_padding,
  12174. /// or #dnnl::algorithm::pooling_avg_exclude_padding.
  12175. /// @param diff_src_desc Diff source memory descriptor.
  12176. /// @param diff_dst_desc Diff destination memory descriptor.
  12177. /// @param strides Vector of strides for spatial dimension.
  12178. /// @param kernel Vector of kernel spatial dimensions.
  12179. /// @param dilation Array of dilations for spatial dimension.
  12180. /// @param padding_l Vector of padding values for low indices for each
  12181. /// spatial dimension `([[front,] top,] left)`.
  12182. /// @param padding_r Vector of padding values for high indices for
  12183. /// each spatial dimension `([[back,] bottom,] right)`.
  12184. /// @param hint_fwd_pd Primitive descriptor for a pooling
  12185. /// forward propagation primitive. It is used as a hint for
  12186. /// deciding which memory format to use.
  12187. /// @param attr Primitive attributes to use. Attributes are optional
  12188. /// and default to empty attributes.
  12189. /// @param allow_empty A flag signifying whether construction is
  12190. /// allowed to fail without throwing an exception. In this case an
  12191. /// empty object will be produced. This flag is optional and
  12192. /// defaults to false.
  12193. primitive_desc(const engine &aengine, algorithm aalgorithm,
  12194. const memory::desc &diff_src_desc,
  12195. const memory::desc &diff_dst_desc, const memory::dims &strides,
  12196. const memory::dims &kernel, const memory::dims &dilation,
  12197. const memory::dims &padding_l, const memory::dims &padding_r,
  12198. const pooling_forward::primitive_desc &hint_fwd_pd,
  12199. const primitive_attr &attr = default_attr(),
  12200. bool allow_empty = false) {
  12201. memory::validate_dims(strides, diff_src_desc.get_ndims() - 2);
  12202. memory::validate_dims(kernel, diff_src_desc.get_ndims() - 2);
  12203. memory::validate_dims(padding_l, diff_src_desc.get_ndims() - 2);
  12204. memory::validate_dims(padding_r, diff_src_desc.get_ndims() - 2);
  12205. memory::validate_dims(dilation, diff_src_desc.get_ndims() - 2);
  12206. dnnl_primitive_desc_t pd = nullptr;
  12207. dnnl_status_t status = dnnl_pooling_backward_primitive_desc_create(
  12208. &pd, aengine.get(), convert_to_c(aalgorithm),
  12209. diff_src_desc.get(), diff_dst_desc.get(), &strides[0],
  12210. &kernel[0], &dilation[0], &padding_l[0], &padding_r[0],
  12211. hint_fwd_pd.get(), attr.get());
  12212. if (!allow_empty)
  12213. error::wrap_c_api(status,
  12214. "could not create a descriptor for a pooling backward "
  12215. "propagation primitive");
  12216. reset(pd);
  12217. }
  12218. /// Constructs a primitive descriptor for a pooling backward propagation
  12219. /// primitive from a C API primitive descriptor that must have a
  12220. /// matching kind.
  12221. ///
  12222. /// @param pd C API primitive descriptor for a pooling backward
  12223. /// propagation primitive.
  12224. primitive_desc(dnnl_primitive_desc_t pd)
  12225. : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
  12226. dnnl::prop_kind::backward_data) {}
  12227. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  12228. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  12229. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  12230. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  12231. /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
  12232. memory::desc workspace_desc() const { return base::workspace_desc(); }
  12233. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  12234. algorithm get_algorithm() const { return base::get_algorithm(); }
  12235. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  12236. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  12237. /// @copydoc dnnl::primitive_desc_base::get_strides()const
  12238. memory::dims get_strides() const { return base::get_strides(); }
  12239. /// @copydoc dnnl::primitive_desc_base::get_kernel()const
  12240. memory::dims get_kernel() const { return base::get_kernel(); }
  12241. /// @copydoc dnnl::primitive_desc_base::get_dilations()const
  12242. memory::dims get_dilations() const { return base::get_dilations(); }
  12243. /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
  12244. memory::dims get_padding_l() const { return base::get_padding_l(); }
  12245. /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
  12246. memory::dims get_padding_r() const { return base::get_padding_r(); }
  12247. };
  12248. /// Default constructor. Produces an empty object.
  12249. pooling_backward() = default;
  12250. /// Constructs a pooling backward propagation primitive.
  12251. ///
  12252. /// @param pd Primitive descriptor for a pooling backward propagation
  12253. /// primitive.
  12254. pooling_backward(const primitive_desc &pd) : primitive(pd) {}
  12255. /// Constructs a pooling backward propagation primitive from a cache blob.
  12256. ///
  12257. /// @param pd Primitive descriptor for a pooling backward propagation
  12258. /// primitive.
  12259. /// @param cache_blob Cache blob.
  12260. pooling_backward(
  12261. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  12262. : primitive(pd, cache_blob) {}
  12263. };
  12264. /// @} dnnl_api_pooling
  12265. /// @addtogroup dnnl_api_prelu PReLU
  12266. ///
  12267. /// PReLU primitive
  12268. /// A primitive to perform PReLU (leaky ReLU with trainable alpha parameter)
  12269. ///
  12270. /// @sa @ref dev_guide_prelu in developer guide
  12271. ///
  12272. /// @{
  12273. /// PReLU forward propagation primitive.
  12274. struct prelu_forward : public primitive {
  12275. /// Primitive descriptor for a PReLU forward propagation primitive.
  12276. struct primitive_desc : public dnnl::primitive_desc {
  12277. /// Default constructor. Produces an empty object.
  12278. primitive_desc() = default;
  12279. /// Constructs a primitive descriptor for a PReLU forward propagation
  12280. /// primitive.
  12281. ///
  12282. /// @param aengine Engine to use.
  12283. /// @param aprop_kind Propagation kind. Possible values are
  12284. /// #dnnl::prop_kind::forward_training, and
  12285. /// #dnnl::prop_kind::forward_inference.
  12286. /// @param src_desc Source memory descriptor.
  12287. /// @param weight_desc Alpha parameters memory descriptor.
  12288. /// @param dst_desc Destination memory descriptor.
  12289. /// @param attr Primitive attributes to use. Attributes are optional
  12290. /// and default to empty attributes.
  12291. /// @param allow_empty A flag signifying whether construction is
  12292. /// allowed to fail without throwing an exception. In this case an
  12293. /// empty object will be produced. This flag is optional and
  12294. /// defaults to false.
  12295. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  12296. const memory::desc &src_desc, const memory::desc &weight_desc,
  12297. const memory::desc &dst_desc,
  12298. const primitive_attr &attr = default_attr(),
  12299. bool allow_empty = false) {
  12300. dnnl_primitive_desc_t pd = nullptr;
  12301. dnnl_status_t status = dnnl_prelu_forward_primitive_desc_create(&pd,
  12302. aengine.get(), dnnl::convert_to_c(aprop_kind),
  12303. src_desc.get(), weight_desc.get(), dst_desc.get(),
  12304. attr.get());
  12305. if (!allow_empty)
  12306. error::wrap_c_api(status,
  12307. "could not create a primitive descriptor for "
  12308. "the prelu forward propagation primitive. Run workload "
  12309. "with environment variable ONEDNN_VERBOSE=all to get "
  12310. "additional diagnostic information.");
  12311. reset(pd);
  12312. }
  12313. /// Constructs a primitive descriptor for a prelu forward
  12314. /// propagation primitive from a C API primitive descriptor that must
  12315. /// have a matching kind.
  12316. ///
  12317. /// @param pd C API primitive descriptor for a prelu forward
  12318. /// propagation primitive.
  12319. primitive_desc(dnnl_primitive_desc_t pd)
  12320. : dnnl::primitive_desc(pd, dnnl::primitive::kind::prelu,
  12321. dnnl::prop_kind::forward_training,
  12322. dnnl::prop_kind::forward_inference) {}
  12323. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  12324. memory::desc src_desc() const { return base::src_desc(0); }
  12325. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  12326. memory::desc dst_desc() const { return base::dst_desc(0); }
  12327. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  12328. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  12329. };
  12330. /// Default constructor. Produces an empty object.
  12331. prelu_forward() = default;
  12332. /// Constructs a prelu forward propagation primitive.
  12333. /// @param pd Primitive descriptor for a prelu forward propagation
  12334. /// primitive.
  12335. prelu_forward(const primitive_desc &pd) : primitive(pd) {}
  12336. /// Constructs a prelu forward propagation primitive from a cache blob.
  12337. /// @param pd Primitive descriptor for a prelu forward propagation
  12338. /// primitive.
  12339. /// @param cache_blob Cache blob.
  12340. prelu_forward(
  12341. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  12342. : primitive(pd, cache_blob) {}
  12343. };
  12344. /// PReLU backward propagation primitive.
  12345. struct prelu_backward : public primitive {
  12346. /// Primitive descriptor for prelu backward propagation.
  12347. struct primitive_desc : public dnnl::primitive_desc {
  12348. /// Default constructor. Produces an empty object.
  12349. primitive_desc() = default;
  12350. /// Constructs a descriptor for a PReLU backward propagation
  12351. /// primitive.
  12352. ///
  12353. /// @param aengine Engine to use.
  12354. /// @param src_desc Source memory descriptor.
  12355. /// @param weight_desc Alpha parameters memory descriptor.
  12356. /// @param diff_src_desc Diff source memory descriptor.
  12357. /// @param diff_weights_desc Diff alpha parameters memory descriptor.
  12358. /// @param diff_dst_desc Diff destination memory descriptor.
  12359. /// @param hint_fwd_pd Primitive descriptor for a PReLU
  12360. /// forward propagation primitive. It is used as a hint for
  12361. /// deciding which memory format to use.
  12362. /// @param attr Primitive attributes to use. Attributes are optional
  12363. /// and default to empty attributes.
  12364. /// @param allow_empty A flag signifying whether construction is
  12365. /// allowed to fail without throwing an exception. In this case an
  12366. /// empty object will be produced. This flag is optional and
  12367. /// defaults to false.
  12368. primitive_desc(const engine &aengine, const memory::desc &src_desc,
  12369. const memory::desc &weight_desc,
  12370. const memory::desc &diff_src_desc,
  12371. const memory::desc &diff_weights_desc,
  12372. const memory::desc &diff_dst_desc,
  12373. const prelu_forward::primitive_desc &hint_fwd_pd,
  12374. const primitive_attr &attr = default_attr(),
  12375. bool allow_empty = false) {
  12376. dnnl_primitive_desc_t pd = nullptr;
  12377. dnnl_status_t status = dnnl_prelu_backward_primitive_desc_create(
  12378. &pd, aengine.get(), src_desc.get(), weight_desc.get(),
  12379. diff_src_desc.get(), diff_weights_desc.get(),
  12380. diff_dst_desc.get(), hint_fwd_pd.get(), attr.get());
  12381. if (!allow_empty)
  12382. error::wrap_c_api(status,
  12383. "could not create a primitive descriptor for "
  12384. "the prelu backward propagation primitive. Run "
  12385. "workload with environment variable ONEDNN_VERBOSE=all "
  12386. "to get additional diagnostic information.");
  12387. reset(pd);
  12388. }
  12389. /// Constructs a primitive descriptor for a prelu backward
  12390. /// propagation primitive from a C API primitive descriptor that must
  12391. /// have a matching kind.
  12392. ///
  12393. /// @param pd C API primitive descriptor for a prelu backward
  12394. /// propagation primitive.
  12395. primitive_desc(dnnl_primitive_desc_t pd)
  12396. : dnnl::primitive_desc(pd, dnnl::primitive::kind::prelu,
  12397. dnnl::prop_kind::backward) {}
  12398. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  12399. memory::desc src_desc() const { return base::src_desc(0); }
  12400. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  12401. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  12402. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  12403. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  12404. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  12405. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  12406. };
  12407. /// Default constructor. Produces an empty object.
  12408. prelu_backward() = default;
  12409. /// Constructs a prelu backward propagation primitive.
  12410. /// @param pd Primitive descriptor for a prelu backward propagation
  12411. /// primitive.
  12412. prelu_backward(const primitive_desc &pd) : primitive(pd) {}
  12413. /// Constructs a prelu backward propagation primitive from a cache blob.
  12414. /// @param pd Primitive descriptor for a prelu backward propagation
  12415. /// primitive.
  12416. /// @param cache_blob Cache blob.
  12417. prelu_backward(
  12418. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  12419. : primitive(pd, cache_blob) {}
  12420. };
  12421. /// @} dnnl_api_prelu
  12422. /// @addtogroup dnnl_api_reduction Reduction
  12423. ///
  12424. /// A primitive to compute reduction operation on data tensor
  12425. /// using min, max, mul, sum, mean and norm_lp operations.
  12426. ///
  12427. /// @sa @ref dev_guide_reduction in developer guide
  12428. ///
  12429. /// @{
  12430. /// Reduction.
  12431. struct reduction : public primitive {
  12432. /// Primitive descriptor for a reduction primitive.
  12433. struct primitive_desc : public dnnl::primitive_desc {
  12434. /// Default constructor. Produces an empty object.
  12435. primitive_desc() = default;
  12436. /// Constructs a primitive descriptor for a reduction primitive using
  12437. /// algorithm specific parameters, source and destination memory
  12438. /// descriptors.
  12439. ///
  12440. /// @note
  12441. /// Destination memory descriptor may be initialized with
  12442. /// #dnnl::memory::format_tag::any value of @p format_tag.
  12443. ///
  12444. /// @param aengine Engine to use.
  12445. /// @param aalgorithm reduction algorithm kind. Possible values:
  12446. /// #dnnl_reduction_max, #dnnl_reduction_min, #dnnl_reduction_sum,
  12447. /// #dnnl_reduction_mul, #dnnl_reduction_mean,
  12448. /// #dnnl_reduction_norm_lp_max, #dnnl_reduction_norm_lp_sum,
  12449. /// #dnnl_reduction_norm_lp_power_p_max,
  12450. /// #dnnl_reduction_norm_lp_power_p_sum.
  12451. /// @param p algorithm specific parameter.
  12452. /// @param eps algorithm specific parameter.
  12453. /// @param src_desc Source memory descriptor.
  12454. /// @param dst_desc Destination memory descriptor.
  12455. /// @param attr Primitive attributes to use. Attributes are optional
  12456. /// and default to empty attributes.
  12457. /// @param allow_empty A flag signifying whether construction is
  12458. /// allowed to fail without throwing an exception. In this case an
  12459. /// empty object will be produced. This flag is optional and
  12460. /// defaults to false.
  12461. primitive_desc(const engine &aengine, algorithm aalgorithm,
  12462. const memory::desc &src_desc, const memory::desc &dst_desc,
  12463. float p, float eps, const primitive_attr &attr = default_attr(),
  12464. bool allow_empty = false) {
  12465. dnnl_primitive_desc_t pd = nullptr;
  12466. dnnl_status_t status = dnnl_reduction_primitive_desc_create(&pd,
  12467. aengine.get(), convert_to_c(aalgorithm), src_desc.get(),
  12468. dst_desc.get(), p, eps, attr.get());
  12469. if (!allow_empty)
  12470. error::wrap_c_api(status,
  12471. "could not create a primitive descriptor for "
  12472. "the reduction primitive. Run workload with "
  12473. "environment variable ONEDNN_VERBOSE=all to get "
  12474. "additional diagnostic information.");
  12475. reset(pd);
  12476. }
  12477. /// Constructs a primitive descriptor for a reduction primitive from a C
  12478. /// API primitive descriptor that must have a matching kind.
  12479. ///
  12480. /// @param pd C API primitive descriptor for a reduction primitive.
  12481. primitive_desc(dnnl_primitive_desc_t pd)
  12482. : dnnl::primitive_desc(pd, dnnl::primitive::kind::reduction) {}
  12483. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  12484. memory::desc src_desc() const { return base::src_desc(0); }
  12485. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  12486. memory::desc dst_desc() const { return base::dst_desc(0); }
  12487. /// @copydoc dnnl::primitive_desc_base::get_p()const
  12488. float get_p() const { return base::get_p(); }
  12489. /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
  12490. float get_epsilon() const { return base::get_epsilon(); }
  12491. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  12492. algorithm get_algorithm() const { return base::get_algorithm(); }
  12493. };
  12494. /// Default constructor. Produces an empty object.
  12495. reduction() = default;
  12496. /// Constructs a reduction primitive.
  12497. /// @param pd Primitive descriptor for a reduction primitive.
  12498. reduction(const primitive_desc &pd) : primitive(pd) {}
  12499. /// Constructs a reduction primitive from a cache blob.
  12500. /// @param pd Primitive descriptor for a reduction primitive.
  12501. /// @param cache_blob Cache blob.
  12502. reduction(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  12503. : primitive(pd, cache_blob) {}
  12504. };
  12505. /// @} dnnl_api_reduction
  12506. /// @} dnnl_api_primitives
  12507. /// @addtogroup dnnl_api_service Service
  12508. ///
  12509. /// A set of functions that aid in oneDNN debugging and profiling.
  12510. ///
  12511. /// @{
  12512. /// @copydoc dnnl_version_t
  12513. using version_t = dnnl_version_t;
  12514. /// Status values returned by the library functions.
  12515. enum class status {
  12516. /// @copydoc dnnl_success
  12517. success = dnnl_success,
  12518. /// @copydoc dnnl_out_of_memory
  12519. out_of_memory = dnnl_out_of_memory,
  12520. /// @copydoc dnnl_invalid_arguments
  12521. invalid_arguments = dnnl_invalid_arguments,
  12522. /// @copydoc dnnl_unimplemented
  12523. unimplemented = dnnl_unimplemented,
  12524. /// @copydoc dnnl_last_impl_reached
  12525. last_impl_reached = dnnl_last_impl_reached,
  12526. /// @copydoc dnnl_runtime_error
  12527. runtime_error = dnnl_runtime_error,
  12528. /// @copydoc dnnl_not_required
  12529. not_required = dnnl_not_required,
  12530. };
  12531. /// @copydoc dnnl_set_verbose()
  12532. inline status set_verbose(int level) {
  12533. return static_cast<status>(dnnl_set_verbose(level));
  12534. }
  12535. /// @copydoc dnnl_version()
  12536. inline const version_t *version() {
  12537. return dnnl_version();
  12538. }
  12539. /// Returns the floating-point math mode that will be used by default
  12540. /// for all subsequently created primitives.
  12541. ///
  12542. /// @returns Output FP math mode.
  12543. inline fpmath_mode get_default_fpmath_mode() {
  12544. dnnl_fpmath_mode_t mode;
  12545. error::wrap_c_api(dnnl_get_default_fpmath_mode(&mode),
  12546. "could not get a default fpmath mode");
  12547. return static_cast<fpmath_mode>(mode);
  12548. }
  12549. /// @copydoc dnnl_set_default_fpmath_mode()
  12550. inline status set_default_fpmath_mode(fpmath_mode mode) {
  12551. return static_cast<status>(
  12552. dnnl_set_default_fpmath_mode(convert_to_c(mode)));
  12553. }
  12554. /// @copydoc dnnl_set_jit_dump()
  12555. inline status set_jit_dump(int enable) {
  12556. return static_cast<status>(dnnl_set_jit_dump(enable));
  12557. }
  12558. /// @copydoc dnnl_set_jit_profiling_flags()
  12559. inline status set_jit_profiling_flags(unsigned flags) {
  12560. return static_cast<status>(dnnl_set_jit_profiling_flags(flags));
  12561. }
  12562. /// @copydoc dnnl_set_jit_profiling_jitdumpdir()
  12563. inline status set_jit_profiling_jitdumpdir(const std::string &dir) {
  12564. return static_cast<status>(dnnl_set_jit_profiling_jitdumpdir(dir.c_str()));
  12565. }
  12566. /// @copydoc dnnl_cpu_isa_t
  12567. enum class cpu_isa {
  12568. /// @copydoc dnnl_cpu_isa_default
  12569. isa_default = dnnl_cpu_isa_default,
  12570. /// @copydoc dnnl_cpu_isa_sse41
  12571. sse41 = dnnl_cpu_isa_sse41,
  12572. /// @copydoc dnnl_cpu_isa_avx
  12573. avx = dnnl_cpu_isa_avx,
  12574. /// @copydoc dnnl_cpu_isa_avx2
  12575. avx2 = dnnl_cpu_isa_avx2,
  12576. /// @copydoc dnnl_cpu_isa_avx2_vnni
  12577. avx2_vnni = dnnl_cpu_isa_avx2_vnni,
  12578. /// @copydoc dnnl_cpu_isa_avx2_vnni_2
  12579. avx2_vnni_2 = dnnl_cpu_isa_avx2_vnni_2,
  12580. /// @copydoc dnnl_cpu_isa_avx512_core
  12581. avx512_core = dnnl_cpu_isa_avx512_core,
  12582. /// @copydoc dnnl_cpu_isa_avx512_core_vnni
  12583. avx512_core_vnni = dnnl_cpu_isa_avx512_core_vnni,
  12584. /// @copydoc dnnl_cpu_isa_avx512_core_bf16
  12585. avx512_core_bf16 = dnnl_cpu_isa_avx512_core_bf16,
  12586. /// @copydoc dnnl_cpu_isa_avx10_1_512
  12587. avx10_1_512 = dnnl_cpu_isa_avx10_1_512,
  12588. /// @copydoc dnnl_cpu_isa_avx512_core_fp16
  12589. avx512_core_fp16 = dnnl_cpu_isa_avx512_core_fp16,
  12590. /// @copydoc dnnl_cpu_isa_avx10_1_512_amx
  12591. avx10_1_512_amx = dnnl_cpu_isa_avx10_1_512_amx,
  12592. /// @copydoc dnnl_cpu_isa_avx512_core_amx
  12593. avx512_core_amx = dnnl_cpu_isa_avx512_core_amx,
  12594. /// @copydoc dnnl_cpu_isa_avx10_1_512_amx_fp16
  12595. avx10_1_512_amx_fp16 = dnnl_cpu_isa_avx10_1_512_amx_fp16,
  12596. /// @copydoc dnnl_cpu_isa_avx512_core_amx_fp16
  12597. avx512_core_amx_fp16 = dnnl_cpu_isa_avx512_core_amx_fp16,
  12598. };
  12599. /// @copydoc dnnl_set_max_cpu_isa()
  12600. inline status set_max_cpu_isa(cpu_isa isa) {
  12601. return static_cast<status>(
  12602. dnnl_set_max_cpu_isa(static_cast<dnnl_cpu_isa_t>(isa)));
  12603. }
  12604. /// @copydoc dnnl_get_effective_cpu_isa()
  12605. inline cpu_isa get_effective_cpu_isa() {
  12606. return static_cast<cpu_isa>(dnnl_get_effective_cpu_isa());
  12607. }
  12608. /// @copydoc dnnl_cpu_isa_hints_t
  12609. enum class cpu_isa_hints {
  12610. /// @copydoc dnnl_cpu_isa_no_hints
  12611. no_hints = dnnl_cpu_isa_no_hints,
  12612. /// @copydoc dnnl_cpu_isa_prefer_ymm
  12613. prefer_ymm = dnnl_cpu_isa_prefer_ymm,
  12614. };
  12615. /// @copydoc dnnl_set_cpu_isa_hints()
  12616. inline status set_cpu_isa_hints(cpu_isa_hints isa_hints) {
  12617. return static_cast<status>(dnnl_set_cpu_isa_hints(
  12618. static_cast<dnnl_cpu_isa_hints_t>(isa_hints)));
  12619. }
  12620. /// @copydoc dnnl_get_cpu_isa_hints()
  12621. inline cpu_isa_hints get_cpu_isa_hints() {
  12622. return static_cast<cpu_isa_hints>(dnnl_get_cpu_isa_hints());
  12623. }
  12624. /// @} dnnl_api_service
  12625. #ifdef DNNL_EXPERIMENTAL_PROFILING
  12626. /// @addtogroup dnnl_api_profiling Profiling
  12627. /// @{
  12628. /// Profiling data kind.
  12629. enum class profiling_data_kind {
  12630. /// Undefined profiling data kind.
  12631. undef = dnnl_profiling_data_kind_undef,
  12632. /// Data kind to query an execution time in nanoseconds.
  12633. time = dnnl_profiling_data_kind_time,
  12634. };
  12635. /// Resets a profiler's state.
  12636. ///
  12637. /// @param stream Stream associated with the profiler.
  12638. inline void reset_profiling(stream &stream) {
  12639. error::wrap_c_api(
  12640. dnnl_reset_profiling(stream.get()), "could not reset profiling");
  12641. }
  12642. /// Returns requested profiling data. The profiling data accumulates for each
  12643. /// primitive execution. The size of the vector will be equal to the number
  12644. /// of executions since the last `dnnl::reset_profiling` call.
  12645. ///
  12646. /// The profiling data can be reset by calling #dnnl::reset_profiling.
  12647. ///
  12648. /// @note
  12649. /// It is required to wait for all submitted primitives to complete
  12650. /// using #dnnl::stream::wait prior to querying profiling data.
  12651. ///
  12652. /// @param stream Stream that was used for executing a primitive that
  12653. /// is being profiled.
  12654. /// @param data_kind Profiling data kind to query.
  12655. ///
  12656. /// @returns A vector with the requested profiling data.
  12657. inline std::vector<uint64_t> get_profiling_data(
  12658. stream &stream, profiling_data_kind data_kind) {
  12659. int num_entries = 0;
  12660. error::wrap_c_api(
  12661. dnnl_query_profiling_data(stream.get(),
  12662. static_cast<dnnl_profiling_data_kind_t>(data_kind),
  12663. &num_entries, nullptr),
  12664. "could not get number of entries for profiling data");
  12665. if (num_entries == 0) return {};
  12666. std::vector<uint64_t> data(num_entries);
  12667. error::wrap_c_api(
  12668. dnnl_query_profiling_data(stream.get(),
  12669. static_cast<dnnl_profiling_data_kind_t>(data_kind),
  12670. &num_entries, data.data()),
  12671. "could not get profiling data");
  12672. return data;
  12673. }
  12674. /// @} dnnl_api_profiling
  12675. #endif
  12676. /// @addtogroup dnnl_api_primitive_cache Primitive Cache
  12677. ///
  12678. /// A set of functions that provide primitive cache control.
  12679. ///
  12680. /// @{
  12681. /// Returns the number of primitives that can be held in the primitive cache
  12682. /// at the same time.
  12683. inline int get_primitive_cache_capacity() {
  12684. int result = 0;
  12685. error::wrap_c_api(dnnl_get_primitive_cache_capacity(&result),
  12686. "could not get primitive cache capacity");
  12687. return result;
  12688. }
  12689. /// @copydoc dnnl_set_primitive_cache_capacity(int capacity)
  12690. inline void set_primitive_cache_capacity(int capacity) {
  12691. error::wrap_c_api(dnnl_set_primitive_cache_capacity(capacity),
  12692. "could not set primitive cache capacity");
  12693. }
  12694. /// @} dnnl_api_primitive_cache
  12695. /// @addtogroup dnnl_api_blas BLAS functions
  12696. ///
  12697. /// A subset of Basic Linear Algebra (BLAS) functions that perform
  12698. /// matrix-matrix multiplication.
  12699. ///
  12700. /// @{
  12701. /// @copydoc dnnl_sgemm()
  12702. inline status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N,
  12703. dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
  12704. const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc) {
  12705. return static_cast<status>(dnnl_sgemm(
  12706. transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc));
  12707. }
  12708. /// @copydoc dnnl_gemm_u8s8s32()
  12709. inline status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
  12710. dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
  12711. dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
  12712. float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
  12713. return static_cast<status>(dnnl_gemm_u8s8s32(transa, transb, offsetc, M, N,
  12714. K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co));
  12715. }
  12716. /// @copydoc dnnl_gemm_s8s8s32()
  12717. inline status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
  12718. dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
  12719. dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
  12720. float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
  12721. return static_cast<status>(dnnl_gemm_s8s8s32(transa, transb, offsetc, M, N,
  12722. K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co));
  12723. }
  12724. /// @} dnnl_api_blas
  12725. // implementation section
  12726. /// @cond DO_NOT_DOCUMENT_THIS
  12727. inline primitive::primitive(const_dnnl_primitive_desc_t c_pd) {
  12728. dnnl_primitive_t result;
  12729. error::wrap_c_api(dnnl_primitive_create(&result, c_pd),
  12730. "could not create a primitive");
  12731. reset(result);
  12732. }
  12733. inline primitive::primitive(const_dnnl_primitive_desc_t c_pd,
  12734. const std::vector<uint8_t> &cache_blob) {
  12735. dnnl_primitive_t result;
  12736. size_t size = cache_blob.size();
  12737. const uint8_t *cache_blob_data = cache_blob.data();
  12738. error::wrap_c_api(dnnl_primitive_create_from_cache_blob(
  12739. &result, c_pd, size, cache_blob_data),
  12740. "could not create a primitive from a cache blob");
  12741. reset(result);
  12742. }
  12743. inline primitive::primitive(const primitive_desc &pd) : primitive(pd.get()) {}
  12744. inline primitive::primitive(
  12745. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  12746. : primitive(pd.get(), cache_blob) {}
  12747. inline void primitive::execute(const stream &astream,
  12748. const std::unordered_map<int, memory> &args) const {
  12749. std::vector<dnnl_exec_arg_t> c_args;
  12750. c_args.reserve(args.size());
  12751. for (const auto &a : args)
  12752. c_args.push_back({a.first, a.second.get(true)});
  12753. error::wrap_c_api(dnnl_primitive_execute(get(), astream.get(),
  12754. (int)c_args.size(), c_args.data()),
  12755. "could not execute a primitive");
  12756. }
  12757. /// @endcond
  12758. #undef DNNL_DEFINE_BITMASK_OPS
  12759. } // namespace dnnl
  12760. /// oneAPI namespace
  12761. /// The oneAPI namespace.
  12762. /// Contains the oneapi::dnnl namespace as an alias to the ::dnnl namespace.
  12763. namespace oneapi {
  12764. // Note: without this guard, doxygen warns of potentially recursive namespace
  12765. #ifndef DOXYGEN_SHOULD_SKIP_THIS
  12766. /// oneDNN alias namespace
  12767. namespace dnnl = ::dnnl;
  12768. #endif
  12769. } // namespace oneapi
  12770. /// @} dnnl_api
  12771. #endif /* ONEAPI_DNNL_DNNL_HPP */