quantization.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. import base64
  2. import bz2
  3. import ctypes
  4. from typing import List
  5. import torch
  6. from torch.nn.parameter import Parameter
  7. from modelscope.utils import logger as logging
# Module-level logger obtained from modelscope's logging utilities
# (imported above under the alias `logging`).
logger = logging.get_logger()
  9. try:
  10. from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
  11. class Kernel:
  12. def __init__(self, code: bytes, function_names: List[str]):
  13. self.code = code
  14. self._function_names = function_names
  15. self._cmodule = LazyKernelCModule(self.code)
  16. for name in self._function_names:
  17. setattr(self, name, KernelFunction(self._cmodule, name))
  18. quantization_code = '$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8w
wZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjy
r+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1Yr
FYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ' # noqa
  19. kernels = Kernel(
  20. bz2.decompress(base64.b64decode(quantization_code)),
  21. [
  22. 'int4WeightCompression',
  23. 'int4WeightExtractionFloat',
  24. 'int4WeightExtractionHalf',
  25. 'int8WeightExtractionFloat',
  26. 'int8WeightExtractionHalf',
  27. ],
  28. )
  29. except Exception as exception:
  30. kernels = None
  31. logger.warning('Failed to load cpm_kernels:' + str(exception))
  32. class W8A16Linear(torch.autograd.Function):
  33. @staticmethod
  34. def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor,
  35. scale_w: torch.Tensor, weight_bit_width):
  36. ctx.inp_shape = inp.size()
  37. ctx.weight_bit_width = weight_bit_width
  38. out_features = quant_w.size(0)
  39. inp = inp.contiguous().view(-1, inp.size(-1))
  40. weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
  41. ctx.weight_shape = weight.size()
  42. output = inp.mm(weight.t())
  43. ctx.save_for_backward(inp, quant_w, scale_w)
  44. return output.view(*(ctx.inp_shape[:-1] + (out_features, )))
  45. @staticmethod
  46. def backward(ctx, grad_output: torch.Tensor):
  47. inp, quant_w, scale_w = ctx.saved_tensors
  48. weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
  49. grad_output = grad_output.contiguous().view(-1, weight.size(0))
  50. grad_input = grad_output.mm(weight)
  51. grad_weight = grad_output.t().mm(inp)
  52. return grad_input.view(ctx.inp_shape), grad_weight.view(
  53. ctx.weight_shape), None, None
  54. def compress_int4_weight(weight: torch.Tensor): # (n, m)
  55. with torch.cuda.device(weight.device):
  56. n, m = weight.size(0), weight.size(1)
  57. assert m % 2 == 0
  58. m = m // 2
  59. out = torch.empty(n, m, dtype=torch.int8, device='cuda')
  60. stream = torch.cuda.current_stream()
  61. gridDim = (n, 1, 1)
  62. blockDim = (min(round_up(m, 32), 1024), 1, 1)
  63. kernels.int4WeightCompression(
  64. gridDim,
  65. blockDim,
  66. 0,
  67. stream,
  68. [
  69. ctypes.c_void_p(weight.data_ptr()),
  70. ctypes.c_void_p(out.data_ptr()),
  71. ctypes.c_int32(n),
  72. ctypes.c_int32(m)
  73. ],
  74. )
  75. return out
  76. def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor,
  77. source_bit_width: int):
  78. assert scale_list.dtype in [torch.half, torch.bfloat16]
  79. assert weight.dtype in [torch.int8]
  80. if source_bit_width == 8:
  81. return weight.to(scale_list.dtype) * scale_list[:, None]
  82. elif source_bit_width == 4:
  83. func = (
  84. kernels.int4WeightExtractionHalf if scale_list.dtype == torch.half
  85. else kernels.int4WeightExtractionBFloat16)
  86. else:
  87. assert False, 'Unsupported bit-width'
  88. with torch.cuda.device(weight.device):
  89. n, m = weight.size(0), weight.size(1)
  90. out = torch.empty(
  91. n,
  92. m * (8 // source_bit_width),
  93. dtype=scale_list.dtype,
  94. device='cuda')
  95. stream = torch.cuda.current_stream()
  96. gridDim = (n, 1, 1)
  97. blockDim = (min(round_up(m, 32), 1024), 1, 1)
  98. func(
  99. gridDim,
  100. blockDim,
  101. 0,
  102. stream,
  103. [
  104. ctypes.c_void_p(weight.data_ptr()),
  105. ctypes.c_void_p(scale_list.data_ptr()),
  106. ctypes.c_void_p(out.data_ptr()),
  107. ctypes.c_int32(n),
  108. ctypes.c_int32(m),
  109. ],
  110. )
  111. return out
  112. class QuantizedLinear(torch.nn.Module):
  113. def __init__(self,
  114. weight_bit_width: int,
  115. weight,
  116. bias=None,
  117. device='cpu',
  118. dtype=None,
  119. empty_init=False,
  120. *args,
  121. **kwargs):
  122. super().__init__()
  123. self.weight_bit_width = weight_bit_width
  124. shape = weight.shape
  125. if weight is None or empty_init:
  126. self.weight = torch.empty(
  127. shape[0],
  128. shape[1] * weight_bit_width // 8,
  129. dtype=torch.int8,
  130. device=device)
  131. self.weight_scale = torch.empty(
  132. shape[0], dtype=dtype, device=device)
  133. else:
  134. self.weight_scale = weight.abs().max(dim=-1).values / (
  135. (2**(weight_bit_width - 1)) - 1)
  136. self.weight = torch.round(weight / self.weight_scale[:, None]).to(
  137. torch.int8)
  138. if weight_bit_width == 4:
  139. self.weight = compress_int4_weight(self.weight)
  140. self.weight = Parameter(self.weight.to(device), requires_grad=False)
  141. self.weight_scale = Parameter(
  142. self.weight_scale.to(device), requires_grad=False)
  143. self.bias = Parameter(
  144. bias.to(device), requires_grad=False) if bias is not None else None
  145. def forward(self, input):
  146. output = W8A16Linear.apply(input, self.weight, self.weight_scale,
  147. self.weight_bit_width)
  148. if self.bias is not None:
  149. output = output + self.bias
  150. return output
  151. def quantize(model, weight_bit_width, empty_init=False, device=None):
  152. """Replace fp16 linear with quantized linear"""
  153. for layer in model.layers:
  154. layer.self_attention.query_key_value = QuantizedLinear(
  155. weight_bit_width=weight_bit_width,
  156. weight=layer.self_attention.query_key_value.weight.to(
  157. torch.cuda.current_device()),
  158. bias=layer.self_attention.query_key_value.bias,
  159. dtype=layer.self_attention.query_key_value.weight.dtype,
  160. device=layer.self_attention.query_key_value.weight.device
  161. if device is None else device,
  162. empty_init=empty_init)
  163. layer.self_attention.dense = QuantizedLinear(
  164. weight_bit_width=weight_bit_width,
  165. weight=layer.self_attention.dense.weight.to(
  166. torch.cuda.current_device()),
  167. bias=layer.self_attention.dense.bias,
  168. dtype=layer.self_attention.dense.weight.dtype,
  169. device=layer.self_attention.dense.weight.device
  170. if device is None else device,
  171. empty_init=empty_init)
  172. layer.mlp.dense_h_to_4h = QuantizedLinear(
  173. weight_bit_width=weight_bit_width,
  174. weight=layer.mlp.dense_h_to_4h.weight.to(
  175. torch.cuda.current_device()),
  176. bias=layer.mlp.dense_h_to_4h.bias,
  177. dtype=layer.mlp.dense_h_to_4h.weight.dtype,
  178. device=layer.mlp.dense_h_to_4h.weight.device
  179. if device is None else device,
  180. empty_init=empty_init)
  181. layer.mlp.dense_4h_to_h = QuantizedLinear(
  182. weight_bit_width=weight_bit_width,
  183. weight=layer.mlp.dense_4h_to_h.weight.to(
  184. torch.cuda.current_device()),
  185. bias=layer.mlp.dense_4h_to_h.bias,
  186. dtype=layer.mlp.dense_4h_to_h.weight.dtype,
  187. device=layer.mlp.dense_4h_to_h.weight.device
  188. if device is None else device,
  189. empty_init=empty_init)
  190. return model