import os
import json
import torch
import random

import gradio as gr
from glob import glob
from omegaconf import OmegaConf
from datetime import datetime
from safetensors import safe_open

from diffusers import AutoencoderKL
from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer

from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
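
# Global index for naming saved sample videos, plus the mapping from sampler
# names shown in the UI to their diffusers scheduler classes.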
sample_idx = 0
scheduler_dict = {
    "Euler": EulerDiscreteScheduler,
    "PNDM": PNDMScheduler,
    "DDIM": DDIMScheduler,
}
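
# Minimal CSS so the refresh and dice buttons render as small square tool buttons.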
- css = """
- .toolbutton {
- margin-buttom: 0em 0em 0em 0em;
- max-width: 2.5em;
- min-width: 2.5em !important;
- height: 2.5em;
- }
- """
class AnimateController:
    def __init__(self):

        # config dirs
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
        self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
        self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
        self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        self.savedir_sample = os.path.join(self.savedir, "sample")
        os.makedirs(self.savedir, exist_ok=True)

        self.stable_diffusion_list = []
        self.motion_module_list = []
        self.personalized_model_list = []

        self.refresh_stable_diffusion()
        self.refresh_motion_module()
        self.refresh_personalized_model()

        # config models
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.pipeline = None
        self.lora_model_state_dict = {}

        self.inference_config = OmegaConf.load("configs/inference/inference.yaml")
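
    # The refresh_* helpers rescan the local model folders so newly added
    # checkpoints appear in the dropdowns without restarting the app.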
    def refresh_stable_diffusion(self):
        self.stable_diffusion_list = glob(os.path.join(self.stable_diffusion_dir, "*/"))

    def refresh_motion_module(self):
        motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
        self.motion_module_list = [os.path.basename(p) for p in motion_module_list]

    def refresh_personalized_model(self):
        personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
        self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
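
    # Loads the tokenizer, text encoder, and VAE from a diffusers-format folder
    # and inflates the 2D UNet into the 3D UNet3DConditionModel that AnimateDiff
    # animates.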
    def update_stable_diffusion(self, stable_diffusion_dropdown):
        self.tokenizer = CLIPTokenizer.from_pretrained(stable_diffusion_dropdown, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(stable_diffusion_dropdown, subfolder="text_encoder").cuda()
        self.vae = AutoencoderKL.from_pretrained(stable_diffusion_dropdown, subfolder="vae").cuda()
        self.unet = UNet3DConditionModel.from_pretrained_2d(stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
        return gr.Dropdown.update()
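
    # Loads the temporal (motion) weights into the inflated UNet. strict=False is
    # intentional: a motion module only provides the temporal layers, so missing
    # keys are expected, while unexpected keys indicate a wrong checkpoint.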
    def update_motion_module(self, motion_module_dropdown):
        if self.unet is None:
            gr.Info("Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)
        else:
            motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
            motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
            missing, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False)
            assert len(unexpected) == 0
            return gr.Dropdown.update()
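
    # Converts an LDM-format DreamBooth checkpoint (.safetensors) to diffusers
    # format and loads it into the current VAE, UNet, and text encoder.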
    def update_base_model(self, base_model_dropdown):
        if self.unet is None:
            gr.Info("Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)
        else:
            base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
            base_model_state_dict = {}
            with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
                for key in f.keys():
                    base_model_state_dict[key] = f.get_tensor(key)

            converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_model_state_dict, self.vae.config)
            self.vae.load_state_dict(converted_vae_checkpoint)

            converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_model_state_dict, self.unet.config)
            self.unet.load_state_dict(converted_unet_checkpoint, strict=False)

            self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
            return gr.Dropdown.update()
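
    # Caches the raw LoRA tensors; they are merged into the pipeline weights
    # later, in animate(), via convert_lora().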
    def update_lora_model(self, lora_model_dropdown):
        self.lora_model_state_dict = {}
        # Check for "none" before joining with the model dir, so the
        # "no LoRA selected" case is actually reachable.
        if lora_model_dropdown != "none":
            lora_model_path = os.path.join(self.personalized_model_dir, lora_model_dropdown)
            with safe_open(lora_model_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    self.lora_model_state_dict[key] = f.get_tensor(key)
        return gr.Dropdown.update()
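
    # Main generation entry point: builds an AnimationPipeline from the loaded
    # components, optionally merges the cached LoRA weights, samples a video,
    # saves it, and appends the run's config to logs.json.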
    def animate(
        self,
        stable_diffusion_dropdown,
        motion_module_dropdown,
        base_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        width_slider,
        length_slider,
        height_slider,
        cfg_scale_slider,
        seed_textbox,
    ):
        if self.unet is None:
            raise gr.Error("Please select a pretrained model path.")
        if motion_module_dropdown == "":
            raise gr.Error("Please select a motion module.")
        if base_model_dropdown == "":
            raise gr.Error("Please select a base DreamBooth model.")

        if is_xformers_available():
            self.unet.enable_xformers_memory_efficient_attention()
        pipeline = AnimationPipeline(
            vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
            scheduler=scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
        ).to("cuda")

        if self.lora_model_state_dict != {}:
            pipeline = convert_lora(pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)
        pipeline.to("cuda")

        # The textbox delivers a string; "" and "-1" both mean "use a random seed".
        if seed_textbox != "" and int(seed_textbox) != -1:
            torch.manual_seed(int(seed_textbox))
        else:
            torch.seed()
        seed = torch.initial_seed()

        sample = pipeline(
            prompt_textbox,
            negative_prompt     = negative_prompt_textbox,
            num_inference_steps = sample_step_slider,
            guidance_scale      = cfg_scale_slider,
            width               = width_slider,
            height              = height_slider,
            video_length        = length_slider,
        ).videos

        # Bump the global sample index so successive runs don't overwrite each other.
        global sample_idx
        sample_idx += 1
        save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4")
        save_videos_grid(sample, save_sample_path)

        sample_config = {
            "prompt": prompt_textbox,
            "n_prompt": negative_prompt_textbox,
            "sampler": sampler_dropdown,
            "num_inference_steps": sample_step_slider,
            "guidance_scale": cfg_scale_slider,
            "width": width_slider,
            "height": height_slider,
            "video_length": length_slider,
            "seed": seed,
        }
        json_str = json.dumps(sample_config, indent=4)
        with open(os.path.join(self.savedir, "logs.json"), "a") as f:
            f.write(json_str)
            f.write("\n\n")

        return gr.Video.update(value=save_sample_path)


controller = AnimateController()

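
# Builds the Gradio Blocks UI: checkpoint selection, sampling configs, and the
# output video panel. Written against the Gradio 3.x API (gr.*.update, .style).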
def ui():
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # [AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725)
            Yuwei Guo, Ceyuan Yang*, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai (*Corresponding Author)<br>
            [arXiv Report](https://arxiv.org/abs/2307.04725) | [Project Page](https://animatediff.github.io/) | [GitHub](https://github.com/guoyww/animatediff/)
            """
        )
- with gr.Column(variant="panel"):
- gr.Markdown(
- """
- ### 1. Model checkpoints (select pretrained model path first).
- """
- )
- with gr.Row():
- stable_diffusion_dropdown = gr.Dropdown(
- label="Pretrained Model Path",
- choices=controller.stable_diffusion_list,
- interactive=True,
- )
- stable_diffusion_dropdown.change(fn=controller.update_stable_diffusion, inputs=[stable_diffusion_dropdown], outputs=[stable_diffusion_dropdown])
-
- stable_diffusion_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
- def update_stable_diffusion():
- controller.refresh_stable_diffusion()
- return gr.Dropdown.update(choices=controller.stable_diffusion_list)
- stable_diffusion_refresh_button.click(fn=update_stable_diffusion, inputs=[], outputs=[stable_diffusion_dropdown])
            with gr.Row():
                motion_module_dropdown = gr.Dropdown(
                    label="Select motion module",
                    choices=controller.motion_module_list,
                    interactive=True,
                )
                motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown])

                motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
                def update_motion_module():
                    controller.refresh_motion_module()
                    return gr.Dropdown.update(choices=controller.motion_module_list)
                motion_module_refresh_button.click(fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])

            base_model_dropdown = gr.Dropdown(
                label="Select base DreamBooth model (required)",
                choices=controller.personalized_model_list,
                interactive=True,
            )
            base_model_dropdown.change(fn=controller.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown])

            lora_model_dropdown = gr.Dropdown(
                label="Select LoRA model (optional)",
                choices=["none"] + controller.personalized_model_list,
                value="none",
                interactive=True,
            )
            lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[lora_model_dropdown], outputs=[lora_model_dropdown])

            lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)

            personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
            def update_personalized_model():
                controller.refresh_personalized_model()
                return [
                    gr.Dropdown.update(choices=controller.personalized_model_list),
                    gr.Dropdown.update(choices=["none"] + controller.personalized_model_list),
                ]
            personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
- with gr.Column(variant="panel"):
- gr.Markdown(
- """
- ### 2. Configs for AnimateDiff.
- """
- )
-
- prompt_textbox = gr.Textbox(label="Prompt", lines=2)
- negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2)
-
- with gr.Row().style(equal_height=False):
- with gr.Column():
- with gr.Row():
- sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
- sample_step_slider = gr.Slider(label="Sampling steps", value=25, minimum=10, maximum=100, step=1)
-
- width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
- height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
- length_slider = gr.Slider(label="Animation length", value=16, minimum=8, maximum=24, step=1)
- cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.5, minimum=0, maximum=20)
-
- with gr.Row():
- seed_textbox = gr.Textbox(label="Seed", value=-1)
- seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                        seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])

                    generate_button = gr.Button(value="Generate", variant="primary")

                result_video = gr.Video(label="Generated Animation", interactive=False)
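
            # Wire the Generate button to AnimateController.animate; the inputs
            # list must match animate()'s parameter order exactly (note the
            # width/length/height ordering).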
            generate_button.click(
                fn=controller.animate,
                inputs=[
                    stable_diffusion_dropdown,
                    motion_module_dropdown,
                    base_model_dropdown,
                    lora_alpha_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    width_slider,
                    length_slider,
                    height_slider,
                    cfg_scale_slider,
                    seed_textbox,
                ],
                outputs=[result_video],
            )

    return demo

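
# share=True requests a public Gradio link; drop the argument to serve locally only.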
- if __name__ == "__main__":
- demo = ui()
- demo.launch(share=True)