| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375 |
- # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import logging
- import sys
- import time
- import numpy as np
- import paddle
- from paddle import static
- from ..log_helper import get_logger
- from .utils import (
- _channelwise_quant_axis1_ops,
- bias_correction_w,
- calculate_quant_cos_error,
- dequant_tensor,
- load_variable_data,
- quant_tensor,
- set_variable_data,
- stable_sigmoid,
- )
# Module-level logger for this file, requested at INFO level with a
# "<time>-<level>: <message>" format string.
_logger = get_logger(
    __name__,
    logging.INFO,
    fmt='%(asctime)s-%(levelname)s: %(message)s',
)
# Stretch limits of the rectified sigmoid used for soft rounding: the sigmoid
# output is rescaled from (0, 1) to (GAMMA, ZETA) before being clipped to
# [0, 1] by compute_soft_rounding / compute_soft_rounding_np.
GAMMA, ZETA = -0.1, 1.1
def compute_soft_rounding(alpha_v):
    """Return the rectified sigmoid of ``alpha_v`` as a Paddle tensor.

    The sigmoid of ``alpha_v`` is stretched from (0, 1) to (GAMMA, ZETA) and
    the result is clipped to [0, 1].

    Args:
        alpha_v: Paddle tensor/parameter holding the rounding logits.

    Returns:
        Tensor with every element in [0, 1].
    """
    stretched = paddle.nn.functional.sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA
    return paddle.clip(stretched, min=0, max=1)
def compute_soft_rounding_np(alpha_v):
    """NumPy counterpart of ``compute_soft_rounding``.

    Applies ``stable_sigmoid`` to ``alpha_v``, stretches the result from
    (0, 1) to (GAMMA, ZETA) and clips it to [0, 1].

    Args:
        alpha_v: array-like of rounding logits accepted by ``stable_sigmoid``.

    Returns:
        Array with every element in [0, 1].
    """
    stretched = stable_sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA
    return np.clip(stretched, 0, 1)
class AdaRoundLoss:
    """Loss terms used while learning the AdaRound rounding parameter."""

    def __init__(self, reg_param=0.01, default_beta_range=(20, 2)):
        """Store the loss hyper-parameters.

        Args:
            reg_param: weight applied to the rounding regularization term.
            default_beta_range: ``(start_beta, end_beta)`` pair used by
                ``compute_beta`` to anneal the regularization exponent.
        """
        self.default_beta_range = default_beta_range
        self.default_reg_param = reg_param

    def compute_recon_loss(self, ada_quantized_output, orig_output):
        """Return the reconstruction loss between two output tensors.

        The element-wise squared error is summed over the last axis and the
        sums are then averaged.
        """
        squared_error = paddle.nn.functional.square_error_cost(
            ada_quantized_output, orig_output
        )
        summed_over_last_axis = paddle.sum(squared_error, axis=-1)
        return paddle.mean(summed_over_last_axis)

    def compute_round_loss(self, alpha_v, warm_start, beta):
        """Return the rounding regularization loss.

        Args:
            alpha_v: rounding parameter tensor.
            warm_start: boolean tensor; when true the loss is a constant zero.
            beta: exponent applied to ``|2 * h - 1|`` in the regularizer.

        Returns:
            The tensor produced by ``static.nn.cond`` from one of the two
            branches below.
        """

        def _zero_loss():
            # Warm-start branch: no rounding penalty at all.
            return paddle.full(shape=[1], dtype='float32', fill_value=0.0)

        def _regularization_loss():
            # Rectified sigmoid of alpha, every element in [0, 1].
            soft_rounding = compute_soft_rounding(alpha_v)
            # Each element contributes 1 - |2h - 1| ** beta, which is zero
            # exactly when h is 0 or 1, so minimizing it pushes h to a
            # hard rounding decision.
            penalty = paddle.sum(
                -paddle.pow(paddle.abs(2 * soft_rounding - 1), beta) + 1
            )
            return self.default_reg_param * penalty

        return static.nn.cond(warm_start, _zero_loss, _regularization_loss)

    def compute_beta(self, max_iter, cur_iter, warm_start):
        """Return the annealed regularization exponent for ``cur_iter``.

        Args:
            max_iter: total number of iterations of the schedule.
            cur_iter: current iteration index.
            warm_start: fraction of ``max_iter`` spent in the warm-start
                phase.

        Returns:
            A value following half a cosine period: ``start_beta`` when
            ``cur_iter`` is at the end of the warm-start phase and
            ``end_beta`` when ``cur_iter == max_iter``.
        """
        beta_start, beta_end = self.default_beta_range
        # Iteration index at which the warm-start phase ends.
        warm_up_iters = warm_start * max_iter
        # Progress through the post-warm-start part of the schedule.
        progress = (cur_iter - warm_up_iters) / (max_iter - warm_up_iters)
        return beta_end + 0.5 * (beta_start - beta_end) * (
            1 + np.cos(progress * np.pi)
        )
class AdaRound:
    """Learnable rounding state for a single weight tensor.

    Owns the ``alpha`` parameter (one logit per weight element) and derives
    soft-rounded quantized / dequantized weights from it.
    """

    def __init__(
        self,
        scale,
        weight_tensor,
        scope=None,
        weight_var_name=None,
        weight_op_type=None,
        is_train=True,
        num_iterations=1000,
    ):
        """Set up the rounding state and create the ``alpha`` parameter.

        Args:
            scale: quantization scale forwarded to ``quant_tensor`` and
                ``dequant_tensor``.
            weight_tensor: original weight values; must support ``.copy()``.
            scope: scope used to read ``alpha`` and to write weights.
            weight_var_name: name of the weight variable; ``alpha`` is
                registered as ``weight_var_name + ".alpha"``, so a string is
                required despite the ``None`` default.
            weight_op_type: type of the op owning the weight; selects the
                quantization axis.
            is_train: stored on the instance; not read elsewhere in this
                class.
            num_iterations: length of the beta / warm-start schedule.
        """
        self.is_train = is_train
        self.num_iterations = num_iterations
        self.warm_start = 0.1
        self.weight_bits = 8
        self.offset = 0.0  # zero-point offset
        self.adaround_loss = AdaRoundLoss()

        self.ori_weight_tensor = weight_tensor
        self.scale = scale
        self.scope = scope
        # Ops listed in _channelwise_quant_axis1_ops quantize along axis 1,
        # every other op along axis 0.
        self.quant_axis = (
            1 if weight_op_type in _channelwise_quant_axis1_ops else 0
        )
        self.weight_var_name = weight_var_name
        self.alpha_name = weight_var_name + ".alpha"
        self.initialize_alpha(weight_tensor.copy(), scale, weight_var_name)

    def initialize_alpha(self, tensor, scale, var_name):
        """
        Initializes alpha parameter, same shape as the weight tensor

        ``alpha`` is chosen so that the rectified sigmoid of ``alpha`` equals
        the fractional part of the scaled weight.
        """
        scaled = quant_tensor(tensor, scale, quant_axis=self.quant_axis)
        fractional = scaled - np.floor(scaled)
        # Inverse of the stretched sigmoid evaluated at the fractional part.
        alpha = -np.log((ZETA - GAMMA) / (fractional - GAMMA) - 1)
        self.alpha_v = paddle.create_parameter(
            shape=alpha.shape,
            dtype="float32",
            name=var_name + ".alpha",
            default_initializer=paddle.nn.initializer.Assign(alpha),
        )

    def _calculate_output_with_adarounded_weights(
        self, program, place, exe, data, fp32_fetch_list, weight_tensor_dequant
    ):
        """Write the given weights into the scope and run ``program`` once.

        Returns:
            The list returned by ``exe.run`` for ``fp32_fetch_list``.
        """
        set_variable_data(
            self.scope, place, self.weight_var_name, weight_tensor_dequant
        )
        return exe.run(
            program=program,
            feed=data,
            fetch_list=[fp32_fetch_list],
            return_numpy=True,
            scope=self.scope,
        )

    def _calculate_quant_weight(self):
        """Return floor(scaled weight) plus the current soft rounding."""
        current_alpha = load_variable_data(self.scope, self.alpha_name)
        soft_rounding = compute_soft_rounding_np(current_alpha)

        # Scale a copy so the stored original weights stay untouched.
        scaled = quant_tensor(
            self.ori_weight_tensor.copy(),
            self.scale,
            quant_axis=self.quant_axis,
        )
        return np.add(np.floor(scaled), soft_rounding)

    def _calculate_adarounded_weights(self):
        """Return the soft-rounded weights mapped back through dequantization."""
        quantized = self._calculate_quant_weight()
        return dequant_tensor(
            quantized + self.offset,
            self.scale,
            quant_axis=self.quant_axis,
        )

    def update_final_weights(self):
        """Return the quantized weights from the current ``alpha`` values."""
        return self._calculate_quant_weight()

    def get_loss(self, beta, warm_start, adaround_out_tensor, orig_out_tensor):
        """Build the loss tensors.

        Returns:
            dict with keys ``'loss'``, ``'round_loss'`` and ``'recon_loss'``
            (in that order); ``'loss'`` is the sum of the other two.
        """
        rounding = self.adaround_loss.compute_round_loss(
            self.alpha_v, warm_start, beta
        )
        reconstruction = self.adaround_loss.compute_recon_loss(
            adaround_out_tensor, orig_out_tensor
        )
        return {
            'loss': rounding + reconstruction,
            'round_loss': rounding,
            'recon_loss': reconstruction,
        }

    def update_beta_warm(self, cur_iteration):
        """Return ``(beta, warm_start)`` for the given iteration.

        ``warm_start`` is true while ``cur_iteration`` is below
        ``num_iterations * self.warm_start``.
        """
        in_warm_start = cur_iteration < self.num_iterations * self.warm_start
        current_beta = self.adaround_loss.compute_beta(
            self.num_iterations, cur_iteration, self.warm_start
        )
        return current_beta, in_warm_start
def run_adaround(
    data_loader,
    fp32_program,
    fetch_list,
    exe,
    scope,
    place,
    quantized_op_pairs,
    weight_op_pairs,
    scale_dict,
    num_iterations=1000,
    lr=0.001,
    bias_correction=False,
    fast_mode=True,
):
    """Learn AdaRound rounding for every listed weight and store the results.

    For each weight the fetch op of ``fp32_program`` is retargeted to the
    output of the op that consumes the weight, a small training program
    holding the ``alpha`` parameter is built, and ``alpha`` is optimized with
    Adam on batches from ``data_loader``. After all weights are processed the
    resulting quantized weights are written into ``scope``.

    Side effects: ``fp32_program``'s fetch op is left pointing at the last
    processed output (it is not restored), and the weight variables in
    ``scope`` are overwritten.

    Args:
        data_loader: callable returning an iterable of feed dicts.
        fp32_program: program to run; its ``fetch`` op input is renamed.
        fetch_list: original fetch targets; only ``fetch_list[0].name`` is
            read, as the fetch input name to rename first.
        exe: executor used for all program runs.
        scope: scope holding the weight variables.
        place: device place forwarded to ``set_variable_data``.
        quantized_op_pairs: mapping of weight variable name to the output
            variable name of the op using that weight.
        weight_op_pairs: mapping of weight variable name to op type.
        scale_dict: mapping of weight variable name to quantization scale.
        num_iterations: schedule length and loop bound for the optimization.
        lr: Adam learning rate.
        bias_correction: when true, pass the final weights through
            ``bias_correction_w`` before storing them.
        fast_mode: when true, stop optimizing a weight as soon as
            ``calculate_quant_cos_error`` exceeds 0.99.

    Raises:
        ValueError: if ``fp32_program`` contains no ``fetch`` op.
    """
    fetch_op_name = fetch_list[0].name
    final_weight_tensor_quant_dict = {}
    for weight_var_name, quant_op_out_name in quantized_op_pairs.items():
        _logger.info(f'Start adaround op: {weight_var_name}')
        weight_op_type = weight_op_pairs[weight_var_name]
        # get scale and weight tensor
        weight_var_tensor = load_variable_data(scope, weight_var_name)
        scale = scale_dict[weight_var_name]

        # Retarget the fetch op from the previously fetched variable to the
        # output of the op that consumes this weight.
        fp32_fetch_list = None
        for _op in fp32_program.global_block().ops:
            if _op.type == "fetch":
                _op._rename_input(fetch_op_name, quant_op_out_name)
                fp32_fetch_list = fp32_program.global_block().var(
                    quant_op_out_name
                )
                fetch_op_name = quant_op_out_name
        if fp32_fetch_list is None:
            # Fail early with a clear message instead of an AttributeError
            # on ``None.shape`` further down.
            raise ValueError(
                "fp32_program has no 'fetch' op, so the output "
                f"'{quant_op_out_name}' cannot be fetched for AdaRound."
            )
        # Placeholder shape: a leading -1 followed by the fetched variable's
        # shape. Unpacking accepts both tuple- and list-valued shapes.
        placeholder_shape = (-1, *fp32_fetch_list.shape)

        # build adaround program
        startup_program = static.Program()
        train_program = static.Program()
        with static.program_guard(train_program, startup_program):
            with paddle.utils.unique_name.guard():
                # initialize adaround
                adaround = AdaRound(
                    scale,
                    weight_var_tensor,
                    scope=scope,
                    weight_var_name=weight_var_name,
                    weight_op_type=weight_op_type,
                    num_iterations=num_iterations,
                )
                orig_out_tensor = static.data(
                    name='orig_out_tensor',
                    shape=placeholder_shape,
                    dtype='float32',
                )
                adaround_out_tensor = static.data(
                    name='adaround_out_tensor',
                    shape=placeholder_shape,
                    dtype='float32',
                )
                # NOTE(review): beta and warm_start are fed below as scalars
                # while these placeholders are declared [-1, 1]; confirm
                # the executor accepts that.
                beta_tensor = static.data(
                    name='beta', shape=[-1, 1], dtype='float32'
                )
                warm_start_tensor = static.data(
                    name='warm_start', shape=[-1, 1], dtype='bool'
                )

                train_fetches_loss = adaround.get_loss(
                    beta_tensor,
                    warm_start_tensor,
                    adaround_out_tensor,
                    orig_out_tensor,
                )
                optimizer = paddle.optimizer.Adam(learning_rate=lr)
                loss = train_fetches_loss['loss']
                optimizer.minimize(loss)
        exe.run(startup_program)

        start_time = time.time()
        prev_start_time = start_time
        for i, data in enumerate(data_loader()):
            # The logged "time" is the gap between the starts of two
            # consecutive iterations.
            prev_start_time = start_time
            start_time = time.time()
            # run fp32 model
            np_orig_out_tensor = exe.run(
                program=fp32_program,
                feed=data,
                fetch_list=[fp32_fetch_list],
                return_numpy=True,
                scope=scope,
            )
            # Run the same program again with the current soft-rounded
            # weights written into the scope.
            adaround_weight_tensor_dequant = (
                adaround._calculate_adarounded_weights()
            )
            np_adaround_out_tensor = (
                adaround._calculate_output_with_adarounded_weights(
                    fp32_program,
                    place,
                    exe,
                    data,
                    fp32_fetch_list,
                    adaround_weight_tensor_dequant,
                )
            )

            # If the two outputs already agree closely (metric above 0.99),
            # skip training for this weight.
            cos_error = calculate_quant_cos_error(
                np_orig_out_tensor[0], np_adaround_out_tensor[0]
            )
            if fast_mode and cos_error > 0.99:
                _logger.info("The cosine error is small, skip training.")
                break

            beta, warm_start = adaround.update_beta_warm(i)
            feed_dict = {
                'orig_out_tensor': np_orig_out_tensor[0],
                'adaround_out_tensor': np_adaround_out_tensor[0],
                'beta': beta,
                'warm_start': warm_start,
            }
            # Fetch order follows the dict returned by get_loss:
            # loss, round_loss, recon_loss.
            out = exe.run(
                train_program,
                feed=feed_dict,
                fetch_list=[v.name for v in train_fetches_loss.values()],
                return_numpy=True,
            )
            _logger.info(
                f"Iter {i:d}, lr {lr:.5f}, loss {np.mean(out[0]):.5f}, "
                f"loss_round {np.mean(out[1]):.5f}, "
                f"loss_recon {np.mean(out[2]):.5f}, "
                f"time {start_time - prev_start_time:.5f}s"
            )
            sys.stdout.flush()
            # NOTE(review): the check runs after the step at
            # i == num_iterations, so up to num_iterations + 1 steps execute.
            if i == num_iterations:
                break

        final_weight_tensor_quant_dict[
            weight_var_name
        ] = adaround.update_final_weights()

        if bias_correction:
            final_weight_tensor_quant_dict[weight_var_name] = bias_correction_w(
                weight_var_tensor,
                final_weight_tensor_quant_dict[weight_var_name],
                scale,
                adaround.quant_axis,
                weight_bits=adaround.weight_bits,
            )

        del adaround

    # update adarounded calibrated weights
    for weight_var_name, final_weight in final_weight_tensor_quant_dict.items():
        set_variable_data(scope, place, weight_var_name, final_weight)
|