Gumbel-Softmax trick vs Softmax with temperature

14

4

From what I understand, the Gumbel-Softmax trick is a technique that enables us to sample discrete random variables, in a way that is differentiable (and therefore suited for end-to-end deep learning).

Many papers and articles describe it as a way of selecting instances in the input (i.e. 'pointers') without using the non-differentiable argmax-function. The thing that confuses me is that this effect can be achieved without randomness by just using Softmax with temperature:

Softmax with temperature $$y_i=\frac{\exp(\frac{x_i}{\tau})}{\sum_{j}\exp(\frac{x_j}{\tau})}$$

Gumbel-Softmax $$y_i=\frac{\exp(\frac{\log(\pi_i)+g_i}{\tau})}{\sum_{j}\exp(\frac{\log(\pi_j)+g_j}{\tau})}$$
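To make the difference between the two formulas concrete, here is a minimal sketch in plain Python. It assumes the Gumbel noise $g_i$ is drawn as $-\log(-\log u)$ with $u$ uniform on (0, 1); the function names and the example probabilities are made up for illustration.

```python
import math
import random

def softmax_with_temperature(logits, tau):
    """Deterministic softmax of logits / tau."""
    scaled = [x / tau for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def sample_gumbel(rng):
    """Draw one Gumbel(0, 1) variate by inverse transform sampling."""
    u = rng.random()
    while u == 0.0:  # random() can return 0.0; log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_softmax(log_pi, tau, rng):
    """Softmax of (log_pi + Gumbel noise) / tau; a fresh random draw per call."""
    noisy = [lp + sample_gumbel(rng) for lp in log_pi]
    return softmax_with_temperature(noisy, tau)

rng = random.Random(0)
pi = [0.6, 0.3, 0.1]                      # made-up class probabilities
log_pi = [math.log(p) for p in pi]

det_a = softmax_with_temperature(log_pi, tau=0.5)
det_b = softmax_with_temperature(log_pi, tau=0.5)
rnd_a = gumbel_softmax(log_pi, tau=0.5, rng=rng)
rnd_b = gumbel_softmax(log_pi, tau=0.5, rng=rng)

print("softmax with temperature:", det_a, det_b)  # identical on every call
print("Gumbel-Softmax:", rnd_a, rnd_b)            # differs between calls
```

The only difference between the two functions is the added noise term, so the first is a fixed function of its input while the second is a random variable.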

My question

From a practical and theoretical perspective, when is it beneficial to incorporate Gumbel noise into a neural network, as opposed to just using Softmax with temperature?

A couple of observations:

  1. When the temperature is low, both Softmax with temperature and the Gumbel-Softmax functions will approximate a one-hot vector. However, before convergence, the Gumbel-Softmax may more suddenly 'change' its decision because of the noise.
  2. When the temperature is higher, the Gumbel noise will get a larger significance and the distribution will become more uniform. Why is this desired?

My best guess is that the introduction of the Gumbel noise enforces stronger exploration before convergence, but I can't recall reading any papers that use this as a motivation to bring in the extra randomness.

Does anyone have any experience or insights on this? Maybe I've completely missed the key point of Gumbel-Softmax :)

4-bit

Posted 2019-08-29T10:30:50.857

Reputation: 141

Answers

5

Let's say you have two states, $X_1$ and $X_2$, and you have a model, $M$, that produces a score $M(X_i)$ for each state (i.e., the logits). Next you can use the logits to compute some distribution

$$P = softmax(\{M(X_1), M(X_2)\})$$

and take the state with the highest probability

$$X=argmax_{X_i}(P)$$

But what if you actually want to sample from $P$ instead of just taking the argmax, and you want the sampling operation to be differentiable? This is where the Gumbel trick comes in: instead of the softmax, you compute

\begin{equation} \ X = argmax_{X_i}(\{M(X_i)+Z_i\}) \end{equation}

where the $Z_i$ are i.i.d. Gumbel(0,1) random variables. It turns out that $X$ equals $X_1$ with probability $P(X_1)$ and $X_2$ with probability $P(X_2)$. In other words, the equation above samples from $P$.
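You can check this claim empirically. The sketch below (plain Python, with made-up logits and a hypothetical helper `gumbel_max_sample`) adds Gumbel(0,1) noise to the logits, takes the argmax many times, and compares the observed frequencies with the softmax probabilities.

```python
import math
import random
from collections import Counter

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_gumbel(rng):
    """Draw one Gumbel(0, 1) variate by inverse transform sampling."""
    u = rng.random()
    while u == 0.0:  # random() can return 0.0; log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_max_sample(logits, rng):
    """Index of the largest logit after adding independent Gumbel noise."""
    noisy = [x + sample_gumbel(rng) for x in logits]
    return max(range(len(noisy)), key=lambda i: noisy[i])

rng = random.Random(0)
logits = [1.0, 0.0]  # M(X_1), M(X_2): made-up scores
n_draws = 20000

counts = Counter(gumbel_max_sample(logits, rng) for _ in range(n_draws))
empirical = [counts[i] / n_draws for i in range(len(logits))]
target = softmax(logits)

print("empirical frequencies:", empirical)
print("softmax probabilities:", target)
```

With enough draws the two printed vectors should agree to within sampling error.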

But it is still not differentiable because of the argmax operation. So instead we compute the Gumbel-Softmax distribution, which replaces the argmax with a softmax at temperature $\tau$. If the temperature is low enough, the Gumbel-Softmax produces something very close to a one-hot vector, where the probability of the predicted label is close to 1 and the other labels have probabilities close to 0. So, for example, if the Gumbel-Softmax gave the highest probability to $X_1$, you can do:

\begin{equation} X = \sum_{i} P_g(X_i) \cdot X_i \approx 1 \cdot X_1 + 0 \cdot X_2 = X_1 \end{equation}

where $P_g$ is the Gumbel-Softmax operation. No argmax is needed! So with this trick we can sample from a discrete distribution in a differentiable way.
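Here is a rough sketch of that weighted sum in plain Python. The state vectors, the logits, and the temperature value are all invented for illustration; at a finite temperature the weights are only approximately one-hot, so the result is only approximately equal to the selected state.

```python
import math
import random

def sample_gumbel(rng):
    """Draw one Gumbel(0, 1) variate by inverse transform sampling."""
    u = rng.random()
    while u == 0.0:  # random() can return 0.0; log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_softmax(logits, tau, rng):
    """Softmax of (logits + Gumbel noise) / tau."""
    scaled = [(x + sample_gumbel(rng)) / tau for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

rng = random.Random(0)
logits = [1.0, 0.0]            # M(X_1), M(X_2): made-up scores
states = [[1.0, 0.0, 2.0],     # hypothetical vector representing X_1
          [0.0, 3.0, -1.0]]    # hypothetical vector representing X_2

weights = gumbel_softmax(logits, tau=0.001, rng=rng)

# Weighted sum over the states: only additions and multiplications,
# so a framework with autodiff could backpropagate through it.
selected = [sum(w * s[k] for w, s in zip(weights, states)) for k in range(3)]

winner = max(range(len(weights)), key=lambda i: weights[i])
print("weights:", weights)
print("weighted sum:", selected)
print("state with the largest weight:", states[winner])
```

Running it with different seeds picks different states, with $X_1$ chosen more often because it has the larger logit.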

Asi sheffer

Posted 2019-08-29T10:30:50.857

Reputation: 51

3

From a practical and theoretical perspective, when is it beneficial to incorporate Gumbel noise into a neural network, as opposed to just using Softmax with temperature?

You don't necessarily need Gumbel-Softmax to obtain "one-hot like" vectors, or the ability to differentiate through an indexing mechanism.

The LSTM architecture and derived variants are examples of this. They model "forget/input" gates using sigmoid outputs, which are deterministic. A "true" gating mechanism would be either 0 or 1, but to make things differentiable, LSTMs relax that constraint to sigmoid outputs. Notice that there are no "random" inputs here, and you can still apply the straight-through trick to make the gates truly discrete (while backpropagating a biased gradient).

Gumbel-Softmax can be used wherever you would consider using a non-stochastic indexing mechanism (it is a more general formulation). But it's especially useful when you want to backpropagate through random samples of discrete variables.

  • VAE with a Gumbel-Softmax or Categorical posterior (encoder) distribution. Notably, you cannot simply use a deterministic softmax here because it would turn your VAE into a standard autoencoder. Autoencoders lack a way to generate new samples from the prior.
  • Actor-Critic architecture with a Gumbel-Softmax or Categorical actor (most policy gradient implementations assume you can re-parameterize the gradients from the critic through the actor without using a score function estimator to estimate the black-box gradient). You cannot simply substitute the deterministic softmax here, because there is a type mismatch: the critic takes as input an action $a \in A$, while the softmax represents the conditional policy distribution $\pi(a|s)$.
  • The "probabilistic" interpretation of a non-random quantization such as an LSTM would essentially be mode-seeking behavior in fitting a density. You have a loss function that takes in categorical decisions $c$, so the expected loss $\mathbb{E}_c[f(c)]$ is minimized by learning some distribution $p(c)$. Quantizing a softmax without sampling the Gumbel noise (e.g. just using a sigmoid or softmax) is akin to choosing the same $c$ every time. For some $f$ this is okay, and for other $f$ this is highly suboptimal (consider the categorical KL divergence as a loss).
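To illustrate the last point, here is a toy sketch in plain Python. The distribution `p` and the per-category losses `f` are invented numbers. It compares the loss you see when you always pick the mode with a Monte Carlo estimate of $\mathbb{E}_c[f(c)]$ obtained by sampling $c$ with Gumbel noise.

```python
import math
import random

def sample_gumbel(rng):
    """Draw one Gumbel(0, 1) variate by inverse transform sampling."""
    u = rng.random()
    while u == 0.0:  # random() can return 0.0; log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))

def sample_category(log_p, rng):
    """Sample an index from p via argmax of log-probabilities plus Gumbel noise."""
    noisy = [lp + sample_gumbel(rng) for lp in log_p]
    return max(range(len(noisy)), key=lambda i: noisy[i])

rng = random.Random(0)
p = [0.5, 0.3, 0.2]      # made-up categorical distribution p(c)
f = [1.0, 4.0, 10.0]     # made-up loss f(c) for each category
log_p = [math.log(q) for q in p]

# Exact expected loss under p(c).
exact_expected_loss = sum(q * loss for q, loss in zip(p, f))

# Deterministic quantization: always choose the most likely category.
mode = max(range(len(p)), key=lambda i: p[i])
deterministic_loss = f[mode]

# Stochastic version: average the loss over sampled categories.
n_draws = 20000
sampled_loss = sum(f[sample_category(log_p, rng)] for _ in range(n_draws)) / n_draws

print("exact E[f(c)]:", exact_expected_loss)
print("loss at the mode only:", deterministic_loss)
print("Monte Carlo estimate:", sampled_loss)
```

The deterministic choice only ever sees $f$ at the mode, while the sampled estimate tracks the expectation over all categories.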

My best guess is that the introduction of the Gumbel noise enforces stronger exploration before convergence, but I can't recall reading any papers that use this as a motivation to bring in the extra randomness.

This is an interesting idea, but there are many ways to inject "exploration" noise into the set of parameters you use in a function approximator.

ejang

Posted 2019-08-29T10:30:50.857

Reputation: 131

1

For the softmax function, no matter what the temperature is, the output is never an exact one-hot vector. If you can accept a soft version, that is fine. However, if you take the argmax to get a hard one-hot vector, the operation is non-differentiable. One alternative is to back-propagate the gradients directly through it with the Straight-Through Estimator (STE) [1] trick, as in [2], but the resulting gradient is an inexact approximation.

The advantage of Gumbel-Softmax [3] is that it samples a one-hot vector according to the currently learned distribution $\pi$: the sample is one-hot, it is differentiable, and each one-hot vector is drawn with the probability that $\pi$ assigns to it.

For your second question: at the beginning, the distribution $\pi$ does not encode any prior knowledge, so we want to sample one-hot vectors close to uniformly (at this stage, the noise matters), and the distribution will gradually converge to the desired, slightly sharper distribution. As you train for more epochs and enough knowledge of the distribution has been learned, gradually decrease the temperature $\tau$ so that $\pi$ converges to a discrete distribution. As you gradually decrease $\tau$, the effect of the noise becomes smaller.
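A gradual decrease of the temperature could look like the sketch below. The exponential-decay form, the floor, and all constants are assumptions chosen for illustration, not values taken from any of the references; in practice they would be tuned per task.

```python
import math

def annealed_tau(step, tau_start=1.0, tau_min=0.5, rate=1e-3):
    """Exponentially decay the temperature with the training step, with a floor.

    All default values are placeholders for illustration.
    """
    return max(tau_min, tau_start * math.exp(-rate * step))

# Temperature at a few training steps.
schedule = [(step, annealed_tau(step)) for step in (0, 100, 500, 1000, 5000)]
for step, tau in schedule:
    print(step, round(tau, 4))
```

The floor keeps $\tau$ from reaching values so small that the gradients of the softmax become numerically unstable.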

PS: The following sentence is incorrect:

When the temperature is low, both Softmax with temperature and the Gumbel-Softmax functions will approximate a one-hot vector.

Gumbel-Softmax can sample an exact one-hot vector rather than an approximation. You can read the PyTorch code at [4].
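As far as I can tell, the `hard` option in that code gets the exact one-hot output with a straight-through construction. Writing $y$ for the soft Gumbel-Softmax sample and $\operatorname{stopgrad}$ for an operation that passes its value through but blocks gradients (my notation, not the library's), it amounts to:

$$y_{\text{hard}} = \operatorname{one\_hot}\big(\arg\max_i y_i\big), \qquad y_{\text{out}} = y_{\text{hard}} - \operatorname{stopgrad}(y) + y$$

In the forward pass the two $y$ terms cancel, so $y_{\text{out}} = y_{\text{hard}}$ is exactly one-hot. In the backward pass only the last term carries a gradient, so the gradient is that of the soft sample $y$, which makes it a biased estimate for the hard sample.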

[1] Binaryconnect: Training deep neural networks with binary weights during propagations

[2] LegoNet: Efficient Convolutional Neural Networks with Lego Filters

[3] Categorical Reparameterization with Gumbel-Softmax

[4] https://github.com/pytorch/pytorch/blob/15f9fe1d92a5d1e86278ae25f92dd9677b4956dc/torch/nn/functional.py#L1237

zhaohui

Posted 2019-08-29T10:30:50.857

Reputation: 11

The hard sampling gradient is a smart trick! – Shaohua Li – 2020-07-21T12:23:32.977