util.py

import os
import imageio
import numpy as np
from typing import Union

import torch
import torchvision

from tqdm import tqdm
from einops import rearrange


def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    """Tile a batch of videos (b c t h w) into a per-frame grid and save as an animated file."""
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        # Arrange the batch dimension into an n_rows-wide image grid for each frame.
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)  # c h w -> h w c
        if rescale:
            x = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)
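
# A minimal usage sketch (not part of the original file): save a dummy 8-frame
# clip laid out in the (b, c, t, h, w) order expected above. With newer imageio
# releases the `fps` keyword may need to be replaced by `duration`.
#
#   dummy = torch.rand(2, 3, 8, 64, 64)                # values already in [0, 1]
#   save_videos_grid(dummy, "./samples/demo.gif", rescale=False, fps=8)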


# DDIM Inversion
@torch.no_grad()
def init_prompt(prompt, pipeline):
    """Build the classifier-free-guidance context: [unconditional, conditional] text embeddings."""
    uncond_input = pipeline.tokenizer(
        [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
        return_tensors="pt"
    )
    uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
    text_input = pipeline.tokenizer(
        [prompt],
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
    context = torch.cat([uncond_embeddings, text_embeddings])

    return context


def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
              sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
    """One reversed DDIM update: x_{t_next} = sqrt(a_{t_next}) * x0_pred + sqrt(1 - a_{t_next}) * eps."""
    timestep, next_timestep = min(
        timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
    alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
    beta_prod_t = 1 - alpha_prod_t
    # Predicted x_0 from the current noisy sample and the noise estimate.
    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
    return next_sample


def get_noise_pred_single(latents, t, context, unet):
    noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
    return noise_pred


@torch.no_grad()
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
    """Walk the scheduler's timesteps in reverse, pushing the latent from x_0 towards noise."""
    context = init_prompt(prompt, pipeline)
    uncond_embeddings, cond_embeddings = context.chunk(2)
    all_latent = [latent]
    latent = latent.clone().detach()
    for i in tqdm(range(num_inv_steps)):
        # diffusers timesteps are stored in descending order, so index from the end.
        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
        latent = next_step(noise_pred, t, latent, ddim_scheduler)
        all_latent.append(latent)
    return all_latent


@torch.no_grad()
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
    """Return the list of latents [x_0, ..., x_T] produced by DDIM inversion of `video_latent`."""
    ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
    return ddim_latents
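
# A minimal wiring sketch (assumption, not part of the original file): the
# checkpoint path is a placeholder, and `pipeline` stands for whichever
# diffusers-style pipeline is in use; it only needs to expose `tokenizer`,
# `text_encoder`, `unet`, and `device`. The scheduler must have had
# `set_timesteps` called so that `timesteps` and `num_inference_steps` exist.
#
#   from diffusers import DDIMScheduler
#
#   ddim_scheduler = DDIMScheduler.from_pretrained("path/to/model", subfolder="scheduler")
#   ddim_scheduler.set_timesteps(50)
#   # video_latent: (b, c, t, h, w) latents of the source clip, on pipeline.device
#   ddim_latents = ddim_inversion(pipeline, ddim_scheduler, video_latent,
#                                 num_inv_steps=50, prompt="")
#   noisy_start = ddim_latents[-1]   # use as the initial noise for editing/sampling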