"""
Creates an Xception Model as defined in:

Xception: Deep Learning with Depthwise Separable Convolutions,
https://arxiv.org/pdf/1610.02357.pdf

This module implements a reduced, multi-task variant: a shared Xception-style
encoder followed by one decoder per task (or a single classification head).
"""
import logging
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from easyfl.models.model import BaseModel
from .ozan_rep_fun import ozan_rep_function, OzanRepFunction, gradnorm_rep_function, GradNormRepFunction, \
    trevor_rep_function, TrevorRepFunction

__all__ = ['xception', 'xception_gradnorm', 'xception_half_gradnorm', 'xception_ozan', 'xception_half',
           'xception_quad', 'xception_double', 'xception_double_ozan', 'xception_half_ozan', 'xception_quad_ozan']

logger = logging.getLogger(__name__)

# Encoder checkpoint loaded by the factory functions when ``pretrained=True``.
# The path is relative to the current working directory.
_PRETRAINED_ENCODER_PATH = 'xception_taskonomy_small2.encoder.pth.tar'


class SeparableConv2d(nn.Module):
    """Depthwise-separable convolution: a grouped spatial conv followed by a 1x1 pointwise conv.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (produced by the pointwise conv).
        kernel_size, stride, padding, dilation: forwarded to the spatial conv.
        bias: whether both convolutions use a bias term.
        groupsize: channels per group of the spatial conv. The spatial conv uses
            ``max(1, in_channels // groupsize)`` groups, so the default of 1 makes
            it fully depthwise.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False,
                 groupsize=1):
        super(SeparableConv2d, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation,
                               groups=max(1, in_channels // groupsize), bias=bias)
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x


class Block(nn.Module):
    """Residual unit made of ``reps`` (ReLU -> SeparableConv2d -> BatchNorm) stages.

    Args:
        in_filters: input channels.
        out_filters: output channels.
        reps: number of separable-conv stages.
        strides: when != 1, the main path ends with a kernel-2 / stride-2 conv
            and the shortcut becomes a 1x1 conv with ``stride=strides``.
        start_with_relu: if False, the leading ReLU of the main path is dropped.
        grow_first: if True the channel change happens in the first stage,
            otherwise in the last one.
    """

    def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
        super(Block, self).__init__()

        if out_filters != in_filters or strides != 1:
            # Projection shortcut whenever the output shape differs from the input.
            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None

        # One in-place ReLU instance is shared by every stage below.
        self.relu = nn.ReLU(inplace=True)
        rep = []

        filters = in_filters
        if grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters

        for _ in range(reps - 1):
            rep.append(self.relu)
            rep.append(SeparableConv2d(filters, filters, 3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(filters))

        if not grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters

        if not start_with_relu:
            rep = rep[1:]
        else:
            # The first ReLU must not be in-place: it would overwrite the block
            # input, which forward() still needs for the shortcut.
            rep[0] = nn.ReLU(inplace=False)

        if strides != 1:
            rep.append(nn.Conv2d(filters, filters, 2, 2))
        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.rep(inp)
        if self.skip is not None:
            skip = self.skipbn(self.skip(inp))
        else:
            skip = inp
        x += skip
        return x


class _XceptionEncoder(nn.Module):
    """Shared trunk of all Xception encoders; subclasses only choose the channel widths.

    The layout is as follows:

    - two stem convolutions (conv1 has stride 2);
    - three downsampling blocks (block1-block3);
    - four width-preserving blocks (block8-block11);
    - a SeparableConv2d projecting to 256 channels, followed by BatchNorm and a
      non-in-place ReLU.

    Attribute names (conv1, bn1, ..., block8, ..., bn3) define the state_dict keys
    used by the pretrained encoder checkpoint, so do not rename them.
    Construction/registration order also determines the order in which random
    initialisation is drawn; keep it stable for reproducibility.

    Args:
        stem1: output channels of conv1.
        stem2: output channels of conv2.
        entry1: output channels of block1.
        entry2: output channels of block2.
        entry3: output channels of block3, also the width of block8-block11.
    """

    def __init__(self, stem1, stem2, entry1, entry2, entry3):
        super(_XceptionEncoder, self).__init__()
        self.conv1 = nn.Conv2d(3, stem1, 3, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(stem1)
        self.relu = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv2 = nn.Conv2d(stem1, stem2, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(stem2)

        self.block1 = Block(stem2, entry1, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(entry1, entry2, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(entry2, entry3, 2, 2, start_with_relu=True, grow_first=True)

        self.block8 = Block(entry3, entry3, 2, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(entry3, entry3, 2, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(entry3, entry3, 2, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(entry3, entry3, 2, 1, start_with_relu=True, grow_first=True)

        self.conv3 = SeparableConv2d(entry3, 256, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(256)

    def forward(self, input):
        x = self.relu(self.bn1(self.conv1(input)))
        x = self.relu(self.bn2(self.conv2(x)))
        for block in (self.block1, self.block2, self.block3,
                      self.block8, self.block9, self.block10, self.block11):
            x = block(x)
        x = self.bn3(self.conv3(x))
        # Final ReLU is not in-place (relu2).
        representation = self.relu2(x)
        return representation


class Encoder(_XceptionEncoder):
    """Default-width encoder (512-channel middle blocks, 256-channel output)."""

    def __init__(self):
        super(Encoder, self).__init__(24, 48, 96, 192, 512)


class EncoderHalf(_XceptionEncoder):
    """Narrow encoder (360-channel middle blocks, 256-channel output)."""

    def __init__(self):
        super(EncoderHalf, self).__init__(24, 48, 64, 128, 360)


class EncoderQuad(_XceptionEncoder):
    """Widest encoder (1024-channel middle blocks, 256-channel output)."""

    def __init__(self):
        logger.debug('entering quad constructor')
        super(EncoderQuad, self).__init__(48, 96, 192, 384, 1024)


class EncoderDouble(_XceptionEncoder):
    """Wide encoder (728-channel middle blocks, 256-channel output)."""

    def __init__(self):
        logger.debug('entering double constructor')
        super(EncoderDouble, self).__init__(32, 64, 128, 256, 728)


def interpolate(inp, size):
    """Bilinearly resize ``inp`` to ``size`` and return it in ``inp``'s dtype.

    The resize itself is always computed in float32 (e.g. so half-precision
    inputs go through the float path). The result is cast back to the input's
    original dtype, whatever it was.
    """
    orig_dtype = inp.dtype
    out = F.interpolate(inp.float(), size=size, mode='bilinear', align_corners=False)
    return out.to(orig_dtype)


class Decoder(nn.Module):
    """Per-task head on top of the 256-channel encoder representation.

    Args:
        output_channels: channels of the dense prediction (used only when
            ``num_classes`` is None).
        num_classes: if given, the decoder is a global-average-pool + linear
            classifier with this many outputs; otherwise it is a dense upsampling
            decoder.
    """

    def __init__(self, output_channels=32, num_classes=None):
        super(Decoder, self).__init__()
        self.output_channels = output_channels
        self.num_classes = num_classes
        if num_classes is not None:
            self.fc = nn.Linear(256, num_classes)
        else:
            self.relu = nn.ReLU(inplace=True)
            self.conv_decode_res = SeparableConv2d(256, 16, 3, padding=1)
            self.conv_decode_res2 = SeparableConv2d(256, 96, 3, padding=1)
            self.bn_conv_decode_res = nn.BatchNorm2d(16)
            self.bn_conv_decode_res2 = nn.BatchNorm2d(96)
            self.upconv1 = nn.ConvTranspose2d(96, 96, 2, 2)
            self.bn_upconv1 = nn.BatchNorm2d(96)
            self.conv_decode1 = SeparableConv2d(96, 64, 3, padding=1)
            self.bn_decode1 = nn.BatchNorm2d(64)
            self.upconv2 = nn.ConvTranspose2d(64, 64, 2, 2)
            self.bn_upconv2 = nn.BatchNorm2d(64)
            self.conv_decode2 = SeparableConv2d(64, 64, 5, padding=2)
            self.bn_decode2 = nn.BatchNorm2d(64)
            self.upconv3 = nn.ConvTranspose2d(64, 32, 2, 2)
            self.bn_upconv3 = nn.BatchNorm2d(32)
            self.conv_decode3 = SeparableConv2d(32, 32, 5, padding=2)
            self.bn_decode3 = nn.BatchNorm2d(32)
            self.upconv4 = nn.ConvTranspose2d(32, 32, 2, 2)
            self.bn_upconv4 = nn.BatchNorm2d(32)
            # 48 = 32 channels from upconv4 + 16 channels from the skip branch.
            self.conv_decode4 = SeparableConv2d(48, output_channels, 5, padding=2)

    def forward(self, representation):
        if self.num_classes is not None:
            x = F.adaptive_avg_pool2d(representation, (1, 1))
            x = x.view(x.size(0), -1)
            return self.fc(x)

        # Skip branch: 16 channels computed at the representation's resolution,
        # resized below to match the upsampled main branch.
        x2 = self.bn_conv_decode_res(self.conv_decode_res(representation))

        # Main branch: four stride-2 transposed convs (16x upsampling overall).
        x = self.bn_conv_decode_res2(self.conv_decode_res2(representation))
        x = self.relu(self.bn_upconv1(self.upconv1(x)))
        x = self.relu(self.bn_decode1(self.conv_decode1(x)))
        x = self.relu(self.bn_upconv2(self.upconv2(x)))
        x = self.relu(self.bn_decode2(self.conv_decode2(x)))
        x = self.relu(self.bn_upconv3(self.upconv3(x)))
        x = self.relu(self.bn_decode3(self.conv_decode3(x)))
        x = self.bn_upconv4(self.upconv4(x))

        # Resize the skip branch to the actual output resolution instead of a
        # fixed 256x256, so inputs of other sizes can be concatenated.
        x2 = interpolate(x2, size=tuple(x.shape[2:]))
        x = torch.cat([x, x2], 1)
        x = self.relu(x)
        x = self.conv_decode4(x)
        return x


class Xception(BaseModel):
    """Multi-task Xception: one shared encoder feeding one Decoder per task.

    Based on https://arxiv.org/pdf/1610.02357.pdf, with a reduced encoder.

    Args:
        tasks: iterable of task names (keys of ``task_to_output_channels``).
            If None, a single classification head is built and ``forward``
            returns its output directly instead of a dict.
        num_classes: if given, each decoder is a linear classification head with
            this many outputs. When ``tasks`` is None it defaults to 1000.
        ozan: False, True or 'gradnorm'. Selects which rep function routes the
            shared representation to the decoders (TrevorRepFunction,
            OzanRepFunction or GradNormRepFunction, respectively).
        half: encoder width. Use 'Quad' or 'Double' for the wider encoders, any
            other truthy value for the half-width encoder, and a falsy value for
            the default encoder.
    """

    def __init__(self, tasks=None, num_classes=None, ozan=False, half=False):
        super(Xception, self).__init__()
        logger.debug('half is %s', half)
        if half == 'Quad':
            logger.debug('running quad code')
            self.encoder = EncoderQuad()
        elif half == 'Double':
            self.encoder = EncoderDouble()
        elif half:
            self.encoder = EncoderHalf()
        else:
            self.encoder = Encoder()

        self.tasks = tasks
        self.ozan = ozan
        self.task_to_decoder = {}
        self.task_to_output_channels = {
            'segment_semantic': 18,
            'depth_zbuffer': 1,
            'normal': 3,
            'normal2': 3,
            'edge_occlusion': 1,
            'reshading': 1,
            'keypoints2d': 1,
            'edge_texture': 1,
            'principal_curvature': 2,
            'rgb': 3,
        }
        if tasks is not None:
            for task in tasks:
                output_channels = self.task_to_output_channels[task]
                self.task_to_decoder[task] = Decoder(output_channels, num_classes)
        else:
            # Respect an explicit num_classes; fall back to the ImageNet-style 1000.
            classes = 1000 if num_classes is None else num_classes
            self.task_to_decoder['classification'] = Decoder(output_channels=0, num_classes=classes)

        # Register the decoders as submodules. The dict above is a plain
        # attribute and is not registered by itself.
        self.decoders = nn.ModuleList(self.task_to_decoder.values())

        # ------- init weights --------
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    # Global call counter for input_per_task_losses, shared by all instances.
    count = 0

    def input_per_task_losses(self, losses):
        """Feed the current per-task losses to GradNormRepFunction.

        For the first 199 calls, the losses are also stored as the initial losses
        and all weights are reset to 1. After that, every 20th call appends the
        current weights to ``gradnorm_weights.txt``.
        """
        Xception.count += 1
        if Xception.count < 200:
            GradNormRepFunction.inital_task_losses = losses
            GradNormRepFunction.current_weights = [1 for _ in losses]
        elif Xception.count % 20 == 0:
            with open("gradnorm_weights.txt", "a") as myfile:
                myfile.write(str(Xception.count) + ': ' + str(GradNormRepFunction.current_weights) + '\n')
        GradNormRepFunction.current_task_losses = losses

    def forward(self, input):
        """Run the encoder and all decoders.

        Returns:
            The classifier output when ``tasks`` is None. Otherwise, a dict
            mapping ``'rep'`` to the raw encoder output and each task name to
            its decoder's output.
        """
        rep = self.encoder(input)
        if self.tasks is None:
            return self.decoders[0](rep)

        outputs = {'rep': rep}
        n_tasks = len(self.decoders)
        if self.ozan == 'gradnorm':
            GradNormRepFunction.n = n_tasks
            rep = gradnorm_rep_function(rep)
            per_task = True  # result is indexed per decoder
        elif self.ozan:
            OzanRepFunction.n = n_tasks
            rep = ozan_rep_function(rep)
            per_task = True
        else:
            TrevorRepFunction.n = n_tasks
            rep = trevor_rep_function(rep)
            per_task = False  # the same result is passed to every decoder

        # Iterate self.decoders (the registered ModuleList) rather than the plain
        # dict values, paired with the dict's task names in insertion order.
        for i, (task, decoder) in enumerate(zip(self.task_to_decoder.keys(), self.decoders)):
            outputs[task] = decoder(rep[i] if per_task else rep)
        return outputs


def _load_pretrained_encoder(model):
    """Load the encoder checkpoint at ``_PRETRAINED_ENCODER_PATH`` into ``model.encoder``.

    NOTE: ``torch.load`` unpickles the file, so only load trusted checkpoints.
    ``map_location='cpu'`` lets GPU-saved checkpoints load on CPU-only hosts;
    ``load_state_dict`` copies the tensors into the existing parameters anyway.
    """
    state_dict = torch.load(_PRETRAINED_ENCODER_PATH, map_location='cpu')
    model.encoder.load_state_dict(state_dict)
    return model


def xception(pretrained=False, **kwargs):
    """Construct an Xception with the default-width encoder.

    Args:
        pretrained: load the pretrained encoder checkpoint.
        **kwargs: forwarded to :class:`Xception`.
    """
    model = Xception(**kwargs)
    if pretrained:
        _load_pretrained_encoder(model)
    return model


def xception_ozan(pretrained=False, **kwargs):
    """Construct a default-width Xception using OzanRepFunction (``ozan=True``)."""
    model = Xception(ozan=True, **kwargs)
    if pretrained:
        _load_pretrained_encoder(model)
    return model


def xception_gradnorm(pretrained=False, **kwargs):
    """Construct a default-width Xception using GradNormRepFunction (``ozan='gradnorm'``)."""
    model = Xception(ozan='gradnorm', **kwargs)
    if pretrained:
        _load_pretrained_encoder(model)
    return model


def xception_half_gradnorm(pretrained=False, **kwargs):
    """Construct a half-width Xception using GradNormRepFunction.

    NOTE: ``pretrained`` is accepted for signature compatibility but ignored;
    no checkpoint is loaded.
    """
    model = Xception(half=True, ozan='gradnorm', **kwargs)
    return model


def xception_half(pretrained=False, **kwargs):
    """Construct an Xception with the half-width encoder."""
    model = Xception(half=True, **kwargs)
    if pretrained:
        _load_pretrained_encoder(model)
    return model


def xception_quad(pretrained=False, **kwargs):
    """Construct an Xception with the 'Quad' (widest) encoder."""
    logger.debug('got quad')
    model = Xception(half='Quad', **kwargs)
    if pretrained:
        _load_pretrained_encoder(model)
    return model


def xception_double(pretrained=False, **kwargs):
    """Construct an Xception with the 'Double' encoder."""
    logger.debug('got double')
    model = Xception(half='Double', **kwargs)
    if pretrained:
        _load_pretrained_encoder(model)
    return model


def xception_quad_ozan(pretrained=False, **kwargs):
    """Construct a 'Quad'-encoder Xception using OzanRepFunction."""
    logger.debug('got quad ozan')
    model = Xception(ozan=True, half='Quad', **kwargs)
    if pretrained:
        _load_pretrained_encoder(model)
    return model


def xception_double_ozan(pretrained=False, **kwargs):
    """Construct a 'Double'-encoder Xception using OzanRepFunction."""
    logger.debug('got double ozan')
    model = Xception(ozan=True, half='Double', **kwargs)
    if pretrained:
        _load_pretrained_encoder(model)
    return model


def xception_half_ozan(pretrained=False, **kwargs):
    """Construct a half-width Xception using OzanRepFunction."""
    model = Xception(ozan=True, half=True, **kwargs)
    if pretrained:
        _load_pretrained_encoder(model)
    return model