## Will the target network, which is less trained than the normal network, output inferior estimates?

I'm having some trouble understanding some parts of the usage of target networks.

I get that having the same network predict the state/action/advantage values for both the current state and the next state (which provides the TD target) can lead to instability.

Based on my understanding, the intuition behind the 1-step TD error is that going a step into the future gives you a better estimate, which can then be used to update your original state/action/advantage value.

However, if you use a target network, which is less trained than the normal net — especially at early stages of the training — wouldn't the state/action/advantage value be updating towards an inferior estimate?

I've tried implementing DQNs and DDPGs on Cartpole, and I've found that the algorithms fail to converge when target networks are used, but work fine when those target networks are removed.

DQN convergence is so slow (and often non-monotonic) that the difference in prediction between the target and training networks is not noticeable. – mirror2image – 2019-07-20T14:29:00.527

## Answers

> However, if you use a target network, which is less trained than the normal net — especially at early stages of the training — wouldn't the state/action/advantage value be updating towards an inferior estimate?

Possibly, but a critical part of how a target network stabilises TD learning with function approximation is that it keeps the updates consistent in the short term.

Consider that with a single network, the TD target calculation will also be biased (and by the same amounts on the first few passes), but the targets will not be consistent: each learning update shifts the estimates by a biased amount, and that shift then feeds into the next target. Stability problems occur when, instead of the bias decaying over time as real data arrives from each step, the bias is large enough to form a positive feedback loop.
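
For reference, for a single transition $$(s,a,r,s')$$ with discount factor $$\gamma$$, the two TD targets being compared are written out below, where $$\bar{\theta}$$ are the parameters of a separately stored target network that is only refreshed from $$\theta$$ occasionally. In both cases the update moves $$\hat{q}(s,a,\theta)$$ some fraction of the way towards $$y$$.

$$y_{\text{single}} = r + \gamma\,\text{max}_{a'}\hat{q}(s',a',\theta)$$

$$y_{\text{target}} = r + \gamma\,\text{max}_{a'}\hat{q}(s',a',\bar{\theta})$$

Because $$y_{\text{single}}$$ depends on the same parameters that each update changes, the target itself can move after every update, whereas $$y_{\text{target}}$$ stays fixed until $$\bar{\theta}$$ is refreshed. The worked example that follows uses $$\gamma = 1$$.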

As a concrete example, suppose an initial random network calculates $$\hat{q}(s,a,\theta)$$, and that:

- The reward is 1.0.
- The true value $$q^*(s,a)$$ is 10.0 (the exact value is not really relevant), and with no discounting that makes $$q^*(s',a') = 9.0$$.
- The NN initially predicts 5.0 for $$\hat{q}(s,a,\theta)$$ and 7.0 for $$\text{max}_{a'}\hat{q}(s',a',\theta)$$.
- We also have a learning network using parameters $$\theta$$ that change on each learning update.
- We have a frozen copy of the initial network using parameters $$\bar{\theta}$$.
- We have some learning rate $$\alpha$$ that, for argument's sake, we say reduces the error in this case by 50% each time it sees this example.
- The approximation process in the NN means that $$\hat{q}(s,a,\theta)$$ and $$\text{max}_{a'}\hat{q}(s',a',\theta)$$ are linked. This is a real feature of neural networks, but hard to model as it evolves over time. Let's say that, with the initial setup, any learning of $$\hat{q}(s,a,\theta)$$ in isolation will make $$\text{max}_{a'}\hat{q}(s',a',\theta)$$ move towards $$\hat{q}(s,a,\theta)$$ by a 50% step. These step sizes are not that important; they could be just 1% and the problem would still occur.
- You can control the learning rate directly, but cannot really control the "generalisation link strength" of the approximation.
- To really trigger the problem I want to show, let's say that these single steps occur far enough from the end of an episode, or from other moderating feedback, that we effectively get a few hundred visits to our $$(s,a)$$ pair without impact from other items (in fact, many other pairs are likely to be going through the same issues).

Without the target network, the first TD target is $$1.0 + 7.0 = 8.0$$. After the update, $$\hat{q}(s,a,\theta) = 6.5$$ and $$\text{max}_{a'}\hat{q}(s',a',\theta) = 6.75$$. This looks good, right? Getting closer to the real values... but keep going... the next updates work like this:

| q(s,a) | max q(s',a') |
|--------|--------------|
| 7.125  | 6.938        |
| 7.531  | 7.234        |
| 7.882  | 7.559        |
| 8.221  | 7.889        |
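
The toy dynamics can be simulated in a few lines of Python. This is a sketch of the made-up model in the list above, not of a real neural network: `lr_frac` (the fraction of TD error removed per update) and `link` (the assumed generalisation link strength) are hypothetical names for the two 50% assumptions, and there is no discounting. The early updates should agree with the table above to within rounding in the last digit, and updates 101 to 105 show where the loop has got to after about 100 isolated iterations.

```python
def simulate_without_target(q, max_next, reward, n_updates, lr_frac=0.5, link=0.5):
    """Toy model of the example: the TD target uses the learning network's own
    estimate of max_a' q(s', a'), which is dragged along by each update to q(s, a)."""
    history = []
    for _ in range(n_updates):
        td_target = reward + max_next        # no discounting, as in the example
        q += lr_frac * (td_target - q)       # remove lr_frac of the TD error
        max_next += link * (q - max_next)    # hypothetical generalisation link
        history.append((q, max_next))
    return history


history = simulate_without_target(q=5.0, max_next=7.0, reward=1.0, n_updates=105)
for step in (1, 2, 3, 4, 5, 101, 102, 103, 104, 105):
    q, m = history[step - 1]
    print(f"update {step:3d}: q(s,a) = {q:7.3f}   max q(s',a') = {m:7.3f}")
```

Changing `link` to something small such as 0.01 slows the drift but, in this toy model, does not remove it.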


Again, this looks OK? But let's come back to it after 100 [isolated, so a bit fake] iterations:

| q(s,a) | max q(s',a') |
|--------|--------------|
| 40.222 | 39.889       |
| 40.556 | 40.222       |
| 40.889 | 40.556       |
| 41.222 | 40.889       |
| 41.556 | 41.222       |


This has overshot the true value of 10.0, and both values keep increasing without bound, each by about 1/3 per update. The initial neural network bias is caught in a positive feedback loop.
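
A short derivation under the toy rules from the list above (each update halves the TD error, and the linked estimate then moves halfway towards the updated value) shows why the values climb steadily instead of settling. Writing $$q_k$$ as shorthand for $$\hat{q}(s,a,\theta)$$, $$m_k$$ for $$\text{max}_{a'}\hat{q}(s',a',\theta)$$ after $$k$$ updates, and $$d_k = q_k - m_k$$:

$$q_{k+1} = q_k + \tfrac{1}{2}\left(1 + m_k - q_k\right), \qquad m_{k+1} = m_k + \tfrac{1}{2}\left(q_{k+1} - m_k\right)$$

$$d_{k+1} = \tfrac{1}{2}\left(q_{k+1} - m_k\right) = \tfrac{1}{4}\left(d_k + 1\right) \;\to\; d^* = \tfrac{1}{3}, \qquad q_{k+1} - q_k = \tfrac{1}{2}\left(1 - d_k\right) \;\to\; \tfrac{1}{3}$$

So the gap between the two estimates settles at about 1/3, while each estimate keeps rising by about 1/3 per update, which matches the steps of roughly 0.333 between rows in the table.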

Now, if we use the frozen target network, which always predicts 7.0 for $$\text{max}_{a'}\hat{q}(s',a',\bar{\theta})$$, what happens is what you would expect. After 100 iterations we have:

| q(s,a) | max q(s',a') |
|--------|--------------|
| 8.0    | 7.5          |
| 8.0    | 7.5          |
| 8.0    | 7.5          |
| 8.0    | 7.5          |
| 8.0    | 7.5          |


These values are still incorrect, but the updates have made a more conservative and stable step towards the correct values. Note the second value is what the learning network would predict for $$\text{max}_{a'}\hat{q}(s',a',\theta)$$, but we have used the other prediction $$\text{max}_{a'}\hat{q}(s',a',\bar{\theta})$$ on each step.
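
With the frozen target, the TD target is the constant $$1.0 + 7.0 = 8.0$$, so under the same 50% error-reduction assumption (writing $$q_k$$ as shorthand for $$\hat{q}(s,a,\theta)$$ after $$k$$ updates, starting from $$q_0 = 5.0$$) the updates contract towards it instead of feeding back into it:

$$q_{k+1} - 8 = \tfrac{1}{2}\left(q_k - 8\right) \quad\Rightarrow\quad q_k = 8 - 3\left(\tfrac{1}{2}\right)^k$$

After 100 updates this is 8.0 to any displayed precision.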

In reality the feedback loops are more complex than this, because this is function approximation, so action value estimates from different $$(s,a)$$ pairs interact in ways that are hard to predict. But the worry is that bias causes divergence in estimates, and this does happen in practice. It is more likely to happen in environments with long episodes, or where state loops are possible.

It is also worth noting that using the frozen target network has not solved the problem; it has just significantly throttled the runaway feedback. The amount of throttling required to keep learning stable will vary depending on the problem. The number of steps between target network updates is a hyperparameter. Set it too low and you risk seeing stability problems. Set it too high and learning will take longer.
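
The same toy model can illustrate that hyperparameter. In this sketch the target network is reduced to a single stored number, `frozen_max_next`, which is overwritten with the learning network's current estimate every `sync_every` updates (a hard update; `sync_every` and the other parameter names are hypothetical). `sync_every=1` is equivalent to having no target network, and a value at least as large as the number of updates keeps the target frozen throughout, as in the 7.0 example.

```python
def simulate_with_target_sync(sync_every, n_updates=100, q=5.0, max_next=7.0,
                              reward=1.0, lr_frac=0.5, link=0.5):
    """Toy model with a target network that is refreshed every `sync_every` updates."""
    frozen_max_next = max_next  # target network starts as a copy of the learning network
    for step in range(n_updates):
        if step % sync_every == 0:
            # Hard update: copy the learning network's estimate into the target network.
            frozen_max_next = max_next
        td_target = reward + frozen_max_next   # bootstrap from the frozen copy
        q += lr_frac * (td_target - q)         # remove lr_frac of the TD error
        max_next += link * (q - max_next)      # learning network still drifts via the link
    return q


for sync_every in (1, 10, 100):
    print(f"sync_every={sync_every:3d}: q(s,a) after 100 updates = "
          f"{simulate_with_target_sync(sync_every):.3f}")
```

In this toy, longer intervals slow the runaway growth but, unless the target is never refreshed, do not stop it, which is the throttling trade-off described above.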

> I've tried implementing DQNs and DDPGs on Cartpole, and I've found that the algorithms fail to converge when target networks are used, but work fine when those target networks are removed.

In that case, your implementations are incorrect. I have definitely observed DQN with a target network working on Cartpole many times.

Whether or not using a target network makes convergence faster or more stable is more complex, and it may be that, for your network design or hyperparameter choices, adding the target network is making performance worse.