@@ -37,94 +37,103 @@ def main(args):

     config = OmegaConf.load(args.config)
     samples = []
+
+    sample_idx = 0
     for model_idx, (config_key, model_config) in enumerate(list(config.items())):
-        ### >>> create validation pipeline >>> ###
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
-        text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
-        vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
-        unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
-
-        pipeline = AnimationPipeline(
-            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
-            scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
-        ).to("cuda")
-
-        # 1. unet ckpt
-        # 1.1 motion module
-        motion_module_state_dict = torch.load(model_config.motion_module, map_location="cpu")
-        if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
-        missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
-        assert len(unexpected) == 0
-
-        # 1.2 T2I
-        if model_config.path != "":
-            if model_config.path.endswith(".ckpt"):
-                state_dict = torch.load(model_config.path)
-                pipeline.unet.load_state_dict(state_dict)
-
-            elif model_config.path.endswith(".safetensors"):
-                state_dict = {}
-                with safe_open(model_config.path, framework="pt", device="cpu") as f:
-                    for key in f.keys():
-                        state_dict[key] = f.get_tensor(key)
-
-                is_lora = all("lora" in k for k in state_dict.keys())
-                if not is_lora:
-                    base_state_dict = state_dict
-                else:
-                    base_state_dict = {}
-                    with safe_open(model_config.base, framework="pt", device="cpu") as f:
-                        for key in f.keys():
-                            base_state_dict[key] = f.get_tensor(key)
-
-                # vae
-                converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
-                pipeline.vae.load_state_dict(converted_vae_checkpoint)
-                # unet
-                converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
-                pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
-                # text_model
-                pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)
-
-                # import pdb
-                # pdb.set_trace()
-                if is_lora:
-                    pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)
-
-        pipeline.to("cuda")
-        ### <<< create validation pipeline <<< ###
-
-        prompts = model_config.prompt
-        n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt
+        motion_modules = model_config.motion_module
+        motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)
+        for motion_module in motion_modules:

-        random_seeds = model_config.pop("seed", [-1])
-        random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
-        random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
-
-        config[config_key].random_seed = []
-        for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
+            ### >>> create validation pipeline >>> ###
+            tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
+            text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
+            vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
+            unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
+
+            pipeline = AnimationPipeline(
+                vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
+                scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
+            ).to("cuda")
+
+            # 1. unet ckpt
+            # 1.1 motion module
+            motion_module_state_dict = torch.load(motion_module, map_location="cpu")
+            if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
+            missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
+            assert len(unexpected) == 0
+
+            # 1.2 T2I
+            if model_config.path != "":
+                if model_config.path.endswith(".ckpt"):
+                    state_dict = torch.load(model_config.path)
+                    pipeline.unet.load_state_dict(state_dict)
+
+                elif model_config.path.endswith(".safetensors"):
+                    state_dict = {}
+                    with safe_open(model_config.path, framework="pt", device="cpu") as f:
+                        for key in f.keys():
+                            state_dict[key] = f.get_tensor(key)
+
+                    is_lora = all("lora" in k for k in state_dict.keys())
+                    if not is_lora:
+                        base_state_dict = state_dict
+                    else:
+                        base_state_dict = {}
+                        with safe_open(model_config.base, framework="pt", device="cpu") as f:
+                            for key in f.keys():
+                                base_state_dict[key] = f.get_tensor(key)
+
+                    # vae
+                    converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
+                    pipeline.vae.load_state_dict(converted_vae_checkpoint)
+                    # unet
+                    converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
+                    pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
+                    # text_model
+                    pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)
+
+                    # import pdb
+                    # pdb.set_trace()
+                    if is_lora:
+                        pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)
+
+            pipeline.to("cuda")
+            ### <<< create validation pipeline <<< ###
+
+            prompts = model_config.prompt
+            n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt

-            # manually set random seed for reproduction
-            if random_seed != -1: torch.manual_seed(random_seed)
-            else: torch.seed()
-            config[config_key].random_seed.append(torch.initial_seed())
+            random_seeds = model_config.get("seed", [-1])
+            random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
+            random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds

-            print(f"current seed: {torch.initial_seed()}")
-            print(f"sampling {prompt} ...")
-            sample = pipeline(
-                prompt,
-                negative_prompt = n_prompt,
-                num_inference_steps = model_config.steps,
-                guidance_scale = model_config.guidance_scale,
-                width = args.W,
-                height = args.H,
-                video_length = args.L,
-            ).videos
-            samples.append(sample)
-
-            prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
-            save_videos_grid(sample, f"{savedir}/sample/{model_idx}-{prompt_idx}-{prompt}.gif")
-            print(f"save to {savedir}/sample/{prompt}.gif")
+            config[config_key].random_seed = []
+            for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
+
+                # manually set random seed for reproduction
+                if random_seed != -1: torch.manual_seed(random_seed)
+                else: torch.seed()
+                config[config_key].random_seed.append(torch.initial_seed())
+
+                print(f"current seed: {torch.initial_seed()}")
+                print(f"sampling {prompt} ...")
+                sample = pipeline(
+                    prompt,
+                    negative_prompt = n_prompt,
+                    num_inference_steps = model_config.steps,
+                    guidance_scale = model_config.guidance_scale,
+                    width = args.W,
+                    height = args.H,
+                    video_length = args.L,
+                ).videos
+                samples.append(sample)
+
+                prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
+                save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif")
+                print(f"save to {savedir}/sample/{sample_idx}-{prompt}.gif")
+
+                sample_idx += 1

     samples = torch.concat(samples)
     save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)
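
For reference, the change above reduces to the following control flow: `motion_module` in the config may now be a single path or a list of paths, the validation pipeline is rebuilt once per motion module, `seed` is read with `.get()` instead of `.pop()` (so the entry stays in the config object, alongside the realized seeds recorded in `config[config_key].random_seed`), and a global `sample_idx` keeps output filenames unique across models, motion modules, and prompts. Below is a minimal runnable sketch of that flow; the config values and the `generate()` stub are illustrative placeholders, not part of the patch.

    from omegaconf import OmegaConf

    # Hypothetical config mirroring the YAML layout the script consumes.
    config = OmegaConf.create({
        "example-model": {
            "motion_module": ["mm_sd_v14.ckpt", "mm_sd_v15.ckpt"],  # str or list of paths
            "seed": 42,                                             # int or list of ints
            "prompt": ["a photo of a corgi", "a watercolor landscape"],
        }
    })

    def generate(motion_module, prompt, seed):
        # Placeholder standing in for pipeline construction + sampling.
        return f"{motion_module} | {prompt} | seed={seed}"

    sample_idx = 0
    for config_key, model_config in config.items():
        # "motion_module" may be a single path or a list; normalize to a list
        # and (in the real script) rebuild the pipeline once per entry.
        motion_modules = model_config.motion_module
        motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)
        for motion_module in motion_modules:
            # .get() rather than .pop(): the "seed" entry stays in the config object.
            random_seeds = model_config.get("seed", [-1])
            random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
            random_seeds = random_seeds * len(model_config.prompt) if len(random_seeds) == 1 else random_seeds
            for prompt, seed in zip(model_config.prompt, random_seeds):
                print(f"{sample_idx}: {generate(motion_module, prompt, seed)}")
                sample_idx += 1  # global counter keeps filenames unique across modules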