# modeling_ctrl.py
# coding=utf-8
# Copyright 2018 Salesforce and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  16. """PyTorch CTRL model."""
  17. from typing import Optional, Union
  18. import numpy as np
  19. import torch
  20. from torch import nn
  21. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  22. from ...cache_utils import Cache, DynamicCache
  23. from ...generation import GenerationMixin
  24. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput
  25. from ...modeling_utils import PreTrainedModel
  26. from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_linear_layer
  27. from ...utils import (
  28. auto_docstring,
  29. logging,
  30. )
  31. from .configuration_ctrl import CTRLConfig
# Module-level logger obtained from the package logging helper, keyed by this module's name.
logger = logging.get_logger(__name__)
  33. def angle_defn(pos, i, d_model_size):
  34. angle_rates = 1 / torch.pow(10000, (2 * (i // 2)) / d_model_size)
  35. return pos * angle_rates
  36. def positional_encoding(position, d_model_size, dtype):
  37. # create the sinusoidal pattern for the positional encoding
  38. angle_rads = angle_defn(
  39. torch.arange(position, dtype=torch.int64).to(dtype).unsqueeze(1),
  40. torch.arange(d_model_size, dtype=torch.int64).to(dtype).unsqueeze(0),
  41. d_model_size,
  42. )
  43. sines = torch.sin(angle_rads[:, 0::2])
  44. cosines = torch.cos(angle_rads[:, 1::2])
  45. pos_encoding = torch.cat([sines, cosines], dim=-1)
  46. return pos_encoding
  47. def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
  48. # calculate attention
  49. matmul_qk = torch.matmul(q, k.permute(0, 1, 3, 2))
  50. dk = k.shape[-1]
  51. scaled_attention_logits = matmul_qk / np.sqrt(dk)
  52. if mask is not None:
  53. nd, ns = scaled_attention_logits.size(-2), scaled_attention_logits.size(-1)
  54. scaled_attention_logits += mask[ns - nd : ns, :ns] * -1e4
  55. if attention_mask is not None:
  56. # Apply the attention mask
  57. scaled_attention_logits = scaled_attention_logits + attention_mask
  58. attention_weights = torch.softmax(scaled_attention_logits, dim=-1)
  59. # Mask heads if we want to
  60. if head_mask is not None:
  61. attention_weights = attention_weights * head_mask
  62. output = torch.matmul(attention_weights, v)
  63. return output, attention_weights
  64. class MultiHeadAttention(nn.Module):
  65. def __init__(self, d_model_size, num_heads, layer_idx=None):
  66. super().__init__()
  67. self.num_heads = num_heads
  68. self.d_model_size = d_model_size
  69. self.layer_idx = layer_idx
  70. self.depth = int(d_model_size / self.num_heads)
  71. self.Wq = nn.Linear(d_model_size, d_model_size)
  72. self.Wk = nn.Linear(d_model_size, d_model_size)
  73. self.Wv = nn.Linear(d_model_size, d_model_size)
  74. self.dense = nn.Linear(d_model_size, d_model_size)
  75. self.pruned_heads = set()
  76. def prune_heads(self, heads):
  77. attention_head_size = self.d_model_size // self.num_heads
  78. if len(heads) == 0:
  79. return
  80. heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, attention_head_size, self.pruned_heads)
  81. # Prune linear layers
  82. self.Wq = prune_linear_layer(self.Wq, index)
  83. self.Wk = prune_linear_layer(self.Wk, index)
  84. self.Wv = prune_linear_layer(self.Wv, index)
  85. self.dense = prune_linear_layer(self.dense, index, dim=1)
  86. # Update hyper params
  87. self.num_heads = self.num_heads - len(heads)
  88. self.d_model_size = attention_head_size * self.num_heads
  89. self.pruned_heads = self.pruned_heads.union(heads)
  90. def split_into_heads(self, x, batch_size):
  91. x = x.reshape(batch_size, -1, self.num_heads, self.depth)
  92. return x.permute([0, 2, 1, 3])
  93. def forward(
  94. self,
  95. v,
  96. k,
  97. q,
  98. mask,
  99. layer_past=None,
  100. attention_mask=None,
  101. head_mask=None,
  102. use_cache=False,
  103. output_attentions=False,
  104. cache_position=None,
  105. ):
  106. batch_size = q.shape[0]
  107. q = self.Wq(q)
  108. k = self.Wk(k)
  109. v = self.Wv(v)
  110. q = self.split_into_heads(q, batch_size)
  111. k = self.split_into_heads(k, batch_size)
  112. v = self.split_into_heads(v, batch_size)
  113. if layer_past is not None:
  114. k, v = layer_past.update(k, v, self.layer_idx, {"cache_position": cache_position})
  115. output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
  116. scaled_attention = output[0].permute([0, 2, 1, 3])
  117. attn = output[1]
  118. original_size_attention = scaled_attention.reshape(batch_size, -1, self.d_model_size)
  119. output = self.dense(original_size_attention)
  120. return output, attn
  121. def point_wise_feed_forward_network(d_model_size, dff):
  122. return nn.Sequential(nn.Linear(d_model_size, dff), nn.ReLU(), nn.Linear(dff, d_model_size))
  123. class EncoderLayer(nn.Module):
  124. def __init__(self, d_model_size, num_heads, dff, rate=0.1, layer_idx=None):
  125. super().__init__()
  126. self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads, layer_idx=layer_idx)
  127. self.ffn = point_wise_feed_forward_network(d_model_size, dff)
  128. self.layernorm1 = nn.LayerNorm(d_model_size, eps=1e-6)
  129. self.layernorm2 = nn.LayerNorm(d_model_size, eps=1e-6)
  130. self.dropout1 = nn.Dropout(rate)
  131. self.dropout2 = nn.Dropout(rate)
  132. def forward(
  133. self,
  134. x,
  135. mask,
  136. layer_past=None,
  137. attention_mask=None,
  138. head_mask=None,
  139. use_cache=False,
  140. output_attentions=False,
  141. cache_position=None,
  142. ):
  143. normed = self.layernorm1(x)
  144. attn_outputs = self.multi_head_attention(
  145. normed,
  146. normed,
  147. normed,
  148. mask,
  149. layer_past=layer_past,
  150. attention_mask=attention_mask,
  151. head_mask=head_mask,
  152. use_cache=use_cache,
  153. output_attentions=output_attentions,
  154. cache_position=cache_position,
  155. )
  156. attn_output = attn_outputs[0]
  157. attn_output = self.dropout1(attn_output)
  158. out1 = x + attn_output
  159. out2 = self.layernorm2(out1)
  160. ffn_output = self.ffn(out2)
  161. ffn_output = self.dropout2(ffn_output)
  162. out2 = out1 + ffn_output
  163. outputs = (out2,) + attn_outputs[1:]
  164. return outputs
  165. @auto_docstring
  166. class CTRLPreTrainedModel(PreTrainedModel):
  167. config: CTRLConfig
  168. base_model_prefix = "transformer"
  169. def _init_weights(self, module):
  170. """Initialize the weights."""
  171. if isinstance(module, (nn.Linear, Conv1D)):
  172. # Slightly different from the TF version which uses truncated_normal for initialization
  173. # cf https://github.com/pytorch/pytorch/pull/5617
  174. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  175. if module.bias is not None:
  176. module.bias.data.zero_()
  177. elif isinstance(module, nn.Embedding):
  178. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  179. if module.padding_idx is not None:
  180. module.weight.data[module.padding_idx].zero_()
  181. elif isinstance(module, nn.LayerNorm):
  182. module.bias.data.zero_()
  183. module.weight.data.fill_(1.0)
  184. @auto_docstring
  185. class CTRLModel(CTRLPreTrainedModel):
  186. def __init__(self, config):
  187. super().__init__(config)
  188. self.d_model_size = config.n_embd
  189. self.num_layers = config.n_layer
  190. self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size, torch.float)
  191. self.w = nn.Embedding(config.vocab_size, config.n_embd)
  192. self.dropout = nn.Dropout(config.embd_pdrop)
  193. self.h = nn.ModuleList(
  194. [
  195. EncoderLayer(config.n_embd, config.n_head, config.dff, config.resid_pdrop, layer_idx=i)
  196. for i in range(config.n_layer)
  197. ]
  198. )
  199. self.layernorm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
  200. # Initialize weights and apply final processing
  201. self.post_init()
  202. def get_input_embeddings(self):
  203. return self.w
  204. def set_input_embeddings(self, new_embeddings):
  205. self.w = new_embeddings
  206. def _prune_heads(self, heads_to_prune):
  207. """
  208. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
  209. """
  210. for layer, heads in heads_to_prune.items():
  211. self.h[layer].multi_head_attention.prune_heads(heads)
  212. @auto_docstring
  213. def forward(
  214. self,
  215. input_ids: Optional[torch.LongTensor] = None,
  216. past_key_values: Optional[Cache] = None,
  217. attention_mask: Optional[torch.FloatTensor] = None,
  218. token_type_ids: Optional[torch.LongTensor] = None,
  219. position_ids: Optional[torch.LongTensor] = None,
  220. head_mask: Optional[torch.FloatTensor] = None,
  221. inputs_embeds: Optional[torch.FloatTensor] = None,
  222. use_cache: Optional[bool] = None,
  223. output_attentions: Optional[bool] = None,
  224. output_hidden_states: Optional[bool] = None,
  225. return_dict: Optional[bool] = None,
  226. cache_position: Optional[torch.Tensor] = None,
  227. **kwargs, # NOOP kwargs, for now
  228. ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPast]:
  229. r"""
  230. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  231. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0].shape[-2]`
  232. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  233. If `past_key_values` is used, only input IDs that do not have their past calculated should be passed as
  234. `input_ids`.
  235. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
  236. [`PreTrainedTokenizer.encode`] for details.
  237. [What are input IDs?](../glossary#input-ids)
  238. Example:
  239. ```python
  240. >>> from transformers import AutoTokenizer, CTRLModel
  241. >>> import torch
  242. >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
  243. >>> model = CTRLModel.from_pretrained("Salesforce/ctrl")
  244. >>> # CTRL was trained with control codes as the first token
  245. >>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
  246. >>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()
  247. >>> outputs = model(**inputs)
  248. >>> last_hidden_states = outputs.last_hidden_state
  249. >>> list(last_hidden_states.shape)
  250. [1, 5, 1280]
  251. ```"""
  252. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  253. use_cache = use_cache if use_cache is not None else self.config.use_cache
  254. output_hidden_states = (
  255. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  256. )
  257. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  258. if input_ids is not None and inputs_embeds is not None:
  259. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  260. elif input_ids is not None:
  261. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  262. input_shape = input_ids.size()
  263. input_ids = input_ids.view(-1, input_shape[-1])
  264. batch_size = input_ids.shape[0]
  265. elif inputs_embeds is not None:
  266. input_shape = inputs_embeds.size()[:-1]
  267. batch_size = inputs_embeds.shape[0]
  268. else:
  269. raise ValueError("You have to specify either input_ids or inputs_embeds")
  270. device = input_ids.device if input_ids is not None else inputs_embeds.device
  271. if use_cache and past_key_values is None:
  272. past_key_values = DynamicCache(config=self.config)
  273. if use_cache and isinstance(past_key_values, tuple):
  274. logger.warning_once(
  275. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  276. "You should pass an instance of `DynamicCache` instead, e.g. "
  277. "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
  278. )
  279. past_key_values = DynamicCache.from_legacy_cache(past_key_values)
  280. past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  281. if position_ids is None:
  282. position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
  283. position_ids = position_ids.unsqueeze(0)
  284. # Attention mask.
  285. if attention_mask is not None:
  286. if batch_size <= 0:
  287. raise ValueError("batch_size has to be defined and > 0")
  288. attention_mask = attention_mask.view(batch_size, -1)
  289. # We create a 3D attention mask from a 2D tensor mask.
  290. # Sizes are [batch_size, 1, 1, to_seq_length]
  291. # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
  292. # this attention mask is more simple than the triangular masking of causal attention
  293. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
  294. attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
  295. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  296. # masked positions, this operation will create a tensor which is 0.0 for
  297. # positions we want to attend and the dtype's smallest value for masked positions.
  298. # Since we are adding it to the raw scores before the softmax, this is
  299. # effectively the same as removing these entirely.
  300. attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
  301. attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
  302. # Prepare head mask if needed
  303. head_mask = self.get_head_mask(head_mask, self.config.n_layer)
  304. if token_type_ids is not None:
  305. token_type_ids = token_type_ids.view(-1, input_shape[-1])
  306. token_type_embeds = self.w(token_type_ids)
  307. token_type_embeds *= np.sqrt(self.d_model_size)
  308. else:
  309. token_type_embeds = 0
  310. if inputs_embeds is None:
  311. inputs_embeds = self.w(input_ids)
  312. # inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
  313. seq_len = input_shape[-1]
  314. mask = torch.triu(torch.ones(seq_len + past_length, seq_len + past_length), 1).to(device)
  315. inputs_embeds *= np.sqrt(self.d_model_size)
  316. # `self.pos_encoding` won't be sent to the correct device along the model, so we do it manually.
  317. self.pos_encoding = self.pos_encoding.to(device)
  318. pos_embeds = self.pos_encoding[position_ids, :]
  319. hidden_states = inputs_embeds + pos_embeds + token_type_embeds
  320. hidden_states = self.dropout(hidden_states)
  321. all_hidden_states = () if output_hidden_states else None
  322. all_attentions = () if output_attentions else None
  323. for i, h in enumerate(self.h):
  324. if output_hidden_states:
  325. all_hidden_states = all_hidden_states + (hidden_states,)
  326. outputs = h(
  327. hidden_states,
  328. mask,
  329. layer_past=past_key_values,
  330. attention_mask=attention_mask,
  331. head_mask=head_mask[i],
  332. use_cache=use_cache,
  333. output_attentions=output_attentions,
  334. cache_position=cache_position,
  335. )
  336. hidden_states = outputs[0]
  337. if output_attentions:
  338. all_attentions += (outputs[1],)
  339. hidden_states = self.layernorm(hidden_states)
  340. if output_hidden_states:
  341. all_hidden_states = all_hidden_states + (hidden_states,)
  342. if not return_dict:
  343. return tuple(
  344. v for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] if v is not None
  345. )
  346. return BaseModelOutputWithPast(
  347. last_hidden_state=hidden_states,
  348. past_key_values=past_key_values,
  349. hidden_states=all_hidden_states,
  350. attentions=all_attentions,
  351. )
  352. @auto_docstring(
  353. custom_intro="""
  354. The CTRL Model transformer with a language modeling head on top (linear layer with weights tied to the input
  355. embeddings).
  356. """
  357. )
  358. class CTRLLMHeadModel(CTRLPreTrainedModel, GenerationMixin):
  359. _tied_weights_keys = ["lm_head.weight"]
  360. def __init__(self, config):
  361. super().__init__(config)
  362. self.transformer = CTRLModel(config)
  363. self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=True)
  364. # Initialize weights and apply final processing
  365. self.post_init()
  366. @auto_docstring
  367. def forward(
  368. self,
  369. input_ids: Optional[torch.LongTensor] = None,
  370. past_key_values: Optional[Cache] = None,
  371. attention_mask: Optional[torch.FloatTensor] = None,
  372. token_type_ids: Optional[torch.LongTensor] = None,
  373. position_ids: Optional[torch.LongTensor] = None,
  374. head_mask: Optional[torch.FloatTensor] = None,
  375. inputs_embeds: Optional[torch.FloatTensor] = None,
  376. labels: Optional[torch.LongTensor] = None,
  377. use_cache: Optional[bool] = None,
  378. output_attentions: Optional[bool] = None,
  379. output_hidden_states: Optional[bool] = None,
  380. return_dict: Optional[bool] = None,
  381. cache_position: Optional[torch.Tensor] = None,
  382. **kwargs,
  383. ) -> Union[tuple[torch.Tensor], CausalLMOutputWithPast]:
  384. r"""
  385. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  386. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0].shape[-2]`
  387. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  388. If `past_key_values` is used, only input IDs that do not have their past calculated should be passed as
  389. `input_ids`.
  390. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
  391. [`PreTrainedTokenizer.encode`] for details.
  392. [What are input IDs?](../glossary#input-ids)
  393. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  394. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  395. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  396. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  397. Example:
  398. ```python
  399. >>> import torch
  400. >>> from transformers import AutoTokenizer, CTRLLMHeadModel
  401. >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
  402. >>> model = CTRLLMHeadModel.from_pretrained("Salesforce/ctrl")
  403. >>> # CTRL was trained with control codes as the first token
  404. >>> inputs = tokenizer("Wikipedia The llama is", return_tensors="pt")
  405. >>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()
  406. >>> sequence_ids = model.generate(inputs["input_ids"])
  407. >>> sequences = tokenizer.batch_decode(sequence_ids)
  408. >>> sequences
  409. ['Wikipedia The llama is a member of the family Bovidae. It is native to the Andes of Peru,']
  410. >>> outputs = model(**inputs, labels=inputs["input_ids"])
  411. >>> round(outputs.loss.item(), 2)
  412. 9.21
  413. >>> list(outputs.logits.shape)
  414. [1, 5, 246534]
  415. ```"""
  416. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  417. transformer_outputs = self.transformer(
  418. input_ids,
  419. past_key_values=past_key_values,
  420. attention_mask=attention_mask,
  421. token_type_ids=token_type_ids,
  422. position_ids=position_ids,
  423. head_mask=head_mask,
  424. inputs_embeds=inputs_embeds,
  425. use_cache=use_cache,
  426. output_attentions=output_attentions,
  427. output_hidden_states=output_hidden_states,
  428. return_dict=return_dict,
  429. cache_position=cache_position,
  430. )
  431. hidden_states = transformer_outputs[0]
  432. lm_logits = self.lm_head(hidden_states)
  433. loss = None
  434. if labels is not None:
  435. loss = self.loss_function(
  436. lm_logits,
  437. labels,
  438. vocab_size=self.config.vocab_size,
  439. **kwargs,
  440. )
  441. if not return_dict:
  442. output = (lm_logits,) + transformer_outputs[1:]
  443. return ((loss,) + output) if loss is not None else output
  444. return CausalLMOutputWithPast(
  445. loss=loss,
  446. logits=lm_logits,
  447. past_key_values=transformer_outputs.past_key_values,
  448. hidden_states=transformer_outputs.hidden_states,
  449. attentions=transformer_outputs.attentions,
  450. )
  451. def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs):
  452. # Overwritten -- inputs_embeds not working properly
  453. # only last tokens for inputs_ids if past is defined in kwargs
  454. if past_key_values is not None:
  455. past_length = past_key_values.get_seq_length()
  456. # Some generation methods already pass only the last input ID
  457. if input_ids.shape[1] > past_length:
  458. remove_prefix_length = past_length
  459. else:
  460. # Default to old behavior: keep only final ID
  461. remove_prefix_length = input_ids.shape[1] - 1
  462. input_ids = input_ids[:, remove_prefix_length:]
  463. model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache}
  464. # token_type_ids are computed on CTRLModel.forward()
  465. kwargs.pop("token_type_ids", None)
  466. # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
  467. for key, value in kwargs.items():
  468. if key not in model_inputs:
  469. print(f"Warning: {key} is not a recognized input.")
  470. model_inputs[key] = value
  471. return model_inputs
  472. @auto_docstring(
  473. custom_intro="""
  474. The CTRL Model transformer with a sequence classification head on top (linear layer).
  475. [`CTRLForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  476. (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the last
  477. token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in
  478. each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
  479. guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last
  480. value in each row of the batch).
  481. """
  482. )
  483. class CTRLForSequenceClassification(CTRLPreTrainedModel):
  484. def __init__(self, config):
  485. super().__init__(config)
  486. self.num_labels = config.num_labels
  487. self.transformer = CTRLModel(config)
  488. self.classifier = nn.Linear(config.n_embd, self.num_labels, bias=False)
  489. # Initialize weights and apply final processing
  490. self.post_init()
  491. @auto_docstring
  492. def forward(
  493. self,
  494. input_ids: Optional[torch.LongTensor] = None,
  495. past_key_values: Optional[Cache] = None,
  496. attention_mask: Optional[torch.FloatTensor] = None,
  497. token_type_ids: Optional[torch.LongTensor] = None,
  498. position_ids: Optional[torch.LongTensor] = None,
  499. head_mask: Optional[torch.FloatTensor] = None,
  500. inputs_embeds: Optional[torch.FloatTensor] = None,
  501. labels: Optional[torch.LongTensor] = None,
  502. use_cache: Optional[bool] = None,
  503. output_attentions: Optional[bool] = None,
  504. output_hidden_states: Optional[bool] = None,
  505. return_dict: Optional[bool] = None,
  506. ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
  507. r"""
  508. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  509. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0].shape[-2]`
  510. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  511. If `past_key_values` is used, only input IDs that do not have their past calculated should be passed as
  512. `input_ids`.
  513. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
  514. [`PreTrainedTokenizer.encode`] for details.
  515. [What are input IDs?](../glossary#input-ids)
  516. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  517. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  518. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  519. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  520. Example of single-label classification:
  521. ```python
  522. >>> import torch
  523. >>> from transformers import AutoTokenizer, CTRLForSequenceClassification
  524. >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
  525. >>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl")
  526. >>> # CTRL was trained with control codes as the first token
  527. >>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
  528. >>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()
  529. >>> with torch.no_grad():
  530. ... logits = model(**inputs).logits
  531. >>> predicted_class_id = logits.argmax().item()
  532. >>> model.config.id2label[predicted_class_id]
  533. 'LABEL_0'
  534. ```
  535. ```python
  536. >>> import torch
  537. >>> torch.manual_seed(42) # doctest: +IGNORE_RESULT
  538. >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
  539. >>> num_labels = len(model.config.id2label)
  540. >>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl", num_labels=num_labels)
  541. >>> labels = torch.tensor(1)
  542. >>> loss = model(**inputs, labels=labels).loss
  543. >>> round(loss.item(), 2)
  544. 0.93
  545. ```
  546. Example of multi-label classification:
  547. ```python
  548. >>> import torch
  549. >>> from transformers import AutoTokenizer, CTRLForSequenceClassification
  550. >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
  551. >>> model = CTRLForSequenceClassification.from_pretrained(
  552. ... "Salesforce/ctrl", problem_type="multi_label_classification"
  553. ... )
  554. >>> # CTRL was trained with control codes as the first token
  555. >>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
  556. >>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()
  557. >>> with torch.no_grad():
  558. ... logits = model(**inputs).logits
  559. >>> predicted_class_id = logits.argmax().item()
  560. >>> model.config.id2label[predicted_class_id]
  561. 'LABEL_0'
  562. ```
  563. ```python
  564. >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
  565. >>> num_labels = len(model.config.id2label)
  566. >>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl", num_labels=num_labels)
  567. >>> num_labels = len(model.config.id2label)
  568. >>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to(
  569. ... torch.float
  570. ... )
  571. >>> loss = model(**inputs, labels=labels).loss
  572. >>> loss.backward() # doctest: +IGNORE_RESULT
  573. ```"""
  574. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  575. transformer_outputs = self.transformer(
  576. input_ids,
  577. past_key_values=past_key_values,
  578. attention_mask=attention_mask,
  579. token_type_ids=token_type_ids,
  580. position_ids=position_ids,
  581. head_mask=head_mask,
  582. inputs_embeds=inputs_embeds,
  583. use_cache=use_cache,
  584. output_attentions=output_attentions,
  585. output_hidden_states=output_hidden_states,
  586. return_dict=return_dict,
  587. )
  588. hidden_states = transformer_outputs[0]
  589. logits = self.classifier(hidden_states)
  590. if input_ids is not None:
  591. batch_size, sequence_length = input_ids.shape[:2]
  592. else:
  593. batch_size, sequence_length = inputs_embeds.shape[:2]
  594. if self.config.pad_token_id is None and batch_size != 1:
  595. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  596. if self.config.pad_token_id is None:
  597. last_non_pad_token = -1
  598. elif input_ids is not None:
  599. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  600. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  601. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  602. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  603. else:
  604. last_non_pad_token = -1
  605. logger.warning_once(
  606. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  607. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  608. )
  609. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  610. loss = None
  611. if labels is not None:
  612. if self.config.problem_type is None:
  613. if self.num_labels == 1:
  614. self.config.problem_type = "regression"
  615. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  616. self.config.problem_type = "single_label_classification"
  617. else:
  618. self.config.problem_type = "multi_label_classification"
  619. if self.config.problem_type == "regression":
  620. loss_fct = MSELoss()
  621. if self.num_labels == 1:
  622. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  623. else:
  624. loss = loss_fct(pooled_logits, labels)
  625. elif self.config.problem_type == "single_label_classification":
  626. loss_fct = CrossEntropyLoss()
  627. loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
  628. elif self.config.problem_type == "multi_label_classification":
  629. loss_fct = BCEWithLogitsLoss()
  630. loss = loss_fct(pooled_logits, labels)
  631. if not return_dict:
  632. output = (pooled_logits,) + transformer_outputs[2:]
  633. return ((loss,) + output) if loss is not None else output
  634. return SequenceClassifierOutput(
  635. loss=loss,
  636. logits=pooled_logits,
  637. hidden_states=transformer_outputs.hidden_states,
  638. attentions=transformer_outputs.attentions,
  639. )
# Names exported by `from <module> import *`: the four model classes defined in this file.
__all__ = ["CTRLForSequenceClassification", "CTRLLMHeadModel", "CTRLModel", "CTRLPreTrainedModel"]