12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 |
- import numpy as np
- import torch
- import torchvision
- from torchvision import transforms
class Cutout(object):
    """Randomly zero out square patches of an image tensor.

    Cutout data augmentation, adapted from
    https://github.com/uoguelph-mlrg/Cutout

    Args:
        length (int): Side length, in pixels, of each square patch.
        n_holes (int): Number of patches cut out per call. Defaults to 1.
    """

    def __init__(self, length=16, n_holes=1):
        self.length = length
        self.n_holes = n_holes

    def __call__(self, img):
        """Apply Cutout to ``img``.

        Args:
            img (Tensor): Image tensor of size (C, H, W). The last two
                dimensions are treated as height and width.

        Returns:
            Tensor: A new tensor with the same shape, dtype and device as
            ``img``. It has ``n_holes`` patches of ``length x length``
            pixels (clipped at the image border) multiplied by zero.
            ``img`` itself is not modified.
        """
        h, w = img.size(-2), img.size(-1)
        # Build the mask with img's dtype/device so the multiply also works
        # for integer tensors and for tensors that are not on the CPU.
        mask = torch.ones((h, w), dtype=img.dtype, device=img.device)
        half = self.length // 2
        for _ in range(self.n_holes):
            # Centres come from NumPy's global RNG (y first, then x), as in
            # the reference implementation, so np.random.seed controls them.
            y = np.random.randint(h)
            x = np.random.randint(w)
            # Anchor the far edge at (start + length) so that odd lengths
            # still give a full length x length patch before clipping.
            y1 = int(np.clip(y - half, 0, h))
            y2 = int(np.clip(y - half + self.length, 0, h))
            x1 = int(np.clip(x - half, 0, w))
            x2 = int(np.clip(x - half + self.length, 0, w))
            mask[y1:y2, x1:x2] = 0
        # Out-of-place multiply; the (H, W) mask broadcasts over channels.
        return img * mask
# Per-channel normalization statistics shared by both pipelines. These match
# the commonly quoted CIFAR-10 training-set mean/std. NOTE(review): confirm
# they correspond to the dataset actually used here.
NORM_MEAN = (0.49139968, 0.48215827, 0.44653124)
NORM_STD = (0.24703233, 0.24348505, 0.26158768)

# Training pipeline. ToPILImage accepts an ndarray or Tensor, so the raw
# sample is converted to an RGB PIL image first. It is then randomly
# cropped to 32x32 (with 4 px padding), randomly flipped, converted to a
# tensor and normalized. Cutout runs last, on the normalized tensor.
transform_train = transforms.Compose([
    transforms.ToPILImage(mode='RGB'),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
    Cutout(),
])

# Evaluation pipeline: tensor conversion and normalization only.
# NOTE(review): unlike transform_train there is no ToPILImage step, so this
# relies on samples already being a PIL image or ndarray that ToTensor
# accepts. Confirm the test set is fed in that form.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])
|