animate.py
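# Batch text-to-video sampling with AnimateDiff: for every model entry in the
# prompt config, build an AnimationPipeline (Stable Diffusion components plus a
# 3D UNet carrying a motion module) and render one GIF per prompt, followed by
# a combined grid of all samples.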
import argparse
import datetime
import inspect
import os
from pathlib import Path

import torch
from omegaconf import OmegaConf
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available

from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import load_weights, save_videos_grid
def main(args):
    # Capture the function's arguments via introspection (kept from the original
    # script; the dict is not used further below).
    *_, func_args = inspect.getargvalues(inspect.currentframe())
    func_args = dict(func_args)

    time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    savedir = f"samples/{Path(args.config).stem}-{time_str}"
    os.makedirs(savedir)

    config = OmegaConf.load(args.config)
    samples = []
    sample_idx = 0
    for model_idx, (config_key, model_config) in enumerate(list(config.items())):
        motion_modules = model_config.motion_module
        motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)

        for motion_module in motion_modules:
            inference_config = OmegaConf.load(model_config.get("inference_config", args.inference_config))

            ### >>> create validation pipeline >>> ###
            tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
            text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
            vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
            # Inflate the pretrained 2D UNet into the 3D UNet used for video sampling.
            unet = UNet3DConditionModel.from_pretrained_2d(
                args.pretrained_model_path, subfolder="unet",
                unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
            )

            if is_xformers_available():
                unet.enable_xformers_memory_efficient_attention()
            else:
                raise RuntimeError("xformers is required for memory-efficient attention in this script.")
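            # NOTE (assumption, not verified here): dropping the check above and
            # running without xformers would fall back to diffusers' standard
            # attention, which typically needs noticeably more GPU memory for
            # the temporal layers.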
            pipeline = AnimationPipeline(
                vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
                scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
            ).to("cuda")

            pipeline = load_weights(
                pipeline,
                # motion module
                motion_module_path=motion_module,
                motion_module_lora_configs=model_config.get("motion_module_lora_configs", []),
                # image layers
                dreambooth_model_path=model_config.get("dreambooth_path", ""),
                lora_model_path=model_config.get("lora_model_path", ""),
                lora_alpha=model_config.get("lora_alpha", 0.8),
            ).to("cuda")
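            # load_weights layers three kinds of checkpoints onto the base
            # pipeline: the AnimateDiff motion module (temporal layers), optional
            # MotionLoRA configs for it, and optional DreamBooth/LoRA checkpoints
            # for the spatial ("image") layers, scaled by lora_alpha.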
            prompts = model_config.prompt
            n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt

            # A single seed (or negative prompt) is broadcast across all prompts;
            # otherwise seeds and prompts are paired one-to-one.
            random_seeds = model_config.get("seed", [-1])
            random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
            random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds

            config[config_key].random_seed = []
            for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
                # Manually set the random seed for reproducibility; record the
                # seed actually used back into the config.
                if random_seed != -1:
                    torch.manual_seed(random_seed)
                else:
                    torch.seed()
                config[config_key].random_seed.append(torch.initial_seed())

                print(f"current seed: {torch.initial_seed()}")
                print(f"sampling {prompt} ...")
                sample = pipeline(
                    prompt,
                    negative_prompt=n_prompt,
                    num_inference_steps=model_config.steps,
                    guidance_scale=model_config.guidance_scale,
                    width=args.W,
                    height=args.H,
                    video_length=args.L,
                ).videos
                samples.append(sample)

                # Build a filesystem-safe name from the first ten words of the prompt.
                prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
                save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif")
                print(f"saved to {savedir}/sample/{sample_idx}-{prompt}.gif")

                sample_idx += 1
    samples = torch.concat(samples)
    save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)

    OmegaConf.save(config, f"{savedir}/config.yaml")
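
# A minimal sketch of a prompt config this script can consume, inferred from
# the fields read above (the key name "ToonYou" and all paths are illustrative):
#
#   ToonYou:
#     motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt"
#     dreambooth_path: ""            # optional DreamBooth checkpoint
#     lora_model_path: ""            # optional LoRA for the image layers
#     seed: [42]                     # -1 draws a fresh seed per prompt
#     steps: 25
#     guidance_scale: 7.5
#     prompt:
#       - "a corgi running on the beach, sunset"
#     n_prompt:
#       - "low quality, blurry"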
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_path", type=str, default="models/StableDiffusion/stable-diffusion-v1-5")
    parser.add_argument("--inference_config", type=str, default="configs/inference/inference-v2.yaml")
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--L", type=int, default=16)
    parser.add_argument("--W", type=int, default=512)
    parser.add_argument("--H", type=int, default=512)

    args = parser.parse_args()
    main(args)
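
# Example invocation (the config path is illustrative; --L is the number of
# frames, --W/--H the resolution):
#   python animate.py --config configs/prompts/my-prompts.yaml --L 16 --W 512 --H 512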