@@ -15,14 +15,12 @@ from transformers import CLIPTextModel, CLIPTokenizer
 from animatediff.models.unet import UNet3DConditionModel
 from animatediff.pipelines.pipeline_animation import AnimationPipeline
 from animatediff.utils.util import save_videos_grid
-from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
-from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
+from animatediff.utils.util import load_weights
 from diffusers.utils.import_utils import is_xformers_available
 
 from einops import rearrange, repeat
 
 import csv, pdb, glob
-from safetensors import safe_open
 import math
 from pathlib import Path
@@ -60,50 +58,16 @@ def main(args):
                 scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
             ).to("cuda")
 
-            # 1. unet ckpt
-            # 1.1 motion module
-            motion_module_state_dict = torch.load(motion_module, map_location="cpu")
-            if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
-            missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
-            assert len(unexpected) == 0
-
-            # 1.2 T2I
-            if model_config.path != "":
-                if model_config.path.endswith(".ckpt"):
-                    state_dict = torch.load(model_config.path)
-                    pipeline.unet.load_state_dict(state_dict)
-
-                elif model_config.path.endswith(".safetensors"):
-                    state_dict = {}
-                    with safe_open(model_config.path, framework="pt", device="cpu") as f:
-                        for key in f.keys():
-                            state_dict[key] = f.get_tensor(key)
-
-                    is_lora = all("lora" in k for k in state_dict.keys())
-                    if not is_lora:
-                        base_state_dict = state_dict
-                    else:
-                        base_state_dict = {}
-                        with safe_open(model_config.base, framework="pt", device="cpu") as f:
-                            for key in f.keys():
-                                base_state_dict[key] = f.get_tensor(key)
-
-                    # vae
-                    converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
-                    pipeline.vae.load_state_dict(converted_vae_checkpoint)
-                    # unet
-                    converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
-                    pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
-                    # text_model
-                    pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)
-
-                    # import pdb
-                    # pdb.set_trace()
-                    if is_lora:
-                        pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)
-
-            pipeline.to("cuda")
-            ### <<< create validation pipeline <<< ###
+            pipeline = load_weights(
+                pipeline,
+                # motion module
+                motion_module_path = motion_module,
+                motion_module_lora_configs = model_config.get("motion_module_lora_configs", []),
+                # image layers
+                dreambooth_model_path = model_config.get("dreambooth_path", ""),
+                lora_model_path = model_config.get("lora_model_path", ""),
+                lora_alpha = model_config.get("lora_alpha", 0.8),
+            ).to("cuda")
 
             prompts = model_config.prompt
             n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt
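
For readers who don't want to chase the helper into animatediff/utils/util.py, below is a minimal sketch of what a load_weights() consolidating the deleted logic could look like. The signature is taken from the new call site above; the body is the removed code reorganized, with DreamBooth and LoRA loading split across the new dreambooth_model_path / lora_model_path arguments instead of being inferred from one model_config.path. motion_module_lora_configs has no counterpart in the removed code, so it is left as a stub. Treat this as an illustration, not the repository's actual implementation.

    # Sketch only: signature from the new call site, body from the deleted code.
    import torch
    from safetensors import safe_open

    from animatediff.utils.convert_from_ckpt import (
        convert_ldm_clip_checkpoint,
        convert_ldm_unet_checkpoint,
        convert_ldm_vae_checkpoint,
    )
    from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora


    def _read_safetensors(path):
        # Illustration-only helper: materialize a safetensors file as a dict.
        state_dict = {}
        with safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)
        return state_dict


    def load_weights(
        pipeline,
        # motion module
        motion_module_path="",
        motion_module_lora_configs=(),
        # image layers
        dreambooth_model_path="",
        lora_model_path="",
        lora_alpha=0.8,
    ):
        # 1. Motion module: inject temporal weights into the 3D UNet. The old
        #    script's "global_step" bookkeeping only fed logging, so it is
        #    omitted here.
        if motion_module_path != "":
            motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
            missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
            assert len(unexpected) == 0

        # 2. Personalized T2I checkpoint, converted from LDM to diffusers layout.
        if dreambooth_model_path != "":
            if dreambooth_model_path.endswith(".ckpt"):
                state_dict = torch.load(dreambooth_model_path, map_location="cpu")
                pipeline.unet.load_state_dict(state_dict)
            elif dreambooth_model_path.endswith(".safetensors"):
                state_dict = _read_safetensors(dreambooth_model_path)
                pipeline.vae.load_state_dict(
                    convert_ldm_vae_checkpoint(state_dict, pipeline.vae.config))
                # strict=False: the inflated 3D UNet has extra temporal layers.
                pipeline.unet.load_state_dict(
                    convert_ldm_unet_checkpoint(state_dict, pipeline.unet.config), strict=False)
                pipeline.text_encoder = convert_ldm_clip_checkpoint(state_dict)

        # 3. Image-layer LoRA, merged into the pipeline at the given alpha.
        if lora_model_path != "":
            pipeline = convert_lora(pipeline, _read_safetensors(lora_model_path), alpha=lora_alpha)

        # 4. Motion-module LoRA is new in this refactor and has no counterpart
        #    in the removed code; see animatediff/utils/util.py for the real handling.
        if motion_module_lora_configs:
            raise NotImplementedError("motion-module LoRA loading not sketched here")

        return pipeline

The net effect on animate.py is that all checkpoint plumbing (safetensors reading, LDM-to-diffusers conversion, LoRA merging) moves behind one call, and the per-model section of the script becomes purely declarative.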
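
Correspondingly, a per-model config entry now only names the artifacts. The keys below are the ones the new call site reads (plus prompt/n_prompt, consumed just after it); the motion_module key is assumed from the loop variable, and every path and prompt string is a placeholder invented for illustration:

    from omegaconf import OmegaConf

    # Hypothetical config entry; key names come from the call site above,
    # paths and prompts are placeholders.
    model_config = OmegaConf.create({
        "motion_module": "models/Motion_Module/mm.ckpt",               # placeholder path
        "motion_module_lora_configs": [],                              # optional, defaults to []
        "dreambooth_path": "models/DreamBooth_LoRA/base.safetensors",  # optional, "" skips
        "lora_model_path": "",                                         # optional, "" skips
        "lora_alpha": 0.8,                                             # LoRA merge strength
        "prompt": ["a dog running on the beach"],
        "n_prompt": ["low quality"],
    })

Because every new key is read with model_config.get(..., default), existing configs that predate this change keep working unchanged.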