# EM

The EM algorithm applies to a large family of estimation problems with latent variables, e.g., Gaussian mixture models (GMMs).

Suppose we have a training set \(\{x^{(1)},...,x^{(n)}\}\) and let \(z\) be the latent variable. By marginalization:

\[p(x;\theta) = \sum_{z}p(x, z; \theta)\]

We wish to fit the parameters \(\theta\) by maximizing the log-likelihood of the data:

\[\begin{split} l(\theta) &= \sum_{i=1}^{n}\log{p(x^{(i)};\theta)} \\ &= \sum_{i=1}^{n}\log\sum_{z^{(i)}}p(x^{(i)}, z^{(i)}; \theta) \end{split}\]

Maximizing \(l(\theta)\) directly is difficult because the sum over \(z^{(i)}\) sits inside the logarithm.
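For concreteness, here is a minimal sketch of this log-likelihood for a 1-D Gaussian mixture, where \(z\) indexes the mixture component and \(p(x,z;\theta) = \pi_z\,\mathcal{N}(x;\mu_z,\sigma_z)\); the parameter values below are hypothetical:

```python
import numpy as np
from scipy.stats import norm
from scipy.special import logsumexp

def gmm_log_likelihood(x, pi, mu, sigma):
    """l(theta) = sum_i log sum_z pi_z * N(x_i; mu_z, sigma_z)."""
    x = np.asarray(x)
    # log p(x_i, z; theta) for every example i (rows) and component z (columns).
    log_joint = np.log(pi) + norm.logpdf(x[:, None], mu, sigma)
    # Marginalize z inside the log; logsumexp keeps it numerically stable.
    return logsumexp(log_joint, axis=1).sum()

# Hypothetical two-component mixture and a few data points.
pi, mu, sigma = np.array([0.3, 0.7]), np.array([-2.0, 1.0]), np.array([1.0, 0.5])
print(gmm_log_likelihood([-1.5, 0.2, 1.1], pi, mu, sigma))
```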

Our strategy is instead to repeatedly construct a lower bound on \(l\) (E-step) and then optimize that lower bound (M-step).

## Lower Bound

Let \(Q\) be any distribution over \(z\); then:

\[\begin{split} \log{p(x;\theta)} &= \log\sum_{z}p(x,z;\theta)\\ &= \log\sum_{z}Q(z)\frac{p(x,z;\theta)}{Q(z)}\\ &\ge \sum_{z}Q(z)\log{\frac{p(x,z;\theta)}{Q(z)}}\quad\mbox{(Jensen's inequality; log is concave)} \end{split}\]

We call this bound the evidence lower bound (ELBO) and denote it by:

\[\mbox{ELBO}(x;Q,\theta) = \sum_{z}Q(z)\log\frac{p(x,z;\theta)}{Q(z)}\]
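The inequality in the derivation is Jensen's inequality applied to the concave \(\log\). A tiny numeric sanity check (a sketch with made-up values for \(Q(z)\) and the ratios \(v_z = p(x,z;\theta)/Q(z)\)):

```python
import numpy as np

rng = np.random.default_rng(0)
Q = rng.dirichlet(np.ones(4))       # a distribution over four latent values z
v = rng.uniform(0.1, 5.0, size=4)   # arbitrary positive ratios p(x,z)/Q(z)

lhs = np.log(Q @ v)                 # log of the expectation: log p(x; theta)
rhs = Q @ np.log(v)                 # expectation of the log: ELBO(x; Q, theta)
assert lhs >= rhs                   # Jensen's inequality for the concave log
print(lhs, rhs)
```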

For the inequality to hold with equality, it is sufficient that the ratio inside the logarithm be a constant \(c\) that does not depend on \(z\):

\[\frac{p(x,z;\theta)}{Q(z)} = c\]

Since \(Q\) is a distribution over \(z\), this means \(Q(z) \propto p(x,z;\theta)\) with normalizer \(\sum_{z}p(x,z;\theta) = p(x;\theta)\), which is equivalent to:

\[Q(z) = p(z|x;\theta)\]
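Concretely, for a discrete \(z\) we can verify both the bound and the equality condition numerically. A minimal sketch, assuming a made-up joint table \(p(x,z;\theta)\) for one fixed \(x\) and \(z \in \{0,1,2\}\):

```python
import numpy as np

rng = np.random.default_rng(1)
p_xz = rng.uniform(0.01, 1.0, size=3)   # joint p(x, z; theta) for a fixed x
p_x = p_xz.sum()                        # marginal p(x; theta) = sum_z p(x, z; theta)

def elbo(Q):
    """ELBO(x; Q, theta) = sum_z Q(z) log(p(x, z; theta) / Q(z))."""
    return np.sum(Q * np.log(p_xz / Q))

Q_random = rng.dirichlet(np.ones(3))    # an arbitrary distribution over z
Q_post = p_xz / p_x                     # the posterior p(z | x; theta)

assert elbo(Q_random) <= np.log(p_x) + 1e-12   # lower bound for any Q
assert np.isclose(elbo(Q_post), np.log(p_x))   # tight exactly at the posterior
```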

## The EM Algorithm

Taking all training examples into account, for any distributions \(Q_{1},...,Q_{n}\):

\[\begin{split} l(\theta) &\ge \sum_{i}\mbox{ELBO}(x^{(i)};Q_{i},\theta)\\ &= \sum_{i}\sum_{z^{(i)}}Q_{i}(z^{(i)})\log\frac{p(x^{(i)},z^{(i)};\theta)}{Q_{i}(z^{(i)})} \end{split}\]

Equality holds when each \(Q_{i}\) equals the posterior distribution under the current setting of \(\theta\):

\[\mbox{E-step:}\quad{Q_{i}{(z^{(i)})}} = p(z^{(i)}|x^{(i)};\theta)\]

The M-step maximizes this lower bound with respect to \(\theta\) while keeping the \(Q_{i}\) fixed:

\[\mbox{M-step:}\quad\theta := \underset{\theta}{\mbox{argmax}}\sum_{i}\mbox{ELBO}(x^{(i)};Q_{i},\theta)\]
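Putting both steps together, here is a minimal EM sketch for a 1-D two-component Gaussian mixture (function name, initial values, and data are all hypothetical). The E-step computes the posterior responsibilities \(Q_{i}(z^{(i)})\), the M-step re-estimates \(\theta\) in closed form, and the final assertion checks the monotone improvement property discussed below:

```python
import numpy as np
from scipy.stats import norm
from scipy.special import logsumexp

def em_gmm(x, pi, mu, sigma, n_iters=50):
    """EM for a 1-D Gaussian mixture; theta = (pi, mu, sigma)."""
    x = np.asarray(x)
    ll_history = []
    for _ in range(n_iters):
        # E-step: Q_i(z) = p(z | x_i; theta), the posterior responsibilities.
        log_joint = np.log(pi) + norm.logpdf(x[:, None], mu, sigma)  # (n, k)
        log_px = logsumexp(log_joint, axis=1, keepdims=True)         # log p(x_i; theta)
        Q = np.exp(log_joint - log_px)                               # rows sum to 1
        ll_history.append(log_px.sum())                              # l(theta) at current theta

        # M-step: maximize sum_i ELBO(x_i; Q_i, theta) over theta with Q fixed
        # (closed-form weighted maximum likelihood for a Gaussian mixture).
        Nk = Q.sum(axis=0)
        pi = Nk / len(x)
        mu = (Q * x[:, None]).sum(axis=0) / Nk
        sigma = np.sqrt((Q * (x[:, None] - mu) ** 2).sum(axis=0) / Nk)
    return pi, mu, sigma, ll_history

# Hypothetical data: samples from two well-separated Gaussians.
rng = np.random.default_rng(2)
x = np.concatenate([rng.normal(-2.0, 1.0, 200), rng.normal(3.0, 0.5, 300)])
pi, mu, sigma, ll = em_gmm(x, np.array([0.5, 0.5]), np.array([-1.0, 1.0]), np.array([1.0, 1.0]))

# Monotone improvement: l(theta_t) <= l(theta_{t+1}) at every iteration.
assert all(b >= a - 1e-9 for a, b in zip(ll, ll[1:]))
print(pi, mu, sigma)
```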

The EM algorithm guarantees \(l(\theta^{(t)}) \le l(\theta^{(t+1)})\): with \(Q_{i}^{(t)}\) chosen in the E-step at \(\theta^{(t)}\),

\[l(\theta^{(t+1)}) \ge \sum_{i}\mbox{ELBO}(x^{(i)};Q_{i}^{(t)},\theta^{(t+1)}) \ge \sum_{i}\mbox{ELBO}(x^{(i)};Q_{i}^{(t)},\theta^{(t)}) = l(\theta^{(t)})\]

where the first inequality holds for any choice of \(Q\), the second follows from the M-step maximization, and the final equality from the E-step choice of \(Q_{i}^{(t)}\). The log-likelihood therefore increases monotonically, which ensures convergence (typically to a local optimum).