# bottomup.py

import copy
import logging
import sys

import numpy as np
import torch

from .evaluators import Evaluator, extract_features
from .exclusive_loss import ExLoss
from .trainers import Trainer
from .utils.transform.transforms import TRANSFORM_VAL_LIST

logger = logging.getLogger(__name__)
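
# BottomUp is one client in a federated, unsupervised person re-identification
# setup (description inferred from this file): the client trains locally on
# pseudo-labels, evaluates with the query/gallery protocol, and periodically
# merges clusters bottom-up by feature similarity via relabel_train_data().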
class BottomUp:
    def __init__(self,
                 cid,
                 model,
                 batch_size,
                 eval_batch_size,
                 num_classes,
                 train_data,
                 test_data,
                 device,
                 embedding_feature_size=2048,
                 initial_epochs=20,
                 local_epochs=2,
                 step_size=16,
                 seed=0):
        self.cid = cid
        self.model = model
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.eval_batch_size = eval_batch_size
        self.device = device
        self.seed = seed
        self.gallery_cam = None
        self.gallery_label = None
        self.query_cam = None
        self.query_label = None
        self.test_gallery_loader = None
        self.test_query_loader = None
        self.train_data = train_data
        self.test_data = test_data
        self.initial_epochs = initial_epochs
        self.local_epochs = local_epochs
        self.step_size = step_size
        self.embedding_feature_size = embedding_feature_size
        self.fixed_layer = False
        self.old_features = None
        self.feature_distance = 0
        self.criterion = ExLoss(self.embedding_feature_size, self.num_classes, t=10).to(device)
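
    # Note (assumption): on the first step the global model object is adopted
    # directly, while later steps copy only the weights into the existing local
    # model, so references to self.model held elsewhere stay valid.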
    def set_model(self, model, current_step):
        if current_step == 0:
            self.model = model.to(self.device)
        else:
            self.model.load_state_dict(model.state_dict())
            self.model = self.model.to(self.device)
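
    # Learning-rate schedule used by train() below: in the initial step
    # (step == 0) training runs for initial_epochs at lr=0.1, decayed 10x
    # every step_size epochs; with the defaults (step_size=16, 20 epochs)
    # that is epochs 0-15 at 0.1 and epochs 16-19 at 0.01. In later steps
    # step_size becomes sys.maxsize, so the lr stays constant at 0.01.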
    def train(self, step, dynamic_epoch=False):
        self.model = self.model.train()
        # adjust training epochs and learning rate
        epochs = self.initial_epochs if step == 0 else self.local_epochs
        init_lr = 0.1 if step == 0 else 0.01
        step_size = self.step_size if step == 0 else sys.maxsize
        logger.info("create train transform loader with batch size {}".format(self.batch_size))
        loader = self.train_data.loader(self.batch_size, self.cid, seed=self.seed, num_workers=6)

        # base parameters of the backbone (e.g. ResNet50)
        base_param_ids = set(map(id, self.model.CNN.base.parameters()))
        # the first three blocks are frozen (requires_grad=False) to save GPU memory
        base_params_need_for_grad = filter(lambda p: p.requires_grad, self.model.CNN.base.parameters())
        # parameters of the newly added layers
        new_params = [p for p in self.model.parameters() if id(p) not in base_param_ids]
        # the backbone learning rate is scaled down to 0.1x
        param_groups = [
            {'params': base_params_need_for_grad, 'lr_mult': 0.1},
            {'params': new_params, 'lr_mult': 1.0}]
        optimizer = torch.optim.SGD(param_groups, lr=init_lr, momentum=0.9, weight_decay=5e-4, nesterov=True)

        # decay the learning rate by 10x every step_size epochs
        def adjust_lr(epoch, step_size):
            lr = init_lr / (10 ** (epoch // step_size))
            for g in optimizer.param_groups:
                g['lr'] = lr * g.get('lr_mult', 1)

        logger.info("number of epochs, {}: {}".format(self.cid, epochs))

        # main training process
        trainer = Trainer(self.model, self.criterion, self.device, fixed_layer=self.fixed_layer)
        for epoch in range(epochs):
            adjust_lr(epoch, step_size)
            stop_local_training = trainer.train(epoch, loader, optimizer, print_freq=max(5, len(loader) // 30 * 10))
            # dynamically decide the number of local epochs based on conditions inside the trainer
            if step > 0 and dynamic_epoch and stop_local_training:
                logger.info("Dynamic epoch: in step {}, stop training {} after epoch {}".format(step, self.cid, epoch))
                break
        return self.model
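
    # Evaluation follows the standard re-ID protocol: rank-1/5/10 are CMC
    # scores (a correct gallery match appears among the top-k results for a
    # query) and mAP is mean average precision over all queries.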
    def evaluate(self, cid, model=None):
        # cid is taken as an argument (instead of self.cid) because of merged training
        if model is None:
            model = self.model
        model = model.eval()
        model = model.to(self.device)

        gallery_id = '{}_{}'.format(cid, 'gallery')
        query_id = '{}_{}'.format(cid, 'query')
        logger.info("create test transform loader with batch size {}".format(self.eval_batch_size))
        gallery_loader = self.test_data.loader(batch_size=self.eval_batch_size,
                                               client_id=gallery_id,
                                               shuffle=False,
                                               num_workers=6)
        query_loader = self.test_data.loader(batch_size=self.eval_batch_size,
                                             client_id=query_id,
                                             shuffle=False,
                                             num_workers=6)
        evaluator = Evaluator(model, self.test_data, query_id, gallery_id, self.device)
        rank1, rank5, rank10, mAP = evaluator.evaluate(query_loader, gallery_loader)
        return rank1, rank5, rank10, mAP
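
    # One round of bottom-up merging: extract features for all local samples,
    # rank candidate pairs by distance plus a cluster-size penalty, merge the
    # closest pairs until nums_to_merge clusters have been absorbed, then
    # rebuild the criterion for the reduced number of classes.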
    # new version of get_new_train_data
    def relabel_train_data(self, device, unlabeled_ys, labeled_ys, nums_to_merge, size_penalty):
        # extract features with the current model
        self.model = self.model.to(device)
        loader = self.train_data.loader(self.batch_size,
                                        self.cid,
                                        shuffle=False,
                                        num_workers=6,
                                        transform=TRANSFORM_VAL_LIST)
        features = extract_features(self.model, loader, device)

        # track how far features drift between rounds via cosine similarity
        if self.old_features:
            similarities = []
            for old_feature, new_feature in zip(self.old_features, features):
                m = torch.cosine_similarity(old_feature, new_feature, dim=0)
                similarities.append(m)
            self.feature_distance = 1 - sum(similarities) / len(similarities)
            logger.info("Cosine distance between features, {}: {}".format(self.cid, self.feature_distance))
        self.old_features = copy.deepcopy(features)

        features = np.array([logit.numpy() for logit in features])
        # group image indices by their current cluster label
        label_to_images = {}
        for idx, l in enumerate(unlabeled_ys):
            label_to_images[l] = label_to_images.get(l, []) + [idx]
        dists = self.calculate_distance(features)
        idx1, idx2 = self.select_merge_data(features, unlabeled_ys, label_to_images, size_penalty, dists)
        unlabeled_ys = self.relabel_new_train_data(idx1, idx2, labeled_ys, unlabeled_ys, nums_to_merge)
        num_classes = len(np.unique(np.array(unlabeled_ys)))

        # rebuild the criterion classifier for the new number of classes
        self.criterion = ExLoss(self.embedding_feature_size, num_classes, t=10).to(device)
        return unlabeled_ys
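
    # Greedy merging: walk candidate pairs in ascending-cost order and union
    # their clusters. `correct` counts merges whose ground-truth labels
    # (labeled_ys) agree, apparently as a cluster-purity diagnostic; it is
    # not used further in this file.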
    def relabel_new_train_data(self, idx1, idx2, labeled_ys, label, num_to_merge):
        correct = 0
        num_before_merge = len(np.unique(np.array(label)))
        # merge clusters with minimum dissimilarity
        for i in range(len(idx1)):
            label1 = label[idx1[i]]
            label2 = label[idx2[i]]
            if label1 < label2:
                label = [label1 if x == label2 else x for x in label]
            else:
                label = [label2 if x == label1 else x for x in label]
            if labeled_ys[idx1[i]] == labeled_ys[idx2[i]]:
                correct += 1
            num_merged = num_before_merge - len(np.unique(np.array(label)))
            if num_merged == num_to_merge:
                break

        # remap the surviving labels to a contiguous range 0..k-1 and write
        # them back to the training data
        unique_label = np.sort(np.unique(np.array(label)))
        for i in range(len(unique_label)):
            label_now = unique_label[i]
            label = [i if x == label_now else x for x in label]
        self.train_data.data[self.cid]['y'] = label

        num_after_merge = len(np.unique(np.array(label)))
        logger.info("num of labels before merge: {}, after merge: {}, merged: {}".format(
            num_before_merge, num_after_merge, num_before_merge - num_after_merge))
        return label
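
    # Pairwise squared Euclidean distances via the identity
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b: the two expanded norm terms
    # form ||x_i||^2 + ||x_j||^2 for every pair, and addmm_ then subtracts
    # 2 * (x @ x.T).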
    def calculate_distance(self, u_feas):
        # pairwise distance matrix between all features
        x = torch.from_numpy(u_feas)
        y = x
        m = len(u_feas)
        dists = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, m) + \
            torch.pow(y, 2).sum(dim=1, keepdim=True).expand(m, m).t()
        # dists -= 2 * x @ y.t(); the positional addmm_(beta, alpha, mat1, mat2)
        # form used originally is deprecated in recent PyTorch
        dists.addmm_(x, y.t(), beta=1, alpha=-2)
        return dists
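
    # Candidate selection: the diagonal and lower triangle are masked with a
    # large constant so each unordered pair is considered once, a size penalty
    # proportional to the two cluster sizes discourages collapsing everything
    # into one large cluster, and pairs already sharing a label are masked
    # out; argsort over the flattened matrix yields pairs by ascending cost.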
    def select_merge_data(self, u_feas, label, label_to_images, ratio_n, dists):
        # mask the diagonal and lower triangle so each pair appears only once
        dists.add_(torch.tril(100000 * torch.ones(len(u_feas), len(u_feas))))
        # penalize candidate pairs by the sizes of the clusters they belong to
        cnt = torch.FloatTensor([len(label_to_images[label[idx]]) for idx in range(len(u_feas))])
        dists += ratio_n * (cnt.view(1, len(cnt)) + cnt.view(len(cnt), 1))
        # never merge samples that already share a cluster label
        for idx in range(len(u_feas)):
            for j in range(idx + 1, len(u_feas)):
                if label[idx] == label[j]:
                    dists[idx, j] = 100000
        dists = dists.numpy()
        ind = np.unravel_index(np.argsort(dists, axis=None), dists.shape)
        idx1 = ind[0]
        idx2 = ind[1]
        return idx1, idx2
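
# Usage sketch (hypothetical; build_model, train_dataset and test_dataset are
# placeholders for objects supplied by the surrounding federated framework and
# are not defined in this file):
#
#   client = BottomUp(cid="client_0",
#                     model=build_model(),
#                     batch_size=16,
#                     eval_batch_size=64,
#                     num_classes=initial_num_clusters,  # e.g. one per sample
#                     train_data=train_dataset,
#                     test_data=test_dataset,
#                     device=torch.device("cuda"))
#   client.train(step=0)  # initial local training
#   rank1, rank5, rank10, mAP = client.evaluate("client_0")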