  1. """
  2. Creates an Xception Model as defined in:
  3. Xception: Deep Learning with Depthwise Separable Convolutions, https://arxiv.org/pdf/1610.02357.pdf
  4. """
  5. import math
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. from easyfl.models.model import BaseModel
  10. from .ozan_rep_fun import ozan_rep_function, OzanRepFunction, gradnorm_rep_function, GradNormRepFunction, \
  11. trevor_rep_function, TrevorRepFunction
__all__ = ['xception',
           'xception_gradnorm',
           'xception_half_gradnorm',
           'xception_ozan',
           'xception_half',
           'xception_quad',
           'xception_double',
           'xception_double_ozan',
           'xception_half_ozan',
           'xception_quad_ozan']
# model_urls = {
#     'xception_taskonomy': 'file:///home/tstand/Dropbox/taskonomy/xception_taskonomy-a4b32ef7.pth.tar'
# }


class SeparableConv2d(nn.Module):
    """Depthwise separable convolution: a per-channel (depthwise) convolution
    followed by a 1x1 pointwise convolution that mixes channels."""

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False,
                 groupsize=1):
        super(SeparableConv2d, self).__init__()
        # groupsize=1 makes the first convolution fully depthwise (groups == in_channels).
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation,
                               groups=max(1, in_channels // groupsize), bias=bias)
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x
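

# Illustrative sketch (not part of the original file): with the default
# groupsize=1, a 3x3 SeparableConv2d uses far fewer parameters than a dense
# 3x3 convolution with the same channel counts.
def _separable_conv_param_demo(in_channels=64, out_channels=128):
    sep = SeparableConv2d(in_channels, out_channels, kernel_size=3, padding=1)
    dense = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)

    def count(m):
        return sum(p.numel() for p in m.parameters())

    # For 64 -> 128: 64*3*3 + 64*128 = 8768 separable vs 64*128*3*3 = 73728 dense.
    return count(sep), count(dense)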


class Block(nn.Module):
    """Xception-style residual block: a stack of `reps` separable convolutions,
    with a strided 1x1 projection on the skip branch when shapes change."""

    def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
        super(Block, self).__init__()
        if out_filters != in_filters or strides != 1:
            # Project the skip branch so it matches the main branch in shape.
            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None
        self.relu = nn.ReLU(inplace=True)
        rep = []
        filters = in_filters
        if grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters
        for i in range(reps - 1):
            rep.append(self.relu)
            rep.append(SeparableConv2d(filters, filters, 3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(filters))
        if not grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters
        if not start_with_relu:
            rep = rep[1:]
        else:
            # The first activation sees the block input, which the skip branch
            # still needs, so it must not run in place.
            rep[0] = nn.ReLU(inplace=False)
        if strides != 1:
            # rep.append(nn.AvgPool2d(3, strides, 1))
            rep.append(nn.Conv2d(filters, filters, 2, 2))
        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.rep(inp)
        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp
        x += skip
        return x
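

# Illustrative sketch (not part of the original file): a stride-2 Block halves
# the spatial resolution (via the trailing 2x2/stride-2 conv) while the 1x1
# strided skip projection keeps the residual sum shape-compatible.
def _block_shape_demo():
    block = Block(48, 96, reps=2, strides=2, start_with_relu=False, grow_first=True)
    y = block(torch.randn(1, 48, 64, 64))
    return y.shape  # expected: torch.Size([1, 96, 32, 32])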


class Encoder(nn.Module):
    """Xception-style encoder producing a 256-channel representation."""

    def __init__(self):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(3, 24, 3, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(24)
        self.relu = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv2 = nn.Conv2d(24, 48, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(48)
        # do relu here
        self.block1 = Block(48, 96, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(96, 192, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(192, 512, 2, 2, start_with_relu=True, grow_first=True)
        # self.block4 = Block(768, 768, 3, 1, start_with_relu=True, grow_first=True)
        # self.block5 = Block(768, 768, 3, 1, start_with_relu=True, grow_first=True)
        # self.block6 = Block(768, 768, 3, 1, start_with_relu=True, grow_first=True)
        # self.block7 = Block(768, 768, 3, 1, start_with_relu=True, grow_first=True)
        self.block8 = Block(512, 512, 2, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(512, 512, 2, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(512, 512, 2, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(512, 512, 2, 1, start_with_relu=True, grow_first=True)
        # self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)
        self.conv3 = SeparableConv2d(512, 256, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(256)
        # self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        # self.bn3 = nn.BatchNorm2d(1536)
        # do relu here
        # self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        # self.bn4 = nn.BatchNorm2d(2048)

    def forward(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        # x = self.block4(x)
        # x = self.block5(x)
        # x = self.block6(x)
        # x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        # x = self.block12(x)
        x = self.conv3(x)
        x = self.bn3(x)
        # x = self.relu(x)
        # x = self.conv4(x)
        # x = self.bn4(x)
        representation = self.relu2(x)
        return representation
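

# Illustrative sketch (not part of the original file): the encoder downsamples
# by 16x overall (conv1 at stride 2 plus three stride-2 blocks), so a
# 3x256x256 input yields a 256x16x16 representation.
def _encoder_shape_demo():
    enc = Encoder().eval()
    rep = enc(torch.randn(1, 3, 256, 256))
    return rep.shape  # expected: torch.Size([1, 256, 16, 16])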


class EncoderHalf(nn.Module):
    """Narrow Encoder variant with roughly half the parameters
    (top width 360 vs 512; (360/512)^2 ≈ 0.49)."""

    def __init__(self):
        super(EncoderHalf, self).__init__()
        self.conv1 = nn.Conv2d(3, 24, 3, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(24)
        self.relu = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv2 = nn.Conv2d(24, 48, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(48)
        # do relu here
        self.block1 = Block(48, 64, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(64, 128, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(128, 360, 2, 2, start_with_relu=True, grow_first=True)
        # self.block4 = Block(768, 768, 3, 1, start_with_relu=True, grow_first=True)
        # self.block5 = Block(768, 768, 3, 1, start_with_relu=True, grow_first=True)
        # self.block6 = Block(768, 768, 3, 1, start_with_relu=True, grow_first=True)
        # self.block7 = Block(768, 768, 3, 1, start_with_relu=True, grow_first=True)
        self.block8 = Block(360, 360, 2, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(360, 360, 2, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(360, 360, 2, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(360, 360, 2, 1, start_with_relu=True, grow_first=True)
        # self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)
        self.conv3 = SeparableConv2d(360, 256, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(256)
        # self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        # self.bn3 = nn.BatchNorm2d(1536)
        # do relu here
        # self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        # self.bn4 = nn.BatchNorm2d(2048)

    def forward(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        # x = self.block4(x)
        # x = self.block5(x)
        # x = self.block6(x)
        # x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        # x = self.block12(x)
        x = self.conv3(x)
        x = self.bn3(x)
        # x = self.relu(x)
        # x = self.conv4(x)
        # x = self.bn4(x)
        representation = self.relu2(x)
        return representation


class EncoderQuad(nn.Module):
    """Wide Encoder variant with roughly 4x the parameters
    (top width 1024 vs 512; (1024/512)^2 = 4)."""

    def __init__(self):
        super(EncoderQuad, self).__init__()
        print('entering quad constructor')
        self.conv1 = nn.Conv2d(3, 48, 3, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(48)
        self.relu = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv2 = nn.Conv2d(48, 96, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(96)
        # do relu here
        self.block1 = Block(96, 192, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(192, 384, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(384, 1024, 2, 2, start_with_relu=True, grow_first=True)
        # self.block4 = Block(768, 768, 3, 1, start_with_relu=True, grow_first=True)
        # self.block5 = Block(768, 768, 3, 1, start_with_relu=True, grow_first=True)
        # self.block6 = Block(768, 768, 3, 1, start_with_relu=True, grow_first=True)
        # self.block7 = Block(768, 768, 3, 1, start_with_relu=True, grow_first=True)
        self.block8 = Block(1024, 1024, 2, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(1024, 1024, 2, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(1024, 1024, 2, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(1024, 1024, 2, 1, start_with_relu=True, grow_first=True)
        # self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)
        self.conv3 = SeparableConv2d(1024, 256, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(256)
        # self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        # self.bn3 = nn.BatchNorm2d(1536)
        # do relu here
        # self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        # self.bn4 = nn.BatchNorm2d(2048)

    def forward(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        # x = self.block4(x)
        # x = self.block5(x)
        # x = self.block6(x)
        # x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        # x = self.block12(x)
        x = self.conv3(x)
        x = self.bn3(x)
        # x = self.relu(x)
        # x = self.conv4(x)
        # x = self.bn4(x)
        representation = self.relu2(x)
        return representation


class EncoderDouble(nn.Module):
    """Wide Encoder variant with roughly 2x the parameters
    (top width 728 vs 512; (728/512)^2 ≈ 2)."""

    def __init__(self):
        super(EncoderDouble, self).__init__()
        print('entering double constructor')
        self.conv1 = nn.Conv2d(3, 32, 3, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv2 = nn.Conv2d(32, 64, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        # do relu here
        self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True)
        # self.block4 = Block(768, 768, 3, 1, start_with_relu=True, grow_first=True)
        # self.block5 = Block(768, 768, 3, 1, start_with_relu=True, grow_first=True)
        # self.block6 = Block(768, 768, 3, 1, start_with_relu=True, grow_first=True)
        # self.block7 = Block(768, 768, 3, 1, start_with_relu=True, grow_first=True)
        self.block8 = Block(728, 728, 2, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(728, 728, 2, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(728, 728, 2, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(728, 728, 2, 1, start_with_relu=True, grow_first=True)
        # self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)
        self.conv3 = SeparableConv2d(728, 256, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(256)
        # self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        # self.bn3 = nn.BatchNorm2d(1536)
        # do relu here
        # self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        # self.bn4 = nn.BatchNorm2d(2048)

    def forward(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        # x = self.block4(x)
        # x = self.block5(x)
        # x = self.block6(x)
        # x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        # x = self.block12(x)
        x = self.conv3(x)
        x = self.bn3(x)
        # x = self.relu(x)
        # x = self.conv4(x)
        # x = self.bn4(x)
        representation = self.relu2(x)
        return representation


def interpolate(inp, size):
    """Bilinear resize that tolerates half-precision input: the interpolation
    runs in float32 and the result is cast back to half if the input was half."""
    t = inp.type()
    inp = inp.float()
    out = nn.functional.interpolate(inp, size=size, mode='bilinear', align_corners=False)
    if out.type() != t:
        out = out.half()
    return out


class Decoder(nn.Module):
    """Per-task head. With num_classes set it is a global-pool + linear
    classifier; otherwise it upsamples the 256-channel representation back to
    a full-resolution map with `output_channels` channels."""

    def __init__(self, output_channels=32, num_classes=None):
        super(Decoder, self).__init__()
        self.output_channels = output_channels
        self.num_classes = num_classes
        if num_classes is not None:
            self.fc = nn.Linear(256, num_classes)
        # else:
        #     self.fc = nn.Linear(256, 1000)
        else:
            self.relu = nn.ReLU(inplace=True)
            self.conv_decode_res = SeparableConv2d(256, 16, 3, padding=1)
            self.conv_decode_res2 = SeparableConv2d(256, 96, 3, padding=1)
            self.bn_conv_decode_res = nn.BatchNorm2d(16)
            self.bn_conv_decode_res2 = nn.BatchNorm2d(96)
            self.upconv1 = nn.ConvTranspose2d(96, 96, 2, 2)
            self.bn_upconv1 = nn.BatchNorm2d(96)
            self.conv_decode1 = SeparableConv2d(96, 64, 3, padding=1)
            self.bn_decode1 = nn.BatchNorm2d(64)
            self.upconv2 = nn.ConvTranspose2d(64, 64, 2, 2)
            self.bn_upconv2 = nn.BatchNorm2d(64)
            self.conv_decode2 = SeparableConv2d(64, 64, 5, padding=2)
            self.bn_decode2 = nn.BatchNorm2d(64)
            self.upconv3 = nn.ConvTranspose2d(64, 32, 2, 2)
            self.bn_upconv3 = nn.BatchNorm2d(32)
            self.conv_decode3 = SeparableConv2d(32, 32, 5, padding=2)
            self.bn_decode3 = nn.BatchNorm2d(32)
            self.upconv4 = nn.ConvTranspose2d(32, 32, 2, 2)
            self.bn_upconv4 = nn.BatchNorm2d(32)
            # 48 input channels: 32 from the main branch + 16 from the residual branch.
            self.conv_decode4 = SeparableConv2d(48, output_channels, 5, padding=2)

    def forward(self, representation):
        # batch_size = representation.shape[0]
        if self.num_classes is None:
            # Residual branch: a cheap 16-channel projection, upsampled straight
            # to the (hard-coded) 256x256 output size.
            x2 = self.conv_decode_res(representation)
            x2 = self.bn_conv_decode_res(x2)
            x2 = interpolate(x2, size=(256, 256))
            # Main branch: four 2x transposed-conv upsampling stages.
            x = self.conv_decode_res2(representation)
            x = self.bn_conv_decode_res2(x)
            x = self.upconv1(x)
            x = self.bn_upconv1(x)
            x = self.relu(x)
            x = self.conv_decode1(x)
            x = self.bn_decode1(x)
            x = self.relu(x)
            x = self.upconv2(x)
            x = self.bn_upconv2(x)
            x = self.relu(x)
            x = self.conv_decode2(x)
            x = self.bn_decode2(x)
            x = self.relu(x)
            x = self.upconv3(x)
            x = self.bn_upconv3(x)
            x = self.relu(x)
            x = self.conv_decode3(x)
            x = self.bn_decode3(x)
            x = self.relu(x)
            x = self.upconv4(x)
            x = self.bn_upconv4(x)
            x = torch.cat([x, x2], 1)
            # print(x.shape, self.static.shape)
            # x = torch.cat([x, x2, input, self.static.expand(batch_size, -1, -1, -1)], 1)
            x = self.relu(x)
            x = self.conv_decode4(x)
            # z = x[:, 19:22, :, :].clone()
            # y = (z).norm(2, 1, True).clamp(min=1e-12)
            # print(y.shape, x[:, 21:24, :, :].shape)
            # x[:, 19:22, :, :] = z / y
        else:
            # Classification head: global average pool, flatten, linear.
            # print(representation.shape)
            x = F.adaptive_avg_pool2d(representation, (1, 1))
            x = x.view(x.size(0), -1)
            # print(x.shape)
            x = self.fc(x)
            # print(x.shape)
        return x
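

# Illustrative sketch (not part of the original file): the two Decoder modes on
# a 16x16 representation (four 2x upsamplings bring 16 back to 256).
def _decoder_mode_demo():
    rep = torch.randn(1, 256, 16, 16)
    logits = Decoder(num_classes=10)(rep)           # torch.Size([1, 10])
    dense = Decoder(output_channels=3).eval()(rep)  # torch.Size([1, 3, 256, 256])
    return logits.shape, dense.shape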


class Xception(BaseModel):
    """
    Xception optimized for the ImageNet dataset, as specified in
    https://arxiv.org/pdf/1610.02357.pdf
    """

    def __init__(self, tasks=None, num_classes=None, ozan=False, half=False):
        """ Constructor
        Args:
            tasks: list of task names; None builds a single 1000-way classifier
            num_classes: number of classes passed to each task decoder (None for dense outputs)
            ozan: False, True, or 'gradnorm'; selects the rep-routing function in forward()
            half: False, True, 'Double', or 'Quad'; selects the encoder width
        """
        super(Xception, self).__init__()
        print('half is', half)
        if half == 'Quad':
            print('running quad code')
            self.encoder = EncoderQuad()
        elif half == 'Double':
            self.encoder = EncoderDouble()
        elif half:
            self.encoder = EncoderHalf()
        else:
            self.encoder = Encoder()
        self.tasks = tasks
        self.ozan = ozan
        self.task_to_decoder = {}
        self.task_to_output_channels = {
            'segment_semantic': 18,
            'depth_zbuffer': 1,
            'normal': 3,
            'normal2': 3,
            'edge_occlusion': 1,
            'reshading': 1,
            'keypoints2d': 1,
            'edge_texture': 1,
            'principal_curvature': 2,
            'rgb': 3,
        }
        if tasks is not None:
            for task in tasks:
                output_channels = self.task_to_output_channels[task]
                decoder = Decoder(output_channels, num_classes)
                self.task_to_decoder[task] = decoder
        else:
            self.task_to_decoder['classification'] = Decoder(output_channels=0, num_classes=1000)
        self.decoders = nn.ModuleList(self.task_to_decoder.values())

        # ------- init weights --------
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        # -----------------------------

    # Class-level step counter used by the GradNorm bookkeeping below.
    count = 0

    def input_per_task_losses(self, losses):
        # if GradNormRepFunction.inital_task_losses is None:
        #     GradNormRepFunction.inital_task_losses = losses
        #     GradNormRepFunction.current_weights = [1 for i in losses]
        Xception.count += 1
        if Xception.count < 200:
            # Warm-up: keep refreshing the "initial" task losses for the first
            # 200 calls before the GradNorm weighting settles in.
            GradNormRepFunction.inital_task_losses = losses
            GradNormRepFunction.current_weights = [1 for i in losses]
        elif Xception.count % 20 == 0:
            # Periodically log the current GradNorm task weights.
            with open("gradnorm_weights.txt", "a") as myfile:
                myfile.write(str(Xception.count) + ': ' + str(GradNormRepFunction.current_weights) + '\n')
        GradNormRepFunction.current_task_losses = losses

    def forward(self, input):
        rep = self.encoder(input)
        if self.tasks is None:
            return self.decoders[0](rep)
        outputs = {'rep': rep}
        if self.ozan == 'gradnorm':
            # GradNorm routing: the rep function gives each decoder its own
            # gradient-reweighted view of the shared representation.
            GradNormRepFunction.n = len(self.decoders)
            rep = gradnorm_rep_function(rep)
            for i, (task, decoder) in enumerate(zip(self.task_to_decoder.keys(), self.decoders)):
                outputs[task] = decoder(rep[i])
        elif self.ozan:
            OzanRepFunction.n = len(self.decoders)
            rep = ozan_rep_function(rep)
            for i, (task, decoder) in enumerate(zip(self.task_to_decoder.keys(), self.decoders)):
                outputs[task] = decoder(rep[i])
        else:
            # Default path: all decoders share the same representation tensor.
            TrevorRepFunction.n = len(self.decoders)
            rep = trevor_rep_function(rep)
            for i, (task, decoder) in enumerate(zip(self.task_to_decoder.keys(), self.decoders)):
                outputs[task] = decoder(rep)
        # Original loss
        # for i, (task, decoder) in enumerate(zip(self.task_to_decoder.keys(), self.decoders)):
        #     outputs[task] = decoder(rep)
        return outputs
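

# Illustrative sketch (not part of the original file): with tasks=None the model
# is a plain 1000-way classifier; with a task list it returns a dict holding the
# shared representation under 'rep' plus one full-resolution output per task.
# Assumes trevor_rep_function passes the tensor through in the forward pass, as
# the forward() code above implies.
def _xception_forward_demo():
    clf = Xception(tasks=None).eval()
    logits = clf(torch.randn(1, 3, 256, 256))  # torch.Size([1, 1000])
    multi = Xception(tasks=['depth_zbuffer', 'normal']).eval()
    outputs = multi(torch.randn(1, 3, 256, 256))
    return logits.shape, sorted(outputs.keys())  # ['depth_zbuffer', 'normal', 'rep']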


def xception(pretrained=False, **kwargs):
    """
    Construct Xception.
    """
    model = Xception(**kwargs)
    if pretrained:
        # state_dict = model_zoo.load_url(model_urls['xception_taskonomy'])
        # for name, weight in state_dict.items():
        #     if 'pointwise' in name:
        #         state_dict[name] = weight.unsqueeze(-1).unsqueeze(-1)
        #     if 'conv1' in name and len(weight.shape) != 4:
        #         state_dict[name] = weight.unsqueeze(1)
        # model.load_state_dict(state_dict)
        # model.load_state_dict(torch.load('xception_taskonomy_small_imagenet_pretrained.pth.tar'))
        model.encoder.load_state_dict(torch.load('xception_taskonomy_small2.encoder.pth.tar'))
        # if num_classes != 1000:
        #     model.fc = nn.Linear(2048, num_classes)
        # import torch
        # print("writing new state dict")
        # torch.save(model.state_dict(), "xception.pth.tar")
        # print("done")
        # import sys
        # sys.exit(1)
    return model


def xception_ozan(pretrained=False, **kwargs):
    """
    Construct Xception with ozan rep routing.
    """
    model = Xception(ozan=True, **kwargs)
    if pretrained:
        # model.load_state_dict(torch.load('xception_taskonomy_small_imagenet_pretrained.pth.tar'))
        model.encoder.load_state_dict(torch.load('xception_taskonomy_small2.encoder.pth.tar'))
    return model


def xception_gradnorm(pretrained=False, **kwargs):
    """
    Construct Xception with GradNorm rep routing.
    """
    model = Xception(ozan='gradnorm', **kwargs)
    if pretrained:
        # model.load_state_dict(torch.load('xception_taskonomy_small_imagenet_pretrained.pth.tar'))
        model.encoder.load_state_dict(torch.load('xception_taskonomy_small2.encoder.pth.tar'))
    return model


def xception_half_gradnorm(pretrained=False, **kwargs):
    """
    Construct half-width Xception with GradNorm rep routing.
    """
    # Note: no pretrained weights are wired up for this variant; the
    # `pretrained` argument is accepted but ignored.
    model = Xception(half=True, ozan='gradnorm', **kwargs)
    return model


def xception_half(pretrained=False, **kwargs):
    """
    Construct half-width Xception.
    """
    # try:
    #     num_classes = kwargs['num_classes']
    # except:
    #     num_classes = 1000
    # if pretrained:
    #     kwargs['num_classes'] = 1000
    model = Xception(half=True, **kwargs)
    if pretrained:
        # state_dict = model_zoo.load_url(model_urls['xception_taskonomy'])
        # for name, weight in state_dict.items():
        #     if 'pointwise' in name:
        #         state_dict[name] = weight.unsqueeze(-1).unsqueeze(-1)
        #     if 'conv1' in name and len(weight.shape) != 4:
        #         state_dict[name] = weight.unsqueeze(1)
        # model.load_state_dict(state_dict)
        # model.load_state_dict(torch.load('xception_taskonomy_small_imagenet_pretrained.pth.tar'))
        model.encoder.load_state_dict(torch.load('xception_taskonomy_small2.encoder.pth.tar'))
        # if num_classes != 1000:
        #     model.fc = nn.Linear(2048, num_classes)
        # import torch
        # print("writing new state dict")
        # torch.save(model.state_dict(), "xception.pth.tar")
        # print("done")
        # import sys
        # sys.exit(1)
    return model


def xception_quad(pretrained=False, **kwargs):
    """
    Construct quad-width Xception.
    """
    # try:
    #     num_classes = kwargs['num_classes']
    # except:
    #     num_classes = 1000
    # if pretrained:
    #     kwargs['num_classes'] = 1000
    print('got quad')
    model = Xception(half='Quad', **kwargs)
    if pretrained:
        # state_dict = model_zoo.load_url(model_urls['xception_taskonomy'])
        # for name, weight in state_dict.items():
        #     if 'pointwise' in name:
        #         state_dict[name] = weight.unsqueeze(-1).unsqueeze(-1)
        #     if 'conv1' in name and len(weight.shape) != 4:
        #         state_dict[name] = weight.unsqueeze(1)
        # model.load_state_dict(state_dict)
        # model.load_state_dict(torch.load('xception_taskonomy_small_imagenet_pretrained.pth.tar'))
        model.encoder.load_state_dict(torch.load('xception_taskonomy_small2.encoder.pth.tar'))
        # if num_classes != 1000:
        #     model.fc = nn.Linear(2048, num_classes)
        # import torch
        # print("writing new state dict")
        # torch.save(model.state_dict(), "xception.pth.tar")
        # print("done")
        # import sys
        # sys.exit(1)
    return model


def xception_double(pretrained=False, **kwargs):
    """
    Construct double-width Xception.
    """
    # try:
    #     num_classes = kwargs['num_classes']
    # except:
    #     num_classes = 1000
    # if pretrained:
    #     kwargs['num_classes'] = 1000
    print('got double')
    model = Xception(half='Double', **kwargs)
    if pretrained:
        # state_dict = model_zoo.load_url(model_urls['xception_taskonomy'])
        # for name, weight in state_dict.items():
        #     if 'pointwise' in name:
        #         state_dict[name] = weight.unsqueeze(-1).unsqueeze(-1)
        #     if 'conv1' in name and len(weight.shape) != 4:
        #         state_dict[name] = weight.unsqueeze(1)
        # model.load_state_dict(state_dict)
        # model.load_state_dict(torch.load('xception_taskonomy_small_imagenet_pretrained.pth.tar'))
        model.encoder.load_state_dict(torch.load('xception_taskonomy_small2.encoder.pth.tar'))
        # if num_classes != 1000:
        #     model.fc = nn.Linear(2048, num_classes)
        # import torch
        # print("writing new state dict")
        # torch.save(model.state_dict(), "xception.pth.tar")
        # print("done")
        # import sys
        # sys.exit(1)
    return model


def xception_quad_ozan(pretrained=False, **kwargs):
    """
    Construct quad-width Xception with ozan rep routing.
    """
    # try:
    #     num_classes = kwargs['num_classes']
    # except:
    #     num_classes = 1000
    # if pretrained:
    #     kwargs['num_classes'] = 1000
    print('got quad ozan')
    model = Xception(ozan=True, half='Quad', **kwargs)
    if pretrained:
        # state_dict = model_zoo.load_url(model_urls['xception_taskonomy'])
        # for name, weight in state_dict.items():
        #     if 'pointwise' in name:
        #         state_dict[name] = weight.unsqueeze(-1).unsqueeze(-1)
        #     if 'conv1' in name and len(weight.shape) != 4:
        #         state_dict[name] = weight.unsqueeze(1)
        # model.load_state_dict(state_dict)
        # model.load_state_dict(torch.load('xception_taskonomy_small_imagenet_pretrained.pth.tar'))
        model.encoder.load_state_dict(torch.load('xception_taskonomy_small2.encoder.pth.tar'))
        # if num_classes != 1000:
        #     model.fc = nn.Linear(2048, num_classes)
        # import torch
        # print("writing new state dict")
        # torch.save(model.state_dict(), "xception.pth.tar")
        # print("done")
        # import sys
        # sys.exit(1)
    return model


def xception_double_ozan(pretrained=False, **kwargs):
    """
    Construct double-width Xception with ozan rep routing.
    """
    # try:
    #     num_classes = kwargs['num_classes']
    # except:
    #     num_classes = 1000
    # if pretrained:
    #     kwargs['num_classes'] = 1000
    print('got double ozan')
    model = Xception(ozan=True, half='Double', **kwargs)
    if pretrained:
        # state_dict = model_zoo.load_url(model_urls['xception_taskonomy'])
        # for name, weight in state_dict.items():
        #     if 'pointwise' in name:
        #         state_dict[name] = weight.unsqueeze(-1).unsqueeze(-1)
        #     if 'conv1' in name and len(weight.shape) != 4:
        #         state_dict[name] = weight.unsqueeze(1)
        # model.load_state_dict(state_dict)
        # model.load_state_dict(torch.load('xception_taskonomy_small_imagenet_pretrained.pth.tar'))
        model.encoder.load_state_dict(torch.load('xception_taskonomy_small2.encoder.pth.tar'))
        # if num_classes != 1000:
        #     model.fc = nn.Linear(2048, num_classes)
        # import torch
        # print("writing new state dict")
        # torch.save(model.state_dict(), "xception.pth.tar")
        # print("done")
        # import sys
        # sys.exit(1)
    return model


def xception_half_ozan(pretrained=False, **kwargs):
    """
    Construct half-width Xception with ozan rep routing.
    """
    model = Xception(ozan=True, half=True, **kwargs)
    if pretrained:
        # model.load_state_dict(torch.load('xception_taskonomy_small_imagenet_pretrained.pth.tar'))
        model.encoder.load_state_dict(torch.load('xception_taskonomy_small2.encoder.pth.tar'))
    return model
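

# Illustrative usage sketch (not part of the original file). The pretrained
# paths above are local .pth.tar files, so pretrained=True only works where
# those checkpoints exist; pretrained=False always works. Backward through the
# shared representation assumes the trevor/ozan rep functions behave as custom
# autograd functions, as their use in Xception.forward() implies.
def _factory_usage_demo():
    model = xception(tasks=['segment_semantic', 'depth_zbuffer'])
    out = model(torch.randn(2, 3, 256, 256))
    loss = out['depth_zbuffer'].abs().mean()  # toy loss, for illustration only
    loss.backward()
    return {task: tensor.shape for task, tensor in out.items()}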