higgs.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. "HIGGS through FLUTE (Flexible Lookup Table Engine for LUT-quantized LLMs) integration file"
  15. from math import sqrt
  16. from typing import Optional
  17. from ..utils import (
  18. is_flute_available,
  19. is_hadamard_available,
  20. is_torch_available,
  21. )
  22. if is_torch_available():
  23. import torch
  24. from torch import nn
  25. if is_flute_available():
  26. from flute.integrations.higgs import prepare_data_transposed
  27. from flute.tune import TuneMetaData, qgemm_v2
  28. if is_hadamard_available():
  29. from fast_hadamard_transform import hadamard_transform
  30. def pad_to_block(tensor, dims, had_block_size, value=0):
  31. pad_dims = [0 for _ in range(2 * len(tensor.shape))]
  32. for dim in dims:
  33. size = tensor.shape[dim]
  34. next_multiple_of_1024 = ((size - 1) // had_block_size + 1) * had_block_size
  35. delta = next_multiple_of_1024 - size
  36. pad_dims[-2 * dim - 1] = delta
  37. return nn.functional.pad(tensor, pad_dims, "constant", value)
  38. def get_higgs_grid(p: int, n: int) -> "torch.Tensor":
  39. if (p, n) == (2, 256):
  40. return torch.tensor(
  41. [
  42. [-2.501467704772949, 0.17954708635807037],
  43. [-0.6761789321899414, 1.2728623151779175],
  44. [-1.8025816679000854, 0.7613157629966736],
  45. [-0.538287878036499, -2.6028504371643066],
  46. [0.8415029644966125, -0.8600977659225464],
  47. [0.7023013234138489, 3.3138747215270996],
  48. [0.5699077844619751, 2.5782253742218018],
  49. [3.292393207550049, -0.6016128063201904],
  50. [0.5561617016792297, -1.7723814249038696],
  51. [-2.1012380123138428, 0.020958125591278076],
  52. [0.46085724234580994, 0.8428705334663391],
  53. [1.4548040628433228, -0.6156039237976074],
  54. [3.210029363632202, 0.3546904921531677],
  55. [0.8893890976905823, -0.5967988967895508],
  56. [0.8618854284286499, -3.2061192989349365],
  57. [1.1360996961593628, -0.23852407932281494],
  58. [1.6646337509155273, -0.9265465140342712],
  59. [1.4767773151397705, 1.2476022243499756],
  60. [-1.0511897802352905, 1.94503915309906],
  61. [-1.56318998336792, -0.3264186680316925],
  62. [-0.1829211413860321, 0.2922491431236267],
  63. [-0.8950616717338562, -1.3887052536010742],
  64. [-0.08206957578659058, -1.329533576965332],
  65. [-0.487422913312912, 1.4817842245101929],
  66. [-1.6769757270812988, -2.8269758224487305],
  67. [-1.5057679414749146, 1.8905963897705078],
  68. [1.8335362672805786, 1.0515104532241821],
  69. [0.3273945450782776, 1.0491033792495728],
  70. [-3.295924186706543, -0.7021600008010864],
  71. [-1.8428784608840942, -1.2315762042999268],
  72. [-0.8575026392936707, -1.7005949020385742],
  73. [-1.120667815208435, 0.6467998027801514],
  74. [-0.1588846743106842, -1.804071068763733],
  75. [-0.8539647459983826, 0.5645008683204651],
  76. [-1.4192019701004028, -0.6175029873847961],
  77. [1.0799058675765991, 1.7871345281600952],
  78. [1.171311855316162, 0.7511613965034485],
  79. [2.162078380584717, 0.8044339418411255],
  80. [1.3969420194625854, -1.243762493133545],
  81. [-0.23818807303905487, 0.053944624960422516],
  82. [2.304199457168579, -1.2667627334594727],
  83. [1.4225027561187744, 0.568610668182373],
  84. [0.376836895942688, -0.7134661674499512],
  85. [2.0404467582702637, 0.4087389409542084],
  86. [0.7639489769935608, -1.1367933750152588],
  87. [0.3622530400753021, -1.4827953577041626],
  88. [0.4100743532180786, 0.36108437180519104],
  89. [-1.5867475271224976, -1.618212342262268],
  90. [-2.2769672870635986, -1.2132309675216675],
  91. [0.9184022545814514, -0.34428009390830994],
  92. [-0.3902314603328705, 0.21785245835781097],
  93. [3.120687484741211, 1.3077973127365112],
  94. [1.587440848350525, -1.6506884098052979],
  95. [-1.718808889389038, -0.038405973464250565],
  96. [-0.6888407468795776, -0.8402308821678162],
  97. [-0.7981445789337158, -1.1117373704910278],
  98. [-2.4124443531036377, 1.3419722318649292],
  99. [-0.6611530184745789, 0.9939885139465332],
  100. [-0.33103418350219727, -0.16702833771705627],
  101. [-2.4091389179229736, -2.326857566833496],
  102. [1.6610108613967896, -2.159703254699707],
  103. [0.014884627424180508, 0.3887578248977661],
  104. [0.029668325558304787, 1.8786455392837524],
  105. [1.180362582206726, 2.699317216873169],
  106. [1.821286678314209, -0.5960053205490112],
  107. [-0.44835323095321655, 3.327436685562134],
  108. [-0.3714401423931122, -2.1466753482818604],
  109. [-1.1103475093841553, -2.4536871910095215],
  110. [-0.39110705256462097, 0.6670510172843933],
  111. [0.474752813577652, -1.1959707736968994],
  112. [-0.013110585510730743, -2.52519154548645],
  113. [-2.0836575031280518, -1.703289270401001],
  114. [-1.1077687740325928, -0.1252644956111908],
  115. [-0.4138077199459076, 1.1837692260742188],
  116. [-1.977599024772644, 1.688241720199585],
  117. [-1.659559965133667, -2.1387736797332764],
  118. [0.03242531046271324, 0.6526556015014648],
  119. [0.9127950072288513, 0.6099498867988586],
  120. [-0.38478314876556396, 0.433487206697464],
  121. [0.27454206347465515, -0.27719801664352417],
  122. [0.10388526320457458, 2.2812814712524414],
  123. [-0.014394169673323631, -3.177137613296509],
  124. [-1.2871228456497192, -0.8961855173110962],
  125. [0.5720916986465454, -0.921597957611084],
  126. [1.1159656047821045, -0.7609877586364746],
  127. [2.4383342266082764, -2.2983546257019043],
  128. [-0.294057160615921, -0.9770799875259399],
  129. [-0.9342701435089111, 1.107579231262207],
  130. [-1.549338698387146, 3.090520143508911],
  131. [2.6076579093933105, 2.051239013671875],
  132. [-0.9259037375450134, 1.407211184501648],
  133. [-0.1747353971004486, 0.540488600730896],
  134. [-0.8963701725006104, 0.8271111249923706],
  135. [0.6480194926261902, 1.0128909349441528],
  136. [0.980783998966217, -0.06156221032142639],
  137. [-0.16883476078510284, 1.0601658821105957],
  138. [0.5839992761611938, 0.004697148688137531],
  139. [-0.34228450059890747, -1.2423977851867676],
  140. [2.500824451446533, 0.3665279746055603],
  141. [-0.17641609907150269, 1.3529551029205322],
  142. [0.05378641560673714, 2.817232847213745],
  143. [-1.2391047477722168, 2.354328155517578],
  144. [0.630434513092041, -0.668536365032196],
  145. [1.7576488256454468, 0.6738647818565369],
  146. [0.4435231387615204, 0.6000469326972961],
  147. [-0.08794835954904556, -0.11511358618736267],
  148. [1.6540337800979614, 0.33995017409324646],
  149. [-0.04202975332736969, -0.5375117063522339],
  150. [-0.4247745871543884, -0.7897617220878601],
  151. [0.06695003807544708, 1.2000739574432373],
  152. [-3.2508881092071533, 0.28734830021858215],
  153. [-1.613816261291504, 0.4944162368774414],
  154. [1.3598989248275757, 0.26117825508117676],
  155. [2.308382511138916, 1.3462618589401245],
  156. [-1.2137469053268433, -1.9254342317581177],
  157. [-0.4889402985572815, 1.8136259317398071],
  158. [-0.1870335340499878, -0.3480615019798279],
  159. [1.0766386985778809, -1.0627082586288452],
  160. [0.4651014506816864, 2.131748914718628],
  161. [-0.1306295394897461, -0.7811847925186157],
  162. [0.06433182954788208, -1.5397958755493164],
  163. [-0.2894323468208313, -0.5789554715156555],
  164. [-0.6081662178039551, 0.4845278263092041],
  165. [2.697964668273926, -0.18515698611736298],
  166. [0.1277363896369934, -0.7221432328224182],
  167. [0.8700758218765259, 0.35042452812194824],
  168. [0.22088994085788727, 0.495242178440094],
  169. [-2.5843818187713623, -0.8000828623771667],
  170. [0.6732649803161621, -1.4362232685089111],
  171. [-1.5286413431167603, 1.0417330265045166],
  172. [-1.1222513914108276, -0.6269875764846802],
  173. [-0.9752035140991211, -0.8750635385513306],
  174. [-2.6369473934173584, 0.6918523907661438],
  175. [0.14478731155395508, -0.041986867785453796],
  176. [-1.5629483461380005, 1.4369450807571411],
  177. [0.38952457904815674, -2.16428804397583],
  178. [-0.16885095834732056, 0.7976621985435486],
  179. [-3.12416934967041, 1.256506085395813],
  180. [0.6843105554580688, -0.4203019142150879],
  181. [1.9345275163650513, 1.934950351715088],
  182. [0.012184220366179943, -2.1080918312072754],
  183. [-0.6350273489952087, 0.7358828186988831],
  184. [-0.837304949760437, -0.6214472651481628],
  185. [0.08211923390626907, -0.9472538232803345],
  186. [2.9332995414733887, -1.4956780672073364],
  187. [1.3806978464126587, -0.2916182279586792],
  188. [0.06773144006729126, 0.9285762310028076],
  189. [-1.1943119764328003, 1.5963770151138306],
  190. [1.6395620107650757, -0.32285431027412415],
  191. [-1.390851378440857, -0.08273141086101532],
  192. [1.816330909729004, -1.2812227010726929],
  193. [0.7921574711799622, -2.1135804653167725],
  194. [0.5817914605140686, 1.2644577026367188],
  195. [1.929347038269043, -0.2386285960674286],
  196. [0.8877345323562622, 1.190008521080017],
  197. [1.4732073545455933, 0.8935023546218872],
  198. [-2.8518524169921875, -1.5478795766830444],
  199. [0.2439267635345459, 0.7576767802238464],
  200. [0.5246709585189819, -2.606659412384033],
  201. [1.150876760482788, 1.4073830842971802],
  202. [-0.2643202245235443, 2.0634236335754395],
  203. [1.555483341217041, -0.0023102816194295883],
  204. [2.0830578804016113, -1.7225427627563477],
  205. [-0.5424830317497253, -1.070199728012085],
  206. [0.9168899655342102, 0.8955540060997009],
  207. [-0.8120972514152527, 2.696739912033081],
  208. [-0.29908373951911926, -1.5310651063919067],
  209. [1.2320337295532227, -1.556247353553772],
  210. [1.8612544536590576, 0.08704725652933121],
  211. [0.22133447229862213, -1.8091708421707153],
  212. [-0.4403655230998993, -0.38571012020111084],
  213. [-1.88539457321167, 1.192205786705017],
  214. [2.239687919616699, 0.004709010478109121],
  215. [1.139495611190796, 0.45733731985092163],
  216. [-1.507995367050171, 0.19716016948223114],
  217. [0.46986445784568787, 1.5422041416168213],
  218. [-1.2573751211166382, -0.35984551906585693],
  219. [-1.7415345907211304, -0.6020717024803162],
  220. [1.0751984119415283, 0.19006384909152985],
  221. [2.24186635017395, -0.46343153715133667],
  222. [0.3610347509384155, -0.07658443599939346],
  223. [-1.3111497163772583, 0.432013601064682],
  224. [0.6164408326148987, 0.24538464844226837],
  225. [-1.9266542196273804, -0.3256155550479889],
  226. [-0.5870336890220642, -0.1879584938287735],
  227. [-1.0476511716842651, 0.3677721917629242],
  228. [-1.229940414428711, 1.2433830499649048],
  229. [0.18550436198711395, 0.22753673791885376],
  230. [-0.017921989783644676, 0.12625974416732788],
  231. [1.1659504175186157, -0.5020995736122131],
  232. [-0.5983408093452454, -1.40438973903656],
  233. [0.7519024014472961, -0.16282692551612854],
  234. [0.9920787811279297, -1.344896912574768],
  235. [-0.8103678226470947, 0.3064485788345337],
  236. [0.6956969499588013, 1.8208192586898804],
  237. [-2.7830491065979004, -0.2299390584230423],
  238. [-0.34681546688079834, 2.4890666007995605],
  239. [-1.4452646970748901, -1.2216600179672241],
  240. [-2.1872897148132324, 0.8926076292991638],
  241. [1.706072211265564, -2.8440372943878174],
  242. [1.1119003295898438, -2.4923460483551025],
  243. [-2.582794666290283, 2.0973289012908936],
  244. [0.04987720400094986, -0.2964983284473419],
  245. [-2.063807487487793, -0.7847916483879089],
  246. [-0.4068813621997833, 0.9135897755622864],
  247. [-0.9814359545707703, -0.3874954879283905],
  248. [-1.4227229356765747, 0.7337291240692139],
  249. [0.3065044581890106, 1.3125417232513428],
  250. [1.2160996198654175, -1.9643305540084839],
  251. [-1.2163853645324707, 0.14608727395534515],
  252. [-2.3030710220336914, -0.37558120489120483],
  253. [0.9232977628707886, 2.1843791007995605],
  254. [-0.1989777386188507, 1.651851773262024],
  255. [-0.714374840259552, -0.39365994930267334],
  256. [-0.7805715799331665, -2.099881887435913],
  257. [0.9015759229660034, -1.7053706645965576],
  258. [0.1033422127366066, 1.5256654024124146],
  259. [-1.8773194551467896, 2.324174165725708],
  260. [1.9227174520492554, 2.7441604137420654],
  261. [-0.5994020104408264, 0.23984014987945557],
  262. [1.3496100902557373, -0.9126054644584656],
  263. [-0.8765304088592529, -3.1877026557922363],
  264. [-1.2040035724639893, -1.5169521570205688],
  265. [1.4261796474456787, 2.150200128555298],
  266. [1.463774561882019, 1.6656692028045654],
  267. [0.20364105701446533, -0.4988172650337219],
  268. [0.5195154547691345, -0.24067887663841248],
  269. [-1.1116786003112793, -1.1599653959274292],
  270. [-0.8490808606147766, -0.1681060940027237],
  271. [0.3189965784549713, -0.9641751646995544],
  272. [-0.5664751529693604, -0.5951744318008423],
  273. [-1.6347930431365967, -0.9137664437294006],
  274. [0.44048091769218445, -0.47259435057640076],
  275. [-2.147747039794922, 0.47442489862442017],
  276. [1.834734320640564, 1.4462147951126099],
  277. [1.1777573823928833, 1.0659226179122925],
  278. [-0.9568989872932434, 0.09495053440332413],
  279. [-1.838529348373413, 0.2950586676597595],
  280. [-0.4800611734390259, 0.014894310384988785],
  281. [-0.5235516428947449, -1.7687653303146362],
  282. [2.0735011100769043, -0.8825281262397766],
  283. [2.637502431869507, 0.8455678224563599],
  284. [2.606602907180786, -0.7848446369171143],
  285. [-1.1886937618255615, 0.9330510497093201],
  286. [0.38082656264305115, 0.13328030705451965],
  287. [0.6847941875457764, 0.7384101152420044],
  288. [1.2638574838638306, -0.007309418171644211],
  289. [0.18292222917079926, -1.22371244430542],
  290. [0.8143821954727173, 1.4976691007614136],
  291. [0.6571850776672363, 0.48368802666664124],
  292. [-0.6991601586341858, 2.150190830230713],
  293. [0.8101756572723389, 0.10206498205661774],
  294. [-0.08768226951360703, -1.084917664527893],
  295. [-0.7208092212677002, 0.03657956421375275],
  296. [0.3211449086666107, 1.803687334060669],
  297. [-0.7835946083068848, 1.6869111061096191],
  298. ]
  299. )
  300. if (p, n) == (2, 64):
  301. return torch.tensor(
  302. [
  303. [-2.7216711044311523, 0.14431366324424744],
  304. [-0.766914427280426, 1.7193410396575928],
  305. [-2.2575762271881104, 1.2476624250411987],
  306. [1.233758807182312, -2.3560616970062256],
  307. [0.8701965808868408, -0.2649352252483368],
  308. [1.4506438970565796, 2.1776366233825684],
  309. [-0.06305818259716034, 1.9049758911132812],
  310. [2.536226511001587, 0.563927412033081],
  311. [0.4599496126174927, -1.8745561838150024],
  312. [-1.900517225265503, -0.30703988671302795],
  313. [0.09386251866817474, 0.8755807280540466],
  314. [1.946500539779663, -0.6743080615997314],
  315. [2.1338934898376465, 1.4581491947174072],
  316. [0.9429940581321716, -0.8038390278816223],
  317. [2.0697755813598633, -1.614896535873413],
  318. [0.772676408290863, 0.22017823159694672],
  319. [1.0689979791641235, -1.525044322013855],
  320. [0.6813604831695557, 1.1345642805099487],
  321. [0.4706456661224365, 2.606626272201538],
  322. [-1.294018030166626, -0.4372096061706543],
  323. [-0.09134224057197571, 0.4610418677330017],
  324. [-0.7907772064208984, -0.48412787914276123],
  325. [0.060459110885858536, -0.9172890186309814],
  326. [-0.5855047702789307, 2.56172513961792],
  327. [0.11484206467866898, -2.659848213195801],
  328. [-1.5893300771713257, 2.188580274581909],
  329. [1.6750942468643188, 0.7089915871620178],
  330. [-0.445697546005249, 0.7452405095100403],
  331. [-1.8539940118789673, -1.8377939462661743],
  332. [-1.5791912078857422, -1.017285943031311],
  333. [-1.030419945716858, -1.5746369361877441],
  334. [-1.9511750936508179, 0.43696075677871704],
  335. [-0.3446580767631531, -1.8953213691711426],
  336. [-1.4219647645950317, 0.7676230669021606],
  337. [-0.9191089272499084, 0.5021472573280334],
  338. [0.20464491844177246, 1.3684605360031128],
  339. [0.5402919054031372, 0.6699410676956177],
  340. [1.8903915882110596, 0.03638288006186485],
  341. [0.4723062515258789, -0.6216739416122437],
  342. [-0.41345009207725525, -0.22752176225185394],
  343. [2.7119064331054688, -0.5111885070800781],
  344. [1.065286636352539, 0.6950305700302124],
  345. [0.40629103779792786, -0.14339995384216309],
  346. [1.2815024852752686, 0.17108257114887238],
  347. [0.01785222627222538, -0.43778058886528015],
  348. [0.054590027779340744, -1.4225547313690186],
  349. [0.3076786696910858, 0.30697619915008545],
  350. [-0.9498570561408997, -0.9576997756958008],
  351. [-2.4640724658966064, -0.9660449028015137],
  352. [1.3714425563812256, -0.39760473370552063],
  353. [-0.4857747256755829, 0.2386789172887802],
  354. [1.2797833681106567, 1.3097363710403442],
  355. [0.5508887767791748, -1.1777795553207397],
  356. [-1.384316325187683, 0.1465839296579361],
  357. [-0.46556955575942993, -1.2442727088928223],
  358. [-0.3915477693080902, -0.7319604158401489],
  359. [-1.4005504846572876, 1.3890998363494873],
  360. [-0.8647305965423584, 1.0617644786834717],
  361. [-0.8901953101158142, -0.01650036871433258],
  362. [-0.9893633723258972, -2.4662880897521973],
  363. [1.445534110069275, -1.049334168434143],
  364. [-0.041650623083114624, 0.012734669260680676],
  365. [-0.3302375078201294, 1.26217782497406],
  366. [0.6934980154037476, 1.7714335918426514],
  367. ]
  368. )
  369. elif (p, n) == (2, 16):
  370. return torch.tensor(
  371. [
  372. [-0.8996632695198059, -1.6360418796539307],
  373. [-0.961183488368988, 1.5999565124511719],
  374. [-1.882026195526123, 0.678778350353241],
  375. [0.36300793290138245, -1.9667866230010986],
  376. [-0.6814072728157043, -0.576818585395813],
  377. [0.7270012497901917, 0.6186859607696533],
  378. [0.3359416127204895, 1.8371193408966064],
  379. [1.859930396080017, 0.036668598651885986],
  380. [0.17208248376846313, -0.9401724338531494],
  381. [-1.7599700689315796, -0.6244229674339294],
  382. [-0.8993809223175049, 0.32267823815345764],
  383. [0.839488685131073, -0.3017036020755768],
  384. [1.5314953327178955, 1.2942044734954834],
  385. [-0.0011779458727687597, 0.00022069070837460458],
  386. [1.4274526834487915, -1.207889199256897],
  387. [-0.16123905777931213, 0.8787511587142944],
  388. ]
  389. )
  390. elif (p, n) == (1, 16):
  391. return torch.tensor(
  392. [
  393. [-2.7325894832611084],
  394. [-2.069017171859741],
  395. [-1.6180464029312134],
  396. [-1.2562311887741089],
  397. [-0.9423404335975647],
  398. [-0.6567591428756714],
  399. [-0.38804829120635986],
  400. [-0.12839503586292267],
  401. [0.12839503586292267],
  402. [0.38804829120635986],
  403. [0.6567591428756714],
  404. [0.9423404335975647],
  405. [1.2562311887741089],
  406. [1.6180464029312134],
  407. [2.069017171859741],
  408. [2.7325894832611084],
  409. ]
  410. )
  411. elif (p, n) == (1, 8):
  412. return torch.tensor(
  413. [
  414. [-2.1519455909729004],
  415. [-1.3439092636108398],
  416. [-0.7560052871704102],
  417. [-0.2450941801071167],
  418. [0.2450941801071167],
  419. [0.7560052871704102],
  420. [1.3439092636108398],
  421. [2.1519455909729004],
  422. ]
  423. )
  424. elif (p, n) == (1, 4):
  425. return torch.tensor([[-1.5104175806045532], [-0.4527800381183624], [0.4527800381183624], [1.5104175806045532]])
  426. else:
  427. raise NotImplementedError(f"Unsupported p={p}, n={n}")
  428. def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256, hadamard_size: int = 1024):
  429. assert len(weight.shape) == 2, "Only 2D weights are supported for now"
  430. grid = get_higgs_grid(p, 2 ** (p * bits)).to(weight.device)
  431. grid_norm_2 = torch.linalg.norm(grid, axis=-1) ** 2
  432. device = weight.device
  433. dtype = weight.dtype
  434. weight = weight.to(copy=True, dtype=torch.float32)
  435. # Pad to Hadamard transform size
  436. weight = pad_to_block(weight, [1], hadamard_size)
  437. # Scale and Hadamard transform
  438. mult = weight.shape[1] // hadamard_size
  439. weight = weight.reshape(-1, mult, hadamard_size)
  440. scales = torch.linalg.norm(weight, axis=-1)
  441. weight = hadamard_transform(weight, 1) / scales[:, :, None]
  442. # Pad to edenn_d and project
  443. weight = pad_to_block(weight, [2], p).reshape(weight.shape[0], mult, -1, p)
  444. # Quantize
  445. codes = torch.empty(weight.shape[:-1], device=device, dtype=torch.uint8)
  446. for i in range(0, weight.shape[0], 16):
  447. codes[i : i + 16] = torch.argmax(2 * weight[i : i + 16] @ grid.T - grid_norm_2, dim=-1).to(torch.uint8)
  448. del weight
  449. codes = codes.reshape(codes.shape[0], -1)
  450. scales = scales / sqrt(hadamard_size)
  451. weight, scales, tables, tables2, tune_metadata = prepare_data_transposed(
  452. codes,
  453. torch.repeat_interleave(scales.to(dtype), hadamard_size // group_size, dim=1),
  454. grid.to(dtype),
  455. num_bits=bits,
  456. group_size=group_size,
  457. vector_size=p,
  458. dtype=dtype,
  459. device=device,
  460. check_correctness=False,
  461. )
  462. return {
  463. "weight": weight,
  464. "scales": scales,
  465. "tables": tables,
  466. "tables2": tables2.view(dtype=torch.float16),
  467. "tune_metadata": tune_metadata,
  468. }
  469. class HiggsLinear(torch.nn.Module):
  470. def __init__(
  471. self,
  472. in_features: int,
  473. out_features: int,
  474. num_bits: int,
  475. bias=True,
  476. dtype: Optional[torch.dtype] = None,
  477. device: Optional[torch.device] = None,
  478. group_size: int = 256,
  479. hadamard_size: int = 1024,
  480. ):
  481. super().__init__()
  482. self.in_features = in_features
  483. self.out_features = out_features
  484. self.num_bits = num_bits
  485. self.group_size = group_size
  486. self.hadamard_size = hadamard_size
  487. assert in_features % group_size == 0
  488. assert num_bits in [2, 3, 4]
  489. self.weight = nn.Parameter(
  490. torch.empty((out_features * num_bits // 16, in_features), dtype=torch.int16, device=device),
  491. requires_grad=False,
  492. )
  493. self.scales = nn.Parameter(
  494. torch.empty((out_features, in_features // group_size), dtype=dtype, device=device), requires_grad=False
  495. )
  496. self.tables = nn.Parameter(torch.empty((2**num_bits,), dtype=dtype, device=device), requires_grad=False)
  497. self.tables2 = nn.Parameter(
  498. torch.empty((2**num_bits, 2**num_bits, 2), dtype=dtype, device=device), requires_grad=False
  499. )
  500. if bias:
  501. self.bias = nn.Parameter(torch.empty(out_features, device=device, dtype=dtype), requires_grad=False)
  502. else:
  503. self.register_parameter("bias", None)
  504. self.workspace = None # must be set externally to be reused among layers
  505. self.tune_metadata: TuneMetaData = None # must be set externally because architecture dependent
  506. def forward(self, x):
  507. x = pad_to_block(x, [-1], self.hadamard_size)
  508. if self.workspace is None:
  509. raise Exception("Workspace must be set before calling forward")
  510. return qgemm_v2(
  511. x,
  512. self.weight,
  513. self.scales,
  514. self.tables,
  515. self.tables2.view(dtype=torch.float32),
  516. self.workspace,
  517. self.tune_metadata,
  518. hadamard_size=self.hadamard_size,
  519. )
  520. def replace_with_higgs_linear(
  521. model,
  522. quantization_config=None,
  523. current_key_name=None,
  524. has_been_replaced=False,
  525. modules_to_not_convert=None,
  526. ):
  527. """
  528. Public method that recursively replaces the Linear layers of the given model with HIGGS quantized layers.
  529. `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
  530. conversion has been successful or not.
  531. Args:
  532. model (`torch.nn.Module`):
  533. The model to convert, can be any `torch.nn.Module` instance.
  534. quantization_config (`HiggsConfig`):
  535. The quantization config object that contains the quantization parameters.
  536. current_key_name (`list`, *optional*):
  537. A list that contains the current key name. This is used for recursion and should not be passed by the user.
  538. has_been_replaced (`bool`, *optional*):
  539. A boolean that indicates if the conversion has been successful or not. This is used for recursion and
  540. should not be passed by the user.
  541. """
  542. from accelerate import init_empty_weights
  543. for name, module in model.named_children():
  544. if current_key_name is None:
  545. current_key_name = []
  546. current_key_name.append(name)
  547. if isinstance(module, nn.Linear):
  548. # Check if the current key is not in the `quantization_config.modules_to_not_convert`
  549. current_key_name_str = ".".join(current_key_name)
  550. if not any(current_key_name_str.endswith(key) for key in modules_to_not_convert):
  551. with init_empty_weights():
  552. in_features = module.in_features
  553. out_features = module.out_features
  554. model._modules[name] = HiggsLinear(
  555. in_features,
  556. out_features,
  557. bias=module.bias is not None,
  558. num_bits=quantization_config.bits,
  559. hadamard_size=quantization_config.hadamard_size,
  560. group_size=quantization_config.group_size,
  561. )
  562. has_been_replaced = True
  563. # Store the module class in case we need to transpose the weight later
  564. model._modules[name].source_cls = type(module)
  565. # Force requires grad to False to avoid unexpected errors
  566. model._modules[name].requires_grad_(False)
  567. if len(list(module.children())) > 0:
  568. _, has_been_replaced = replace_with_higgs_linear(
  569. module,
  570. quantization_config=quantization_config,
  571. current_key_name=current_key_name,
  572. has_been_replaced=has_been_replaced,
  573. modules_to_not_convert=modules_to_not_convert,
  574. )
  575. # Remove the last key for recursion
  576. current_key_name.pop(-1)
  577. return model, has_been_replaced
  578. def dequantize_higgs(model, current_key_name=None):
  579. """
  580. Dequantizes the HiggsLinear layers in the given model by replacing them with standard torch.nn.Linear layers.
  581. Args:
  582. model (torch.nn.Module): The model containing HiggsLinear layers to be dequantized.
  583. current_key_name (list, optional): A list to keep track of the current module names during recursion. Defaults to None.
  584. Returns:
  585. torch.nn.Module: The model with HiggsLinear layers replaced by torch.nn.Linear layers.
  586. """
  587. with torch.no_grad():
  588. for name, module in model.named_children():
  589. if current_key_name is None:
  590. current_key_name = []
  591. current_key_name.append(name)
  592. if isinstance(module, HiggsLinear):
  593. in_features = module.in_features
  594. out_features = module.out_features
  595. model._modules[name] = torch.nn.Linear(
  596. in_features,
  597. out_features,
  598. bias=module.bias is not None,
  599. device=module.scales.device,
  600. dtype=module.scales.dtype,
  601. )
  602. model._modules[name].weight.data = module(
  603. torch.eye(in_features, device=module.scales.device, dtype=module.scales.dtype)
  604. ).T.contiguous()
  605. if len(list(module.children())) > 0:
  606. _ = dequantize_higgs(
  607. module,
  608. current_key_name=current_key_name,
  609. )
  610. # Remove the last key for recursion
  611. current_key_name.pop(-1)
  612. return model