| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License.
- # -------------------------------------------------------------------------
- import logging
- import random
- import torch
- from transformers import MT5Config, T5Config
- logger = logging.getLogger(__name__)
class T5Encoder(torch.nn.Module):
    """Wrapper module that exposes only the first output of a T5 encoder.

    The wrapped encoder is called with ``(input_ids, attention_mask)`` and
    element 0 of its result (the last hidden state) is returned.
    """

    def __init__(self, encoder, config: T5Config | MT5Config):
        """Keep references to the wrapped encoder and its model config."""
        super().__init__()
        self.config = config
        self.encoder = encoder

    def forward(self, input_ids, attention_mask):
        """Run the wrapped encoder and return element 0 of its output."""
        encoder_outputs = self.encoder(input_ids, attention_mask)
        return encoder_outputs[0]
class T5EncoderInputs:
    """Holder for the two tensors passed to the T5 encoder."""

    def __init__(self, input_ids, attention_mask):
        """Store the encoder input tensors.

        Args:
            input_ids: token id tensor.
            attention_mask: mask tensor paired with ``input_ids``.
        """
        self.input_ids: torch.LongTensor = input_ids
        self.attention_mask: torch.LongTensor = attention_mask

    @staticmethod
    def create_dummy(
        batch_size: int,
        sequence_length: int,
        vocab_size: int,
        device: torch.device,
        use_int32_inputs: bool = False,
    ) -> "T5EncoderInputs":
        """Create dummy inputs for T5 encoder.

        Args:
            batch_size (int): batch size
            sequence_length (int): sequence length
            vocab_size (int): vocabulary size; token ids are sampled from
                ``[0, vocab_size)``
            device (torch.device): device of output tensors
            use_int32_inputs (bool): when True both tensors are created as
                ``torch.int32``; otherwise ``torch.int64``

        Returns:
            T5EncoderInputs: dummy inputs for encoder
        """
        dtype = torch.int32 if use_int32_inputs else torch.int64

        # torch.randint excludes `high`, so pass vocab_size itself to make
        # every valid id (including vocab_size - 1) reachable. This also keeps
        # the range non-empty when vocab_size == 1.
        input_ids = torch.randint(
            low=0,
            high=vocab_size,
            size=(batch_size, sequence_length),
            dtype=dtype,
            device=device,
        )

        attention_mask = torch.ones([batch_size, sequence_length], dtype=dtype, device=device)
        if sequence_length >= 2:
            for i in range(batch_size):
                # Zero a random-length prefix of the row. The prefix length is
                # at most sequence_length - 1, so the last position stays 1.
                padding_position = random.randint(0, sequence_length - 1)
                attention_mask[i, :padding_position] = 0

        return T5EncoderInputs(input_ids, attention_mask)

    def to_list(self) -> list:
        """Return ``[input_ids, attention_mask]``, omitting entries that are None."""
        return [v for v in (self.input_ids, self.attention_mask) if v is not None]
|