PyTorch LSTM not training


So I am currently trying to implement an LSTM in PyTorch, but for some reason the loss is not decreasing. Here is my network:

from torch import nn


class MyNN(nn.Module):
    def __init__(self, input_size=3, seq_len=107, pred_len=68, hidden_size=50, num_layers=1, dropout=0.2):
        super().__init__()
        
        self.pred_len = pred_len
        
        self.rnn = nn.LSTM(
            input_size=input_size, 
            hidden_size=hidden_size, 
            num_layers=num_layers, 
            dropout=dropout, 
            bidirectional=True,
            batch_first=True
        )
        
        self.linear = nn.Linear(hidden_size*2, 5)  # hidden_size*2 because the LSTM is bidirectional
    
    def forward(self, X):
        lstm_output, (hidden_state, cell_state) = self.rnn(X)
        
        labels = self.linear(lstm_output[:, :self.pred_len, :])  # keep only the first pred_len timesteps
        
        return lstm_output, labels
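
For reference, a quick shape check of what this module returns (a minimal sketch; the batch size of 4 is arbitrary):

import torch

net = MyNN(num_layers=1, dropout=0)
X = torch.randn(4, 107, 3)   # (batch, seq_len, input_size), since batch_first=True
lstm_output, labels = net(X)
print(lstm_output.shape)     # torch.Size([4, 107, 100]) -> hidden_size * 2 directions
print(labels.shape)          # torch.Size([4, 68, 5])    -> first pred_len timesteps, 5 outputs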

And here is my training loop:

from torch import optim
from tqdm import tqdm
import matplotlib.pyplot as plt

LEARNING_RATE = 1e-2


net = MyNN(num_layers=1, dropout=0)

compute_loss = nn.MSELoss()

optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)


all_loss = []
for data in tqdm(train_loader):
    X, y = data
    
    optimizer.zero_grad()

    lstm_output, output = net(X.float())
    
    # Computing the loss
    loss = compute_loss(y, output)
    all_loss.append(loss.item())  # store a plain float so the list can be plotted
    loss.backward()
    
    optimizer.step()
    
# Plot
plt.plot(all_loss, marker=".")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()

And this is what I got: [plot of the training loss, which does not decrease]

I have been trying to figure out what I am doing wrong, but I have no idea. Also, I previously used a Keras LSTM on this dataset and it worked well.

Any help? Thanks!

David Marques

Posted 2020-09-21T22:34:53.170

Reputation: 35

Answers


You are looking at the loss of every individual batch. You should average the loss over all batches in an epoch instead: when you compare different batches, the loss can go up simply because one batch is harder to predict than another, so the per-batch curve is not really interpretable. Start with that. If the problem persists, it is probably exploding gradients; in that case, lower your learning rate to 1e-3 or 1e-4, or even less if it continues.
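
For example, here is a minimal sketch of tracking the average loss per epoch, reusing the names from your code (net, train_loader, compute_loss, optimizer); the number of epochs is an arbitrary choice:

NUM_EPOCHS = 20  # arbitrary training budget

epoch_losses = []
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    num_batches = 0
    for X, y in train_loader:
        optimizer.zero_grad()
        lstm_output, output = net(X.float())
        loss = compute_loss(output, y.float())
        loss.backward()
        optimizer.step()

        running_loss += loss.item()  # .item() gives a plain float, detached from the graph
        num_batches += 1

    epoch_losses.append(running_loss / num_batches)

plt.plot(epoch_losses, marker=".")
plt.xlabel("Epoch")
plt.ylabel("Average training loss")
plt.show()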

YuseqYaseq

Posted 2020-09-21T22:34:53.170

Reputation: 153

Hmm, I had already tried lowering the LR. What I didn't try was clipping the gradients, because I did not think this could be exploding gradients. I will try it and see if it solves the problem. Thanks! – David Marques – 2020-09-22T13:18:00.483

If lowering the LR didn't help, gradient clipping shouldn't work either, because both do nothing more than scale or clip the gradients. Take a look at my edit, though. – YuseqYaseq – 2020-09-22T13:23:12.167
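
For what it's worth, gradient clipping in PyTorch is a single call between backward() and step(); a minimal sketch using the names from the question, with an arbitrary max_norm of 1.0:

import torch

loss.backward()
torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)  # rescale gradients whose total norm exceeds 1.0
optimizer.step()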

Omg, that makes total sense... I knew I was making some idiotic mistake. I spent an afternoon yesterday thinking I was calling the loss function the wrong way or something like that. I will check it later when I get the chance and then accept the answer. Thank you! – David Marques – 2020-09-22T13:28:46.970