@@ -51,14 +51,14 @@ class VanillaTemporalModule(nn.Module):
     def __init__(
         self,
         in_channels,
-        num_attention_heads = 8,
-        num_transformer_block = 2,
-        attention_block_types =( "Temporal_Self", "Temporal_Self" ),
-        cross_frame_attention_mode = None,
-        temporal_position_encoding = False,
-        temporal_position_encoding_max_len = 24,
-        temporal_attention_dim_div = 1,
-        zero_initialize = True,
+        num_attention_heads                = 8,
+        num_transformer_block              = 2,
+        attention_block_types              =( "Temporal_Self", "Temporal_Self" ),
+        cross_frame_attention_mode         = None,
+        temporal_position_encoding         = False,
+        temporal_position_encoding_max_len = 24,
+        temporal_attention_dim_div         = 1,
+        zero_initialize                    = True,
     ):
         super().__init__()
 
@@ -92,20 +92,17 @@ class TemporalTransformer3DModel(nn.Module):
         attention_head_dim,
 
         num_layers,
-        attention_block_types=(
-            "Temporal_Self",
-            "Temporal_Self",
-        ),
-        dropout=0.0,
-        norm_num_groups=32,
-        cross_attention_dim=768,
-        activation_fn="geglu",
-        attention_bias=False,
-        upcast_attention=False,
-
-        cross_frame_attention_mode=None,
-        temporal_position_encoding=False,
-        temporal_position_encoding_max_len=24,
+        attention_block_types              = ( "Temporal_Self", "Temporal_Self", ),
+        dropout                            = 0.0,
+        norm_num_groups                    = 32,
+        cross_attention_dim                = 768,
+        activation_fn                      = "geglu",
+        attention_bias                     = False,
+        upcast_attention                   = False,
+
+        cross_frame_attention_mode         = None,
+        temporal_position_encoding         = False,
+        temporal_position_encoding_max_len = 24,
     ):
         super().__init__()
 
@@ -228,10 +225,14 @@ class TemporalTransformerBlock(nn.Module):
 
 
 class PositionalEncoding(nn.Module):
-    def __init__(self, d_model: int, dropout: float = 0., max_len: int = 24):
+    def __init__(
+        self,
+        d_model,
+        dropout = 0.,
+        max_len = 24
+    ):
         super().__init__()
         self.dropout = nn.Dropout(p=dropout)
-        # print(f"d_model: {d_model}")
         position = torch.arange(max_len).unsqueeze(1)
         div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
         pe = torch.zeros(1, max_len, d_model)
@@ -247,10 +248,10 @@ class PositionalEncoding(nn.Module):
 class VersatileAttention(CrossAttention):
     def __init__(
         self,
-        attention_mode=None,
-        cross_frame_attention_mode=None,
-        temporal_position_encoding=False,
-        temporal_position_encoding_max_len=24,
+        attention_mode                     = None,
+        cross_frame_attention_mode         = None,
+        temporal_position_encoding         = False,
+        temporal_position_encoding_max_len = 24,
         *args, **kwargs
     ):
         super().__init__(*args, **kwargs)
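
For reference only, not part of the patch: a minimal, self-contained sketch of the sinusoidal table that the unchanged PositionalEncoding context lines above construct, using the defaults shown in the diff (dropout = 0., max_len = 24). The sin/cos fill and the forward() below are assumptions based on the standard construction; the hunks do not include those lines, and the class name here is hypothetical.

    import math
    import torch
    import torch.nn as nn

    class PositionalEncodingSketch(nn.Module):
        # Mirrors the __init__ signature introduced by the diff: d_model, dropout = 0., max_len = 24.
        def __init__(self, d_model, dropout=0., max_len=24):
            super().__init__()
            self.dropout = nn.Dropout(p=dropout)
            # Same three construction lines as the diff's context: positions, frequency terms, empty table.
            position = torch.arange(max_len).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
            pe = torch.zeros(1, max_len, d_model)
            # Assumed standard sinusoidal fill (these lines are not shown in the hunks).
            pe[0, :, 0::2] = torch.sin(position * div_term)
            pe[0, :, 1::2] = torch.cos(position * div_term)
            self.register_buffer("pe", pe)  # shape (1, max_len, d_model)

        def forward(self, x):
            # Assumed usage: add the first x.size(1) positional rows, then apply dropout.
            x = x + self.pe[:, :x.size(1)]
            return self.dropout(x)

    # Quick check with 16 frames and an even feature width; the frame count must stay <= max_len (24).
    x = torch.randn(2, 16, 320)
    print(PositionalEncodingSketch(d_model=320)(x).shape)  # torch.Size([2, 16, 320])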