| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636 |
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import annotations
- import atexit
- import builtins
- import functools
- import importlib.util
- import inspect
- import os
- import shutil
- import sys
- import tempfile
- import textwrap
- import types
- from importlib.machinery import SourceFileLoader
- import numpy as np
- import paddle
- from paddle.base import backward, core, framework, unique_name
- from paddle.base.data_feeder import convert_dtype
- from paddle.base.layer_helper import LayerHelper
- from paddle.base.wrapped_decorator import signature_safe_contextmanager
- from paddle.framework import CUDAPinnedPlace
- from paddle.jit.utils import OrderedSet
- from paddle.utils import flatten
- from .ast_utils import ast_to_source_code
__all__ = []

# NOTE(Aurelius): The trailing dot `.` is intentional: it keeps modules whose
# names merely begin with "paddle" (such as paddlenlp) from matching.
PADDLE_MODULE_PREFIX = 'paddle.'

ALREADY_D2S = '__already_d2s'

# NOTE(liym27): Read this attribute with `getattr(ast_node, ORIGIN_INFO)`
# rather than the `.` operator to get the original information of an ast node.
ORIGIN_INFO = "Original information of source code for ast node."

# Guards against calling atexit.register more than once.
DEL_TEMP_DIR = True

# Regular-expression fragments for Python identifiers / module prefixes.
RE_PYNAME = '[a-zA-Z0-9_]+'
RE_PYMODULE = r'[a-zA-Z0-9_]+\.'

# Assign does not support float64, so a float32 value is used as the magic
# number that marks "no value returned".
RETURN_NO_VALUE_VAR_NAME = "__no_value_return_var"
RETURN_NO_VALUE_MAGIC_NUM = 1.77113e27
# Variable types listed here are treated as having no shape.
NO_SHAPE_VAR_TYPE = [
    core.VarDesc.VarType.READER,
    core.VarDesc.VarType.STEP_SCOPES,
    core.VarDesc.VarType.FEED_MINIBATCH,
    core.VarDesc.VarType.FETCH_LIST,
]
def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
    """
    Create a global data Tensor whose fed data is not checked.

    The Tensor is created in the global block with `need_check_feed=False`,
    so neither the dtype nor the shape of fed data is validated. This is
    used when translating dygraph code into a static graph, because dygraph
    input data can have varying lengths.

    Note:
        The created Tensor has `stop_gradient=True`, which means no gradient
        is passed backward through it. Set :code:`var.stop_gradient = False`
        if backward gradients are required.

    Args:
        name (str): The name/alias of the Tensor, see :ref:`api_guide_Name`
            for more details.
        shape (list|tuple): Shape of the Tensor. A dimension given as None
            (for example a changeable batch size) is stored as -1.
        dtype (np.dtype|VarType|str, optional): The data type. Supported
            dtypes: bool, float16, float32, float64, int8, int16, int32,
            int64, uint8. Default: float32
        lod_level (int, optional): The LoD level of the LoDTensor. Usually
            users don't have to set this value. Default: 0

    Returns:
        Tensor: The global Tensor that gives access to the data.
    """
    # `locals()` is forwarded verbatim to LayerHelper, so at this point it
    # must contain only the four parameters; bind no other names above.
    helper = LayerHelper('data', **locals())
    resolved_shape = [-1 if dim is None else dim for dim in shape]
    return helper.create_global_variable(
        name=name,
        shape=resolved_shape,
        dtype=dtype,
        type=core.VarDesc.VarType.LOD_TENSOR,
        stop_gradient=True,
        lod_level=lod_level,
        is_data=True,
        need_check_feed=False,
    )
def create_undefined_variable():
    """
    Create a variable that stands for an undefined value.

    The variable is created via `data_layer_not_check` with a unique
    "undefined_var" name, shape [1] and dtype float64. It has
    `stop_gradient=False` and is assigned `RETURN_NO_VALUE_MAGIC_NUM` by an
    assign op appended to block 0.

    Returns:
        The created variable.
    """
    var = data_layer_not_check(
        unique_name.generate("undefined_var"), [1], "float64"
    )
    var.stop_gradient = False
    # The variable is created in block(0), so the assign op is appended to
    # block(0) too, whatever block is currently being built.
    # NOTE: `**locals()` forwards `var` to LayerHelper; bind no other locals
    # before this call.
    helper = LayerHelper('create_undefined_variable', **locals())
    saved_block_idx = helper.main_program.current_block_idx
    helper.main_program.current_block_idx = 0
    try:
        paddle.assign(RETURN_NO_VALUE_MAGIC_NUM, var)
    finally:
        # Restore the caller's current block even if the assign fails;
        # otherwise later ops would silently be appended to block(0).
        helper.main_program.current_block_idx = saved_block_idx
    return var
class UndefinedVar:
    """
    Placeholder object for a variable named `name` that has no value yet.

    Calling `check()` raises UnboundLocalError naming that variable.
    """

    def __init__(self, name):
        # Name of the variable this placeholder stands for.
        self.name = name

    def check(self):
        """
        Report the use of the not-yet-created variable.

        Raises:
            UnboundLocalError: always, with the variable name in the message.
        """
        raise UnboundLocalError(
            f"local variable '{self.name}' should be created before using it."
        )
class Dygraph2StaticException(Exception):
    """
    Exception carrying a single `message` argument; presumably used to
    report dygraph-to-static conversion errors.
    """

    def __init__(self, message):
        # Hand the message to Exception so it becomes `args[0]` / `str(self)`.
        super().__init__(message)
def saw(x):
    """
    Return `x` unchanged, unless it is an UndefinedVar.

    For an UndefinedVar the result of `x.check()` is returned.
    """
    return x.check() if isinstance(x, UndefinedVar) else x
def parse_arg_and_kwargs(function):
    """
    Return the positional argument names and default values of `function`.

    A leading `self` argument is dropped. Example: for
    `def f(self, x, y=1)` the result is `(['x', 'y'], {'y': 1})`.

    Returns:
        tuple: (list of argument names, dict mapping name -> default value)
    """
    spec = inspect.getfullargspec(function)
    positional = spec.args
    if positional and positional[0] == 'self':
        positional = positional[1:]

    defaults = spec.defaults
    if not defaults:
        return positional, {}

    # Defaults always belong to the trailing positional arguments.
    assert len(defaults) <= len(positional)
    defaulted_names = positional[-len(defaults) :]
    return positional, dict(zip(defaulted_names, defaults))
def parse_varargs_name(function):
    """
    Return the name of the `*args` parameter of `function`, or None.

    Example: 'input' for `foo(x, *input)`.
    """
    return inspect.getfullargspec(function).varargs
def type_name(v):
    """Return the name of the type of `v`, e.g. 'int' for 1."""
    value_type = type(v)
    return value_type.__name__
def make_hashable(x, error_msg=None):
    """
    Makes input `x` hashable.

    - tuple/list/set: converted recursively into a tuple (set elements keep
      the set's iteration order).
    - already-hashable objects: returned unchanged.
    - np.ndarray: reduced to a hash of its shape, dtype and raw bytes, so
      arrays with different values, shapes or dtypes get different results.
    - dict: converted to a tuple of `(key, make_hashable(value))` pairs in
      insertion order, so dicts that differ only in their keys stay distinct.

    Args:
        x: The object to convert.
        error_msg (str, optional): Prefix of the error message raised for an
            unsupported unhashable object. It is also used for nested
            elements.

    Returns:
        A hashable representation of `x`.

    Raises:
        ValueError: If `x` or one of its nested elements cannot be made
            hashable.
    """
    if isinstance(x, (tuple, list, set)):
        return tuple(make_hashable(item, error_msg) for item in x)
    try:
        hash(x)
    except TypeError:
        if isinstance(x, np.ndarray):
            # `tobytes()` replaces the deprecated `tostring()`. Shape and
            # dtype are included because identical raw bytes can represent
            # different arrays, e.g. zeros of shape (2, 3) vs (3, 2).
            return hash((x.shape, x.dtype, x.tobytes()))
        elif isinstance(x, dict):
            # Keys are kept so that {'a': 1} and {'b': 1} stay distinct.
            return tuple(
                (key, make_hashable(value, error_msg))
                for key, value in x.items()
            )
        error_msg = error_msg or "Requires a hashable object."
        raise ValueError(
            f"{error_msg} But received type: {type(x).__name__}"
        )
    return x
# NOTE(Aurelius84): The Paddle inner APIs listed below are treated like user
# code, so @to_static code transformation is applied to them as usual,
# because they contain user-defined layers
# (e.g. paddle.distributed.auto_parallel.helper.ProxyLayer).
AS_NOT_INNER_FUNC_LIST = {
    "paddle.nn.layer.container.Sequential",
}
def as_not_paddle_func(path):
    """
    Register an API or class path as an ignored case for is_paddle_func.

    Callables matching a registered path make `is_paddle_func(func)` return
    False (with the default `ignore_white_list=True`).

    Args:
        path (str): Dotted "<module name>.<name>" path, e.g.
            "paddle.nn.layer.container.Sequential".
    """
    # No `global` statement is needed: the module-level set is mutated in
    # place, never rebound.
    AS_NOT_INNER_FUNC_LIST.add(path)
def is_paddle_func(func, ignore_white_list=True):
    """
    Return True if `func` is defined in a Paddle module.

    `func` counts as a Paddle function when the module found by
    `inspect.getmodule` has a name starting with `PADDLE_MODULE_PREFIX`.
    If `ignore_white_list` is True, callables whose "<module>.<name>" path
    is in `AS_NOT_INNER_FUNC_LIST` (registered via `as_not_paddle_func`)
    are reported as non-Paddle.

    Args:
        func: A function, bound method, functools.partial or callable object.
        ignore_white_list (bool): Whether to honour AS_NOT_INNER_FUNC_LIST.

    Returns:
        bool: The result. Any exception raised while inspecting `func`
        yields False.
    """

    def in_white_list(module, func_name):
        if func_name is None:
            return False
        return (module.__name__ + '.' + func_name) in AS_NOT_INNER_FUNC_LIST

    try:
        if isinstance(func, functools.partial):
            func = func.func
        func_name = getattr(func, '__name__', None)
        if inspect.ismethod(func):
            # Bound methods are identified by the class of their receiver.
            func_name = func.__self__.__class__.__name__
            func = func.__func__
        elif not inspect.isfunction(func):
            # Callable objects (e.g. nn.Sequential instances) are identified
            # by their class name. Plain functions keep their own __name__;
            # every object has `__class__`, so testing for it would rename
            # them all to 'function' and no function path could ever match.
            func_name = func.__class__.__name__
        m = inspect.getmodule(func)
        flag = m is not None and m.__name__.startswith(PADDLE_MODULE_PREFIX)
        if ignore_white_list:
            flag = flag and not in_white_list(m, func_name)
        return flag
    except Exception:
        return False
def get_temp_dir():
    """
    Return the @to_static temp directory of the current process.

    The directory is `~/.cache/paddle/to_static_tmp/<pid>` (normalized with
    `os.path.normpath` on Windows). It is created, including parents, if it
    does not exist yet.

    Returns:
        str: Path of the directory.
    """
    dir_name = f"paddle/to_static_tmp/{os.getpid()}"
    temp_dir = os.path.join(os.path.expanduser('~/.cache'), dir_name)
    if sys.platform.startswith('win'):
        temp_dir = os.path.normpath(temp_dir)
    # `exist_ok=True` avoids the check-then-create race in which a concurrent
    # caller creates the directory between an existence test and makedirs.
    os.makedirs(temp_dir, exist_ok=True)
    return temp_dir
def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
    """
    Transform modified AST of decorated function into python callable object.

    The AST is turned back into source code, prefixed with the statements
    from `_inject_import_statements()`, written to a `.py` file in
    `get_temp_dir()`, and loaded as a module. The callable is taken from
    `__i_m_p_l__` if the module defines it, otherwise from the attribute
    named `dyfunc.__name__`. Globals and closure variables of `dyfunc` are
    then copied into the callable's globals.

    TODO: If only decorate one of inner function instead of decorating the main
    function, the other inner functions are invisible for the decorated function.

    Args:
        ast_root: Root AST node of the transformed function.
        dyfunc: The original dygraph function or method.
        delete_on_exit (bool): Register removal of the temp directory at
            interpreter exit. Registration happens at most once per process.

    Returns:
        tuple: (callable_func, path of the generated source file)

    Raises:
        ValueError: If the generated module defines neither `__i_m_p_l__`
            nor an attribute named `dyfunc.__name__`.
    """
    global DEL_TEMP_DIR

    def remove_if_exit(dir_path):
        if os.path.exists(dir_path):
            shutil.rmtree(dir_path)

    def func_prefix(func):
        # Bound methods get a "<ClassName>_<func name>" prefix so generated
        # files are easier to identify.
        pre_fix = func.__name__
        if hasattr(func, '__self__'):
            try:
                pre_fix = func.__self__.__class__.__name__ + '_' + func.__name__
            except Exception:
                # Fall back to the plain name; unlike a bare `except:` this
                # does not swallow KeyboardInterrupt / SystemExit.
                pass
        return pre_fix

    source = ast_to_source_code(ast_root)
    source = _inject_import_statements() + source
    temp_dir = get_temp_dir()
    f = tempfile.NamedTemporaryFile(
        mode='w',
        prefix=func_prefix(dyfunc),
        suffix='.py',
        delete=False,
        dir=temp_dir,
        encoding='utf-8',
    )
    with f:
        # Strip the '.py' suffix to get the module name.
        module_name = os.path.basename(f.name[:-3])
        f.write(source)

    if delete_on_exit and DEL_TEMP_DIR:
        # Clear temporary files in TEMP_DIR while exiting Python process
        atexit.register(remove_if_exit, dir_path=temp_dir)
        DEL_TEMP_DIR = False

    func_name = dyfunc.__name__
    loader = SourceFileLoader(module_name, f.name)
    spec = importlib.util.spec_from_loader(loader.name, loader)
    module = importlib.util.module_from_spec(spec)
    loader.exec_module(module)
    # The 'forward' or 'another_forward' of 'TranslatedLayer' cannot be obtained
    # through 'func_name'. So set the special function name '__i_m_p_l__'.
    if hasattr(module, '__i_m_p_l__'):
        callable_func = module.__i_m_p_l__
        callable_func.__name__ = func_name
    elif hasattr(module, func_name):
        callable_func = getattr(module, func_name)
    else:
        raise ValueError(
            f'Function: {func_name} doesn\'t exist in the Module transformed from AST.'
        )
    # After transform dygraph function into callable_func saved in tmp file,
    # it lost the global variables from imported statements or defined in source file.
    # Recovers the necessary variables by `__globals__`.
    recover_globals_attribute(dyfunc, callable_func)

    return callable_func, f.name
def _inject_import_statements():
    """
    Return the import header prepended to generated @to_static source files.

    Each statement ends with a newline, including the last one.
    """
    header_lines = (
        "import paddle",
        "from paddle import Tensor",
        "import paddle.base as base",
        "import paddle.jit.dy2static as _jst",
        "from typing import *",
        "import numpy as np",
        "import warnings",
        "warnings.filterwarnings('ignore', category=DeprecationWarning)",
    )
    return "".join(line + "\n" for line in header_lines)
def recover_globals_attribute(src_obj, dst_obj):
    """
    Copy the globals and closure variables of `src_obj` into `dst_obj`'s globals.

    Global names of the form `__xxx__` are skipped. The nonlocal (closure)
    variables of `src_obj` are copied afterwards. `dst_obj` is a standalone
    function, so names that were closure cells in `src_obj` are looked up as
    globals there.
    """
    src_namespace = getattr(src_obj, '__globals__', {})
    dst_namespace = getattr(dst_obj, '__globals__', {})

    def _is_dunder(name):
        return name.startswith('__') and name.endswith('__')

    dst_namespace.update(
        (name, value)
        for name, value in src_namespace.items()
        if not _is_dunder(name)
    )
    closure_vars = inspect.getclosurevars(src_obj)
    dst_namespace.update(closure_vars.nonlocals)
def func_to_source_code(function, dedent=True):
    """
    Return the source code of `function` as a single string.

    A `functools.partial` is unwrapped to its underlying callable. Lines
    whose first non-blank character is `#` become empty lines, which keeps
    line numbers (and thus error positions) aligned with the original file.
    With `dedent=True` the common leading indentation is removed.

    Raises:
        TypeError: If `function` is not a function or method.
    """
    target = function.func if isinstance(function, functools.partial) else function
    if not inspect.isfunction(target) and not inspect.ismethod(target):
        raise TypeError(
            f"The type of 'function' should be a function or method, but received {type(target).__name__}."
        )

    raw_lines, _ = inspect.getsourcelines(target)
    kept_lines = []
    for line in raw_lines:
        is_comment = line.lstrip().startswith('#')
        kept_lines.append('\n' if is_comment else line)

    joined = ''.join(kept_lines)
    return textwrap.dedent(joined) if dedent else joined
def input_specs_compatible(src_input_specs, desired_input_specs):
    """
    Returns True if the two input specs are compatible, otherwise False.

    If the lengths differ, every element of `src_input_specs` must appear
    in `desired_input_specs`. If the lengths are equal, the specs are
    compared position by position. A pair that contains a
    `paddle.static.InputSpec` is checked with `_compatible_tensor_spec`;
    any other pair is checked with `_compatible_non_tensor_spec`.

    Args:
        src_input_specs (list or tuple[InputSpec et.al]): list/tuple of
            paddle.static.InputSpec or int/str et.al
        desired_input_specs (list or tuple[InputSpec et.al]): list/tuple of
            paddle.static.InputSpec or int/str et.al
    """
    if len(src_input_specs) != len(desired_input_specs):
        # NOTE(chenweihang): if the input_spec of jit.save is a subset of
        # input_spec of to_static, also compatible
        return all(spec in desired_input_specs for spec in src_input_specs)

    for src_spec, desired_spec in zip(src_input_specs, desired_input_specs):
        involves_tensor = isinstance(
            src_spec, paddle.static.InputSpec
        ) or isinstance(desired_spec, paddle.static.InputSpec)
        checker = (
            _compatible_tensor_spec
            if involves_tensor
            else _compatible_non_tensor_spec
        )
        if not checker(src_spec, desired_spec):
            return False
    return True
def _compatible_tensor_spec(src_spec, desired_spec):
    """
    Check whether two tensor type spec is compatible.

    Both specs must be `paddle.static.InputSpec` instances of the same rank.
    A dimension that is None or negative on either side matches anything;
    other dimensions must be equal. The dtypes must match after
    `convert_dtype`.
    """
    both_are_specs = isinstance(
        src_spec, paddle.static.InputSpec
    ) and isinstance(desired_spec, paddle.static.InputSpec)
    if not both_are_specs:
        return False

    src_shape, other_shape = src_spec.shape, desired_spec.shape
    if len(src_shape) != len(other_shape):
        return False

    for src_dim, other_dim in zip(src_shape, other_shape):
        if src_dim is None or src_dim < 0:
            continue
        if other_dim is None or other_dim < 0:
            continue
        if src_dim != other_dim:
            return False

    if convert_dtype(src_spec.dtype) != convert_dtype(desired_spec.dtype):
        return False
    return True
def _compatible_non_tensor_spec(src_spec, desired_spec):
    """
    Check whether two non-tensor type spec is compatible.

    Both specs are converted with `make_hashable` and the results are
    compared. A spec that cannot be converted is represented by None.

    NOTE(review): as a consequence, two specs that both fail conversion
    compare as compatible — confirm this is intended.
    """

    def hash_value(spec):
        try:
            return make_hashable(spec)
        except Exception:
            # Narrowed from a bare `except:` so that KeyboardInterrupt and
            # SystemExit are not swallowed.
            return None

    if hash_value(src_spec) != hash_value(desired_spec):
        return False
    return True
class GetterSetterHelper:
    """
    Read and write named variables through a getter/setter callback pair.

    Two classes of names are handled by the setter and getter functions:
    w_vars (loop_vars) + push_pop_vars. This helper is factored out to
    simplify the setter logic in convert_while and convert_cond.

    All names in `name_lists` (None entries are ignored) are merged into one
    sorted list without duplicates. The position of a name in that list is
    its index into the sequence returned by `getter_func` and passed to
    `setter_func`.
    """

    def __init__(self, getter_func, setter_func, *name_lists):
        unique_names = {
            name
            for names in name_lists
            if names is not None
            for name in names
        }
        self._union = sorted(unique_names)
        self.getter = getter_func
        self.setter = setter_func
        self.name2id = {
            name: position for position, name in enumerate(self._union)
        }

    def union(self):
        """Return the sorted list of all known names."""
        return self._union

    def get(self, names):
        """
        Return a tuple with the current values of `names`.

        Returns `()` if the getter yields None.
        """
        requested = [] if names is None else names
        current = self.getter()
        if current is None:
            return ()
        for name in requested:
            assert (
                name in self.name2id
            ), f"the name `{name}` not in name union set`{self.name2id.keys()}`."
        return tuple(current[self.name2id[name]] for name in requested)

    def set(self, names, values):
        """
        Replace the values of `names` with `values` and pass the result to the setter.

        Does nothing if the getter yields None.
        """
        requested = [] if names is None else names
        new_values = [] if values is None else values
        current = self.getter()
        if current is None:
            return
        for name in requested:
            assert (
                name in self.name2id
            ), f"the name `{name}` not in name union set`{self.name2id.keys()}`."
        updated = list(current)
        positions = [self.name2id[name] for name in requested]
        for position, value in zip(positions, new_values):
            updated[position] = value
        self.setter(updated)
def prim_or_cinn_is_enabled(build_strategy, backend):
    """
    Return True if CINN is enabled; otherwise return `prim_is_enabled()`.

    CINN is checked first, so the prim check only runs when CINN is off.
    """
    if cinn_is_enabled(build_strategy, backend):
        return True
    return prim_is_enabled()
def cinn_is_enabled(build_strategy, backend):
    """
    Return True if CINN is requested, else False.

    CINN is requested when `backend` is 'CINN', when `build_strategy` is
    given and its `build_cinn_pass` is truthy, or when the environment
    variable `FLAGS_use_cinn` is 'true' or '1' (case-insensitive).
    """
    if backend == 'CINN' or (
        build_strategy is not None and build_strategy.build_cinn_pass
    ):
        return True
    env_flag = os.getenv('FLAGS_use_cinn')
    if env_flag is None:
        return False
    return env_flag.lower() in ('true', '1')
def cse_is_enabled():
    """Return the value of the Paddle flag `FLAGS_enable_cse_in_dy2st`."""
    flag_name = "FLAGS_enable_cse_in_dy2st"
    flags = paddle.get_flags([flag_name])
    return flags[flag_name]
def prim_is_enabled():
    """
    Return whether backward or forward prim is enabled.

    `core.check_and_set_prim_all_enabled()` is called first.
    """
    core.check_and_set_prim_all_enabled()
    bwd_enabled = core._is_bwd_prim_enabled()
    return bwd_enabled or core._is_fwd_prim_enabled()
def is_api_in_module_helper(obj, module_prefix):
    """
    Return True if `obj` belongs to a module whose name starts with `module_prefix`.

    Returns False when `inspect.getmodule` cannot determine the module.
    """
    module = inspect.getmodule(obj)
    if module is None:
        return False
    return module.__name__.startswith(module_prefix)
def is_builtin(func, name=None):
    """
    Return True if `func` is a builtin, optionally with the given `name`.

    `func` counts as builtin if it is a `types.BuiltinFunctionType` or one of
    the values in the `builtins` module. If `name` is None, any builtin
    matches; otherwise `func.__name__` must equal `name`.
    """

    def _name_matches():
        return name is None or func.__name__ == name

    looks_builtin = isinstance(
        func, types.BuiltinFunctionType
    ) or func in builtins.__dict__.values()
    if looks_builtin and _name_matches():
        return True
    return False
@signature_safe_contextmanager
def backend_guard(backend):
    """
    Context manager that adjusts prim settings for `backend`.

    Inside the guard, the CINN backend runs with all prim switched on. On
    exit, the forward and backward prim states captured on entry are
    restored, even if the body raised.
    """
    core.check_and_set_prim_all_enabled()
    saved_state = (core._is_fwd_prim_enabled(), core._is_bwd_prim_enabled())
    if backend == 'CINN':
        core._set_prim_all_enabled(True)
    try:
        yield
    finally:
        fwd_enabled, bwd_enabled = saved_state
        core._set_prim_forward_enabled(fwd_enabled)
        core._set_prim_backward_enabled(bwd_enabled)
def construct_grad_names(grad_info_map, x_vars, param_vars, out_vars):
    """
    Collect the gradient variable names of inputs, parameters and outputs.

    Returns:
        dict: Keys 'x', 'param' and 'out', each mapped to a list of gradient
        names. The list holds `framework.EMPTY_VAR_NAME` wherever
        `backward._get_grad_vars` returned something that is not a
        `framework.Variable`.
    """

    def _grad_name(grad_var):
        if isinstance(grad_var, framework.Variable):
            return grad_var.name
        return framework.EMPTY_VAR_NAME

    grad_var_names = {}
    for key, forward_vars in (
        ('x', x_vars),
        ('param', param_vars),
        ('out', out_vars),
    ):
        grad_vars = backward._get_grad_vars(grad_info_map, forward_vars)
        grad_var_names[key] = [_grad_name(v) for v in grad_vars]
    return grad_var_names
@signature_safe_contextmanager
def tensor_name_guard(tensors, names):
    """
    Temporarily rename `tensors` to `names`, restoring the original names on exit.

    Args:
        tensors: Sequence of objects with a writable `name` attribute.
        names: Sequence of new names, one per tensor.

    Raises:
        AssertionError: If `tensors` and `names` differ in length.
    """
    # Validate and snapshot *before* entering `try`. Otherwise a failure here
    # would reach `finally` with `origin_names` unbound, and the resulting
    # UnboundLocalError would mask the real error.
    assert len(tensors) == len(names), (
        f"tensor_name_guard expects one name per tensor, but got "
        f"{len(names)} names for {len(tensors)} tensors."
    )
    origin_names = [t.name for t in tensors]
    try:
        for t, name in zip(tensors, names):
            t.name = name
        yield
    finally:
        # Also undoes a partially applied rename.
        for t, name in zip(tensors, origin_names):
            t.name = name
def cuda_pinned_tensors_move_to_excepted_place(inputs):
    """
    Move CUDA-pinned, stop-gradient eager tensors in `inputs` to the expected place.

    This is a no-op unless Paddle is compiled with CUDA. For every flattened
    input that is a `core.eager.Tensor` with `stop_gradient` set and whose
    place equals `CUDAPinnedPlace()`, the tensor is copied to
    `framework._current_expected_place()`. The copy's buffer is then shared
    back into the original tensor object, so callers keep the same object.
    """
    if not paddle.is_compiled_with_cuda():
        return
    expected_place = framework._current_expected_place()
    cuda_pinned_place = CUDAPinnedPlace()
    for value in flatten(inputs):
        is_pinned_no_grad = (
            isinstance(value, core.eager.Tensor)
            and value.stop_gradient
            and value.place._equals(cuda_pinned_place)
        )
        if not is_pinned_no_grad:
            continue
        # NOTE(review): the second `_copy_to` argument is passed as True;
        # presumably it requests a blocking copy — confirm.
        moved = value._copy_to(expected_place, True)
        moved.stop_gradient = True
        moved._share_buffer_to(value)
|