| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159 |
- import argparse
- import datetime
- import inspect
- import os
- from omegaconf import OmegaConf
- import torch
- import diffusers
- from diffusers import AutoencoderKL, DDIMScheduler
- from tqdm.auto import tqdm
- from transformers import CLIPTextModel, CLIPTokenizer
- from animatediff.models.unet import UNet3DConditionModel
- from animatediff.pipelines.pipeline_animation import AnimationPipeline
- from animatediff.utils.util import save_videos_grid
- from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
- from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
- from diffusers.utils.import_utils import is_xformers_available
- from einops import rearrange, repeat
- import csv, pdb, glob
- from safetensors import safe_open
- import math
- from pathlib import Path
def main(args):
    """Run AnimateDiff text-to-video sampling for every entry in the config.

    For each model config entry, loads the base Stable Diffusion components,
    injects each configured motion module (plus an optional personalized T2I
    checkpoint or LoRA), samples one video per prompt, and writes individual
    GIFs plus a combined grid under ``samples/<config-stem>-<timestamp>/``.

    Args:
        args: parsed CLI namespace with ``pretrained_model_path``,
            ``inference_config``, ``config``, and the video size flags
            ``W`` (width), ``H`` (height), ``L`` (frame count).
    """
    # Snapshot the frame's argument values; "global_step" from the motion
    # module checkpoint is merged in below for record-keeping.
    *_, func_args = inspect.getargvalues(inspect.currentframe())
    func_args = dict(func_args)

    time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    savedir = f"samples/{Path(args.config).stem}-{time_str}"
    os.makedirs(savedir, exist_ok=True)

    inference_config = OmegaConf.load(args.inference_config)
    config = OmegaConf.load(args.config)

    samples = []
    sample_idx = 0
    for model_idx, (config_key, model_config) in enumerate(list(config.items())):

        # Accept either a single path or a list of motion-module checkpoints.
        motion_modules = model_config.motion_module
        motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)
        for motion_module in motion_modules:

            ### >>> create validation pipeline >>> ###
            tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
            text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
            vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
            unet = UNet3DConditionModel.from_pretrained_2d(
                args.pretrained_model_path,
                subfolder="unet",
                unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
            )

            # xformers memory-efficient attention is required here; raise an
            # explicit error instead of `assert False`, which is silently
            # stripped under `python -O` and gives no diagnostic.
            if is_xformers_available():
                unet.enable_xformers_memory_efficient_attention()
            else:
                raise RuntimeError(
                    "xformers is required for memory-efficient attention but is not available."
                )

            pipeline = AnimationPipeline(
                vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
                scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
            ).to("cuda")

            # 1. unet ckpt
            # 1.1 motion module (loaded non-strict: it only covers the
            # temporal layers; any *unexpected* key indicates a bad file).
            motion_module_state_dict = torch.load(motion_module, map_location="cpu")
            if "global_step" in motion_module_state_dict:
                func_args.update({"global_step": motion_module_state_dict["global_step"]})
            missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
            assert len(unexpected) == 0, f"unexpected motion-module keys: {sorted(unexpected)[:5]}"

            # 1.2 T2I: optional personalized base weights (full ckpt or LoRA)
            if model_config.path != "":
                if model_config.path.endswith(".ckpt"):
                    # NOTE(review): raw .ckpt files commonly nest weights under
                    # a "state_dict" key; this loads the top level directly —
                    # confirm the expected checkpoint layout.
                    state_dict = torch.load(model_config.path)
                    pipeline.unet.load_state_dict(state_dict)

                elif model_config.path.endswith(".safetensors"):
                    state_dict = {}
                    with safe_open(model_config.path, framework="pt", device="cpu") as f:
                        for key in f.keys():
                            state_dict[key] = f.get_tensor(key)

                    # A file whose keys all mention "lora" is a LoRA delta and
                    # needs the separate base checkpoint from model_config.base.
                    is_lora = all("lora" in k for k in state_dict.keys())
                    if not is_lora:
                        base_state_dict = state_dict
                    else:
                        base_state_dict = {}
                        with safe_open(model_config.base, framework="pt", device="cpu") as f:
                            for key in f.keys():
                                base_state_dict[key] = f.get_tensor(key)

                    # Convert LDM-layout weights into diffusers layout.
                    # vae
                    converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
                    pipeline.vae.load_state_dict(converted_vae_checkpoint)
                    # unet (non-strict: motion-module keys are absent here)
                    converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
                    pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
                    # text_model
                    pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)

                    if is_lora:
                        pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)

            pipeline.to("cuda")
            ### <<< create validation pipeline <<< ###

            prompts = model_config.prompt
            # Broadcast a single negative prompt across all prompts.
            n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt

            random_seeds = model_config.get("seed", [-1])
            random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
            random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds

            # Record the seed actually used for each sample so the saved
            # config.yaml makes every run reproducible.
            config[config_key].random_seed = []
            for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):

                # manually set random seed for reproduction (-1 → fresh seed)
                if random_seed != -1:
                    torch.manual_seed(random_seed)
                else:
                    torch.seed()
                config[config_key].random_seed.append(torch.initial_seed())

                print(f"current seed: {torch.initial_seed()}")
                print(f"sampling {prompt} ...")
                sample = pipeline(
                    prompt,
                    negative_prompt=n_prompt,
                    num_inference_steps=model_config.steps,
                    guidance_scale=model_config.guidance_scale,
                    width=args.W,
                    height=args.H,
                    video_length=args.L,
                ).videos
                samples.append(sample)

                # Build a filesystem-safe name; keep `prompt` intact instead
                # of shadowing it as the original code did.
                save_name = "-".join(prompt.replace("/", "").split(" ")[:10])
                save_path = f"{savedir}/sample/{sample_idx}-{save_name}.gif"
                save_videos_grid(sample, save_path)
                # Fixed: the log previously omitted the sample index and so
                # pointed at a path that was never written.
                print(f"save to {save_path}")

                sample_idx += 1

    samples = torch.concat(samples)
    save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)

    OmegaConf.save(config, f"{savedir}/config.yaml")
if __name__ == "__main__":
    # CLI entry point: model/config paths plus output video geometry.
    parser = argparse.ArgumentParser()

    path_flags = {
        "--pretrained_model_path": "models/StableDiffusion/stable-diffusion-v1-5",
        "--inference_config": "configs/inference/inference.yaml",
    }
    for flag, default in path_flags.items():
        parser.add_argument(flag, type=str, default=default)
    parser.add_argument("--config", type=str, required=True)

    # Video length (frames), width, and height in pixels.
    parser.add_argument("--L", type=int, default=16)
    parser.add_argument("--W", type=int, default=512)
    parser.add_argument("--H", type=int, default=512)

    main(parser.parse_args())
|