|
|
@@ -30,7 +30,8 @@ def get_down_block(
|
|
|
|
|
|
unet_use_cross_frame_attention=None,
|
|
|
unet_use_temporal_attention=None,
|
|
|
-
|
|
|
+ use_inflated_groupnorm=None,
|
|
|
+
|
|
|
use_motion_module=None,
|
|
|
|
|
|
motion_module_type=None,
|
|
|
@@ -50,6 +51,8 @@ def get_down_block(
|
|
|
downsample_padding=downsample_padding,
|
|
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
|
|
|
|
|
+ use_inflated_groupnorm=use_inflated_groupnorm,
|
|
|
+
|
|
|
use_motion_module=use_motion_module,
|
|
|
motion_module_type=motion_module_type,
|
|
|
motion_module_kwargs=motion_module_kwargs,
|
|
|
@@ -77,6 +80,7 @@ def get_down_block(
|
|
|
|
|
|
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
|
|
unet_use_temporal_attention=unet_use_temporal_attention,
|
|
|
+ use_inflated_groupnorm=use_inflated_groupnorm,
|
|
|
|
|
|
use_motion_module=use_motion_module,
|
|
|
motion_module_type=motion_module_type,
|
|
|
@@ -106,6 +110,7 @@ def get_up_block(
|
|
|
|
|
|
unet_use_cross_frame_attention=None,
|
|
|
unet_use_temporal_attention=None,
|
|
|
+ use_inflated_groupnorm=None,
|
|
|
|
|
|
use_motion_module=None,
|
|
|
motion_module_type=None,
|
|
|
@@ -125,6 +130,8 @@ def get_up_block(
|
|
|
resnet_groups=resnet_groups,
|
|
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
|
|
|
|
|
+ use_inflated_groupnorm=use_inflated_groupnorm,
|
|
|
+
|
|
|
use_motion_module=use_motion_module,
|
|
|
motion_module_type=motion_module_type,
|
|
|
motion_module_kwargs=motion_module_kwargs,
|
|
|
@@ -152,6 +159,7 @@ def get_up_block(
|
|
|
|
|
|
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
|
|
unet_use_temporal_attention=unet_use_temporal_attention,
|
|
|
+ use_inflated_groupnorm=use_inflated_groupnorm,
|
|
|
|
|
|
use_motion_module=use_motion_module,
|
|
|
motion_module_type=motion_module_type,
|
|
|
@@ -181,6 +189,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
|
|
|
|
|
|
unet_use_cross_frame_attention=None,
|
|
|
unet_use_temporal_attention=None,
|
|
|
+ use_inflated_groupnorm=None,
|
|
|
|
|
|
use_motion_module=None,
|
|
|
|
|
|
@@ -206,6 +215,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
|
|
|
non_linearity=resnet_act_fn,
|
|
|
output_scale_factor=output_scale_factor,
|
|
|
pre_norm=resnet_pre_norm,
|
|
|
+
|
|
|
+ use_inflated_groupnorm=use_inflated_groupnorm,
|
|
|
)
|
|
|
]
|
|
|
attentions = []
|
|
|
@@ -248,6 +259,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
|
|
|
non_linearity=resnet_act_fn,
|
|
|
output_scale_factor=output_scale_factor,
|
|
|
pre_norm=resnet_pre_norm,
|
|
|
+
|
|
|
+ use_inflated_groupnorm=use_inflated_groupnorm,
|
|
|
)
|
|
|
)
|
|
|
|
|
|
@@ -290,6 +303,7 @@ class CrossAttnDownBlock3D(nn.Module):
|
|
|
|
|
|
unet_use_cross_frame_attention=None,
|
|
|
unet_use_temporal_attention=None,
|
|
|
+ use_inflated_groupnorm=None,
|
|
|
|
|
|
use_motion_module=None,
|
|
|
|
|
|
@@ -318,6 +332,8 @@ class CrossAttnDownBlock3D(nn.Module):
|
|
|
non_linearity=resnet_act_fn,
|
|
|
output_scale_factor=output_scale_factor,
|
|
|
pre_norm=resnet_pre_norm,
|
|
|
+
|
|
|
+ use_inflated_groupnorm=use_inflated_groupnorm,
|
|
|
)
|
|
|
)
|
|
|
if dual_cross_attention:
|
|
|
@@ -421,6 +437,8 @@ class DownBlock3D(nn.Module):
|
|
|
output_scale_factor=1.0,
|
|
|
add_downsample=True,
|
|
|
downsample_padding=1,
|
|
|
+
|
|
|
+ use_inflated_groupnorm=None,
|
|
|
|
|
|
use_motion_module=None,
|
|
|
motion_module_type=None,
|
|
|
@@ -444,6 +462,8 @@ class DownBlock3D(nn.Module):
|
|
|
non_linearity=resnet_act_fn,
|
|
|
output_scale_factor=output_scale_factor,
|
|
|
pre_norm=resnet_pre_norm,
|
|
|
+
|
|
|
+ use_inflated_groupnorm=use_inflated_groupnorm,
|
|
|
)
|
|
|
)
|
|
|
motion_modules.append(
|
|
|
@@ -526,6 +546,7 @@ class CrossAttnUpBlock3D(nn.Module):
|
|
|
|
|
|
unet_use_cross_frame_attention=None,
|
|
|
unet_use_temporal_attention=None,
|
|
|
+ use_inflated_groupnorm=None,
|
|
|
|
|
|
use_motion_module=None,
|
|
|
|
|
|
@@ -556,6 +577,8 @@ class CrossAttnUpBlock3D(nn.Module):
|
|
|
non_linearity=resnet_act_fn,
|
|
|
output_scale_factor=output_scale_factor,
|
|
|
pre_norm=resnet_pre_norm,
|
|
|
+
|
|
|
+ use_inflated_groupnorm=use_inflated_groupnorm,
|
|
|
)
|
|
|
)
|
|
|
if dual_cross_attention:
|
|
|
@@ -661,6 +684,8 @@ class UpBlock3D(nn.Module):
|
|
|
output_scale_factor=1.0,
|
|
|
add_upsample=True,
|
|
|
|
|
|
+ use_inflated_groupnorm=None,
|
|
|
+
|
|
|
use_motion_module=None,
|
|
|
motion_module_type=None,
|
|
|
motion_module_kwargs=None,
|
|
|
@@ -685,6 +710,8 @@ class UpBlock3D(nn.Module):
|
|
|
non_linearity=resnet_act_fn,
|
|
|
output_scale_factor=output_scale_factor,
|
|
|
pre_norm=resnet_pre_norm,
|
|
|
+
|
|
|
+ use_inflated_groupnorm=use_inflated_groupnorm,
|
|
|
)
|
|
|
)
|
|
|
motion_modules.append(
|