backbone.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. """PyTorch LSTM model. """
  3. import torch.nn as nn
  4. from modelscope.metainfo import Models
  5. from modelscope.models import TorchModel
  6. from modelscope.models.builder import MODELS
  7. from modelscope.outputs import BackboneModelOutput
  8. from modelscope.utils.constant import Tasks
  @MODELS.register_module(group_key=Tasks.backbone, module_name=Models.lstm)
  class LSTMModel(TorchModel):
      """Bidirectional single-layer LSTM backbone.

      Token ids are embedded by ``Embedding`` and then run through one
      bidirectional ``nn.LSTM`` built with ``batch_first=True``.

      Args:
          vocab_size: Number of rows in the embedding table.
          embed_width: Embedding dimension; also the LSTM input size.
          hidden_size: Per-direction LSTM hidden size (default 100). A
              ``lstm_hidden_size`` keyword argument, when supplied, takes
              precedence over this value.
          **kwargs: Any other keyword arguments are accepted and ignored.
      """

      def __init__(self, vocab_size, embed_width, hidden_size=100, **kwargs):
          super().__init__()
          # Configs may carry the width under ``lstm_hidden_size``; prefer it.
          lstm_width = (kwargs['lstm_hidden_size']
                        if 'lstm_hidden_size' in kwargs else hidden_size)
          # Keep this registration order (embedding before lstm): it fixes the
          # parameter order and the RNG draws used for initialisation.
          self.embedding = Embedding(vocab_size, embed_width)
          self.lstm = nn.LSTM(
              input_size=embed_width,
              hidden_size=lstm_width,
              num_layers=1,
              bidirectional=True,
              batch_first=True,
          )

      def forward(self, input_ids, **kwargs) -> BackboneModelOutput:
          """Encode ``input_ids`` and return the per-token LSTM features.

          Args:
              input_ids: Token ids; presumably ``(batch, seq_len)`` given
                  ``batch_first=True`` -- TODO confirm against callers.
              **kwargs: Ignored. NOTE(review): an attention mask passed here
                  is not applied, so padded positions are encoded like tokens.

          Returns:
              ``BackboneModelOutput`` whose ``last_hidden_state`` is the LSTM
              output sequence. Both directions are concatenated, so the
              feature size is twice the effective hidden size.
          """
          sequence_output, _ = self.lstm(self.embedding(input_ids))
          return BackboneModelOutput(last_hidden_state=sequence_output)
  class Embedding(nn.Module):
      """Thin ``nn.Module`` wrapper around a single ``nn.Embedding`` table.

      Args:
          vocab_size: Number of embedding rows (distinct token ids).
          embed_width: Size of each embedding vector.
      """

      def __init__(self, vocab_size, embed_width):
          super().__init__()
          # The attribute name ``embedding`` determines the parameter key
          # (``embedding.weight``); keep it so saved checkpoints still load.
          self.embedding = nn.Embedding(
              num_embeddings=vocab_size, embedding_dim=embed_width)

      def forward(self, input_ids):
          """Return the embedding vector for each id in ``input_ids``."""
          table = self.embedding
          return table(input_ids)