dataset.py

import os, io, csv, math, random
import numpy as np
from einops import rearrange
from decord import VideoReader

import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from animatediff.utils.util import zero_rank_print


class WebVid10M(Dataset):
    def __init__(
        self,
        csv_path, video_folder,
        sample_size=256, sample_stride=4, sample_n_frames=16,
        is_image=False,
    ):
        zero_rank_print(f"loading annotations from {csv_path} ...")
        with open(csv_path, 'r') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        self.length = len(self.dataset)
        zero_rank_print(f"data scale: {self.length}")
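        # Each CSV row is expected to carry at least 'videoid', 'name' (the caption),
        # and 'page_dir' columns, as in the WebVid metadata CSVs.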

        self.video_folder = video_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.is_image = is_image

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
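        # Resize(sample_size[0]) scales the shorter edge, CenterCrop cuts the final
        # sample_size, and Normalize(mean=std=0.5) maps [0, 1] inputs to [-1, 1].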

    def get_batch(self, idx):
        video_dict = self.dataset[idx]
        videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']

        video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
        video_reader = VideoReader(video_dir)
        video_length = len(video_reader)

        if not self.is_image:
            clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
            start_idx = random.randint(0, video_length - clip_length)
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
        else:
            batch_index = [random.randint(0, video_length - 1)]
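        # Frame sampling: clip_length spans sample_n_frames frames at sample_stride
        # (capped at the video length), the clip start is drawn uniformly, and
        # np.linspace picks sample_n_frames evenly spaced indices inside the clip.
        # In image mode a single random frame is used instead.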

        # (f, h, w, c) uint8 frames -> (f, c, h, w) float tensor in [0, 1]
        pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.
        del video_reader

        if self.is_image:
            pixel_values = pixel_values[0]

        return pixel_values, name

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # A corrupted or unreadable video raises inside get_batch; fall back to a
        # random index and retry until a sample can be decoded.
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break
            except Exception:
                idx = random.randint(0, self.length - 1)

        pixel_values = self.pixel_transforms(pixel_values)
        sample = dict(pixel_values=pixel_values, text=name)
        return sample
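

# Minimal sketch of how a training loop might consume this dataset (illustrative
# only: the batch size, worker count, and (b, c, f, h, w) layout are assumptions,
# not part of this file):
#
#   dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=8, shuffle=True)
#   for batch in dataloader:
#       # pixel_values arrives as (b, f, c, h, w) in [-1, 1]; text is a list of captions
#       video = rearrange(batch["pixel_values"], "b f c h w -> b c f h w")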


if __name__ == "__main__":
    from animatediff.utils.util import save_videos_grid

    dataset = WebVid10M(
        csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
        video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
        sample_size=256,
        sample_stride=4, sample_n_frames=16,
        is_image=True,
    )
    import pdb
    pdb.set_trace()

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16)
    for idx, batch in enumerate(dataloader):
        print(batch["pixel_values"].shape, len(batch["text"]))
        # for i in range(batch["pixel_values"].shape[0]):
        #     save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True)
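        # With is_image=False, batch["pixel_values"] would instead have shape
        # (B, F, C, H, W); the commented-out save_videos_grid call above assumes
        # that layout, permuting to (B, C, F, H, W) before saving.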