wkv_cuda.cu 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. #include <stdio.h>
  2. #include <assert.h>
  3. #define MIN_VALUE (-1e38)
  4. template <typename F>
  5. __global__ void kernel_forward(
  6. const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
  7. const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y
  8. ) {
  9. const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  10. const int _b = idx / C;
  11. const int _c = idx % C;
  12. const int _offset = _b * T * C + _c;
  13. F u = _u[_c];
  14. F w = _w[_c];
  15. const F *__restrict__ const k = _k + _offset;
  16. const F *__restrict__ const v = _v + _offset;
  17. F *__restrict__ const y = _y + _offset;
  18. // aa and bb are running sums divided by exp(pp) (to avoid overflow)
  19. F aa = 0, bb = 0, pp = MIN_VALUE;
  20. for (int i = 0; i < T; i++) {
  21. const int ii = i * C;
  22. const F kk = k[ii];
  23. const F vv = v[ii];
  24. F ww = u + kk;
  25. F p = max(pp, ww);
  26. F e1 = exp(pp - p);
  27. F e2 = exp(ww - p);
  28. y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
  29. ww = w + pp;
  30. p = max(ww, kk);
  31. e1 = exp(ww - p);
  32. e2 = exp(kk - p);
  33. aa = e1 * aa + e2 * vv;
  34. bb = e1 * bb + e2;
  35. pp = p;
  36. }
  37. }
  38. template <typename F>
  39. __global__ void kernel_forward_with_state(
  40. const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
  41. const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y, F *__restrict__ const _s
  42. ) {
  43. const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  44. const int _b = idx / C;
  45. const int _c = idx % C;
  46. const int _offset_s = _b * C * 3 + _c * 3;
  47. const int _offset = _b * T * C + _c;
  48. F u = _u[_c];
  49. F w = _w[_c];
  50. const F *__restrict__ const k = _k + _offset;
  51. const F *__restrict__ const v = _v + _offset;
  52. F *__restrict__ const y = _y + _offset;
  53. F *__restrict__ const s = _s + _offset_s;
  54. // aa and bb are running sums divided by exp(pp) (to avoid overflow)
  55. F aa = s[0], bb = s[1], pp = s[2];
  56. for (int i = 0; i < T; i++) {
  57. const int ii = i * C;
  58. const F kk = k[ii];
  59. const F vv = v[ii];
  60. F ww = u + kk;
  61. F p = max(pp, ww);
  62. F e1 = exp(pp - p);
  63. F e2 = exp(ww - p);
  64. y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
  65. ww = w + pp;
  66. p = max(ww, kk);
  67. e1 = exp(ww - p);
  68. e2 = exp(kk - p);
  69. aa = e1 * aa + e2 * vv;
  70. bb = e1 * bb + e2;
  71. pp = p;
  72. }
  73. s[0] = aa;
  74. s[1] = bb;
  75. s[2] = pp;
  76. }
  77. template <typename F>
  78. __global__ void kernel_backward(
  79. const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
  80. const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _y,
  81. const F *__restrict__ const _gy, F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk,
  82. F *__restrict__ const _gv
  83. ) {
  84. const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  85. const int _b = idx / C;
  86. const int _c = idx % C;
  87. const int _offset = _b * T * C + _c;
  88. F u = _u[_c];
  89. F w = _w[_c];
  90. const F *__restrict__ const k = _k + _offset;
  91. const F *__restrict__ const v = _v + _offset;
  92. const F *__restrict__ const y = _y + _offset;
  93. const F *__restrict__ const gy = _gy + _offset;
  94. F *__restrict__ const gk = _gk + _offset;
  95. F *__restrict__ const gv = _gv + _offset;
  96. F q[Tmax], r[Tmax];
  97. F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
  98. for (int i = 0; i < T; i++) {
  99. const int ii = i * C;
  100. const F kk = k[ii];
  101. const F vv = v[ii];
  102. const F yy = y[ii];
  103. F ww = u + kk;
  104. F p = max(pp, ww);
  105. F e1 = exp(pp - p);
  106. F e2 = exp(ww - p);
  107. const F qq = gy[ii] / (e1 * bb + e2);
  108. gw += (ga - gb * yy) * e1 * qq;
  109. gu += (vv - yy) * e2 * qq;
  110. q[i] = qq;
  111. r[i] = ww - p;
  112. ww = w + pp;
  113. p = max(ww, kk);
  114. e1 = exp(ww - p);
  115. e2 = exp(kk - p);
  116. ga = e1 * (aa + ga);
  117. gb = e1 * (bb + gb);
  118. aa = e1 * aa + e2 * vv;
  119. bb = e1 * bb + e2;
  120. pp = p;
  121. }
  122. const int _offsetBC = _b * C + _c;
  123. _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
  124. _gu[_offsetBC] = gu;
  125. aa = 0, bb = 0, pp = MIN_VALUE;
  126. for (int i = T - 1; i >= 0; i--) {
  127. const int ii = i * C;
  128. const F kk = k[ii];
  129. const F vv = v[ii];
  130. const F yy = y[ii];
  131. const F qq = q[i];
  132. const F rr = r[i];
  133. F e1 = qq * exp(rr);
  134. F e2 = exp(kk + pp);
  135. gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
  136. gv[ii] = e1 + e2 * aa;
  137. const F ww = w + pp;
  138. const F www = rr - u - kk;
  139. const F p = max(ww, www);
  140. e1 = exp(ww - p);
  141. e2 = qq * exp(www - p);
  142. aa = e1 * aa + e2;
  143. bb = e1 * bb - e2 * yy;
  144. pp = p;
  145. }
  146. }
  147. void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
  148. dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
  149. assert(B * C % threadsPerBlock.x == 0);
  150. dim3 numBlocks(B * C / threadsPerBlock.x);
  151. kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
  152. }
  153. void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s) {
  154. dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
  155. assert(B * C % threadsPerBlock.x == 0);
  156. dim3 numBlocks(B * C / threadsPerBlock.x);
  157. kernel_forward_with_state<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, s);
  158. }
  159. void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
  160. dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
  161. assert(B * C % threadsPerBlock.x == 0);
  162. dim3 numBlocks(B * C / threadsPerBlock.x);
  163. kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
  164. }