Variational Inference
The goal of variational inference is to approximate the conditional density of latent variables (which I will call hidden variables) given the observed variables. Instead of estimating this density directly, we look for its best approximation, the member with the smallest KL divergence to it from a family of candidate densities $\mathscr{L}$.

Problem of approximate inference

Let $\mathbf{x}=\{x_i\}^{n}_{i=1}$ be a set of observed variables and $\mathbf{z}=\{z_j\}^{m}_{j=1}$ be a set of hidden variables, with joint density $p(\mathbf{x}, \mathbf{z})$.
  • The inference problem is to compute the conditional density of the hidden variables given the observations, i.e. $p(\mathbf{z}|\mathbf{x})$.
We can write the conditional density as

$$p(\mathbf{z}|\mathbf{x}) = \frac{p(\mathbf{z},\mathbf{x})}{p(\mathbf{x})} = \frac{p(\mathbf{z},\mathbf{x})}{\int p(\mathbf{z},\mathbf{x}) \,\text{d}\mathbf{z}}$$

The denominator is the marginal density $p(\mathbf{x})$ of the observations, obtained by integrating the joint density over the hidden variables; it is also called the evidence. In general, computing this integral is hard. We now turn to a practical example, the Bayesian mixture of Gaussians.

Bayesian mixture of Gaussians

Consider a mixture of $K$ unit-variance (variance equal to 1) univariate (single-variable) Gaussians. The mean of the $k$-th Gaussian is $\mu_k$, and we collect the means as $\boldsymbol{\mu}=\{\mu_1, \dots, \mu_K\}$. Each mean parameter is drawn from a common prior distribution $p(\mu_k)$, which we assume to be $p(\mu_k)=\mathcal{N}(0,\sigma^2)$. To generate an observation $x_i$ from the model, we first draw a cluster assignment $\mathbf{c}_i=[0,\dots,1,\dots,0]$ (a one-hot indicator vector with a 1 at the position of the chosen component) from a uniform Categorical distribution; this indicates which mixture component $x_i$ comes from. We then draw $x_i$ from that component,

$$x_i \sim \mathcal{N}(\mathbf{c}_i^\top \boldsymbol{\mu},\, 1)$$
The full model is
$$\begin{aligned}
\mu_k &\sim \mathcal{N}(0,\sigma^2), \quad k=1,\dots,K\\
\mathbf{c}_i &\sim \text{Categorical}\!\left(\tfrac{1}{K}, \dots, \tfrac{1}{K}\right), \quad i=1,\dots,n\\
p(x_i \mid \mathbf{c}_i, \boldsymbol{\mu}) &= \mathcal{N}(\mathbf{c}_i^\top\boldsymbol{\mu},\, 1)
\end{aligned}$$
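As a concrete illustration, here is a minimal sketch (my own addition, not from the original text) of this generative process in Python with NumPy. The values of `K`, `n`, and `sigma` are arbitrary choices for the example, and cluster assignments are stored as integer indices rather than one-hot vectors.

```python
import numpy as np

rng = np.random.default_rng(0)
K, n, sigma = 3, 100, 5.0              # illustrative choices, not from the text

mu = rng.normal(0.0, sigma, size=K)    # mu_k ~ N(0, sigma^2), k = 1, ..., K
c = rng.integers(0, K, size=n)         # c_i ~ Categorical(1/K, ..., 1/K), stored as an index
x = rng.normal(mu[c], 1.0)             # x_i ~ N(mu_{c_i}, 1)
```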
The joint density of the hidden and observed variables is

$$p(\boldsymbol{\mu}, \mathbf{c}, \mathbf{x}) = p(\boldsymbol{\mu}) \prod_{i=1}^n p(\mathbf{c}_i, x_i \mid \boldsymbol{\mu}) = p(\boldsymbol{\mu}) \prod_{i=1}^n p(\mathbf{c}_i)\, p(x_i \mid \mathbf{c}_i, \boldsymbol{\mu})$$

Note that the prior $p(\boldsymbol{\mu})$ appears once rather than inside the product, and each factor splits as $p(\mathbf{c}_i, x_i \mid \boldsymbol{\mu}) = p(\mathbf{c}_i)\, p(x_i \mid \mathbf{c}_i, \boldsymbol{\mu})$ because the assignment $\mathbf{c}_i$ is drawn independently of $\boldsymbol{\mu}$.
Given the observed $\mathbf{x}$, our hidden variables are $\mathbf{z} = \{\boldsymbol{\mu}, \mathbf{c}\}$. Hence the evidence integral is

$$p(\mathbf{x}) = \int p(\boldsymbol{\mu}) \prod_{i=1}^n \sum_{\mathbf{c}_i} p(\mathbf{c}_i)\, p(x_i \mid \mathbf{c}_i, \boldsymbol{\mu}) \,\text{d}\boldsymbol{\mu}$$

This integral does not reduce to a product of one-dimensional integrals over the $\mu_k$, and expanding the product of sums instead yields $K^n$ terms, one per assignment configuration, so evaluating the evidence exactly is intractable for realistic $n$.
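To make the intractability concrete, the sketch below (again my own illustration, not from the original text) computes the evidence exactly for a tiny data set by enumerating all $K^n$ assignments; conditioned on an assignment the model is jointly Gaussian, so $\boldsymbol{\mu}$ can be integrated out in closed form. The loop has $K^n$ iterations, which is exactly what makes this hopeless for realistic $n$.

```python
from itertools import product

import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(0)
K, n, sigma = 2, 5, 5.0              # tiny on purpose: the loop below runs K**n times
x = rng.normal(0.0, 1.0, size=n)     # any small data vector will do for the demonstration

evidence = 0.0
for assignment in product(range(K), repeat=n):
    Z = np.zeros((n, K))
    Z[np.arange(n), assignment] = 1.0   # one-hot assignment matrix (rows are the c_i)
    # Given the assignments, x = Z mu + eps with mu ~ N(0, sigma^2 I) and eps ~ N(0, I),
    # so p(x | c) is a zero-mean Gaussian with covariance sigma^2 Z Z^T + I.
    cov = sigma**2 * Z @ Z.T + np.eye(n)
    evidence += (1.0 / K)**n * multivariate_normal.pdf(x, mean=np.zeros(n), cov=cov)

print("p(x) =", evidence)
```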

The evidence lower bound (ELBO)

In variational inference, we specify a family $\mathscr{L}$ of densities over the hidden variables. Our goal is to find the member $q(\mathbf{z}) \in \mathscr{L}$ with the smallest KL divergence to the posterior density. Inference then becomes an optimization problem:

$$q^*(\mathbf{z})=\arg\min_{q(\mathbf{z}) \in \mathscr{L}} \text{KL}\big(q(\mathbf{z})\,\|\,p(\mathbf{z}|\mathbf{x})\big)$$
$q^*(\mathbf{z})$ is the best approximation of $p(\mathbf{z}|\mathbf{x})$ within the family. By the definition of KL divergence,

$$\begin{aligned}
\text{KL}\big(q(\mathbf{z})\,\|\,p(\mathbf{z}|\mathbf{x})\big) &= \mathbb{E}[\log q(\mathbf{z})] - \mathbb{E}[\log p(\mathbf{z}|\mathbf{x})]\\
&= \mathbb{E}[\log q(\mathbf{z})] - \mathbb{E}\left[\log \frac{p(\mathbf{z},\mathbf{x})}{p(\mathbf{x})}\right]\\
&= \mathbb{E}[\log q(\mathbf{z})] - \mathbb{E}[\log p(\mathbf{z},\mathbf{x})] + \mathbb{E}[\log p(\mathbf{x})]
\end{aligned}$$
Because all expectations are taken with respect to $q(\mathbf{z})$, we have $\mathbb{E}[\log p(\mathbf{x})] = \log p(\mathbf{x})$. So computing this KL divergence again requires the evidence $p(\mathbf{x})$, which is intractable.
Instead of computing the KL divergence directly, we optimize an alternative objective that equals the negative KL divergence plus a constant:

$$\text{ELBO}(q) = \mathbb{E}[\log p(\mathbf{z},\mathbf{x})] - \mathbb{E}[\log q(\mathbf{z})]$$

This function is called the evidence lower bound (ELBO). From the expansion above, $\text{ELBO}(q) = -\text{KL}\big(q(\mathbf{z})\,\|\,p(\mathbf{z}|\mathbf{x})\big) + \log p(\mathbf{x})$. Since $\log p(\mathbf{x})$ is a constant with respect to $q(\mathbf{z})$, maximizing the ELBO is equivalent to minimizing the KL divergence; and since the KL divergence is non-negative, $\text{ELBO}(q) \le \log p(\mathbf{x})$, which is where the name comes from.
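As a side note (not spelled out above), the lower-bound property can also be obtained directly from Jensen's inequality, using the concavity of the logarithm and assuming $q(\mathbf{z})$ puts mass wherever $p(\mathbf{z},\mathbf{x})$ does:

$$\log p(\mathbf{x}) = \log \int p(\mathbf{z},\mathbf{x}) \,\text{d}\mathbf{z} = \log \mathbb{E}_{q}\!\left[\frac{p(\mathbf{z},\mathbf{x})}{q(\mathbf{z})}\right] \ge \mathbb{E}_{q}\!\left[\log \frac{p(\mathbf{z},\mathbf{x})}{q(\mathbf{z})}\right] = \text{ELBO}(q)$$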
We can also rewrite the ELBO as the expected log-likelihood of the data minus the KL divergence between $q(\mathbf{z})$ and the prior $p(\mathbf{z})$:

$$\begin{aligned}
\text{ELBO}(q) &= \mathbb{E}[\log p(\mathbf{z},\mathbf{x})] - \mathbb{E}[\log q(\mathbf{z})]\\
&= \mathbb{E}[\log p(\mathbf{x}|\mathbf{z})\,p(\mathbf{z})] - \mathbb{E}[\log q(\mathbf{z})]\\
&= \mathbb{E}[\log p(\mathbf{x}|\mathbf{z})] + \mathbb{E}[\log p(\mathbf{z})] - \mathbb{E}[\log q(\mathbf{z})]\\
&= \mathbb{E}[\log p(\mathbf{x}|\mathbf{z})] - \text{KL}\big(q(\mathbf{z})\,\|\,p(\mathbf{z})\big)
\end{aligned}$$
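As a quick sanity check (my own illustration, not from the original text), the following sketch verifies on a toy model with a single binary hidden variable that the two expressions for the ELBO agree, and that the ELBO equals $\log p(\mathbf{x})$ minus the KL divergence to the exact posterior:

```python
import numpy as np

p_z = np.array([0.3, 0.7])             # prior p(z) over a binary hidden variable
p_x_given_z = np.array([0.9, 0.2])     # likelihood p(x | z) for one fixed observation x
q_z = np.array([0.5, 0.5])             # an arbitrary variational density q(z)

p_joint = p_z * p_x_given_z            # p(z, x)
p_x = p_joint.sum()                    # evidence p(x)
p_post = p_joint / p_x                 # exact posterior p(z | x)

elbo_a = np.sum(q_z * (np.log(p_joint) - np.log(q_z)))       # E[log p(z,x)] - E[log q(z)]
elbo_b = np.sum(q_z * np.log(p_x_given_z)) \
         - np.sum(q_z * (np.log(q_z) - np.log(p_z)))         # E[log p(x|z)] - KL(q || p(z))
kl_posterior = np.sum(q_z * (np.log(q_z) - np.log(p_post)))  # KL(q || p(z|x))

assert np.isclose(elbo_a, elbo_b)
assert np.isclose(elbo_a, np.log(p_x) - kl_posterior)
```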

What does ELBO mean?

  1. $\mathbb{E}[\log p(\mathbf{x}|\mathbf{z})]$ is the expected log-likelihood. It encourages the density over the hidden variables to explain the observed data.
  2. $-\text{KL}\big(q(\mathbf{z})\,\|\,p(\mathbf{z})\big)$ is the negative KL divergence between the variational density and the prior. It encourages the variational density to stay close to the prior.
Thus the ELBO mirrors the usual balance between likelihood and prior.