pipeline_animation.py

# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

import inspect
from typing import Callable, List, Optional, Union
from dataclasses import dataclass

import numpy as np
import torch

from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.utils import deprecate, logging, BaseOutput

from einops import rearrange

from ..models.unet import UNet3DConditionModel


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
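
# Note: this pipeline mirrors the standard Stable Diffusion text-to-image sampling loop,
# but the latents keep an explicit frame axis (b, c, f, h, w) and the denoiser is a
# UNet3DConditionModel, so all frames of a clip are denoised jointly and then decoded
# through the 2D VAE after folding the frame axis into the batch dimension.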


@dataclass
class AnimationPipelineOutput(BaseOutput):
    videos: Union[torch.Tensor, np.ndarray]


class AnimationPipeline(DiffusionPipeline):
    _optional_components = []

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet3DConditionModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def enable_vae_slicing(self):
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        self.vae.disable_slicing()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    def _execution_device(self):
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        text_embeddings = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        text_embeddings = text_embeddings[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            uncond_embeddings = uncond_embeddings[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings
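
    # Note: 0.18215 below is the fixed Stable Diffusion VAE latent scaling factor that was
    # applied when the latents were produced, so it is divided out before decoding.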

    def decode_latents(self, latents):
        video_length = latents.shape[2]
        latents = 1 / 0.18215 * latents
        latents = rearrange(latents, "b c f h w -> (b f) c h w")
        video = self.vae.decode(latents).sample
        video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
        video = (video / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        video = video.cpu().float().numpy()
        return video

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(self, prompt, height, width, callback_steps):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        if latents is None:
            rand_device = "cpu" if device.type == "mps" else device

            if isinstance(generator, list):
                shape = shape
                # shape = (1,) + shape[1:]
                latents = [
                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
                    for i in range(batch_size)
                ]
                latents = torch.cat(latents, dim=0).to(device)
            else:
                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents
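
    # __call__ runs the full text-to-video sampling loop: encode the prompt (plus an
    # unconditional copy when classifier-free guidance is enabled), draw noise latents with
    # a frame axis, iterate the scheduler's denoising steps through the 3D UNet, and decode
    # the result to a (batch, channel, frame, height, width) video with values in [0, 1].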

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        video_length: Optional[int],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_videos_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "tensor",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        # Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # Define call parameters
        # batch_size = 1 if isinstance(prompt, str) else len(prompt)
        batch_size = 1
        if latents is not None:
            batch_size = latents.shape[0]
        if isinstance(prompt, list):
            batch_size = len(prompt)

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # Encode input prompt
        prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
        if negative_prompt is not None:
            negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
        text_embeddings = self._encode_prompt(
            prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            video_length,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )
        latents_dtype = latents.dtype

        # Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
                # noise_pred = []
                # import pdb
                # pdb.set_trace()
                # for batch_idx in range(latent_model_input.shape[0]):
                #     noise_pred_single = self.unet(latent_model_input[batch_idx:batch_idx+1], t, encoder_hidden_states=text_embeddings[batch_idx:batch_idx+1]).sample.to(dtype=latents_dtype)
                #     noise_pred.append(noise_pred_single)
                # noise_pred = torch.cat(noise_pred)

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # Post-processing
        video = self.decode_latents(latents)

        # Convert to tensor
        if output_type == "tensor":
            video = torch.from_numpy(video)

        if not return_dict:
            return video

        return AnimationPipelineOutput(videos=video)
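

# ---------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the adapted diffusers file). The
# base checkpoint id, the example prompt, and the idea of passing in an already-constructed
# UNet3DConditionModel are assumptions; adapt them to however your project builds the
# 3D UNet (e.g. inflating a 2D UNet and loading motion-module weights).
# ---------------------------------------------------------------------------------------
def _example_usage(unet: UNet3DConditionModel, base_model: str = "runwayml/stable-diffusion-v1-5"):
    from diffusers import DDIMScheduler  # any scheduler from the Union above would also work

    # Load the frozen 2D components from a standard Stable Diffusion checkpoint layout.
    vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae")
    text_encoder = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder")
    tokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
    scheduler = DDIMScheduler.from_pretrained(base_model, subfolder="scheduler")

    pipeline = AnimationPipeline(
        vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler
    ).to("cuda")

    output = pipeline(
        "a panda surfing on a wave",  # example prompt (assumption)
        video_length=16,
        height=512,
        width=512,
        num_inference_steps=25,
        guidance_scale=7.5,
    )
    # (batch, channel, frame, height, width) tensor with values in [0, 1]
    return output.videos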