| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586 |
- # coding=utf-8
- # Copyright 2019 The Google AI Language Team Authors and The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch ELECTRA model."""
- import math
- import os
- from dataclasses import dataclass
- from typing import Callable, Optional, Union
- import torch
- from torch import nn
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from ...activations import ACT2FN, get_activation
- from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
- from ...generation import GenerationMixin
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutputWithCrossAttentions,
- BaseModelOutputWithPastAndCrossAttentions,
- CausalLMOutputWithCrossAttentions,
- MaskedLMOutput,
- MultipleChoiceModelOutput,
- QuestionAnsweringModelOutput,
- SequenceClassifierOutput,
- TokenClassifierOutput,
- )
- from ...modeling_utils import PreTrainedModel
- from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
- from ...utils import ModelOutput, auto_docstring, logging
- from ...utils.deprecation import deprecate_kwarg
- from .configuration_electra import ElectraConfig
# Module-level logger obtained from the library's logging utilities; used for progress and warning messages.
logger = logging.get_logger(__name__)
def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_or_generator="discriminator"):
    """Load tf checkpoints in a pytorch model.

    Every variable in the TensorFlow checkpoint is mapped onto an attribute path of ``model`` and copied into
    the matching parameter in place.

    Args:
        model: The PyTorch ELECTRA model whose parameters are overwritten.
        config: Model configuration. Not used by this function.
        tf_checkpoint_path: Path to the TensorFlow checkpoint.
        discriminator_or_generator: ``"discriminator"`` (default) or ``"generator"``. With ``"generator"``, TF
            variables under ``generator/`` are loaded into the ``electra`` submodule, and the original
            ``electra/`` scope is renamed to ``discriminator/`` (and therefore skipped unless such an attribute
            exists).

    Returns:
        The same ``model`` instance.

    Raises:
        ImportError: If TensorFlow or NumPy cannot be imported.
        ValueError: If a checkpoint array's shape does not match the target parameter's shape.
    """
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")

    # Load weights from TF model
    names = []
    arrays = []
    for name, shape in tf.train.list_variables(tf_path):
        logger.info(f"Loading TF weight {name} with shape {shape}")
        names.append(name)
        arrays.append(tf.train.load_variable(tf_path, name))

    # Scope components such as "layer_3" are split into an attribute name and an integer index.
    indexed_scope = re.compile(r"[A-Za-z]+_\d+")

    for original_name, array in zip(names, arrays):
        name = original_name
        try:
            # Rename TF scopes so they line up with the PyTorch module hierarchy.
            if isinstance(model, ElectraForMaskedLM):
                name = name.replace("electra/embeddings/", "generator/embeddings/")
            if discriminator_or_generator == "generator":
                name = name.replace("electra/", "discriminator/")
                name = name.replace("generator/", "electra/")
            name = name.replace("dense_1", "dense_prediction")
            name = name.replace("generator_predictions/output_bias", "generator_lm_head/bias")

            scope_path = name.split("/")
            # `global_step` and `temperature` variables have no PyTorch counterpart.
            if any(n in ("global_step", "temperature") for n in scope_path):
                logger.info(f"Skipping {original_name}")
                continue

            # Walk the attribute path from the model root down to the target tensor.
            pointer = model
            for m_name in scope_path:
                scope_names = re.split(r"_(\d+)", m_name) if indexed_scope.fullmatch(m_name) else [m_name]
                if scope_names[0] in ("kernel", "gamma", "output_weights"):
                    pointer = getattr(pointer, "weight")
                elif scope_names[0] in ("output_bias", "beta"):
                    pointer = getattr(pointer, "bias")
                elif scope_names[0] == "squad":
                    pointer = getattr(pointer, "classifier")
                else:
                    pointer = getattr(pointer, scope_names[0])
                if len(scope_names) >= 2:
                    pointer = pointer[int(scope_names[1])]

            # `m_name` is now the last path component.
            if m_name.endswith("_embeddings"):
                pointer = getattr(pointer, "weight")
            elif m_name == "kernel":
                # Kernel arrays are transposed before being compared with / copied into the PyTorch weight.
                array = np.transpose(array)

            if pointer.shape != array.shape:
                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
            logger.info(f"Initialize PyTorch weight {scope_path} from {original_name}")
            pointer.data = torch.from_numpy(array)
        except AttributeError as e:
            # Variables without a matching attribute on the model are skipped rather than aborting the load.
            logger.info(f"Skipping {original_name} (mapped to {name}): {e}")
    return model
class ElectraEmbeddings(nn.Module):
    """Build input embeddings by summing word, token-type and (when absolute) position embeddings.

    The sum is passed through LayerNorm and dropout. All embedding tables use ``config.embedding_size``.
    """

    def __init__(self, config):
        super().__init__()
        embedding_dim = config.embedding_size
        self.word_embeddings = nn.Embedding(config.vocab_size, embedding_dim, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, embedding_dim)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, embedding_dim)

        # The attribute keeps the TensorFlow variable name `LayerNorm` (not snake_case) so TF checkpoints load.
        self.LayerNorm = nn.LayerNorm(embedding_dim, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Non-persistent buffers: they follow the module across devices but are not saved in the state dict.
        all_positions = torch.arange(config.max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", all_positions, persistent=False)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        zero_token_types = torch.zeros(self.position_ids.size(), dtype=torch.long)
        self.register_buffer("token_type_ids", zero_token_types, persistent=False)

    # Adapted from transformers.models.bert.modeling_bert.BertEmbeddings.forward
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Embed ``input_ids`` (or take ``inputs_embeds`` as given) and add the auxiliary embeddings.

        ``past_key_values_length`` offsets the default position ids so cached decoding continues from the right
        position. Missing ``token_type_ids`` default to zeros.
        """
        input_shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
        seq_length = input_shape[1]

        if position_ids is None:
            start = past_key_values_length
            position_ids = self.position_ids[:, start : start + seq_length]

        if token_type_ids is None:
            # Prefer the registered all-zero buffer: it lets the model be traced without token_type_ids
            # (see issue #5664).
            if hasattr(self, "token_type_ids"):
                token_type_ids = self.token_type_ids[:, :seq_length].expand(input_shape[0], seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
        if self.position_embedding_type == "absolute":
            embeddings += self.position_embeddings(position_ids)
        return self.dropout(self.LayerNorm(embeddings))
# Adapted from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Electra
class ElectraSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention, used for both self-attention and cross-attention.

    Optionally adds relative position scores (``relative_key`` / ``relative_key_query``) and reads/writes keys and
    values through a ``Cache`` (or ``EncoderDecoderCache``) object.
    """

    def __init__(self, config, position_embedding_type=None, layer_idx=None):
        super().__init__()
        # The divisibility check is skipped when the config defines `embedding_size`.
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            # One embedding per signed distance in [-(max - 1), max - 1].
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder
        self.layer_idx = layer_idx

    def _split_heads(self, tensor: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Reshape a projected tensor from (batch, seq, all_head_size) to (batch, num_heads, seq, head_size)."""
        return tensor.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor]:
        """Attend from ``hidden_states`` to itself, or to ``encoder_hidden_states`` when given (cross-attention).

        Args:
            hidden_states: 3-D query source, unpacked as (batch_size, seq_length, hidden).
            attention_mask: Additive mask added to the scaled scores before the softmax.
            head_mask: Multiplied into the attention probabilities after dropout.
            encoder_hidden_states: If given, keys/values are computed from these states instead.
            past_key_values: Optional cache; new keys/values are written into it.
            output_attentions: Accepted for interface compatibility; the probabilities are always returned.
            cache_position: Forwarded to the cache update for self-attention only.

        Returns:
            ``(context_layer, attention_probs)``.
        """
        batch_size, _, _ = hidden_states.shape
        device = hidden_states.device
        query_layer = self._split_heads(self.query(hidden_states), batch_size)

        is_updated = False
        is_cross_attention = encoder_hidden_states is not None
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # after the first generated id, we can subsequently re-use all key/value_layer from cache
                    curr_past_key_value = past_key_values.cross_attention_cache
                else:
                    curr_past_key_value = past_key_values.self_attention_cache
            else:
                curr_past_key_value = past_key_values

        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # Reuse the cross-attention keys/values cached on an earlier call.
            key_layer = curr_past_key_value.layers[self.layer_idx].keys
            value_layer = curr_past_key_value.layers[self.layer_idx].values
        else:
            key_layer = self._split_heads(self.key(current_states), batch_size)
            value_layer = self._split_heads(self.value(current_states), batch_size)
            if past_key_values is not None:
                # Save keys/values to the cache; `cache_position` only applies to self-attention.
                cache_position = cache_position if not is_cross_attention else None
                key_layer, value_layer = curr_past_key_value.update(
                    key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
                )
                # Mark this layer's cross-attention cache as filled so later calls take the reuse branch.
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if past_key_values is not None and not is_cross_attention:
                # With a self-attention cache, the keys are (cached + current) and the current queries occupy the
                # last `query_length` positions. Giving every query position `key_length - 1` is only correct for
                # a single-token step; it is wrong for multi-token passes such as the first (prompt) forward.
                position_ids_l = torch.arange(
                    key_length - query_length, key_length, dtype=torch.long, device=device
                ).view(-1, 1)
            elif past_key_values is not None:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=device).view(-1, 1)
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in ElectraModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        # This drops entire tokens to attend to, as in the original Transformer paper.
        attention_probs = self.dropout(attention_probs)
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(context_layer.size()[:-2] + (self.all_head_size,))
        return context_layer, attention_probs
# Adapted from transformers.models.bert.modeling_bert.BertSelfOutput
class ElectraSelfOutput(nn.Module):
    """Post-attention block: dense projection, dropout, residual addition, then LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project ``hidden_states``, add the residual ``input_tensor`` and normalize."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
# Maps `config._attn_implementation` to the self-attention class that `ElectraAttention` instantiates;
# only the "eager" implementation is registered.
ELECTRA_SELF_ATTENTION_CLASSES = dict(eager=ElectraSelfAttention)
# Adapted from transformers.models.bert.modeling_bert.BertAttention with Bert->Electra,BERT->ELECTRA
class ElectraAttention(nn.Module):
    """Attention sub-block: attention core (``self.self``) followed by ``ElectraSelfOutput`` (``self.output``).

    Also supports pruning attention heads.
    """

    def __init__(self, config, position_embedding_type=None, layer_idx=None):
        super().__init__()
        attention_cls = ELECTRA_SELF_ATTENTION_CLASSES[config._attn_implementation]
        self.self = attention_cls(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx)
        self.output = ElectraSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads; the removed indices accumulate in ``self.pruned_heads``."""
        if not len(heads):
            return
        core = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, core.num_attention_heads, core.attention_head_size, self.pruned_heads
        )

        # Prune the q/k/v projections (default `dim`) and the output projection along `dim=1`.
        core.query = prune_linear_layer(core.query, index)
        core.key = prune_linear_layer(core.key, index)
        core.value = prune_linear_layer(core.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the head bookkeeping consistent with the resized layers.
        core.num_attention_heads = core.num_attention_heads - len(heads)
        core.all_head_size = core.attention_head_size * core.num_attention_heads
        self.pruned_heads = {*self.pruned_heads, *heads}

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor]:
        """Run the attention core, then the output projection with ``hidden_states`` as the residual.

        Returns the projected output followed by whatever extra entries (attention weights) the core returned.
        """
        context, *extras = self.self(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        attention_output = self.output(context, hidden_states)
        return (attention_output, *extras)
# Adapted from transformers.models.bert.modeling_bert.BertIntermediate
class ElectraIntermediate(nn.Module):
    """Feed-forward expansion: Linear(hidden_size -> intermediate_size) followed by the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # `hidden_act` is either a key into ACT2FN or a callable used as-is.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the dense expansion and then the activation."""
        return self.intermediate_act_fn(self.dense(hidden_states))
# Adapted from transformers.models.bert.modeling_bert.BertOutput
class ElectraOutput(nn.Module):
    """Feed-forward contraction: Linear(intermediate_size -> hidden_size), dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Contract ``hidden_states`` back to hidden size, add ``input_tensor`` and normalize."""
        residual_sum = self.dropout(self.dense(hidden_states)) + input_tensor
        return self.LayerNorm(residual_sum)
# Adapted from transformers.models.bert.modeling_bert.BertLayer with Bert->Electra
class ElectraLayer(GradientCheckpointingLayer):
    """One transformer block: self-attention, optional cross-attention (decoders only), then a feed-forward.

    The feed-forward is applied through `apply_chunking_to_forward`.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ElectraAttention(config, layer_idx=layer_idx)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention and not self.is_decoder:
            raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
        if self.add_cross_attention:
            self.crossattention = ElectraAttention(config, position_embedding_type="absolute", layer_idx=layer_idx)
        self.intermediate = ElectraIntermediate(config)
        self.output = ElectraOutput(config)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor]:
        """Return ``(layer_output, self_attention_weights..., cross_attention_weights...)``.

        The cross-attention entries are present only when cross-attention ran.
        """
        attention_output, *attention_weights = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            past_key_values=past_key_values,
            cache_position=cache_position,
        )

        run_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if run_cross_attention:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )
            attention_output, *cross_weights = self.crossattention(
                attention_output,
                attention_mask=encoder_attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )
            attention_weights += cross_weights

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        return (layer_output, *attention_weights)

    def feed_forward_chunk(self, attention_output):
        """Intermediate expansion followed by the output block, with ``attention_output`` as the residual."""
        return self.output(self.intermediate(attention_output), attention_output)
# Adapted from transformers.models.bert.modeling_bert.BertEncoder with Bert->Electra
class ElectraEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``ElectraLayer`` blocks applied in sequence."""

    def __init__(self, config, layer_idx=None):
        # `layer_idx` is not used here: each layer gets its index from its position in the stack.
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([ElectraLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        """Run every layer over ``hidden_states``.

        For decoders with ``use_cache``, a fresh ``EncoderDecoderCache`` is created when none is given, and legacy
        tuple caches are converted (with a deprecation warning). Caching is disabled while training with gradient
        checkpointing.

        Returns:
            A ``BaseModelOutputWithPastAndCrossAttentions``. When ``return_dict`` is false, returns a tuple of its
            non-``None`` fields in this order: last hidden state, cache, hidden states, self-attentions,
            cross-attentions.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        if use_cache and self.config.is_decoder and past_key_values is None:
            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))

        if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple):
            logger.warning_once(
                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
            )
            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                layer_head_mask,
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
                encoder_attention_mask=encoder_attention_mask,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                # A layer only returns cross-attention weights when it actually ran cross-attention (a decoder given
                # `encoder_hidden_states`). Without this length check, a cross-attention-capable decoder called
                # without encoder states would raise IndexError here.
                if self.config.add_cross_attention and len(layer_outputs) > 2:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    past_key_values,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
class ElectraDiscriminatorPredictions(nn.Module):
    """Discriminator head with two dense layers.

    Computes dense(hidden -> hidden), then ``config.hidden_act``, then dense(hidden -> 1). The trailing size-1
    dimension is squeezed away, leaving one logit per position.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = get_activation(config.hidden_act)
        self.dense_prediction = nn.Linear(config.hidden_size, 1)
        self.config = config

    def forward(self, discriminator_hidden_states):
        """Return logits with the last (size-1) dimension removed."""
        activated = self.activation(self.dense(discriminator_hidden_states))
        return self.dense_prediction(activated).squeeze(-1)
class ElectraGeneratorPredictions(nn.Module):
    """Prediction module for the generator.

    Consists of one dense layer (hidden_size -> embedding_size), a ``"gelu"`` activation and a LayerNorm over
    ``embedding_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.activation = get_activation("gelu")
        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        self.dense = nn.Linear(config.hidden_size, config.embedding_size)

    def forward(self, generator_hidden_states):
        """Map generator hidden states to embedding-size features: dense -> activation -> LayerNorm."""
        hidden_states = self.dense(generator_hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states
@auto_docstring
class ElectraPreTrainedModel(PreTrainedModel):
    # Shared base for ELECTRA models. It declares:
    # - the config class,
    # - the TF-checkpoint loader,
    # - the attribute name of the base model,
    # - support for gradient checkpointing.
    config: ElectraConfig
    load_tf_weights = load_tf_weights_in_electra
    base_model_prefix = "electra"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize ``module``'s parameters in place.

        - Linear and Embedding weights: drawn from N(0, ``config.initializer_range``).
        - Linear biases and the Embedding padding row: zeroed.
        - LayerNorm: reset to weight 1, bias 0.
        - Other module types: left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version, which uses truncated_normal
            # (cf https://github.com/pytorch/pytorch/pull/5617).
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    module.bias.data.zero_()
            elif module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`ElectraForPreTraining`].
    """
)
class ElectraForPreTrainingOutput(ModelOutput):
    r"""
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss of the ELECTRA objective.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
        Prediction scores of the head (scores for each token before SoftMax).
    """

    # Every field defaults to None. The docstring above is read by `auto_docstring`, so it is kept as is.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    # Tuple of tensors; presumably one entry per layer (plus embeddings) — TODO confirm against the producer.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    # Tuple of tensors; presumably the per-layer attention probabilities — TODO confirm against the producer.
    attentions: Optional[tuple[torch.FloatTensor]] = None
@auto_docstring
class ElectraModel(ElectraPreTrainedModel):
    # Bare ELECTRA transformer: embeddings, an optional linear projection from `embedding_size` to
    # `hidden_size` (created only when the two differ), and the encoder stack.

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = ElectraEmbeddings(config)

        # The encoder works in `hidden_size`; bridge the gap when the embeddings use another width.
        if config.embedding_size != config.hidden_size:
            self.embeddings_project = nn.Linear(config.embedding_size, config.hidden_size)

        self.encoder = ElectraEncoder(config)
        self.config = config

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithCrossAttentions]:
        # Per-call flags fall back to the config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of `input_ids` / `inputs_embeds` must be given.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # Number of positions already stored in the cache (legacy tuple format or a `Cache` object).
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = (
                past_key_values[0][0].shape[-2]
                if not isinstance(past_key_values, Cache)
                else past_key_values.get_seq_length()
            )

        if attention_mask is None:
            # Keys/values span the cached positions plus the new tokens, so the default all-ones mask must
            # cover `past_key_values_length + seq_length` positions. Without a cache this is exactly `input_shape`.
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                # Reuse the registered buffer (sliced to the current length) rather than allocating a new tensor.
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        hidden_states = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        # Present only when embedding_size != hidden_size (see __init__).
        if hasattr(self, "embeddings_project"):
            hidden_states = self.embeddings_project(hidden_states)

        hidden_states = self.encoder(
            hidden_states,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return hidden_states
class ElectraClassificationHead(nn.Module):
    """Sentence-level classification head that reads the hidden state at the first position."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # A dedicated classifier dropout takes precedence over the generic hidden dropout.
        if config.classifier_dropout is not None:
            dropout_prob = config.classifier_dropout
        else:
            dropout_prob = config.hidden_dropout_prob
        self.activation = get_activation("gelu")
        self.dropout = nn.Dropout(dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        """
        Args:
            features: encoder output; only index 0 along dim 1 (the `<s>` / [CLS] token) is used.

        Returns:
            Logits with `config.num_labels` entries on the last dimension.
        """
        first_token = features[:, 0, :]
        # The same dropout module is applied before the dense layer and again before the output projection.
        hidden = self.dense(self.dropout(first_token))
        # Although BERT uses tanh here, the Electra authors appear to have used GELU.
        hidden = self.activation(hidden)
        return self.out_proj(self.dropout(hidden))
# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Electra
class ElectraSequenceSummary(nn.Module):
    r"""
    Reduce a sequence of hidden states to a single summary vector per example.

    Args:
        config ([`ElectraConfig`]):
            Model config. The following optional attributes are read; missing ones fall back to the defaults noted:

            - **summary_type** (`str`, default `"last"`) -- How the summary vector is extracted:

                - `"last"` -- hidden state of the last token (like XLNet)
                - `"first"` -- hidden state of the first token (like Bert)
                - `"mean"` -- mean over all token hidden states
                - `"cls_index"` -- hidden state at a caller-supplied classification-token position (GPT/GPT-2)
                - `"attn"` -- not implemented; raises `NotImplementedError`

            - **summary_use_proj** (`bool`) -- Apply a linear projection after extraction.
            - **summary_proj_to_labels** (`bool`) -- When projecting and `config.num_labels > 0`, project to
              `config.num_labels` outputs instead of `config.hidden_size`.
            - **summary_activation** (`Optional[str]`) -- Name of an activation (resolved with `get_activation`)
              applied after the projection; an empty string or `None` means no activation.
            - **summary_first_dropout** (`float`) -- Dropout probability before the projection; used only if > 0.
            - **summary_last_dropout** (`float`) -- Dropout probability after the activation; used only if > 0.
    """

    def __init__(self, config: ElectraConfig):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "last")
        if self.summary_type == "attn":
            # Would need a standard multi-head attention module with absolute positional embedding.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            raise NotImplementedError

        projection = nn.Identity()
        if getattr(config, "summary_use_proj", False):
            to_labels = getattr(config, "summary_proj_to_labels", False) and config.num_labels > 0
            out_features = config.num_labels if to_labels else config.hidden_size
            projection = nn.Linear(config.hidden_size, out_features)
        self.summary = projection

        activation_name = getattr(config, "summary_activation", None)
        self.activation: Callable = get_activation(activation_name) if activation_name else nn.Identity()

        first_p = getattr(config, "summary_first_dropout", 0)
        self.first_dropout = nn.Dropout(first_p) if first_p > 0 else nn.Identity()

        last_p = getattr(config, "summary_last_dropout", 0)
        self.last_dropout = nn.Dropout(last_p) if last_p > 0 else nn.Identity()

    def forward(
        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
    ) -> torch.FloatTensor:
        """
        Summarise `hidden_states` according to `self.summary_type`.

        Args:
            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Positions of the classification token; read only when `summary_type == "cls_index"`. When omitted,
                the last position of each sequence is used.

        Returns:
            `torch.FloatTensor`: The summary vectors, after dropout / projection / activation / dropout.
        """
        mode = self.summary_type
        if mode == "attn":
            raise NotImplementedError

        if mode == "first":
            output = hidden_states[:, 0]
        elif mode == "last":
            output = hidden_states[:, -1]
        elif mode == "mean":
            output = hidden_states.mean(dim=1)
        elif mode == "cls_index":
            if cls_index is not None:
                # (bsz, XX) -> (bsz, XX, 1, 1) -> (bsz, XX, 1, hidden_size)
                index = cls_index.unsqueeze(-1).unsqueeze(-1)
                index = index.expand((-1,) * (index.dim() - 1) + (hidden_states.size(-1),))
            else:
                # No positions supplied: gather the last token of every sequence.
                index = torch.full_like(
                    hidden_states[..., :1, :],
                    hidden_states.shape[-2] - 1,
                    dtype=torch.long,
                )
            output = hidden_states.gather(-2, index).squeeze(-2)  # shape (bsz, XX, hidden_size)

        for stage in (self.first_dropout, self.summary, self.activation, self.last_dropout):
            output = stage(output)
        return output
@auto_docstring(
    custom_intro="""
    ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """
)
class ElectraForSequenceClassification(ElectraPreTrainedModel):
    # ELECTRA backbone followed by `ElectraClassificationHead` on the first-token hidden state.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.electra = ElectraModel(config)
        self.classifier = ElectraClassificationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_out = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.classifier(backbone_out[0])

        loss = None
        if labels is not None:
            problem_type = self.config.problem_type
            if problem_type is None:
                # Infer the task from num_labels and the label dtype; the result is cached on the config.
                if self.num_labels == 1:
                    problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    problem_type = "single_label_classification"
                else:
                    problem_type = "multi_label_classification"
                self.config.problem_type = problem_type

            if problem_type == "regression":
                mse = MSELoss()
                if self.num_labels == 1:
                    loss = mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = mse(logits, labels)
            elif problem_type == "single_label_classification":
                cross_entropy = CrossEntropyLoss()
                loss = cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                bce = BCEWithLogitsLoss()
                loss = bce(logits, labels)

        if not return_dict:
            leading = (logits,) if loss is None else (loss, logits)
            return leading + backbone_out[1:]

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
@auto_docstring(
    custom_intro="""
    Electra model with a binary classification head on top as used during pretraining for identifying generated tokens.

    It is recommended to load the discriminator checkpoint into that model.
    """
)
class ElectraForPreTraining(ElectraPreTrainedModel):
    # ELECTRA backbone + per-token discriminator head (replaced-token detection).

    def __init__(self, config):
        super().__init__(config)

        self.electra = ElectraModel(config)
        self.discriminator_predictions = ElectraDiscriminatorPredictions(config)
        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], ElectraForPreTrainingOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the ELECTRA loss. Input should be a sequence of tokens (see `input_ids` docstring)
            Indices should be in `[0, 1]`:

            - 0 indicates the token is an original token,
            - 1 indicates the token was replaced.

        Examples:

        ```python
        >>> from transformers import ElectraForPreTraining, AutoTokenizer
        >>> import torch

        >>> discriminator = ElectraForPreTraining.from_pretrained("google/electra-base-discriminator")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-base-discriminator")

        >>> sentence = "The quick brown fox jumps over the lazy dog"
        >>> fake_sentence = "The quick brown fox fake over the lazy dog"

        >>> fake_tokens = tokenizer.tokenize(fake_sentence, add_special_tokens=True)
        >>> fake_inputs = tokenizer.encode(fake_sentence, return_tensors="pt")
        >>> discriminator_outputs = discriminator(fake_inputs)
        >>> predictions = torch.round((torch.sign(discriminator_outputs[0]) + 1) / 2)

        >>> fake_tokens
        ['[CLS]', 'the', 'quick', 'brown', 'fox', 'fake', 'over', 'the', 'lazy', 'dog', '[SEP]']

        >>> predictions.squeeze().tolist()
        [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
        ```"""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_out = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = backbone_out[0]
        logits = self.discriminator_predictions(sequence_output)

        loss = None
        if labels is not None:
            bce = nn.BCEWithLogitsLoss()
            seq_len = sequence_output.shape[1]
            flat_logits = logits.view(-1, seq_len)
            if attention_mask is None:
                loss = bce(flat_logits, labels.float())
            else:
                # Score only positions whose attention mask is exactly 1 (i.e. skip padding).
                keep = attention_mask.view(-1, seq_len) == 1
                loss = bce(flat_logits[keep], labels[keep].float())

        if not return_dict:
            leading = (logits,) if loss is None else (loss, logits)
            return leading + backbone_out[1:]

        return ElectraForPreTrainingOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
@auto_docstring(
    custom_intro="""
    Electra model with a language modeling head on top.

    Even though both the discriminator and generator may be loaded into this model, the generator is the only model of
    the two to have been trained for the masked language modeling task.
    """
)
class ElectraForMaskedLM(ElectraPreTrainedModel):
    # The output projection weight is declared as tied (see `_tied_weights_keys`).
    _tied_weights_keys = ["generator_lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)

        self.electra = ElectraModel(config)
        self.generator_predictions = ElectraGeneratorPredictions(config)
        # Maps embedding-size features to vocabulary logits.
        self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size)
        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.generator_lm_head

    def set_output_embeddings(self, word_embeddings):
        self.generator_lm_head = word_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_out = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        prediction_scores = self.generator_lm_head(self.generator_predictions(backbone_out[0]))

        loss = None
        if labels is not None:
            # CrossEntropyLoss ignores targets equal to -100 by default.
            cross_entropy = nn.CrossEntropyLoss()
            loss = cross_entropy(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            leading = (prediction_scores,) if loss is None else (loss, prediction_scores)
            return leading + backbone_out[1:]

        return MaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
@auto_docstring(
    custom_intro="""
    Electra model with a token classification head on top.

    Both the discriminator and generator may be loaded into this model.
    """
)
class ElectraForTokenClassification(ElectraPreTrainedModel):
    # ELECTRA backbone + dropout + a per-token linear classifier.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.electra = ElectraModel(config)
        # A dedicated classifier dropout takes precedence over the generic hidden dropout.
        if config.classifier_dropout is not None:
            dropout_prob = config.classifier_dropout
        else:
            dropout_prob = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_out = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.classifier(self.dropout(backbone_out[0]))

        loss = None
        if labels is not None:
            cross_entropy = CrossEntropyLoss()
            loss = cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            leading = (logits,) if loss is None else (loss, logits)
            return leading + backbone_out[1:]

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
@auto_docstring
class ElectraForQuestionAnswering(ElectraPreTrainedModel):
    # ELECTRA backbone + a linear layer producing start/end span logits for extractive QA.
    config: ElectraConfig
    base_model_prefix = "electra"

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.electra = ElectraModel(config)
        # forward() splits this output into exactly two columns (start, end).
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward `return_dict` so the backbone's output type matches what this method reads below.
        # Without it, a backbone tuple breaks `.hidden_states` / `.attentions`.
        discriminator_hidden_states = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = discriminator_hidden_states[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Positions may arrive with an extra trailing dimension (e.g. after multi-GPU gathering); drop it.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Positions outside the input are clamped to `ignored_index`, which the loss then ignores.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (
                start_logits,
                end_logits,
            ) + discriminator_hidden_states[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=discriminator_hidden_states.hidden_states,
            attentions=discriminator_hidden_states.attentions,
        )
@auto_docstring
class ElectraForMultipleChoice(ElectraPreTrainedModel):
    # Each (example, choice) pair is encoded independently and scored with one logit; the logits
    # are then regrouped per example so CrossEntropyLoss can pick the right choice.

    def __init__(self, config):
        super().__init__(config)

        self.electra = ElectraModel(config)
        self.sequence_summary = ElectraSequenceSummary(config)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        def merge_choice_dim(tensor, trailing_dims=1):
            # Fold the choice axis into the batch axis, keeping the last `trailing_dims` dimensions.
            if tensor is None:
                return None
            return tensor.view(-1, *tensor.shape[-trailing_dims:])

        input_ids = merge_choice_dim(input_ids)
        attention_mask = merge_choice_dim(attention_mask)
        token_type_ids = merge_choice_dim(token_type_ids)
        position_ids = merge_choice_dim(position_ids)
        inputs_embeds = merge_choice_dim(inputs_embeds, trailing_dims=2)

        backbone_out = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled = self.sequence_summary(backbone_out[0])
        # One score per (example, choice) pair, regrouped as (batch, num_choices).
        reshaped_logits = self.classifier(pooled).view(-1, num_choices)

        loss = None
        if labels is not None:
            cross_entropy = CrossEntropyLoss()
            loss = cross_entropy(reshaped_logits, labels)

        if not return_dict:
            leading = (reshaped_logits,) if loss is None else (loss, reshaped_logits)
            return leading + backbone_out[1:]

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
@auto_docstring(
    custom_intro="""
    ELECTRA Model with a `language modeling` head on top for CLM fine-tuning.
    """
)
class ElectraForCausalLM(ElectraPreTrainedModel, GenerationMixin):
    # ELECTRA backbone + generator prediction layer + vocabulary projection, usable with `generate()`.
    _tied_weights_keys = ["generator_lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)

        if not config.is_decoder:
            logger.warning("If you want to use `ElectraForCausalLM` as a standalone, add `is_decoder=True.`")

        self.electra = ElectraModel(config)
        self.generator_predictions = ElectraGeneratorPredictions(config)
        # Maps embedding-size features to vocabulary logits.
        self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size)

        # Initialize weights and apply final processing (same entry point as every other ELECTRA head;
        # `post_init` covers weight initialisation as well as the remaining final-setup hooks).
        self.post_init()

    def get_output_embeddings(self):
        return self.generator_lm_head

    def set_output_embeddings(self, new_embeddings):
        self.generator_lm_head = new_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        Example:

        ```python
        >>> from transformers import AutoTokenizer, ElectraForCausalLM, ElectraConfig
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-base-generator")
        >>> config = ElectraConfig.from_pretrained("google/electra-base-generator")
        >>> config.is_decoder = True
        >>> model = ElectraForCausalLM.from_pretrained("google/electra-base-generator", config=config)

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # No point caching key/values during a training (loss) pass.
        if labels is not None:
            use_cache = False

        outputs = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.generator_lm_head(self.generator_predictions(sequence_output))

        lm_loss = None
        if labels is not None:
            # Delegated to the model's configured loss function; extra kwargs are passed through to it.
            lm_loss = self.loss_function(
                prediction_scores,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[1:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
# Public API of this module: the ELECTRA model classes plus the TF-checkpoint weight loader.
__all__ = [
    "ElectraForCausalLM",
    "ElectraForMaskedLM",
    "ElectraForMultipleChoice",
    "ElectraForPreTraining",
    "ElectraForQuestionAnswering",
    "ElectraForSequenceClassification",
    "ElectraForTokenClassification",
    "ElectraModel",
    "ElectraPreTrainedModel",
    "load_tf_weights_in_electra",
]
|