# modeling_cpmant.py (33 KB)
  1. # coding=utf-8
  2. # Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch CPMAnt"""
  16. import math
  17. from typing import Optional, Union
  18. import torch
  19. import torch.nn.functional as F
  20. from torch import nn
  21. from torch.nn import CrossEntropyLoss
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache, DynamicCache
  24. from ...generation import GenerationMixin
  25. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  26. from ...modeling_utils import PreTrainedModel
  27. from ...utils import auto_docstring, logging
  28. from .configuration_cpmant import CpmAntConfig
# Module-level logger, obtained from the library's logging helper and keyed by this module's name.
logger = logging.get_logger(__name__)
  30. class CpmAntLayerNorm(nn.Module):
  31. """
  32. We use Root Mean Square (RMS) Layer Normalization, please see https://huggingface.co/papers/1910.07467 for details."
  33. """
  34. def __init__(self, config: CpmAntConfig):
  35. super().__init__()
  36. self.eps = config.eps
  37. self.dim_norm = config.hidden_size
  38. self.weight = nn.Parameter(torch.empty(config.hidden_size))
  39. def forward(self, hidden_states: torch.Tensor):
  40. """
  41. Args:
  42. hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
  43. """
  44. if hidden_states.size(-1) != self.dim_norm:
  45. raise AssertionError("hidden_states.size(-1) != self.dim_norm")
  46. old_dtype = hidden_states.dtype
  47. variance = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
  48. hidden_states = (hidden_states * torch.rsqrt(variance + self.eps)).to(old_dtype) * self.weight
  49. return hidden_states
  50. class CpmAntAttention(nn.Module):
  51. def __init__(self, config: CpmAntConfig, layer_idx=None):
  52. super().__init__()
  53. self.dim_model = config.hidden_size
  54. self.num_heads = config.num_attention_heads
  55. self.dim_head = config.dim_head
  56. self.layer_idx = layer_idx
  57. self.project_q = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)
  58. self.project_k = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)
  59. self.project_v = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)
  60. self.attention_out = nn.Linear(self.num_heads * self.dim_head, self.dim_model, bias=False)
  61. self.softmax = torch.nn.Softmax(dim=-1)
  62. if config.dropout_p is not None:
  63. self.dropout = torch.nn.Dropout(p=config.dropout_p)
  64. else:
  65. self.dropout = None
  66. def forward(
  67. self,
  68. hidden_q: torch.Tensor,
  69. hidden_kv: torch.Tensor,
  70. attention_mask: torch.BoolTensor,
  71. position_bias: torch.Tensor,
  72. output_attentions: Optional[bool] = False,
  73. past_key_values: Optional[Cache] = None,
  74. use_cache: Optional[bool] = None,
  75. cache_position: Optional[torch.Tensor] = None,
  76. ):
  77. """
  78. Args:
  79. hidden_q (`torch.Tensor`):
  80. Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences.
  81. hidden_kv (`torch.Tensor` of shape `(batch, len_k, dim_model)`)):
  82. Tensor *key_value* and *query* of shape `(batch, len_k, dim_model)`
  83. attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
  84. Avoid invalid areas to participate in the calculation of self-attention.
  85. position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
  86. Provide positional information to self-attention block.
  87. output_attentions (`bool`, *optional*):
  88. Whether or not to return the attentions tensors of all attention layers.
  89. past_key_values (`Cache`, *optional*):
  90. Cached past key and value projection states.
  91. use_cache (`bool`, *optional*):
  92. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  93. (see `past_key_values`).
  94. """
  95. batch_size = hidden_q.size(0)
  96. len_q = hidden_q.size(1)
  97. len_k = hidden_kv.size(1)
  98. query = self.project_q(hidden_q)
  99. key = self.project_k(hidden_kv)
  100. value = self.project_v(hidden_kv)
  101. query = query.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
  102. key = key.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
  103. value = value.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
  104. if past_key_values is not None:
  105. key, value = past_key_values.update(key, value, self.layer_idx, {"cache_position": cache_position})
  106. len_k = key.size(-2)
  107. # (batch_size, num_heads, len_q, dim_head) @ (batch_size, num_heads, dim_head, len_k) -> (batch_size, num_heads, len_q, len_k)
  108. score = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.dim_head)
  109. score = score + position_bias
  110. score = torch.masked_fill(
  111. score,
  112. attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False),
  113. torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype),
  114. )
  115. score = self.softmax(score)
  116. score = torch.masked_fill(
  117. score,
  118. attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False),
  119. torch.scalar_tensor(0, device=score.device, dtype=score.dtype),
  120. )
  121. if output_attentions:
  122. attn_weights = score
  123. else:
  124. attn_weights = None
  125. if self.dropout is not None:
  126. score = self.dropout(score)
  127. # (batch_size, num_heads, len_q, len_k) @ (batch_size, num_heads, len_k, dim_head) -> (batch_size, num_heads, len_q, dim_head)
  128. score = torch.matmul(score, value)
  129. score = score.view(batch_size, self.num_heads, len_q, self.dim_head).permute(0, 2, 1, 3)
  130. score = score.contiguous().view(batch_size, len_q, self.num_heads * self.dim_head)
  131. score = self.attention_out(score)
  132. return score, attn_weights
  133. class CpmAntSelfAttentionBlock(nn.Module):
  134. def __init__(self, config: CpmAntConfig, layer_idx=None):
  135. super().__init__()
  136. self.layernorm_before_attention = CpmAntLayerNorm(config)
  137. self.self_attention = CpmAntAttention(config, layer_idx=layer_idx)
  138. if config.dropout_p:
  139. self.dropout = torch.nn.Dropout(config.dropout_p)
  140. else:
  141. self.dropout = None
  142. def forward(
  143. self,
  144. hidden_states: torch.Tensor,
  145. attention_mask: torch.Tensor,
  146. position_bias: Optional[torch.Tensor] = None,
  147. output_attentions: Optional[bool] = False,
  148. past_key_values: Optional[Cache] = None,
  149. use_cache: Optional[bool] = None,
  150. cache_position: Optional[torch.Tensor] = None,
  151. ):
  152. """
  153. Args:
  154. hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
  155. Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences.
  156. attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
  157. Avoid invalid areas to participate in the calculation of self-attention.
  158. position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
  159. Provide positional information to self-attention block.
  160. output_attentions (`bool`, *optional*):
  161. Whether or not to return the attentions tensors of all attention layers.
  162. past_key_values (`Cache`, *optional*):
  163. Cached past key and value projection states.
  164. use_cache (`bool`, *optional*):
  165. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  166. (see `past_key_values`).
  167. """
  168. outputs = self.layernorm_before_attention(hidden_states)
  169. outputs, attn_weights = self.self_attention(
  170. outputs,
  171. outputs,
  172. attention_mask,
  173. position_bias,
  174. output_attentions,
  175. past_key_values,
  176. use_cache,
  177. cache_position,
  178. )
  179. if self.dropout is not None:
  180. outputs = self.dropout(outputs)
  181. hidden_states = hidden_states + outputs
  182. return hidden_states, attn_weights
  183. class CpmAntDenseGatedACT(nn.Module):
  184. def __init__(self, config: CpmAntConfig):
  185. super().__init__()
  186. self.w_0 = nn.Linear(config.hidden_size, config.dim_ff, bias=False)
  187. self.w_1 = nn.Linear(config.hidden_size, config.dim_ff, bias=False)
  188. self.act = torch.nn.GELU()
  189. def forward(self, hidden_states: torch.Tensor):
  190. """Transform an input tensor from one feature space to another via a nonlinear operation
  191. Args:
  192. hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
  193. """
  194. gate_score = self.act(self.w_0(hidden_states))
  195. hidden_states = self.w_1(hidden_states)
  196. hidden_states = gate_score * hidden_states
  197. return hidden_states
  198. class CpmAntFeedForward(nn.Module):
  199. def __init__(self, config: CpmAntConfig):
  200. super().__init__()
  201. self.w_in = CpmAntDenseGatedACT(config)
  202. if config.dropout_p is not None:
  203. self.dropout = torch.nn.Dropout(config.dropout_p)
  204. else:
  205. self.dropout = None
  206. self.w_out = nn.Linear(config.dim_ff, config.hidden_size, bias=False)
  207. def forward(self, hidden_states: torch.Tensor):
  208. """
  209. Args:
  210. hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
  211. """
  212. hidden_states = self.w_in(hidden_states)
  213. if self.dropout is not None:
  214. hidden_states = self.dropout(hidden_states)
  215. hidden_states = self.w_out(hidden_states)
  216. return hidden_states
  217. class CpmAntFFNBlock(nn.Module):
  218. def __init__(self, config: CpmAntConfig):
  219. super().__init__()
  220. self.layernorm_before_ffn = CpmAntLayerNorm(config)
  221. self.ffn = CpmAntFeedForward(config)
  222. if config.dropout_p:
  223. self.dropout = torch.nn.Dropout(config.dropout_p)
  224. else:
  225. self.dropout = None
  226. def forward(
  227. self,
  228. hidden_states: torch.Tensor,
  229. ):
  230. """
  231. Args:
  232. hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
  233. Hidden states before feed forward layer.
  234. """
  235. ln_outputs = self.layernorm_before_ffn(hidden_states)
  236. outputs = self.ffn(ln_outputs)
  237. if self.dropout is not None:
  238. outputs = self.dropout(outputs)
  239. hidden_states = hidden_states + outputs
  240. return hidden_states
  241. class CpmAntTransformerBlock(nn.Module):
  242. def __init__(self, config: CpmAntConfig, layer_idx=None):
  243. super().__init__()
  244. self.self_att = CpmAntSelfAttentionBlock(config, layer_idx=layer_idx)
  245. self.ffn = CpmAntFFNBlock(config)
  246. def forward(
  247. self,
  248. hidden_states: torch.Tensor,
  249. attention_mask: torch.Tensor,
  250. position_bias: Optional[torch.Tensor] = None,
  251. output_attentions: Optional[bool] = False,
  252. past_key_values: Optional[Cache] = None,
  253. use_cache: Optional[bool] = None,
  254. cache_position: Optional[torch.Tensor] = None,
  255. ):
  256. """
  257. Args:
  258. hidden_states (`torch.Tensor`):
  259. Input to the layer of shape `(batch, seq_len, dim_model)`
  260. attention_mask (`torch.Tensor`):
  261. Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)`
  262. position_bias (`torch.Tensor`):
  263. Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)`
  264. output_attentions (`bool`, *optional*):
  265. Whether or not to return the attentions tensors of all attention layers.
  266. past_key_values (`Cache`, *optional*):
  267. Cached past key and value projection states
  268. use_cache (`bool`, *optional*):
  269. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  270. (see `past_key_values`).
  271. """
  272. hidden_states, attn_weights = self.self_att(
  273. hidden_states,
  274. attention_mask=attention_mask,
  275. position_bias=position_bias,
  276. output_attentions=output_attentions,
  277. past_key_values=past_key_values,
  278. use_cache=use_cache,
  279. cache_position=cache_position,
  280. )
  281. hidden_states = self.ffn(hidden_states)
  282. return hidden_states, attn_weights
  283. class CpmAntEncoder(nn.Module):
  284. def __init__(self, config: CpmAntConfig):
  285. super().__init__()
  286. self.num_layers = config.num_hidden_layers
  287. self.layers = nn.ModuleList([CpmAntTransformerBlock(config, layer_idx=i) for i in range(self.num_layers)])
  288. self.output_layernorm = CpmAntLayerNorm(config)
  289. def forward(
  290. self,
  291. hidden_states: torch.Tensor,
  292. attention_mask: torch.Tensor,
  293. position_bias: torch.Tensor,
  294. output_attentions: Optional[bool] = None,
  295. output_hidden_states: Optional[bool] = None,
  296. past_key_values: Optional[Cache] = None,
  297. use_cache: Optional[bool] = None,
  298. cache_position: Optional[torch.Tensor] = None,
  299. ):
  300. """
  301. Args:
  302. hidden_states (`torch.Tensor`):
  303. Input to the layer of shape `(batch, seq_len, dim_model)`
  304. attention_mask (`torch.Tensor`):
  305. Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)`
  306. position_bias (`torch.Tensor`):
  307. Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)`
  308. output_attentions (`bool`, *optional*):
  309. Whether or not to return the attentions tensors of all attention layers.
  310. output_hidden_states (`bool`, *optional*):
  311. Whether or not to return the hidden states of all layers.
  312. past_key_values (`Cache`, *optional*):
  313. Cached past key and value projection states
  314. use_cache (`bool`, *optional*):
  315. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  316. (see `past_key_values`).
  317. """
  318. all_hidden_states = () if output_hidden_states else None
  319. all_self_attns = () if output_attentions else None
  320. for i, layer in enumerate(self.layers):
  321. if output_hidden_states:
  322. all_hidden_states += (hidden_states,)
  323. layer_outputs = layer(
  324. hidden_states,
  325. attention_mask,
  326. position_bias,
  327. output_attentions=output_attentions,
  328. past_key_values=past_key_values,
  329. use_cache=use_cache,
  330. )
  331. hidden_states, attn_weights = layer_outputs
  332. if output_attentions:
  333. all_self_attns += (attn_weights,)
  334. hidden_states = self.output_layernorm(hidden_states)
  335. if output_hidden_states:
  336. all_hidden_states += (hidden_states,)
  337. return hidden_states, all_hidden_states, all_self_attns
  338. # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->CPMAnt
  339. class CpmAntIntermediate(nn.Module):
  340. def __init__(self, config):
  341. super().__init__()
  342. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  343. if isinstance(config.hidden_act, str):
  344. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  345. else:
  346. self.intermediate_act_fn = config.hidden_act
  347. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  348. hidden_states = self.dense(hidden_states)
  349. hidden_states = self.intermediate_act_fn(hidden_states)
  350. return hidden_states
  351. class CpmAntSegmentPositionEmbedding(nn.Module):
  352. def __init__(self, config: CpmAntConfig):
  353. super().__init__()
  354. self.num_heads = config.num_attention_heads
  355. self.num_buckets = config.position_bias_num_buckets
  356. self.max_distance = config.position_bias_max_distance
  357. self.num_segments = config.segment_types
  358. self.relative_attention_bias = nn.Parameter(
  359. torch.empty(
  360. config.segment_types * config.segment_types + config.position_bias_num_buckets,
  361. config.num_attention_heads,
  362. )
  363. )
  364. def forward(
  365. self,
  366. key_pos: torch.Tensor,
  367. query_pos: torch.Tensor,
  368. key_segment: torch.Tensor,
  369. query_segment: torch.Tensor,
  370. ):
  371. with torch.no_grad():
  372. batch = key_pos.size(0)
  373. keylen = key_pos.size(1)
  374. querylen = query_pos.size(1)
  375. if key_pos.size(0) != query_pos.size(0):
  376. raise AssertionError(
  377. f"key_pos.size(0) should be equal to query_pos.size(0), but got {key_pos.size(0)} and {query_pos.size(0)}!"
  378. )
  379. if keylen != key_segment.size(1) or querylen != query_segment.size(1):
  380. raise AssertionError(
  381. f"keylen should be equal to key_segment.size(1), but got {keylen} and {key_segment.size(1)}!"
  382. )
  383. if querylen != query_segment.size(1):
  384. raise AssertionError(
  385. f"querylen should be equal to query_segment.size(1), but got {querylen} and {query_segment.size(1)}!"
  386. )
  387. key_pos = key_pos.view(batch, -1, keylen)
  388. query_pos = query_pos.view(batch, querylen, -1)
  389. key_segment = key_segment.view(batch, -1, keylen)
  390. query_segment = query_segment.view(batch, querylen, -1)
  391. relative_position_bucket = self._segment_relative_position_bucket(query_segment, key_segment)
  392. relative_position_bucket = relative_position_bucket + self.num_buckets
  393. # (batch, len_q, len_k)
  394. absolute_position_bucket = self._position_bucket(
  395. torch.arange(keylen, dtype=torch.int32, device=relative_position_bucket.device)[None, :]
  396. - torch.arange(querylen, dtype=torch.int32, device=relative_position_bucket.device)[:, None],
  397. num_buckets=self.num_buckets,
  398. max_distance=self.max_distance,
  399. )
  400. relative_position_bucket = torch.where(
  401. (key_segment == query_segment),
  402. absolute_position_bucket[None, :, :],
  403. relative_position_bucket,
  404. )
  405. # (batch, len_q, len_k, num_heads)
  406. embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
  407. # (batch, num_heads, len_q, len_k)
  408. embeds = embeds.permute(0, 3, 1, 2).contiguous()
  409. return embeds
  410. def _segment_relative_position_bucket(self, query_segment, key_segment):
  411. return query_segment * self.num_segments + key_segment
  412. def _position_bucket(self, relative_position, num_buckets=32, max_distance=128):
  413. relative_buckets = 0
  414. # always bidirectional in CPMAnt
  415. num_buckets //= 2
  416. relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
  417. relative_position = torch.abs(relative_position)
  418. max_exact = num_buckets // 2
  419. is_small = relative_position < max_exact
  420. relative_position_if_large = max_exact + (
  421. torch.log(relative_position.float() / max_exact)
  422. / math.log(max_distance / max_exact)
  423. * (num_buckets - max_exact)
  424. ).to(torch.int32)
  425. relative_position_if_large = torch.min(
  426. relative_position_if_large,
  427. torch.full_like(relative_position_if_large, num_buckets - 1),
  428. )
  429. relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_position_if_large)
  430. return relative_buckets
  431. # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->CPMAnt
  432. class CpmAntOutput(nn.Module):
  433. def __init__(self, config):
  434. super().__init__()
  435. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  436. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  437. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  438. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  439. hidden_states = self.dense(hidden_states)
  440. hidden_states = self.dropout(hidden_states)
  441. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  442. return hidden_states
  443. @auto_docstring
  444. class CpmAntPreTrainedModel(PreTrainedModel):
  445. config: CpmAntConfig
  446. base_model_prefix = "cpmant"
  447. def _init_weights(self, module):
  448. """Initialize the weights"""
  449. if isinstance(module, nn.Linear):
  450. module.weight.data.normal_(mean=0.0, std=self.config.init_std)
  451. if module.bias is not None:
  452. module.bias.data.zero_()
  453. elif isinstance(module, nn.Embedding):
  454. module.weight.data.normal_(mean=0.0, std=self.config.init_std)
  455. if module.padding_idx is not None:
  456. module.weight.data[module.padding_idx].zero_()
  457. elif isinstance(module, nn.LayerNorm):
  458. module.bias.data.zero_()
  459. module.weight.data.fill_(1.0)
  460. elif isinstance(module, CpmAntLayerNorm):
  461. module.weight.data.fill_(1.0)
  462. elif isinstance(module, CpmAntSegmentPositionEmbedding):
  463. module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std)
@auto_docstring
class CpmAntModel(CpmAntPreTrainedModel):
    # Bare CPMAnt model: prepends a block of prompt token ids to the input, embeds tokens and segments,
    # runs the encoder with a segment-aware position bias, and returns the hidden states.

    def __init__(self, config: CpmAntConfig):
        super().__init__(config)
        self.encoder = CpmAntEncoder(config)
        self.segment_embedding = nn.Embedding(config.segment_types, config.hidden_size)
        # The embedding table holds the vocabulary plus `prompt_types * prompt_length` extra prompt ids.
        self.input_embedding = nn.Embedding(
            config.vocab_size + config.prompt_types * config.prompt_length, config.hidden_size
        )
        self.position_bias = CpmAntSegmentPositionEmbedding(config)
        self.prompt_length = config.prompt_length
        self.vocab_size = config.vocab_size

        self.post_init()

    def get_input_embeddings(self):
        return self.input_embedding

    def set_input_embeddings(self, embeddings, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        self.input_embedding = embeddings

    def _prepare_attention_mask(self, input_ids, span, context, length):
        # Builds a (batch, seqlen, seqlen) mask; `input_ids` is only read for its sizes and device.
        batch = input_ids.size(0)
        seqlen = input_ids.size(1)
        device = input_ids.device
        # Entry [i, j] is True when j <= i, i.e. a position may look at itself and at earlier positions.
        directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(-1, 1)
        # A key flagged in `context` is visible to every query; a query that is not flagged in `context`
        # additionally sees keys according to the directional mask.
        attention_mask = context[:, None, :] | (
            context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
        )
        # Attention is restricted to query/key pairs that carry the same `span` value.
        attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
        # mask for left padding
        # Over the non-prompt part, a descending counter is compared with `length`, so only the last
        # `length` positions of each row are kept.
        mask_1d = (
            torch.tensor(list(range(seqlen - self.prompt_length))[::-1], device=device)[None, :].repeat(batch, 1)
            < length[:, None]
        )
        # The first `prompt_length` positions are always kept.
        mask_1d = torch.cat((torch.ones(batch, self.prompt_length, device=device).bool(), mask_1d), dim=1)
        # A pair survives only if both its query position and its key position are kept.
        attention_mask = mask_1d.view(batch, seqlen, 1) & mask_1d.view(batch, 1, seqlen) & attention_mask
        return attention_mask

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPast]:
        r"""
        input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`CPMAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        """
        # Arguments left as None fall back to the corresponding config values.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # add prompts ahead
        # NOTE(review): `input_ids` defaults to None but is dereferenced unconditionally below.
        if input_ids.dtype != torch.int32:
            input_ids = input_ids.to(torch.int32)
        dtype, device = input_ids.dtype, input_ids.device
        # Token id 0 is treated as padding: it gets segment 0, every other token gets segment 2.
        segment = torch.where(input_ids != 0, 2, 0).to(dtype=dtype, device=device)
        # Number of non-padding tokens per row.
        length = (segment != 0).sum(-1).to(dtype=dtype, device=device)
        # Prepend `prompt_length` consecutive prompt ids, starting at `prompt_length * 2 + vocab_size`,
        # to every row.
        input_ids = torch.cat(
            (
                torch.arange(
                    self.prompt_length * 2 + self.vocab_size,
                    self.prompt_length * 3 + self.vocab_size,
                    dtype=dtype,
                    device=device,
                ).repeat(input_ids.size(0), 1),
                input_ids,
            ),
            dim=1,
        )
        batch, seq_length = input_ids.size()
        # The prepended prompt positions get segment 0.
        segment = torch.cat((torch.zeros(batch, self.prompt_length, dtype=dtype, device=device), segment), dim=1)
        # `context` is all ones and `span` is all zeros, so within `_prepare_attention_mask` only the
        # left-padding mask actually removes anything.
        context = torch.full((batch, seq_length), 1, dtype=dtype, device=device)
        position = torch.arange(seq_length, dtype=dtype, device=device).repeat(batch, 1)
        span = torch.full((batch, seq_length), 0, dtype=dtype, device=device)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)
        # Legacy tuple caches are converted, with a one-time deprecation warning.
        if use_cache and isinstance(past_key_values, tuple):
            logger.warning_once(
                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                "You should pass an instance of `DynamicCache` instead, e.g. "
                "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
            )
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        input_ids = input_ids.contiguous()
        hidden_states = self.input_embedding(input_ids)
        segment_states = self.segment_embedding(segment)
        # With a non-empty cache only the segment embedding of the last position is kept.
        if past_length != 0:
            segment_states = segment_states[:, -1:, :]

        hidden_states = hidden_states + segment_states

        # Mask and bias are built for the full (prompt + input) length and then cut down to the query
        # positions that are not yet covered by the cache.
        attention_mask = self._prepare_attention_mask(input_ids, span, context, length)
        position_bias = self.position_bias(position, position, segment, segment)

        attention_mask = attention_mask[:, past_length:, :]
        position_bias = position_bias[:, :, past_length:, :]
        hidden_states = hidden_states[:, past_length:, :]

        hidden_states, all_hidden_states, all_attentions = self.encoder(
            hidden_states,
            attention_mask,
            position_bias,
            output_attentions,
            output_hidden_states,
            past_key_values,
            use_cache,
            cache_position,
        )

        # The prompt positions are only present in the outputs when nothing was cached beforehand.
        if past_length == 0:
            hidden_states = hidden_states[:, self.prompt_length :, :]
            # drop the prompt
            if all_attentions is not None:
                new_attentions = ()
                for attention in all_attentions:
                    new_attentions += (attention[:, :, self.prompt_length :, self.prompt_length :],)
                all_attentions = new_attentions
            if all_hidden_states is not None:
                new_hidden_states = ()
                for hidden_state in all_hidden_states:
                    new_hidden_states += (hidden_state[:, self.prompt_length :, :],)
                all_hidden_states = new_hidden_states

        # In tuple mode, entries that are None are omitted, so positions depend on what was requested.
        if not return_dict:
            return tuple(
                v for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
  600. @auto_docstring(
  601. custom_intro="""
  602. The CPMAnt Model with a language modeling head on top (linear layer with weights tied to the input embeddings).
  603. """
  604. )
  605. class CpmAntForCausalLM(CpmAntPreTrainedModel, GenerationMixin):
  606. _tied_weights_keys = ["lm_head.weight"]
  607. def __init__(self, config: CpmAntConfig):
  608. super().__init__(config)
  609. self.cpmant = CpmAntModel(config)
  610. # lm_head.weight is tied to cpmant.input_embedding.weight
  611. self.lm_head = nn.Linear(
  612. config.hidden_size, config.vocab_size + config.prompt_types * config.prompt_length, bias=False
  613. )
  614. self.post_init()
  615. @auto_docstring
  616. def forward(
  617. self,
  618. input_ids: Optional[torch.Tensor] = None,
  619. past_key_values: Optional[Cache] = None,
  620. use_cache: Optional[bool] = None,
  621. output_attentions: Optional[bool] = None,
  622. output_hidden_states: Optional[bool] = None,
  623. labels: Optional[torch.Tensor] = None,
  624. return_dict: Optional[bool] = None,
  625. attention_mask: Optional[torch.Tensor] = None, # dummy parameter for text-generation pipeline
  626. cache_position: Optional[torch.Tensor] = None,
  627. **kwargs,
  628. ) -> Union[tuple, CausalLMOutputWithPast]:
  629. r"""
  630. input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
  631. Indices of input sequence tokens in the vocabulary.
  632. Indices can be obtained using [`CPMAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  633. [`PreTrainedTokenizer.__call__`] for details.
  634. [What are input IDs?](../glossary#input-ids)
  635. labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  636. Labels for computing the masked language modeling loss.
  637. Example:
  638. Text Generation with CpmAntForCausalLM.
  639. ```python
  640. >>> from transformers import CPMAntTokenizer, CpmAntForCausalLM
  641. >>> texts = "今天天气不错,"
  642. >>> model = CpmAntForCausalLM.from_pretrained("openbmb/cpm-ant-10b")
  643. >>> tokenizer = CPMAntTokenizer.from_pretrained("openbmb/cpm-ant-10b")
  644. >>> input_ids = tokenizer(texts, return_tensors="pt")
  645. >>> outputs = model.generate(**input_ids)
  646. >>> output_texts = tokenizer.batch_decode(outputs)
  647. >>> print(output_texts)
  648. ['今天天气不错,阳光明媚,我和妈妈一起去超市买东西。\n在超市里,我看到了一个很好玩的玩具,它的名字叫“机器人”。它有一个圆圆的脑袋,两只圆圆的眼睛,还有一个圆圆的']
  649. ```
  650. """
  651. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  652. model_output = self.cpmant(
  653. input_ids,
  654. output_attentions,
  655. output_hidden_states,
  656. past_key_values,
  657. use_cache,
  658. return_dict,
  659. cache_position,
  660. )
  661. hidden_states = model_output.last_hidden_state if return_dict else model_output[0]
  662. logits = self.lm_head(hidden_states)
  663. loss = None
  664. if labels is not None:
  665. loss_func = CrossEntropyLoss()
  666. loss = loss_func(logits.view(-1, logits.size(-1)), labels.view(-1))
  667. if not return_dict:
  668. output = (logits,) + model_output[1:]
  669. return ((loss,) + output) if loss is not None else output
  670. return CausalLMOutputWithPast(
  671. loss=loss,
  672. logits=logits,
  673. past_key_values=model_output.past_key_values,
  674. hidden_states=model_output.hidden_states,
  675. attentions=model_output.attentions,
  676. )
  677. def get_input_embeddings(self):
  678. return self.cpmant.input_embedding
  679. def set_input_embeddings(self, embeddings):
  680. self.cpmant.input_embedding = embeddings
  681. def _reorder_cache(self, past_key_values, beam_idx):
  682. past_key_values = [list(each) if each is not None else each for each in past_key_values]
  683. for key_value_layer in past_key_values:
  684. key_value_layer[0] = key_value_layer[0][beam_idx]
  685. key_value_layer[1] = key_value_layer[1][beam_idx]
  686. return past_key_values
# Names exported by `from <this module> import *`.
__all__ = ["CpmAntForCausalLM", "CpmAntModel", "CpmAntPreTrainedModel"]