model.py

import copy

import torch
import torch.nn.functional as F
import torchvision.models as models
from torch import nn

from easyfl.models.model import BaseModel
from easyfl.models.resnet import ResNet18, ResNet50

SimSiam = "simsiam"
SimSiamNoSG = "simsiam_no_sg"
SimCLR = "simclr"
MoCo = "moco"
MoCoV2 = "moco_v2"
BYOL = "byol"
BYOLNoSG = "byol_no_sg"
BYOLNoEMA = "byol_no_ema"
BYOLNoEMA_NoSG = "byol_no_ema_no_sg"
BYOLNoPredictor = "byol_no_p"
Symmetric = "symmetric"
SymmetricNoSG = "symmetric_no_sg"

OneLayer = "1_layer"
TwoLayer = "2_layer"

RESNET18 = "resnet18"
RESNET50 = "resnet50"


def get_encoder(arch=RESNET18):
    return models.__dict__[arch]


def get_model(model, encoder_network, predictor_network=TwoLayer):
    mlp = False
    T = 0.07
    stop_gradient = True
    has_predictor = True

    if model == SymmetricNoSG:
        stop_gradient = False
        model = Symmetric
    elif model == SimSiamNoSG:
        stop_gradient = False
        model = SimSiam
    elif model == BYOLNoSG:
        stop_gradient = False
        model = BYOL
    elif model == BYOLNoPredictor:
        has_predictor = False
        model = BYOL
    elif model == MoCoV2:
        model = MoCo
        mlp = True
        T = 0.2

    if model == Symmetric:
        if encoder_network == RESNET50:
            return SymmetricModel(net=ResNet50(), stop_gradient=stop_gradient)
        else:
            return SymmetricModel(stop_gradient=stop_gradient)
    elif model == SimSiam:
        net = ResNet18()
        if encoder_network == RESNET50:
            net = ResNet50()
        return SimSiamModel(net=net, stop_gradient=stop_gradient)
    elif model == MoCo:
        net = ResNet18
        if encoder_network == RESNET50:
            net = ResNet50
        return MoCoModel(net=net, mlp=mlp, T=T)
    elif model == BYOL:
        net = ResNet18()
        if encoder_network == RESNET50:
            net = ResNet50()
        return BYOLModel(net=net, stop_gradient=stop_gradient, has_predictor=has_predictor,
                         predictor_network=predictor_network)
    elif model == SimCLR:
        net = ResNet18()
        if encoder_network == RESNET50:
            net = ResNet50()
        return SimCLRModel(net=net)
    else:
        raise NotImplementedError


def get_encoder_network(model, encoder_network, num_classes=10, projection_size=2048, projection_hidden_size=4096):
    if model in [MoCo, MoCoV2]:
        num_classes = 128

    if encoder_network == RESNET18:
        resnet = ResNet18(num_classes=num_classes)
    elif encoder_network == RESNET50:
        resnet = ResNet50(num_classes=num_classes)
    else:
        raise NotImplementedError

    if model in [Symmetric, SimSiam, BYOL, SymmetricNoSG, SimSiamNoSG, BYOLNoSG, SimCLR]:
        resnet.fc = MLP(resnet.feature_dim, projection_size, projection_hidden_size)
    if model == MoCoV2:
        resnet.fc = MLP(resnet.feature_dim, num_classes, resnet.feature_dim)

    return resnet


# ------------- SymmetricModel Model -----------------

class SymmetricModel(BaseModel):
    def __init__(
            self,
            net=ResNet18(),
            image_size=32,
            projection_size=2048,
            projection_hidden_size=4096,
            stop_gradient=True
    ):
        super().__init__()
        self.online_encoder = net
        self.online_encoder.fc = MLP(net.feature_dim, projection_size, projection_hidden_size)  # projector
        self.stop_gradient = stop_gradient

        # send a mock image tensor to instantiate singleton parameters
        self.forward(torch.randn(2, 3, image_size, image_size), torch.randn(2, 3, image_size, image_size))

    def forward(self, image_one, image_two):
        f = self.online_encoder
        z1, z2 = f(image_one), f(image_two)
        if self.stop_gradient:
            loss = D(z1, z2)
        else:
            loss = D_NO_SG(z1, z2)
        return loss


# ------------- SimSiam Model -----------------

class SimSiamModel(BaseModel):
    def __init__(
            self,
            net=ResNet18(),
            image_size=32,
            projection_size=2048,
            projection_hidden_size=4096,
            stop_gradient=True,
    ):
        super().__init__()
        self.online_encoder = net
        self.online_encoder.fc = MLP(net.feature_dim, projection_size, projection_hidden_size)  # projector
        self.online_predictor = MLP(
            projection_size, projection_size, projection_hidden_size
        )
        self.stop_gradient = stop_gradient

        # send a mock image tensor to instantiate singleton parameters
        self.forward(torch.randn(2, 3, image_size, image_size), torch.randn(2, 3, image_size, image_size))

    def forward(self, image_one, image_two):
        f, h = self.online_encoder, self.online_predictor
        z1, z2 = f(image_one), f(image_two)
        p1, p2 = h(z1), h(z2)
        if self.stop_gradient:
            loss = D(p1, z2) / 2 + D(p2, z1) / 2
        else:
            loss = D_NO_SG(p1, z2) / 2 + D_NO_SG(p2, z1) / 2
        return loss


class MLP(nn.Module):
    def __init__(self, dim, projection_size, hidden_size=4096, num_layer=TwoLayer):
        super().__init__()
        self.in_features = dim
        if num_layer == OneLayer:
            self.net = nn.Sequential(
                nn.Linear(dim, projection_size),
            )
        elif num_layer == TwoLayer:
            self.net = nn.Sequential(
                nn.Linear(dim, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_size, projection_size),
            )
        else:
            raise NotImplementedError(f"Not defined MLP: {num_layer}")

    def forward(self, x):
        return self.net(x)


def D(p, z, version='simplified'):  # negative cosine similarity
    if version == 'original':
        z = z.detach()  # stop gradient
        p = F.normalize(p, dim=1)  # l2-normalize
        z = F.normalize(z, dim=1)  # l2-normalize
        return -(p * z).sum(dim=1).mean()
    elif version == 'simplified':  # same result, faster
        return -F.cosine_similarity(p, z.detach(), dim=-1).mean()
    else:
        raise Exception


def D_NO_SG(p, z, version='simplified'):  # negative cosine similarity without stop gradient
    if version == 'original':
        p = F.normalize(p, dim=1)  # l2-normalize
        z = F.normalize(z, dim=1)  # l2-normalize
        return -(p * z).sum(dim=1).mean()
    elif version == 'simplified':  # same result, faster
        return -F.cosine_similarity(p, z, dim=-1).mean()
    else:
        raise Exception
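

# Illustrative check (not part of the library): assuming p and z are 2-D (N, C)
# tensors, the 'original' and 'simplified' branches of D compute the same value
# up to floating-point tolerance, e.g.:
#
#   p, z = torch.randn(4, 8), torch.randn(4, 8)
#   assert torch.allclose(D(p, z, 'original'), D(p, z, 'simplified'), atol=1e-6)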


# ------------- BYOL Model -----------------

class BYOLModel(BaseModel):
    def __init__(
            self,
            net=ResNet18(),
            image_size=32,
            projection_size=2048,
            projection_hidden_size=4096,
            moving_average_decay=0.99,
            stop_gradient=True,
            has_predictor=True,
            predictor_network=TwoLayer,
    ):
        super().__init__()
        self.online_encoder = net
        if not hasattr(net, 'feature_dim'):
            feature_dim = list(net.children())[-1].in_features
        else:
            feature_dim = net.feature_dim
        self.online_encoder.fc = MLP(feature_dim, projection_size, projection_hidden_size)  # projector

        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size, predictor_network)
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.stop_gradient = stop_gradient
        self.has_predictor = has_predictor

        # debug purpose
        # self.forward(torch.randn(2, 3, image_size, image_size), torch.randn(2, 3, image_size, image_size))
        # self.reset_moving_average()

    def _get_target_encoder(self):
        target_encoder = copy.deepcopy(self.online_encoder)
        return target_encoder

    def reset_moving_average(self):
        del self.target_encoder
        self.target_encoder = None

    def update_moving_average(self):
        assert (
            self.target_encoder is not None
        ), "target encoder has not been created yet"
        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

    def forward(self, image_one, image_two):
        online_pred_one = self.online_encoder(image_one)
        online_pred_two = self.online_encoder(image_two)

        if self.has_predictor:
            online_pred_one = self.online_predictor(online_pred_one)
            online_pred_two = self.online_predictor(online_pred_two)

        if self.stop_gradient:
            with torch.no_grad():
                if self.target_encoder is None:
                    self.target_encoder = self._get_target_encoder()
                target_proj_one = self.target_encoder(image_one)
                target_proj_two = self.target_encoder(image_two)

                target_proj_one = target_proj_one.detach()
                target_proj_two = target_proj_two.detach()
        else:
            if self.target_encoder is None:
                self.target_encoder = self._get_target_encoder()
            target_proj_one = self.target_encoder(image_one)
            target_proj_two = self.target_encoder(image_two)

        loss_one = byol_loss_fn(online_pred_one, target_proj_two)
        loss_two = byol_loss_fn(online_pred_two, target_proj_one)
        loss = loss_one + loss_two
        return loss.mean()


class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


def update_moving_average(ema_updater, ma_model, current_model):
    for current_params, ma_params in zip(
            current_model.parameters(), ma_model.parameters()
    ):
        old_weight, up_weight = ma_params.data, current_params.data
        ma_params.data = ema_updater.update_average(old_weight, up_weight)


def byol_loss_fn(x, y):
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)
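

# Illustrative sketch (not part of the library): how the EMA target update is
# meant to be driven from a training loop. The optimizer settings and the two
# augmented views `x1`, `x2` are assumptions supplied by the caller's pipeline.
#
#   model = BYOLModel(net=ResNet18())
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
#   loss = model(x1, x2)           # forward creates the target encoder on first call
#   optimizer.zero_grad()
#   loss.backward()
#   optimizer.step()
#   model.update_moving_average()  # EMA update of the target encoder after each step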


# ------------- MoCo Model -----------------

class MoCoModel(BaseModel):
    def __init__(self, net=ResNet18, dim=128, K=4096, m=0.99, T=0.1, bn_splits=8, symmetric=True, mlp=False):
        super().__init__()

        self.K = K
        self.m = m
        self.T = T
        self.symmetric = symmetric

        # create the encoders
        self.encoder_q = net(num_classes=dim)
        self.encoder_k = net(num_classes=dim)

        if mlp:
            feature_dim = self.encoder_q.feature_dim
            self.encoder_q.fc = MLP(feature_dim, dim, feature_dim)
            self.encoder_k.fc = MLP(feature_dim, dim, feature_dim)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not updated by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def reset_key_encoder(self):
        """
        Re-initialize the key encoder as a copy of the query encoder.
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not updated by gradient

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.t()  # transpose
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_single_gpu(self, x, device):
        """
        Batch shuffle, for making use of BatchNorm.
        """
        # random shuffle index
        idx_shuffle = torch.randperm(x.shape[0]).to(device)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        return x[idx_shuffle], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_single_gpu(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        """
        return x[idx_unshuffle]

    def contrastive_loss(self, im_q, im_k, device):
        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)  # l2-normalize

        # compute key features
        with torch.no_grad():  # no gradient to keys
            # shuffle for making use of BN
            im_k_, idx_unshuffle = self._batch_shuffle_single_gpu(im_k, device)

            k = self.encoder_k(im_k_)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)  # l2-normalize

            # undo shuffle
            k = self._batch_unshuffle_single_gpu(k, idx_unshuffle)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positives are the 0-th
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)

        loss = nn.CrossEntropyLoss().to(device)(logits, labels)

        return loss, q, k

    def forward(self, im1, im2, device):
        """
        Input:
            im1: a batch of query images
            im2: a batch of key images
        Output:
            loss
        """
        # momentum update of the key encoder
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()

        # compute loss
        if self.symmetric:  # symmetric loss
            loss_12, q1, k2 = self.contrastive_loss(im1, im2, device)
            loss_21, q2, k1 = self.contrastive_loss(im2, im1, device)
            loss = loss_12 + loss_21
            k = torch.cat([k1, k2], dim=0)
        else:  # asymmetric loss
            loss, q, k = self.contrastive_loss(im1, im2, device)

        self._dequeue_and_enqueue(k)

        return loss


# ------------- SimCLR Model -----------------

class SimCLRModel(BaseModel):
    def __init__(self, net=ResNet18(), image_size=32, projection_size=2048, projection_hidden_size=4096):
        super().__init__()
        self.online_encoder = net
        self.online_encoder.fc = MLP(net.feature_dim, projection_size, projection_hidden_size)  # projector

    def forward(self, image):
        return self.online_encoder(image)
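

# Illustrative smoke test (not part of the library): builds a BYOL model through
# get_model and runs one forward pass on random CIFAR-sized (32x32) inputs. The
# batch size of 4 and the BYOL/ResNet-18 choice are arbitrary assumptions for
# demonstration; it relies on easyfl's ResNet18 accepting 32x32 inputs.
if __name__ == "__main__":
    ssl_model = get_model(BYOL, RESNET18)
    x1 = torch.randn(4, 3, 32, 32)  # first augmented view
    x2 = torch.randn(4, 3, 32, 32)  # second augmented view
    loss = ssl_model(x1, x2)  # BYOLModel.forward returns the mean BYOL loss
    print(loss.item())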