How are LSTMs trained for text generation?



I've seen some articles about text generation using LSTMs (or GRUs).

Basically it seems you train them by unrolling them and feeding one letter into each input step. But say you trained it with text which includes the string:

"The dog chased the cat"

and also includes the string:

"The lion chased the cat"

Both should be acceptable. But although they are very similar, they differ entirely after the 4th character, so they will result in two very different state vectors. And the longer you unroll the network, the more they will differ. How is it then possible for an LSTM to learn that "The [something] chased the cat" is an acceptable phrase?

Equally, suppose you train it to learn to pair up parentheses. I can see how you could manually design a network to do this, but how exactly could it be trained to do this just from strings like "(das asdasdas) axd"?

What I'm getting at is that I don't get how it could LEARN any more structure than a Markov model.

Any ideas?

(Also, I've only ever seen one article that showed an LSTM pairing parentheses; has that result ever been replicated?) I get that it is possible for an LSTM to do this, but I don't get how it can learn to do it!

What I'm getting at is that you usually train with an input phrase and then compare the output to the expected phrase to get the error. But in text generation there might be millions of acceptable expected phrases! So how can you compute an error?


Posted 2017-11-10T06:19:58.130




You compute the error in the same way as usual: treat the actual next character in the training text as the "ground truth", even if other continuations are possible. The LSTM outputs a probability distribution over the next character; it will not learn to associate a single "true" output except in rare circumstances (such as completing a word that, according to the training data, can only end one way). In practice this means that acceptable loss values in NLP training are interpreted differently from those in other supervised learning settings, and are typically higher than for datasets where the ground truth is more predictable from the input.
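As a sketch of that error computation (using NumPy, with a made-up vocabulary and random scores standing in for a real network's output at one time step), the per-step error is just the cross-entropy between the model's next-character distribution and the character that actually came next in the training text:

```python
import numpy as np

# Hypothetical setup: a small character vocabulary, and random scores
# standing in for the logits an LSTM would emit at one time step.
vocab = list("abcdefghijklmnopqrstuvwxyz ")
char_to_idx = {c: i for i, c in enumerate(vocab)}

def softmax(logits):
    e = np.exp(logits - logits.max())
    return e / e.sum()

rng = np.random.default_rng(0)
logits = rng.normal(size=len(vocab))   # pretend LSTM output for this step
probs = softmax(logits)                # next-character probability distribution

# The character that actually followed in the training text is treated as
# ground truth, even though other continuations would also be acceptable.
target = char_to_idx["t"]
loss = -np.log(probs[target])          # cross-entropy for this single step
```

Because many continuations are valid, the probability assigned to the target rarely approaches 1, so the average loss stays higher than in tasks with a single correct answer.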

In this way, it is a lot like a Markov or n-gram model, except that it learns its own internal state representations (they are not prescribed by the model, only limited by the structure and the data available). Gated RNN architectures like LSTM and GRU include a mechanism for handling long-term memory, so they can relatively easily learn things like quote pairs, parenthesis pairs, etc.
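To make the parenthesis example concrete (a hand-written toy, not a trained network): a single counter is enough state to track nesting depth, and a counter with additive, input-gated updates is exactly the kind of quantity one unit of an LSTM's cell state can learn to carry over long spans:

```python
# Toy illustration: one "cell state" unit used as a counter is enough to
# track parenthesis depth. An LSTM can learn an equivalent additive update
# (+1 on "(", -1 on ")") in a single unit of its cell state.
def paren_depth_trace(text):
    cell = 0      # analogous to one unit of LSTM cell state
    trace = []
    for ch in text:
        if ch == "(":
            cell += 1   # like the input gate writing +1 to the cell
        elif ch == ")":
            cell -= 1   # ... or writing -1
        trace.append(cell)
    return trace

trace = paren_depth_trace("(das asdasdas) axd")
```

The depth returns to 0 after the closing parenthesis, which is the signal a trained network would use to judge whether a ")" is acceptable next.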

What I'm getting at is that I don't get how it could LEARN any sort of structure more than a Markov model.

A Markov model can in theory represent any sequence knowledge that an LSTM can; the difference is efficiency. RNNs have more compact internal representations than the explicit state representations of, say, HMMs. This makes them harder to work with in some ways (it is hard to see what an RNN state vector represents), but it means they can store more complex knowledge about state transitions in less memory. It also means that RNNs can generalise more readily (although with no guarantee of correctness): they will not fall back to random behaviour when given a brand-new state with no training history.

Neil Slater



In practice, does this mean you have to set the learning rate very low so it doesn't try to get too close to one particular answer, for example? – zooby – 2017-11-11T20:59:44.840

@zooby I don't think it needs to be lower than normal, but deep NNs often use low values, e.g. 0.001 or 0.0001, with gradient adjustment (RMSProp seems to be a popular choice for LSTM) and mini-batches (which will smooth out some of the differences) – Neil Slater – 2017-11-11T21:19:25.643