lsq.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import paddle
  16. from paddle.autograd import PyLayer
  17. from paddle.framework import ParamAttr
  18. from paddle.nn.initializer import Constant
  19. from paddle.utils import unique_name
  20. from ..layer.layers import Layer
  21. def round(x):
  22. sign = paddle.sign(x)
  23. x = sign * paddle.floor(paddle.abs(x) + 0.5)
  24. return x
  25. class LsqFunc(PyLayer):
  26. @staticmethod
  27. def forward(ctx, weight, alpha, g, Qn, Qp, per_channel=False, quant_axis=0):
  28. ctx.save_for_backward(weight, alpha)
  29. ctx.other = g, Qn, Qp, per_channel, quant_axis
  30. if per_channel:
  31. sizes = weight.shape
  32. weight = weight.reshape((weight.shape[quant_axis], -1))
  33. weight = weight.transpose((1, 0))
  34. alpha = paddle.broadcast_to(alpha, weight.shape)
  35. quant_w = round(paddle.divide(weight, alpha)).clip(Qn, Qp)
  36. quant_w = quant_w * alpha
  37. quant_w = quant_w.transpose((1, 0))
  38. quant_w = quant_w.reshape(sizes)
  39. else:
  40. quant_w = round(paddle.divide(weight, alpha)).clip(Qn, Qp)
  41. quant_w = quant_w * alpha
  42. return quant_w
  43. @staticmethod
  44. def backward(ctx, grad_weight):
  45. weight, alpha = ctx.saved_tensor()
  46. g, Qn, Qp, per_channel, quant_axis = ctx.other
  47. if per_channel:
  48. sizes = weight.shape
  49. weight = weight.reshape((weight.shape[quant_axis], -1))
  50. weight = weight.transpose((1, 0))
  51. alpha = paddle.broadcast_to(alpha, weight.shape)
  52. q_w = paddle.divide(weight, alpha)
  53. q_w = q_w.transpose((1, 0))
  54. q_w = q_w.reshape(sizes)
  55. else:
  56. q_w = paddle.divide(weight, alpha)
  57. lower_flag = paddle.cast((q_w < Qn), 'float32')
  58. upper_flag = paddle.cast((q_w > Qp), 'float32')
  59. middle_flag = 1.0 - lower_flag - upper_flag
  60. if per_channel:
  61. grad_alpha = (
  62. (
  63. lower_flag * Qn
  64. + upper_flag * Qp
  65. + middle_flag * round(q_w)
  66. - middle_flag * q_w
  67. )
  68. * grad_weight
  69. * g
  70. )
  71. grad_alpha = grad_alpha.reshape(
  72. (grad_alpha.shape[quant_axis], -1)
  73. ).sum(axis=1)
  74. else:
  75. grad_alpha = (
  76. (
  77. (
  78. lower_flag * Qn
  79. + upper_flag * Qp
  80. + middle_flag * round(q_w)
  81. - middle_flag * q_w
  82. )
  83. * grad_weight
  84. * g
  85. )
  86. .sum()
  87. .unsqueeze(axis=0)[0]
  88. )
  89. grad_weight = middle_flag * grad_weight
  90. return grad_weight, grad_alpha
  91. class LsqPlusActFunc(PyLayer):
  92. @staticmethod
  93. def forward(ctx, x, alpha, beta, g, Qn, Qp):
  94. ctx.save_for_backward(x, alpha, beta)
  95. ctx.other = g, Qn, Qp
  96. quant_x = round(paddle.divide((x - beta), alpha)).clip(Qn, Qp)
  97. return quant_x * alpha + beta
  98. @staticmethod
  99. def backward(ctx, grad_x):
  100. x, alpha, beta = ctx.saved_tensor()
  101. g, Qn, Qp = ctx.other
  102. q_x = (x - beta) / alpha
  103. lower_flag = paddle.cast((q_x < Qn), 'float32')
  104. upper_flag = paddle.cast((q_x > Qp), 'float32')
  105. middle_flag = 1.0 - lower_flag - upper_flag
  106. grad_alpha = (
  107. (
  108. (
  109. lower_flag * Qn
  110. + upper_flag * Qp
  111. + middle_flag * round(q_x)
  112. - middle_flag * q_x
  113. )
  114. * grad_x
  115. * g
  116. )
  117. .sum()
  118. .unsqueeze(axis=0)[0]
  119. )
  120. grad_beta = (
  121. ((lower_flag + upper_flag) * grad_x * g).sum().unsqueeze(axis=0)[0]
  122. )
  123. grad_x = middle_flag * grad_x
  124. return grad_x, grad_alpha, grad_beta
  125. class FakeQuantActLSQPlus(Layer):
  126. def __init__(
  127. self,
  128. quant_bits,
  129. all_positive=False,
  130. symmetric=False,
  131. batch_init=20,
  132. dtype='float32',
  133. name=None,
  134. reduce_type=None,
  135. ):
  136. super().__init__()
  137. '''
  138. Args:
  139. quant_bits(int): quantization bit number for weights.
  140. all_positive(bool): whether unsigned or signed quantization, where True for unsigned quantization and False for signed quantization.
  141. symmetric(bool): whether symmetric or asymmetric quantization.
  142. batch_init(int): number of batches that collect Gaussian approximation for the weight distribution in each layer.
  143. dtype(str): data type.
  144. name(str): the name of the weight.
  145. reduce_type(str): the reduce type which is needed when parallel training.
  146. '''
  147. self.bits = quant_bits
  148. self.all_positive = all_positive
  149. self.symmetric = symmetric
  150. self.batch_init = batch_init
  151. self.name = name
  152. self.reduce_type = reduce_type
  153. if self.all_positive:
  154. # unsigned activation
  155. self.Qn = 0
  156. self.Qp = 2**self.bits - 1
  157. else:
  158. # signed activation
  159. self.Qn = -(2 ** (self.bits - 1))
  160. self.Qp = 2 ** (self.bits - 1) - 1
  161. scale_prefix = f"{name}.scale" if name else 'quant_dequant.scale'
  162. self._scale_name = unique_name.generate(scale_prefix)
  163. s_attr = ParamAttr(
  164. name=self._scale_name, initializer=Constant(1.0), trainable=True
  165. )
  166. self.s = self.create_parameter(shape=[], attr=s_attr, dtype='float32')
  167. self.s.stop_gradient = False
  168. if not self.symmetric:
  169. beta_prefix = f"{name}.beta" if name else 'quant_dequant.beta'
  170. self._beta_name = unique_name.generate(beta_prefix)
  171. beta_attr = ParamAttr(
  172. name=self._beta_name, initializer=Constant(0.0), trainable=True
  173. )
  174. self.beta = self.create_parameter(
  175. shape=[], attr=beta_attr, dtype='float32'
  176. )
  177. self.beta.stop_gradient = False
  178. self.init_state = 0
  179. def forward(self, activation):
  180. if self.reduce_type == "max":
  181. paddle.distributed.all_reduce(
  182. self.s, op=paddle.distributed.ReduceOp.MAX
  183. )
  184. if not self.symmetric and self.reduce_type == "max":
  185. paddle.distributed.all_reduce(
  186. self.beta, op=paddle.distributed.ReduceOp.MAX
  187. )
  188. if self.init_state == 0:
  189. self.g = paddle.to_tensor(
  190. 1.0 / math.sqrt(activation.numel() * self.Qp)
  191. )
  192. min_a = paddle.min(activation.detach())
  193. max_a = paddle.max(activation.detach())
  194. self.s.set_value((max_a - min_a) / (self.Qp - self.Qn))
  195. if not self.symmetric:
  196. self.beta.set_value(min_a - self.s * self.Qn)
  197. self.init_state += 1
  198. elif self.init_state < self.batch_init:
  199. min_a = paddle.min(activation.detach())
  200. max_a = paddle.max(activation.detach())
  201. self.s.set_value(
  202. self.s * 0.9 + 0.1 * (max_a - min_a) / (self.Qp - self.Qn)
  203. )
  204. if not self.symmetric:
  205. self.beta.set_value(
  206. self.s * 0.9 + 0.1 * (min_a - self.s * self.Qn)
  207. )
  208. self.init_state += 1
  209. else:
  210. self.init_state += 1
  211. activation.stop_gradient = False
  212. if not self.symmetric:
  213. q_a = LsqPlusActFunc.apply(
  214. activation, self.s, self.beta, self.g, self.Qn, self.Qp
  215. )
  216. else:
  217. q_a = LsqFunc.apply(
  218. activation, self.s, self.g, self.Qn, self.Qp, per_channel=False
  219. )
  220. return q_a
  221. class FakeQuantWeightLSQPlus(Layer):
  222. def __init__(
  223. self,
  224. quant_bits,
  225. all_positive=False,
  226. per_channel=False,
  227. batch_init=20,
  228. channel_num=None,
  229. quant_linear=False,
  230. dtype='float32',
  231. name=None,
  232. reduce_type=None,
  233. ):
  234. super().__init__()
  235. '''
  236. Args:
  237. quant_bits(int): quantization bit number for weights.
  238. all_positive(bool): whether unsigned or signed quantization, where True for unsigned quantization and False for signed quantization.
  239. per_channel(bool): whether layer-wise or channel-wise quantization, where True for layer-wise quantization and False for channel-wise quantization.
  240. batch_init(int): number of batches that collect Gaussian approximation for the weight distribution in each layer.
  241. channel_num(int): the channel number of the weight which is needed when per_channel is True.
  242. quant_linear(bool): whether the weight is from Linear.
  243. dtype(str): data type.
  244. name(str): the name of the weight.
  245. reduce_type(str): the reduce type which is needed when parallel training.
  246. '''
  247. self.bits = quant_bits
  248. self.all_positive = all_positive
  249. self.per_channel = per_channel
  250. self.quant_linear = quant_linear
  251. self.batch_init = batch_init
  252. self.name = name
  253. self.quant_axis = 1 if quant_linear else 0
  254. self.collect_axis = 0 if quant_linear else 1
  255. self.reduce_type = reduce_type
  256. if self.all_positive:
  257. # unsigned weight
  258. self.Qn = 0
  259. self.Qp = 2**self.bits - 1
  260. else:
  261. # signed weight
  262. self.Qn = -(2 ** (self.bits - 1))
  263. self.Qp = 2 ** (self.bits - 1) - 1
  264. self.init_state = 0
  265. scale_prefix = f"{name}.scale" if name else 'quant_dequant.scale'
  266. self._scale_name = unique_name.generate(scale_prefix)
  267. s_attr = ParamAttr(
  268. name=self._scale_name, initializer=Constant(1.0), trainable=True
  269. )
  270. self.s = self.create_parameter(
  271. shape=[channel_num], attr=s_attr, dtype=dtype
  272. )
  273. self.s.stop_gradient = False
  274. def forward(self, weight):
  275. if self.reduce_type == "max":
  276. paddle.distributed.all_reduce(
  277. self.s, op=paddle.distributed.ReduceOp.MAX
  278. )
  279. if self.init_state == 0:
  280. self.g = paddle.to_tensor(1.0 / math.sqrt(weight.numel() * self.Qp))
  281. self.div = 2**self.bits - 1
  282. if self.per_channel:
  283. weight_tmp = weight.detach().reshape((weight.shape[0], -1))
  284. mean = paddle.mean(weight_tmp, axis=self.collect_axis)
  285. std = paddle.std(weight_tmp, axis=self.collect_axis)
  286. s = paddle.max(
  287. paddle.stack(
  288. [paddle.abs(mean - 3 * std), paddle.abs(mean + 3 * std)]
  289. ),
  290. axis=0,
  291. )
  292. self.s.set_value(s / self.div)
  293. else:
  294. mean = paddle.mean(weight.detach())
  295. std = paddle.std(weight.detach())
  296. self.s.set_value(
  297. max(
  298. [paddle.abs(mean - 3 * std), paddle.abs(mean + 3 * std)]
  299. )
  300. / self.div
  301. )
  302. self.init_state += 1
  303. elif self.init_state < self.batch_init:
  304. self.div = 2**self.bits - 1
  305. if self.per_channel:
  306. weight_tmp = weight.detach().reshape((weight.shape[0], -1))
  307. mean = paddle.mean(weight_tmp, axis=self.collect_axis)
  308. std = paddle.std(weight_tmp, axis=self.collect_axis)
  309. s = paddle.max(
  310. paddle.stack(
  311. [paddle.abs(mean - 3 * std), paddle.abs(mean + 3 * std)]
  312. ),
  313. axis=0,
  314. )
  315. self.s.set_value(s * 0.9 + 0.1 * s / self.div)
  316. else:
  317. mean = paddle.mean(weight.detach())
  318. std = paddle.std(weight.detach())
  319. self.s.set_value(
  320. self.s * 0.9
  321. + 0.1
  322. * max(
  323. [paddle.abs(mean - 3 * std), paddle.abs(mean + 3 * std)]
  324. )
  325. / self.div
  326. )
  327. self.init_state += 1
  328. elif self.init_state == self.batch_init:
  329. self.init_state += 1
  330. weight.stop_gradient = False
  331. w_q = LsqFunc.apply(
  332. weight,
  333. self.s,
  334. self.g,
  335. self.Qn,
  336. self.Qp,
  337. self.per_channel,
  338. self.quant_axis,
  339. )
  340. return w_q