t5_encoder.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # -------------------------------------------------------------------------
  5. import logging
  6. import random
  7. import torch
  8. from transformers import MT5Config, T5Config
  9. logger = logging.getLogger(__name__)
  10. class T5Encoder(torch.nn.Module):
  11. """T5 encoder outputs only the last hidden state"""
  12. def __init__(self, encoder, config: T5Config | MT5Config):
  13. super().__init__()
  14. self.encoder = encoder
  15. self.config = config
  16. def forward(self, input_ids, attention_mask):
  17. return self.encoder(input_ids, attention_mask)[0]
  18. class T5EncoderInputs:
  19. def __init__(self, input_ids, attention_mask):
  20. self.input_ids: torch.LongTensor = input_ids
  21. self.attention_mask: torch.LongTensor = attention_mask
  22. @staticmethod
  23. def create_dummy(
  24. batch_size: int,
  25. sequence_length: int,
  26. vocab_size: int,
  27. device: torch.device,
  28. use_int32_inputs: bool = False,
  29. ): # -> T5EncoderInputs
  30. """Create dummy inputs for T5 encoder.
  31. Args:
  32. batch_size (int): batch size
  33. sequence_length (int): sequence length
  34. vocab_size (int): vocabulary size
  35. device (torch.device): device of output tensors
  36. Returns:
  37. T5EncoderInputs: dummy inputs for encoder
  38. """
  39. dtype = torch.int32 if use_int32_inputs else torch.int64
  40. input_ids = torch.randint(
  41. low=0,
  42. high=vocab_size - 1,
  43. size=(batch_size, sequence_length),
  44. dtype=dtype,
  45. device=device,
  46. )
  47. attention_mask = torch.ones([batch_size, sequence_length], dtype=dtype, device=device)
  48. if sequence_length >= 2:
  49. for i in range(batch_size):
  50. padding_position = random.randint(0, sequence_length - 1)
  51. attention_mask[i, :padding_position] = 0
  52. return T5EncoderInputs(input_ids, attention_mask)
  53. def to_list(self) -> list:
  54. input_list = [v for v in [self.input_ids, self.attention_mask] if v is not None]
  55. return input_list