variable_index.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import paddle
  16. from . import core, unique_name
# Largest signed 32-bit integer. Used throughout this module as the sentinel
# for an open-ended slice bound (e.g. the `end` of `x[1:]`).
MAX_INTEGER = 2**31 - 1
  18. def replace_ellipsis(var, item):
  19. from .framework import Variable
  20. # Use slice(None) to replace Ellipsis.
  21. # For var, var.shape = [3,4,5,6]
  22. #
  23. # var[..., 1:2] -> var[:, :, :, 1:2]
  24. # var[0, ...] -> var[0]
  25. # var[0, ..., 1:2] -> var[0, :, :, 1:2]
  26. item = list(item)
  27. # Remove Variable to skip bug when counting Ellipsis
  28. item_remove_var = [
  29. ele
  30. for ele in item
  31. if not isinstance(ele, (Variable, paddle.pir.Value, np.ndarray))
  32. and ele is not None
  33. ]
  34. ell_count = item_remove_var.count(Ellipsis)
  35. if ell_count == 0:
  36. return item
  37. elif ell_count > 1:
  38. raise IndexError("An index can only have a single ellipsis ('...')")
  39. ell_idx = item.index(Ellipsis)
  40. if ell_idx == len(item) - 1:
  41. return item[:-1]
  42. else:
  43. item[ell_idx : ell_idx + 1] = [slice(None)] * (
  44. len(var.shape) - len(item) + item.count(None) + 1
  45. )
  46. return item
  47. def replace_ndarray_and_range(item):
  48. new_item = []
  49. for slice_item in item:
  50. if isinstance(slice_item, np.ndarray):
  51. new_item.append(paddle.assign(slice_item))
  52. elif isinstance(slice_item, range):
  53. new_item.append(list(slice_item))
  54. else:
  55. new_item.append(slice_item)
  56. return new_item
  57. def replace_none(item):
  58. new_item = []
  59. none_axes = []
  60. for i, slice_item in enumerate(item):
  61. if slice_item is None:
  62. none_axes.append(i)
  63. else:
  64. new_item.append(slice_item)
  65. return new_item, none_axes
  66. def is_scalar_tensor(ele):
  67. from .framework import Variable
  68. if isinstance(ele, Variable):
  69. if len(ele.shape) == 0 and ele.dtype != paddle.bool:
  70. return True
  71. elif isinstance(ele, paddle.pir.Value):
  72. if len(ele.shape) == 0 and ele.dtype != paddle.base.libpaddle.BOOL:
  73. return True
  74. return False
  75. def deal_attrs(attrs, attr, attr_name, tensor_attr_name, inputs, infer_flags):
  76. from .framework import Variable
  77. if paddle.utils._contain_var(attr):
  78. inputs[tensor_attr_name] = paddle.utils._convert_to_tensor_list(
  79. attr, dtype="int64"
  80. )
  81. for i, dim in enumerate(attr):
  82. if isinstance(dim, (Variable, paddle.pir.Value)):
  83. attrs[attr_name].append(-1)
  84. infer_flags[i] = -1
  85. else:
  86. attrs[attr_name].append(dim)
  87. else:
  88. attrs[attr_name] = attr
  89. def get_value_for_bool_tensor(var, item):
  90. if len(item.shape) > len(var.shape):
  91. raise IndexError(
  92. "The dims of bool index doesn't match indexed array, "
  93. "the dims of bool index except to be equal or less "
  94. f"than {len(var.shape)}, but received {len(item.shape)}."
  95. )
  96. i = 0
  97. item_shape = item.shape
  98. while i < len(item.shape):
  99. dim_len = item_shape[i]
  100. if dim_len != -1 and var.shape[i] != -1 and dim_len != var.shape[i]:
  101. raise IndexError(
  102. "The dimension of bool index doesn't match indexed array along "
  103. f"dimension {i}, the target dimension is {var.shape[i]}, but received {dim_len}."
  104. )
  105. i += 1
  106. if len(item.shape) == len(var.shape):
  107. return paddle.masked_select(var, item)
  108. bool_2_idx = paddle.nonzero(item)
  109. return paddle.gather_nd(var, bool_2_idx)
  110. def _setitem_for_tensor_array(var, item, value):
  111. """branches for tensor array setitem operation.
  112. A item can be a:
  113. (1) int/Variable, which is a simple number/variable such as [1], [-2]
  114. (2) Slice, which is represented by bounds such as [2:-1]
  115. (3) Tuple, which includes the above two cases such as [2:-1, 1]
  116. If item is case (1), we perform paddle.tensor.array_write,
  117. in other cases, we raise a NotImplementedError.
  118. """
  119. from .framework import Variable
  120. assert (
  121. not paddle.in_dynamic_mode()
  122. ), "setitem for tensor_array must be called in static graph mode."
  123. if isinstance(item, (Variable, paddle.pir.Value, int)):
  124. from paddle.jit.dy2static.convert_operators import to_static_variable
  125. from paddle.tensor import array_write
  126. item = paddle.cast(to_static_variable(item), dtype='int64')
  127. value = to_static_variable(value)
  128. return array_write(x=value, i=item, array=var)
  129. else:
  130. raise NotImplementedError(
  131. f"Only support __setitem__ by Int/Variable in tensor_array, but gets {type(item)}"
  132. )
def deal_advanced_index(
    ori_tensor, indices, is_for_setitem, values, out_is_view=True
):
    """
    Transpose origin Tensor and advanced indices to the front.

    Args:
        ori_tensor (Tensor): tensor to be indexed (already basic-indexed by the callers in this file).
        indices (List): per-axis slots; each slot is either None (no advanced index on
            that axis) or a ``(dim, index_tensor)`` pair. Must have at least
            ``ori_tensor.ndim`` entries, since ``indices[i]`` is read for every axis.
        is_for_setitem (bool): whether the call serves __setitem__; controls whether
            ``trans_back_dim`` and ``transed_value_tensor`` are computed.
        values (Tensor|None): value tensor for __setitem__; only read when ``is_for_setitem`` is True.
        out_is_view (bool): incoming view flag; forced to True when a transpose is performed.

    Returns:
        transed_tensor (Tensor): transposed tensor, corresponding with advanced indices
        transed_index (List): advanced indices transposed to the front
        trans_back_dim (List): order of axes to transpose back to original order. Only used in __setitem__.
        pos_of_new_dim (int): axis of new dim in the result. Only used in __getitem__.
        rank_of_new_dim (int): rank of new dim in the result. Only used in __getitem__.
        transed_value_tensor (Tensor): value tensor transposed to the front. Only used in __setitem__.
        out_is_view (bool): the (possibly updated) view flag.
    """
    transed_dim = []
    transed_index = []

    # These flags indicate whether the result obtained by gather_nd requires a second transpose.
    # Only used in __getitem__.
    pos_of_new_dim = MAX_INTEGER
    rank_of_new_dim = 1

    for i, indice in enumerate(indices):
        if indice is not None:
            if i == 0:
                # case 1: advanced indices at axis 0, the new dim will be at first.
                pos_of_new_dim = 0
            if i > 0 and len(transed_dim) > 0 and transed_dim[-1] != i - 1:
                # case 2: there are not adjacent advanced indices, the new dim will be at first.
                pos_of_new_dim = 0
            else:
                pos_of_new_dim = min(pos_of_new_dim, i)
                rank_of_new_dim = max(rank_of_new_dim, indice[1].ndim)
            transed_dim.append(i)
            transed_index.append(indice[1])
    # Axes without an advanced index follow, in their original relative order.
    for i in range(ori_tensor.ndim):
        if indices[i] is None:
            transed_dim.append(i)

    # argsort of a permutation yields its inverse permutation.
    trans_back_dim = np.argsort(transed_dim).tolist() if is_for_setitem else []
    transed_value_tensor = None

    if transed_dim == list(range(ori_tensor.ndim)):
        # Advanced indices already occupy the leading axes: no transpose needed.
        transed_tensor = ori_tensor
        if is_for_setitem:
            transed_value_tensor = values
    else:
        out_is_view = True
        transed_tensor = ori_tensor.transpose(transed_dim)
        if is_for_setitem:
            if values.ndim > 1 and pos_of_new_dim != 0:
                # If the value tensor is not a scalar / 1-D Tensor, and the src tensor was
                # transposed at 1st dim, the value tensor should be transposed too.
                transed_value_tensor = values.transpose(transed_dim)
            else:
                transed_value_tensor = values

    return (
        transed_tensor,
        transed_index,
        trans_back_dim,
        pos_of_new_dim,
        rank_of_new_dim,
        transed_value_tensor,
        out_is_view,
    )
  193. def slice_is_same_to_original(start, end, step):
  194. if start is None and end is None and step is None:
  195. return True
  196. # If there is Variable, we cannot determine whether it is the same to original.
  197. if isinstance(start, (paddle.base.Variable, paddle.pir.Value)):
  198. return False
  199. if isinstance(end, (paddle.base.Variable, paddle.pir.Value)):
  200. return False
  201. if isinstance(step, (paddle.base.Variable, paddle.pir.Value)):
  202. return False
  203. return start == 0 and end == MAX_INTEGER and step == 1
  204. def is_tensor_array_type(value):
  205. from .framework import in_pir_mode
  206. if in_pir_mode():
  207. return value.is_dense_tensor_array_type()
  208. else:
  209. return (
  210. hasattr(value, "desc")
  211. and value.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY
  212. )
def parse_index(x, indices):
    """Parse an index expression into basic-slice parameters and advanced indices.

    The index is first normalized (ndarray -> tensor, range -> list, Ellipsis
    expanded, None entries removed and recorded), then walked element by
    element. Two counters are maintained:

    * ``dim``: the axis of ``x`` the current element applies to.
    * ``estimated_dim``: the slot in ``advanced_index`` for the current
      element; it advances for slices, bools and advanced indices, but not
      for int / 0-D tensor indices (whose axis is recorded in ``decrease_axes``).

    Args:
        x: The tensor (or tensor array) being indexed.
        indices: A single index element or a tuple of them.

    Returns:
        tuple: ``(starts, ends, steps, axes, none_axes, decrease_axes,
        advanced_index, has_advanced_index, use_strided_slice)``.
        ``starts/ends/steps/axes`` are parallel lists describing the basic
        slices that differ from a full slice; ``advanced_index`` holds
        ``(slot, index_tensor)`` pairs or None; ``use_strided_slice`` is True
        if any recorded step is a tensor or is not 1.

    Raises:
        IndexError: For an int index ``>=`` the (known) axis length, a boolean
            index whose length does not match the axis, or an unsupported
            index type.
    """
    is_tensor_array = is_tensor_array_type(x)

    advanced_index = (
        [] if is_tensor_array else [None] * 2 * len(x.shape)
    )  # content is (dim, index)
    # for set_value / slice / strided_slice OP
    decrease_axes = []
    axes = []
    starts = []
    ends = []
    steps = []
    use_strided_slice = False
    has_advanced_index = False

    if not isinstance(indices, tuple):
        indices = (indices,)

    indices = replace_ndarray_and_range(indices)
    indices = replace_ellipsis(x, indices)
    indices, none_axes = replace_none(indices)

    estimated_dim = 0
    dim = 0
    for i, slice_item in enumerate(indices):
        start, end, step = None, None, None
        # `type(...) is int` (not isinstance) so that bool, a subclass of int,
        # falls through to the dedicated bool branch below.
        if type(slice_item) is int:
            if (
                not is_tensor_array
                and x.shape[dim] is not None
                and x.shape[dim] >= 0
                and slice_item >= x.shape[dim]
            ):
                # For python, if users write a, b = var, the __getitem__
                # method will iterate through 0, 1, 2 ... until __getitem__
                # throws an IndexError, then stop. The var[0], var[1] will
                # be given to a, b respectively. If more values are given,
                # the unpack size would cause error.
                # We raises IndexError here to support grammar like `a, b = var`
                raise IndexError(
                    "slice_item %d at dim %d should be >= 0 and < x.shape[%d]: %d"
                    % (slice_item, dim, dim, x.shape[dim])
                )
            # not calculate result to reduce call times for slice OP.
            decrease_axes.append(dim)
            start = slice_item
            step = 1
            # For index -1, `slice_item + 1` would be 0, so an open end is used instead.
            end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
            dim += 1
        elif is_scalar_tensor(slice_item):
            # not calculate result to reduce call times for slice OP.
            decrease_axes.append(dim)
            start = slice_item
            step = 1
            end = slice_item + 1
            dim += 1
        elif isinstance(slice_item, bool):
            # single bool is advanced-indexing
            none_axes.append(dim)
            advanced_index[estimated_dim] = (
                estimated_dim,
                paddle.to_tensor([slice_item]),
            )
            has_advanced_index = True
            estimated_dim += 1
        elif isinstance(slice_item, slice):
            start = slice_item.start
            end = slice_item.stop
            step = slice_item.step

            # A bare `:` needs no slice op; just advance to the next axis.
            if start is None and end is None and step is None:
                estimated_dim += 1
                dim += 1
                continue

            step = 1 if step is None else step
            if start is None:
                start = 0 if step > 0 else MAX_INTEGER
            if end is None:
                end = MAX_INTEGER if step > 0 else -1

            if not (
                is_tensor_array
                or isinstance(end, (paddle.base.Variable, paddle.pir.Value))
                or isinstance(step, (paddle.base.Variable, paddle.pir.Value))
            ):
                # Clamp a static end that reaches past a known axis length to
                # the open-ended sentinel for the step direction.
                if x.shape[dim] != -1 and end >= x.shape[dim]:
                    end = MAX_INTEGER if step > 0 else -1
            estimated_dim += 1
            dim += 1

        elif isinstance(slice_item, (list, tuple)):
            advanced_index[estimated_dim] = (
                estimated_dim,
                paddle.to_tensor(slice_item),
            )

            if (
                advanced_index[estimated_dim][1].dtype == paddle.bool
                and len(slice_item) != x.shape[dim]
            ):
                raise IndexError(
                    f"The shape of boolean index {len(slice_item)} did not match indexed tensor {x.shape[dim]} along axis {dim}"
                )

            has_advanced_index = True
            estimated_dim += 1
            dim += 1
        elif isinstance(slice_item, paddle.base.Variable):
            # In this case, the Variable is not 0-dim Tensor and will be treated as advanced-indexing.
            if (
                slice_item.dtype == paddle.bool
                or slice_item.dtype == paddle.base.libpaddle.BOOL
            ):
                if slice_item.ndim == 0:
                    # 0-D bool Tensor, same as single PY-bool.
                    none_axes.append(dim)

                elif slice_item.shape[0] != x.shape[dim]:
                    raise IndexError(
                        f"The shape of boolean index {slice_item.shape[0]} did not match indexed tensor {x.shape[dim]} along axis {dim}"
                    )
            advanced_index[estimated_dim] = (estimated_dim, slice_item)
            has_advanced_index = True
            estimated_dim += 1
            dim += 1
        elif isinstance(slice_item, paddle.pir.Value):
            # In this case, the Variable is not 0-dim Tensor and will be treated as advanced-indexing.
            if slice_item.dtype == paddle.pir.core.DataType.BOOL:
                if slice_item.ndim == 0:
                    # 0-D bool Tensor, same as single PY-bool.
                    none_axes.append(dim)

                elif slice_item.shape[0] != x.shape[dim]:
                    raise IndexError(
                        f"The shape of boolean index {slice_item.shape[0]} did not match indexed tensor {x.shape[dim]} along axis {dim}"
                    )
            advanced_index[estimated_dim] = (estimated_dim, slice_item)
            has_advanced_index = True
            estimated_dim += 1
            dim += 1

        else:
            raise IndexError(
                f"Valid index accept int / bool / slice / ellipsis / list / Tuple / Ndarray / Tensor, but received {slice_item}."
            )
        # Record a basic slice only when it actually narrows the axis.
        # `dim` was already advanced above, hence `dim - 1`.
        if not slice_is_same_to_original(start, end, step):
            starts.append(start)
            ends.append(end)
            steps.append(step)
            axes.append(dim - 1)
            use_strided_slice = (
                True
                if (
                    isinstance(step, (paddle.base.Variable, paddle.pir.Value))
                    or step != 1
                )
                else use_strided_slice
            )
    return (
        starts,
        ends,
        steps,
        axes,
        none_axes,
        decrease_axes,
        advanced_index,
        has_advanced_index,
        use_strided_slice,
    )
  370. def _setitem_static(x, indices, values):
  371. """
  372. In dynamic mode, this function will modify the value at input tensor, returning same Tensor as input.
  373. But it will return a new Tensor with assigned value in static mode.
  374. Args:
  375. x(Tensor): Tensor to be set value.
  376. indices(int|slice|None|Tensor|List|Tuple...): Indices, used to indicate the position of the element to be fetched.
  377. values(Tensor|Number|Ndarray): values to be assigned to the x.
  378. """
  379. from . import in_dynamic_or_pir_mode
  380. from .framework import Variable, default_main_program, in_pir_mode
  381. is_tensor_array = is_tensor_array_type(x)
  382. if is_tensor_array:
  383. return _setitem_for_tensor_array(x, indices, values)
  384. # step1: parsing the index and recording them
  385. (
  386. starts,
  387. ends,
  388. steps,
  389. axes,
  390. none_axes,
  391. decrease_axes,
  392. advanced_index,
  393. has_advanced_index,
  394. use_strided_slice,
  395. ) = parse_index(x, indices)
  396. inputs = {'Input': x}
  397. attrs = {
  398. 'axes': axes,
  399. 'starts': starts,
  400. 'ends': ends,
  401. 'steps': steps,
  402. 'decrease_axes': decrease_axes,
  403. 'none_axes': none_axes,
  404. }
  405. value_tensor = None
  406. StartsTensorList = None
  407. EndsTensorList = None
  408. StepsTensorList = None
  409. shape = None
  410. if paddle.utils._contain_var(starts):
  411. StartsTensorList = paddle.utils._convert_to_tensor_list(starts)
  412. inputs['StartsTensorList'] = StartsTensorList
  413. del attrs['starts']
  414. if paddle.utils._contain_var(ends):
  415. EndsTensorList = paddle.utils._convert_to_tensor_list(ends)
  416. inputs['EndsTensorList'] = EndsTensorList
  417. del attrs['ends']
  418. if paddle.utils._contain_var(steps):
  419. StepsTensorList = paddle.utils._convert_to_tensor_list(steps)
  420. inputs['StepsTensorList'] = StepsTensorList
  421. del attrs['steps']
  422. if not has_advanced_index:
  423. # step2. Parse values
  424. dtype = x.dtype
  425. attrs['dtype'] = dtype
  426. from .data_feeder import convert_dtype
  427. if isinstance(values, (bool, int, float, complex)):
  428. values = np.array([values]).astype(convert_dtype(dtype))
  429. if isinstance(values, np.ndarray):
  430. shape = list(values.shape)
  431. values = values.ravel().tolist()
  432. attrs["values"] = values
  433. attrs["shape"] = shape
  434. elif isinstance(values, (Variable, paddle.pir.Value)):
  435. values = values.astype(dtype)
  436. inputs["ValueTensor"] = values
  437. value_tensor = values
  438. else:
  439. raise TypeError(
  440. "Only support to assign an integer, float, numpy.ndarray or "
  441. f"paddle.Tensor to a paddle.Tensor, but received {type(values)}"
  442. )
  443. # step3.1: Only basic indexing, use OP set_value to set value.
  444. if in_dynamic_or_pir_mode():
  445. if in_pir_mode():
  446. if isinstance(starts, (list, tuple)):
  447. if paddle.utils._contain_var(starts):
  448. starts = paddle.utils.get_int_tensor_list(starts)
  449. if isinstance(ends, (list, tuple)):
  450. if paddle.utils._contain_var(ends):
  451. ends = paddle.utils.get_int_tensor_list(ends)
  452. if isinstance(steps, (list, tuple)):
  453. if paddle.utils._contain_var(steps):
  454. steps = paddle.utils.get_int_tensor_list(steps)
  455. if value_tensor is None:
  456. output = paddle._C_ops.set_value_(
  457. x,
  458. starts,
  459. ends,
  460. steps,
  461. axes,
  462. decrease_axes,
  463. none_axes,
  464. shape,
  465. values,
  466. )
  467. else:
  468. output = paddle._C_ops.set_value_with_tensor_(
  469. x,
  470. value_tensor,
  471. starts,
  472. ends,
  473. steps,
  474. axes,
  475. decrease_axes,
  476. none_axes,
  477. )
  478. if in_pir_mode():
  479. # map var to the new output, for dy2static
  480. from paddle.jit.pir_dy2static.parameter_recorder import (
  481. _global_inplace_map,
  482. )
  483. _global_inplace_map.add(default_main_program(), x, output)
  484. return output
  485. else:
  486. helper = paddle.base.layer_helper.LayerHelper(
  487. 'set_value', **locals()
  488. )
  489. if helper.main_program.current_block_idx != 0:
  490. # not in global block, we should create a global variable.
  491. output = helper._create_global_variable_for_type_inference(
  492. dtype=x.dtype
  493. )
  494. else:
  495. output = helper.create_variable_for_type_inference(
  496. dtype=x.dtype
  497. )
  498. cur_block = default_main_program().current_block()
  499. cur_block.append_op(
  500. type="set_value",
  501. inputs=inputs,
  502. outputs={'Out': output},
  503. attrs=attrs,
  504. inplace_map={"Input": "Out"},
  505. )
  506. # map var to the new output
  507. paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add(
  508. cur_block.program, x.desc.id(), output
  509. )
  510. return output
  511. else:
  512. # step3.2: Case for there are advanced indexing.
  513. # 1. get __getitem__ result of basic indexing;
  514. # 2. transpose original tensor so that the axis with advanced indexing will come to the first;
  515. # 3. assign values to the sliced result by index_put OP;
  516. # 4. transpose back and assign the result to original tensor by set_value OP.
  517. if not isinstance(values, (Variable, paddle.pir.Value)):
  518. values = paddle.assign(values).astype(x.dtype)
  519. sub_tensor, is_view = get_tensor_with_basic_indexing(
  520. x,
  521. axes,
  522. starts,
  523. ends,
  524. steps,
  525. decrease_axes,
  526. none_axes,
  527. use_strided_slice,
  528. )
  529. (
  530. transed_sub_tensor,
  531. adjusted_advanced_index,
  532. transback_dim,
  533. _,
  534. _,
  535. values,
  536. is_view,
  537. ) = deal_advanced_index(
  538. sub_tensor, advanced_index, True, values, is_view
  539. )
  540. if values.dtype != transed_sub_tensor.dtype:
  541. values = values.astype(transed_sub_tensor.dtype)
  542. if paddle.in_dynamic_mode():
  543. if (
  544. len(adjusted_advanced_index) == 1
  545. and adjusted_advanced_index[0].dtype
  546. in (paddle.bool, paddle.base.libpaddle.BOOL)
  547. and len(
  548. adjusted_advanced_index[0].shape
  549. == len(transed_sub_tensor.shape)
  550. )
  551. ):
  552. if values.shape != transed_sub_tensor.shape:
  553. values = values.expand(transed_sub_tensor.shape)
  554. transed_sub_tensor = paddle._C_ops.where_(
  555. paddle.logical_not(adjusted_advanced_index[0]),
  556. transed_sub_tensor,
  557. values,
  558. )
  559. if not is_view:
  560. return x
  561. else:
  562. # NOTE(zoooo0820): directly return result instead of another set_value, after backward bug fixed.
  563. transed_sub_tensor = transed_sub_tensor.index_put_(
  564. adjusted_advanced_index, values
  565. )
  566. if not is_view:
  567. return x
  568. else:
  569. transed_sub_tensor = transed_sub_tensor.index_put(
  570. adjusted_advanced_index, values
  571. )
  572. transback_sub_tensor = transed_sub_tensor.transpose(transback_dim)
  573. inputs["ValueTensor"] = transback_sub_tensor
  574. if in_dynamic_or_pir_mode():
  575. if in_pir_mode():
  576. if isinstance(starts, (list, tuple)):
  577. if paddle.utils._contain_var(starts):
  578. starts = paddle.utils.get_int_tensor_list(starts)
  579. if isinstance(ends, (list, tuple)):
  580. if paddle.utils._contain_var(ends):
  581. ends = paddle.utils.get_int_tensor_list(ends)
  582. if isinstance(steps, (list, tuple)):
  583. if paddle.utils._contain_var(steps):
  584. ends = paddle.utils.get_int_tensor_list(steps)
  585. output = paddle._C_ops.set_value_with_tensor_(
  586. x,
  587. transback_sub_tensor,
  588. starts,
  589. ends,
  590. steps,
  591. axes,
  592. decrease_axes,
  593. none_axes,
  594. )
  595. from paddle.jit.pir_dy2static.parameter_recorder import (
  596. _global_inplace_map,
  597. )
  598. _global_inplace_map.add(default_main_program(), x, output)
  599. else:
  600. helper = paddle.base.layer_helper.LayerHelper(
  601. 'set_value', **locals()
  602. )
  603. if helper.main_program.current_block_idx != 0:
  604. # not in global block, we should create a global variable.
  605. output = helper._create_global_variable_for_type_inference(
  606. dtype=x.dtype
  607. )
  608. else:
  609. output = helper.create_variable_for_type_inference(
  610. dtype=x.dtype
  611. )
  612. cur_block = default_main_program().current_block()
  613. cur_block.append_op(
  614. type="set_value",
  615. inputs=inputs,
  616. outputs={'Out': output},
  617. attrs=attrs,
  618. inplace_map={"Input": "Out"},
  619. )
  620. # map var to the new output
  621. paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add(
  622. cur_block.program, x.desc.id(), output
  623. )
  624. return output
def get_tensor_with_basic_indexing(
    x, axes, starts, ends, steps, decrease_axes, none_axes, use_strided_slice
):
    """Apply the basic (slice / int / None) part of an index to ``x``.

    Args:
        x: Tensor (or, in PIR mode, possibly a dense tensor array) to slice.
        axes, starts, ends, steps: Parallel lists describing the slices, as
            produced by ``parse_index``. Entries may be tensors.
        decrease_axes: Axes to drop from the result (int / 0-D tensor indices).
        none_axes: Positions where a new axis is inserted.
            NOTE(review): this list is rewritten in place below, so the
            caller's list is modified -- confirm callers rely on / tolerate this.
        use_strided_slice: Use the ``strided_slice`` op instead of ``slice``.

    Returns:
        tuple: ``(out, out_is_view)``. ``out_is_view`` is True when a slice or
        unsqueeze was applied; the PIR tensor-array paths return it as False.
    """
    from .dygraph.base import in_to_static_mode

    out_is_view = False
    if in_to_static_mode() and hasattr(x, "is_view_var"):
        x.is_view_var = True

    if len(axes) == 0:
        # Nothing to slice.
        out = x
    else:
        out_is_view = True
        op_type = "strided_slice" if use_strided_slice else "slice"
        inputs = {'Input': [x]}
        attrs = {
            'axes': axes,
            'starts': [],
            'ends': [],
            'decrease_axis': decrease_axes,
        }
        if use_strided_slice:
            attrs['strides'] = []
        infer_flags = [1] * len(axes)
        # Each call stores the bound list either in attrs (all static) or in
        # inputs as a tensor list (contains tensors).
        deal_attrs(
            attrs, starts, "starts", "StartsTensorList", inputs, infer_flags
        )
        deal_attrs(attrs, ends, "ends", "EndsTensorList", inputs, infer_flags)
        deal_attrs(
            attrs, steps, "strides", "StridesTensorList", inputs, infer_flags
        )
        attrs['infer_flags'] = infer_flags

        from . import in_dynamic_or_pir_mode, in_pir_mode

        if in_dynamic_or_pir_mode():
            # Prefer the tensor-list form of each bound when it exists.
            if "StartsTensorList" in inputs.keys():
                st = inputs['StartsTensorList']
            else:
                st = attrs['starts']
            if "EndsTensorList" in inputs.keys():
                end = inputs['EndsTensorList']
            else:
                end = attrs['ends']
            if "StridesTensorList" in inputs.keys():
                stride = inputs['StridesTensorList']
            else:
                stride = attrs['strides']
            if use_strided_slice:
                # TODO(zoooo0820): support strided_slice_array until PIR API is ready

                out = paddle._C_ops.strided_slice(x, axes, st, end, stride)
                if len(decrease_axes) > 0:
                    out = paddle._C_ops.squeeze(out, decrease_axes)
            else:
                if in_pir_mode():
                    if isinstance(st, (list, tuple)):
                        if paddle.utils._contain_var(st):
                            st = paddle.utils.get_int_tensor_list(st)
                    if isinstance(end, (list, tuple)):
                        if paddle.utils._contain_var(end):
                            end = paddle.utils.get_int_tensor_list(end)
                    # Tensor arrays use dedicated ops and return early.
                    if x.is_dense_tensor_array_type():
                        if len(decrease_axes) > 0:
                            return (
                                paddle._pir_ops.slice_array_dense(x, st),
                                False,
                            )
                        else:
                            return (
                                paddle._pir_ops.slice_array(x, st, end),
                                False,
                            )
                out = paddle._C_ops.slice(
                    x,
                    axes,
                    st,
                    end,
                    attrs['infer_flags'],
                    attrs['decrease_axis'],
                )
        else:
            # Old static graph: append the slice op to the current block.
            from .framework import default_main_program

            target_block = default_main_program().current_block()

            slice_out_var = target_block.create_var(
                name=unique_name.generate_with_ignorable_key(
                    x.name + "_" + op_type
                ),
                dtype=x.dtype,
            )
            target_block.append_op(
                type=op_type,
                inputs=inputs,
                outputs={'Out': [slice_out_var]},
                attrs=attrs,
            )
            out = slice_out_var

    if len(none_axes) > 0:
        out_is_view = True
        # Deal with cases that decrease_axes is not empty
        # For example:
        # # x.shape: (2,3,4)
        # out = x[0, 0:2, None] # out.shape : (2, 1, 4)
        for idx, axis in enumerate(none_axes):
            # Shift each new-axis position left by the number of dropped axes before it.
            l = len([i for i in decrease_axes if i < axis])
            new_axis = axis - l
            none_axes[idx] = new_axis

        out = paddle.unsqueeze(out, axis=none_axes)

    if in_to_static_mode() and hasattr(out, "is_view_var"):
        out.is_view_var = True
    return out, out_is_view
  731. def _getitem_static(x, indices):
  732. """
  733. Args:
  734. x(Tensor): Tensor to be indexing.
  735. indices(int|slice|None|Tensor|List|Tuple...): Indices, used to indicate the position of the element to be fetched.
  736. """
  737. # step1: parsing the index and recording them
  738. (
  739. starts,
  740. ends,
  741. steps,
  742. axes,
  743. none_axes,
  744. decrease_axes,
  745. advanced_index,
  746. has_advanced_index,
  747. use_strided_slice,
  748. ) = parse_index(x, indices)
  749. # step2: Dealing with basic indexing
  750. out, _ = get_tensor_with_basic_indexing(
  751. x,
  752. axes,
  753. starts,
  754. ends,
  755. steps,
  756. decrease_axes,
  757. none_axes,
  758. use_strided_slice,
  759. )
  760. # step3: Dealing with advanced indexing
  761. if has_advanced_index:
  762. (
  763. transed_tensor,
  764. adjusted_advanced_index,
  765. _,
  766. pos_of_new_dim,
  767. rank_of_new_dim,
  768. _,
  769. _,
  770. ) = deal_advanced_index(out, advanced_index, False, None)
  771. # TODO(zooooo0820): Replacing gather_nd to another advanced OP for handling of mixed indexes more efficiently
  772. if len(adjusted_advanced_index) == 1 and adjusted_advanced_index[
  773. 0
  774. ].dtype in (paddle.bool, paddle.base.libpaddle.BOOL):
  775. # Note: now slice not support 0-size Tensor, so only one bool tensor can return empty 0-size.
  776. out = get_value_for_bool_tensor(
  777. transed_tensor, adjusted_advanced_index[0]
  778. )
  779. else:
  780. adjusted_advanced_index = parse_bool_and_broadcast_indices(
  781. adjusted_advanced_index
  782. )
  783. if len(adjusted_advanced_index) > 1:
  784. advanced_index_tensor = paddle.stack(
  785. adjusted_advanced_index, axis=-1
  786. )
  787. else:
  788. # fast path for single bool tensor, since stack is much slower than unsuqeeze
  789. advanced_index_tensor = adjusted_advanced_index[0].unsqueeze(-1)
  790. out = paddle.gather_nd(transed_tensor, advanced_index_tensor)
  791. if pos_of_new_dim != 0:
  792. perm = (
  793. list(range(rank_of_new_dim, pos_of_new_dim + rank_of_new_dim))
  794. + list(range(0, rank_of_new_dim))
  795. + list(range(pos_of_new_dim + rank_of_new_dim, out.ndim))
  796. )
  797. out = out.transpose(perm)
  798. return out
  799. def parse_bool_and_broadcast_indices(indices):
  800. # deal with multiple Tensors and translating bool tensor to int tensor.
  801. # In static mode, bool-tensor cannot be broadcasted since its corresponding int tensor's shape cannot be infered.
  802. for i, indice in enumerate(indices):
  803. if (
  804. indice.dtype == paddle.bool
  805. or indice.dtype == paddle.base.libpaddle.BOOL
  806. ):
  807. indices[i] = paddle.nonzero(indice)[:, 0]
  808. if len(indices) > 1:
  809. indices = paddle.broadcast_tensors(indices)
  810. return indices