Revisiting Variational Inference for Statisticians
*Variational Inference: A Review for Statisticians* is perhaps the go-to paper for learning variational inference (VI). After all, the paper has over 2800 citations, indicating its popularity in the community. I recently decided to reread the paper while trying to closely follow the derivations. In this blog post, I'll extend the derivations for the Gaussian mixture model of the paper in the hope of elucidating some of the steps over which the authors went quickly.
Published
27 February 2022
Blei et al. illustrate coordinate ascent variational inference (CAVI) using a simple Gaussian mixture model (Blei et al., 2017). The model¹ places a prior on the mean of each component while keeping the variance of the likelihood fixed.
In the following, we will derive the joint probability and CAVI update equations for the model. Finally, we use these equations to implement the model in Python.
Constructing the log joint
We start by defining the components of the model. Note that we can write the probability of the prior component means as
$$p(\boldsymbol{\mu}) = \prod_{k=1}^{K} \mathcal{N}(\mu_k \mid 0, \sigma^2).$$
Similarly, the prior for the latent variables $\mathbf{z}_n$ may be expressed as
$$p(\mathbf{z}_n) = \prod_{k=1}^{K} \left(\frac{1}{K}\right)^{z_{nk}},$$
while the likelihood is given by
$$p(x_n \mid \boldsymbol{\mu}, \mathbf{z}_n) = \prod_{k=1}^{K} \mathcal{N}(x_n \mid \mu_k, 1)^{z_{nk}}.$$
We now introduce the variables $\mathbf{X} = \{x_n\}_{n=1}^{N}$ and $\mathbf{Z} = \{\mathbf{z}_n\}_{n=1}^{N}$ to denote the complete data set. Note that $p(\mathbf{Z})$ and $p(\mathbf{X}\mid\boldsymbol{\mu}, \mathbf{Z})$ are simply

$$p(\mathbf{Z}) = \prod_{n=1}^{N} p(\mathbf{z}_n) \quad \text{and} \quad p(\mathbf{X} \mid \boldsymbol{\mu}, \mathbf{Z}) = \prod_{n=1}^{N} p(x_n \mid \boldsymbol{\mu}, \mathbf{z}_n),$$

so that the log joint of the model reads

$$\log p(\mathbf{X}, \mathbf{Z}, \boldsymbol{\mu}) = \log p(\boldsymbol{\mu}) + \sum_{n=1}^{N} \left[ \log p(\mathbf{z}_n) + \log p(x_n \mid \boldsymbol{\mu}, \mathbf{z}_n) \right]. \tag{1}$$
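To make the generative process concrete, here is a minimal simulation sketch in Python. The function name `simulate`, the use of numpy and the argument names are my own choices, not from the paper.

```python
import numpy as np

def simulate(N, K, sigma2, rng=None):
    """Draw a dataset from the Gaussian mixture model defined above."""
    rng = np.random.default_rng(rng)
    mu = rng.normal(0.0, np.sqrt(sigma2), size=K)   # mu_k ~ N(0, sigma^2)
    z = rng.integers(0, K, size=N)                  # z_n uniform over the K components
    x = rng.normal(mu[z], 1.0)                      # x_n ~ N(mu_{z_n}, 1)
    return x, z, mu
```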
The variational density for the mixture assignments
To obtain the (log) variational distribution of $\mathbf{z}_n$, we simply take the expectation of the log joint $(1)$ with respect to all other variables of the model. In our simple Gaussian mixture this corresponds to $q(\mu_k)$, as the component means are the only other latent variables:

$$\begin{aligned}
\log q^{*}(\mathbf{z}_n) &= \mathbb{E}_{q(\boldsymbol{\mu})}\left[\log p(\mathbf{z}_n) + \log p(x_n \mid \boldsymbol{\mu}, \mathbf{z}_n)\right] + \text{const} \\
&= \sum_k z_{nk} \left( \log \tfrac{1}{K} + \mathbb{E}_{q(\mu_k)}\left[\log \mathcal{N}(x_n \mid \mu_k, 1)\right] \right) + \text{const} \\
&= \sum_k z_{nk} \left( \mathbb{E}_{q(\mu_k)}[\mu_k]\, x_n - \tfrac{1}{2}\,\mathbb{E}_{q(\mu_k)}[\mu_k^2] \right) + \text{const}. \tag{2}
\end{aligned}$$
Here I have dropped terms that are constant in $z_{nk}$ (only the terms involving expectations w.r.t. $q(\mu_k)$ remain). Let's take a closer look at the last line of $(2)$: exponentiating $\log q^{*}(\mathbf{z}_n)$ reveals that it has the form of a multinomial distribution,
$$q^{*}(\mathbf{z}_n) \propto \prod_{k} \rho_{nk}^{z_{nk}},$$
thus, in order to normalise the distribution, we require that the variational parameter $\rho_{nk}$ represents a probability. We therefore define

$$r_{nk} = \frac{\rho_{nk}}{\sum_{j=1}^{K} \rho_{nj}}, \qquad \rho_{nk} = \exp\left\{ \mathbb{E}_{q(\mu_k)}[\mu_k]\, x_n - \tfrac{1}{2}\,\mathbb{E}_{q(\mu_k)}[\mu_k^2] \right\}, \tag{3}$$

which are the parameters of the multinomial variational distribution $q^{*}(\mathbf{z}_n)$.
The variational density for the component means
We proceed analogously for $\mu_k$ and take the expectation of the log joint with respect to all other variables, i.e. the mixture assignments:

$$\begin{aligned}
\log q^{*}(\mu_k) &= \mathbb{E}_{q(\mathbf{Z})}\left[\log p(\mu_k) + \sum_n \log p(x_n \mid \boldsymbol{\mu}, \mathbf{z}_n)\right] + \text{const} \\
&= -\frac{\mu_k^2}{2\sigma^2} + \sum_n \mathbb{E}_{q(\mathbf{z}_n)}[z_{nk}] \left( x_n \mu_k - \frac{\mu_k^2}{2} \right) + \text{const} \\
&= \left( \sum_n \mathbb{E}_{q(\mathbf{z}_n)}[z_{nk}]\, x_n \right) \mu_k - \left( \frac{\sum_n \mathbb{E}_{q(\mathbf{z}_n)}[z_{nk}]}{2} + \frac{1}{2\sigma^2} \right) \mu_k^2 + \text{const}.
\end{aligned}$$

The last line of the derivation suggests that the variational distribution of $\mu_k$ is Gaussian with natural parameter

$$\boldsymbol{\eta} = \left[ \sum_n \mathbb{E}_{q(\mathbf{z}_n)}[z_{nk}]\, x_n,\; -\left( \frac{\sum_n \mathbb{E}_{q(\mathbf{z}_n)}[z_{nk}]}{2} + \frac{1}{2\sigma^2} \right) \right]$$

and sufficient statistic $t(\mu_k) = [\mu_k, \mu_k^2]$. Using standard exponential family results (Blei, 2016), we find that the posterior mean and variance are given by

$$m_k = \frac{\sum_n \mathbb{E}_{q(\mathbf{z}_n)}[z_{nk}]\, x_n}{\frac{1}{\sigma^2} + \sum_n \mathbb{E}_{q(\mathbf{z}_n)}[z_{nk}]}, \qquad s_k^2 = \frac{1}{\frac{1}{\sigma^2} + \sum_n \mathbb{E}_{q(\mathbf{z}_n)}[z_{nk}]}.$$
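Assuming the responsibilities $r_{nk}$ (which, as shown below, equal $\mathbb{E}_{q(\mathbf{z}_n)}[z_{nk}]$) are available as an array `r`, a small numpy sketch of this update could look as follows; the function name `update_mu_factors` is mine.

```python
import numpy as np

def update_mu_factors(x, r, sigma2):
    """Update the Gaussian variational factors q(mu_k) given responsibilities r of shape (N, K)."""
    denom = 1.0 / sigma2 + r.sum(axis=0)          # 1/sigma^2 + sum_n E[z_nk]
    m = (r * x[:, None]).sum(axis=0) / denom      # posterior means m_k
    s2 = 1.0 / denom                              # posterior variances s_k^2
    return m, s2
```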
Although we have derived the parameters of our variational distributions, we cannot work with these results yet, as all of them contain unresolved expectations. However, we can leverage the form of our variational distributions, i.e. $\mathbf{z}_n$ and $\mu_k$ are respectively multinomially and normally distributed. For example, to resolve the expectation of $z_{nk}$, we use $(3)$ to determine

$$\mathbb{E}_{q(\mathbf{z}_n)}[z_{nk}] = r_{nk}.$$
It is easy to see that $\mathbb{E}_{q(\mu_k)}[\mu_k] = m_k$. To determine the second moment of $\mu_k$, which is also required to compute $r_{nk}$, we make use of the standard identity $\mathrm{Var}[\mu_k] = \mathbb{E}[\mu_k^2] - \mathbb{E}[\mu_k]^2$ and obtain
$$\mathbb{E}_{q(\mu_k)}[\mu_k^2] = m_k^2 + s_k^2.$$
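Putting the pieces together, the responsibilities of $(3)$ can now be computed from the current $m_k$ and $s_k^2$. Below is a small numpy sketch; the function name and the log-space normalisation for numerical stability are my additions.

```python
import numpy as np

def update_responsibilities(x, m, s2):
    """Compute r_{nk} from the current variational factors q(mu_k)."""
    # log rho_{nk} = E[mu_k] x_n - E[mu_k^2] / 2 = m_k x_n - (m_k^2 + s_k^2) / 2
    log_rho = np.outer(x, m) - 0.5 * (m**2 + s2)
    log_rho -= log_rho.max(axis=1, keepdims=True)    # stabilise before exponentiating
    r = np.exp(log_rho)
    return r / r.sum(axis=1, keepdims=True)          # each row now sums to one
```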
Implementing the model
With these equations in hand we can easily implement the model.
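The following is a minimal sketch of the CAVI loop in numpy, alternating between the two updates derived above until a fixed number of iterations is reached. The function name, the random initialisation, the prior variance `sigma2 = 10.0` in the usage example and the iteration count are my own choices, not taken from the paper.

```python
import numpy as np

def cavi_gmm(x, K, sigma2, n_iter=100, rng=None):
    """Coordinate ascent VI for the univariate Gaussian mixture with unit likelihood
    variance and a N(0, sigma2) prior on the component means."""
    rng = np.random.default_rng(rng)
    m = rng.normal(0.0, 1.0, size=K)      # variational means of q(mu_k)
    s2 = np.ones(K)                       # variational variances of q(mu_k)
    for _ in range(n_iter):
        # update the responsibilities q(z_n)
        log_rho = np.outer(x, m) - 0.5 * (m**2 + s2)
        log_rho -= log_rho.max(axis=1, keepdims=True)
        r = np.exp(log_rho)
        r /= r.sum(axis=1, keepdims=True)
        # update the Gaussian factors q(mu_k)
        denom = 1.0 / sigma2 + r.sum(axis=0)
        m = (r * x[:, None]).sum(axis=0) / denom
        s2 = 1.0 / denom
    return m, s2, r

# Usage on simulated data, roughly matching the setting described below
# (N = 100 observations, true means -4, 0 and 9, equal mixture weights).
rng = np.random.default_rng(0)
true_mu = np.array([-4.0, 0.0, 9.0])
z = rng.integers(0, 3, size=100)
x = rng.normal(true_mu[z], 1.0)
m, s2, r = cavi_gmm(x, K=3, sigma2=10.0, n_iter=50, rng=1)
print(np.sort(m))   # the variational means should lie close to the true component means
```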
The following plot illustrates a fit of the model to simulated data with $N=100$, $\mu=[-4, 0, 9]$ and equal mixture component probabilities.
¹ Note that I have slightly altered the notation of the paper, using $\mathbf{z}$ instead of $\mathbf{c}$ and $n$ instead of $i$.
Blei, D. M., Kucukelbir, A., & McAuliffe, J. D. (2017). Variational Inference: A Review for Statisticians. Journal of the American Statistical Association, 112(518), 859–877. https://doi.org/10.1080/01621459.2017.1285773
Blei, D. M. (2016). The Exponential Family.
Bishop, C. M. (2006). Pattern Recognition and Machine Learning.