downstream.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. # Copyright (c) 2022 Zhipu.AI
  2. """Multiple choice model."""
  3. import torch
  4. import torch.nn
  5. from .modeling_glm import GLMModel
  6. class GLMForMultiTokenCloze(torch.nn.Module):
  7. def __init__(self,
  8. language_model: GLMModel,
  9. take_softmax=True,
  10. length_penalty=0.0):
  11. super(GLMForMultiTokenCloze, self).__init__()
  12. self.model = language_model
  13. self.take_softmax = take_softmax
  14. self.length_penalty = length_penalty
  15. def state_dict(self, destination=None, prefix='', keep_vars=False):
  16. # [h.remove() for h in self.hook_handles]
  17. sd = self.model.state_dict(destination, prefix, keep_vars)
  18. return sd
  19. def load_state_dict(self, state_dict, strict=True):
  20. return self.model.load_state_dict(state_dict, strict=strict)
  21. def named_parameters(self, prefix: str = '', recurse: bool = True):
  22. return self.model.named_parameters(prefix=prefix, recurse=recurse)
  23. def forward(self,
  24. input_ids,
  25. position_ids,
  26. attention_mask,
  27. target_ids=None,
  28. logit_mask=None,
  29. prompt_pos=None):
  30. if target_ids is None:
  31. return self.model(input_ids, position_ids, attention_mask)
  32. num_choices = None
  33. if len(input_ids.shape) == 3:
  34. batch_size, num_choices = input_ids.shape[:2]
  35. input_ids = input_ids.reshape(-1, input_ids.size(-1))
  36. attention_mask = attention_mask.reshape(-1,
  37. *attention_mask.size()[2:])
  38. position_ids = position_ids.reshape(-1, *position_ids.size()[2:])
  39. target_ids = target_ids.reshape(-1, target_ids.size(-1))
  40. logit_mask = logit_mask.reshape(-1, logit_mask.size(-1))
  41. if prompt_pos is not None:
  42. prompt_pos = prompt_pos.reshape(-1, prompt_pos.size(-1))
  43. outputs, *mems = self.model(
  44. input_ids, position_ids, attention_mask, prompt_pos=prompt_pos)
  45. if self.take_softmax:
  46. outputs = torch.nn.functional.log_softmax(outputs, dim=-1)
  47. # select the target logits
  48. batch_ids = torch.arange(
  49. target_ids.size(0), dtype=torch.long, device=target_ids.device)
  50. batch_ids = batch_ids.unsqueeze(1).expand_as(target_ids)
  51. seq_ids = torch.arange(
  52. target_ids.size(-1), dtype=torch.long, device=target_ids.device)
  53. seq_ids = seq_ids.unsqueeze(0).expand_as(target_ids)
  54. logits = outputs[batch_ids, seq_ids, target_ids]
  55. logits = (logits * logit_mask).sum(dim=1)
  56. if self.length_penalty > 0.0:
  57. logits = logits / logit_mask.sum(dim=1)**self.length_penalty
  58. if num_choices is not None:
  59. logits = logits.view(-1, num_choices)
  60. return (logits, *mems)
  61. class GLMForMultiTokenClozeFast(torch.nn.Module):
  62. def __init__(self, language_model, take_softmax=True, length_penalty=0.0):
  63. super(GLMForMultiTokenClozeFast, self).__init__()
  64. self.model = language_model
  65. self.take_softmax = take_softmax
  66. self.length_penalty = length_penalty
  67. def forward(self, input_ids, position_ids, attention_mask, dec_input_ids,
  68. dec_position_ids, dec_attention_mask, dec_target_ids,
  69. dec_logit_mask):
  70. # encoder
  71. outputs, *mems = self.model(
  72. input_ids,
  73. position_ids,
  74. attention_mask,
  75. return_memory=True,
  76. detach_memory=False)
  77. batch_size, num_choices, max_dec_len = dec_input_ids.size()
  78. max_enc_len = input_ids.size(-1)
  79. enc_mems = []
  80. for hidden in mems:
  81. hidden = hidden.unsqueeze(1).expand(-1, num_choices, -1,
  82. -1).reshape(
  83. batch_size * num_choices,
  84. *hidden.size()[1:])
  85. enc_mems.append(hidden)
  86. def build_dec_mask_matrix(seq_length, sep, memory_length=0):
  87. m = enc_mems[0].new_ones((1, seq_length, seq_length))
  88. m = torch.tril(m)
  89. # sep = dec_attention_mask
  90. ids = torch.arange(
  91. memory_length, device=sep.device, dtype=sep.dtype).view(1, -1)
  92. mask = ids < sep.view(-1, 1) # batch * mem
  93. mask = mask.unsqueeze(1).float().expand(-1, seq_length, -1)
  94. m = m.expand(batch_size * num_choices, -1, -1)
  95. m = torch.cat((mask, m), dim=2)
  96. m = m.unsqueeze(1)
  97. return m
  98. dec_input_ids = dec_input_ids.reshape(-1, max_dec_len)
  99. dec_position_ids = dec_position_ids.reshape(
  100. -1,
  101. *dec_position_ids.size()[2:])
  102. # dec_attention_mask = dec_attention_mask.reshape(-1, *dec_attention_mask.size()[2:]).unsqueeze(1)
  103. dec_attention_mask = build_dec_mask_matrix(
  104. max_dec_len, dec_attention_mask.reshape(-1), max_enc_len)
  105. dec_target_ids = dec_target_ids.reshape(-1, dec_target_ids.size(-1))
  106. dec_logit_mask = dec_logit_mask.reshape(-1, dec_logit_mask.size(-1))
  107. outputs, *mems = self.model(dec_input_ids, dec_position_ids,
  108. dec_attention_mask, *enc_mems)
  109. if self.take_softmax:
  110. outputs = torch.nn.functional.log_softmax(outputs, dim=-1)
  111. batch_ids = torch.arange(
  112. dec_target_ids.size(0),
  113. dtype=torch.long,
  114. device=dec_target_ids.device)
  115. batch_ids = batch_ids.unsqueeze(1).expand_as(dec_target_ids)
  116. seq_ids = torch.arange(
  117. dec_target_ids.size(-1),
  118. dtype=torch.long,
  119. device=dec_target_ids.device)
  120. seq_ids = seq_ids.unsqueeze(0).expand_as(dec_target_ids)
  121. logits = outputs[batch_ids, seq_ids, dec_target_ids]
  122. logits = (logits * dec_logit_mask).sum(dim=1)
  123. if self.length_penalty > 0.0:
  124. logits = logits / dec_logit_mask.sum(dim=1)**self.length_penalty
  125. if num_choices is not None:
  126. logits = logits.view(-1, num_choices)
  127. return (logits, *mems)
  128. class GLMForSingleTokenCloze(torch.nn.Module):
  129. def __init__(self, language_model, take_softmax=False):
  130. super().__init__()
  131. self.model = language_model
  132. self.take_softmax = take_softmax
  133. def state_dict(self, destination=None, prefix='', keep_vars=False):
  134. # [h.remove() for h in self.hook_handles]
  135. sd = self.model.state_dict(destination, prefix, keep_vars)
  136. return sd
  137. def load_state_dict(self, state_dict, strict=True):
  138. return self.model.load_state_dict(state_dict, strict=strict)
  139. def named_parameters(self, prefix: str = '', recurse: bool = True):
  140. return self.model.named_parameters(prefix=prefix, recurse=recurse)
  141. def forward(self,
  142. input_ids,
  143. position_ids,
  144. attention_mask,
  145. target_ids=None,
  146. logit_mask=None,
  147. prompt_pos=None):
  148. if target_ids is None:
  149. return self.model(input_ids, position_ids, attention_mask)
  150. assert len(input_ids.shape) == 2
  151. outputs, *mems = self.model(
  152. input_ids, position_ids, attention_mask, prompt_pos=prompt_pos)
  153. batch_ids = torch.arange(
  154. outputs.size(0),
  155. dtype=attention_mask.dtype,
  156. device=attention_mask.device)
  157. target_logits = outputs[batch_ids, attention_mask]
  158. if self.take_softmax:
  159. target_prob = torch.nn.functional.log_softmax(
  160. target_logits, dim=-1)
  161. else:
  162. target_prob = target_logits
  163. batch_ids = batch_ids.unsqueeze(1).expand_as(target_ids)
  164. output = target_prob[batch_ids, target_ids]
  165. return (output, target_logits, *mems)
  166. class GLMForSequenceClassification(torch.nn.Module):
  167. def __init__(self,
  168. language_model,
  169. hidden_size,
  170. hidden_dropout,
  171. pool_token,
  172. num_class=1):
  173. super().__init__()
  174. self.pool_token = pool_token
  175. self.model = language_model
  176. self.num_class = num_class
  177. # Multi-choice head.
  178. self.pool_layer = torch.nn.Linear(hidden_size, hidden_size)
  179. self.multichoice_dropout = torch.nn.Dropout(hidden_dropout)
  180. self.multichoice_head = torch.nn.Linear(hidden_size, num_class)
  181. def forward(self, input_ids, position_ids, attention_mask):
  182. num_choices = None
  183. if len(input_ids.shape) == 3:
  184. assert self.num_class == 1
  185. batch_size, num_choices = input_ids.shape[:2]
  186. input_ids = input_ids.reshape(-1, input_ids.size(-1))
  187. attention_mask = attention_mask.reshape(-1,
  188. *attention_mask.size()[2:])
  189. position_ids = position_ids.reshape(-1, *position_ids.size()[2:])
  190. outputs, *mems = self.model(input_ids, position_ids, attention_mask)
  191. if self.pool_token == 'start':
  192. output = outputs[torch.arange(
  193. outputs.size(0),
  194. dtype=attention_mask.dtype,
  195. device=attention_mask.device), attention_mask]
  196. elif self.pool_token == 'pad':
  197. output = outputs[torch.arange(
  198. outputs.size(0),
  199. dtype=attention_mask.dtype,
  200. device=attention_mask.device), attention_mask - 1]
  201. elif self.pool_token == 'cls':
  202. output = outputs[:, 0]
  203. else:
  204. raise NotImplementedError
  205. output = torch.tanh(self.pool_layer(output))
  206. multichoice_output = self.multichoice_dropout(output)
  207. logits = self.multichoice_head(multichoice_output)
  208. if num_choices is not None:
  209. logits = logits.view(-1, num_choices)
  210. return (logits, *mems)