inplace_utils.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
import sys
import warnings

import paddle  # noqa: F401
from paddle.base.wrapped_decorator import wrap_decorator
from paddle.framework import in_dynamic_mode
# NOTE(pangyoki): Inplace APIs whose names end with an underscore (`_`) only take effect when they
# call `_C_ops` in dygraph mode. In static graph mode the inplace mechanism is not used; instead,
# the static-graph implementation of the original (non-inplace) API is called.
# NOTE(GGBond8488): Simply running the original (non-inplace) version of the API in static graph
# mode may, with low probability, produce results that are inconsistent with dynamic graph mode.
  23. def _inplace_apis_in_dygraph_only_(func):
  24. def __impl__(*args, **kwargs):
  25. if not in_dynamic_mode():
  26. origin_api_name = func.__name__[:-1]
  27. warnings.warn(
  28. f"In static graph mode, {func.__name__}() is the same as {origin_api_name}() and does not perform inplace operation."
  29. )
  30. from ..base.dygraph.base import in_to_static_mode
  31. if in_to_static_mode():
  32. for arg in args:
  33. if hasattr(arg, "is_view_var") and arg.is_view_var:
  34. raise ValueError(
  35. f'Sorry about what\'s happened. In to_static mode, {func.__name__}\'s output variable {arg.name} is a viewed Tensor in dygraph. This will result in inconsistent calculation behavior between dynamic and static graphs. You must find the location of the strided API be called, and call {arg.name} = {arg.name}.assign().'
  36. )
  37. origin_func = f"{func.__module__}.{origin_api_name}"
  38. return eval(origin_func)(*args, **kwargs)
  39. return func(*args, **kwargs)
  40. return __impl__
  41. inplace_apis_in_dygraph_only = wrap_decorator(_inplace_apis_in_dygraph_only_)