Yuwei Guo 2 éve
szülő
commit
ebfd7b74f7
3 módosított fájl, 99 hozzáadás és 87 törlés
  1. 5 3
      configs/prompts/1-ToonYou.yaml
  2. 2 1
      requirements.txt
  3. 92 83
      scripts/animate.py

+ 5 - 3
configs/prompts/1-ToonYou.yaml

@@ -1,7 +1,9 @@
 ToonYou:
-  base:           ""
-  path:           "models/DreamBooth_LoRA/toonyou_beta3.safetensors"
-  motion_module:  "models/Motion_Module/mm_sd_v14.ckpt"
+  base: ""
+  path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors"
+  motion_module:
+    - "models/Motion_Module/mm_sd_v14.ckpt"
+    - "models/Motion_Module/mm_sd_v15.ckpt"
 
   seed:           [10788741199826055526, 6520604954829636163, 6519455744612555650, 16372571278361863751]
   steps:          25

+ 2 - 1
requirements.txt

@@ -3,7 +3,8 @@ torch==1.12.1+cu113
 torchvision==0.13.1+cu113
 diffusers[torch]==0.11.1
 transformers==4.25.1
+imageio==2.27.0
+gdown
 einops
 omegaconf
 safetensors
-imageio==2.27.0

+ 92 - 83
scripts/animate.py

@@ -37,94 +37,103 @@ def main(args):
 
     config  = OmegaConf.load(args.config)
     samples = []
+    
+    sample_idx = 0
     for model_idx, (config_key, model_config) in enumerate(list(config.items())):
-        ### >>> create validation pipeline >>> ###
-        tokenizer    = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
-        text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
-        vae          = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")            
-        unet         = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
-
-        pipeline = AnimationPipeline(
-            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
-            scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
-        ).to("cuda")
-
-        # 1. unet ckpt
-        # 1.1 motion module
-        motion_module_state_dict = torch.load(model_config.motion_module, map_location="cpu")
-        if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
-        missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
-        assert len(unexpected) == 0
         
-        # 1.2 T2I
-        if model_config.path != "":
-            if model_config.path.endswith(".ckpt"):
-                state_dict = torch.load(model_config.path)
-                pipeline.unet.load_state_dict(state_dict)
-                
-            elif model_config.path.endswith(".safetensors"):
-                state_dict = {}
-                with safe_open(model_config.path, framework="pt", device="cpu") as f:
-                    for key in f.keys():
-                        state_dict[key] = f.get_tensor(key)
-                        
-                is_lora = all("lora" in k for k in state_dict.keys())
-                if not is_lora:
-                    base_state_dict = state_dict
-                else:
-                    base_state_dict = {}
-                    with safe_open(model_config.base, framework="pt", device="cpu") as f:
-                        for key in f.keys():
-                            base_state_dict[key] = f.get_tensor(key)                
-                
-                # vae
-                converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
-                pipeline.vae.load_state_dict(converted_vae_checkpoint)
-                # unet
-                converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
-                pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
-                # text_model
-                pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)
-                
-                # import pdb
-                # pdb.set_trace()
-                if is_lora:
-                    pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)
-
-        pipeline.to("cuda")
-        ### <<< create validation pipeline <<< ###
-
-        prompts      = model_config.prompt
-        n_prompts    = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt
+        motion_modules = model_config.motion_module
+        motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)
+        for motion_module in motion_modules:
         
-        random_seeds = model_config.pop("seed", [-1])
-        random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
-        random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
-        
-        config[config_key].random_seed = []
-        for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
+            ### >>> create validation pipeline >>> ###
+            tokenizer    = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
+            text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
+            vae          = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")            
+            unet         = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
+
+            pipeline = AnimationPipeline(
+                vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
+                scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
+            ).to("cuda")
+
+            # 1. unet ckpt
+            # 1.1 motion module
+            motion_module_state_dict = torch.load(motion_module, map_location="cpu")
+            if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
+            missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
+            assert len(unexpected) == 0
+            
+            # 1.2 T2I
+            if model_config.path != "":
+                if model_config.path.endswith(".ckpt"):
+                    state_dict = torch.load(model_config.path)
+                    pipeline.unet.load_state_dict(state_dict)
+                    
+                elif model_config.path.endswith(".safetensors"):
+                    state_dict = {}
+                    with safe_open(model_config.path, framework="pt", device="cpu") as f:
+                        for key in f.keys():
+                            state_dict[key] = f.get_tensor(key)
+                            
+                    is_lora = all("lora" in k for k in state_dict.keys())
+                    if not is_lora:
+                        base_state_dict = state_dict
+                    else:
+                        base_state_dict = {}
+                        with safe_open(model_config.base, framework="pt", device="cpu") as f:
+                            for key in f.keys():
+                                base_state_dict[key] = f.get_tensor(key)                
+                    
+                    # vae
+                    converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
+                    pipeline.vae.load_state_dict(converted_vae_checkpoint)
+                    # unet
+                    converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
+                    pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
+                    # text_model
+                    pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)
+                    
+                    # import pdb
+                    # pdb.set_trace()
+                    if is_lora:
+                        pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)
+
+            pipeline.to("cuda")
+            ### <<< create validation pipeline <<< ###
+
+            prompts      = model_config.prompt
+            n_prompts    = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt
             
-            # manually set random seed for reproduction
-            if random_seed != -1: torch.manual_seed(random_seed)
-            else: torch.seed()
-            config[config_key].random_seed.append(torch.initial_seed())
+            random_seeds = model_config.get("seed", [-1])
+            random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
+            random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
             
-            print(f"current seed: {torch.initial_seed()}")
-            print(f"sampling {prompt} ...")
-            sample = pipeline(
-                prompt,
-                negative_prompt     = n_prompt,
-                num_inference_steps = model_config.steps,
-                guidance_scale      = model_config.guidance_scale,
-                width               = args.W,
-                height              = args.H,
-                video_length        = args.L,
-            ).videos
-            samples.append(sample)
-
-            prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
-            save_videos_grid(sample, f"{savedir}/sample/{model_idx}-{prompt_idx}-{prompt}.gif")
-            print(f"save to {savedir}/sample/{prompt}.gif")
+            config[config_key].random_seed = []
+            for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
+                
+                # manually set random seed for reproduction
+                if random_seed != -1: torch.manual_seed(random_seed)
+                else: torch.seed()
+                config[config_key].random_seed.append(torch.initial_seed())
+                
+                print(f"current seed: {torch.initial_seed()}")
+                print(f"sampling {prompt} ...")
+                sample = pipeline(
+                    prompt,
+                    negative_prompt     = n_prompt,
+                    num_inference_steps = model_config.steps,
+                    guidance_scale      = model_config.guidance_scale,
+                    width               = args.W,
+                    height              = args.H,
+                    video_length        = args.L,
+                ).videos
+                samples.append(sample)
+
+                prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
+                save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif")
+                print(f"save to {savedir}/sample/{prompt}.gif")
+                
+                sample_idx += 1
 
     samples = torch.concat(samples)
     save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)