cross_entropy.py

import torch as t
from federatedml.util import consts
from torch.nn.functional import one_hot


def cross_entropy(p2, p1, reduction='mean'):
    """
    Cross entropy between a predicted distribution p2 and a target
    (one-hot) distribution p1, both of shape (batch, num_classes).
    """
    p2 = p2 + consts.FLOAT_ZERO  # small epsilon so t.log never sees an exact 0
    assert p2.shape == p1.shape
    if reduction == 'sum':
        return -t.sum(p1 * t.log(p2))
    elif reduction == 'mean':
        # per-sample cross entropy, then averaged over the batch
        return -t.mean(t.sum(p1 * t.log(p2), dim=1))
    elif reduction == 'none':
        return -t.sum(p1 * t.log(p2), dim=1)
    else:
        raise ValueError('unknown reduction')
class CrossEntropyLoss(t.nn.Module):
    """
    A cross-entropy loss that does NOT apply Softmax itself:
    `pred` is expected to already hold class probabilities.
    """

    def __init__(self, reduction='mean'):
        super(CrossEntropyLoss, self).__init__()
        self.reduction = reduction

    def forward(self, pred, label):
        # pin num_classes to pred's width so the one-hot encoding matches
        # pred even when a batch does not contain every class
        one_hot_label = one_hot(label.flatten().long(), num_classes=pred.shape[1])
        loss_ = cross_entropy(pred, one_hot_label, self.reduction)
        return loss_
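
A minimal usage sketch, not part of the file: because forward expects probabilities rather than logits, the caller applies Softmax first. Assuming consts.FLOAT_ZERO is a tiny positive constant, the mean-reduced result should agree with torch's built-in CrossEntropyLoss applied to the raw logits, up to that epsilon. The tensor names and sizes below are illustrative.

import torch as t
import torch.nn.functional as F

logits = t.randn(8, 5)            # illustrative batch of raw scores
labels = t.randint(0, 5, (8,))    # integer class labels in [0, 5)
probs = F.softmax(logits, dim=1)  # this loss expects probabilities

loss_here = CrossEntropyLoss(reduction='mean')(probs, labels)
loss_ref = t.nn.CrossEntropyLoss(reduction='mean')(logits, labels)
print(float(loss_here), float(loss_ref))  # nearly equal, up to the epsilon

# reduction='none' surfaces the per-sample losses, shape (8,)
per_sample = CrossEntropyLoss(reduction='none')(probs, labels)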