transform.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. import numpy as np
  2. import torch
  3. import torchvision
  4. from torch import nn
  5. from torchvision import transforms
  6. class MoCoTransform:
  7. """
  8. A stochastic data augmentation module that transforms any given data example randomly
  9. resulting in two correlated views of the same example,
  10. denoted x ̃i and x ̃j, which we consider as a positive pair.
  11. """
  12. def __init__(self, size=32, gaussian=False):
  13. self.train_transform = transforms.Compose([
  14. transforms.ToPILImage(mode='RGB'),
  15. transforms.RandomResizedCrop(size),
  16. transforms.RandomHorizontalFlip(p=0.5),
  17. transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
  18. transforms.RandomGrayscale(p=0.2),
  19. transforms.ToTensor()])
  20. self.test_transform = transforms.Compose([
  21. torchvision.transforms.Resize(size=size),
  22. transforms.ToTensor()]
  23. )
  24. def __call__(self, x):
  25. return self.train_transform(x), self.train_transform(x)
  26. class SimSiamTransform:
  27. """
  28. A stochastic data augmentation module that transforms any given data example randomly
  29. resulting in two correlated views of the same example,
  30. denoted x ̃i and x ̃j, which we consider as a positive pair.
  31. """
  32. def __init__(self, size=32, gaussian=False):
  33. s = 1
  34. color_jitter = torchvision.transforms.ColorJitter(
  35. 0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
  36. )
  37. if gaussian:
  38. self.train_transform = torchvision.transforms.Compose(
  39. [
  40. torchvision.transforms.ToPILImage(mode='RGB'),
  41. torchvision.transforms.RandomResizedCrop(size=size),
  42. torchvision.transforms.RandomHorizontalFlip(), # with 0.5 probability
  43. torchvision.transforms.RandomApply([color_jitter], p=0.8),
  44. torchvision.transforms.RandomGrayscale(p=0.2),
  45. GaussianBlur(kernel_size=int(0.1 * size)),
  46. torchvision.transforms.ToTensor(),
  47. ]
  48. )
  49. else:
  50. self.train_transform = torchvision.transforms.Compose(
  51. [
  52. torchvision.transforms.ToPILImage(mode='RGB'),
  53. torchvision.transforms.RandomResizedCrop(size=size),
  54. torchvision.transforms.RandomHorizontalFlip(), # with 0.5 probability
  55. torchvision.transforms.RandomApply([color_jitter], p=0.8),
  56. torchvision.transforms.RandomGrayscale(p=0.2),
  57. torchvision.transforms.ToTensor(),
  58. ]
  59. )
  60. self.test_transform = torchvision.transforms.Compose(
  61. [
  62. torchvision.transforms.Resize(size=size),
  63. torchvision.transforms.ToTensor(),
  64. ]
  65. )
  66. def __call__(self, x):
  67. return self.train_transform(x), self.train_transform(x)
  68. class SimCLRTransform:
  69. """
  70. A stochastic data augmentation module that transforms any given data example randomly
  71. resulting in two correlated views of the same example,
  72. denoted x ̃i and x ̃j, which we consider as a positive pair.
  73. data_format is array or image
  74. """
  75. def __init__(self, size=32, gaussian=False, data_format="array"):
  76. s = 1
  77. color_jitter = torchvision.transforms.ColorJitter(
  78. 0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
  79. )
  80. if gaussian:
  81. self.train_transform = torchvision.transforms.Compose(
  82. [
  83. torchvision.transforms.ToPILImage(mode='RGB'),
  84. # torchvision.transforms.Resize(size=size),
  85. torchvision.transforms.RandomResizedCrop(size=size),
  86. torchvision.transforms.RandomHorizontalFlip(), # with 0.5 probability
  87. torchvision.transforms.RandomApply([color_jitter], p=0.8),
  88. torchvision.transforms.RandomGrayscale(p=0.2),
  89. GaussianBlur(kernel_size=int(0.1 * size)),
  90. # RandomApply(torchvision.transforms.GaussianBlur((3, 3), (1.0, 2.0)), p=0.2),
  91. torchvision.transforms.ToTensor(),
  92. ]
  93. )
  94. else:
  95. if data_format == "array":
  96. self.train_transform = torchvision.transforms.Compose(
  97. [
  98. torchvision.transforms.ToPILImage(mode='RGB'),
  99. # torchvision.transforms.Resize(size=size),
  100. torchvision.transforms.RandomResizedCrop(size=size),
  101. torchvision.transforms.RandomHorizontalFlip(), # with 0.5 probability
  102. torchvision.transforms.RandomApply([color_jitter], p=0.8),
  103. torchvision.transforms.RandomGrayscale(p=0.2),
  104. torchvision.transforms.ToTensor(),
  105. ]
  106. )
  107. else:
  108. self.train_transform = torchvision.transforms.Compose(
  109. [
  110. torchvision.transforms.RandomResizedCrop(size=size),
  111. torchvision.transforms.RandomHorizontalFlip(), # with 0.5 probability
  112. torchvision.transforms.RandomApply([color_jitter], p=0.8),
  113. torchvision.transforms.RandomGrayscale(p=0.2),
  114. torchvision.transforms.ToTensor(),
  115. ]
  116. )
  117. self.test_transform = torchvision.transforms.Compose(
  118. [
  119. torchvision.transforms.Resize(size=size),
  120. torchvision.transforms.ToTensor(),
  121. ]
  122. )
  123. self.fine_tune_transform = torchvision.transforms.Compose(
  124. [
  125. torchvision.transforms.ToPILImage(mode='RGB'),
  126. torchvision.transforms.Resize(size=size),
  127. torchvision.transforms.ToTensor(),
  128. ]
  129. )
  130. def __call__(self, x):
  131. return self.train_transform(x), self.train_transform(x)
  132. class GaussianBlur(object):
  133. """blur a single image on CPU"""
  134. def __init__(self, kernel_size):
  135. radias = kernel_size // 2
  136. kernel_size = radias * 2 + 1
  137. self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),
  138. stride=1, padding=0, bias=False, groups=3)
  139. self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),
  140. stride=1, padding=0, bias=False, groups=3)
  141. self.k = kernel_size
  142. self.r = radias
  143. self.blur = nn.Sequential(
  144. nn.ReflectionPad2d(radias),
  145. self.blur_h,
  146. self.blur_v
  147. )
  148. self.pil_to_tensor = transforms.ToTensor()
  149. self.tensor_to_pil = transforms.ToPILImage()
  150. def __call__(self, img):
  151. img = self.pil_to_tensor(img).unsqueeze(0)
  152. sigma = np.random.uniform(0.1, 2.0)
  153. x = np.arange(-self.r, self.r + 1)
  154. x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))
  155. x = x / x.sum()
  156. x = torch.from_numpy(x).view(1, -1).repeat(3, 1)
  157. self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))
  158. self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))
  159. with torch.no_grad():
  160. img = self.blur(img)
  161. img = img.squeeze()
  162. img = self.tensor_to_pil(img)
  163. return img