codegeex.py 35 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030
  1. # Copyright (c) 2022 Zhipu.AI
  2. import math
  3. import torch
  4. import torch.nn.functional as F
  5. def fast_gelu(x):
  6. """Mindspore's fast gelu implementation."""
  7. return x / (1 + torch.exp(-1.702 * torch.abs(x))) * torch.exp(
  8. 0.851 * (x - torch.abs(x)))
  9. class MLP(torch.nn.Module):
  10. """MLP.
  11. MLP will take the input with h hidden state, project it to 4*h
  12. hidden dimension, perform nonlinear transformation, and project the
  13. state back into h hidden dimension. At the end, dropout is also
  14. applied.
  15. """
  16. def __init__(
  17. self,
  18. hidden_size,
  19. ):
  20. super(MLP, self).__init__()
  21. self.hidden_size = hidden_size
  22. # Project to 4h.
  23. self.dense_h_to_4h = torch.nn.Linear(
  24. self.hidden_size,
  25. 4 * self.hidden_size,
  26. )
  27. self.activation_func = fast_gelu
  28. # Project back to h.
  29. self.dense_4h_to_h = torch.nn.Linear(
  30. 4 * self.hidden_size,
  31. self.hidden_size,
  32. )
  33. def forward(self, hidden_states):
  34. # [s, b, 4hp]
  35. intermediate_parallel = self.dense_h_to_4h(hidden_states)
  36. intermediate_parallel = self.activation_func(intermediate_parallel)
  37. # [s, b, h]
  38. output = self.dense_4h_to_h(intermediate_parallel)
  39. return output
  40. class SelfAttention(torch.nn.Module):
  41. """self-attention layer abstract class.
  42. Self-attention layer takes input with size [b, s, h]
  43. and returns output of the same size.
  44. """
  45. def __init__(
  46. self,
  47. hidden_size,
  48. num_attention_heads,
  49. layer_number,
  50. fp16=True,
  51. attention_softmax_in_fp32=True,
  52. ):
  53. super(SelfAttention, self).__init__()
  54. self.hidden_size = hidden_size
  55. self.num_attention_heads = num_attention_heads
  56. self.fp16 = fp16
  57. self.attention_softmax_in_fp32 = attention_softmax_in_fp32
  58. self.layer_number = max(1, layer_number)
  59. assert self.hidden_size % self.num_attention_heads == 0
  60. self.hidden_size_per_attention_head = int(self.hidden_size
  61. // self.num_attention_heads)
  62. self.query = torch.nn.Linear(self.hidden_size, self.hidden_size)
  63. self.key = torch.nn.Linear(self.hidden_size, self.hidden_size)
  64. self.value = torch.nn.Linear(self.hidden_size, self.hidden_size)
  65. self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
  66. self.softmax = torch.nn.Softmax(dim=-1)
  67. self.dense = torch.nn.Linear(self.hidden_size, self.hidden_size)
  68. def forward(
  69. self,
  70. hidden_states,
  71. attention_mask,
  72. layer_past=None,
  73. get_key_value=False,
  74. prompt_length=None,
  75. context_length=None,
  76. ):
  77. # hidden_states: [sq, b, h]
  78. # =====================
  79. # Query, Key, and Value
  80. # =====================
  81. query_layer = self.query(hidden_states)
  82. key_layer = self.key(hidden_states)
  83. value_layer = self.value(hidden_states)
  84. new_query_layer_shape = query_layer.size()[:-1] + (
  85. self.num_attention_heads, self.hidden_size_per_attention_head
  86. ) # noqa
  87. query_layer = query_layer.view(*new_query_layer_shape)
  88. new_query_layer_shape = key_layer.size()[:-1] + (
  89. self.num_attention_heads, self.hidden_size_per_attention_head)
  90. key_layer = key_layer.view(*new_query_layer_shape)
  91. new_query_layer_shape = value_layer.size()[:-1] + (
  92. self.num_attention_heads, self.hidden_size_per_attention_head
  93. ) # noqa
  94. value_layer = value_layer.view(*new_query_layer_shape)
  95. # ==================================
  96. # Adjust key and value for inference
  97. # ==================================
  98. if layer_past is not None:
  99. past_key, past_value = layer_past
  100. key_layer = torch.cat((past_key.type_as(key_layer), key_layer),
  101. dim=0)
  102. value_layer = torch.cat(
  103. (past_value.type_as(value_layer), value_layer), dim=0)
  104. if get_key_value:
  105. present = (key_layer, value_layer)
  106. # ===================================
  107. # Raw attention scores. [b, np, sq, sk]
  108. # ===================================
  109. # [b, np, sq, sk]
  110. output_size = (query_layer.size(1), query_layer.size(2),
  111. query_layer.size(0), key_layer.size(0))
  112. # [sq, b, np, hn] -> [sq, b * np, hn]
  113. query_layer = query_layer.contiguous().view(
  114. output_size[2], output_size[0] * output_size[1], -1)
  115. key_layer = key_layer.contiguous().view(
  116. output_size[3], output_size[0] * output_size[1], -1)
  117. # Raw attention scores. [b * np, sq, sk]
  118. matmul_result = torch.matmul(
  119. query_layer.transpose(0, 1),
  120. key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor
  121. # change view to [b, np, sq, sk]
  122. attention_scores = matmul_result.view(*output_size)
  123. # ==================================================
  124. # Update attention mask for inference. [b, np, sq, sk]
  125. # ==================================================
  126. if get_key_value:
  127. with torch.no_grad():
  128. if layer_past is not None:
  129. attention_mask = attention_mask[
  130. ...,
  131. attention_scores.size(3)
  132. - 1, :attention_scores.size(3)].unsqueeze(2)
  133. else:
  134. attention_mask = attention_mask[
  135. ..., :attention_scores.size(3), :attention_scores.
  136. size(3)]
  137. if context_length is not None:
  138. attention_mask = torch.clone(attention_mask)
  139. attention_mask[:, :, context_length:, :] = True
  140. # attention scores and attention mask [b, np, sq, sk]
  141. # attention_scores = attention_mask_func(attention_scores, attention_mask)
  142. attention_scores = attention_scores - attention_mask * 10000.0
  143. if self.attention_softmax_in_fp32:
  144. attention_probs = self.softmax(attention_scores.float()).half()
  145. else:
  146. attention_probs = self.softmax(attention_scores)
  147. # =========================
  148. # Context layer. [sq, b, hp]
  149. # =========================
  150. # value_layer -> context layer.
  151. # [sq, b, np, hn] --> [b, np, sq, hn]
  152. # context layer shape: [b, np, sq, hn]
  153. output_size = (value_layer.size(1), value_layer.size(2),
  154. query_layer.size(0), value_layer.size(3))
  155. # change view [sq, b * np, hn]
  156. value_layer = value_layer.view(
  157. value_layer.size(0), output_size[0] * output_size[1], -1)
  158. # change view [b * np, sq, sk]
  159. attention_probs = attention_probs.view(output_size[0] * output_size[1],
  160. output_size[2], -1)
  161. context_layer = torch.bmm(
  162. attention_probs,
  163. value_layer.unsqueeze(0).transpose(1, 2).squeeze(0))
  164. # change view [b, np, sq, hn]
  165. context_layer = context_layer.view(*output_size)
  166. # # [b, np, sq, hn] --> [sq, b, np, hn]
  167. context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
  168. # # [sq, b, np, hn] --> [sq, b, hp]
  169. new_context_layer_shape = context_layer.size()[:-2] + (
  170. self.hidden_size, )
  171. context_layer = context_layer.view(*new_context_layer_shape)
  172. # =================
  173. # Output. [sq, b, h]
  174. # =================
  175. output = self.dense(context_layer)
  176. if get_key_value:
  177. output = [output, present]
  178. return output
  179. class TopQuerySelfAttention(torch.nn.Module):
  180. """Top query self-attention layer abstract class.
  181. Self-attention layer takes input with size [b, s, h]
  182. and returns output of the same size.
  183. """
  184. def __init__(
  185. self,
  186. hidden_size,
  187. num_attention_heads,
  188. layer_number,
  189. fp16=True,
  190. attention_softmax_in_fp32=True,
  191. ):
  192. super(TopQuerySelfAttention, self).__init__()
  193. self.hidden_size = hidden_size
  194. self.num_attention_heads = num_attention_heads
  195. self.fp16 = fp16
  196. self.attention_softmax_in_fp32 = attention_softmax_in_fp32
  197. self.layer_number = max(1, layer_number)
  198. assert self.hidden_size % self.num_attention_heads == 0
  199. self.hidden_size_per_attention_head = int(self.hidden_size
  200. // self.num_attention_heads)
  201. self.query = torch.nn.Linear(self.hidden_size, self.hidden_size)
  202. self.key = torch.nn.Linear(self.hidden_size, self.hidden_size)
  203. self.value = torch.nn.Linear(self.hidden_size, self.hidden_size)
  204. self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
  205. self.softmax = torch.nn.Softmax(dim=-1)
  206. self.dense = torch.nn.Linear(self.hidden_size, self.hidden_size)
  207. def forward(
  208. self,
  209. hidden_states,
  210. query_hidden_state,
  211. attention_mask,
  212. layer_past=None,
  213. get_key_value=False,
  214. prompt_length=None,
  215. context_length=None,
  216. ):
  217. # hidden_states: [sq, b, h]
  218. query_layer = self.query(query_hidden_state)
  219. key_layer = self.key(hidden_states)
  220. value_layer = self.value(hidden_states)
  221. new_query_layer_shape = query_layer.size()[:-1] + (
  222. self.num_attention_heads, self.hidden_size_per_attention_head
  223. ) # noqa
  224. query_layer = query_layer.view(*new_query_layer_shape)
  225. new_query_layer_shape = key_layer.size()[:-1] + (
  226. self.num_attention_heads, self.hidden_size_per_attention_head)
  227. key_layer = key_layer.view(*new_query_layer_shape)
  228. new_query_layer_shape = value_layer.size()[:-1] + (
  229. self.num_attention_heads, self.hidden_size_per_attention_head
  230. ) # noqa
  231. value_layer = value_layer.view(*new_query_layer_shape)
  232. # ==================================
  233. # Adjust key and value for inference
  234. # ==================================
  235. if layer_past is not None:
  236. past_key, past_value = layer_past
  237. key_layer = torch.cat((past_key.type_as(key_layer), key_layer),
  238. dim=0)
  239. value_layer = torch.cat(
  240. (past_value.type_as(value_layer), value_layer), dim=0)
  241. if get_key_value:
  242. present = (key_layer, value_layer)
  243. # ===================================
  244. # Raw attention scores. [b, np, sq, sk]
  245. # ===================================
  246. # [b, np, sq, sk]
  247. output_size = (query_layer.size(1), query_layer.size(2),
  248. query_layer.size(0), key_layer.size(0))
  249. # [s, b, np, hn] -> [s, b * np, hn]
  250. query_layer = query_layer.contiguous().view(
  251. output_size[2], output_size[0] * output_size[1], -1)
  252. key_layer = key_layer.contiguous().view(
  253. output_size[3], output_size[0] * output_size[1], -1)
  254. # Raw attention scores. [b * np, sq, sk]
  255. matmul_result = torch.matmul(
  256. query_layer.transpose(0, 1),
  257. key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor
  258. # change view to [b, np, s, s]
  259. attention_scores = matmul_result.view(*output_size)
  260. # ==================================================
  261. # Update attention mask for inference. [b, np, sq, sk]
  262. # ==================================================
  263. if get_key_value:
  264. with torch.no_grad():
  265. if layer_past is not None:
  266. attention_mask = attention_mask[
  267. ...,
  268. attention_scores.size(3)
  269. - 1, :attention_scores.size(3)].unsqueeze(2)
  270. else:
  271. attention_mask = attention_mask[
  272. ..., :attention_scores.size(3), :attention_scores.
  273. size(3)]
  274. if context_length is not None:
  275. attention_mask = torch.clone(attention_mask)
  276. attention_mask[:, :, context_length:, :] = True
  277. # attention scores and attention mask [b, np, sq, sk]
  278. # attention_scores = attention_mask_func(attention_scores, attention_mask)
  279. attention_scores = attention_scores - attention_mask * 10000.0
  280. if self.attention_softmax_in_fp32:
  281. attention_probs = self.softmax(attention_scores.float()).half()
  282. else:
  283. attention_probs = self.softmax(attention_scores)
  284. # =========================
  285. # Context layer. [sq, b, hp]
  286. # =========================
  287. # value_layer -> context layer.
  288. # [sq, b, np, hn] --> [b, np, sq, hn]
  289. # context layer shape: [b, np, sq, hn]
  290. output_size = (value_layer.size(1), value_layer.size(2),
  291. query_layer.size(0), value_layer.size(3))
  292. # change view [sq, b * np, hn]
  293. value_layer = value_layer.view(
  294. value_layer.size(0), output_size[0] * output_size[1], -1)
  295. # change view [b * np, sq, sk]
  296. attention_probs = attention_probs.view(output_size[0] * output_size[1],
  297. output_size[2], -1)
  298. # matmul: [b * np, sq, hn]
  299. context_layer = torch.bmm(
  300. attention_probs,
  301. value_layer.unsqueeze(0).transpose(1, 2).squeeze(0))
  302. # change view [b, np, sq, hn]
  303. context_layer = context_layer.view(*output_size)
  304. # [b, np, sq, hn] --> [sq, b, np, hn]
  305. context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
  306. # [sq, b, np, hn] --> [sq, b, hp]
  307. new_context_layer_shape = context_layer.size()[:-2] + \
  308. (self.hidden_size,) # noqa
  309. context_layer = context_layer.view(*new_context_layer_shape)
  310. # =================
  311. # Output. [sq, b, h]
  312. # =================
  313. output = self.dense(context_layer)
  314. if get_key_value:
  315. output = [output, present]
  316. return output
  317. class TransformerLayer(torch.nn.Module):
  318. """A single transformer layer.
  319. Transformore layer takes input with size [b, s, h] and returns an
  320. output of the same size.
  321. """
  322. def __init__(
  323. self,
  324. hidden_size,
  325. num_attention_heads,
  326. layer_number,
  327. layernorm_epsilon=1e-5,
  328. fp16=True,
  329. attention_softmax_in_fp32=True,
  330. ):
  331. super(TransformerLayer, self).__init__()
  332. self.hidden_size = hidden_size
  333. self.layernorm_epsilon = layernorm_epsilon
  334. self.layer_number = layer_number
  335. # Layernorm on the input data.
  336. self.input_layernorm = torch.nn.LayerNorm(
  337. hidden_size, eps=self.layernorm_epsilon)
  338. # Self attention.
  339. self.attention = SelfAttention(hidden_size, num_attention_heads,
  340. layer_number, fp16,
  341. attention_softmax_in_fp32)
  342. # Layernorm on the input data.
  343. self.post_attention_layernorm = torch.nn.LayerNorm(
  344. self.hidden_size, eps=self.layernorm_epsilon)
  345. self.mlp = MLP(self.hidden_size)
  346. def forward(
  347. self,
  348. hidden_states,
  349. attention_mask,
  350. layer_past=None,
  351. get_key_value=False,
  352. prompt_length=None,
  353. context_length=None,
  354. ):
  355. # hidden_states: [b, s, h]
  356. # Use FP32 for Layernorm
  357. # layernorm_output = self.input_layernorm(hidden_states.float()).half()
  358. layernorm_output = self.input_layernorm(hidden_states)
  359. # Self attention.
  360. attention_output = self.attention(
  361. layernorm_output,
  362. attention_mask,
  363. layer_past=layer_past,
  364. get_key_value=get_key_value,
  365. prompt_length=prompt_length,
  366. context_length=context_length)
  367. if get_key_value:
  368. attention_output, presents = attention_output
  369. # Residual connection.
  370. residual = hidden_states
  371. layernorm_input = attention_output + residual
  372. # Use FP32 for Layernorm
  373. # layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half()
  374. layernorm_output = self.post_attention_layernorm(layernorm_input)
  375. mlp_output = self.mlp(layernorm_output)
  376. output = mlp_output + layernorm_input
  377. if get_key_value:
  378. output = [output, presents]
  379. return output
  380. class TopQueryLayer(torch.nn.Module):
  381. """A single top query layer.
  382. Top query layer takes input with size [b, s, h] and returns an
  383. output of the same size.
  384. """
  385. def __init__(
  386. self,
  387. hidden_size,
  388. num_attention_heads,
  389. layer_number,
  390. layernorm_epsilon=1e-5,
  391. ):
  392. super(TopQueryLayer, self).__init__()
  393. self.hidden_size = hidden_size
  394. self.num_attention_heads = num_attention_heads
  395. self.layernorm_epsilon = layernorm_epsilon
  396. self.layer_number = layer_number
  397. # Use FP32 for Layernorm
  398. self.input_layernorm = torch.nn.LayerNorm(
  399. self.hidden_size, eps=self.layernorm_epsilon)
  400. # Self attention.
  401. self.attention = TopQuerySelfAttention(self.hidden_size,
  402. self.num_attention_heads,
  403. self.layer_number)
  404. # Layernorm on the input data.
  405. self.post_attention_layernorm = torch.nn.LayerNorm(
  406. self.hidden_size, eps=self.layernorm_epsilon)
  407. # MLP
  408. self.mlp = MLP(self.hidden_size)
  409. def forward(
  410. self,
  411. hidden_states,
  412. query_hidden_state,
  413. attention_mask,
  414. layer_past=None,
  415. get_key_value=False,
  416. prompt_length=None,
  417. context_length=None,
  418. ):
  419. # hidden_states: [b, s, h]
  420. assert query_hidden_state != None # noqa
  421. # Use FP32 for Layernorm
  422. # layernorm_output = self.input_layernorm(hidden_states.float()).half()
  423. layernorm_output = self.input_layernorm(hidden_states)
  424. # Self attention.
  425. attention_output = self.attention(
  426. layernorm_output,
  427. query_hidden_state,
  428. attention_mask,
  429. layer_past=layer_past,
  430. get_key_value=get_key_value,
  431. prompt_length=prompt_length,
  432. context_length=context_length)
  433. if get_key_value:
  434. attention_output, presents = attention_output
  435. # Residual connection.
  436. residual = hidden_states
  437. layernorm_input = attention_output + residual
  438. # Use FP32 for Layernorm
  439. # layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half()
  440. layernorm_output = self.post_attention_layernorm(layernorm_input)
  441. # MLP.
  442. mlp_output = self.mlp(layernorm_output)
  443. # Second residual connection.
  444. residual = layernorm_input
  445. output = mlp_output + residual
  446. if get_key_value:
  447. output = [output, presents]
  448. return output
  449. class Transformer(torch.nn.Module):
  450. """Transformer class."""
  451. def __init__(
  452. self,
  453. hidden_size,
  454. num_attention_heads,
  455. num_layers,
  456. layernorm_epsilon=1e-5,
  457. ):
  458. super(Transformer, self).__init__()
  459. self.hidden_size = hidden_size
  460. self.num_attention_heads = num_attention_heads
  461. self.layernorm_epsilon = layernorm_epsilon
  462. # Number of layers:
  463. self.num_layers = num_layers
  464. self.num_unique_layers = None
  465. #################
  466. assert self.num_unique_layers is None
  467. #################
  468. if self.num_unique_layers is None:
  469. self.num_unique_layers = self.num_layers
  470. assert self.num_layers % self.num_unique_layers == 0, \
  471. 'number of layers should be divisible by number of unique layers'
  472. # Transformer layers.
  473. def build_layer(layer_number):
  474. return TransformerLayer(self.hidden_size, self.num_attention_heads,
  475. layer_number)
  476. self.layers = torch.nn.ModuleList(
  477. [build_layer(i + 1) for i in range(self.num_unique_layers)])
  478. self.topQueryLayer = TopQueryLayer(self.hidden_size,
  479. self.num_attention_heads,
  480. self.num_unique_layers)
  481. self.final_layernorm = torch.nn.LayerNorm(
  482. self.hidden_size, eps=self.layernorm_epsilon)
  483. def _get_layer_index(self, layer_number):
  484. return layer_number % self.num_unique_layers
  485. def _get_layer(self, layer_number):
  486. return self.layers[self._get_layer_index(layer_number)]
  487. def forward(
  488. self,
  489. hidden_states,
  490. query_hidden_state,
  491. attention_mask,
  492. layer_past=None,
  493. get_key_value=False,
  494. prompt_length=None,
  495. context_length=None,
  496. ):
  497. # data format change to avoid explicit transposes : [b s h] --> [s b h]
  498. hidden_states = hidden_states.transpose(0, 1).contiguous()
  499. query_hidden_state = query_hidden_state.transpose(0, 1).contiguous()
  500. if get_key_value:
  501. presents = []
  502. for index in range(self.num_layers):
  503. layer = self._get_layer(index)
  504. past = None
  505. if layer_past is not None:
  506. past = layer_past[index]
  507. hidden_states = layer(
  508. hidden_states,
  509. attention_mask,
  510. layer_past=past,
  511. get_key_value=get_key_value,
  512. prompt_length=prompt_length,
  513. context_length=context_length)
  514. if get_key_value:
  515. hidden_states, present = hidden_states
  516. presents.append(present)
  517. # Use FP32 for Layernorm
  518. # hidden_states_ = self.final_layernorm(hidden_states.float()).half()
  519. hidden_states_ = self.final_layernorm(hidden_states)
  520. #################################
  521. # top query layer
  522. #################################
  523. past = None
  524. if layer_past is not None:
  525. past = layer_past[self.num_layers]
  526. hidden_states = self.topQueryLayer(
  527. hidden_states_,
  528. query_hidden_state,
  529. attention_mask,
  530. layer_past=past,
  531. get_key_value=get_key_value,
  532. prompt_length=prompt_length,
  533. context_length=context_length)
  534. if get_key_value:
  535. hidden_states, present = hidden_states
  536. presents.append(present)
  537. # reverting data format change [s b h] --> [b s h]
  538. output = hidden_states.transpose(0, 1).contiguous()
  539. if get_key_value:
  540. output = [output, presents]
  541. return output
  542. def state_dict_for_save_checkpoint(self,
  543. destination=None,
  544. prefix='',
  545. keep_vars=False):
  546. return self.state_dict(destination, prefix, keep_vars)
  547. class Embedding(torch.nn.Module):
  548. """Language model embeddings.
  549. Arguments:
  550. hidden_size: hidden size
  551. vocab_size: vocabulary size
  552. max_sequence_length: maximum size of sequence. This
  553. is used for positional embedding
  554. """
  555. def __init__(
  556. self,
  557. hidden_size,
  558. vocab_size,
  559. max_sequence_length,
  560. ):
  561. super(Embedding, self).__init__()
  562. self.hidden_size = hidden_size
  563. self.vocab_size = vocab_size
  564. self.max_sequence_length = max_sequence_length
  565. # Word embeddings.
  566. self.word_embeddings = torch.nn.Embedding(self.vocab_size,
  567. self.hidden_size)
  568. self._word_embeddings_key = 'word_embeddings'
  569. # Position embedding.
  570. self.position_embeddings = torch.nn.Embedding(self.max_sequence_length,
  571. self.hidden_size)
  572. self.position_embeddings = self.position_embeddings.half()
  573. self._position_embeddings_key = 'position_embeddings'
  574. def forward(self, input_ids, position_ids):
  575. # Embeddings.
  576. words_embeddings = self.word_embeddings(input_ids)
  577. position_embeddings = self.position_embeddings(position_ids)
  578. embeddings = words_embeddings + position_embeddings
  579. return embeddings
  580. def state_dict_for_save_checkpoint(self,
  581. destination=None,
  582. prefix='',
  583. keep_vars=False):
  584. """For easy load."""
  585. state_dict_ = {}
  586. state_dict_[self._word_embeddings_key] \
  587. = self.word_embeddings.state_dict(destination, prefix, keep_vars)
  588. state_dict_[self._position_embeddings_key] \
  589. = self.position_embeddings.state_dict(
  590. destination, prefix, keep_vars)
  591. return state_dict_
  592. def load_state_dict(self, state_dict, strict=True):
  593. """Customized load."""
  594. # Word embedding.
  595. if self._word_embeddings_key in state_dict:
  596. state_dict_ = state_dict[self._word_embeddings_key]
  597. else:
  598. # for backward compatibility.
  599. state_dict_ = {}
  600. for key in state_dict.keys():
  601. if 'word_embeddings' in key:
  602. state_dict_[key.split('word_embeddings.')[1]] \
  603. = state_dict[key]
  604. state_dict_['weight'] = state_dict_['weight'][:self.vocab_size]
  605. self.word_embeddings.load_state_dict(state_dict_, strict=strict)
  606. # Position embedding.
  607. if self._position_embeddings_key in state_dict:
  608. state_dict_ = state_dict[self._position_embeddings_key]
  609. else:
  610. # for backward compatibility.
  611. state_dict_ = {}
  612. for key in state_dict.keys():
  613. if 'position_embeddings' in key:
  614. state_dict_[key.split('position_embeddings.')[1]] \
  615. = state_dict[key]
  616. self.position_embeddings.load_state_dict(state_dict_, strict=strict)
  617. class QueryEmbedding(torch.nn.Module):
  618. """Language model embeddings.
  619. Arguments:
  620. hidden_size: hidden size
  621. vocab_size: vocabulary size
  622. max_sequence_length: maximum size of sequence. This
  623. is used for positional embedding
  624. """
  625. def __init__(
  626. self,
  627. hidden_size,
  628. vocab_size,
  629. max_sequence_length,
  630. ):
  631. super(QueryEmbedding, self).__init__()
  632. self.hidden_size = hidden_size
  633. self.vocab_size = vocab_size
  634. self.max_sequence_length = max_sequence_length
  635. # Top query position embedding (serial).
  636. self.top_query_embeddings = torch.nn.Embedding(
  637. self.max_sequence_length, self.hidden_size)
  638. self.top_query_embeddings = self.top_query_embeddings.half()
  639. self._top_query_embeddings_key = 'top_query_embeddings'
  640. def forward(self, position_ids):
  641. # Embeddings.
  642. embeddings = self.top_query_embeddings(position_ids)
  643. return embeddings
  644. def state_dict_for_save_checkpoint(self,
  645. destination=None,
  646. prefix='',
  647. keep_vars=False):
  648. """For easy load."""
  649. state_dict_ = {}
  650. state_dict_[self._top_query_embeddings_key] \
  651. = self.top_query_embeddings.state_dict(
  652. destination, prefix, keep_vars)
  653. return state_dict_
  654. def load_state_dict(self, state_dict, strict=True):
  655. """Customized load."""
  656. # Position embedding.
  657. if self._top_query_embeddings_key in state_dict:
  658. state_dict_ = state_dict[self._top_query_embeddings_key]
  659. else:
  660. # for backward compatibility.
  661. state_dict_ = {}
  662. for key in state_dict.keys():
  663. if 'top_query_embeddings' in key:
  664. state_dict_[key.split('top_query_embeddings.')[1]] \
  665. = state_dict[key]
  666. self.top_query_embeddings.load_state_dict(state_dict_, strict=strict)
  667. class TransformerLanguageModel(torch.nn.Module):
  668. """Transformer language model.
  669. Arguments:
  670. transformer_hparams: transformer hyperparameters
  671. attention_mask_func: a function that takes `unmaksed-attention-scores`
  672. with size [b, np, s, s] and an `attention-mask` and will apply
  673. the masking. The function should return a masked score of the
  674. same size [b, np, s, s].
  675. masked-attention-scores = attention_mask_func(
  676. unmaksed-attention-scores, attention-mask)
  677. vocab_size: vocabulary size
  678. max_sequence_length: maximum size of sequence. This
  679. is used for positional embedding
  680. """
  681. def __init__(
  682. self,
  683. hidden_size,
  684. num_layers,
  685. num_attention_heads,
  686. padded_vocab_size,
  687. max_position_embeddings,
  688. ):
  689. super(TransformerLanguageModel, self).__init__()
  690. self.hidden_size = hidden_size
  691. self.num_layers = num_layers
  692. self.num_attention_heads = num_attention_heads
  693. self.padded_vocab_size = padded_vocab_size
  694. self.max_position_embeddings = max_position_embeddings
  695. # Embeddings
  696. self.embedding = Embedding(self.hidden_size, self.padded_vocab_size,
  697. self.max_position_embeddings)
  698. self._embedding_key = 'embedding'
  699. # Query embeddings
  700. self.topQueryEmbedding = QueryEmbedding(self.hidden_size,
  701. self.padded_vocab_size,
  702. self.max_position_embeddings)
  703. self._topQueryEmbedding_key = 'topQueryEmbedding'
  704. # Transformer
  705. self.transformer = Transformer(self.hidden_size,
  706. self.num_attention_heads,
  707. self.num_layers)
  708. self._transformer_key = 'transformer'
  709. def forward(
  710. self,
  711. input_ids,
  712. position_ids,
  713. attention_mask,
  714. layer_past=None,
  715. get_key_value=False,
  716. prompt_length=None,
  717. context_length=None,
  718. ):
  719. # Embeddings.
  720. embedding_output = self.embedding(input_ids, position_ids)
  721. query_position_ids = position_ids
  722. queryEmbedding_out = self.topQueryEmbedding(query_position_ids)
  723. # Transformer.
  724. transformer_output = self.transformer(
  725. embedding_output,
  726. queryEmbedding_out,
  727. attention_mask,
  728. layer_past=layer_past,
  729. get_key_value=get_key_value,
  730. prompt_length=prompt_length,
  731. context_length=context_length)
  732. return transformer_output
  733. def state_dict_for_save_checkpoint(self,
  734. destination=None,
  735. prefix='',
  736. keep_vars=False):
  737. """For easy load."""
  738. state_dict_ = {}
  739. state_dict_[self._embedding_key] \
  740. = self.embedding.state_dict_for_save_checkpoint(
  741. destination, prefix, keep_vars)
  742. state_dict_[self._topQueryEmbedding_key] \
  743. = self.topQueryEmbedding.state_dict_for_save_checkpoint(
  744. destination, prefix, keep_vars)
  745. state_dict_[self._transformer_key] \
  746. = self.transformer.state_dict_for_save_checkpoint(
  747. destination, prefix, keep_vars)
  748. return state_dict_
  749. def load_state_dict(self, state_dict, strict=True):
  750. """Customized load."""
  751. # Embedding.
  752. if self._embedding_key in state_dict:
  753. state_dict_ = state_dict[self._embedding_key]
  754. else:
  755. # for backward compatibility.
  756. state_dict_ = {}
  757. for key in state_dict.keys():
  758. if '_embeddings' in key:
  759. state_dict_[key] = state_dict[key]
  760. self.embedding.load_state_dict(state_dict_, strict=strict)
  761. if self._topQueryEmbedding_key in state_dict:
  762. state_dict_ = state_dict[self._topQueryEmbedding_key]
  763. else:
  764. # for backward compatibility.
  765. state_dict_ = {}
  766. for key in state_dict.keys():
  767. if '_embeddings' in key:
  768. state_dict_[key] = state_dict[key]
  769. self.topQueryEmbedding.load_state_dict(state_dict_, strict=strict)
  770. # Transformer.
  771. if self._transformer_key in state_dict:
  772. state_dict_ = state_dict[self._transformer_key]
  773. else:
  774. # for backward compatibility.
  775. state_dict_ = {}
  776. for key in state_dict.keys():
  777. if 'transformer.' in key:
  778. state_dict_[key.split('transformer.')[1]] = state_dict[key]
  779. self.transformer.load_state_dict(state_dict_, strict=strict)
  780. class CodeGeeXModel(torch.nn.Module):
  781. """CodeGeeX: A Multilingual Code Generation Model."""
  782. def __init__(
  783. self,
  784. hidden_size,
  785. num_layers,
  786. num_attention_heads,
  787. padded_vocab_size,
  788. max_position_embeddings,
  789. ):
  790. super(CodeGeeXModel, self).__init__()
  791. self.language_model = TransformerLanguageModel(
  792. hidden_size, num_layers, num_attention_heads, padded_vocab_size,
  793. max_position_embeddings)
  794. self._language_model_key = 'language_model'
  795. def forward(
  796. self,
  797. input_ids,
  798. position_ids,
  799. attention_mask,
  800. layer_past=None,
  801. get_key_value=False,
  802. prompt_length=None,
  803. context_length=None,
  804. ):
  805. # Language model.
  806. lm_output = self.language_model(
  807. input_ids,
  808. position_ids,
  809. attention_mask,
  810. layer_past=layer_past,
  811. get_key_value=get_key_value,
  812. prompt_length=prompt_length,
  813. context_length=context_length)
  814. if get_key_value:
  815. lm_output, presents = lm_output
  816. output = F.linear(
  817. lm_output,
  818. self.language_model.embedding.word_embeddings.weight.half())
  819. if get_key_value:
  820. output = [output, presents]
  821. return output
  822. def state_dict_for_save_checkpoint(self,
  823. destination=None,
  824. prefix='',
  825. keep_vars=False):
  826. state_dict_ = {}
  827. state_dict_[self._language_model_key] \
  828. = self.language_model.state_dict_for_save_checkpoint(
  829. destination, prefix, keep_vars)
  830. return state_dict_
  831. def load_state_dict(self, state_dict, strict=True):
  832. """Customized load."""
  833. if self._language_model_key in state_dict:
  834. state_dict = state_dict[self._language_model_key]
  835. self.language_model.load_state_dict(state_dict, strict=strict)