complex.h 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. #pragma once
  2. #include <complex>
  3. #include <c10/macros/Macros.h>
  4. #include <c10/util/Half.h>
  5. #include <torch/headeronly/util/complex.h>
  6. // std functions
  7. //
// The implementations of these functions also follow the design of C++20.
  9. namespace std {
  10. template <typename T>
  11. constexpr T real(const c10::complex<T>& z) {
  12. return z.real();
  13. }
  14. template <typename T>
  15. constexpr T imag(const c10::complex<T>& z) {
  16. return z.imag();
  17. }
  18. template <typename T>
  19. C10_HOST_DEVICE T abs(const c10::complex<T>& z) {
  20. #if defined(__CUDACC__) || defined(__HIPCC__)
  21. return thrust::abs(static_cast<thrust::complex<T>>(z));
  22. #else
  23. return std::abs(static_cast<std::complex<T>>(z));
  24. #endif
  25. }
  26. #if defined(USE_ROCM)
  27. #define ROCm_Bug(x)
  28. #else
  29. #define ROCm_Bug(x) x
  30. #endif
  31. template <typename T>
  32. C10_HOST_DEVICE T arg(const c10::complex<T>& z) {
  33. return ROCm_Bug(std)::atan2(std::imag(z), std::real(z));
  34. }
  35. #undef ROCm_Bug
  36. template <typename T>
  37. constexpr T norm(const c10::complex<T>& z) {
  38. return z.real() * z.real() + z.imag() * z.imag();
  39. }
  40. // For std::conj, there are other versions of it:
  41. // constexpr std::complex<float> conj( float z );
  42. // template< class DoubleOrInteger >
  43. // constexpr std::complex<double> conj( DoubleOrInteger z );
  44. // constexpr std::complex<long double> conj( long double z );
  45. // These are not implemented
  46. // TODO(@zasdfgbnm): implement them as c10::conj
  47. template <typename T>
  48. constexpr c10::complex<T> conj(const c10::complex<T>& z) {
  49. return c10::complex<T>(z.real(), -z.imag());
  50. }
  51. // Thrust does not have complex --> complex version of thrust::proj,
  52. // so this function is not implemented at c10 right now.
  53. // TODO(@zasdfgbnm): implement it by ourselves
  54. // There is no c10 version of std::polar, because std::polar always
  55. // returns std::complex. Use c10::polar instead;
  56. } // namespace std
  57. #define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
  58. // math functions are included in a separate file
  59. #include <c10/util/complex_math.h> // IWYU pragma: keep
  60. // utilities for complex types
  61. #include <c10/util/complex_utils.h> // IWYU pragma: keep
  62. #undef C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H