Variational Inference 101
Say you want to estimate a distribution of a latent vector $Z$ based on observations $X$. You can do that with variational inference!
Variational inference (VI) is a method that reduces posterior inference to optimisation. Why is it called that?
- Inference: we seek to say something about the latent variables $Z$
- Variational: we optimise over a set of distributions of latent variables $Z$
In theory, we can derive the posterior of $Z$ given observations of $X$, i.e. $p(Z|X)$, according to Bayes' theorem: \begin{equation} p(Z|X) = \dfrac{p(X|Z)p(Z)}{p(X)} \end{equation} where:
- $p(Z)$ is the prior distribution of latent variables
- $p(X|Z)$ is the likelihood of observing the data $X$ given a particular value of the latent variables $Z$
- $p(X)$ is the marginal distribution of $X$
However, using Bayes' rule out of the box is often problematic. The reason is that calculating the marginal distribution $p(X)$ is rarely tractable, since it requires marginalising out the latent space: $p(X) = \int_Z p(X, Z)dZ$. This is typically only possible when we know the form of the posterior distribution in advance, which allows us to find its normalising constant.
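To make the marginalisation concrete, here is a minimal sketch in Python (NumPy/SciPy assumed) for a toy model of my own choosing, where both the prior and the likelihood are one-dimensional Gaussians. In this special case the marginal is available in closed form, so a naive Monte Carlo estimate of the integral can be checked against it; for richer models no such closed form exists, which is exactly the problem VI addresses.

```python
# A toy illustration of the marginalisation p(X) = ∫ p(X|Z) p(Z) dZ.
# Assumed model (my choice for this sketch): Z ~ N(0, 1), X | Z ~ N(Z, 1),
# so the exact marginal is p(X) = N(X; 0, 2) and we can check the estimate.
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
x = 1.5                                   # a single observed data point

# Naive Monte Carlo estimate: draw Z from the prior, average the likelihood.
z_samples = rng.normal(loc=0.0, scale=1.0, size=100_000)   # Z ~ p(Z)
p_x_mc = norm.pdf(x, loc=z_samples, scale=1.0).mean()      # E_{Z~p(Z)}[p(x|Z)]

p_x_exact = norm.pdf(x, loc=0.0, scale=np.sqrt(2.0))       # closed form here
print(f"Monte Carlo p(x) ≈ {p_x_mc:.4f}, exact p(x) = {p_x_exact:.4f}")
```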
Variational inference proposes a solution to this problem: approximate the intractable $p(Z|X)$ with a tractable surrogate distribution $q(Z)$ so that the difference between $p(Z|X)$ and $q(Z)$ is minimal. A standard way to measure the dissimilarity between two distributions is the Kullback-Leibler (KL) divergence:
\[D_{\text{KL}}(q(Z)|| p(Z|X)) \equiv \mathbb{E}_{Z \sim q}\left[ \log\dfrac{q(Z)}{p(Z|X)} \right]\]Great, but we do not know $p(Z|X)$ – we actually want to approximate it with $q(Z)$! So how do we calculate the KL divergence? We need to notice that we can decompose the KL divergence into a more convenient form by using the fact that $p(Z|X) = \frac{p(Z,X)}{p(X)}$:
\[
\begin{aligned}
D_{\text{KL}}(q(Z)\,||\, p(Z|X)) &= \mathbb{E}_{Z \sim q}\left[ \log\dfrac{q(Z)p(X)}{p(Z,X)} \right] && \text{(substitute $p(Z|X) = p(Z,X)/p(X)$)} \\
&= \mathbb{E}_{Z \sim q}\left[ \log\dfrac{q(Z)}{p(Z,X)} \right] + \mathbb{E}_{Z \sim q}\left[ \log p(X)\right] && \text{(log of a product)} \\
&= \mathbb{E}_{Z \sim q}\left[ \log\dfrac{q(Z)}{p(Z,X)} \right] + \log p(X) && \text{($p(X)$ is constant w.r.t. $Z$, so the expectation drops)} \\
&= \log p(X) - \mathbb{E}_{Z \sim q}\left[ \log\dfrac{p(X,Z)}{q(Z)} \right] && \text{(log of a ratio)}
\end{aligned}
\]
Note that \(\log p(X) \geq \mathbb{E}_{Z \sim q}\left[ \log\dfrac{p(X,Z)}{q(Z)} \right]\) for any $X$, $q(\cdot)$ and $p(\cdot)$, since $D_{\text{KL}}(q(Z)|| p(Z|X)) \geq 0$ (the KL divergence is always non-negative, even though it is not a true distance). Let us call $\log p(X)$ the “evidence”. Then $\mathbb{E}_{Z \sim q}\left[ \log\dfrac{p(X,Z)}{q(Z)} \right]$ is a lower bound on the evidence $\log p(X)$.
Accordingly, we define the Evidence Lower Bound (ELBO) as:
\[\text{ELBO}(q) \equiv \mathbb{E}_{Z \sim q}\left[ \log\dfrac{p(X,Z)}{q(Z)} \right]\]And now we can also rewrite the Kullback-Leibler divergence as a function of ELBO: \(D_{\text{KL}}(q(Z)|| p(Z|X)) = \log p(X) - \text{ELBO}(q)\)
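As a numerical sanity check of this identity, here is a minimal Python sketch (NumPy/SciPy assumed; the conjugate Gaussian toy model and the particular $q$ are my own choices, not part of the derivation). It estimates the ELBO by Monte Carlo with samples from $q$ and compares $\log p(X) - \text{ELBO}(q)$ against the exact KL divergence between $q(Z)$ and the known posterior.

```python
# Monte Carlo estimate of the ELBO and a check of the identity
# D_KL(q(Z) || p(Z|X)) = log p(X) - ELBO(q) on a toy conjugate model.
# Assumed model (my choice): Z ~ N(0, 1), X | Z ~ N(Z, 1), observed x.
# Then p(Z|x) = N(x/2, 1/2) and log p(x) = log N(x; 0, 2) are known exactly.
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
x = 1.5

# An arbitrary surrogate q(Z) = N(mu_q, sigma_q^2).
mu_q, sigma_q = 0.3, 0.9

# ELBO(q) = E_{Z~q}[ log p(x, Z) - log q(Z) ], estimated with samples from q.
z = rng.normal(mu_q, sigma_q, size=200_000)
log_joint = norm.logpdf(x, loc=z, scale=1.0) + norm.logpdf(z, loc=0.0, scale=1.0)
log_q = norm.logpdf(z, loc=mu_q, scale=sigma_q)
elbo = np.mean(log_joint - log_q)

# Exact quantities for the check.
log_evidence = norm.logpdf(x, loc=0.0, scale=np.sqrt(2.0))
mu_p, sigma_p = x / 2.0, np.sqrt(0.5)                      # exact posterior
kl_exact = (np.log(sigma_p / sigma_q)
            + (sigma_q**2 + (mu_q - mu_p)**2) / (2 * sigma_p**2) - 0.5)

print(f"log p(x) - ELBO ≈ {log_evidence - elbo:.4f}")
print(f"exact KL(q || p(Z|x)) = {kl_exact:.4f}")
```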
Notice that $\log p(X)$ (i.e. the evidence) is independent of $q(Z)$: changing the surrogate distribution for $Z$ does not change the marginal probability of the observed data. Hence, the following optimisation problems are equivalent:
\[q = \arg\min_{q \in \mathbb{Q}} D_{\text{KL}}(q(Z)|| p(Z|X))\]and
\[q = \arg\max_{q \in \mathbb{Q}} \text{ELBO}(q)\](since $\text{ELBO}(q)$ enters the KL decomposition with a negative sign). Because the latter problem does not require knowledge of the posterior distribution $p(Z|X)$, we can solve it to find the optimal surrogate $q(Z)$, which approximates $p(Z|X)$!
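Below is a minimal sketch of what that optimisation can look like in practice, again for the toy Gaussian model used above (an assumption of the example, not a general recipe). Because everything is Gaussian, the ELBO has a closed form and we can hand its negative to a generic optimiser; crucially, the objective never references $p(Z|X)$, yet the optimal $q$ recovers the exact posterior.

```python
# Maximising the ELBO over a Gaussian family q(Z) = N(mu, sigma^2) for the
# assumed toy model: Z ~ N(0, 1), X | Z ~ N(Z, 1), observed x.
# The ELBO is available in closed form here, so a generic optimiser suffices.
import numpy as np
from scipy.optimize import minimize

x = 1.5

def negative_elbo(params):
    mu, log_sigma = params
    sigma2 = np.exp(2.0 * log_sigma)
    # E_q[log p(x|Z)] + E_q[log p(Z)] + entropy of q, all in closed form.
    e_log_lik = -0.5 * np.log(2 * np.pi) - 0.5 * ((x - mu) ** 2 + sigma2)
    e_log_prior = -0.5 * np.log(2 * np.pi) - 0.5 * (mu ** 2 + sigma2)
    entropy_q = 0.5 * np.log(2 * np.pi * np.e * sigma2)
    return -(e_log_lik + e_log_prior + entropy_q)

result = minimize(negative_elbo, x0=np.array([0.0, 0.0]))
mu_opt, sigma_opt = result.x[0], np.exp(result.x[1])
print(f"optimal q:       N({mu_opt:.3f}, {sigma_opt**2:.3f})")
print(f"exact posterior: N({x/2:.3f}, {0.5:.3f})")
```

In realistic models the ELBO has no closed form; it is instead estimated with Monte Carlo samples from $q$ (as in the previous snippet) and maximised with stochastic gradients, but the logic is the same.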
In the next posts, we will discuss more advanced topics in VI, its challenges, and its applications.