compiler.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216
  1. # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import sys
  15. from . import core, framework
  16. from .framework import cpu_places, cuda_places, xpu_places
# ``__all__`` is empty, so ``from <this module> import *`` exports no names.
__all__ = []

# Module-level aliases for classes provided by the ``core`` extension module.
BuildStrategy = core.CompiledProgram.BuildStrategy
InferNativeConfig = core.NativeConfig
InferAnalysisConfig = core.AnalysisConfig
DeviceType = core.DeviceType
  22. def _place_obj(place):
  23. p = core.Place()
  24. p.set_place(place)
  25. return p
  26. def _is_pserver_mode(main_program):
  27. main = main_program if main_program else framework.default_main_program()
  28. for op in main.global_block().ops:
  29. if op.type in ["send", "recv"]:
  30. return True
  31. return False
  32. def _has_backward_op(graph):
  33. for node in graph.nodes():
  34. if (
  35. node.is_op()
  36. and node.op() is not None
  37. and node.op().type().endswith("_grad")
  38. ):
  39. return True
  40. return False
  41. def _prune_feed_ops(program):
  42. # prune the feed ops in the program.
  43. pop_idx = []
  44. for i, op in enumerate(program.global_block().ops):
  45. if op.type == "feed":
  46. pop_idx.append(i)
  47. for index in pop_idx[::-1]:
  48. program.global_block()._remove_op(index)
  49. def _has_optimize_op(block):
  50. for op in block.ops:
  51. op_maker = core.op_proto_and_checker_maker
  52. optimize = core.op_proto_and_checker_maker.OpRole.Optimize
  53. if op_maker.kOpRoleVarAttrName() in op.attr_names and int(
  54. op.all_attrs()[op_maker.kOpRoleAttrName()]
  55. ) == int(optimize):
  56. return True
  57. return False
  58. def _should_broadcast_or_not_exists(program, var_name):
  59. block = program.global_block()
  60. var = block.vars.get(var_name, None)
  61. if var is None:
  62. return True
  63. is_distributed = getattr(var, '_is_distributed', False) or getattr(
  64. var, 'is_distributed', False
  65. )
  66. return not is_distributed
  67. class CompiledProgram:
  68. """
  69. :api_attr: Static Graph
  70. The CompiledProgram is used to transform a program or graph for
  71. various optimizations according to the configuration of build_strategy,
  72. for example, the operators' fusion in the computation graph, memory
  73. optimization during the execution of the computation graph, etc.
  74. For more information about build_strategy, please refer to
  75. :code:`paddle.static.BuildStrategy`.
  76. Args:
  77. program_or_graph (Graph|Program): This argument is the Program or Graph
  78. being executed.
  79. build_strategy(BuildStrategy): This argument is used to compile the
  80. program or graph with the specified options, such as operators' fusion
  81. in the computational graph and memory optimization during the execution
  82. of the computational graph. For more information about build_strategy,
  83. please refer to :code:`paddle.static.BuildStrategy`. The default is None.
  84. Returns:
  85. CompiledProgram
  86. Example:
  87. .. code-block:: python
  88. >>> import numpy
  89. >>> import paddle
  90. >>> import paddle.static as static
  91. >>> paddle.enable_static()
  92. >>> place = paddle.CPUPlace()
  93. >>> exe = static.Executor(place)
  94. >>> data = static.data(name='X', shape=[None, 1], dtype='float32')
  95. >>> hidden = static.nn.fc(x=data, size=10)
  96. >>> loss = paddle.mean(hidden)
  97. >>> paddle.optimizer.SGD(learning_rate=0.01).minimize(loss)
  98. >>> exe.run(static.default_startup_program())
  99. >>> compiled_prog = static.CompiledProgram(
  100. ... static.default_main_program())
  101. >>> x = numpy.random.random(size=(10, 1)).astype('float32')
  102. >>> loss_data, = exe.run(compiled_prog,
  103. ... feed={"X": x},
  104. ... fetch_list=[loss.name])
  105. """
  106. def __init__(self, program_or_graph, build_strategy=None):
  107. if isinstance(program_or_graph, core.Graph):
  108. self._graph = program_or_graph
  109. # don't not create a new program here.
  110. self._program = None
  111. elif isinstance(program_or_graph, framework.Program):
  112. _prune_feed_ops(program_or_graph)
  113. self._graph = core.Graph(program_or_graph.desc)
  114. self._program = program_or_graph
  115. else:
  116. raise TypeError(
  117. "The type of program_to_graph parameter is wrong, expected Graph or Program, but received %s"
  118. % type(program_or_graph)
  119. )
  120. self._scope = None
  121. self._place = None
  122. self._executor = None
  123. self._compiled = False
  124. self._is_inference = False
  125. self._share_vars_from = None
  126. self._places = None
  127. self._build_strategy = build_strategy
  128. def _with_inference_optimize(self, config):
  129. """Add inference optimize
  130. Args:
  131. config: instance of `NativeConfig` or `AnalysisConfig` to create predictor
  132. Returns:
  133. self
  134. """
  135. assert (
  136. not self._is_inference
  137. ), "Already compiled with inference, cannot be recompiled."
  138. assert any(
  139. [
  140. isinstance(config, InferNativeConfig),
  141. isinstance(config, InferAnalysisConfig),
  142. ]
  143. )
  144. self._is_inference = True
  145. self._infer_config = config
  146. return self
  147. def _with_distributed(self):
  148. raise NotImplementedError(
  149. "Subclass of CompiledProgram should implement _with_distributed method."
  150. )
  151. def _compile_data_parallel(self, places, use_device, scope=None):
  152. if self._share_vars_from:
  153. if scope:
  154. sys.stderr.write("share_vars_from is set, scope is ignored.\n")
  155. if self._share_vars_from._executor is None:
  156. raise ValueError(
  157. "The shared Program is not compiled and executed, so there is no "
  158. "variables to share."
  159. )
  160. self._local_scopes = self._share_vars_from._executor.local_scopes()
  161. else:
  162. assert scope is not None, ""
  163. self._local_scopes = []
  164. assert isinstance(
  165. places, (list, tuple)
  166. ), f"Currently, The places type can only be list or tuple, but the input type is {type(places)}."
  167. if self._build_strategy is None:
  168. self._build_strategy = BuildStrategy()
  169. self._build_strategy.is_distribution = _is_pserver_mode(self._program)
  170. # TODO(wuyi): trainer endpoints should be passed in through
  171. # build_strategy, not program.xxx.
  172. # TODO(gongwb): let user to set them once.
  173. if (
  174. self._program
  175. and self._build_strategy.num_trainers > 1
  176. and self._program._trainers_endpoints
  177. ):
  178. tps = self._program._trainers_endpoints
  179. assert self._build_strategy.num_trainers == len(
  180. tps
  181. ), "The trainer numbers is not equal to endpoint numbers."
  182. self._build_strategy.trainers_endpoints = tps
  183. if self._program:
  184. self._build_strategy.nccl_comm_num = self._program._nccl_comm_num
  185. self._build_strategy.use_hierarchical_allreduce = (
  186. self._program._use_hierarchical_allreduce
  187. )
  188. self._build_strategy.hierarchical_allreduce_inter_nranks = (
  189. self._program._hierarchical_allreduce_inter_nranks
  190. )
  191. if self._build_strategy.sync_batch_norm:
  192. self._build_strategy.enable_sequential_execution = True
  193. if self._program is not None and self._program._enable_dgc:
  194. assert (
  195. self._build_strategy.num_trainers * len(places) > 1
  196. ), "DGC is not available for single card training."
  197. assert (
  198. self._build_strategy.reduce_strategy
  199. == BuildStrategy.ReduceStrategy.AllReduce
  200. ), "DGC \
  201. only can be used for AllReduce BuildStrategy."
  202. # DGC doesn't support fuse for now, close fuse.
  203. self._build_strategy.fuse_all_reduce_ops = False
  204. self._persistable_vars = []
  205. for node in self._graph.nodes():
  206. if (
  207. node.is_var()
  208. and node.var() is not None
  209. and node.var().persistable()
  210. and node.var().type() != core.VarDesc.VarType.RAW
  211. ):
  212. name = node.name()
  213. if (
  214. self._program is not None
  215. and _should_broadcast_or_not_exists(self._program, name)
  216. ):
  217. self._persistable_vars.append(node.name())
  218. places = list(map(_place_obj, places))
  219. # ParallelExecutor would broadcast all the parameters during initializing.
  220. # The parameters of each process should be in the same ordered for the data-parallelism
  221. # distributed training to keep the broadcast correct.
  222. self._persistable_vars = list(set(self._persistable_vars))
  223. self._persistable_vars.sort()
  224. if core.is_cuda_graph_capturing():
  225. raise RuntimeError(
  226. "CUDA Graph is not allowed to capture when running the first batch."
  227. )
  228. return core.CompiledProgram(
  229. places,
  230. self._persistable_vars,
  231. '',
  232. self._scope,
  233. self._local_scopes,
  234. self._build_strategy,
  235. self._graph,
  236. )
  237. def _compile_inference(self):
  238. return core.create_paddle_predictor(self._infer_config)
  239. def _compile(self, scope, place):
  240. """Compile the program based on the configs.
  241. Args:
  242. scope: The variables (resources) that are associated with
  243. this compiled program.
  244. place: The location that the compiled program will be run on.
  245. Returns:
  246. self
  247. """
  248. if self._compiled:
  249. if scope and self._scope != scope:
  250. raise ValueError("Cannot compile program with different scope.")
  251. if place and not self._place._equals(place):
  252. raise ValueError("Cannot compile program with different place.")
  253. return self
  254. self._compiled = True
  255. self._scope = scope
  256. self._place = place
  257. if self._is_inference:
  258. self._executor = self._compile_inference()
  259. else:
  260. self._places = [self._place]
  261. if isinstance(self._place, core.CUDAPlace):
  262. use_device = DeviceType.CUDA
  263. elif isinstance(self._place, core.XPUPlace):
  264. use_device = DeviceType.XPU
  265. else:
  266. use_device = DeviceType.CPU
  267. self._executor = self._compile_data_parallel(
  268. use_device=use_device, scope=self._scope, places=self._places
  269. )
  270. return self
  271. def _get_places(self, place, place_list):
  272. has_set_place = place_list is not None
  273. if has_set_place:
  274. for p in place_list:
  275. assert (
  276. p._type() == place._type()
  277. ), "Place type not match. You may set wrong type of places."
  278. else:
  279. if isinstance(place, core.CUDAPlace):
  280. place_list = cuda_places()
  281. elif isinstance(place, core.XPUPlace):
  282. place_list = xpu_places()
  283. else:
  284. place_list = cpu_places()
  285. assert place_list, "No places for execution."
  286. return place_list
  287. class IpuDynamicPatcher:
  288. """
  289. Patcher for IPU dynamic2static support.
  290. """
  291. patcher_cache = []
  292. def __init__(self):
  293. pass
  294. @staticmethod
  295. def convert_concrete_program(
  296. ipu_strategy, concrete_program, class_instance=None
  297. ):
  298. """
  299. Convert the ConcreteProgram to IPUConcreteProgram.
  300. """
  301. import paddle
  302. from ..base import backward
  303. from ..base.dygraph.base import switch_to_static_graph
  304. from ..base.framework import device_guard
  305. inputs = concrete_program.inputs
  306. outputs = concrete_program.outputs
  307. startup_program = concrete_program.startup_program
  308. scope = paddle.static.global_scope()
  309. @switch_to_static_graph
  310. def append_backward_desc():
  311. program = concrete_program.main_program
  312. # backward with optimizer to add backward graph to program
  313. backward.gradients_with_optimizer(program, ipu_strategy._optimizer)
  314. # initialize backward parameters
  315. exe = paddle.static.Executor(paddle.CPUPlace())
  316. startup_program = paddle.static.default_startup_program()
  317. exe.run(startup_program)
  318. return program
  319. if ipu_strategy.enable_fp16:
  320. class_instance.to(dtype="float16")
  321. # copy the bias and filters
  322. for param_or_buffer in concrete_program.parameters:
  323. param_or_buffer_tensor = scope.var(
  324. param_or_buffer.name
  325. ).get_tensor()
  326. src_tensor = param_or_buffer.value().get_tensor()
  327. param_or_buffer_tensor._share_data_with(src_tensor)
  328. # TODO(czr): feed and fetch list needs to consider more type
  329. if class_instance:
  330. feed_list = [elem.name for elem in inputs[1:] if elem is not None]
  331. else:
  332. feed_list = [elem.name for elem in inputs if elem is not None]
  333. fetch_list = [elem.name for elem in outputs]
  334. if ipu_strategy.is_training:
  335. concrete_program.main_program = append_backward_desc()
  336. # copy optimizer parameters
  337. optimizer = ipu_strategy._optimizer
  338. for k, v in optimizer._accumulators.items():
  339. for param_name, var_tmp in v.items():
  340. var = optimizer.helper.create_global_variable(
  341. name=var_tmp.name,
  342. persistable=True,
  343. dtype=var_tmp.dtype,
  344. type=var_tmp.type,
  345. shape=var_tmp.shape,
  346. belong_to_optimizer=True,
  347. )
  348. device = optimizer._get_device_for_param(param_name)
  349. with device_guard(device):
  350. optimizer.helper.set_variable_initializer(
  351. var,
  352. initializer=paddle.nn.initializer.Constant(
  353. value=0.0
  354. ),
  355. )
  356. param_or_lr_tensor = scope.find_var(
  357. var_tmp.name
  358. ).get_tensor()
  359. optim_tensor = var.value().get_tensor()
  360. param_or_lr_tensor._share_data_with(optim_tensor)
  361. optimizer._accumulators[k][param_name] = var
  362. @switch_to_static_graph
  363. def func_compile():
  364. if ipu_strategy.enable_fp16:
  365. amp_list = paddle.static.amp.CustomOpLists()
  366. amp_list.unsupported_list = {"cumsum"}
  367. to_fp16_var_names = paddle.static.amp.cast_model_to_fp16(
  368. concrete_program.main_program,
  369. amp_list,
  370. use_fp16_guard=False,
  371. )
  372. paddle.static.amp.cast_parameters_to_fp16(
  373. paddle.CPUPlace(),
  374. concrete_program.main_program,
  375. to_fp16_var_names=to_fp16_var_names,
  376. )
  377. program = IpuCompiledProgram(
  378. concrete_program.main_program,
  379. ipu_strategy=ipu_strategy,
  380. scope=scope,
  381. ).compile(feed_list, fetch_list)
  382. return program
  383. main_program = func_compile()
  384. concrete_program.main_program = main_program
  385. return concrete_program
  386. @staticmethod
  387. def patch_program_cache(ipu_strategy):
  388. """Monkey patch ProgramCache descriptor to support dynamic2static in IPU.
  389. Args:
  390. ipu_strategy: The ipu_strategy used in dynamic graph.
  391. Returns:
  392. None
  393. """
  394. from paddle.jit.dy2static import logging_utils
  395. from paddle.jit.dy2static.partial_program import partial_program_from
  396. from paddle.jit.dy2static.program_translator import (
  397. MAX_TRACED_PROGRAM_COUNT,
  398. CacheKey,
  399. ProgramCache,
  400. )
  401. old_getter = ProgramCache.__getitem__
  402. def patch_getter(self, item):
  403. if not isinstance(item, CacheKey):
  404. raise ValueError(
  405. 'type(item) should be CacheKey, but received %s'
  406. % type(item).__name__
  407. )
  408. item_id = hash(item)
  409. self._recent_key = item_id
  410. if item_id not in self._caches or ipu_strategy.need_compile:
  411. if item_id in self._caches:
  412. logging_utils.warn(
  413. "ipu_strategy chances detected. Please sync weights."
  414. )
  415. if self._caches and not ipu_strategy.need_compile:
  416. logging_utils.warn(
  417. "dynamic2static on IPU doesn't support multiple caches. Please make sure"
  418. "dynamic inputs is not used."
  419. )
  420. concrete_program, _ = self._build_once(item)
  421. concrete_program = IpuDynamicPatcher.convert_concrete_program(
  422. ipu_strategy, concrete_program, item.class_instance
  423. )
  424. self._caches[item_id] = (
  425. concrete_program,
  426. partial_program_from(
  427. concrete_program, item.class_instance is not None
  428. ),
  429. )
  430. # Note: raise warnings if number of traced program is more than `max_tracing_count`
  431. current_tracing_count = len(self._caches)
  432. if current_tracing_count > MAX_TRACED_PROGRAM_COUNT:
  433. logging_utils.warn(
  434. f"Current traced program number: {current_tracing_count} > `max_tracing_count`:{MAX_TRACED_PROGRAM_COUNT}. Too much cached programs will bring expensive overhead. "
  435. "The reason may be: (1) passing tensors with different shapes, (2) passing python objects instead of tensors."
  436. )
  437. return self._caches[item_id]
  438. ProgramCache.__getitem__ = patch_getter
  439. IpuDynamicPatcher.patcher_cache.append(
  440. [ProgramCache, '__getitem__', old_getter]
  441. )
  442. @staticmethod
  443. def patch_lr_scheduler(ipu_strategy):
  444. from paddle.optimizer.lr import LRScheduler
  445. # For IPU dynamic graph usage, lr_var is not synced in executor as static graph mode do.
  446. # Manually set lr to ipu_strategy to update the lr.
  447. old_step = LRScheduler.step
  448. def patch_step(self, epoch=None):
  449. old_step(self, epoch)
  450. ipu_strategy.set_options({"lr": self.last_lr})
  451. LRScheduler.step = patch_step
  452. IpuDynamicPatcher.patcher_cache.append([LRScheduler, 'step', old_step])
  453. @staticmethod
  454. def register_patch(ipu_strategy):
  455. IpuDynamicPatcher.patch_program_cache(ipu_strategy)
  456. IpuDynamicPatcher.patch_lr_scheduler(ipu_strategy)
  457. @staticmethod
  458. def release_patch():
  459. for module, key, attr in IpuDynamicPatcher.patcher_cache:
  460. setattr(module, key, attr)
  461. class IpuStrategy:
  462. """
  463. Help users precisely control the graph building in :code:`paddle.static.IpuCompiledProgram` .
  464. Returns:
  465. The IpuStrategy instance.
  466. Examples:
  467. .. code-block:: python
  468. >>> # doctest: +REQUIRES(env:IPU)
  469. >>> import paddle
  470. >>> import paddle.static as static
  471. >>> paddle.enable_static()
  472. >>> ipu_strategy = static.IpuStrategy()
  473. """
  474. def __init__(self):
  475. if core.is_compiled_with_ipu():
  476. self._ipu_strategy = core.IpuStrategy()
  477. default_options = {
  478. 'location_optimizer': {
  479. 'on_chip': 0,
  480. 'use_replicated_tensor_sharding': 1,
  481. }, # set optimizer location
  482. 'accumulation_and_replication_reduction_type': 1, # popart::ReductionType::Mean
  483. 'mean_accumulation_and_replication_reduction_strategy': 1, # popart::MeanReductionStrategy::Post
  484. }
  485. self._ipu_strategy.set_options(default_options)
  486. self.has_custom_ops = False
  487. self.custom_op_names = []
  488. self.need_compile = True
  489. else:
  490. raise RuntimeError(
  491. "Can not use IpuStrategy in non IPU compiled environment, please re-compile with WITH_IPU=ON."
  492. )
  493. from paddle import in_dynamic_mode
  494. if in_dynamic_mode():
  495. self.register_patch()
    def register_patch(self):
        """
        Register patch function to support dynamic to static on IPU. This operation would break the dy2static functionality on CPU.
        Use `release_patch` to release the patch.

        This delegates to ``IpuDynamicPatcher.register_patch``, passing this
        strategy as the argument.

        Returns:
            None.

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:IPU)

                >>> import paddle
                >>> import paddle.static as static

                >>> ipu_strategy = static.IpuStrategy()

                >>> ipu_strategy.register_patch()
        """
        IpuDynamicPatcher.register_patch(self)
    def release_patch(self):
        """
        Release the registered IPU functions.

        This delegates to ``IpuDynamicPatcher.release_patch``, which takes no
        arguments, so the release is not specific to this strategy instance.

        Returns:
            None.

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:IPU)

                >>> import paddle
                >>> import paddle.static as static

                >>> ipu_strategy = static.IpuStrategy()

                >>> ipu_strategy.release_patch()
        """
        IpuDynamicPatcher.release_patch()
  521. def set_optimizer(self, optimizer):
  522. """
  523. Set optimizer to ipu_strategy in dynamic mode.
  524. Args:
  525. optimizer (Optimizer): Optimizer to be used in training.
  526. Returns:
  527. None.
  528. Examples:
  529. .. code-block:: python
  530. >>> # doctest: +REQUIRES(env:IPU)
  531. >>> import paddle
  532. >>> import paddle.static as static
  533. >>> linear = paddle.nn.Linear(10, 10)
  534. >>> optimizer = paddle.optimizer.SGD(learning_rate=0.01,
  535. ... parameters=linear.parameters())
  536. >>> ipu_strategy = static.IpuStrategy()
  537. >>> ipu_strategy.set_optimizer(optimizer)
  538. """
  539. from paddle import in_dynamic_mode
  540. if in_dynamic_mode():
  541. self._optimizer = optimizer
  542. optimizer_attrs = self.parse_optimizer(optimizer)
  543. self._ipu_strategy.set_options(optimizer_attrs)
  544. else:
  545. raise RuntimeError("Only needs to set optimizer in dynamic mode.")
  546. def parse_optimizer(self, optimizer):
  547. """
  548. Parse optimizer attributes for IPU dynamic to static support. Currently only support parse lr.
  549. Args:
  550. optimizer (Optimizer): Optimizer to be parsed.
  551. Returns:
  552. Dict.
  553. Examples:
  554. .. code-block:: python
  555. >>> # doctest: +REQUIRES(env:IPU)
  556. >>> import paddle
  557. >>> import paddle.static as static
  558. >>> linear = paddle.nn.Linear(10, 10)
  559. >>> optimizer = paddle.optimizer.SGD(learning_rate=0.01,
  560. ... parameters=linear.parameters())
  561. >>> ipu_strategy = static.IpuStrategy()
  562. >>> attrs = ipu_strategy.parse_optimizer(optimizer)
  563. """
  564. def get_lr():
  565. from paddle.optimizer.lr import LRScheduler
  566. if isinstance(optimizer._learning_rate, float):
  567. return {"lr": optimizer._learning_rate}
  568. elif isinstance(optimizer._learning_rate, LRScheduler):
  569. return {"lr": optimizer._learning_rate()}
  570. attr_fn = [get_lr]
  571. optimizer_attrs = {"is_dynamic": True}
  572. for fn in attr_fn:
  573. optimizer_attrs.update(fn())
  574. return optimizer_attrs
  575. def set_graph_config(
  576. self,
  577. num_ipus=1,
  578. is_training=True,
  579. micro_batch_size=1,
  580. enable_manual_shard=False,
  581. ):
  582. """
  583. Set graph configuration to the IpuStrategy instance.
  584. Args:
  585. num_ipus (int, optional): Number of IPU devices. Default 1, which means only use 1 IPU.
  586. is_training (bool, optional): True is training graph, False is inference graph. Default True, which means is training mode.
  587. batch_size (int, optional): The batch-size in the graph. Used to make the graph batch-size fixed,
  588. if the batch-size in the graph is dynamic. Default 1, which means the batch-size would be set 1, if the batch-size is dynamic.
  589. enable_manual_shard (bool, optional): Enable graph sharding or not. Only if num_ipus > 1, enable_manual_shard is able to be set True.
  590. Default False, which means disabled.
  591. Returns:
  592. None.
  593. Examples:
  594. .. code-block:: python
  595. >>> # doctest: +REQUIRES(env:IPU)
  596. >>> import paddle
  597. >>> import paddle.static as static
  598. >>> paddle.enable_static()
  599. >>> ipu_strategy = static.IpuStrategy()
  600. >>> ipu_strategy.set_graph_config(num_ipus=1,
  601. ... is_training=True,
  602. ... micro_batch_size=1,
  603. ... enable_manual_shard=False)
  604. """
  605. if num_ipus == 1 and enable_manual_shard:
  606. raise RuntimeError(
  607. "Only if num_ipus > 1, enable_manual_shard is able to be set True."
  608. )
  609. options = {
  610. 'num_ipus': num_ipus,
  611. 'is_training': is_training,
  612. 'micro_batch_size': micro_batch_size,
  613. 'enable_manual_shard': enable_manual_shard,
  614. }
  615. self.set_options(options)
  616. def set_pipelining_config(
  617. self,
  618. enable_pipelining=False,
  619. batches_per_step=1,
  620. enable_gradient_accumulation=False,
  621. accumulation_factor=1,
  622. ):
  623. """
  624. Set pipelining configuration to the IpuStrategy instance. Used to optimize the throughput performance.
  625. Args:
  626. enable_pipelining (bool, optional): Enable data pipelining between subgraphs. Only if enable_manual_shard=True, enable_pipelining is able to be set True.
  627. Default False, which means disabled.
  628. batches_per_step (int, optional): Set the batches per run in data pipelining mode. Only if enable_pipelining=True, batches_per_step is able to be set > 1.
  629. Default 1, which means no data pipelining.
  630. enable_gradient_accumulation (bool, optional): Enable to accumulate gradients before updating the weights in training mode. Only if enable_pipelining=True,
  631. enable_gradient_accumulation is able to be set True. Default False, which means no gradient accumulation.
  632. accumulation_factor (int, optional): Specify the number of micro-batches to accumulate
  633. before applying the varUpdate. Default 1, which means disable the accumulation.
  634. Returns:
  635. None.
  636. Examples:
  637. .. code-block:: python
  638. >>> # doctest: +REQUIRES(env:IPU)
  639. >>> import paddle
  640. >>> import paddle.static as static
  641. >>> paddle.enable_static()
  642. >>> ipu_strategy = static.IpuStrategy()
  643. >>> ipu_strategy.set_pipelining_config(enable_pipelining=False,
  644. ... batches_per_step=1,
  645. ... enable_gradient_accumulation=False,
  646. ... accumulation_factor=1)
  647. """
  648. enable_manual_shard = self.get_option('enable_manual_shard')
  649. if not enable_manual_shard and enable_pipelining:
  650. raise RuntimeError(
  651. "Only if enable_manual_shard=True, enable_pipelining is able to be set True."
  652. )
  653. options = {
  654. 'enable_pipelining': enable_pipelining,
  655. 'batches_per_step': batches_per_step,
  656. 'enable_gradient_accumulation': enable_gradient_accumulation,
  657. 'accumulation_factor': accumulation_factor,
  658. }
  659. self.set_options(options)
  660. def set_precision_config(self, enable_fp16=False):
  661. """
  662. Set half computation configuration to the IpuStrategy instance. Used to optimize the performance.
  663. Args:
  664. enable_fp16 (bool, optional): Enable FLOAT16 mode and transform FLOAT32 to FLOAT16. Default False, which means disable FLOAT16 mode.
  665. Returns:
  666. None.
  667. Examples:
  668. .. code-block:: python
  669. >>> # doctest: +REQUIRES(env:IPU)
  670. >>> import paddle
  671. >>> import paddle.static as static
  672. >>> paddle.enable_static()
  673. >>> ipu_strategy = static.IpuStrategy()
  674. >>> ipu_strategy.set_precision_config(enable_fp16=False)
  675. """
  676. options = {
  677. 'enable_fp16': enable_fp16,
  678. }
  679. self.set_options(options)
  680. def add_custom_op(
  681. self, paddle_op, popart_op=None, domain='custom.ops', version=1
  682. ):
  683. """
  684. Add a mapping to use popart custom ops running on the IPU.
  685. Args:
  686. paddle_op(str): the name of custom op in paddle.
  687. popart_op(str): the name of custom op in popart.
  688. domain(str): domain name of custom op in popart.
  689. version(int): version of custom op in popart.
  690. Returns:
  691. None.
  692. Examples:
  693. .. code-block:: python
  694. >>> # doctest: +REQUIRES(env:IPU)
  695. >>> import paddle
  696. >>> import paddle.static as static
  697. >>> paddle.enable_static()
  698. >>> ipu_strategy = static.IpuStrategy()
  699. >>> ipu_strategy.add_custom_op('paddle_relu', 'popart_relu')
  700. """
  701. if popart_op is None:
  702. popart_op = paddle_op
  703. custom_op = {
  704. 'paddle_op': paddle_op,
  705. 'popart_op': popart_op,
  706. 'domain': domain,
  707. 'version': version,
  708. }
  709. self.set_options({'custom_op': custom_op})
  710. self.custom_op_names.append(paddle_op)
  711. if not self.has_custom_ops:
  712. self.has_custom_ops = True
  def set_options(self, options):
      """
      Apply a dict of options to the underlying strategy object.

      Args:
          options(dict): dict of options.

      Returns:
          None.

      Examples:
          .. code-block:: python

              >>> # doctest: +REQUIRES(env:IPU)
              >>> import paddle
              >>> import paddle.static as static

              >>> paddle.enable_static()

              >>> ipu_strategy = static.IpuStrategy()
              >>> options = {'num_ipus':1, 'enable_fp16': True}
              >>> ipu_strategy.set_options(options)
      """
      self._ipu_strategy.set_options(options)

      # Keys in this set may change without flagging a recompile; any other
      # key in `options` marks the strategy as needing compilation.
      no_recompile_keys = {'lr'}
      if not options.keys() <= no_recompile_keys:
          self.need_compile = True
  def get_option(self, option):
      """
      Look up a single option by name.

      Args:
          option(str): name of option.

      Returns:
          The ``'value'`` entry of the record returned by the underlying
          strategy object for this option.

      Examples:
          .. code-block:: python

              >>> # doctest: +REQUIRES(env:IPU)
              >>> import paddle
              >>> import paddle.static as static

              >>> paddle.enable_static()

              >>> ipu_strategy = static.IpuStrategy()
              >>> num_ipus = ipu_strategy.get_option('num_ipus')
      """
      record = self._ipu_strategy.get_option(option)
      return record['value']
  def enable_pattern(self, pattern):
      """
      Turn on a PopART pattern used to optimize the graph.

      Args:
          pattern(string): the name of the pattern.

      Returns:
          None.

      Examples:
          .. code-block:: python

              >>> # doctest: +REQUIRES(env:IPU)
              >>> import paddle
              >>> import paddle.static as static

              >>> paddle.enable_static()

              >>> ipu_strategy = static.IpuStrategy()
              >>> ipu_strategy.enable_pattern("ViewSimplifyPattern")
      """
      # Delegate to the wrapped strategy object; its return value is discarded.
      wrapped_strategy = self._ipu_strategy
      wrapped_strategy.enable_pattern(pattern)
  def disable_pattern(self, pattern):
      """
      Turn off a PopART pattern.

      Args:
          pattern(string): the name of the pattern.

      Returns:
          None.

      Examples:
          .. code-block:: python

              >>> # doctest: +REQUIRES(env:IPU)
              >>> import paddle
              >>> import paddle.static as static

              >>> paddle.enable_static()

              >>> ipu_strategy = static.IpuStrategy()
              >>> ipu_strategy.disable_pattern("ViewSimplifyPattern")
      """
      # Delegate to the wrapped strategy object; its return value is discarded.
      wrapped_strategy = self._ipu_strategy
      wrapped_strategy.disable_pattern(pattern)
  @property
  def num_ipus(self):
      """
      The number of IPU devices configured on this IpuStrategy instance,
      read from the ``'num_ipus'`` option.
      """
      option_name = 'num_ipus'
      return self.get_option(option_name)
  @property
  def is_training(self):
      """
      Whether this IpuStrategy instance is set up for training (as opposed
      to inference), read from the ``'is_training'`` option.
      """
      option_name = 'is_training'
      return self.get_option(option_name)
  @property
  def enable_pipelining(self):
      """
      Whether pipelining is enabled on this IpuStrategy instance, read from
      the ``'enable_pipelining'`` option.
      """
      option_name = 'enable_pipelining'
      return self.get_option(option_name)
  @property
  def enable_fp16(self):
      """
      Whether float16 mode is enabled on this IpuStrategy instance, read
      from the ``'enable_fp16'`` option.
      """
      option_name = 'enable_fp16'
      return self.get_option(option_name)
  class IpuCompiledProgram:
      """
      The IpuCompiledProgram is used to transform a program to a ipu-target program,
      such as forward graph extraction, computing graph transformation, useless scale Ops clean, etc.

      Args:
          program(Program, optional): This parameter represents the :code:`Program`
              to be executed. Default is None, which means the program will be set to
              the default program :code:`paddle.static.default_main_program()` .
          scope(Scope, optional): The scope used to run this program, you can switch
              it to different scope. Default is None, which means use the global
              scope :code:`paddle.static.global_scope()` .
          ipu_strategy(IpuStrategy, optional): This argument is used to build the program with the
              specified options, such as half computation, training or inference session, the number of IPUs, etc.
              Default is None, which means build the program based on the default `ipu_strategy`.

      Returns:
          IpuCompiledProgram

      Raises:
          ValueError: If PaddlePaddle is not compiled with IPU support.
          TypeError: If ``program`` is not a ``Program`` instance.

      Example:
          .. code-block:: python

              >>> # doctest: +REQUIRES(env:IPU)
              >>> import paddle
              >>> import paddle.static as static

              >>> paddle.enable_static()

              >>> a = static.data(name='data', shape=[None, 1], dtype='int32')
              >>> b = a + 1
              >>> main_prog = static.default_main_program()

              >>> ipu_strategy = static.IpuStrategy()
              >>> ipu_strategy.set_graph_config(num_ipus=1, is_training=True, micro_batch_size=1)
              >>> ipu_strategy.set_pipelining_config(enable_pipelining=False, batches_per_step=1, enable_gradient_accumulation=False, accumulation_factor=1)
              >>> ipu_strategy.set_precision_config(enable_fp16=False)

              >>> ipu_compiled_program = static.IpuCompiledProgram(
              ...     main_prog,
              ...     ipu_strategy=ipu_strategy)
      """

      def __init__(self, program=None, scope=None, ipu_strategy=None):
          if not core.is_compiled_with_ipu():
              raise ValueError(
                  "Can not use this function since PaddlePaddle is not compiled with IPU"
              )

          if program is None:
              program = framework.default_main_program()

          if not isinstance(program, framework.Program):
              raise TypeError(
                  "The type of program is wrong, expected Program, but got %s"
                  % type(program)
              )

          self._program = program
          self._compiled = False

          if scope is not None:
              self._scope = scope
          else:
              # Imported lazily, only when the global scope is actually needed.
              import paddle

              self._scope = paddle.static.global_scope()

          if ipu_strategy is not None:
              self._ipu_strategy = ipu_strategy
          else:
              self._ipu_strategy = IpuStrategy()

          # Read custom-op info from the resolved strategy, not from the raw
          # argument: the argument is None when the caller relies on the
          # default strategy, and dereferencing it raised AttributeError.
          if self._ipu_strategy.has_custom_ops:
              self._custom_op_names = set(self._ipu_strategy.custom_op_names)
          else:
              self._custom_op_names = ()

          self._backend = core.IpuBackend.get_instance()

      def compile(self, feed_list, fetch_list):
          """
          This interface is used to compile the input Program to a program
          to run the model on the ipu.

          Note that the input program is modified in place: its ``feed`` and
          ``fetch`` ops and variables are removed from the global block before
          the graph passes are applied.

          Args:
              feed_list(list): This parameter represents the input Tensors of the model.
                  The docstring example passes variable names.
              fetch_list(list): This parameter represents the Tensors that need to be returned
                  after the model. The docstring example passes variable names.

          Returns:
              Program

          Example:
              .. code-block:: python

                  >>> # doctest: +REQUIRES(env:IPU)
                  >>> import paddle
                  >>> import paddle.static as static

                  >>> paddle.enable_static()

                  >>> a = static.data(name='data', shape=[None, 1], dtype='int32')
                  >>> b = a + 1
                  >>> main_prog = static.default_main_program()

                  >>> ipu_strategy = static.IpuStrategy()
                  >>> ipu_strategy.set_graph_config(num_ipus=1, is_training=True, micro_batch_size=1)
                  >>> ipu_strategy.set_pipelining_config(enable_pipelining=False, batches_per_step=1, enable_gradient_accumulation=False, accumulation_factor=1)
                  >>> ipu_strategy.set_precision_config(enable_fp16=False)

                  >>> program = static.IpuCompiledProgram(
                  ...     main_prog,
                  ...     ipu_strategy=ipu_strategy).compile([a.name], [b.name])
          """
          self._backend.set_scope(self._scope)
          self._backend.set_ipu_strategy(self._ipu_strategy._ipu_strategy)

          # feed and fetch don't have a corresponding popart op, so remove both here
          global_block = self._program.global_block()
          need_to_remove_op_index = []
          for i, op in enumerate(global_block.ops):
              op.desc.set_is_target(False)
              if op.type == 'feed' or op.type == 'fetch':
                  need_to_remove_op_index.append(i)

          # Remove from the back so the earlier recorded indices stay valid.
          for index in need_to_remove_op_index[::-1]:
              global_block._remove_op(index)

          for var in ['feed', 'fetch']:
              if global_block.has_var(var):
                  global_block._remove_var(var)

          self._program.desc.flush()
          self._graph = core.Graph(self._program.desc)

          # Optimizer-related passes are applied only for training.
          if self._ipu_strategy.is_training:
              passes = [
                  'optimizer_extract_pass',
                  'optimizer_state_align_pass',
              ]
              for pass_name in passes:
                  a_pass = core.get_pass(pass_name)
                  a_pass.apply(self._graph)

          passes = [
              'forward_graph_extract_pass',
              'infer_shape_pass',
              'avg_shard_pass',
              'delete_scale_op_pass',
          ]
          for pass_name in passes:
              a_pass = core.get_pass(pass_name)
              # Only the shape-inference pass is given the feed list here.
              if pass_name == 'infer_shape_pass':
                  a_pass.set('feed_list', feed_list)
              a_pass.apply(self._graph)

          a_pass = core.get_pass('popart_canonicalization_pass')
          if self._custom_op_names:
              a_pass.set('custom_ops', self._custom_op_names)
          a_pass.apply(self._graph)

          passes = [
              'ipu_inplace_pass',
              'ipu_graph_builder_pass',
              'ipu_runtime_replacer_pass',
          ]
          for pass_name in passes:
              a_pass = core.get_pass(pass_name)
              a_pass.set('feed_list', feed_list)
              a_pass.set('fetch_list', fetch_list)
              a_pass.apply(self._graph)

          # Convert the transformed graph back into a fresh Program.
          convert_pass = core.get_pass('graph_to_program_pass')
          desc = core.ProgramDesc()
          convert_pass.set_not_owned('program', desc)
          convert_pass.apply(self._graph)
          program = framework.Program._construct_from_desc(desc)

          if hasattr(self._program, 'lr_scheduler'):
              # how to share var between two different block ?
              lr_var_name = self._program.lr_scheduler._var_name
              program.lr_scheduler = self._program.lr_scheduler
              # Program.clone will clone lr_scheduler, so lr_var is set as
              # an lr_scheduler attribute
              global_block = self._program.global_block()
              program.lr_scheduler.lr_var = global_block.vars[lr_var_name]

          # with popart, we need to support batches_per_step, which means
          # the shape of feed_var and feed_tensor(maybe numpy array) will
          # mismatch, so we set need_check_feed to False. Thus we can avoid
          # modifying the logic of run.
          program_global_block = program.global_block()
          for feed_name in feed_list:
              feed_var = program_global_block.var(feed_name)
              feed_var.desc.set_need_check_feed(False)

          # Keep a reference to the source program on the compiled one.
          if not hasattr(program, 'org_program'):
              program.org_program = self._program

          self._ipu_strategy.need_compile = False

          return program