animate.py
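
"""Batch text-to-video sampling script for AnimateDiff.

For every model entry in the prompt config, the script loads the Stable
Diffusion base weights, injects a motion-module checkpoint into the inflated
3D UNet, optionally merges a personalized T2I checkpoint (.ckpt /
.safetensors, including LoRA), and renders each prompt to a GIF under
samples/<config-stem>-<timestamp>/.

Illustrative invocation (the config path is an example, not a shipped file):

    python animate.py --config configs/prompts/example.yaml --L 16 --W 512 --H 512
"""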

import argparse
import datetime
import inspect
import os
from pathlib import Path

import torch
from omegaconf import OmegaConf
from diffusers import AutoencoderKL, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from safetensors import safe_open

from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora


def main(args):
    *_, func_args = inspect.getargvalues(inspect.currentframe())
    func_args = dict(func_args)

    # one output folder per run, named after the config file and a timestamp
    time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    savedir = f"samples/{Path(args.config).stem}-{time_str}"
    os.makedirs(savedir)

    inference_config = OmegaConf.load(args.inference_config)
    config = OmegaConf.load(args.config)
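
    # Illustrative shape of the prompt config loaded above; the field names match
    # this script's accesses, but the entry key and all values are examples only:
    #
    #   ToonYou:
    #     motion_module:  "models/Motion_Module/mm_sd_v15.ckpt"
    #     path:           "models/DreamBooth_LoRA/toonyou.safetensors"
    #     base:           ""      # base .safetensors, only read for LoRA checkpoints
    #     lora_alpha:     0.8     # only read for LoRA checkpoints
    #     seed:           [-1]    # -1 -> draw a fresh seed
    #     steps:          25
    #     guidance_scale: 7.5
    #     prompt:
    #       - "best quality, masterpiece, 1girl, walking"
    #     n_prompt:
    #       - "worst quality, low quality"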

    samples = []

    sample_idx = 0
    for model_idx, (config_key, model_config) in enumerate(list(config.items())):

        motion_modules = model_config.motion_module
        motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)
        for motion_module in motion_modules:

            ### >>> create validation pipeline >>> ###
            tokenizer    = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
            text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
            vae          = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
            # inflate the 2D SD UNet weights into the 3D UNet that hosts the motion module
            unet         = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))

            pipeline = AnimationPipeline(
                vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
                scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
            ).to("cuda")
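            # NOTE: the pipeline is rebuilt from the base SD weights for every
            # motion module, so each module/checkpoint combination starts from a
            # clean state (at the cost of reloading the base model each time).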

            # 1. unet ckpt
            # 1.1 motion module
            motion_module_state_dict = torch.load(motion_module, map_location="cpu")
            if "global_step" in motion_module_state_dict:
                func_args.update({"global_step": motion_module_state_dict["global_step"]})
            # strict=False: the checkpoint only carries the temporal layers, but every
            # key it contains must still land somewhere in the inflated UNet
            missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
            assert len(unexpected) == 0

            # 1.2 T2I
            if model_config.path != "":
                if model_config.path.endswith(".ckpt"):
                    state_dict = torch.load(model_config.path)
                    pipeline.unet.load_state_dict(state_dict)

                elif model_config.path.endswith(".safetensors"):
                    state_dict = {}
                    with safe_open(model_config.path, framework="pt", device="cpu") as f:
                        for key in f.keys():
                            state_dict[key] = f.get_tensor(key)

                    # a checkpoint whose keys are all LoRA weights needs a separate base model
                    is_lora = all("lora" in k for k in state_dict.keys())
                    if not is_lora:
                        base_state_dict = state_dict
                    else:
                        base_state_dict = {}
                        with safe_open(model_config.base, framework="pt", device="cpu") as f:
                            for key in f.keys():
                                base_state_dict[key] = f.get_tensor(key)

                    # vae
                    converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
                    pipeline.vae.load_state_dict(converted_vae_checkpoint)
                    # unet
                    converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
                    pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
                    # text_model
                    pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)

                    if is_lora:
                        pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)

            pipeline.to("cuda")
            ### <<< create validation pipeline <<< ###

            prompts   = model_config.prompt
            n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt

            # broadcast a single seed across all prompts; -1 means "draw a fresh seed"
            random_seeds = model_config.get("seed", [-1])
            random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
            random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds

            config[config_key].random_seed = []
            for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):

                # manually set random seed for reproduction
                if random_seed != -1:
                    torch.manual_seed(random_seed)
                else:
                    torch.seed()
                # record the seed actually used so the run can be reproduced
                config[config_key].random_seed.append(torch.initial_seed())

                print(f"current seed: {torch.initial_seed()}")
  94. print(f"sampling {prompt} ...")
  95. sample = pipeline(
  96. prompt,
  97. negative_prompt = n_prompt,
  98. num_inference_steps = model_config.steps,
  99. guidance_scale = model_config.guidance_scale,
  100. width = args.W,
  101. height = args.H,
  102. video_length = args.L,
  103. ).videos
  104. samples.append(sample)

                # slugify the prompt for use in the filename
                prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
                save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif")
                print(f"save to {savedir}/sample/{sample_idx}-{prompt}.gif")

                sample_idx += 1

    samples = torch.concat(samples)
    save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)

    OmegaConf.save(config, f"{savedir}/config.yaml")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_path", type=str, default="models/StableDiffusion/stable-diffusion-v1-5")
    parser.add_argument("--inference_config",      type=str, default="configs/inference/inference.yaml")
    parser.add_argument("--config",                type=str, required=True)

    parser.add_argument("--L", type=int, default=16)
    parser.add_argument("--W", type=int, default=512)
    parser.add_argument("--H", type=int, default=512)

    args = parser.parse_args()
    main(args)