total_variation.cpp 777 B

12345678910111213141516171819202122
  1. #include <torch/extension.h>
  2. #include <vector>
  3. void total_variation_add_grad_cuda(torch::Tensor param, torch::Tensor grad, float wx, float wy, float wz, bool dense_mode);
  4. #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
  5. #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
  6. #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
  7. void total_variation_add_grad(torch::Tensor param, torch::Tensor grad, float wx, float wy, float wz, bool dense_mode) {
  8. CHECK_INPUT(param);
  9. CHECK_INPUT(grad);
  10. total_variation_add_grad_cuda(param, grad, wx, wy, wz, dense_mode);
  11. }
  12. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  13. m.def("total_variation_add_grad", &total_variation_add_grad, "Add total variation grad");
  14. }