spqr.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. "SpQR (Sparse-Quantized Representation) integration file"
  15. from ..utils import is_accelerate_available, is_spqr_available, is_torch_available
  16. if is_torch_available():
  17. import torch.nn as nn
  18. def replace_with_spqr_linear(
  19. model,
  20. quantization_config=None,
  21. modules_to_not_convert=None,
  22. current_key_name=None,
  23. has_been_replaced=False,
  24. ):
  25. """
  26. Public method that recursively replaces the Linear layers of the given model with SpQR quantized layers.
  27. `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
  28. conversion has been successful or not.
  29. Args:
  30. model (`torch.nn.Module`):
  31. The model to convert, can be any `torch.nn.Module` instance.
  32. quantization_config (`SpQRConfig`):
  33. The quantization config object that contains the quantization parameters.
  34. modules_to_not_convert (`list[str]`, *optional*):
  35. A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be
  36. converted.
  37. current_key_name (`list`, *optional*):
  38. A list that contains the current key name. This is used for recursion and should not be passed by the user.
  39. has_been_replaced (`bool`, *optional*):
  40. A boolean that indicates if the conversion has been successful or not. This is used for recursion and
  41. should not be passed by the user.
  42. """
  43. if modules_to_not_convert is None:
  44. modules_to_not_convert = []
  45. if is_accelerate_available():
  46. from accelerate import init_empty_weights
  47. if is_spqr_available():
  48. from spqr_quant import QuantizedLinear
  49. for name, module in model.named_children():
  50. if current_key_name is None:
  51. current_key_name = []
  52. current_key_name.append(name)
  53. if isinstance(module, nn.Linear):
  54. # Check if the current key is not in the `modules_to_not_convert`
  55. if ".".join(current_key_name) + ".weight" not in modules_to_not_convert:
  56. with init_empty_weights():
  57. tensor_name = ".".join(current_key_name)
  58. shapes = quantization_config.shapes
  59. shapes_keys = shapes.keys()
  60. shapes_valid = (
  61. f"{tensor_name}.dense_weights.shape" in shapes_keys
  62. and f"{tensor_name}.row_offsets.shape" in shapes_keys
  63. and f"{tensor_name}.col_vals.shape" in shapes_keys
  64. and f"{tensor_name}.in_perm.shape" in shapes_keys
  65. )
  66. if not shapes_valid:
  67. raise ValueError(
  68. f"The SpQR quantization config does not contain the shape "
  69. f"configuration for {tensor_name}. This indicates that the "
  70. f"configuration is either invalid or corrupted."
  71. )
  72. dense_weights_shape = shapes[f"{tensor_name}.dense_weights.shape"]
  73. row_offsets_shape = shapes[f"{tensor_name}.row_offsets.shape"]
  74. col_vals_shape = shapes[f"{tensor_name}.col_vals.shape"]
  75. in_perm_shape = shapes[f"{tensor_name}.in_perm.shape"]
  76. in_features = module.in_features
  77. out_features = module.out_features
  78. model._modules[name] = QuantizedLinear.create_placehodler(
  79. rows=out_features,
  80. cols=in_features,
  81. bits=quantization_config.bits,
  82. beta1=quantization_config.beta1,
  83. beta2=quantization_config.beta2,
  84. dense_weights_shape=dense_weights_shape,
  85. row_offsets_shape=row_offsets_shape,
  86. col_vals_shape=col_vals_shape,
  87. in_perm_shape=in_perm_shape,
  88. )
  89. has_been_replaced = True
  90. # Store the module class in case we need to transpose the weight later
  91. model._modules[name].source_cls = type(module)
  92. # Force requires grad to False to avoid unexpected errors
  93. model._modules[name].requires_grad_(False)
  94. else:
  95. pass
  96. if len(list(module.children())) > 0:
  97. _, has_been_replaced = replace_with_spqr_linear(
  98. module,
  99. quantization_config=quantization_config,
  100. modules_to_not_convert=modules_to_not_convert,
  101. current_key_name=current_key_name,
  102. has_been_replaced=has_been_replaced,
  103. )
  104. # Remove the last key for recursion
  105. current_key_name.pop(-1)
  106. return model, has_been_replaced