Browse Source

training script

Yuwei Guo 1 year ago
parent
commit
e816747d66
8 changed files with 744 additions and 1 deletions
  1. 2 0
      .gitignore
  2. 30 1
      README.md
  3. 98 0
      animatediff/data/dataset.py
  4. 5 0
      animatediff/utils/util.py
  5. 48 0
      configs/training/image_finetune.yaml
  6. 66 0
      configs/training/training.yaml
  7. 2 0
      environment.yaml
  8. 493 0
      train.py

+ 2 - 0
.gitignore

@@ -1,4 +1,6 @@
 samples/
+wandb/
+outputs/
 __pycache__/
 models/StableDiffusion/stable-diffusion-v1-5
 scripts/animate_inter.py

+ 30 - 1
README.md

@@ -63,7 +63,7 @@ Contributions are always welcome!! The <code>dev</code> branch is for community
 </details>
 
 
-## Setup for Inference
+## Setups for Inference
 
 ### Prepare Environment
 
@@ -139,6 +139,35 @@ Then run the following commands:
 python -m scripts.animate --config [path to the config file]
 ```
 
+
+## Steps for Training
+
+### Dataset
+Before training, download the videos files and the `.csv` annotations of [WebVid10M](https://maxbain.com/webvid-dataset/) to the local mechine.
+Note that our examplar training script requires all the videos to be saved in a single folder. You may change this by modifying `animatediff/data/dataset.py`.
+
+### Configuration
+After dataset preparations, update the below data paths in the config `.yaml` files in `configs/training/` folder:
+```
+train_data:
+  csv_path:     [Replace with .csv Annotation File Path]
+  video_folder: [Replace with Video Folder Path]
+  sample_size:  256
+```
+Other training parameters (lr, epochs, validation settings, etc.) are also included in the config files.
+
+### Training
+To train motion modules
+```
+torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/training.yaml
+```
+
+To finetune the unet's image layers
+```
+torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/image_finetune.yaml
+```
+
+
 ## Gradio Demo
 We have created a Gradio demo to make AnimateDiff easier to use. To launch the demo, please run the following commands:
 ```

+ 98 - 0
animatediff/data/dataset.py

@@ -0,0 +1,98 @@
+import os, io, csv, math, random
+import numpy as np
+from einops import rearrange
+from decord import VideoReader
+
+import torch
+import torchvision.transforms as transforms
+from torch.utils.data.dataset import Dataset
+from animatediff.utils.util import zero_rank_print
+
+
+
+class WebVid10M(Dataset):
+    def __init__(
+            self,
+            csv_path, video_folder,
+            sample_size=256, sample_stride=4, sample_n_frames=16,
+            is_image=False,
+        ):
+        zero_rank_print(f"loading annotations from {csv_path} ...")
+        with open(csv_path, 'r') as csvfile:
+            self.dataset = list(csv.DictReader(csvfile))
+        self.length = len(self.dataset)
+        zero_rank_print(f"data scale: {self.length}")
+
+        self.video_folder    = video_folder
+        self.sample_stride   = sample_stride
+        self.sample_n_frames = sample_n_frames
+        self.is_image        = is_image
+        
+        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
+        self.pixel_transforms = transforms.Compose([
+            transforms.RandomHorizontalFlip(),
+            transforms.Resize(sample_size[0]),
+            transforms.CenterCrop(sample_size),
+            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+        ])
+    
+    def get_batch(self, idx):
+        video_dict = self.dataset[idx]
+        videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
+        
+        video_dir    = os.path.join(self.video_folder, f"{videoid}.mp4")
+        video_reader = VideoReader(video_dir)
+        video_length = len(video_reader)
+        
+        if not self.is_image:
+            clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
+            start_idx   = random.randint(0, video_length - clip_length)
+            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
+        else:
+            batch_index = [random.randint(0, video_length - 1)]
+
+        pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
+        pixel_values = pixel_values / 255.
+        del video_reader
+
+        if self.is_image:
+            pixel_values = pixel_values[0]
+        
+        return pixel_values, name
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, idx):
+        while True:
+            try:
+                pixel_values, name = self.get_batch(idx)
+                break
+
+            except Exception as e:
+                idx = random.randint(0, self.length-1)
+
+        pixel_values = self.pixel_transforms(pixel_values)
+        sample = dict(pixel_values=pixel_values, text=name)
+        return sample
+
+
+
+if __name__ == "__main__":
+    from animatediff.utils.util import save_videos_grid
+
+    dataset = WebVid10M(
+        csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
+        video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
+        sample_size=256,
+        sample_stride=4, sample_n_frames=16,
+        is_image=True,
+    )
+    import pdb
+    pdb.set_trace()
+    
+    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16,)
+    for idx, batch in enumerate(dataloader):
+        print(batch["pixel_values"].shape, len(batch["text"]))
+        # for i in range(batch["pixel_values"].shape[0]):
+        #     save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True)

+ 5 - 0
animatediff/utils/util.py

@@ -5,11 +5,16 @@ from typing import Union
 
 import torch
 import torchvision
+import torch.distributed as dist
 
 from tqdm import tqdm
 from einops import rearrange
 
 
+def zero_rank_print(s):
+    if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
+
+
 def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
     videos = rearrange(videos, "b c t h w -> t b c h w")
     outputs = []

+ 48 - 0
configs/training/image_finetune.yaml

@@ -0,0 +1,48 @@
+image_finetune: true
+
+output_dir: "outputs"
+pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5"
+
+noise_scheduler_kwargs:
+  num_train_timesteps: 1000
+  beta_start:          0.00085
+  beta_end:            0.012
+  beta_schedule:       "scaled_linear"
+  steps_offset:        1
+  clip_sample:         false
+
+train_data:
+  csv_path:     "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv"
+  video_folder: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val"
+  sample_size:  256
+
+validation_data:
+  prompts:
+    - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons."
+    - "A drone view of celebration with Christma tree and fireworks, starry sky - background."
+    - "Robot dancing in times square."
+    - "Pacific coast, carmel by the sea ocean and waves."
+  num_inference_steps: 25
+  guidance_scale: 8.
+
+trainable_modules:
+  - "."
+
+unet_checkpoint_path: ""
+
+learning_rate:    1.e-5
+train_batch_size: 50
+
+max_train_epoch:      -1
+max_train_steps:      100
+checkpointing_epochs: -1
+checkpointing_steps:  60
+
+validation_steps:       5000
+validation_steps_tuple: [2, 50]
+
+global_seed: 42
+mixed_precision_training: true
+enable_xformers_memory_efficient_attention: True
+
+is_debug: False

+ 66 - 0
configs/training/training.yaml

@@ -0,0 +1,66 @@
+image_finetune: false
+
+output_dir: "outputs"
+pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5"
+
+unet_additional_kwargs:
+  use_motion_module              : true
+  motion_module_resolutions      : [ 1,2,4,8 ]
+  unet_use_cross_frame_attention : false
+  unet_use_temporal_attention    : false
+
+  motion_module_type: Vanilla
+  motion_module_kwargs:
+    num_attention_heads                : 8
+    num_transformer_block              : 1
+    attention_block_types              : [ "Temporal_Self", "Temporal_Self" ]
+    temporal_position_encoding         : true
+    temporal_position_encoding_max_len : 24
+    temporal_attention_dim_div         : 1
+    zero_initialize                    : true
+
+noise_scheduler_kwargs:
+  num_train_timesteps: 1000
+  beta_start:          0.00085
+  beta_end:            0.012
+  beta_schedule:       "linear"
+  steps_offset:        1
+  clip_sample:         false
+
+train_data:
+  csv_path:        "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv"
+  video_folder:    "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val"
+  sample_size:     256
+  sample_stride:   4
+  sample_n_frames: 16
+
+validation_data:
+  prompts:
+    - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons."
+    - "A drone view of celebration with Christma tree and fireworks, starry sky - background."
+    - "Robot dancing in times square."
+    - "Pacific coast, carmel by the sea ocean and waves."
+  num_inference_steps: 25
+  guidance_scale: 8.
+
+trainable_modules:
+  - "motion_modules."
+
+unet_checkpoint_path: ""
+
+learning_rate:    1.e-4
+train_batch_size: 4
+
+max_train_epoch:      -1
+max_train_steps:      100
+checkpointing_epochs: -1
+checkpointing_steps:  60
+
+validation_steps:       5000
+validation_steps_tuple: [2, 50]
+
+global_seed: 42
+mixed_precision_training: true
+enable_xformers_memory_efficient_attention: True
+
+is_debug: False

+ 2 - 0
environment.yaml

@@ -14,8 +14,10 @@ dependencies:
     - transformers==4.25.1
     - xformers==0.0.16
     - imageio==2.27.0
+    - decord==0.6.0
     - gdown
     - einops
     - omegaconf
     - safetensors
     - gradio
+    - wandb

+ 493 - 0
train.py

@@ -0,0 +1,493 @@
+import os
+import math
+import wandb
+import random
+import logging
+import inspect
+import argparse
+import datetime
+import subprocess
+
+from pathlib import Path
+from tqdm.auto import tqdm
+from einops import rearrange
+from omegaconf import OmegaConf
+from safetensors import safe_open
+from typing import Dict, Optional, Tuple
+
+import torch
+import torchvision
+import torch.nn.functional as F
+import torch.distributed as dist
+from torch.optim.swa_utils import AveragedModel
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+import diffusers
+from diffusers import AutoencoderKL, DDIMScheduler
+from diffusers.models import UNet2DConditionModel
+from diffusers.pipelines import StableDiffusionPipeline
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version
+from diffusers.utils.import_utils import is_xformers_available
+
+import transformers
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from animatediff.data.dataset import WebVid10M
+from animatediff.models.unet import UNet3DConditionModel
+from animatediff.pipelines.pipeline_animation import AnimationPipeline
+from animatediff.utils.util import save_videos_grid, zero_rank_print
+
+
+
+def init_dist(launcher="slurm", backend='nccl', port=29500, **kwargs):
+    """Initializes distributed environment."""
+    if launcher == 'pytorch':
+        rank = int(os.environ['RANK'])
+        num_gpus = torch.cuda.device_count()
+        local_rank = rank % num_gpus
+        torch.cuda.set_device(local_rank)
+        dist.init_process_group(backend=backend, **kwargs)
+        
+    elif launcher == 'slurm':
+        proc_id = int(os.environ['SLURM_PROCID'])
+        ntasks = int(os.environ['SLURM_NTASKS'])
+        node_list = os.environ['SLURM_NODELIST']
+        num_gpus = torch.cuda.device_count()
+        local_rank = proc_id % num_gpus
+        torch.cuda.set_device(local_rank)
+        addr = subprocess.getoutput(
+            f'scontrol show hostname {node_list} | head -n1')
+        os.environ['MASTER_ADDR'] = addr
+        os.environ['WORLD_SIZE'] = str(ntasks)
+        os.environ['RANK'] = str(proc_id)
+        port = os.environ.get('PORT', port)
+        os.environ['MASTER_PORT'] = str(port)
+        dist.init_process_group(backend=backend)
+        zero_rank_print(f"proc_id: {proc_id}; local_rank: {local_rank}; ntasks: {ntasks}; node_list: {node_list}; num_gpus: {num_gpus}; addr: {addr}; port: {port}")
+        
+    else:
+        raise NotImplementedError(f'Not implemented launcher type: `{launcher}`!')
+    
+    return local_rank
+
+
+
+def main(
+    image_finetune: bool,
+    
+    name: str,
+    use_wandb: bool,
+    launcher: str,
+    
+    output_dir: str,
+    pretrained_model_path: str,
+
+    train_data: Dict,
+    validation_data: Dict,
+    cfg_random_null_text: bool = True,
+    cfg_random_null_text_ratio: float = 0.1,
+    
+    unet_checkpoint_path: str = "",
+    unet_additional_kwargs: Dict = {},
+    ema_decay: float = 0.9999,
+    noise_scheduler_kwargs = None,
+    
+    max_train_epoch: int = -1,
+    max_train_steps: int = 100,
+    validation_steps: int = 100,
+    validation_steps_tuple: Tuple = (-1,),
+
+    learning_rate: float = 3e-5,
+    scale_lr: bool = False,
+    lr_warmup_steps: int = 0,
+    lr_scheduler: str = "constant",
+
+    trainable_modules: Tuple[str] = (None, ),
+    num_workers: int = 32,
+    train_batch_size: int = 1,
+    adam_beta1: float = 0.9,
+    adam_beta2: float = 0.999,
+    adam_weight_decay: float = 1e-2,
+    adam_epsilon: float = 1e-08,
+    max_grad_norm: float = 1.0,
+    gradient_accumulation_steps: int = 1,
+    gradient_checkpointing: bool = False,
+    checkpointing_epochs: int = 5,
+    checkpointing_steps: int = -1,
+
+    mixed_precision_training: bool = True,
+    enable_xformers_memory_efficient_attention: bool = True,
+
+    global_seed: int = 42,
+    is_debug: bool = False,
+):
+    check_min_version("0.10.0.dev0")
+
+    # Initialize distributed training
+    local_rank      = init_dist(launcher=launcher)
+    global_rank     = dist.get_rank()
+    num_processes   = dist.get_world_size()
+    is_main_process = global_rank == 0
+
+    seed = global_seed + global_rank
+    torch.manual_seed(seed)
+    
+    # Logging folder
+    folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S")
+    output_dir = os.path.join(output_dir, folder_name)
+    if is_debug and os.path.exists(output_dir):
+        os.system(f"rm -rf {output_dir}")
+
+    *_, config = inspect.getargvalues(inspect.currentframe())
+
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+
+    if is_main_process and (not is_debug) and use_wandb:
+        run = wandb.init(project="animatediff", name=folder_name, config=config)
+
+    # Handle the output folder creation
+    if is_main_process:
+        os.makedirs(output_dir, exist_ok=True)
+        os.makedirs(f"{output_dir}/samples", exist_ok=True)
+        os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
+        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
+        OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
+
+    # Load scheduler, tokenizer and models.
+    noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
+
+    vae          = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
+    tokenizer    = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
+    if not image_finetune:
+        unet = UNet3DConditionModel.from_pretrained_2d(
+            pretrained_model_path, subfolder="unet", 
+            unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs)
+        )
+    else:
+        unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")
+        
+    # Load pretrained unet weights
+    if unet_checkpoint_path != "":
+        zero_rank_print(f"from checkpoint: {unet_checkpoint_path}")
+        unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
+        if "global_step" in unet_checkpoint_path: zero_rank_print(f"global_step: {unet_checkpoint_path['global_step']}")
+        state_dict = unet_checkpoint_path["state_dict"] if "state_dict" in unet_checkpoint_path else unet_checkpoint_path
+
+        m, u = unet.load_state_dict(state_dict, strict=False)
+        zero_rank_print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+        assert len(u) == 0
+        
+    # Freeze vae and text_encoder
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+    
+    # Set unet trainable parameters
+    unet.requires_grad_(False)
+    for name, param in unet.named_parameters():
+        for trainable_module_name in trainable_modules:
+            if trainable_module_name in name:
+                param.requires_grad = True
+                break
+            
+    trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters()))
+    optimizer = torch.optim.AdamW(
+        trainable_params,
+        lr=learning_rate,
+        betas=(adam_beta1, adam_beta2),
+        weight_decay=adam_weight_decay,
+        eps=adam_epsilon,
+    )
+
+    if is_main_process:
+        zero_rank_print(f"trainable params number: {len(trainable_params)}")
+        zero_rank_print(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")
+
+    # Enable xformers
+    if enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
+            unet.enable_xformers_memory_efficient_attention()
+        else:
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+    # Enable gradient checkpointing
+    if gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
+    # Move models to GPU
+    vae.to(local_rank)
+    text_encoder.to(local_rank)
+
+    # Get the training dataset
+    train_dataset = WebVid10M(**train_data, is_image=image_finetune)
+    distributed_sampler = DistributedSampler(
+        train_dataset,
+        num_replicas=num_processes,
+        rank=global_rank,
+        shuffle=True,
+        seed=global_seed,
+    )
+
+    # DataLoaders creation:
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=train_batch_size,
+        shuffle=False,
+        sampler=distributed_sampler,
+        num_workers=num_workers,
+        pin_memory=True,
+        drop_last=True,
+    )
+
+    # Get the training iteration
+    if max_train_steps == -1:
+        assert max_train_epoch != -1
+        max_train_steps = max_train_epoch * len(train_dataloader)
+        
+    if checkpointing_steps == -1:
+        assert checkpointing_epochs != -1
+        checkpointing_steps = checkpointing_epochs * len(train_dataloader)
+
+    if scale_lr:
+        learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size * num_processes)
+
+    # Scheduler
+    lr_scheduler = get_scheduler(
+        lr_scheduler,
+        optimizer=optimizer,
+        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
+        num_training_steps=max_train_steps * gradient_accumulation_steps,
+    )
+
+    # Validation pipeline
+    if not image_finetune:
+        validation_pipeline = AnimationPipeline(
+            unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler,
+        ).to("cuda")
+    else:
+        validation_pipeline = StableDiffusionPipeline.from_pretrained(
+            pretrained_model_path,
+            unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, safety_checker=None,
+        )
+    validation_pipeline.enable_vae_slicing()
+
+    # DDP warpper
+    unet.to(local_rank)
+    unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)
+
+    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
+    # Afterwards we recalculate our number of training epochs
+    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
+
+    # Train!
+    total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps
+
+    if is_main_process:
+        logging.info("***** Running training *****")
+        logging.info(f"  Num examples = {len(train_dataset)}")
+        logging.info(f"  Num Epochs = {num_train_epochs}")
+        logging.info(f"  Instantaneous batch size per device = {train_batch_size}")
+        logging.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+        logging.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
+        logging.info(f"  Total optimization steps = {max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    # Only show the progress bar once on each machine.
+    progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process)
+    progress_bar.set_description("Steps")
+
+    # Support mixed-precision training
+    scaler = torch.cuda.amp.GradScaler() if mixed_precision_training else None
+
+    for epoch in range(first_epoch, num_train_epochs):
+        train_dataloader.sampler.set_epoch(epoch)
+        unet.train()
+        
+        for step, batch in enumerate(train_dataloader):
+            if cfg_random_null_text:
+                batch['text'] = [name if random.random() > cfg_random_null_text_ratio else "" for name in batch['text']]
+                
+            # Data batch sanity check
+            if epoch == first_epoch and step == 0:
+                pixel_values, texts = batch['pixel_values'].cpu(), batch['text']
+                if not image_finetune:
+                    pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
+                    for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
+                        pixel_value = pixel_value[None, ...]
+                        save_videos_grid(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.gif", rescale=True)
+                else:
+                    for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
+                        pixel_value = pixel_value / 2. + 0.5
+                        torchvision.utils.save_image(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.png")
+                    
+            ### >>>> Training >>>> ###
+            
+            # Convert videos to latent space            
+            pixel_values = batch["pixel_values"].to(local_rank)
+            video_length = pixel_values.shape[1]
+            with torch.no_grad():
+                if not image_finetune:
+                    pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
+                    latents = vae.encode(pixel_values).latent_dist
+                    latents = latents.sample()
+                    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
+                else:
+                    latents = vae.encode(pixel_values).latent_dist
+                    latents = latents.sample()
+
+                latents = latents * 0.18215
+
+            # Sample noise that we'll add to the latents
+            noise = torch.randn_like(latents)
+            bsz = latents.shape[0]
+            
+            # Sample a random timestep for each video
+            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+            timesteps = timesteps.long()
+            
+            # Add noise to the latents according to the noise magnitude at each timestep
+            # (this is the forward diffusion process)
+            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+            
+            # Get the text embedding for conditioning
+            with torch.no_grad():
+                prompt_ids = tokenizer(
+                    batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+                ).input_ids.to(latents.device)
+                encoder_hidden_states = text_encoder(prompt_ids)[0]
+                
+            # Get the target for loss depending on the prediction type
+            if noise_scheduler.config.prediction_type == "epsilon":
+                target = noise
+            elif noise_scheduler.config.prediction_type == "v_prediction":
+                raise NotImplementedError
+            else:
+                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+            # Predict the noise residual and compute loss
+            # Mixed-precision training
+            with torch.cuda.amp.autocast(enabled=mixed_precision_training):
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+            optimizer.zero_grad()
+
+            # Backpropagate
+            if mixed_precision_training:
+                scaler.scale(loss).backward()
+                """ >>> gradient clipping >>> """
+                scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
+                """ <<< gradient clipping <<< """
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                loss.backward()
+                """ >>> gradient clipping >>> """
+                torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
+                """ <<< gradient clipping <<< """
+                optimizer.step()
+
+            lr_scheduler.step()
+            progress_bar.update(1)
+            global_step += 1
+            
+            ### <<<< Training <<<< ###
+            
+            # Wandb logging
+            if is_main_process and (not is_debug) and use_wandb:
+                wandb.log({"train_loss": loss.item()}, step=global_step)
+                
+            # Save checkpoint
+            if is_main_process and (global_step % checkpointing_steps == 0 or step == len(train_dataloader) - 1):
+                save_path = os.path.join(output_dir, f"checkpoints")
+                state_dict = {
+                    "epoch": epoch,
+                    "global_step": global_step,
+                    "state_dict": unet.state_dict(),
+                }
+                if step == len(train_dataloader) - 1:
+                    torch.save(state_dict, os.path.join(save_path, f"checkpoint-epoch-{epoch+1}.ckpt"))
+                else:
+                    torch.save(state_dict, os.path.join(save_path, f"checkpoint.ckpt"))
+                logging.info(f"Saved state to {save_path} (global_step: {global_step})")
+                
+            # Periodically validation
+            if is_main_process and (global_step % validation_steps == 0 or global_step in validation_steps_tuple):
+                samples = []
+                
+                generator = torch.Generator(device=latents.device)
+                generator.manual_seed(global_seed)
+                
+                height = train_data.sample_size[0] if not isinstance(train_data.sample_size, int) else train_data.sample_size
+                width  = train_data.sample_size[1] if not isinstance(train_data.sample_size, int) else train_data.sample_size
+
+                prompts = validation_data.prompts[:2] if global_step < 1000 and (not image_finetune) else validation_data.prompts
+
+                for idx, prompt in enumerate(prompts):
+                    if not image_finetune:
+                        sample = validation_pipeline(
+                            prompt,
+                            generator    = generator,
+                            video_length = train_data.sample_n_frames,
+                            height       = height,
+                            width        = width,
+                            **validation_data,
+                        ).videos
+                        save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif")
+                        samples.append(sample)
+                        
+                    else:
+                        sample = validation_pipeline(
+                            prompt,
+                            generator           = generator,
+                            height              = height,
+                            width               = width,
+                            num_inference_steps = validation_data.get("num_inference_steps", 25),
+                            guidance_scale      = validation_data.get("guidance_scale", 8.),
+                        ).images[0]
+                        sample = torchvision.transforms.functional.to_tensor(sample)
+                        samples.append(sample)
+                
+                if not image_finetune:
+                    samples = torch.concat(samples)
+                    save_path = f"{output_dir}/samples/sample-{global_step}.gif"
+                    save_videos_grid(samples, save_path)
+                    
+                else:
+                    samples = torch.stack(samples)
+                    save_path = f"{output_dir}/samples/sample-{global_step}.png"
+                    torchvision.utils.save_image(samples, save_path, nrow=4)
+
+                logging.info(f"Saved samples to {save_path}")
+                
+            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs)
+            
+            if global_step >= max_train_steps:
+                break
+            
+    dist.destroy_process_group()
+
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config",   type=str, required=True)
+    parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="pytorch")
+    parser.add_argument("--wandb",    action="store_true")
+    args = parser.parse_args()
+
+    name   = Path(args.config).stem
+    config = OmegaConf.load(args.config)
+
+    main(name=name, launcher=args.launcher, use_wandb=args.wandb, **config)