| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import torch
- import torch.nn.functional as F
- from torch.nn.modules.loss import _Loss
def compute_kl_loss(p, q, filter_scores=None):
    """Symmetric KL divergence between the softmax distributions of two logits.

    Both inputs are normalised with a softmax over the last dimension. The KL
    divergence is evaluated in each direction, summed over the last dimension,
    optionally re-weighted, averaged over all remaining entries, and the two
    directional values are then averaged.

    Args:
        p: Logits of the first prediction.
        q: Logits of the second prediction; used elementwise with ``p``.
        filter_scores: Optional weights multiplied into the per-sample
            divergences before averaging (filter mechanism). Must broadcast
            against the tensor obtained after summing over the last dimension.

    Returns:
        A scalar tensor: ``(KL(q || p) + KL(p || q)) / 2``.
    """

    def _directed_kl(pred_logits, target_logits):
        # F.kl_div expects log-probabilities as input and probabilities as
        # target; it returns target * (log(target) - input) elementwise.
        elementwise = F.kl_div(
            F.log_softmax(pred_logits, dim=-1),
            F.softmax(target_logits, dim=-1),
            reduction='none',
        )
        # Collapse the class dimension. Whether "sum" or "mean" is the better
        # choice here depends on the task.
        per_sample = elementwise.sum(dim=-1)
        # Filter mechanism: scale each sample's divergence by its score.
        if filter_scores is not None:
            per_sample = filter_scores * per_sample
        return per_sample.mean()

    forward_term = _directed_kl(p, q)
    backward_term = _directed_kl(q, p)
    return (forward_term + backward_term) / 2
class CatKLLoss(_Loss):
    """KL divergence between two categorical distributions given as log-probs.

    Args:
        reduction: How the per-sample divergences are combined: ``'none'``
            returns one value per sample, ``'sum'`` adds them up and ``'mean'``
            (default) averages them.
    """

    def __init__(self, reduction='mean'):
        super(CatKLLoss, self).__init__()
        # Kept as an assertion so the exception type seen by existing callers
        # does not change.
        assert reduction in ['none', 'sum', 'mean'], (
            "reduction must be one of 'none', 'sum' or 'mean'")
        self.reduction = reduction

    def forward(self, log_qy, log_py):
        """Compute KL(qy|py) = Eq[qy * log(q(y) / p(y))].

        Args:
            log_qy: Log-probabilities of q, shape (batch_size, latent_size).
            log_py: Log-probabilities of p, shape (batch_size, latent_size).

        Returns:
            A tensor of per-sample divergences (summed over dim 1) when
            ``reduction == 'none'``; otherwise their sum or mean as a scalar.
        """
        qy = torch.exp(log_qy)
        # Categories with q(y) == 0 contribute nothing to the divergence
        # (0 * log 0 := 0). Zero the log-ratio there *before* multiplying so
        # that neither the forward value (0 * -inf) nor the gradient becomes
        # NaN when log_qy contains -inf.
        log_ratio = (log_qy - log_py).masked_fill(qy == 0, 0.0)
        kl = torch.sum(qy * log_ratio, dim=1)

        if self.reduction == 'mean':
            kl = kl.mean()
        elif self.reduction == 'sum':
            kl = kl.sum()
        return kl
|