criterions.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import torch
  3. import torch.nn.functional as F
  4. from torch.nn.modules.loss import _Loss
  5. def compute_kl_loss(p, q, filter_scores=None):
  6. p_loss = F.kl_div(
  7. F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none')
  8. q_loss = F.kl_div(
  9. F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none')
  10. # You can choose whether to use function "sum" and "mean" depending on your task
  11. p_loss = p_loss.sum(dim=-1)
  12. q_loss = q_loss.sum(dim=-1)
  13. # mask is for filter mechanism
  14. if filter_scores is not None:
  15. p_loss = filter_scores * p_loss
  16. q_loss = filter_scores * q_loss
  17. p_loss = p_loss.mean()
  18. q_loss = q_loss.mean()
  19. loss = (p_loss + q_loss) / 2
  20. return loss
  21. class CatKLLoss(_Loss):
  22. """
  23. CatKLLoss
  24. """
  25. def __init__(self, reduction='mean'):
  26. super(CatKLLoss, self).__init__()
  27. assert reduction in ['none', 'sum', 'mean']
  28. self.reduction = reduction
  29. def forward(self, log_qy, log_py):
  30. """
  31. KL(qy|py) = Eq[qy * log(q(y) / p(y))]
  32. log_qy: (batch_size, latent_size)
  33. log_py: (batch_size, latent_size)
  34. """
  35. qy = torch.exp(log_qy)
  36. kl = torch.sum(qy * (log_qy - log_py), dim=1)
  37. if self.reduction == 'mean':
  38. kl = kl.mean()
  39. elif self.reduction == 'sum':
  40. kl = kl.sum()
  41. return kl