animate.py

import argparse
import datetime
import inspect
import os
import csv, pdb, glob
import math
from pathlib import Path

from omegaconf import OmegaConf

import torch
import diffusers
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from einops import rearrange, repeat
from safetensors import safe_open

from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora


def main(args):
    *_, func_args = inspect.getargvalues(inspect.currentframe())
    func_args = dict(func_args)

    time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    savedir = f"samples/{Path(args.config).stem}-{time_str}"
    os.makedirs(savedir)

    config = OmegaConf.load(args.config)
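
    # The config is a YAML file with one block per model. A sketch of the expected
    # layout, inferred from how `model_config` is read below (keys match the code,
    # paths and values are illustrative placeholders):
    #
    #   ToonYou:
    #     motion_module: "models/Motion_Module/mm_sd_v14.ckpt"      # str or list of ckpts
    #     inference_config: "configs/inference/inference-v1.yaml"   # optional override
    #     path: "models/DreamBooth_LoRA/toonyou.safetensors"        # "" keeps the base SD weights
    #     base: ""                                                  # base .safetensors when `path` is LoRA-only
    #     lora_alpha: 0.8
    #     seed: [42, -1]            # -1 draws a fresh random seed
    #     steps: 25
    #     guidance_scale: 7.5
    #     prompt:   ["a girl smiling, best quality"]
    #     n_prompt: ["worst quality, low quality"]
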
    samples = []

    sample_idx = 0
    for model_idx, (config_key, model_config) in enumerate(list(config.items())):

        motion_modules = model_config.motion_module
        motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)
        for motion_module in motion_modules:

            inference_config = OmegaConf.load(model_config.get("inference_config", args.inference_config))

            ### >>> create validation pipeline >>> ###
            tokenizer    = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
            text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
            vae          = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
            unet         = UNet3DConditionModel.from_pretrained_2d(
                args.pretrained_model_path, subfolder="unet",
                unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
            )
            if is_xformers_available():
                unet.enable_xformers_memory_efficient_attention()
            else:
                assert False, "xformers is required for memory-efficient attention in this script"
            pipeline = AnimationPipeline(
                vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
                scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
            ).to("cuda")

            # 1. unet ckpt
            # 1.1 motion module
            motion_module_state_dict = torch.load(motion_module, map_location="cpu")
            if "global_step" in motion_module_state_dict:
                func_args.update({"global_step": motion_module_state_dict["global_step"]})
            missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
            assert len(unexpected) == 0
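
            # NOTE: the motion-module checkpoint is expected to contain only the temporal
            # (motion) layers, so strict=False tolerates the spatial keys it lacks, while
            # the assert above guarantees that every key it does contain was matched.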

            # 1.2 T2I
            if model_config.path != "":
                if model_config.path.endswith(".ckpt"):
                    state_dict = torch.load(model_config.path)
                    pipeline.unet.load_state_dict(state_dict)

                elif model_config.path.endswith(".safetensors"):
                    state_dict = {}
                    with safe_open(model_config.path, framework="pt", device="cpu") as f:
                        for key in f.keys():
                            state_dict[key] = f.get_tensor(key)

                    is_lora = all("lora" in k for k in state_dict.keys())
                    if not is_lora:
                        base_state_dict = state_dict
                    else:
                        base_state_dict = {}
                        with safe_open(model_config.base, framework="pt", device="cpu") as f:
                            for key in f.keys():
                                base_state_dict[key] = f.get_tensor(key)

                    # vae
                    converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
                    pipeline.vae.load_state_dict(converted_vae_checkpoint)
                    # unet
                    converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
                    pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
                    # text_model
                    pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)

                    if is_lora:
                        pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)

            pipeline.to("cuda")
            ### <<< create validation pipeline <<< ###

            # a single negative prompt or seed in the config is broadcast to every prompt
            prompts   = model_config.prompt
            n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt

            random_seeds = model_config.get("seed", [-1])
            random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
            random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds

            config[config_key].random_seed = []
            for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):

                # manually set random seed for reproduction
                if random_seed != -1:
                    torch.manual_seed(random_seed)
                else:
                    torch.seed()
                config[config_key].random_seed.append(torch.initial_seed())

                print(f"current seed: {torch.initial_seed()}")
                print(f"sampling {prompt} ...")
                sample = pipeline(
                    prompt,
                    negative_prompt     = n_prompt,
                    num_inference_steps = model_config.steps,
                    guidance_scale      = model_config.guidance_scale,
                    width               = args.W,
                    height              = args.H,
                    video_length        = args.L,
                ).videos
                samples.append(sample)

                prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
                save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif")
                print(f"save to {savedir}/sample/{sample_idx}-{prompt}.gif")

                sample_idx += 1

    samples = torch.concat(samples)
    save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)

    OmegaConf.save(config, f"{savedir}/config.yaml")
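
    # At this point {savedir} holds one GIF per prompt under sample/, the combined
    # 4-per-row grid sample.gif, and config.yaml recording the seeds actually used.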


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_path", type=str, default="models/StableDiffusion/stable-diffusion-v1-5")
    parser.add_argument("--inference_config",      type=str, default="configs/inference/inference-v1.yaml")
    parser.add_argument("--config",                type=str, required=True)

    parser.add_argument("--L", type=int, default=16)
    parser.add_argument("--W", type=int, default=512)
    parser.add_argument("--H", type=int, default=512)

    args = parser.parse_args()
    main(args)
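
# Example invocation (illustrative; the prompt config path is a placeholder):
#   python animate.py --config configs/prompts/my_prompts.yaml --W 512 --H 512 --L 16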