Yuwei Guo 2 tahun lalu
induk
melakukan
4f0efbe11f
4 mengubah file dengan 346 tambahan dan 3 penghapusan
  1. 9 0
      README.md
  2. TEMPAT SAMPAH
      __assets__/figs/gradio.jpg
  3. 337 0
      app.py
  4. 0 3
      configs/inference/inference.yaml

+ 9 - 0
README.md

@@ -124,6 +124,15 @@ Then run the following commands:
 python -m scripts.animate --config [path to the config file]
 ```
 
+## Gradio Demo
+We develop a Gradio demo to support a easier usage. To launch it, run the following commands:
+```
+conda activate animatediff
+python app.py
+```
+By default, the demo will be run at `localhost:7860`.
+<br><img src="__assets__/figs/gradio.jpg" style="width: 50em; margin-top: 1em">
+
 ## Gallery
 Here we demonstrate several best results we found in our experiments.
 

TEMPAT SAMPAH
__assets__/figs/gradio.jpg


+ 337 - 0
app.py

@@ -0,0 +1,337 @@
+import gradio as gr
+import os
+from glob import glob
+import random
+import pdb
+from transformers import CLIPTextModel, CLIPTokenizer
+from animatediff.models.unet import UNet3DConditionModel
+from animatediff.pipelines.pipeline_animation import AnimationPipeline
+
+from diffusers import AutoencoderKL
+from datetime import datetime
+import os
+from omegaconf import OmegaConf
+import json
+import torch
+
+from diffusers import AutoencoderKL
+from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
+
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from animatediff.models.unet import UNet3DConditionModel
+from animatediff.pipelines.pipeline_animation import AnimationPipeline
+from animatediff.utils.util import save_videos_grid
+from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
+from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
+from diffusers.utils.import_utils import is_xformers_available
+
+from safetensors import safe_open
+
+
+sample_idx     = 0
+
+scheduler_dict = {
+    "Euler": EulerDiscreteScheduler,
+    "PNDM": PNDMScheduler,
+    "DDIM": DDIMScheduler,
+}
+
+css = """
+.toolbutton {
+    margin-buttom: 0em 0em 0em 0em;
+    max-width: 2.5em;
+    min-width: 2.5em !important;
+    height: 2.5em;
+}
+"""
+
+
+class AnimateController:
+    def __init__(self):
+        
+        # config dirs
+        self.basedir                = os.getcwd()
+        self.stable_diffusion_dir   = os.path.join(self.basedir, "models", "StableDiffusion")
+        self.motion_module_dir      = os.path.join(self.basedir, "models", "Motion_Module")
+        self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
+        self.savedir                = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
+        self.savedir_sample         = os.path.join(self.savedir, "sample")
+        os.makedirs(self.savedir, exist_ok=True)
+
+        self.stable_diffusion_list   = []
+        self.motion_module_list      = []
+        self.personalized_model_list = []
+        
+        self.refresh_stable_diffusion()
+        self.refresh_motion_module()
+        self.refresh_personalized_model()
+        
+        # config models
+        self.tokenizer             = None
+        self.text_encoder          = None
+        self.vae                   = None
+        self.unet                  = None
+        self.pipeline              = None
+        self.lora_model_state_dict = {}
+        
+        self.inference_config      = OmegaConf.load("configs/inference/inference.yaml")
+
+    def refresh_stable_diffusion(self):
+        self.stable_diffusion_list = glob(os.path.join(self.stable_diffusion_dir, "*/"))
+
+    def refresh_motion_module(self):
+        motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
+        self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
+
+    def refresh_personalized_model(self):
+        personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
+        self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
+
+    def update_stable_diffusion(self, stable_diffusion_dropdown):
+        self.tokenizer = CLIPTokenizer.from_pretrained(stable_diffusion_dropdown, subfolder="tokenizer")
+        self.text_encoder = CLIPTextModel.from_pretrained(stable_diffusion_dropdown, subfolder="text_encoder").cuda()
+        self.vae = AutoencoderKL.from_pretrained(stable_diffusion_dropdown, subfolder="vae").cuda()
+        self.unet = UNet3DConditionModel.from_pretrained_2d(stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
+        return gr.Dropdown.update()
+
+    def update_motion_module(self, motion_module_dropdown):
+        if self.unet is None:
+            gr.Info(f"Please select a pretrained model path.")
+            return gr.Dropdown.update(value=None)
+        else:
+            motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
+            motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
+            missing, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False)
+            assert len(unexpected) == 0
+            return gr.Dropdown.update()
+
+    def update_base_model(self, base_model_dropdown):
+        if self.unet is None:
+            gr.Info(f"Please select a pretrained model path.")
+            return gr.Dropdown.update(value=None)
+        else:
+            base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
+            base_model_state_dict = {}
+            with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
+                for key in f.keys():
+                    base_model_state_dict[key] = f.get_tensor(key)
+                    
+            converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_model_state_dict, self.vae.config)
+            self.vae.load_state_dict(converted_vae_checkpoint)
+
+            converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_model_state_dict, self.unet.config)
+            self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
+
+            self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
+            return gr.Dropdown.update()
+
+    def update_lora_model(self, lora_model_dropdown):
+        lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
+        self.lora_model_state_dict = {}
+        if lora_model_dropdown == "none": pass
+        else:
+            with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f:
+                for key in f.keys():
+                    self.lora_model_state_dict[key] = f.get_tensor(key)
+        return gr.Dropdown.update()
+
+    def animate(
+        self,
+        stable_diffusion_dropdown,
+        motion_module_dropdown,
+        base_model_dropdown,
+        lora_alpha_slider,
+        prompt_textbox, 
+        negative_prompt_textbox, 
+        sampler_dropdown, 
+        sample_step_slider, 
+        width_slider, 
+        length_slider, 
+        height_slider, 
+        cfg_scale_slider, 
+        seed_textbox
+    ):    
+        if self.unet is None:
+            raise gr.Error(f"Please select a pretrained model path.")
+        if motion_module_dropdown == "": 
+            raise gr.Error(f"Please select a motion module.")
+        if base_model_dropdown == "":
+            raise gr.Error(f"Please select a base DreamBooth model.")
+
+        if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
+
+        pipeline = AnimationPipeline(
+            vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
+            scheduler=scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
+        ).to("cuda")
+        
+        if self.lora_model_state_dict != {}:
+            pipeline = convert_lora(pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)
+
+        pipeline.to("cuda")
+
+        if seed_textbox != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
+        else: torch.seed()
+        seed = torch.initial_seed()
+        
+        sample = pipeline(
+            prompt_textbox,
+            negative_prompt     = negative_prompt_textbox,
+            num_inference_steps = sample_step_slider,
+            guidance_scale      = cfg_scale_slider,
+            width               = width_slider,
+            height              = height_slider,
+            video_length        = length_slider,
+        ).videos
+
+        save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4")
+        save_videos_grid(sample, save_sample_path)
+    
+        sample_config = {
+            "prompt": prompt_textbox,
+            "n_prompt": negative_prompt_textbox,
+            "sampler": sampler_dropdown,
+            "num_inference_steps": sample_step_slider,
+            "guidance_scale": cfg_scale_slider,
+            "width": width_slider,
+            "height": height_slider,
+            "video_length": length_slider,
+            "seed": seed
+        }
+        json_str = json.dumps(sample_config, indent=4)
+        with open(os.path.join(self.savedir, "logs.json"), "a") as f:
+            f.write(json_str)
+            f.write("\n\n")
+            
+        return gr.Video.update(value=save_sample_path)
+        
+
+controller = AnimateController()
+
+
+def ui():
+    with gr.Blocks(css=css) as demo:
+        gr.Markdown(
+            """
+            # [AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725)
+            Yuwei Guo, Ceyuan Yang*, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai (*Corresponding Author)<br>
+            [Arxiv Report](https://arxiv.org/abs/2307.04725) | [Project Page](https://animatediff.github.io/) | [Github](https://github.com/guoyww/animatediff/)
+            """
+        )
+        with gr.Column(variant="panel"):
+            gr.Markdown(
+                """
+                ### 1. Model checkpoints (select pretrained model path first).
+                """
+            )
+            with gr.Row():
+                stable_diffusion_dropdown = gr.Dropdown(
+                    label="Pretrained Model Path",
+                    choices=controller.stable_diffusion_list,
+                    interactive=True,
+                )
+                stable_diffusion_dropdown.change(fn=controller.update_stable_diffusion, inputs=[stable_diffusion_dropdown], outputs=[stable_diffusion_dropdown])
+                
+                stable_diffusion_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
+                def update_stable_diffusion():
+                    controller.refresh_stable_diffusion()
+                    return gr.Dropdown.update(choices=controller.stable_diffusion_list)
+                stable_diffusion_refresh_button.click(fn=update_stable_diffusion, inputs=[], outputs=[stable_diffusion_dropdown])
+
+            with gr.Row():
+                motion_module_dropdown = gr.Dropdown(
+                    label="Select motion module",
+                    choices=controller.motion_module_list,
+                    interactive=True,
+                )
+                motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown])
+                
+                motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
+                def update_motion_module():
+                    controller.refresh_motion_module()
+                    return gr.Dropdown.update(choices=controller.motion_module_list)
+                motion_module_refresh_button.click(fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])
+                
+                base_model_dropdown = gr.Dropdown(
+                    label="Select base Dreambooth model (required)",
+                    choices=controller.personalized_model_list,
+                    interactive=True,
+                )
+                base_model_dropdown.change(fn=controller.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown])
+                
+                lora_model_dropdown = gr.Dropdown(
+                    label="Select LoRA model (optional)",
+                    choices=["none"] + controller.personalized_model_list,
+                    value="none",
+                    interactive=True,
+                )
+                lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[lora_model_dropdown], outputs=[lora_model_dropdown])
+                
+                lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)
+                
+                personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
+                def update_personalized_model():
+                    controller.refresh_personalized_model()
+                    return [
+                        gr.Dropdown.update(choices=controller.personalized_model_list),
+                        gr.Dropdown.update(choices=["none"] + controller.personalized_model_list)
+                    ]
+                personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
+
+        with gr.Column(variant="panel"):
+            gr.Markdown(
+                """
+                ### 2. Configs for AnimateDiff.
+                """
+            )
+            
+            prompt_textbox = gr.Textbox(label="Prompt", lines=2)
+            negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2)
+                
+            with gr.Row().style(equal_height=False):
+                with gr.Column():
+                    with gr.Row():
+                        sampler_dropdown   = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
+                        sample_step_slider = gr.Slider(label="Sampling steps", value=25, minimum=10, maximum=100, step=1)
+                        
+                    width_slider     = gr.Slider(label="Width",            value=512, minimum=256, maximum=1024, step=64)
+                    height_slider    = gr.Slider(label="Height",           value=512, minimum=256, maximum=1024, step=64)
+                    length_slider    = gr.Slider(label="Animation length", value=16,  minimum=8,   maximum=24,   step=1)
+                    cfg_scale_slider = gr.Slider(label="CFG Scale",        value=7.5, minimum=0,   maximum=20)
+                    
+                    with gr.Row():
+                        seed_textbox = gr.Textbox(label="Seed", value=-1)
+                        seed_button  = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
+                        seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox])
+            
+                    generate_button = gr.Button(value="Generate", variant='primary')
+                    
+                result_video = gr.Video(label="Generated Animation", interactive=False)
+
+            generate_button.click(
+                fn=controller.animate,
+                inputs=[
+                    stable_diffusion_dropdown,
+                    motion_module_dropdown,
+                    base_model_dropdown,
+                    lora_alpha_slider,
+                    prompt_textbox, 
+                    negative_prompt_textbox, 
+                    sampler_dropdown, 
+                    sample_step_slider, 
+                    width_slider, 
+                    length_slider, 
+                    height_slider, 
+                    cfg_scale_slider, 
+                    seed_textbox,
+                ],
+                outputs=[result_video]
+            )
+            
+    return demo
+
+
+if __name__ == "__main__":
+    demo = ui()
+    demo.launch(share=True)

+ 0 - 3
configs/inference/inference.yaml

@@ -21,9 +21,6 @@ unet_additional_kwargs:
     temporal_attention_dim_div: 1
 
 noise_scheduler_kwargs:
-  num_train_timesteps: 1000
   beta_start: 0.00085
   beta_end: 0.012
   beta_schedule: "linear"
-  steps_offset: 1
-  clip_sample: false