import torch
def soft_cross_entropy(input, targets):
    """Soft cross-entropy between student and teacher logits (distillation).

    Args:
        input: student logits, shape ``(batch, ..., num_classes)``.
        targets: teacher logits of the same shape; softmax is applied here,
            so raw logits (not probabilities) are expected.

    Returns:
        Scalar tensor: cross-entropy summed over the class dimension,
        then averaged over all remaining dimensions.
    """
    student_likelihood = torch.nn.functional.log_softmax(input, dim=-1)
    targets_prob = torch.nn.functional.softmax(targets, dim=-1)
    # Sum over classes gives per-example CE; mean averages over the batch.
    return (-targets_prob * student_likelihood).sum(dim=-1).mean()
def soft_cross_entropy_tinybert(input, targets):
    """TinyBERT-style soft cross-entropy between student and teacher logits.

    Unlike :func:`soft_cross_entropy`, this averages over *all* elements
    (including the class dimension) instead of summing over classes first,
    so the result is scaled down by ``num_classes``. This matches the
    TinyBERT reference implementation.

    Args:
        input: student logits, shape ``(batch, ..., num_classes)``.
        targets: teacher logits of the same shape; softmax is applied here.

    Returns:
        Scalar tensor: elementwise ``-p_teacher * log p_student`` averaged
        over every element.
    """
    student_likelihood = torch.nn.functional.log_softmax(input, dim=-1)
    targets_prob = torch.nn.functional.softmax(targets, dim=-1)
    # .mean() over all elements, NOT .sum(dim=-1).mean() — deliberate
    # TinyBERT convention; do not "fix" to match the sibling function.
    return (-targets_prob * student_likelihood).mean()
def soft_kl_div_loss(input, targets, reduction="batchmean", **kwargs):
    """KL divergence between teacher and student distributions from logits.

    Args:
        input: student logits, shape ``(batch, ..., num_classes)``.
        targets: teacher logits of the same shape; softmax is applied here.
        reduction: passed to ``torch.nn.functional.kl_div``. The default
            ``"batchmean"`` is the mathematically correct KL over a batch.
        **kwargs: forwarded to ``torch.nn.functional.kl_div``.

    Returns:
        Scalar tensor (for the default reduction): ``KL(teacher || student)``.
    """
    # kl_div expects log-probabilities for the input and probabilities
    # for the target, hence the asymmetric log_softmax / softmax pair.
    student_likelihood = torch.nn.functional.log_softmax(input, dim=-1)
    targets_prob = torch.nn.functional.softmax(targets, dim=-1)
    return torch.nn.functional.kl_div(student_likelihood, targets_prob, reduction=reduction, **kwargs)
def mse_loss(inputs, targets, **kwargs):
    """Mean-squared-error loss (thin wrapper over ``F.mse_loss``).

    Args:
        inputs: predicted tensor.
        targets: target tensor, broadcastable against ``inputs``.
        **kwargs: forwarded to ``torch.nn.functional.mse_loss``
            (e.g. ``reduction``).

    Returns:
        Scalar tensor with the default ``"mean"`` reduction.
    """
    return torch.nn.functional.mse_loss(inputs, targets, **kwargs)
def cosine_embedding_loss(input1, input2, target, **kwargs):
    """Cosine-embedding loss (wrapper over ``F.cosine_embedding_loss``).

    Args:
        input1: first batch of vectors, shape ``(batch, dim)``.
        input2: second batch of vectors, same shape as ``input1``.
        target: tensor of ``1`` / ``-1`` labels, shape ``(batch,)``.
        **kwargs: forwarded to ``torch.nn.functional.cosine_embedding_loss``
            (e.g. ``margin``, ``reduction``).

    Returns:
        Scalar tensor with the default ``"mean"`` reduction.
    """
    # Default to "mean" but let callers override it: the original passed
    # reduction="mean" explicitly alongside **kwargs, which raised a
    # duplicate-keyword TypeError whenever a caller supplied reduction.
    kwargs.setdefault("reduction", "mean")
    return torch.nn.functional.cosine_embedding_loss(input1, input2, target, **kwargs)