@@ -0,0 +1,337 @@
+import json
+import os
+import random
+from datetime import datetime
+from glob import glob
+
+import gradio as gr
+import torch
+from omegaconf import OmegaConf
+from safetensors import safe_open
+
+from diffusers import AutoencoderKL
+from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
+from diffusers.utils.import_utils import is_xformers_available
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from animatediff.models.unet import UNet3DConditionModel
+from animatediff.pipelines.pipeline_animation import AnimationPipeline
+from animatediff.utils.util import save_videos_grid
+from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
+from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
+
+
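+# Global index used to name saved sample videos as <idx>.mp4 under the current session folder.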
+sample_idx = 0
+
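+# Samplers exposed in the UI, mapped to their diffusers scheduler classes.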
+scheduler_dict = {
+    "Euler": EulerDiscreteScheduler,
+    "PNDM": PNDMScheduler,
+    "DDIM": DDIMScheduler,
+}
+
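+# Minimal CSS so the refresh / random-seed buttons render as small square tool buttons.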
+css = """
+.toolbutton {
+    margin-bottom: 0em;
+    max-width: 2.5em;
+    min-width: 2.5em !important;
+    height: 2.5em;
+}
+"""
+
+
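+# Holds the currently loaded model components (tokenizer, text encoder, VAE, UNet, LoRA weights)
+# and implements the callbacks behind the Gradio controls.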
+class AnimateController:
+    def __init__(self):
+
+        # config dirs
+        self.basedir = os.getcwd()
+        self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
+        self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
+        self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
+        self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
+        self.savedir_sample = os.path.join(self.savedir, "sample")
+        os.makedirs(self.savedir, exist_ok=True)
+
+        self.stable_diffusion_list = []
+        self.motion_module_list = []
+        self.personalized_model_list = []
+
+        self.refresh_stable_diffusion()
+        self.refresh_motion_module()
+        self.refresh_personalized_model()
+
+        # config models
+        self.tokenizer = None
+        self.text_encoder = None
+        self.vae = None
+        self.unet = None
+        self.pipeline = None
+        self.lora_model_state_dict = {}
+
+        self.inference_config = OmegaConf.load("configs/inference/inference.yaml")
+
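+    # The refresh_* methods rescan the model folders so checkpoints added while the app
+    # is running show up in the corresponding dropdowns.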
+ def refresh_stable_diffusion(self):
+        self.stable_diffusion_list = glob(os.path.join(self.stable_diffusion_dir, "*/"))
+
+    def refresh_motion_module(self):
+        motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
+        self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
+
+    def refresh_personalized_model(self):
+        personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
+        self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
+
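+    # Loads the base Stable Diffusion components from the selected diffusers folder and
+    # inflates the 2D UNet into the AnimateDiff UNet3DConditionModel.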
+ def update_stable_diffusion(self, stable_diffusion_dropdown):
+        self.tokenizer = CLIPTokenizer.from_pretrained(stable_diffusion_dropdown, subfolder="tokenizer")
+        self.text_encoder = CLIPTextModel.from_pretrained(stable_diffusion_dropdown, subfolder="text_encoder").cuda()
+        self.vae = AutoencoderKL.from_pretrained(stable_diffusion_dropdown, subfolder="vae").cuda()
+        self.unet = UNet3DConditionModel.from_pretrained_2d(stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
+        return gr.Dropdown.update()
+
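+    # Merges the selected motion module weights into the already-loaded UNet; only the
+    # temporal (motion) layers should match, so any unexpected keys are treated as an error.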
+ def update_motion_module(self, motion_module_dropdown):
+        if self.unet is None:
+            gr.Info("Please select a pretrained model path.")
+            return gr.Dropdown.update(value=None)
+        else:
+            motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
+            motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
+            missing, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False)
+            assert len(unexpected) == 0
+            return gr.Dropdown.update()
+
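+    # Loads a personalized DreamBooth checkpoint (.safetensors in LDM layout), converts its
+    # VAE / UNet / CLIP weights to diffusers format, and overwrites the base model with them.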
+ def update_base_model(self, base_model_dropdown):
+        if self.unet is None:
+            gr.Info("Please select a pretrained model path.")
+            return gr.Dropdown.update(value=None)
+        else:
+            base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
+            base_model_state_dict = {}
+            with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
+                for key in f.keys():
+                    base_model_state_dict[key] = f.get_tensor(key)
+
+            converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_model_state_dict, self.vae.config)
+            self.vae.load_state_dict(converted_vae_checkpoint)
+
+            converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_model_state_dict, self.unet.config)
+            self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
+
+            self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
+            return gr.Dropdown.update()
+
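+    # Caches the selected LoRA weights; they are merged into the pipeline (scaled by the
+    # alpha slider) at generation time rather than here.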
+ def update_lora_model(self, lora_model_dropdown):
+        self.lora_model_state_dict = {}
+        if lora_model_dropdown == "none": pass
+        else:
+            lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
+            with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f:
+                for key in f.keys():
+                    self.lora_model_state_dict[key] = f.get_tensor(key)
+        return gr.Dropdown.update()
+
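+    # Builds an AnimationPipeline from the loaded components, applies any cached LoRA weights,
+    # runs sampling, saves the result as an .mp4 grid, and logs the generation settings.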
+ def animate(
+        self,
+        stable_diffusion_dropdown,
+        motion_module_dropdown,
+        base_model_dropdown,
+        lora_alpha_slider,
+        prompt_textbox,
+        negative_prompt_textbox,
+        sampler_dropdown,
+        sample_step_slider,
+        width_slider,
+        length_slider,
+        height_slider,
+        cfg_scale_slider,
+        seed_textbox
+    ):
+        if self.unet is None:
+            raise gr.Error("Please select a pretrained model path.")
+        if motion_module_dropdown == "":
+            raise gr.Error("Please select a motion module.")
+        if base_model_dropdown == "":
+            raise gr.Error("Please select a base DreamBooth model.")
+
+        if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
+
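+        # The pipeline is rebuilt on every call so the scheduler matches the sampler
+        # currently selected in the UI.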
+ pipeline = AnimationPipeline(
+            vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
+            scheduler=scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
+        ).to("cuda")
+
+        if self.lora_model_state_dict != {}:
+            pipeline = convert_lora(pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)
+
+        pipeline.to("cuda")
+
+        # seed_textbox arrives as a string; anything other than "" or "-1" is used as a fixed seed
+        if seed_textbox != "-1" and seed_textbox != "": torch.manual_seed(int(seed_textbox))
+        else: torch.seed()
+        seed = torch.initial_seed()
+
+        sample = pipeline(
+            prompt_textbox,
+            negative_prompt = negative_prompt_textbox,
+            num_inference_steps = sample_step_slider,
+            guidance_scale = cfg_scale_slider,
+            width = width_slider,
+            height = height_slider,
+            video_length = length_slider,
+        ).videos
+
+        global sample_idx
+        save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4")
+        save_videos_grid(sample, save_sample_path)
+        # advance the global counter so the next generation gets a new filename
+        sample_idx += 1
+
+        sample_config = {
+            "prompt": prompt_textbox,
+            "n_prompt": negative_prompt_textbox,
+            "sampler": sampler_dropdown,
+            "num_inference_steps": sample_step_slider,
+            "guidance_scale": cfg_scale_slider,
+            "width": width_slider,
+            "height": height_slider,
+            "video_length": length_slider,
+            "seed": seed
+        }
+        json_str = json.dumps(sample_config, indent=4)
+        with open(os.path.join(self.savedir, "logs.json"), "a") as f:
+            f.write(json_str)
+            f.write("\n\n")
+
+ return gr.Video.update(value=save_sample_path)
+
+
+controller = AnimateController()
+
+
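+# Assembles the Gradio Blocks UI: checkpoint selection dropdowns, sampling controls,
+# and the output video, wired to the controller callbacks above.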
+def ui():
+    with gr.Blocks(css=css) as demo:
+        gr.Markdown(
+            """
+            # [AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725)
+            Yuwei Guo, Ceyuan Yang*, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai (*Corresponding Author)<br>
+            [Arxiv Report](https://arxiv.org/abs/2307.04725) | [Project Page](https://animatediff.github.io/) | [Github](https://github.com/guoyww/animatediff/)
+            """
+        )
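+        # Section 1: checkpoint selection (base Stable Diffusion, motion module, DreamBooth / LoRA models).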
+ with gr.Column(variant="panel"):
+            gr.Markdown(
+                """
+                ### 1. Model checkpoints (select pretrained model path first).
+                """
+            )
+            with gr.Row():
+                stable_diffusion_dropdown = gr.Dropdown(
+                    label="Pretrained Model Path",
+                    choices=controller.stable_diffusion_list,
+                    interactive=True,
+                )
+                stable_diffusion_dropdown.change(fn=controller.update_stable_diffusion, inputs=[stable_diffusion_dropdown], outputs=[stable_diffusion_dropdown])
+
+                stable_diffusion_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
+                def update_stable_diffusion():
+                    controller.refresh_stable_diffusion()
+                    return gr.Dropdown.update(choices=controller.stable_diffusion_list)
+                stable_diffusion_refresh_button.click(fn=update_stable_diffusion, inputs=[], outputs=[stable_diffusion_dropdown])
+
+            with gr.Row():
+                motion_module_dropdown = gr.Dropdown(
+                    label="Select motion module",
+                    choices=controller.motion_module_list,
+                    interactive=True,
+                )
+                motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown])
+
+                motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
+                def update_motion_module():
+                    controller.refresh_motion_module()
+                    return gr.Dropdown.update(choices=controller.motion_module_list)
+                motion_module_refresh_button.click(fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])
+
+                base_model_dropdown = gr.Dropdown(
+                    label="Select base Dreambooth model (required)",
+                    choices=controller.personalized_model_list,
+                    interactive=True,
+                )
+                base_model_dropdown.change(fn=controller.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown])
+
+                lora_model_dropdown = gr.Dropdown(
+                    label="Select LoRA model (optional)",
+                    choices=["none"] + controller.personalized_model_list,
+                    value="none",
+                    interactive=True,
+                )
+                lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[lora_model_dropdown], outputs=[lora_model_dropdown])
+
+                lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)
+
+                personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
+                def update_personalized_model():
+                    controller.refresh_personalized_model()
+                    return [
+                        gr.Dropdown.update(choices=controller.personalized_model_list),
+                        gr.Dropdown.update(choices=["none"] + controller.personalized_model_list)
+                    ]
+                personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
+
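+        # Section 2: sampling configuration (prompt, sampler, resolution, length, CFG, seed) and output video.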
+ with gr.Column(variant="panel"):
+            gr.Markdown(
+                """
+                ### 2. Configs for AnimateDiff.
+                """
+            )
+
+            prompt_textbox = gr.Textbox(label="Prompt", lines=2)
+            negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2)
+
+            with gr.Row().style(equal_height=False):
+                with gr.Column():
+                    with gr.Row():
+                        sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
+                        sample_step_slider = gr.Slider(label="Sampling steps", value=25, minimum=10, maximum=100, step=1)
+
+                    width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
+                    height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
+                    length_slider = gr.Slider(label="Animation length", value=16, minimum=8, maximum=24, step=1)
+                    cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.5, minimum=0, maximum=20)
+
+                    with gr.Row():
+                        seed_textbox = gr.Textbox(label="Seed", value=-1)
+                        seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
+                        seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])
+
+                    generate_button = gr.Button(value="Generate", variant='primary')
+
+                result_video = gr.Video(label="Generated Animation", interactive=False)
+
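+            # Input order here must match the signature of AnimateController.animate above.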
+ generate_button.click(
+                fn=controller.animate,
+                inputs=[
+                    stable_diffusion_dropdown,
+                    motion_module_dropdown,
+                    base_model_dropdown,
+                    lora_alpha_slider,
+                    prompt_textbox,
+                    negative_prompt_textbox,
+                    sampler_dropdown,
+                    sample_step_slider,
+                    width_slider,
+                    length_slider,
+                    height_slider,
+                    cfg_scale_slider,
+                    seed_textbox,
+                ],
+                outputs=[result_video]
+            )
+
+    return demo
+
+
+if __name__ == "__main__":
+    demo = ui()
+    demo.launch(share=True)