transforms.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. from __future__ import absolute_import
  2. import math
  3. import random
  4. from PIL import Image
  5. from torchvision import transforms
  6. class RectScale(object):
  7. def __init__(self, height, width, interpolation=Image.BILINEAR):
  8. self.height = height
  9. self.width = width
  10. self.interpolation = interpolation
  11. def __call__(self, img):
  12. w, h = img.size
  13. if h == self.height and w == self.width:
  14. return img
  15. return img.resize((self.width, self.height), self.interpolation)
  16. class RandomSizedRectCrop(object):
  17. def __init__(self, height, width, interpolation=Image.BILINEAR):
  18. self.height = height
  19. self.width = width
  20. self.interpolation = interpolation
  21. def __call__(self, img):
  22. for attempt in range(10):
  23. area = img.size[0] * img.size[1]
  24. target_area = random.uniform(0.64, 1.0) * area
  25. aspect_ratio = random.uniform(2, 3)
  26. h = int(round(math.sqrt(target_area * aspect_ratio)))
  27. w = int(round(math.sqrt(target_area / aspect_ratio)))
  28. if w <= img.size[0] and h <= img.size[1]:
  29. x1 = random.randint(0, img.size[0] - w)
  30. y1 = random.randint(0, img.size[1] - h)
  31. img = img.crop((x1, y1, x1 + w, y1 + h))
  32. assert (img.size == (w, h))
  33. return img.resize((self.width, self.height), self.interpolation)
  34. # Fallback
  35. scale = RectScale(self.height, self.width,
  36. interpolation=self.interpolation)
  37. return scale(img)
  38. normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  39. TRANSFORM_TRAIN_LIST = transforms.Compose([
  40. RandomSizedRectCrop(256, 128),
  41. transforms.RandomHorizontalFlip(),
  42. transforms.ToTensor(),
  43. normalizer,
  44. ])
  45. TRANSFORM_VAL_LIST = transformer = transforms.Compose([
  46. RectScale(256, 128),
  47. transforms.ToTensor(),
  48. normalizer,
  49. ])