| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859 |
- # Copyright (c) 2022 Zhipu.AI
- import random
- import torch
class PromptSpell(torch.nn.Module):
    """Learnable prompt embeddings with an optional re-parameterisation head.

    The module owns ``spell_length`` embedding vectors of width
    ``hidden_size``. ``forward`` returns them as a single-batch tensor,
    optionally passed through one of the following heads:

    * ``'none'`` -- the raw embedding table is returned.
    * ``'mlp'``  -- a two-layer MLP (Linear -> ReLU -> Linear) is applied.
    * ``'lstm'`` -- a 2-layer bidirectional LSTM followed by a two-layer MLP
      that projects the ``2 * hidden_size`` LSTM output back to
      ``hidden_size``.
    """

    def __init__(self, spell_length, hidden_size, spell_func):
        """Build the embedding table and the head selected by ``spell_func``.

        Args:
            spell_length: Number of prompt positions (rows of the table).
            hidden_size: Width of every prompt vector.
            spell_func: One of ``'lstm'``, ``'mlp'`` or ``'none'``.

        Raises:
            NotImplementedError: If ``spell_func`` is none of the above.
        """
        super().__init__()
        self.spell_length = spell_length
        self.hidden_size = hidden_size
        self.spell_embeddings = torch.nn.Embedding(self.spell_length,
                                                   self.hidden_size)
        self.spell_func = spell_func
        if self.spell_func == 'lstm':
            self.lstm_head = torch.nn.LSTM(
                input_size=self.hidden_size,
                hidden_size=self.hidden_size,
                num_layers=2,
                bidirectional=True,
                batch_first=True)
            # The bidirectional LSTM emits 2 * hidden_size features per
            # position; the MLP maps them back to hidden_size.
            self.mlp_head = torch.nn.Sequential(
                torch.nn.Linear(2 * self.hidden_size, self.hidden_size),
                torch.nn.ReLU(),
                torch.nn.Linear(self.hidden_size, self.hidden_size))
        elif self.spell_func == 'mlp':
            self.mlp_head = torch.nn.Sequential(
                torch.nn.Linear(self.hidden_size, self.hidden_size),
                torch.nn.ReLU(),
                torch.nn.Linear(self.hidden_size, self.hidden_size))
        elif self.spell_func != 'none':
            # f-string instead of ``+`` so that a non-string value still
            # produces NotImplementedError rather than a TypeError.
            raise NotImplementedError(f'Prompt function {self.spell_func}')

    def init_embedding(self, word_embeddings=None, task_tokens=None,
                       num_words=5000):
        """Initialise every prompt vector from existing word embeddings.

        For each prompt position a random row among the first ``num_words``
        rows of ``word_embeddings`` is drawn. Without ``task_tokens`` that row
        is copied as is; with ``task_tokens`` it is linearly blended, using a
        random ratio in ``[0, 1)``, with the embedding of a randomly chosen
        task token.

        Args:
            word_embeddings: Indexable table of embedding vectors, presumably
                a ``(vocab, hidden_size)`` tensor -- confirm against callers.
                Required despite the ``None`` default, which is kept for
                interface compatibility.
            task_tokens: Optional non-empty sequence of row indices into
                ``word_embeddings``.
            num_words: Upper bound (exclusive) on the sampled row index. It is
                clamped to ``len(word_embeddings)`` so that small tables do
                not trigger an out-of-range lookup.

        Raises:
            TypeError: If ``word_embeddings`` is ``None``.
        """
        if word_embeddings is None:
            raise TypeError('init_embedding() requires word_embeddings')
        # Never sample past the end of the provided table.
        vocab_limit = min(num_words, len(word_embeddings))
        with torch.no_grad():
            for i in range(self.spell_length):
                rand_token = random.randrange(vocab_limit)
                if task_tokens is None:
                    target_embedding = word_embeddings[rand_token]
                else:
                    word_embedding = word_embeddings[rand_token]
                    task_token = random.choice(task_tokens)
                    task_embedding = word_embeddings[task_token]
                    ratio = random.random()
                    target_embedding = (word_embedding * ratio
                                        + task_embedding * (1 - ratio))
                # In-place row copy; autograd is disabled by no_grad above.
                self.spell_embeddings.weight[i].copy_(target_embedding)

    def forward(self):
        """Return the prompt embeddings after the configured head.

        Returns:
            Tensor of shape ``(1, spell_length, hidden_size)``.
        """
        # Add a leading batch dimension of size one.
        prompt_embeds = self.spell_embeddings.weight.unsqueeze(0)
        if self.spell_func == 'lstm':
            # nn.LSTM returns (output, (h_n, c_n)); only the output is used.
            prompt_embeds = self.lstm_head(prompt_embeds)[0]
        if self.spell_func in ('lstm', 'mlp'):
            prompt_embeds = self.mlp_head(prompt_embeds)
        return prompt_embeds
|