Estimating a multivariate Gaussian's parameters by gradient descent
I've seen this idea in action before, but never tried my hand at it until yesterday at work.
Turns out, if you have data that can be modelled by a multivariate Gaussian distribution, you can optimize the parameters of that Gaussian by gradient descent to maximize the likelihood of the data under the model.
The key tricky piece of this problem is that a multivariate Gaussian distribution's covariance parameter has to be structured (see Covariance matrix). When we pair this with gradient descent, though, we can never guarantee that a gradient update to the square, covariance-like matrix we are optimizing will preserve that structure. What can we do to solve this problem?
We use the Cholesky decomposition.
For example, using JAX's scipy wrapper (for automatic differentiation-compatible Cholesky decomposition):
import jax.numpy as np
from jax.scipy.linalg import cholesky

a = np.array([
    [1, 0.8],
    [0.8, 1],
])
U = cholesky(a)  # returns the upper triangle
From U, we can reconstitute a by taking the dot product of the transpose of U with U:
a_hat = np.dot(np.transpose(U), U)
assert np.allclose(a_hat, a)  # allclose rather than ==, because floating-point arithmetic isn't exact
That's all cool and such, but how does this apply to fitting a multivariate Gaussian against data? As mentioned above, the key problem is that when we initialize the parameters we want to optimize, we often sample them from a Gaussian (or some other unbounded) distribution. To initialize a covariance matrix, we might be tempted to draw a square matrix, but there's no easy way to guarantee that it has the required structure of a covariance matrix.
That's where reconstituting a covariance matrix from an upper triangular matrix helps. To sweeten the deal, because of JAX's fully NumPy-compatible API, we can initialize however we want, and then apply a transformation step to get back to a valid covariance matrix:
from jax import random as rnd

key = rnd.PRNGKey(42)
init_cov = rnd.normal(key, shape=(5, 5))  # let's make a 5x5 covariance matrix

def transform_to_covariance_matrix(sq_mat):
    """Transform an unconstrained square matrix into a valid covariance matrix."""
    U = np.triu(sq_mat)    # keep only the upper triangle
    U_T = np.transpose(U)
    return np.dot(U_T, U)  # U^T U is symmetric and positive semi-definite
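To convince ourselves that the transform does what we want, here is a quick check (a small sketch reusing init_cov from above): the result should be symmetric, and its eigenvalues should be non-negative.
cov = transform_to_covariance_matrix(init_cov)
print(np.allclose(cov, np.transpose(cov)))  # True: the result is symmetric
print(np.linalg.eigvalsh(cov))              # all non-negative (up to floating point), i.e. positive semi-definite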
And now with that, we can perform a full optimization of the init_cov parameter to its maximum likelihood estimate. Let me show you how this works.
Firstly, we define the likelihood of observing our data under a multivariate Gaussian model:
from jax.scipy.stats import multivariate_normal
from jax import vmap

def loglike(params, data):
    mu, untransformed_cov = params
    cov = transform_to_covariance_matrix(untransformed_cov)

    def logpdf_func(datum):
        """logpdf of multivariate normal for one datum."""
        return multivariate_normal.logpdf(datum, mu, cov)

    logp = vmap(logpdf_func)(data)  # map over the leading (sample) axis of data
    return np.sum(logp)
Notice two pieces:
1. We used some pretty cool/rad JAX tooling, like vmap, to eliminate sample dimensions and for-loops (see the small illustration below).
2. We also assumed that our covariance matrix is passed into the loglike function in an untransformed form, and we transform it to the correct form inside the function.
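If vmap is new to you, here is a tiny illustrative sketch of the pattern used above (square_norm is a made-up toy function): vmap maps a single-datum function over the leading (sample) axis of an array, replacing an explicit Python for-loop.
def square_norm(x):
    """Squared norm of a single vector."""
    return np.sum(x ** 2)

batch = np.ones((3, 5))          # three "data points", each 5-dimensional
print(vmap(square_norm)(batch))  # [5. 5. 5.] -- one value per data point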
Now, we perform gradient-based optimization. We define the derivative function that we'll use later for calculating gradients. One subtlety: JAX's optimizers minimize their objective, so to maximize the log-likelihood we take the gradient of its negative (grad differentiates with respect to the first argument, params, by default):
from jax import grad

# Negate the log-likelihood, because the Adam optimizer below *minimizes* its objective.
dloglike = grad(lambda params, data: -loglike(params, data))
We then initialize our parameters:
mu = rnd.normal(key, shape=(5,))
untransformed_cov = rnd.normal(key, shape=(5, 5))

params = mu, untransformed_cov  # package them up into a convenient variable.
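One thing the code above and below assumes is that we already have a data array of samples in hand. If you want a fully self-contained run, here is a minimal sketch that simulates draws from a "true" multivariate Gaussian; the particular true_mu and true_cov values are made up for illustration:
true_mu = np.arange(5.0)           # made-up "true" mean
true_cov = 0.5 * np.eye(5) + 0.5   # made-up "true" covariance: 1.0 on the diagonal, 0.5 off it
data = rnd.multivariate_normal(rnd.PRNGKey(0), true_mu, true_cov, shape=(1000,))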
Finally, we use JAX's built-in optimizers, calling on jit to make computations fast:
from jax import jit
from jax.experimental.optimizers import adam  # in newer JAX versions, this module lives at jax.example_libraries.optimizers

init, update, get_params = adam(step_size=0.005)
get_params = jit(get_params); update = jit(update)
dloglike = jit(dloglike); loglike = jit(loglike)

state = init(params)
for i in range(300):
    params = get_params(state)
    g = dloglike(params, data)
    state = update(i, g, state)

mu_opt, untransformed_cov_opt = get_params(state)
cov_opt = transform_to_covariance_matrix(untransformed_cov_opt)
The optimized covariance matrix will be pretty darn close to the true one!
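If you simulated data with the made-up true_mu and true_cov from the sketch above, you can eyeball the recovery directly:
print(np.round(cov_opt, 2))  # should look a lot like true_cov
print(np.round(mu_opt, 2))   # should look a lot like true_mu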
Covariance matrix
A covariance matrix must be square, symmetric across the diagonal, and positive semi-definite.
Cholesky decomposition
The Cholesky decomposition can be thought of as a "square root"-ish operation on a square matrix. Basically, it takes any Covariance matrix and decomposes it into the dot product of a triangular matrix and its conjugate transpose. (For matrices involving only real numbers, the conjugate transpose is just the transpose.)