layers_utils.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583
  1. # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. from collections import defaultdict
  16. from collections.abc import Sequence
  17. from uuid import uuid4
  18. from weakref import WeakKeyDictionary
  19. import numpy as np
  20. import paddle
  21. from paddle.pir.core import convert_np_dtype_to_dtype_
  22. from ..base.data_feeder import check_dtype, convert_dtype
  23. from ..base.framework import (
  24. Block,
  25. Variable,
  26. in_dygraph_mode,
  27. )
  28. from ..pir import Value
  29. def convert_to_list(value, n, name, dtype=int):
  30. """
  31. Converts a single numerical type or iterable of numerical
  32. types into a numerical type list.
  33. Arguments:
  34. value: The value to validate and convert. Could an int, or any iterable
  35. of ints.
  36. n: The size of the list to be returned.
  37. name: The name of the argument being validated, e.g. "stride" or
  38. "filter_size". This is only used to format error messages.
  39. dtype: the numerical type of the element of the list to be returned.
  40. Returns:
  41. A list of n dtypes.
  42. Raises:
  43. ValueError: If something else than an int/long or iterable thereof was
  44. passed.
  45. """
  46. if isinstance(value, dtype):
  47. return [
  48. value,
  49. ] * n
  50. else:
  51. try:
  52. value_list = list(value)
  53. except TypeError:
  54. raise ValueError(
  55. "The "
  56. + name
  57. + "'s type must be list or tuple. Received: "
  58. + str(value)
  59. )
  60. if len(value_list) != n:
  61. raise ValueError(
  62. "The "
  63. + name
  64. + "'s length must be "
  65. + str(n)
  66. + ". Received: "
  67. + str(value)
  68. )
  69. for single_value in value_list:
  70. assert not isinstance(single_value, (Variable, paddle.pir.Value)), (
  71. "Required numerical type with '%s', but received Tensor."
  72. % dtype
  73. )
  74. try:
  75. dtype(single_value)
  76. except (ValueError, TypeError):
  77. raise ValueError(
  78. "The "
  79. + name
  80. + "'s type must be a list or tuple of "
  81. + str(n)
  82. + " "
  83. + str(dtype)
  84. + " . Received: "
  85. + str(value)
  86. + " "
  87. "including element "
  88. + str(single_value)
  89. + " of type"
  90. + " "
  91. + str(type(single_value))
  92. )
  93. return value_list
  94. def is_sequence(seq):
  95. """
  96. Whether `seq` is an entry or nested structure
  97. """
  98. if isinstance(seq, dict):
  99. return True
  100. return isinstance(seq, Sequence) and not isinstance(seq, str)
  101. class UniqueIdMap(WeakKeyDictionary):
  102. def __init__(self):
  103. super().__init__(self)
  104. self.data = defaultdict(uuid4)
  105. uniqueidmap = UniqueIdMap()
  106. def uniqueid(obj):
  107. if isinstance(obj, str):
  108. return (hash(obj),)
  109. elif isinstance(obj, list):
  110. return (id(obj),)
  111. else:
  112. return (uniqueidmap[obj].int,)
  113. def _hash_with_id(*args):
  114. """
  115. Return int hash value calculated by id(arg) or tuple(id1,id2, ...).
  116. """
  117. assert len(args) > 0
  118. info = ()
  119. for v in args:
  120. info = info + uniqueid(v)
  121. return hash(info)
  122. def _sorted(dict_):
  123. """
  124. Returns a sorted list of the dict keys, with error if keys not sortable.
  125. """
  126. try:
  127. return sorted(dict_.keys())
  128. except TypeError:
  129. raise TypeError("nest only supports dicts with sortable keys.")
  130. def _yield_value(iterable):
  131. if isinstance(iterable, dict):
  132. for key in _sorted(iterable):
  133. yield iterable[key]
  134. else:
  135. yield from iterable
  136. def _yield_flat_nest(nest):
  137. for n in _yield_value(nest):
  138. if is_sequence(n):
  139. yield from _yield_flat_nest(n)
  140. else:
  141. yield n
  142. def to_sequence(nest):
  143. if is_sequence(nest):
  144. return nest
  145. else:
  146. return [nest]
  147. def flatten(nest):
  148. """
  149. :alias_main: paddle.flatten
  150. :alias: paddle.flatten,paddle.tensor.flatten,paddle.tensor.manipulation.flatten
  151. :old_api: paddle.base.layers.flatten
  152. Traverse all entries in the nested structure and put them into an list.
  153. """
  154. if is_sequence(nest):
  155. return list(_yield_flat_nest(nest))
  156. else:
  157. return [nest]
  158. def _sequence_like(instance, args):
  159. """
  160. Convert the sequence `args` to the same type as `instance`.
  161. """
  162. if isinstance(instance, dict):
  163. result = dict(zip(_sorted(instance), args))
  164. return type(instance)((key, result[key]) for key in instance.keys())
  165. elif (
  166. isinstance(instance, tuple)
  167. and hasattr(instance, "_fields")
  168. and isinstance(instance._fields, Sequence)
  169. and all(isinstance(f, str) for f in instance._fields)
  170. ):
  171. # This is a namedtuple
  172. return type(instance)(*args)
  173. else:
  174. # Not a namedtuple
  175. return type(instance)(args)
  176. def _packed_nest_with_indices(structure, flat, index):
  177. """
  178. Helper function for pack_sequence_as.
  179. """
  180. packed = []
  181. for s in _yield_value(structure):
  182. if is_sequence(s):
  183. new_index, child = _packed_nest_with_indices(s, flat, index)
  184. packed.append(_sequence_like(s, child))
  185. index = new_index
  186. else:
  187. packed.append(flat[index])
  188. index += 1
  189. return index, packed
  190. def pack_sequence_as(structure, flat_sequence):
  191. """
  192. Pack a given flattened sequence into a given structure.
  193. """
  194. if not is_sequence(flat_sequence):
  195. raise TypeError("flat_sequence must be a sequence")
  196. if not is_sequence(structure):
  197. if len(flat_sequence) != 1:
  198. raise ValueError(
  199. "Structure is a scalar but len(flat_sequence) == %d > 1"
  200. % len(flat_sequence)
  201. )
  202. return flat_sequence[0]
  203. flat_structure = flatten(structure)
  204. if len(flat_structure) != len(flat_sequence):
  205. raise ValueError(
  206. "Could not pack sequence. Structure had %d elements, but flat_sequence "
  207. "had %d elements. Structure: %s, flat_sequence: %s."
  208. % (
  209. len(flat_structure),
  210. len(flat_sequence),
  211. structure,
  212. flat_sequence,
  213. )
  214. )
  215. _, packed = _packed_nest_with_indices(structure, flat_sequence, 0)
  216. return _sequence_like(structure, packed)
  217. def map_structure(func, *structure):
  218. """
  219. Apply `func` to each entry in `structure` and return a new structure.
  220. """
  221. flat_structure = [flatten(s) for s in structure]
  222. entries = zip(*flat_structure)
  223. return pack_sequence_as(structure[0], [func(*x) for x in entries])
  224. def hold_mutable_vars(structure):
  225. """
  226. Returns whether structure holds sequence like `list/dict`.
  227. """
  228. for s in structure:
  229. if is_sequence(s):
  230. return True
  231. return False
  232. def copy_mutable_vars(structure):
  233. """
  234. Returns vars copied from sequence without mutable property.
  235. """
  236. flat_structure = copy.copy(flatten(structure))
  237. return pack_sequence_as(structure, flat_structure)
  238. def _recursive_assert_same_structure(nest1, nest2, check_types):
  239. """
  240. Helper function for `assert_same_structure`.
  241. """
  242. is_sequence_nest1 = is_sequence(nest1)
  243. if is_sequence_nest1 != is_sequence(nest2):
  244. raise ValueError(
  245. "The two structures don't have the same nested structure.\n\n"
  246. f"First structure: {nest1}\n\nSecond structure: {nest2}."
  247. )
  248. if not is_sequence_nest1:
  249. return # finished checking
  250. if check_types:
  251. type_nest1 = type(nest1)
  252. type_nest2 = type(nest2)
  253. if type_nest1 != type_nest2:
  254. raise TypeError(
  255. "The two structures don't have the same sequence type. First "
  256. f"structure has type {type_nest1}, while second structure has type {type_nest2}."
  257. )
  258. if isinstance(nest1, dict):
  259. keys1 = set(nest1.keys())
  260. keys2 = set(nest2.keys())
  261. if keys1 != keys2:
  262. raise ValueError(
  263. "The two dictionaries don't have the same set of keys. First "
  264. f"structure has keys {keys1}, while second structure has keys {keys2}."
  265. )
  266. nest1_as_sequence = list(_yield_value(nest1))
  267. nest2_as_sequence = list(_yield_value(nest2))
  268. for n1, n2 in zip(nest1_as_sequence, nest2_as_sequence):
  269. _recursive_assert_same_structure(n1, n2, check_types)
  270. def padding_to_same_structure(nest1, nest2, obj=None):
  271. def _padding_to_same_structure_single(value, obj):
  272. def change_none_to_obj(x):
  273. if x is None:
  274. return obj
  275. return x
  276. if is_sequence(value):
  277. value = pack_sequence_as(
  278. value, [change_none_to_obj(item) for item in flatten(value)]
  279. )
  280. else:
  281. value = change_none_to_obj(value)
  282. return value
  283. nest1 = _padding_to_same_structure_single(nest1, obj)
  284. nest2 = _padding_to_same_structure_single(nest2, obj)
  285. return nest1, nest2
  286. def assert_same_structure(nest1, nest2, check_types=True):
  287. """
  288. Confirm two nested structures with the same structure.
  289. """
  290. len_nest1 = len(flatten(nest1)) if is_sequence(nest1) else 1
  291. len_nest2 = len(flatten(nest2)) if is_sequence(nest2) else 1
  292. if len_nest1 != len_nest2:
  293. raise ValueError(
  294. "The two structures don't have the same number of "
  295. "elements.\n\nFirst structure (%i elements): %s\n\n"
  296. "Second structure (%i elements): %s"
  297. % (len_nest1, nest1, len_nest2, nest2)
  298. )
  299. _recursive_assert_same_structure(nest1, nest2, check_types)
  300. def _is_symmetric_padding(padding, data_dim):
  301. """
  302. Check whether padding is symmetrical.
  303. """
  304. assert len(padding) == data_dim * 2 or len(padding) == data_dim
  305. is_sys = True
  306. if len(padding) == data_dim * 2:
  307. for i in range(data_dim):
  308. if padding[i * 2] != padding[i * 2 + 1]:
  309. is_sys = False
  310. return is_sys
  311. def _contain_var(list_or_tuple):
  312. """
  313. Check whether list or tuple contains variable / Value.
  314. """
  315. for item in list_or_tuple:
  316. if isinstance(item, (Variable, paddle.pir.Value)):
  317. return True
  318. return False
  319. def get_int_tensor_list(ele_list, default_dtype='int64'):
  320. int_tensor_list = []
  321. for ele in ele_list:
  322. if isinstance(ele, paddle.pir.Value):
  323. ele.stop_gradient = True
  324. if convert_dtype(ele.dtype) != default_dtype:
  325. ele = paddle.cast(x=ele, dtype=default_dtype)
  326. if ele.shape != []:
  327. ele = paddle.reshape(ele, [])
  328. int_tensor_list.append(ele)
  329. else:
  330. temp_out = paddle.tensor.fill_constant(
  331. shape=[],
  332. dtype=convert_np_dtype_to_dtype_(np.dtype(default_dtype)),
  333. value=ele,
  334. force_cpu=True,
  335. )
  336. int_tensor_list.append(temp_out)
  337. return int_tensor_list
  338. def get_shape_tensor_inputs(inputs, attrs, shape, op_type):
  339. from paddle.tensor import fill_constant
  340. def _get_attr_shape(list_shape):
  341. attr_shape = []
  342. for idx, dim in enumerate(list_shape):
  343. if isinstance(dim, Variable):
  344. attr_shape.append(-1)
  345. else:
  346. attr_shape.append(dim)
  347. return attr_shape
  348. def _get_shape_tensor(list_shape):
  349. shape_tensor_list = []
  350. for idx, dim in enumerate(list_shape):
  351. if isinstance(dim, Variable):
  352. dim.stop_gradient = True
  353. check_dtype(
  354. dim.dtype,
  355. 'shape[' + str(idx) + ']',
  356. ['int32', 'int64'],
  357. op_type,
  358. '(When type of shape in' + op_type + 'is list or tuple.)',
  359. )
  360. if convert_dtype(dim.dtype) == 'int64':
  361. dim = paddle.cast(x=dim, dtype='int32')
  362. shape_tensor_list.append(dim)
  363. else:
  364. temp_out = fill_constant([], 'int32', dim, force_cpu=True)
  365. shape_tensor_list.append(temp_out)
  366. return shape_tensor_list
  367. if isinstance(shape, Variable):
  368. shape.stop_gradient = True
  369. check_dtype(
  370. shape.dtype,
  371. 'shape',
  372. ['int32', 'int64'],
  373. 'fill_constant',
  374. '(When type of shape in' + op_type + ' is Variable.)',
  375. )
  376. if convert_dtype(shape.dtype) == 'int64':
  377. shape = paddle.cast(shape, 'int32')
  378. inputs["ShapeTensor"] = shape
  379. elif isinstance(shape, (list, tuple)):
  380. attrs["shape"] = _get_attr_shape(shape)
  381. if _contain_var(shape):
  382. inputs['ShapeTensorList'] = _get_shape_tensor(shape)
  383. else:
  384. raise TypeError("Shape only supports Variable, or list, or tuple.")
  385. def _convert_to_tensor_list(old_list, dtype="int32"):
  386. """
  387. Converts all elements of a list to Variable / Value.
  388. """
  389. from paddle.tensor import fill_constant
  390. new_list_tensor = []
  391. for ele in old_list:
  392. if isinstance(ele, (Variable, paddle.pir.Value)):
  393. ele.stop_gradient = True
  394. new_list_tensor.append(ele)
  395. else:
  396. assert isinstance(ele, int)
  397. temp_out = fill_constant([1], dtype, ele, force_cpu=True)
  398. new_list_tensor.append(temp_out)
  399. return new_list_tensor
  400. def convert_shape_to_list(shape):
  401. """
  402. Convert shape(list, tuple, variable) to list in imperative mode
  403. """
  404. if isinstance(shape, (list, tuple)):
  405. shape = [x.item(0) if isinstance(x, Variable) else x for x in shape]
  406. else:
  407. if in_dygraph_mode():
  408. shape = shape.astype(int).tolist()
  409. return shape
  410. def check_shape(shape):
  411. """
  412. Check shape type and shape elements type before passing it to fill_constant
  413. """
  414. if isinstance(shape, (Variable, Value)):
  415. check_dtype(shape.dtype, 'shape', ['int32', 'int64'], 'fill_constant')
  416. elif isinstance(shape, (list, tuple)):
  417. for ele in shape:
  418. if not isinstance(ele, (Variable, Value)):
  419. if ele < 0:
  420. raise ValueError(
  421. "All elements in ``shape`` must be positive when it's a list or tuple"
  422. )
  423. if not isinstance(ele, int):
  424. raise TypeError(
  425. "All elements in ``shape`` must be integers when it's a list or tuple"
  426. )
  427. else:
  428. check_dtype(
  429. ele.dtype,
  430. 'element of shape',
  431. ['int32', 'int64'],
  432. 'fill_constant',
  433. )
  434. def try_set_static_shape_tensor(tensor, shape):
  435. """Try to set static shape of tensor from a shape tensor.
  436. For example,
  437. import paddle
  438. paddle.enable_static()
  439. data = paddle.static.data(name="x", shape=[-1, 2], dtype='float32')
  440. shape = paddle.shape(data) # shape should be [-1, 2] instead of [-1, -1]
  441. x = paddle.uniform(shape)
  442. print(x.shape)
  443. # (-1, 2)
  444. """
  445. if not in_dygraph_mode():
  446. # static graph mode, and shape is not all inferred (contains -1)
  447. if -1 in tensor.shape:
  448. if isinstance(shape, Variable):
  449. shape = try_get_constant_shape_from_tensor(shape)
  450. if shape:
  451. tensor.desc.set_shape(shape)
  452. def try_get_constant_shape_from_tensor(shape_tensor):
  453. """Try to get shape from a tensor with constant value.
  454. For example,
  455. import paddle
  456. paddle.enable_static()
  457. data = paddle.static.data(name="x", shape=[-1, 2], dtype='float32')
  458. shape = paddle.shape(data) # shape should be [-1, 2] instead of [-1, -1]
  459. x = paddle.uniform(shape)
  460. print(x.shape)
  461. # (-1, 2)
  462. """
  463. if not in_dygraph_mode():
  464. try:
  465. if shape_tensor.op is not None:
  466. generate_op = shape_tensor.op
  467. if generate_op.type == 'shape':
  468. var = shape_tensor.block.vars[
  469. generate_op.input_arg_names[0]
  470. ]
  471. return var.shape
  472. except:
  473. return None
  474. return None
  475. def get_inputs_outputs_in_block(block):
  476. """
  477. Returns the inputs and outputs variable used in this block but not
  478. created in this block.
  479. """
  480. assert isinstance(
  481. block, Block
  482. ), "input non-Block argument for get_inputs_outputs_in_block."
  483. assert (
  484. block.parent_idx != -1
  485. ), "input block should be a sub-block, not main block."
  486. # Find input/output var names of all ops in block
  487. inner_inputs = set()
  488. inner_outputs = set()
  489. for op in block.ops:
  490. for iname in op.input_names:
  491. for in_var_name in op.input(iname):
  492. if not block.has_var(in_var_name):
  493. # variable not created in this block
  494. inner_inputs.add(in_var_name)
  495. for oname in op.output_names:
  496. for out_var_name in op.output(oname):
  497. if not block.has_var(out_var_name):
  498. # variable not created in this block
  499. inner_outputs.add(out_var_name)
  500. return inner_inputs, inner_outputs