import torch.nn as nn
from torch.nn import init
from torchvision import models

from easyfl.models import BaseModel

def weights_init_kaiming(m):
    """Kaiming initialization for Conv/Linear/BatchNorm1d layers; intended for use with module.apply()."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')  # For old PyTorch, you may use kaiming_normal.
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
        init.constant_(m.bias.data, 0.0)
    elif classname.find('BatchNorm1d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)

def weights_init_classifier(m):
    """Initialize the final classification layer with a small-variance normal distribution."""
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        init.normal_(m.weight.data, std=0.001)
        init.constant_(m.bias.data, 0.0)

class ClassBlock(nn.Module):
    """Bottleneck block (Linear + BatchNorm + optional LeakyReLU/Dropout) followed by a classification layer."""

    def __init__(self, input_dim, class_num, droprate, relu=False, bnorm=True, num_bottleneck=512, linear=True,
                 return_f=False):
        super(ClassBlock, self).__init__()
        self.return_f = return_f
        add_block = []
        if linear:
            add_block += [nn.Linear(input_dim, num_bottleneck)]
        else:
            num_bottleneck = input_dim
        if bnorm:
            add_block += [nn.BatchNorm1d(num_bottleneck)]
        if relu:
            add_block += [nn.LeakyReLU(0.1)]
        if droprate > 0:
            add_block += [nn.Dropout(p=droprate)]
        add_block = nn.Sequential(*add_block)
        add_block.apply(weights_init_kaiming)

        classifier = []
        classifier += [nn.Linear(num_bottleneck, class_num)]
        classifier = nn.Sequential(*classifier)
        classifier.apply(weights_init_classifier)

        self.add_block = add_block
        self.classifier = classifier

    def forward(self, x):
        x = self.add_block(x)
        if self.return_f:
            f = x  # keep the bottleneck feature so callers can use it alongside the logits
            x = self.classifier(x)
            return x, f
        else:
            x = self.classifier(x)
            return x

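# A minimal sketch of the shapes ClassBlock produces. The batch size and the 751
# identities below are arbitrary example values, not anything required by the class.
def _classblock_demo():
    import torch
    x = torch.randn(4, 2048)  # four pooled backbone features
    block = ClassBlock(2048, class_num=751, droprate=0.5)
    logits = block(x)  # shape (4, 751)
    block_f = ClassBlock(2048, class_num=751, droprate=0.5, return_f=True)
    logits, feat = block_f(x)  # logits (4, 751) plus the (4, 512) bottleneck feature
    return logits.shape, feat.shape
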
# Define the ResNet50-based Model
class Model(BaseModel):
    def __init__(self, class_num=0, droprate=0.5, stride=2):
        super(Model, self).__init__()
        model_ft = models.resnet50(pretrained=True)
        self.class_num = class_num
        if stride == 1:
            # Use stride 1 in the last residual block to keep a larger spatial feature map.
            model_ft.layer4[0].downsample[0].stride = (1, 1)
            model_ft.layer4[0].conv2.stride = (1, 1)
        model_ft.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.model = model_ft
        if self.class_num != 0:
            self.classifier = ClassBlock(2048, class_num, droprate)
        else:
            # The class_num of 10 is only a placeholder: the classification layer is removed
            # below, so the model outputs the 512-d bottleneck feature instead of logits.
            self.classifier = ClassBlock(2048, 10, droprate)
            self.classifier.classifier = nn.Sequential()

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)
        x = x.view(x.size(0), x.size(1))  # flatten pooled features to (batch, 2048)
        x = self.classifier(x)
        return x

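# Sketch of calling the model. The 256x128 input size is an assumption (a common
# person re-ID crop); any spatial size ResNet50 accepts works. Note that
# pretrained=True downloads ImageNet weights on first use.
def _model_demo():
    import torch
    images = torch.randn(4, 3, 256, 128)
    net = Model(class_num=751)      # with a classification head: logits of shape (4, 751)
    logits = net(images)
    extractor = Model(class_num=0)  # head removed: bottleneck features of shape (4, 512)
    feats = extractor(images)
    return logits.shape, feats.shape
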
def get_classifier(class_num, num_bottleneck=512):
    """Build a standalone classification layer matching the ClassBlock bottleneck dimension."""
    classifier = []
    classifier += [nn.Linear(num_bottleneck, class_num)]
    classifier = nn.Sequential(*classifier)
    classifier.apply(weights_init_classifier)
    return classifier

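# One plausible way to pair get_classifier() with a feature-only Model, e.g. keeping a
# separate classification head per dataset or client. This pairing is illustrative and
# may differ from how the surrounding training code actually wires the two together.
def _split_head_demo():
    import torch
    backbone = Model(class_num=0)         # emits 512-d bottleneck features
    head = get_classifier(class_num=751)  # freshly initialized Linear(512, 751)
    feats = backbone(torch.randn(4, 3, 256, 128))  # (4, 512)
    logits = head(feats)                  # (4, 751)
    return logits.shape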