# test_train.py
  1. import os
  2. import math
  3. import wandb
  4. import random
  5. import logging
  6. import inspect
  7. import argparse
  8. import datetime
  9. import subprocess
  10. from pathlib import Path
  11. from tqdm.auto import tqdm
  12. from einops import rearrange
  13. from omegaconf import OmegaConf
  14. from safetensors import safe_open
  15. from typing import Dict, Optional, Tuple
  16. import torch
  17. import torchvision
  18. import torch.nn.functional as F
  19. import torch.distributed as dist
  20. from torch.optim.swa_utils import AveragedModel
  21. from torch.utils.data.distributed import DistributedSampler
  22. from torch.nn.parallel import DistributedDataParallel as DDP
  23. import diffusers
  24. from diffusers import AutoencoderKL, DDIMScheduler
  25. from diffusers.models import UNet2DConditionModel
  26. from diffusers.pipelines import StableDiffusionPipeline
  27. from diffusers.optimization import get_scheduler
  28. from diffusers.utils import check_min_version
  29. from diffusers.utils.import_utils import is_xformers_available
  30. import transformers
  31. from transformers import CLIPTextModel, CLIPTokenizer
  32. from animatediff.data.dataset import WebVid10M
  33. from animatediff.models.unet import UNet3DConditionModel
  34. from animatediff.pipelines.pipeline_animation import AnimationPipeline
  35. from animatediff.utils.util import save_videos_grid, zero_rank_print
  36. from animatediff.models.lora import LoRAModule
  37. def init_dist(launcher="slurm", backend='nccl', port=29500, **kwargs):
  38. """Initializes distributed environment."""
  39. if launcher == 'pytorch':
  40. rank = int(os.environ['RANK'])
  41. num_gpus = torch.cuda.device_count()
  42. local_rank = rank % num_gpus
  43. torch.cuda.set_device(local_rank)
  44. dist.init_process_group(backend=backend, **kwargs)
  45. elif launcher == 'slurm':
  46. proc_id = int(os.environ['SLURM_PROCID'])
  47. ntasks = int(os.environ['SLURM_NTASKS'])
  48. node_list = os.environ['SLURM_NODELIST']
  49. num_gpus = torch.cuda.device_count()
  50. local_rank = proc_id % num_gpus
  51. torch.cuda.set_device(local_rank)
  52. addr = subprocess.getoutput(
  53. f'scontrol show hostname {node_list} | head -n1')
  54. os.environ['MASTER_ADDR'] = addr
  55. os.environ['WORLD_SIZE'] = str(ntasks)
  56. os.environ['RANK'] = str(proc_id)
  57. port = os.environ.get('PORT', port)
  58. os.environ['MASTER_PORT'] = str(port)
  59. dist.init_process_group(backend=backend)
  60. zero_rank_print(f"proc_id: {proc_id}; local_rank: {local_rank}; ntasks: {ntasks}; node_list: {node_list}; num_gpus: {num_gpus}; addr: {addr}; port: {port}")
  61. else:
  62. raise NotImplementedError(f'Not implemented launcher type: `{launcher}`!')
  63. return local_rank
  64. def main(
  65. image_finetune: bool,
  66. name: str,
  67. use_wandb: bool,
  68. launcher: str,
  69. output_dir: str,
  70. pretrained_model_path: str,
  71. train_data: Dict,
  72. validation_data: Dict,
  73. cfg_random_null_text: bool = True,
  74. cfg_random_null_text_ratio: float = 0.1,
  75. unet_checkpoint_path: str = "",
  76. unet_additional_kwargs: Dict = {},
  77. ema_decay: float = 0.9999,
  78. noise_scheduler_kwargs = None,
  79. max_train_epoch: int = -1,
  80. max_train_steps: int = 100,
  81. validation_steps: int = 100,
  82. validation_steps_tuple: Tuple = (-1,),
  83. learning_rate: float = 3e-5,
  84. scale_lr: bool = False,
  85. lr_warmup_steps: int = 0,
  86. lr_scheduler: str = "constant",
  87. trainable_modules: Tuple[str] = (None, ),
  88. num_workers: int = 32,
  89. train_batch_size: int = 1,
  90. adam_beta1: float = 0.9,
  91. adam_beta2: float = 0.999,
  92. adam_weight_decay: float = 1e-2,
  93. adam_epsilon: float = 1e-08,
  94. max_grad_norm: float = 1.0,
  95. gradient_accumulation_steps: int = 1,
  96. gradient_checkpointing: bool = False,
  97. checkpointing_epochs: int = 5,
  98. checkpointing_steps: int = -1,
  99. mixed_precision_training: bool = True,
  100. enable_xformers_memory_efficient_attention: bool = True,
  101. global_seed: int = 42,
  102. is_debug: bool = False,
  103. ):
  104. check_min_version("0.10.0.dev0")
  105. # Initialize distributed training
  106. local_rank = init_dist(launcher=launcher)
  107. global_rank = dist.get_rank()
  108. num_processes = dist.get_world_size()
  109. is_main_process = global_rank == 0
  110. seed = global_seed + global_rank
  111. torch.manual_seed(seed)
  112. # Logging folder
  113. folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S")
  114. output_dir = os.path.join(output_dir, folder_name)
  115. if is_debug and os.path.exists(output_dir):
  116. os.system(f"rm -rf {output_dir}")
  117. *_, config = inspect.getargvalues(inspect.currentframe())
  118. # Make one log on every process with the configuration for debugging.
  119. logging.basicConfig(
  120. format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
  121. datefmt="%m/%d/%Y %H:%M:%S",
  122. level=logging.INFO,
  123. )
  124. if is_main_process and (not is_debug) and use_wandb:
  125. run = wandb.init(project="animatediff", name=folder_name, config=config)
  126. # Handle the output folder creation
  127. if is_main_process:
  128. os.makedirs(output_dir, exist_ok=True)
  129. os.makedirs(f"{output_dir}/samples", exist_ok=True)
  130. os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
  131. os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
  132. OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
  133. # Load scheduler, tokenizer and models.
  134. noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
  135. vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
  136. tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
  137. text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
  138. if not image_finetune:
  139. unet = UNet3DConditionModel.from_pretrained_2d(
  140. pretrained_model_path, subfolder="unet",
  141. unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs)
  142. )
  143. else:
  144. unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")
  145. # Load pretrained unet weights
  146. if unet_checkpoint_path != "":
  147. zero_rank_print(f"from checkpoint: {unet_checkpoint_path}")
  148. unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
  149. if "global_step" in unet_checkpoint_path: zero_rank_print(f"global_step: {unet_checkpoint_path['global_step']}")
  150. state_dict = unet_checkpoint_path["state_dict"] if "state_dict" in unet_checkpoint_path else unet_checkpoint_path
  151. m, u = unet.load_state_dict(state_dict, strict=False)
  152. zero_rank_print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
  153. assert len(u) == 0
  154. loras = []
  155. for name, module in unet.named_modules():
  156. if "to_q" in name or "to_k" in name or "to_v" in name or "to_out.0" in name:
  157. lora_name = name + "_lora"
  158. print(lora_name)
  159. lora = LoRAModule(lora_name,module)
  160. loras.append(lora)
  161. # print(module.__class__.__name__)
  162. print("enable LoRA for U-Net")
  163. for lora in loras:
  164. lora.apply_to()
  165. # Freeze vae and text_encoder
  166. vae.requires_grad_(False)
  167. text_encoder.requires_grad_(False)
  168. # Set unet trainable parameters
  169. unet.requires_grad_(False)
  170. for lora in loras:
  171. lora.requires_grad = True
  172. input()
  173. # for name, param in unet.named_parameters():
  174. # for trainable_module_name in trainable_modules:
  175. # if trainable_module_name in name:
  176. # param.requires_grad = True
  177. # break
  178. trainable_params = trainable_params = [param for lora in loras for param in lora.parameters() if param.requires_grad]
  179. optimizer = torch.optim.AdamW(
  180. trainable_params,
  181. lr=learning_rate,
  182. betas=(adam_beta1, adam_beta2),
  183. weight_decay=adam_weight_decay,
  184. eps=adam_epsilon,
  185. )
  186. if is_main_process:
  187. zero_rank_print(f"trainable params number: {len(trainable_params)}")
  188. zero_rank_print(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")
  189. # Enable xformers
  190. # if enable_xformers_memory_efficient_attention:
  191. # if is_xformers_available():
  192. # unet.enable_xformers_memory_efficient_attention()
  193. # else:
  194. # raise ValueError("xformers is not available. Make sure it is installed correctly")
  195. # Enable gradient checkpointing
  196. if gradient_checkpointing:
  197. unet.enable_gradient_checkpointing()
  198. # Move models to GPU
  199. vae.to(local_rank)
  200. text_encoder.to(local_rank)
  201. # Get the training dataset
  202. train_dataset = WebVid10M(**train_data, is_image=image_finetune)
  203. distributed_sampler = DistributedSampler(
  204. train_dataset,
  205. num_replicas=num_processes,
  206. rank=global_rank,
  207. shuffle=True,
  208. seed=global_seed,
  209. )
  210. # DataLoaders creation:
  211. train_dataloader = torch.utils.data.DataLoader(
  212. train_dataset,
  213. batch_size=train_batch_size,
  214. shuffle=False,
  215. sampler=distributed_sampler,
  216. num_workers=num_workers,
  217. pin_memory=True,
  218. drop_last=True,
  219. )
  220. # Get the training iteration
  221. if max_train_steps == -1:
  222. assert max_train_epoch != -1
  223. max_train_steps = max_train_epoch * len(train_dataloader)
  224. if checkpointing_steps == -1:
  225. assert checkpointing_epochs != -1
  226. checkpointing_steps = checkpointing_epochs * len(train_dataloader)
  227. if scale_lr:
  228. learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size * num_processes)
  229. # Scheduler
  230. lr_scheduler = get_scheduler(
  231. lr_scheduler,
  232. optimizer=optimizer,
  233. num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
  234. num_training_steps=max_train_steps * gradient_accumulation_steps,
  235. )
  236. # Validation pipeline
  237. if not image_finetune:
  238. validation_pipeline = AnimationPipeline(
  239. unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler,
  240. ).to("cuda")
  241. else:
  242. validation_pipeline = StableDiffusionPipeline.from_pretrained(
  243. pretrained_model_path,
  244. unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, safety_checker=None,
  245. )
  246. validation_pipeline.enable_vae_slicing()
  247. class LoRAModuleWrapper(torch.nn.Module):
  248. def __init__(self, loras):
  249. super().__init__()
  250. for i, lora in enumerate(loras):
  251. module_name = lora.lora_name.replace('.', '-')
  252. self.add_module(module_name, lora)
  253. loras_wrapper = LoRAModuleWrapper(loras).to(local_rank)
  254. # DDP warpper
  255. unet.to(local_rank)
  256. loras_ddp = DDP(loras_wrapper, device_ids=[local_rank], output_device=local_rank)
  257. # We need to recalculate our total training steps as the size of the training dataloader may have changed.
  258. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
  259. # Afterwards we recalculate our number of training epochs
  260. num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
  261. # Train!
  262. total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps
  263. if is_main_process:
  264. logging.info("***** Running training *****")
  265. logging.info(f" Num examples = {len(train_dataset)}")
  266. logging.info(f" Num Epochs = {num_train_epochs}")
  267. logging.info(f" Instantaneous batch size per device = {train_batch_size}")
  268. logging.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
  269. logging.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
  270. logging.info(f" Total optimization steps = {max_train_steps}")
  271. global_step = 0
  272. first_epoch = 0
  273. # Only show the progress bar once on each machine.
  274. progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process)
  275. progress_bar.set_description("Steps")
  276. # Support mixed-precision training
  277. scaler = torch.cuda.amp.GradScaler() if mixed_precision_training else None
  278. for epoch in range(first_epoch, num_train_epochs):
  279. train_dataloader.sampler.set_epoch(epoch)
  280. unet.train()
  281. for step, batch in enumerate(train_dataloader):
  282. if cfg_random_null_text:
  283. batch['text'] = [name if random.random() > cfg_random_null_text_ratio else "" for name in batch['text']]
  284. # Data batch sanity check
  285. if epoch == first_epoch and step == 0:
  286. pixel_values, texts = batch['pixel_values'].cpu(), batch['text']
  287. if not image_finetune:
  288. pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
  289. for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
  290. pixel_value = pixel_value[None, ...]
  291. save_videos_grid(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.gif", rescale=True)
  292. else:
  293. for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
  294. pixel_value = pixel_value / 2. + 0.5
  295. torchvision.utils.save_image(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.png")
  296. ### >>>> Training >>>> ###
  297. # Convert videos to latent space
  298. pixel_values = batch["pixel_values"].to(local_rank)
  299. video_length = pixel_values.shape[1]
  300. with torch.no_grad():
  301. if not image_finetune:
  302. pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
  303. latents = vae.encode(pixel_values).latent_dist
  304. latents = latents.sample()
  305. latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
  306. else:
  307. latents = vae.encode(pixel_values).latent_dist
  308. latents = latents.sample()
  309. latents = latents * 0.18215
  310. # Sample noise that we'll add to the latents
  311. noise = torch.randn_like(latents)
  312. bsz = latents.shape[0]
  313. # Sample a random timestep for each video
  314. timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
  315. timesteps = timesteps.long()
  316. # Add noise to the latents according to the noise magnitude at each timestep
  317. # (this is the forward diffusion process)
  318. noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
  319. # Get the text embedding for conditioning
  320. with torch.no_grad():
  321. prompt_ids = tokenizer(
  322. batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
  323. ).input_ids.to(latents.device)
  324. encoder_hidden_states = text_encoder(prompt_ids)[0]
  325. # Get the target for loss depending on the prediction type
  326. if noise_scheduler.config.prediction_type == "epsilon":
  327. target = noise
  328. elif noise_scheduler.config.prediction_type == "v_prediction":
  329. raise NotImplementedError
  330. else:
  331. raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
  332. # Predict the noise residual and compute loss
  333. # Mixed-precision training
  334. with torch.cuda.amp.autocast(enabled=mixed_precision_training):
  335. model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
  336. loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
  337. optimizer.zero_grad()
  338. # Backpropagate
  339. if mixed_precision_training:
  340. scaler.scale(loss).backward()
  341. """ >>> gradient clipping >>> """
  342. scaler.unscale_(optimizer)
  343. torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
  344. """ <<< gradient clipping <<< """
  345. scaler.step(optimizer)
  346. scaler.update()
  347. else:
  348. loss.backward()
  349. """ >>> gradient clipping >>> """
  350. torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
  351. """ <<< gradient clipping <<< """
  352. optimizer.step()
  353. lr_scheduler.step()
  354. progress_bar.update(1)
  355. global_step += 1
  356. # for lora in loras:
  357. # print(f"lora_down grad: {lora.lora_down.weight.grad}")
  358. # print(f"lora_up grad: {lora.lora_up.weight.grad}")
  359. # print(f"lora_down requires_grad: {lora.lora_down.weight.requires_grad}")
  360. # print(f"lora_up requires_grad: {lora.lora_up.weight.requires_grad}")
  361. # print(f"lora_down data: {lora.lora_down.weight.data}")
  362. # print(f"lora_up data: {lora.lora_up.weight.data}")
  363. # input()
  364. ### <<<< Training <<<< ###
  365. # Wandb logging
  366. if is_main_process and (not is_debug) and use_wandb:
  367. wandb.log({"train_loss": loss.item()}, step=global_step)
  368. # Save checkpoint
  369. if is_main_process and (global_step % checkpointing_steps == 0 or step == len(train_dataloader) - 1):
  370. lora_save_path = os.path.join(output_dir, f"checkpoints", "lora")
  371. os.makedirs(lora_save_path, exist_ok=True)
  372. loras_state_dict = {
  373. "epoch": epoch,
  374. "global_step": global_step,
  375. "state_dict": loras_wrapper.state_dict(),
  376. }
  377. if step == len(train_dataloader) - 1:
  378. torch.save(loras_state_dict, os.path.join(lora_save_path, f"checkpoint-epoch-{epoch+1}.ckpt"))
  379. else:
  380. torch.save(loras_state_dict, os.path.join(lora_save_path, f"checkpoint.ckpt"))
  381. logging.info(f"Saved LoRA state to {lora_save_path} (global_step: {global_step})")
  382. # Periodically validation
  383. if is_main_process and (global_step % validation_steps == 0 or global_step in validation_steps_tuple):
  384. samples = []
  385. generator = torch.Generator(device=latents.device)
  386. generator.manual_seed(global_seed)
  387. height = train_data.sample_size[0] if not isinstance(train_data.sample_size, int) else train_data.sample_size
  388. width = train_data.sample_size[1] if not isinstance(train_data.sample_size, int) else train_data.sample_size
  389. prompts = validation_data.prompts[:2] if global_step < 1000 and (not image_finetune) else validation_data.prompts
  390. for idx, prompt in enumerate(prompts):
  391. if not image_finetune:
  392. sample = validation_pipeline(
  393. prompt,
  394. generator = generator,
  395. video_length = train_data.sample_n_frames,
  396. height = height,
  397. width = width,
  398. **validation_data,
  399. ).videos
  400. save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif")
  401. samples.append(sample)
  402. else:
  403. sample = validation_pipeline(
  404. prompt,
  405. generator = generator,
  406. height = height,
  407. width = width,
  408. num_inference_steps = validation_data.get("num_inference_steps", 25),
  409. guidance_scale = validation_data.get("guidance_scale", 8.),
  410. ).images[0]
  411. sample = torchvision.transforms.functional.to_tensor(sample)
  412. samples.append(sample)
  413. if not image_finetune:
  414. samples = torch.concat(samples)
  415. save_path = f"{output_dir}/samples/sample-{global_step}.gif"
  416. save_videos_grid(samples, save_path)
  417. else:
  418. samples = torch.stack(samples)
  419. save_path = f"{output_dir}/samples/sample-{global_step}.png"
  420. torchvision.utils.save_image(samples, save_path, nrow=4)
  421. logging.info(f"Saved samples to {save_path}")
  422. logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
  423. progress_bar.set_postfix(**logs)
  424. if global_step >= max_train_steps:
  425. break
  426. dist.destroy_process_group()
  427. if __name__ == "__main__":
  428. parser = argparse.ArgumentParser()
  429. parser.add_argument("--config", type=str, required=True)
  430. parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="pytorch")
  431. parser.add_argument("--wandb", action="store_true")
  432. args = parser.parse_args()
  433. name = Path(args.config).stem
  434. config = OmegaConf.load(args.config)
  435. main(name=name, launcher=args.launcher, use_wandb=args.wandb, **config)