prompt.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. # Copyright (c) 2022 Zhipu.AI
  2. import random
  3. import torch
  4. class PromptSpell(torch.nn.Module):
  5. def __init__(self, spell_length, hidden_size, spell_func):
  6. super(PromptSpell, self).__init__()
  7. self.spell_length = spell_length
  8. self.hidden_size = hidden_size
  9. self.spell_embeddings = torch.nn.Embedding(self.spell_length,
  10. self.hidden_size)
  11. self.spell_func = spell_func
  12. if self.spell_func == 'lstm':
  13. self.lstm_head = torch.nn.LSTM(
  14. input_size=self.hidden_size,
  15. hidden_size=self.hidden_size,
  16. num_layers=2,
  17. # dropout=self.lstm_dropout,
  18. bidirectional=True,
  19. batch_first=True) # .to(torch.device("cuda"))
  20. self.mlp_head = torch.nn.Sequential(
  21. torch.nn.Linear(2 * self.hidden_size, self.hidden_size),
  22. torch.nn.ReLU(),
  23. torch.nn.Linear(self.hidden_size, self.hidden_size))
  24. elif self.spell_func == 'mlp':
  25. self.mlp_head = torch.nn.Sequential(
  26. torch.nn.Linear(self.hidden_size, self.hidden_size),
  27. torch.nn.ReLU(),
  28. torch.nn.Linear(self.hidden_size, self.hidden_size))
  29. elif self.spell_func != 'none':
  30. raise NotImplementedError('Prompt function ' + self.spell_func)
  31. def init_embedding(self, word_embeddings=None, task_tokens=None):
  32. num_words = 5000
  33. with torch.no_grad():
  34. for i in range(self.spell_length):
  35. rand_token = random.randrange(num_words)
  36. if task_tokens is None:
  37. target_embedding = word_embeddings[rand_token]
  38. else:
  39. word_embedding = word_embeddings[rand_token]
  40. task_token = random.choice(task_tokens)
  41. task_embedding = word_embeddings[task_token]
  42. ratio = random.random()
  43. target_embedding = word_embedding * ratio + task_embedding * (
  44. 1 - ratio)
  45. self.spell_embeddings.weight.data[i] = target_embedding
  46. def forward(self):
  47. prompt_embeds = self.spell_embeddings.weight.unsqueeze(0)
  48. if self.spell_func == 'lstm':
  49. prompt_embeds = self.lstm_head(prompt_embeds)[0]
  50. if self.spell_func == 'lstm' or self.spell_func == 'mlp':
  51. prompt_embeds = self.mlp_head(prompt_embeds)
  52. return prompt_embeds