Bladeren bron

support v2

Yuwei Guo 2 jaren geleden
bovenliggende
commit
108921965d
3 gewijzigde bestanden met toevoegingen van 61 en 6 verwijderingen
  1. 22 2
      animatediff/models/resnet.py
  2. 11 3
      animatediff/models/unet.py
  3. 28 1
      animatediff/models/unet_blocks.py

+ 22 - 2
animatediff/models/resnet.py

@@ -18,6 +18,17 @@ class InflatedConv3d(nn.Conv2d):
         return x
 
 
+class InflatedGroupNorm(nn.GroupNorm):
+    def forward(self, x):
+        video_length = x.shape[2]
+
+        x = rearrange(x, "b c f h w -> (b f) c h w")
+        x = super().forward(x)
+        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
+
+        return x
+
+
 class Upsample3D(nn.Module):
     def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
         super().__init__()
@@ -112,6 +123,7 @@ class ResnetBlock3D(nn.Module):
         time_embedding_norm="default",
         output_scale_factor=1.0,
         use_in_shortcut=None,
+        use_inflated_groupnorm=None,
     ):
         super().__init__()
         self.pre_norm = pre_norm
@@ -126,7 +138,11 @@ class ResnetBlock3D(nn.Module):
         if groups_out is None:
             groups_out = groups
 
-        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+        assert use_inflated_groupnorm != None
+        if use_inflated_groupnorm:
+            self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+        else:
+            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
 
         self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
 
@@ -142,7 +158,11 @@ class ResnetBlock3D(nn.Module):
         else:
             self.time_emb_proj = None
 
-        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+        if use_inflated_groupnorm:
+            self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+        else:
+            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+
         self.dropout = torch.nn.Dropout(dropout)
         self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
 

+ 11 - 3
animatediff/models/unet.py

@@ -24,7 +24,7 @@ from .unet_blocks import (
     get_down_block,
     get_up_block,
 )
-from .resnet import InflatedConv3d
+from .resnet import InflatedConv3d, InflatedGroupNorm
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -77,6 +77,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
         
+        use_inflated_groupnorm=False,
+        
         # Additional
         use_motion_module              = False,
         motion_module_resolutions      = ( 1,2,4,8 ),
@@ -88,7 +90,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
         unet_use_temporal_attention    = None,
     ):
         super().__init__()
-
+        
         self.sample_size = sample_size
         time_embed_dim = block_out_channels[0] * 4
 
@@ -150,6 +152,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
 
                 unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                 unet_use_temporal_attention=unet_use_temporal_attention,
+                use_inflated_groupnorm=use_inflated_groupnorm,
                 
                 use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
                 motion_module_type=motion_module_type,
@@ -175,6 +178,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
 
                 unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                 unet_use_temporal_attention=unet_use_temporal_attention,
+                use_inflated_groupnorm=use_inflated_groupnorm,
                 
                 use_motion_module=use_motion_module and motion_module_mid_block,
                 motion_module_type=motion_module_type,
@@ -227,6 +231,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
 
                 unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                 unet_use_temporal_attention=unet_use_temporal_attention,
+                use_inflated_groupnorm=use_inflated_groupnorm,
 
                 use_motion_module=use_motion_module and (res in motion_module_resolutions),
                 motion_module_type=motion_module_type,
@@ -236,7 +241,10 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
             prev_output_channel = output_channel
 
         # out
-        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+        if use_inflated_groupnorm:
+            self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+        else:
+            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
         self.conv_act = nn.SiLU()
         self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
 

+ 28 - 1
animatediff/models/unet_blocks.py

@@ -30,7 +30,8 @@ def get_down_block(
     
     unet_use_cross_frame_attention=None,
     unet_use_temporal_attention=None,
-    
+    use_inflated_groupnorm=None,
+
     use_motion_module=None,
     
     motion_module_type=None,
@@ -50,6 +51,8 @@ def get_down_block(
             downsample_padding=downsample_padding,
             resnet_time_scale_shift=resnet_time_scale_shift,
 
+            use_inflated_groupnorm=use_inflated_groupnorm,
+
             use_motion_module=use_motion_module,
             motion_module_type=motion_module_type,
             motion_module_kwargs=motion_module_kwargs,
@@ -77,6 +80,7 @@ def get_down_block(
 
             unet_use_cross_frame_attention=unet_use_cross_frame_attention,
             unet_use_temporal_attention=unet_use_temporal_attention,
+            use_inflated_groupnorm=use_inflated_groupnorm,
             
             use_motion_module=use_motion_module,
             motion_module_type=motion_module_type,
@@ -106,6 +110,7 @@ def get_up_block(
 
     unet_use_cross_frame_attention=None,
     unet_use_temporal_attention=None,
+    use_inflated_groupnorm=None,
     
     use_motion_module=None,
     motion_module_type=None,
@@ -125,6 +130,8 @@ def get_up_block(
             resnet_groups=resnet_groups,
             resnet_time_scale_shift=resnet_time_scale_shift,
 
+            use_inflated_groupnorm=use_inflated_groupnorm,
+
             use_motion_module=use_motion_module,
             motion_module_type=motion_module_type,
             motion_module_kwargs=motion_module_kwargs,
@@ -152,6 +159,7 @@ def get_up_block(
 
             unet_use_cross_frame_attention=unet_use_cross_frame_attention,
             unet_use_temporal_attention=unet_use_temporal_attention,
+            use_inflated_groupnorm=use_inflated_groupnorm,
 
             use_motion_module=use_motion_module,
             motion_module_type=motion_module_type,
@@ -181,6 +189,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
 
         unet_use_cross_frame_attention=None,
         unet_use_temporal_attention=None,
+        use_inflated_groupnorm=None,
 
         use_motion_module=None,
         
@@ -206,6 +215,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
                 non_linearity=resnet_act_fn,
                 output_scale_factor=output_scale_factor,
                 pre_norm=resnet_pre_norm,
+
+                use_inflated_groupnorm=use_inflated_groupnorm,
             )
         ]
         attentions = []
@@ -248,6 +259,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+
+                    use_inflated_groupnorm=use_inflated_groupnorm,
                 )
             )
 
@@ -290,6 +303,7 @@ class CrossAttnDownBlock3D(nn.Module):
 
         unet_use_cross_frame_attention=None,
         unet_use_temporal_attention=None,
+        use_inflated_groupnorm=None,
         
         use_motion_module=None,
 
@@ -318,6 +332,8 @@ class CrossAttnDownBlock3D(nn.Module):
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+
+                    use_inflated_groupnorm=use_inflated_groupnorm,
                 )
             )
             if dual_cross_attention:
@@ -421,6 +437,8 @@ class DownBlock3D(nn.Module):
         output_scale_factor=1.0,
         add_downsample=True,
         downsample_padding=1,
+
+        use_inflated_groupnorm=None,
         
         use_motion_module=None,
         motion_module_type=None,
@@ -444,6 +462,8 @@ class DownBlock3D(nn.Module):
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+
+                    use_inflated_groupnorm=use_inflated_groupnorm,
                 )
             )
             motion_modules.append(
@@ -526,6 +546,7 @@ class CrossAttnUpBlock3D(nn.Module):
 
         unet_use_cross_frame_attention=None,
         unet_use_temporal_attention=None,
+        use_inflated_groupnorm=None,
         
         use_motion_module=None,
 
@@ -556,6 +577,8 @@ class CrossAttnUpBlock3D(nn.Module):
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+
+                    use_inflated_groupnorm=use_inflated_groupnorm,
                 )
             )
             if dual_cross_attention:
@@ -661,6 +684,8 @@ class UpBlock3D(nn.Module):
         output_scale_factor=1.0,
         add_upsample=True,
 
+        use_inflated_groupnorm=None,
+
         use_motion_module=None,
         motion_module_type=None,
         motion_module_kwargs=None,
@@ -685,6 +710,8 @@ class UpBlock3D(nn.Module):
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+
+                    use_inflated_groupnorm=use_inflated_groupnorm,
                 )
             )
             motion_modules.append(