Estimating a multivariate Gaussian's parameters by gradient descent

I've seen this idea in action before, but never tried my hand at it until yesterday at work.

Turns out, if you have data that can be modelled by a multivariate Gaussian distribution, you can optimize that Gaussian's parameters by gradient descent to maximize the likelihood of the data under the model.
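
Concretely, for data points x_1, ..., x_N in d dimensions, the quantity we want to maximize is the log-likelihood under a Gaussian with mean mu and covariance Sigma (the standard formula, written out here for reference):

$$
\log L(\mu, \Sigma) = \sum_{i=1}^{N} \log \mathcal{N}(x_i \mid \mu, \Sigma)
= -\frac{1}{2} \sum_{i=1}^{N} \Big[ (x_i - \mu)^\top \Sigma^{-1} (x_i - \mu) + \log\det\Sigma + d \log 2\pi \Big]
$$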

The tricky piece of this problem is that a multivariate Gaussian's covariance parameter has to have a particular structure. (see Covariance matrix) When we pair this with gradient descent, though, the problem we face is that there is no guarantee that updating a raw, square, covariance-like matrix with the gradient of our likelihood function will preserve that structure. What can we do to solve this problem?

We use the Cholesky decomposition.

For example, using JAX's scipy wrapper (for automatic differentiation-compatible Cholesky decomposition):

import jax.numpy as np
from jax.scipy.linalg import cholesky

a = np.array([
    [1, 0.8],
    [0.8, 1],
])

U = cholesky(a)  # returns the upper triangle

From U, we can reconstitute a by taking the dot product of the transpose of U with U:

a_hat = np.dot(np.transpose(U), U)

assert np.allclose(a_hat, a)  # allclose rather than exact equality, to allow for floating-point error

That's all cool and such, but how does this apply to fitting a multivariate Gaussian against data? As mentioned above, the key problem is that when we initialize the parameters we want to optimize, we often sample numbers from a Gaussian (or some other unbounded) distribution. To initialize a covariance matrix, we might be tempted to draw a random square matrix, but there's no easy way to guarantee that it has the required structure of a covariance matrix.

That's where reconstituting a covariance matrix from an upper triangular matrix can help. To sweeten the deal, because of JAX's fully NumPy-compatible API, we can initialize however we want, and then apply a transformation step that gets us back to a covariance matrix:

from jax import random as rnd

key = rnd.PRNGKey(42)
init_cov = rnd.normal(key, shape=(5, 5))  # a raw 5x5 matrix; not yet a valid covariance matrix

def transform_to_covariance_matrix(sq_mat):
    """Keep the upper triangle of sq_mat and return U^T U, which is symmetric and positive semi-definite."""
    U = np.triu(sq_mat)
    U_T = np.transpose(U)
    return np.dot(U_T, U)
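
As a quick sanity check (my own addition, not part of the original recipe), the transformed matrix really does have the structure we need:

cov_check = transform_to_covariance_matrix(init_cov)
assert np.allclose(cov_check, cov_check.T)             # symmetric
assert np.all(np.linalg.eigvalsh(cov_check) >= -1e-6)  # positive semi-definite, up to numerical noise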

And now with that, we can optimize an unconstrained square matrix like init_cov (together with a mean vector) to maximize the likelihood of our data. Let me show you how this works.

Firstly, we define the log-likelihood of observing our data under a multivariate Gaussian model:

from jax.scipy.stats import multivariate_normal
from jax import vmap

def loglike(params, data):
    mu, untransformed_cov = params
    cov = transform_to_covariance_matrix(untransformed_cov)
    def logpdf_func(datum):
        """logpdf of the multivariate normal for one datum; note the argument order (x, mean, cov)."""
        return multivariate_normal.logpdf(datum, mu, cov)
    logp = vmap(logpdf_func)(data)
    return np.sum(logp)

Notice two pieces:
1. We used some pretty cool/rad JAX tooling, like vmap, to eliminate sample dimensions and for-loops (see the sketch after this list).
2. We assumed that the covariance matrix is passed into loglike in untransformed form, and we transform it into a valid covariance matrix inside the function.
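
For the curious, here is roughly what the vmap call replaces: an explicit Python loop over the rows of data. This is just an illustrative sketch; the vmapped version above is the one you want, especially under jit.

def loglike_loop(params, data):
    """Same log-likelihood as loglike above, but with an explicit Python loop instead of vmap."""
    mu, untransformed_cov = params
    cov = transform_to_covariance_matrix(untransformed_cov)
    logps = [multivariate_normal.logpdf(datum, mu, cov) for datum in data]
    return np.sum(np.stack(logps))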

Now, we perform gradient-based optimization. One subtlety: JAX's optimizers perform gradient descent, so to maximize the log-likelihood we take gradients of its negation. We define the derivative function, which we'll use later for calculating gradients:

from jax import grad
# Gradient of the *negative* log-likelihood: doing gradient descent on this maximizes the likelihood.
dloglike = grad(lambda params, data: -loglike(params, data))

We then initialize our parameters:

mu_key, cov_key = rnd.split(key)  # split the PRNG key so the two draws are independent
mu = rnd.normal(mu_key, shape=(5,))
untransformed_cov = rnd.normal(cov_key, shape=(5, 5))

params = mu, untransformed_cov  # package them up into a convenient variable.
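
The snippets above assume a data array of shape (number of samples, 5) already exists. To make the example runnable end to end, here is one hypothetical way to simulate it from a known Gaussian (true_mu, true_cov, and the sample size are my own choices, not from the original post):

# Hypothetical ground truth that we will try to recover from the data.
true_mu = np.zeros(5)
true_cov = np.eye(5) + 0.3  # 1.3 on the diagonal, 0.3 off-diagonal; symmetric and positive definite

data_key = rnd.PRNGKey(0)
data = rnd.multivariate_normal(data_key, true_mu, true_cov, shape=(1000,))  # shape (1000, 5)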

Finally, we use JAX's built-in optimizers, calling on jit to make computations fast:

from jax import jit
from jax.experimental.optimizers import adam  # in newer JAX versions, this lives in jax.example_libraries.optimizers

init, update, get_params = adam(step_size=0.005)
get_params = jit(get_params); update = jit(update)
dloglike = jit(dloglike); loglike = jit(loglike)

state = init(params)
for i in range(300):
    params = get_params(state)
    g = dloglike(params, data)
    state = update(i, g, state)
mu_opt, untransformed_cov_opt = get_params(state)
cov_opt = transform_to_covariance_matrix(untransformed_cov_opt)

The optimized covariance matrix will be pretty darn close to the true one!
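
If you generated data with the hypothetical true_mu and true_cov from the sketch above, you can eyeball the recovered parameters directly. Depending on your random initialization, you may need more than 300 iterations (or a larger step_size) for a tight match:

print(np.round(mu_opt, 2))   # should be heading towards true_mu
print(np.round(cov_opt, 2))  # should be heading towards true_cov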

Covariance matrix

A covariance matrix must be square, symmetric across the diagonal, and positive semi-definite.
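
As a quick illustration (my own example, not from the original post): symmetry alone is not enough. The following symmetric matrix has a negative eigenvalue, so it is not positive semi-definite and cannot be a covariance matrix:

not_a_cov = np.array([
    [1.0, 2.0],
    [2.0, 1.0],
])
print(np.linalg.eigvalsh(not_a_cov))  # eigenvalues are approximately -1 and 3; the negative one rules it out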

Cholesky decomposition

The Cholesky decomposition can be thought of as a "square root"-ish operation on a square matrix. Basically, it takes a Covariance matrix (more generally, any Hermitian positive-definite matrix) and decomposes it into the product of a triangular matrix and its conjugate transpose. (For matrices involving only real numbers, the conjugate transpose is equal to the plain transpose.)
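
In symbols (the standard statement, added here for reference): for a Hermitian positive-definite matrix A,

$$
A = L L^{*} = U^{*} U,
$$

where L is lower triangular, U = L^{*} is upper triangular, and the star denotes the conjugate transpose.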