Yuwei Guo 2 년 전
부모
커밋
41a698ae8e
1개의 변경된 파일29개의 추가작업 그리고 28개의 파일을 삭제
  1. 29 28
      animatediff/models/motion_module.py

+ 29 - 28
animatediff/models/motion_module.py

@@ -51,14 +51,14 @@ class VanillaTemporalModule(nn.Module):
     def __init__(
         self,
         in_channels,
-        num_attention_heads                 = 8,
-        num_transformer_block               = 2,
-        attention_block_types               =( "Temporal_Self", "Temporal_Self" ),
-        cross_frame_attention_mode          = None,
-        temporal_position_encoding          = False,
-        temporal_position_encoding_max_len  = 24,
-        temporal_attention_dim_div          = 1,
-        zero_initialize                     = True,
+        num_attention_heads                = 8,
+        num_transformer_block              = 2,
+        attention_block_types              =( "Temporal_Self", "Temporal_Self" ),
+        cross_frame_attention_mode         = None,
+        temporal_position_encoding         = False,
+        temporal_position_encoding_max_len = 24,
+        temporal_attention_dim_div         = 1,
+        zero_initialize                    = True,
     ):
         super().__init__()
         
@@ -92,20 +92,17 @@ class TemporalTransformer3DModel(nn.Module):
         attention_head_dim,
 
         num_layers,
-        attention_block_types=(
-            "Temporal_Self", 
-            "Temporal_Self",
-        ),        
-        dropout=0.0,
-        norm_num_groups=32,
-        cross_attention_dim=768,
-        activation_fn="geglu",
-        attention_bias=False,
-        upcast_attention=False,
-
-        cross_frame_attention_mode=None,
-        temporal_position_encoding=False,
-        temporal_position_encoding_max_len=24,
+        attention_block_types              = ( "Temporal_Self", "Temporal_Self", ),        
+        dropout                            = 0.0,
+        norm_num_groups                    = 32,
+        cross_attention_dim                = 768,
+        activation_fn                      = "geglu",
+        attention_bias                     = False,
+        upcast_attention                   = False,
+        
+        cross_frame_attention_mode         = None,
+        temporal_position_encoding         = False,
+        temporal_position_encoding_max_len = 24,
     ):
         super().__init__()
 
@@ -228,10 +225,14 @@ class TemporalTransformerBlock(nn.Module):
 
 
 class PositionalEncoding(nn.Module):
-    def __init__(self, d_model: int, dropout: float = 0., max_len: int = 24):
+    def __init__(
+        self, 
+        d_model, 
+        dropout = 0., 
+        max_len = 24
+    ):
         super().__init__()
         self.dropout = nn.Dropout(p=dropout)
-        # print(f"d_model: {d_model}")
         position = torch.arange(max_len).unsqueeze(1)
         div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
         pe = torch.zeros(1, max_len, d_model)
@@ -247,10 +248,10 @@ class PositionalEncoding(nn.Module):
 class VersatileAttention(CrossAttention):
     def __init__(
             self,
-            attention_mode=None,
-            cross_frame_attention_mode=None,
-            temporal_position_encoding=False,
-            temporal_position_encoding_max_len=24,            
+            attention_mode                     = None,
+            cross_frame_attention_mode         = None,
+            temporal_position_encoding         = False,
+            temporal_position_encoding_max_len = 24,            
             *args, **kwargs
         ):
         super().__init__(*args, **kwargs)