import os
import math
import wandb
import random
import logging
import inspect
import argparse
import datetime
import subprocess

from pathlib import Path
from tqdm.auto import tqdm
from einops import rearrange
from omegaconf import OmegaConf
from safetensors import safe_open
from typing import Dict, Optional, Tuple

import torch
import torchvision
import torch.nn.functional as F
import torch.distributed as dist
from torch.optim.swa_utils import AveragedModel
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

import diffusers
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.models import UNet2DConditionModel
from diffusers.pipelines import StableDiffusionPipeline
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available

import transformers
from transformers import CLIPTextModel, CLIPTokenizer

from animatediff.data.dataset import WebVid10M
from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid, zero_rank_print


def init_dist(launcher="slurm", backend='nccl', port=29500, **kwargs):
    """Initializes distributed environment."""
    if launcher == 'pytorch':
        rank = int(os.environ['RANK'])
        num_gpus = torch.cuda.device_count()
        local_rank = rank % num_gpus
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend=backend, **kwargs)

    elif launcher == 'slurm':
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        local_rank = proc_id % num_gpus
        torch.cuda.set_device(local_rank)
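        # Expand the SLURM node list, take the first host, and advertise it as the
        # rendezvous (master) address for torch.distributed.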
        addr = subprocess.getoutput(
            f'scontrol show hostname {node_list} | head -n1')
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        port = os.environ.get('PORT', port)
        os.environ['MASTER_PORT'] = str(port)
        dist.init_process_group(backend=backend)
        zero_rank_print(f"proc_id: {proc_id}; local_rank: {local_rank}; ntasks: {ntasks}; node_list: {node_list}; num_gpus: {num_gpus}; addr: {addr}; port: {port}")

    else:
        raise NotImplementedError(f'Not implemented launcher type: `{launcher}`!')

    return local_rank


def main(
    image_finetune: bool,

    name: str,
    use_wandb: bool,
    launcher: str,

    output_dir: str,
    pretrained_model_path: str,
    train_data: Dict,
    validation_data: Dict,
    cfg_random_null_text: bool = True,
    cfg_random_null_text_ratio: float = 0.1,

    unet_checkpoint_path: str = "",
    unet_additional_kwargs: Dict = {},
    ema_decay: float = 0.9999,
    noise_scheduler_kwargs = None,

    max_train_epoch: int = -1,
    max_train_steps: int = 100,
    validation_steps: int = 100,
    validation_steps_tuple: Tuple = (-1,),
    learning_rate: float = 3e-5,
    scale_lr: bool = False,
    lr_warmup_steps: int = 0,
    lr_scheduler: str = "constant",
    trainable_modules: Tuple[str] = (None, ),
    num_workers: int = 32,
    train_batch_size: int = 1,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_weight_decay: float = 1e-2,
    adam_epsilon: float = 1e-08,
    max_grad_norm: float = 1.0,
    gradient_accumulation_steps: int = 1,
    gradient_checkpointing: bool = False,
    checkpointing_epochs: int = 5,
    checkpointing_steps: int = -1,
    mixed_precision_training: bool = True,
    enable_xformers_memory_efficient_attention: bool = True,
    global_seed: int = 42,
    is_debug: bool = False,
):
    check_min_version("0.10.0.dev0")

    # Initialize distributed training
    local_rank = init_dist(launcher=launcher)
    global_rank = dist.get_rank()
    num_processes = dist.get_world_size()
    is_main_process = global_rank == 0
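    # Offset the seed by rank so different processes draw different random noise.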
    seed = global_seed + global_rank
    torch.manual_seed(seed)

    # Logging folder
    folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S")
    output_dir = os.path.join(output_dir, folder_name)
    if is_debug and os.path.exists(output_dir):
        os.system(f"rm -rf {output_dir}")

    *_, config = inspect.getargvalues(inspect.currentframe())

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    if is_main_process and (not is_debug) and use_wandb:
        run = wandb.init(project="animatediff", name=folder_name, config=config)

    # Handle the output folder creation
    if is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/samples", exist_ok=True)
        os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
        OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))

    # Load scheduler, tokenizer and models.
    noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))

    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
    if not image_finetune:
        unet = UNet3DConditionModel.from_pretrained_2d(
            pretrained_model_path, subfolder="unet",
            unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs)
        )
    else:
        unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")

    # Load pretrained unet weights
    if unet_checkpoint_path != "":
        zero_rank_print(f"from checkpoint: {unet_checkpoint_path}")
        unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
        if "global_step" in unet_checkpoint_path: zero_rank_print(f"global_step: {unet_checkpoint_path['global_step']}")
        state_dict = unet_checkpoint_path["state_dict"] if "state_dict" in unet_checkpoint_path else unet_checkpoint_path

        m, u = unet.load_state_dict(state_dict, strict=False)
        zero_rank_print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
        assert len(u) == 0

    # Freeze vae and text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Set unet trainable parameters
    unet.requires_grad_(False)
    for name, param in unet.named_parameters():
        for trainable_module_name in trainable_modules:
            if trainable_module_name in name:
                param.requires_grad = True
                break
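
    # Hand only the parameters unfrozen above (those matching trainable_modules) to the optimizer.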
    trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters()))
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

    if is_main_process:
        zero_rank_print(f"trainable params number: {len(trainable_params)}")
        zero_rank_print(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")

    # Enable xformers
    if enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Enable gradient checkpointing
    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Move models to GPU
    vae.to(local_rank)
    text_encoder.to(local_rank)

    # Get the training dataset
    train_dataset = WebVid10M(**train_data, is_image=image_finetune)
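    # Shard the dataset across ranks; each process sees its own shuffled subset every epoch.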
    distributed_sampler = DistributedSampler(
        train_dataset,
        num_replicas=num_processes,
        rank=global_rank,
        shuffle=True,
        seed=global_seed,
    )

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=False,
        sampler=distributed_sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )

    # Get the training iteration
    if max_train_steps == -1:
        assert max_train_epoch != -1
        max_train_steps = max_train_epoch * len(train_dataloader)

    if checkpointing_steps == -1:
        assert checkpointing_epochs != -1
        checkpointing_steps = checkpointing_epochs * len(train_dataloader)

    if scale_lr:
        learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size * num_processes)

    # Scheduler
    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )

    # Validation pipeline
    if not image_finetune:
        validation_pipeline = AnimationPipeline(
            unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler,
        ).to("cuda")
    else:
        validation_pipeline = StableDiffusionPipeline.from_pretrained(
            pretrained_model_path,
            unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, safety_checker=None,
        )
    validation_pipeline.enable_vae_slicing()
    # DDP wrapper
    unet.to(local_rank)
    unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)
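    # Note: only the UNet is wrapped in DDP; the frozen VAE and text encoder have no gradients to synchronize.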
    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps

    if is_main_process:
        logging.info("***** Running training *****")
        logging.info(f"  Num examples = {len(train_dataset)}")
        logging.info(f"  Num Epochs = {num_train_epochs}")
        logging.info(f"  Instantaneous batch size per device = {train_batch_size}")
        logging.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logging.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
        logging.info(f"  Total optimization steps = {max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process)
    progress_bar.set_description("Steps")

    # Support mixed-precision training
    scaler = torch.cuda.amp.GradScaler() if mixed_precision_training else None

    for epoch in range(first_epoch, num_train_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        unet.train()

        for step, batch in enumerate(train_dataloader):
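            # Randomly replace a fraction (cfg_random_null_text_ratio) of the prompts with the
            # empty string so the model also learns the unconditional branch used for
            # classifier-free guidance at inference time.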
            if cfg_random_null_text:
                batch['text'] = [name if random.random() > cfg_random_null_text_ratio else "" for name in batch['text']]

            # Data batch sanity check
            if epoch == first_epoch and step == 0:
                pixel_values, texts = batch['pixel_values'].cpu(), batch['text']
                if not image_finetune:
                    pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
                    for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
                        pixel_value = pixel_value[None, ...]
                        save_videos_grid(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.gif", rescale=True)
                else:
                    for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
                        pixel_value = pixel_value / 2. + 0.5
                        torchvision.utils.save_image(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.png")

            ### >>>> Training >>>> ###

            # Convert videos to latent space
            pixel_values = batch["pixel_values"].to(local_rank)
            video_length = pixel_values.shape[1]
            with torch.no_grad():
                if not image_finetune:
                    pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
                    latents = vae.encode(pixel_values).latent_dist
                    latents = latents.sample()
                    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
                else:
                    latents = vae.encode(pixel_values).latent_dist
                    latents = latents.sample()
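
                # 0.18215 is the standard Stable Diffusion VAE latent scaling factor.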
                latents = latents * 0.18215
            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]

            # Sample a random timestep for each video
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            with torch.no_grad():
                prompt_ids = tokenizer(
                    batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
                ).input_ids.to(latents.device)
                encoder_hidden_states = text_encoder(prompt_ids)[0]

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                raise NotImplementedError
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Predict the noise residual and compute loss
            # Mixed-precision training
            with torch.cuda.amp.autocast(enabled=mixed_precision_training):
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            optimizer.zero_grad()

            # Backpropagate
            if mixed_precision_training:
                scaler.scale(loss).backward()
                """ >>> gradient clipping >>> """
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
                """ <<< gradient clipping <<< """
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                """ >>> gradient clipping >>> """
                torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
                """ <<< gradient clipping <<< """
                optimizer.step()

            lr_scheduler.step()
            progress_bar.update(1)
            global_step += 1

            ### <<<< Training <<<< ###

            # Wandb logging
            if is_main_process and (not is_debug) and use_wandb:
                wandb.log({"train_loss": loss.item()}, step=global_step)

            # Save checkpoint
            if is_main_process and (global_step % checkpointing_steps == 0 or step == len(train_dataloader) - 1):
                save_path = os.path.join(output_dir, f"checkpoints")
                state_dict = {
                    "epoch": epoch,
                    "global_step": global_step,
                    "state_dict": unet.state_dict(),
                }
                if step == len(train_dataloader) - 1:
                    torch.save(state_dict, os.path.join(save_path, f"checkpoint-epoch-{epoch+1}.ckpt"))
                else:
                    torch.save(state_dict, os.path.join(save_path, f"checkpoint.ckpt"))
                logging.info(f"Saved state to {save_path} (global_step: {global_step})")
            # Periodic validation
            if is_main_process and (global_step % validation_steps == 0 or global_step in validation_steps_tuple):
                samples = []

                generator = torch.Generator(device=latents.device)
                generator.manual_seed(global_seed)

                height = train_data.sample_size[0] if not isinstance(train_data.sample_size, int) else train_data.sample_size
                width = train_data.sample_size[1] if not isinstance(train_data.sample_size, int) else train_data.sample_size

                prompts = validation_data.prompts[:2] if global_step < 1000 and (not image_finetune) else validation_data.prompts

                for idx, prompt in enumerate(prompts):
                    if not image_finetune:
                        sample = validation_pipeline(
                            prompt,
                            generator=generator,
                            video_length=train_data.sample_n_frames,
                            height=height,
                            width=width,
                            **validation_data,
                        ).videos
                        save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif")
                        samples.append(sample)

                    else:
                        sample = validation_pipeline(
                            prompt,
                            generator=generator,
                            height=height,
                            width=width,
                            num_inference_steps=validation_data.get("num_inference_steps", 25),
                            guidance_scale=validation_data.get("guidance_scale", 8.),
                        ).images[0]
                        sample = torchvision.transforms.functional.to_tensor(sample)
                        samples.append(sample)

                if not image_finetune:
                    samples = torch.concat(samples)
                    save_path = f"{output_dir}/samples/sample-{global_step}.gif"
                    save_videos_grid(samples, save_path)

                else:
                    samples = torch.stack(samples)
                    save_path = f"{output_dir}/samples/sample-{global_step}.png"
                    torchvision.utils.save_image(samples, save_path, nrow=4)

                logging.info(f"Saved samples to {save_path}")
            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= max_train_steps:
                break

    dist.destroy_process_group()
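
# Example launch (a sketch, not taken from the repository): with the default "pytorch"
# launcher the script reads torchrun-style environment variables, so a single-node run
# could look like the following (the config path is a placeholder):
#   torchrun --nproc_per_node=8 train.py --config path/to/training_config.yaml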

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="pytorch")
    parser.add_argument("--wandb", action="store_true")
    args = parser.parse_args()

    name = Path(args.config).stem
    config = OmegaConf.load(args.config)

    main(name=name, launcher=args.launcher, use_wandb=args.wandb, **config)