dlpack.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. from ..base.core import LoDTensor
  16. from ..base.data_feeder import check_type
  17. from ..base.framework import in_dygraph_mode
  18. __all__ = [
  19. 'to_dlpack',
  20. 'from_dlpack',
  21. ]
  22. def to_dlpack(x):
  23. """
  24. Encodes a tensor to DLPack.
  25. Args:
  26. x (Tensor): The input tensor, and the data type can be `bool`, `float16`, `float32`,
  27. `float64`, `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`,
  28. `complex128`.
  29. Returns:
  30. dltensor, and the data type is PyCapsule.
  31. Examples:
  32. .. code-block:: python
  33. >>> import paddle
  34. >>> # x is a tensor with shape [2, 4]
  35. >>> x = paddle.to_tensor([[0.2, 0.3, 0.5, 0.9],
  36. ... [0.1, 0.2, 0.6, 0.7]])
  37. >>> dlpack = paddle.utils.dlpack.to_dlpack(x)
  38. >>> print(dlpack)
  39. >>> # doctest: +SKIP('the address will change in every run')
  40. <capsule object "dltensor" at 0x7f6103c681b0>
  41. """
  42. if in_dygraph_mode():
  43. if not isinstance(x, (paddle.Tensor, paddle.base.core.eager.Tensor)):
  44. raise TypeError(
  45. "The type of 'x' in to_dlpack must be paddle.Tensor,"
  46. f" but received {type(x)}."
  47. )
  48. return x.value().get_tensor()._to_dlpack()
  49. check_type(x, 'x', (LoDTensor), 'to_dlpack')
  50. return x._to_dlpack()
  51. def from_dlpack(dlpack):
  52. """
  53. Decodes a DLPack to a tensor.
  54. Args:
  55. dlpack (PyCapsule): a PyCapsule object with the dltensor.
  56. Returns:
  57. out (Tensor), a tensor decoded from DLPack. One thing to be noted, if we get
  58. an input dltensor with data type as `bool`, we return the decoded
  59. tensor as `uint8`.
  60. Examples:
  61. .. code-block:: python
  62. >>> import paddle
  63. >>> # x is a tensor with shape [2, 4]
  64. >>> x = paddle.to_tensor([[0.2, 0.3, 0.5, 0.9],
  65. ... [0.1, 0.2, 0.6, 0.7]])
  66. >>> dlpack = paddle.utils.dlpack.to_dlpack(x)
  67. >>> x = paddle.utils.dlpack.from_dlpack(dlpack)
  68. >>> print(x)
  69. Tensor(shape=[2, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
  70. [[0.20000000, 0.30000001, 0.50000000, 0.89999998],
  71. [0.10000000, 0.20000000, 0.60000002, 0.69999999]])
  72. """
  73. t = type(dlpack)
  74. dlpack_flag = t.__module__ == 'builtins' and t.__name__ == 'PyCapsule'
  75. if not dlpack_flag:
  76. raise TypeError(
  77. "The type of 'dlpack' in from_dlpack must be PyCapsule object,"
  78. f" but received {type(dlpack)}."
  79. )
  80. if in_dygraph_mode():
  81. out = paddle.base.core.from_dlpack(dlpack)
  82. out = paddle.to_tensor(out)
  83. return out
  84. out = paddle.base.core.from_dlpack(dlpack)
  85. return out