123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483 |
- import copy
- import torch
- import torch.nn.functional as F
- import torchvision.models as models
- from torch import nn
- from easyfl.models.model import BaseModel
- from easyfl.models.resnet import ResNet18, ResNet50
# Identifiers of the supported self-supervised methods and their ablations.
# "NoSG" variants disable stop-gradient; see get_model for how each name is
# mapped onto a base method plus flags.
SimSiam = "simsiam"
SimSiamNoSG = "simsiam_no_sg"
SimCLR = "simclr"
MoCo = "moco"
MoCoV2 = "moco_v2"
BYOL = "byol"
BYOLNoSG = "byol_no_sg"
BYOLNoEMA = "byol_no_ema"
BYOLNoEMA_NoSG = "byol_no_ema_no_sg"
BYOLNoPredictor = "byol_no_p"
Symmetric = "symmetric"
SymmetricNoSG = "symmetric_no_sg"

# Depth selectors understood by the MLP head.
OneLayer = "1_layer"
TwoLayer = "2_layer"

# Backbone architecture names.
RESNET18 = "resnet18"
RESNET50 = "resnet50"
def get_encoder(arch=RESNET18):
    """Look up an entry of ``torchvision.models`` by name.

    Args:
        arch: Name to look up in the ``torchvision.models`` namespace.

    Returns:
        The object stored under ``arch`` in that namespace. It is returned
        as-is; nothing is called or instantiated here.

    Raises:
        KeyError: If the namespace has no entry called ``arch``.
    """
    registry = vars(models)
    return registry[arch]
def get_model(model, encoder_network, predictor_network=TwoLayer):
    """Build the self-supervised model selected by ``model``.

    Ablation names are first reduced to a base method plus flags
    (``stop_gradient``, ``has_predictor``, MoCo ``mlp``/``T``); the base
    method then decides which model class is instantiated.

    Args:
        model: One of the method identifier constants of this module.
        encoder_network: ``RESNET50`` selects a ResNet-50 backbone; every
            other value falls back to ResNet-18.
        predictor_network: ``OneLayer`` or ``TwoLayer``; only forwarded to
            ``BYOLModel``.

    Returns:
        A ``SymmetricModel``, ``SimSiamModel``, ``MoCoModel``, ``BYOLModel``
        or ``SimCLRModel`` instance.

    Raises:
        NotImplementedError: If ``model`` matches no branch. Note that
            ``BYOLNoEMA`` and ``BYOLNoEMA_NoSG`` have no branch here.
    """
    mlp = False
    T = 0.07
    stop_gradient = True
    has_predictor = True

    # Reduce ablation variants to (base method, flags).
    if model == SymmetricNoSG:
        stop_gradient = False
        model = Symmetric
    elif model == SimSiamNoSG:
        stop_gradient = False
        model = SimSiam
    elif model == BYOLNoSG:
        stop_gradient = False
        model = BYOL
    elif model == BYOLNoPredictor:
        has_predictor = False
        model = BYOL
    elif model == MoCoV2:
        model = MoCo
        mlp = True
        T = 0.2

    # Pick the backbone class once. Instances are created only in the branch
    # that needs one, so no throw-away ResNet-18 is built for ResNet-50 runs
    # and every returned model owns a fresh network instead of a shared
    # default instance.
    backbone_cls = ResNet50 if encoder_network == RESNET50 else ResNet18

    if model == Symmetric:
        return SymmetricModel(net=backbone_cls(), stop_gradient=stop_gradient)
    if model == SimSiam:
        return SimSiamModel(net=backbone_cls(), stop_gradient=stop_gradient)
    if model == MoCo:
        # MoCoModel instantiates the class itself (once per encoder).
        return MoCoModel(net=backbone_cls, mlp=mlp, T=T)
    if model == BYOL:
        return BYOLModel(
            net=backbone_cls(),
            stop_gradient=stop_gradient,
            has_predictor=has_predictor,
            predictor_network=predictor_network,
        )
    if model == SimCLR:
        return SimCLRModel(net=backbone_cls())
    raise NotImplementedError(f"Unsupported model: {model}")
def get_encoder_network(model, encoder_network, num_classes=10, projection_size=2048, projection_hidden_size=4096):
    """Build a bare ResNet encoder whose ``fc`` head matches the given method.

    Args:
        model: Method identifier constant of this module.
        encoder_network: ``RESNET18`` or ``RESNET50``.
        num_classes: Output size passed to the ResNet constructor; forced to
            128 for ``MoCo`` and ``MoCoV2``.
        projection_size: Output size of the MLP head used for the
            Symmetric/SimSiam/BYOL/SimCLR family.
        projection_hidden_size: Hidden size of that MLP head.

    Returns:
        The ResNet instance, with ``fc`` replaced where the method requires.

    Raises:
        NotImplementedError: If ``encoder_network`` is neither ``RESNET18``
            nor ``RESNET50``.
    """
    out_dim = 128 if model in (MoCo, MoCoV2) else num_classes

    if encoder_network == RESNET18:
        backbone = ResNet18(num_classes=out_dim)
    elif encoder_network == RESNET50:
        backbone = ResNet50(num_classes=out_dim)
    else:
        raise NotImplementedError

    # NOTE(review): BYOLNoPredictor, BYOLNoEMA and BYOLNoEMA_NoSG are not in
    # this list and therefore keep the stock ``fc`` - confirm this is intended.
    uses_projection_head = model in (
        Symmetric, SimSiam, BYOL, SymmetricNoSG, SimSiamNoSG, BYOLNoSG, SimCLR,
    )
    if uses_projection_head:
        backbone.fc = MLP(backbone.feature_dim, projection_size, projection_hidden_size)
    if model == MoCoV2:
        # The hidden width of this head equals the backbone feature width.
        backbone.fc = MLP(backbone.feature_dim, out_dim, backbone.feature_dim)
    return backbone
class SymmetricModel(BaseModel):
    """Siamese model without a predictor: one encoder applied to two views.

    The loss is the negative cosine similarity between the two projections.
    With ``stop_gradient=True`` the second projection is detached (``D``);
    otherwise gradients flow through both (``D_NO_SG``).

    Args:
        net: Backbone exposing ``feature_dim``; its ``fc`` attribute is
            replaced by an MLP projection head. A new ``ResNet18()`` is
            created when omitted.
        image_size: Side length of the random images used for the
            construction-time forward pass.
        projection_size: Output size of the projection head.
        projection_hidden_size: Hidden size of the projection head.
        stop_gradient: Whether to detach the second projection in the loss.
    """

    def __init__(
        self,
        net=None,
        image_size=32,
        projection_size=2048,
        projection_hidden_size=4096,
        stop_gradient=True,
    ):
        super().__init__()
        # A ``ResNet18()`` default in the signature is evaluated once, at
        # definition time: every instance relying on it would share (and
        # re-head) the same network. Build a fresh one per instance instead.
        if net is None:
            net = ResNet18()
        self.online_encoder = net
        self.online_encoder.fc = MLP(net.feature_dim, projection_size, projection_hidden_size)
        self.stop_gradient = stop_gradient

        # One forward pass on a random two-image batch at construction time.
        # NOTE(review): presumably a warm-up/sanity pass - confirm the intent.
        self.forward(torch.randn(2, 3, image_size, image_size), torch.randn(2, 3, image_size, image_size))

    def forward(self, image_one, image_two):
        """Return the scalar loss for two batches of views.

        Args:
            image_one: First batch of views, fed to the encoder.
            image_two: Second batch of views, fed to the same encoder.

        Returns:
            Mean negative cosine similarity between the two projections.
        """
        f = self.online_encoder
        z1, z2 = f(image_one), f(image_two)
        if self.stop_gradient:
            return D(z1, z2)
        return D_NO_SG(z1, z2)
class SimSiamModel(BaseModel):
    """Siamese model with a shared encoder and an MLP predictor on top.

    Each view is encoded to ``z`` and then predicted to ``p``. The loss is the
    symmetrised negative cosine similarity ``D(p1, z2) / 2 + D(p2, z1) / 2``,
    with ``z`` detached when ``stop_gradient`` is true.

    Args:
        net: Backbone exposing ``feature_dim``; its ``fc`` attribute is
            replaced by an MLP projection head. A new ``ResNet18()`` is
            created when omitted.
        image_size: Side length of the random images used for the
            construction-time forward pass.
        projection_size: Output size of the projection head and predictor.
        projection_hidden_size: Hidden size of the projection head and
            predictor.
        stop_gradient: Whether to detach the target projection in the loss.
    """

    def __init__(
        self,
        net=None,
        image_size=32,
        projection_size=2048,
        projection_hidden_size=4096,
        stop_gradient=True,
    ):
        super().__init__()
        # A ``ResNet18()`` default in the signature is evaluated once, at
        # definition time: every instance relying on it would share (and
        # re-head) the same network. Build a fresh one per instance instead.
        if net is None:
            net = ResNet18()
        self.online_encoder = net
        self.online_encoder.fc = MLP(net.feature_dim, projection_size, projection_hidden_size)
        self.online_predictor = MLP(
            projection_size, projection_size, projection_hidden_size
        )
        self.stop_gradient = stop_gradient

        # One forward pass on a random two-image batch at construction time.
        # NOTE(review): presumably a warm-up/sanity pass - confirm the intent.
        self.forward(torch.randn(2, 3, image_size, image_size), torch.randn(2, 3, image_size, image_size))

    def forward(self, image_one, image_two):
        """Return the scalar symmetrised loss for two batches of views.

        Args:
            image_one: First batch of views.
            image_two: Second batch of views.

        Returns:
            Average of the two cross-view negative cosine similarities.
        """
        f, h = self.online_encoder, self.online_predictor
        z1, z2 = f(image_one), f(image_two)
        p1, p2 = h(z1), h(z2)
        # Same formula in both modes; only gradient flow into z differs.
        distance = D if self.stop_gradient else D_NO_SG
        return distance(p1, z2) / 2 + distance(p2, z1) / 2
class MLP(nn.Module):
    """Projection/prediction head with one or two linear layers.

    Args:
        dim: Input feature size; also stored as ``in_features``.
        projection_size: Output feature size.
        hidden_size: Width of the hidden layer (two-layer variant only).
        num_layer: ``OneLayer`` for a single ``Linear``; ``TwoLayer`` for
            ``Linear -> BatchNorm1d -> ReLU -> Linear``.

    Raises:
        NotImplementedError: If ``num_layer`` is neither ``OneLayer`` nor
            ``TwoLayer``.
    """

    def __init__(self, dim, projection_size, hidden_size=4096, num_layer=TwoLayer):
        super().__init__()
        self.in_features = dim

        if num_layer == OneLayer:
            stages = [nn.Linear(dim, projection_size)]
        elif num_layer == TwoLayer:
            stages = [
                nn.Linear(dim, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_size, projection_size),
            ]
        else:
            raise NotImplementedError(f"Not defined MLP: {num_layer}")
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the head to ``x`` and return the result."""
        return self.net(x)
def D(p, z, version='simplified'):
    """Mean negative cosine similarity between ``p`` and a detached ``z``.

    ``z`` is detached in both versions, so gradients only reach ``p``.

    Args:
        p: Prediction tensor.
        z: Target tensor; no gradient flows into it.
        version: ``'original'`` normalises both inputs along dim 1 and takes
            their dot product; ``'simplified'`` calls
            ``F.cosine_similarity`` along the last dim.

    Returns:
        Scalar tensor holding the mean negative cosine similarity.

    Raises:
        ValueError: If ``version`` is not one of the two supported names.
    """
    if version not in ('original', 'simplified'):
        # ValueError is still an Exception, so existing handlers keep working.
        raise ValueError(f"Unknown version: {version!r} (expected 'original' or 'simplified')")

    z = z.detach()
    if version == 'original':
        p = F.normalize(p, dim=1)
        z = F.normalize(z, dim=1)
        return -(p * z).sum(dim=1).mean()
    return -F.cosine_similarity(p, z, dim=-1).mean()
def D_NO_SG(p, z, version='simplified'):
    """Mean negative cosine similarity between ``p`` and ``z``, no stop-gradient.

    Unlike ``D``, ``z`` is not detached, so gradients reach both inputs.

    Args:
        p: First tensor.
        z: Second tensor.
        version: ``'original'`` normalises both inputs along dim 1 and takes
            their dot product; ``'simplified'`` calls
            ``F.cosine_similarity`` along the last dim.

    Returns:
        Scalar tensor holding the mean negative cosine similarity.

    Raises:
        ValueError: If ``version`` is not one of the two supported names.
    """
    if version == 'original':
        p = F.normalize(p, dim=1)
        z = F.normalize(z, dim=1)
        return -(p * z).sum(dim=1).mean()
    if version == 'simplified':
        return -F.cosine_similarity(p, z, dim=-1).mean()
    # ValueError is still an Exception, so existing handlers keep working.
    raise ValueError(f"Unknown version: {version!r} (expected 'original' or 'simplified')")
class BYOLModel(BaseModel):
    """BYOL-style model: online encoder (+ optional predictor) and a target encoder.

    The target encoder is a deep copy of the online encoder, created lazily
    on the first ``forward`` call. ``update_moving_average`` moves its
    parameters towards the online encoder using an ``EMA`` updater.

    Args:
        net: Backbone whose ``fc`` attribute is replaced by an MLP projection
            head. A new ``ResNet18()`` is created when omitted.
        image_size: Accepted for signature compatibility; not used here.
        projection_size: Output size of the projection head and predictor.
        projection_hidden_size: Hidden size of the projection head and
            predictor.
        moving_average_decay: ``beta`` handed to the ``EMA`` updater.
        stop_gradient: If true, target projections are computed under
            ``torch.no_grad()`` and detached.
        has_predictor: If false, the predictor is skipped in ``forward``
            (it is still constructed).
        predictor_network: ``OneLayer`` or ``TwoLayer`` predictor depth.
    """

    def __init__(
        self,
        net=None,
        image_size=32,
        projection_size=2048,
        projection_hidden_size=4096,
        moving_average_decay=0.99,
        stop_gradient=True,
        has_predictor=True,
        predictor_network=TwoLayer,
    ):
        super().__init__()
        # A ``ResNet18()`` default in the signature is evaluated once, at
        # definition time: every instance relying on it would share (and
        # re-head) the same network. Build a fresh one per instance instead.
        if net is None:
            net = ResNet18()
        self.online_encoder = net
        if not hasattr(net, 'feature_dim'):
            # Fall back to the input width of the backbone's last child module.
            feature_dim = list(net.children())[-1].in_features
        else:
            feature_dim = net.feature_dim
        self.online_encoder.fc = MLP(feature_dim, projection_size, projection_hidden_size)
        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size, predictor_network)
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)
        self.stop_gradient = stop_gradient
        self.has_predictor = has_predictor

    def _get_target_encoder(self):
        """Return a deep copy of the current online encoder."""
        return copy.deepcopy(self.online_encoder)

    def reset_moving_average(self):
        """Drop the target encoder; the next ``forward`` re-creates it."""
        del self.target_encoder
        self.target_encoder = None

    def update_moving_average(self):
        """Move the target encoder's parameters towards the online encoder.

        Raises:
            AssertionError: If the target encoder has not been created yet.
        """
        assert (
            self.target_encoder is not None
        ), "target encoder has not been created yet"
        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

    def forward(self, image_one, image_two):
        """Return the scalar BYOL loss for two batches of views.

        Args:
            image_one: First batch of views.
            image_two: Second batch of views.

        Returns:
            Batch mean of the sum of the two cross-view ``byol_loss_fn`` terms.
        """
        online_pred_one = self.online_encoder(image_one)
        online_pred_two = self.online_encoder(image_two)
        if self.has_predictor:
            online_pred_one = self.online_predictor(online_pred_one)
            online_pred_two = self.online_predictor(online_pred_two)

        # Created after the online passes above, as in both original branches.
        if self.target_encoder is None:
            self.target_encoder = self._get_target_encoder()

        if self.stop_gradient:
            with torch.no_grad():
                target_proj_one = self.target_encoder(image_one).detach()
                target_proj_two = self.target_encoder(image_two).detach()
        else:
            target_proj_one = self.target_encoder(image_one)
            target_proj_two = self.target_encoder(image_two)

        # Each online output is matched against the target of the other view.
        loss_one = byol_loss_fn(online_pred_one, target_proj_two)
        loss_two = byol_loss_fn(online_pred_two, target_proj_one)
        loss = loss_one + loss_two
        return loss.mean()
class EMA:
    """Exponential-moving-average rule with a fixed decay ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Blend ``new`` into ``old``.

        Args:
            old: Current average, or ``None`` when there is none yet.
            new: Latest value.

        Returns:
            ``new`` when ``old`` is ``None``; otherwise
            ``old * beta + (1 - beta) * new``.
        """
        return new if old is None else old * self.beta + (1 - self.beta) * new
def update_moving_average(ema_updater, ma_model, current_model):
    """Update ``ma_model``'s parameters in place from ``current_model``.

    Parameters of the two models are paired positionally. Only what
    ``.parameters()`` yields is touched.

    Args:
        ema_updater: Object exposing ``update_average(old, new)``.
        ma_model: Model holding the averaged weights; modified in place.
        current_model: Model supplying the new weights.
    """
    pairs = zip(current_model.parameters(), ma_model.parameters())
    for source, averaged in pairs:
        averaged.data = ema_updater.update_average(averaged.data, source.data)
def byol_loss_fn(x, y):
    """Per-sample BYOL regression loss ``2 - 2 * <x_hat, y_hat>``.

    Both inputs are L2-normalised along the last dimension before their dot
    product is taken.

    Args:
        x: First tensor.
        y: Second tensor, same shape as ``x``.

    Returns:
        Tensor with the last dimension reduced away; no batch reduction.
    """
    x_unit = F.normalize(x, dim=-1, p=2)
    y_unit = F.normalize(y, dim=-1, p=2)
    similarity = (x_unit * y_unit).sum(dim=-1)
    return 2 - 2 * similarity
class MoCoModel(BaseModel):
    """MoCo-style contrastive model: query encoder, momentum key encoder, key queue.

    Args:
        net: Callable invoked as ``net(num_classes=dim)`` to build each of
            the two encoders.
        dim: Encoder output size and number of rows of the queue.
        K: Number of keys (columns) held in the queue.
        m: Momentum used when updating the key encoder.
        T: Temperature by which logits are divided.
        bn_splits: Accepted for signature compatibility; not used here.
        symmetric: If true, ``forward`` sums the loss over both view orders.
        mlp: If true, each encoder's ``fc`` is replaced by an MLP head.
    """

    def __init__(self, net=ResNet18, dim=128, K=4096, m=0.99, T=0.1, bn_splits=8, symmetric=True, mlp=False):
        super().__init__()
        self.K = K
        self.m = m
        self.T = T
        self.symmetric = symmetric

        self.encoder_q = net(num_classes=dim)
        self.encoder_k = net(num_classes=dim)
        if mlp:
            width = self.encoder_q.feature_dim
            self.encoder_q.fc = MLP(width, dim, width)
            self.encoder_k.fc = MLP(width, dim, width)

        # Start the key encoder as a frozen copy of the query encoder.
        for q_param, k_param in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            k_param.data.copy_(q_param.data)
            k_param.requires_grad = False

        # Queue of shape (dim, K), columns normalised, plus its write cursor.
        self.register_buffer("queue", nn.functional.normalize(torch.randn(dim, K), dim=0))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def reset_key_encoder(self):
        """Overwrite the key encoder's parameters with the query encoder's.

        This is a hard copy, not a momentum step; the key parameters are also
        marked as not requiring gradients.
        """
        for q_param, k_param in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            k_param.data.copy_(q_param.data)
            k_param.requires_grad = False

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """Momentum update: ``k = k * m + q * (1 - m)`` for every parameter pair."""
        for q_param, k_param in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            k_param.data = k_param.data * self.m + q_param.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Write ``keys`` into the queue at the cursor and advance it modulo ``K``.

        Args:
            keys: Tensor whose first dimension is the batch; it is transposed
                before being written into the queue columns.

        Raises:
            AssertionError: If ``K`` is not a multiple of the batch size.
        """
        n_keys = keys.shape[0]
        start = int(self.queue_ptr)
        assert self.K % n_keys == 0

        self.queue[:, start:start + n_keys] = keys.t()
        self.queue_ptr[0] = (start + n_keys) % self.K

    @torch.no_grad()
    def _batch_shuffle_single_gpu(self, x, device):
        """Randomly permute the batch dimension of ``x``.

        Returns:
            The shuffled batch and the index tensor that undoes the shuffle.
        """
        idx_shuffle = torch.randperm(x.shape[0]).to(device)
        # argsort of a permutation yields its inverse permutation.
        idx_unshuffle = torch.argsort(idx_shuffle)
        return x[idx_shuffle], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_single_gpu(self, x, idx_unshuffle):
        """Restore the original batch order using ``idx_unshuffle``."""
        return x[idx_unshuffle]

    def contrastive_loss(self, im_q, im_k, device):
        """Compute the contrastive loss for one (query, key) ordering.

        Args:
            im_q: Batch fed to the query encoder.
            im_k: Batch fed to the key encoder (shuffled, then unshuffled).
            device: Device for the shuffle indices and the labels.

        Returns:
            Tuple ``(loss, q, k)`` with the normalised query and key features.
        """
        q = nn.functional.normalize(self.encoder_q(im_q), dim=1)

        # Keys are computed without gradients.
        with torch.no_grad():
            shuffled, idx_unshuffle = self._batch_shuffle_single_gpu(im_k, device)
            k = nn.functional.normalize(self.encoder_k(shuffled), dim=1)
            k = self._batch_unshuffle_single_gpu(k, idx_unshuffle)

        # Positive logit per sample (N x 1), negatives against the queue (N x K).
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        logits = torch.cat([l_pos, l_neg], dim=1)
        logits /= self.T

        # The positive sits in column 0, so every target label is 0.
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)
        loss = nn.functional.cross_entropy(logits, labels)
        return loss, q, k

    def forward(self, im1, im2, device):
        """Run one training step's forward pass and update the queue.

        The key encoder receives a momentum update first. Afterwards the keys
        produced in this call are written into the queue.

        Args:
            im1: First batch of views.
            im2: Second batch of views.
            device: Device for the shuffle indices and the labels.

        Returns:
            The loss; summed over both view orders when ``symmetric`` is true.
        """
        with torch.no_grad():
            self._momentum_update_key_encoder()

        if not self.symmetric:
            loss, _, k = self.contrastive_loss(im1, im2, device)
        else:
            loss_12, _, k2 = self.contrastive_loss(im1, im2, device)
            loss_21, _, k1 = self.contrastive_loss(im2, im1, device)
            loss = loss_12 + loss_21
            k = torch.cat([k1, k2], dim=0)

        self._dequeue_and_enqueue(k)
        return loss
class SimCLRModel(BaseModel):
    """Encoder with an MLP projection head; ``forward`` returns projections only.

    No loss is computed in this class.

    Args:
        net: Backbone exposing ``feature_dim``; its ``fc`` attribute is
            replaced by an MLP projection head. A new ``ResNet18()`` is
            created when omitted.
        image_size: Accepted for signature compatibility; not used here.
        projection_size: Output size of the projection head.
        projection_hidden_size: Hidden size of the projection head.
    """

    def __init__(self, net=None, image_size=32, projection_size=2048, projection_hidden_size=4096):
        super().__init__()
        # A ``ResNet18()`` default in the signature is evaluated once, at
        # definition time: every instance relying on it would share (and
        # re-head) the same network. Build a fresh one per instance instead.
        if net is None:
            net = ResNet18()
        self.online_encoder = net
        self.online_encoder.fc = MLP(net.feature_dim, projection_size, projection_hidden_size)

    def forward(self, image):
        """Return the projection of ``image`` produced by the encoder."""
        return self.online_encoder(image)
|