PixelShuffle.h 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>

#include <cstdint>
#include <limits>
  3. namespace at::native {
  4. inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_factor) {
  5. TORCH_CHECK(self.dim() >= 3,
  6. "pixel_shuffle expects input to have at least 3 dimensions, but got input with ",
  7. self.dim(), " dimension(s)");
  8. TORCH_CHECK(upscale_factor > 0,
  9. "pixel_shuffle expects a positive upscale_factor, but got ",
  10. upscale_factor);
  11. int64_t c = self.size(-3);
  12. int64_t upscale_factor_squared = upscale_factor * upscale_factor;
  13. TORCH_CHECK(c % upscale_factor_squared == 0,
  14. "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
  15. "upscale_factor, but input.size(-3)=", c, " is not divisible by ", upscale_factor_squared);
  16. }
  17. inline void check_pixel_unshuffle_shapes(const Tensor& self, int64_t downscale_factor) {
  18. TORCH_CHECK(
  19. self.dim() >= 3,
  20. "pixel_unshuffle expects input to have at least 3 dimensions, but got input with ",
  21. self.dim(),
  22. " dimension(s)");
  23. TORCH_CHECK(
  24. downscale_factor > 0,
  25. "pixel_unshuffle expects a positive downscale_factor, but got ",
  26. downscale_factor);
  27. int64_t h = self.size(-2);
  28. int64_t w = self.size(-1);
  29. TORCH_CHECK(
  30. h % downscale_factor == 0,
  31. "pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=",
  32. h,
  33. " is not divisible by ",
  34. downscale_factor);
  35. TORCH_CHECK(
  36. w % downscale_factor == 0,
  37. "pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=",
  38. w,
  39. " is not divisible by ",
  40. downscale_factor);
  41. }
  42. } // namespace at::native