Config Enumerate in Pyro
Pyro is a powerful probabilistic programming language that allows us to define complex statistical models and perform inference with them. The library has become widely used in our lab because it supports stochastic variational inference, which lets statistical models scale to large datasets. In this post, I’ll take a closer look at Pyro’s enumeration strategy for discrete latent variables and illustrate this feature with a simple model.
Let us consider a standard textbook problem (this one is in fact from David MacKay’s superb book Information Theory, Inference, and Learning Algorithms): we blindly pick an urn from a set of ten urns, each containing $10$ balls. Urn $u$ ($u = 0, \dots, 9$) contains $u$ black balls and $10-u$ white balls. We then draw $N$ times with replacement from the chosen urn, obtaining $n_B$ black and $N-n_B$ white balls. After drawing $N=10$ times, we ask ourselves which urn we have drawn from.
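By Bayes’ theorem, with a uniform prior $P(u) = \frac{1}{10}$ and a binomial likelihood, the posterior probability distribution is

$$P(u \mid n_B, N) = \frac{P(n_B \mid u, N)\,P(u)}{P(n_B \mid N)} = \frac{f_u^{n_B}(1-f_u)^{N-n_B}}{\sum_{u'=0}^{9} f_{u'}^{n_B}(1-f_{u'})^{N-n_B}}, \qquad f_u = \frac{u}{10},$$

which we could easily evaluate analytically; here, however, we use Pyro instead. We start by defining a function that lets us simulate the experiment described above.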
The function first samples the urn $u$ uniformly from a Categorical distribution, since each category $u=0,\dots,9$ has probability $\frac{1}{10}$. It then draws from a Binomial distribution with $10$ trials and success probability $\frac{u}{10}$, i.e. it counts the black balls in $10$ draws with replacement. The function returns the true urn, whose identity we want to infer with a Pyro model, together with the actual draws (the observations). The argument $n$ specifies how many times the experiment is performed.
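A minimal sketch of such a simulation function, written in plain Python rather than with Pyro/PyTorch distributions, might look as follows. The name `simulate`, the optional `rng` argument and the assumption that one urn is shared by all $n$ repetitions are illustrative choices, not the original code.

```python
import random

NUM_URNS = 10       # urns u = 0, ..., 9
BALLS_PER_URN = 10  # urn u holds u black and 10 - u white balls


def simulate(n, num_draws=10, rng=None):
    """Hypothetical stdlib sketch of the urn experiment.

    Picks one urn uniformly at random, then repeats the experiment n times:
    each repetition draws `num_draws` balls with replacement and records
    how many of them are black.
    """
    if rng is None:
        rng = random.Random()
    urn = rng.randrange(NUM_URNS)      # uniform Categorical over the urns
    p_black = urn / BALLS_PER_URN      # chance of drawing a black ball
    counts = [
        # One Binomial(num_draws, p_black) sample, built from Bernoulli trials.
        sum(rng.random() < p_black for _ in range(num_draws))
        for _ in range(n)
    ]
    return urn, counts


# Perform the experiment 10 times with a seeded generator for reproducibility.
true_urn, counts = simulate(10, rng=random.Random(0))
print(true_urn, counts)
```

Passing a seeded `random.Random` instance keeps runs reproducible without touching the global random state.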
Defining a statistical model in Pyro means writing a model function that, in some sense, “reverse engineers” the stochastic process of interest. For this reason, our model looks very similar to the function we defined to simulate our data. Let us go through it line by line:
The model first defines a prior over the urn probabilities u, in this case a Dirichlet distribution, the conjugate prior of the Categorical distribution. Setting the concentration of the Dirichlet to a vector of ones gives a flat distribution over all probability vectors, so a priori no urn u is favoured. The next statement, with pyro.plate('…'), opens a so-called plate, a context that declares the observations conditionally independent and enables vectorised computation. Within this context we sample the chosen urn from a Categorical(u) distribution. Finally, the program evaluates the likelihood of the observations y given the sampled urn (note that the observations are passed via the obs keyword argument of the sample statement).
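The generative story that this model encodes can be sketched in plain Python by sampling forward through it. This is not Pyro code: the function name `sample_from_model` and its parameters are hypothetical, and the Dirichlet draw uses the standard construction of normalising independent Gamma(1, 1) variables.

```python
import random


def sample_from_model(n, num_urns=10, num_draws=10, balls_per_urn=10, rng=None):
    """Hypothetical stdlib sketch of the model's generative story.

    Returns the urn probabilities u, the urn chosen for each observation,
    and the simulated black-ball counts y.
    """
    if rng is None:
        rng = random.Random()
    # u ~ Dirichlet(1, ..., 1): normalising i.i.d. Gamma(1, 1) draws gives a
    # point distributed uniformly over the probability simplex.
    gammas = [rng.gammavariate(1.0, 1.0) for _ in range(num_urns)]
    total = sum(gammas)
    u = [g / total for g in gammas]

    urns, y = [], []
    # "Plate": given u, the n observations are conditionally independent.
    for _ in range(n):
        urn = rng.choices(range(num_urns), weights=u)[0]  # urn ~ Categorical(u)
        p_black = urn / balls_per_urn                      # fraction of black balls
        # y_i ~ Binomial(num_draws, p_black), built from Bernoulli trials.
        y.append(sum(rng.random() < p_black for _ in range(num_draws)))
        urns.append(urn)
    return u, urns, y
```

In the Pyro model the last step is not simulated: the Binomial sample statement is conditioned on the observed counts through obs, and the urn of each observation becomes the latent variable to be handled during inference.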
Optimising the model
We will use stochastic variational inference (SVI) and set up the code accordingly. To keep things simple, we use the AutoDiagonalNormal guide from Pyro’s autoguide module, which places a Normal distribution with diagonal covariance over all latent variables (in an unconstrained space). The Trace_ELBO loss estimates the evidence lower bound (ELBO) from execution traces of the model and the guide, and finally we use the Adam optimiser to perform the optimisation.
However, executing this code fails with a NotImplementedError: Cannot transform _IntegerInterval constraints exception, so what went wrong here? The error is due to the Categorical distribution, whose support is discrete: the guide tries to map every latent variable to an unconstrained real space, which is impossible for the integer-valued urn. To make the model work, we have to explicitly tell Pyro to enumerate out this variable during training, i.e. to sum over all of its possible values instead of sampling it. Enumeration can happen sequentially or in parallel; the latter is usually faster because all values are evaluated at once with vectorised tensor operations.
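Enumerating out the urn replaces its sampled value with an exact sum over all of its possible values, weighted by their probabilities under u. The stdlib sketch below computes the resulting marginal log-likelihood $\log p(y \mid u)$ for one fixed u, looping over the urns as sequential enumeration would. It is a conceptual illustration of what gets summed, not Pyro’s implementation, and all function names are hypothetical. Parallel enumeration computes the same sum, but with the loop over k replaced by an extra broadcast tensor dimension.

```python
import math


def binomial_log_pmf(k, n, p):
    """log Binomial(k | n, p), with -inf for counts that are impossible."""
    if p == 0.0:
        return 0.0 if k == 0 else -math.inf
    if p == 1.0:
        return 0.0 if k == n else -math.inf
    return math.log(math.comb(n, k)) + k * math.log(p) + (n - k) * math.log1p(-p)


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_likelihood_urn_summed_out(y, u, num_draws=10, balls_per_urn=10):
    """log p(y | u) with each observation's discrete urn enumerated out:

        sum_i log sum_k u[k] * Binomial(y_i | num_draws, k / balls_per_urn)
    """
    total = 0.0
    for y_i in y:
        # Sequential enumeration: loop over every value k the urn can take.
        terms = [
            math.log(u_k) + binomial_log_pmf(y_i, num_draws, k / balls_per_urn)
            for k, u_k in enumerate(u)
            if u_k > 0.0  # urns with zero probability contribute nothing
        ]
        total += logsumexp(terms)
    return total


uniform = [0.1] * 10
print(log_likelihood_urn_summed_out([6, 7, 5], uniform))
```

Because the urn is summed out exactly, no variational distribution is needed for it; only the continuous u remains for the guide.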
The simplest way to enable enumeration is to decorate our model with @config_enumerate.
This tells Pyro to enumerate all discrete sample sites in the model. Next, we need to keep the enumerated variable out of the guide; in other words, we have to tell the guide for which variables it should provide variational distributions. Here we have two possibilities: we can either hide the Categorical “urn” site or expose all other variables (“u”), using pyro.poutine.block in both cases.
Finally, we have to modify the loss function: rather than Trace_ELBO, we use TraceEnum_ELBO, which supports enumeration over discrete sample sites.
Let us now try out the code. We start by performing the hypothetical urn experiment 10 times: we select a random urn and then, 10 times over, draw 10 balls with replacement and count the number of black balls.
After training the model for 1000 iterations we find that the ELBO has converged.
Obtaining the posterior distribution of u requires a few additional lines of code, which draw 5000 samples from the fitted posterior distribution. To figure out which urn we have most likely drawn from, we plot the resulting distribution, which indicates that urn 6 is the most likely urn.
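Because this problem has a closed-form answer, the Pyro result can be cross-checked against the exact posterior obtained from Bayes’ theorem with a binomial likelihood and a uniform prior over urns. The stdlib sketch below assumes, as in the simulation, that all experiments used the same urn. The Pyro model instead places a Dirichlet prior over urn probabilities and lets each observation pick its own urn, so the two posteriors answer closely related but not identical questions. The function names are hypothetical.

```python
import math


def log_binomial_kernel(n_b, n, f):
    """log f^n_b (1 - f)^(n - n_b); the binomial coefficient is dropped
    because it is the same for every urn and cancels on normalisation."""
    if f == 0.0:
        return 0.0 if n_b == 0 else -math.inf
    if f == 1.0:
        return 0.0 if n_b == n else -math.inf
    return n_b * math.log(f) + (n - n_b) * math.log1p(-f)


def exact_urn_posterior(counts, num_urns=10, num_draws=10, balls_per_urn=10):
    """Exact P(u | counts) under a uniform prior (which also cancels),
    assuming every experiment used the same urn u."""
    log_post = [
        sum(log_binomial_kernel(n_b, num_draws, u / balls_per_urn) for n_b in counts)
        for u in range(num_urns)
    ]
    m = max(log_post)                      # subtract the max for stability
    weights = [math.exp(lp - m) for lp in log_post]
    z = sum(weights)
    return [w / z for w in weights]


# A single experiment with 6 black balls out of 10 draws.
posterior = exact_urn_posterior([6])
most_likely_urn = max(range(len(posterior)), key=posterior.__getitem__)
print(most_likely_urn, [round(p, 3) for p in posterior])
```

Applied to the counts from the simulated experiment, this gives a reference distribution against which the histogram of posterior samples can be compared.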