12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760 |
- # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
- import torch
- from torch import nn
- from .attention import Transformer3DModel
- from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
- from .motion_module import get_motion_module
- import pdb
- def get_down_block(
- down_block_type,
- num_layers,
- in_channels,
- out_channels,
- temb_channels,
- add_downsample,
- resnet_eps,
- resnet_act_fn,
- attn_num_head_channels,
- resnet_groups=None,
- cross_attention_dim=None,
- downsample_padding=None,
- dual_cross_attention=False,
- use_linear_projection=False,
- only_cross_attention=False,
- upcast_attention=False,
- resnet_time_scale_shift="default",
-
- unet_use_cross_frame_attention=None,
- unet_use_temporal_attention=None,
- use_inflated_groupnorm=None,
- use_motion_module=None,
-
- motion_module_type=None,
- motion_module_kwargs=None,
- ):
- down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
- if down_block_type == "DownBlock3D":
- return DownBlock3D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- downsample_padding=downsample_padding,
- resnet_time_scale_shift=resnet_time_scale_shift,
- use_inflated_groupnorm=use_inflated_groupnorm,
- use_motion_module=use_motion_module,
- motion_module_type=motion_module_type,
- motion_module_kwargs=motion_module_kwargs,
- )
- elif down_block_type == "CrossAttnDownBlock3D":
- if cross_attention_dim is None:
- raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
- return CrossAttnDownBlock3D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- downsample_padding=downsample_padding,
- cross_attention_dim=cross_attention_dim,
- attn_num_head_channels=attn_num_head_channels,
- dual_cross_attention=dual_cross_attention,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- resnet_time_scale_shift=resnet_time_scale_shift,
- unet_use_cross_frame_attention=unet_use_cross_frame_attention,
- unet_use_temporal_attention=unet_use_temporal_attention,
- use_inflated_groupnorm=use_inflated_groupnorm,
-
- use_motion_module=use_motion_module,
- motion_module_type=motion_module_type,
- motion_module_kwargs=motion_module_kwargs,
- )
- raise ValueError(f"{down_block_type} does not exist.")
- def get_up_block(
- up_block_type,
- num_layers,
- in_channels,
- out_channels,
- prev_output_channel,
- temb_channels,
- add_upsample,
- resnet_eps,
- resnet_act_fn,
- attn_num_head_channels,
- resnet_groups=None,
- cross_attention_dim=None,
- dual_cross_attention=False,
- use_linear_projection=False,
- only_cross_attention=False,
- upcast_attention=False,
- resnet_time_scale_shift="default",
- unet_use_cross_frame_attention=None,
- unet_use_temporal_attention=None,
- use_inflated_groupnorm=None,
-
- use_motion_module=None,
- motion_module_type=None,
- motion_module_kwargs=None,
- ):
- up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
- if up_block_type == "UpBlock3D":
- return UpBlock3D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- resnet_time_scale_shift=resnet_time_scale_shift,
- use_inflated_groupnorm=use_inflated_groupnorm,
- use_motion_module=use_motion_module,
- motion_module_type=motion_module_type,
- motion_module_kwargs=motion_module_kwargs,
- )
- elif up_block_type == "CrossAttnUpBlock3D":
- if cross_attention_dim is None:
- raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
- return CrossAttnUpBlock3D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- cross_attention_dim=cross_attention_dim,
- attn_num_head_channels=attn_num_head_channels,
- dual_cross_attention=dual_cross_attention,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- resnet_time_scale_shift=resnet_time_scale_shift,
- unet_use_cross_frame_attention=unet_use_cross_frame_attention,
- unet_use_temporal_attention=unet_use_temporal_attention,
- use_inflated_groupnorm=use_inflated_groupnorm,
- use_motion_module=use_motion_module,
- motion_module_type=motion_module_type,
- motion_module_kwargs=motion_module_kwargs,
- )
- raise ValueError(f"{up_block_type} does not exist.")
class UNetMidBlock3DCrossAttn(nn.Module):
    """Middle block of the 3D UNet.

    Layout: one leading resnet, then ``num_layers`` repetitions of
    (cross-attention transformer -> optional motion module -> resnet).
    The channel count stays at ``in_channels`` throughout; there are
    ``num_layers + 1`` resnets and ``num_layers`` attention / motion slots.
    Motion modules are only created when ``use_motion_module`` is truthy;
    otherwise ``None`` placeholders keep the three lists in lockstep.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels
        # Fall back to a group count that divides typical channel widths.
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # there is always at least one resnet
        resnets = [
            ResnetBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                use_inflated_groupnorm=use_inflated_groupnorm,
            )
        ]
        attentions = []
        motion_modules = []

        for _ in range(num_layers):
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                # First two positional args: num heads, head dim.
                Transformer3DModel(
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=in_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                ) if use_motion_module else None
            )
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
        """Run the block: leading resnet, then attn -> motion -> resnet per layer.

        NOTE(review): ``attention_mask`` is accepted but never used here.
        """
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
            # Motion module is an identity when disabled (None placeholder).
            hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
            hidden_states = resnet(hidden_states, temb)

        return hidden_states
class CrossAttnDownBlock3D(nn.Module):
    """Encoder block with cross attention.

    Layout: ``num_layers`` repetitions of (resnet -> cross-attention
    transformer -> optional motion module), followed by an optional
    :class:`Downsample3D`. The first resnet maps ``in_channels`` to
    ``out_channels``; all later sub-modules operate at ``out_channels``.
    Motion-module slots hold ``None`` when ``use_motion_module`` is falsy.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        attentions = []
        motion_modules = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            # Only the first resnet changes the channel count (in -> out).
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                # First two positional args: num heads, head dim.
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                ) if use_motion_module else None
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        # Toggled externally; enables torch.utils.checkpoint in forward.
        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
        """Return ``(hidden_states, output_states)``.

        ``output_states`` collects the output after every
        (resnet, attn, motion) triple and after the downsampler, for use
        as skip connections by the decoder.
        NOTE(review): ``attention_mask`` is accepted but never used here.
        """
        output_states = ()

        for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
            if self.training and self.gradient_checkpointing:

                # Wrapper so checkpoint can call modules positionally;
                # return_dict=False makes attn return a tuple instead of
                # a dataclass-like output.
                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
                if motion_module is not None:
                    # requires_grad_() ensures the checkpoint input requires
                    # grad so the recomputation participates in backprop.
                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)

            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample

                # add motion module
                hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
class DownBlock3D(nn.Module):
    """Encoder block without cross attention.

    Layout: ``num_layers`` repetitions of (resnet -> optional motion
    module), followed by an optional :class:`Downsample3D`. The first
    resnet maps ``in_channels`` to ``out_channels``; motion-module slots
    hold ``None`` when ``use_motion_module`` is falsy.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        motion_modules = []

        for i in range(num_layers):
            # Only the first resnet changes the channel count (in -> out).
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                ) if use_motion_module else None
            )

        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        # Toggled externally; enables torch.utils.checkpoint in forward.
        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        """Return ``(hidden_states, output_states)``.

        ``output_states`` collects the output after every (resnet, motion)
        pair and after the downsampler, for use as skip connections.
        """
        output_states = ()

        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            if self.training and self.gradient_checkpointing:

                # Wrapper so checkpoint can call modules positionally.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                if motion_module is not None:
                    # requires_grad_() ensures the checkpoint input requires
                    # grad so the recomputation participates in backprop.
                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
            else:
                hidden_states = resnet(hidden_states, temb)

                # add motion module
                hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
class CrossAttnUpBlock3D(nn.Module):
    """Decoder block with cross attention.

    Layout: ``num_layers`` repetitions of (concat skip connection ->
    resnet -> cross-attention transformer -> optional motion module),
    followed by an optional :class:`Upsample3D`. Each resnet's input is
    the running hidden state concatenated with one popped skip tensor,
    hence the ``resnet_in_channels + res_skip_channels`` input width.
    Motion-module slots hold ``None`` when ``use_motion_module`` is falsy.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        attentions = []
        motion_modules = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            # Skip widths mirror the encoder: the last layer receives the
            # block-input skip, earlier layers receive out_channels skips.
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                # First two positional args: num heads, head dim.
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                ) if use_motion_module else None
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        # Toggled externally; enables torch.utils.checkpoint in forward.
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        attention_mask=None,
    ):
        """Consume skip tensors from the END of ``res_hidden_states_tuple``
        (one per layer), run resnet/attn/motion, then optionally upsample.

        NOTE(review): ``attention_mask`` is accepted but never used here.
        """
        for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            # Concatenate skip connection along the channel dimension.
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                # Wrapper so checkpoint can call modules positionally;
                # return_dict=False makes attn return a tuple instead of
                # a dataclass-like output.
                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
                if motion_module is not None:
                    # requires_grad_() ensures the checkpoint input requires
                    # grad so the recomputation participates in backprop.
                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)

            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample

                # add motion module
                hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
class UpBlock3D(nn.Module):
    """Decoder block without cross attention.

    Layout: ``num_layers`` repetitions of (concat skip connection ->
    resnet -> optional motion module), followed by an optional
    :class:`Upsample3D`. Each resnet's input is the running hidden state
    concatenated with one popped skip tensor. Motion-module slots hold
    ``None`` when ``use_motion_module`` is falsy.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        motion_modules = []

        for i in range(num_layers):
            # Skip widths mirror the encoder: the last layer receives the
            # block-input skip, earlier layers receive out_channels skips.
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                ) if use_motion_module else None
            )

        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        # Toggled externally; enables torch.utils.checkpoint in forward.
        self.gradient_checkpointing = False

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,):
        """Consume skip tensors from the END of ``res_hidden_states_tuple``
        (one per layer), run resnet/motion, then optionally upsample.
        """
        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            # Concatenate skip connection along the channel dimension.
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                # Wrapper so checkpoint can call modules positionally.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                if motion_module is not None:
                    # requires_grad_() ensures the checkpoint input requires
                    # grad so the recomputation participates in backprop.
                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
|