UnfoldBackward.h 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. #pragma once
  2. #include <ATen/core/Tensor.h>
  3. #include <ATen/TensorIterator.h>
  4. #include <ATen/native/DispatchStub.h>
  5. #include <ATen/native/NonEmptyUtils.h>
  6. #ifndef AT_PER_OPERATOR_HEADERS
  7. #include <ATen/Functions.h>
  8. #else
  9. #include <ATen/ops/arange.h>
  10. #endif
  11. namespace at::native {
  12. using unfold_backward_fn = void (*)(
  13. Tensor& grad_in,
  14. const Tensor& grad,
  15. int64_t dim,
  16. int64_t size,
  17. int64_t step
  18. );
  19. DECLARE_DISPATCH(unfold_backward_fn, unfold_backward_stub)
  20. namespace {
  21. // Note on naming: it is unconventional.
  22. // grad_in does not mean that it is a gradient wrt to input,
  23. // grad_in/grad_out is just an input/output of unfold_backward kernel.
  24. [[maybe_unused]] static TensorIterator _make_unfold_backward_iter_over_grad_out(
  25. Tensor& grad_out,
  26. const Tensor& grad_in,
  27. int64_t dim,
  28. int64_t size,
  29. int64_t step) {
  30. dim = maybe_wrap_dim(dim, grad_out.dim());
  31. // last dim stores the folds
  32. auto grad_out_dim_size = ensure_nonempty_size(grad_out, dim);
  33. auto grad_in_dim_size = ensure_nonempty_size(grad_in, dim);
  34. // dictates the number of elements to iterate over
  35. // in dimension `dim`
  36. auto iter_dim_size = std::min(
  37. grad_out_dim_size,
  38. (grad_in_dim_size - 1) * step + size
  39. );
  40. /* prepare grad_out for TensorIterator { */
  41. auto grad_out_strides = ensure_nonempty_vec(grad_out.strides().vec());
  42. auto grad_out_sizes = ensure_nonempty_vec(grad_out.sizes().vec());
  43. grad_out_sizes[dim] = iter_dim_size;
  44. auto grad_out_restrided = grad_out.as_strided(
  45. grad_out_sizes, grad_out_strides
  46. );
  47. /* } */
  48. /* prepare grad_in for TensorIterator { */
  49. auto grad_in_strides = ensure_nonempty_vec(grad_in.strides().vec());
  50. auto grad_in_sizes = ensure_nonempty_vec(grad_in.sizes().vec());
  51. // set strides for dim to 0
  52. // and size to 1 because
  53. // this dimension is indexed inside the kernel
  54. grad_in_strides[dim] = 0;
  55. grad_in_sizes[dim] = 1;
  56. grad_in_strides.pop_back();
  57. grad_in_sizes.pop_back();
  58. auto grad_in_restrided = grad_in.squeeze(-1).as_strided(
  59. grad_in_sizes, grad_in_strides
  60. );
  61. /* } */
  62. // During the TensorIterator iteration we have to know
  63. // i_dim in grad_out[i_1,...,i_dim,...i_n],
  64. // idx_dim stores this information
  65. /* prepare idx_dim for TensorIterator { */
  66. auto idx_dim = at::arange(
  67. 0, iter_dim_size, grad_in.options().dtype(at::kLong)
  68. );
  69. auto grad_out_dim = ensure_nonempty_dim(grad_out.dim());
  70. auto idx_dim_strides = std::vector<int64_t>(grad_out_dim, 0);
  71. auto idx_dim_sizes = std::vector<int64_t>(grad_out_dim, 1);
  72. idx_dim_strides[dim] = 1;
  73. idx_dim_sizes[dim] = iter_dim_size;
  74. // idx_dim size will broadcast over determined by grad_out sizes in TensorIterator
  75. auto idx_dim_restrided = idx_dim.as_strided(idx_dim_sizes, idx_dim_strides);
  76. /* } */
  77. auto iter = TensorIteratorConfig()
  78. .set_check_mem_overlap(false)
  79. .check_all_same_dtype(false)
  80. .resize_outputs(false)
  81. .add_owned_output(grad_out_restrided)
  82. .add_owned_const_input(grad_in_restrided)
  83. .add_owned_const_input(idx_dim_restrided)
  84. .build();
  85. return iter;
  86. }
  87. }
  88. } // namespace at::native