adaround.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. import sys
  16. import time
  17. import numpy as np
  18. import paddle
  19. from paddle import static
  20. from ..log_helper import get_logger
  21. from .utils import (
  22. _channelwise_quant_axis1_ops,
  23. bias_correction_w,
  24. calculate_quant_cos_error,
  25. dequant_tensor,
  26. load_variable_data,
  27. quant_tensor,
  28. set_variable_data,
  29. stable_sigmoid,
  30. )
# Module-level logger; messages are formatted as "<time>-<level>: <message>".
_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)

# Stretch limits of the rectified sigmoid used for soft rounding: the sigmoid
# output is rescaled as ``sigmoid(x) * (ZETA - GAMMA) + GAMMA`` and then
# clipped to [0, 1] (see compute_soft_rounding / compute_soft_rounding_np).
GAMMA = -0.1
ZETA = 1.1
  36. def compute_soft_rounding(alpha_v):
  37. return paddle.clip(
  38. paddle.nn.functional.sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA,
  39. min=0,
  40. max=1,
  41. )
  42. def compute_soft_rounding_np(alpha_v):
  43. return np.clip(
  44. stable_sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA, a_min=0, a_max=1
  45. )
  46. class AdaRoundLoss:
  47. def __init__(self, reg_param=0.01, default_beta_range=(20, 2)):
  48. self.default_reg_param = reg_param
  49. self.default_beta_range = default_beta_range
  50. def compute_recon_loss(self, ada_quantized_output, orig_output):
  51. square_cost = paddle.nn.functional.square_error_cost(
  52. ada_quantized_output, orig_output
  53. )
  54. recon_loss = paddle.mean(paddle.sum(square_cost, axis=-1))
  55. return recon_loss
  56. def compute_round_loss(self, alpha_v, warm_start, beta):
  57. def round_loss_fn():
  58. # compute rectified sigmoid of parameter 'alpha' which maps it between zero and one
  59. h_v = compute_soft_rounding(alpha_v)
  60. # calculate regularization term - which ensures parameter to converge to exactly zeros and ones
  61. # at the end of optimization
  62. reg_term = paddle.sum(
  63. -paddle.pow(paddle.abs(2 * h_v - 1), beta) + 1
  64. )
  65. # calculate the rounding loss
  66. round_loss = self.default_reg_param * reg_term
  67. return round_loss
  68. round_loss = static.nn.cond(
  69. warm_start,
  70. lambda: paddle.full(shape=[1], dtype='float32', fill_value=0.0),
  71. round_loss_fn,
  72. )
  73. return round_loss
  74. def compute_beta(self, max_iter, cur_iter, warm_start):
  75. # Start and stop beta for annealing of rounding loss (start_beta, end_beta)
  76. start_beta, end_beta = self.default_beta_range
  77. # iteration at end of warm start period, which is 20% of max iterations
  78. warm_start_end_iter = warm_start * max_iter
  79. # compute relative iteration of current iteration
  80. rel_iter = (cur_iter - warm_start_end_iter) / (
  81. max_iter - warm_start_end_iter
  82. )
  83. beta = end_beta + 0.5 * (start_beta - end_beta) * (
  84. 1 + np.cos(rel_iter * np.pi)
  85. )
  86. return beta
  87. class AdaRound:
  88. def __init__(
  89. self,
  90. scale,
  91. weight_tensor,
  92. scope=None,
  93. weight_var_name=None,
  94. weight_op_type=None,
  95. is_train=True,
  96. num_iterations=1000,
  97. ):
  98. self.is_train = is_train
  99. self.num_iterations = num_iterations
  100. self.warm_start = 0.1
  101. self.weight_bits = 8
  102. self.offset = 0.0 # zero-point offset
  103. self.adaround_loss = AdaRoundLoss()
  104. self.ori_weight_tensor = weight_tensor
  105. self.scale = scale
  106. self.scope = scope
  107. self.quant_axis = 0
  108. if weight_op_type in _channelwise_quant_axis1_ops:
  109. self.quant_axis = 1
  110. self.weight_var_name = weight_var_name
  111. self.alpha_name = weight_var_name + ".alpha"
  112. self.initialize_alpha(weight_tensor.copy(), scale, weight_var_name)
  113. def initialize_alpha(self, tensor, scale, var_name):
  114. """
  115. Initializes alpha parameter, same shape as the weight tensor
  116. """
  117. tensor_scale = quant_tensor(tensor, scale, quant_axis=self.quant_axis)
  118. tensor_floor = np.floor(tensor_scale)
  119. tensor = tensor_scale - tensor_floor
  120. alpha = -np.log((ZETA - GAMMA) / (tensor - GAMMA) - 1)
  121. self.alpha_v = paddle.create_parameter(
  122. shape=alpha.shape,
  123. dtype="float32",
  124. name=var_name + ".alpha",
  125. default_initializer=paddle.nn.initializer.Assign(alpha),
  126. )
  127. def _calculate_output_with_adarounded_weights(
  128. self, program, place, exe, data, fp32_fetch_list, weight_tensor_dequant
  129. ):
  130. set_variable_data(
  131. self.scope, place, self.weight_var_name, weight_tensor_dequant
  132. )
  133. adaround_out_tensor = exe.run(
  134. program=program,
  135. feed=data,
  136. fetch_list=[fp32_fetch_list],
  137. return_numpy=True,
  138. scope=self.scope,
  139. )
  140. return adaround_out_tensor
  141. def _calculate_quant_weight(self):
  142. np_alpha = load_variable_data(self.scope, self.alpha_name)
  143. h_alpha = compute_soft_rounding_np(np_alpha)
  144. # Scale the tensor
  145. tensor_scale = quant_tensor(
  146. self.ori_weight_tensor.copy(),
  147. self.scale,
  148. quant_axis=self.quant_axis,
  149. )
  150. weight_tensor = np.floor(tensor_scale)
  151. # Adaround the tensor
  152. weight_tensor_quant = np.add(weight_tensor, h_alpha)
  153. return weight_tensor_quant
  154. def _calculate_adarounded_weights(self):
  155. weight_tensor_quant = self._calculate_quant_weight()
  156. # Dequantize the tensor
  157. weight_tensor_dequant = dequant_tensor(
  158. weight_tensor_quant + self.offset,
  159. self.scale,
  160. quant_axis=self.quant_axis,
  161. )
  162. return weight_tensor_dequant
  163. def update_final_weights(self):
  164. weight_tensor_quant = self._calculate_quant_weight()
  165. return weight_tensor_quant
  166. def get_loss(self, beta, warm_start, adaround_out_tensor, orig_out_tensor):
  167. round_loss = self.adaround_loss.compute_round_loss(
  168. self.alpha_v, warm_start, beta
  169. )
  170. recon_loss = self.adaround_loss.compute_recon_loss(
  171. adaround_out_tensor, orig_out_tensor
  172. )
  173. loss = round_loss + recon_loss
  174. losses = {
  175. 'loss': loss,
  176. 'round_loss': round_loss,
  177. 'recon_loss': recon_loss,
  178. }
  179. return losses
  180. def update_beta_warm(self, cur_iteration):
  181. warm_start = cur_iteration < self.num_iterations * self.warm_start
  182. beta = self.adaround_loss.compute_beta(
  183. self.num_iterations, cur_iteration, self.warm_start
  184. )
  185. return beta, warm_start
  186. def run_adaround(
  187. data_loader,
  188. fp32_program,
  189. fetch_list,
  190. exe,
  191. scope,
  192. place,
  193. quantized_op_pairs,
  194. weight_op_pairs,
  195. scale_dict,
  196. num_iterations=1000,
  197. lr=0.001,
  198. bias_correction=False,
  199. fast_mode=True,
  200. ):
  201. fetch_op_name = fetch_list[0].name
  202. final_weight_tensor_quant_dict = {}
  203. for weight_var_name, quant_op_out_name in quantized_op_pairs.items():
  204. _logger.info(f'Start adaround op: {weight_var_name}')
  205. weight_op_type = weight_op_pairs[weight_var_name]
  206. # get scale and weight tensor
  207. weight_var_tensor = load_variable_data(scope, weight_var_name)
  208. scale = scale_dict[weight_var_name]
  209. fp32_fetch_list = None
  210. for _op in fp32_program.global_block().ops:
  211. if _op.type == "fetch":
  212. _op._rename_input(fetch_op_name, quant_op_out_name)
  213. fp32_fetch_list = fp32_program.global_block().var(
  214. quant_op_out_name
  215. )
  216. fetch_op_name = quant_op_out_name
  217. # build adaround program
  218. startup_program = static.Program()
  219. train_program = static.Program()
  220. with static.program_guard(train_program, startup_program):
  221. with paddle.utils.unique_name.guard():
  222. # initialize adaround
  223. adaround = AdaRound(
  224. scale,
  225. weight_var_tensor,
  226. scope=scope,
  227. weight_var_name=weight_var_name,
  228. weight_op_type=weight_op_type,
  229. num_iterations=num_iterations,
  230. )
  231. orig_out_tensor = static.data(
  232. name='orig_out_tensor',
  233. shape=(-1,) + fp32_fetch_list.shape,
  234. dtype='float32',
  235. )
  236. adaround_out_tensor = static.data(
  237. name='adaround_out_tensor',
  238. shape=(-1,) + fp32_fetch_list.shape,
  239. dtype='float32',
  240. )
  241. beta_tensor = static.data(
  242. name='beta', shape=[-1, 1], dtype='float32'
  243. )
  244. warm_start_tensor = static.data(
  245. name='warm_start', shape=[-1, 1], dtype='bool'
  246. )
  247. train_fetches_loss = adaround.get_loss(
  248. beta_tensor,
  249. warm_start_tensor,
  250. adaround_out_tensor,
  251. orig_out_tensor,
  252. )
  253. optimizer = paddle.optimizer.Adam(learning_rate=lr)
  254. loss = train_fetches_loss['loss']
  255. optimizer.minimize(loss)
  256. exe.run(startup_program)
  257. start_time = time.time()
  258. prev_start_time = start_time
  259. for i, data in enumerate(data_loader()):
  260. prev_start_time = start_time
  261. start_time = time.time()
  262. # run fp32 model
  263. np_orig_out_tensor = exe.run(
  264. program=fp32_program,
  265. feed=data,
  266. fetch_list=[fp32_fetch_list],
  267. return_numpy=True,
  268. scope=scope,
  269. )
  270. adaround_weight_tensor_dequant = (
  271. adaround._calculate_adarounded_weights()
  272. )
  273. np_adaround_out_tensor = (
  274. adaround._calculate_output_with_adarounded_weights(
  275. fp32_program,
  276. place,
  277. exe,
  278. data,
  279. fp32_fetch_list,
  280. adaround_weight_tensor_dequant,
  281. )
  282. )
  283. # If the cosine distance of the two tensor is small, skip training
  284. cos_error = calculate_quant_cos_error(
  285. np_orig_out_tensor[0], np_adaround_out_tensor[0]
  286. )
  287. if fast_mode and cos_error > 0.99:
  288. _logger.info("The cosine error is small, skip training.")
  289. break
  290. beta, warm_start = adaround.update_beta_warm(i)
  291. feed_dict = {
  292. 'orig_out_tensor': np_orig_out_tensor[0],
  293. 'adaround_out_tensor': np_adaround_out_tensor[0],
  294. 'beta': beta,
  295. 'warm_start': warm_start,
  296. }
  297. out = exe.run(
  298. train_program,
  299. feed=feed_dict,
  300. fetch_list=[v.name for v in train_fetches_loss.values()],
  301. return_numpy=True,
  302. )
  303. _logger.info(
  304. f"Iter {i:d}, lr {lr:.5f}, loss {np.mean(out[0]):.5f}, loss_round {np.mean(out[1]):.5f}, loss_recon {np.mean(out[2]):.5f}, time {start_time - prev_start_time:.5f}s"
  305. )
  306. sys.stdout.flush()
  307. if i == num_iterations:
  308. break
  309. final_weight_tensor_quant_dict[
  310. weight_var_name
  311. ] = adaround.update_final_weights()
  312. if bias_correction:
  313. final_weight_tensor_quant_dict[weight_var_name] = bias_correction_w(
  314. weight_var_tensor,
  315. final_weight_tensor_quant_dict[weight_var_name],
  316. scale,
  317. adaround.quant_axis,
  318. weight_bits=adaround.weight_bits,
  319. )
  320. del adaround
  321. # update adarounded calibrated weights
  322. for weight_var_name in quantized_op_pairs.keys():
  323. set_variable_data(
  324. scope,
  325. place,
  326. weight_var_name,
  327. final_weight_tensor_quant_dict[weight_var_name],
  328. )