modeling_layers.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from functools import partial
  15. from typing import Optional
  16. import torch
  17. import torch.nn as nn
  18. from .cache_utils import Cache
  19. from .modeling_outputs import (
  20. BaseModelOutputWithPast,
  21. QuestionAnsweringModelOutput,
  22. SequenceClassifierOutputWithPast,
  23. TokenClassifierOutput,
  24. )
  25. from .models.auto import AutoModel
  26. from .processing_utils import Unpack
  27. from .utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
# Module-level logger named after this module; the classes below use its `warning_once`
# method to report overridden caching arguments and undetectable padding.
logger = logging.get_logger(__name__)
  29. class GradientCheckpointingLayer(nn.Module):
  30. """Base class for layers with gradient checkpointing.
  31. This class enables gradient checkpointing functionality for a layer. By default, gradient checkpointing is disabled
  32. (`gradient_checkpointing = False`). When `model.set_gradient_checkpointing()` is called, gradient checkpointing is
  33. enabled by setting `gradient_checkpointing = True` and assigning a checkpointing function to `_gradient_checkpointing_func`.
  34. Important:
  35. When using gradient checkpointing with `use_reentrant=True`, inputs that require gradients (e.g. hidden states)
  36. must be passed as positional arguments (`*args`) rather than keyword arguments to properly propagate gradients.
  37. Example:
  38. ```python
  39. >>> # Correct - hidden_states passed as positional arg
  40. >>> out = self.layer(hidden_states, attention_mask=attention_mask)
  41. >>> # Incorrect - hidden_states passed as keyword arg
  42. >>> out = self.layer(hidden_states=hidden_states, attention_mask=attention_mask)
  43. ```
  44. """
  45. gradient_checkpointing = False
  46. def __call__(self, *args, **kwargs):
  47. if self.gradient_checkpointing and self.training:
  48. do_warn = False
  49. layer_name = self.__class__.__name__
  50. message = f"Caching is incompatible with gradient checkpointing in {layer_name}. Setting"
  51. if "use_cache" in kwargs and kwargs["use_cache"]:
  52. kwargs["use_cache"] = False
  53. message += " `use_cache=False`,"
  54. do_warn = True
  55. # different names for the same thing in different layers
  56. # TODO cyril: this one without `S` can be removed after deprection cycle
  57. if "past_key_value" in kwargs and kwargs["past_key_value"] is not None:
  58. kwargs["past_key_value"] = None
  59. message += " `past_key_value=None`,"
  60. do_warn = True
  61. if "past_key_values" in kwargs and kwargs["past_key_values"] is not None:
  62. kwargs["past_key_values"] = None
  63. message += " `past_key_values=None`,"
  64. do_warn = True
  65. if "layer_past" in kwargs and kwargs["layer_past"] is not None:
  66. kwargs["layer_past"] = None
  67. message += " `layer_past=None`,"
  68. do_warn = True
  69. # warn if anything was changed
  70. if do_warn:
  71. message = message.rstrip(",") + "."
  72. logger.warning_once(message)
  73. return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
  74. return super().__call__(*args, **kwargs)
  75. @auto_docstring
  76. class GenericForSequenceClassification:
  77. base_model_prefix = "model"
  78. def __init__(self, config):
  79. super().__init__(config)
  80. self.num_labels = config.num_labels
  81. # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class
  82. setattr(self, self.base_model_prefix, AutoModel.from_config(config))
  83. self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
  84. # Initialize weights and apply final processing
  85. self.post_init()
  86. @can_return_tuple
  87. @auto_docstring
  88. def forward(
  89. self,
  90. input_ids: Optional[torch.LongTensor] = None,
  91. attention_mask: Optional[torch.Tensor] = None,
  92. position_ids: Optional[torch.LongTensor] = None,
  93. past_key_values: Optional[Cache] = None,
  94. inputs_embeds: Optional[torch.FloatTensor] = None,
  95. labels: Optional[torch.LongTensor] = None,
  96. use_cache: Optional[bool] = None,
  97. **kwargs: Unpack[TransformersKwargs],
  98. ) -> SequenceClassifierOutputWithPast:
  99. transformer_outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)(
  100. input_ids,
  101. attention_mask=attention_mask,
  102. position_ids=position_ids,
  103. past_key_values=past_key_values,
  104. inputs_embeds=inputs_embeds,
  105. use_cache=use_cache,
  106. **kwargs,
  107. )
  108. hidden_states = transformer_outputs.last_hidden_state
  109. logits = self.score(hidden_states)
  110. if input_ids is not None:
  111. batch_size = input_ids.shape[0]
  112. else:
  113. batch_size = inputs_embeds.shape[0]
  114. if self.config.pad_token_id is None and batch_size != 1:
  115. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  116. if self.config.pad_token_id is None:
  117. last_non_pad_token = -1
  118. elif input_ids is not None:
  119. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  120. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  121. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  122. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  123. else:
  124. last_non_pad_token = -1
  125. logger.warning_once(
  126. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  127. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  128. )
  129. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  130. loss = None
  131. if labels is not None:
  132. loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
  133. return SequenceClassifierOutputWithPast(
  134. loss=loss,
  135. logits=pooled_logits,
  136. past_key_values=transformer_outputs.past_key_values,
  137. hidden_states=transformer_outputs.hidden_states,
  138. attentions=transformer_outputs.attentions,
  139. )
  140. @auto_docstring
  141. class GenericForQuestionAnswering:
  142. base_model_prefix = "model"
  143. def __init__(self, config):
  144. super().__init__(config)
  145. # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class
  146. setattr(self, self.base_model_prefix, AutoModel.from_config(config))
  147. self.qa_outputs = nn.Linear(config.hidden_size, 2)
  148. # Initialize weights and apply final processing
  149. self.post_init()
  150. def get_input_embeddings(self):
  151. return getattr(self, self.base_model_prefix).embed_tokens
  152. def set_input_embeddings(self, value):
  153. getattr(self, self.base_model_prefix).embed_tokens = value
  154. @can_return_tuple
  155. @auto_docstring
  156. def forward(
  157. self,
  158. input_ids: Optional[torch.LongTensor] = None,
  159. attention_mask: Optional[torch.Tensor] = None,
  160. position_ids: Optional[torch.LongTensor] = None,
  161. past_key_values: Optional[Cache] = None,
  162. inputs_embeds: Optional[torch.FloatTensor] = None,
  163. start_positions: Optional[torch.LongTensor] = None,
  164. end_positions: Optional[torch.LongTensor] = None,
  165. **kwargs: Unpack[TransformersKwargs],
  166. ) -> QuestionAnsweringModelOutput:
  167. outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)(
  168. input_ids,
  169. attention_mask=attention_mask,
  170. position_ids=position_ids,
  171. past_key_values=past_key_values,
  172. inputs_embeds=inputs_embeds,
  173. **kwargs,
  174. )
  175. sequence_output = outputs.last_hidden_state
  176. logits = self.qa_outputs(sequence_output)
  177. start_logits, end_logits = logits.split(1, dim=-1)
  178. start_logits = start_logits.squeeze(-1).contiguous()
  179. end_logits = end_logits.squeeze(-1).contiguous()
  180. loss = None
  181. if start_positions is not None and end_positions is not None:
  182. loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
  183. return QuestionAnsweringModelOutput(
  184. loss=loss,
  185. start_logits=start_logits,
  186. end_logits=end_logits,
  187. hidden_states=outputs.hidden_states,
  188. attentions=outputs.attentions,
  189. )
  190. @auto_docstring
  191. class GenericForTokenClassification:
  192. base_model_prefix = "model"
  193. def __init__(self, config):
  194. super().__init__(config)
  195. self.num_labels = config.num_labels
  196. # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class
  197. setattr(self, self.base_model_prefix, AutoModel.from_config(config))
  198. if getattr(config, "classifier_dropout", None) is not None:
  199. classifier_dropout = config.classifier_dropout
  200. elif getattr(config, "hidden_dropout", None) is not None:
  201. classifier_dropout = config.hidden_dropout
  202. else:
  203. classifier_dropout = 0.1
  204. self.dropout = nn.Dropout(classifier_dropout)
  205. self.score = nn.Linear(config.hidden_size, config.num_labels)
  206. # Initialize weights and apply final processing
  207. self.post_init()
  208. @can_return_tuple
  209. @auto_docstring
  210. def forward(
  211. self,
  212. input_ids: Optional[torch.LongTensor] = None,
  213. attention_mask: Optional[torch.Tensor] = None,
  214. position_ids: Optional[torch.LongTensor] = None,
  215. past_key_values: Optional[Cache] = None,
  216. inputs_embeds: Optional[torch.FloatTensor] = None,
  217. labels: Optional[torch.LongTensor] = None,
  218. use_cache: Optional[bool] = None,
  219. **kwargs: Unpack[TransformersKwargs],
  220. ) -> TokenClassifierOutput:
  221. outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)(
  222. input_ids,
  223. attention_mask=attention_mask,
  224. position_ids=position_ids,
  225. past_key_values=past_key_values,
  226. inputs_embeds=inputs_embeds,
  227. use_cache=use_cache,
  228. **kwargs,
  229. )
  230. sequence_output = outputs.last_hidden_state
  231. sequence_output = self.dropout(sequence_output)
  232. logits = self.score(sequence_output)
  233. loss = None
  234. if labels is not None:
  235. loss = self.loss_function(logits, labels, self.config)
  236. return TokenClassifierOutput(
  237. loss=loss,
  238. logits=logits,
  239. hidden_states=outputs.hidden_states,
  240. attentions=outputs.attentions,
  241. )