# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import os
import json

import torch
import torch.nn as nn
import torch.utils.checkpoint

from diffusers.configuration_utils import ConfigMixin, register_to_config
# Note: ModelMixin lives at `diffusers.modeling_utils` in older diffusers releases;
# newer releases moved it to `diffusers.models.modeling_utils`.
from diffusers.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput, logging
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

from .unet_blocks import (
    CrossAttnDownBlock3D,
    CrossAttnUpBlock3D,
    DownBlock3D,
    UNetMidBlock3DCrossAttn,
    UpBlock3D,
    get_down_block,
    get_up_block,
)
from .resnet import InflatedConv3d, InflatedGroupNorm


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class UNet3DConditionOutput(BaseOutput):
    """
    Output of `UNet3DConditionModel.forward`.

    Args:
        sample (`torch.FloatTensor` of shape `(batch, channel, frame, height, width)`): the UNet output tensor.
    """

    sample: torch.FloatTensor


class UNet3DConditionModel(ModelMixin, ConfigMixin):
    """
    A conditional 3D UNet adapted from the diffusers 2D conditional UNet, with inflated spatial layers and
    optional temporal motion modules inserted into the down, mid, and up blocks.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        mid_block_type: str = "UNetMidBlock3DCrossAttn",
        up_block_types: Tuple[str] = (
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",

        use_inflated_groupnorm=False,

        # Additional motion-module options
        use_motion_module=False,
        motion_module_resolutions=(1, 2, 4, 8),
        motion_module_mid_block=False,
        motion_module_decoder_only=False,
        motion_module_type=None,
        motion_module_kwargs={},
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
    ):
        super().__init__()

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4

        # input
        self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))

        # time
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]
        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            res = 2 ** i
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,

                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,
                use_inflated_groupnorm=use_inflated_groupnorm,

                use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
            self.down_blocks.append(down_block)
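
        # (When `motion_module_decoder_only` is True, the condition above keeps the encoder purely
        # spatial and temporal motion modules are inserted only in the up blocks below.)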

        # mid
        if mid_block_type == "UNetMidBlock3DCrossAttn":
            self.mid_block = UNetMidBlock3DCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,

                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,
                use_inflated_groupnorm=use_inflated_groupnorm,

                use_motion_module=use_motion_module and motion_module_mid_block,
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
        else:
            raise ValueError(f"unknown mid_block_type: {mid_block_type}")

        # count how many layers upsample the videos
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            res = 2 ** (3 - i)
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=reversed_attention_head_dim[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,

                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,
                use_inflated_groupnorm=use_inflated_groupnorm,

                use_motion_module=use_motion_module and (res in motion_module_resolutions),
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if use_inflated_groupnorm:
            self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
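
    # `InflatedConv3d` / `InflatedGroupNorm` (from .resnet) are assumed to apply their 2D counterparts
    # frame-by-frame on (batch, channel, frame, height, width) tensors, which is what allows the
    # pretrained 2D weights loaded in `from_pretrained_2d` below to be reused without any reshaping
    # of the state dict.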

    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor into slices to compute attention
        in several steps. This is useful for saving some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, the maximum amount of memory will be saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_slicable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_slicable_dims(module)

        num_slicable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_slicable_layers * [1]

        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any child which exposes the set_attention_slice method
        # gets the message.
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)
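
    # A minimal usage sketch, assuming `unet` is an instantiated UNet3DConditionModel:
    #     unet.set_attention_slice("auto")  # halve every sliceable head dim: attention in two steps
    #     unet.set_attention_slice("max")   # one slice at a time, lowest memory
    #     unet.set_attention_slice(2)       # the same fixed slice size for every attention layer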

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
            module.gradient_checkpointing = value
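
    # Gradient checkpointing is normally toggled through the ModelMixin helpers rather than by
    # calling this hook directly (assuming a diffusers version that provides them, which
    # `_supports_gradient_checkpointing = True` advertises):
    #     unet.enable_gradient_checkpointing()
    #     unet.disable_gradient_checkpointing()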

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet3DConditionOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, channel, frame, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.

        Returns:
            [`UNet3DConditionOutput`] or `tuple`:
            [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        # By default, samples have to be at least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # `self.time_proj` does not contain any weights and always returns f32 tensors,
        # but the time embedding might be running in fp16, so we need to cast here.
        # There might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # pre-process
        sample = self.conv_in(sample)

        # down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)

            down_block_res_samples += res_samples

        # mid
        sample = self.mid_block(
            sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
        )

        # up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    encoder_hidden_states=encoder_hidden_states,
                )

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet3DConditionOutput(sample=sample)

    @classmethod
    def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
        if subfolder is not None:
            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
        print(f"loading temporal unet's pretrained weights from {pretrained_model_path} ...")

        config_file = os.path.join(pretrained_model_path, "config.json")
        if not os.path.isfile(config_file):
            raise RuntimeError(f"{config_file} does not exist")
        with open(config_file, "r") as f:
            config = json.load(f)
        config["_class_name"] = cls.__name__
        config["down_block_types"] = [
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ]
        config["up_block_types"] = [
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
        ]

        from diffusers.utils import WEIGHTS_NAME

        if unet_additional_kwargs is None:
            unet_additional_kwargs = {}
        model = cls.from_config(config, **unet_additional_kwargs)
        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
        if not os.path.isfile(model_file):
            raise RuntimeError(f"{model_file} does not exist")
        state_dict = torch.load(model_file, map_location="cpu")

        # strict=False: the 2D checkpoint has no weights for the newly added temporal layers,
        # so those keys are reported as missing and keep their fresh initialization.
        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")

        params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
        print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")

        return model
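

# A minimal usage sketch, assuming a local Stable Diffusion checkpoint laid out the diffusers way
# (a "unet" subfolder containing config.json and the 2D UNet weights). The path, the frame count,
# and the (1, 77, 768) text-embedding shape are illustrative placeholders; the embedding width must
# match the checkpoint's `cross_attention_dim`.
if __name__ == "__main__":
    unet = UNet3DConditionModel.from_pretrained_2d(
        "path/to/stable-diffusion-v1-5",  # hypothetical local checkpoint directory
        subfolder="unet",
    )
    unet.eval()

    with torch.no_grad():
        latents = torch.randn(1, 4, 8, 32, 32)  # (batch, channel, frame, height, width)
        text_emb = torch.randn(1, 77, 768)      # (batch, sequence_length, feature_dim)
        pred = unet(latents, timestep=10, encoder_hidden_states=text_emb).sample
    print(pred.shape)  # same shape as the input latents: (1, 4, 8, 32, 32)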