// EmbeddingBag.h

  1. #include <ATen/core/Tensor.h>
  2. #include <ATen/Config.h>
  3. #include <cstdint>
  4. #ifdef USE_FBGEMM
  5. #include <fbgemm/FbgemmEmbedding.h>
  6. #endif
  7. namespace at::native {
  // Reduction modes for embedding bag. The numeric values are spelled out
  // explicitly because the functions declared in this header carry the mode as
  // a plain int64_t, and the comparison operators below cast the enumerator to
  // int64_t to compare against it.
  enum class EmbeddingBagMode {
    SUM = 0,
    MEAN = 1,
    MAX = 2,
  };
  13. [[maybe_unused]] static bool operator==(int64_t op1, EmbeddingBagMode op2) {
  14. return op1 == static_cast<int64_t>(op2);
  15. }
  16. [[maybe_unused]] static bool operator!=(int64_t op1, EmbeddingBagMode op2) {
  17. return !(op1 == op2);
  18. }
  19. void check_arguments(
  20. const Tensor& weight,
  21. const Tensor& indices,
  22. const Tensor& offsets,
  23. const int64_t mode,
  24. const std::optional<Tensor>& per_sample_weights,
  25. bool include_last_offset);
  26. void make_bag_size_out(
  27. Tensor& bag_size_out,
  28. const Tensor& offsets,
  29. const Tensor& indices,
  30. const int64_t mode,
  31. const bool include_last_offset,
  32. const bool requires_grad);
  33. void make_max_indices_out(
  34. Tensor& max_indices_out,
  35. const Tensor& weight,
  36. const Tensor& indices,
  37. const Tensor& offsets,
  38. const Tensor& bag_size,
  39. const int64_t mode,
  40. bool include_last_offset);
  41. void make_offset2bag_out(
  42. Tensor& offset2bag,
  43. Tensor& output,
  44. const Tensor& weight,
  45. const Tensor& indices,
  46. const Tensor& offsets,
  47. const int64_t mode,
  48. const std::optional<Tensor>& per_sample_weights,
  49. const int64_t padding_idx = -1);
  50. #ifdef USE_FBGEMM
  51. template<bool has_weight, typename TIndex, typename TData>
  52. struct _CallbackAndBlockSize {
  53. using TCallback = typename fbgemm::EmbeddingSpMDMKernelSignature<TData, TIndex, TIndex, TData>::Type;
  54. int64_t blockSize = -1;
  55. TCallback callback = nullptr;
  56. static TCallback generateCallback(int64_t block_size) {
  57. return fbgemm::GenerateEmbeddingSpMDM<TData, TIndex, TIndex, TData>(
  58. block_size,
  59. has_weight,
  60. /* normalize_by_lengths */false,
  61. /* prefetch */16,
  62. /* is_weight_positional */false,
  63. /* use_offsets */true);
  64. }
  65. _CallbackAndBlockSize() = default;
  66. explicit _CallbackAndBlockSize(std::optional<int64_t> maybe_block_size)
  67. : blockSize(maybe_block_size.value_or(-1))
  68. , callback(maybe_block_size.has_value() ? generateCallback(maybe_block_size.value()) : nullptr)
  69. {}
  70. };
  71. template<typename... StorageMixins>
  72. struct _EmbeddingBagKernelCacheImpl : private StorageMixins... {
  73. _EmbeddingBagKernelCacheImpl() = default;
  74. // use each of the mixins to store corresponding kernel and block size
  75. explicit _EmbeddingBagKernelCacheImpl(std::optional<int64_t> maybe_block_size)
  76. : StorageMixins(maybe_block_size)...
  77. {}
  78. // this method is thread safe (call sites may call from different threads)
  79. template<bool has_weight, typename TIndex, typename TData>
  80. typename _CallbackAndBlockSize<has_weight, TIndex, TData>::TCallback
  81. getCallback(int64_t block_size) const {
  82. // if the cache doesn't store the kernel for the incoming block size
  83. // (so it is different from the one stored in corresponding mixin)
  84. // regenerate the kernel (not writing it into the cache so we avoid locks)
  85. if (block_size != _CallbackAndBlockSize<has_weight, TIndex, TData>::blockSize) {
  86. return _CallbackAndBlockSize<has_weight, TIndex, TData>::generateCallback(block_size);
  87. }
  88. // else retrieve the cached kernel from the corresponding mixin
  89. return _CallbackAndBlockSize<has_weight, TIndex, TData>::callback;
  90. }
  91. };
  92. // instantiate the cache with the list of storage mixins
  93. // for each of the 8 _EmbeddingBagKernelCache* usages in the EmbeddingBag.cpp impl file
  94. using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl<
  95. _CallbackAndBlockSize<true, int32_t, float>,
  96. _CallbackAndBlockSize<false, int32_t, float>,
  97. _CallbackAndBlockSize<true, int64_t, float>,
  98. _CallbackAndBlockSize<false, int64_t, float>,
  99. _CallbackAndBlockSize<true, int32_t, unsigned short>,
  100. _CallbackAndBlockSize<false, int32_t, unsigned short>,
  101. _CallbackAndBlockSize<true, int64_t, unsigned short>,
  102. _CallbackAndBlockSize<false, int64_t, unsigned short>>;
  103. #else
  // Placeholder cache used when FBGEMM is not available. It stores nothing and
  // exists so that code naming _EmbeddingBagKernelCache compiles in both build
  // configurations.
  struct _EmbeddingBagKernelCache {
    // Default-constructible, like the FBGEMM-backed cache; without this,
    // default construction would compile only when USE_FBGEMM is defined.
    _EmbeddingBagKernelCache() = default;

    // Accepts and ignores the block size: there is no kernel to generate.
    explicit _EmbeddingBagKernelCache(std::optional<int64_t> /* maybe_block_size */) {}
  };
  107. #endif
  108. void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
  109. Tensor& bag_size, Tensor* max_indices,
  110. const Tensor &weight, const Tensor &indices,
  111. const Tensor &offsets, const int64_t mode = 0,
  112. const std::optional<Tensor>& per_sample_weights = std::nullopt,
  113. bool include_last_offset = false,
  114. int64_t padding_idx = -1,
  115. _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
  116. void _embedding_bag_cpu_out(
  117. at::Tensor& output,
  118. at::Tensor& offset2bag,
  119. at::Tensor& bag_size,
  120. at::Tensor* p_max_indices,
  121. const at::Tensor& weight,
  122. const at::Tensor& indices,
  123. const at::Tensor& offsets,
  124. const bool scale_grad_by_freq,
  125. const int64_t mode,
  126. const bool sparse,
  127. const std::optional<at::Tensor>& per_sample_weights,
  128. const bool include_last_offset,
  129. const std::optional<int64_t>& padding_idx,
  130. _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
  131. } // namespace at::native