123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295 |
- import collections
- import torch
# Most recent (unweighted) value of each task loss. The loss functions in this
# module overwrite these via `global`, and the per-task criteria built by
# `get_losses` read them back for reporting.
sl = nl = nl2 = nl3 = dl = el = rl = kl = tl = al = cl = 0

# NOTE(review): not updated by any function in this section of the file;
# presumably used elsewhere — confirm before removing.
popular_offsets = collections.defaultdict(int)
# NOTE(review): not referenced in this section of the file — confirm usage.
batch_number = 0
# Single-character task codes (as used in a task string for `parse_tasks` and
# `get_losses`) mapped to the task name. The task name is also the key used in
# the loss dict assembled by `get_losses`.
TASKS = {
    's': 'segment_semantic',
    'd': 'depth_zbuffer',
    'n': 'normal',
    'N': 'normal2',
    'k': 'keypoints2d',
    'e': 'edge_occlusion',
    'r': 'reshading',
    't': 'edge_texture',
    'a': 'rgb',
    'c': 'principal_curvature',
}
# Per-task reporting-criterion names (the keys `get_losses` puts next to
# 'Loss') mapped back to their single-character task codes (keys of `TASKS`).
LOSSES = {
    "ss_l": 's',
    "edge2d_l": 't',
    "depth_l": 'd',
    "norm_l": 'n',
    "key_l": 'k',
    "edge_l": 'e',
    "shade_l": 'r',
    "rgb_l": 'a',
    "pc_l": 'c',
    # `get_losses` registers the 'N' (normal2) criterion under 'norm2'; without
    # this entry `parse_loss_names` raises KeyError for that name. Appended last
    # so the iteration order of the existing keys is unchanged.
    "norm2": 'N',
}
def parse_tasks(task_str):
    """Translate a string of task codes into the list of task names.

    Each character is looked up in ``TASKS``; order and duplicates are kept.

    Raises:
        KeyError: if a character is not a key of ``TASKS``.
    """
    return [TASKS[code] for code in task_str]
def parse_loss_names(loss_names):
    """Translate criterion names into their task codes via ``LOSSES``.

    Order and duplicates of ``loss_names`` are kept.

    Raises:
        KeyError: if a name is not a key of ``LOSSES``.
    """
    return [LOSSES[name] for name in loss_names]
def segment_semantic_loss(output, target, mask):
    """Cross-entropy loss for semantic segmentation.

    ``output`` is cast to float and used as class logits; ``target`` is cast to
    long and its dim 1 is squeezed to give class indices. Pixels labelled 0 are
    ignored, and the loss is averaged over the remaining pixels. ``mask`` is
    accepted for interface compatibility and not used. The result is also
    stored in the module global ``sl``.
    """
    global sl
    logits = output.float()
    labels = target.long().squeeze(dim=1)
    sl = torch.nn.functional.cross_entropy(
        logits, labels, ignore_index=0, reduction='mean'
    )
    return sl
def normal_loss(output, target, mask):
    """Shift-tolerant masked L1 loss for surface normals.

    Delegates to ``rotate_loss`` with ``normal_loss_base`` as the per-sample
    loss. The result is also stored in the module global ``nl``.
    """
    global nl
    shift_tolerant = rotate_loss(output, target, mask, normal_loss_base)
    nl = shift_tolerant
    return shift_tolerant
def normal_loss_simple(output, target, mask):
    """Masked mean absolute error for surface normals (no shift search).

    The per-element L1 error is multiplied in place by ``mask`` (cast to float)
    and averaged over all elements, so masked-out elements still count in the
    denominator. The result is also stored in the module global ``nl``.
    """
    global nl
    abs_err = torch.nn.functional.l1_loss(output, target, reduction='none')
    abs_err.mul_(mask.float())
    nl = abs_err.mean()
    return nl
def rotate_loss(output, target, mask, loss_name):
    """Shift-tolerant loss: the best of nine one-pixel translations of ``output``.

    A 1-pixel border is cropped from ``target`` and ``mask``. ``output`` is then
    cropped nine ways: unshifted, plus every shift of at most one pixel along
    dims 2 and 3. ``loss_name`` is evaluated on each crop, and the element-wise
    minimum of the nine results is averaged and returned. Despite the name, no
    rotation is performed.

    Args:
        output: prediction tensor, indexed as 4-D with spatial dims 2 and 3.
        target: ground-truth tensor, indexed the same way.
        mask: tensor indexed the same way; converted to float.
        loss_name: despite the name, a callable
            ``fn(output_crop, target_crop, mask_crop)``. With
            ``normal_loss_base`` / ``depth_loss_base`` it returns one value per
            sample, so the minimum is taken independently for each sample.

    Returns:
        Scalar tensor: mean of the element-wise minimum over the nine crops.
    """
    target = target[:, :, 1:-1, 1:-1].float()
    mask = mask[:, :, 1:-1, 1:-1].float()
    output = output.float()

    # Each slice keeps (size - 2) entries of a spatial dim: starting one pixel
    # early, centred, or one pixel late relative to the cropped target.
    lead, mid, lag = slice(0, -2), slice(1, -1), slice(2, None)
    crops = (
        (mid, mid),
        (lead, mid),
        (mid, lead),
        (lag, mid),
        (mid, lag),
        (lead, lead),
        (lag, lag),
        (lead, lag),
        (lag, lead),
    )

    loss = None
    for rows, cols in crops:
        candidate = loss_name(output[:, :, rows, cols], target, mask)
        loss = candidate if loss is None else torch.min(loss, candidate)
    return loss.mean()
def normal_loss_base(output, target, mask):
    """Per-sample masked mean absolute error (used by ``normal_loss``).

    The element-wise L1 error is multiplied in place by ``mask`` (no dtype
    conversion here) and averaged over dims 1, 2 and 3, leaving one value per
    entry of dim 0.
    """
    abs_err = torch.nn.functional.l1_loss(output, target, reduction='none')
    abs_err.mul_(mask)
    return abs_err.mean(dim=(1, 2, 3))
def normal2_loss(output, target, mask):
    """Masked mean absolute difference for the 'normal2' task.

    Both tensors are cast to float. The absolute difference is weighted by
    ``mask`` (cast to float) and averaged over all elements. The result is also
    stored in the module global ``nl3``.
    """
    global nl3
    weighted = (output.float() - target.float()).abs().mul(mask.float())
    nl3 = weighted.mean()
    return nl3
def depth_loss_simple(output, target, mask):
    """Masked mean absolute error for depth (no shift search).

    The per-element L1 error is multiplied in place by ``mask`` (cast to float)
    and averaged over all elements, so masked-out elements still count in the
    denominator. The result is also stored in the module global ``dl``.
    """
    global dl
    abs_err = torch.nn.functional.l1_loss(output, target, reduction='none')
    abs_err.mul_(mask.float())
    dl = abs_err.mean()
    return dl
def depth_loss(output, target, mask):
    """Shift-tolerant masked L1 loss for depth.

    Delegates to ``rotate_loss`` with ``depth_loss_base`` as the per-sample
    loss. The result is also stored in the module global ``dl``.
    """
    global dl
    shift_tolerant = rotate_loss(output, target, mask, depth_loss_base)
    dl = shift_tolerant
    return shift_tolerant
def depth_loss_base(output, target, mask):
    """Per-sample masked mean absolute error (used by ``depth_loss``).

    The element-wise L1 error is multiplied in place by ``mask`` (cast to
    float) and averaged over dims 1, 2 and 3, leaving one value per entry of
    dim 0.
    """
    abs_err = torch.nn.functional.l1_loss(output, target, reduction='none')
    abs_err.mul_(mask.float())
    return abs_err.mean(dim=(1, 2, 3))
def edge_loss_simple(output, target, mask):
    """Masked mean absolute error for occlusion edges.

    The per-element L1 error is multiplied in place by ``mask`` and averaged
    over all elements, so masked-out elements still count in the denominator.
    The result is also stored in the module global ``el``.
    """
    global el
    abs_err = torch.nn.functional.l1_loss(output, target, reduction='none')
    abs_err.mul_(mask)
    el = abs_err.mean()
    return el
def reshade_loss(output, target, mask):
    """Masked mean absolute error for the reshading task.

    The per-element L1 error is multiplied in place by ``mask`` and averaged
    over all elements. The result is also stored in the module global ``rl``.
    """
    global rl
    abs_err = torch.nn.functional.l1_loss(output, target, reduction='none')
    abs_err.mul_(mask)
    rl = abs_err.mean()
    return rl
def keypoints2d_loss(output, target, mask):
    """Mean absolute error for 2-D keypoints.

    ``mask`` is accepted for interface compatibility and not used. The result
    is also stored in the module global ``kl``.
    """
    global kl
    kl = torch.nn.functional.l1_loss(output, target, reduction='mean')
    return kl
def edge2d_loss(output, target, mask):
    """Mean absolute error for texture edges.

    ``mask`` is accepted for interface compatibility and not used. The result
    is also stored in the module global ``tl``.
    """
    global tl
    tl = torch.nn.functional.l1_loss(output, target, reduction='mean')
    return tl
def auto_loss(output, target, mask):
    """Mean absolute error for RGB reconstruction.

    ``mask`` is accepted for interface compatibility and not used. The result
    is also stored in the module global ``al``.
    """
    global al
    al = torch.nn.functional.l1_loss(output, target, reduction='mean')
    return al
def pc_loss(output, target, mask):
    """Masked mean absolute error for principal curvature.

    The per-element L1 error is multiplied in place by ``mask`` and averaged
    over all elements. The result is also stored in the module global ``cl``.
    """
    global cl
    abs_err = torch.nn.functional.l1_loss(output, target, reduction='none')
    abs_err.mul_(mask)
    cl = abs_err.mean()
    return cl
def edge_loss(output, target, mask):
    """Masked mean absolute error for occlusion edges.

    ``get_losses`` selects this variant when ``is_rotate_loss`` is true. It
    performs no shift search: the per-element L1 error is multiplied in place
    by ``mask`` and averaged over all elements. The result is also stored in
    the module global ``el``.
    """
    global el
    abs_err = torch.nn.functional.l1_loss(output, target, reduction='none')
    abs_err.mul_(mask)
    el = abs_err.mean()
    return el
def get_taskonomy_loss(losses):
    """Build a multi-task loss that sums the per-task losses present in a batch.

    Args:
        losses: dict mapping a target key (e.g. ``'depth_zbuffer'``) to a
            callable ``fn(output, target, mask)`` returning that task's loss.

    Returns:
        A function ``taskonomy_loss(output, target)``, where ``output`` and
        ``target`` are dicts keyed like ``losses``:

        - ``target['mask']``, if present, is passed as ``mask`` to every loss
          function; otherwise ``None`` is passed.
        - For every key of ``target`` that is also in ``losses`` (in
          ``target`` order), ``output[key]`` is cast to float and scored.
        - The result is the sum of those losses, or ``None`` when no key
          matches.
    """
    def taskonomy_loss(output, target):
        mask = target['mask'] if 'mask' in target else None
        sum_loss = None
        for name, task_target in target.items():
            if name not in losses:
                continue
            this_loss = losses[name](output[name].float(), task_target, mask)
            # Compare with None instead of testing tensor truthiness. bool() on
            # a tensor raises for multi-element losses, forces a device sync,
            # and would silently drop an accumulated loss (and its gradient)
            # whose value happens to be exactly 0.
            sum_loss = this_loss if sum_loss is None else sum_loss + this_loss
        return sum_loss
    return taskonomy_loss
def get_losses(task_str, is_rotate_loss, task_weights=None):
    """Build the combined training loss and the per-task reporting criteria.

    Args:
        task_str: string of task codes (keys of ``TASKS``), e.g. ``'sdn'``.
            Only membership is checked: characters without a branch below are
            ignored, and the order of characters does not matter.
        is_rotate_loss: if true, depth and normals use the shift-tolerant
            ``depth_loss`` / ``normal_loss``, and occlusion edges use
            ``edge_loss``; otherwise the ``*_simple`` variants are used.
        task_weights: optional comma-separated floats, e.g. ``'1,0.5'``. The
            weights are paired with the selected tasks in the fixed order
            s, d, n, N, k, e, r, t, a, c (NOT the order of ``task_str``). Tasks
            beyond the number of weights stay unweighted; surplus weights are
            ignored.

    Returns:
        dict whose first entry is ``'Loss'`` -> ``fn(output, target)`` (the
        summed loss from ``get_taskonomy_loss``). It is followed by one entry
        per selected task (``'ss_l'``, ``'depth_l'``, ``'norm_l'``, ``'norm2'``,
        ...). Each such entry is ``fn(x, y)`` returning the module-level value
        last stored by that task's loss function, times its weight.

    Raises:
        ValueError: if an entry of ``task_weights`` is not a valid float.
    """
    losses = {}
    criteria = {}
    if 's' in task_str:
        losses['segment_semantic'] = segment_semantic_loss
        criteria['ss_l'] = lambda x, y: sl
    if 'd' in task_str:
        if not is_rotate_loss:
            losses['depth_zbuffer'] = depth_loss_simple
        else:
            losses['depth_zbuffer'] = depth_loss
        criteria['depth_l'] = lambda x, y: dl
    if 'n' in task_str:
        if not is_rotate_loss:
            losses['normal'] = normal_loss_simple
        else:
            losses['normal'] = normal_loss
        criteria['norm_l'] = lambda x, y: nl
    if 'N' in task_str:
        losses['normal2'] = normal2_loss
        criteria['norm2'] = lambda x, y: nl3
    if 'k' in task_str:
        losses['keypoints2d'] = keypoints2d_loss
        criteria['key_l'] = lambda x, y: kl
    if 'e' in task_str:
        if not is_rotate_loss:
            losses['edge_occlusion'] = edge_loss_simple
        else:
            losses['edge_occlusion'] = edge_loss
        criteria['edge_l'] = lambda x, y: el
    if 'r' in task_str:
        losses['reshading'] = reshade_loss
        criteria['shade_l'] = lambda x, y: rl
    if 't' in task_str:
        losses['edge_texture'] = edge2d_loss
        criteria['edge2d_l'] = lambda x, y: tl
    if 'a' in task_str:
        losses['rgb'] = auto_loss
        criteria['rgb_l'] = lambda x, y: al
    if 'c' in task_str:
        losses['principal_curvature'] = pc_loss
        criteria['pc_l'] = lambda x, y: cl

    if task_weights:
        weights = [float(x) for x in task_weights.split(',')]
        # `losses` and `criteria` were filled in lock-step above, so the i-th
        # items of each belong to the same task. Snapshot the items so the
        # dicts are not iterated while their values are being rebound.
        paired = zip(list(losses.items()), weights, list(criteria.items()))
        for (loss_key, loss_fn), weight, (crit_key, crit_fn) in paired:
            # Default arguments capture the current fn/weight; a plain closure
            # would late-bind and make every wrapper use the last loop values.
            losses[loss_key] = (
                lambda x, y, z, fn=loss_fn, w=weight: fn(x, y, z) * w
            )
            criteria[crit_key] = lambda x, y, fn=crit_fn, w=weight: fn(x, y) * w

    return {'Loss': get_taskonomy_loss(losses), **criteria}
|