# executor.py (123 KB)
  1. # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import logging
  16. import os
  17. import sys
  18. import warnings
  19. from functools import lru_cache
  20. import numpy as np
  21. from paddle import pir
  22. from paddle.base.framework import in_cinn_mode
  23. from paddle.base.libpaddle.pir import apply_cinn_pass
  24. from ..pir import (
  25. Program as PirProgram,
  26. Value,
  27. translate_to_pir,
  28. translate_to_pir_with_param_map,
  29. )
  30. from . import compiler, core, framework, unique_name
  31. from .data_feeder import convert_dtype
  32. from .framework import (
  33. Operator,
  34. Program,
  35. Variable,
  36. _apply_pass,
  37. convert_np_dtype_to_dtype_,
  38. default_main_program,
  39. get_flags,
  40. in_pir_mode,
  41. paddle_type_to_proto_type,
  42. process_type_promotion,
  43. set_flags,
  44. )
  45. from .incubate.checkpoint import auto_checkpoint as acp
  46. from .trainer_factory import FetchHandlerMonitor, TrainerFactory
  47. from .wrapped_decorator import signature_safe_contextmanager
# Nothing is exported through `from <module> import *`.
__all__ = []

# Module-level default scope: returned by global_scope() and replaced by
# _switch_scope().
g_scope = core.Scope()

# Short aliases for the configuration classes exposed by `core`.
InferNativeConfig = core.NativeConfig
InferAnalysisConfig = core.AnalysisConfig
def global_scope():
    """
    :api_attr: Static Graph

    Get the global/default scope instance. Many APIs use
    :code:`global_scope` as their default value, e.g., :code:`Executor.run`.

    The function returns whatever the module-level ``g_scope`` currently
    refers to, so the result changes after the default scope is switched.

    Returns:
        Scope: The global/default scope instance.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> import numpy

            >>> paddle.static.global_scope().var("data").get_tensor().set(numpy.ones((2, 2)), paddle.CPUPlace())
            >>> numpy.array(paddle.static.global_scope().find_var("data").get_tensor())
    """
    return g_scope
  67. def _switch_scope(scope):
  68. global g_scope
  69. ex = g_scope
  70. g_scope = scope
  71. return ex
  72. @signature_safe_contextmanager
  73. def scope_guard(scope):
  74. """
  75. This function switches scope through python `with` statement.
  76. Scope records the mapping between variable names and variables ( :ref:`api_guide_Variable` ),
  77. similar to brackets in programming languages.
  78. If this function is not invoked, all variables and variable names are recorded in the default global scope.
  79. When users need to create variables with the same name,
  80. they need to switch scopes through this function
  81. if they do not want the mapping of variables with the same name to be overwritten.
  82. After switching through the `with` statement,
  83. all variables created in the `with` block will be assigned to a new scope.
  84. Parameters:
  85. scope: The new scope.
  86. Returns:
  87. None
  88. Examples:
  89. .. code-block:: python
  90. >>> import paddle
  91. >>> import numpy
  92. >>> paddle.enable_static()
  93. >>> new_scope = paddle.static.Scope()
  94. >>> with paddle.static.scope_guard(new_scope):
  95. ... paddle.static.global_scope().var("data").get_tensor().set(numpy.ones((2, 2)), paddle.CPUPlace())
  96. >>> numpy.array(new_scope.find_var("data").get_tensor())
  97. array([[1., 1.],
  98. [1., 1.]])
  99. """
  100. ex = _switch_scope(scope)
  101. try:
  102. yield
  103. finally:
  104. _switch_scope(ex)
  105. def as_numpy(tensor, copy=False):
  106. """
  107. Convert a Tensor to a numpy.ndarray, its only support Tensor without LoD information.
  108. For higher dimensional sequence data, please use LoDTensor directly.
  109. Examples:
  110. .. code-block:: python
  111. >>> import paddle.base as base
  112. >>> import numpy
  113. >>> new_scope = base.Scope()
  114. >>> with base.scope_guard(new_scope):
  115. ... base.global_scope().var("data").get_tensor().set(numpy.ones((2, 2)), base.CPUPlace())
  116. >>> tensor = new_scope.find_var("data").get_tensor()
  117. >>> base.executor.as_numpy(tensor) # or numpy.array(new_scope.find_var("data").get_tensor())
  118. Args:
  119. tensor(Variable): a instance of Tensor
  120. copy(bool, optional): Whether to use deep copy.
  121. Returns:
  122. numpy.ndarray
  123. """
  124. if isinstance(tensor, core.LoDTensorArray):
  125. return [as_numpy(t, copy) for t in tensor]
  126. if isinstance(tensor, list):
  127. return [as_numpy(t, copy) for t in tensor]
  128. assert isinstance(tensor, core.LoDTensor)
  129. lod = tensor.lod()
  130. if len(lod) > 0:
  131. raise RuntimeError(
  132. "Some of your fetched tensors hold LoD information. \
  133. They can not be completely cast to Python ndarray. \
  134. Please set the parameter 'return_numpy' as 'False' to \
  135. return LoDTensor itself directly."
  136. )
  137. if tensor._is_initialized():
  138. if copy:
  139. return np.array(tensor)
  140. else:
  141. return np.asarray(tensor)
  142. else:
  143. return None
  144. def dtype_is_compatible_with(first, second):
  145. """
  146. Returns True if the first dtype can be compatible the second one.
  147. Currently, we require the two dtype's have to be same.
  148. Args:
  149. dtype (np.dtype|VarType|str): The type of data: float32, int64, etc.
  150. Returns:
  151. True if the two types are same.
  152. """
  153. if not isinstance(first, core.VarDesc.VarType):
  154. first = convert_np_dtype_to_dtype_(first)
  155. if not isinstance(second, core.VarDesc.VarType):
  156. second = convert_np_dtype_to_dtype_(second)
  157. return first == second
  158. def dimension_is_compatible_with(first, second):
  159. """
  160. Returns True if the two dimensions are compatible.
  161. A dimension is compatible with the other if:
  162. 1. The length of the dimensions are same.
  163. 2. Each non-negative number of the two dimensions are same.
  164. 3. For negative number or 'None' in a dimension, it means unknown so it
  165. is compatible with any number.
  166. Args:
  167. first (list/tuple): integers representing shape. "None" or negative
  168. number means unknown.
  169. second (list/tuple): integers representing shape. "None" or negative
  170. number means unknown.
  171. Returns:
  172. True if the two dimensions are compatible.
  173. """
  174. dim_len = len(first)
  175. if dim_len != len(second):
  176. return False
  177. for i in range(dim_len):
  178. if first[i] is None or first[i] < 0:
  179. continue
  180. if second[i] is None or second[i] < 0:
  181. continue
  182. if first[i] != second[i]:
  183. return False
  184. return True
  185. def check_feed_shape_type(var, feed, num_places=1):
  186. """
  187. Returns True if the variable doesn't require feed check or it is compatible
  188. with the shape and have same dtype as the fed value.
  189. A dimension is compatible with the other if:
  190. 1. The length of the dimensions are same.
  191. 2. Each non-negative number of the two dimensions are same.
  192. 3. For negative number or 'None' in a dimension, it means unknown so it
  193. is compatible with any number.
  194. Args:
  195. var (Variable): the Variable object
  196. feed (LoDTensor): the fed value, which must be a LoDTensor
  197. num_places: an integer value indicating the number of places.
  198. ParallelExecutor will divide data into devices (CPU/GPU) evenly.
  199. Returns:
  200. True if the shape and dtype of variable is compatible with the feed value
  201. Raises:
  202. ValueError: if the shape or dtype of the variable is not compatible with
  203. the feed value
  204. """
  205. if var.desc.need_check_feed():
  206. diff_shape = core.diff_tensor_shape(feed, var.desc, num_places)
  207. if diff_shape is not None:
  208. raise ValueError(
  209. 'The fed Variable %r should have dimensions = %d, shape = '
  210. '%r, but received fed shape %r on each device'
  211. % (var.name, len(var.shape), var.shape, diff_shape)
  212. )
  213. if not dtype_is_compatible_with(feed._dtype(), var.dtype):
  214. var_dtype_format = (
  215. convert_dtype(var.dtype)
  216. if isinstance(var.dtype, core.VarDesc.VarType)
  217. else var.dtype
  218. )
  219. feed_dtype_format = (
  220. convert_dtype(feed._dtype())
  221. if isinstance(feed._dtype(), core.VarDesc.VarType)
  222. else feed._dtype()
  223. )
  224. raise ValueError(
  225. f'The data type of fed Variable {var.name!r} must be {var_dtype_format!r}, but received {feed_dtype_format!r}'
  226. )
  227. return True
  228. def pir_check_feed_shape_type(feed, name, target_shape, dtype, num_places=1):
  229. """
  230. Returns True if the variable doesn't require feed check or it is compatible
  231. with the shape and have same dtype as the fed value.
  232. A dimension is compatible with the other if:
  233. 1. The length of the dimensions are same.
  234. 2. Each non-negative number of the two dimensions are same.
  235. 3. For negative number or 'None' in a dimension, it means unknown so it
  236. is compatible with any number.
  237. Args:
  238. feed (LoDTensor): the fed value, which must be a LoDTensor
  239. name (str): name of the variable
  240. target_shape (list): the shape that will be compared with feed
  241. dtype (core.VarDesc.VarType): the dtype that will be compared with feed
  242. num_places: an integer value indicating the number of places.
  243. ParallelExecutor will divide data into devices (CPU/GPU) evenly.
  244. Returns:
  245. True if the shape and dtype of variable is compatible with the feed value
  246. Raises:
  247. ValueError: if the shape or dtype of the variable is not compatible with
  248. the feed value
  249. """
  250. diff_shape = core.diff_tensor_shape(feed, target_shape, num_places)
  251. if diff_shape is not None:
  252. warnings.warn(
  253. 'The fed Variable %r should have dimensions = %d, shape = '
  254. '%r, but received fed shape %r on each device'
  255. % (name, len(target_shape), target_shape, diff_shape)
  256. )
  257. if not dtype_is_compatible_with(feed._dtype(), dtype):
  258. var_dtype_format = (
  259. convert_dtype(dtype)
  260. if isinstance(dtype, core.VarDesc.VarType)
  261. else dtype
  262. )
  263. feed_dtype_format = (
  264. convert_dtype(feed._dtype())
  265. if isinstance(feed._dtype(), core.VarDesc.VarType)
  266. else feed._dtype()
  267. )
  268. warnings.warn(
  269. f'The data type of fed Variable {name!r} must be {var_dtype_format!r}, but received {feed_dtype_format!r}'
  270. )
  271. return True
  272. def has_feed_operators(block, feed_targets, feed_holder_name):
  273. """Check whether the block already has feed operators.
  274. Return false if the block does not have any feed operators.
  275. If some feed operators have been prepended to the block, check that
  276. the info contained in these feed operators matches the feed_targets
  277. and feed_holder_name. Raise exception when any mismatch is found.
  278. Return true when the block has feed operators with matching info.
  279. Args:
  280. block: a block instance (typically global block of a program)
  281. feed_targets: a dictionary of {feed_target_name: feed_target_data}
  282. feed_holder_name: the name of the variable that holds the data of
  283. all feed targets. The type of this feed_holder variable is
  284. FEED_MINIBATCH, which is essentially vector<LoDTensor>.
  285. Returns:
  286. A boolean value that indicates whether a block has feed operators
  287. that match the info contained in feed_targets and feed_holder_name.
  288. """
  289. feed_count = 0
  290. for op in block.ops:
  291. if op.desc.type() == 'feed':
  292. feed_count += 1
  293. assert op.desc.input('X')[0] == feed_holder_name
  294. feed_target_name = op.desc.output('Out')[0]
  295. if feed_target_name not in feed_targets:
  296. raise Exception(
  297. f"'feed_targets' does not have {feed_target_name} variable"
  298. )
  299. else:
  300. break
  301. if feed_count > 0 and feed_count != len(feed_targets):
  302. raise Exception(
  303. "Feed operators in program desc do not match 'feed_targets'"
  304. )
  305. return feed_count > 0
  306. def has_fetch_operators(
  307. block, fetch_targets, fetch_holder_name, fetch_op='fetch'
  308. ):
  309. """Check whether the block already has fetch operators.
  310. Return false if the block does not have any fetch operators.
  311. If some fetch operators have been appended to the block, check that
  312. the info contained in these fetch operators matches the fetch_targets
  313. and fetch_holder_name. Raise exception when any mismatch is found.
  314. Return true when the block has fetch operators with matching info.
  315. Args:
  316. block: a block instance (typically global block of a program)
  317. fetch_targets: a dictionary of {fetch_target_name: fetch_target_data}
  318. fetch_holder_name: the name of the variable that holds the data of
  319. all fetch targets. The type of this fetch_holder variable is
  320. FETCH_LIST, which is essentially vector<LoDTensor>.
  321. fetch_op: the operator name of fetch
  322. Return:
  323. A boolean value that indicates whether a block has fetch operators
  324. that match the info contained in fetch_targets and fetch_holder_name.
  325. """
  326. fetch_count = 0
  327. for op in block.ops:
  328. if op.desc.type() == fetch_op:
  329. fetch_count += 1
  330. assert op.desc.output('Out')[0] == fetch_holder_name
  331. fetch_target_name = op.desc.input('X')[0]
  332. if fetch_target_name not in [
  333. var.desc.name() for var in fetch_targets
  334. ]:
  335. raise Exception(
  336. f"'fetch_targets' does not have {fetch_target_name} variable"
  337. )
  338. idx = op.desc.attr('col')
  339. assert fetch_target_name == fetch_targets[idx].desc.name()
  340. if fetch_count > 0 and fetch_count != len(fetch_targets):
  341. raise Exception(
  342. "Fetch operators in program desc do not match 'fetch_targets'"
  343. )
  344. return fetch_count > 0
  345. def has_fetch_operations(
  346. block, fetch_targets, fetch_holder_name, fetch_op='pd_op.fetch'
  347. ):
  348. """Check whether the block already has fetch operation.
  349. Return false if the block does not have any fetch operation.
  350. If some fetch operation have been appended to the block, check that
  351. the info contained in these fetch operation matches the fetch_targets.
  352. Raise exception when any mismatch is found.
  353. Return true when the block has fetch operation with matching info.
  354. Args:
  355. block: a block instance (typically global block of a program)
  356. fetch_targets: a list of fetch_target_data
  357. fetch_op: the operator name of fetch
  358. Return:
  359. A boolean value that indicates whether a block has fetch operators
  360. that match the info contained in fetch_targets.
  361. """
  362. from paddle.autograd.backward_utils import ValueSet
  363. fetch_info = [[], []]
  364. for op in block.ops:
  365. if op.name() == fetch_op:
  366. fetch_info[0].append(op.operand_source(0))
  367. fetch_info[1].append(op.attrs()["name"])
  368. need_fetch_info = []
  369. for i, fetch_var in enumerate(fetch_targets):
  370. if isinstance(fetch_var, str):
  371. if fetch_var not in fetch_info[1]:
  372. raise Exception(
  373. f"Found fetch_target[{i}] is type(str) and doesn't have fetch op."
  374. )
  375. elif fetch_var not in ValueSet(fetch_info[0]):
  376. need_fetch_info.append(fetch_var)
  377. return need_fetch_info
  378. def _add_feed_fetch_ops(
  379. program, feed, fetch_list, feed_var_name, fetch_var_name, use_fetch_v2=False
  380. ):
  381. tmp_program = program.clone()
  382. global_block = tmp_program.global_block()
  383. if feed_var_name in global_block.vars:
  384. feed_var = global_block.var(feed_var_name)
  385. else:
  386. feed_var = global_block.create_var(
  387. name=feed_var_name,
  388. type=core.VarDesc.VarType.FEED_MINIBATCH,
  389. persistable=True,
  390. )
  391. if fetch_var_name in global_block.vars:
  392. fetch_var = global_block.var(fetch_var_name)
  393. else:
  394. fetch_var = global_block.create_var(
  395. name=fetch_var_name,
  396. type=core.VarDesc.VarType.FETCH_LIST,
  397. persistable=True,
  398. )
  399. # prepend feed operators
  400. if not has_feed_operators(global_block, feed, feed_var_name):
  401. for i, name in enumerate(feed):
  402. if global_block.has_var(name):
  403. out = global_block.var(name)
  404. global_block._prepend_op(
  405. type='feed',
  406. inputs={'X': [feed_var]},
  407. outputs={'Out': [out]},
  408. attrs={'col': i},
  409. )
  410. else:
  411. warnings.warn(
  412. "The variable %s is not found in program. It is not declared or is pruned."
  413. % name
  414. )
  415. if use_fetch_v2:
  416. fetch_op = 'fetch_v2'
  417. else:
  418. fetch_op = 'fetch'
  419. # append fetch_operators
  420. if not has_fetch_operators(
  421. global_block, fetch_list, fetch_var_name, fetch_op
  422. ):
  423. for i, var in enumerate(fetch_list):
  424. assert isinstance(
  425. var, (Variable, str)
  426. ), f"Wrong type for fetch_list[{i}]: {type(var)}"
  427. global_block.append_op(
  428. type=fetch_op,
  429. inputs={'X': [var]},
  430. outputs={'Out': [fetch_var]},
  431. attrs={'col': i},
  432. )
  433. return tmp_program
  434. def _add_pir_fetch_ops(program, fetch_list, fetch_var_name):
  435. import paddle
  436. global_block = program.global_block()
  437. fetch_op = "pd_op.fetch"
  438. need_fetch_info = has_fetch_operations(
  439. global_block, fetch_list, fetch_var_name, fetch_op
  440. )
  441. if need_fetch_info:
  442. with paddle.static.program_guard(program):
  443. for i, fetch_input in enumerate(need_fetch_info):
  444. assert isinstance(
  445. fetch_input, Value
  446. ), f"Wrong type for fetch_list[{i}]: {type(fetch_input)}"
  447. out = paddle._pir_ops.fetch(
  448. fetch_input, fetch_var_name + str(i), i
  449. )
  450. out.persistable = True
  451. def _merge_tensors(tensor, micro_batch_num):
  452. if micro_batch_num <= 1:
  453. return tensor
  454. assert len(tensor) % micro_batch_num == 0
  455. chunk_tensor = [
  456. tensor[i : i + micro_batch_num]
  457. for i in range(0, len(tensor), micro_batch_num)
  458. ]
  459. return [np.array(chunk) for chunk in chunk_tensor]
  460. def _apply_inplace_addto_pass(
  461. program, enable_inplace, enable_addto, skip_var_names
  462. ):
  463. use_cuda = True if core.is_compiled_with_cuda() else False
  464. attrs = {"use_cuda": use_cuda, "mem_opt_skip_vars": skip_var_names}
  465. attr_types = {"use_cuda": "bool", "mem_opt_skip_vars": "list[str]"}
  466. empty_startup_program = Program()
  467. if enable_inplace:
  468. pass_name = "buffer_shared_inplace_pass"
  469. _apply_pass(
  470. program, empty_startup_program, pass_name, attrs, attr_types
  471. )
  472. if enable_addto and use_cuda:
  473. pass_name = "inplace_addto_op_pass"
  474. _apply_pass(
  475. program, empty_startup_program, pass_name, attrs, attr_types
  476. )
  477. def _fetch_var(name, scope=None, return_numpy=True):
  478. """
  479. Fetch the value of the variable with the given name from the
  480. given scope.
  481. Args:
  482. name(str): name of the variable. Typically, only persistable variables
  483. can be found in the scope used for running the program.
  484. scope(core.Scope|None): scope object. It should be the scope where
  485. you pass to Executor.run() when running your program.
  486. If None, global_scope() will be used. Default None.
  487. return_numpy(bool): whether convert the tensor to numpy.ndarray.
  488. Default True.
  489. Returns:
  490. LodTensor|numpy.ndarray
  491. """
  492. assert isinstance(name, str)
  493. if scope is None:
  494. scope = global_scope()
  495. assert isinstance(scope, core._Scope)
  496. var = scope.find_var(_to_name_str(name))
  497. assert var is not None, (
  498. "Cannot find " + name + " in scope. Perhaps you need to make the"
  499. " variable persistable by using var.persistable = True in your"
  500. " program."
  501. )
  502. tensor = var.get_tensor()
  503. if return_numpy:
  504. tensor = as_numpy(tensor, copy=True)
  505. return tensor
  506. def _to_name_str(var):
  507. def _to_str(var):
  508. if isinstance(var, Variable):
  509. return var.desc.name()
  510. elif isinstance(var, str):
  511. return var
  512. elif isinstance(var, str):
  513. return str(var)
  514. elif isinstance(var, Operator):
  515. return str(id(var))
  516. elif isinstance(var, Value):
  517. return str(var)
  518. elif isinstance(var, Value):
  519. return str(var)
  520. else:
  521. raise TypeError(str(var) + " should be Variable, Operator or str")
  522. # NOTEz(zhiqiu): The item in fetch_list may be tuple returned by Optimizer.minimize(),
  523. # see comments in _split_optimize_ops_in_fetch_list for more details.
  524. if isinstance(var, tuple):
  525. var = var[0]
  526. if isinstance(var, list):
  527. s = [_to_str(item) for item in var]
  528. return ','.join(s)
  529. else:
  530. return _to_str(var)
  531. def _prepare_fleet_executor():
  532. from ..distributed.backup_env import getenv_or_backup
  533. from ..distributed.fleet.proto import fleet_executor_desc_pb2
  534. trainer_endpoints_str = getenv_or_backup("PADDLE_TRAINER_ENDPOINTS", "")
  535. trainer_endpoints = trainer_endpoints_str.split(',')
  536. fleet_exe_desc = fleet_executor_desc_pb2.FleetExecutorDesc()
  537. cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0))
  538. fleet_exe_desc.cur_rank = cur_rank
  539. nrank = len(trainer_endpoints)
  540. for rank, endpoint in enumerate(trainer_endpoints):
  541. rank_info = fleet_executor_desc_pb2.RankInfo()
  542. rank_info.rank = rank
  543. rank_info.ip_port = endpoint
  544. fleet_exe_desc.cluster_info.append(rank_info)
  545. fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
  546. return fleet_exe
  547. def _get_strong_program_cache_key_for_new_exe(program, scope, feed, fetch_list):
  548. if isinstance(program, PirProgram):
  549. return (
  550. str(program)
  551. + str(scope.raw_address())
  552. + _get_program_cache_key(feed, fetch_list)
  553. )
  554. else:
  555. return (
  556. program.desc.cached_hash_str()
  557. + str(scope.raw_address())
  558. + _get_program_cache_key(feed, fetch_list)
  559. )
  560. def _get_strong_program_cache_key(program, feed, fetch_list):
  561. # TODO(zhiqiu): use hash_str to generate cache key as above
  562. def _get_varname_from_block(block):
  563. block_str = []
  564. for var_name in list(block.vars.keys()):
  565. block_str.append(var_name)
  566. return "\n".join(block_str)
  567. inner_program = (
  568. program._program
  569. if isinstance(program, compiler.CompiledProgram)
  570. else program
  571. )
  572. return (
  573. _get_varname_from_block(inner_program.blocks[0])
  574. + str(id(program))
  575. + _get_program_cache_key(feed, fetch_list)
  576. )
  577. def _get_feed_fetch_var_names(feed, fetch_list):
  578. feed_var_names = []
  579. if isinstance(feed, dict):
  580. feed_var_names = list(feed.keys())
  581. elif isinstance(feed, (list, tuple)):
  582. for i, each in enumerate(feed):
  583. feed_var_names += list(each.keys())
  584. fetch_var_names = list(map(_to_name_str, fetch_list))
  585. return feed_var_names + fetch_var_names
  586. def _get_program_cache_key(feed, fetch_list):
  587. return str(_get_feed_fetch_var_names(feed, fetch_list))
  588. def _as_lodtensor(data, place, dtype=None):
  589. """
  590. Convert numpy.ndarray to Tensor, its only support Tensor without LoD information.
  591. For higher dimensional sequence data, please use LoDTensor directly.
  592. Examples:
  593. .. code-block:: python
  594. >>> import numpy as np
  595. >>> import paddle.base as base
  596. >>> place = base.CPUPlace()
  597. >>> exe = base.Executor(place)
  598. >>> data = np.array((100, 200, 300))
  599. >>> np_outs = map(lambda x: base.executor._as_lodtensor(x, place), data)
  600. Args:
  601. data(numpy.ndarray|list|tuple|scalar): a instance of array, scalar, list or tuple
  602. data(core.Place): the place of created tensor
  603. dtype(core.VarDesc.VarType|str): the expected data type of created tensor
  604. Returns:
  605. LoDTensor
  606. """
  607. # NOTE(zhiqiu): convert python builtin, like float, int, and list, to numpy ndarray
  608. if not isinstance(data, np.ndarray):
  609. assert (
  610. dtype is not None
  611. ), 'The dtype should be given when feed data is not np.ndarray'
  612. dtype = (
  613. convert_dtype(dtype)
  614. if isinstance(dtype, core.VarDesc.VarType)
  615. else dtype
  616. )
  617. if np.isscalar(data):
  618. data = np.array(data).astype(dtype)
  619. elif isinstance(data, (list, tuple)):
  620. data = np.array(data)
  621. if data.dtype == np.object_:
  622. raise TypeError(
  623. "\n\tFailed to convert input data to a regular ndarray :\n\t* Usually "
  624. "this means the input data contains nested lists with different lengths. "
  625. "Please consider using 'base.create_lod_tensor' to convert it to a LoD-Tensor."
  626. )
  627. data = data.astype(dtype)
  628. else:
  629. raise TypeError(
  630. f"Convert data of type {type(data)} to Tensor is not supported"
  631. )
  632. # convert numpy.ndarray to tensor
  633. tensor = core.LoDTensor()
  634. tensor.set(data, place)
  635. return tensor
  636. def _can_use_interpreter_core(program, place):
  637. compiled = isinstance(program, compiler.CompiledProgram) or isinstance(
  638. program._graph, compiler.CompiledProgram
  639. )
  640. if compiled:
  641. compiled_program = (
  642. program
  643. if isinstance(program, compiler.CompiledProgram)
  644. else program._graph
  645. )
  646. # Unsupported case 1: inference
  647. if compiled_program._is_inference:
  648. warnings.warn(
  649. "Standalone executor is not used for inference",
  650. UserWarning,
  651. )
  652. return False
  653. return True
  654. @lru_cache
  655. def _warning_once(msg):
  656. logging.warning(msg)
class FetchHandler:
    """Holds a dict of variables to fetch and a callback for the results.

    Subclasses override :meth:`handler`; see :meth:`help` for an example.
    """

    def __init__(self, var_dict=None, period_secs=60):
        # var_dict is mandatory despite the None default.
        assert var_dict is not None
        self.var_dict = var_dict
        # Only stored here; presumably the fetch interval in seconds used by
        # the caller -- confirm against the code that drives the handler.
        self.period_secs = period_secs

    def handler(self, res_dict):
        """Write the first element of every ndarray value in ``res_dict`` to stdout."""
        for key in res_dict:
            # Exact type match: ndarray subclasses are skipped.
            if type(res_dict[key]) is np.ndarray:
                sys.stdout.write(f"{key}[0]: {res_dict[key][0]} ")
        sys.stdout.write("\n")

    @staticmethod
    def help():
        """Print an example of subclassing ``FetchHandler``."""
        print(
            """
class FetchHandlerExample(FetchHandler):
    def handler(self, res_dict):
        print(res_dict["auc"])
        print("auc: {}, {}".format(res_dict["auc"], time.ctime()))
auc = Variable()
var_dict = {"auc": auc}
handler = FetchHandlerExample(var_dict=var_dict)
"""
        )
  680. class _StandaloneExecutor:
  681. def __init__(self, place, plan, scope):
  682. self._place = core.Place()
  683. self._place.set_place(place)
  684. self._plan = plan
  685. self._scope = scope
  686. self._new_exe = self._create_new_executor()
  687. def run(
  688. self, feed_names, return_numpy=True, enable_job_schedule_profiler=False
  689. ):
  690. """
  691. Args:
  692. feed_names(list): This parameter represents the input names of the model.
  693. fetch_list(list): This parameter represents the Tensors that need to be returned
  694. after the model runs. The default is None.
  695. return_numpy(bool): This parameter indicates whether convert the fetched Tensors
  696. (the Tensor specified in the fetch list) to numpy.ndarray. if it is False,
  697. the type of the return value is a list of :code:`LoDTensor`. The default is True.
  698. """
  699. tensors = self._new_exe.run(
  700. feed_names, enable_job_schedule_profiler
  701. )._move_to_list()
  702. if return_numpy:
  703. tensors = as_numpy(tensors, copy=True)
  704. if not get_flags("FLAGS_enable_pir_in_executor")[
  705. 'FLAGS_enable_pir_in_executor'
  706. ]:
  707. return _merge_tensors(tensors, self._plan.micro_batch_num())
  708. return tensors
  709. else:
  710. if self._plan.micro_batch_num() > 1:
  711. raise RuntimeError(
  712. "`merge_tensor` does not support when return_numpy is False."
  713. )
  714. return tensors
  715. def run_profile(self, feed_names) -> core.ProgramDesc:
  716. program_desc = self._new_exe.run_profile(feed_names)
  717. return program_desc
  718. def _create_new_executor(self):
  719. new_exe = core.StandaloneExecutor(self._place, self._plan, self._scope)
  720. return new_exe
  721. class _ExecutorCache:
  722. class _CachedData:
  723. def __init__(
  724. self,
  725. program,
  726. feed,
  727. fetch_list,
  728. feed_var_name,
  729. fetch_var_name,
  730. place,
  731. scope,
  732. ):
  733. self.program = program
  734. self.feed = feed
  735. self.fetch_list = fetch_list
  736. self.feed_var_name = feed_var_name
  737. self.fetch_var_name = fetch_var_name
  738. self.place = place
  739. self.scope = scope
  740. # NOTE(Ruibiao): Not all changeable item is considered for key at present,
  741. # ONLY: program, feed, and fetch_list
  742. if isinstance(self.program, compiler.CompiledProgram):
  743. if not self.program._program:
  744. # The program holds no _program, maybe it is constructed by graph.
  745. # Convert graph to program in order to generate key.
  746. self.program._program = framework.IrGraph(
  747. self.program._graph
  748. ).to_program()
  749. self.key = hash(
  750. _get_strong_program_cache_key_for_new_exe(
  751. self.program._program,
  752. self.scope,
  753. self.feed,
  754. self.fetch_list,
  755. )
  756. )
  757. else:
  758. self.key = hash(
  759. _get_strong_program_cache_key_for_new_exe(
  760. self.program, self.scope, self.feed, self.fetch_list
  761. )
  762. )
  763. def __eq__(self, other):
  764. return (
  765. isinstance(other, _ExecutorCache._CachedData)
  766. and self.key == other.key
  767. )
  768. def __hash__(self):
  769. return self.key
  770. def __init__(self):
  771. # NOTE(Ruibiao): Wrap the lru_cache in constructor so that the cache is local to
  772. # the _ExecutorCache instance, otherwise a global cache may not be released after
  773. # the Executor instance deleted
  774. self._get_cached_program_and_executor = lru_cache(maxsize=8)(
  775. self._get_program_and_executor
  776. )
  777. self._get_cached_program_and_executor_pir_mode = lru_cache(maxsize=8)(
  778. self._get_pir_program_and_executor
  779. )
  780. def clear(self):
  781. self._get_cached_program_and_executor.cache_clear()
  782. def get_program_and_executor(
  783. self,
  784. program,
  785. feed,
  786. fetch_list,
  787. feed_var_name,
  788. fetch_var_name,
  789. place,
  790. scope,
  791. ):
  792. return self._get_cached_program_and_executor(
  793. self._CachedData(
  794. program,
  795. feed,
  796. fetch_list,
  797. feed_var_name,
  798. fetch_var_name,
  799. place,
  800. scope,
  801. )
  802. )
  803. def _get_program_and_executor(self, cached_data):
  804. program = cached_data.program
  805. inner_program = (
  806. program._program
  807. if isinstance(program, compiler.CompiledProgram)
  808. else program
  809. )
  810. feed = cached_data.feed
  811. fetch_list = cached_data.fetch_list
  812. feed_var_name = cached_data.feed_var_name
  813. fetch_var_name = cached_data.fetch_var_name
  814. place = cached_data.place
  815. scope = cached_data.scope
  816. # To apply IR pass, compile the Program to IrGraph and convert it back to Program
  817. if isinstance(program, compiler.CompiledProgram) or isinstance(
  818. program._graph, compiler.CompiledProgram
  819. ):
  820. compiled_program = (
  821. program
  822. if isinstance(program, compiler.CompiledProgram)
  823. else program._graph
  824. )
  825. build_strategy = compiled_program._build_strategy
  826. # print(f"Program before convert:\n {inner_program}", flush=True)
  827. use_cuda_graph = False
  828. # When using cuda graph, the cuda graph preparation logic in PE is not
  829. # executed, but it is processed in the constructor of new executor.
  830. if (
  831. build_strategy is not None
  832. and build_strategy.allow_cuda_graph_capture
  833. ):
  834. use_cuda_graph = True
  835. build_strategy.allow_cuda_graph_capture = False
  836. set_flags({"FLAGS_new_executor_use_cuda_graph": True})
  837. compiled_program._compile(scope, place)
  838. if use_cuda_graph:
  839. build_strategy.allow_cuda_graph_capture = True
  840. ir_graph = framework.IrGraph(compiled_program._graph)
  841. converted_program = ir_graph.to_program()
  842. if hasattr(inner_program, 'lr_scheduler'):
  843. converted_program.lr_scheduler = inner_program.lr_scheduler
  844. inner_program = converted_program
  845. # print(f"Program after convert:\n {inner_program}", flush=True)
  846. else:
  847. build_strategy = None
  848. from paddle.incubate.autograd import prim2orig, prim_enabled
  849. if prim_enabled() and program == default_main_program():
  850. prim2orig()
  851. inner_program = program
  852. program = _add_feed_fetch_ops(
  853. program=inner_program,
  854. feed=feed,
  855. fetch_list=fetch_list,
  856. feed_var_name=feed_var_name,
  857. fetch_var_name=fetch_var_name,
  858. use_fetch_v2=True,
  859. )
  860. # standalone executor will apply buffer_shared_inplace_pass and
  861. # inplace_addto_op_pass to program according to build_strategy
  862. enable_inplace = (
  863. True
  864. if build_strategy is None or build_strategy.enable_inplace
  865. else False
  866. )
  867. enable_addto = (
  868. True
  869. if build_strategy is not None and build_strategy.enable_addto
  870. else False
  871. )
  872. if get_flags('FLAGS_enable_pir_in_executor')[
  873. 'FLAGS_enable_pir_in_executor'
  874. ]:
  875. # todo(phlrain), skip inplace add addto pass in new IR
  876. enable_inplace = False
  877. enable_addto = False
  878. if enable_inplace or enable_addto:
  879. # inplace should skip feed and fetch var
  880. skip_var_names = _get_feed_fetch_var_names(feed, fetch_list)
  881. _apply_inplace_addto_pass(
  882. program, enable_inplace, enable_addto, skip_var_names
  883. )
  884. new_program = program.clone()
  885. if (
  886. new_program._pipeline_opt
  887. and "standalone_opt" in new_program._pipeline_opt
  888. ):
  889. from paddle.distributed.passes.pipeline_scheduler_pass import (
  890. apply_pass,
  891. )
  892. standalone_opt = new_program._pipeline_opt["standalone_opt"]
  893. pass_name = standalone_opt["schedule_mode"]
  894. plan = apply_pass(
  895. new_program, new_program, pass_name, standalone_opt
  896. )
  897. else:
  898. default_job = core.Job("default")
  899. if get_flags("FLAGS_enable_pir_in_executor")[
  900. 'FLAGS_enable_pir_in_executor'
  901. ]:
  902. # if enables distributed training with prim mechanism (prim is behind of distributed)
  903. # step 1: translate program to pir program.
  904. # step 2: decompose PHI ops in pir program into prim ops.
  905. # When decomposing backward ops, the grad_var_to_var in distributed context is needed to finding corresponding forward op.
  906. if (
  907. os.getenv("FLAGS_enable_prim_after_distribute")
  908. in ['True', 'true', '1']
  909. and new_program._need_decomp
  910. ):
  911. (
  912. pir_program,
  913. param_mapping,
  914. ) = translate_to_pir_with_param_map(new_program.desc)
  915. from paddle.decomposition import decomp
  916. decomp.decompose_pir_program(
  917. pir_program, param_mapping, new_program._grad_var_to_var
  918. )
  919. if in_cinn_mode():
  920. apply_cinn_pass(pir_program)
  921. type_to_program = {"default": pir_program}
  922. else:
  923. type_to_program = {
  924. "default": translate_to_pir(new_program.desc)
  925. }
  926. else:
  927. type_to_program = {"default": new_program.desc}
  928. plan = core.Plan([default_job], type_to_program)
  929. if (
  930. new_program._pass_opt
  931. and "pass_list" in new_program._pass_opt
  932. and len(new_program._pass_opt['pass_list']) > 0
  933. ):
  934. pm = pir.PassManager()
  935. for p in new_program._pass_opt['pass_list']:
  936. # Temporary implementation, it will be refined when auto_parallel refactored
  937. if p == 'eliminate_transpose':
  938. from paddle.distributed.auto_parallel.static.pir_pass import (
  939. eliminate_transpose_by_reshape,
  940. )
  941. for job_type in plan.job_types():
  942. ir_program = plan.ir_program(job_type)
  943. eliminate_transpose_by_reshape(ir_program)
  944. else:
  945. pm.add_pass(p, {})
  946. for job_type in plan.job_types():
  947. ir_program = plan.ir_program(job_type)
  948. pm.run(ir_program)
  949. new_exe = _StandaloneExecutor(place, plan, scope)
  950. return new_program, new_exe
  951. def get_pir_program_and_executor(
  952. self,
  953. program,
  954. feed,
  955. fetch_list,
  956. feed_var_name,
  957. fetch_var_name,
  958. place,
  959. scope,
  960. ):
  961. return self._get_cached_program_and_executor_pir_mode(
  962. self._CachedData(
  963. program,
  964. feed,
  965. fetch_list,
  966. feed_var_name,
  967. fetch_var_name,
  968. place,
  969. scope,
  970. )
  971. )
  972. def _get_pir_program_and_executor(self, cached_data):
  973. program = cached_data.program
  974. feed = cached_data.feed
  975. fetch_list = cached_data.fetch_list
  976. feed_var_name = cached_data.feed_var_name
  977. fetch_var_name = cached_data.fetch_var_name
  978. place = cached_data.place
  979. scope = cached_data.scope
  980. _add_pir_fetch_ops(
  981. program, fetch_list=fetch_list, fetch_var_name=fetch_var_name
  982. )
  983. default_job = core.Job("default")
  984. type_to_program = {"default": program}
  985. plan = core.Plan([default_job], type_to_program)
  986. new_exe = _StandaloneExecutor(place, plan, scope)
  987. data_op_infos = []
  988. global_block = program.global_block()
  989. for op in global_block.ops:
  990. if op.name() == 'pd_op.data':
  991. feed_target_name = op.attrs()["name"]
  992. var_type = paddle_type_to_proto_type[op.attrs()["dtype"]]
  993. var_shape = op.attrs()["shape"]
  994. tup = (
  995. feed_target_name,
  996. var_type,
  997. var_shape,
  998. op.result(0).persistable,
  999. )
  1000. data_op_infos.append(tup)
  1001. from paddle.decomposition import decomp
  1002. if core._enable_dist_prim_all():
  1003. with decomp.prim_guard():
  1004. decomp.decompose_dist_program(program)
  1005. if in_cinn_mode():
  1006. apply_cinn_pass(program)
  1007. return program, new_exe, data_op_infos
  1008. class Executor:
  1009. """
  1010. :api_attr: Static Graph
  1011. An Executor in Python, supports single/multiple-GPU running,
  1012. and single/multiple-CPU running.
  1013. Args:
  1014. place(paddle.CPUPlace()|paddle.CUDAPlace(n)|str|None): This parameter represents
  1015. which device the executor runs on. When this parameter is None, PaddlePaddle
  1016. will set the default device according to its installation version. If Paddle
  1017. is CPU version, the default device would be set to `CPUPlace()` . If Paddle is
  1018. GPU version, the default device would be set to `CUDAPlace(0)` . Default is None.
  1019. If ``place`` is string, it can be ``cpu``, and ``gpu:x``, where ``x``
  1020. is the index of the GPUs. Note: users only pass one Place or None to initialize
  1021. Executor when using multiple-cards. Other APIs will override the cards. See
  1022. `document for multiple-cards <https://www.paddlepaddle.org.cn/documentation/docs/en/develop/guides/01_paddle2.0_introduction/update_en.html#stand-alone-multi-card-launch>`_
  1023. Returns:
  1024. Executor
  1025. Examples:
  1026. .. code-block:: python
  1027. >>> import paddle
  1028. >>> import numpy
  1029. >>> import os
  1030. >>> # Executor is only used in static graph mode
  1031. >>> paddle.enable_static()
  1032. >>> # Set place explicitly.
  1033. >>> # use_cuda = True
  1034. >>> # place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
  1035. >>> # exe = paddle.static.Executor(place)
  1036. >>> # If you don't set place, PaddlePaddle sets the default device.
  1037. >>> exe = paddle.static.Executor()
  1038. >>> train_program = paddle.static.Program()
  1039. >>> startup_program = paddle.static.Program()
  1040. >>> with paddle.static.program_guard(train_program, startup_program):
  1041. ... data = paddle.static.data(name='X', shape=[None, 1], dtype='float32')
  1042. ... hidden = paddle.static.nn.fc(data, 10)
  1043. ... loss = paddle.mean(hidden)
  1044. ... paddle.optimizer.SGD(learning_rate=0.01).minimize(loss)
  1045. ...
  1046. >>> # Run the startup program once and only once.
  1047. >>> # Not need to optimize/compile the startup program.
  1048. >>> exe.run(startup_program)
  1049. >>> # Run the main program directly without compile.
  1050. >>> x = numpy.random.random(size=(10, 1)).astype('float32')
  1051. >>> loss_data, = exe.run(train_program, feed={"X": x}, fetch_list=[loss.name])
  1052. >>> # Or, compiled the program and run. See `CompiledProgram`
  1053. >>> # for more details.
  1054. >>> compiled_prog = paddle.static.CompiledProgram(
  1055. ... train_program)
  1056. >>> loss_data, = exe.run(compiled_prog, feed={"X": x}, fetch_list=[loss.name])
  1057. """
  1058. def __init__(self, place=None):
  1059. if place is None:
  1060. expected_place = framework._current_expected_place_()
  1061. self.place = expected_place
  1062. else:
  1063. self.place = framework._get_paddle_place(place)
  1064. self.program_caches = {}
  1065. self.ctx_caches = {}
  1066. self.trainer_caches = {}
  1067. self.scope_caches = {}
  1068. self.micro_scope_cache = {}
  1069. self.var_caches = {}
  1070. self.pruned_program_caches = {}
  1071. p = core.Place()
  1072. p.set_place(self.place)
  1073. self._default_executor = core.Executor(p)
  1074. self._closed = False
  1075. self.pruned_program_scope_caches = {}
  1076. self._prepare_to_run_called = False
  1077. self._auto_checkpoint_name = unique_name.generate(
  1078. "__auto_checkpoint_executor__"
  1079. )
  1080. self._executor_cache = _ExecutorCache()
  1081. self._fleet_executor = None
  1082. # TODO(liyurui): This option will be removed and always true when the functionality
  1083. # of fleet executor with standalone executor is ready.
  1084. self._fleet_executor_with_standalone = False
  1085. self.op_role_key = core.op_proto_and_checker_maker.kOpRoleAttrName()
  1086. self.enable_job_schedule_profiler = False
  1087. def _is_optimizer_op(self, op):
  1088. return self.op_role_key in op.attr_names and int(
  1089. op.all_attrs()[self.op_role_key]
  1090. ) & int(core.op_proto_and_checker_maker.OpRole.Optimize)
  1091. def __del__(self):
  1092. # NOTE(Ruibiao): The manually call of clear is required. Because in Python, executor_cache
  1093. # may not immediately destructed after Executor instance deleted (so does not the _StandaloneExecutor),
  1094. # that brings errors to mkl-dnn unit tests (see ClearMKLDNNCache in interpretercore.cc for why).
  1095. self.close()
  1096. self._executor_cache.clear()
  1097. def _get_scope_cache(self, program_cache_key):
  1098. return self.scope_caches.get(program_cache_key, None)
  1099. def _get_ctx_cache(self, program_cache_key):
  1100. return self.ctx_caches.get(program_cache_key, None)
  1101. def _get_trainer_cache(self, program_cache_key):
  1102. return self.trainer_caches.get(program_cache_key, None)
  1103. def _get_program_cache(self, program_cache_key):
  1104. return self.program_caches.get(program_cache_key, None)
  1105. def _add_program_cache(self, program_cache_key, program):
  1106. self.program_caches[program_cache_key] = program
  1107. def _get_pruned_program_cache(self, program_cache_key):
  1108. return self.pruned_program_caches.get(program_cache_key, None)
  1109. def _add_pruned_program_cache(self, program_cache_key, program):
  1110. self.pruned_program_caches[program_cache_key] = program
  1111. def _get_pruned_program_scope_cache(self, program_cache_key):
  1112. return self.pruned_program_scope_caches.get(program_cache_key, None)
  1113. def _add_pruned_program_scope_cache(self, program_cache_key, program):
  1114. self.pruned_program_scope_caches[program_cache_key] = program
  1115. def _add_ctx_cache(self, ctx_cache_key, ctx):
  1116. self.ctx_caches[ctx_cache_key] = ctx
  1117. def _add_trainer_cache(self, trainer_cache_key, ctx):
  1118. self.trainer_caches[trainer_cache_key] = ctx
  1119. def _add_scope_cache(self, scope_cache_key, scope):
  1120. self.scope_caches[scope_cache_key] = scope
  1121. def _add_micro_scopes_cache(self, program_cache_key, micro_scopes: list):
  1122. self.micro_scope_cache[program_cache_key] = micro_scopes
  1123. def _get_micro_scopes_cache(self, program_cache_key):
  1124. return self.micro_scope_cache.get(program_cache_key, None)
  1125. def _log_force_set_program_cache(self, use_program_cache):
  1126. _warning_once(
  1127. f"use_program_cache is force set to {use_program_cache} by FLAGS_FORCE_USE_PROGRAM_CACHE"
  1128. )
def _feed_data(self, program, feed, feed_var_name, scope):
    """Copy the values in ``feed`` into ``scope`` for each leading feed op.

    Args:
        program: program whose global block is scanned for ``feed`` ops.
        feed(dict): maps a feed op's output variable name to its value.
        feed_var_name(str): name of the feed holder variable, used as the
            target when PIR is disabled.
        scope: scope that receives the feed variables.
    """
    # feed var to framework
    global_block = program.global_block()
    for op in global_block.ops:
        if op.desc.type() == 'feed':
            feed_target_name = op.desc.output('Out')[0]
            cur_feed = feed[feed_target_name]
            var = global_block.var(feed_target_name)
            # STRINGS variables skip tensor conversion and the shape/dtype check.
            if var.dtype != core.VarDesc.VarType.STRINGS:
                if not isinstance(cur_feed, core.LoDTensor):
                    cur_feed = _as_lodtensor(
                        cur_feed, self.place, var.dtype
                    )
                check_feed_shape_type(var, cur_feed)
            idx = op.desc.attr('col')
            pir_flag_name = 'FLAGS_enable_pir_in_executor'
            if get_flags(pir_flag_name)[pir_flag_name]:
                # With PIR enabled the value is stored under the target
                # variable's own name.
                core.set_feed_variable(
                    scope, cur_feed, feed_target_name, idx
                )
            else:
                micro_cur_feed = [cur_feed]
                num_micro_batch = 1
                if (
                    program._pipeline_opt
                    and "standalone_opt" in program._pipeline_opt
                ):
                    num_micro_batch = program._pipeline_opt[
                        "standalone_opt"
                    ]["num_micro_batches"]
                    # `shape` may be a method or an attribute depending on
                    # the type of cur_feed.
                    batch_size = (
                        cur_feed.shape()[0]
                        if callable(cur_feed.shape)
                        else cur_feed.shape[0]
                    )
                    assert batch_size % num_micro_batch == 0
                    # Split along axis 0 into equal micro batches.
                    micro_cur_feed = np.split(
                        np.array(cur_feed), num_micro_batch, 0
                    )
                for i in range(num_micro_batch):
                    # np.split yields ndarrays, so they are converted back
                    # to tensors; the unsplit value is passed through as is.
                    micro_feed = (
                        _as_lodtensor(
                            micro_cur_feed[i], self.place, var.dtype
                        )
                        if num_micro_batch > 1
                        else micro_cur_feed[i]
                    )
                    # Micro batch i of column idx goes to slot
                    # idx * num_micro_batch + i of the feed holder.
                    core.set_feed_variable(
                        scope,
                        micro_feed,
                        feed_var_name,
                        idx * num_micro_batch + i,
                    )
        else:
            # Scanning stops at the first non-feed op; presumably feed ops
            # are always placed at the front of the block -- confirm.
            break
  1184. def _pir_feed_data(self, program, feed, scope, data_op_infos):
  1185. # feed var to framework
  1186. feed_target_names = set()
  1187. for data_op_info in data_op_infos:
  1188. feed_target_name = data_op_info[0]
  1189. feed_target_names.add(feed_target_name)
  1190. var_type = data_op_info[1]
  1191. var_shape = data_op_info[2]
  1192. is_persistable = data_op_info[3]
  1193. if feed_target_name not in feed.keys() and is_persistable:
  1194. # If the feed_target_name is not in feed list, but is persistable, maybe it is a optimizer param
  1195. # and don't need feed data.
  1196. continue
  1197. cur_feed = feed[feed_target_name]
  1198. if not isinstance(cur_feed, core.LoDTensor):
  1199. cur_feed = _as_lodtensor(cur_feed, self.place, var_type)
  1200. pir_check_feed_shape_type(
  1201. cur_feed, feed_target_name, var_shape, var_type
  1202. )
  1203. # the last arg of set_feed_variable has no effect in pir, we pass 0 by default.
  1204. core.set_feed_variable(scope, cur_feed, feed_target_name, 0)
  1205. # pop variable which is not found in program
  1206. for feed_name in list(feed.keys()):
  1207. if feed_name not in feed_target_names:
  1208. feed.pop(feed_name)
  1209. warnings.warn(
  1210. "The value %s is not found in program. It is not declared or is pruned."
  1211. % feed_name
  1212. )
  1213. def _fetch_data(self, fetch_list, fetch_var_name, scope):
  1214. outs = [
  1215. core.get_fetch_variable(scope, fetch_var_name, i)
  1216. for i in range(len(fetch_list))
  1217. ]
  1218. return outs
  1219. @classmethod
  1220. def _split_optimize_ops_in_fetch_list(cls, fetch_list):
  1221. """
  1222. Split optimize_ops from fetch_list, which provided to specify program pruning.
  1223. Args:
  1224. fetch_list(list): The original fetch_list.
  1225. Possible types of fetch_list are:
  1226. fetch_list = ['loss']
  1227. fetch_list = [[sgd, sgd], 'loss']
  1228. fetch_list = [([sgd, sgd], [(param, grad)]), 'loss']
  1229. Returns:
  1230. optimize_ops(list): The optimize operators splited from fetch_list.
  1231. fetch_list(list): The updated fetch_list which does not contain optimize operators.
  1232. """
  1233. _optimize_ops = []
  1234. _fetch_list = []
  1235. def _get_targets(_optimize_ops, _fetch_list, item):
  1236. if isinstance(item, Operator):
  1237. if item._is_optimize_op():
  1238. _optimize_ops.append(item)
  1239. else:
  1240. raise TypeError(
  1241. "The operator in fetch_list is not an optimize_op"
  1242. )
  1243. elif isinstance(item, (Variable, str)):
  1244. _fetch_list.append(item)
  1245. else:
  1246. raise TypeError(
  1247. "The item in fetch_list should be str, variable or optimize_op, but received %s.",
  1248. type(item),
  1249. )
  1250. for index, item in enumerate(fetch_list):
  1251. # NOTE(zhiqiu): to support (optimizer_ops, param_and_grads) and optimizer_ops in fetch_list
  1252. # we should handle tuple and list in fetch_list.
  1253. # TODO(zhiqiu): find a better way to handle that.
  1254. if isinstance(item, list):
  1255. for i in item:
  1256. _get_targets(_optimize_ops, _fetch_list, i)
  1257. elif isinstance(item, tuple):
  1258. if not isinstance(item[0], (list, tuple)):
  1259. raise TypeError(
  1260. f"Requires fetch_list[{index}][0] shall be one of (list, tuple) when type(fetch_list[{index}]) is `tuple`, but received fetch_list[{index}][0]'s type is `{type(item[0]).__name__}`."
  1261. )
  1262. for i in item[0]:
  1263. _get_targets(_optimize_ops, _fetch_list, i)
  1264. else:
  1265. _get_targets(_optimize_ops, _fetch_list, item)
  1266. return _fetch_list, _optimize_ops
  @classmethod
  def _prune_program(
      cls, program, feed=None, fetch_list=None, optimize_ops=None
  ):
      """
      Prune operators and variables which are not needed to generate
      :code:`fetch_list` and optimize operators.
      Prune operators and variables which are needed
      to generate variables to be fed.

      Notes: This is a very low level API. Users should not use this API
      directly.

      Args:
          program(Program|CompiledProgram): the origin program
          feed(list|dict): feed dict, or a list/tuple of feed dicts. Only
              the key names are used.
          fetch_list(list|Variable): A list of variables need to be fetched.
              ``None`` is treated as an empty list and a single non-sequence
              item is wrapped into a one-element list.
          optimize_ops(list[Operator]): A list of optimizer operators. When
              it is empty or ``None``, every operator of the program whose
              ``_is_optimize_op()`` is true is used instead. The caller's
              list is never modified.

      Returns:
          Program: A new, pruned program. For a ``CompiledProgram`` the same
          object is returned with its ``_program`` and ``_graph`` replaced
          and ``_compiled`` reset; ``None`` is returned when the
          ``CompiledProgram`` holds no ``_program``.
      """
      compiled = isinstance(program, compiler.CompiledProgram)
      if compiled:
          if program._program:
              origin_program = program._program
          else:
              warnings.warn(
                  "The program holds no _program, maybe it is constructed by graph, which can't be pruned yet."
              )
              return
      else:
          origin_program = program

      # Collect the names of all fed variables.
      if isinstance(feed, dict):
          feed_names = list(feed.keys())
      elif isinstance(feed, (list, tuple)):
          feed_names = [name for each in feed for name in each.keys()]
      else:
          feed_names = []

      # Normalise fetch_list into a fresh list so the defaults (None) and a
      # single item are accepted and the caller's object is left untouched.
      if fetch_list is None:
          fetch_targets = []
      elif isinstance(fetch_list, (list, tuple)):
          fetch_targets = list(fetch_list)
      else:
          fetch_targets = [fetch_list]

      # if optimize_ops is [] (or None), all optimize ops in the program is used.
      # Work on a private list instead of appending to the argument.
      optimize_targets = list(optimize_ops) if optimize_ops else []
      if not optimize_targets:
          optimize_targets = [
              op
              for block in origin_program.blocks
              for op in block.ops
              if op._is_optimize_op()
          ]

      targets = fetch_targets + optimize_targets
      pruned_program = origin_program._prune_with_input(feed_names, targets)

      if compiled:
          # for compiled program, update the underlying program, re-generate graph,
          # and reset the flag so it can be compiled again.
          program._program = pruned_program
          program._graph = core.Graph(pruned_program.desc)
          program._compiled = False
      else:
          program = pruned_program

      return program
  @classmethod
  def _update_feed(cls, program, feed):
      """
      Drop from ``feed`` every entry whose name the program's global block
      does not contain, warning once per dropped name.

      Notes: This is a very low level API. Users should not use this API
      directly.

      Args:
          program(Program|CompiledProgram): the pruned program.
          feed(list|dict): feed dict, or a list/tuple of feed dicts. The
              dict(s) are modified in place.

      Returns:
          feed:(list|dict) the same ``feed`` object after the removal. It is
          returned untouched when a ``CompiledProgram`` holds no ``_program``.
      """
      if isinstance(program, compiler.CompiledProgram):
          if not program._program:
              warnings.warn(
                  "The program holds no _program, maybe it is constructed by graph."
              )
              return feed
          block = program._program.global_block()
      else:
          block = program.global_block()

      if isinstance(feed, dict):
          # Iterate over a snapshot of the keys because entries are removed.
          for name in tuple(feed.keys()):
              if block.has_var(name):
                  continue
              feed.pop(name)
              warnings.warn(
                  "The variable %s is not found in program. It is not declared or is pruned."
                  % name
              )
      elif isinstance(feed, (list, tuple)):
          for per_device_feed in feed:
              for name in tuple(per_device_feed.keys()):
                  if block.has_var(name):
                      continue
                  per_device_feed.pop(name)
                  warnings.warn(
                      "The variable %s is not found in program. It is not declared or is pruned."
                      % name
                  )
      return feed
  1361. '''
  1362. TODO(typhoonzero): Define "no longer use" meaning? Can user create
  1363. a new Executor for the same program and run?
  1364. TODO(panyx0718): Why ParallelExecutor doesn't have close?
  1365. '''
  def close(self):
      """
      Close the executor. This interface is used for distributed training (PServers mode).
      This executor can not be used after calling the interface, because
      this interface releases resources associated with the current Trainer.

      Calling it again on an already closed executor does nothing.

      Returns:
          None

      Examples:
          .. code-block:: python

              >>> import paddle

              >>> cpu = paddle.CPUPlace()
              >>> exe = paddle.static.Executor(cpu)
              >>> # execute training or testing
              >>> exe.close()
      """
      if self._closed:
          return

      self._closed = True
      # Release every cached trainer before closing the underlying executor.
      for cached_trainer in self.trainer_caches.values():
          self._default_executor.release_trainer(cached_trainer)
          del cached_trainer
      self._default_executor.close()
  def flush(self):
      """
      flush all trainer param to root_scope

      Every cached trainer is released through the default executor and the
      trainer cache is then emptied. Nothing happens on a closed executor.
      """
      if not self._closed:
          for cached_trainer in self.trainer_caches.values():
              self._default_executor.release_trainer(cached_trainer)
              del cached_trainer
          self.trainer_caches.clear()
  def run(
      self,
      program=None,
      feed=None,
      fetch_list=None,
      feed_var_name='feed',
      fetch_var_name='fetch',
      scope=None,
      return_numpy=True,
      use_program_cache=False,
      use_prune=False,
  ):
      """
      Run the specified :code:`Program` or :code:`CompiledProgram`. It should be noted that the executor
      will execute all the operators in :code:`Program` or :code:`CompiledProgram` without pruning some
      operators of the :code:`Program` or :code:`CompiledProgram` according to fetch_list. And you could
      specify the scope to store the :code:`Tensor` during the executor running if the scope
      is not set, the executor will use the global scope, i.e. :code:`paddle.static.global_scope()`.

      Args:
          program(Program|CompiledProgram): This parameter represents the :code:`Program` or
              :code:`CompiledProgram` to be executed. If this parameter is not provided, that
              parameter is None, the program will be set to :code:`paddle.static.default_main_program()`.
              The default is None.
          feed(list|dict): This parameter represents the input Tensors of the model.
              If it is single card training, the feed is dict type, and if it is multi-card
              training, the parameter feed can be dict or list of Tensors. If the
              parameter type is dict, the data in the feed will be split and sent to
              multiple devices (CPU/GPU), that is to say, the input data will be evenly
              sent to different devices, so you should make sure the number of samples of
              the current mini-batch must be greater than the number of places;
              if the parameter type is list, those data are copied directly to each device,
              so the length of this list should be equal to the number of places.
              The default is None.
          fetch_list(list): This parameter represents the Tensors that need to be returned
              after the model runs. The default is None.
          feed_var_name(str): This parameter represents the name of the input Tensor of
              the feed operator. The default is "feed".
          fetch_var_name(str): This parameter represents the name of the output Tensor of
              the fetch operator. The default is "fetch".
          scope(Scope): the scope used to run this program, you can switch
              it to different scope. default is :code:`paddle.static.global_scope()`
          return_numpy(bool): This parameter indicates whether convert the fetched Tensors
              (the Tensor specified in the fetch list) to numpy.ndarray. if it is False,
              the type of the return value is a list of :code:`LoDTensor`. The default is True.
          use_program_cache(bool): This parameter indicates whether the input :code:`Program` is cached.
              If the parameter is True, the model may run faster in the following cases:
              the input program is :code:`paddle.static.Program`, and the parameters(program, feed Tensor name
              and fetch_list Tensor) of this interface remains unchanged during running.
              The default is False. The environment variable
              ``FLAGS_FORCE_USE_PROGRAM_CACHE``, when set, overrides this argument.
          use_prune(bool): This parameter indicates whether the input :code:`Program` will be pruned.
              If the parameter is True, the program will be pruned according to the given feed and fetch_list,
              which means the operators and variables in program that generate :code:`feed` and are not
              needed to generate :code:`fetch_list` will be pruned. The default is False, which means the
              program will not be pruned and all the operators and variables will be executed during running.
              Note that if the tuple returned from :code:`Optimizer.minimize()` is passed to :code:`fetch_list`,
              :code:`use_prune` will be overridden to True, and the program will be pruned.

      Returns:

          List: The fetched result list.

      Examples:

          .. code-block:: python
              :name: code-example-1

              >>> import paddle
              >>> import numpy

              >>> # First create the Executor.
              >>> paddle.enable_static()
              >>> place = paddle.CPUPlace() # paddle.CUDAPlace(0)
              >>> exe = paddle.static.Executor(place)

              >>> data = paddle.static.data(name='X', shape=[None, 1], dtype='float32')
              >>> hidden = paddle.static.nn.fc(data, 10)
              >>> loss = paddle.mean(hidden)
              >>> adam = paddle.optimizer.Adam()
              >>> adam.minimize(loss)
              >>> i = paddle.zeros(shape=[1], dtype='int64')
              >>> array = paddle.tensor.array_write(x=loss, i=i)

              >>> # Run the startup program once and only once.
              >>> exe.run(paddle.static.default_startup_program())

              >>> x = numpy.random.random(size=(10, 1)).astype('float32')
              >>> loss_val, array_val = exe.run(feed={'X': x},
              ...                               fetch_list=[loss.name, array.name])
              >>> print(array_val)
              >>> # doctest: +SKIP("Random output")
              [array(0.16870381, dtype=float32)]
              >>> # doctest: -SKIP

          .. code-block:: python
              :name: code-example-2

              >>> # doctest: +REQUIRES(env:GPU)
              >>> import paddle
              >>> import numpy as np

              >>> # First create the Executor.
              >>> paddle.enable_static()
              >>> place = paddle.CUDAPlace(0)
              >>> exe = paddle.static.Executor(place)

              >>> data = paddle.static.data(name='X', shape=[None, 1], dtype='float32')
              >>> class_dim = 2
              >>> prediction = paddle.static.nn.fc(data, class_dim)
              >>> loss = paddle.mean(prediction)
              >>> adam = paddle.optimizer.Adam()
              >>> adam.minimize(loss)

              >>> # Run the startup program once and only once.
              >>> exe.run(paddle.static.default_startup_program())
              >>> build_strategy = paddle.static.BuildStrategy()
              >>> binary = paddle.static.CompiledProgram(
              ...     paddle.static.default_main_program(), build_strategy=build_strategy)
              >>> batch_size = 6
              >>> x = np.random.random(size=(batch_size, 1)).astype('float32')

              >>> prediction, = exe.run(binary,
              ...                       feed={'X': x},
              ...                       fetch_list=[prediction.name])
              >>> # If the user uses two GPU cards to run this python code, the printed result will be
              >>> # (6, class_dim). The first dimension value of the printed result is the batch_size.
              >>> print("The prediction shape: {}".format(
              ...     np.array(prediction).shape))
              The prediction shape: (6, 2)

              >>> print(prediction)
              >>> # doctest: +SKIP("Random output")
              [[-0.37789783 -0.19921964]
               [-0.3577645  -0.18863106]
               [-0.24274671 -0.12814042]
               [-0.24635398 -0.13003758]
               [-0.49232286 -0.25939852]
               [-0.44514108 -0.2345845 ]]
              >>> # doctest: -SKIP

      """
      # Temporary FLAGS, just for testing the performance of program cache
      forced_cache_setting = os.environ.get(
          'FLAGS_FORCE_USE_PROGRAM_CACHE', None
      )
      if forced_cache_setting is not None:
          enabling_values = (1, '1', True, 'True', 'true')
          use_program_cache = forced_cache_setting in enabling_values
          self._log_force_set_program_cache(use_program_cache)

      # Arguments that both execution paths receive unchanged.
      shared_kwargs = {
          'feed': feed,
          'fetch_list': fetch_list,
          'feed_var_name': feed_var_name,
          'fetch_var_name': fetch_var_name,
          'scope': scope,
          'return_numpy': return_numpy,
      }

      if in_pir_mode():
          res = self._run_pir_impl(program=program, **shared_kwargs)
      else:
          # do type promotion if necessary
          program = process_type_promotion(program)
          res = self._run_impl(
              program=program,
              use_program_cache=use_program_cache,
              use_prune=use_prune,
              **shared_kwargs,
          )

      core.update_autotune_status()
      return res
  def _run_impl(
      self,
      program,
      feed,
      fetch_list,
      feed_var_name,
      fetch_var_name,
      scope,
      return_numpy,
      use_program_cache,
      use_prune,
  ):
      """
      Execute ``program`` for the non-PIR path of :code:`run`.

      The steps, in order:

      1. Fall back to ``default_main_program()`` when ``program`` is None
         and normalise ``fetch_list`` with ``_check_fetch_list``.
      2. Programs carrying ``_pipeline_opt`` (and not using the new
         executor) are dispatched to the fleet executor or to
         ``_run_pipeline``, unless a ``"startup_program"`` entry replaces
         the program.
      3. Programs carrying ``_heter_pipeline_opt`` replace
         ``self._default_executor`` with one bound to the heter place.
      4. Optimizer ops found in ``fetch_list`` force ``use_prune``; a pruned
         program is built (or taken from the pruned-program cache) and the
         feed is trimmed accordingly.
      5. When ``_can_use_interpreter_core`` accepts the program it is run
         through the executor cache; otherwise the program is compiled and
         run through ``_run_inference``.

      Args:
          program(Program|CompiledProgram|None): program to run.
          feed(dict|list|tuple|None): feed data; a list/tuple must hold
              exactly one dict on the interpreter-core path.
          fetch_list: targets to fetch, see ``_check_fetch_list``.
          feed_var_name(str): name of the feed variable.
          fetch_var_name(str): name of the fetch variable.
          scope(Scope|None): scope to run in; ``global_scope()`` when None.
          return_numpy(bool): forwarded to the executor's ``run``.
          use_program_cache(bool): forwarded to ``_run_pipeline``.
          use_prune(bool): whether to prune the program before running.

      Returns:
          Whatever the selected executor returns for the fetch targets.

      Raises:
          RuntimeError: the executor has been closed.
          TypeError: ``feed`` is not a dict on the interpreter-core path.
          ValueError: with pruning on the legacy path, a data variable that
              needs feeding is missing from ``feed``.
      """
      if self._closed:
          raise RuntimeError("Attempted to use a closed Executor")

      use_default_main_program = program is None
      if program is None:
          program = default_main_program()

      fetch_list = self._check_fetch_list(fetch_list)

      from paddle.distributed.auto_parallel.static.utils import (
          use_new_executor,
      )

      if (
          isinstance(program, Program)
          and program._pipeline_opt
          and not use_new_executor()
      ):
          if "fleet_opt" in program._pipeline_opt:
              # Move prepare here for port conflict with nccl in startup program
              if self._fleet_executor is None:
                  self._fleet_executor = _prepare_fleet_executor()
              return self._run_using_fleet_executor(
                  program=program,
                  feed=feed,
                  fetch_list=fetch_list,
                  with_standalone_executor=self._fleet_executor_with_standalone,
                  return_numpy=return_numpy,
              )
          if "startup_program" in program._pipeline_opt:
              program = program._pipeline_opt["startup_program"]
          else:
              return self._run_pipeline(
                  program,
                  fetch_list=fetch_list,
                  use_program_cache=use_program_cache,
              )

      if isinstance(program, Program) and program._heter_pipeline_opt:
          # change default executor
          heter_place = program._heter_pipeline_opt["heter_place"]
          heter_place = framework._get_paddle_place(heter_place)
          p = core.Place()
          p.set_place(heter_place)
          self._default_executor = core.Executor(p)
          # TODO(zhangminxu): support heterps pipeline training using exe.run
          if "startup_program" in program._heter_pipeline_opt:
              program = program._heter_pipeline_opt["startup_program"]

      if (
          isinstance(program, Program)
          and len(program.global_block().ops) == 0
      ):
          if use_default_main_program:
              error_info = (
                  "Now you are using default_main_program, "
                  "but there are no operators in the program to be executed. "
                  "Please ensure you create model correctly or you can pass "
                  "the Program or the CompiledProgram manually."
              )
              warnings.warn(error_info)

      if scope is None:
          scope = global_scope()

      # use_prune can be overridden by putting optimize_ops in fetch_list
      _origin_fetch_list = fetch_list
      _origin_program = program
      fetch_list, optimize_ops = self._split_optimize_ops_in_fetch_list(
          fetch_list
      )
      if optimize_ops:
          use_prune = True
      if use_prune:
          cache_key = _get_strong_program_cache_key(
              program, feed, _origin_fetch_list
          )
          cached_pruned_program = self._get_pruned_program_cache(cache_key)
          if cached_pruned_program is None:
              if isinstance(program, compiler.CompiledProgram):
                  program_scope_cache = self._get_pruned_program_scope_cache(
                      str(id(_origin_program))
                  )
                  # copy the original program, so it can be cached.
                  program = copy.copy(program)
                  # share the local scopes for same original CompiledProgram.
                  program._share_vars_from = program_scope_cache
                  if (
                      self._get_pruned_program_scope_cache(
                          str(id(_origin_program))
                      )
                      is None
                  ):
                      self._add_pruned_program_scope_cache(
                          str(id(_origin_program)), program
                      )
              pruned_program = self._prune_program(
                  program, feed, fetch_list, optimize_ops
              )
              self._add_pruned_program_cache(cache_key, pruned_program)
          else:
              pruned_program = cached_pruned_program

          feed = self._update_feed(pruned_program, feed)
          program = pruned_program

      if _can_use_interpreter_core(program, self.place):
          if feed is None:
              feed = {}
          elif isinstance(feed, (list, tuple)):
              assert len(feed) == 1, "Not compiled with data parallel"
              feed = feed[0]
          if not isinstance(feed, dict):
              raise TypeError(
                  "feed requires dict as its Parameter. But you passed in %s"
                  % (type(feed))
              )
          feed = self._update_feed(program, feed)

          # Values the scheduling flags are reset to once the run is over.
          stored_flag = {}
          if isinstance(program, compiler.CompiledProgram) or isinstance(
              program._graph, compiler.CompiledProgram
          ):
              compiled_program = (
                  program
                  if isinstance(program, compiler.CompiledProgram)
                  else program._graph
              )
              build_strategy = compiled_program._build_strategy
              if build_strategy is not None and build_strategy.sequential_run:
                  schedule_flag = [
                      'FLAGS_new_executor_serial_run',
                      'FLAGS_new_executor_sequential_run',
                  ]
                  for flag in schedule_flag:
                      # The value to restore is read from the environment;
                      # only the string 'true' (any case) counts as enabled.
                      value = os.getenv(flag, False)
                      if isinstance(value, str):
                          value = value.lower()
                          value = True if value == 'true' else False
                      stored_flag[flag] = bool(value)
                  set_flags({f: True for f in schedule_flag})

          # The flags forced on above are process-wide, so restore them even
          # when preparing or running the program raises.
          try:
              program, new_exe = self._executor_cache.get_program_and_executor(
                  program,
                  feed,
                  fetch_list,
                  feed_var_name,
                  fetch_var_name,
                  self.place,
                  scope,
              )

              self._feed_data(program, feed, feed_var_name, scope)
              if hasattr(program, 'lr_scheduler'):
                  from paddle.optimizer.lr import LRScheduler

                  assert isinstance(
                      program.lr_scheduler, LRScheduler
                  ), "must be LRScheduler"
                  lr_scheduler = program.lr_scheduler
                  lr_value = lr_scheduler()
                  lr_var = program.global_block().vars[lr_scheduler._var_name]
                  data = np.array([lr_value]).astype(
                      convert_dtype(lr_var.dtype)
                  )
                  tensor = core.get_variable_tensor(
                      scope, lr_scheduler._var_name
                  )
                  # NOTE(dev): `tensor.set(data, self.place)` always call TensorCopySync that is a blocking behavior. So we use `_copy_from` to replace it.
                  cpu_tensor = _as_lodtensor(data, core.CPUPlace())
                  if core.is_cuda_graph_capturing():
                      warnings.warn(
                          "Caution!!! When capturing CUDA Graph, the learning rate scheduler would not "
                          "take any effect! Please set the learning rate manually before each batch!"
                      )
                  elif core.is_compiled_with_ipu():
                      # for ipu, tensor is allocated on cpu
                      tensor._copy_from(cpu_tensor, tensor._place())
                  else:
                      tensor._copy_from(cpu_tensor, self.place)

              ret = new_exe.run(
                  list(feed.keys()),
                  return_numpy,
                  self.enable_job_schedule_profiler,
              )
          finally:
              set_flags(stored_flag)
          return ret

      compiled = isinstance(program, compiler.CompiledProgram)

      # Check if paddle.static.data() variable no feed data
      if use_prune:
          if compiled:
              global_block = program._program.global_block()
          else:
              global_block = program.global_block()
          for varname in global_block.vars:
              vardesc = global_block.desc.find_var(varname.encode())
              varobj = global_block.vars[varname]

              if (
                  vardesc.persistable() is False
                  and vardesc.type() == core.VarDesc.VarType.LOD_TENSOR
                  and vardesc.need_check_feed() is True
                  and varobj.stop_gradient is True
                  and varobj.is_data is True
                  and varobj.belong_to_optimizer is False
                  and varname not in feed
              ):
                  raise ValueError('Need feed data for variable %s' % varname)

      acp._auto_checkpoint(self, program)

      program._compile(scope, self.place)
      assert (
          program._is_inference
      ), f"Program must have _is_inference = True, but get {program._is_inference}"
      return self._run_inference(program._executor, feed)
  def _run_pir_impl(
      self,
      program,
      feed,
      fetch_list,
      feed_var_name,
      fetch_var_name,
      scope,
      return_numpy,
  ):
      """
      Execute ``program`` for the PIR path of :code:`run`.

      Args:
          program: the PIR program to run; ``paddle.pir.core.default_main_program()``
              is used when it is None.
          feed(dict|list|tuple|None): feed data. None becomes an empty dict
              and a list/tuple must contain exactly one dict.
          fetch_list: targets to fetch, normalised by ``_check_fetch_list``.
          feed_var_name(str): forwarded to the executor cache.
          fetch_var_name(str): forwarded to the executor cache.
          scope(Scope|None): scope to run in; ``global_scope()`` when None.
          return_numpy(bool): forwarded to the executor's ``run``.

      Returns:
          The result of the cached executor's ``run`` call.

      Raises:
          RuntimeError: the executor has been closed.
          TypeError: ``feed`` does not end up being a dict.
      """
      import paddle

      pir_program_type = paddle.pir.Program
      get_default_pir_program = paddle.pir.core.default_main_program

      if self._closed:
          raise RuntimeError("Attempted to use a closed Executor")

      use_default_main_program = program is None
      if use_default_main_program:
          program = get_default_pir_program()

      fetch_list = self._check_fetch_list(fetch_list)

      has_no_ops = (
          isinstance(program, pir_program_type)
          and len(program.global_block().ops) == 0
      )
      if has_no_ops and use_default_main_program:
          warnings.warn(
              "Now you are using default_main_program, "
              "but there are no operators in the program to be executed. "
              "Please ensure you create model correctly or you can pass "
              "the Program or the CompiledProgram manually."
          )

      if scope is None:
          scope = global_scope()

      # Reduce every accepted feed form to a single dict.
      if feed is None:
          feed = {}
      elif isinstance(feed, (list, tuple)):
          assert len(feed) == 1, "Not compiled with data parallel"
          feed = feed[0]
      if not isinstance(feed, dict):
          raise TypeError(
              "feed requires dict as its Parameter. But you passed in %s"
              % (type(feed))
          )

      prepared = self._executor_cache.get_pir_program_and_executor(
          program,
          feed,
          fetch_list,
          feed_var_name,
          fetch_var_name,
          self.place,
          scope,
      )
      program, new_exe, data_op_infos = prepared

      self._pir_feed_data(program, feed, scope, data_op_infos)

      if hasattr(program, 'lr_scheduler'):
          from paddle.optimizer.lr import LRScheduler

          assert isinstance(
              program.lr_scheduler, LRScheduler
          ), "must be LRScheduler"
          scheduler = program.lr_scheduler
          current_lr = scheduler()
          lr_param = program.get_parameter_value_by_name(program.lr_name)
          lr_array = np.array([current_lr]).astype(
              convert_dtype(lr_param.dtype)
          )
          # The learning-rate tensor is looked up in the global scope.
          lr_tensor = core.get_variable_tensor(global_scope(), program.lr_name)
          # NOTE(dev): `tensor.set(data, self.place)` always call TensorCopySync that is a blocking behavior. So we use `_copy_from` to replace it.
          host_tensor = _as_lodtensor(lr_array, core.CPUPlace())
          if core.is_cuda_graph_capturing():
              warnings.warn(
                  "Caution!!! When capturing CUDA Graph, the learning rate scheduler would not "
                  "take any effect! Please set the learning rate manually before each batch!"
              )
          else:
              # for ipu, tensor is allocated on cpu
              destination = (
                  lr_tensor._place()
                  if core.is_compiled_with_ipu()
                  else self.place
              )
              lr_tensor._copy_from(host_tensor, destination)

      return new_exe.run(list(feed.keys()), return_numpy)
  1850. def _run_inference(self, exe, feed):
  1851. return exe.run(feed)
  1852. def _check_fetch_list(self, fetch_list):
  1853. is_fetch_var = lambda var: isinstance(var, (Variable, str, Value))
  1854. is_tuple_list = lambda var: isinstance(var, (tuple, list))
  1855. if fetch_list is None:
  1856. return []
  1857. if is_fetch_var(fetch_list):
  1858. return [fetch_list]
  1859. assert is_tuple_list(fetch_list), (
  1860. "Currently , The fetch_list type only should be list or tuple, \n"
  1861. f"but the input type is {type(fetch_list)}. For more information please refer to \n"
  1862. "the executor.run(...)."
  1863. )
  1864. res = []
  1865. for i, var in enumerate(fetch_list):
  1866. if is_fetch_var(var):
  1867. res.append(var)
  1868. # such as [x, 'mean_out', loss]
  1869. elif is_tuple_list(var):
  1870. if all(is_fetch_var(v) for v in var):
  1871. res.extend(list(var))
  1872. else:
  1873. res.append(var)
  1874. else:
  1875. raise TypeError(
  1876. f"Require fetch_list[{i}] 's type shall be one of (Value, str), but received {type(var).__name__}."
  1877. )
  1878. return res
  1879. def _dump_debug_info(self, program=None, trainer=None):
  1880. with open(str(id(program)) + "_train_desc.prototxt", "w") as fout:
  1881. fout.write(str(trainer))
  1882. if program._fleet_opt and "fleet_desc" in program._fleet_opt:
  1883. with open("fleet_desc.prototxt", "w") as fout:
  1884. fout.write(str(program._fleet_opt["fleet_desc"]))
  1885. def _adjust_pipeline_resource(self, pipeline_opt, dataset, pipeline_num):
  1886. filelist_length = len(dataset.dataset.get_filelist())
  1887. if filelist_length < pipeline_num:
  1888. pipeline_num = filelist_length
  1889. print(
  1890. "Pipeline training: setting the pipeline num to %d is enough because there are only %d files"
  1891. % (filelist_length, filelist_length)
  1892. )
  1893. if filelist_length < pipeline_num * pipeline_opt["concurrency_list"][0]:
  1894. print(
  1895. "Pipeline training: setting the 1st element in concurrency_list to %d is enough because there are only %d files"
  1896. % (filelist_length // pipeline_num, filelist_length)
  1897. )
  1898. pipeline_opt["concurrency_list"][0] = (
  1899. filelist_length // pipeline_num
  1900. )
  1901. dataset.set_thread(pipeline_opt["concurrency_list"][0] * pipeline_num)
  1902. return pipeline_num
  def split_program_by_device(self, program):
      """
      Group the ops of ``program``'s global block by their ``op_device``
      attribute and describe the non-cpu ("heter") region.

      Args:
          program: program whose global block is inspected. Ops absorbed
              into a heter region get their ``op_device`` attribute
              rewritten in place.

      Returns:
          None when no group has a device other than ``"cpu"`` (a warning
          is printed). Otherwise a five-element list for the last non-cpu
          group: ``[start_op_index, end_op_index, send_list, recv_list,
          program]``, where ``program`` is a clone pruned with
          ``_prune_with_input(recv, send)``.
      """
      ops_list = []
      type_list = []
      pre = None
      type_cpu = "cpu"
      # Step 1: cut the op sequence into runs of consecutive ops that share
      # the same device. Scanning stops at the first optimizer op; ops
      # without an "op_device" attribute are not recorded at all.
      for op in program.global_block().ops:
          if self._is_optimizer_op(op):
              break
          if op.has_attr("op_device"):
              # An empty device string is treated as cpu.
              cur_attr = (
                  op.attr("op_device")
                  if op.attr("op_device") != ""
                  else type_cpu
              )
              if pre is None or pre != cur_attr:
                  ops_list.append([])
                  type_list.append(cur_attr)
              ops_list[-1].append(op)
              pre = cur_attr
      l = len(type_list)
      i = 0
      type_heter = None
      # Step 2: cpu runs sandwiched between two runs of the same heter
      # device are re-assigned to that heter device.
      while i < l:
          # Skip leading cpu runs.
          while i < l and type_list[i] == type_cpu:
              i += 1
          if i == l:
              break

          type_heter = type_list[i]
          i += 1
          start = i
          valid = True
          # Advance over cpu runs until the same heter device shows up
          # again; meeting a different non-cpu device invalidates the span.
          while i < l and type_list[i] != type_heter:
              if type_list[i] != type_cpu:
                  valid = False
                  break
              i += 1

          if i == l:
              break
          elif not valid:
              continue

          for j in range(start, i):
              for op in ops_list[j]:
                  op._set_attr("op_device", type_heter)
              type_list[j] = type_heter
              # NOTE(review): this increment has no effect, `range` rebinds j.
              j += 1

      # Step 3: merge neighbouring runs that now share a device.
      pre = None
      merged_ops_list = []
      merged_type_list = []
      for i in range(l):
          if pre is None or pre != type_list[i]:
              merged_ops_list.append([])
              merged_type_list.append(type_list[i])
          merged_ops_list[-1].extend(ops_list[i])
          pre = type_list[i]

      # Names of all non-persistable variables of the global block.
      data_vars = set()
      for k in program.global_block().vars:
          var = program.global_block().var(k)
          if not var.persistable:
              data_vars.add(var.name)

      # Step 4: for every merged group, work out which of its inputs come
      # from the previous group.
      l = len(merged_ops_list)
      inputs_pre = set()
      outputs_pre = set()
      in_from_pre = [[] for i in range(l)]
      for i in range(l):
          inputs = set()
          outputs = set()
          for op in merged_ops_list[i]:
              # NOTE: `input` shadows the builtin inside this loop.
              for input in op.input_names:
                  for tmp in op.input(input):
                      # Only names not already produced inside this group
                      # count as inputs of the group.
                      if tmp not in outputs:
                          inputs.add(tmp)
              for output in op.output_names:
                  for tmp in op.output(output):
                      outputs.add(tmp)
          if i == 0:
              in_from_pre[i] = []
          elif i == 1:
              # The second group may additionally receive the
              # non-persistable variables collected above.
              in_from_pre[i] = (outputs_pre | data_vars) & inputs
          else:
              in_from_pre[i] = outputs_pre & inputs
          # NOTE(review): inputs_pre is assigned but never read afterwards.
          inputs_pre = copy.deepcopy(inputs)
          outputs_pre = copy.deepcopy(outputs)

      # Step 5: per group, record its op index range, what it sends to the
      # next group, and a program clone (pruned for non-cpu groups).
      l = len(in_from_pre)
      start_list = []
      end_list = []
      send_list = [[] for i in range(l)]
      # NOTE: `sum` shadows the builtin; it is the running op count.
      sum = 0
      program_list = []
      for i in range(l):
          start_list.append(sum)
          end_list.append(sum + len(merged_ops_list[i]) - 1)
          sum += len(merged_ops_list[i])
          if i < l - 1:
              send_list[i].extend(list(in_from_pre[i + 1]))
          prog = program.clone()
          if merged_type_list[i] != type_cpu:
              prog = prog._prune_with_input(
                  list(in_from_pre[i]), list(send_list[i])
              )
              program_list.append(prog)
          else:
              program_list.append(prog)
      recv_list = [list(i) for i in in_from_pre]
      # Step 6: pick the heter group. When several non-cpu groups exist a
      # message is printed and the last one wins.
      found = False
      heter_index = None
      for i in range(len(merged_type_list)):
          t = merged_type_list[i]
          if t != type_cpu:
              if found:
                  print("only one region of program can be heter")
              found = True
              heter_index = i
      if heter_index is None:
          print("warning: non heter program")
          return None
      else:
          return [
              start_list[heter_index],
              end_list[heter_index],
              send_list[heter_index],
              recv_list[heter_index],
              program_list[heter_index],
          ]
  def _prepare_trainer(
      self,
      program=None,
      dataset=None,
      scope=None,
      thread=0,
      debug=False,
      fetch_list=None,
      fetch_info=None,
      print_period=100,
  ):
      """
      Build and configure the trainer used to run ``program`` on ``dataset``.

      Args:
          program(Program|CompiledProgram): program to train. Its
              ``_fleet_opt``, ``_pipeline_opt`` and ``_heter_pipeline_opt``
              select which trainer description is created.
          dataset: dataset whose ``thread_num`` is used when ``thread`` is
              not positive.
          scope(Scope|None): scope to use; ``global_scope()`` when None.
          thread(int): number of threads; values <= 0 fall back to the
              ``worker_places`` count (when ``use_ps_gpu`` is set in
              ``_fleet_opt``) or to ``dataset.thread_num``.
          debug(bool): forwarded to ``trainer._set_debug``.
          fetch_list(list|None): variables to fetch while training.
          fetch_info(list|None): one entry per item of ``fetch_list``.
          print_period(int): forwarded together with the fetch settings.

      Returns:
          tuple: ``(scope, trainer)``.

      Raises:
          AssertionError: ``fetch_list`` and ``fetch_info`` differ in length.
          RuntimeError: no positive thread number is available.
      """
      is_heter = 0
      use_ps_gpu = 0
      if program._fleet_opt is not None:
          if program._fleet_opt.get("worker_class", "") == "HeterCpuWorker":
              is_heter = 1
          if program._fleet_opt.get("trainer", "") == "HeterXpuTrainer":
              is_heter = 1
          if program._fleet_opt.get("use_ps_gpu", False):
              use_ps_gpu = True
      if scope is None:
          scope = global_scope()
      if fetch_list is None:
          fetch_list = []
      if fetch_info is None:
          fetch_info = []
      assert len(fetch_list) == len(fetch_info)
      compiled = isinstance(program, compiler.CompiledProgram)
      if is_heter:
          # Heter split information, handed to the trainer below.
          ret = self.split_program_by_device(program)
      if not compiled:
          # TODO: Need a better way to distinguish and specify different execution mode
          if program._pipeline_opt:
              trainer = TrainerFactory()._create_trainer(
                  program._pipeline_opt
              )
          elif program._heter_pipeline_opt:
              trainer = TrainerFactory()._create_trainer(
                  program._heter_pipeline_opt
              )
          else:
              trainer = TrainerFactory()._create_trainer(program._fleet_opt)
              trainer._set_thread_barrier(program._is_distributed)
          trainer._set_program(program)
          if is_heter:
              trainer._set_heter_info(ret)
      else:
          # For a CompiledProgram the options are read from the wrapped
          # ``program.program``.
          if program._pipeline_opt:
              trainer = TrainerFactory()._create_trainer(
                  program.program._pipeline_opt
              )
          elif program._heter_pipeline_opt:
              trainer = TrainerFactory()._create_trainer(
                  program.program._heter_pipeline_opt
              )
          else:
              trainer = TrainerFactory()._create_trainer(
                  program.program._fleet_opt
              )
          trainer._set_program(program.program)

      if thread <= 0:
          if use_ps_gpu:
              trainer._set_thread(len(program._fleet_opt["worker_places"]))
          elif dataset.thread_num <= 0:
              # The two literals are concatenated, so the first one needs a
              # trailing space to keep the words apart.
              raise RuntimeError(
                  "You should set thread num first, either in Dataset "
                  "or in Executor.train_from_dataset"
              )
          else:
              trainer._set_thread(dataset.thread_num)
      else:
          trainer._set_thread(thread)

      trainer._set_debug(debug)
      trainer._set_fetch_var_and_info(fetch_list, fetch_info, print_period)
      return scope, trainer
  2101. def _run_from_dataset(
  2102. self,
  2103. program=None,
  2104. dataset=None,
  2105. scope=None,
  2106. thread=0,
  2107. is_infer=False,
  2108. debug=False,
  2109. fetch_list=None,
  2110. fetch_info=None,
  2111. print_period=100,
  2112. fetch_handler=None,
  2113. ):
  2114. if program._pipeline_opt is not None:
  2115. import paddle
  2116. if dataset is not None:
  2117. raise RuntimeError("dataset should be None for pipeline mode")
  2118. # The following fake dataset is created to call
  2119. # the _prepare_trainer api, and it is meaningless.
  2120. data_vars = []
  2121. for var in program.global_block().vars.values():
  2122. if var.is_data:
  2123. data_vars.append(var)
  2124. dataset = paddle.base.DatasetFactory().create_dataset(
  2125. 'FileInstantDataset'
  2126. )
  2127. dataset.set_batch_size(1)
  2128. dataset.set_thread(1)
  2129. dataset.set_filelist(['None'])
  2130. dataset.set_use_var(data_vars)
  2131. elif program._heter_pipeline_opt is not None:
  2132. stage_id = program._heter_pipeline_opt["pipeline_stage"]
  2133. # print("test_fl_stage_id: {}".format(stage_id))
  2134. heter_place = program._heter_pipeline_opt["heter_place"]
  2135. if stage_id != 0:
  2136. if "is_fl_mode" not in program._heter_pipeline_opt:
  2137. import paddle
  2138. if dataset is not None:
  2139. raise RuntimeError(
  2140. "dataset should be None for heter pipeline mode"
  2141. )
  2142. # The following fake dataset is created to call
  2143. # the _prepare_trainer api, and it is meaningless.
  2144. data_vars = []
  2145. for var in program.global_block().vars.values():
  2146. if var.is_data:
  2147. data_vars.append(var)
  2148. dataset = paddle.base.DatasetFactory().create_dataset(
  2149. 'InMemoryDataset'
  2150. )
  2151. dataset.set_batch_size(1)
  2152. dataset.set_thread(1)
  2153. dataset.set_filelist(['None'])
  2154. dataset.set_use_var(data_vars)
  2155. else:
  2156. if dataset is None:
  2157. raise RuntimeError(
  2158. "dataset is need and should be initialized"
  2159. )
  2160. # change default executor
  2161. heter_place = framework._get_paddle_place(heter_place)
  2162. p = core.Place()
  2163. p.set_place(heter_place)
  2164. self._default_executor = core.Executor(p)
  2165. else:
  2166. if dataset is None:
  2167. raise RuntimeError("dataset is need and should be initialized")
  2168. dataset._prepare_to_run()
  2169. real_fetch_list = []
  2170. if program._pipeline_opt:
  2171. real_program = program._pipeline_opt["section_program"]
  2172. for fetch_var in fetch_list:
  2173. if isinstance(fetch_var, Variable):
  2174. fetch_var_name = fetch_var.name
  2175. else:
  2176. fetch_var_name = fetch_var
  2177. if fetch_var_name in real_program.global_block().vars:
  2178. real_fetch_list.append(fetch_var)
  2179. program._pipeline_opt["section_program"] = _add_feed_fetch_ops(
  2180. program=program._pipeline_opt["section_program"],
  2181. feed=[],
  2182. fetch_list=real_fetch_list,
  2183. feed_var_name='feed',
  2184. fetch_var_name='fetch',
  2185. )
  2186. main_block = program._pipeline_opt["section_program"].block(0)
  2187. for op in main_block.ops:
  2188. # set the op_role of fetch op to Optimize to avoid
  2189. # erase the fetched vars by gc for pipeline
  2190. if op.type == 'fetch':
  2191. op._set_attr(
  2192. 'op_role',
  2193. core.op_proto_and_checker_maker.OpRole.Optimize,
  2194. )
  2195. fetch_list = None
  2196. scope, trainer = self._prepare_trainer(
  2197. program=program,
  2198. dataset=dataset,
  2199. scope=scope,
  2200. thread=thread,
  2201. debug=debug,
  2202. fetch_list=fetch_list,
  2203. fetch_info=fetch_info,
  2204. print_period=print_period,
  2205. )
  2206. trainer._set_infer(is_infer)
  2207. trainer._gen_trainer_desc()
  2208. if program._pipeline_opt is None:
  2209. if program._heter_pipeline_opt is None:
  2210. self._dump_debug_info(program=program, trainer=trainer)
  2211. # warning if dataset not set psgpu in psgpu mode
  2212. if dataset.use_ps_gpu is False and trainer.proto_desc.use_ps_gpu:
  2213. logging.warning("dataset should call set_use_ps_gpu in PsGpu mode")
  2214. dataset._dynamic_adjust_before_train(trainer.proto_desc.thread_num)
  2215. reused_trainer = program._heter_pipeline_opt is not None or (
  2216. program._fleet_opt is not None
  2217. and program._fleet_opt.get("use_ps_gpu", False)
  2218. and program._fleet_opt.get("dump_fields_path", "") == ""
  2219. )
  2220. if reused_trainer is False:
  2221. trainer_instance = (
  2222. self._default_executor.init_for_dataset( # -->InitForDataset
  2223. program.desc, trainer._desc(), scope, dataset.dataset
  2224. )
  2225. )
  2226. else:
  2227. # cache trainer instance for heterps pipeline training
  2228. if fetch_list is None:
  2229. fetch_list = []
  2230. cache_key = _get_strong_program_cache_key(program, None, fetch_list)
  2231. trainer_instance = self._get_trainer_cache(cache_key)
  2232. if trainer_instance is None:
  2233. trainer_instance = self._default_executor.init_for_dataset(
  2234. program.desc, trainer._desc(), scope, dataset.dataset
  2235. )
  2236. # print("test_fl_ps - trainer_desc: {}\n".format(trainer))
  2237. self._add_trainer_cache(cache_key, trainer_instance)
  2238. else:
  2239. trainer_instance.ResetDataset(dataset.dataset)
  2240. if fetch_handler is not None:
  2241. scope0 = trainer_instance.get_worker_scope(0)
  2242. fetch_monitor = FetchHandlerMonitor(scope0, fetch_handler)
  2243. fetch_monitor.start()
  2244. self._default_executor.run_from_dataset(trainer_instance)
  2245. fetch_monitor.stop()
  2246. if reused_trainer is False:
  2247. self._default_executor.release_trainer(trainer_instance)
  2248. else:
  2249. self._default_executor.run_from_dataset(trainer_instance)
  2250. if reused_trainer is False:
  2251. self._default_executor.release_trainer(trainer_instance)
  2252. dataset._dynamic_adjust_after_train()
  2253. dataset._finish_to_run()
  2254. if real_fetch_list:
  2255. arr = scope.find_var('fetch').get_fetch_list()
  2256. tensors = arr._move_to_list()
  2257. return as_numpy(tensors)
  2258. return None
  2259. def _prepare_pipeline_ctx(
  2260. self,
  2261. program=None,
  2262. dataset=None,
  2263. scope=None,
  2264. thread=0,
  2265. is_infer=False,
  2266. debug=False,
  2267. fetch_list=None,
  2268. fetch_info=None,
  2269. print_period=100,
  2270. fetch_handler=None,
  2271. use_program_cache=False,
  2272. ):
  2273. assert program._pipeline_opt is not None
  2274. assert dataset is None, "dataset should be None for pipeline mode"
  2275. cache_key = _get_strong_program_cache_key(program, None, fetch_list)
  2276. ctx = self._get_ctx_cache(cache_key)
  2277. if use_program_cache and ctx is not None:
  2278. return ctx
  2279. import paddle
  2280. # The following fake dataset is created to call
  2281. # the _prepare_trainer api, and it is meaningless.
  2282. def _get_dataset():
  2283. data_vars = []
  2284. for var in program.global_block().vars.values():
  2285. if var.is_data:
  2286. data_vars.append(var)
  2287. dataset = paddle.base.DatasetFactory().create_dataset(
  2288. 'FileInstantDataset'
  2289. )
  2290. dataset.set_batch_size(1)
  2291. dataset.set_thread(1)
  2292. dataset.set_filelist(['None'])
  2293. dataset.set_use_var(data_vars)
  2294. dataset._prepare_to_run()
  2295. return dataset
  2296. dataset = _get_dataset()
  2297. def _get_real_program_fetch_list():
  2298. real_program = program._pipeline_opt["section_program"]
  2299. real_fetch_list = []
  2300. for fetch_var in fetch_list:
  2301. if isinstance(fetch_var, Variable):
  2302. fetch_var_name = fetch_var.name
  2303. else:
  2304. fetch_var_name = fetch_var
  2305. if fetch_var_name in real_program.global_block().vars:
  2306. real_fetch_list.append(fetch_var)
  2307. real_program = _add_feed_fetch_ops(
  2308. program=real_program,
  2309. feed=[],
  2310. fetch_list=real_fetch_list,
  2311. feed_var_name='feed',
  2312. fetch_var_name='fetch',
  2313. )
  2314. main_block = real_program.block(0)
  2315. for op in main_block.ops:
  2316. # set the op_role of fetch op to Optimize to avoid
  2317. # erase the fetched vars by gc for pipeline
  2318. if op.type == 'fetch':
  2319. op._set_attr(
  2320. 'op_role',
  2321. core.op_proto_and_checker_maker.OpRole.Optimize,
  2322. )
  2323. return real_program, real_fetch_list
  2324. real_program, real_fetch_list = _get_real_program_fetch_list()
  2325. program._pipeline_opt["section_program"] = real_program
  2326. fetch_list = None
  2327. scope, trainer = self._prepare_trainer(
  2328. program=program,
  2329. dataset=dataset,
  2330. scope=scope,
  2331. thread=thread,
  2332. debug=debug,
  2333. fetch_list=fetch_list,
  2334. fetch_info=fetch_info,
  2335. print_period=print_period,
  2336. )
  2337. trainer._set_infer(is_infer)
  2338. trainer._gen_trainer_desc()
  2339. # NOTE: only for debug, very slow
  2340. # self._dump_debug_info(program=program, trainer=trainer)
  2341. # warning if dataset not set psgpu in psgpu mode
  2342. if dataset.use_ps_gpu is False and trainer.proto_desc.use_ps_gpu:
  2343. logging.warning("dataset should call set_use_ps_gpu in PsGpu mode")
  2344. dataset._dynamic_adjust_before_train(trainer.proto_desc.thread_num)
  2345. trainer_desc = trainer._desc() # slow, cache
  2346. trainer_instance = self._default_executor.init_for_dataset(
  2347. program.desc, trainer_desc, scope, dataset.dataset
  2348. )
  2349. ctx = [scope, real_fetch_list, trainer_instance]
  2350. if use_program_cache:
  2351. self._add_ctx_cache(cache_key, ctx)
  2352. return ctx
  2353. def _prepare_fleet_executor_carrier(
  2354. self,
  2355. carrier_id="",
  2356. program=None,
  2357. scope=None,
  2358. fleet_opt=None,
  2359. micro_scope_list=[],
  2360. with_standalone_executor=False,
  2361. ):
  2362. num_micro_batches = (
  2363. fleet_opt["num_micro_batches"]
  2364. if "num_micro_batches" in fleet_opt
  2365. else 1
  2366. )
  2367. cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0))
  2368. trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(',')
  2369. nrank = len(trainer_endpoints)
  2370. assert 'scheduler' in fleet_opt or 'tasks' in fleet_opt, (
  2371. "Fleet executor need configuration for scheduler, you can choose from 1F1B or Origin. "
  2372. "Or you can provide a list of task nodes to init fleet executor directly."
  2373. )
  2374. if 'tasks' in fleet_opt:
  2375. assert 'task_id_to_rank' in fleet_opt, (
  2376. "If you provide tasks to init fleet executor,"
  2377. " task_id_to_rank should also be provided."
  2378. )
  2379. print('fleet executor will use user defined task nodes')
  2380. tasks = [task.task_node() for task in fleet_opt['tasks']]
  2381. task_id_to_rank = fleet_opt['task_id_to_rank']
  2382. else:
  2383. scheduler = fleet_opt['scheduler']
  2384. if scheduler == '1F1B':
  2385. from paddle.distributed.fleet.fleet_executor_utils import (
  2386. run1f1b,
  2387. )
  2388. if (
  2389. "dist_strategy" not in fleet_opt
  2390. or "pp_degree" not in fleet_opt["dist_strategy"]
  2391. or fleet_opt["dist_strategy"]["pp_degree"] == 1
  2392. ):
  2393. warnings.warn("Using 1F1B scheduler with pp_degree == 1.")
  2394. tasks, task_id_to_rank = run1f1b(
  2395. program,
  2396. cur_rank,
  2397. fleet_opt.get('num_micro_batches', 1),
  2398. fleet_opt.get('dist_strategy', {}),
  2399. nrank,
  2400. with_standalone_executor,
  2401. )
  2402. elif scheduler == 'Origin':
  2403. from paddle.distributed.fleet.fleet_executor_utils import origin
  2404. if (
  2405. "dist_strategy" in fleet_opt
  2406. and "pp_degree" in fleet_opt["dist_strategy"]
  2407. ):
  2408. assert (
  2409. fleet_opt["dist_strategy"]["pp_degree"] == 1
  2410. ), "For pipeline mode, the scheduler should be 1F1B instead of Origin."
  2411. if "num_micro_batches" in fleet_opt:
  2412. assert (
  2413. fleet_opt["num_micro_batches"] == 1
  2414. ), "For origin scheduler mode, the num micro batches should be 1."
  2415. tasks, task_id_to_rank = origin(program, cur_rank)
  2416. else:
  2417. raise "Fleet_executor only supports 1F1B and Origin scheduler, " "but received " + str(
  2418. scheduler
  2419. ) + "."
  2420. # NOTE: have to hold these vars, otherwise will be destructed
  2421. fleet_opt['tasks'] = tasks
  2422. fleet_opt['task_id_to_rank'] = task_id_to_rank
  2423. place = core.Place()
  2424. place.set_place(self.place)
  2425. inference_root_scope_vars = (
  2426. fleet_opt["fetch_var"] if "fetch_var" in fleet_opt else []
  2427. )
  2428. self._fleet_executor.init(
  2429. carrier_id,
  2430. program.desc,
  2431. scope,
  2432. place,
  2433. num_micro_batches,
  2434. tasks,
  2435. task_id_to_rank,
  2436. inference_root_scope_vars,
  2437. micro_scope_list,
  2438. )
  def _run_using_fleet_executor(
      self,
      program=None,
      feed=None,
      feed_var_name="feed",
      fetch_var_name="fetch",
      fetch_list=None,
      with_standalone_executor=False,
      return_numpy=True,
  ):
      """Run ``program`` through the fleet executor and return fetched values.

      The program with feed/fetch ops, the scope and the micro-batch scopes
      are cached per ``(program, feed, fetch_list)`` key. The carrier is
      initialised only when no cached program exists for that key.

      Args:
          program: program carrying ``_pipeline_opt["fleet_opt"]``.
          feed: data to feed; skipped when falsy.
          feed_var_name: name of the feed holder variable.
          fetch_var_name: name of the fetch holder variable.
          fetch_list: variables to fetch from the cached scope.
          with_standalone_executor: forwarded to
              ``_prepare_fleet_executor_carrier``.
          return_numpy: whether values read through ``fleet_opt["fetch_var"]``
              are converted with ``as_numpy``.

      Returns:
          * when ``"fetch_var"`` is in ``fleet_opt``: a list with one list of
            results per micro scope (empty entries are dropped);
          * otherwise, when ``fetch_list`` is truthy: the fetch list converted
            by ``as_numpy``;
          * otherwise ``None``.
      """

      def _mark_fetch_ops_as_optimize(target_program):
          # set the op_role of fetch op to Optimize to avoid
          # erase the fetched vars by gc for pipeline
          for op in target_program.block(0).ops:
              if op.type == 'fetch':
                  op._set_attr(
                      'op_role',
                      core.op_proto_and_checker_maker.OpRole.Optimize,
                  )

      cache_key = _get_strong_program_cache_key(program, feed, fetch_list)
      cached_program = self._get_program_cache(cache_key)
      cached_scope = self._get_scope_cache(cache_key)
      micro_cached_scopes = self._get_micro_scopes_cache(cache_key)
      fleet_opt = program._pipeline_opt["fleet_opt"]
      if cached_scope is None:
          cached_scope = global_scope()
          self._add_scope_cache(cache_key, cached_scope)
      if micro_cached_scopes is None:
          micro_cached_scopes = []
          if (
              "inference_generation" in fleet_opt
              and fleet_opt["inference_generation"]
          ):
              for _ in range(int(fleet_opt["num_micro_batches"])):
                  micro_cached_scopes.append(cached_scope.new_scope())
              self._add_micro_scopes_cache(cache_key, micro_cached_scopes)
      if cached_program is None:
          assert (
              program._pipeline_opt
          ), "program should have _pipeline_opt to start carrier"
          real_feed = [] if feed is None else feed
          real_program = program
          if "section_program" in program._pipeline_opt:
              real_program = program._pipeline_opt["section_program"]
          cached_program = _add_feed_fetch_ops(
              program=real_program,
              feed=real_feed,
              fetch_list=fetch_list,
              feed_var_name=feed_var_name,
              fetch_var_name=fetch_var_name,
          )
          _mark_fetch_ops_as_optimize(cached_program)
          self._add_program_cache(cache_key, cached_program)
          if 'tasks' in fleet_opt:
              # Insert feed/fetch op for cloned program in each task node,
              # these ops has already been inserted into the origin program.
              # To avoid every task nodes all have feed/fetch ops,
              # only insert feed ops into the first task node,
              # then insert fetch ops into the last task node.

              # Insert feed ops
              feed_task = fleet_opt['tasks'][0]
              print("Inserting feed ops for task", feed_task.task_id())
              feed_program = feed_task.get_program()
              feed_program = self._add_feed_ops(
                  program=feed_program,
                  feed=real_feed,
                  feed_var_name=feed_var_name,
              )
              feed_task.set_program(feed_program)

              # Insert fetch ops
              fetch_task = fleet_opt['tasks'][-1]
              print("Inserting fetch ops for task", fetch_task.task_id())
              fetch_program = fetch_task.get_program()
              fetch_program = self._add_fetch_ops(
                  program=fetch_program,
                  fetch_list=fetch_list,
                  fetch_var_name=fetch_var_name,
              )
              _mark_fetch_ops_as_optimize(fetch_program)
              fetch_task.set_program(fetch_program)

          # The carrier receives the cached micro scopes, so no further
          # child scopes are created here.
          self._prepare_fleet_executor_carrier(
              cache_key,
              program=cached_program,
              scope=cached_scope,
              fleet_opt=fleet_opt,
              micro_scope_list=micro_cached_scopes,
              with_standalone_executor=with_standalone_executor,
          )

      if feed:
          # NOTE: don't have to traverse programs in task nodes,
          # since they all sub program of cached program and
          # cached program is also added feed fetch var
          self._feed_data(cached_program, feed, feed_var_name, cached_scope)

      from paddle.optimizer.lr import LRScheduler

      if hasattr(program, 'lr_scheduler'):
          # Write the scheduler's current value into its variable in the
          # cached scope before running.
          lr_scheduler = program.lr_scheduler
          assert isinstance(lr_scheduler, LRScheduler), "must be LRScheduler"
          lr_value = lr_scheduler()
          lr_var = program.global_block().vars[lr_scheduler._var_name]
          data = np.array([lr_value]).astype(convert_dtype(lr_var.dtype))
          tensor = core.get_variable_tensor(
              cached_scope, lr_scheduler._var_name
          )
          tensor.set(data, self.place)

      self._fleet_executor.run(cache_key)
      if "fetch_var" in fleet_opt:
          # If we speed up the generation in evaluation, we need to generate
          # multiple queries at the same time. Each query lives in a separate
          # scope so they do not mix up, which means the final result is
          # spread over multiple scopes and each one has to be fetched.
          result_list = []
          for scope in micro_cached_scopes:
              scope_result_list = []
              for varname in fleet_opt["fetch_var"]:
                  tensor = None
                  try:
                      tensor = core.get_variable_tensor(scope, varname)
                      if return_numpy:
                          tensor = as_numpy(tensor)
                  except Exception:
                      # Fall back to reading the variable as a tensor array.
                      # Only Exception is caught so that KeyboardInterrupt /
                      # SystemExit still propagate.
                      var = scope.find_var(varname)
                      tensor = var.get_lod_tensor_array()
                      if return_numpy:
                          tensor = as_numpy(tensor)
                      else:
                          tensor = list(tensor)

                  if isinstance(tensor, np.ndarray):
                      # bool() of a multi-element ndarray raises and a
                      # single zero element is falsy, so test for emptiness
                      # explicitly. NOTE(review): presumably as_numpy yields
                      # an ndarray for a single tensor - confirm.
                      has_value = tensor.size > 0
                  else:
                      has_value = bool(tensor)
                  if has_value:
                      scope_result_list.append(tensor)

              if scope_result_list:
                  result_list.append(scope_result_list)
          return result_list

      if fetch_list:
          arr = cached_scope.find_var(fetch_var_name).get_fetch_list()
          tensors = arr._move_to_list()
          return as_numpy(tensors)
      return None
  def _add_feed_ops(self, program, feed, feed_var_name):
      """Return a clone of ``program`` with ``feed`` ops prepended.

      The feed holder variable named ``feed_var_name`` is reused when the
      clone's global block already has it, otherwise it is created. Feed ops
      are only added when ``has_feed_operators`` reports they are missing;
      a feed name without a matching variable produces a warning instead of
      an op.
      """
      cloned = program.clone()
      block = cloned.global_block()

      if feed_var_name not in block.vars:
          feed_holder = block.create_var(
              name=feed_var_name,
              type=core.VarDesc.VarType.FEED_MINIBATCH,
              persistable=True,
          )
      else:
          feed_holder = block.var(feed_var_name)

      # Nothing to do when the feed operators are already in place.
      if has_feed_operators(block, feed, feed_var_name):
          return cloned

      for col, var_name in enumerate(feed):
          if not block.has_var(var_name):
              warnings.warn(
                  "The variable %s is not found in program. It is not declared or is pruned."
                  % var_name
              )
              continue
          target = block.var(var_name)
          block._prepend_op(
              type='feed',
              inputs={'X': [feed_holder]},
              outputs={'Out': [target]},
              attrs={'col': col},
          )

      return cloned
  @classmethod
  def _add_fetch_ops(
      cls, program, fetch_list, fetch_var_name, use_fetch_v2=False
  ):
      """Return a clone of ``program`` with fetch ops appended for ``fetch_list``.

      The fetch holder variable named ``fetch_var_name`` is reused when the
      clone's global block already has it, otherwise it is created. The op
      type is ``'fetch_v2'`` when ``use_fetch_v2`` is true and ``'fetch'``
      otherwise. Ops are only appended when ``has_fetch_operators`` reports
      they are missing.
      """
      cloned = program.clone()
      block = cloned.global_block()

      if fetch_var_name not in block.vars:
          fetch_holder = block.create_var(
              name=fetch_var_name,
              type=core.VarDesc.VarType.FETCH_LIST,
              persistable=True,
          )
      else:
          fetch_holder = block.var(fetch_var_name)

      fetch_op = 'fetch_v2' if use_fetch_v2 else 'fetch'

      # Nothing to do when the fetch operators are already in place.
      if has_fetch_operators(block, fetch_list, fetch_var_name, fetch_op):
          return cloned

      for col, target in enumerate(fetch_list):
          assert isinstance(
              target, (Variable, str)
          ), f"Wrong type for fetch_list[{col}]: {type(target)}"
          block.append_op(
              type=fetch_op,
              inputs={'X': [target]},
              outputs={'Out': [fetch_holder]},
              attrs={'col': col},
          )

      return cloned
  2651. @classmethod
  2652. def _remove_fetch_ops(cls, program, fetch_op_name='fetch'):
  2653. tmp_program = program.clone()
  2654. global_block = tmp_program.global_block()
  2655. op_num = len(global_block.ops)
  2656. for idx in reversed(range(op_num)):
  2657. if global_block.ops[idx].type == fetch_op_name:
  2658. global_block._remove_op(idx)
  2659. return tmp_program
  def _run_pipeline(
      self,
      program=None,
      dataset=None,
      scope=None,
      thread=0,
      is_infer=False,
      debug=False,
      fetch_list=None,
      fetch_info=None,
      print_period=100,
      fetch_handler=None,
      use_program_cache=False,
  ):
      """Run a pipeline program once and return its fetched values.

      The scope, effective fetch list and trainer instance come from
      ``_prepare_pipeline_ctx``. If ``program`` has an ``lr_scheduler``, its
      current value is written into the scheduler's variable in the scope
      before the run. The trainer instance is released after the run unless
      ``use_program_cache`` is set.

      Returns:
          The fetch list converted by ``as_numpy`` when there is something
          to fetch, otherwise ``None``.
      """
      ctx = self._prepare_pipeline_ctx(
          program,
          dataset,
          scope,
          thread,
          is_infer,
          debug,
          fetch_list,
          fetch_info,
          print_period,
          fetch_handler,
          use_program_cache,
      )
      scope, real_fetch_list, trainer_instance = ctx

      from paddle.optimizer.lr import LRScheduler

      if hasattr(program, 'lr_scheduler'):
          scheduler = program.lr_scheduler
          assert isinstance(scheduler, LRScheduler), "must be LRScheduler"
          current_lr = scheduler()
          lr_var = program.global_block().vars[scheduler._var_name]
          lr_dtype = convert_dtype(lr_var.dtype)
          lr_data = np.array([current_lr]).astype(lr_dtype)
          lr_tensor = core.get_variable_tensor(scope, scheduler._var_name)
          lr_tensor.set(lr_data, self.place)

      self._default_executor.run_from_dataset(trainer_instance)

      # A cached context keeps its trainer instance for later runs.
      if not use_program_cache:
          self._default_executor.release_trainer(trainer_instance)

      if not real_fetch_list:
          return None

      fetched = scope.find_var('fetch').get_fetch_list()._move_to_list()
      return as_numpy(fetched)
  def infer_from_dataset(
      self,
      program=None,
      dataset=None,
      scope=None,
      thread=0,
      debug=False,
      fetch_list=None,
      fetch_info=None,
      print_period=100,
      fetch_handler=None,
  ):
      """
      Infer from a pre-defined Dataset. Dataset is defined in paddle.base.dataset.

      Given a program, infer_from_dataset consumes the data samples in the
      dataset. This is the inference counterpart of train_from_dataset: both
      go through the same internal routine, and this one switches the trainer
      to inference mode. According to the original documentation, in
      distributed training pushing gradients is disabled in
      infer_from_dataset, which makes it convenient for multi-threaded
      evaluation.

      Args:
          program(Program|CompiledProgram): the program that needs to be run.
          dataset(paddle.base.Dataset): dataset created outside this function,
              a user should provide a well-defined dataset before calling this function.
              Please check the document of Dataset if needed. default is None
          scope(Scope): the scope used to run this program, you can switch it to different scope
              for each run. default is global_scope
          thread(int): number of threads a user wants to run in this function. Default is 0, which
              means using the thread num of dataset
          debug(bool): whether to run in debug mode, default is False
          fetch_list(Tensor List): fetch Tensor list, each Tensor will be printed during
              the run, default is None
          fetch_info(String List): print information for each Tensor, default is None
          print_period(int): the number of mini-batches for each print, default is 100
          fetch_handler(FetchHandler): a user defined class for fetch output.

      Returns:
          Whatever the internal dataset run returns: None in the common case,
          or the fetched values when the program runs in pipeline mode and
          fetches at least one variable.

      Examples:

          .. code-block:: python

              >>> import paddle

              >>> paddle.enable_static()
              >>> place = paddle.CPUPlace()  # you can set place = paddle.CUDAPlace(0) to use gpu
              >>> exe = paddle.static.Executor(place)
              >>> x = paddle.static.data(name="x", shape=[None, 10, 10], dtype="int64")
              >>> y = paddle.static.data(name="y", shape=[None, 1], dtype="int64", lod_level=1)
              >>> dataset = paddle.base.DatasetFactory().create_dataset()
              >>> dataset.set_use_var([x, y])
              >>> dataset.set_thread(1)
              >>> # you should set your own filelist, e.g. filelist = ["dataA.txt"]
              >>> filelist = []
              >>> dataset.set_filelist(filelist)
              >>> exe.run(paddle.static.default_startup_program())
              >>> exe.infer_from_dataset(program=paddle.static.default_main_program(),
              ...                        dataset=dataset)
      """
      return self._run_from_dataset(
          program=program,
          dataset=dataset,
          scope=scope,
          thread=thread,
          is_infer=True,
          debug=debug,
          fetch_list=fetch_list,
          fetch_info=fetch_info,
          print_period=print_period,
          fetch_handler=fetch_handler,
      )
  def start_heter_trainer(
      self,
      program=None,
      scope=None,
      debug=False,
      fetch_list=None,
      fetch_info=None,
      print_period=100,
      fetch_handler=None,
  ):
      """Create a single-thread trainer without a dataset, run it, return it.

      The trainer is prepared with ``dataset=None`` and ``thread=1`` and is
      put in training (non-inference) mode. The trainer instance is not
      released here; it is handed back to the caller. ``fetch_handler`` is
      accepted but not used by this method.

      Returns:
          The trainer instance created by the default executor.
      """
      trainer_kwargs = {
          "program": program,
          "dataset": None,
          "scope": scope,
          "thread": 1,
          "debug": debug,
          "fetch_list": fetch_list,
          "fetch_info": fetch_info,
          "print_period": print_period,
      }
      scope, trainer = self._prepare_trainer(**trainer_kwargs)

      trainer._set_infer(False)
      trainer._gen_trainer_desc()
      self._dump_debug_info(program=program, trainer=trainer)

      # No dataset is attached to this trainer instance, hence the None.
      instance = self._default_executor.init_for_dataset(
          program.desc, trainer._desc(), scope, None
      )
      self._default_executor.run_from_dataset(instance)
      return instance
  def train_from_dataset(
      self,
      program=None,
      dataset=None,
      scope=None,
      thread=0,
      debug=False,
      fetch_list=None,
      fetch_info=None,
      print_period=100,
      fetch_handler=None,
  ):
      """
      Train from a pre-defined Dataset. Dataset is defined in paddle.base.dataset.

      Given a program, train_from_dataset consumes the data samples in the
      dataset. It goes through the same internal routine as
      infer_from_dataset, with the trainer left in training mode.

      Note: train_from_dataset will destroy all resources created within executor for each run.

      Args:
          program(Program|CompiledProgram): the program that needs to be run.
          dataset(paddle.base.Dataset): dataset created outside this function,
              a user should provide a well-defined dataset before calling this function.
              Please check the document of Dataset if needed.
          scope(Scope): the scope used to run this program, you can switch it to different scope
              for each run. default is global_scope
          thread(int): number of threads a user wants to run in this function. Default is 0, which
              means using the thread num of dataset
          debug(bool): whether to run in debug mode, default is False
          fetch_list(Tensor List): fetch Tensor list, each variable will be printed
              during training
          fetch_info(String List): print information for each Tensor, its length should be equal
              to fetch_list
          print_period(int): the number of mini-batches for each print, default is 100
          fetch_handler(FetchHandler): a user defined class for fetch output.

      Returns:
          Whatever the internal dataset run returns: None in the common case,
          or the fetched values when the program runs in pipeline mode and
          fetches at least one variable.

      Examples:

          .. code-block:: python

              >>> import paddle

              >>> paddle.enable_static()
              >>> place = paddle.CPUPlace() # you can set place = paddle.CUDAPlace(0) to use gpu
              >>> exe = paddle.static.Executor(place)
              >>> x = paddle.static.data(name="x", shape=[None, 10, 10], dtype="int64")
              >>> y = paddle.static.data(name="y", shape=[None, 1], dtype="int64", lod_level=1)
              >>> dataset = paddle.base.DatasetFactory().create_dataset()
              >>> dataset.set_use_var([x, y])
              >>> dataset.set_thread(1)
              >>> # you should set your own filelist, e.g. filelist = ["dataA.txt"]
              >>> filelist = []
              >>> dataset.set_filelist(filelist)
              >>> exe.run(paddle.static.default_startup_program())
              >>> exe.train_from_dataset(program=paddle.static.default_main_program(),
              ...                        dataset=dataset)
      """
      return self._run_from_dataset(
          program=program,
          dataset=dataset,
          scope=scope,
          thread=thread,
          is_infer=False,
          debug=debug,
          fetch_list=fetch_list,
          fetch_info=fetch_info,
          print_period=print_period,
          fetch_handler=fetch_handler,
      )