Yuwei Guo 2 роки тому
батько
коміт
d09e5cfa53
1 змінених файлів з 11 додано та 47 видалено
  1. 11 47
      scripts/animate.py

+ 11 - 47
scripts/animate.py

@@ -15,14 +15,12 @@ from transformers import CLIPTextModel, CLIPTokenizer
 from animatediff.models.unet import UNet3DConditionModel
 from animatediff.pipelines.pipeline_animation import AnimationPipeline
 from animatediff.utils.util import save_videos_grid
-from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
-from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
+from animatediff.utils.util import load_weights
 from diffusers.utils.import_utils import is_xformers_available
 
 from einops import rearrange, repeat
 
 import csv, pdb, glob
-from safetensors import safe_open
 import math
 from pathlib import Path
 
@@ -60,50 +58,16 @@ def main(args):
                 scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
             ).to("cuda")
 
-            # 1. unet ckpt
-            # 1.1 motion module
-            motion_module_state_dict = torch.load(motion_module, map_location="cpu")
-            if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
-            missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
-            assert len(unexpected) == 0
-            
-            # 1.2 T2I
-            if model_config.path != "":
-                if model_config.path.endswith(".ckpt"):
-                    state_dict = torch.load(model_config.path)
-                    pipeline.unet.load_state_dict(state_dict)
-                    
-                elif model_config.path.endswith(".safetensors"):
-                    state_dict = {}
-                    with safe_open(model_config.path, framework="pt", device="cpu") as f:
-                        for key in f.keys():
-                            state_dict[key] = f.get_tensor(key)
-                            
-                    is_lora = all("lora" in k for k in state_dict.keys())
-                    if not is_lora:
-                        base_state_dict = state_dict
-                    else:
-                        base_state_dict = {}
-                        with safe_open(model_config.base, framework="pt", device="cpu") as f:
-                            for key in f.keys():
-                                base_state_dict[key] = f.get_tensor(key)                
-                    
-                    # vae
-                    converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
-                    pipeline.vae.load_state_dict(converted_vae_checkpoint)
-                    # unet
-                    converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
-                    pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
-                    # text_model
-                    pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)
-                    
-                    # import pdb
-                    # pdb.set_trace()
-                    if is_lora:
-                        pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)
-
-            pipeline.to("cuda")
-            ### <<< create validation pipeline <<< ###
+            pipeline = load_weights(
+                pipeline,
+                # motion module
+                motion_module_path         = motion_module,
+                motion_module_lora_configs = model_config.get("motion_module_lora_configs", []),
+                # image layers
+                dreambooth_model_path      = model_config.get("dreambooth_path", ""),
+                lora_model_path            = model_config.get("lora_model_path", ""),
+                lora_alpha                 = model_config.get("lora_alpha", 0.8),
+            ).to("cuda")
 
             prompts      = model_config.prompt
             n_prompts    = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt