
optimize memory cost

Yuwei Guo committed 2 years ago · commit 05fdf470ad
2 changed files with 10 additions and 1 deletion
  1. animatediff/pipelines/pipeline_animation.py (+6 −1)
  2. scripts/animate.py (+4 −0)

animatediff/pipelines/pipeline_animation.py (+6 −1)

@@ -6,6 +6,7 @@ from dataclasses import dataclass
 
 import numpy as np
 import torch
+from tqdm import tqdm
 
 from diffusers.utils import is_accelerate_available
 from packaging import version
@@ -239,7 +240,11 @@ class AnimationPipeline(DiffusionPipeline):
         video_length = latents.shape[2]
         latents = 1 / 0.18215 * latents
         latents = rearrange(latents, "b c f h w -> (b f) c h w")
-        video = self.vae.decode(latents).sample
+        # video = self.vae.decode(latents).sample
+        video = []
+        for frame_idx in tqdm(range(latents.shape[0])):
+            video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
+        video = torch.cat(video)
         video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
         video = (video / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16

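The hunk above replaces a single batched VAE decode with a per-frame loop, so the decoder's peak activation memory scales with one frame rather than with batch × frames. A minimal standalone sketch of the same technique, assuming a diffusers AutoencoderKL and latents already packed as (frames, channels, h, w); the function name and the `scale` default are illustrative, not from the commit:

import torch
from tqdm import tqdm

@torch.no_grad()
def decode_latents_framewise(vae, latents, scale=0.18215):
    # latents: (num_frames, 4, h, w); undo the Stable Diffusion latent scaling
    latents = latents / scale
    frames = []
    for i in tqdm(range(latents.shape[0]), desc="VAE decode"):
        # decode one frame at a time so peak memory stays at roughly one frame
        frames.append(vae.decode(latents[i : i + 1]).sample)
    return torch.cat(frames)  # (num_frames, 3, H, W)

The trade-off is wall-clock time: the decoder now runs once per frame instead of once per batch, which is why the commit wraps the loop in a tqdm progress bar.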
scripts/animate.py (+4 −0)

@@ -17,6 +17,7 @@ from animatediff.pipelines.pipeline_animation import AnimationPipeline
 from animatediff.utils.util import save_videos_grid
 from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
 from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
+from diffusers.utils.import_utils import is_xformers_available
 
 from einops import rearrange, repeat
 
@@ -51,6 +52,9 @@ def main(args):
             vae          = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")            
             unet         = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
 
+            if is_xformers_available(): unet.enable_xformers_memory_efficient_attention()
+            else: assert False, "xformers is unavailable; memory-efficient attention requires it"
+
             pipeline = AnimationPipeline(
                 vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
                 scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
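
The second change enables xformers memory-efficient attention on the UNet and asserts when xformers is missing. A gentler variant, sketched under the assumption that the model exposes diffusers' standard ModelMixin API; the helper name and error message are illustrative, not from the commit:

from diffusers.utils.import_utils import is_xformers_available

def enable_memory_efficient_attention(unet):
    # Same availability check as the commit, but fail with an actionable
    # message instead of a bare assert.
    if is_xformers_available():
        unet.enable_xformers_memory_efficient_attention()
    else:
        raise RuntimeError(
            "xformers is not installed; it is required here to keep "
            "attention memory usage low (pip install xformers)."
        )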