123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190 |
- import numpy as np
- import torch
- import torchvision
- from torch import nn
- from torchvision import transforms
- class MoCoTransform:
- """
- A stochastic data augmentation module that transforms any given data example randomly
- resulting in two correlated views of the same example,
- denoted x ̃i and x ̃j, which we consider as a positive pair.
- """
- def __init__(self, size=32, gaussian=False):
- self.train_transform = transforms.Compose([
- transforms.ToPILImage(mode='RGB'),
- transforms.RandomResizedCrop(size),
- transforms.RandomHorizontalFlip(p=0.5),
- transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
- transforms.RandomGrayscale(p=0.2),
- transforms.ToTensor()])
- self.test_transform = transforms.Compose([
- torchvision.transforms.Resize(size=size),
- transforms.ToTensor()]
- )
- def __call__(self, x):
- return self.train_transform(x), self.train_transform(x)
- class SimSiamTransform:
- """
- A stochastic data augmentation module that transforms any given data example randomly
- resulting in two correlated views of the same example,
- denoted x ̃i and x ̃j, which we consider as a positive pair.
- """
- def __init__(self, size=32, gaussian=False):
- s = 1
- color_jitter = torchvision.transforms.ColorJitter(
- 0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
- )
- if gaussian:
- self.train_transform = torchvision.transforms.Compose(
- [
- torchvision.transforms.ToPILImage(mode='RGB'),
- torchvision.transforms.RandomResizedCrop(size=size),
- torchvision.transforms.RandomHorizontalFlip(), # with 0.5 probability
- torchvision.transforms.RandomApply([color_jitter], p=0.8),
- torchvision.transforms.RandomGrayscale(p=0.2),
- GaussianBlur(kernel_size=int(0.1 * size)),
- torchvision.transforms.ToTensor(),
- ]
- )
- else:
- self.train_transform = torchvision.transforms.Compose(
- [
- torchvision.transforms.ToPILImage(mode='RGB'),
- torchvision.transforms.RandomResizedCrop(size=size),
- torchvision.transforms.RandomHorizontalFlip(), # with 0.5 probability
- torchvision.transforms.RandomApply([color_jitter], p=0.8),
- torchvision.transforms.RandomGrayscale(p=0.2),
- torchvision.transforms.ToTensor(),
- ]
- )
- self.test_transform = torchvision.transforms.Compose(
- [
- torchvision.transforms.Resize(size=size),
- torchvision.transforms.ToTensor(),
- ]
- )
- def __call__(self, x):
- return self.train_transform(x), self.train_transform(x)
- class SimCLRTransform:
- """
- A stochastic data augmentation module that transforms any given data example randomly
- resulting in two correlated views of the same example,
- denoted x ̃i and x ̃j, which we consider as a positive pair.
- data_format is array or image
- """
- def __init__(self, size=32, gaussian=False, data_format="array"):
- s = 1
- color_jitter = torchvision.transforms.ColorJitter(
- 0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
- )
- if gaussian:
- self.train_transform = torchvision.transforms.Compose(
- [
- torchvision.transforms.ToPILImage(mode='RGB'),
- # torchvision.transforms.Resize(size=size),
- torchvision.transforms.RandomResizedCrop(size=size),
- torchvision.transforms.RandomHorizontalFlip(), # with 0.5 probability
- torchvision.transforms.RandomApply([color_jitter], p=0.8),
- torchvision.transforms.RandomGrayscale(p=0.2),
- GaussianBlur(kernel_size=int(0.1 * size)),
- # RandomApply(torchvision.transforms.GaussianBlur((3, 3), (1.0, 2.0)), p=0.2),
- torchvision.transforms.ToTensor(),
- ]
- )
- else:
- if data_format == "array":
- self.train_transform = torchvision.transforms.Compose(
- [
- torchvision.transforms.ToPILImage(mode='RGB'),
- # torchvision.transforms.Resize(size=size),
- torchvision.transforms.RandomResizedCrop(size=size),
- torchvision.transforms.RandomHorizontalFlip(), # with 0.5 probability
- torchvision.transforms.RandomApply([color_jitter], p=0.8),
- torchvision.transforms.RandomGrayscale(p=0.2),
- torchvision.transforms.ToTensor(),
- ]
- )
- else:
- self.train_transform = torchvision.transforms.Compose(
- [
- torchvision.transforms.RandomResizedCrop(size=size),
- torchvision.transforms.RandomHorizontalFlip(), # with 0.5 probability
- torchvision.transforms.RandomApply([color_jitter], p=0.8),
- torchvision.transforms.RandomGrayscale(p=0.2),
- torchvision.transforms.ToTensor(),
- ]
- )
- self.test_transform = torchvision.transforms.Compose(
- [
- torchvision.transforms.Resize(size=size),
- torchvision.transforms.ToTensor(),
- ]
- )
- self.fine_tune_transform = torchvision.transforms.Compose(
- [
- torchvision.transforms.ToPILImage(mode='RGB'),
- torchvision.transforms.Resize(size=size),
- torchvision.transforms.ToTensor(),
- ]
- )
- def __call__(self, x):
- return self.train_transform(x), self.train_transform(x)
- class GaussianBlur(object):
- """blur a single image on CPU"""
- def __init__(self, kernel_size):
- radias = kernel_size // 2
- kernel_size = radias * 2 + 1
- self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),
- stride=1, padding=0, bias=False, groups=3)
- self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),
- stride=1, padding=0, bias=False, groups=3)
- self.k = kernel_size
- self.r = radias
- self.blur = nn.Sequential(
- nn.ReflectionPad2d(radias),
- self.blur_h,
- self.blur_v
- )
- self.pil_to_tensor = transforms.ToTensor()
- self.tensor_to_pil = transforms.ToPILImage()
- def __call__(self, img):
- img = self.pil_to_tensor(img).unsqueeze(0)
- sigma = np.random.uniform(0.1, 2.0)
- x = np.arange(-self.r, self.r + 1)
- x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))
- x = x / x.sum()
- x = torch.from_numpy(x).view(1, -1).repeat(3, 1)
- self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))
- self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))
- with torch.no_grad():
- img = self.blur(img)
- img = img.squeeze()
- img = self.tensor_to_pil(img)
- return img
|