
gumbel max trick

1 goal

We have a categorical random variable taking values in \(\{1,\dots,k\}\), with associated log-probabilities \(\{x_1,\dots,x_k\}\), and we want to draw samples from it.

2 usual way

Usually, we would exponentiate and then normalize: \[ \pi_i = \frac{\exp(x_i)}{\sum_{j=1}^{k} \exp(x_j)}, \quad i = 1,\dots,k \] and then sample \(y = i\) with probability \(\pi_i\). (This exponentiate-and-normalize step is exactly the softmax.)
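For comparison, here is a minimal NumPy sketch of the usual approach. The function names `softmax` and `sample_softmax` are illustrative only. Categories are 0-indexed (\(0,\dots,k-1\)) rather than \(1,\dots,k\), following Python convention.

```python
import numpy as np

def softmax(x):
    """Exponentiate and normalize log-probabilities x into probabilities pi."""
    x = np.asarray(x, dtype=float)
    # Subtracting the max leaves pi unchanged but avoids overflow in exp.
    e = np.exp(x - x.max())
    return e / e.sum()

def sample_softmax(x, rng):
    """Draw one category index in {0, ..., k-1} with probability pi_i."""
    pi = softmax(x)
    return rng.choice(len(pi), p=pi)

rng = np.random.default_rng(0)
x = np.log([0.1, 0.2, 0.7])   # log-probabilities x_1..x_3
print(softmax(x))              # approximately [0.1, 0.2, 0.7]
print(sample_softmax(x, rng))
```

The explicit normalization (the sum over all \(k\) categories) is the step the trick below avoids.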

3 trick

Instead, take \[ y = \arg\max_{i\in\{1,\dots,k\}} \left( x_i + z_i \right) \] where \(z_1,\dots,z_k \sim \text{Gumbel}(0,1)\) are drawn i.i.d. from the standard Gumbel distribution.
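The trick can be sketched in a few lines of NumPy. The function name `gumbel_max_sample` is hypothetical, and categories are again 0-indexed. The noise comes from NumPy's `Generator.gumbel`. An equivalent inverse-CDF construction, \(z = -\log(-\log u)\) with \(u \sim \text{Uniform}(0,1)\), is shown as a comment.

```python
import numpy as np

def gumbel_max_sample(x, rng):
    """Sample y = argmax_i (x_i + z_i) with z_i ~ Gumbel(0, 1) i.i.d."""
    x = np.asarray(x, dtype=float)
    # One independent standard Gumbel draw per category.
    z = rng.gumbel(loc=0.0, scale=1.0, size=x.shape)
    # Equivalently: z = -np.log(-np.log(rng.uniform(size=x.shape)))
    return int(np.argmax(x + z))

rng = np.random.default_rng(0)
x = np.log([0.1, 0.2, 0.7])
print(gumbel_max_sample(x, rng))
```

No `exp` and no sum over categories is needed. Adding a constant to every \(x_i\) does not change the argmax, so unnormalized log-probabilities work as-is.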

It turns out that \(y\) sampled this way has exactly the same distribution as \(y\) sampled according to the \(\pi_i\)'s, i.e. \(P(y = i) = \pi_i\), without ever exponentiating or normalizing!
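A quick Monte Carlo check of this claim follows. It draws many Gumbel-max samples at once (one row of noise per sample) and compares the empirical frequencies with the softmax probabilities. The log-probabilities and sample size are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(42)
x = np.array([1.0, 2.0, 0.5, -1.0])          # arbitrary unnormalized log-probabilities
pi = np.exp(x - x.max()); pi /= pi.sum()      # target distribution (softmax)

n = 200_000
z = rng.gumbel(size=(n, x.size))              # n rows of i.i.d. Gumbel(0,1) noise
y = np.argmax(x + z, axis=1)                  # one Gumbel-max sample per row
freq = np.bincount(y, minlength=x.size) / n   # empirical P(y = i)

print(np.round(pi, 3))
print(np.round(freq, 3))                      # should agree with pi to about 2 decimals
```

With \(n = 200{,}000\), the standard error of each frequency is at most about \(0.001\). Agreement to roughly two decimal places is therefore expected, though not exact.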

Proof here
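A short version of the standard argument follows. The standard Gumbel has CDF \(F(z) = \exp(-e^{-z})\) and density \(f(z) = e^{-z}\exp(-e^{-z})\). Category \(i\) wins exactly when \(x_j + z_j < x_i + z_i\) for every \(j \neq i\). Ties have probability zero. Conditioning on \(z_i = z\):
\[ P(y = i) = \int_{-\infty}^{\infty} f(z) \prod_{j \neq i} F(x_i - x_j + z)\, dz = \int_{-\infty}^{\infty} e^{-z} \exp\!\Big(-e^{-z} \sum_{j=1}^{k} e^{x_j - x_i}\Big)\, dz \]
The \(j = i\) term of the sum equals \(1\) and comes from \(f(z)\) itself. Writing \(S = \sum_{j} e^{x_j - x_i} = 1/\pi_i\) and substituting \(t = e^{-z}\):
\[ P(y = i) = \int_{0}^{\infty} e^{-S t}\, dt = \frac{1}{S} = \frac{\exp(x_i)}{\sum_{j=1}^{k} \exp(x_j)} = \pi_i \]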

4 Source

Created: 2021-09-14 Tue 21:44