// Gelu.h
#pragma once
#include <c10/util/Exception.h>
#include <string>
#include <string_view>
  4. namespace at::native {
  5. // These constants control the approximation behavior of gelu function.
  6. enum class GeluType {
  7. None, // Baseline Gelu
  8. Tanh, // Tanh Gelu Approximation
  9. END
  10. };
  11. inline GeluType get_gelutype_enum(const std::string_view approximate) {
  12. if (approximate == "none") {
  13. return GeluType::None;
  14. } else if (approximate == "tanh") {
  15. return GeluType::Tanh;
  16. } else {
  17. TORCH_CHECK(false, "approximate argument must be either none or tanh.");
  18. }
  19. }
  20. inline std::string gelutype_to_string(const GeluType type) {
  21. switch(type) {
  22. case GeluType::None: return "none";
  23. case GeluType::Tanh: return "tanh";
  24. default: TORCH_CHECK(false, "unknown GELU type: ", static_cast<int>(type));
  25. }
  26. }
  27. } // namespace at::native