# animate.py — AnimateDiff text-to-video sampling script.
import argparse
import csv
import datetime
import glob
import inspect
import math
import os
import pdb
from pathlib import Path

import diffusers
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from einops import rearrange, repeat
from omegaconf import OmegaConf
from safetensors import safe_open
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
from animatediff.utils.util import save_videos_grid
  21. def main(args):
  22. *_, func_args = inspect.getargvalues(inspect.currentframe())
  23. func_args = dict(func_args)
  24. time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
  25. savedir = f"samples/{Path(args.config).stem}-{time_str}"
  26. os.makedirs(savedir)
  27. inference_config = OmegaConf.load(args.inference_config)
  28. config = OmegaConf.load(args.config)
  29. samples = []
  30. for model_idx, (config_key, model_config) in enumerate(list(config.items())):
  31. ### >>> create validation pipeline >>> ###
  32. tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
  33. text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
  34. vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
  35. unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
  36. pipeline = AnimationPipeline(
  37. vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
  38. scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
  39. ).to("cuda")
  40. # 1. unet ckpt
  41. # 1.1 motion module
  42. motion_module_state_dict = torch.load(model_config.motion_module, map_location="cpu")
  43. if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
  44. missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
  45. assert len(unexpected) == 0
  46. # 1.2 T2I
  47. if model_config.path != "":
  48. if model_config.path.endswith(".ckpt"):
  49. state_dict = torch.load(model_config.path)
  50. pipeline.unet.load_state_dict(state_dict)
  51. elif model_config.path.endswith(".safetensors"):
  52. state_dict = {}
  53. with safe_open(model_config.path, framework="pt", device="cpu") as f:
  54. for key in f.keys():
  55. state_dict[key] = f.get_tensor(key)
  56. is_lora = all("lora" in k for k in state_dict.keys())
  57. if not is_lora:
  58. base_state_dict = state_dict
  59. else:
  60. base_state_dict = {}
  61. with safe_open(model_config.base, framework="pt", device="cpu") as f:
  62. for key in f.keys():
  63. base_state_dict[key] = f.get_tensor(key)
  64. # vae
  65. converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
  66. pipeline.vae.load_state_dict(converted_vae_checkpoint)
  67. # unet
  68. converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
  69. pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
  70. # text_model
  71. pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)
  72. # import pdb
  73. # pdb.set_trace()
  74. if is_lora:
  75. pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)
  76. pipeline.to("cuda")
  77. ### <<< create validation pipeline <<< ###
  78. prompts = model_config.prompt
  79. n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt
  80. random_seeds = model_config.pop("seed", [-1])
  81. random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
  82. random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
  83. config[config_key].random_seed = []
  84. for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
  85. # manually set random seed for reproduction
  86. if random_seed != -1: torch.manual_seed(random_seed)
  87. else: torch.seed()
  88. config[config_key].random_seed.append(torch.initial_seed())
  89. print(f"current seed: {torch.initial_seed()}")
  90. print(f"sampling {prompt} ...")
  91. sample = pipeline(
  92. prompt,
  93. negative_prompt = n_prompt,
  94. num_inference_steps = model_config.steps,
  95. guidance_scale = model_config.guidance_scale,
  96. width = args.W,
  97. height = args.H,
  98. video_length = args.L,
  99. ).videos
  100. samples.append(sample)
  101. prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
  102. save_videos_grid(sample, f"{savedir}/sample/{model_idx}-{prompt_idx}-{prompt}.gif")
  103. print(f"save to {savedir}/sample/{prompt}.gif")
  104. samples = torch.concat(samples)
  105. save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)
  106. OmegaConf.save(config, f"{savedir}/config.yaml")
  107. if __name__ == "__main__":
  108. parser = argparse.ArgumentParser()
  109. parser.add_argument("--pretrained_model_path", type=str, default="models/StableDiffusion/stable-diffusion-v1-5",)
  110. parser.add_argument("--inference_config", type=str, default="configs/inference/inference.yaml")
  111. parser.add_argument("--config", type=str, required=True)
  112. parser.add_argument("--L", type=int, default=16 )
  113. parser.add_argument("--W", type=int, default=512)
  114. parser.add_argument("--H", type=int, default=512)
  115. args = parser.parse_args()
  116. main(args)