im2col.h 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. #pragma once
#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/Parallel.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
#include <algorithm>
#include <functional>
  9. namespace at::native {
  10. template <typename T>
  11. static void im2col(
  12. const T* data_im,
  13. const int64_t channels,
  14. const int64_t height,
  15. const int64_t width,
  16. const int64_t output_height,
  17. const int64_t output_width,
  18. const int64_t kernel_h,
  19. const int64_t kernel_w,
  20. const int64_t pad_h,
  21. const int64_t pad_w,
  22. const int64_t stride_h,
  23. const int64_t stride_w,
  24. const int64_t dilation_h,
  25. const int64_t dilation_w,
  26. T* data_col,
  27. bool is_channels_last = false) {
  28. const int64_t height_col = output_height;
  29. const int64_t width_col = output_width;
  30. const int64_t channels_col = channels * kernel_h * kernel_w;
  31. if (is_channels_last) {
  32. at::parallel_for(0, height_col * width_col, 0, [&](int64_t begin, int64_t end) {
  33. int64_t h_col{0}, w_col{0};
  34. data_index_init(begin, h_col, height_col, w_col, width_col);
  35. for (const auto i_col : c10::irange(begin, end)) {
  36. for (const auto h_offset : c10::irange(kernel_h)) {
  37. int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
  38. for (const auto w_offset : c10::irange(kernel_w)) {
  39. int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
  40. const T* slice_im = data_im + (h_im * width + w_im) * channels;
  41. T* slice_col = data_col + (i_col * kernel_h * kernel_w + h_offset * kernel_w + w_offset) * channels;
  42. if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
  43. std::copy_n(slice_im, channels, slice_col);
  44. } else {
  45. std::fill_n(slice_col, channels, T(0));
  46. }
  47. }
  48. }
  49. // move the next index
  50. data_index_step(h_col, height_col, w_col, width_col);
  51. }
  52. });
  53. } else {
  54. at::parallel_for(0, channels_col, 0, [&](int64_t begin, int64_t end) {
  55. int64_t c_im{0}, h_offset{0}, w_offset{0};
  56. data_index_init(begin, c_im, channels, h_offset, kernel_h, w_offset, kernel_w);
  57. for (const auto c_col : c10::irange(begin, end)) {
  58. for (const auto h_col : c10::irange(height_col)) {
  59. int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
  60. for (const auto w_col : c10::irange(width_col)) {
  61. int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
  62. data_col[(c_col * height_col + h_col) * width_col + w_col] =
  63. (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width)
  64. ? c10::load(&(data_im[(c_im * height + h_im) * width + w_im]))
  65. : static_cast<T>(0);
  66. }
  67. }
  68. // move to the next index
  69. data_index_step(c_im, channels, h_offset, kernel_h, w_offset, kernel_w);
  70. }
  71. });
  72. }
  73. }
  74. template <typename T>
  75. static void col2im(
  76. const T* data_col,
  77. const int64_t channels,
  78. const int64_t height,
  79. const int64_t width,
  80. const int64_t output_height,
  81. const int64_t output_width,
  82. const int64_t kernel_h,
  83. const int64_t kernel_w,
  84. const int64_t pad_h,
  85. const int64_t pad_w,
  86. const int64_t stride_h,
  87. const int64_t stride_w,
  88. const int64_t dilation_h,
  89. const int64_t dilation_w,
  90. T* data_im,
  91. bool is_channels_last = false) {
  92. std::fill_n(data_im, height * width * channels, T(0));
  93. const int64_t height_col = output_height;
  94. const int64_t width_col = output_width;
  95. const int64_t channels_col = channels * kernel_h * kernel_w;
  96. if (is_channels_last) {
  97. for (const auto h_col : c10::irange(height_col)) {
  98. for (const auto w_col : c10::irange(width_col)) {
  99. for (const auto h_offset : c10::irange(kernel_h)) {
  100. int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
  101. for (const auto w_offset : c10::irange(kernel_w)) {
  102. int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
  103. T* slice_im = data_im + (h_im * width + w_im) * channels;
  104. const T* slice_col = data_col + ((h_col * width_col + w_col) * kernel_h * kernel_w
  105. + h_offset * kernel_w + w_offset) * channels;
  106. if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width) {
  107. std::transform(slice_col, slice_col + channels, slice_im, slice_im, std::plus<T>());
  108. }
  109. }
  110. }
  111. }
  112. }
  113. } else {
  114. for (const auto c_col : c10::irange(channels_col)) {
  115. int64_t w_offset = c_col % kernel_w;
  116. int64_t h_offset = (c_col / kernel_w) % kernel_h;
  117. int64_t c_im = c_col / kernel_h / kernel_w;
  118. for (const auto h_col : c10::irange(height_col)) {
  119. int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
  120. for (const auto w_col : c10::irange(width_col)) {
  121. int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
  122. if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width)
  123. data_im[(c_im * height + h_im) * width + w_im] +=
  124. data_col[(c_col * height_col + h_col) * width_col + w_col];
  125. }
  126. }
  127. }
  128. }
  129. }
  130. } // namespace at::native