LossMulti.h 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. #pragma once
  2. #include <ATen/core/Tensor.h>
  3. #include <ATen/AccumulateType.h>
  4. #include <ATen/Dispatch.h>
  5. #include <ATen/TensorUtils.h>
  6. namespace at::native {
  7. inline void multilabel_margin_loss_shape_check(
  8. int64_t& nframe,
  9. int64_t& dim,
  10. const int64_t& ndims,
  11. const Tensor& input,
  12. const Tensor& target) {
  13. TORCH_CHECK(
  14. (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
  15. "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
  16. input.sizes());
  17. if (ndims <= 1) {
  18. nframe = 1;
  19. dim = ndims == 0 ? 1 : input.size(0);
  20. TORCH_CHECK(
  21. target.dim() <= 1 && target.numel() == dim,
  22. "inconsistent target size: ", target.sizes(), " for input of size: ",
  23. input.sizes());
  24. } else {
  25. nframe = input.size(0);
  26. dim = input.size(1);
  27. TORCH_CHECK(
  28. target.dim() == 2 && target.size(0) == nframe &&
  29. target.size(1) == dim,
  30. "inconsistent target size: ", target.sizes(), " for input of size: ",
  31. input.sizes());
  32. }
  33. }
  34. inline void multi_margin_loss_shape_check(
  35. int64_t& nframe,
  36. int64_t& dim,
  37. const int64_t& ndims,
  38. const Tensor& input,
  39. const Tensor& target,
  40. const std::optional<Tensor>& weight) {
  41. TORCH_CHECK(
  42. (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
  43. "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
  44. input.sizes());
  45. if (ndims <= 1) {
  46. nframe = 1;
  47. dim = ndims == 0 ? 1 : input.size(0);
  48. } else {
  49. nframe = input.size(0);
  50. dim = input.size(1);
  51. }
  52. TORCH_CHECK(
  53. target.dim() <= 1 && target.numel() == nframe,
  54. "inconsistent target size, expected ", nframe, " but got ",
  55. target.sizes());
  56. if (weight && weight->defined()) {
  57. TORCH_CHECK(
  58. weight->dim() <= 1 && weight->numel() == dim,
  59. "inconsistent weight size, expected ", dim, " but got ",
  60. weight->sizes());
  61. }
  62. }
  63. } // namespace at::native