I'm trying to train a variational auto-encoder on the CIFAR-10 dataset. I've put together a simple network in TensorFlow with 4 layers each in the encoder and the decoder, and an encoded vector of size 256. To calculate the latent loss, I make the encoder output log variances instead of standard deviations, so the latent loss function looks like:
latent_loss = -0.5 * tf.reduce_sum(1 + log_var_vector - tf.square(mean_vector) - tf.exp(log_var_vector), axis=1)
I found this formulation to be more stable than taking the logarithms directly in the KL-divergence formula, since the latter often results in an infinite loss value. I apply a sigmoid activation on the last layer of the decoder, and the generative loss is computed as the mean-squared error. The combined loss is simply the sum of the latent and generative losses. I train the network in batches of 40 using the Adam optimizer with a learning rate of 0.001.
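For reference, here is a minimal NumPy sketch of the loss setup described above, not my actual TensorFlow code. The function names `latent_loss` and `generative_loss` are placeholders, and the sketch assumes flattened 32x32x3 images and a mean-squared error averaged over all pixels of an image. It shows that the latent loss is exactly zero when the encoder outputs a zero mean and a zero log variance for every latent dimension.

```python
import numpy as np

def latent_loss(mean_vector, log_var_vector):
    # KL divergence between N(mean, exp(log_var)) and N(0, 1),
    # summed over the latent dimensions: one value per example.
    return -0.5 * np.sum(
        1.0 + log_var_vector - np.square(mean_vector) - np.exp(log_var_vector),
        axis=1,
    )

def generative_loss(x, x_reconstructed):
    # Assumption: squared error averaged over the pixels of each image.
    return np.mean(np.square(x - x_reconstructed), axis=1)

batch_size, latent_dim, n_pixels = 40, 256, 32 * 32 * 3

# Encoder output that matches the prior exactly: zero mean, unit variance.
mean_vector = np.zeros((batch_size, latent_dim))
log_var_vector = np.zeros((batch_size, latent_dim))
kl_at_prior = latent_loss(mean_vector, log_var_vector)

# Shifting every mean to 1 adds 0.5 per latent dimension.
kl_shifted = latent_loss(np.ones((batch_size, latent_dim)), log_var_vector)

# Combined loss as described: plain sum of both terms, averaged over the batch.
rng = np.random.default_rng(0)
x = rng.uniform(0.0, 1.0, size=(batch_size, n_pixels))
x_reconstructed = np.full((batch_size, n_pixels), 0.5)
combined = np.mean(kl_at_prior + generative_loss(x, x_reconstructed))
```

A latent loss of zero therefore means the encoder output equals the prior for every input, so the encoded vector carries no information about the image.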
The problem is that my network doesn't train. The latent loss immediately drops to zero, and the generative loss doesn't go down. However, when I optimize only for the generative loss, that loss does decrease as expected. Under this setting, the value of the latent loss quickly jumps to very large values (order of 10e4 to 10e6).
I have a hunch that the culprit is the extreme mismatch between the magnitudes of the two losses. The KL-divergence is unbounded, whereas the mean-squared error always stays below 1, so when I optimize for both, the generative loss basically becomes irrelevant.
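To make the mismatch concrete, here is a rough NumPy sketch with random stand-in data rather than real encoder or decoder outputs. It assumes flattened 32x32x3 images with values in [0, 1], and it compares the per-image mean-squared error with the per-image sum of squared errors and with the summed KL term. All variable names are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, n_pixels, latent_dim = 40, 32 * 32 * 3, 256

# Stand-ins for input images and sigmoid decoder outputs, both in [0, 1].
x = rng.uniform(0.0, 1.0, size=(batch_size, n_pixels))
x_hat = rng.uniform(0.0, 1.0, size=(batch_size, n_pixels))

squared_error = np.square(x - x_hat)
mse_per_image = squared_error.mean(axis=1)  # averaged over pixels, below 1
sse_per_image = squared_error.sum(axis=1)   # summed over pixels, up to n_pixels

# Stand-ins for encoder outputs; the KL term is summed over 256 dimensions.
mean_vector = rng.normal(size=(batch_size, latent_dim))
log_var_vector = rng.normal(size=(batch_size, latent_dim))
kl_per_image = -0.5 * np.sum(
    1.0 + log_var_vector - np.square(mean_vector) - np.exp(log_var_vector),
    axis=1,
)

print("mean MSE per image:", mse_per_image.mean())
print("mean summed squared error per image:", sse_per_image.mean())
print("mean KL per image:", kl_per_image.mean())
```

The averaged reconstruction term differs from the summed one by a factor of `n_pixels` (3072 here), while the KL term is a sum over all latent dimensions. Whether putting both terms on the same footing fixes the training is exactly what I'm unsure about.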
Any suggestions to solve the problem are welcome.