1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
- import os, io, csv, math, random
- import numpy as np
- from einops import rearrange
- from decord import VideoReader
- import torch
- import torchvision.transforms as transforms
- from torch.utils.data.dataset import Dataset
- from animatediff.utils.util import zero_rank_print
class WebVid10M(Dataset):
    """Video/caption dataset backed by a CSV annotation file and a folder of mp4s.

    Every CSV row describes one video. The file ``<video_folder>/<videoid>.mp4``
    is opened with decord and either a strided clip of ``sample_n_frames``
    frames or, in image mode, one random frame is decoded. Samples are
    returned as ``{"pixel_values": tensor, "text": caption}``.

    Args:
        csv_path: Path of the annotation CSV. Its header must contain the
            ``videoid`` and ``name`` columns.
        video_folder: Directory holding the ``<videoid>.mp4`` files.
        sample_size: Output size, either an int (square) or a 2-item sequence
            passed to ``CenterCrop``; its first item is also the ``Resize``
            target.
        sample_stride: Nominal distance, in frames, between sampled frames.
        sample_n_frames: Number of frames per clip (ignored in image mode).
        is_image: When true, each sample is a single random frame.

    Raises:
        ValueError: If the CSV header lacks a required column.
    """

    # Columns that ``get_batch`` reads from every annotation row.
    _REQUIRED_COLUMNS = ("videoid", "name")

    def __init__(
        self,
        csv_path, video_folder,
        sample_size=256, sample_stride=4, sample_n_frames=16,
        is_image=False,
    ):
        super().__init__()

        zero_rank_print(f"loading annotations from {csv_path} ...")
        self.dataset = self._read_annotations(csv_path)
        self.length = len(self.dataset)
        zero_rank_print(f"data scale: {self.length}")

        self.video_folder = video_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.is_image = is_image

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            # Maps values from [0, 1] to [-1, 1].
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    @classmethod
    def _read_annotations(cls, csv_path):
        """Return all rows of *csv_path* as dicts after validating its header.

        Failing here, once, is deliberate: a missing column would otherwise
        raise ``KeyError`` for every row inside ``__getitem__``, whose retry
        loop would swallow it and spin forever.
        """
        # newline="" is what the csv module expects; it keeps line breaks that
        # appear inside quoted captions intact.
        with open(csv_path, "r", newline="") as csvfile:
            reader = csv.DictReader(csvfile)
            header = reader.fieldnames or []
            missing = [column for column in cls._REQUIRED_COLUMNS if column not in header]
            if missing:
                raise ValueError(
                    f"{csv_path} is missing required column(s): {', '.join(missing)}"
                )
            return list(reader)

    def get_batch(self, idx):
        """Decode the frames of annotation row *idx*.

        Returns:
            ``(pixel_values, name)`` where ``pixel_values`` is a float tensor
            scaled to [0, 1]. The decoded batch is permuted with
            ``(0, 3, 1, 2)``, i.e. channels-first per frame; in image mode the
            leading frame axis is dropped.

        Raises:
            ValueError: If the video contains no frames.
            Exception: Whatever decord raises for unreadable/missing files.
        """
        video_dict = self.dataset[idx]
        videoid, name = video_dict['videoid'], video_dict['name']

        video_path = os.path.join(self.video_folder, f"{videoid}.mp4")
        video_reader = VideoReader(video_path)
        video_length = len(video_reader)
        if video_length == 0:
            # Fail explicitly instead of relying on a downstream index error.
            raise ValueError(f"video has no frames: {video_path}")

        if not self.is_image:
            # Span of the clip in source frames, clamped to the video length;
            # short videos are covered end to end with a smaller stride.
            clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
            start_idx = random.randint(0, video_length - clip_length)
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
        else:
            batch_index = [random.randint(0, video_length - 1)]

        frames = video_reader.get_batch(batch_index).asnumpy()
        pixel_values = torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.
        # Drop the reader as soon as the frames are decoded.
        del video_reader

        if self.is_image:
            pixel_values = pixel_values[0]

        return pixel_values, name

    def __len__(self):
        """Return the number of annotation rows."""
        return self.length

    def __getitem__(self, idx):
        """Return the transformed sample for *idx*, substituting a random row on failure.

        Any exception raised while decoding makes the loader fall back to a
        different, randomly chosen row, so a few broken videos do not stop
        training.
        """
        # NOTE(review): retries are unbounded and silent; a systemic failure
        # (e.g. a wrong ``video_folder``) therefore loops forever. Consider
        # logging or bounding consecutive failures.
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
            except Exception:
                idx = random.randint(0, self.length - 1)
            else:
                break

        pixel_values = self.pixel_transforms(pixel_values)
        return dict(pixel_values=pixel_values, text=name)
if __name__ == "__main__":
    # Smoke test: build the dataset in single-image mode, iterate it through a
    # DataLoader, and print each batch's tensor shape and caption count.
    dataset = WebVid10M(
        csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
        video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
        sample_size=256,
        sample_stride=4, sample_n_frames=16,
        is_image=True,
    )

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16)
    # To inspect clips visually, construct the dataset with is_image=False and
    # write each sample out with animatediff.utils.util.save_videos_grid.
    for batch in dataloader:
        print(batch["pixel_values"].shape, len(batch["text"]))
|