# Copyright (c) Alibaba, Inc. and its affiliates. import torch import torch.nn.functional as F from torch.nn.modules.loss import _Loss def compute_kl_loss(p, q, filter_scores=None): p_loss = F.kl_div( F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none') q_loss = F.kl_div( F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none') # You can choose whether to use function "sum" and "mean" depending on your task p_loss = p_loss.sum(dim=-1) q_loss = q_loss.sum(dim=-1) # mask is for filter mechanism if filter_scores is not None: p_loss = filter_scores * p_loss q_loss = filter_scores * q_loss p_loss = p_loss.mean() q_loss = q_loss.mean() loss = (p_loss + q_loss) / 2 return loss class CatKLLoss(_Loss): """ CatKLLoss """ def __init__(self, reduction='mean'): super(CatKLLoss, self).__init__() assert reduction in ['none', 'sum', 'mean'] self.reduction = reduction def forward(self, log_qy, log_py): """ KL(qy|py) = Eq[qy * log(q(y) / p(y))] log_qy: (batch_size, latent_size) log_py: (batch_size, latent_size) """ qy = torch.exp(log_qy) kl = torch.sum(qy * (log_qy - log_py), dim=1) if self.reduction == 'mean': kl = kl.mean() elif self.reduction == 'sum': kl = kl.sum() return kl