fleet.py 67 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import os
  16. import time
  17. import paddle
  18. from paddle.base import compiler
  19. from paddle.base.wrapped_decorator import wrap_decorator
  20. from paddle.framework import _global_flags, in_dynamic_mode
  21. from paddle.framework.ir import apply_build_strategy
  22. from .base import topology as tp
  23. from .base.distributed_strategy import DistributedStrategy
  24. from .base.meta_optimizer_factory import MetaOptimizerFactory
  25. from .base.role_maker import PaddleCloudRoleMaker, RoleMakerBase
  26. from .base.runtime_factory import RuntimeFactory
  27. from .base.strategy_compiler import StrategyCompiler
  28. from .meta_parallel import model_parallel_random_seed
  29. from .utils.log_util import logger, set_log_level
# Empty on purpose: ``from <this module> import *`` exports no names.
__all__ = []
  31. def apply_ir_passes(main_program, startup_program, config):
  32. build_strategy = config._user_defined_strategy.build_strategy._copy()
  33. if not _global_flags()['FLAGS_apply_pass_to_program']:
  34. return build_strategy
  35. pipeline_opt = getattr(main_program, "_pipeline_opt", {})
  36. if pipeline_opt:
  37. main_program = pipeline_opt["section_program"]
  38. startup_program = startup_program._pipeline_opt["startup_program"]
  39. pass_attrs = {"use_cuda": config._is_collective}
  40. fuse_all_reduce = config._user_defined_strategy.fuse_all_reduce_ops
  41. if fuse_all_reduce and build_strategy.fuse_all_optimizer_ops:
  42. # FIXME(zjl): currently, fuse_all_optimizer_ops
  43. # have conflict with fuse_all_reduce_ops because
  44. # RawProgramOptimizer also inserts coalesce_tensor
  45. # into program. These two procedures may conflict
  46. # in which vars are to be fused.
  47. logger.warning(
  48. 'Currently, the fuse_all_optimizer_ops pass has conflict with fuse_all_reduce_ops pass. Disable the fuse_all_optimizer_ops pass temporarily.'
  49. )
  50. build_strategy.fuse_all_optimizer_ops = False
  51. return apply_build_strategy(
  52. main_program, startup_program, build_strategy, pass_attrs
  53. )
  54. def _inited_runtime_handler_(func):
  55. def __impl__(*args, **kwargs):
  56. cls = args[0]
  57. if cls._runtime_handle is None:
  58. raise ValueError("Fleet can not find suitable runtime handler")
  59. return func(*args, **kwargs)
  60. return __impl__
  61. def _is_non_distributed_check_(func):
  62. def __impl__(*args, **kwargs):
  63. cls = args[0]
  64. if (
  65. cls._role_maker is not None
  66. and cls._role_maker._is_non_distributed() is True
  67. ):
  68. logger.warning(
  69. "%s() function doesn't work when use non_distributed fleet."
  70. % (func.__name__)
  71. )
  72. return
  73. return func(*args, **kwargs)
  74. return __impl__
# Decorator forms of _inited_runtime_handler_ / _is_non_distributed_check_,
# produced by paddle's ``wrap_decorator``.
inited_runtime_handler = wrap_decorator(_inited_runtime_handler_)
is_non_distributed_check = wrap_decorator(_is_non_distributed_check_)
class Fleet:
    """
    Unified API for distributed training of PaddlePaddle.
    Please reference the https://github.com/PaddlePaddle/PaddleFleetX for details


    Returns:
        Fleet: A Fleet instance

    Examples:
        .. code-block:: python
            :name: code-example1

            >>> # Example1: for collective training
            >>> import paddle
            >>> paddle.enable_static()
            >>> import paddle.distributed.fleet as fleet

            >>> fleet.init(is_collective=True)

            >>> strategy = fleet.DistributedStrategy()
            >>> linear = paddle.nn.Linear(10, 10)
            >>> optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=linear.parameters())
            >>> optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)

            >>> # do distributed training

        .. code-block:: python
            :name: code-example2

            >>> # Example2: for parameter server training
            >>> import paddle
            >>> paddle.enable_static()
            >>> import paddle.distributed.fleet as fleet
            >>> strategy = fleet.DistributedStrategy()
            >>> fleet.init(strategy=strategy)

            >>> optimizer = paddle.optimizer.SGD(learning_rate=0.001)
            >>> optimizer = fleet.distributed_optimizer(optimizer)

            >>> if fleet.is_first_worker():
            ...     print("this is first worker")

            >>> print("current node index: {}".format(fleet.worker_index()))
            >>> print("total number of worker num: {}".format(fleet.worker_num()))

            >>> if fleet.is_worker():
            ...     print("this is worker")
            >>> print("worker endpoints: {}".format(fleet.worker_endpoints(to_string=True)))

            >>> print("server num: {}".format(fleet.server_num()))
            >>> print("server endpoints: {}".format(fleet.server_endpoints(to_string=True)))

            >>> if fleet.is_server():
            ...     print("this is server")
            >>> fleet.stop_worker()
    """
  119. def __init__(self):
  120. self._role_maker = None
  121. self.strategy_compiler = None
  122. self._is_collective = False
  123. self._runtime_handle = None
  124. self._util = None
  125. self._context = {}
  126. self.user_defined_optimizer = paddle.optimizer.Optimizer(0.0)
  127. def init(
  128. self,
  129. role_maker=None,
  130. is_collective=False,
  131. strategy=None,
  132. log_level="INFO",
  133. ):
  134. """
  135. Initialize role_maker in Fleet.
  136. This function is responsible for the distributed architecture
  137. what you want to run your code behind.
  138. Args:
  139. role_maker (RoleMakerBase, optional): A ``RoleMakerBase`` containing the configuration
  140. of environment variables related to distributed training.If you did not initialize
  141. the rolemaker by yourself, it will be automatically initialized to PaddleRoleMaker.
  142. The default value is None.
  143. is_collective (Boolean, optional): A ``Boolean`` variable determines whether the program
  144. runs on Collective mode or ParameterServer mode. True means the program runs on
  145. Collective mode, and False means running on ParameterServer mode. The default value
  146. is False.
  147. strategy (DistributedStrategy): Extra properties for distributed training.
  148. For details, please refer to paddle.distributed.fleet.DistributedStrategy. Default: None.
  149. log_level (Integer, String, optional): A ``Integer`` or ``String`` Variable determining how hight
  150. the logging level is. Default is "INFO".
  151. Returns:
  152. None
  153. Examples:
  154. .. code-block:: python
  155. :name: code-init-example1
  156. >>> import paddle.distributed.fleet as fleet
  157. >>> fleet.init()
  158. .. code-block:: python
  159. :name: code-init-example2
  160. >>> import paddle.distributed.fleet as fleet
  161. >>> fleet.init(is_collective=True)
  162. .. code-block:: python
  163. :name: code-init-example3
  164. >>> import paddle.distributed.fleet as fleet
  165. >>> role = fleet.PaddleCloudRoleMaker()
  166. >>> fleet.init(role)
  167. .. code-block:: python
  168. :name: code-init-example4
  169. >>> import paddle.distributed.fleet as fleet
  170. >>> strategy = fleet.DistributedStrategy()
  171. >>> fleet.init(strategy=strategy)
  172. .. code-block:: python
  173. :name: code-init-example5
  174. >>> import paddle.distributed.fleet as fleet
  175. >>> strategy = fleet.DistributedStrategy()
  176. >>> fleet.init(log_level = "DEBUG")
  177. """
  178. from paddle.distributed import parallel_helper
  179. set_log_level(log_level)
  180. if strategy is None:
  181. strategy = DistributedStrategy()
  182. self._user_defined_strategy = copy.deepcopy(strategy)
  183. if role_maker is None:
  184. if isinstance(is_collective, bool):
  185. self._is_collective = is_collective
  186. self._role_maker = PaddleCloudRoleMaker(
  187. is_collective=self._is_collective
  188. )
  189. else:
  190. raise ValueError(
  191. f"`is_collective` should be instance of `bool`, but got {type(is_collective)}"
  192. )
  193. else:
  194. if isinstance(role_maker, RoleMakerBase):
  195. self._role_maker = role_maker
  196. self._is_collective = role_maker._is_collective
  197. else:
  198. raise ValueError(
  199. f"`role_maker` should be subclass of `RoleMakerBase`, but got {type(role_maker)}"
  200. )
  201. self._role_maker._generate_role()
  202. from paddle.distributed import fleet
  203. fleet.util._set_role_maker(self._role_maker)
  204. self.strategy_compiler = StrategyCompiler()
  205. if in_dynamic_mode():
  206. if parallel_helper._is_parallel_ctx_initialized():
  207. logger.warning(
  208. "The dygraph parallel environment has been initialized."
  209. )
  210. else:
  211. # FLAGS_nccl_nrings is used for dynamic graph multi-stream communication
  212. if "FLAGS_nccl_nrings" in os.environ:
  213. logger.warning(
  214. "You have set the environment variable FLAGS_nccl_nrings "
  215. "outside the program, so the nccl_comm_num in "
  216. "DistributedStrategy will not take effect here."
  217. )
  218. else:
  219. os.environ["FLAGS_nccl_nrings"] = str(
  220. self._user_defined_strategy.nccl_comm_num
  221. )
  222. paddle.distributed.init_parallel_env()
  223. # hybrid parallel not support for npu/xpu
  224. if not self._user_defined_strategy.heter_ccl_mode:
  225. # init hybrid parallel environment in dygraph
  226. if tp._HYBRID_PARALLEL_GROUP is None:
  227. self._init_hybrid_parallel_env()
  228. else:
  229. logger.warning(
  230. "The dygraph hybrid parallel environment has been initialized."
  231. )
  232. elif self._is_collective:
  233. use_sharding = self._user_defined_strategy.sharding
  234. # global group
  235. global_rank = self.worker_index()
  236. global_world_size = self.worker_num()
  237. # NOTE(wangxi): see sharding_optimizer
  238. global_ring_id = 3 if use_sharding else 0
  239. global_ranks = list(range(global_world_size))
  240. if tp._HYBRID_PARALLEL_GROUP is None:
  241. tp._CommunicateGroup()
  242. cg = tp._HYBRID_PARALLEL_GROUP
  243. self._hcg = cg
  244. cg.set_comm_group(
  245. 'global',
  246. global_rank,
  247. global_world_size,
  248. global_ring_id,
  249. global_ranks,
  250. )
  251. use_tensor_parallel = self._user_defined_strategy.tensor_parallel
  252. use_mp = use_sharding or use_tensor_parallel
  253. # hybrid group
  254. if use_mp is False:
  255. return
  256. mp_degree_sharding = 1
  257. mp_degree_tensor_parallel = 1
  258. if use_sharding:
  259. sharding_configs = self._user_defined_strategy.sharding_configs
  260. mp_degree_sharding = int(sharding_configs['mp_degree'])
  261. if use_tensor_parallel:
  262. tensor_parallel_configs = (
  263. self._user_defined_strategy.tensor_parallel_configs
  264. )
  265. mp_degree_tensor_parallel = int(
  266. tensor_parallel_configs['tensor_parallel_degree']
  267. )
  268. if use_sharding and use_tensor_parallel:
  269. assert mp_degree_sharding == mp_degree_tensor_parallel
  270. mp_degree = (
  271. mp_degree_sharding
  272. if use_sharding
  273. else mp_degree_tensor_parallel
  274. )
  275. if mp_degree > 1:
  276. assert global_world_size % mp_degree == 0
  277. # NOTE(wangxi): mp_ring_id sync with sharding_optimizer.py _build_groups
  278. mp_ring_id = 0
  279. mp_rank = global_rank % mp_degree
  280. mp_group_id = global_rank // mp_degree
  281. mp_group_ranks = [
  282. idx
  283. for idx in global_ranks
  284. if idx // mp_degree == mp_group_id
  285. ]
  286. cg.set_comm_group(
  287. 'model', mp_rank, mp_degree, mp_ring_id, mp_group_ranks
  288. )
  289. return self
  290. # test allreduce perf
  291. def allreduce_perf(
  292. self,
  293. iteration,
  294. x,
  295. group,
  296. perf_size,
  297. perf_threshold_time,
  298. warmup=False,
  299. ):
  300. if group is None or group.nranks <= 1:
  301. logger.warning("allreduce_perf is invalid, group invalid!")
  302. return
  303. paddle.distributed.barrier()
  304. paddle.device.cuda.synchronize()
  305. start_t = time.time()
  306. for _ in range(iteration):
  307. paddle.distributed.all_reduce(x, group=group)
  308. paddle.device.cuda.synchronize()
  309. end_t = time.time()
  310. ret = (end_t - start_t) / iteration
  311. if warmup:
  312. return
  313. logger.info(
  314. f"[AllReduceTest] nbytes {perf_size}B test result: {ret} s/iter"
  315. )
  316. if perf_threshold_time > -1 and ret > perf_threshold_time:
  317. logger.warning(
  318. f"[Perf Warning] AllReduce Test Timeout! {ret} > {perf_threshold_time}"
  319. )
  320. # test reduce perf
  321. def reduce_perf(self, iteration, x, group, perf_size, perf_threshold_time):
  322. if group is None or group.nranks <= 1:
  323. logger.warning("reduce_perf is invalid, group invalid!")
  324. return
  325. paddle.distributed.barrier()
  326. paddle.device.cuda.synchronize()
  327. start_t = time.time()
  328. for _ in range(iteration):
  329. paddle.distributed.reduce(x, dst=min(group.ranks), group=group)
  330. paddle.device.cuda.synchronize()
  331. end_t = time.time()
  332. ret = (end_t - start_t) / iteration
  333. logger.info(
  334. f"[ReduceTest] nbytes {perf_size}B test result: {ret} s/iter"
  335. )
  336. if perf_threshold_time > -1 and ret > perf_threshold_time:
  337. logger.warning(
  338. f"[Perf Warning] Reduce Test Timeout! {ret} > {perf_threshold_time}"
  339. )
  340. # test broadcast perf
  341. def broadcast_perf(
  342. self, iteration, x, group, perf_size, perf_threshold_time
  343. ):
  344. if group is None or group.nranks <= 1:
  345. logger.warning("broadcast_perf is invalid, group invalid!")
  346. return
  347. paddle.distributed.barrier()
  348. paddle.device.cuda.synchronize()
  349. start_t = time.time()
  350. for _ in range(iteration):
  351. paddle.distributed.broadcast(x, src=min(group.ranks), group=group)
  352. paddle.device.cuda.synchronize()
  353. end_t = time.time()
  354. ret = (end_t - start_t) / iteration
  355. logger.info(
  356. f"[BroadcastTest] nbytes {perf_size}B test result: {ret} s/iter"
  357. )
  358. if perf_threshold_time > -1 and ret > perf_threshold_time:
  359. logger.warning(
  360. f"[Perf Warning] Broadcast Test Timeout! {ret} > {perf_threshold_time}"
  361. )
  362. # test allgather perf
  363. def allgather_perf(
  364. self, iteration, x, group, perf_size, perf_threshold_time
  365. ):
  366. if group is None or group.nranks <= 1:
  367. logger.warning("allgather_perf is invalid, group invalid!")
  368. return
  369. paddle.distributed.barrier()
  370. paddle.device.cuda.synchronize()
  371. start_t = time.time()
  372. for _ in range(iteration):
  373. tmp = []
  374. paddle.distributed.all_gather(tmp, x, group=group)
  375. paddle.device.cuda.synchronize()
  376. end_t = time.time()
  377. ret = (end_t - start_t) / iteration
  378. logger.info(
  379. f"[AllgatherTest] nbytes {perf_size}B test result: {ret} s/iter"
  380. )
  381. if perf_threshold_time > -1 and ret > perf_threshold_time:
  382. logger.warning(
  383. f"[Perf Warning] Allgather Test Timeout! {ret} > {perf_threshold_time}"
  384. )
  385. # test reduce_scatter perf
  386. def reduce_scatter_perf(
  387. self,
  388. iteration,
  389. x,
  390. group,
  391. perf_size,
  392. perf_threshold_time,
  393. ):
  394. if group is None or group.nranks <= 1:
  395. logger.warning("reduce_scatter_perf is invalid, group invalid!")
  396. return
  397. paddle.distributed.barrier()
  398. paddle.device.cuda.synchronize()
  399. parallelism = group.nranks
  400. output_shape = x.shape
  401. if x.shape[0] % parallelism != 0:
  402. logger.warning(
  403. f"the shape of input[{x.shape[0]}] can't be divided exactly by reduce_scatter parallelism[{parallelism}], test stopped!"
  404. )
  405. return
  406. output_shape[0] = output_shape[0] // parallelism
  407. output = paddle.empty(shape=output_shape, dtype=x.dtype)
  408. start_t = time.time()
  409. for _ in range(iteration):
  410. paddle.distributed.stream.reduce_scatter(
  411. output,
  412. x,
  413. op=paddle.distributed.ReduceOp.SUM,
  414. group=group,
  415. sync_op=True,
  416. )
  417. paddle.device.cuda.synchronize()
  418. end_t = time.time()
  419. ret = (end_t - start_t) / iteration
  420. logger.info(
  421. f"[ReduceScatterTest] nbytes {perf_size}B test result: {ret} s/iter"
  422. )
  423. if perf_threshold_time > -1 and ret > perf_threshold_time:
  424. logger.warning(
  425. f"[Perf Warning] ReduceScatter Test Timeout! {ret} > {perf_threshold_time}"
  426. )
  427. def _collective_perf_impl(self, round=50, context={}, hcg=None):
  428. if hcg is None:
  429. hcg = self.get_hybrid_communicate_group()
  430. collective_perf_func_map = {
  431. "allreduce": self.allreduce_perf,
  432. "reduce": self.reduce_perf,
  433. "broadcast": self.broadcast_perf,
  434. "allgather": self.allgather_perf,
  435. "reduce_scatter": self.reduce_scatter_perf,
  436. }
  437. dp_group = hcg.get_data_parallel_group()
  438. sharding_group = hcg.get_sharding_parallel_group()
  439. mp_group = hcg.get_model_parallel_group()
  440. data_group = None
  441. if dp_group.nranks > 1:
  442. data_group = dp_group
  443. elif sharding_group.nranks > 1:
  444. data_group = sharding_group
  445. collective_perf_group_map = {
  446. "allreduce": data_group,
  447. "reduce": data_group,
  448. "broadcast": data_group,
  449. "allgather": mp_group,
  450. "reduce_scatter": mp_group,
  451. }
  452. for comm_type, size_and_time in context.items():
  453. # test 1M ~ 1G as default
  454. nbytes = 1 << 20 # 1048576(1MB)
  455. final_nbytes = 1 << 30 # 1073741824(1GB)
  456. dtype = paddle.float32
  457. time_threshold = 0
  458. if size_and_time is not None:
  459. nbytes = size_and_time[0]
  460. # Run only once when test specific message size.
  461. final_nbytes = nbytes
  462. time_threshold = size_and_time[1]
  463. if nbytes <= 0:
  464. logger.warning(
  465. f"Size for collective performance check should be positive, but got {nbytes}"
  466. )
  467. return
  468. while nbytes <= final_nbytes:
  469. x = paddle.zeros([nbytes // 4], dtype=dtype)
  470. # warmup
  471. self.allreduce_perf(10, x, None, nbytes, 1, warmup=True)
  472. collective_perf_func_map[comm_type](
  473. iteration=round,
  474. x=x,
  475. group=collective_perf_group_map[comm_type],
  476. perf_size=nbytes,
  477. perf_threshold_time=time_threshold,
  478. )
  479. nbytes = nbytes << 1
  480. def collective_perf(self, comm_type, round=50, size_and_time={}):
  481. """
  482. Run performance test for given communication type
  483. and compare the time cost with the threshold.
  484. Args:
  485. comm_type (str): Communication type for performance test. Currently support
  486. "allreduce", "broadcast", "reduce", "allgather" and "reduce_scatter".
  487. round (int, optional): Loop times for performance test. More loops will cost more time
  488. and provide more accurate result. Defaults to 50.
  489. size_and_time (dict, optional): Message sizes and time thresholds for performance test.
  490. each pair will invoke a performance check. Defaults to {}, which indicates
  491. acting performance check from 1MB to 1GB without threshold set.
  492. Returns:
  493. None
  494. Examples:
  495. .. code-block:: python
  496. >>> import paddle.distributed.fleet as fleet
  497. >>> fleet.init(is_collective=True)
  498. >>> # run two tests, one with 1MB (threshold 0.5s) and another with 1GB (threshold 1s)
  499. >>> size_and_time = {1<<20: 0.5, 1<<30: 1}
  500. >>> fleet.collective_perf("allreduce", round=50, size_and_time = size_and_time)
  501. """
  502. if not self._is_collective:
  503. logger.warning(
  504. "fleet.collective_perf is only for collective mode, will return with no test acted."
  505. )
  506. return
  507. for size, time_threshold in size_and_time.items():
  508. context = {comm_type: [size, time_threshold]}
  509. self._collective_perf_impl(round=round, context=context)
  510. def _init_hybrid_parallel_env(self):
  511. """initialize the hybrid environment."""
  512. self.hybrid_configs = self._user_defined_strategy.hybrid_configs
  513. self.dp_degree = self.hybrid_configs["dp_degree"]
  514. self.mp_degree = self.hybrid_configs["mp_degree"]
  515. self.pp_degree = self.hybrid_configs["pp_degree"]
  516. self.sep_degree = self.hybrid_configs["sep_degree"]
  517. self.sharding_degree = self.hybrid_configs["sharding_degree"]
  518. assert self.mp_degree >= 0, "mp_degree should be greater or equal to 0"
  519. assert self.pp_degree >= 0, "pp_degree should be greater or equal to 0"
  520. assert (
  521. self.sep_degree >= 0
  522. ), "sep_degree should be greater or equal to 0"
  523. assert (
  524. self.sharding_degree >= 0
  525. ), "sharding_degree should be greater or equal to 0"
  526. self.mp_degree = max(self.mp_degree, 1)
  527. self.pp_degree = max(self.pp_degree, 1)
  528. self.sep_degree = max(self.sep_degree, 1)
  529. if self.dp_degree < 0:
  530. nranks = paddle.distributed.get_world_size()
  531. self.dp_degree = nranks // (self.mp_degree * self.pp_degree)
  532. self.dp_degree = max(self.dp_degree, 1)
  533. d_hybrid_degree = {
  534. "dp": ["data", self.dp_degree],
  535. "pp": ['pipe', self.pp_degree],
  536. "sharding": ['sharding', self.sharding_degree],
  537. "mp": ['model', self.mp_degree],
  538. "sep": ["sep", self.sep_degree],
  539. }
  540. order = self._user_defined_strategy.hybrid_parallel_order
  541. if order[:].sort() != list(d_hybrid_degree.keys())[:].sort():
  542. raise AssertionError(
  543. 'The order of hybrid_config setting is incorrect.'
  544. )
  545. hybrid_group_names = []
  546. dims = []
  547. for h_name in order:
  548. name, degree = d_hybrid_degree[h_name]
  549. hybrid_group_names.append(name)
  550. dims.append(degree)
  551. self._topology = tp.CommunicateTopology(
  552. hybrid_group_names=hybrid_group_names, dims=dims
  553. )
  554. self._hcg = tp.HybridCommunicateGroup(self._topology)
  555. if self.mp_degree > 1:
  556. tensor_parallel_configs = (
  557. self._user_defined_strategy.tensor_parallel_configs
  558. )
  559. tensor_init_seed = tensor_parallel_configs["tensor_init_seed"]
  560. if tensor_init_seed == -1:
  561. model_parallel_random_seed()
  562. else:
  563. model_parallel_random_seed(tensor_init_seed)
  564. def get_hybrid_communicate_group(self):
  565. assert self._hcg is not None
  566. return self._hcg
  567. def get_hybrid_parallel_topology(self):
  568. assert self._topology is not None
  569. return self._topology
  570. def is_first_worker(self):
  571. """
  572. Check whether the node is the first instance of worker.
  573. Returns:
  574. bool: True if this is the first node of worker, False if not.
  575. Examples:
  576. .. code-block:: python
  577. >>> import paddle.distributed.fleet as fleet
  578. >>> fleet.init()
  579. >>> fleet.is_first_worker()
  580. """
  581. return self._role_maker._is_first_worker()
  582. def worker_index(self):
  583. """
  584. Get current worker index.
  585. Returns:
  586. int: node id
  587. Examples:
  588. .. code-block:: python
  589. >>> import paddle.distributed.fleet as fleet
  590. >>> fleet.init()
  591. >>> fleet.worker_index()
  592. """
  593. return self._role_maker._worker_index()
  594. def worker_num(self):
  595. """
  596. Get current total worker number.
  597. Returns:
  598. int: worker numbers
  599. Examples:
  600. .. code-block:: python
  601. >>> import paddle.distributed.fleet as fleet
  602. >>> fleet.init()
  603. >>> fleet.worker_num()
  604. """
  605. return self._role_maker._worker_num()
  606. def node_num(self):
  607. return self._role_maker._get_node_num()
  608. def local_rank(self):
  609. return self._role_maker._get_local_rank()
  610. def local_device_ids(self):
  611. return self._role_maker._get_local_device_ids()
  612. def world_device_ids(self):
  613. return self._role_maker._get_world_device_ids()
  614. def is_worker(self):
  615. """
  616. Check whether the node is an instance of worker.
  617. Returns:
  618. bool: True if this is a node of worker,
  619. False if not.
  620. Examples:
  621. .. code-block:: python
  622. >>> import paddle.distributed.fleet as fleet
  623. >>> fleet.init()
  624. >>> fleet.is_worker()
  625. """
  626. return self._role_maker._is_worker()
  627. def is_coordinator(self):
  628. return self._role_maker._is_coordinator()
  629. def worker_endpoints(self, to_string=False):
  630. """
  631. Get current worker endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].
  632. Returns:
  633. list/string: server endpoints
  634. Examples:
  635. .. code-block:: python
  636. >>> import paddle.distributed.fleet as fleet
  637. >>> fleet.init()
  638. >>> fleet.worker_endpoints()
  639. """
  640. if to_string:
  641. return ",".join(self._role_maker._get_trainer_endpoints())
  642. else:
  643. return self._role_maker._get_trainer_endpoints()
  644. def server_num(self):
  645. """
  646. Get current total worker number.
  647. Returns:
  648. int: server number
  649. Examples:
  650. .. code-block:: python
  651. >>> import paddle.distributed.fleet as fleet
  652. >>> fleet.init()
  653. >>> fleet.server_num()
  654. """
  655. return len(self._role_maker._get_pserver_endpoints())
  656. def server_index(self):
  657. """
  658. Get current server index.
  659. Returns:
  660. int: node id
  661. Examples:
  662. .. code-block:: python
  663. >>> import paddle.distributed.fleet as fleet
  664. >>> fleet.init()
  665. >>> fleet.server_index()
  666. """
  667. return self._role_maker._server_index()
  668. def server_endpoints(self, to_string=False):
  669. """
  670. Get current server endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].
  671. Returns:
  672. list/string: server endpoints
  673. Examples:
  674. .. code-block:: python
  675. >>> import paddle.distributed.fleet as fleet
  676. >>> fleet.init()
  677. >>> fleet.server_endpoints()
  678. """
  679. if to_string:
  680. return ",".join(self._role_maker._get_pserver_endpoints())
  681. else:
  682. return self._role_maker._get_pserver_endpoints()
  683. def is_server(self):
  684. """
  685. Check whether the node is an instance of server.
  686. Returns:
  687. bool: True if this is a node of server,
  688. False if not.
  689. Examples:
  690. .. code-block:: python
  691. >>> import paddle.distributed.fleet as fleet
  692. >>> fleet.init()
  693. >>> fleet.is_server()
  694. """
  695. return self._role_maker._is_server()
  696. def barrier_worker(self):
  697. """
  698. barrier all workers
  699. Returns:
  700. None
  701. Examples:
  702. .. code-block:: python
  703. >>> import paddle.distributed.fleet as fleet
  704. >>> fleet.init()
  705. >>> fleet.barrier_worker()
  706. """
  707. self._role_maker._barrier("worker")
  708. def all_reduce(self, input, mode="sum"):
  709. """
  710. all reduce input between all workers, mode can be sum, mean or max, default is sum
  711. Returns:
  712. list/int: all reduce result
  713. Examples:
  714. .. code-block:: python
  715. >>> import paddle.distributed.fleet as fleet
  716. >>> fleet.init()
  717. >>> res = fleet.all_reduce(5)
  718. """
  719. return self._role_maker._all_reduce(input, mode, "worker")
  720. @is_non_distributed_check
  721. @inited_runtime_handler
  722. def init_worker(self, scopes=None):
  723. """
  724. initialize `Communicator` for parameter server training.
  725. Returns:
  726. None
  727. Examples:
  728. .. code-block:: python
  729. >>> import paddle.distributed.fleet as fleet
  730. >>> fleet.init()
  731. >>> # build net
  732. >>> # fleet.distributed_optimizer(...)
  733. >>> fleet.init_worker()
  734. """
  735. self._runtime_handle._init_worker(scopes)
  736. @is_non_distributed_check
  737. @inited_runtime_handler
  738. def init_coordinator(self, scopes=None):
  739. """
  740. initialize coordinator node
  741. """
  742. self._runtime_handle._init_coordinator(scopes)
  743. def make_fl_strategy(self):
  744. self._runtime_handle._make_fl_strategy()
  745. @is_non_distributed_check
  746. @inited_runtime_handler
  747. def get_fl_client(self):
  748. """
  749. get worker(training node) ptr
  750. """
  751. return self._runtime_handle._worker
  752. @is_non_distributed_check
  753. @inited_runtime_handler
  754. def init_server(self, *args, **kwargs):
  755. """
  756. init_server executor to initialize startup program,
  757. if the `args` is not empty, it will run load_persistables for increment training.
  758. Returns:
  759. None
  760. Examples:
  761. .. code-block:: python
  762. >>> import paddle.distributed.fleet as fleet
  763. >>> fleet.init()
  764. >>> # build net
  765. >>> # fleet.distributed_optimizer(...)
  766. >>> fleet.init_server()
  767. """
  768. self._runtime_handle._init_server(*args, **kwargs)
  769. @is_non_distributed_check
  770. @inited_runtime_handler
  771. def load_model(self, path, mode):
  772. """
  773. load fleet model from path
  774. Returns:
  775. None
  776. Examples:
  777. .. code-block:: python
  778. >>> import paddle.distributed.fleet as fleet
  779. >>> fleet.init()
  780. >>> # build net
  781. >>> # fleet.distributed_optimizer(...)
  782. >>> fleet.load_model("path", mode=0)
  783. """
  784. self._runtime_handle._load_persistables(path, mode)
  785. @is_non_distributed_check
  786. @inited_runtime_handler
  787. def load_one_table(self, table_id, path, mode):
  788. """
  789. load fleet one table from path
  790. Returns:
  791. None
  792. Examples:
  793. .. code-block:: python
  794. >>> import paddle.distributed.fleet as fleet
  795. >>> fleet.init()
  796. >>> # build net
  797. >>> # fleet.distributed_optimizer(...)
  798. >>> fleet.load_one_table(0, "path", mode=0)
  799. """
  800. self._runtime_handle._load_one_table(table_id, path, mode)
  801. @is_non_distributed_check
  802. @inited_runtime_handler
  803. def load_inference_model(self, path, mode):
  804. """
  805. load fleet inference model from path
  806. Returns:
  807. None
  808. Examples:
  809. .. code-block:: python
  810. >>> import paddle.distributed.fleet as fleet
  811. >>> fleet.init()
  812. >>> # build net
  813. >>> # fleet.distributed_optimizer(...)
  814. >>> fleet.load_inference_model("path", mode=1)
  815. """
  816. self._runtime_handle._load_inference_model(path, mode)
  817. @is_non_distributed_check
  818. @inited_runtime_handler
  819. def run_server(self):
  820. """
  821. run server will run pserver main program with executor.
  822. Returns:
  823. None
  824. Examples:
  825. .. code-block:: python
  826. >>> import paddle.distributed.fleet as fleet
  827. >>> fleet.init()
  828. >>> # build net
  829. >>> # fleet.distributed_optimizer(...)
  830. >>> if fleet.is_server():
  831. ... fleet.init_server()
  832. """
  833. self._runtime_handle._run_server()
  834. @is_non_distributed_check
  835. @inited_runtime_handler
  836. def stop_worker(self):
  837. """
  838. stop `Communicator` and give training complete notice to parameter server.
  839. Returns:
  840. None
  841. Examples:
  842. .. code-block:: python
  843. >>> import paddle.distributed.fleet as fleet
  844. >>> fleet.init()
  845. >>> # build net
  846. >>> # fleet.distributed_optimizer(...)
  847. >>> fleet.init_server()
  848. """
  849. self._runtime_handle._stop_worker()
  850. @is_non_distributed_check
  851. @inited_runtime_handler
  852. def save(self, dirname, feed=[], fetch=[], **configs):
  853. inference = True
  854. if not feed and not fetch:
  855. inference = False
  856. place = paddle.CPUPlace()
  857. executor = paddle.static.Executor(place)
  858. if inference:
  859. feeded_var_names = []
  860. fetch_var_names = []
  861. for var in feed:
  862. if isinstance(var, str):
  863. feeded_var_names.append(var)
  864. elif isinstance(var, paddle.static.Variable):
  865. feeded_var_names.append(var.name)
  866. else:
  867. raise ValueError("feed must be [str|Variable]")
  868. for var in fetch:
  869. if isinstance(var, str):
  870. fetch_var_names.append(var)
  871. elif isinstance(var, paddle.static.Variable):
  872. fetch_var_names.append(var.name)
  873. else:
  874. raise ValueError("feed must be [str|Variable]")
  875. fetch_vars = [
  876. paddle.static.default_main_program().global_block().var(name)
  877. for name in fetch_var_names
  878. ]
  879. self._runtime_handle._save_inference_model(
  880. executor, dirname, feeded_var_names, fetch_vars, None, True, 0
  881. )
  882. else:
  883. increment_mode = 0
  884. if "mode" in configs:
  885. increment_mode = int(configs["mode"])
  886. self._runtime_handle._save_persistables(
  887. executor, dirname, main_program=None, mode=increment_mode
  888. )
  889. @is_non_distributed_check
  890. @inited_runtime_handler
  891. def save_inference_model(
  892. self,
  893. executor,
  894. dirname,
  895. feeded_var_names,
  896. target_vars,
  897. main_program=None,
  898. export_for_deployment=True,
  899. mode=0,
  900. ):
  901. """
  902. save inference model for inference.
  903. Returns:
  904. None
  905. Examples:
  906. .. code-block:: python
  907. >>> import paddle.distributed.fleet as fleet
  908. >>> fleet.init()
  909. >>> # build net
  910. >>> # fleet.distributed_optimizer(...)
  911. >>> fleet.init_server()
  912. """
  913. self._runtime_handle._save_inference_model(
  914. executor,
  915. dirname,
  916. feeded_var_names,
  917. target_vars,
  918. main_program,
  919. export_for_deployment,
  920. mode,
  921. )
  922. @is_non_distributed_check
  923. @inited_runtime_handler
  924. def save_persistables(self, executor, dirname, main_program=None, mode=0):
  925. """
  926. saves all persistable tensors from :code:`main_program` to
  927. the folder :code:`dirname`. You can refer to
  928. The :code:`dirname` is used to specify the folder where persistable tensors
  929. are going to be saved. If you would like to save tensors in separate
  930. files, set :code:`filename` None.
  931. Args:
  932. executor(Executor): The executor to run for saving persistable tensors.
  933. You can refer to :ref:`api_guide_executor_en` for
  934. more details.
  935. dirname(str, optional): The saving directory path.
  936. When you need to save the parameter to the memory, set it to None.
  937. main_program(Program, optional): The program whose persistable tensors will
  938. be saved. Default: None.
  939. Returns:
  940. None
  941. Examples:
  942. .. code-block:: python
  943. >>> import paddle
  944. >>> paddle.enable_static()
  945. >>> import paddle.distributed.fleet as fleet
  946. >>> fleet.init()
  947. >>> # build net
  948. >>> # fleet.distributed_optimizer(...)
  949. >>> exe = paddle.static.Executor(paddle.CPUPlace())
  950. >>> fleet.save_persistables(exe, "dirname", paddle.static.default_main_program())
  951. """
  952. self._runtime_handle._save_persistables(
  953. executor, dirname, main_program, mode
  954. )
  955. @is_non_distributed_check
  956. @inited_runtime_handler
  957. def save_cache_model(self, dirname, **configs):
  958. return self._runtime_handle._save_cache_model(dirname, **configs)
  959. @is_non_distributed_check
  960. @inited_runtime_handler
  961. def check_save_pre_patch_done(self):
  962. return self._runtime_handle._check_save_pre_patch_done()
  963. @is_non_distributed_check
  964. @inited_runtime_handler
  965. def save_cache_table(
  966. self, table_id, pass_id, mem_cache_key_threshold=4000000000
  967. ):
  968. return self._runtime_handle._save_cache_table(
  969. table_id, pass_id, mem_cache_key_threshold
  970. )
  971. @is_non_distributed_check
  972. @inited_runtime_handler
  973. def save_one_table(self, table_id, path, mode):
  974. """
  975. save fleet one table from path
  976. Returns:
  977. None
  978. Examples:
  979. .. code-block:: python
  980. >>> import paddle.distributed.fleet as fleet
  981. >>> fleet.init()
  982. >>> # build net
  983. >>> # fleet.distributed_optimizer(...)
  984. >>> fleet.save_one_table(0, "path", mode=0)
  985. """
  986. self._runtime_handle._save_one_table(table_id, path, mode)
  987. @is_non_distributed_check
  988. @inited_runtime_handler
  989. def save_dense_params(
  990. self, executor, dirname, scope, program, var_names=None
  991. ):
  992. """
  993. save fleet one table from path
  994. Returns:
  995. None
  996. Examples:
  997. .. code-block:: python
  998. >>> import paddle.distributed.fleet as fleet
  999. >>> fleet.init()
  1000. >>> import paddle
  1001. >>> place = paddle.CPUPlace()
  1002. >>> exe = paddle.static.Executor(place)
  1003. >>> # build net
  1004. >>> # fleet.distributed_optimizer(...)
  1005. >>> fleet.save_dense_params(exe, "path", scope=paddle.static.global_scope(), program=paddle.static.default_main_program())
  1006. """
  1007. self._runtime_handle._save_dense_params(
  1008. executor, dirname, scope, program, var_names
  1009. )
  1010. @is_non_distributed_check
  1011. @inited_runtime_handler
  1012. def set_date(self, table_id, day_id):
  1013. """
  1014. set_date for gpups table
  1015. Returns:
  1016. None
  1017. Examples:
  1018. .. code-block:: python
  1019. >>> import paddle.distributed.fleet as fleet
  1020. >>> fleet.init()
  1021. >>> # build net
  1022. >>> # fleet.distributed_optimizer(...)
  1023. >>> fleet.set_date(0, "20250101")
  1024. """
  1025. self._runtime_handle._set_date(table_id, day_id)
  1026. @is_non_distributed_check
  1027. @inited_runtime_handler
  1028. def shrink(self, threshold=None):
  1029. self._runtime_handle._shrink(threshold)
  1030. def distributed_optimizer(self, optimizer, strategy=None):
  1031. """
  1032. Optimizer for distributed training.
  1033. For the distributed training, this method would rebuild a new instance of DistributedOptimizer.
  1034. Which has basic Optimizer function and special features for distributed training.
  1035. Args:
  1036. optimizer(Optimizer): The executor to run for init server.
  1037. strategy(DistributedStrategy): Extra properties for distributed optimizer.
  1038. It is recommended to use DistributedStrategy in fleet.init(). The strategy
  1039. here is for compatibility. If the strategy in fleet.distributed_optimizer()
  1040. is not None, then it will overwrite the DistributedStrategy in fleet.init(),
  1041. which will take effect in distributed training.
  1042. Returns:
  1043. Fleet: instance of fleet.
  1044. Examples:
  1045. .. code-block:: python
  1046. >>> import paddle
  1047. >>> import paddle.distributed.fleet as fleet
  1048. >>> fleet.init(is_collective=True)
  1049. >>> linear = paddle.nn.Linear(10, 10)
  1050. >>> strategy = fleet.DistributedStrategy()
  1051. >>> optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=linear.parameters())
  1052. >>> optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
  1053. """
  1054. self.user_defined_optimizer = optimizer
  1055. if strategy is not None:
  1056. if self._is_collective:
  1057. logger.warning(
  1058. "It is recommended to use DistributedStrategy "
  1059. "in fleet.init(). The strategy here is only for compatibility. "
  1060. "If the strategy in fleet.distributed_optimizer() is "
  1061. "not None, then it will overwrite the DistributedStrategy in fleet.init(), "
  1062. "which will take effect in distributed training."
  1063. )
  1064. self._user_defined_strategy = copy.deepcopy(strategy)
  1065. self._context = {}
  1066. return self
  1067. def _get_amp_optimizer(self):
  1068. # imitate target optimizer retrieval
  1069. amp_optimizer = None
  1070. for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
  1071. if hasattr(optimizer, 'amp_init'):
  1072. amp_optimizer = optimizer
  1073. break
  1074. if amp_optimizer is None:
  1075. if hasattr(self.user_defined_optimizer, 'amp_init'):
  1076. amp_optimizer = self.user_defined_optimizer
  1077. assert (
  1078. amp_optimizer is not None
  1079. ), "amp_init can only be used when the amp(auto mixed precision) strategy is turned on."
  1080. return amp_optimizer
  1081. def get_loss_scaling(self):
  1082. """Return the real-time loss scaling factor."""
  1083. amp_optimizer = self._get_amp_optimizer()
  1084. return amp_optimizer.get_loss_scaling()
  1085. def amp_init(
  1086. self, place, scope=None, test_program=None, use_fp16_test=False
  1087. ):
  1088. """
  1089. Init the amp training, such as cast fp32 parameters to fp16 type.
  1090. Args:
  1091. place(CUDAPlace): place is used to initialize
  1092. fp16 parameters with fp32 values.
  1093. scope(Scope): The scope is used to find fp32 parameters.
  1094. test_program(Program): The program is used for testing.
  1095. use_fp16_test(bool): Whether to use fp16 testing.
  1096. Examples:
  1097. .. code-block:: python
  1098. >>> import paddle
  1099. >>> import paddle.nn.functional as F
  1100. >>> paddle.enable_static()
  1101. >>> def run_example_code():
  1102. ... place = paddle.CUDAPlace(0)
  1103. ... exe = paddle.static.Executor(place)
  1104. ... data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
  1105. ... conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
  1106. ... # 1) Use fp16_guard to control the range of fp16 kernels used.
  1107. ... with paddle.static.amp.fp16_guard():
  1108. ... bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
  1109. ... pool = F.max_pool2d(bn, kernel_size=2, stride=2)
  1110. ... hidden = paddle.static.nn.fc(pool, size=10)
  1111. ... loss = paddle.mean(hidden)
  1112. ... # 2) Create the optimizer and set `multi_precision` to True.
  1113. ... # Setting `multi_precision` to True can avoid the poor accuracy
  1114. ... # or the slow convergence in a way.
  1115. ... optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
  1116. ... # 3) These ops in `custom_black_list` will keep in the float32 computation type.
  1117. ... amp_list = paddle.static.amp.CustomOpLists(
  1118. ... custom_black_list=['pool2d'])
  1119. ... # 4) The entry of Paddle AMP.
  1120. ... # Enable pure fp16 training by setting `use_pure_fp16` to True.
  1121. ... optimizer = paddle.static.amp.decorate(
  1122. ... optimizer,
  1123. ... amp_list,
  1124. ... init_loss_scaling=128.0,
  1125. ... use_dynamic_loss_scaling=True,
  1126. ... use_pure_fp16=True)
  1127. ... # If you don't use the default_startup_program(), you should pass
  1128. ... # your defined `startup_program` into `minimize`.
  1129. ... optimizer.minimize(loss)
  1130. ... exe.run(paddle.static.default_startup_program())
  1131. ... # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
  1132. ... # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
  1133. ... optimizer.amp_init(place, scope=paddle.static.global_scope())
  1134. >>> if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
  1135. ... run_example_code()
  1136. """
  1137. amp_optimizer = self._get_amp_optimizer()
  1138. return amp_optimizer.amp_init(place, scope, test_program, use_fp16_test)
  1139. def _get_qat_optimizer(self):
  1140. # imitate target optimizer retrieval
  1141. qat_optimizer = None
  1142. for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
  1143. if hasattr(optimizer, 'qat_init'):
  1144. qat_optimizer = optimizer
  1145. break
  1146. if qat_optimizer is None:
  1147. if hasattr(self.user_defined_optimizer, 'qat_init'):
  1148. qat_optimizer = self.user_defined_optimizer
  1149. assert (
  1150. qat_optimizer is not None
  1151. ), "qat_init can only be used when the qat(quantization aware training) strategy is turned on."
  1152. return qat_optimizer
  1153. def qat_init(self, place, scope=None, test_program=None):
  1154. """
  1155. Init the qat training, such as insert qdq ops and scale variables.
  1156. Args:
  1157. place(CUDAPlace): place is used to initialize
  1158. scale parameters.
  1159. scope(Scope): The scope is used to find parameters and variables.
  1160. test_program(Program): The program is used for testing.
  1161. """
  1162. qat_optimizer = self._get_qat_optimizer()
  1163. return qat_optimizer.qat_init(
  1164. place, scope=scope, test_program=test_program
  1165. )
  1166. def _final_strategy(self):
  1167. if "valid_strategy" not in self._context:
  1168. print(
  1169. "WARNING: You may need to call minimize function before this function is called"
  1170. )
  1171. return {}
  1172. else:
  1173. return self._context["valid_strategy"]
  1174. def _get_applied_meta_list(self):
  1175. if "applied_meta_list" not in self._context:
  1176. print(
  1177. "WARNING: You may need to call minimize function before _get_applied_meta_list called"
  1178. )
  1179. return []
  1180. else:
  1181. return self._context["applied_meta_list"]
  1182. def _get_applied_graph_list(self):
  1183. if "applied_graph_list" not in self._context:
  1184. print(
  1185. "WARNING: You may need to call minimize function before _get_applied_graph_list called"
  1186. )
  1187. return []
  1188. else:
  1189. return self._context["applied_graph_list"]
  1190. def minimize(
  1191. self, loss, startup_program=None, parameter_list=None, no_grad_set=None
  1192. ):
  1193. """
  1194. Add distributed operations to minimize ``loss`` by updating ``parameter_list``.
  1195. Args:
  1196. loss (Tensor): A ``Tensor`` containing the value to minimize.
  1197. startup_program (Program, optional): :ref:`api_paddle_static_Program` for
  1198. initializing parameters in ``parameter_list``. The default value
  1199. is None, at this time :ref:`api_paddle_static_default_startup_program` will be used.
  1200. parameter_list (Iterable, optional): Iterable of ``Tensor`` or ``Tensor.name`` to update
  1201. to minimize ``loss``. The default value is None, at this time all parameters
  1202. will be updated.
  1203. no_grad_set (set, optional): Set of ``Tensor`` or ``Tensor.name`` that don't need
  1204. to be updated. The default value is None.
  1205. Returns:
  1206. tuple: tuple (optimize_ops, params_grads), A list of operators appended
  1207. by minimize and a list of (param, grad) tensor pairs, param is
  1208. ``Parameter``, grad is the gradient value corresponding to the parameter.
  1209. The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
  1210. indicate program pruning. If so, the program will be pruned by ``feed`` and
  1211. ``fetch_list`` before run, see details in ``Executor``.
  1212. Examples:
  1213. .. code-block:: python
  1214. >>> import paddle
  1215. >>> paddle.enable_static()
  1216. >>> import paddle.distributed.fleet as fleet
  1217. >>> import paddle.nn.functional as F
  1218. >>> hid_dim = 10
  1219. >>> label_dim = 2
  1220. >>> input_x = paddle.static.data(name='x', shape=[None, 13], dtype='float32')
  1221. >>> input_y = paddle.static.data(name='y', shape=[None, 1], dtype='int64')
  1222. >>> fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim, activation='tanh')
  1223. >>> fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim, activation='tanh')
  1224. >>> prediction = paddle.static.nn.fc(x=[fc_2], size=label_dim, activation='softmax')
  1225. >>> cost = F.cross_entropy(input=prediction, label=input_y)
  1226. >>> avg_cost = paddle.mean(x=cost)
  1227. >>> fleet.init(is_collective=True)
  1228. >>> strategy = fleet.DistributedStrategy()
  1229. >>> linear = paddle.nn.Linear(10, 10)
  1230. >>> optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=linear.parameters())
  1231. >>> optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
  1232. >>> optimizer.minimize(avg_cost)
  1233. >>> # for more examples, please reference https://github.com/PaddlePaddle/PaddleFleetX
  1234. """
  1235. if not isinstance(loss, list):
  1236. return self._minimize_impl(
  1237. loss, startup_program, parameter_list, no_grad_set
  1238. )
  1239. else:
  1240. if (
  1241. in_dynamic_mode()
  1242. or self._role_maker._is_non_distributed()
  1243. or self._is_collective
  1244. ):
  1245. raise ValueError("loss can be list only in PS mode")
  1246. return self._minimize_losses_impl(
  1247. loss, startup_program, parameter_list, no_grad_set
  1248. )
  1249. def _minimize_impl(
  1250. self, loss, startup_program=None, parameter_list=None, no_grad_set=None
  1251. ):
  1252. context = {}
  1253. context["user_defined_strategy"] = copy.deepcopy(
  1254. self._user_defined_strategy
  1255. )
  1256. if in_dynamic_mode():
  1257. # imitate target optimizer retrieval
  1258. target_opt = self.user_defined_optimizer
  1259. self._context = context
  1260. return target_opt.minimize(loss)
  1261. else:
  1262. # cache original feed forward program
  1263. self.origin_main_program = loss.block.program
  1264. # add distributed attr
  1265. if not hasattr(self.origin_main_program, "distributed_info_"):
  1266. self.origin_main_program.distributed_info_ = {}
  1267. self.origin_main_program.distributed_info_[
  1268. "dp_degree"
  1269. ] = self._user_defined_strategy.sharding_configs["dp_degree"]
  1270. self.origin_main_program.distributed_info_[
  1271. "mp_degree"
  1272. ] = self._user_defined_strategy.sharding_configs["mp_degree"]
  1273. self.origin_main_program.distributed_info_[
  1274. "pp_degree"
  1275. ] = self._user_defined_strategy.sharding_configs["pp_degree"]
  1276. self.origin_main_program.distributed_info_[
  1277. "sharding_degree"
  1278. ] = self._user_defined_strategy.sharding_configs[
  1279. "sharding_degree"
  1280. ]
  1281. context["origin_main_program"] = self.origin_main_program
  1282. context["origin_main_programs"] = [self.origin_main_program]
  1283. context["loss"] = loss
  1284. if startup_program is None:
  1285. self.origin_startup_program = (
  1286. paddle.static.default_startup_program().clone(
  1287. for_test=False
  1288. )
  1289. )
  1290. startup_program = paddle.static.default_startup_program()
  1291. else:
  1292. self.origin_startup_program = startup_program.clone(
  1293. for_test=False
  1294. )
  1295. context["origin_startup_program"] = startup_program
  1296. context["origin_startup_programs"] = [startup_program]
  1297. context["role_maker"] = self._role_maker
  1298. # Use the auto-parallel's routines instead
  1299. if (
  1300. self._user_defined_strategy.semi_auto
  1301. or self._user_defined_strategy.auto_search
  1302. ):
  1303. from ..auto_parallel.static.parallelizer import AutoParallelizer
  1304. auto_parallelizer = AutoParallelizer(self)
  1305. (
  1306. optimize_ops,
  1307. params_grads,
  1308. dist_startup_prog,
  1309. dist_main_prog,
  1310. ) = auto_parallelizer.parallelize(
  1311. loss, startup_program, parameter_list, no_grad_set
  1312. )
  1313. return (
  1314. optimize_ops,
  1315. params_grads,
  1316. dist_startup_prog,
  1317. dist_main_prog,
  1318. )
  1319. context["user_defined_strategy"] = copy.deepcopy(
  1320. self._user_defined_strategy
  1321. )
  1322. copy_user_defined_strategy = copy.deepcopy(
  1323. self._user_defined_strategy
  1324. )
  1325. can_not_apply_optimizer_list = []
  1326. valid_optimizer_list = []
  1327. valid_graph_optimizer_list = []
  1328. skip_names = []
  1329. if (
  1330. self._is_collective
  1331. and len(self._user_defined_strategy.sparse_table_configs) > 0
  1332. ):
  1333. skip_names.append("ShardingOptimizer")
  1334. # compile time
  1335. distributed_optimizer_list = (
  1336. MetaOptimizerFactory()._get_valid_meta_optimizers(
  1337. self.user_defined_optimizer, skip_names
  1338. )
  1339. )
  1340. # trigger the auto-parallel in very strict condition
  1341. # strategy = DistributedStrategy()
  1342. # strategy.auto = True
  1343. # optimizer = paddle.optimizer.SGD(learning_rate=0.1)
  1344. # optimizer = fleet.distributed_optimizer(optimizer, strategy)
  1345. if copy_user_defined_strategy._is_strict_auto():
  1346. # turn on all the strategy for each optimizer
  1347. for opt in distributed_optimizer_list:
  1348. opt._enable_strategy(copy_user_defined_strategy, context)
  1349. valid_optimizer_list = []
  1350. valid_graph_optimizer_list = []
  1351. # recall meta optimizers for ranking
  1352. for opt in distributed_optimizer_list:
  1353. opt._set_basic_info(
  1354. loss,
  1355. self._role_maker,
  1356. self.user_defined_optimizer,
  1357. copy_user_defined_strategy,
  1358. )
  1359. if opt._can_apply() and not opt._is_graph_out():
  1360. valid_optimizer_list.append(opt)
  1361. elif opt._can_apply() and opt._is_graph_out():
  1362. valid_graph_optimizer_list.append(opt)
  1363. else:
  1364. can_not_apply_optimizer_list.append(opt)
  1365. # fix set collective and fleet ps gpu error
  1366. if (
  1367. self._is_collective
  1368. and len(self._user_defined_strategy.sparse_table_configs) > 0
  1369. ):
  1370. context["use_fleet_ps"] = True
  1371. from .meta_optimizers import ParameterServerOptimizer
  1372. meta_optimizer = ParameterServerOptimizer(
  1373. self.user_defined_optimizer
  1374. )
  1375. meta_optimizer._set_basic_info(
  1376. loss,
  1377. self._role_maker,
  1378. self.user_defined_optimizer,
  1379. copy_user_defined_strategy,
  1380. )
  1381. valid_optimizer_list.clear()
  1382. valid_optimizer_list.append(meta_optimizer)
  1383. can_not_apply_optimizer_list.append(meta_optimizer)
  1384. # meaningless, just for compatibility with other code
  1385. graph_optimizer = None
  1386. # valid_graph_optimizer_list.clear()
  1387. # valid_graph_optimizer_list.append(graph_optimizer)
  1388. # can_not_apply_optimizer_list.append(graph_optimizer)
  1389. print("valid_optimizer_list=", valid_optimizer_list)
  1390. # combine recalled meta optimizers to be a valid meta optimizer
  1391. (
  1392. meta_optimizer,
  1393. graph_optimizer,
  1394. ) = self.strategy_compiler.generate_optimizer(
  1395. loss,
  1396. self._role_maker,
  1397. self.user_defined_optimizer,
  1398. copy_user_defined_strategy,
  1399. valid_optimizer_list,
  1400. valid_graph_optimizer_list,
  1401. )
  1402. print("meta_optimizer=", meta_optimizer)
  1403. print("graph_optimizer=", graph_optimizer)
  1404. valid_strategy = self.strategy_compiler._get_valid_strategy(
  1405. copy_user_defined_strategy, can_not_apply_optimizer_list
  1406. )
  1407. context["valid_strategy"] = copy.deepcopy(valid_strategy)
  1408. logger.debug("valid_strategy: " + str(context["valid_strategy"]))
  1409. logger.debug(
  1410. "user_defined_strategy: "
  1411. + str(context["user_defined_strategy"])
  1412. )
  1413. applied_meta_list = self.strategy_compiler._get_applied_meta_list()
  1414. applied_graph_list = (
  1415. self.strategy_compiler._get_applied_graph_list()
  1416. )
  1417. context['applied_meta_list'] = applied_meta_list
  1418. context['applied_graph_list'] = applied_graph_list
  1419. self._context = context
  1420. self.valid_strategy = valid_strategy
  1421. self.valid_strategy._enable_env()
  1422. optimize_ops = []
  1423. params_grads = []
  1424. if (
  1425. self._role_maker._is_non_distributed()
  1426. and not self._is_collective
  1427. ):
  1428. if self._runtime_handle is None:
  1429. self._runtime_handle = RuntimeFactory()._create_runtime(
  1430. context
  1431. )
  1432. compiled_program = compiler.CompiledProgram(
  1433. self.origin_main_program
  1434. )
  1435. loss.block.program._graph = compiled_program
  1436. return self.user_defined_optimizer.minimize(
  1437. loss,
  1438. startup_program,
  1439. parameter_list,
  1440. no_grad_set=no_grad_set,
  1441. )
  1442. if meta_optimizer:
  1443. logger.debug(
  1444. "before minimize program id: " + str(id(loss.block.program))
  1445. )
  1446. optimize_ops, params_grads = meta_optimizer.minimize(
  1447. loss,
  1448. startup_program,
  1449. parameter_list,
  1450. no_grad_set=no_grad_set,
  1451. )
  1452. logger.debug(
  1453. "after minimize program id: " + str(id(loss.block.program))
  1454. )
  1455. default_program = paddle.static.default_main_program()
  1456. logger.debug("default program id: " + str(id(default_program)))
  1457. if id(default_program) != id(loss.block.program):
  1458. paddle.framework.switch_main_program(loss.block.program)
  1459. logger.debug(
  1460. "default program id after switch: "
  1461. + str(id(default_program))
  1462. )
  1463. else:
  1464. (
  1465. optimize_ops,
  1466. params_grads,
  1467. ) = self.user_defined_optimizer.minimize(
  1468. loss,
  1469. startup_program,
  1470. parameter_list,
  1471. no_grad_set=no_grad_set,
  1472. )
  1473. context["program_optimize_ops"] = optimize_ops
  1474. context["program_params_grads"] = params_grads
  1475. if graph_optimizer:
  1476. logger.debug(
  1477. "before graph minimize program id: "
  1478. + str(id(loss.block.program))
  1479. )
  1480. optimize_ops, params_grads = graph_optimizer.minimize(
  1481. loss,
  1482. startup_program,
  1483. parameter_list,
  1484. no_grad_set=no_grad_set,
  1485. )
  1486. # since we do not encourage users to use graph operations
  1487. # if a graph optimizer takes effect, mostly
  1488. # optimizers_ops and params_grads are None
  1489. # i.e. users can not modify current computation graph anymore
  1490. context["graph_optimize_ops"] = optimize_ops
  1491. context["graph_optimize_grads"] = params_grads
  1492. elif loss.block.program._pass_applied is None:
  1493. apply_ir_passes(loss.block.program, startup_program, self)
  1494. if not self._role_maker._is_heter_parameter_server_mode:
  1495. program = paddle.static.default_main_program()
  1496. opt_info = (
  1497. {} if program._fleet_opt is None else program._fleet_opt
  1498. )
  1499. opt_info["mpi_size"] = self.worker_num()
  1500. opt_info["mpi_rank"] = self.worker_index()
  1501. for (
  1502. k,
  1503. v,
  1504. ) in self._user_defined_strategy.trainer_desc_configs.items():
  1505. if v or k not in opt_info:
  1506. opt_info[k] = v
  1507. program._fleet_opt = opt_info
  1508. if self._runtime_handle is None:
  1509. self._runtime_handle = RuntimeFactory()._create_runtime(context)
  1510. from paddle.distributed import fleet
  1511. fleet.util._set_strategy(context["valid_strategy"])
  1512. return optimize_ops, params_grads
  1513. def _minimize_losses_impl(
  1514. self,
  1515. losses,
  1516. startup_programs=None,
  1517. parameter_list=None,
  1518. no_grad_set=None,
  1519. ):
  1520. context = {}
  1521. # cache original feed forward program
  1522. self.origin_main_program = losses[0].block.program
  1523. context["origin_main_program"] = self.origin_main_program
  1524. context["origin_main_programs"] = []
  1525. for loss in losses:
  1526. context["origin_main_programs"].append(loss.block.program)
  1527. context["loss"] = losses
  1528. if startup_programs is None:
  1529. if len(losses) == 1:
  1530. startup_programs = [paddle.static.default_startup_program()]
  1531. else:
  1532. raise ValueError(
  1533. "startup_program can't be None when loss is list."
  1534. )
  1535. self.origin_startup_program = startup_programs[0].clone(for_test=False)
  1536. context["origin_startup_program"] = startup_programs[0]
  1537. context["origin_startup_programs"] = []
  1538. for program in startup_programs:
  1539. context["origin_startup_programs"].append(program)
  1540. context["role_maker"] = self._role_maker
  1541. context["user_defined_strategy"] = copy.deepcopy(
  1542. self._user_defined_strategy
  1543. )
  1544. context["valid_strategy"] = copy.deepcopy(self._user_defined_strategy)
  1545. self._context = context
  1546. self.valid_strategy = context["valid_strategy"]
  1547. self.valid_strategy._enable_env()
  1548. optimize_ops = []
  1549. params_grads = []
  1550. from .meta_optimizers import ParameterServerOptimizer
  1551. ps_optimizer = ParameterServerOptimizer(self.user_defined_optimizer)
  1552. ps_optimizer._set_basic_info(
  1553. losses,
  1554. self._role_maker,
  1555. self.user_defined_optimizer,
  1556. self._user_defined_strategy,
  1557. )
  1558. optimize_ops, params_grads = ps_optimizer.minimize_losses_impl(
  1559. losses, startup_programs, parameter_list, no_grad_set=no_grad_set
  1560. )
  1561. # default_program = paddle.static.default_main_program()
  1562. # if id(default_program) != id(losses[0].block.program):
  1563. # paddle.framework.switch_main_program(losses[0].block.program)
  1564. context["program_optimize_ops"] = optimize_ops
  1565. context["program_params_grads"] = params_grads
  1566. for loss in losses:
  1567. program = loss.block.program
  1568. opt_info = {} if program._fleet_opt is None else program._fleet_opt
  1569. opt_info["mpi_size"] = self.worker_num()
  1570. opt_info["mpi_rank"] = self.worker_index()
  1571. for (
  1572. k,
  1573. v,
  1574. ) in self._user_defined_strategy.trainer_desc_configs.items():
  1575. if v or k not in opt_info:
  1576. opt_info[k] = v
  1577. program._fleet_opt = opt_info
  1578. logger.info(
  1579. "fleet base opt info: "
  1580. + str(id(program))
  1581. + str(program._fleet_opt)
  1582. )
  1583. if self._runtime_handle is None:
  1584. self._runtime_handle = RuntimeFactory()._create_runtime(context)
  1585. from paddle.distributed import fleet
  1586. fleet.util._set_strategy(context["valid_strategy"])
  1587. return optimize_ops, params_grads