train-checkpoint.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. import os
  2. import math
  3. import wandb
  4. import random
  5. import logging
  6. import inspect
  7. import argparse
  8. import datetime
  9. import subprocess
  10. from pathlib import Path
  11. from tqdm.auto import tqdm
  12. from einops import rearrange
  13. from omegaconf import OmegaConf
  14. from safetensors import safe_open
  15. from typing import Dict, Optional, Tuple
  16. import torch
  17. import torchvision
  18. import torch.nn.functional as F
  19. import torch.distributed as dist
  20. from torch.optim.swa_utils import AveragedModel
  21. from torch.utils.data.distributed import DistributedSampler
  22. from torch.nn.parallel import DistributedDataParallel as DDP
  23. import diffusers
  24. from diffusers import AutoencoderKL, DDIMScheduler
  25. from diffusers.models import UNet2DConditionModel
  26. from diffusers.pipelines import StableDiffusionPipeline
  27. from diffusers.optimization import get_scheduler
  28. from diffusers.utils import check_min_version
  29. from diffusers.utils.import_utils import is_xformers_available
  30. import transformers
  31. from transformers import CLIPTextModel, CLIPTokenizer
  32. from animatediff.data.dataset import WebVid10M
  33. from animatediff.models.unet import UNet3DConditionModel
  34. from animatediff.pipelines.pipeline_animation import AnimationPipeline
  35. from animatediff.utils.util import save_videos_grid, zero_rank_print
  36. def init_dist(launcher="slurm", backend='nccl', port=29500, **kwargs):
  37. """Initializes distributed environment."""
  38. if launcher == 'pytorch':
  39. rank = int(os.environ['RANK'])
  40. num_gpus = torch.cuda.device_count()
  41. local_rank = rank % num_gpus
  42. torch.cuda.set_device(local_rank)
  43. dist.init_process_group(backend=backend, **kwargs)
  44. elif launcher == 'slurm':
  45. proc_id = int(os.environ['SLURM_PROCID'])
  46. ntasks = int(os.environ['SLURM_NTASKS'])
  47. node_list = os.environ['SLURM_NODELIST']
  48. num_gpus = torch.cuda.device_count()
  49. local_rank = proc_id % num_gpus
  50. torch.cuda.set_device(local_rank)
  51. addr = subprocess.getoutput(
  52. f'scontrol show hostname {node_list} | head -n1')
  53. os.environ['MASTER_ADDR'] = addr
  54. os.environ['WORLD_SIZE'] = str(ntasks)
  55. os.environ['RANK'] = str(proc_id)
  56. port = os.environ.get('PORT', port)
  57. os.environ['MASTER_PORT'] = str(port)
  58. dist.init_process_group(backend=backend)
  59. zero_rank_print(f"proc_id: {proc_id}; local_rank: {local_rank}; ntasks: {ntasks}; node_list: {node_list}; num_gpus: {num_gpus}; addr: {addr}; port: {port}")
  60. else:
  61. raise NotImplementedError(f'Not implemented launcher type: `{launcher}`!')
  62. return local_rank
  63. def main(
  64. image_finetune: bool,
  65. name: str,
  66. use_wandb: bool,
  67. launcher: str,
  68. output_dir: str,
  69. pretrained_model_path: str,
  70. train_data: Dict,
  71. validation_data: Dict,
  72. cfg_random_null_text: bool = True,
  73. cfg_random_null_text_ratio: float = 0.1,
  74. unet_checkpoint_path: str = "",
  75. unet_additional_kwargs: Dict = {},
  76. ema_decay: float = 0.9999,
  77. noise_scheduler_kwargs = None,
  78. max_train_epoch: int = -1,
  79. max_train_steps: int = 100,
  80. validation_steps: int = 100,
  81. validation_steps_tuple: Tuple = (-1,),
  82. learning_rate: float = 3e-5,
  83. scale_lr: bool = False,
  84. lr_warmup_steps: int = 0,
  85. lr_scheduler: str = "constant",
  86. trainable_modules: Tuple[str] = (None, ),
  87. num_workers: int = 32,
  88. train_batch_size: int = 1,
  89. adam_beta1: float = 0.9,
  90. adam_beta2: float = 0.999,
  91. adam_weight_decay: float = 1e-2,
  92. adam_epsilon: float = 1e-08,
  93. max_grad_norm: float = 1.0,
  94. gradient_accumulation_steps: int = 1,
  95. gradient_checkpointing: bool = False,
  96. checkpointing_epochs: int = 5,
  97. checkpointing_steps: int = -1,
  98. mixed_precision_training: bool = True,
  99. enable_xformers_memory_efficient_attention: bool = True,
  100. global_seed: int = 42,
  101. is_debug: bool = False,
  102. ):
  103. check_min_version("0.10.0.dev0")
  104. # Initialize distributed training
  105. local_rank = init_dist(launcher=launcher)
  106. global_rank = dist.get_rank()
  107. num_processes = dist.get_world_size()
  108. is_main_process = global_rank == 0
  109. seed = global_seed + global_rank
  110. torch.manual_seed(seed)
  111. # Logging folder
  112. folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S")
  113. output_dir = os.path.join(output_dir, folder_name)
  114. if is_debug and os.path.exists(output_dir):
  115. os.system(f"rm -rf {output_dir}")
  116. *_, config = inspect.getargvalues(inspect.currentframe())
  117. # Make one log on every process with the configuration for debugging.
  118. logging.basicConfig(
  119. format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
  120. datefmt="%m/%d/%Y %H:%M:%S",
  121. level=logging.INFO,
  122. )
  123. if is_main_process and (not is_debug) and use_wandb:
  124. run = wandb.init(project="animatediff", name=folder_name, config=config)
  125. # Handle the output folder creation
  126. if is_main_process:
  127. os.makedirs(output_dir, exist_ok=True)
  128. os.makedirs(f"{output_dir}/samples", exist_ok=True)
  129. os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
  130. os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
  131. OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
  132. # Load scheduler, tokenizer and models.
  133. noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
  134. vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
  135. tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
  136. text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
  137. if not image_finetune:
  138. unet = UNet3DConditionModel.from_pretrained_2d(
  139. pretrained_model_path, subfolder="unet",
  140. unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs)
  141. )
  142. else:
  143. unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")
  144. # Load pretrained unet weights
  145. if unet_checkpoint_path != "":
  146. zero_rank_print(f"from checkpoint: {unet_checkpoint_path}")
  147. unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
  148. if "global_step" in unet_checkpoint_path: zero_rank_print(f"global_step: {unet_checkpoint_path['global_step']}")
  149. state_dict = unet_checkpoint_path["state_dict"] if "state_dict" in unet_checkpoint_path else unet_checkpoint_path
  150. m, u = unet.load_state_dict(state_dict, strict=False)
  151. zero_rank_print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
  152. assert len(u) == 0
  153. # Freeze vae and text_encoder
  154. vae.requires_grad_(False)
  155. text_encoder.requires_grad_(False)
  156. # Set unet trainable parameters
  157. unet.requires_grad_(False)
  158. for name, param in unet.named_parameters():
  159. for trainable_module_name in trainable_modules:
  160. if trainable_module_name in name:
  161. param.requires_grad = True
  162. break
  163. trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters()))
  164. optimizer = torch.optim.AdamW(
  165. trainable_params,
  166. lr=learning_rate,
  167. betas=(adam_beta1, adam_beta2),
  168. weight_decay=adam_weight_decay,
  169. eps=adam_epsilon,
  170. )
  171. if is_main_process:
  172. zero_rank_print(f"trainable params number: {len(trainable_params)}")
  173. zero_rank_print(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")
  174. # Enable xformers
  175. if enable_xformers_memory_efficient_attention:
  176. if is_xformers_available():
  177. unet.enable_xformers_memory_efficient_attention()
  178. else:
  179. raise ValueError("xformers is not available. Make sure it is installed correctly")
  180. # Enable gradient checkpointing
  181. if gradient_checkpointing:
  182. unet.enable_gradient_checkpointing()
  183. # Move models to GPU
  184. vae.to(local_rank)
  185. text_encoder.to(local_rank)
  186. # Get the training dataset
  187. train_dataset = WebVid10M(**train_data, is_image=image_finetune)
  188. distributed_sampler = DistributedSampler(
  189. train_dataset,
  190. num_replicas=num_processes,
  191. rank=global_rank,
  192. shuffle=True,
  193. seed=global_seed,
  194. )
  195. # DataLoaders creation:
  196. train_dataloader = torch.utils.data.DataLoader(
  197. train_dataset,
  198. batch_size=train_batch_size,
  199. shuffle=False,
  200. sampler=distributed_sampler,
  201. num_workers=num_workers,
  202. pin_memory=True,
  203. drop_last=True,
  204. )
  205. # Get the training iteration
  206. if max_train_steps == -1:
  207. assert max_train_epoch != -1
  208. max_train_steps = max_train_epoch * len(train_dataloader)
  209. if checkpointing_steps == -1:
  210. assert checkpointing_epochs != -1
  211. checkpointing_steps = checkpointing_epochs * len(train_dataloader)
  212. if scale_lr:
  213. learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size * num_processes)
  214. # Scheduler
  215. lr_scheduler = get_scheduler(
  216. lr_scheduler,
  217. optimizer=optimizer,
  218. num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
  219. num_training_steps=max_train_steps * gradient_accumulation_steps,
  220. )
  221. # Validation pipeline
  222. if not image_finetune:
  223. validation_pipeline = AnimationPipeline(
  224. unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler,
  225. ).to("cuda")
  226. else:
  227. validation_pipeline = StableDiffusionPipeline.from_pretrained(
  228. pretrained_model_path,
  229. unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, safety_checker=None,
  230. )
  231. validation_pipeline.enable_vae_slicing()
  232. # DDP warpper
  233. unet.to(local_rank)
  234. unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)
  235. # We need to recalculate our total training steps as the size of the training dataloader may have changed.
  236. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
  237. # Afterwards we recalculate our number of training epochs
  238. num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
  239. # Train!
  240. total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps
  241. if is_main_process:
  242. logging.info("***** Running training *****")
  243. logging.info(f" Num examples = {len(train_dataset)}")
  244. logging.info(f" Num Epochs = {num_train_epochs}")
  245. logging.info(f" Instantaneous batch size per device = {train_batch_size}")
  246. logging.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
  247. logging.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
  248. logging.info(f" Total optimization steps = {max_train_steps}")
  249. global_step = 0
  250. first_epoch = 0
  251. # Only show the progress bar once on each machine.
  252. progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process)
  253. progress_bar.set_description("Steps")
  254. # Support mixed-precision training
  255. scaler = torch.cuda.amp.GradScaler() if mixed_precision_training else None
  256. for epoch in range(first_epoch, num_train_epochs):
  257. train_dataloader.sampler.set_epoch(epoch)
  258. unet.train()
  259. for step, batch in enumerate(train_dataloader):
  260. if cfg_random_null_text:
  261. batch['text'] = [name if random.random() > cfg_random_null_text_ratio else "" for name in batch['text']]
  262. # Data batch sanity check
  263. if epoch == first_epoch and step == 0:
  264. pixel_values, texts = batch['pixel_values'].cpu(), batch['text']
  265. if not image_finetune:
  266. pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
  267. for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
  268. pixel_value = pixel_value[None, ...]
  269. save_videos_grid(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.gif", rescale=True)
  270. else:
  271. for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
  272. pixel_value = pixel_value / 2. + 0.5
  273. torchvision.utils.save_image(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.png")
  274. ### >>>> Training >>>> ###
  275. # Convert videos to latent space
  276. pixel_values = batch["pixel_values"].to(local_rank)
  277. video_length = pixel_values.shape[1]
  278. with torch.no_grad():
  279. if not image_finetune:
  280. pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
  281. latents = vae.encode(pixel_values).latent_dist
  282. latents = latents.sample()
  283. latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
  284. else:
  285. latents = vae.encode(pixel_values).latent_dist
  286. latents = latents.sample()
  287. latents = latents * 0.18215
  288. # Sample noise that we'll add to the latents
  289. noise = torch.randn_like(latents)
  290. bsz = latents.shape[0]
  291. # Sample a random timestep for each video
  292. timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
  293. timesteps = timesteps.long()
  294. # Add noise to the latents according to the noise magnitude at each timestep
  295. # (this is the forward diffusion process)
  296. noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
  297. # Get the text embedding for conditioning
  298. with torch.no_grad():
  299. prompt_ids = tokenizer(
  300. batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
  301. ).input_ids.to(latents.device)
  302. encoder_hidden_states = text_encoder(prompt_ids)[0]
  303. # Get the target for loss depending on the prediction type
  304. if noise_scheduler.config.prediction_type == "epsilon":
  305. target = noise
  306. elif noise_scheduler.config.prediction_type == "v_prediction":
  307. raise NotImplementedError
  308. else:
  309. raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
  310. # Predict the noise residual and compute loss
  311. # Mixed-precision training
  312. with torch.cuda.amp.autocast(enabled=mixed_precision_training):
  313. model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
  314. loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
  315. optimizer.zero_grad()
  316. # Backpropagate
  317. if mixed_precision_training:
  318. scaler.scale(loss).backward()
  319. """ >>> gradient clipping >>> """
  320. scaler.unscale_(optimizer)
  321. torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
  322. """ <<< gradient clipping <<< """
  323. scaler.step(optimizer)
  324. scaler.update()
  325. else:
  326. loss.backward()
  327. """ >>> gradient clipping >>> """
  328. torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
  329. """ <<< gradient clipping <<< """
  330. optimizer.step()
  331. lr_scheduler.step()
  332. progress_bar.update(1)
  333. global_step += 1
  334. ### <<<< Training <<<< ###
  335. # Wandb logging
  336. if is_main_process and (not is_debug) and use_wandb:
  337. wandb.log({"train_loss": loss.item()}, step=global_step)
  338. # Save checkpoint
  339. if is_main_process and (global_step % checkpointing_steps == 0 or step == len(train_dataloader) - 1):
  340. save_path = os.path.join(output_dir, f"checkpoints")
  341. state_dict = {
  342. "epoch": epoch,
  343. "global_step": global_step,
  344. "state_dict": unet.state_dict(),
  345. }
  346. if step == len(train_dataloader) - 1:
  347. torch.save(state_dict, os.path.join(save_path, f"checkpoint-epoch-{epoch+1}.ckpt"))
  348. else:
  349. torch.save(state_dict, os.path.join(save_path, f"checkpoint.ckpt"))
  350. logging.info(f"Saved state to {save_path} (global_step: {global_step})")
  351. # Periodically validation
  352. if is_main_process and (global_step % validation_steps == 0 or global_step in validation_steps_tuple):
  353. samples = []
  354. generator = torch.Generator(device=latents.device)
  355. generator.manual_seed(global_seed)
  356. height = train_data.sample_size[0] if not isinstance(train_data.sample_size, int) else train_data.sample_size
  357. width = train_data.sample_size[1] if not isinstance(train_data.sample_size, int) else train_data.sample_size
  358. prompts = validation_data.prompts[:2] if global_step < 1000 and (not image_finetune) else validation_data.prompts
  359. for idx, prompt in enumerate(prompts):
  360. if not image_finetune:
  361. sample = validation_pipeline(
  362. prompt,
  363. generator = generator,
  364. video_length = train_data.sample_n_frames,
  365. height = height,
  366. width = width,
  367. **validation_data,
  368. ).videos
  369. save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif")
  370. samples.append(sample)
  371. else:
  372. sample = validation_pipeline(
  373. prompt,
  374. generator = generator,
  375. height = height,
  376. width = width,
  377. num_inference_steps = validation_data.get("num_inference_steps", 25),
  378. guidance_scale = validation_data.get("guidance_scale", 8.),
  379. ).images[0]
  380. sample = torchvision.transforms.functional.to_tensor(sample)
  381. samples.append(sample)
  382. if not image_finetune:
  383. samples = torch.concat(samples)
  384. save_path = f"{output_dir}/samples/sample-{global_step}.gif"
  385. save_videos_grid(samples, save_path)
  386. else:
  387. samples = torch.stack(samples)
  388. save_path = f"{output_dir}/samples/sample-{global_step}.png"
  389. torchvision.utils.save_image(samples, save_path, nrow=4)
  390. logging.info(f"Saved samples to {save_path}")
  391. logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
  392. progress_bar.set_postfix(**logs)
  393. if global_step >= max_train_steps:
  394. break
  395. dist.destroy_process_group()
  396. if __name__ == "__main__":
  397. parser = argparse.ArgumentParser()
  398. parser.add_argument("--config", type=str, required=True)
  399. parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="pytorch")
  400. parser.add_argument("--wandb", action="store_true")
  401. args = parser.parse_args()
  402. name = Path(args.config).stem
  403. config = OmegaConf.load(args.config)
  404. main(name=name, launcher=args.launcher, use_wandb=args.wandb, **config)