auto_cast.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import warnings
  16. import paddle
  17. from paddle.base import core
  18. from paddle.base.framework import (
  19. _current_expected_place,
  20. _dygraph_tracer,
  21. dygraph_only,
  22. in_dynamic_or_pir_mode,
  23. in_pir_mode,
  24. )
  25. from paddle.base.wrapped_decorator import signature_safe_contextmanager
  26. from paddle.static.amp.decorator import OptimizerWithMixedPrecision
  27. from .amp_lists import black_list, white_list
  28. AMP_RELATED_FLAGS = [
  29. 'FLAGS_cudnn_exhaustive_search',
  30. 'FLAGS_conv_workspace_size_limit',
  31. 'FLAGS_cudnn_batchnorm_spatial_persistent',
  32. ]
  33. AMP_RELATED_FLAGS_SETTING = {
  34. 'FLAGS_cudnn_exhaustive_search': 1,
  35. 'FLAGS_conv_workspace_size_limit': 1000,
  36. 'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
  37. }
  38. AMP_LEVEL = core.AmpLevel
  39. _g_amp_state_ = None
  40. def amp_state():
  41. global _g_amp_state_
  42. return _g_amp_state_
  43. class AMPGlobalState:
  44. def __init__(self):
  45. self.model_parameters = []
  46. self.use_master_grad = False
  47. self.already_register_final_backward_hook = False
  48. self.already_classify_params_meshes = False # For dist
  49. self.mesh2params = {} # For dist
  50. self.amp_dtype = 'float32'
  51. def __setattr__(self, name, val):
  52. self.__dict__[name] = val
  53. _amp_global_state = AMPGlobalState()
  54. def amp_global_state():
  55. return _amp_global_state
  56. # NOTE(zhiqiu): similar as paddle.static.amp.fp16_lists.AutoMixedPrecisionLists._update_list
  57. # The reason why not use AutoMixedPrecisionLists is that custom_black_varnames is not suitable for imperative mode.
  58. def _update_list(
  59. custom_white_list, custom_black_list, level='O1', dtype='float16'
  60. ):
  61. """
  62. Update black and white list according to users' custom list.
  63. """
  64. if level == 'O0':
  65. _white_list = set()
  66. _black_list = set()
  67. return _white_list, _black_list
  68. _white_list = copy.copy(white_list()[dtype][level])
  69. _black_list = copy.copy(black_list()[dtype][level])
  70. if custom_white_list and custom_black_list:
  71. for op_name in custom_white_list:
  72. if op_name in custom_black_list:
  73. raise ValueError(
  74. "Custom white list overlap " "custom black list"
  75. )
  76. if custom_white_list:
  77. for op_name in custom_white_list:
  78. if op_name in _black_list:
  79. _black_list.remove(op_name)
  80. _white_list.add(op_name)
  81. if custom_black_list:
  82. for op_name in custom_black_list:
  83. if op_name in _white_list:
  84. _white_list.remove(op_name)
  85. _black_list.add(op_name)
  86. return _white_list, _black_list
  87. def _in_amp_guard():
  88. """
  89. Judge whether current code block is in `amp_guard` context.
  90. """
  91. tracer = _dygraph_tracer()
  92. if tracer:
  93. if tracer._amp_level == core.AmpLevel.O1:
  94. return True
  95. else:
  96. return False
  97. else:
  98. return False
  99. def _in_pure_fp16_guard():
  100. tracer = _dygraph_tracer()
  101. return tracer and tracer._amp_level == core.AmpLevel.O2
  102. def _is_gpu_float16_supported():
  103. """
  104. Judge whether current gpu support float16 amp.
  105. """
  106. prop = paddle.device.cuda.get_device_capability()
  107. return prop[0] >= 7
  108. def _is_gpu_bfloat16_supported():
  109. """
  110. Judge whether current gpu support bfloat16 amp.
  111. """
  112. prop = paddle.device.cuda.get_device_capability()
  113. cuda_version = paddle.version.cuda()
  114. if cuda_version is not None and cuda_version != 'False':
  115. cuda_version_check = int(cuda_version.split('.')[0]) >= 11
  116. else:
  117. cuda_version_check = False
  118. return prop[0] >= 8 and cuda_version_check
  119. def _is_custom_device_bfloat16_supported():
  120. """
  121. Judge whether current custom device support bfloat16 amp.
  122. """
  123. place = _current_expected_place()
  124. return place.get_device_type() == 'npu'
  125. def need_keep_fp32(layer, dtype):
  126. need_keep_fp32 = False
  127. # Highest priority. Because all the layers except BN will use bfloat16 params in bfloat16 training,
  128. # here we provide a option to keep fp32 param.
  129. if not layer._cast_to_low_precision:
  130. need_keep_fp32 = True
  131. # The BN layers will keep fp32
  132. elif isinstance(
  133. layer,
  134. (
  135. paddle.nn.BatchNorm,
  136. paddle.nn.BatchNorm1D,
  137. paddle.nn.BatchNorm2D,
  138. paddle.nn.BatchNorm3D,
  139. paddle.nn.SyncBatchNorm,
  140. ),
  141. ):
  142. need_keep_fp32 = True
  143. # layer._dtype is used to set params dtype. BF16 will use bf16 params.
  144. elif (layer._dtype == 'float16') or (
  145. (dtype == 'float16')
  146. and isinstance(
  147. layer,
  148. (
  149. paddle.nn.LayerNorm,
  150. paddle.nn.InstanceNorm1D,
  151. paddle.nn.InstanceNorm2D,
  152. paddle.nn.InstanceNorm3D,
  153. ),
  154. )
  155. ):
  156. need_keep_fp32 = True
  157. return need_keep_fp32
  158. def set_excluded_layers(models, excluded_layers):
  159. excluded_layers_instances = []
  160. excluded_layers_types = []
  161. error_message = "excluded_layers must be either a nn.Layer instance/type or a list of nn.Layer instances/types."
  162. if excluded_layers is None:
  163. excluded_layers = []
  164. elif isinstance(excluded_layers, paddle.nn.Layer):
  165. excluded_layers_instances = [excluded_layers]
  166. elif isinstance(excluded_layers, type) and issubclass(
  167. excluded_layers, paddle.nn.Layer
  168. ):
  169. excluded_layers_types = [excluded_layers]
  170. elif isinstance(excluded_layers, list):
  171. for item in excluded_layers:
  172. if isinstance(item, paddle.nn.Layer):
  173. excluded_layers_instances.append(item)
  174. elif issubclass(item, paddle.nn.Layer):
  175. excluded_layers_types.append(item)
  176. else:
  177. raise TypeError(error_message)
  178. else:
  179. raise TypeError(error_message)
  180. for idx in range(len(excluded_layers_instances)):
  181. for layer in excluded_layers_instances[idx].sublayers(
  182. include_self=True
  183. ):
  184. layer._cast_to_low_precision = False
  185. excluded_layers_types = tuple(excluded_layers_types)
  186. for idx in range(len(models)):
  187. for layer in models[idx].sublayers(include_self=True):
  188. if isinstance(layer, excluded_layers_types):
  189. layer._cast_to_low_precision = False
  190. def _pir_apply(self, func, dtype, include_sublayers=True):
  191. if include_sublayers:
  192. for layer in self.children():
  193. _pir_apply(layer, func, dtype, include_sublayers)
  194. for key, param in self._parameters.items():
  195. if param is not None:
  196. param_applied = func(param, dtype)
  197. for key, buf in self._buffers.items():
  198. if buf is not None:
  199. self._buffers[key] = func(buf, dtype)
  200. self._dtype = dtype
  201. def _pir_transform(t, dtype):
  202. main = paddle.static.default_main_program()
  203. startup = paddle.static.default_startup_program()
  204. with paddle.static.program_guard(startup):
  205. block = startup.global_block()
  206. for op in block.ops:
  207. if (
  208. op.name() == 'builtin.set_parameter'
  209. and op.attrs()['parameter_name'] == t.name
  210. ):
  211. param = op.operand(0).source()
  212. cast_param = paddle.cast(param, dtype)
  213. cast_param.persistable = True
  214. paddle._pir_ops.update_parameter(cast_param, t.name)
  215. block.remove_op(op)
  216. break
  217. main.set_parameters_from(startup)
  218. with paddle.static.program_guard(main):
  219. paddle.pir.reset_insertion_point_to_start()
  220. block = main.global_block()
  221. cast_param = paddle._pir_ops.parameter(t.name)
  222. cast_param.trainable = t.trainable
  223. cast_param.stop_gradient = t.stop_gradient
  224. cast_param.persistable = t.persistable
  225. cast_param.optimize_attr = t.optimize_attr
  226. cast_param.regularizer = t.regularizer
  227. cast_param.do_model_average = t.do_model_average
  228. cast_param.need_clip = t.need_clip
  229. cast_param.is_distributed = t.is_distributed
  230. cast_param.is_parameter = t.is_parameter
  231. op = t.get_defining_op()
  232. t.replace_all_uses_with(cast_param)
  233. block.remove_op(op)
  234. t.value_assign(cast_param)
  235. def _pir_to_impl(self, dtype, include_sublayers, floating_only):
  236. def transform(t, dtype):
  237. if floating_only and (not paddle.is_floating_point(t)):
  238. return t
  239. return _pir_transform(t, dtype)
  240. with warnings.catch_warnings():
  241. warnings.filterwarnings("ignore", category=UserWarning)
  242. _pir_apply(self, transform, dtype, include_sublayers)
  243. self._dtype = dtype
  244. return self
  245. def amp_initialize(models, dtype, excluded_layers):
  246. set_excluded_layers(models, excluded_layers)
  247. for idx in range(len(models)):
  248. for layer in models[idx].sublayers(include_self=True):
  249. if need_keep_fp32(layer, dtype):
  250. continue
  251. if dtype == "float16" and isinstance(
  252. layer,
  253. (
  254. paddle.incubate.nn.FusedFeedForward,
  255. paddle.incubate.nn.FusedMultiHeadAttention,
  256. ),
  257. ):
  258. layer._amp_decorate(dtype=dtype)
  259. continue
  260. if in_pir_mode():
  261. _pir_to_impl(
  262. layer,
  263. dtype=dtype,
  264. include_sublayers=False,
  265. floating_only=True,
  266. )
  267. else:
  268. layer._to_impl(
  269. dtype=dtype, include_sublayers=False, floating_only=True
  270. )
  271. return models
  272. def check_models(models):
  273. for model in models:
  274. if not isinstance(model, paddle.nn.Layer):
  275. raise RuntimeError(
  276. f"Current train mode is pure fp16, models should be paddle.nn.Layer, but receive {type(model)}."
  277. )
  278. if isinstance(model, paddle.DataParallel):
  279. raise RuntimeError(
  280. "For distributed AMP training, you should first use paddle.amp.decorate() to decorate origin model, and then call paddle.DataParallel get distributed model."
  281. )
  282. def _is_valid_optimizer(optimizer):
  283. from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
  284. DygraphShardingOptimizer,
  285. DygraphShardingOptimizerV2,
  286. )
  287. return isinstance(
  288. optimizer,
  289. (
  290. paddle.optimizer.Optimizer,
  291. DygraphShardingOptimizer,
  292. DygraphShardingOptimizerV2,
  293. ),
  294. )
  295. def check_optimizers(optimizers):
  296. for optimizer in optimizers:
  297. if not _is_valid_optimizer(optimizer):
  298. raise RuntimeError(
  299. f"Current train mode is pure fp16, optimizers should be paddle.optimizer.Optimizer or DygraphShardingOptimizer, but receive {type(optimizer)}."
  300. )
  301. @signature_safe_contextmanager
  302. def amp_guard(
  303. enable=True,
  304. custom_white_list=None,
  305. custom_black_list=None,
  306. level='O1',
  307. dtype='float16',
  308. use_promote=True,
  309. ):
  310. """
  311. Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
  312. If enabled, the input data type (float32 or float16) of each operator is decided
  313. by autocast algorithm for better performance.
  314. Commonly, it is used together with `GradScaler` to achieve Auto-Mixed-Precision in
  315. imperative mode. It is used together with `decorator` to achieve Pure fp16 in imperative mode.
  316. Args:
  317. enable(bool, optional): Enable auto-mixed-precision or not. Default is True.
  318. custom_white_list(set|list|tuple, optional): The custom white_list. It's the set of ops that support
  319. fp16 calculation and are considered numerically-safe and performance-critical. These ops
  320. will be converted to fp16.
  321. custom_black_list(set|list|tuple, optional): The custom black_list. The set of ops that support fp16
  322. calculation and are considered numerically-dangerous and whose effects may also be
  323. observed in downstream ops. These ops will not be converted to fp16.
  324. level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the input data type of each operator will be casted by white_list and black_list;
  325. O2 represent Pure fp16, all operators parameters and input data will be casted to fp16, except operators in black_list, don't support fp16 kernel and batchnorm. Default is O1(amp)
  326. dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
  327. Examples:
  328. .. code-block:: python
  329. >>> # doctest: +REQUIRES(env:GPU)
  330. >>> import paddle
  331. >>> data = paddle.uniform([10, 3, 32, 32], paddle.float32, -1, 1)
  332. >>> conv2d = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
  333. >>> conv2d = paddle.amp.amp_decorate(models=conv2d, level='O2')
  334. >>> with paddle.amp.amp_guard():
  335. ... conv = conv2d(data)
  336. ... print(conv.dtype)
  337. >>> # doctest: +SKIP("This has diff in xdoctest env")
  338. paddle.float16
  339. >>> # doctest: -SKIP
  340. ...
  341. >>> with paddle.amp.amp_guard(enable=False):
  342. ... conv = conv2d(data)
  343. ... print(conv.dtype)
  344. >>> # doctest: +SKIP("This has diff in xdoctest env")
  345. paddle.float32
  346. >>> # doctest: -SKIP
  347. """
  348. assert (
  349. in_dynamic_or_pir_mode()
  350. ), "We only support 'amp_guard' in dynamic or pir mode."
  351. amp_state = locals()
  352. global _g_amp_state_
  353. original_state = _g_amp_state_
  354. _g_amp_state_ = amp_state
  355. # check amp_level: O0-O2
  356. level = level.upper()
  357. if level not in ['O0', 'OD', 'O1', 'O2']:
  358. raise ValueError("level should be O0, OD, O1 or O2.")
  359. # check amp_dtype: float16 or bfloat16
  360. dtype = dtype.lower()
  361. if enable:
  362. if dtype not in ['float16', 'bfloat16']:
  363. raise ValueError(
  364. "If enable amp, dtype should be 'float16' or 'bfloat16'."
  365. )
  366. amp_dtype = dtype
  367. amp_global_state().amp_dtype = amp_dtype
  368. if level == 'OD':
  369. amp_level = AMP_LEVEL.OD
  370. elif level == 'O1':
  371. amp_level = AMP_LEVEL.O1
  372. elif level == 'O2':
  373. amp_level = AMP_LEVEL.O2
  374. elif level == 'O0':
  375. amp_level = AMP_LEVEL.O0
  376. _white_list, _black_list = _update_list(
  377. custom_white_list, custom_black_list, level, dtype
  378. )
  379. if in_pir_mode():
  380. if not enable:
  381. amp_level = AMP_LEVEL.O0
  382. amp_dtype = "float32"
  383. amp_attrs = core._get_amp_attrs()
  384. # set amp level
  385. original_amp_level = amp_attrs._amp_level
  386. amp_attrs._amp_level = amp_level
  387. # set amp op list
  388. original_white_list, original_black_list = core._get_amp_op_list()
  389. core._set_amp_op_list(_white_list, _black_list)
  390. # set amp dtype
  391. original_amp_dtype = amp_attrs._amp_dtype
  392. amp_attrs._amp_dtype = amp_dtype
  393. # switch promote
  394. if amp_level == AMP_LEVEL.O2:
  395. original_use_promote = amp_attrs._use_promote
  396. amp_attrs._use_promote = use_promote
  397. try:
  398. yield
  399. finally:
  400. _g_amp_state_ = original_state
  401. amp_attrs._amp_level = original_amp_level
  402. core._set_amp_op_list(original_white_list, original_black_list)
  403. amp_attrs._amp_dtype = original_amp_dtype
  404. if amp_level == AMP_LEVEL.O2:
  405. amp_attrs._use_promote = original_use_promote
  406. else:
  407. # check tracer
  408. tracer = _dygraph_tracer()
  409. if not tracer:
  410. raise ValueError(
  411. "current_tracer is None, maybe it is not in imperative mode."
  412. )
  413. # check device_type:
  414. # NOTE: Now, amp only support gpu for float16 and bfloat16, xpu for float16, npu for float16 and bfloat16.
  415. # Maybe we will support cpu for bfloat16.
  416. if enable and not (
  417. tracer._expected_place.is_gpu_place()
  418. or tracer._expected_place.is_xpu_place()
  419. or tracer._expected_place.is_custom_place()
  420. ):
  421. warnings.warn(
  422. 'amp_guard can only be enabled on CUDAPlace, XPUPlace, and CustomPlace, current place is %s, so it makes no effect.'
  423. % tracer._expected_place
  424. )
  425. enable = False
  426. if enable:
  427. # For xpu:
  428. if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'):
  429. warnings.warn('XPUPlace only support float16 amp.')
  430. enable = False
  431. # For custom device:
  432. if (
  433. tracer._expected_place.is_custom_place()
  434. and not _is_custom_device_bfloat16_supported()
  435. and (dtype == 'bfloat16')
  436. ):
  437. warnings.warn('CustomPlace only support float16 amp.')
  438. enable = False
  439. # For gpu float16: Compute Capability should >= 7.
  440. # For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11.
  441. if tracer._expected_place.is_gpu_place():
  442. if (dtype == 'float16') and not _is_gpu_float16_supported():
  443. prop = paddle.device.cuda.get_device_capability()
  444. warnings.warn(
  445. "For float16, amp only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d."
  446. % (
  447. paddle.device.cuda.get_device_name(),
  448. prop[0],
  449. prop[1],
  450. )
  451. )
  452. enable = False
  453. elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported():
  454. prop = paddle.device.cuda.get_device_capability()
  455. cuda_version = paddle.version.cuda()
  456. warnings.warn(
  457. "For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: %s, with Compute Capability: %d.%d, current CUDA Version is: %s."
  458. % (
  459. paddle.device.cuda.get_device_name(),
  460. prop[0],
  461. prop[1],
  462. cuda_version,
  463. )
  464. )
  465. enable = False
  466. if not enable:
  467. amp_level = AMP_LEVEL.O0
  468. amp_dtype = "float32"
  469. # master_grad_hook will run at the end of backward.
  470. # Since backward_final_hook will be cleared once they have been
  471. # done, we should register the hook every step.
  472. if (
  473. amp_global_state().use_master_grad
  474. and not amp_global_state().already_register_final_backward_hook
  475. ):
  476. def master_grad_hook():
  477. # NOTE(lizhiyu): To support semi-auto of dygraph mode, we must
  478. # classify the params of model into different calsses according to their process_mesh.
  479. # Otherwise, fault will occur.
  480. if not amp_global_state().already_classify_params_meshes:
  481. for param in amp_global_state().model_parameters:
  482. if param is not None and param.process_mesh is not None:
  483. if (
  484. param.process_mesh
  485. not in amp_global_state().mesh2params
  486. ):
  487. amp_global_state().mesh2params[
  488. param.process_mesh
  489. ] = [param]
  490. else:
  491. amp_global_state().mesh2params[
  492. param.process_mesh
  493. ].append(param)
  494. amp_global_state().already_classify_params_meshes = True
  495. if len(amp_global_state().mesh2params):
  496. for _, params in amp_global_state().mesh2params.items():
  497. core.eager.set_master_grads(params)
  498. else:
  499. core.eager.set_master_grads(
  500. amp_global_state().model_parameters
  501. )
  502. amp_global_state().already_register_final_backward_hook = False
  503. core.eager._add_backward_final_hook(master_grad_hook)
  504. amp_global_state().already_register_final_backward_hook = True
  505. if tracer:
  506. # enable auto_cast
  507. original_amp_level = tracer._amp_level
  508. tracer._amp_level = amp_level
  509. # set amp op list
  510. original_white_list, original_black_list = tracer._get_amp_op_list()
  511. tracer._set_amp_op_list(_white_list, _black_list)
  512. # TODO(zhiqiu) set amp related flags automatically in this guard
  513. # Currently, if FLAGS_cudnn_batchnorm_spatial_persistent is set True in amp_guard,
  514. # batch_norm can run in fast mode, but batch_norm_grad can not if backward if not executed inside amp_guard.
  515. # So, users need to set related flags manually.
  516. # original_flags = get_flags(AMP_RELATED_FLAGS)
  517. # set_flags(AMP_RELATED_FLAGS_SETTING)
  518. # set amp dtype
  519. original_amp_dtype = tracer._amp_dtype
  520. tracer._amp_dtype = amp_dtype
  521. # switch promote
  522. if amp_level == AMP_LEVEL.O2:
  523. original_use_promote = tracer._use_promote
  524. tracer._use_promote = use_promote
  525. # restore status
  526. try:
  527. yield
  528. finally:
  529. if tracer:
  530. _g_amp_state_ = original_state
  531. tracer._amp_level = original_amp_level
  532. tracer._set_amp_op_list(
  533. original_white_list, original_black_list
  534. )
  535. # set_flags(original_flags)
  536. tracer._amp_dtype = original_amp_dtype
  537. if amp_level == AMP_LEVEL.O2:
  538. tracer._use_promote = original_use_promote
  539. class StateDictHook:
  540. def __init__(self, save_dtype):
  541. self._save_dtype = save_dtype
  542. def __call__(self, state_dict):
  543. for key in state_dict:
  544. param = state_dict[key]
  545. if paddle.is_floating_point(param):
  546. param_applied = paddle.cast(param, self._save_dtype)
  547. param_applied.name = param.name
  548. state_dict[key] = param_applied
  549. def _set_multi_precision(optimizer, multi_precision):
  550. from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
  551. DygraphShardingOptimizer,
  552. DygraphShardingOptimizerV2,
  553. )
  554. optimizer = (
  555. optimizer._inner_opt
  556. if isinstance(
  557. optimizer, (DygraphShardingOptimizer, DygraphShardingOptimizerV2)
  558. )
  559. else optimizer
  560. )
  561. if hasattr(optimizer, "_multi_precision"):
  562. optimizer._multi_precision = multi_precision
  563. @dygraph_only
  564. def amp_decorate(
  565. models,
  566. optimizers=None,
  567. level='O1',
  568. dtype='float16',
  569. master_weight=None,
  570. save_dtype=None,
  571. master_grad=False,
  572. excluded_layers=None,
  573. ):
  574. """
  575. Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.
  576. When level is O2(pure fp16), the decorate will cast all parameters of models to FP16, except BatchNorm, InstanceNorm and LayerNorm.
  577. Commonly, it is used together with `amp_guard` to achieve Pure fp16 in imperative mode.
  578. Args:
  579. models(Layer|list of Layer, optional): The defined models by user, models must be either a single model or a list of models. Default is None.
  580. optimizers(Optimizer|list of Optimizer, optional): The defined optimizers by user, optimizers must be either a single optimizer or a list of optimizers. Default is None.
  581. level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the decorator will do nothing;
  582. O2 represent Pure fp16/bf16, the decorator will cast all parameters of models to FP16/BF16, except BatchNorm, InstanceNorm and LayerNorm. Default is O1(amp)
  583. dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
  584. master_weight(bool, optional): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None.
  585. save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, bfloat16, float32, float64 or None.
  586. The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None.
  587. Examples:
  588. .. code-block:: python
  589. >>> # doctest: +REQUIRES(env:GPU)
  590. >>> # Demo1: single model and optimizer:
  591. >>> import paddle
  592. >>> paddle.device.set_device('gpu')
  593. >>> model = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
  594. >>> optimizer = paddle.optimizer.SGD(parameters=model.parameters())
  595. >>> model, optimizer = paddle.amp.amp_decorate(models=model, optimizers=optimizer, level='O2')
  596. >>> data = paddle.rand([10, 3, 32, 32])
  597. >>> with paddle.amp.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
  598. ... output = model(data)
  599. ... print(output.dtype)
  600. paddle.float16
  601. >>> # Demo2: multi models and optimizers:
  602. >>> model2 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
  603. >>> optimizer2 = paddle.optimizer.Adam(parameters=model2.parameters())
  604. >>> models, optimizers = paddle.amp.amp_decorate(models=[model, model2], optimizers=[optimizer, optimizer2], level='O2')
  605. >>> data = paddle.rand([10, 3, 32, 32])
  606. >>> with paddle.amp.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
  607. ... output = models[0](data)
  608. ... output2 = models[1](data)
  609. ... print(output.dtype)
  610. ... print(output2.dtype)
  611. paddle.float16
  612. paddle.float16
  613. >>> # Demo3: optimizers is None:
  614. >>> model3 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
  615. >>> optimizer3 = paddle.optimizer.Adam(parameters=model2.parameters())
  616. >>> model = paddle.amp.amp_decorate(models=model3, level='O2')
  617. >>> data = paddle.rand([10, 3, 32, 32])
  618. >>> with paddle.amp.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
  619. ... output = model(data)
  620. ... print(output.dtype)
  621. paddle.float16
  622. """
  623. if level not in ['O1', 'O2']:
  624. raise ValueError(
  625. "level should be O1 or O2, O1 represent AMP train mode, O2 represent Pure fp16 train mode."
  626. )
  627. if dtype not in ['float16', 'bfloat16']:
  628. raise ValueError("dtype only support float16 or bfloat16.")
  629. if level == 'O1':
  630. if optimizers is None:
  631. return models
  632. else:
  633. return models, optimizers
  634. # check tracer
  635. tracer = _dygraph_tracer()
  636. if not tracer:
  637. raise ValueError(
  638. "current_tracer is None, maybe it is not in imperative mode."
  639. )
  640. # check device_type:
  641. if not (
  642. tracer._expected_place.is_gpu_place()
  643. or tracer._expected_place.is_xpu_place()
  644. or tracer._expected_place.is_custom_place()
  645. ):
  646. if optimizers is None:
  647. return models
  648. else:
  649. return models, optimizers
  650. # For xpu:
  651. if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'):
  652. if optimizers is None:
  653. return models
  654. else:
  655. return models, optimizers
  656. # For custom device:
  657. if (
  658. tracer._expected_place.is_custom_place()
  659. and not _is_custom_device_bfloat16_supported()
  660. and (dtype == 'bfloat16')
  661. ):
  662. if optimizers is None:
  663. return models
  664. else:
  665. return models, optimizers
  666. # For gpu float16: Compute Capability should >= 7.
  667. # For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11.
  668. if tracer._expected_place.is_gpu_place():
  669. if (dtype == 'float16' and not _is_gpu_float16_supported()) or (
  670. dtype == 'bfloat16' and not _is_gpu_bfloat16_supported()
  671. ):
  672. if optimizers is None:
  673. return models
  674. else:
  675. return models, optimizers
  676. models_is_list = False
  677. if isinstance(models, paddle.nn.Layer):
  678. models_is_list = False
  679. models = [models]
  680. check_models(models)
  681. elif isinstance(models, list):
  682. check_models(models)
  683. models_is_list = True
  684. else:
  685. raise TypeError(
  686. "models must be either a single model or a list of models."
  687. )
  688. # initialize parameters of the model.
  689. amp_initialize(models=models, dtype=dtype, excluded_layers=excluded_layers)
  690. if optimizers is not None:
  691. # check optimizers
  692. optimizers_is_list = False
  693. if _is_valid_optimizer(optimizers):
  694. optimizers_is_list = False
  695. optimizers = [optimizers]
  696. check_optimizers(optimizers)
  697. elif isinstance(optimizers, list):
  698. check_optimizers(optimizers)
  699. optimizers_is_list = True
  700. else:
  701. raise TypeError(
  702. "optimizers must be either a single optimizer or a list of optimizers."
  703. )
  704. # support master_weight
  705. use_multi_precision = master_weight is not False
  706. for opt in optimizers:
  707. _set_multi_precision(opt, use_multi_precision)
  708. # support master_grad
  709. if master_grad:
  710. amp_global_state().use_master_grad = True
  711. for idx in range(len(models)):
  712. amp_global_state().model_parameters.extend(models[idx].parameters())
  713. if save_dtype is not None:
  714. if save_dtype not in ['float16', 'bfloat16', 'float32', 'float64']:
  715. raise ValueError(
  716. "save_dtype can only be float16 float32 or float64, but your input save_dtype is %s."
  717. % save_dtype
  718. )
  719. for idx in range(len(models)):
  720. for layer in models[idx].sublayers(include_self=True):
  721. layer.register_state_dict_hook(StateDictHook(save_dtype))
  722. if models_is_list:
  723. if optimizers is not None:
  724. if optimizers_is_list:
  725. return models, optimizers
  726. else:
  727. return models, optimizers[0]
  728. else:
  729. return models
  730. else:
  731. if optimizers is not None:
  732. if optimizers_is_list:
  733. return models[0], optimizers
  734. else:
  735. return models[0], optimizers[0]
  736. else:
  737. return models[0]
  738. def auto_cast(
  739. enable=True,
  740. custom_white_list=None,
  741. custom_black_list=None,
  742. level='O1',
  743. dtype='float16',
  744. use_promote=True,
  745. ):
  746. """
  747. Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
  748. If enabled, the input data type (float32, float16 or bfloat16) of each operator is decided
  749. by autocast algorithm for better performance.
  750. Commonly, it is used together with `GradScaler` and `decorator` to achieve Auto-Mixed-Precision in
  751. imperative mode.
  752. Args:
  753. enable(bool, optional): Enable auto-mixed-precision or not. Default is True.
  754. custom_white_list(set|list|tuple, optional): A default white list is already set. Usually there is no need to set custom white list.
  755. The set of ops should be considered numerically-safe and performance-critical. These ops will be converted to float16/bfloat16.
  756. custom_black_list(set|list|tuple, optional): A default black list is already set. You can set a custom black list according to the model.
  757. The set of ops are considered numerically-dangerous and whose effects may also be observed in downstream ops. These ops will not be
  758. converted to float16/bfloat16.
  759. level(str, optional): Auto mixed precision level. Accepted values are "O1", "O2" and "OD": At the O1 level, operators in the white list
  760. will use float16/bfloat16 inputs for calculations, and operators in the black list will use float32 inputs for calculations. At the O2
  761. level, model's parameters will be casted to float16/bfloat16 by using `decorator`, and operators that have all float16/bfloat16 inputs
  762. will be converted to float16/bfloat16, and that have any float32 input will be converted to float32. For the OD level, operators in
  763. default white list will compute in float16/bfloat16, and the others will compute in float32. Default is O1.
  764. dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
  765. use_promote(bool, optional): Whether to promotes to fp32 when op has any float32 inputs. It is only supported when amp level is O2. Default is True.
  766. Examples:
  767. .. code-block:: python
  768. >>> # doctest: +REQUIRES(env:GPU)
  769. >>> import paddle
  770. >>> conv2d = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
  771. >>> data = paddle.rand([10, 3, 32, 32])
  772. >>> with paddle.amp.auto_cast():
  773. ... conv = conv2d(data)
  774. ... print(conv.dtype)
  775. >>> # doctest: +SKIP("This has diff in xdoctest env")
  776. paddle.float16
  777. >>> # doctest: -SKIP
  778. >>> with paddle.amp.auto_cast(enable=False):
  779. ... conv = conv2d(data)
  780. ... print(conv.dtype)
  781. >>> # doctest: +SKIP("This has diff in xdoctest env")
  782. paddle.float32
  783. >>> # doctest: -SKIP
  784. >>> with paddle.amp.auto_cast(custom_black_list={'conv2d'}):
  785. ... conv = conv2d(data)
  786. ... print(conv.dtype)
  787. >>> # doctest: +SKIP("This has diff in xdoctest env")
  788. paddle.float32
  789. >>> # doctest: -SKIP
  790. >>> a = paddle.rand([2, 3])
  791. >>> b = paddle.rand([2, 3])
  792. >>> with paddle.amp.auto_cast(custom_white_list={'elementwise_add'}):
  793. ... c = a + b
  794. ... print(c.dtype)
  795. >>> # doctest: +SKIP("This has diff in xdoctest env")
  796. paddle.float16
  797. >>> # doctest: -SKIP
  798. >>> with paddle.amp.auto_cast(custom_white_list={'elementwise_add'}, level='O2'):
  799. ... d = a + b
  800. ... print(d.dtype)
  801. >>> # doctest: +SKIP("This has diff in xdoctest env")
  802. paddle.float16
  803. >>> # doctest: -SKIP
  804. """
  805. return amp_guard(
  806. enable, custom_white_list, custom_black_list, level, dtype, use_promote
  807. )
  808. def decorate(
  809. models,
  810. optimizers=None,
  811. level='O1',
  812. dtype='float16',
  813. master_weight=None,
  814. save_dtype=None,
  815. master_grad=False,
  816. excluded_layers=None,
  817. ):
  818. """
  819. Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.
  820. When level is O2(pure float16/bfloat16), the decorate will cast all parameters of models to float16/bfloat16, except BatchNorm, InstanceNorm and LayerNorm.
  821. Commonly, it is used together with `auto_cast` to achieve Pure float16/bfloat16 in imperative mode.
  822. Args:
  823. models(Layer|list of Layer): The defined models by user, models must be either a single model or a list of models. Default is None.
  824. optimizers(Optimizer|list of Optimizer, optional): The defined optimizers by user, optimizers must be either a single optimizer or a list of optimizers. Default is None.
  825. level(str, optional): Auto mixed precision level. Accepted values are 'O1' and 'O2': O1 represent mixed precision, the decorator will do nothing;
  826. O2 represent Pure float16/bfloat16, the decorator will cast all parameters of models to float16/bfloat16, except BatchNorm, InstanceNorm and LayerNorm. Default is O1(amp)
  827. dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
  828. master_weight(bool, optional): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None.
  829. save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, bfloat16, float32, float64 or None.
  830. The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None.
  831. master_grad(bool, optional): For level='O2', whether to use float32 weight gradients for calculations such as gradient clipping, weight decay, and weight updates. If master_grad is enabled, the weight
  832. gradients will be float32 dtype after the backpropagation. Default is False, there is only float16 weight gradients.
  833. excluded_layers(Layer|list of Layer, optional): Specify the layers not to be decorated. The weights of these layers will always keep float32 when level is O2. `excluded_layers` can be specified as
  834. an Layer instance/type or a list of Layer instances/types. Default is None, the weights of the whole model will be casted to float16 or bfloat16.
  835. Examples:
  836. .. code-block:: python
  837. >>> # doctest: +REQUIRES(env:GPU)
  838. >>> # Demo1: single model and optimizer:
  839. >>> import paddle
  840. >>> paddle.device.set_device('gpu')
  841. >>> model = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
  842. >>> optimizer = paddle.optimizer.SGD(parameters=model.parameters())
  843. >>> model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer, level='O2')
  844. >>> data = paddle.rand([10, 3, 32, 32])
  845. >>> with paddle.amp.auto_cast(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
  846. ... output = model(data)
  847. ... print(output.dtype)
  848. paddle.float16
  849. >>> # Demo2: multi models and optimizers:
  850. >>> model2 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
  851. >>> optimizer2 = paddle.optimizer.Adam(parameters=model2.parameters())
  852. >>> models, optimizers = paddle.amp.decorate(models=[model, model2], optimizers=[optimizer, optimizer2], level='O2')
  853. >>> data = paddle.rand([10, 3, 32, 32])
  854. >>> with paddle.amp.auto_cast(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
  855. ... output = models[0](data)
  856. ... output2 = models[1](data)
  857. ... print(output.dtype)
  858. ... print(output2.dtype)
  859. paddle.float16
  860. paddle.float16
  861. >>> # Demo3: optimizers is None:
  862. >>> model3 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
  863. >>> optimizer3 = paddle.optimizer.Adam(parameters=model3.parameters())
  864. >>> model = paddle.amp.decorate(models=model3, level='O2')
  865. >>> data = paddle.rand([10, 3, 32, 32])
  866. >>> with paddle.amp.auto_cast(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
  867. ... output = model(data)
  868. ... print(output.dtype)
  869. paddle.float16
  870. """
  871. if paddle.framework.in_pir_mode():
  872. assert not isinstance(models, (list, tuple))
  873. assert not isinstance(optimizers, (list, tuple))
  874. if level in ['O0', 'OD', 'O1']:
  875. if optimizers is None:
  876. return models
  877. else:
  878. optimizers = OptimizerWithMixedPrecision(
  879. optimizer=optimizers,
  880. amp_lists=None,
  881. level=level,
  882. dtype=dtype,
  883. init_loss_scaling=1.0,
  884. incr_every_n_steps=None,
  885. decr_every_n_nan_or_inf=None,
  886. incr_ratio=None,
  887. decr_ratio=None,
  888. use_dynamic_loss_scaling=False,
  889. use_amp_guard=None,
  890. use_master_grad=master_grad,
  891. use_promote=None,
  892. )
  893. return models, optimizers
  894. elif level == 'O2':
  895. amp_initialize(
  896. models=[models], dtype=dtype, excluded_layers=excluded_layers
  897. )
  898. use_multi_precision = master_weight is not False
  899. _set_multi_precision(optimizers, use_multi_precision)
  900. if optimizers is None:
  901. return models
  902. else:
  903. optimizers = OptimizerWithMixedPrecision(
  904. optimizer=optimizers,
  905. amp_lists=None,
  906. level=level,
  907. dtype=dtype,
  908. init_loss_scaling=1.0,
  909. incr_every_n_steps=None,
  910. decr_every_n_nan_or_inf=None,
  911. incr_ratio=None,
  912. decr_ratio=None,
  913. use_dynamic_loss_scaling=False,
  914. use_amp_guard=None,
  915. use_master_grad=master_grad,
  916. use_promote=None,
  917. )
  918. return models, optimizers
  919. else:
  920. raise ValueError("level should be O0, OD, O1 or O2.")
  921. else:
  922. return amp_decorate(
  923. models,
  924. optimizers,
  925. level,
  926. dtype,
  927. master_weight,
  928. save_dtype,
  929. master_grad,
  930. excluded_layers,
  931. )