losses.py

import collections

import torch

# Per-task running loss values. Each loss function stores its most recent
# result in one of these module-level globals so the reporting lambdas built
# in get_losses() can read it back out.
sl = 0   # segment_semantic
nl = 0   # normal
nl2 = 0  # normal (unused alternative)
nl3 = 0  # normal2
dl = 0   # depth_zbuffer
el = 0   # edge_occlusion
rl = 0   # reshading
kl = 0   # keypoints2d
tl = 0   # edge_texture
al = 0   # rgb (autoencoder)
cl = 0   # principal_curvature

popular_offsets = collections.defaultdict(int)
batch_number = 0

# Single-character task codes -> dataset target names.
TASKS = {
    's': 'segment_semantic',
    'd': 'depth_zbuffer',
    'n': 'normal',
    'N': 'normal2',
    'k': 'keypoints2d',
    'e': 'edge_occlusion',
    'r': 'reshading',
    't': 'edge_texture',
    'a': 'rgb',
    'c': 'principal_curvature',
}

# Loss display names -> single-character task codes.
LOSSES = {
    "ss_l": 's',
    "edge2d_l": 't',
    "depth_l": 'd',
    "norm_l": 'n',
    "key_l": 'k',
    "edge_l": 'e',
    "shade_l": 'r',
    "rgb_l": 'a',
    "pc_l": 'c',
}


def parse_tasks(task_str):
    """Expand a task string such as 'sdn' into a list of task names."""
    tasks = []
    for char in task_str:
        tasks.append(TASKS[char])
    return tasks


def parse_loss_names(loss_names):
    """Map loss display names back to their single-character task codes."""
    tasks = []
    for l in loss_names:
        tasks.append(LOSSES[l])
    return tasks
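

def _demo_parse_tasks():
    # Hypothetical example, not part of the original file: shows how the
    # task-code and loss-name tables above are expected to round-trip.
    assert parse_tasks('sdn') == ['segment_semantic', 'depth_zbuffer', 'normal']
    assert parse_loss_names(['ss_l', 'depth_l', 'norm_l']) == ['s', 'd', 'n']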
def segment_semantic_loss(output, target, mask):
    global sl
    sl = torch.nn.functional.cross_entropy(output.float(), target.long().squeeze(dim=1),
                                           ignore_index=0, reduction='mean')
    return sl


def normal_loss(output, target, mask):
    global nl
    nl = rotate_loss(output, target, mask, normal_loss_base)
    return nl


def normal_loss_simple(output, target, mask):
    global nl
    out = torch.nn.functional.l1_loss(output, target, reduction='none')
    out *= mask.float()
    nl = out.mean()
    return nl


def rotate_loss(output, target, mask, loss_name):
    # Shift-tolerant loss: crop the target and mask by one pixel, evaluate the
    # base loss for the nine one-pixel shifts of the prediction, and keep the
    # per-image minimum so slight spatial misalignment is not penalised.
    global popular_offsets
    target = target[:, :, 1:-1, 1:-1].float()
    mask = mask[:, :, 1:-1, 1:-1].float()
    output = output.float()

    val1 = loss = loss_name(output[:, :, 1:-1, 1:-1], target, mask)
    val2 = loss_name(output[:, :, 0:-2, 1:-1], target, mask)
    loss = torch.min(loss, val2)
    val3 = loss_name(output[:, :, 1:-1, 0:-2], target, mask)
    loss = torch.min(loss, val3)
    val4 = loss_name(output[:, :, 2:, 1:-1], target, mask)
    loss = torch.min(loss, val4)
    val5 = loss_name(output[:, :, 1:-1, 2:], target, mask)
    loss = torch.min(loss, val5)
    val6 = loss_name(output[:, :, 0:-2, 0:-2], target, mask)
    loss = torch.min(loss, val6)
    val7 = loss_name(output[:, :, 2:, 2:], target, mask)
    loss = torch.min(loss, val7)
    val8 = loss_name(output[:, :, 0:-2, 2:], target, mask)
    loss = torch.min(loss, val8)
    val9 = loss_name(output[:, :, 2:, 0:-2], target, mask)
    loss = torch.min(loss, val9)

    # lst = [val1, val2, val3, val4, val5, val6, val7, val8, val9]
    # print(loss.size())
    loss = loss.mean()
    # print(loss)
    return loss
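

def _demo_rotate_loss():
    # Hypothetical sketch, not part of the original file: the tensor shapes and
    # the choice of normal_loss_base as the base loss are assumptions.
    output = torch.rand(2, 3, 16, 16)
    target = torch.rand(2, 3, 16, 16)
    mask = torch.ones(2, 1, 16, 16)
    # Each shifted crop is 14x14; the returned value is the mean over the batch
    # of the per-image minimum across the nine shifts.
    print(float(rotate_loss(output, target, mask, normal_loss_base)))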
def normal_loss_base(output, target, mask):
    out = torch.nn.functional.l1_loss(output, target, reduction='none')
    out *= mask
    out = out.mean(dim=(1, 2, 3))
    return out


def normal2_loss(output, target, mask):
    global nl3
    diff = output.float() - target.float()
    out = torch.abs(diff)
    out = out * mask.float()
    nl3 = out.mean()
    return nl3


def depth_loss_simple(output, target, mask):
    global dl
    out = torch.nn.functional.l1_loss(output, target, reduction='none')
    out *= mask.float()
    dl = out.mean()
    return dl


def depth_loss(output, target, mask):
    global dl
    dl = rotate_loss(output, target, mask, depth_loss_base)
    return dl


def depth_loss_base(output, target, mask):
    out = torch.nn.functional.l1_loss(output, target, reduction='none')
    out *= mask.float()
    out = out.mean(dim=(1, 2, 3))
    return out


def edge_loss_simple(output, target, mask):
    global el
    out = torch.nn.functional.l1_loss(output, target, reduction='none')
    out *= mask
    el = out.mean()
    return el


def reshade_loss(output, target, mask):
    global rl
    out = torch.nn.functional.l1_loss(output, target, reduction='none')
    out *= mask
    rl = out.mean()
    return rl


def keypoints2d_loss(output, target, mask):
    global kl
    kl = torch.nn.functional.l1_loss(output, target)
    return kl


def edge2d_loss(output, target, mask):
    global tl
    tl = torch.nn.functional.l1_loss(output, target)
    return tl


def auto_loss(output, target, mask):
    global al
    al = torch.nn.functional.l1_loss(output, target)
    return al


def pc_loss(output, target, mask):
    global cl
    out = torch.nn.functional.l1_loss(output, target, reduction='none')
    out *= mask
    cl = out.mean()
    return cl


def edge_loss(output, target, mask):
    # Currently identical to edge_loss_simple.
    global el
    out = torch.nn.functional.l1_loss(output, target, reduction='none')
    out *= mask
    el = out.mean()
    return el


def get_taskonomy_loss(losses):
    def taskonomy_loss(output, target):
        # Sum the per-task losses over every target that has a registered loss.
        if 'mask' in target:
            mask = target['mask']
        else:
            mask = None
        sum_loss = None
        num = 0
        for n, t in target.items():
            if n in losses:
                o = output[n].float()
                this_loss = losses[n](o, t, mask)
                num += 1
                if sum_loss:
                    sum_loss = sum_loss + this_loss
                else:
                    sum_loss = this_loss
        return sum_loss  # /num  # should not take average when using xception_taskonomy_new

    return taskonomy_loss
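

def _demo_taskonomy_loss():
    # Hypothetical sketch, not part of the original file: the tensor shapes and
    # the task choice (depth + reshading) are assumptions; it shows how the
    # combined loss looks up a per-task loss by target name and sums the results.
    losses = {'depth_zbuffer': depth_loss_simple, 'reshading': reshade_loss}
    combined = get_taskonomy_loss(losses)
    output = {'depth_zbuffer': torch.rand(2, 1, 8, 8),
              'reshading': torch.rand(2, 1, 8, 8)}
    target = {'depth_zbuffer': torch.rand(2, 1, 8, 8),
              'reshading': torch.rand(2, 1, 8, 8),
              'mask': torch.ones(2, 1, 8, 8)}
    print(float(combined(output, target)))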
def get_losses(task_str, is_rotate_loss, task_weights=None):
    # Build the per-task loss table and the reporting criteria for the tasks
    # named in task_str (one character per task, see TASKS above).
    losses = {}
    criteria = {}
    if 's' in task_str:
        losses['segment_semantic'] = segment_semantic_loss
        criteria['ss_l'] = lambda x, y: sl
    if 'd' in task_str:
        if not is_rotate_loss:
            losses['depth_zbuffer'] = depth_loss_simple
        else:
            losses['depth_zbuffer'] = depth_loss
        criteria['depth_l'] = lambda x, y: dl
    if 'n' in task_str:
        if not is_rotate_loss:
            losses['normal'] = normal_loss_simple
        else:
            losses['normal'] = normal_loss
        criteria['norm_l'] = lambda x, y: nl
        # criteria['norm_l2'] = lambda x, y: nl2
    if 'N' in task_str:
        losses['normal2'] = normal2_loss
        criteria['norm2'] = lambda x, y: nl3
    if 'k' in task_str:
        losses['keypoints2d'] = keypoints2d_loss
        criteria['key_l'] = lambda x, y: kl
    if 'e' in task_str:
        if not is_rotate_loss:
            losses['edge_occlusion'] = edge_loss_simple
        else:
            losses['edge_occlusion'] = edge_loss
        # losses['edge_occlusion'] = edge_loss
        criteria['edge_l'] = lambda x, y: el
    if 'r' in task_str:
        losses['reshading'] = reshade_loss
        criteria['shade_l'] = lambda x, y: rl
    if 't' in task_str:
        losses['edge_texture'] = edge2d_loss
        criteria['edge2d_l'] = lambda x, y: tl
    if 'a' in task_str:
        losses['rgb'] = auto_loss
        criteria['rgb_l'] = lambda x, y: al
    if 'c' in task_str:
        losses['principal_curvature'] = pc_loss
        criteria['pc_l'] = lambda x, y: cl

    if task_weights:
        # Scale each task's loss and its reported criterion by the matching
        # weight (a comma-separated string, one weight per task, in the order
        # the losses were added above).
        weights = [float(x) for x in task_weights.split(',')]
        for l, w, c in zip(losses.items(), weights, criteria.items()):
            losses[l[0]] = lambda x, y, z, l=l[1], w=w: l(x, y, z) * w
            criteria[c[0]] = lambda x, y, c=c[1], w=w: c(x, y) * w

    taskonomy_loss = get_taskonomy_loss(losses)
    criteria2 = {'Loss': taskonomy_loss}
    for key, value in criteria.items():
        criteria2[key] = value
    criteria = criteria2
    return criteria
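

if __name__ == '__main__':
    # Hypothetical smoke test, not part of the original file: the task string
    # 'dr' (depth + reshading) and the random tensor shapes are assumptions.
    criteria = get_losses('dr', is_rotate_loss=False)
    output = {'depth_zbuffer': torch.rand(2, 1, 8, 8),
              'reshading': torch.rand(2, 1, 8, 8)}
    target = {'depth_zbuffer': torch.rand(2, 1, 8, 8),
              'reshading': torch.rand(2, 1, 8, 8),
              'mask': torch.ones(2, 1, 8, 8)}
    # 'Loss' is evaluated first, which updates the per-task globals before the
    # reporting criteria read them.
    for name, criterion in criteria.items():
        print(name, float(criterion(output, target)))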