unet.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
  2. from dataclasses import dataclass
  3. from typing import List, Optional, Tuple, Union
  4. import os
  5. import json
  6. import pdb
  7. import torch
  8. import torch.nn as nn
  9. import torch.utils.checkpoint
  10. from diffusers.configuration_utils import ConfigMixin, register_to_config
  11. from diffusers.modeling_utils import ModelMixin
  12. from diffusers.utils import BaseOutput, logging
  13. from diffusers.models.embeddings import TimestepEmbedding, Timesteps
  14. from .unet_blocks import (
  15. CrossAttnDownBlock3D,
  16. CrossAttnUpBlock3D,
  17. DownBlock3D,
  18. UNetMidBlock3DCrossAttn,
  19. UpBlock3D,
  20. get_down_block,
  21. get_up_block,
  22. )
  23. from .resnet import InflatedConv3d
  24. logger = logging.get_logger(__name__) # pylint: disable=invalid-name
  25. @dataclass
  26. class UNet3DConditionOutput(BaseOutput):
  27. sample: torch.FloatTensor
  28. class UNet3DConditionModel(ModelMixin, ConfigMixin):
  29. _supports_gradient_checkpointing = True
  30. @register_to_config
  31. def __init__(
  32. self,
  33. sample_size: Optional[int] = None,
  34. in_channels: int = 4,
  35. out_channels: int = 4,
  36. center_input_sample: bool = False,
  37. flip_sin_to_cos: bool = True,
  38. freq_shift: int = 0,
  39. down_block_types: Tuple[str] = (
  40. "CrossAttnDownBlock3D",
  41. "CrossAttnDownBlock3D",
  42. "CrossAttnDownBlock3D",
  43. "DownBlock3D",
  44. ),
  45. mid_block_type: str = "UNetMidBlock3DCrossAttn",
  46. up_block_types: Tuple[str] = (
  47. "UpBlock3D",
  48. "CrossAttnUpBlock3D",
  49. "CrossAttnUpBlock3D",
  50. "CrossAttnUpBlock3D"
  51. ),
  52. only_cross_attention: Union[bool, Tuple[bool]] = False,
  53. block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
  54. layers_per_block: int = 2,
  55. downsample_padding: int = 1,
  56. mid_block_scale_factor: float = 1,
  57. act_fn: str = "silu",
  58. norm_num_groups: int = 32,
  59. norm_eps: float = 1e-5,
  60. cross_attention_dim: int = 1280,
  61. attention_head_dim: Union[int, Tuple[int]] = 8,
  62. dual_cross_attention: bool = False,
  63. use_linear_projection: bool = False,
  64. class_embed_type: Optional[str] = None,
  65. num_class_embeds: Optional[int] = None,
  66. upcast_attention: bool = False,
  67. resnet_time_scale_shift: str = "default",
  68. # Additional
  69. use_motion_module = False,
  70. motion_module_resolutions = ( 1,2,4,8 ),
  71. motion_module_mid_block = False,
  72. motion_module_decoder_only = False,
  73. motion_module_type = None,
  74. motion_module_kwargs = {},
  75. unet_use_cross_frame_attention = None,
  76. unet_use_temporal_attention = None,
  77. ):
  78. super().__init__()
  79. self.sample_size = sample_size
  80. time_embed_dim = block_out_channels[0] * 4
  81. # input
  82. self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
  83. # time
  84. self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
  85. timestep_input_dim = block_out_channels[0]
  86. self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
  87. # class embedding
  88. if class_embed_type is None and num_class_embeds is not None:
  89. self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
  90. elif class_embed_type == "timestep":
  91. self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
  92. elif class_embed_type == "identity":
  93. self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
  94. else:
  95. self.class_embedding = None
  96. self.down_blocks = nn.ModuleList([])
  97. self.mid_block = None
  98. self.up_blocks = nn.ModuleList([])
  99. if isinstance(only_cross_attention, bool):
  100. only_cross_attention = [only_cross_attention] * len(down_block_types)
  101. if isinstance(attention_head_dim, int):
  102. attention_head_dim = (attention_head_dim,) * len(down_block_types)
  103. # down
  104. output_channel = block_out_channels[0]
  105. for i, down_block_type in enumerate(down_block_types):
  106. res = 2 ** i
  107. input_channel = output_channel
  108. output_channel = block_out_channels[i]
  109. is_final_block = i == len(block_out_channels) - 1
  110. down_block = get_down_block(
  111. down_block_type,
  112. num_layers=layers_per_block,
  113. in_channels=input_channel,
  114. out_channels=output_channel,
  115. temb_channels=time_embed_dim,
  116. add_downsample=not is_final_block,
  117. resnet_eps=norm_eps,
  118. resnet_act_fn=act_fn,
  119. resnet_groups=norm_num_groups,
  120. cross_attention_dim=cross_attention_dim,
  121. attn_num_head_channels=attention_head_dim[i],
  122. downsample_padding=downsample_padding,
  123. dual_cross_attention=dual_cross_attention,
  124. use_linear_projection=use_linear_projection,
  125. only_cross_attention=only_cross_attention[i],
  126. upcast_attention=upcast_attention,
  127. resnet_time_scale_shift=resnet_time_scale_shift,
  128. unet_use_cross_frame_attention=unet_use_cross_frame_attention,
  129. unet_use_temporal_attention=unet_use_temporal_attention,
  130. use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
  131. motion_module_type=motion_module_type,
  132. motion_module_kwargs=motion_module_kwargs,
  133. )
  134. self.down_blocks.append(down_block)
  135. # mid
  136. if mid_block_type == "UNetMidBlock3DCrossAttn":
  137. self.mid_block = UNetMidBlock3DCrossAttn(
  138. in_channels=block_out_channels[-1],
  139. temb_channels=time_embed_dim,
  140. resnet_eps=norm_eps,
  141. resnet_act_fn=act_fn,
  142. output_scale_factor=mid_block_scale_factor,
  143. resnet_time_scale_shift=resnet_time_scale_shift,
  144. cross_attention_dim=cross_attention_dim,
  145. attn_num_head_channels=attention_head_dim[-1],
  146. resnet_groups=norm_num_groups,
  147. dual_cross_attention=dual_cross_attention,
  148. use_linear_projection=use_linear_projection,
  149. upcast_attention=upcast_attention,
  150. unet_use_cross_frame_attention=unet_use_cross_frame_attention,
  151. unet_use_temporal_attention=unet_use_temporal_attention,
  152. use_motion_module=use_motion_module and motion_module_mid_block,
  153. motion_module_type=motion_module_type,
  154. motion_module_kwargs=motion_module_kwargs,
  155. )
  156. else:
  157. raise ValueError(f"unknown mid_block_type : {mid_block_type}")
  158. # count how many layers upsample the videos
  159. self.num_upsamplers = 0
  160. # up
  161. reversed_block_out_channels = list(reversed(block_out_channels))
  162. reversed_attention_head_dim = list(reversed(attention_head_dim))
  163. only_cross_attention = list(reversed(only_cross_attention))
  164. output_channel = reversed_block_out_channels[0]
  165. for i, up_block_type in enumerate(up_block_types):
  166. res = 2 ** (3 - i)
  167. is_final_block = i == len(block_out_channels) - 1
  168. prev_output_channel = output_channel
  169. output_channel = reversed_block_out_channels[i]
  170. input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
  171. # add upsample block for all BUT final layer
  172. if not is_final_block:
  173. add_upsample = True
  174. self.num_upsamplers += 1
  175. else:
  176. add_upsample = False
  177. up_block = get_up_block(
  178. up_block_type,
  179. num_layers=layers_per_block + 1,
  180. in_channels=input_channel,
  181. out_channels=output_channel,
  182. prev_output_channel=prev_output_channel,
  183. temb_channels=time_embed_dim,
  184. add_upsample=add_upsample,
  185. resnet_eps=norm_eps,
  186. resnet_act_fn=act_fn,
  187. resnet_groups=norm_num_groups,
  188. cross_attention_dim=cross_attention_dim,
  189. attn_num_head_channels=reversed_attention_head_dim[i],
  190. dual_cross_attention=dual_cross_attention,
  191. use_linear_projection=use_linear_projection,
  192. only_cross_attention=only_cross_attention[i],
  193. upcast_attention=upcast_attention,
  194. resnet_time_scale_shift=resnet_time_scale_shift,
  195. unet_use_cross_frame_attention=unet_use_cross_frame_attention,
  196. unet_use_temporal_attention=unet_use_temporal_attention,
  197. use_motion_module=use_motion_module and (res in motion_module_resolutions),
  198. motion_module_type=motion_module_type,
  199. motion_module_kwargs=motion_module_kwargs,
  200. )
  201. self.up_blocks.append(up_block)
  202. prev_output_channel = output_channel
  203. # out
  204. self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
  205. self.conv_act = nn.SiLU()
  206. self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
  207. def set_attention_slice(self, slice_size):
  208. r"""
  209. Enable sliced attention computation.
  210. When this option is enabled, the attention module will split the input tensor in slices, to compute attention
  211. in several steps. This is useful to save some memory in exchange for a small speed decrease.
  212. Args:
  213. slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
  214. When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
  215. `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
  216. provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
  217. must be a multiple of `slice_size`.
  218. """
  219. sliceable_head_dims = []
  220. def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
  221. if hasattr(module, "set_attention_slice"):
  222. sliceable_head_dims.append(module.sliceable_head_dim)
  223. for child in module.children():
  224. fn_recursive_retrieve_slicable_dims(child)
  225. # retrieve number of attention layers
  226. for module in self.children():
  227. fn_recursive_retrieve_slicable_dims(module)
  228. num_slicable_layers = len(sliceable_head_dims)
  229. if slice_size == "auto":
  230. # half the attention head size is usually a good trade-off between
  231. # speed and memory
  232. slice_size = [dim // 2 for dim in sliceable_head_dims]
  233. elif slice_size == "max":
  234. # make smallest slice possible
  235. slice_size = num_slicable_layers * [1]
  236. slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
  237. if len(slice_size) != len(sliceable_head_dims):
  238. raise ValueError(
  239. f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
  240. f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
  241. )
  242. for i in range(len(slice_size)):
  243. size = slice_size[i]
  244. dim = sliceable_head_dims[i]
  245. if size is not None and size > dim:
  246. raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
  247. # Recursively walk through all the children.
  248. # Any children which exposes the set_attention_slice method
  249. # gets the message
  250. def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
  251. if hasattr(module, "set_attention_slice"):
  252. module.set_attention_slice(slice_size.pop())
  253. for child in module.children():
  254. fn_recursive_set_attention_slice(child, slice_size)
  255. reversed_slice_size = list(reversed(slice_size))
  256. for module in self.children():
  257. fn_recursive_set_attention_slice(module, reversed_slice_size)
  258. def _set_gradient_checkpointing(self, module, value=False):
  259. if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
  260. module.gradient_checkpointing = value
  261. def forward(
  262. self,
  263. sample: torch.FloatTensor,
  264. timestep: Union[torch.Tensor, float, int],
  265. encoder_hidden_states: torch.Tensor,
  266. class_labels: Optional[torch.Tensor] = None,
  267. attention_mask: Optional[torch.Tensor] = None,
  268. return_dict: bool = True,
  269. ) -> Union[UNet3DConditionOutput, Tuple]:
  270. r"""
  271. Args:
  272. sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
  273. timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
  274. encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
  275. return_dict (`bool`, *optional*, defaults to `True`):
  276. Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
  277. Returns:
  278. [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
  279. [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
  280. returning a tuple, the first element is the sample tensor.
  281. """
  282. # By default samples have to be AT least a multiple of the overall upsampling factor.
  283. # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
  284. # However, the upsampling interpolation output size can be forced to fit any upsampling size
  285. # on the fly if necessary.
  286. default_overall_up_factor = 2**self.num_upsamplers
  287. # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
  288. forward_upsample_size = False
  289. upsample_size = None
  290. if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
  291. logger.info("Forward upsample size to force interpolation output size.")
  292. forward_upsample_size = True
  293. # prepare attention_mask
  294. if attention_mask is not None:
  295. attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
  296. attention_mask = attention_mask.unsqueeze(1)
  297. # center input if necessary
  298. if self.config.center_input_sample:
  299. sample = 2 * sample - 1.0
  300. # time
  301. timesteps = timestep
  302. if not torch.is_tensor(timesteps):
  303. # This would be a good case for the `match` statement (Python 3.10+)
  304. is_mps = sample.device.type == "mps"
  305. if isinstance(timestep, float):
  306. dtype = torch.float32 if is_mps else torch.float64
  307. else:
  308. dtype = torch.int32 if is_mps else torch.int64
  309. timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
  310. elif len(timesteps.shape) == 0:
  311. timesteps = timesteps[None].to(sample.device)
  312. # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
  313. timesteps = timesteps.expand(sample.shape[0])
  314. t_emb = self.time_proj(timesteps)
  315. # timesteps does not contain any weights and will always return f32 tensors
  316. # but time_embedding might actually be running in fp16. so we need to cast here.
  317. # there might be better ways to encapsulate this.
  318. t_emb = t_emb.to(dtype=self.dtype)
  319. emb = self.time_embedding(t_emb)
  320. if self.class_embedding is not None:
  321. if class_labels is None:
  322. raise ValueError("class_labels should be provided when num_class_embeds > 0")
  323. if self.config.class_embed_type == "timestep":
  324. class_labels = self.time_proj(class_labels)
  325. class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
  326. emb = emb + class_emb
  327. # pre-process
  328. sample = self.conv_in(sample)
  329. # down
  330. down_block_res_samples = (sample,)
  331. for downsample_block in self.down_blocks:
  332. if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
  333. sample, res_samples = downsample_block(
  334. hidden_states=sample,
  335. temb=emb,
  336. encoder_hidden_states=encoder_hidden_states,
  337. attention_mask=attention_mask,
  338. )
  339. else:
  340. sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
  341. down_block_res_samples += res_samples
  342. # mid
  343. sample = self.mid_block(
  344. sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
  345. )
  346. # up
  347. for i, upsample_block in enumerate(self.up_blocks):
  348. is_final_block = i == len(self.up_blocks) - 1
  349. res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
  350. down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
  351. # if we have not reached the final block and need to forward the
  352. # upsample size, we do it here
  353. if not is_final_block and forward_upsample_size:
  354. upsample_size = down_block_res_samples[-1].shape[2:]
  355. if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
  356. sample = upsample_block(
  357. hidden_states=sample,
  358. temb=emb,
  359. res_hidden_states_tuple=res_samples,
  360. encoder_hidden_states=encoder_hidden_states,
  361. upsample_size=upsample_size,
  362. attention_mask=attention_mask,
  363. )
  364. else:
  365. sample = upsample_block(
  366. hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
  367. )
  368. # post-process
  369. sample = self.conv_norm_out(sample)
  370. sample = self.conv_act(sample)
  371. sample = self.conv_out(sample)
  372. if not return_dict:
  373. return (sample,)
  374. return UNet3DConditionOutput(sample=sample)
  375. @classmethod
  376. def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
  377. if subfolder is not None:
  378. pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
  379. print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...")
  380. config_file = os.path.join(pretrained_model_path, 'config.json')
  381. if not os.path.isfile(config_file):
  382. raise RuntimeError(f"{config_file} does not exist")
  383. with open(config_file, "r") as f:
  384. config = json.load(f)
  385. config["_class_name"] = cls.__name__
  386. config["down_block_types"] = [
  387. "CrossAttnDownBlock3D",
  388. "CrossAttnDownBlock3D",
  389. "CrossAttnDownBlock3D",
  390. "DownBlock3D"
  391. ]
  392. config["up_block_types"] = [
  393. "UpBlock3D",
  394. "CrossAttnUpBlock3D",
  395. "CrossAttnUpBlock3D",
  396. "CrossAttnUpBlock3D"
  397. ]
  398. from diffusers.utils import WEIGHTS_NAME
  399. model = cls.from_config(config, **unet_additional_kwargs)
  400. model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
  401. if not os.path.isfile(model_file):
  402. raise RuntimeError(f"{model_file} does not exist")
  403. state_dict = torch.load(model_file, map_location="cpu")
  404. m, u = model.load_state_dict(state_dict, strict=False)
  405. print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
  406. # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
  407. params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
  408. print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
  409. return model