model.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. from __future__ import absolute_import
  2. from torch import nn
  3. from torch.nn import functional as F
  4. from torch.nn import init
  5. from easyfl.models.model import BaseModel
  6. from .resnet import *
# Public API of this module: only the full model is exported; AvgPooling is
# an internal building block.
__all__ = ["BUCModel"]
  8. class AvgPooling(nn.Module):
  9. def __init__(self, input_feature_size, embedding_feature_size=2048, dropout=0.5):
  10. super(self.__class__, self).__init__()
  11. # embedding
  12. self.embedding_feature_size = embedding_feature_size
  13. self.embedding = nn.Linear(input_feature_size, embedding_feature_size)
  14. self.embedding_bn = nn.BatchNorm1d(embedding_feature_size)
  15. init.kaiming_normal_(self.embedding.weight, mode='fan_out')
  16. init.constant_(self.embedding.bias, 0)
  17. init.constant_(self.embedding_bn.weight, 1)
  18. init.constant_(self.embedding_bn.bias, 0)
  19. self.drop = nn.Dropout(dropout)
  20. def forward(self, inputs):
  21. net = inputs.mean(dim=1)
  22. eval_features = F.normalize(net, p=2, dim=1)
  23. net = self.embedding(net)
  24. net = self.embedding_bn(net)
  25. net = F.normalize(net, p=2, dim=1)
  26. net = self.drop(net)
  27. return net, eval_features
  28. class BUCModel(BaseModel):
  29. def __init__(self, dropout=0.5, embedding_feature_size=2048):
  30. super(self.__class__, self).__init__()
  31. self.CNN = resnet50(dropout=dropout)
  32. self.avg_pooling = AvgPooling(input_feature_size=2048,
  33. embedding_feature_size=embedding_feature_size,
  34. dropout=dropout)
  35. def forward(self, x):
  36. # resnet encoding
  37. resnet_feature = self.CNN(x)
  38. shape = resnet_feature.shape
  39. # reshape back into (batch, samples, ...)
  40. # samples of video frames, we only use images, so always 1.
  41. resnet_feature = resnet_feature.view(shape[0], 1, -1)
  42. # avg pooling
  43. output = self.avg_pooling(resnet_feature)
  44. return output