yes, but did it work? evaluating variational inference

This post is about the paper Yes, but Did It Work?: Evaluating Variational Inference by Yuling Yao, Aki Vehtari, Daniel Simpson, Andrew Gelman, which will appear at ICML 2018. I'm going to try to summarise/explain the paper in my own words, largely for my own benefit. I'm also going to do this without writing any mathematical formulae, because I don't remember how to do LaTeX with my website, and I don't feel like shaving that particular yak right now.

After the accepted ICML papers were announced, I went through the list hunting for relevant work. I've decided it's a better use of my time to read papers that have been accepted somewhere, rather than drowning under the firehose of my arXiv RSS feed. This paper ticked two boxes: variational inference, and knowing if it worked. It also ticked a third, secret box of "titles that make it sound like the paper will have been written in a casual, conversational style, eschewing the tradition of appearing smarter by obfuscating the point".

Single-line summary: they describe two diagnostics for evaluating the variational posterior, with different properties and use-cases.

So let's talk about these two diagnostics.

Pareto Smoothed Importance Sampling (PSIS)

At this point I realised the author overlap between this paper and Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC, which in turn builds on Pareto Smoothed Importance Sampling.

So what is PSIS and how is it useful for evaluating VI?

Importance sampling is a technique that lets us estimate expectations under a distribution which is difficult to sample from (the target distribution), using an approximating proposal distribution which is easy to sample from. You sample from the proposal distribution, then weight those samples by the ratio of the target density to the proposal density (evaluated at each sample point). These weights are called importance ratios.
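
To make that concrete, here's a minimal sketch of (self-normalised) importance sampling in Python. The target and proposal are a toy choice of my own, not anything from the paper:

```python
import numpy as np
from scipy import stats

# Toy example: target is a standard normal, proposal is a wider Student-t
# that is easy to sample from and has heavier tails than the target.
target = stats.norm(0.0, 1.0)
proposal = stats.t(df=5, loc=0.0, scale=2.0)

n = 10_000
z = proposal.rvs(size=n, random_state=0)

# Importance ratios: target density over proposal density at each sample point.
ratios = target.pdf(z) / proposal.pdf(z)

# Self-normalised estimate of E[z^2] under the target, using proposal samples.
estimate = np.sum(ratios * z**2) / np.sum(ratios)
print(estimate)  # should land close to 1, the second moment of a standard normal
```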

Pareto smoothing comes in because, when the proposal distribution is a poor fit to the target distribution, these weights can have very high variance. The proposal distribution is the denominator in the importance ratio, so if you imagine that this distribution is a lot thinner than the target distribution - that is, it's near zero in regions where the target distribution is not - you can end up with some very large importance ratios, hence high variance. This means that you would need a lot of samples to estimate the expectation of interest. Pareto smoothing is a way to control this variance. It builds on the idea of simply truncating the importance ratios (Truncated Importance Sampling, Ionides 2008) by instead fitting a Pareto distribution to them.
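
For reference, the truncation idea on its own is only a couple of lines. This is my own sketch; the threshold of the mean ratio times sqrt(n) is, as far as I remember, the one suggested by Ionides, but treat the specifics as approximate:

```python
import numpy as np

def truncated_ratios(ratios):
    """Truncated importance sampling: cap the raw importance ratios at a
    threshold (here the mean ratio times sqrt(n)) to control their variance,
    at the price of some bias."""
    ratios = np.asarray(ratios, dtype=float)
    threshold = ratios.mean() * np.sqrt(len(ratios))
    return np.minimum(ratios, threshold)
```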

Side note: Part of the motivation of using the Pareto distribution at this point, I think, is to use its fitted parameters to do diagnostics on the proposal distribution. This is exactly what "Yes, But Did It Work?" is doing, but they already talk about it in the original PSIS paper, so I guess part of the novelty of this ICML paper is bringing it explicitly to the VI area. More about VI when I'm done with this Pareto stuff.

So how does fitting a Pareto distribution to the importance ratios help? In practice, you fit the Pareto distribution, and then instead of simply truncating the top M importance ratios (M is chosen empirically/arbitrarily) you replace them using the inverse cumulative distribution function of the Pareto distribution you fit. This replacement operates on the ranks of these importance ratios (the smallest of the M, the second-smallest and so on), replacing each with the corresponding quantile of the fitted Pareto distribution. This reminds me of rank-based inverse-normal transformations I've seen used in genetics (weirdly difficult to find papers about this, here is an R vignette). They argue that this produces an IS estimate that is less biased than what you get using truncated IS. Moreover, you can inspect the parameters of the fitted Pareto distribution to do diagnostics.
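
Here's roughly what that replacement looks like in code. This is my own sketch, not the authors' implementation (the loo package and its ports do the fitting more carefully), and the tail-size heuristic and quantile scheme are just my reading of the PSIS paper:

```python
import numpy as np
from scipy import stats

def pareto_smooth(ratios, M=None):
    """Sketch of Pareto smoothing: fit a generalised Pareto distribution to the
    M largest importance ratios, then replace those ratios, in rank order, with
    quantiles of the fitted distribution. Returns the smoothed ratios and the
    estimated shape parameter k."""
    ratios = np.asarray(ratios, dtype=float)
    n = len(ratios)
    if M is None:
        M = int(min(0.2 * n, 3 * np.sqrt(n)))   # tail-size heuristic from the PSIS paper
    order = np.argsort(ratios)
    tail_idx = order[-M:]                        # indices of the M largest ratios
    cutoff = ratios[order[-M - 1]]               # largest ratio left out of the tail

    # Fit a generalised Pareto to the exceedances over the cutoff (scipy's `c` is k).
    k, _, scale = stats.genpareto.fit(ratios[tail_idx] - cutoff, floc=0.0)

    # Replace the tail ratios, smallest to largest, with quantiles of the fit.
    quantiles = (np.arange(1, M + 1) - 0.5) / M
    smoothed = ratios.copy()
    smoothed[tail_idx] = cutoff + stats.genpareto.ppf(quantiles, k, loc=0.0, scale=scale)
    return smoothed, k
```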

The reason they use a Pareto distribution to model the top M importance ratios is because It Is Known. Rather, it is shown in Pickands, 1975 to be an appropriate choice. To be specific, they use a generalised Pareto distribution. This distribution has three parameters (location, scale, shape), and it has finite moments only up to order 1/k, where k is the shape parameter. That means that if k > 0.5, the variance of the importance ratios is infinite, but if k < 1 at least the mean exists. They point to 0.5 < k < 0.7 as a regime where the importance sampling procedure exhibits a practically useful convergence rate. Side note: I don't quite see where the jump from modelling the tail of the importance ratios to drawing conclusions about all the importance ratios happened. I suppose if you observe that your tail has finite variance, then you must have finite variance in the rest of the values, but I would have expected an additional step to extend the conclusions made about the fit of the Pareto distribution to the rest of the importance ratios.
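
A quick way to convince yourself of the moment claim is to simulate (this is my own sanity check, not something from the paper):

```python
import numpy as np
from scipy import stats

# Sanity check of the moment claim: a generalised Pareto with shape k only has
# finite moments up to order 1/k.
for k in [0.3, 0.7, 1.2]:
    draws = stats.genpareto.rvs(c=k, size=200_000, random_state=0)
    print(k, draws.mean(), draws.var())
# With k = 0.3 both sample moments settle down; with k = 0.7 the sample variance
# is huge and jumps around if you change the seed (the true variance is infinite);
# with k = 1.2 even the sample mean misbehaves (the true mean is infinite).
```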

Now, relating this back to variational inference is straightforward: replace "proposal distribution" with "variational posterior" (the target distribution is the true posterior). PSIS, via the shape parameter of the fitted Pareto distribution, gives us a diagnostic for how well the variational posterior matches the true posterior.

But wait... don't we need the true posterior to calculate the importance ratios? Isn't this circular? The answer is that you can use the joint distribution (p(z, x) rather than p(z|x)) because the estimate of k is invariant to a constant multiplicative factor, which will be p(x).

The diagnostic approach is thus (a rough code sketch follows the list):

  1. Run VI, get variational distribution q(z) approximating p(z|x).
  2. Sample a bunch of zs from q(z)
  3. Calculate p(z, x) for all the zs (remember, x is known - it is a specific dataset), and get the importance ratios p(z, x)/q(z)
  4. Fit a generalised Pareto distribution to the largest M importance ratios
  5. Report the (estimated) shape parameter k
  6. If k > 0.7, the VI approximation is not reliable
  7. If k < 0.5, the VI approximation is good, and PSIS can additionally be used to calculate further divergence measures
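
Stitching the list above into code, this is roughly what the diagnostic looks like. The three functions passed in are hypothetical stand-ins for whatever your VI framework exposes; the tail fit is the same rough sketch as before, and in practice you'd let the loo package (or its Python port) estimate k for you:

```python
import numpy as np
from scipy import stats

def psis_khat(log_joint, q_sample, q_log_prob, n_draws=5_000):
    """Rough sketch of the PSIS diagnostic for a fitted variational posterior.

    log_joint(z)  evaluates log p(z, x) at the observed data x,
    q_sample(n)   draws n samples from the variational posterior q(z),
    q_log_prob(z) evaluates log q(z).
    All three are placeholders for your VI framework's API."""
    z = q_sample(n_draws)
    log_ratios = log_joint(z) - q_log_prob(z)   # log importance ratios
    log_ratios -= log_ratios.max()              # constant shift: doesn't change k
    ratios = np.sort(np.exp(log_ratios))

    # Fit a generalised Pareto to the exceedances in the tail (scipy's `c` is k).
    M = int(min(0.2 * n_draws, 3 * np.sqrt(n_draws)))
    cutoff = ratios[-M - 1]
    k_hat, _, _ = stats.genpareto.fit(ratios[-M:] - cutoff, floc=0.0)

    if k_hat > 0.7:
        verdict = "unreliable: don't trust the variational approximation"
    elif k_hat > 0.5:
        verdict = "usable with caution"
    else:
        verdict = "good: PSIS can also be used to compute further divergences"
    return k_hat, verdict
```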

They touch on two other points in this paper, regarding PSIS:

  1. The shape parameter k is invariant under reparametrisation, but reparametrisation can influence the VI procedure and produce better/worse proposal distributions. So looking at k can help guide reparametrisation efforts
  2. Marginal PSIS diagnostics are not useful. These marginal diagnostics would follow the above procedure, but instead of sampling full zs, sampling only one dimension at a time. Compared to the PSIS diagnostic evaluated on the joint distribution, these marginal ks are never larger (and are usually smaller) than k, and can be misleading. Also, this means you need access to the marginal distribution p(z_i, x) (or p(z_i | x)) to get the importance ratios, which may be unavailable. So don't do it.

Variational Simulation-Based Calibration Diagnostic (VSBC)

The PSIS diagnostic looks at the full approximate posterior. However, sometimes you don't need to properly approximate the full posterior, and can get away with producing useful point estimates. VSBC evaluates the quality of point estimates. It is based on Validation of Software for Bayesian Models Using Posterior Quantiles (Cook et al., 2006).

The key observation from (Cook et al., 2006) is going to be fun to explain with no proper equations. Let's try: suppose we have access to p(z), p(x|z) and p(z|x) (this last will be replaced by the variational approximation shortly). We simulate an x by first sampling a z from p(z), then sampling x from p(x|z). Now, we sample multiple z's from p(z|x). We can ask what fraction of those sampled z's are smaller than the original z - we call this the calibration probability. Now, if we were to do this multiple times (picking a z, then an x, then multiple resampled z's) we would get a distribution of calibration probabilities. And that distribution should be uniform. I think this is Cook's observation.
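
To see the uniformity claim in action, here's a toy conjugate example of my own (normal prior, normal likelihood) where the exact posterior is available in closed form, so the calibration probabilities can be computed directly:

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(1)

# Toy conjugate model: z ~ N(0, 1), x_i | z ~ N(z, 1) for 5 observations,
# so the exact posterior p(z | x) is N(sum(x) / (n + 1), 1 / (n + 1)).
n_replications, n_obs, n_post_draws = 2_000, 5, 100
calibration_probs = []
for _ in range(n_replications):
    z_true = rng.normal(0.0, 1.0)
    x = rng.normal(z_true, 1.0, size=n_obs)
    post_mean = x.sum() / (n_obs + 1)
    post_sd = np.sqrt(1.0 / (n_obs + 1))
    z_draws = rng.normal(post_mean, post_sd, size=n_post_draws)
    # Calibration probability: fraction of posterior draws below the original z.
    calibration_probs.append(np.mean(z_draws < z_true))

# Cook's observation: these probabilities should look uniform on [0, 1].
print(stats.kstest(calibration_probs, "uniform"))
```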

So to relate this to VI, we can perform the above procedure, replacing the true posterior p(z|x) with the approximate posterior q(z). (This means we have to do a full VI step for each dataset x we sample!) We could then in principle ask how far the distribution of calibration probabilities deviates from uniform, but in this paper they suggest (following on from other literature) to instead measure how asymmetric this distribution is.

Thus, the VSBC diagnostic is to test for asymmetry in the distribution of calibration probabilities. They do this using a Kolmogorov-Smirnov test between the distribution of probabilities and one minus that distribution. More specifically, they focus on marginal calibration probabilities - so where I said 'z' above, imagine one dimension of z at a time. This is necessary because z < z' only makes sense for scalars.
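
In code, the asymmetry check itself is tiny once you have the marginal calibration probabilities (which are the expensive part, since each one requires a full VI fit on a simulated dataset). This is my reading of the test as described above, not necessarily the authors' exact formulation:

```python
import numpy as np
from scipy import stats

def vsbc_asymmetry_test(calibration_probs):
    """Given the marginal calibration probabilities for one dimension of z
    (one per simulated dataset / VI fit), test for asymmetry around 0.5 by
    comparing the distribution of the probabilities against one minus the
    probabilities. A small p-value flags a biased point estimate."""
    p = np.asarray(calibration_probs, dtype=float)
    return stats.ks_2samp(p, 1.0 - p)
```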

So running the diagnostic means running VI multiple times over simulated datasets. If your generative model of the data is poor, this diagnostic won't tell you much about how your VI scheme will work on real data; and since VSBC measures average performance over simulated datasets, it doesn't say much about any given instance of real data either. An advantage of VSBC over PSIS is that it looks at marginals, so you can potentially identify which dimensions of z are problematic during fitting.

Applications, etc.

Given these two diagnostics, they then show how they can be used in a couple of different settings - Bayesian linear regression, logistic regression, a hierarchical model (the famous Eight-School model), and a cancer classification application. In all cases, they use mean-field Gaussian automatic differentiation variational inference.

The big question for me, and probably a lot of other users of variational inference, is how well these can be applied to the types of posteriors we try to approximate using hideous neural networks. VSBC may be computationally impractical because it puts the whole VI procedure inside an inner loop, although it's easily parallelisable. High-dimensional posteriors are problematic for importance sampling and thus PSIS, although I don't know what "high" is - 10, 100? 1000?? Multimodality in the posterior is also a challenge, as they point out in the discussion - the VI approximation could completely miss a mode, but the PSIS diagnostic would nonetheless indicate all is well. They suggest using PSIS to evaluate some other divergence (such as a KL divergence) to diagnose this case.

In summary, this has been a post about evaluating variational inference using two diagnostics - Pareto-smoothed importance sampling, and variational simulation-based calibration. At its core this paper feels like an application of previous/existing work to a slightly new domain (variational inference). I'm curious to try these diagnostics on my own variational posteriors. Code is seemingly available (maybe just for PSIS) - R package (loo), and also a Python/Matlab port.