utils.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import annotations
  15. import atexit
  16. import builtins
  17. import functools
  18. import importlib.util
  19. import inspect
  20. import os
  21. import shutil
  22. import sys
  23. import tempfile
  24. import textwrap
  25. import types
  26. from importlib.machinery import SourceFileLoader
  27. import numpy as np
  28. import paddle
  29. from paddle.base import backward, core, framework, unique_name
  30. from paddle.base.data_feeder import convert_dtype
  31. from paddle.base.layer_helper import LayerHelper
  32. from paddle.base.wrapped_decorator import signature_safe_contextmanager
  33. from paddle.framework import CUDAPinnedPlace
  34. from paddle.jit.utils import OrderedSet
  35. from paddle.utils import flatten
  36. from .ast_utils import ast_to_source_code
__all__ = []

# NOTE(Aurelius): Do not forget the dot `.` so that other top-level modules
# whose names merely start with "paddle" (such as paddlenlp) are not matched.
PADDLE_MODULE_PREFIX = 'paddle.'
ALREADY_D2S = '__already_d2s'

# NOTE(liym27): Please use `getattr(ast_node, ORIGIN_INFO)` instead of the `.`
# operator to get the original information of an ast node.
ORIGIN_INFO = "Original information of source code for ast node."

# A flag to avoid calling atexit.register more than once (cleared after the
# temp-dir cleanup hook is registered).
DEL_TEMP_DIR = True

# Regex fragments: one run of identifier characters, and the same run
# followed by a literal dot (one component of a dotted module path).
RE_PYNAME = '[a-zA-Z0-9_]+'
RE_PYMODULE = r'[a-zA-Z0-9_]+\.'

RETURN_NO_VALUE_VAR_NAME = "__no_value_return_var"
# `assign` does not support float64, so a value representable in float32 is
# used as the magic number.
RETURN_NO_VALUE_MAGIC_NUM = 1.77113e27

# NOTE(review): presumably the variable types for which a shape is not
# meaningful (inferred from the name) -- confirm at the use sites.
NO_SHAPE_VAR_TYPE = [
    core.VarDesc.VarType.READER,
    core.VarDesc.VarType.STEP_SCOPES,
    core.VarDesc.VarType.FEED_MINIBATCH,
    core.VarDesc.VarType.FETCH_LIST,
]
  56. def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
  57. """
  58. This function creates a Tensor on the global block. The created Tensor
  59. doesn't check the dtype and the shape of feed data because dygraph input
  60. data can be various-length. This API is used in translating dygraph into
  61. static graph.
  62. Note:
  63. The default :code:`stop_gradient` attribute of the Tensor created by
  64. this API is true, which means the gradient won't be passed backward
  65. through the data Tensor. Set :code:`var.stop_gradient = False` If
  66. user would like to pass backward gradient.
  67. Args:
  68. name (str): The name/alias of the Tensor, see :ref:`api_guide_Name`
  69. for more details.
  70. shape (list|tuple): List|Tuple of integers declaring the shape. You can
  71. set "None" at a dimension to indicate the dimension can be of any
  72. size. For example, it is useful to set changeable batch size as "None"
  73. dtype (np.dtype|VarType|str, optional): The type of the data. Supported
  74. dtype: bool, float16, float32, float64, int8, int16, int32, int64,
  75. uint8. Default: float32
  76. lod_level (int, optional): The LoD level of the LoDTensor. Usually users
  77. don't have to set this value. Default: 0
  78. Returns:
  79. Tensor: The global Tensor that gives access to the data.
  80. """
  81. helper = LayerHelper('data', **locals())
  82. shape = list(shape)
  83. for i in range(len(shape)):
  84. if shape[i] is None:
  85. shape[i] = -1
  86. return helper.create_global_variable(
  87. name=name,
  88. shape=shape,
  89. dtype=dtype,
  90. type=core.VarDesc.VarType.LOD_TENSOR,
  91. stop_gradient=True,
  92. lod_level=lod_level,
  93. is_data=True,
  94. need_check_feed=False,
  95. )
  96. def create_undefined_variable():
  97. var = data_layer_not_check(
  98. unique_name.generate("undefined_var"), [1], "float64"
  99. )
  100. var.stop_gradient = False
  101. # the variable is created in block(0), we append assign in block(0) either.
  102. helper = LayerHelper('create_undefined_variable', **locals())
  103. saved_block_ids = helper.main_program.current_block_idx
  104. helper.main_program.current_block_idx = 0
  105. paddle.assign(RETURN_NO_VALUE_MAGIC_NUM, var)
  106. helper.main_program.current_block_idx = saved_block_ids
  107. return var
  108. class UndefinedVar:
  109. def __init__(self, name):
  110. self.name = name
  111. def check(self):
  112. raise UnboundLocalError(
  113. "local variable '{}' should be created before using it."
  114. )
  115. class Dygraph2StaticException(Exception):
  116. def __init__(self, message):
  117. super().__init__(message)
  118. def saw(x):
  119. if isinstance(x, UndefinedVar):
  120. return x.check()
  121. else:
  122. return x
  123. def parse_arg_and_kwargs(function):
  124. """
  125. Returns full argument names as list. e.g ['x', 'y', 'z']
  126. """
  127. fullargspec = inspect.getfullargspec(function)
  128. arg_names = fullargspec.args
  129. if arg_names and 'self' == arg_names[0]:
  130. arg_names = fullargspec.args[1:]
  131. # parse default kwargs
  132. default_kwargs = {}
  133. default_values = fullargspec.defaults
  134. if default_values:
  135. assert len(default_values) <= len(arg_names)
  136. default_kwarg_names = arg_names[-len(default_values) :]
  137. default_kwargs = dict(zip(default_kwarg_names, default_values))
  138. return arg_names, default_kwargs
  139. def parse_varargs_name(function):
  140. """
  141. Returns varargs name string of function. e.g: 'input' from `foo(x, *input)`
  142. """
  143. fullargspec = inspect.getfullargspec(function)
  144. varargs = fullargspec.varargs
  145. return varargs
  146. def type_name(v):
  147. return type(v).__name__
  148. def make_hashable(x, error_msg=None):
  149. """
  150. Makes input `x` hashable.
  151. For some unhashable objects, such as `dict/list/set/np.ndarray`,applying hash function by using their values.
  152. """
  153. if isinstance(x, (tuple, list, set)):
  154. return tuple(map(make_hashable, x))
  155. try:
  156. hash(x)
  157. except TypeError:
  158. if isinstance(x, np.ndarray):
  159. # Note: `tostring()` will return the binary data from np.ndarray that
  160. # means different value will lead to different hash code.
  161. return hash(x.tostring())
  162. elif isinstance(x, dict):
  163. return tuple(map(make_hashable, x.values()))
  164. error_msg = error_msg or "Requires a hashable object."
  165. raise ValueError(f"{error_msg} But received type: {type_name(x)}")
  166. return x
# NOTE(Aurelius84): Treat the following Paddle-internal APIs as ordinary code,
# so that the @to_static code transformation is applied to them as usual,
# because they contain user-defined layers, like
# paddle.distributed.auto_parallel.helper.ProxyLayer.
# Entries are fully qualified names: `module.__name__ + '.' + name`.
AS_NOT_INNER_FUNC_LIST = {"paddle.nn.layer.container.Sequential"}
  171. def as_not_paddle_func(path):
  172. """
  173. Append API or class as ignored case for is_paddle_func, and they
  174. will be returned False while calling is_paddle_func(func).
  175. """
  176. global INNER_FUNC_WHITE_LIST
  177. AS_NOT_INNER_FUNC_LIST.add(path)
  178. def is_paddle_func(func, ignore_white_list=True):
  179. """
  180. Return True if function is defined in Paddle module.
  181. Skip to check APIs in white list if specifying ignore_white_list as True.
  182. """
  183. def in_white_list(module, func_name):
  184. if func_name is None:
  185. return False
  186. return (module.__name__ + '.' + func_name) in AS_NOT_INNER_FUNC_LIST
  187. try:
  188. if isinstance(func, functools.partial):
  189. func = func.func
  190. func_name = getattr(func, '__name__', None)
  191. if inspect.ismethod(func):
  192. func_name = func.__self__.__class__.__name__
  193. func = func.__func__
  194. elif hasattr(func, '__class__'): # for nn.Sequential
  195. func_name = func.__class__.__name__
  196. m = inspect.getmodule(func)
  197. flag = m is not None and m.__name__.startswith(PADDLE_MODULE_PREFIX)
  198. if ignore_white_list:
  199. flag = flag and not in_white_list(m, func_name)
  200. return flag
  201. except Exception:
  202. return False
  203. def get_temp_dir():
  204. """
  205. Return @to_static temp directory.
  206. """
  207. dir_name = f"paddle/to_static_tmp/{os.getpid()}"
  208. temp_dir = os.path.join(os.path.expanduser('~/.cache'), dir_name)
  209. is_windows = sys.platform.startswith('win')
  210. if is_windows:
  211. temp_dir = os.path.normpath(temp_dir)
  212. if not os.path.exists(temp_dir):
  213. os.makedirs(temp_dir)
  214. return temp_dir
  215. def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
  216. """
  217. Transform modified AST of decorated function into python callable object.
  218. TODO: If only decorate one of inner function instead of decorating the main
  219. function, the other inner functions are invisible for the decorated function.
  220. """
  221. def remove_if_exit(dir_path):
  222. if os.path.exists(dir_path):
  223. shutil.rmtree(dir_path)
  224. def func_prefix(func):
  225. pre_fix = func.__name__
  226. if hasattr(func, '__self__'):
  227. try:
  228. pre_fix = func.__self__.__class__.__name__ + '_' + func.__name__
  229. except:
  230. pass
  231. return pre_fix
  232. source = ast_to_source_code(ast_root)
  233. source = _inject_import_statements() + source
  234. temp_dir = get_temp_dir()
  235. f = tempfile.NamedTemporaryFile(
  236. mode='w',
  237. prefix=func_prefix(dyfunc),
  238. suffix='.py',
  239. delete=False,
  240. dir=temp_dir,
  241. encoding='utf-8',
  242. )
  243. with f:
  244. module_name = os.path.basename(f.name[:-3])
  245. f.write(source)
  246. global DEL_TEMP_DIR
  247. if delete_on_exit and DEL_TEMP_DIR:
  248. # Clear temporary files in TEMP_DIR while exiting Python process
  249. atexit.register(remove_if_exit, dir_path=temp_dir)
  250. DEL_TEMP_DIR = False
  251. func_name = dyfunc.__name__
  252. loader = SourceFileLoader(module_name, f.name)
  253. spec = importlib.util.spec_from_loader(loader.name, loader)
  254. module = importlib.util.module_from_spec(spec)
  255. loader.exec_module(module)
  256. # The 'forward' or 'another_forward' of 'TranslatedLayer' cannot be obtained
  257. # through 'func_name'. So set the special function name '__i_m_p_l__'.
  258. if hasattr(module, '__i_m_p_l__'):
  259. callable_func = module.__i_m_p_l__
  260. callable_func.__name__ = func_name
  261. elif hasattr(module, func_name):
  262. callable_func = getattr(module, func_name)
  263. else:
  264. raise ValueError(
  265. f'Function: {func_name} doesn\'t exist in the Module transformed from AST.'
  266. )
  267. # After transform dygraph function into callable_func saved in tmp file,
  268. # it lost the global variables from imported statements or defined in source file.
  269. # Recovers the necessary variables by `__globals__`.
  270. recover_globals_attribute(dyfunc, callable_func)
  271. return callable_func, f.name
  272. def _inject_import_statements():
  273. import_statements = [
  274. "import paddle",
  275. "from paddle import Tensor",
  276. "import paddle.base as base",
  277. "import paddle.jit.dy2static as _jst",
  278. "from typing import *",
  279. "import numpy as np",
  280. "import warnings",
  281. "warnings.filterwarnings('ignore', category=DeprecationWarning)",
  282. ]
  283. return '\n'.join(import_statements) + '\n'
  284. def recover_globals_attribute(src_obj, dst_obj):
  285. attr_name = '__globals__'
  286. src_globals = getattr(src_obj, attr_name, {})
  287. dst_globals = getattr(dst_obj, attr_name, {})
  288. for k, v in src_globals.items():
  289. # ignore builtin attribute.
  290. if not (k.startswith('__') and k.endswith('__')):
  291. dst_globals[k] = v
  292. # Inject source function closure into destination function globals
  293. # Because the destination function is a standalone function, the original
  294. # closure of the source function is compiled as LOAD_GLOBAL in the
  295. # destination function.
  296. src_closure = inspect.getclosurevars(src_obj)
  297. for k, v in src_closure.nonlocals.items():
  298. dst_globals[k] = v
  299. def func_to_source_code(function, dedent=True):
  300. """
  301. Transforms function into raw string of source code.
  302. """
  303. if isinstance(function, functools.partial):
  304. function = function.func
  305. if not (inspect.isfunction(function) or inspect.ismethod(function)):
  306. raise TypeError(
  307. f"The type of 'function' should be a function or method, but received {type(function).__name__}."
  308. )
  309. source_code_list, _ = inspect.getsourcelines(function)
  310. # Replace comments with blank lines so that error messages are not misplaced
  311. source_code_list = [
  312. line if not line.lstrip().startswith('#') else '\n'
  313. for line in source_code_list
  314. ]
  315. source_code = ''.join(source_code_list)
  316. if dedent:
  317. source_code = textwrap.dedent(source_code)
  318. return source_code
  319. def input_specs_compatible(src_input_specs, desired_input_specs):
  320. """
  321. Returns True if the two input specs are compatible, otherwise False.
  322. args:
  323. src_input_spec (list or tuple[InputSpec et.al]): list/tuple of
  324. paddle.static.InputSpec or int/str et.al
  325. desired_input_specs (list or tuple[InputSpec et.al]): list/tuple of
  326. paddle.static.InputSpec or int/str et.al
  327. """
  328. len_specs = len(src_input_specs)
  329. if len_specs != len(desired_input_specs):
  330. # NOTE(chenweihang): if the input_spec of jit.save is a subset of
  331. # input_spec of to_static, also compatible
  332. for spec in src_input_specs:
  333. if spec not in desired_input_specs:
  334. return False
  335. else:
  336. for src_spec, desired_spec in zip(src_input_specs, desired_input_specs):
  337. if isinstance(src_spec, paddle.static.InputSpec) or isinstance(
  338. desired_spec, paddle.static.InputSpec
  339. ):
  340. if not _compatible_tensor_spec(src_spec, desired_spec):
  341. return False
  342. else:
  343. if not _compatible_non_tensor_spec(src_spec, desired_spec):
  344. return False
  345. return True
  346. def _compatible_tensor_spec(src_spec, desired_spec):
  347. """
  348. Check whether two tensor type spec is compatible.
  349. """
  350. for spec in [src_spec, desired_spec]:
  351. if not isinstance(spec, paddle.static.InputSpec):
  352. return False
  353. src_shape = src_spec.shape
  354. other_shape = desired_spec.shape
  355. len_shape = len(src_shape)
  356. if len_shape != len(other_shape):
  357. return False
  358. for j in range(len_shape):
  359. if src_shape[j] is None or src_shape[j] < 0:
  360. continue
  361. if other_shape[j] is None or other_shape[j] < 0:
  362. continue
  363. if src_shape[j] != other_shape[j]:
  364. return False
  365. src_dtype = convert_dtype(src_spec.dtype)
  366. other_dtype = convert_dtype(desired_spec.dtype)
  367. if src_dtype != other_dtype:
  368. return False
  369. return True
  370. def _compatible_non_tensor_spec(src_spec, desired_spec):
  371. """
  372. Check whether two non-tensor type spec is compatible.
  373. """
  374. def hash_value(spec):
  375. try:
  376. hash_val = make_hashable(spec)
  377. except:
  378. hash_val = None
  379. return hash_val
  380. src_hash_val = hash_value(src_spec)
  381. desired_hash_val = hash_value(desired_spec)
  382. if src_hash_val != desired_hash_val:
  383. return False
  384. else:
  385. return True
  386. class GetterSetterHelper:
  387. """we have two classes of names in setter and getter function:
  388. w_vars(loop_vars) + push_pop_vars
  389. To simplify the setter logic in convert_while and convert_cond,
  390. we extract the helper class here.
  391. """
  392. def __init__(self, getter_func, setter_func, *name_lists):
  393. name_lists = ([] if x is None else x for x in name_lists)
  394. name_sets = (OrderedSet(x) for x in name_lists)
  395. self._union = list(
  396. functools.reduce(lambda x, y: x | y, name_sets, OrderedSet())
  397. )
  398. self._union.sort()
  399. self.getter = getter_func
  400. self.setter = setter_func
  401. self.name2id = {name: idx for idx, name in enumerate(self._union)}
  402. def union(self):
  403. return self._union
  404. def get(self, names):
  405. if names is None:
  406. names = []
  407. vars = self.getter()
  408. if vars is None:
  409. return ()
  410. for n in names:
  411. assert (
  412. n in self.name2id
  413. ), f"the name `{n}` not in name union set`{self.name2id.keys()}`."
  414. return tuple(vars[self.name2id[n]] for n in names)
  415. def set(self, names, values):
  416. if names is None:
  417. names = []
  418. if values is None:
  419. values = []
  420. vars = self.getter()
  421. if vars is None:
  422. return
  423. for n in names:
  424. assert (
  425. n in self.name2id
  426. ), f"the name `{n}` not in name union set`{self.name2id.keys()}`."
  427. vars = list(vars)
  428. indices = [self.name2id[n] for n in names]
  429. for i, v in zip(indices, values):
  430. vars[i] = v
  431. self.setter(vars)
  432. def prim_or_cinn_is_enabled(build_strategy, backend):
  433. return cinn_is_enabled(build_strategy, backend) or prim_is_enabled()
  434. def cinn_is_enabled(build_strategy, backend):
  435. if backend == 'CINN':
  436. return True
  437. if build_strategy is not None and build_strategy.build_cinn_pass:
  438. return True
  439. value = os.getenv('FLAGS_use_cinn')
  440. if value is not None and value.lower() in ['true', '1']:
  441. return True
  442. return False
  443. def cse_is_enabled():
  444. return paddle.get_flags(["FLAGS_enable_cse_in_dy2st"])[
  445. "FLAGS_enable_cse_in_dy2st"
  446. ]
  447. def prim_is_enabled():
  448. core.check_and_set_prim_all_enabled()
  449. return core._is_bwd_prim_enabled() or core._is_fwd_prim_enabled()
  450. def is_api_in_module_helper(obj, module_prefix):
  451. m = inspect.getmodule(obj)
  452. return m is not None and m.__name__.startswith(module_prefix)
  453. def is_builtin(func, name=None):
  454. """predict whether a function is a builtin function with name={name}.
  455. if name == None, then any builtin function will return True
  456. """
  457. def name_judge():
  458. return name is None or func.__name__ == name
  459. if isinstance(func, types.BuiltinFunctionType) and name_judge():
  460. return True
  461. elif func in builtins.__dict__.values() and name_judge():
  462. return True
  463. else:
  464. return False
  465. @signature_safe_contextmanager
  466. def backend_guard(backend):
  467. core.check_and_set_prim_all_enabled()
  468. origin_fwd = core._is_fwd_prim_enabled()
  469. origin_bwd = core._is_bwd_prim_enabled()
  470. if backend == 'CINN':
  471. core._set_prim_all_enabled(True)
  472. try:
  473. yield
  474. finally:
  475. core._set_prim_forward_enabled(origin_fwd)
  476. core._set_prim_backward_enabled(origin_bwd)
  477. def construct_grad_names(grad_info_map, x_vars, param_vars, out_vars):
  478. grad_var_names = {}
  479. fn = lambda grad_var: (
  480. grad_var.name
  481. if isinstance(grad_var, framework.Variable)
  482. else framework.EMPTY_VAR_NAME
  483. )
  484. x_grad_vars = backward._get_grad_vars(grad_info_map, x_vars)
  485. grad_var_names['x'] = list(map(fn, x_grad_vars))
  486. param_grad_vars = backward._get_grad_vars(grad_info_map, param_vars)
  487. grad_var_names['param'] = list(map(fn, param_grad_vars))
  488. out_grad_vars = backward._get_grad_vars(grad_info_map, out_vars)
  489. grad_var_names['out'] = list(map(fn, out_grad_vars))
  490. return grad_var_names
  491. @signature_safe_contextmanager
  492. def tensor_name_guard(tensors, names):
  493. try:
  494. assert len(tensors) == len(names)
  495. origin_names = [t.name for t in tensors]
  496. for t, name in zip(tensors, names):
  497. t.name = name
  498. yield
  499. finally:
  500. for t, name in zip(tensors, origin_names):
  501. t.name = name
def cuda_pinned_tensors_move_to_excepted_place(inputs):
    """
    Move eager Tensors found in `inputs` out of CUDA pinned memory.

    This only runs when Paddle is compiled with CUDA. Each element of
    ``flatten(inputs)`` is handled if it is a ``core.eager.Tensor`` with
    ``stop_gradient`` set whose place equals ``CUDAPinnedPlace()``. Such a
    Tensor is copied to ``framework._current_expected_place()``, and the
    copy is then passed to ``_share_buffer_to(value)``. Presumably this lets
    existing references to the original Tensor see data on the expected
    place; confirm against ``_share_buffer_to``.

    NOTE(review): "excepted" in the name presumably means "expected". The
    name is kept unchanged for callers.
    """
    if paddle.is_compiled_with_cuda():
        expected_place = framework._current_expected_place()
        cuda_pinned_place = CUDAPinnedPlace()
        for value in flatten(inputs):
            if (
                isinstance(value, core.eager.Tensor)
                and value.stop_gradient
                and value.place._equals(cuda_pinned_place)
            ):
                # NOTE(review): the second argument presumably requests a
                # blocking copy; confirm against `_copy_to`.
                var = value._copy_to(expected_place, True)
                var.stop_gradient = True
                var._share_buffer_to(value)