# util-checkpoint.py

import os
import imageio
import numpy as np
from typing import Union

import torch
import torchvision
import torch.distributed as dist

from safetensors import safe_open
from tqdm import tqdm
from einops import rearrange

from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, convert_motion_lora_ckpt_to_diffusers, convert_motion_lora_ckpt_to_diffusers_test


def zero_rank_print(s):
    # Print only from the rank-0 process, or when torch.distributed is not initialized at all.
    if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0):
        print("### " + s)


def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)
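
# Example (sketch, not part of the original file): saving a random clip with save_videos_grid.
# `videos` is expected in (batch, channels, frames, height, width) layout with values in [0, 1]
# (or in [-1, 1] together with rescale=True). The tensor shape and output path below are
# illustrative placeholders.
#
#   dummy = torch.rand(2, 3, 16, 64, 64)   # 2 videos, 3 channels, 16 frames, 64x64
#   save_videos_grid(dummy, "samples/dummy.gif", rescale=False, n_rows=2, fps=8)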


# DDIM Inversion
@torch.no_grad()
def init_prompt(prompt, pipeline):
    uncond_input = pipeline.tokenizer(
        [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
        return_tensors="pt"
    )
    uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
    text_input = pipeline.tokenizer(
        [prompt],
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
    context = torch.cat([uncond_embeddings, text_embeddings])
    return context


def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
              sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
    # Reversed DDIM update: given the noise prediction at `timestep`, move the sample one
    # scheduler step toward higher noise (used by ddim_loop to invert latents).
    timestep, next_timestep = min(
        timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
    alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
    beta_prod_t = 1 - alpha_prod_t
    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
    return next_sample
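
# Example (sketch, not part of the original file): one inverted DDIM step on dummy data.
# Assumes a diffusers DDIMScheduler with set_timesteps() already called, so that
# .timesteps and .num_inference_steps are populated; shapes below are illustrative only.
#
#   from diffusers import DDIMScheduler
#   sched = DDIMScheduler(num_train_timesteps=1000)
#   sched.set_timesteps(50)
#   x   = torch.randn(1, 4, 64, 64)
#   eps = torch.randn_like(x)
#   x_next = next_step(eps, int(sched.timesteps[-1]), x, sched)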


def get_noise_pred_single(latents, t, context, unet):
    noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
    return noise_pred


@torch.no_grad()
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
    context = init_prompt(prompt, pipeline)
    uncond_embeddings, cond_embeddings = context.chunk(2)
    all_latent = [latent]
    latent = latent.clone().detach()
    for i in tqdm(range(num_inv_steps)):
        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
        latent = next_step(noise_pred, t, latent, ddim_scheduler)
        all_latent.append(latent)
    return all_latent


@torch.no_grad()
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
    ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
    return ddim_latents
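
# Example (sketch, not part of the original file): inverting video latents back toward noise.
# Assumes `pipeline` is an AnimationPipeline-style object exposing .unet, .text_encoder,
# .tokenizer and .device, and that `latents` matches the UNet input layout (b, c, f, h, w).
# The scheduler settings and step count below are illustrative.
#
#   from diffusers import DDIMScheduler
#   scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear")
#   scheduler.set_timesteps(50)
#   inv_latents = ddim_inversion(pipeline, scheduler, latents, num_inv_steps=50, prompt="")
#   start_latent = inv_latents[-1]   # most-noised latent; can seed subsequent sampling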


def load_weights(
    animation_pipeline,
    # motion module
    motion_module_path         = "",
    motion_module_lora_configs = [],
    # image layers
    dreambooth_model_path = "",
    lora_model_path       = "",
    lora_alpha            = 0.8,
):
    # 1.1 motion module
    unet_state_dict = {}
    if motion_module_path != "":
        print(f"load motion module from {motion_module_path}")
        motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
        motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
        unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})

    missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
    assert len(unexpected) == 0
    del unet_state_dict

    if dreambooth_model_path != "":
        print(f"load dreambooth model from {dreambooth_model_path}")
        if dreambooth_model_path.endswith(".safetensors"):
            dreambooth_state_dict = {}
            with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    dreambooth_state_dict[key] = f.get_tensor(key)
        elif dreambooth_model_path.endswith(".ckpt"):
            dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu")

        # 1. vae
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config)
        animation_pipeline.vae.load_state_dict(converted_vae_checkpoint)
        # 2. unet
        converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config)
        animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
        # 3. text_model
        animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
        del dreambooth_state_dict

    if lora_model_path != "":
        print(f"load lora model from {lora_model_path}")
        assert lora_model_path.endswith(".safetensors")
        lora_state_dict = {}
        with safe_open(lora_model_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                lora_state_dict[key] = f.get_tensor(key)

        animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha)
        del lora_state_dict

    for motion_module_lora_config in motion_module_lora_configs:
        path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
        print(f"load motion LoRA from {path}")
        motion_lora_state_dict = torch.load(path, map_location="cpu")
        motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict

        animation_pipeline = convert_motion_lora_ckpt_to_diffusers_test(animation_pipeline, motion_lora_state_dict, alpha)

    return animation_pipeline
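
# Example (sketch, not part of the original file): assembling a pipeline for inference.
# `pipeline` is assumed to be an already-constructed AnimationPipeline; every path below is a
# placeholder, and the LoRA arguments may be left empty as shown.
#
#   pipeline = load_weights(
#       pipeline,
#       motion_module_path         = "models/Motion_Module/mm_sd_v15.ckpt",
#       motion_module_lora_configs = [],
#       dreambooth_model_path      = "models/DreamBooth_LoRA/example_dreambooth.safetensors",
#       lora_model_path            = "",
#       lora_alpha                 = 0.8,
#   )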