
feat: add lora-script

tingsong lu committed 1 year ago
commit 36e739b7c1

+ 532 - 0
.ipynb_checkpoints/test_train-checkpoint.py

@@ -0,0 +1,532 @@
+import os
+import math
+import wandb
+import random
+import logging
+import inspect
+import argparse
+import datetime
+import subprocess
+
+from pathlib import Path
+from tqdm.auto import tqdm
+from einops import rearrange
+from omegaconf import OmegaConf
+from safetensors import safe_open
+from typing import Dict, Optional, Tuple
+
+import torch
+import torchvision
+import torch.nn.functional as F
+import torch.distributed as dist
+from torch.optim.swa_utils import AveragedModel
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+import diffusers
+from diffusers import AutoencoderKL, DDIMScheduler
+from diffusers.models import UNet2DConditionModel
+from diffusers.pipelines import StableDiffusionPipeline
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version
+from diffusers.utils.import_utils import is_xformers_available
+
+import transformers
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from animatediff.data.dataset import WebVid10M
+from animatediff.models.unet import UNet3DConditionModel
+from animatediff.pipelines.pipeline_animation import AnimationPipeline
+from animatediff.utils.util import save_videos_grid, zero_rank_print
+from animatediff.models.lora import LoRAModule
+
+
+
+def init_dist(launcher="slurm", backend='nccl', port=29500, **kwargs):
+    """Initializes distributed environment."""
+    if launcher == 'pytorch':
+        rank = int(os.environ['RANK'])
+        num_gpus = torch.cuda.device_count()
+        local_rank = rank % num_gpus
+        torch.cuda.set_device(local_rank)
+        dist.init_process_group(backend=backend, **kwargs)
+        
+    elif launcher == 'slurm':
+        proc_id = int(os.environ['SLURM_PROCID'])
+        ntasks = int(os.environ['SLURM_NTASKS'])
+        node_list = os.environ['SLURM_NODELIST']
+        num_gpus = torch.cuda.device_count()
+        local_rank = proc_id % num_gpus
+        torch.cuda.set_device(local_rank)
+        addr = subprocess.getoutput(
+            f'scontrol show hostname {node_list} | head -n1')
+        os.environ['MASTER_ADDR'] = addr
+        os.environ['WORLD_SIZE'] = str(ntasks)
+        os.environ['RANK'] = str(proc_id)
+        port = os.environ.get('PORT', port)
+        os.environ['MASTER_PORT'] = str(port)
+        dist.init_process_group(backend=backend)
+        zero_rank_print(f"proc_id: {proc_id}; local_rank: {local_rank}; ntasks: {ntasks}; node_list: {node_list}; num_gpus: {num_gpus}; addr: {addr}; port: {port}")
+        
+    else:
+        raise NotImplementedError(f'Not implemented launcher type: `{launcher}`!')
+    
+    return local_rank
+
+
+
+def main(
+    image_finetune: bool,
+    
+    name: str,
+    use_wandb: bool,
+    launcher: str,
+    
+    output_dir: str,
+    pretrained_model_path: str,
+
+    train_data: Dict,
+    validation_data: Dict,
+    cfg_random_null_text: bool = True,
+    cfg_random_null_text_ratio: float = 0.1,
+    
+    unet_checkpoint_path: str = "",
+    unet_additional_kwargs: Dict = {},
+    ema_decay: float = 0.9999,
+    noise_scheduler_kwargs = None,
+    
+    max_train_epoch: int = -1,
+    max_train_steps: int = 100,
+    validation_steps: int = 100,
+    validation_steps_tuple: Tuple = (-1,),
+
+    learning_rate: float = 3e-5,
+    scale_lr: bool = False,
+    lr_warmup_steps: int = 0,
+    lr_scheduler: str = "constant",
+
+    trainable_modules: Tuple[str] = (None, ),
+    num_workers: int = 32,
+    train_batch_size: int = 1,
+    adam_beta1: float = 0.9,
+    adam_beta2: float = 0.999,
+    adam_weight_decay: float = 1e-2,
+    adam_epsilon: float = 1e-08,
+    max_grad_norm: float = 1.0,
+    gradient_accumulation_steps: int = 1,
+    gradient_checkpointing: bool = False,
+    checkpointing_epochs: int = 5,
+    checkpointing_steps: int = -1,
+
+    mixed_precision_training: bool = True,
+    enable_xformers_memory_efficient_attention: bool = True,
+
+    global_seed: int = 42,
+    is_debug: bool = False,
+):
+    check_min_version("0.10.0.dev0")
+
+    # Initialize distributed training
+    local_rank      = init_dist(launcher=launcher)
+    global_rank     = dist.get_rank()
+    num_processes   = dist.get_world_size()
+    is_main_process = global_rank == 0
+
+    seed = global_seed + global_rank
+    torch.manual_seed(seed)
+    
+    # Logging folder
+    folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S")
+    output_dir = os.path.join(output_dir, folder_name)
+    if is_debug and os.path.exists(output_dir):
+        os.system(f"rm -rf {output_dir}")
+
+    *_, config = inspect.getargvalues(inspect.currentframe())
+
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+
+    if is_main_process and (not is_debug) and use_wandb:
+        run = wandb.init(project="animatediff", name=folder_name, config=config)
+
+    # Handle the output folder creation
+    if is_main_process:
+        os.makedirs(output_dir, exist_ok=True)
+        os.makedirs(f"{output_dir}/samples", exist_ok=True)
+        os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
+        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
+        OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
+
+    # Load scheduler, tokenizer and models.
+    noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
+
+    vae          = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
+    tokenizer    = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
+    if not image_finetune:
+        unet = UNet3DConditionModel.from_pretrained_2d(
+            pretrained_model_path, subfolder="unet", 
+            unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs)
+        )
+    else:
+        unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")
+        
+    # Load pretrained unet weights
+    if unet_checkpoint_path != "":
+        zero_rank_print(f"from checkpoint: {unet_checkpoint_path}")
+        unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
+        if "global_step" in unet_checkpoint_path: zero_rank_print(f"global_step: {unet_checkpoint_path['global_step']}")
+        state_dict = unet_checkpoint_path["state_dict"] if "state_dict" in unet_checkpoint_path else unet_checkpoint_path
+
+        m, u = unet.load_state_dict(state_dict, strict=False)
+        zero_rank_print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+        assert len(u) == 0
+        
+    loras = []
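+    # Attach a LoRAModule to every attention projection (to_q / to_k / to_v / to_out.0).
+    # apply_to() monkey-patches each projection's forward, so the frozen UNet weights stay
+    # untouched and only the low-rank adapters receive gradients.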
+    for module_name, module in unet.named_modules():
+        if "to_q" in module_name or "to_k" in module_name or "to_v" in module_name or "to_out.0" in module_name:
+            lora_name = module_name + "_lora"
+            print(lora_name)
+            lora = LoRAModule(lora_name, module)
+            loras.append(lora)
+            # print(module.__class__.__name__)
+    print("enable LoRA for U-Net")
+    for lora in loras:
+        lora.apply_to()
+    
+    # Freeze vae and text_encoder
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+    
+    # Set unet trainable parameters
+    unet.requires_grad_(False)
+    
+    for lora in loras:
+        lora.requires_grad_(True)
+    
+    # for name, param in unet.named_parameters():
+    #     for trainable_module_name in trainable_modules:
+    #         if trainable_module_name in name:
+    #             param.requires_grad = True
+    #             break
+            
+    trainable_params = [param for lora in loras for param in lora.parameters() if param.requires_grad]
+    optimizer = torch.optim.AdamW(
+        trainable_params,
+        lr=learning_rate,
+        betas=(adam_beta1, adam_beta2),
+        weight_decay=adam_weight_decay,
+        eps=adam_epsilon,
+    )
+
+    if is_main_process:
+        zero_rank_print(f"trainable params number: {len(trainable_params)}")
+        zero_rank_print(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")
+
+    # Enable xformers
+    # if enable_xformers_memory_efficient_attention:
+    #     if is_xformers_available():
+    #         unet.enable_xformers_memory_efficient_attention()
+    #     else:
+    #         raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+    # Enable gradient checkpointing
+    if gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
+    # Move models to GPU
+    vae.to(local_rank)
+    text_encoder.to(local_rank)
+
+    # Get the training dataset
+    train_dataset = WebVid10M(**train_data, is_image=image_finetune)
+    distributed_sampler = DistributedSampler(
+        train_dataset,
+        num_replicas=num_processes,
+        rank=global_rank,
+        shuffle=True,
+        seed=global_seed,
+    )
+
+    # DataLoaders creation:
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=train_batch_size,
+        shuffle=False,
+        sampler=distributed_sampler,
+        num_workers=num_workers,
+        pin_memory=True,
+        drop_last=True,
+    )
+
+    # Get the training iteration
+    if max_train_steps == -1:
+        assert max_train_epoch != -1
+        max_train_steps = max_train_epoch * len(train_dataloader)
+        
+    if checkpointing_steps == -1:
+        assert checkpointing_epochs != -1
+        checkpointing_steps = checkpointing_epochs * len(train_dataloader)
+
+    if scale_lr:
+        learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size * num_processes)
+
+    # Scheduler
+    lr_scheduler = get_scheduler(
+        lr_scheduler,
+        optimizer=optimizer,
+        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
+        num_training_steps=max_train_steps * gradient_accumulation_steps,
+    )
+
+    # Validation pipeline
+    if not image_finetune:
+        validation_pipeline = AnimationPipeline(
+            unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler,
+        ).to("cuda")
+    else:
+        validation_pipeline = StableDiffusionPipeline.from_pretrained(
+            pretrained_model_path,
+            unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, safety_checker=None,
+        )
+    validation_pipeline.enable_vae_slicing()
+
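+    # Gather all LoRA modules under a single nn.Module so one DDP wrapper can sync their
+    # gradients; add_module() rejects '.' in names, hence the '-' substitution below.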
+    class LoRAModuleWrapper(torch.nn.Module):
+        def __init__(self, loras):
+            super().__init__()
+            for i, lora in enumerate(loras):
+                module_name = lora.lora_name.replace('.', '-')
+                self.add_module(module_name, lora)
+    loras_wrapper = LoRAModuleWrapper(loras).to(local_rank)
+    
+    # DDP wrapper
+    unet.to(local_rank)
+    loras_ddp = DDP(loras_wrapper, device_ids=[local_rank], output_device=local_rank)
+
+    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
+    # Afterwards we recalculate our number of training epochs
+    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
+
+    # Train!
+    total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps
+
+    if is_main_process:
+        logging.info("***** Running training *****")
+        logging.info(f"  Num examples = {len(train_dataset)}")
+        logging.info(f"  Num Epochs = {num_train_epochs}")
+        logging.info(f"  Instantaneous batch size per device = {train_batch_size}")
+        logging.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+        logging.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
+        logging.info(f"  Total optimization steps = {max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    # Only show the progress bar once on each machine.
+    progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process)
+    progress_bar.set_description("Steps")
+
+    # Support mixed-precision training
+    scaler = torch.cuda.amp.GradScaler() if mixed_precision_training else None
+
+    for epoch in range(first_epoch, num_train_epochs):
+        train_dataloader.sampler.set_epoch(epoch)
+        unet.train()
+        
+        for step, batch in enumerate(train_dataloader):
+            if cfg_random_null_text:
+                batch['text'] = [name if random.random() > cfg_random_null_text_ratio else "" for name in batch['text']]
+                
+            # Data batch sanity check
+            if epoch == first_epoch and step == 0:
+                pixel_values, texts = batch['pixel_values'].cpu(), batch['text']
+                if not image_finetune:
+                    pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
+                    for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
+                        pixel_value = pixel_value[None, ...]
+                        save_videos_grid(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.gif", rescale=True)
+                else:
+                    for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
+                        pixel_value = pixel_value / 2. + 0.5
+                        torchvision.utils.save_image(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.png")
+                    
+            ### >>>> Training >>>> ###
+            
+            # Convert videos to latent space            
+            pixel_values = batch["pixel_values"].to(local_rank)
+            video_length = pixel_values.shape[1]
+            with torch.no_grad():
+                if not image_finetune:
+                    pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
+                    latents = vae.encode(pixel_values).latent_dist
+                    latents = latents.sample()
+                    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
+                else:
+                    latents = vae.encode(pixel_values).latent_dist
+                    latents = latents.sample()
+
+                latents = latents * 0.18215
+
+            # Sample noise that we'll add to the latents
+            noise = torch.randn_like(latents)
+            bsz = latents.shape[0]
+            
+            # Sample a random timestep for each video
+            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+            timesteps = timesteps.long()
+            
+            # Add noise to the latents according to the noise magnitude at each timestep
+            # (this is the forward diffusion process)
+            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+            
+            # Get the text embedding for conditioning
+            with torch.no_grad():
+                prompt_ids = tokenizer(
+                    batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+                ).input_ids.to(latents.device)
+                encoder_hidden_states = text_encoder(prompt_ids)[0]
+                
+            # Get the target for loss depending on the prediction type
+            if noise_scheduler.config.prediction_type == "epsilon":
+                target = noise
+            elif noise_scheduler.config.prediction_type == "v_prediction":
+                raise NotImplementedError
+            else:
+                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+            # Predict the noise residual and compute loss
+            # Mixed-precision training
+            with torch.cuda.amp.autocast(enabled=mixed_precision_training):
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+            optimizer.zero_grad()
+
+            # Backpropagate
+            if mixed_precision_training:
+                scaler.scale(loss).backward()
+                """ >>> gradient clipping >>> """
+                scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
+                """ <<< gradient clipping <<< """
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                loss.backward()
+                """ >>> gradient clipping >>> """
+                torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
+                """ <<< gradient clipping <<< """
+                optimizer.step()
+
+            lr_scheduler.step()
+            progress_bar.update(1)
+            global_step += 1
+            
+#             for lora in loras:
+#                 print(f"lora_down grad: {lora.lora_down.weight.grad}")
+#                 print(f"lora_up grad: {lora.lora_up.weight.grad}")
+
+#                 print(f"lora_down requires_grad: {lora.lora_down.weight.requires_grad}")
+#                 print(f"lora_up requires_grad: {lora.lora_up.weight.requires_grad}")
+
+#                 print(f"lora_down data: {lora.lora_down.weight.data}")
+#                 print(f"lora_up data: {lora.lora_up.weight.data}")
+#                 input()
+
+            
+            ### <<<< Training <<<< ###
+            
+            # Wandb logging
+            if is_main_process and (not is_debug) and use_wandb:
+                wandb.log({"train_loss": loss.item()}, step=global_step)
+                
+            # Save checkpoint
+            if is_main_process and (global_step % checkpointing_steps == 0 or step == len(train_dataloader) - 1):
+                lora_save_path = os.path.join(output_dir, f"checkpoints", "lora")
+                os.makedirs(lora_save_path, exist_ok=True)
+                loras_state_dict = {
+                    "epoch": epoch,
+                    "global_step": global_step,
+                    "state_dict": loras_wrapper.state_dict(),
+                }
+                if step == len(train_dataloader) - 1:
+                    torch.save(loras_state_dict, os.path.join(lora_save_path, f"checkpoint-epoch-{epoch+1}.ckpt"))
+                else:
+                    torch.save(loras_state_dict, os.path.join(lora_save_path, f"checkpoint.ckpt"))
+                logging.info(f"Saved LoRA state to {lora_save_path} (global_step: {global_step})")
+                
+            # Periodically validation
+            if is_main_process and (global_step % validation_steps == 0 or global_step in validation_steps_tuple):
+                samples = []
+                
+                generator = torch.Generator(device=latents.device)
+                generator.manual_seed(global_seed)
+                
+                height = train_data.sample_size[0] if not isinstance(train_data.sample_size, int) else train_data.sample_size
+                width  = train_data.sample_size[1] if not isinstance(train_data.sample_size, int) else train_data.sample_size
+
+                prompts = validation_data.prompts[:2] if global_step < 1000 and (not image_finetune) else validation_data.prompts
+
+                for idx, prompt in enumerate(prompts):
+                    if not image_finetune:
+                        sample = validation_pipeline(
+                            prompt,
+                            generator    = generator,
+                            video_length = train_data.sample_n_frames,
+                            height       = height,
+                            width        = width,
+                            **validation_data,
+                        ).videos
+                        save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif")
+                        samples.append(sample)
+                        
+                    else:
+                        sample = validation_pipeline(
+                            prompt,
+                            generator           = generator,
+                            height              = height,
+                            width               = width,
+                            num_inference_steps = validation_data.get("num_inference_steps", 25),
+                            guidance_scale      = validation_data.get("guidance_scale", 8.),
+                        ).images[0]
+                        sample = torchvision.transforms.functional.to_tensor(sample)
+                        samples.append(sample)
+                
+                if not image_finetune:
+                    samples = torch.concat(samples)
+                    save_path = f"{output_dir}/samples/sample-{global_step}.gif"
+                    save_videos_grid(samples, save_path)
+                    
+                else:
+                    samples = torch.stack(samples)
+                    save_path = f"{output_dir}/samples/sample-{global_step}.png"
+                    torchvision.utils.save_image(samples, save_path, nrow=4)
+
+                logging.info(f"Saved samples to {save_path}")
+                
+            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs)
+            
+            if global_step >= max_train_steps:
+                break
+            
+    dist.destroy_process_group()
+
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config",   type=str, required=True)
+    parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="pytorch")
+    parser.add_argument("--wandb",    action="store_true")
+    args = parser.parse_args()
+
+    name   = Path(args.config).stem
+    config = OmegaConf.load(args.config)
+
+    main(name=name, launcher=args.launcher, use_wandb=args.wandb, **config)
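A minimal inference-time sketch of how the LoRA checkpoint saved by the loop above could be merged back into a diffusers UNet, using convert_motion_lora_ckpt_to_diffusers_test from animatediff/utils/convert_lora_safetensor_to_diffusers.py; the base-model path and checkpoint path below are hypothetical:

    import torch
    from diffusers import StableDiffusionPipeline
    from animatediff.utils.convert_lora_safetensor_to_diffusers import (
        convert_motion_lora_ckpt_to_diffusers_test,
    )

    # Hypothetical paths: a base model in diffusers format and a checkpoint produced above.
    pipe = StableDiffusionPipeline.from_pretrained("models/StableDiffusion/stable-diffusion-v1-5")
    ckpt = torch.load("outputs/lora-run/checkpoints/lora/checkpoint.ckpt", map_location="cpu")

    # State-dict keys look like "down_blocks-0-attentions-0-transformer_blocks-0-attn1-to_q_lora.lora_down.weight";
    # the converter walks pipeline.unet and applies W += alpha * (lora_up @ lora_down).
    # alpha should match the training-time multiplier * (alpha / lora_dim), i.e. 1.0 * 32 / 64 = 0.5 by default.
    pipe = convert_motion_lora_ckpt_to_diffusers_test(pipe, ckpt["state_dict"], alpha=0.5)
    pipe.to("cuda")

The same call works with an AnimationPipeline, since the converter only touches pipeline.unet.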

+ 493 - 0
.ipynb_checkpoints/train-checkpoint.py

@@ -0,0 +1,493 @@
+import os
+import math
+import wandb
+import random
+import logging
+import inspect
+import argparse
+import datetime
+import subprocess
+
+from pathlib import Path
+from tqdm.auto import tqdm
+from einops import rearrange
+from omegaconf import OmegaConf
+from safetensors import safe_open
+from typing import Dict, Optional, Tuple
+
+import torch
+import torchvision
+import torch.nn.functional as F
+import torch.distributed as dist
+from torch.optim.swa_utils import AveragedModel
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+import diffusers
+from diffusers import AutoencoderKL, DDIMScheduler
+from diffusers.models import UNet2DConditionModel
+from diffusers.pipelines import StableDiffusionPipeline
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version
+from diffusers.utils.import_utils import is_xformers_available
+
+import transformers
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from animatediff.data.dataset import WebVid10M
+from animatediff.models.unet import UNet3DConditionModel
+from animatediff.pipelines.pipeline_animation import AnimationPipeline
+from animatediff.utils.util import save_videos_grid, zero_rank_print
+
+
+
+def init_dist(launcher="slurm", backend='nccl', port=29500, **kwargs):
+    """Initializes distributed environment."""
+    if launcher == 'pytorch':
+        rank = int(os.environ['RANK'])
+        num_gpus = torch.cuda.device_count()
+        local_rank = rank % num_gpus
+        torch.cuda.set_device(local_rank)
+        dist.init_process_group(backend=backend, **kwargs)
+        
+    elif launcher == 'slurm':
+        proc_id = int(os.environ['SLURM_PROCID'])
+        ntasks = int(os.environ['SLURM_NTASKS'])
+        node_list = os.environ['SLURM_NODELIST']
+        num_gpus = torch.cuda.device_count()
+        local_rank = proc_id % num_gpus
+        torch.cuda.set_device(local_rank)
+        addr = subprocess.getoutput(
+            f'scontrol show hostname {node_list} | head -n1')
+        os.environ['MASTER_ADDR'] = addr
+        os.environ['WORLD_SIZE'] = str(ntasks)
+        os.environ['RANK'] = str(proc_id)
+        port = os.environ.get('PORT', port)
+        os.environ['MASTER_PORT'] = str(port)
+        dist.init_process_group(backend=backend)
+        zero_rank_print(f"proc_id: {proc_id}; local_rank: {local_rank}; ntasks: {ntasks}; node_list: {node_list}; num_gpus: {num_gpus}; addr: {addr}; port: {port}")
+        
+    else:
+        raise NotImplementedError(f'Not implemented launcher type: `{launcher}`!')
+    
+    return local_rank
+
+
+
+def main(
+    image_finetune: bool,
+    
+    name: str,
+    use_wandb: bool,
+    launcher: str,
+    
+    output_dir: str,
+    pretrained_model_path: str,
+
+    train_data: Dict,
+    validation_data: Dict,
+    cfg_random_null_text: bool = True,
+    cfg_random_null_text_ratio: float = 0.1,
+    
+    unet_checkpoint_path: str = "",
+    unet_additional_kwargs: Dict = {},
+    ema_decay: float = 0.9999,
+    noise_scheduler_kwargs = None,
+    
+    max_train_epoch: int = -1,
+    max_train_steps: int = 100,
+    validation_steps: int = 100,
+    validation_steps_tuple: Tuple = (-1,),
+
+    learning_rate: float = 3e-5,
+    scale_lr: bool = False,
+    lr_warmup_steps: int = 0,
+    lr_scheduler: str = "constant",
+
+    trainable_modules: Tuple[str] = (None, ),
+    num_workers: int = 32,
+    train_batch_size: int = 1,
+    adam_beta1: float = 0.9,
+    adam_beta2: float = 0.999,
+    adam_weight_decay: float = 1e-2,
+    adam_epsilon: float = 1e-08,
+    max_grad_norm: float = 1.0,
+    gradient_accumulation_steps: int = 1,
+    gradient_checkpointing: bool = False,
+    checkpointing_epochs: int = 5,
+    checkpointing_steps: int = -1,
+
+    mixed_precision_training: bool = True,
+    enable_xformers_memory_efficient_attention: bool = True,
+
+    global_seed: int = 42,
+    is_debug: bool = False,
+):
+    check_min_version("0.10.0.dev0")
+
+    # Initialize distributed training
+    local_rank      = init_dist(launcher=launcher)
+    global_rank     = dist.get_rank()
+    num_processes   = dist.get_world_size()
+    is_main_process = global_rank == 0
+
+    seed = global_seed + global_rank
+    torch.manual_seed(seed)
+    
+    # Logging folder
+    folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S")
+    output_dir = os.path.join(output_dir, folder_name)
+    if is_debug and os.path.exists(output_dir):
+        os.system(f"rm -rf {output_dir}")
+
+    *_, config = inspect.getargvalues(inspect.currentframe())
+
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+
+    if is_main_process and (not is_debug) and use_wandb:
+        run = wandb.init(project="animatediff", name=folder_name, config=config)
+
+    # Handle the output folder creation
+    if is_main_process:
+        os.makedirs(output_dir, exist_ok=True)
+        os.makedirs(f"{output_dir}/samples", exist_ok=True)
+        os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
+        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
+        OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
+
+    # Load scheduler, tokenizer and models.
+    noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
+
+    vae          = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
+    tokenizer    = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
+    if not image_finetune:
+        unet = UNet3DConditionModel.from_pretrained_2d(
+            pretrained_model_path, subfolder="unet", 
+            unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs)
+        )
+    else:
+        unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")
+        
+    # Load pretrained unet weights
+    if unet_checkpoint_path != "":
+        zero_rank_print(f"from checkpoint: {unet_checkpoint_path}")
+        unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
+        if "global_step" in unet_checkpoint_path: zero_rank_print(f"global_step: {unet_checkpoint_path['global_step']}")
+        state_dict = unet_checkpoint_path["state_dict"] if "state_dict" in unet_checkpoint_path else unet_checkpoint_path
+
+        m, u = unet.load_state_dict(state_dict, strict=False)
+        zero_rank_print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+        assert len(u) == 0
+        
+    # Freeze vae and text_encoder
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+    
+    # Set unet trainable parameters
+    unet.requires_grad_(False)
+    for name, param in unet.named_parameters():
+        for trainable_module_name in trainable_modules:
+            if trainable_module_name in name:
+                param.requires_grad = True
+                break
+            
+    trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters()))
+    optimizer = torch.optim.AdamW(
+        trainable_params,
+        lr=learning_rate,
+        betas=(adam_beta1, adam_beta2),
+        weight_decay=adam_weight_decay,
+        eps=adam_epsilon,
+    )
+
+    if is_main_process:
+        zero_rank_print(f"trainable params number: {len(trainable_params)}")
+        zero_rank_print(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")
+
+    # Enable xformers
+    if enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
+            unet.enable_xformers_memory_efficient_attention()
+        else:
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+    # Enable gradient checkpointing
+    if gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
+    # Move models to GPU
+    vae.to(local_rank)
+    text_encoder.to(local_rank)
+
+    # Get the training dataset
+    train_dataset = WebVid10M(**train_data, is_image=image_finetune)
+    distributed_sampler = DistributedSampler(
+        train_dataset,
+        num_replicas=num_processes,
+        rank=global_rank,
+        shuffle=True,
+        seed=global_seed,
+    )
+
+    # DataLoaders creation:
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=train_batch_size,
+        shuffle=False,
+        sampler=distributed_sampler,
+        num_workers=num_workers,
+        pin_memory=True,
+        drop_last=True,
+    )
+
+    # Get the training iteration
+    if max_train_steps == -1:
+        assert max_train_epoch != -1
+        max_train_steps = max_train_epoch * len(train_dataloader)
+        
+    if checkpointing_steps == -1:
+        assert checkpointing_epochs != -1
+        checkpointing_steps = checkpointing_epochs * len(train_dataloader)
+
+    if scale_lr:
+        learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size * num_processes)
+
+    # Scheduler
+    lr_scheduler = get_scheduler(
+        lr_scheduler,
+        optimizer=optimizer,
+        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
+        num_training_steps=max_train_steps * gradient_accumulation_steps,
+    )
+
+    # Validation pipeline
+    if not image_finetune:
+        validation_pipeline = AnimationPipeline(
+            unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler,
+        ).to("cuda")
+    else:
+        validation_pipeline = StableDiffusionPipeline.from_pretrained(
+            pretrained_model_path,
+            unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, safety_checker=None,
+        )
+    validation_pipeline.enable_vae_slicing()
+
+    # DDP wrapper
+    unet.to(local_rank)
+    unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)
+
+    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
+    # Afterwards we recalculate our number of training epochs
+    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
+
+    # Train!
+    total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps
+
+    if is_main_process:
+        logging.info("***** Running training *****")
+        logging.info(f"  Num examples = {len(train_dataset)}")
+        logging.info(f"  Num Epochs = {num_train_epochs}")
+        logging.info(f"  Instantaneous batch size per device = {train_batch_size}")
+        logging.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+        logging.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
+        logging.info(f"  Total optimization steps = {max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    # Only show the progress bar once on each machine.
+    progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process)
+    progress_bar.set_description("Steps")
+
+    # Support mixed-precision training
+    scaler = torch.cuda.amp.GradScaler() if mixed_precision_training else None
+
+    for epoch in range(first_epoch, num_train_epochs):
+        train_dataloader.sampler.set_epoch(epoch)
+        unet.train()
+        
+        for step, batch in enumerate(train_dataloader):
+            if cfg_random_null_text:
+                batch['text'] = [name if random.random() > cfg_random_null_text_ratio else "" for name in batch['text']]
+                
+            # Data batch sanity check
+            if epoch == first_epoch and step == 0:
+                pixel_values, texts = batch['pixel_values'].cpu(), batch['text']
+                if not image_finetune:
+                    pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
+                    for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
+                        pixel_value = pixel_value[None, ...]
+                        save_videos_grid(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.gif", rescale=True)
+                else:
+                    for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
+                        pixel_value = pixel_value / 2. + 0.5
+                        torchvision.utils.save_image(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.png")
+                    
+            ### >>>> Training >>>> ###
+            
+            # Convert videos to latent space            
+            pixel_values = batch["pixel_values"].to(local_rank)
+            video_length = pixel_values.shape[1]
+            with torch.no_grad():
+                if not image_finetune:
+                    pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
+                    latents = vae.encode(pixel_values).latent_dist
+                    latents = latents.sample()
+                    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
+                else:
+                    latents = vae.encode(pixel_values).latent_dist
+                    latents = latents.sample()
+
+                latents = latents * 0.18215
+
+            # Sample noise that we'll add to the latents
+            noise = torch.randn_like(latents)
+            bsz = latents.shape[0]
+            
+            # Sample a random timestep for each video
+            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+            timesteps = timesteps.long()
+            
+            # Add noise to the latents according to the noise magnitude at each timestep
+            # (this is the forward diffusion process)
+            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+            
+            # Get the text embedding for conditioning
+            with torch.no_grad():
+                prompt_ids = tokenizer(
+                    batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+                ).input_ids.to(latents.device)
+                encoder_hidden_states = text_encoder(prompt_ids)[0]
+                
+            # Get the target for loss depending on the prediction type
+            if noise_scheduler.config.prediction_type == "epsilon":
+                target = noise
+            elif noise_scheduler.config.prediction_type == "v_prediction":
+                raise NotImplementedError
+            else:
+                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+            # Predict the noise residual and compute loss
+            # Mixed-precision training
+            with torch.cuda.amp.autocast(enabled=mixed_precision_training):
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+            optimizer.zero_grad()
+
+            # Backpropagate
+            if mixed_precision_training:
+                scaler.scale(loss).backward()
+                """ >>> gradient clipping >>> """
+                scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
+                """ <<< gradient clipping <<< """
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                loss.backward()
+                """ >>> gradient clipping >>> """
+                torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
+                """ <<< gradient clipping <<< """
+                optimizer.step()
+
+            lr_scheduler.step()
+            progress_bar.update(1)
+            global_step += 1
+            
+            ### <<<< Training <<<< ###
+            
+            # Wandb logging
+            if is_main_process and (not is_debug) and use_wandb:
+                wandb.log({"train_loss": loss.item()}, step=global_step)
+                
+            # Save checkpoint
+            if is_main_process and (global_step % checkpointing_steps == 0 or step == len(train_dataloader) - 1):
+                save_path = os.path.join(output_dir, f"checkpoints")
+                state_dict = {
+                    "epoch": epoch,
+                    "global_step": global_step,
+                    "state_dict": unet.state_dict(),
+                }
+                if step == len(train_dataloader) - 1:
+                    torch.save(state_dict, os.path.join(save_path, f"checkpoint-epoch-{epoch+1}.ckpt"))
+                else:
+                    torch.save(state_dict, os.path.join(save_path, f"checkpoint.ckpt"))
+                logging.info(f"Saved state to {save_path} (global_step: {global_step})")
+                
+            # Periodically validation
+            if is_main_process and (global_step % validation_steps == 0 or global_step in validation_steps_tuple):
+                samples = []
+                
+                generator = torch.Generator(device=latents.device)
+                generator.manual_seed(global_seed)
+                
+                height = train_data.sample_size[0] if not isinstance(train_data.sample_size, int) else train_data.sample_size
+                width  = train_data.sample_size[1] if not isinstance(train_data.sample_size, int) else train_data.sample_size
+
+                prompts = validation_data.prompts[:2] if global_step < 1000 and (not image_finetune) else validation_data.prompts
+
+                for idx, prompt in enumerate(prompts):
+                    if not image_finetune:
+                        sample = validation_pipeline(
+                            prompt,
+                            generator    = generator,
+                            video_length = train_data.sample_n_frames,
+                            height       = height,
+                            width        = width,
+                            **validation_data,
+                        ).videos
+                        save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif")
+                        samples.append(sample)
+                        
+                    else:
+                        sample = validation_pipeline(
+                            prompt,
+                            generator           = generator,
+                            height              = height,
+                            width               = width,
+                            num_inference_steps = validation_data.get("num_inference_steps", 25),
+                            guidance_scale      = validation_data.get("guidance_scale", 8.),
+                        ).images[0]
+                        sample = torchvision.transforms.functional.to_tensor(sample)
+                        samples.append(sample)
+                
+                if not image_finetune:
+                    samples = torch.concat(samples)
+                    save_path = f"{output_dir}/samples/sample-{global_step}.gif"
+                    save_videos_grid(samples, save_path)
+                    
+                else:
+                    samples = torch.stack(samples)
+                    save_path = f"{output_dir}/samples/sample-{global_step}.png"
+                    torchvision.utils.save_image(samples, save_path, nrow=4)
+
+                logging.info(f"Saved samples to {save_path}")
+                
+            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs)
+            
+            if global_step >= max_train_steps:
+                break
+            
+    dist.destroy_process_group()
+
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config",   type=str, required=True)
+    parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="pytorch")
+    parser.add_argument("--wandb",    action="store_true")
+    args = parser.parse_args()
+
+    name   = Path(args.config).stem
+    config = OmegaConf.load(args.config)
+
+    main(name=name, launcher=args.launcher, use_wandb=args.wandb, **config)
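Both scripts are driven by an OmegaConf YAML whose keys mirror main()'s signature. A hypothetical minimal config, written as Python for illustration; the train_data entries other than sample_size/sample_n_frames depend on WebVid10M's constructor in animatediff/data/dataset.py and are assumptions here:

    from omegaconf import OmegaConf

    config = OmegaConf.create(dict(
        image_finetune=False,
        output_dir="outputs",
        pretrained_model_path="models/StableDiffusion/stable-diffusion-v1-5",
        noise_scheduler_kwargs=dict(num_train_timesteps=1000, beta_start=0.00085,
                                    beta_end=0.012, beta_schedule="linear"),
        # csv_path/video_folder are assumed WebVid10M arguments; sample_size and
        # sample_n_frames are read directly by the validation code above.
        train_data=dict(csv_path="path/to/webvid.csv", video_folder="path/to/videos",
                        sample_size=256, sample_n_frames=16),
        validation_data=dict(prompts=["a corgi running on the beach"],
                             num_inference_steps=25, guidance_scale=8.0),
        trainable_modules=["motion_modules."],   # assumed: fine-tune only the motion modules
        max_train_steps=100,
    ))
    OmegaConf.save(config, "configs/training/lora.yaml")

The script would then be launched with e.g. --config configs/training/lora.yaml and --launcher pytorch under torchrun.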

+ 84 - 0
animatediff/models/.ipynb_checkpoints/lora-checkpoint.py

@@ -0,0 +1,84 @@
+import torch
+import math
+
+class LoRAModule(torch.nn.Module):
+    """
+    Replaces the forward method of the original Linear, instead of replacing the Linear module itself.
+    """
+
+    def __init__(
+        self,
+        lora_name,
+        org_module: torch.nn.Module,
+        multiplier=1.0,
+        lora_dim=64,
+        alpha=32,
+        dropout=None,
+        rank_dropout=None,
+        module_dropout=None,
+    ):
+        """if alpha == 0 or None, alpha is rank (no scaling)."""
+        super().__init__()
+        self.lora_name = lora_name
+
+        in_dim = org_module.in_features
+        out_dim = org_module.out_features
+
+        self.lora_dim = lora_dim
+
+        self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+        self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
+
+        if type(alpha) == torch.Tensor:
+            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
+        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
+        self.scale = alpha / self.lora_dim
+        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant
+
+        # same as microsoft's
+        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+        torch.nn.init.zeros_(self.lora_up.weight)
+
+        self.multiplier = multiplier
+        self.org_module = org_module  # remove in applying
+        self.dropout = dropout
+        self.rank_dropout = rank_dropout
+        self.module_dropout = module_dropout
+
+    def apply_to(self):
+        self.org_forward = self.org_module.forward
+        self.org_module.forward = self.forward
+        del self.org_module
+
+    def forward(self, x):
+        org_forwarded = self.org_forward(x)
+
+        # module dropout
+        if self.module_dropout is not None and self.training:
+            if torch.rand(1) < self.module_dropout:
+                return org_forwarded
+
+        lx = self.lora_down(x)
+
+        # normal dropout
+        if self.dropout is not None and self.training:
+            lx = torch.nn.functional.dropout(lx, p=self.dropout)
+
+        # rank dropout
+        if self.rank_dropout is not None and self.training:
+            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
+            if len(lx.size()) == 3:
+                mask = mask.unsqueeze(1)  # for Text Encoder
+            elif len(lx.size()) == 4:
+                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
+            lx = lx * mask
+
+            # scaling for rank dropout: treat as if the rank is changed
+            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
+        else:
+            scale = self.scale
+
+        lx = self.lora_up(lx)
+
+        return org_forwarded + lx * self.multiplier * scale
+    

+ 84 - 0
animatediff/models/lora.py

@@ -0,0 +1,84 @@
+import torch
+import math
+
+class LoRAModule(torch.nn.Module):
+    """
+    Replaces the forward method of the original Linear, instead of replacing the Linear module itself.
+    """
+
+    def __init__(
+        self,
+        lora_name,
+        org_module: torch.nn.Module,
+        multiplier=1.0,
+        lora_dim=64,
+        alpha=32,
+        dropout=None,
+        rank_dropout=None,
+        module_dropout=None,
+    ):
+        """if alpha == 0 or None, alpha is rank (no scaling)."""
+        super().__init__()
+        self.lora_name = lora_name
+
+        in_dim = org_module.in_features
+        out_dim = org_module.out_features
+
+        self.lora_dim = lora_dim
+
+        self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+        self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
+
+        if type(alpha) == torch.Tensor:
+            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
+        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
+        self.scale = alpha / self.lora_dim
+        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant
+
+        # same as microsoft's
+        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+        torch.nn.init.zeros_(self.lora_up.weight)
+
+        self.multiplier = multiplier
+        self.org_module = org_module  # remove in applying
+        self.dropout = dropout
+        self.rank_dropout = rank_dropout
+        self.module_dropout = module_dropout
+
+    def apply_to(self):
+        self.org_forward = self.org_module.forward
+        self.org_module.forward = self.forward
+        del self.org_module
+
+    def forward(self, x):
+        org_forwarded = self.org_forward(x)
+
+        # module dropout
+        if self.module_dropout is not None and self.training:
+            if torch.rand(1) < self.module_dropout:
+                return org_forwarded
+
+        lx = self.lora_down(x)
+
+        # normal dropout
+        if self.dropout is not None and self.training:
+            lx = torch.nn.functional.dropout(lx, p=self.dropout)
+
+        # rank dropout
+        if self.rank_dropout is not None and self.training:
+            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
+            if len(lx.size()) == 3:
+                mask = mask.unsqueeze(1)  # for Text Encoder
+            elif len(lx.size()) == 4:
+                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
+            lx = lx * mask
+
+            # scaling for rank dropout: treat as if the rank is changed
+            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
+        else:
+            scale = self.scale
+
+        lx = self.lora_up(lx)
+
+        return org_forwarded + lx * self.multiplier * scale
+    
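A standalone sketch of the apply_to() hook (layer size and lora_name are hypothetical): after patching, calling the original Linear runs the frozen base projection plus the low-rank branch, and only lora_down/lora_up carry gradients.

    import torch
    from animatediff.models.lora import LoRAModule

    base = torch.nn.Linear(320, 320)             # stands in for an attention projection
    base.requires_grad_(False)                   # base weight stays frozen
    lora = LoRAModule("attn1.to_q_lora", base, lora_dim=4, alpha=4)
    lora.apply_to()                              # base.forward now routes through lora.forward

    x = torch.randn(2, 77, 320)
    y = base(x)                                  # equals the plain projection at init (lora_up is zero-initialized)
    print([n for n, _ in lora.named_parameters()])  # ['lora_down.weight', 'lora_up.weight']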

+ 182 - 0
animatediff/utils/.ipynb_checkpoints/convert_lora_safetensor_to_diffusers-checkpoint.py

@@ -0,0 +1,182 @@
+# coding=utf-8
+# Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" Conversion script for the LoRA's safetensors checkpoints. """
+
+import argparse
+
+import torch
+from safetensors.torch import load_file
+
+from diffusers import StableDiffusionPipeline
+import pdb
+
+
+
+def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
+    # directly update weight in diffusers model
+    for key in state_dict:
+        # only process lora down key
+        if "up." in key: continue
+
+        up_key    = key.replace(".down.", ".up.")
+        model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
+        model_key = model_key.replace("to_out.", "to_out.0.")
+        layer_infos = model_key.split(".")[:-1]
+
+        curr_layer = pipeline.unet
+        while len(layer_infos) > 0:
+            temp_name = layer_infos.pop(0)
+            curr_layer = curr_layer.__getattr__(temp_name)
+
+        weight_down = state_dict[key]
+        weight_up   = state_dict[up_key]
+        curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
+
+    return pipeline
+
+def convert_motion_lora_ckpt_to_diffusers_test(pipeline, state_dict, alpha=1.0):
+    # directly update weight in diffusers model
+    for key in state_dict:
+        if "weight" in key:
+            # only process lora down key
+            if "up." in key: continue
+            
+            up_key    = key.replace("_down.", "_up.")
+            model_key = key.replace('-', '.').replace("_lora", "").replace("lora_down.", "").replace("lora_up.", "")
+            print(up_key)
+            print(key)
+            print(model_key)
+            layer_infos = model_key.split(".")[:-1]
+            curr_layer = pipeline.unet
+            while len(layer_infos) > 0:
+                temp_name = layer_infos.pop(0)
+                curr_layer = curr_layer.__getattr__(temp_name)
+
+            weight_down = state_dict[key]
+            weight_up   = state_dict[up_key]
+            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
+            print(weight_down)
+            print(weight_up)
+            print("------")
+            print(curr_layer)
+
+    return pipeline
+
+
+
+def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
+    # load base model
+    # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
+
+    # load LoRA weight from .safetensors
+    # state_dict = load_file(checkpoint_path)
+
+    visited = []
+
+    # directly update weight in diffusers model
+    for key in state_dict:
+        # it is suggested to print out the key, it usually will be something like below
+        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
+
+        # as we have set the alpha beforehand, so just skip
+        if ".alpha" in key or key in visited:
+            continue
+
+        if "text" in key:
+            layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
+            curr_layer = pipeline.text_encoder
+        else:
+            layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
+            curr_layer = pipeline.unet
+
+        # find the target layer
+        temp_name = layer_infos.pop(0)
+        while len(layer_infos) > -1:
+            try:
+                curr_layer = curr_layer.__getattr__(temp_name)
+                if len(layer_infos) > 0:
+                    temp_name = layer_infos.pop(0)
+                elif len(layer_infos) == 0:
+                    break
+            except Exception:
+                if len(temp_name) > 0:
+                    temp_name += "_" + layer_infos.pop(0)
+                else:
+                    temp_name = layer_infos.pop(0)
+
+        pair_keys = []
+        if "lora_down" in key:
+            pair_keys.append(key.replace("lora_down", "lora_up"))
+            pair_keys.append(key)
+        else:
+            pair_keys.append(key)
+            pair_keys.append(key.replace("lora_up", "lora_down"))
+
+        # update weight
+        if len(state_dict[pair_keys[0]].shape) == 4:
+            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
+            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
+            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
+        else:
+            weight_up = state_dict[pair_keys[0]].to(torch.float32)
+            weight_down = state_dict[pair_keys[1]].to(torch.float32)
+            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
+
+        # update visited list
+        for item in pair_keys:
+            visited.append(item)
+
+    return pipeline
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
+    )
+    parser.add_argument(
+        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
+    )
+    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+    parser.add_argument(
+        "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
+    )
+    parser.add_argument(
+        "--lora_prefix_text_encoder",
+        default="lora_te",
+        type=str,
+        help="The prefix of text encoder weight in safetensors",
+    )
+    parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
+    parser.add_argument(
+        "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
+    )
+    parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
+
+    args = parser.parse_args()
+
+    base_model_path = args.base_model_path
+    checkpoint_path = args.checkpoint_path
+    dump_path = args.dump_path
+    lora_prefix_unet = args.lora_prefix_unet
+    lora_prefix_text_encoder = args.lora_prefix_text_encoder
+    alpha = args.alpha
+
+    pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
+    state_dict = load_file(checkpoint_path)
+    pipe = convert_lora(pipe, state_dict, lora_prefix_unet, lora_prefix_text_encoder, alpha)
+
+    pipe = pipe.to(args.device)
+    pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
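
The conversion above folds every lora_up / lora_down pair straight into the base weights instead of keeping separate adapter modules. A minimal sketch of the same update on a single linear layer (layer size and rank below are illustrative, not taken from any particular checkpoint):

    import torch
    import torch.nn as nn

    def merge_lora_pair(layer: nn.Linear, lora_up: torch.Tensor, lora_down: torch.Tensor, alpha: float = 0.75) -> None:
        # same update as above, applied per key pair: W <- W + alpha * (up @ down)
        delta = alpha * torch.mm(lora_up.to(torch.float32), lora_down.to(torch.float32))
        layer.weight.data += delta.to(layer.weight.dtype).to(layer.weight.device)

    layer = nn.Linear(768, 768)
    rank = 4
    lora_down = 0.01 * torch.randn(rank, 768)   # (rank, in_features)
    lora_up   = torch.zeros(768, rank)          # (out_features, rank)
    merge_lora_pair(layer, lora_up, lora_down, alpha=0.75)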

+ 159 - 0
animatediff/utils/.ipynb_checkpoints/util-checkpoint.py

@@ -0,0 +1,159 @@
+import os
+import imageio
+import numpy as np
+from typing import Union
+
+import torch
+import torchvision
+import torch.distributed as dist
+
+from safetensors import safe_open
+from tqdm import tqdm
+from einops import rearrange
+from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
+from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, convert_motion_lora_ckpt_to_diffusers, convert_motion_lora_ckpt_to_diffusers_test
+
+
+def zero_rank_print(s):
+    if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
+
+
+def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
+    videos = rearrange(videos, "b c t h w -> t b c h w")
+    outputs = []
+    for x in videos:
+        x = torchvision.utils.make_grid(x, nrow=n_rows)
+        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
+        if rescale:
+            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
+        x = (x * 255).numpy().astype(np.uint8)
+        outputs.append(x)
+
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    imageio.mimsave(path, outputs, fps=fps)
+
+
+# DDIM Inversion
+@torch.no_grad()
+def init_prompt(prompt, pipeline):
+    uncond_input = pipeline.tokenizer(
+        [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
+        return_tensors="pt"
+    )
+    uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
+    text_input = pipeline.tokenizer(
+        [prompt],
+        padding="max_length",
+        max_length=pipeline.tokenizer.model_max_length,
+        truncation=True,
+        return_tensors="pt",
+    )
+    text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
+    context = torch.cat([uncond_embeddings, text_embeddings])
+
+    return context
+
+
+def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
+              sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
+    timestep, next_timestep = min(
+        timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
+    alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
+    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
+    beta_prod_t = 1 - alpha_prod_t
+    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
+    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
+    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
+    return next_sample
+
+
+def get_noise_pred_single(latents, t, context, unet):
+    noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
+    return noise_pred
+
+
+@torch.no_grad()
+def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
+    context = init_prompt(prompt, pipeline)
+    uncond_embeddings, cond_embeddings = context.chunk(2)
+    all_latent = [latent]
+    latent = latent.clone().detach()
+    for i in tqdm(range(num_inv_steps)):
+        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
+        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
+        latent = next_step(noise_pred, t, latent, ddim_scheduler)
+        all_latent.append(latent)
+    return all_latent
+
+
+@torch.no_grad()
+def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
+    ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
+    return ddim_latents
+
+def load_weights(
+    animation_pipeline,
+    # motion module
+    motion_module_path         = "",
+    motion_module_lora_configs = [],
+    # image layers
+    dreambooth_model_path = "",
+    lora_model_path       = "",
+    lora_alpha            = 0.8,
+):
+    # 1.1 motion module
+    unet_state_dict = {}
+    if motion_module_path != "":
+        print(f"load motion module from {motion_module_path}")
+        motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
+        motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
+        unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
+    
+    missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
+    assert len(unexpected) == 0
+    del unet_state_dict
+
+    if dreambooth_model_path != "":
+        print(f"load dreambooth model from {dreambooth_model_path}")
+        if dreambooth_model_path.endswith(".safetensors"):
+            dreambooth_state_dict = {}
+            with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
+                for key in f.keys():
+                    dreambooth_state_dict[key] = f.get_tensor(key)
+        elif dreambooth_model_path.endswith(".ckpt"):
+            dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu")
+            
+        # 1. vae
+        converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config)
+        animation_pipeline.vae.load_state_dict(converted_vae_checkpoint)
+        # 2. unet
+        converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config)
+        animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
+        # 3. text_model
+        animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
+        del dreambooth_state_dict
+        
+    if lora_model_path != "":
+        print(f"load lora model from {lora_model_path}")
+        assert lora_model_path.endswith(".safetensors")
+        lora_state_dict = {}
+        with safe_open(lora_model_path, framework="pt", device="cpu") as f:
+            for key in f.keys():
+                lora_state_dict[key] = f.get_tensor(key)
+                
+        animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha)
+        del lora_state_dict
+
+
+    for motion_module_lora_config in motion_module_lora_configs:
+        path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
+        print(f"load motion LoRA from {path}")
+
+        motion_lora_state_dict = torch.load(path, map_location="cpu")
+        # print(motion_lora_state_dict)
+        # input()
+        motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
+
+        animation_pipeline = convert_motion_lora_ckpt_to_diffusers_test(animation_pipeline, motion_lora_state_dict, alpha)
+
+    return animation_pipeline
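
For reference, a usage sketch of the load_weights helper defined above, assuming an AnimationPipeline named animation_pipeline has already been built and that the referenced checkpoints exist locally (the paths mirror the updated 1-ToonYou.yaml entry and are placeholders):

    from animatediff.utils.util import load_weights

    animation_pipeline = load_weights(
        animation_pipeline,
        motion_module_path         = "models/Motion_Module/mm_sd_v15_v2.ckpt",
        motion_module_lora_configs = [
            {"path": "models/MotionLoRA/zoom_in_24/checkpoint-epoch-200.ckpt", "alpha": 1.0},
        ],
        dreambooth_model_path      = "models/DreamBooth_LoRA/toonyou_beta6.safetensors",
        lora_model_path            = "",
        lora_alpha                 = 0.8,
    )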

+ 28 - 0
animatediff/utils/convert_lora_safetensor_to_diffusers.py

@@ -47,6 +47,34 @@ def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
 
     return pipeline
 
+def convert_motion_lora_ckpt_to_diffusers_test(pipeline, state_dict, alpha=1.0):
+    # directly update weight in diffusers model
+    for key in state_dict:
+        if "weight" in key:
+            # only process lora down key
+            if "up." in key: continue
+            
+            up_key    = key.replace("_down.", "_up.")
+            model_key = key.replace('-', '.').replace("_lora", "").replace("lora_down.", "").replace("lora_up.", "")
+            print(up_key)
+            print(key)
+            print(model_key)
+            layer_infos = model_key.split(".")[:-1]
+            curr_layer = pipeline.unet
+            while len(layer_infos) > 0:
+                temp_name = layer_infos.pop(0)
+                curr_layer = curr_layer.__getattr__(temp_name)
+
+            weight_down = state_dict[key]
+            weight_up   = state_dict[up_key]
+            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
+            print(weight_down)
+            print(weight_up)
+            print("------")
+            print(curr_layer)
+
+    return pipeline
+
 
 
 def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
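
The _test variant above expects the key layout produced by the LoRAModuleWrapper in test_train.py: module paths with "." replaced by "-", an "_lora" suffix, and lora_down. / lora_up. sub-keys, which it strips back out to recover the dotted attribute path of the original projection before walking to it with repeated __getattr__ calls. An equivalent, compact way to resolve such a dotted path (the example path is illustrative):

    from functools import reduce

    def resolve_module(root, dotted_path: str):
        # e.g. resolve_module(pipeline.unet, "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q")
        return reduce(getattr, dotted_path.split("."), root)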

+ 4 - 2
animatediff/utils/util.py

@@ -11,7 +11,7 @@ from safetensors import safe_open
 from tqdm import tqdm
 from einops import rearrange
 from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
-from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, convert_motion_lora_ckpt_to_diffusers
+from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, convert_motion_lora_ckpt_to_diffusers, convert_motion_lora_ckpt_to_diffusers_test
 
 
 def zero_rank_print(s):
@@ -150,8 +150,10 @@ def load_weights(
         print(f"load motion LoRA from {path}")
 
         motion_lora_state_dict = torch.load(path, map_location="cpu")
+        # print(motion_lora_state_dict)
+        # input()
         motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
 
-        animation_pipeline = convert_motion_lora_ckpt_to_diffusers(animation_pipeline, motion_lora_state_dict, alpha)
+        animation_pipeline = convert_motion_lora_ckpt_to_diffusers_test(animation_pipeline, motion_lora_state_dict, alpha)
 
     return animation_pipeline

+ 26 - 0
configs/inference/.ipynb_checkpoints/inference-v1-checkpoint.yaml

@@ -0,0 +1,26 @@
+unet_additional_kwargs:
+  unet_use_cross_frame_attention: false
+  unet_use_temporal_attention: false
+  use_motion_module: true
+  motion_module_resolutions:
+  - 1
+  - 2
+  - 4
+  - 8
+  motion_module_mid_block: false
+  motion_module_decoder_only: false
+  motion_module_type: Vanilla
+  motion_module_kwargs:
+    num_attention_heads: 8
+    num_transformer_block: 1
+    attention_block_types:
+    - Temporal_Self
+    - Temporal_Self
+    temporal_position_encoding: true
+    temporal_position_encoding_max_len: 24
+    temporal_attention_dim_div: 1
+
+noise_scheduler_kwargs:
+  beta_start: 0.00085
+  beta_end: 0.012
+  beta_schedule: "linear"

+ 27 - 0
configs/inference/.ipynb_checkpoints/inference-v2-checkpoint.yaml

@@ -0,0 +1,27 @@
+unet_additional_kwargs:
+  use_inflated_groupnorm: true
+  unet_use_cross_frame_attention: false
+  unet_use_temporal_attention: false
+  use_motion_module: true
+  motion_module_resolutions:
+  - 1
+  - 2
+  - 4
+  - 8
+  motion_module_mid_block: true
+  motion_module_decoder_only: false
+  motion_module_type: Vanilla
+  motion_module_kwargs:
+    num_attention_heads: 8
+    num_transformer_block: 1
+    attention_block_types:
+    - Temporal_Self
+    - Temporal_Self
+    temporal_position_encoding: true
+    temporal_position_encoding_max_len: 32
+    temporal_attention_dim_div: 1
+
+noise_scheduler_kwargs:
+  beta_start: 0.00085
+  beta_end: 0.012
+  beta_schedule: "linear"
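
These two blocks are consumed the same way test_train.py consumes the equivalent keys: unet_additional_kwargs is forwarded to UNet3DConditionModel.from_pretrained_2d, and noise_scheduler_kwargs to DDIMScheduler. A minimal sketch against the non-checkpoint copy of this file (the Stable Diffusion path is the repo's default location):

    from omegaconf import OmegaConf
    from diffusers import DDIMScheduler
    from animatediff.models.unet import UNet3DConditionModel

    inference_config = OmegaConf.load("configs/inference/inference-v2.yaml")
    unet = UNet3DConditionModel.from_pretrained_2d(
        "models/StableDiffusion/stable-diffusion-v1-5", subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
    )
    noise_scheduler = DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs))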

+ 26 - 0
configs/prompts/.ipynb_checkpoints/1-ToonYou-checkpoint.yaml

@@ -0,0 +1,26 @@
+ToonYou:
+  motion_module:
+    - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+    
+  motion_module_lora_configs:
+    - path:  "/root/autodl-tmp/AnimateDiff/models/MotionLoRA/zoom_in_24/checkpoint-epoch-200.ckpt"
+      alpha: 1.0
+
+  dreambooth_path: "models/DreamBooth_LoRA/toonyou_beta6.safetensors"
+  lora_model_path: ""
+
+  seed:           [10788741199826055526, 6520604954829636163, 6519455744612555650, 16372571278361863751]
+  steps:          25
+  guidance_scale: 7.5
+
+  prompt:
+    - "best quality, masterpiece, 1girl, looking at viewer, blurry background, upper body, contemporary, dress"
+    - "masterpiece, best quality, 1girl, solo, cherry blossoms, hanami, pink flower, white flower, spring season, wisteria, petals, flower, plum blossoms, outdoors, falling petals, white hair, black eyes,"
+    - "best quality, masterpiece, 1boy, formal, abstract, looking at viewer, masculine, marble pattern"
+    - "best quality, masterpiece, 1girl, cloudy sky, dandelion, contrapposto, alternate hairstyle,"
+
+  n_prompt:
+    - "(worst quality, low quality:2),NSFW,monochrome,zombie,overexposure,watermark,text,bad anatomy,bad hand,extra hands,extra fingers,too many fingers,fused fingers,bad arm,distorted arm,extra arms,fused arms,extra legs,missing leg,disembodied leg,extra nipples,detached arm,liquid hand,inverted hand,disembodied limb,small breasts,loli,oversized head,extra body,completely nude,extra navel,EasyNegative,(hair between eyes),sketch,duplicate,ugly,huge eyes,text,logo,worst face,(bad and mutated hands:1.3),(blurry:2),horror,geometry,bad_prompt,(bad hands),(missing fingers),multiple limbs,bad anatomy,(interlocked fingers:1.2),Ugly Fingers,(extra digit and hands and fingers and legs and arms:1.4),((2girl)),(deformed fingers:1.2),(long fingers:1.2),(bad-artist-anime),bad-artist,bad hand,extra legs,"
+    - "(worst quality, low quality:2),NSFW,monochrome,zombie,overexposure,watermark,text,bad anatomy,bad hand,extra hands,extra fingers,too many fingers,fused fingers,bad arm,distorted arm,extra arms,fused arms,extra legs,missing leg,disembodied leg,extra nipples,detached arm,liquid hand,inverted hand,disembodied limb,small breasts,loli,oversized head,extra body,completely nude,extra navel,EasyNegative,(hair between eyes),sketch,duplicate,ugly,huge eyes,text,logo,worst face,(bad and mutated hands:1.3),(blurry:2),horror,geometry,bad_prompt,(bad hands),(missing fingers),multiple limbs,bad anatomy,(interlocked fingers:1.2),Ugly Fingers,(extra digit and hands and fingers and legs and arms:1.4),((2girl)),(deformed fingers:1.2),(long fingers:1.2),(bad-artist-anime),bad-artist,bad hand,extra legs,"
+    - "(worst quality, low quality:2),NSFW,monochrome,zombie,overexposure,watermark,text,bad anatomy,bad hand,extra hands,extra fingers,too many fingers,fused fingers,bad arm,distorted arm,extra arms,fused arms,extra legs,missing leg,disembodied leg,extra nipples,detached arm,liquid hand,inverted hand,disembodied limb,small breasts,loli,oversized head,extra body,completely nude,extra navel,EasyNegative,(hair between eyes),sketch,duplicate,ugly,huge eyes,text,logo,worst face,(bad and mutated hands:1.3),(blurry:2),horror,geometry,bad_prompt,(bad hands),(missing fingers),multiple limbs,bad anatomy,(interlocked fingers:1.2),Ugly Fingers,(extra digit and hands and fingers and legs and arms:1.4),((2girl)),(deformed fingers:1.2),(long fingers:1.2),(bad-artist-anime),bad-artist,bad hand,extra legs,"
+    - "(worst quality, low quality:2),NSFW,monochrome,zombie,overexposure,watermark,text,bad anatomy,bad hand,extra hands,extra fingers,too many fingers,fused fingers,bad arm,distorted arm,extra arms,fused arms,extra legs,missing leg,disembodied leg,extra nipples,detached arm,liquid hand,inverted hand,disembodied limb,small breasts,loli,oversized head,extra body,completely nude,extra navel,EasyNegative,(hair between eyes),sketch,duplicate,ugly,huge eyes,text,logo,worst face,(bad and mutated hands:1.3),(blurry:2),horror,geometry,bad_prompt,(bad hands),(missing fingers),multiple limbs,bad anatomy,(interlocked fingers:1.2),Ugly Fingers,(extra digit and hands and fingers and legs and arms:1.4),((2girl)),(deformed fingers:1.2),(long fingers:1.2),(bad-artist-anime),bad-artist,bad hand,extra legs,"

+ 23 - 0
configs/prompts/.ipynb_checkpoints/2-Lyriel-checkpoint.yaml

@@ -0,0 +1,23 @@
+Lyriel:
+  motion_module:
+    - "models/Motion_Module/mm_sd_v14.ckpt"
+    - "models/Motion_Module/mm_sd_v15.ckpt"
+
+  dreambooth_path: "models/DreamBooth_LoRA/lyriel_v16.safetensors"
+  lora_model_path: ""
+
+  seed:           [10917152860782582783, 6399018107401806238, 15875751942533906793, 6653196880059936551]
+  steps:          25
+  guidance_scale: 7.5
+
+  prompt:
+    - "dark shot, epic realistic, portrait of halo, sunglasses, blue eyes, tartan scarf, white hair by atey ghailan, by greg rutkowski, by greg tocchini, by james gilleard, by joe fenton, by kaethe butcher, gradient yellow, black, brown and magenta color scheme, grunge aesthetic!!! graffiti tag wall background, art by greg rutkowski and artgerm, soft cinematic light, adobe lightroom, photolab, hdr, intricate, highly detailed, depth of field, faded, neutral colors, hdr, muted colors, hyperdetailed, artstation, cinematic, warm lights, dramatic light, intricate details, complex background, rutkowski, teal and orange"
+    - "A forbidden castle high up in the mountains, pixel art, intricate details2, hdr, intricate details, hyperdetailed5, natural skin texture, hyperrealism, soft light, sharp, game art, key visual, surreal"
+    - "dark theme, medieval portrait of a man sharp features, grim, cold stare, dark colors, Volumetric lighting, baroque oil painting by Greg Rutkowski, Artgerm, WLOP, Alphonse Mucha dynamic lighting hyperdetailed intricately detailed, hdr, muted colors, complex background, hyperrealism, hyperdetailed, amandine van ray"
+    - "As I have gone alone in there and with my treasures bold, I can keep my secret where and hint of riches new and old. Begin it where warm waters halt and take it in a canyon down, not far but too far to walk, put in below the home of brown."
+
+  n_prompt:
+    - "3d, cartoon, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, young, loli, elf, 3d, illustration"
+    - "3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular"
+    - "dof, grayscale, black and white, bw, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular,badhandsv5-neg, By bad artist -neg 1, monochrome"
+    - "holding an item, cowboy, hat, cartoon, 3d, disfigured, bad art, deformed,extra limbs,close up,b&w, wierd colors, blurry, duplicate, morbid, mutilated, [out of frame], extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, out of frame, ugly, extra limbs, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, Photoshop, video game, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, 3d render"

+ 10 - 7
configs/prompts/1-ToonYou.yaml

@@ -1,9 +1,12 @@
 ToonYou:
   motion_module:
-    - "models/Motion_Module/mm_sd_v14.ckpt"
-    - "models/Motion_Module/mm_sd_v15.ckpt"
+    - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+    
+  motion_module_lora_configs:
+    - path:  "/root/autodl-tmp/AnimateDiff/models/MotionLoRA/zoom_in_24/checkpoint-epoch-200.ckpt"
+      alpha: 1.0
 
-  dreambooth_path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors"
+  dreambooth_path: "models/DreamBooth_LoRA/toonyou_beta6.safetensors"
   lora_model_path: ""
 
   seed:           [10788741199826055526, 6520604954829636163, 6519455744612555650, 16372571278361863751]
@@ -17,7 +20,7 @@ ToonYou:
     - "best quality, masterpiece, 1girl, cloudy sky, dandelion, contrapposto, alternate hairstyle,"
 
   n_prompt:
-    - ""
-    - "badhandv4,easynegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, teeth"
-    - ""
-    - ""
+    - "(worst quality, low quality:2),NSFW,monochrome,zombie,overexposure,watermark,text,bad anatomy,bad hand,extra hands,extra fingers,too many fingers,fused fingers,bad arm,distorted arm,extra arms,fused arms,extra legs,missing leg,disembodied leg,extra nipples,detached arm,liquid hand,inverted hand,disembodied limb,small breasts,loli,oversized head,extra body,completely nude,extra navel,EasyNegative,(hair between eyes),sketch,duplicate,ugly,huge eyes,text,logo,worst face,(bad and mutated hands:1.3),(blurry:2),horror,geometry,bad_prompt,(bad hands),(missing fingers),multiple limbs,bad anatomy,(interlocked fingers:1.2),Ugly Fingers,(extra digit and hands and fingers and legs and arms:1.4),((2girl)),(deformed fingers:1.2),(long fingers:1.2),(bad-artist-anime),bad-artist,bad hand,extra legs,"
+    - "(worst quality, low quality:2),NSFW,monochrome,zombie,overexposure,watermark,text,bad anatomy,bad hand,extra hands,extra fingers,too many fingers,fused fingers,bad arm,distorted arm,extra arms,fused arms,extra legs,missing leg,disembodied leg,extra nipples,detached arm,liquid hand,inverted hand,disembodied limb,small breasts,loli,oversized head,extra body,completely nude,extra navel,EasyNegative,(hair between eyes),sketch,duplicate,ugly,huge eyes,text,logo,worst face,(bad and mutated hands:1.3),(blurry:2),horror,geometry,bad_prompt,(bad hands),(missing fingers),multiple limbs,bad anatomy,(interlocked fingers:1.2),Ugly Fingers,(extra digit and hands and fingers and legs and arms:1.4),((2girl)),(deformed fingers:1.2),(long fingers:1.2),(bad-artist-anime),bad-artist,bad hand,extra legs,"
+    - "(worst quality, low quality:2),NSFW,monochrome,zombie,overexposure,watermark,text,bad anatomy,bad hand,extra hands,extra fingers,too many fingers,fused fingers,bad arm,distorted arm,extra arms,fused arms,extra legs,missing leg,disembodied leg,extra nipples,detached arm,liquid hand,inverted hand,disembodied limb,small breasts,loli,oversized head,extra body,completely nude,extra navel,EasyNegative,(hair between eyes),sketch,duplicate,ugly,huge eyes,text,logo,worst face,(bad and mutated hands:1.3),(blurry:2),horror,geometry,bad_prompt,(bad hands),(missing fingers),multiple limbs,bad anatomy,(interlocked fingers:1.2),Ugly Fingers,(extra digit and hands and fingers and legs and arms:1.4),((2girl)),(deformed fingers:1.2),(long fingers:1.2),(bad-artist-anime),bad-artist,bad hand,extra legs,"
+    - "(worst quality, low quality:2),NSFW,monochrome,zombie,overexposure,watermark,text,bad anatomy,bad hand,extra hands,extra fingers,too many fingers,fused fingers,bad arm,distorted arm,extra arms,fused arms,extra legs,missing leg,disembodied leg,extra nipples,detached arm,liquid hand,inverted hand,disembodied limb,small breasts,loli,oversized head,extra body,completely nude,extra navel,EasyNegative,(hair between eyes),sketch,duplicate,ugly,huge eyes,text,logo,worst face,(bad and mutated hands:1.3),(blurry:2),horror,geometry,bad_prompt,(bad hands),(missing fingers),multiple limbs,bad anatomy,(interlocked fingers:1.2),Ugly Fingers,(extra digit and hands and fingers and legs and arms:1.4),((2girl)),(deformed fingers:1.2),(long fingers:1.2),(bad-artist-anime),bad-artist,bad hand,extra legs,"
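
Note that the updated ToonYou entry (and its .ipynb_checkpoints copy above) now points motion_module_lora_configs at a locally trained checkpoint under MotionLoRA/zoom_in_24/; the checkpoint-epoch-N.ckpt naming matches the LoRA checkpoints written by test_train.py, and the training config later in this diff points at the corresponding zoom_in_24.csv data.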

+ 189 - 0
configs/prompts/v2/.ipynb_checkpoints/5-RealisticVision-MotionLoRA-checkpoint.yaml

@@ -0,0 +1,189 @@
+ZoomIn:
+  inference_config: "configs/inference/inference-v2.yaml"
+  motion_module:
+    - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+  motion_module_lora_configs:
+    - path:  "models/MotionLoRA/v2_lora_ZoomIn.ckpt"
+      alpha: 1.0
+
+  dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
+  lora_model_path: ""
+
+  seed:           45987230
+  steps:          25
+  guidance_scale: 7.5
+
+  prompt:
+    - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+
+  n_prompt:
+    - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
+
+
+
+ZoomOut:
+  inference_config: "configs/inference/inference-v2.yaml"
+  motion_module:
+    - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+  motion_module_lora_configs:
+    - path:  "models/MotionLoRA/v2_lora_ZoomOut.ckpt"
+      alpha: 1.0
+
+  dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
+  lora_model_path: ""
+
+  seed:           45987230
+  steps:          25
+  guidance_scale: 7.5
+
+  prompt:
+    - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+
+  n_prompt:
+    - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
+
+
+
+PanLeft:
+  inference_config: "configs/inference/inference-v2.yaml"
+  motion_module:
+    - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+  motion_module_lora_configs:
+    - path:  "models/MotionLoRA/v2_lora_PanLeft.ckpt"
+      alpha: 1.0
+
+  dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
+  lora_model_path: ""
+
+  seed:           45987230
+  steps:          25
+  guidance_scale: 7.5
+
+  prompt:
+    - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+
+  n_prompt:
+    - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
+
+
+
+PanRight:
+  inference_config: "configs/inference/inference-v2.yaml"
+  motion_module:
+    - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+  motion_module_lora_configs:
+    - path:  "models/MotionLoRA/v2_lora_PanRight.ckpt"
+      alpha: 1.0
+
+  dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
+  lora_model_path: ""
+
+  seed:           45987230
+  steps:          25
+  guidance_scale: 7.5
+
+  prompt:
+    - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+
+  n_prompt:
+    - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
+
+
+
+TiltUp:
+  inference_config: "configs/inference/inference-v2.yaml"
+  motion_module:
+    - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+  motion_module_lora_configs:
+    - path:  "models/MotionLoRA/v2_lora_TiltUp.ckpt"
+      alpha: 1.0
+
+  dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
+  lora_model_path: ""
+
+  seed:           45987230
+  steps:          25
+  guidance_scale: 7.5
+
+  prompt:
+    - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+
+  n_prompt:
+    - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
+
+
+
+TiltDown:
+  inference_config: "configs/inference/inference-v2.yaml"
+  motion_module:
+    - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+  motion_module_lora_configs:
+    - path:  "models/MotionLoRA/v2_lora_TiltDown.ckpt"
+      alpha: 1.0
+
+  dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
+  lora_model_path: ""
+
+  seed:           45987230
+  steps:          25
+  guidance_scale: 7.5
+
+  prompt:
+    - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+
+  n_prompt:
+    - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
+
+
+
+RollingAnticlockwise:
+  inference_config: "configs/inference/inference-v2.yaml"
+  motion_module:
+    - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+  motion_module_lora_configs:
+    - path:  "models/MotionLoRA/v2_lora_RollingAnticlockwise.ckpt"
+      alpha: 1.0
+
+  dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
+  lora_model_path: ""
+
+  seed:           45987230
+  steps:          25
+  guidance_scale: 7.5
+
+  prompt:
+    - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+
+  n_prompt:
+    - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
+
+
+
+RollingClockwise:
+  inference_config: "configs/inference/inference-v2.yaml"
+  motion_module:
+    - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+  motion_module_lora_configs:
+    - path:  "models/MotionLoRA/v2_lora_RollingClockwise.ckpt"
+      alpha: 1.0
+
+  dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
+  lora_model_path: ""
+
+  seed:           45987230
+  steps:          25
+  guidance_scale: 7.5
+
+  prompt:
+    - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+
+  n_prompt:
+    - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"

+ 66 - 0
configs/training/.ipynb_checkpoints/training-checkpoint.yaml

@@ -0,0 +1,66 @@
+image_finetune: false
+
+output_dir: "outputs"
+pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5"
+
+unet_additional_kwargs:
+  use_motion_module              : true
+  motion_module_resolutions      : [ 1,2,4,8 ]
+  unet_use_cross_frame_attention : false
+  unet_use_temporal_attention    : false
+
+  motion_module_type: Vanilla
+  motion_module_kwargs:
+    num_attention_heads                : 8
+    num_transformer_block              : 1
+    attention_block_types              : [ "Temporal_Self", "Temporal_Self" ]
+    temporal_position_encoding         : true
+    temporal_position_encoding_max_len : 32
+    temporal_attention_dim_div         : 1
+    zero_initialize                    : true
+
+noise_scheduler_kwargs:
+  num_train_timesteps: 1000
+  beta_start:          0.00085
+  beta_end:            0.012
+  beta_schedule:       "linear"
+  steps_offset:        1
+  clip_sample:         false
+
+train_data:
+  csv_path:        "/root/autodl-tmp/zoom_in_24.csv"
+  video_folder:    "/root/autodl-tmp/data"
+  sample_size:     256
+  sample_stride:   4
+  sample_n_frames: 16
+
+validation_data:
+  prompts:
+    - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons."
+    - "A drone view of celebration with Christmas tree and fireworks, starry sky - background."
+    - "Robot dancing in times square."
+    - "Pacific coast, carmel by the sea ocean and waves."
+  num_inference_steps: 25
+  guidance_scale: 8.
+
+trainable_modules:
+  - "motion_modules."
+
+unet_checkpoint_path: "/root/autodl-tmp/mm_sd_v15_v2.ckpt"
+
+learning_rate:    1.e-4
+train_batch_size: 1
+
+max_train_epoch:      1000
+max_train_steps:      50000
+checkpointing_epochs: 10
+checkpointing_steps:  60
+
+validation_steps:       5000
+validation_steps_tuple: [2, 50]
+
+global_seed: 42
+mixed_precision_training: true
+enable_xformers_memory_efficient_attention: True
+
+is_debug: False

+ 9 - 9
configs/training/training.yaml

@@ -15,7 +15,7 @@ unet_additional_kwargs:
     num_transformer_block              : 1
     attention_block_types              : [ "Temporal_Self", "Temporal_Self" ]
     temporal_position_encoding         : true
-    temporal_position_encoding_max_len : 24
+    temporal_position_encoding_max_len : 32
     temporal_attention_dim_div         : 1
     zero_initialize                    : true
 
@@ -28,8 +28,8 @@ noise_scheduler_kwargs:
   clip_sample:         false
 
 train_data:
-  csv_path:        "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv"
-  video_folder:    "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val"
+  csv_path:        "/root/autodl-tmp/zoom_in_24.csv"
+  video_folder:    "/root/autodl-tmp/data"
   sample_size:     256
   sample_stride:   4
   sample_n_frames: 16
@@ -46,14 +46,14 @@ validation_data:
 trainable_modules:
   - "motion_modules."
 
-unet_checkpoint_path: ""
+unet_checkpoint_path: "/root/autodl-tmp/mm_sd_v15_v2.ckpt"
 
 learning_rate:    1.e-4
-train_batch_size: 4
+train_batch_size: 1
 
-max_train_epoch:      -1
-max_train_steps:      100
-checkpointing_epochs: -1
+max_train_epoch:      1000
+max_train_steps:      50000
+checkpointing_epochs: 10
 checkpointing_steps:  60
 
 validation_steps:       5000
@@ -63,4 +63,4 @@ global_seed: 42
 mixed_precision_training: true
 enable_xformers_memory_efficient_attention: True
 
-is_debug: False
+is_debug: False
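
A minimal sketch of how the train_data block is consumed, mirroring test_train.py, which builds the dataset as WebVid10M(**train_data, is_image=image_finetune):

    from omegaconf import OmegaConf
    from animatediff.data.dataset import WebVid10M

    config = OmegaConf.load("configs/training/training.yaml")
    train_dataset = WebVid10M(**config.train_data, is_image=config.image_finetune)
    print(len(train_dataset))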

+ 0 - 0
models/DreamBooth_LoRA/Put personalized T2I checkpoints here.txt


+ 0 - 0
models/MotionLoRA/Put MotionLoRA checkpoints here.txt


+ 0 - 0
models/Motion_Module/Put motion module checkpoints here.txt


+ 0 - 0
models/StableDiffusion/Put diffusers stable-diffusion-v1-5 repo here.txt


+ 1 - 1
scripts/animate.py

@@ -112,7 +112,7 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--pretrained_model_path", type=str, default="models/StableDiffusion/stable-diffusion-v1-5",)
-    parser.add_argument("--inference_config",      type=str, default="configs/inference/inference-v1.yaml")    
+    parser.add_argument("--inference_config",      type=str, default="configs/inference/inference-v2.yaml")    
     parser.add_argument("--config",                type=str, required=True)
     
     parser.add_argument("--L", type=int, default=16 )
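
With this change the script defaults to the v2 inference config, so it is typically launched as python scripts/animate.py --config configs/prompts/1-ToonYou.yaml (pass --L to override the default of 16). A quick way to inspect which motion modules and MotionLoRA checkpoints a prompt config pulls in, assuming the YAML layout shown above:

    from omegaconf import OmegaConf

    config = OmegaConf.load("configs/prompts/1-ToonYou.yaml")
    for model_name, model_config in config.items():
        print(model_name, list(model_config.motion_module),
              model_config.get("motion_module_lora_configs", []))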

+ 532 - 0
test_train.py

@@ -0,0 +1,532 @@
+import os
+import math
+import wandb
+import random
+import logging
+import inspect
+import argparse
+import datetime
+import subprocess
+
+from pathlib import Path
+from tqdm.auto import tqdm
+from einops import rearrange
+from omegaconf import OmegaConf
+from safetensors import safe_open
+from typing import Dict, Optional, Tuple
+
+import torch
+import torchvision
+import torch.nn.functional as F
+import torch.distributed as dist
+from torch.optim.swa_utils import AveragedModel
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+import diffusers
+from diffusers import AutoencoderKL, DDIMScheduler
+from diffusers.models import UNet2DConditionModel
+from diffusers.pipelines import StableDiffusionPipeline
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version
+from diffusers.utils.import_utils import is_xformers_available
+
+import transformers
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from animatediff.data.dataset import WebVid10M
+from animatediff.models.unet import UNet3DConditionModel
+from animatediff.pipelines.pipeline_animation import AnimationPipeline
+from animatediff.utils.util import save_videos_grid, zero_rank_print
+from animatediff.models.lora import LoRAModule
+
+
+
+def init_dist(launcher="slurm", backend='nccl', port=29500, **kwargs):
+    """Initializes distributed environment."""
+    if launcher == 'pytorch':
+        rank = int(os.environ['RANK'])
+        num_gpus = torch.cuda.device_count()
+        local_rank = rank % num_gpus
+        torch.cuda.set_device(local_rank)
+        dist.init_process_group(backend=backend, **kwargs)
+        
+    elif launcher == 'slurm':
+        proc_id = int(os.environ['SLURM_PROCID'])
+        ntasks = int(os.environ['SLURM_NTASKS'])
+        node_list = os.environ['SLURM_NODELIST']
+        num_gpus = torch.cuda.device_count()
+        local_rank = proc_id % num_gpus
+        torch.cuda.set_device(local_rank)
+        addr = subprocess.getoutput(
+            f'scontrol show hostname {node_list} | head -n1')
+        os.environ['MASTER_ADDR'] = addr
+        os.environ['WORLD_SIZE'] = str(ntasks)
+        os.environ['RANK'] = str(proc_id)
+        port = os.environ.get('PORT', port)
+        os.environ['MASTER_PORT'] = str(port)
+        dist.init_process_group(backend=backend)
+        zero_rank_print(f"proc_id: {proc_id}; local_rank: {local_rank}; ntasks: {ntasks}; node_list: {node_list}; num_gpus: {num_gpus}; addr: {addr}; port: {port}")
+        
+    else:
+        raise NotImplementedError(f'Not implemented launcher type: `{launcher}`!')
+    
+    return local_rank
+
+
+
+def main(
+    image_finetune: bool,
+    
+    name: str,
+    use_wandb: bool,
+    launcher: str,
+    
+    output_dir: str,
+    pretrained_model_path: str,
+
+    train_data: Dict,
+    validation_data: Dict,
+    cfg_random_null_text: bool = True,
+    cfg_random_null_text_ratio: float = 0.1,
+    
+    unet_checkpoint_path: str = "",
+    unet_additional_kwargs: Dict = {},
+    ema_decay: float = 0.9999,
+    noise_scheduler_kwargs = None,
+    
+    max_train_epoch: int = -1,
+    max_train_steps: int = 100,
+    validation_steps: int = 100,
+    validation_steps_tuple: Tuple = (-1,),
+
+    learning_rate: float = 3e-5,
+    scale_lr: bool = False,
+    lr_warmup_steps: int = 0,
+    lr_scheduler: str = "constant",
+
+    trainable_modules: Tuple[str] = (None, ),
+    num_workers: int = 32,
+    train_batch_size: int = 1,
+    adam_beta1: float = 0.9,
+    adam_beta2: float = 0.999,
+    adam_weight_decay: float = 1e-2,
+    adam_epsilon: float = 1e-08,
+    max_grad_norm: float = 1.0,
+    gradient_accumulation_steps: int = 1,
+    gradient_checkpointing: bool = False,
+    checkpointing_epochs: int = 5,
+    checkpointing_steps: int = -1,
+
+    mixed_precision_training: bool = True,
+    enable_xformers_memory_efficient_attention: bool = True,
+
+    global_seed: int = 42,
+    is_debug: bool = False,
+):
+    check_min_version("0.10.0.dev0")
+
+    # Initialize distributed training
+    local_rank      = init_dist(launcher=launcher)
+    global_rank     = dist.get_rank()
+    num_processes   = dist.get_world_size()
+    is_main_process = global_rank == 0
+
+    seed = global_seed + global_rank
+    torch.manual_seed(seed)
+    
+    # Logging folder
+    folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S")
+    output_dir = os.path.join(output_dir, folder_name)
+    if is_debug and os.path.exists(output_dir):
+        os.system(f"rm -rf {output_dir}")
+
+    *_, config = inspect.getargvalues(inspect.currentframe())
+
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+
+    if is_main_process and (not is_debug) and use_wandb:
+        run = wandb.init(project="animatediff", name=folder_name, config=config)
+
+    # Handle the output folder creation
+    if is_main_process:
+        os.makedirs(output_dir, exist_ok=True)
+        os.makedirs(f"{output_dir}/samples", exist_ok=True)
+        os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
+        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
+        OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
+
+    # Load scheduler, tokenizer and models.
+    noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
+
+    vae          = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
+    tokenizer    = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
+    if not image_finetune:
+        unet = UNet3DConditionModel.from_pretrained_2d(
+            pretrained_model_path, subfolder="unet", 
+            unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs)
+        )
+    else:
+        unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")
+        
+    # Load pretrained unet weights
+    if unet_checkpoint_path != "":
+        zero_rank_print(f"from checkpoint: {unet_checkpoint_path}")
+        unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
+        if "global_step" in unet_checkpoint_path: zero_rank_print(f"global_step: {unet_checkpoint_path['global_step']}")
+        state_dict = unet_checkpoint_path["state_dict"] if "state_dict" in unet_checkpoint_path else unet_checkpoint_path
+
+        m, u = unet.load_state_dict(state_dict, strict=False)
+        zero_rank_print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+        assert len(u) == 0
+        
+    loras = []
+    for module_name, module in unet.named_modules():
+        if "to_q" in module_name or "to_k" in module_name or "to_v" in module_name or "to_out.0" in module_name:
+            lora_name = module_name + "_lora"
+            print(lora_name)
+            lora = LoRAModule(lora_name, module)
+            loras.append(lora)
+            # print(module.__class__.__name__)
+    print("enable LoRA for U-Net")
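+    # LoRAModule comes from animatediff.models.lora; apply_to() is expected to hook each
+    # wrapped projection's forward so the low-rank lora_down/lora_up path is added on top
+    # of the frozen base weight.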
+    for lora in loras:
+        lora.apply_to()
+    
+    # Freeze vae and text_encoder
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+    
+    # Set unet trainable parameters
+    unet.requires_grad_(False)
+    
+    for lora in loras:
+        lora.requires_grad_(True)
+    
+    # for name, param in unet.named_parameters():
+    #     for trainable_module_name in trainable_modules:
+    #         if trainable_module_name in name:
+    #             param.requires_grad = True
+    #             break
+            
+    trainable_params = [param for lora in loras for param in lora.parameters() if param.requires_grad]
+    optimizer = torch.optim.AdamW(
+        trainable_params,
+        lr=learning_rate,
+        betas=(adam_beta1, adam_beta2),
+        weight_decay=adam_weight_decay,
+        eps=adam_epsilon,
+    )
+
+    if is_main_process:
+        zero_rank_print(f"trainable params number: {len(trainable_params)}")
+        zero_rank_print(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")
+
+    # Enable xformers
+    # if enable_xformers_memory_efficient_attention:
+    #     if is_xformers_available():
+    #         unet.enable_xformers_memory_efficient_attention()
+    #     else:
+    #         raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+    # Enable gradient checkpointing
+    if gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
+    # Move models to GPU
+    vae.to(local_rank)
+    text_encoder.to(local_rank)
+
+    # Get the training dataset
+    train_dataset = WebVid10M(**train_data, is_image=image_finetune)
+    distributed_sampler = DistributedSampler(
+        train_dataset,
+        num_replicas=num_processes,
+        rank=global_rank,
+        shuffle=True,
+        seed=global_seed,
+    )
+
+    # DataLoaders creation:
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=train_batch_size,
+        shuffle=False,
+        sampler=distributed_sampler,
+        num_workers=num_workers,
+        pin_memory=True,
+        drop_last=True,
+    )
+
+    # Get the training iteration
+    if max_train_steps == -1:
+        assert max_train_epoch != -1
+        max_train_steps = max_train_epoch * len(train_dataloader)
+        
+    if checkpointing_steps == -1:
+        assert checkpointing_epochs != -1
+        checkpointing_steps = checkpointing_epochs * len(train_dataloader)
+
+    if scale_lr:
+        learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size * num_processes)
+
+    # Scheduler
+    lr_scheduler = get_scheduler(
+        lr_scheduler,
+        optimizer=optimizer,
+        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
+        num_training_steps=max_train_steps * gradient_accumulation_steps,
+    )
+
+    # Validation pipeline
+    if not image_finetune:
+        validation_pipeline = AnimationPipeline(
+            unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler,
+        ).to("cuda")
+    else:
+        validation_pipeline = StableDiffusionPipeline.from_pretrained(
+            pretrained_model_path,
+            unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, safety_checker=None,
+        )
+    validation_pipeline.enable_vae_slicing()
+
+    class LoRAModuleWrapper(torch.nn.Module):
+        def __init__(self, loras):
+            super().__init__()
+            for lora in loras:
+                module_name = lora.lora_name.replace('.', '-')
+                self.add_module(module_name, lora)
+    loras_wrapper = LoRAModuleWrapper(loras).to(local_rank)
+    
+    # DDP wrapper
+    unet.to(local_rank)
+    loras_ddp = DDP(loras_wrapper, device_ids=[local_rank], output_device=local_rank)
+
+    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
+    # Afterwards we recalculate our number of training epochs
+    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
+
+    # Train!
+    total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps
+
+    if is_main_process:
+        logging.info("***** Running training *****")
+        logging.info(f"  Num examples = {len(train_dataset)}")
+        logging.info(f"  Num Epochs = {num_train_epochs}")
+        logging.info(f"  Instantaneous batch size per device = {train_batch_size}")
+        logging.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+        logging.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
+        logging.info(f"  Total optimization steps = {max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    # Only show the progress bar once on each machine.
+    progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process)
+    progress_bar.set_description("Steps")
+
+    # Support mixed-precision training
+    scaler = torch.cuda.amp.GradScaler() if mixed_precision_training else None
+
+    for epoch in range(first_epoch, num_train_epochs):
+        train_dataloader.sampler.set_epoch(epoch)
+        unet.train()
+        
+        for step, batch in enumerate(train_dataloader):
+            if cfg_random_null_text:
+                batch['text'] = [name if random.random() > cfg_random_null_text_ratio else "" for name in batch['text']]
+                
+            # Data batch sanity check
+            if epoch == first_epoch and step == 0:
+                pixel_values, texts = batch['pixel_values'].cpu(), batch['text']
+                if not image_finetune:
+                    pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
+                    for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
+                        pixel_value = pixel_value[None, ...]
+                        save_videos_grid(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.gif", rescale=True)
+                else:
+                    for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
+                        pixel_value = pixel_value / 2. + 0.5
+                        torchvision.utils.save_image(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.png")
+                    
+            ### >>>> Training >>>> ###
+            
+            # Convert videos to latent space            
+            pixel_values = batch["pixel_values"].to(local_rank)
+            video_length = pixel_values.shape[1]
+            with torch.no_grad():
+                if not image_finetune:
+                    pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
+                    latents = vae.encode(pixel_values).latent_dist
+                    latents = latents.sample()
+                    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
+                else:
+                    latents = vae.encode(pixel_values).latent_dist
+                    latents = latents.sample()
+
+                latents = latents * 0.18215
+
+            # Sample noise that we'll add to the latents
+            noise = torch.randn_like(latents)
+            bsz = latents.shape[0]
+            
+            # Sample a random timestep for each video
+            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+            timesteps = timesteps.long()
+            
+            # Add noise to the latents according to the noise magnitude at each timestep
+            # (this is the forward diffusion process)
+            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+            
+            # Get the text embedding for conditioning
+            with torch.no_grad():
+                prompt_ids = tokenizer(
+                    batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+                ).input_ids.to(latents.device)
+                encoder_hidden_states = text_encoder(prompt_ids)[0]
+                
+            # Get the target for loss depending on the prediction type
+            if noise_scheduler.config.prediction_type == "epsilon":
+                target = noise
+            elif noise_scheduler.config.prediction_type == "v_prediction":
+                raise NotImplementedError
+            else:
+                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+            # Predict the noise residual and compute loss
+            # Mixed-precision training
+            with torch.cuda.amp.autocast(enabled=mixed_precision_training):
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+            optimizer.zero_grad()
+
+            # Backpropagate
+            if mixed_precision_training:
+                scaler.scale(loss).backward()
+                """ >>> gradient clipping >>> """
+                scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)  # only the LoRA params carry gradients
+                """ <<< gradient clipping <<< """
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                loss.backward()
+                """ >>> gradient clipping >>> """
+                torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
+                """ <<< gradient clipping <<< """
+                optimizer.step()
+
+            lr_scheduler.step()
+            progress_bar.update(1)
+            global_step += 1
+            
+#             for lora in loras:
+#                 print(f"lora_down grad: {lora.lora_down.weight.grad}")
+#                 print(f"lora_up grad: {lora.lora_up.weight.grad}")
+
+#                 print(f"lora_down requires_grad: {lora.lora_down.weight.requires_grad}")
+#                 print(f"lora_up requires_grad: {lora.lora_up.weight.requires_grad}")
+
+#                 print(f"lora_down data: {lora.lora_down.weight.data}")
+#                 print(f"lora_up data: {lora.lora_up.weight.data}")
+#                 input()
+
+            
+            ### <<<< Training <<<< ###
+            
+            # Wandb logging
+            if is_main_process and (not is_debug) and use_wandb:
+                wandb.log({"train_loss": loss.item()}, step=global_step)
+                
+            # Save checkpoint
+            if is_main_process and (global_step % checkpointing_steps == 0 or step == len(train_dataloader) - 1):
+                lora_save_path = os.path.join(output_dir, "checkpoints", "lora")
+                os.makedirs(lora_save_path, exist_ok=True)
+                loras_state_dict = {
+                    "epoch": epoch,
+                    "global_step": global_step,
+                    "state_dict": loras_wrapper.state_dict(),
+                }
+                if step == len(train_dataloader) - 1:
+                    torch.save(loras_state_dict, os.path.join(lora_save_path, f"checkpoint-epoch-{epoch+1}.ckpt"))
+                else:
+                    torch.save(loras_state_dict, os.path.join(lora_save_path, f"checkpoint.ckpt"))
+                logging.info(f"Saved LoRA state to {lora_save_path} (global_step: {global_step})")
+                
+            # Periodically validation
+            if is_main_process and (global_step % validation_steps == 0 or global_step in validation_steps_tuple):
+                samples = []
+                
+                generator = torch.Generator(device=latents.device)
+                generator.manual_seed(global_seed)
+                
+                height = train_data.sample_size[0] if not isinstance(train_data.sample_size, int) else train_data.sample_size
+                width  = train_data.sample_size[1] if not isinstance(train_data.sample_size, int) else train_data.sample_size
+
+                prompts = validation_data.prompts[:2] if global_step < 1000 and (not image_finetune) else validation_data.prompts
+
+                for idx, prompt in enumerate(prompts):
+                    if not image_finetune:
+                        sample = validation_pipeline(
+                            prompt,
+                            generator    = generator,
+                            video_length = train_data.sample_n_frames,
+                            height       = height,
+                            width        = width,
+                            **validation_data,
+                        ).videos
+                        save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif")
+                        samples.append(sample)
+                        
+                    else:
+                        sample = validation_pipeline(
+                            prompt,
+                            generator           = generator,
+                            height              = height,
+                            width               = width,
+                            num_inference_steps = validation_data.get("num_inference_steps", 25),
+                            guidance_scale      = validation_data.get("guidance_scale", 8.),
+                        ).images[0]
+                        sample = torchvision.transforms.functional.to_tensor(sample)
+                        samples.append(sample)
+                
+                if not image_finetune:
+                    samples = torch.concat(samples)
+                    save_path = f"{output_dir}/samples/sample-{global_step}.gif"
+                    save_videos_grid(samples, save_path)
+                    
+                else:
+                    samples = torch.stack(samples)
+                    save_path = f"{output_dir}/samples/sample-{global_step}.png"
+                    torchvision.utils.save_image(samples, save_path, nrow=4)
+
+                logging.info(f"Saved samples to {save_path}")
+                
+            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs)
+            
+            if global_step >= max_train_steps:
+                break
+            
+    dist.destroy_process_group()
+
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config",   type=str, required=True)
+    parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="pytorch")
+    parser.add_argument("--wandb",    action="store_true")
+    args = parser.parse_args()
+
+    name   = Path(args.config).stem
+    config = OmegaConf.load(args.config)
+
+    main(name=name, launcher=args.launcher, use_wandb=args.wandb, **config)
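
Under --launcher pytorch, init_dist reads the RANK environment variable, so the script is meant to be started through a distributed launcher such as torchrun rather than invoked directly; with --launcher slurm it reads the usual SLURM_* variables instead. Each saved LoRA checkpoint is a plain torch dict whose "state_dict" entry is what load_weights / convert_motion_lora_ckpt_to_diffusers_test later consume. A sketch for inspecting one (the path is a placeholder for an actual run directory under outputs/):

    import torch

    ckpt = torch.load("outputs/<run>/checkpoints/lora/checkpoint.ckpt", map_location="cpu")
    print(ckpt["epoch"], ckpt["global_step"])
    lora_state_dict = ckpt["state_dict"]
    # keys follow the LoRAModuleWrapper naming: the module path with "." replaced by "-",
    # an "_lora" suffix, then lora_down / lora_up sub-weights
    print(list(lora_state_dict.keys())[:4])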