I am implementing a seq2seq autoencoder in PyTorch.
Q1) We can make the encoder bidirectional, but can we make the decoder bidirectional as well (does that even make sense) if we are not using teacher forcing?
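For context on Q1, here is a minimal sketch of decoding without teacher forcing. It assumes the decoder's 128-dim output is fed straight back in as the next 128-dim input (a real model may add a projection layer instead), and the zero initial state and zero "start" input are hypothetical placeholders for whatever the encoder provides. Each step consumes the decoder's own previous prediction, so the sequence is produced strictly left to right: at step t, the outputs for later steps do not exist yet.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

hidden_size = 128
num_layers = 2
seq_len = 5  # hypothetical target length

# Unidirectional decoder; output size equals input size (128), so each
# output can be fed straight back in as the next input (an assumption).
decoder = nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)

# Hypothetical initial state; in the real model this comes from the encoder.
h = torch.zeros(num_layers, 1, hidden_size)
c = torch.zeros(num_layers, 1, hidden_size)

x = torch.zeros(1, 1, hidden_size)  # hypothetical start input: (seq=1, batch=1, 128)
outputs = []
with torch.no_grad():
    for t in range(seq_len):
        # Step t sees only x (the decoder's own output from step t-1) and the
        # state carried over from earlier steps.
        out, (h, c) = decoder(x, (h, c))
        outputs.append(out)
        x = out  # no teacher forcing: feed the prediction back in

decoded = torch.cat(outputs, dim=0)  # (seq_len, 1, 128)
print(decoded.shape)
```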
If we can't, then suppose the architecture is:
encoder = nn.LSTM(128, 128, num_layers=2, bidirectional=True)
decoder = nn.LSTM(128, 128, num_layers=2, bidirectional=False)
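Here 128 is the input size and the hidden size of both LSTMs (the per-step encoder_output has a last dimension of 2 * 128 = 256, because the two directions are concatenated). The encoder's final hidden state has shape (4, 1, 128), following the convention (num_directions * num_layers, batch_size, hidden_size) = (2 * 2, 1, 128).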
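A quick shape check for the two LSTMs above, written with `num_layers` (the keyword `nn.LSTM` actually accepts). The sequence length of 7 is arbitrary, and the input uses PyTorch's default `batch_first=False` layout (seq_len, batch, feature). It shows the final states are (4, 1, 128) while `encoder_output` is (7, 1, 256).

```python
import torch
import torch.nn as nn

input_size = hidden_size = 128
num_layers = 2

encoder = nn.LSTM(input_size, hidden_size, num_layers=num_layers, bidirectional=True)
decoder = nn.LSTM(input_size, hidden_size, num_layers=num_layers, bidirectional=False)

seq_len, batch_size = 7, 1  # hypothetical sizes for the check
# (seq_len, batch, feature) because batch_first=False by default
input_tensor = torch.randn(seq_len, batch_size, input_size)

encoder_output, (encoder_hidden, encoder_state) = encoder(input_tensor)

# Last layer's output at every step, both directions concatenated: (7, 1, 256)
print(encoder_output.shape)
# Final hidden and cell states: (num_directions * num_layers, batch, hidden) = (4, 1, 128)
print(encoder_hidden.shape, encoder_state.shape)
```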
Q2) Among these 4 tensors of size (1, 128), which one is the final hidden state of which encoder layer and which direction? The complete encoder output is:
encoder_output, (encoder_hidden, encoder_state) = encoder(input_tensor)
My basic assumption is:
encoder_hidden[0, :, :] is from encoder layer 1 forward direction
encoder_hidden[1, :, :] is from encoder layer 1 reverse direction
encoder_hidden[2, :, :] is from encoder layer 2 forward direction
encoder_hidden[3, :, :] is from encoder layer 2 reverse direction
But I have no way to confirm this.
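One way to check the layout empirically, as a sketch. For the top layer: `output` holds the last layer's hidden state at every time step with the forward half in `[..., :128]` and the backward half in `[..., 128:]`; the forward direction finishes at the last step and the backward direction at step 0, so `h_n[2]` should match `output[-1, :, :128]` and `h_n[3]` should match `output[0, :, 128:]`. For the bottom layer, the sketch copies the `*_l0` and `*_l0_reverse` parameters into a separate one-layer bidirectional LSTM and compares its final state with `h_n[0:2]`. The input length and seed are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
H = 128
encoder = nn.LSTM(H, H, num_layers=2, bidirectional=True)
x = torch.randn(6, 1, H)  # hypothetical input: (seq_len, batch, feature)

with torch.no_grad():
    output, (h_n, c_n) = encoder(x)

    # Check 1 (layer 2): forward final state = last step, forward half;
    # backward final state = first step, backward half.
    layer2_fwd_ok = torch.allclose(h_n[2], output[-1, :, :H])
    layer2_bwd_ok = torch.allclose(h_n[3], output[0, :, H:])

    # Check 2 (layer 1): a one-layer bidirectional LSTM given the same
    # layer-0 weights (both directions) should reproduce h_n[0] and h_n[1].
    layer1 = nn.LSTM(H, H, num_layers=1, bidirectional=True)
    layer0_params = {
        k: v for k, v in encoder.state_dict().items()
        if k.endswith(("_l0", "_l0_reverse"))
    }
    layer1.load_state_dict(layer0_params)
    _, (h1, _) = layer1(x)
    layer1_ok = torch.allclose(h_n[0:2], h1, atol=1e-6)

print(layer2_fwd_ok, layer2_bwd_ok, layer1_ok)
```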
Now suppose I have identified this and my assumption above is correct. I still have to pass these encoder hidden states as the thought vector to the decoder, which has 2 layers but is unidirectional, so it expects only two tensors as its initial hidden state (one tensor per layer).
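A minimal sketch of regrouping the (4, 1, 128) tensor by layer, assuming the layer-major ordering listed above (older PyTorch docs describe separating the layers with `h_n.view(num_layers, num_directions, batch, hidden_size)`). Note that `nn.LSTM` takes an initial cell state `c_0` as well as `h_0`, each of shape (num_layers, batch, hidden_size) = (2, 1, 128) here, so the cell state needs the same treatment. The random tensors stand in for real encoder states.

```python
import torch

num_layers, num_directions, batch_size, H = 2, 2, 1, 128

# Stand-ins for the encoder's final states: (num_directions * num_layers, batch, H) = (4, 1, 128)
encoder_hidden = torch.randn(num_layers * num_directions, batch_size, H)
encoder_state = torch.randn(num_layers * num_directions, batch_size, H)

# Separate layers and directions: (num_layers, num_directions, batch, H)
h = encoder_hidden.view(num_layers, num_directions, batch_size, H)
c = encoder_state.view(num_layers, num_directions, batch_size, H)

# Index 0 along dim 1 is the forward direction, index 1 the backward one.
h_fwd, h_bwd = h[:, 0], h[:, 1]  # each (num_layers, batch, H) = (2, 1, 128)
c_fwd, c_bwd = c[:, 0], c[:, 1]

# The unidirectional 2-layer decoder wants h_0 and c_0 of shape (2, 1, 128),
# so each direction already has the right shape; the open question is how
# to combine the two directions into one tensor per layer.
print(h_fwd.shape, h_bwd.shape)
```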
Q3) How should I convert the 4 tensors into 2? Some ideas:
- Concatenate pairs of these 4 tensors to get 2 tensors, and instantiate the decoder with a hidden dimension twice the encoder's hidden dimension.
- Average or add pairs of these 4 tensors (the decoder hidden dimension then stays the same).
- Do the first option, then pass the concatenated tensors through a linear layer that maps them back to the original decoder hidden dimension.
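The three ideas can be sketched side by side as below. This is an illustration under assumptions, not a recommendation: each layer's forward and backward states are paired (one natural pairing among several), the same reduction is applied to the cell state, the decoder input size stays 128, and `bridge_h`/`bridge_c` are hypothetical names for the extra linear layers, which would be trained jointly with the rest of the model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_layers, batch_size, H = 2, 1, 128

# Stand-ins for the bidirectional encoder's final (h_n, c_n), shape (4, 1, 128).
encoder_hidden = torch.randn(num_layers * 2, batch_size, H)
encoder_state = torch.randn(num_layers * 2, batch_size, H)

def split_directions(state):
    """(num_layers*2, B, H) -> forward and backward parts, each (num_layers, B, H)."""
    state = state.view(num_layers, 2, batch_size, H)
    return state[:, 0], state[:, 1]

def concat_dirs(state):
    """Option 1: concatenate forward/backward per layer -> (num_layers, B, 2H)."""
    fwd, bwd = split_directions(state)
    return torch.cat([fwd, bwd], dim=-1)

def average_dirs(state):
    """Option 2: average forward/backward per layer -> (num_layers, B, H).
    Dropping the / 2 gives the 'add' variant."""
    fwd, bwd = split_directions(state)
    return (fwd + bwd) / 2

# Option 1: decoder hidden size doubled to 2H; its input size stays H.
decoder_2h = nn.LSTM(H, 2 * H, num_layers=num_layers)
init_1 = (concat_dirs(encoder_hidden), concat_dirs(encoder_state))

# Option 2: decoder hidden size stays H.
decoder_h = nn.LSTM(H, H, num_layers=num_layers)
init_2 = (average_dirs(encoder_hidden), average_dirs(encoder_state))

# Option 3: concatenate, then a learned linear map 2H -> H applied to the
# last dimension (hypothetical bridge layers, one for h and one for c).
bridge_h = nn.Linear(2 * H, H)
bridge_c = nn.Linear(2 * H, H)
init_3 = (bridge_h(concat_dirs(encoder_hidden)), bridge_c(concat_dirs(encoder_state)))

x = torch.randn(3, batch_size, H)  # hypothetical decoder input sequence
with torch.no_grad():
    out1, _ = decoder_2h(x, init_1)  # (3, 1, 256)
    out2, _ = decoder_h(x, init_2)   # (3, 1, 128)
    out3, _ = decoder_h(x, init_3)   # (3, 1, 128)
```

Option 1 changes the decoder's output width too, so anything reading the decoder output (for example a reconstruction layer) has to expect 2H; options 2 and 3 keep the original decoder size, with option 3 adding learnable parameters.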
I have been confused about this for a long time; any help is appreciated. Thanks.