Notes on differential computing
This is an overview page of my notes on differential computing and JAX.
Contents:
Stuff I've built/done with JAX:
Where differential computing gets used
In case you've found yourself living like a digital hermit for the past decade (no judgment, sometimes I do fantasize about going offline for a year), deep learning has been the place where automatic differentiation has been most heavily utilized. With deep learning, the core technical problem is optimizing the parameters of a model to minimize some loss function. This is where an automatic differentiation (AD) system earns its keep: it automatically calculates the full set of partial derivatives of the loss function w.r.t. each parameter in the model, and those partial derivatives are used to update their respective model parameters in the direction that minimizes the loss.
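To make that concrete, here is a minimal sketch of a single gradient descent step in JAX; the linear model, squared-error loss, and learning rate are all made up for illustration:
import jax.numpy as np
from jax import grad
def loss(params, x, y):
    """Squared-error loss of a toy linear model; params is a (weights, bias) tuple."""
    w, b = params
    y_pred = np.dot(x, w) + b
    return np.mean((y_pred - y) ** 2)
dloss = grad(loss)  # partial derivatives of the loss w.r.t. params
def sgd_step(params, x, y, lr=0.01):
    """Move each parameter a small step against its gradient."""
    grads = dloss(params, x, y)
    return tuple(p - lr * g for p, g in zip(params, grads))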
Because deep learning models and their applications proliferated in the 2010-2020 decade, AD systems were most commonly associated with neural networks and deep learning. However, that is not the only place where AD systems show up.
For example, AD is used in the Bayesian statistical modelling world. Hamiltonian Monte Carlo samplers use AD to help the sampler program identify the direction in which its next MCMC step should be taken. AD systems can also be used to optimize parameters of non-neural network models of the world against data, such as Gaussian Mixture Models and Hidden Markov Models. We can even use AD in a class of problems called "input design" problems, where we try to optimize not the parameters of the model w.r.t. some output, but the inputs (assuming we know how to cast the inputs into some continuous numerical space.)
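To make the input design idea concrete, here is a small illustrative sketch: we differentiate a model's output with respect to its inputs rather than its parameters, and nudge the inputs uphill. The toy model function below is made up purely for illustration:
import jax.numpy as np
from jax import grad
def model_output(inputs, params):
    """A toy differentiable model: a quadratic bump centred on params."""
    centre, height = params
    return height - np.sum((inputs - centre) ** 2)
dinputs = grad(model_output, argnums=0)  # gradient w.r.t. the inputs, not the parameters
def design_step(inputs, params, lr=0.1):
    """Move the inputs in the direction that increases the model output."""
    return inputs + lr * dinputs(inputs, params)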
What automatic differentiation systems exist
Where do AD systems live? Firstly, they definitely live inside deep learning frameworks such as PyTorch and TensorFlow; without an AD system, these frameworks would not work.
Secondly, they also live in independent packages. In Julia, there are two AD packages: one called Zygote.jl, and the other called AutoGrad.jl; both of them are actively developed. autograd, which was the reference Python package that AutoGrad.jl was written against, is also the precursor to JAX, which I think of as automatic differentiation on steroids.
What makes JAX special
The Python scientific computing stack, also known as the PyData or SciPy stack, provides a large library of numerical programs that can be composed together into higher-order programs. What JAX provides is a fully API-compatible reimplementation of the stack's differentiable functions, with what I think is near-complete coverage of those functions.
As such, users familiar with NumPy and SciPy can, with minimal changes to lines of code, write automatically differentiable versions of their existing programs, and develop new programs that are automatically differentiable.
How does this work? I have written a longer-form collection of teaching materials on this, but here is a quick example. If I have a silly program like the following one:
from jax.scipy.linalg import cholesky
import jax.numpy as np
def some_function(params, data):
    # do stuff with params, for example, a cholesky decomposition:
    U = cholesky(params)
    # followed by a sum of a sine transform, for whatever reason it might be needed
    return np.sum(np.sin(U))
I can get the gradient function easily, using JAX's provided grad function:
from jax import grad
dsome_function = grad(some_function)
dsome_function has the same function signature as some_function, but instead of returning a scalar value, it returns the derivative of some_function w.r.t. params (the first argument), in the same (and possibly nested) data structure as params. That is to say, if params were a tuple, or a dictionary, or a list, or any other native Python construct, dsome_function's output would have the same structure.
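For instance, here is a small illustrative sketch (the function and its parameters are made up) showing that a dictionary of parameters comes back as a dictionary of gradients:
import jax.numpy as np
from jax import grad
def f(params, x):
    return np.sum(params["w"] * x + params["b"])
df = grad(f)
grads = df({"w": np.array([1.0, 2.0]), "b": 0.5}, np.array([3.0, 4.0]))
# grads is a dict with the same keys, "w" and "b", as params.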
grad can do many more fancy things, such as differentiating through loops and flow control, and second- through nth-order derivatives, and I'd encourage you to check out the docs to learn more. (That is out of scope for this essay, as I'm focusing on high-level points.)
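As a quick illustration of those two points, using toy functions:
from jax import grad
def f(x):
    return x ** 3
d2f = grad(grad(f))  # second derivative, i.e. 6x
print(d2f(2.0))      # 12.0
def g(x):
    total = 0.0
    for i in range(3):  # plain Python loops and flow control are fine
        total = total + x ** (i + 1)
    return total
dg = grad(g)    # derivative of x + x^2 + x^3, i.e. 1 + 2x + 3x^2
print(dg(1.0))  # 6.0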
Providing grad as a first-class citizen, in an API-compatible fashion with the scientific Python computing stack, makes it very easy to adopt differential computing tooling in one's programs.
A design choice made early on by the JAX developers was full NumPy and SciPy API compatibility, with minimal differences (mostly in the realm of random number generation) that are very well-documented. Incidentally, this practice is also adopted by Dask and CuPy, which give us distributed and GPU-backed arrays respectively. This practice reflects a healthy dose of respect for what already exists and for end-users as well.
I think a contrasting example best illustrates this point. Consider a PyTorch or TensorFlow array vs. a JAX NumPy array.
To plot a PyTorch array's values in matplotlib, one must first convert it to a NumPy array:
import torch
import matplotlib.pyplot as plt
a = torch.tensor([1, 2, 3])
plt.plot(a.numpy())
On the other hand, with JAX:
import jax.numpy as np
import matplotlib.pyplot as plt
a = np.array([1, 2, 3])
plt.plot(a)
The syntax with JAX is identical to what one might write with vanilla NumPy. This is, I think, in part because JAX's developers have also strived to adhere to NEP18. This little API difference leads to compound interest-like savings in time for model developers.
Composable program transforms
JAX's grad function is merely the "gateway drug" to this bigger idea of "composable program transforms". grad is one example of a composable program transform: that is, transforming one program into another in a composable fashion. Other transforms include vmap, lax.scan, jit, and more. These all accept Python functions and return Python functions. jit, in particular, can accelerate a program anywhere from 2 to 100 fold on a CPU, depending on what your reasonable baseline comparison is.
In particular, the latter three of the aforementioned transforms allow for highly performant loop code, written without loops, that can also be composed together; a sketch follows below. I provide further details in a differential learning workshop that I have been developing, which you can take a look at.
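Here is a minimal sketch of what composing these transforms can look like; the cumulative-sum example is made up, but it shows a loop expressed with lax.scan, vectorized with vmap, and compiled with jit:
import jax.numpy as np
from jax import jit, vmap, lax
def cumsum_scan(xs):
    """Cumulative sum written as a lax.scan instead of a Python for-loop."""
    def step(carry, x):
        carry = carry + x
        return carry, carry
    _, out = lax.scan(step, 0.0, xs)
    return out
batched_cumsum = jit(vmap(cumsum_scan))  # vectorize over a batch, then compile
xs = np.arange(12.0).reshape(3, 4)
print(batched_cumsum(xs))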
There are other automatic program transformations in active development, and one exciting realm I see is the probabilistic programming world. In PyMC3, for example, an automated transform happens when we take our PyMC3 syntax, which is written in a domain-specific language (DSL) implemented in Python, and transform/compile it into a compute graph that gives us a likelihood function. It's as if PyMC3 gives us a likelihood(func) analogue of JAX's grad(func). (If you've tried writing probabilistic model likelihoods by hand, you'll know how much of a convenience this is!)
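For a sense of what that convenience buys, here is a minimal sketch of a log-likelihood written by hand against JAX's scipy wrappers, for a toy Gaussian model with unknown mean and scale; tools like PyMC3 generate the equivalent of this for you from model syntax:
import jax.numpy as np
from jax.scipy.stats import norm
def gaussian_loglike(params, data):
    """Hand-written log-likelihood of data under a Gaussian model."""
    mu, log_sigma = params
    return np.sum(norm.logpdf(data, loc=mu, scale=np.exp(log_sigma)))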
index
This is the landing page for my notes.
This is 100% inspired by Andy Matuschak's famous notes page. I'm not technically skilled enough to replicate the full "Andy Mode", though, so I just did some simple hacks. If you're curious how these notes are compiled, check out the summary in How these notes are made into HTML pages.
This is my "notes garden". I tend to it on a daily basis, and it contains some of my less fully-formed thoughts. Nothing here is intended to be cited, as the link structure evolves over time. The notes are best viewed on a desktop/laptop computer, because of the use of hovers for previews.
There's no formal "navigation", or "search" for these pages. To go somewhere, click on any of the "high-level" notes below, and enjoy.
jax-unirep
The main use of jax-unirep is to get "standardized" representations of proteins that can easily be used in downstream ML applications.
Infinitely Wide Neural Networks
How useful are infinitely wide neural networks? I tried writing a Jupyter notebook to explore the idea.
https://ericmjl.github.io/essays-on-data-science/machine-learning/nngp/
Workshop on Differential Learning
Key Information:
Feedback from Kannan Sankar and Mei Xiao after today's differential computing tutorial.
Topics covered: the _i indicator in the summation; vmap is a compiled for-loop; other thoughts.
Differential computing with JAX
A collection of notes about JAX and differential computing.
Estimating a multivariate Gaussian's parameters by gradient descent
I've seen this idea in action before, but never tried my hand at it until yesterday at work.
Turns out, if you have data that can be modelled by a multivariate Gaussian distribution, you can optimize the parameters of that Gaussian to maximize the likelihood of data under the multivariate Gaussian model.
The key tricky piece of this problem is that a multivariate Gaussian distribution's covariance parameter has to be structured: it must be symmetric and positive definite (see Covariance matrix). When pairing this with gradient descent, the problem we face is that a gradient update to the square, covariance-like matrix we input is not guaranteed to preserve that structure. What can we do to solve this problem?
We use the Cholesky decomposition.
For example, using JAX's scipy wrapper (for automatic differentiation-compatible Cholesky decomposition):
import jax.numpy as np
from jax.scipy.linalg import cholesky
a = np.array([
    [1, 0.8],
    [0.8, 1],
])
U = cholesky(a) # returns the upper triangle
From U, we can reconstitute a by taking the dot product of the transpose of U against U:
a_hat = np.dot(np.transpose(U), U)
assert np.allclose(a_hat, a)
That's all cool and such, but how does this apply to fitting a multivariate Gaussian against data? As mentioned above, the key problem is that when we initialize our parameters to optimize, we oftentimes sample numbers from a Gaussian (or other unbounded) distribution. To initialize a covariance matrix, we might be tempted to draw a square matrix, but there's no easy way to guarantee that it follows the desired structure of a covariance matrix.
That's where the reconstitution of a covariance matrix from an upper-triangular matrix can help. To sweeten the deal, because of JAX's fully NumPy-compatible API, we can instead initialize however we want, and then perform a transformation step to get back to a covariance matrix:
from jax import random as rnd
key = rnd.PRNGKey(42)
init_cov = rnd.normal(key, shape=(5, 5)) # let's make a 5x5 covariance matrix
def transform_to_covariance_matrix(sq_mat):
    U = np.triu(sq_mat)
    U_T = np.transpose(U)
    return np.dot(U_T, U)
And now with that, we can perform a full optimization of the init_cov parameter to its maximum likelihood. Let me show you how this works.
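(For concreteness, I'll assume that data holds draws from a known multivariate Gaussian; a minimal sketch of generating such data, with made-up "true" parameters, might look like this:)
import numpy as onp  # vanilla NumPy, just for data generation
true_mean = onp.zeros(5)
true_cov = 0.5 * onp.eye(5) + 0.5  # equicorrelated 5x5 covariance matrix
data = onp.random.multivariate_normal(true_mean, true_cov, size=1000)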
Firstly, we define the likelihood of observing our data under a multivariate Gaussian model:
from jax.scipy.stats import multivariate_normal
from jax import vmap
def loglike(params, data):
    mu, untransformed_cov = params
    cov = transform_to_covariance_matrix(untransformed_cov)
    def logpdf_func(datum):
        """logpdf of multivariate normal for one datum."""
        return multivariate_normal.logpdf(datum, mu, cov)
    logp = vmap(logpdf_func)(data)
    return np.sum(logp)
Notice two pieces:
1. We used some pretty cool/rad JAX tooling, like vmap, to eliminate sample dimensions and for-loops!
2. We also assumed that our covariance matrix is passed into the loglike function in an untransformed form, and we transform it to the correct form directly.
Now, we perform gradient-based optimization. We define the derivative function, which we'll use later for calculating gradients:
from jax import grad
dloglike = grad(loglike)
We then initialize our parameters:
mu = rnd.normal(key, shape=(5,))
untransformed_cov = rnd.normal(key, shape=(5, 5))
params = mu, untransformed_cov  # package them up into a convenient variable.
Finally, we use JAX's built-in optimizers, calling on jit to make computations fast:
from jax import jit
from jax.experimental.optimizers import adam
init, update, get_params = adam(step_size=0.005)
get_params = jit(get_params); update = jit(update)
dloglike = jit(dloglike); loglike = jit(loglike)
state = init(params)
for i in range(300):
    params = get_params(state)
    g = dloglike(params, data)
    state = update(i, g, state)
mu_opt, untransformed_cov_opt = get_params(state)
cov_opt = transform_to_covariance_matrix(untransformed_cov_opt)
The optimized covariance matrix will be pretty darn close to the true one!
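If you generated the data from known parameters, as in the sketch earlier, one quick sanity check is to compare the optimized covariance against the empirical covariance of the data:
empirical_cov = np.cov(data, rowvar=False)
print(np.round(cov_opt, 2))
print(np.round(empirical_cov, 2))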
What is differential computing
I think at its core, we can think of differential computing as "computing where derivatives (from calculus) are a first-class citizen". This is analogous to probabilistic computing, in which probabilistic constructs (such as probability distributions and MCMC samplers) are given "first class" status in a language. By "first class" status, I mean that the relevant computing constructs are a well-developed category of constructs in the language, with clearly defined interfaces.
In differential computing, being able to evaluate gradients of a math function is at the heart of the language. (How the gradients are used is a matter of application.) The objective of a differential computing system is to get a computer to automatically evaluate gradients of a math function that we have written in that language. This is where automatic differentiation (AD) systems come into play. They take a program, perhaps written as a Python function, and automatically transform it into, perhaps, another Python function, which can be used to evaluate the derivative of the original function.
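A tiny sketch of what "automatically evaluating gradients" looks like in practice, using a toy function:
from jax import grad
def f(x):
    return 3.0 * x ** 2 + 2.0 * x
df = grad(f)    # a new Python function that evaluates df/dx
print(df(1.0))  # 8.0, i.e. 6x + 2 at x = 1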
All of this falls under the paradigm of AD systems, as I mentioned earlier. Symbolic differentiation can be considered one subclass of AD systems; this is a point PyMC3 developer Brandon Willard would make. Packages such as autograd and now JAX (also by the autograd authors) are another subclass of AD systems, which leverage the chain rule and a recorder tape of math operations called in the program to automatically construct the gradient function.