base.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824
  1. # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. import sys
  16. import warnings
  17. import decorator
  18. import paddle
  19. from paddle.base import core, framework
  20. from paddle.base.framework import global_var
  21. from paddle.base.multiprocess_utils import CleanupFuncRegistrar
  22. from ..framework import _get_paddle_place
  23. from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
  24. from .tracer import Tracer
# `__all__` is empty, so a star-import of this module exposes no names.
__all__ = []

# Tensors whose name ends with this suffix are converted by
# `_convert_into_variable` with `persistable=False`.
NON_PERSISTABLE_VAR_NAME_SUFFIX = "__non_persistable"
  27. def in_to_static_mode():
  28. """
  29. Return a bool value that indicates whether running code under `@to_static`
  30. """
  31. return global_var._in_to_static_mode_
  32. # TODO(Aurelius84): Need to remove this alias after clean usage in PaddleX
  33. in_declarative_mode = in_to_static_mode
  34. def to_static_unsupport_argument_warning(
  35. func_name, input_names, inputs, support_values
  36. ):
  37. """
  38. Warning if inputs do not elementwisely equals to support_values.
  39. It's a utility function for dy2static when dygraph interface have
  40. more inputs than static interface such as paddle.grad.
  41. """
  42. for name, inp, sup in zip(input_names, inputs, support_values):
  43. if inp != sup:
  44. warnings.warn(
  45. f"{func_name} has unsupported parameter in jit: "
  46. + f"{name}, jit will discard it"
  47. )
  48. def _switch_to_static_graph_(func):
  49. def __impl__(*args, **kwargs):
  50. with framework._dygraph_guard(None):
  51. return func(*args, **kwargs)
  52. return __impl__
  53. switch_to_static_graph = wrap_decorator(_switch_to_static_graph_)
  54. @signature_safe_contextmanager
  55. def _to_static_mode_guard_(is_to_static=True):
  56. global global_var
  57. original_val = global_var._in_to_static_mode_
  58. global_var._in_to_static_mode_ = is_to_static
  59. try:
  60. yield
  61. finally:
  62. global_var._in_to_static_mode_ = original_val
  63. @signature_safe_contextmanager
  64. def param_guard(parameters):
  65. # Note: parameters is a reference of self._parameters or self._buffers
  66. if in_to_static_mode() and not paddle.in_dynamic_mode() and parameters:
  67. try:
  68. origin_parameters = parameters.copy()
  69. for name, var_base in parameters.items():
  70. if isinstance(var_base, list):
  71. new_var = [_convert_into_variable(var) for var in var_base]
  72. else:
  73. new_var = _convert_into_variable(var_base)
  74. parameters[name] = new_var
  75. yield
  76. finally:
  77. parameters.update(origin_parameters)
  78. else:
  79. yield
  80. def _convert_into_variable(tensor):
  81. """
  82. Convert Tensor into Variable.
  83. """
  84. if paddle.framework.use_pir_api():
  85. return paddle.pir.core._convert_into_value(tensor)
  86. if isinstance(tensor, core.eager.Tensor):
  87. # Check whether has been created before.
  88. new_var = tensor.block._find_var_recursive(tensor.name)
  89. if new_var is not None:
  90. assert isinstance(new_var, framework.Variable)
  91. # Convert EagerParamBase into Parameter with same attributes in dy2stat.
  92. elif isinstance(tensor, framework.EagerParamBase):
  93. new_var = tensor._to_static_var(to_parameter=True)
  94. else:
  95. # Note(Aurelius84): Convert Tensor in self._buffers into Variable with
  96. # same attributes and set persistable=True to allow saving this var.
  97. # Because users can create a Tensor in `__init__` like a
  98. # `mask` Tensor or `hidden_0` in RNN layers, which is equivalent to a Parameter
  99. # and necessary for inferring. It will be pruned if it's not necessary for inferring.
  100. # But if its shape is empty while created from `create_variable()`, we consider this buffer
  101. # non-persistable. See case of `dropout_state` in lstm api.
  102. is_persistable = True
  103. if tensor.name.endswith(NON_PERSISTABLE_VAR_NAME_SUFFIX):
  104. is_persistable = False
  105. new_var = tensor._to_static_var(
  106. to_parameter=False, persistable=is_persistable
  107. )
  108. # add param into parameter recorder to collect all the params used in this program.
  109. if new_var.persistable is True:
  110. from paddle.jit.dy2static.program_translator import (
  111. ProgramTranslator,
  112. )
  113. ProgramTranslator.get_instance()._params_recorder.add(
  114. tensor.block.program, tensor
  115. )
  116. return new_var
  117. else:
  118. return tensor
  119. def enabled():
  120. """
  121. This function checks whether the program runs in dynamic graph mode or not.
  122. You can enable dynamic graph mode with :ref:`api_paddle_disable_static` api,
  123. or disable dynamic graph mode with :ref:`api_paddle_enable_static` .
  124. **Note**:
  125. ``base.dygraph.enabled`` is the alias of ``base.in_dygraph_mode``, and
  126. ``base.in_dygraph_mode`` is recommended to use for now.
  127. Returns:
  128. bool: Whether the program is running in dynamic graph mode.
  129. Examples:
  130. .. code-block:: python
  131. >>> import paddle.base as base
  132. >>> base.enable_dygraph() # Now we are in dygragh mode
  133. >>> print(base.dygraph.enabled())
  134. True
  135. >>> base.disable_dygraph()
  136. >>> print(base.dygraph.enabled())
  137. False
  138. """
  139. # TODO(jiabin): Make this check as in_dygraph_mode when we support default eager mode.
  140. return framework.in_dygraph_mode()
  141. def enable_dygraph(place=None):
  142. """
  143. .. note::
  144. Dynamic graph mode is turn ON by default since paddle 2.0.0
  145. This API turn OFF static graph mode. You can turn ON static graph mode by `enable_static <./disable_dygraph_en.html>`_ .
  146. Parameters:
  147. place(paddle.CPUPlace|paddle.CUDAPlace|str, optional): Place to run dynamic graph. Default: None. Which means that the running place will be
  148. determined according to the way of paddle compilation. If ``place`` is string, It can be ``cpu``, and ``gpu:x``, where ``x`` is the
  149. index of the GPUs.
  150. return:
  151. None
  152. Examples:
  153. .. code-block:: python
  154. >>> import paddle
  155. >>> print(paddle.in_dynamic_mode())
  156. True
  157. >>> paddle.enable_static()
  158. >>> print(paddle.in_dynamic_mode())
  159. False
  160. >>> paddle.disable_static()
  161. >>> print(paddle.in_dynamic_mode())
  162. True
  163. """
  164. global global_var
  165. if global_var._functional_dygraph_context_manager is None:
  166. global_var._functional_dygraph_context_manager = guard(
  167. place=_get_paddle_place(place)
  168. )
  169. global_var._functional_dygraph_context_manager.__enter__()
  170. # call disable_dygraph when Python exit
  171. CleanupFuncRegistrar.register(disable_dygraph)
  172. def disable_dygraph():
  173. """
  174. .. note::
  175. Dynamic graph mode is turn ON by default since paddle 2.0.0
  176. This API turn ON static graph mode. You can turn ON static graph mode by `disable_static <./enable_dygraph_en.html>`_ .
  177. return:
  178. None
  179. Examples:
  180. .. code-block:: python
  181. >>> import paddle
  182. >>> print(paddle.in_dynamic_mode())
  183. True
  184. >>> paddle.enable_static()
  185. >>> print(paddle.in_dynamic_mode())
  186. False
  187. >>> paddle.disable_static()
  188. >>> print(paddle.in_dynamic_mode())
  189. True
  190. """
  191. global global_var
  192. if global_var._functional_dygraph_context_manager is not None:
  193. global_var._functional_dygraph_context_manager.__exit__(*sys.exc_info())
  194. global_var._functional_dygraph_context_manager = None
  195. @signature_safe_contextmanager
  196. def _switch_tracer_mode_guard_(is_train=True):
  197. tracer = framework._dygraph_tracer()
  198. if tracer:
  199. has_grad = tracer._has_grad
  200. tracer._has_grad = is_train
  201. try:
  202. yield
  203. finally:
  204. tracer._has_grad = has_grad
  205. else:
  206. yield
  207. def no_grad(func=None):
  208. """
  209. :api_attr: imperative
  210. Create a context which disables dygraph gradient calculation.
  211. In this mode, the result of every computation will have `stop_gradient=True`.
  212. Also functions as a decorator. (Make sure to instantiate without parenthesis.)
  213. Examples:
  214. .. code-block:: python
  215. >>> import numpy as np
  216. >>> import paddle.base as base
  217. >>> # use as generator
  218. >>> data = np.array([[2, 3], [4, 5]]).astype('float32')
  219. >>> with base.dygraph.guard():
  220. ... l0 = paddle.nn.Linear(2, 2) # l0.weight.gradient() is None
  221. ... l1 = paddle.nn.Linear(2, 2)
  222. ... with base.dygraph.no_grad():
  223. ... # l1.weight.stop_gradient is False
  224. ... tmp = l1.weight * 2 # tmp.stop_gradient is True
  225. ... x = base.dygraph.to_variable(data)
  226. ... y = l0(x) + tmp
  227. ... o = l1(y)
  228. ... o.backward()
  229. ... print(tmp.gradient() is None)
  230. ... print(l0.weight.gradient() is None)
  231. True
  232. False
  233. >>> @base.dygraph.no_grad
  234. >>> def test_layer():
  235. ... with base.dygraph.guard():
  236. ... inp = np.ones([3, 1024], dtype='float32')
  237. ... t = base.dygraph.base.to_variable(inp)
  238. ... linear1 = paddle.nn.Linear(1024, 4, bias_attr=False)
  239. ... linear2 = paddle.nn.Linear(4, 4)
  240. ... ret = linear1(t)
  241. ... dy_ret = linear2(ret)
  242. ...
  243. >>> test_layer()
  244. """
  245. if in_to_static_mode():
  246. warnings.warn(
  247. "paddle.no_grad is only supported for inference model, and not supported for training under @to_static."
  248. )
  249. if func is None:
  250. return _switch_tracer_mode_guard_(is_train=False)
  251. else:
  252. @decorator.decorator
  253. def __impl__(func, *args, **kwargs):
  254. with _switch_tracer_mode_guard_(is_train=False):
  255. return func(*args, **kwargs)
  256. return __impl__(func)
  257. class _DecoratorContextManager:
  258. """Allow a context manager to be used as a decorator"""
  259. def __call__(self, func):
  260. @decorator.decorator
  261. def _decorate_function(func, *args, **kwargs):
  262. with self:
  263. return func(*args, **kwargs)
  264. @decorator.decorator
  265. def _decorate_generator(func, *args, **kwargs):
  266. gen = func(*args, **kwargs)
  267. with self:
  268. yield from gen
  269. if inspect.isgeneratorfunction(func):
  270. return _decorate_generator(func)
  271. else:
  272. return _decorate_function(func)
  273. def __enter__(self):
  274. raise NotImplementedError
  275. def __exit__(self, exc_type, exc_value, traceback):
  276. raise NotImplementedError
  277. def clone(self):
  278. # override this method if your children class takes __init__ parameters
  279. return self.__class__()
  280. def is_grad_enabled():
  281. """
  282. Returns whether current dygraph gradient calculation mode is enabled.
  283. Returns:
  284. bool: True if current dygraph gradient calculation mode is enabled, otherwise false.
  285. Examples:
  286. .. code-block:: python
  287. >>> import paddle
  288. >>> # Dygraph gradient calculation mode is enabled by default.
  289. >>> paddle.is_grad_enabled()
  290. True
  291. >>> with paddle.set_grad_enabled(False):
  292. ... paddle.is_grad_enabled()
  293. False
  294. >>> paddle.enable_static()
  295. >>> paddle.is_grad_enabled()
  296. False
  297. """
  298. tracer = framework._dygraph_tracer()
  299. return tracer._has_grad if tracer else False
  300. def _set_grad_enabled(mode):
  301. tracer = framework._dygraph_tracer()
  302. if tracer:
  303. tracer._has_grad = mode
  304. class set_grad_enabled(_DecoratorContextManager):
  305. """
  306. Create a context which enables or disables dygraph gradient calculation.
  307. Args:
  308. mode(bool): whether to enable (`True`), or disable (`False`) grad.
  309. Returns:
  310. None.
  311. Examples:
  312. .. code-block:: python
  313. >>> import paddle
  314. >>> x = paddle.to_tensor([1.], stop_gradient=False)
  315. >>> is_train = False
  316. >>> with paddle.set_grad_enabled(is_train):
  317. ... y = x * 2
  318. >>> print(y.stop_gradient)
  319. True
  320. >>> paddle.set_grad_enabled(True)
  321. >>> y = x * 2
  322. >>> print(y.stop_gradient)
  323. False
  324. >>> paddle.set_grad_enabled(False)
  325. >>> y = x * 2
  326. >>> print(y.stop_gradient)
  327. True
  328. """
  329. def __init__(self, mode):
  330. self.prev = is_grad_enabled()
  331. _set_grad_enabled(mode)
  332. self.mode = mode
  333. def __enter__(self):
  334. ...
  335. def __exit__(self, *args):
  336. _set_grad_enabled(self.prev)
  337. def clone(self):
  338. return self.__class__(self.mode)
  339. class no_grad_(_DecoratorContextManager):
  340. """
  341. :api_attr: imperative
  342. Create a context which disables dygraph gradient calculation.
  343. In this mode, the result of every computation will have `stop_gradient` set
  344. to `True`.
  345. Also functions as a decorator. (Make sure to use an instance.)
  346. Examples:
  347. .. code-block:: python
  348. >>> import numpy as np
  349. >>> import paddle
  350. >>> # use as generator
  351. >>> data = np.array([[2, 3], [4, 5]]).astype('float32')
  352. >>> l0 = paddle.nn.Linear(2, 2) # l0.weight.gradient() is None
  353. >>> l1 = paddle.nn.Linear(2, 2)
  354. >>> with paddle.no_grad():
  355. ... # l1.weight.stop_gradient is False
  356. ... tmp = l1.weight * 2 # tmp.stop_gradient is True
  357. >>> x = paddle.to_tensor(data)
  358. >>> y = l0(x) + tmp
  359. >>> o = l1(y)
  360. >>> o.backward()
  361. >>> print(tmp.gradient() is None)
  362. True
  363. >>> print(l0.weight.gradient() is None)
  364. False
  365. >>> # use as decorator
  366. >>> @paddle.no_grad()
  367. >>> def test_layer():
  368. ... inp = np.ones([3, 1024], dtype='float32')
  369. ... t = paddle.to_tensor(inp)
  370. ... linear1 = paddle.nn.Linear(1024, 4, bias_attr=False)
  371. ... linear2 = paddle.nn.Linear(4, 4)
  372. ... ret = linear1(t)
  373. ... dy_ret = linear2(ret)
  374. ...
  375. >>> test_layer()
  376. """
  377. def __enter__(self):
  378. self.prev = is_grad_enabled()
  379. _set_grad_enabled(False)
  380. def __exit__(self, *args):
  381. _set_grad_enabled(self.prev)
  382. class enable_grad(_DecoratorContextManager):
  383. """
  384. :api_attr: imperative
  385. Create a context which enable dygraph gradient calculation,
  386. if it has been disabled by `no_grad` or `set_grad_enabled`.
  387. In this mode, the result of every computation will have `stop_gradient` set
  388. to `False`.
  389. Also functions as a decorator. (Make sure to use an instance.)
  390. Examples:
  391. .. code-block:: python
  392. >>> import paddle
  393. >>> # use as generator
  394. >>> x = paddle.to_tensor([1.], stop_gradient=False)
  395. >>> with paddle.no_grad():
  396. ... with paddle.enable_grad():
  397. ... y = x * 2
  398. >>> assert(y.stop_gradient == False)
  399. >>> y.backward()
  400. >>> assert(x.grad is not None)
  401. >>> # use as decorator
  402. >>> @paddle.enable_grad()
  403. >>> def double(x):
  404. ... return x * 2
  405. ...
  406. >>> with paddle.no_grad():
  407. ... z = double(x)
  408. ...
  409. >>> assert(z.stop_gradient == False)
  410. """
  411. def __enter__(self):
  412. self.prev = is_grad_enabled()
  413. _set_grad_enabled(True)
  414. def __exit__(self, *args):
  415. _set_grad_enabled(self.prev)
  416. @signature_safe_contextmanager
  417. def guard(place=None):
  418. """
  419. :api_attr: imperative
  420. This context will create a dygraph context for dygraph to run, using python ``with`` statement.
  421. Parameters:
  422. place(base.CPUPlace| base.CUDAPlace|str, optional): Place to execute dygraph.
  423. If None, the running place will be determined according to the way of paddle compilation.
  424. If ``place`` is string, It can be ``cpu``, ``gpu:x`` and ``xpu:x``, where ``x`` is the
  425. index of the GPUs or XPUs. Default: None
  426. return:
  427. None
  428. Examples:
  429. .. code-block:: python
  430. >>> import numpy as np
  431. >>> import paddle.base as base
  432. >>> with base.dygraph.guard():
  433. ... inp = np.ones([3, 1024], dtype='float32')
  434. ... t = base.dygraph.base.to_variable(inp)
  435. ... linear1 = paddle.nn.Linear(1024, 4, bias_attr=False)
  436. ... linear2 = paddle.nn.Linear(4, 4)
  437. ... ret = linear1(t)
  438. ... dy_ret = linear2(ret)
  439. ...
  440. """
  441. train = framework.Program()
  442. startup = framework.Program()
  443. tracer = Tracer()
  444. if place is not None:
  445. expected_place = _get_paddle_place(place)
  446. else:
  447. expected_place = framework._current_expected_place_()
  448. with framework.program_guard(train, startup):
  449. with framework.unique_name.guard():
  450. with framework._dygraph_guard(tracer):
  451. with framework._dygraph_place_guard(expected_place):
  452. yield
  453. @framework.non_static_only
  454. def grad(
  455. outputs,
  456. inputs,
  457. grad_outputs=None,
  458. retain_graph=None,
  459. create_graph=False,
  460. only_inputs=True,
  461. allow_unused=False,
  462. no_grad_vars=None,
  463. ):
  464. '''
  465. .. note::
  466. **This API is ONLY available in imperative mode.**
  467. This API computes the sum of gradients of `outputs` with respect to each `inputs` .
  468. Parameters:
  469. outputs (Tensor|list(Tensor)|tuple(Tensor)): the output Tensor or
  470. Tensor list/tuple of the graph to compute gradients.
  471. inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or
  472. Tensor list/tuple of the graph to compute gradients. The returned
  473. values of this API are the gradients of `inputs` .
  474. grad_outputs (Tensor|list(Tensor|None)|tuple(Tensor|None), optional):
  475. initial gradient values of `outputs` . If `grad_outputs` is None,
  476. the initial gradient values of `outputs` would be Tensors filled with 1;
  477. if `grad_outputs` is not None, it must have the same length as `outputs` ,
  478. and in this case, the initial gradient value of the i-th `outputs` would
  479. be: (1) a Tensor filled with 1 when the i-th element of `grad_outputs`
  480. is None; (2) the i-th element of `grad_outputs` when the i-th element of
  481. `grad_outputs` is a Tensor. Default None.
  482. retain_graph (bool, optional): whether to retain the forward graph which
  483. is used to calculate the gradient. When it is True, the graph would
  484. be retained, in which way users can calculate backward twice for the
  485. same graph. When it is False, the graph would be freed. Default None,
  486. which means it is equal to `create_graph` .
  487. create_graph (bool, optional): whether to create the gradient graphs of
  488. the computing process. When it is True, higher order derivatives are
  489. supported to compute; when it is False, the gradient graphs of the
  490. computing process would be discarded. Default False.
  491. only_inputs (bool, optional): whether to only compute the gradients of
  492. `inputs` . If it is False, the gradients of all remaining leaf
  493. Tensors in the graph would be also computed and accumulated.
  494. If it is True, only the gradients of `inputs` would be computed.
  495. Default True. only_inputs=False is under development, and it is
  496. not supported yet.
  497. allow_unused (bool, optional): whether to raise error or return None if some
  498. Tensors of `inputs` are unreachable in the graph. If some Tensors of
  499. `inputs` are unreachable in the graph (i.e., their gradients are None),
  500. error would be raised if allow_unused=False, or None would be returned as
  501. their gradients if allow_unused=True. Default False.
  502. no_grad_vars (Tensor|list(Tensor)|tuple(Tensor)|set(Tensor), optional):
  503. the Tensors whose gradients are not needed to compute. Default None.
  504. Returns:
  505. list: a list of Tensors, whose length is the same as the Tensor number
  506. inside `inputs`, and the i-th returned Tensor is the sum of gradients of
  507. `outputs` with respect to the i-th `inputs`.
  508. Examples:
  509. .. code-block:: python
  510. :name: code-example-1
  511. >>> import paddle
  512. >>> def test_dygraph_grad(create_graph):
  513. ... x = paddle.ones(shape=[1], dtype='float32')
  514. ... x.stop_gradient = False
  515. ... y = x * x
  516. ...
  517. ... # Since y = x * x, dx = 2 * x
  518. ... dx = paddle.grad(
  519. ... outputs=[y],
  520. ... inputs=[x],
  521. ... create_graph=create_graph,
  522. ... retain_graph=True)[0]
  523. ...
  524. ... z = y + dx
  525. ...
  526. ... # If create_graph = False, the gradient of dx
  527. ... # would not be backpropagated. Therefore,
  528. ... # z = x * x + dx, and x.gradient() = 2 * x = 2.0
  529. ...
  530. ... # If create_graph = True, the gradient of dx
  531. ... # would be backpropagated. Therefore,
  532. ... # z = x * x + dx = x * x + 2 * x, and
  533. ... # x.gradient() = 2 * x + 2 = 4.0
  534. ...
  535. ... z.backward()
  536. ... return x.gradient()
  537. ...
  538. >>> print(test_dygraph_grad(create_graph=False))
  539. [2.]
  540. >>> print(test_dygraph_grad(create_graph=True))
  541. [4.]
  542. .. code-block:: python
  543. :name: code-example-2
  544. >>> import paddle
  545. >>> def test_dygraph_grad(grad_outputs=None):
  546. ... x = paddle.to_tensor(2.0)
  547. ... x.stop_gradient = False
  548. ...
  549. ... y1 = x * x
  550. ... y2 = x * 3
  551. ...
  552. ... # If grad_outputs=None, dy1 = [1], dy2 = [1].
  553. ... # If grad_outputs=[g1, g2], then:
  554. ... # - dy1 = [1] if g1 is None else g1
  555. ... # - dy2 = [1] if g2 is None else g2
  556. ...
  557. ... # Since y1 = x * x, dx = 2 * x * dy1.
  558. ... # Since y2 = x * 3, dx = 3 * dy2.
  559. ... # Therefore, the final result would be:
  560. ... # dx = 2 * x * dy1 + 3 * dy2 = 4 * dy1 + 3 * dy2.
  561. ...
  562. ... dx = paddle.grad(
  563. ... outputs=[y1, y2],
  564. ... inputs=[x],
  565. ... grad_outputs=grad_outputs)[0]
  566. ...
  567. ... return dx.numpy()
  568. ...
  569. >>> grad_value = paddle.to_tensor(4.0)
  570. >>> # dy1 = [1], dy2 = [1]
  571. >>> print(test_dygraph_grad(None))
  572. 7.
  573. >>> # dy1 = [1], dy2 = [4]
  574. >>> print(test_dygraph_grad([None, grad_value]))
  575. 16.
  576. >>> # dy1 = [4], dy2 = [1]
  577. >>> print(test_dygraph_grad([grad_value, None]))
  578. 19.
  579. >>> # dy1 = [3], dy2 = [4]
  580. >>> grad_y1 = paddle.to_tensor(3.0)
  581. >>> print(test_dygraph_grad([grad_y1, grad_value]))
  582. 24.
  583. '''
  584. if in_to_static_mode():
  585. # In dy2static context, we call static interface `gradients`
  586. # to calculate grads.
  587. from paddle.static import gradients
  588. to_static_unsupport_argument_warning(
  589. "paddle.grad",
  590. ["retain_graph", "create_grad", "only_inputs", "allow_unused"],
  591. [retain_graph, create_graph, only_inputs, allow_unused],
  592. [None, False, True, False],
  593. )
  594. return gradients(outputs, inputs, grad_outputs, no_grad_vars)
  595. def check_in_out(in_out_list, name):
  596. assert in_out_list is not None, f"{name} should not be None"
  597. if isinstance(in_out_list, (list, tuple)):
  598. assert len(in_out_list) > 0, f"{name} cannot be empty"
  599. for each_var in in_out_list:
  600. assert isinstance(
  601. each_var, core.eager.Tensor
  602. ), f"Elements of {name} must be Tensor"
  603. return in_out_list
  604. else:
  605. assert isinstance(
  606. in_out_list, core.eager.Tensor
  607. ), f"{name} must be Tensor or list of Tensor"
  608. return [in_out_list]
  609. outputs = check_in_out(outputs, 'outputs')
  610. inputs = check_in_out(inputs, 'inputs')
  611. if grad_outputs is not None:
  612. if not isinstance(grad_outputs, (list, tuple)):
  613. grad_outputs = [grad_outputs]
  614. for each_var in grad_outputs:
  615. if each_var is not None:
  616. assert isinstance(
  617. each_var, core.eager.Tensor
  618. ), "grad_outputs must be None, a Variable or a list containing None or Variables"
  619. else:
  620. grad_outputs = []
  621. if len(grad_outputs) > 0:
  622. assert len(grad_outputs) == len(
  623. outputs
  624. ), "The length of grad_outputs must be equal to outputs"
  625. if no_grad_vars is None:
  626. no_grad_vars = []
  627. elif isinstance(no_grad_vars, core.eager.Tensor):
  628. no_grad_vars = [no_grad_vars]
  629. elif isinstance(no_grad_vars, (list, tuple, set)):
  630. no_grad_vars = list(no_grad_vars)
  631. for var in no_grad_vars:
  632. assert isinstance(
  633. var, core.eager.Tensor
  634. ), "no_grad_vars can only contains Tensor"
  635. else:
  636. raise AssertionError(
  637. "no_grad_vars must be None, Tensor or list/tuple/set of Tensors"
  638. )
  639. assert isinstance(create_graph, bool), "create_graph must be True or False"
  640. if retain_graph is None:
  641. retain_graph = create_graph
  642. assert isinstance(
  643. retain_graph, bool
  644. ), "retain_graph must be None, True or False"
  645. assert isinstance(allow_unused, bool), "allow_unused must be True or False"
  646. assert isinstance(only_inputs, bool), "only_inputs must be True or False"
  647. assert only_inputs, "only_inputs=False is not supported yet"
  648. return core.eager.run_partial_grad(
  649. outputs,
  650. inputs,
  651. grad_outputs,
  652. retain_graph,
  653. create_graph,
  654. only_inputs,
  655. allow_unused,
  656. no_grad_vars,
  657. )