Learning a simple sequence with RNN (Keras)



I am trying to learn a very simple sequence using an RNN (implemented in Keras).

The input sequence consists of randomly generated integers from 0 to 99:

x=np.random.randint(0,100, size=2000)

while the expected output value at time t is the (t-2)th input term, i.e.:

$y_t = x_{t-2}$

such that an example dataset looks like this:

|  X   |  Y   |
|    0 |   NA |
|   24 |   NA |
|   33 |    0 |
|    6 |   24 |
|   78 |   33 |
|   11 |    6 |
|    . |    . |
|    . |    . |

Note: I drop the NA rows before training.
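
For concreteness, the lagged dataset described above could be built along the following lines. This is only a sketch: the seed and the variable names `x_in` and `y` are illustrative choices, not part of the original post.

```python
import numpy as np

# Seed is an arbitrary choice, used only to make the sketch reproducible.
rng = np.random.default_rng(0)
x = rng.integers(0, 100, size=2000)   # integers in [0, 99]

# The target at time t is the input at time t-2, so the first two
# time steps have no target (the NA rows) and are dropped.
x_in = x[2:]    # inputs at t = 2, 3, ..., 1999
y = x[:-2]      # matching targets x[t-2]

print(x_in[:5], y[:5])
```

After dropping the two NA rows, 1998 input/target pairs remain.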

I am trying to train a simple RNN to learn this sequence as below:

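from keras.models import Sequential
from keras.layers import SimpleRNN

# Reshape to (samples, timesteps, features) to match the SimpleRNN input shape.
xtrain = np.reshape(df['X'].values, (df.shape[0], 1, 1))

model = Sequential()
model.add(SimpleRNN(2, input_shape=(None, 1)))
model.compile(optimizer='adam', loss='mse')
model.fit(x=xtrain, y=df['Y'], epochs=200, batch_size=5)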

However, I find that this implementation gets stuck in a local minimum that predicts a constant value (~50) for all test observations.

Could anyone help me with the right way of implementing a basic RNN in Keras to learn this sequence?


Posted 2018-02-21T08:27:11.190

Reputation: 41

Your network may be too small to memorize a sequence of 100 integers. – Daniel – 2018-09-02T20:37:13.337

@Aditya, could you share the code for the solution of your problem? I'm interested too. – Basj – 2019-02-11T10:38:50.047



Using the raw integers as inputs and targets will make this a very difficult task. A better approach would be to come up with a vector for each number. You can simply encode each number directly as a vector, use a "onehot" representation, or use an Embedding layer. You can see an example of embeddings in conx that is built on Keras here:


and an example without embeddings using onehot representations here:

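As a rough illustration of the "onehot" idea (this is not the conx example itself), each integer can be mapped to a 100-dimensional vector with a single 1. The sketch below also groups the inputs into windows of three consecutive time steps, so that the value from two steps back is actually present in each sample. The window length, seed, and variable names are assumptions made for illustration.

```python
import numpy as np

NUM_CLASSES = 100   # integers 0..99
WINDOW = 3          # assumed window: x[t-2], x[t-1], x[t]

rng = np.random.default_rng(0)   # arbitrary seed
x = rng.integers(0, NUM_CLASSES, size=2000)

# Sliding windows of WINDOW consecutive inputs, shape (1998, 3).
windows = np.stack(
    [x[i:len(x) - WINDOW + 1 + i] for i in range(WINDOW)], axis=1
)
# For a window ending at time t, the target x[t-2] is its first element.
targets = windows[:, 0]

# Row k of the identity matrix is the one-hot vector for integer k.
eye = np.eye(NUM_CLASSES, dtype=np.float32)
x_onehot = eye[windows]   # shape (1998, 3, 100)
y_onehot = eye[targets]   # shape (1998, 100)

print(x_onehot.shape, y_onehot.shape)
```

Arrays shaped like this could be fed to a recurrent layer with `input_shape=(3, 100)` and a softmax output over the 100 classes, rather than regressing on the raw integer with MSE.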

Doug Blank

Posted 2018-02-21T08:27:11.190

Reputation: 291

The idea of using a vector / an embedding for each integer is to have a higher dimensionality than 1, so that the NN can detect "features" in them? – Basj – 2019-02-11T10:43:36.843

Yes, you could describe it that way. It is easier for the network to start with such representations, rather than having to learn to break up an integer value. I presume that the items represented by (say) integers 22 and 23 don't have anything particular in common, so it would be better to represent them as non-overlapping vectors. – Doug Blank – 2019-02-12T12:45:28.470
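
The Embedding-layer option mentioned in the answer might look roughly like the sketch below, assuming `tensorflow.keras` is available. The window length, embedding size, number of RNN units, and epoch count are illustrative guesses rather than tuned values, and no particular accuracy is claimed; the task is treated as 100-way classification instead of regression.

```python
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, SimpleRNN, Dense

NUM_CLASSES, WINDOW = 100, 3   # assumed: integers 0..99, 3-step windows

rng = np.random.default_rng(0)   # arbitrary seed
x = rng.integers(0, NUM_CLASSES, size=2000)

# Each sample holds x[t-2], x[t-1], x[t]; the target is x[t-2].
windows = np.stack(
    [x[i:len(x) - WINDOW + 1 + i] for i in range(WINDOW)], axis=1
)
targets = windows[:, 0]

model = Sequential([
    # Learns a 16-dimensional vector for each of the 100 integers.
    Embedding(input_dim=NUM_CLASSES, output_dim=16),
    SimpleRNN(64),
    # One output probability per possible integer.
    Dense(NUM_CLASSES, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# Two epochs only keep the sketch fast; real training would likely need more.
model.fit(windows, targets, epochs=2, batch_size=32, verbose=0)

probs = model.predict(windows[:5], verbose=0)   # shape (5, 100)
predicted = probs.argmax(axis=1)                # most likely integer per sample
```

With integer inputs the Embedding layer plays the role of the one-hot vectors, so no manual encoding step is needed.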