12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821 |
- """
- Creates an Xception Model as defined in:
- Xception: Deep Learning with Depthwise Separable Convolutions, https://arxiv.org/pdf/1610.02357.pdf
- """
- import math
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from easyfl.models.model import BaseModel
- from .ozan_rep_fun import ozan_rep_function, OzanRepFunction, gradnorm_rep_function, GradNormRepFunction, \
- trevor_rep_function, TrevorRepFunction
# Public factory functions exported by ``from <module> import *``.
__all__ = [
    'xception',
    'xception_gradnorm',
    'xception_half_gradnorm',
    'xception_ozan',
    'xception_half',
    'xception_quad',
    'xception_double',
    'xception_double_ozan',
    'xception_half_ozan',
    'xception_quad_ozan',
]
- # model_urls = {
- # 'xception_taskonomy':'file:///home/tstand/Dropbox/taskonomy/xception_taskonomy-a4b32ef7.pth.tar'
- # }
class SeparableConv2d(nn.Module):
    """Separable convolution: a grouped spatial conv followed by a 1x1 conv.

    ``conv1`` keeps the channel count at ``in_channels`` and splits it into
    ``max(1, in_channels // groupsize)`` groups (fully depthwise when
    ``groupsize == 1``). ``pointwise`` is a plain 1x1 conv that maps the result
    to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False,
                 groupsize=1):
        super(SeparableConv2d, self).__init__()
        n_groups = max(1, in_channels // groupsize)
        self.conv1 = nn.Conv2d(
            in_channels, in_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=n_groups, bias=bias,
        )
        self.pointwise = nn.Conv2d(
            in_channels, out_channels, kernel_size=1,
            stride=1, padding=0, dilation=1, groups=1, bias=bias,
        )

    def forward(self, x):
        return self.pointwise(self.conv1(x))
class Block(nn.Module):
    """Residual unit made of ``reps`` (ReLU -> SeparableConv2d 3x3 -> BatchNorm) triples.

    Args:
        in_filters / out_filters: input and output channel counts.
        reps: number of ReLU/SeparableConv2d/BatchNorm triples.
        strides: when != 1, a Conv2d(kernel=2, stride=2) is appended to the main
            path and the shortcut becomes a strided 1x1 conv + BatchNorm.
        start_with_relu: if False the leading ReLU is dropped; if True it is
            replaced by a non-inplace ReLU so the block input (reused by the
            shortcut) is not overwritten.
        grow_first: change the channel count in the first triple (True) or the
            last one (False).
    """

    def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
        super(Block, self).__init__()
        needs_projection = out_filters != in_filters or strides != 1
        if needs_projection:
            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None
        self.relu = nn.ReLU(inplace=True)

        layers = []

        def add_unit(c_in, c_out):
            # One ReLU -> separable 3x3 conv -> BatchNorm triple.
            layers.append(self.relu)
            layers.append(SeparableConv2d(c_in, c_out, 3, stride=1, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(c_out))

        width = in_filters
        if grow_first:
            add_unit(in_filters, out_filters)
            width = out_filters
        for _ in range(reps - 1):
            add_unit(width, width)
        if not grow_first:
            add_unit(in_filters, out_filters)
            width = out_filters

        if start_with_relu:
            layers[0] = nn.ReLU(inplace=False)
        else:
            del layers[0]

        if strides != 1:
            # Learned 2x2/stride-2 downsampling at the end of the main path.
            layers.append(nn.Conv2d(width, width, 2, 2))
        self.rep = nn.Sequential(*layers)

    def forward(self, inp):
        out = self.rep(inp)
        shortcut = inp if self.skip is None else self.skipbn(self.skip(inp))
        out += shortcut
        return out
class Encoder(nn.Module):
    """Base-width shared feature extractor.

    Structure:
      * stem: Conv(3->24, stride 2) + BN + ReLU, then Conv(24->48) + BN + ReLU;
      * three stride-2 Blocks: 48 -> 96 -> 192 -> 512;
      * four stride-1 512-channel Blocks;
      * SeparableConv2d 512 -> 256, BN, and a non-inplace ReLU.

    The stem conv and the three strided Blocks each halve the spatial size, so
    the 256-channel output is 16x smaller than the input.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(3, 24, 3, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(24)
        self.relu = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv2 = nn.Conv2d(24, 48, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(48)
        # Downsampling stages.
        self.block1 = Block(48, 96, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(96, 192, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(192, 512, 2, 2, start_with_relu=True, grow_first=True)
        # Constant-width residual stages (names kept for checkpoint compatibility).
        self.block8 = Block(512, 512, 2, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(512, 512, 2, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(512, 512, 2, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(512, 512, 2, 1, start_with_relu=True, grow_first=True)
        self.conv3 = SeparableConv2d(512, 256, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(256)

    def forward(self, input):
        x = self.relu(self.bn1(self.conv1(input)))
        x = self.relu(self.bn2(self.conv2(x)))
        for stage in (self.block1, self.block2, self.block3,
                      self.block8, self.block9, self.block10, self.block11):
            x = stage(x)
        x = self.bn3(self.conv3(x))
        # Non-inplace ReLU for the shared representation.
        return self.relu2(x)
class EncoderHalf(nn.Module):
    """Reduced-width variant of the shared feature extractor.

    The layout is the same as the base encoder with narrower stages:
      * stem: Conv(3->24, stride 2) + BN + ReLU, then Conv(24->48) + BN + ReLU;
      * three stride-2 Blocks: 48 -> 64 -> 128 -> 360;
      * four stride-1 360-channel Blocks;
      * SeparableConv2d 360 -> 256, BN, and a non-inplace ReLU.

    The output has 256 channels and is 16x smaller spatially than the input.
    """

    def __init__(self):
        super(EncoderHalf, self).__init__()
        self.conv1 = nn.Conv2d(3, 24, 3, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(24)
        self.relu = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv2 = nn.Conv2d(24, 48, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(48)
        # Downsampling stages.
        self.block1 = Block(48, 64, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(64, 128, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(128, 360, 2, 2, start_with_relu=True, grow_first=True)
        # Constant-width residual stages (names kept for checkpoint compatibility).
        self.block8 = Block(360, 360, 2, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(360, 360, 2, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(360, 360, 2, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(360, 360, 2, 1, start_with_relu=True, grow_first=True)
        self.conv3 = SeparableConv2d(360, 256, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(256)

    def forward(self, input):
        x = self.relu(self.bn1(self.conv1(input)))
        x = self.relu(self.bn2(self.conv2(x)))
        for stage in (self.block1, self.block2, self.block3,
                      self.block8, self.block9, self.block10, self.block11):
            x = stage(x)
        x = self.bn3(self.conv3(x))
        # Non-inplace ReLU for the shared representation.
        return self.relu2(x)
class EncoderQuad(nn.Module):
    """Widest variant of the shared feature extractor.

    Structure:
      * stem: Conv(3->48, stride 2) + BN + ReLU, then Conv(48->96) + BN + ReLU;
      * three stride-2 Blocks: 96 -> 192 -> 384 -> 1024;
      * four stride-1 1024-channel Blocks;
      * SeparableConv2d 1024 -> 256, BN, and a non-inplace ReLU.

    The output has 256 channels and is 16x smaller spatially than the input.
    """

    def __init__(self):
        super(EncoderQuad, self).__init__()
        self.conv1 = nn.Conv2d(3, 48, 3, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(48)
        self.relu = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv2 = nn.Conv2d(48, 96, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(96)
        # Downsampling stages.
        self.block1 = Block(96, 192, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(192, 384, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(384, 1024, 2, 2, start_with_relu=True, grow_first=True)
        # Constant-width residual stages (names kept for checkpoint compatibility).
        self.block8 = Block(1024, 1024, 2, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(1024, 1024, 2, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(1024, 1024, 2, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(1024, 1024, 2, 1, start_with_relu=True, grow_first=True)
        self.conv3 = SeparableConv2d(1024, 256, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(256)

    def forward(self, input):
        x = self.relu(self.bn1(self.conv1(input)))
        x = self.relu(self.bn2(self.conv2(x)))
        for stage in (self.block1, self.block2, self.block3,
                      self.block8, self.block9, self.block10, self.block11):
            x = stage(x)
        x = self.bn3(self.conv3(x))
        # Non-inplace ReLU for the shared representation.
        return self.relu2(x)
class EncoderDouble(nn.Module):
    """Wider variant of the shared feature extractor.

    Structure:
      * stem: Conv(3->32, stride 2) + BN + ReLU, then Conv(32->64) + BN + ReLU;
      * three stride-2 Blocks: 64 -> 128 -> 256 -> 728;
      * four stride-1 728-channel Blocks;
      * SeparableConv2d 728 -> 256, BN, and a non-inplace ReLU.

    The output has 256 channels and is 16x smaller spatially than the input.
    """

    def __init__(self):
        super(EncoderDouble, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv2 = nn.Conv2d(32, 64, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        # Downsampling stages.
        self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True)
        # Constant-width residual stages (names kept for checkpoint compatibility).
        self.block8 = Block(728, 728, 2, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(728, 728, 2, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(728, 728, 2, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(728, 728, 2, 1, start_with_relu=True, grow_first=True)
        self.conv3 = SeparableConv2d(728, 256, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(256)

    def forward(self, input):
        x = self.relu(self.bn1(self.conv1(input)))
        x = self.relu(self.bn2(self.conv2(x)))
        for stage in (self.block1, self.block2, self.block3,
                      self.block8, self.block9, self.block10, self.block11):
            x = stage(x)
        x = self.bn3(self.conv3(x))
        # Non-inplace ReLU for the shared representation.
        return self.relu2(x)
def interpolate(inp, size):
    """Bilinearly resize an NCHW tensor to ``size`` while preserving its dtype.

    The resize always runs in float32, and the result is then cast back to the
    input's floating-point dtype (float16 stays float16, float64 stays float64).
    Non-floating-point inputs yield a float32 result.

    Args:
        inp: tensor to resize (bilinear mode requires 4-D input).
        size: target ``(height, width)``.
    """
    out = nn.functional.interpolate(inp.float(), size=size, mode='bilinear', align_corners=False)
    if inp.is_floating_point():
        # The old code always downcast to half here, corrupting float64 inputs.
        out = out.to(inp.dtype)
    return out
class Decoder(nn.Module):
    """Per-task head applied to the 256-channel encoder representation.

    * ``num_classes`` given: global average pooling followed by a linear layer
      that produces ``num_classes`` logits.
    * ``num_classes`` is None: a dense decoder. It upsamples the representation
      16x using four stride-2 transposed convolutions and concatenates a
      16-channel skip branch resized to the same resolution. It then predicts
      ``output_channels`` maps.

    Args:
        output_channels: channels of the dense prediction (unused for classification).
        num_classes: number of classes; selects the classification head when not None.
    """

    def __init__(self, output_channels=32, num_classes=None):
        super(Decoder, self).__init__()
        self.output_channels = output_channels
        self.num_classes = num_classes
        if num_classes is not None:
            self.fc = nn.Linear(256, num_classes)
        else:
            self.relu = nn.ReLU(inplace=True)
            # Skip-branch (16 ch) and main-branch (96 ch) projections of the representation.
            self.conv_decode_res = SeparableConv2d(256, 16, 3, padding=1)
            self.conv_decode_res2 = SeparableConv2d(256, 96, 3, padding=1)
            self.bn_conv_decode_res = nn.BatchNorm2d(16)
            self.bn_conv_decode_res2 = nn.BatchNorm2d(96)
            # Four x2 upsampling stages.
            self.upconv1 = nn.ConvTranspose2d(96, 96, 2, 2)
            self.bn_upconv1 = nn.BatchNorm2d(96)
            self.conv_decode1 = SeparableConv2d(96, 64, 3, padding=1)
            self.bn_decode1 = nn.BatchNorm2d(64)
            self.upconv2 = nn.ConvTranspose2d(64, 64, 2, 2)
            self.bn_upconv2 = nn.BatchNorm2d(64)
            self.conv_decode2 = SeparableConv2d(64, 64, 5, padding=2)
            self.bn_decode2 = nn.BatchNorm2d(64)
            self.upconv3 = nn.ConvTranspose2d(64, 32, 2, 2)
            self.bn_upconv3 = nn.BatchNorm2d(32)
            self.conv_decode3 = SeparableConv2d(32, 32, 5, padding=2)
            self.bn_decode3 = nn.BatchNorm2d(32)
            self.upconv4 = nn.ConvTranspose2d(32, 32, 2, 2)
            self.bn_upconv4 = nn.BatchNorm2d(32)
            # 32 upsampled channels + 16 skip channels = 48 inputs.
            self.conv_decode4 = SeparableConv2d(48, output_channels, 5, padding=2)

    def forward(self, representation):
        if self.num_classes is not None:
            x = F.adaptive_avg_pool2d(representation, (1, 1))
            x = x.view(x.size(0), -1)
            return self.fc(x)

        x = self.bn_conv_decode_res2(self.conv_decode_res2(representation))
        x = self.relu(self.bn_upconv1(self.upconv1(x)))
        x = self.relu(self.bn_decode1(self.conv_decode1(x)))
        x = self.relu(self.bn_upconv2(self.upconv2(x)))
        x = self.relu(self.bn_decode2(self.conv_decode2(x)))
        x = self.relu(self.bn_upconv3(self.upconv3(x)))
        x = self.relu(self.bn_decode3(self.conv_decode3(x)))
        x = self.bn_upconv4(self.upconv4(x))

        skip = self.bn_conv_decode_res(self.conv_decode_res(representation))
        # Resize the skip branch to the resolution the main path actually reached,
        # rather than a fixed 256x256 (which only matched 256x256 inputs).
        skip = interpolate(skip, size=tuple(x.shape[-2:]))
        x = self.relu(torch.cat([x, skip], 1))
        return self.conv_decode4(x)
class Xception(BaseModel):
    """Multi-task Xception-style network: a shared encoder plus one decoder per task.

    The architecture loosely follows https://arxiv.org/pdf/1610.02357.pdf, but
    it uses a reduced encoder whose width is selected by ``half``.
    """

    # Number of ``input_per_task_losses`` calls. It is always accessed as
    # ``Xception.count``, so it is shared by all instances.
    count = 0

    def __init__(self, tasks=None, num_classes=None, ozan=False, half=False):
        """Build the encoder and the per-task decoders.

        Args:
            tasks: iterable of task names (keys of ``task_to_output_channels``).
                If None, a single 1000-class classification decoder is built
                under the key ``'classification'``.
            num_classes: forwarded to every task decoder.
            ozan: ``'gradnorm'`` uses ``gradnorm_rep_function``; any other truthy
                value uses ``ozan_rep_function``; falsy uses ``trevor_rep_function``.
            half: encoder selector. ``'Quad'`` -> EncoderQuad, ``'Double'`` ->
                EncoderDouble, other truthy -> EncoderHalf, falsy -> Encoder.

        Raises:
            KeyError: if a task name is not in ``task_to_output_channels``.
        """
        super(Xception, self).__init__()
        if half == 'Quad':
            self.encoder = EncoderQuad()
        elif half == 'Double':
            self.encoder = EncoderDouble()
        elif half:
            self.encoder = EncoderHalf()
        else:
            self.encoder = Encoder()
        self.tasks = tasks
        self.ozan = ozan
        self.task_to_decoder = {}
        self.task_to_output_channels = {
            'segment_semantic': 18,
            'depth_zbuffer': 1,
            'normal': 3,
            'normal2': 3,
            'edge_occlusion': 1,
            'reshading': 1,
            'keypoints2d': 1,
            'edge_texture': 1,
            'principal_curvature': 2,
            'rgb': 3,
        }
        if tasks is not None:
            for task in tasks:
                if task not in self.task_to_output_channels:
                    raise KeyError('Unknown task {!r}; supported tasks are {}'.format(
                        task, sorted(self.task_to_output_channels)))
                output_channels = self.task_to_output_channels[task]
                self.task_to_decoder[task] = Decoder(output_channels, num_classes)
        else:
            self.task_to_decoder['classification'] = Decoder(output_channels=0, num_classes=1000)
        # ModuleList holding the same decoder objects, in task order.
        self.decoders = nn.ModuleList(self.task_to_decoder.values())

        # ------- init weights --------
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-style normal init scaled by fan-out (k_h * k_w * out_channels).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def input_per_task_losses(self, losses):
        """Feed the latest per-task losses to GradNormRepFunction.

        During the first 199 calls, ``losses`` is stored as the initial task
        losses and all task weights are reset to 1. After that, every 20th call
        appends the current weights to ``gradnorm_weights.txt``. Every call sets
        ``GradNormRepFunction.current_task_losses``.
        """
        Xception.count += 1
        if Xception.count < 200:
            # Attribute name (sic) must match GradNormRepFunction.
            GradNormRepFunction.inital_task_losses = losses
            GradNormRepFunction.current_weights = [1 for _ in losses]
        elif Xception.count % 20 == 0:
            with open("gradnorm_weights.txt", "a") as myfile:
                myfile.write(str(Xception.count) + ': ' + str(GradNormRepFunction.current_weights) + '\n')
        GradNormRepFunction.current_task_losses = losses

    def forward(self, input):
        """Encode ``input`` and run every task decoder.

        Returns the classification decoder's output when ``tasks`` is None.
        Otherwise returns a dict with the shared representation under ``'rep'``
        plus one output per task.
        """
        rep = self.encoder(input)
        if self.tasks is None:
            return self.decoders[0](rep)
        outputs = {'rep': rep}
        n_decoders = len(self.decoders)
        if self.ozan == 'gradnorm':
            GradNormRepFunction.n = n_decoders
            rep = gradnorm_rep_function(rep)
            per_task = True  # result is indexed per decoder
        elif self.ozan:
            OzanRepFunction.n = n_decoders
            rep = ozan_rep_function(rep)
            per_task = True
        else:
            TrevorRepFunction.n = n_decoders
            rep = trevor_rep_function(rep)
            per_task = False  # every decoder receives the same tensor
        for i, (task, decoder) in enumerate(zip(self.task_to_decoder, self.decoders)):
            outputs[task] = decoder(rep[i] if per_task else rep)
        return outputs
def xception(pretrained=False, **kwargs):
    """Construct an ``Xception`` model (base-width encoder unless ``half`` is passed).

    Args:
        pretrained: if True, load encoder weights from the file
            ``xception_taskonomy_small2.encoder.pth.tar`` (relative to the
            current working directory).
        **kwargs: forwarded to ``Xception``.
    """
    model = Xception(**kwargs)
    if pretrained:
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts;
        # load_state_dict then copies the values into the model's own tensors.
        state_dict = torch.load('xception_taskonomy_small2.encoder.pth.tar', map_location='cpu')
        model.encoder.load_state_dict(state_dict)
    return model
def xception_ozan(pretrained=False, **kwargs):
    """Construct an ``Xception`` model that uses the ``ozan_rep_function`` path.

    Args:
        pretrained: if True, load encoder weights from
            ``xception_taskonomy_small2.encoder.pth.tar``.
        **kwargs: forwarded to ``Xception``.
    """
    model = Xception(ozan=True, **kwargs)
    if pretrained:
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts.
        state_dict = torch.load('xception_taskonomy_small2.encoder.pth.tar', map_location='cpu')
        model.encoder.load_state_dict(state_dict)
    return model
def xception_gradnorm(pretrained=False, **kwargs):
    """Construct an ``Xception`` model that uses the GradNorm representation path.

    Args:
        pretrained: if True, load encoder weights from
            ``xception_taskonomy_small2.encoder.pth.tar``.
        **kwargs: forwarded to ``Xception``.
    """
    model = Xception(ozan='gradnorm', **kwargs)
    if pretrained:
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts.
        state_dict = torch.load('xception_taskonomy_small2.encoder.pth.tar', map_location='cpu')
        model.encoder.load_state_dict(state_dict)
    return model
def xception_half_gradnorm(pretrained=False, **kwargs):
    """Construct a half-width ``Xception`` that uses the GradNorm representation path.

    ``pretrained`` is accepted for signature parity with the other factories,
    but this factory never loads weights.
    """
    return Xception(ozan='gradnorm', half=True, **kwargs)
def xception_half(pretrained=False, **kwargs):
    """Construct an ``Xception`` model with the half-width encoder (``EncoderHalf``).

    Args:
        pretrained: if True, load encoder weights from
            ``xception_taskonomy_small2.encoder.pth.tar``.
        **kwargs: forwarded to ``Xception``.
    """
    model = Xception(half=True, **kwargs)
    if pretrained:
        # NOTE(review): this is the same file the base-width ``xception`` factory
        # loads, but EncoderHalf has different channel widths. A strict
        # load_state_dict will raise on size mismatch unless the file really
        # holds EncoderHalf weights; confirm which checkpoint is intended.
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts.
        state_dict = torch.load('xception_taskonomy_small2.encoder.pth.tar', map_location='cpu')
        model.encoder.load_state_dict(state_dict)
    return model
def xception_quad(pretrained=False, **kwargs):
    """Construct an ``Xception`` model with the widest encoder (``EncoderQuad``).

    Args:
        pretrained: if True, load encoder weights from
            ``xception_taskonomy_small2.encoder.pth.tar``.
        **kwargs: forwarded to ``Xception``.
    """
    model = Xception(half='Quad', **kwargs)
    if pretrained:
        # NOTE(review): this is the same file the base-width ``xception`` factory
        # loads, but EncoderQuad has different channel widths. A strict
        # load_state_dict will raise on size mismatch unless the file really
        # holds EncoderQuad weights; confirm which checkpoint is intended.
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts.
        state_dict = torch.load('xception_taskonomy_small2.encoder.pth.tar', map_location='cpu')
        model.encoder.load_state_dict(state_dict)
    return model
def xception_double(pretrained=False, **kwargs):
    """Construct an ``Xception`` model with the ``EncoderDouble`` encoder.

    Args:
        pretrained: if True, load encoder weights from
            ``xception_taskonomy_small2.encoder.pth.tar``.
        **kwargs: forwarded to ``Xception``.
    """
    model = Xception(half='Double', **kwargs)
    if pretrained:
        # NOTE(review): this is the same file the base-width ``xception`` factory
        # loads, but EncoderDouble has different channel widths. A strict
        # load_state_dict will raise on size mismatch unless the file really
        # holds EncoderDouble weights; confirm which checkpoint is intended.
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts.
        state_dict = torch.load('xception_taskonomy_small2.encoder.pth.tar', map_location='cpu')
        model.encoder.load_state_dict(state_dict)
    return model
def xception_quad_ozan(pretrained=False, **kwargs):
    """Construct an ``EncoderQuad``-based ``Xception`` that uses the ``ozan_rep_function`` path.

    Args:
        pretrained: if True, load encoder weights from
            ``xception_taskonomy_small2.encoder.pth.tar``.
        **kwargs: forwarded to ``Xception``.
    """
    model = Xception(ozan=True, half='Quad', **kwargs)
    if pretrained:
        # NOTE(review): this is the same file the base-width ``xception`` factory
        # loads, but EncoderQuad has different channel widths. A strict
        # load_state_dict will raise on size mismatch unless the file really
        # holds EncoderQuad weights; confirm which checkpoint is intended.
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts.
        state_dict = torch.load('xception_taskonomy_small2.encoder.pth.tar', map_location='cpu')
        model.encoder.load_state_dict(state_dict)
    return model
def xception_double_ozan(pretrained=False, **kwargs):
    """Construct an ``EncoderDouble``-based ``Xception`` that uses the ``ozan_rep_function`` path.

    Args:
        pretrained: if True, load encoder weights from
            ``xception_taskonomy_small2.encoder.pth.tar``.
        **kwargs: forwarded to ``Xception``.
    """
    model = Xception(ozan=True, half='Double', **kwargs)
    if pretrained:
        # NOTE(review): this is the same file the base-width ``xception`` factory
        # loads, but EncoderDouble has different channel widths. A strict
        # load_state_dict will raise on size mismatch unless the file really
        # holds EncoderDouble weights; confirm which checkpoint is intended.
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts.
        state_dict = torch.load('xception_taskonomy_small2.encoder.pth.tar', map_location='cpu')
        model.encoder.load_state_dict(state_dict)
    return model
def xception_half_ozan(pretrained=False, **kwargs):
    """Construct a half-width ``Xception`` that uses the ``ozan_rep_function`` path.

    Args:
        pretrained: if True, load encoder weights from
            ``xception_taskonomy_small2.encoder.pth.tar``.
        **kwargs: forwarded to ``Xception``.
    """
    model = Xception(ozan=True, half=True, **kwargs)
    if pretrained:
        # NOTE(review): this is the same file the base-width ``xception`` factory
        # loads, but EncoderHalf has different channel widths. A strict
        # load_state_dict will raise on size mismatch unless the file really
        # holds EncoderHalf weights; confirm which checkpoint is intended.
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts.
        state_dict = torch.load('xception_taskonomy_small2.encoder.pth.tar', map_location='cpu')
        model.encoder.load_state_dict(state_dict)
    return model
|