Browse Source

update infer script

Yuwei Guo 2 years ago
parent
commit
069dd9432f

+ 26 - 0
animatediff/utils/convert_lora_safetensor_to_diffusers.py

@@ -23,6 +23,32 @@ from safetensors.torch import load_file
 from diffusers import StableDiffusionPipeline
 import pdb
 
+
+
+def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
+    # directly update weight in diffusers model
+    for key in state_dict:
+        # only process lora down key
+        if "up." in key: continue
+
+        up_key    = key.replace(".down.", ".up.")
+        model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
+        model_key = model_key.replace("to_out.", "to_out.0.")
+        layer_infos = model_key.split(".")[:-1]
+
+        curr_layer = pipeline.unet
+        while len(layer_infos) > 0:
+            temp_name = layer_infos.pop(0)
+            curr_layer = curr_layer.__getattr__(temp_name)
+
+        weight_down = state_dict[key]
+        weight_up   = state_dict[up_key]
+        curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
+
+    return pipeline
+
+
+
 def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
     # load base model
     # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)

+ 68 - 0
animatediff/utils/util.py

@@ -7,8 +7,11 @@ import torch
 import torchvision
 import torch.distributed as dist
 
+from safetensors import safe_open
 from tqdm import tqdm
 from einops import rearrange
+from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
+from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, convert_motion_lora_ckpt_to_diffusers
 
 
 def zero_rank_print(s):
@@ -87,3 +90,68 @@ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
 def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
     ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
     return ddim_latents
+
+def load_weights(
+    animation_pipeline,
+    # motion module
+    motion_module_path         = "",
+    motion_module_lora_configs = [],
+    # image layers
+    dreambooth_model_path = "",
+    lora_model_path       = "",
+    lora_alpha            = 0.8,
+):
+    # 1.1 motion module
+    unet_state_dict = {}
+    if motion_module_path != "":
+        print(f"load motion module from {motion_module_path}")
+        motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
+        motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
+        unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
+    
+    missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
+    assert len(unexpected) == 0
+    del unet_state_dict
+
+    if dreambooth_model_path != "":
+        print(f"load dreambooth model from {dreambooth_model_path}")
+        if dreambooth_model_path.endswith(".safetensors"):
+            dreambooth_state_dict = {}
+            with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
+                for key in f.keys():
+                    dreambooth_state_dict[key] = f.get_tensor(key)
+        elif dreambooth_model_path.endswith(".ckpt"):
+            dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu")
+            
+        # 1. vae
+        converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config)
+        animation_pipeline.vae.load_state_dict(converted_vae_checkpoint)
+        # 2. unet
+        converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config)
+        animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
+        # 3. text_model
+        animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
+        del dreambooth_state_dict
+        
+    if lora_model_path != "":
+        print(f"load lora model from {lora_model_path}")
+        assert lora_model_path.endswith(".safetensors")
+        lora_state_dict = {}
+        with safe_open(lora_model_path, framework="pt", device="cpu") as f:
+            for key in f.keys():
+                lora_state_dict[key] = f.get_tensor(key)
+                
+        animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha)
+        del lora_state_dict
+
+
+    for motion_module_lora_config in motion_module_lora_configs:
+        path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
+        print(f"load motion LoRA from {path}")
+
+        motion_lora_state_dict = torch.load(path, map_location="cpu")
+        motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
+
+        animation_pipeline = convert_motion_lora_ckpt_to_diffusers(animation_pipeline, motion_lora_state_dict, alpha)
+
+    return animation_pipeline