Description
Please bear with me here.
This might be a little confusing, so I'm including pseudocode to illustrate what's unclear to me. I've been following a tutorial on machine translation with LSTMs and an attention mechanism, and it was mentioned that we need to define a for loop over the target sequence.
I've written the pseudocode so that it looks roughly like Keras.
This is the pseudocode:
h = encoder(input) # encoder hidden states h(1)...h(Tx), used to compute the attention weights and the context vector
decoder_hidden_state = 0 # hidden state s(0)
decoder_cell_state = 0 # cell state c(0)
outputs = []
for t in range(Ty): # Ty being the length of the target sequence
    context = calc_attention(decoder_hidden_state, h) # attends over h(1),...,h(Tx) using decoder_hidden_state(t-1)
    decoder_output, decoder_hidden_state, decoder_cell_state = decoder_lstm(context, initial_state=[decoder_hidden_state, decoder_cell_state])
    probabilities = dense_layer(decoder_output)
    outputs.append(probabilities)
model = Model(input, outputs)
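The tutorial doesn't spell out calc_attention, so here is a minimal sketch of what I assume it does (Bahdanau-style additive attention; the layer names and the hidden size of 10 are my own guesses, not from the tutorial):

from tensorflow.keras.layers import RepeatVector, Concatenate, Dense, Softmax, Dot

attn_repeat = RepeatVector(Tx)              # copy s(t-1) across all Tx encoder steps
attn_concat = Concatenate(axis=-1)
attn_dense1 = Dense(10, activation='tanh')  # assumed small hidden layer for the alignment scores
attn_dense2 = Dense(1)                      # one unnormalized score per encoder step
attn_softmax = Softmax(axis=1)              # normalize the scores over the Tx axis
attn_dot = Dot(axes=1)                      # weighted sum of the h's -> context vector

def calc_attention(decoder_hidden_state, h):
    # h: (batch, Tx, D), decoder_hidden_state: (batch, LATENT_DIM)
    s = attn_repeat(decoder_hidden_state)               # (batch, Tx, LATENT_DIM)
    x = attn_concat([h, s])                             # (batch, Tx, D + LATENT_DIM)
    alphas = attn_softmax(attn_dense2(attn_dense1(x)))  # (batch, Tx, 1)
    return attn_dot([alphas, h])                        # (batch, 1, D), one timestep of input for decoder_lstm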
The thing that is unclear to me is why we are using a for loop. It was said that "In a regular seq2seq, we pass in the entire target input sequence all at once because the output was calculated all at once. But we need a loop over Ty steps since each context depends on the previous state".
But I think the same can be done in the case of attention if I just remove the for loop.
Just like in the code below, which is the decoder part of a normal seq2seq:
from tensorflow.keras.layers import Input, Embedding, LSTM

decoder_inputs_placeholder = Input(shape=(max_len_target,))       # teacher-forced target tokens
decoder_embedding = Embedding(num_words_output, EMBEDDING_DIM)
decoder_inputs_x = decoder_embedding(decoder_inputs_placeholder)  # (batch, Ty, EMBEDDING_DIM)
decoder_lstm = LSTM(
    LATENT_DIM,
    return_sequences=True,  # output at every timestep
    return_state=True,      # also return the final h and c states
)
If I want to add attention, can't I just define the states here and call the calc_attention function, which would return the context for a particular timestep while decoding, and pass that into the LSTM call just as in the pseudocode above? (See the sketch after the code below.)
decoder_outputs, decoder_hidden_state, decoder_cell_state = decoder_lstm(  # one call over the whole target sequence
decoder_inputs_x,
initial_state=[decoder_hidden_state, decoder_cell_state]
)
decoder_outputs = decoder_dense(decoder_outputs)
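In other words, something like this, reusing calc_attention from above (a rough sketch of what I have in mind, with the states defined beforehand, e.g. initialized to zeros as in the pseudocode):

context = calc_attention(decoder_hidden_state, h)  # context from the current decoder state
decoder_outputs, decoder_hidden_state, decoder_cell_state = decoder_lstm(
    context,
    initial_state=[decoder_hidden_state, decoder_cell_state]
)
decoder_outputs = decoder_dense(decoder_outputs)

Why is the explicit loop over Ty necessary here, when the plain seq2seq decoder gets away with a single call?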