weighted_loss.py 425 B

12345678910111213141516
  1. import torch as t
  2. from torch.nn import BCELoss
  3. class WeightedBCE(t.nn.Module):
  4. def __init__(self) -> None:
  5. super().__init__()
  6. self.loss_fn = BCELoss(reduce=False)
  7. def forward(self, pred, label_and_weight):
  8. label, weights = label_and_weight
  9. losses = self.loss_fn(pred, label)
  10. losses = losses * weights
  11. loss_val = losses.sum() / weights.sum()
  12. return loss_val