adv_utils.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
  2. # All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import torch
  16. from torch import nn
  17. from modelscope.utils.logger import get_logger
# Module-level logger; the adversarial-loss helpers below use it to warn when the
# perturbation gradient is unusable and the adversarial term is skipped.
logger = get_logger()
  19. def _symmetric_kl_div(logits1, logits2, attention_mask=None):
  20. """
  21. Calculate two logits' the KL div value symmetrically.
  22. :param logits1: The first logit.
  23. :param logits2: The second logit.
  24. :param attention_mask: An optional attention_mask which is used to mask some element out.
  25. This is usually useful in token_classification tasks.
  26. If the shape of logits is [N1, N2, ... Nn, D], the shape of attention_mask should be [N1, N2, ... Nn]
  27. :return: The mean loss.
  28. """
  29. labels_num = logits1.shape[-1]
  30. KLDiv = nn.KLDivLoss(reduction='none')
  31. loss = torch.sum(
  32. KLDiv(nn.LogSoftmax(dim=-1)(logits1),
  33. nn.Softmax(dim=-1)(logits2)),
  34. dim=-1) + torch.sum(
  35. KLDiv(nn.LogSoftmax(dim=-1)(logits2),
  36. nn.Softmax(dim=-1)(logits1)),
  37. dim=-1)
  38. if attention_mask is not None:
  39. loss = torch.sum(
  40. loss * attention_mask) / torch.sum(attention_mask) / labels_num
  41. else:
  42. loss = torch.mean(loss) / labels_num
  43. return loss
  44. def compute_adv_loss(embedding,
  45. model,
  46. ori_logits,
  47. ori_loss,
  48. adv_grad_factor,
  49. adv_bound=None,
  50. sigma=5e-6,
  51. **kwargs):
  52. """
  53. Calculate the adv loss of the model.
  54. :param embedding: Original sentence embedding
  55. :param model: The model, or the forward function(including decoder/classifier),
  56. accept kwargs as input, output logits
  57. :param ori_logits: The original logits outputted from the model function
  58. :param ori_loss: The original loss
  59. :param adv_grad_factor: This factor will be multiplied by the KL loss grad and then the result will be added to
  60. the original embedding.
  61. More details please check:https://arxiv.org/abs/1908.04577
  62. The range of this value always be 1e-3~1e-7
  63. :param adv_bound: adv_bound is used to cut the top and the bottom bound of the produced embedding.
  64. If not proveded, 2 * sigma will be used as the adv_bound factor
  65. :param sigma: The std factor used to produce a 0 mean normal distribution.
  66. If adv_bound not proveded, 2 * sigma will be used as the adv_bound factor
  67. :param kwargs: the input param used in model function
  68. :return: The original loss adds the adv loss
  69. """
  70. adv_bound = adv_bound if adv_bound is not None else 2 * sigma
  71. embedding_1 = embedding + embedding.data.new(embedding.size()).normal_(
  72. 0, sigma) # 95% in +- 1e-5
  73. kwargs.pop('input_ids')
  74. if 'inputs_embeds' in kwargs:
  75. kwargs.pop('inputs_embeds')
  76. with_attention_mask = False if 'with_attention_mask' not in kwargs else kwargs[
  77. 'with_attention_mask']
  78. attention_mask = kwargs['attention_mask']
  79. if not with_attention_mask:
  80. attention_mask = None
  81. if 'with_attention_mask' in kwargs:
  82. kwargs.pop('with_attention_mask')
  83. outputs = model(**kwargs, inputs_embeds=embedding_1)
  84. v1_logits = outputs.logits
  85. loss = _symmetric_kl_div(ori_logits, v1_logits, attention_mask)
  86. emb_grad = torch.autograd.grad(loss, embedding_1)[0].data
  87. emb_grad_norm = emb_grad.norm(
  88. dim=2, keepdim=True, p=float('inf')).max(
  89. 1, keepdim=True)[0]
  90. is_nan = torch.any(torch.isnan(emb_grad_norm))
  91. if is_nan:
  92. logger.warning('Nan occurred when calculating adv loss.')
  93. return ori_loss
  94. emb_grad = emb_grad / (emb_grad_norm + 1e-6)
  95. embedding_2 = embedding_1 + adv_grad_factor * emb_grad
  96. embedding_2 = torch.max(embedding_1 - adv_bound, embedding_2)
  97. embedding_2 = torch.min(embedding_1 + adv_bound, embedding_2)
  98. outputs = model(**kwargs, inputs_embeds=embedding_2)
  99. adv_logits = outputs.logits
  100. adv_loss = _symmetric_kl_div(ori_logits, adv_logits, attention_mask)
  101. return ori_loss + adv_loss
  102. def compute_adv_loss_pair(embedding,
  103. model,
  104. start_logits,
  105. end_logits,
  106. ori_loss,
  107. adv_grad_factor,
  108. adv_bound=None,
  109. sigma=5e-6,
  110. **kwargs):
  111. """
  112. Calculate the adv loss of the model. This function is used in the pair logits scenario.
  113. :param embedding: Original sentence embedding
  114. :param model: The model, or the forward function(including decoder/classifier),
  115. accept kwargs as input, output logits
  116. :param start_logits: The original start logits outputted from the model function
  117. :param end_logits: The original end logits outputted from the model function
  118. :param ori_loss: The original loss
  119. :param adv_grad_factor: This factor will be multiplied by the KL loss grad and then the result will be added to
  120. the original embedding.
  121. More details please check:https://arxiv.org/abs/1908.04577
  122. The range of this value always be 1e-3~1e-7
  123. :param adv_bound: adv_bound is used to cut the top and the bottom bound of the produced embedding.
  124. If not proveded, 2 * sigma will be used as the adv_bound factor
  125. :param sigma: The std factor used to produce a 0 mean normal distribution.
  126. If adv_bound not proveded, 2 * sigma will be used as the adv_bound factor
  127. :param kwargs: the input param used in model function
  128. :return: The original loss adds the adv loss
  129. """
  130. adv_bound = adv_bound if adv_bound is not None else 2 * sigma
  131. embedding_1 = embedding + embedding.data.new(embedding.size()).normal_(
  132. 0, sigma) # 95% in +- 1e-5
  133. kwargs.pop('input_ids')
  134. if 'inputs_embeds' in kwargs:
  135. kwargs.pop('inputs_embeds')
  136. outputs = model(**kwargs, inputs_embeds=embedding_1)
  137. v1_logits_start, v1_logits_end = outputs.logits
  138. loss = _symmetric_kl_div(start_logits,
  139. v1_logits_start) + _symmetric_kl_div(
  140. end_logits, v1_logits_end)
  141. loss = loss / 2
  142. emb_grad = torch.autograd.grad(loss, embedding_1)[0].data
  143. emb_grad_norm = emb_grad.norm(
  144. dim=2, keepdim=True, p=float('inf')).max(
  145. 1, keepdim=True)[0]
  146. is_nan = torch.any(torch.isnan(emb_grad_norm))
  147. if is_nan:
  148. logger.warning('Nan occurred when calculating pair adv loss.')
  149. return ori_loss
  150. emb_grad = emb_grad / emb_grad_norm
  151. embedding_2 = embedding_1 + adv_grad_factor * emb_grad
  152. embedding_2 = torch.max(embedding_1 - adv_bound, embedding_2)
  153. embedding_2 = torch.min(embedding_1 + adv_bound, embedding_2)
  154. outputs = model(**kwargs, inputs_embeds=embedding_2)
  155. adv_logits_start, adv_logits_end = outputs.logits
  156. adv_loss = _symmetric_kl_div(start_logits,
  157. adv_logits_start) + _symmetric_kl_div(
  158. end_logits, adv_logits_end)
  159. return ori_loss + adv_loss