Commit 6eb5ce3

Transformer debugging
Misc fixes: added torch.rand, the decoding layer needed a second input, and simplified the embedding size and window size declarations.
1 parent d5632e1 commit 6eb5ce3

1 file changed: +12 −8 lines changed

transformer.py

@@ -17,24 +17,26 @@ def __init__(self, class_num=10) -> None:
 
         super().__init__()
 
-        self.embed_size = 32*32*3/(8*3)
-        self.window_size = 8*3
+        self.window_size = 24 #(8 x 3)
+        self.embed_size = 128 #32 x 32 x 3 / (8 x 3)
+        # Embed size and window size multiply to 32 x 32 x 3
         self.class_num = class_num
 
         ## --- ENCODER --- ##
         self.encoder_pos_embed = Positional_Encoding_Layer(self.window_size, self.embed_size)
 
         # Implements self-attention with 16 attention heads
-        self.encoding_layer = torch.nn.TransformerEncoderLayer(self.embed_size, 16, \
+        self.encoding_layer = torch.nn.TransformerEncoderLayer(d_model=self.embed_size, nhead=16, \
             activation='gelu', dropout = 0.0)
-        self.encoder = torch.nn.TransformerEncoder(self.encoding_layer, 5)
 
+        self.encoder = torch.nn.TransformerEncoder(self.encoding_layer, 5)
 
+
         ## --- DECODER --- ##
         self.decoder_pos_embed = Positional_Encoding_Layer(self.window_size, self.embed_size)
 
         # Implements self-attention with 8 attention heads
-        self.decoding_layer = torch.nn.TransformerDecoderLayer(self.embed_size, 8, \
+        self.decoding_layer = torch.nn.TransformerDecoderLayer(d_model=self.embed_size, nhead=8, \
            activation='gelu', dropout = 0.0)
         self.decoder = torch.nn.TransformerDecoder(self.decoding_layer, 5)
 
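Note on the size change above: the old expression 32*32*3/(8*3) evaluates to 128.0 under Python 3's true division, and a float d_model breaks the layer constructors, which is why the values are now hardcoded as ints. A minimal standalone sketch of the fixed encoder setup (not repo code; the dummy tensor shapes are assumptions):

    import torch

    # window_size x embed_size must cover the 32 x 32 x 3 input exactly
    window_size, embed_size = 24, 128
    assert window_size * embed_size == 32 * 32 * 3

    # Keyword arguments make the intent explicit; 128 is divisible by 16 heads
    enc_layer = torch.nn.TransformerEncoderLayer(d_model=embed_size, nhead=16,
                                                 activation='gelu', dropout=0.0)
    encoder = torch.nn.TransformerEncoder(enc_layer, num_layers=5)

    # Dummy batch in the default (seq_len, batch, d_model) layout
    tokens = torch.rand(window_size, 1, embed_size)
    print(encoder(tokens).shape)   # torch.Size([24, 1, 128])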

@@ -69,8 +71,9 @@ def forward(self, inputs):
         # Add positional embeddings for decoder
         positioned_dec = self.decoder_pos_embed(enc_out)
 
-        # Pass through decoder
-        dec_out = self.decoder(positioned_dec)
+        # Pass positioned encoder output
+        # AND original encoder output through decoder
+        dec_out = self.decoder(positioned_dec, enc_out)
 
         # Finally, pass through linear layers (RELU activation for each except last)
         flat = nn.Flatten()(dec_out)
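The decoder change above addresses the "needed second input" fix from the commit message: nn.TransformerDecoder.forward takes both a target sequence and the encoder output ("memory"), so calling self.decoder(positioned_dec) alone fails with a missing-argument error. A standalone sketch (shapes and variable names here are illustrative, not taken from the repo):

    import torch

    d_model, nhead, seq_len, batch = 128, 8, 24, 1
    dec_layer = torch.nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead,
                                                 activation='gelu', dropout=0.0)
    decoder = torch.nn.TransformerDecoder(dec_layer, num_layers=5)

    enc_out = torch.rand(seq_len, batch, d_model)         # stand-in encoder output
    positioned_dec = torch.rand(seq_len, batch, d_model)  # stand-in positioned input

    # decoder(positioned_dec) alone raises a missing-argument error for 'memory';
    # the fix passes the encoder output as the memory argument
    dec_out = decoder(tgt=positioned_dec, memory=enc_out)
    print(dec_out.shape)   # torch.Size([24, 1, 128])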
@@ -89,7 +92,8 @@ def __init__(self, window_size, emb_size):
         """ Initializes trainable positional embeddings to add to the input """
         super(Positional_Encoding_Layer, self).__init__()
 
-        self.pos_embed = nn.Parameter(torch.tensor(window_size, emb_size))
+        #self.pos_embed = nn.Parameter(torch.tensor([window_size, emb_size], dtype=torch.float32))
+        self.pos_embed = nn.Parameter(torch.rand(window_size, emb_size))
 
     def forward(self, word_embeds):
         """ Adds (trainable) positional embeddings to the input """
