random.h 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. // Philox Counter based RNG implementation for Metal
  2. // Borrowed from aten/src/ATen/core/PhiloxRNGEngine.h
  3. // Which in turn borrowed from
  4. // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
  5. #pragma once
  6. #include <metal_stdlib>
  7. namespace c10 {
  8. namespace metal {
  9. namespace detail {
  // Converts a 32-bit random word into a float in [0, 1). Only the low 31 bits
  // of `value` contribute; the top bit is masked away before the conversion.
  constexpr float uint32_to_uniform_float(uint32_t value) {
    // Largest factor for which `MAX_INT * factor` still lands below 1.0 once
    // the product has been rounded to float.
    constexpr float kUnitScale = 4.6566127342e-10;
    const uint32_t low_bits = value & 0x7FFFFFFF;
    return static_cast<float>(low_bits) * kUnitScale;
  }
  15. inline uint2 splitlong(ulong v) {
  16. return uint2(v >> 32, v & 0xffffffff);
  17. }
  18. } // namespace detail
  19. namespace philox4 {
  20. uint2 mulhilo(uint a, uint b) {
  21. auto rc = static_cast<ulong>(a) * b;
  22. return detail::splitlong(rc);
  23. }
  24. uint4 single_round(uint4 ctr, uint2 key) {
  25. constexpr uint kPhiloxSA = 0xD2511F53;
  26. constexpr uint kPhiloxSB = 0xCD9E8D57;
  27. auto rc0 = mulhilo(kPhiloxSA, ctr.x);
  28. auto rc1 = mulhilo(kPhiloxSB, ctr.z);
  29. return uint4(rc1.y ^ ctr.y ^ key.x, rc1.x, rc0.y ^ ctr.w ^ key.y, rc0.x);
  30. }
  31. uint4 multiple_rounds(uint4 ctr, uint2 key, uint rounds) {
  32. constexpr uint2 kPhilox10 = {0x9E3779B9, 0xBB67AE85};
  33. for (uint round = 0; round < rounds - 1; ++round) {
  34. ctr = single_round(ctr, key);
  35. key += kPhilox10;
  36. }
  37. return ctr;
  38. }
  39. uint4 rand(long seed, long index) {
  40. uint4 ctr = 0;
  41. ctr.zw = detail::splitlong(index);
  42. return multiple_rounds(ctr, detail::splitlong(seed), 10);
  43. }
  44. } // namespace philox4
  45. float randn(long seed, long index) {
  46. auto value = philox4::rand(seed, index);
  47. float u1 = 1.0 - detail::uint32_to_uniform_float(value.x);
  48. float u2 = 1.0 - detail::uint32_to_uniform_float(value.y);
  49. return ::metal::sqrt(-2.0 * ::metal::log(u1)) *
  50. ::metal::cos(2.0 * M_PI_F * u2);
  51. }
  52. float rand(long seed, long index) {
  53. auto value = philox4::rand(seed, index);
  54. return detail::uint32_to_uniform_float(value.x);
  55. }
  56. long randint64(long seed, long index, long low, long high) {
  57. auto range = high - low;
  58. auto value = philox4::rand(seed, index);
  59. // TODO: Implement better algorithm for large ranges
  60. return low +
  61. static_cast<long>(detail::uint32_to_uniform_float(value.x) * range);
  62. }
  63. } // namespace metal
  64. } // namespace c10