cifar10.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import numpy as np
  2. import torch
  3. import torchvision
  4. from torchvision import transforms
  5. class Cutout(object):
  6. """Cutout data augmentation is adopted from https://github.com/uoguelph-mlrg/Cutout"""
  7. def __init__(self, length=16):
  8. self.length = length
  9. def __call__(self, img):
  10. """
  11. Args:
  12. img (Tensor): Tensor image of size (C, H, W).
  13. Returns:
  14. Tensor: Image with n_holes of dimension length x length cut out of it.
  15. """
  16. h = img.size(1)
  17. w = img.size(2)
  18. mask = np.ones((h, w), np.float32)
  19. y = np.random.randint(h)
  20. x = np.random.randint(w)
  21. y1 = np.clip(y - self.length // 2, 0, h)
  22. y2 = np.clip(y + self.length // 2, 0, h)
  23. x1 = np.clip(x - self.length // 2, 0, w)
  24. x2 = np.clip(x + self.length // 2, 0, w)
  25. mask[y1: y2, x1: x2] = 0.
  26. mask = torch.from_numpy(mask)
  27. mask = mask.expand_as(img)
  28. img *= mask
  29. return img
  30. transform_train = transforms.Compose([
  31. torchvision.transforms.ToPILImage(mode='RGB'),
  32. transforms.RandomCrop(32, padding=4),
  33. transforms.RandomHorizontalFlip(),
  34. transforms.ToTensor(),
  35. transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)),
  36. ])
  37. transform_train.transforms.append(Cutout())
  38. transform_test = transforms.Compose([
  39. transforms.ToTensor(),
  40. transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)),
  41. ])