@@ -17,24 +17,26 @@ def __init__(self, class_num=10) -> None:

         super().__init__()

-        self.embed_size = 32 * 32 * 3 / (8 * 3)
-        self.window_size = 8 * 3
+        self.window_size = 24  # (8 x 3)
+        self.embed_size = 128  # 32 x 32 x 3 / (8 x 3)
+        # Embed size and window size multiply to 32 x 32 x 3
         self.class_num = class_num

         ## --- ENCODER --- ##
         self.encoder_pos_embed = Positional_Encoding_Layer(self.window_size, self.embed_size)

         # Implements self-attention with 16 attention heads
-        self.encoding_layer = torch.nn.TransformerEncoderLayer(self.embed_size, 16, \
+        self.encoding_layer = torch.nn.TransformerEncoderLayer(d_model=self.embed_size, nhead=16, \
                                                                 activation='gelu', dropout=0.0)
-        self.encoder = torch.nn.TransformerEncoder(self.encoding_layer, 5)

+        self.encoder = torch.nn.TransformerEncoder(self.encoding_layer, 5)

+
         ## --- DECODER --- ##
         self.decoder_pos_embed = Positional_Encoding_Layer(self.window_size, self.embed_size)

         # Implements self-attention with 8 attention heads
-        self.decoding_layer = torch.nn.TransformerDecoderLayer(self.embed_size, 8, \
+        self.decoding_layer = torch.nn.TransformerDecoderLayer(d_model=self.embed_size, nhead=8, \
                                                                 activation='gelu', dropout=0.0)

         self.decoder = torch.nn.TransformerDecoder(self.decoding_layer, 5)
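The constructor changes above replace the derived size expressions with their literal values and pass `d_model` and `nhead` to the layer constructors by keyword. Note that the old `32 * 32 * 3 / (8 * 3)` evaluates to the float `128.0` under true division, which the layer constructors would reject for `d_model`, so the integer literal also sidesteps a type error. A minimal standalone sketch of the resulting encoder stack, assuming the default `(seq, batch, d_model)` layout and a hypothetical batch size of 4 (neither is stated in the patch):

```python
import torch

window_size, embed_size, batch = 24, 128, 4  # sizes from the patch; batch is assumed

# Same configuration as the patch: 16 heads, GELU, no dropout, 5 stacked layers
encoding_layer = torch.nn.TransformerEncoderLayer(d_model=embed_size, nhead=16,
                                                  activation='gelu', dropout=0.0)
encoder = torch.nn.TransformerEncoder(encoding_layer, 5)

x = torch.rand(window_size, batch, embed_size)  # (seq, batch, d_model), the default layout
enc_out = encoder(x)
print(enc_out.shape)  # torch.Size([24, 4, 128])
```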
@@ -69,8 +71,9 @@ def forward(self, inputs):
         # Add positional embeddings for decoder
         positioned_dec = self.decoder_pos_embed(enc_out)

-        # Pass through decoder
-        dec_out = self.decoder(positioned_dec)
+        # Pass positioned encoder output
+        # AND original encoder output through decoder
+        dec_out = self.decoder(positioned_dec, enc_out)

         # Finally, pass through linear layers (RELU activation for each except last)
         flat = nn.Flatten()(dec_out)
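The decoder fix matters because `torch.nn.TransformerDecoder.forward` takes two required tensors, the target sequence and the encoder memory, so the old single-argument call would raise a TypeError. A minimal sketch of the corrected call; the batch size and the stand-in positional add are assumptions, not taken from the patch:

```python
import torch

window_size, embed_size, batch = 24, 128, 4  # batch size is assumed

decoding_layer = torch.nn.TransformerDecoderLayer(d_model=embed_size, nhead=8,
                                                  activation='gelu', dropout=0.0)
decoder = torch.nn.TransformerDecoder(decoding_layer, 5)

enc_out = torch.rand(window_size, batch, embed_size)               # stand-in encoder output
positioned_dec = enc_out + torch.rand(window_size, 1, embed_size)  # stand-in positional add

# decoder(tgt, memory): both arguments are required
dec_out = decoder(positioned_dec, enc_out)
print(dec_out.shape)  # torch.Size([24, 4, 128])
```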
@@ -89,7 +92,8 @@ def __init__(self, window_size, emb_size):
         """ Initializes trainable positional embeddings to add to the input """
         super(Positional_Encoding_Layer, self).__init__()

-        self.pos_embed = nn.Parameter(torch.tensor(window_size, emb_size))
+        #self.pos_embed = nn.Parameter(torch.tensor([window_size, emb_size], dtype=torch.float32))
+        self.pos_embed = nn.Parameter(torch.rand(window_size, emb_size))

     def forward(self, word_embeds):
         """ Adds (trainable) positional embeddings to the input """