exclusive_loss.py

from __future__ import absolute_import
import torch
import torch.nn.functional as F
from torch import nn, autograd


class Exclusive(autograd.Function):
    # Legacy (non-static) Function API, kept commented out for reference:
    # def __init__(ctx, V):
    #     super(Exclusive, ctx).__init__()
    #     ctx.V = V

    @staticmethod
    def forward(ctx, inputs, targets, V):
        ctx.V = V
        ctx.save_for_backward(inputs, targets)
        # Logits are similarities between the batch features and the memory bank V.
        outputs = inputs.mm(ctx.V.t())
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        inputs, targets = ctx.saved_tensors
        grad_inputs = grad_outputs.mm(ctx.V) if ctx.needs_input_grad[0] else None
        # Update the memory bank in place: average each class entry with the
        # incoming feature, then re-normalize to unit length.
        for x, y in zip(inputs, targets):
            ctx.V[y] = F.normalize((ctx.V[y] + x) / 2, p=2, dim=0)
        return grad_inputs, None, None


class ExLoss(nn.Module):
    def __init__(self, num_features, num_classes, t=1.0, weight=None):
        super(ExLoss, self).__init__()
        self.num_features = num_features
        self.t = t  # temperature scaling applied to the similarity logits
        self.weight = weight
        # Memory bank V: one feature vector per class, updated during backward().
        self.register_buffer('V', torch.zeros(num_classes, num_features))

    def forward(self, inputs, targets):
        outputs = Exclusive.apply(inputs, targets, self.V) * self.t
        # outputs = Exclusive(self.V)(inputs, targets) * self.t  # legacy call style
        loss = F.cross_entropy(outputs, targets, weight=self.weight)
        return loss, outputs
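

# Minimal usage sketch (not part of the original file): the feature size,
# class count, batch size, and temperature below are illustrative assumptions,
# not values taken from the repository.
if __name__ == '__main__':
    num_features, num_classes, batch_size = 128, 10, 4
    criterion = ExLoss(num_features, num_classes, t=10.0)
    # L2-normalize the batch features so they match the unit-length entries in V.
    feats = F.normalize(torch.randn(batch_size, num_features), dim=1).requires_grad_()
    labels = torch.randint(0, num_classes, (batch_size,))
    loss, logits = criterion(feats, labels)
    loss.backward()  # gradients flow to feats; V is also updated in place
    print(loss.item(), logits.shape, feats.grad.shape)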