util.py

import os
import imageio
import numpy as np
from typing import Union

import torch
import torchvision
import torch.distributed as dist

from tqdm import tqdm
from einops import rearrange


def zero_rank_print(s):
    # Print only on rank 0 when torch.distributed is initialized; otherwise print unconditionally.
    if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0):
        print("### " + s)


def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    # videos: (batch, channels, time, height, width); tile each frame's batch into a grid and save as a video.
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)  # (c, h, w) -> (h, w, c)
        if rescale:
            x = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)
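
# Usage sketch (illustrative only; `samples` is a hypothetical tensor of shape
# (batch, channels, frames, height, width) with values in [0, 1], or in [-1, 1] if rescale=True):
#
#   save_videos_grid(samples, "outputs/sample.gif", rescale=False, n_rows=6, fps=8)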


# DDIM Inversion
@torch.no_grad()
def init_prompt(prompt, pipeline):
    # Build the classifier-free-guidance context: [unconditional embedding, prompt embedding].
    uncond_input = pipeline.tokenizer(
        [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
        return_tensors="pt"
    )
    uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
    text_input = pipeline.tokenizer(
        [prompt],
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
    context = torch.cat([uncond_embeddings, text_embeddings])
    return context


def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
              sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
    # Reverse DDIM update used for inversion: map `sample` one scheduler step toward higher noise.
    timestep, next_timestep = min(
        timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
    alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
    beta_prod_t = 1 - alpha_prod_t
    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
    return next_sample


def get_noise_pred_single(latents, t, context, unet):
    noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
    return noise_pred


@torch.no_grad()
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
    context = init_prompt(prompt, pipeline)
    uncond_embeddings, cond_embeddings = context.chunk(2)
    all_latent = [latent]
    latent = latent.clone().detach()
    # Walk the scheduler's timesteps from smallest to largest (clean -> noisy), predicting noise
    # with the conditional embeddings only, and collect the latent at every inversion step.
    for i in tqdm(range(num_inv_steps)):
        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
        latent = next_step(noise_pred, t, latent, ddim_scheduler)
        all_latent.append(latent)
    return all_latent


@torch.no_grad()
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
    ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
    return ddim_latents
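

# Usage sketch (an assumption, not part of this module): invert a VAE-encoded video latent with a
# diffusers-style pipeline that exposes .tokenizer, .text_encoder, .unet, and .device, plus a
# DDIMScheduler. The names `pipe` and `video_latents` are hypothetical; `video_latents` is assumed
# to have the (b, c, t, h, w) layout the UNet above expects.
#
#   from diffusers import StableDiffusionPipeline, DDIMScheduler
#
#   pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
#   ddim_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
#   ddim_scheduler.set_timesteps(50)  # needed so .timesteps and .num_inference_steps are populated
#
#   inv_latents = ddim_inversion(pipe, ddim_scheduler, video_latents, num_inv_steps=50)
#   start_latent = inv_latents[-1]  # most-noised latent; use it as the starting point for sampling/editing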