modeling_flax_outputs.py 41 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700
  1. # Copyright 2021 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. import flax
  16. import jax.numpy as jnp
  17. from .utils import ModelOutput
  18. @flax.struct.dataclass
  19. class FlaxBaseModelOutput(ModelOutput):
  20. """
  21. Base class for model's outputs, with potential hidden states and attentions.
  22. Args:
  23. last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
  24. Sequence of hidden-states at the output of the last layer of the model.
  25. hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  26. Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
  27. `(batch_size, sequence_length, hidden_size)`.
  28. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
  29. attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  30. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  31. sequence_length)`.
  32. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  33. heads.
  34. """
  35. last_hidden_state: Optional[jnp.ndarray] = None
  36. hidden_states: Optional[tuple[jnp.ndarray]] = None
  37. attentions: Optional[tuple[jnp.ndarray]] = None
  38. @flax.struct.dataclass
  39. class FlaxBaseModelOutputWithNoAttention(ModelOutput):
  40. """
  41. Base class for model's outputs, with potential hidden states.
  42. Args:
  43. last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
  44. Sequence of hidden-states at the output of the last layer of the model.
  45. hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  46. Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
  47. for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the
  48. model at the output of each layer plus the optional initial embedding outputs.
  49. """
  50. last_hidden_state: Optional[jnp.ndarray] = None
  51. hidden_states: Optional[tuple[jnp.ndarray]] = None
  52. @flax.struct.dataclass
  53. class FlaxBaseModelOutputWithPoolingAndNoAttention(ModelOutput):
  54. """
  55. Base class for model's outputs that also contains a pooling of the last hidden states.
  56. Args:
  57. last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
  58. Sequence of hidden-states at the output of the last layer of the model.
  59. pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
  60. Last layer hidden-state after a pooling operation on the spatial dimensions.
  61. hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  62. Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
  63. for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the
  64. model at the output of each layer plus the optional initial embedding outputs.
  65. """
  66. last_hidden_state: Optional[jnp.ndarray] = None
  67. pooler_output: Optional[jnp.ndarray] = None
  68. hidden_states: Optional[tuple[jnp.ndarray]] = None
  69. @flax.struct.dataclass
  70. class FlaxImageClassifierOutputWithNoAttention(ModelOutput):
  71. """
  72. Base class for outputs of image classification models.
  73. Args:
  74. logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
  75. Classification (or regression if config.num_labels==1) scores (before SoftMax).
  76. hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when
  77. `config.output_hidden_states=True`):
  78. Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
  79. for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
  80. called feature maps) of the model at the output of each stage.
  81. """
  82. logits: Optional[jnp.ndarray] = None
  83. hidden_states: Optional[tuple[jnp.ndarray]] = None
  84. @flax.struct.dataclass
  85. class FlaxBaseModelOutputWithPast(ModelOutput):
  86. """
  87. Base class for model's outputs, with potential hidden states and attentions.
  88. Args:
  89. last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
  90. Sequence of hidden-states at the output of the last layer of the model.
  91. past_key_values (`dict[str, jnp.ndarray]`):
  92. Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
  93. auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
  94. hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  95. Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
  96. `(batch_size, sequence_length, hidden_size)`.
  97. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
  98. attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  99. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  100. sequence_length)`.
  101. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  102. heads.
  103. """
  104. last_hidden_state: Optional[jnp.ndarray] = None
  105. past_key_values: Optional[dict[str, jnp.ndarray]] = None
  106. hidden_states: Optional[tuple[jnp.ndarray]] = None
  107. attentions: Optional[tuple[jnp.ndarray]] = None
  108. @flax.struct.dataclass
  109. class FlaxBaseModelOutputWithPooling(ModelOutput):
  110. """
  111. Base class for model's outputs that also contains a pooling of the last hidden states.
  112. Args:
  113. last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
  114. Sequence of hidden-states at the output of the last layer of the model.
  115. pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
  116. Last layer hidden-state of the first token of the sequence (classification token) further processed by a
  117. Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
  118. prediction (classification) objective during pretraining.
  119. hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  120. Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
  121. `(batch_size, sequence_length, hidden_size)`.
  122. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
  123. attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  124. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  125. sequence_length)`.
  126. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  127. heads.
  128. """
  129. last_hidden_state: Optional[jnp.ndarray] = None
  130. pooler_output: Optional[jnp.ndarray] = None
  131. hidden_states: Optional[tuple[jnp.ndarray]] = None
  132. attentions: Optional[tuple[jnp.ndarray]] = None
  133. @flax.struct.dataclass
  134. class FlaxBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
  135. """
  136. Base class for model's outputs that also contains a pooling of the last hidden states.
  137. Args:
  138. last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
  139. Sequence of hidden-states at the output of the last layer of the model.
  140. pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
  141. Last layer hidden-state of the first token of the sequence (classification token) after further processing
  142. through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
  143. the classification token after processing through a linear layer and a tanh activation function. The linear
  144. layer weights are trained from the next sentence prediction (classification) objective during pretraining.
  145. hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  146. Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
  147. for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
  148. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
  149. attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  150. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  151. sequence_length)`.
  152. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  153. heads.
  154. cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
  155. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  156. sequence_length)`.
  157. Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
  158. weighted average in the cross-attention heads.
  159. past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  160. Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
  161. `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
  162. `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
  163. encoder_sequence_length, embed_size_per_head)`.
  164. Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
  165. `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
  166. input) to speed up sequential decoding.
  167. """
  168. last_hidden_state: Optional[jnp.ndarray] = None
  169. pooler_output: Optional[jnp.ndarray] = None
  170. hidden_states: Optional[tuple[jnp.ndarray]] = None
  171. past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
  172. attentions: Optional[tuple[jnp.ndarray]] = None
  173. cross_attentions: Optional[tuple[jnp.ndarray]] = None
  174. @flax.struct.dataclass
  175. class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
  176. """
  177. Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
  178. Args:
  179. last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
  180. Sequence of hidden-states at the output of the last layer of the model.
  181. If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
  182. hidden_size)` is output.
  183. past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  184. Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
  185. `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
  186. `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
  187. encoder_sequence_length, embed_size_per_head)`.
  188. Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
  189. `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
  190. input) to speed up sequential decoding.
  191. hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  192. Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
  193. `(batch_size, sequence_length, hidden_size)`.
  194. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
  195. attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  196. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  197. sequence_length)`.
  198. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  199. heads.
  200. cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
  201. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  202. sequence_length)`.
  203. Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
  204. weighted average in the cross-attention heads.
  205. """
  206. last_hidden_state: Optional[jnp.ndarray] = None
  207. past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
  208. hidden_states: Optional[tuple[jnp.ndarray]] = None
  209. attentions: Optional[tuple[jnp.ndarray]] = None
  210. cross_attentions: Optional[tuple[jnp.ndarray]] = None
  211. @flax.struct.dataclass
  212. class FlaxSeq2SeqModelOutput(ModelOutput):
  213. """
  214. Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
  215. decoding.
  216. Args:
  217. last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
  218. Sequence of hidden-states at the output of the last layer of the decoder of the model.
  219. If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
  220. hidden_size)` is output.
  221. past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  222. Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
  223. `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
  224. `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
  225. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
  226. blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  227. decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  228. Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
  229. `(batch_size, sequence_length, hidden_size)`.
  230. Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
  231. decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  232. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  233. sequence_length)`.
  234. Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
  235. self-attention heads.
  236. cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  237. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  238. sequence_length)`.
  239. Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
  240. weighted average in the cross-attention heads.
  241. encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  242. Sequence of hidden-states at the output of the last layer of the encoder of the model.
  243. encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  244. Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
  245. `(batch_size, sequence_length, hidden_size)`.
  246. Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
  247. encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  248. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  249. sequence_length)`.
  250. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
  251. self-attention heads.
  252. """
  253. last_hidden_state: Optional[jnp.ndarray] = None
  254. past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
  255. decoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
  256. decoder_attentions: Optional[tuple[jnp.ndarray]] = None
  257. cross_attentions: Optional[tuple[jnp.ndarray]] = None
  258. encoder_last_hidden_state: Optional[jnp.ndarray] = None
  259. encoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
  260. encoder_attentions: Optional[tuple[jnp.ndarray]] = None
  261. @flax.struct.dataclass
  262. class FlaxCausalLMOutputWithCrossAttentions(ModelOutput):
  263. """
  264. Base class for causal language model (or autoregressive) outputs.
  265. Args:
  266. logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
  267. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  268. hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  269. Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
  270. `(batch_size, sequence_length, hidden_size)`.
  271. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
  272. attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  273. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  274. sequence_length)`.
  275. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  276. heads.
  277. cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  278. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  279. sequence_length)`.
  280. Cross attentions weights after the attention softmax, used to compute the weighted average in the
  281. cross-attention heads.
  282. past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  283. Tuple of `jnp.ndarray` tuples of length `config.n_layers`, with each tuple containing the cached key, value
  284. states of the self-attention and the cross-attention layers if model is used in encoder-decoder setting.
  285. Only relevant if `config.is_decoder = True`.
  286. Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
  287. `past_key_values` input) to speed up sequential decoding.
  288. """
  289. logits: Optional[jnp.ndarray] = None
  290. past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
  291. hidden_states: Optional[tuple[jnp.ndarray]] = None
  292. attentions: Optional[tuple[jnp.ndarray]] = None
  293. cross_attentions: Optional[tuple[jnp.ndarray]] = None
  294. @flax.struct.dataclass
  295. class FlaxMaskedLMOutput(ModelOutput):
  296. """
  297. Base class for masked language models outputs.
  298. Args:
  299. logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
  300. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  301. hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  302. Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
  303. `(batch_size, sequence_length, hidden_size)`.
  304. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
  305. attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  306. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  307. sequence_length)`.
  308. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  309. heads.
  310. """
  311. logits: Optional[jnp.ndarray] = None
  312. hidden_states: Optional[tuple[jnp.ndarray]] = None
  313. attentions: Optional[tuple[jnp.ndarray]] = None
  314. FlaxCausalLMOutput = FlaxMaskedLMOutput
  315. @flax.struct.dataclass
  316. class FlaxSeq2SeqLMOutput(ModelOutput):
  317. """
  318. Base class for sequence-to-sequence language models outputs.
  319. Args:
  320. logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
  321. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  322. past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  323. Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
  324. `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
  325. `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
  326. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
  327. blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  328. decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  329. Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
  330. `(batch_size, sequence_length, hidden_size)`.
  331. Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
  332. decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  333. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  334. sequence_length)`.
  335. Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
  336. self-attention heads.
  337. cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  338. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  339. sequence_length)`.
  340. Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
  341. weighted average in the cross-attention heads.
  342. encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  343. Sequence of hidden-states at the output of the last layer of the encoder of the model.
  344. encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  345. Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
  346. `(batch_size, sequence_length, hidden_size)`.
  347. Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
  348. encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  349. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  350. sequence_length)`.
  351. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
  352. self-attention heads.
  353. """
  354. logits: Optional[jnp.ndarray] = None
  355. past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
  356. decoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
  357. decoder_attentions: Optional[tuple[jnp.ndarray]] = None
  358. cross_attentions: Optional[tuple[jnp.ndarray]] = None
  359. encoder_last_hidden_state: Optional[jnp.ndarray] = None
  360. encoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
  361. encoder_attentions: Optional[tuple[jnp.ndarray]] = None
  362. @flax.struct.dataclass
  363. class FlaxNextSentencePredictorOutput(ModelOutput):
  364. """
  365. Base class for outputs of models predicting if two sentences are consecutive or not.
  366. Args:
  367. logits (`jnp.ndarray` of shape `(batch_size, 2)`):
  368. Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
  369. before SoftMax).
  370. hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  371. Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
  372. `(batch_size, sequence_length, hidden_size)`.
  373. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
  374. attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  375. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  376. sequence_length)`.
  377. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  378. heads.
  379. """
  380. logits: Optional[jnp.ndarray] = None
  381. hidden_states: Optional[tuple[jnp.ndarray]] = None
  382. attentions: Optional[tuple[jnp.ndarray]] = None
  383. @flax.struct.dataclass
  384. class FlaxSequenceClassifierOutput(ModelOutput):
  385. """
  386. Base class for outputs of sentence classification models.
  387. Args:
  388. logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
  389. Classification (or regression if config.num_labels==1) scores (before SoftMax).
  390. hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  391. Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
  392. `(batch_size, sequence_length, hidden_size)`.
  393. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
  394. attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  395. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  396. sequence_length)`.
  397. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  398. heads.
  399. """
  400. logits: Optional[jnp.ndarray] = None
  401. hidden_states: Optional[tuple[jnp.ndarray]] = None
  402. attentions: Optional[tuple[jnp.ndarray]] = None
  403. @flax.struct.dataclass
  404. class FlaxSeq2SeqSequenceClassifierOutput(ModelOutput):
  405. """
  406. Base class for outputs of sequence-to-sequence sentence classification models.
  407. Args:
  408. logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
  409. Classification (or regression if config.num_labels==1) scores (before SoftMax).
  410. past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  411. Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
  412. `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
  413. `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
  414. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
  415. blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  416. decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  417. Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
  418. `(batch_size, sequence_length, hidden_size)`.
  419. Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
  420. decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  421. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  422. sequence_length)`.
  423. Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
  424. self-attention heads.
  425. cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  426. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  427. sequence_length)`.
  428. Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
  429. weighted average in the cross-attention heads.
  430. encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  431. Sequence of hidden-states at the output of the last layer of the encoder of the model.
  432. encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  433. Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
  434. `(batch_size, sequence_length, hidden_size)`.
  435. Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
  436. encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  437. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  438. sequence_length)`.
  439. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
  440. self-attention heads.
  441. """
  442. logits: Optional[jnp.ndarray] = None
  443. past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
  444. decoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
  445. decoder_attentions: Optional[tuple[jnp.ndarray]] = None
  446. cross_attentions: Optional[tuple[jnp.ndarray]] = None
  447. encoder_last_hidden_state: Optional[jnp.ndarray] = None
  448. encoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
  449. encoder_attentions: Optional[tuple[jnp.ndarray]] = None
  450. @flax.struct.dataclass
  451. class FlaxMultipleChoiceModelOutput(ModelOutput):
  452. """
  453. Base class for outputs of multiple choice models.
  454. Args:
  455. logits (`jnp.ndarray` of shape `(batch_size, num_choices)`):
  456. *num_choices* is the second dimension of the input tensors. (see *input_ids* above).
  457. Classification scores (before SoftMax).
  458. hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  459. Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
  460. `(batch_size, sequence_length, hidden_size)`.
  461. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
  462. attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  463. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  464. sequence_length)`.
  465. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  466. heads.
  467. """
  468. logits: Optional[jnp.ndarray] = None
  469. hidden_states: Optional[tuple[jnp.ndarray]] = None
  470. attentions: Optional[tuple[jnp.ndarray]] = None
  471. @flax.struct.dataclass
  472. class FlaxTokenClassifierOutput(ModelOutput):
  473. """
  474. Base class for outputs of token classification models.
  475. Args:
  476. logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.num_labels)`):
  477. Classification scores (before SoftMax).
  478. hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  479. Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
  480. `(batch_size, sequence_length, hidden_size)`.
  481. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
  482. attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  483. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  484. sequence_length)`.
  485. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  486. heads.
  487. """
  488. logits: Optional[jnp.ndarray] = None
  489. hidden_states: Optional[tuple[jnp.ndarray]] = None
  490. attentions: Optional[tuple[jnp.ndarray]] = None
  491. @flax.struct.dataclass
  492. class FlaxQuestionAnsweringModelOutput(ModelOutput):
  493. """
  494. Base class for outputs of question answering models.
  495. Args:
  496. start_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
  497. Span-start scores (before SoftMax).
  498. end_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
  499. Span-end scores (before SoftMax).
  500. hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  501. Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
  502. `(batch_size, sequence_length, hidden_size)`.
  503. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
  504. attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  505. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  506. sequence_length)`.
  507. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  508. heads.
  509. """
  510. start_logits: Optional[jnp.ndarray] = None
  511. end_logits: Optional[jnp.ndarray] = None
  512. hidden_states: Optional[tuple[jnp.ndarray]] = None
  513. attentions: Optional[tuple[jnp.ndarray]] = None
  514. @flax.struct.dataclass
  515. class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
  516. """
  517. Base class for outputs of sequence-to-sequence question answering models.
  518. Args:
  519. start_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
  520. Span-start scores (before SoftMax).
  521. end_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
  522. Span-end scores (before SoftMax).
  523. past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  524. Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
  525. `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
  526. `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
  527. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
  528. blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  529. decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  530. Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
  531. `(batch_size, sequence_length, hidden_size)`.
  532. Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
  533. decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  534. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  535. sequence_length)`.
  536. Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
  537. self-attention heads.
  538. cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  539. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  540. sequence_length)`.
  541. Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
  542. weighted average in the cross-attention heads.
  543. encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  544. Sequence of hidden-states at the output of the last layer of the encoder of the model.
  545. encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  546. Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
  547. `(batch_size, sequence_length, hidden_size)`.
  548. Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
  549. encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  550. Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  551. sequence_length)`.
  552. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
  553. self-attention heads.
  554. """
  555. start_logits: Optional[jnp.ndarray] = None
  556. end_logits: Optional[jnp.ndarray] = None
  557. past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
  558. decoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
  559. decoder_attentions: Optional[tuple[jnp.ndarray]] = None
  560. cross_attentions: Optional[tuple[jnp.ndarray]] = None
  561. encoder_last_hidden_state: Optional[jnp.ndarray] = None
  562. encoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
  563. encoder_attentions: Optional[tuple[jnp.ndarray]] = None