Differential computing with JAX

What JAX enables and doesn't enable

JAX as a package doesn't pretend to be a replacement for established deep learning frameworks. That is because JAX doesn't provide deep learning abstractions as first-class citizens; its focus is on the much more generally useful idea of composable program transformations. Comparing it against a deep learning framework is a bit of a red herring - a distraction from what JAX actually enables.

What JAX actually enables is for us to write numerical programs using the NumPy API that are performant and automatically differentiable. vmap and lax.scan help us eliminate Python loop overhead in our code; jit just-in-time compiles code to accelerate it; and grad gives us differentiability, opening the door for us to write optimization routines that solve real-world problems. (see: grad, jit, vmap, and lax.scan)
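
As a minimal sketch of how these transformations compose (the loss function here is a made-up toy for illustration, not anything from the library):

```python
import jax.numpy as np
from jax import grad, jit, vmap

def loss(w, x):
    # a toy scalar-valued function of parameters w and one data point x
    return np.sum((w * x) ** 2)

# grad differentiates w.r.t. w (the first argument), vmap maps over a
# batch of x's without a Python loop, and jit compiles the whole thing.
batched_grad = jit(vmap(grad(loss), in_axes=(None, 0)))

w = np.array([1.0, 2.0])
xs = np.ones((8, 2))
print(batched_grad(w, xs).shape)  # one gradient per data point: (8, 2)
```

Each transformation takes a function and returns a function, which is why they stack so cleanly.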

At work, I have used JAX productively in both neural network and non-neural network settings, with the unifying theme being gradient-based optimization of model parameters. With JAX, I can seamlessly move between problem classes while using the PyData community's idiomatic NumPy API. We have used JAX to implement Hierarchical Dirichlet Process autoregressive multivariate Gaussian hidden Markov models (what a mouthful!), LSTM recurrent neural networks, graph neural networks, simple feed-forward neural networks, linear models, and more... and train them using the same gradient descent tooling available to us in JAX.

The upside here is that we could hand-craft each model and tailor it to the problem at hand, with code written in an explicit fashion that exposed the many layers of abstraction we sometimes needed. This is also the downside of writing JAX code: we had to write a lot of it, partly because in some cases the abstractions we needed hadn't been implemented anywhere, and partly because in others they weren't easily available in JAX.

One thing I wanted to highlight though: leveraging simple tricks learned from the neural network and probabilistic programming world (such as optimizing in unbounded rather than bounded space), we were able to train covariance matrices in our multivariate Gaussian HMMs using gradient descent rather than expectation-maximization, and it just worked. It was amazing to see in action.
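
To sketch the trick itself (a toy example of my own devising, not our production HMM code): instead of constraining a scale parameter to be positive, we optimize its logarithm in unbounded space and exponentiate inside the loss:

```python
import jax.numpy as np
from jax import grad

data = np.array([-1.3, 0.2, 0.8, 2.1, -0.5])

def neg_log_likelihood(log_sigma, data):
    # sigma is positive by construction, so plain gradient descent
    # on log_sigma can never step out of bounds
    sigma = np.exp(log_sigma)
    return np.sum(0.5 * (data / sigma) ** 2 + np.log(sigma))

dnll = grad(neg_log_likelihood)

log_sigma = 0.0
for _ in range(100):
    log_sigma -= 0.1 * dnll(log_sigma, data)

print(np.exp(log_sigma))  # converges to the MLE standard deviation
```

The same idea extends to covariance matrices, e.g. by parameterizing a Cholesky factor with unconstrained off-diagonal entries and log-transformed diagonal entries.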

Now, the lack of deep learning abstractions in JAX doesn't mean that JAX is unavailable as a backend to other computing frameworks! A flurry of development after JAX's initial release led to a suite of deep learning libraries and probabilistic programming languages targeting JAX as an array backend, precisely because it provides a library of composable, Python-compatible program transformations.

For deep learning libraries, an experimental stax module exists inside JAX; my intern Arkadij Kummer and I used it productively in a JAX-based reimplementation of an LSTM model used for protein engineering. flax, also developed by Googlers, provides a PyTorch-like API that builds on top of the functional programming paradigm encouraged by JAX. The Neural Tangents package for infinitely wide, Bayesian neural networks follows stax's idioms, with well-documented differences (though without reasons given).
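
To give a flavour of stax (a hedged sketch; note that in newer JAX releases the module lives at jax.example_libraries.stax rather than jax.experimental.stax), layers are (init_fun, apply_fun) pairs that compose with stax.serial:

```python
import jax.numpy as np
from jax import random
from jax.example_libraries import stax  # formerly jax.experimental.stax

# A small feed-forward network; each stax layer is an (init_fun, apply_fun) pair
init_fun, apply_fun = stax.serial(
    stax.Dense(64), stax.Relu,
    stax.Dense(1),
)

key = random.PRNGKey(0)
# -1 stands in for the batch dimension; inputs have 10 features
out_shape, params = init_fun(key, (-1, 10))

x = np.ones((5, 10))
y = apply_fun(params, x)
print(y.shape)  # (5, 1)
```

Everything stays functional: params is an ordinary pytree that can be fed straight into grad.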

For probabilistic programming languages, even TensorFlow Probability has a JAX backend as an alternative to the TensorFlow one. PyMC3, which is built on top of Theano, is getting a JAX-ified Theano backend, while mcx, written by the French software developer Rémi Louf, is a pedagogical PPL built entirely on JAX.

What makes JAX special

API-compatible differentiable computing

The Python scientific computing stack, also known as the PyData or SciPy stack, provides a large library of numerical programs that can be composed into higher-order programs. What JAX provides is a fully API-compatible reimplementation of the stack's differentiable functions, with what I think is near-complete coverage.

As such, users familiar with NumPy and SciPy can, with minimal changes to lines of code, write automatically differentiable versions of their existing programs, and develop new programs that are automatically differentiable.

How does this work? I have written a longer-form collection of teaching materials on this, but here is a quick example. If I have a silly program like the following one:

from jax.scipy.linalg import cholesky
import jax.numpy as np

def some_function(params, data):
    # do stuff with params, for example, a cholesky decomposition:
    U = cholesky(params)
    # followed by a sum of sine transform, for whatever reason it might be needed
    return np.sum(np.sin(U))

I can get the gradient function easily, using JAX's provided grad function:

from jax import grad

dsome_function = grad(some_function)

dsome_function has the same function signature as some_function, but instead of returning a scalar value, it returns the derivative of some_function w.r.t. params (the first argument), in the same (possibly nested) data structure as params. That is to say, if params were a tuple, a dictionary, a list, or any other native Python construct, dsome_function's return value would have the same structure.
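
For instance (a toy function of my own for illustration), passing a dictionary of parameters returns a dictionary of gradients:

```python
import jax.numpy as np
from jax import grad

def loss_fn(params, data):
    # params is a plain Python dict -- JAX treats it as a "pytree"
    return np.sum(params["w"] * data + params["b"])

dloss_fn = grad(loss_fn)

params = {"w": np.array([1.0, 2.0]), "b": np.array(0.5)}
grads = dloss_fn(params, np.array([3.0, 4.0]))
print(grads)  # a dict with the same keys as params, holding the gradients
```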

grad can do many more fancy things, such as differentiating through loops and flow control, and taking second- and higher-order derivatives, and I'd encourage you to check out the docs to learn more. (That is out of scope for this essay, as I'm focusing on high-level points.)
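
A quick taste of both, using a toy function of my own:

```python
from jax import grad

def f(x):
    # Python loops and conditionals work under grad, because grad
    # traces with concrete values (no jit involved here)
    if x < 0:
        return -x
    total = 0.0
    for _ in range(3):
        total = total + x ** 2  # f(x) = 3x^2 for x >= 0
    return total

df = grad(f)          # 6x for x > 0
d2f = grad(grad(f))   # constant 6.0: grad composes with itself
print(df(2.0), d2f(2.0))  # 12.0 6.0
```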

Providing grad as a first-class citizen in an API-compatible fashion with the scientific Python computing stack makes it very easy to adopt differential computing tooling in one's programs.

API-compatibility with the rest of the PyData stack

A design choice made early on by the JAX developers was full NumPy and SciPy API compatibility, with minimal, very well-documented differences (mostly in the realm of random number generation). Incidentally, this practice is also adopted by Dask and CuPy, which give us distributed and GPU-backed arrays respectively. It reflects a healthy dose of respect for what already exists and for end-users as well.

I think a contrasting example best illustrates this point. Consider a PyTorch or TensorFlow array vs. a JAX NumPy array.

To plot a PyTorch array's values in matplotlib, one must first convert it to a NumPy array:

import torch
import matplotlib.pyplot as plt

a = torch.tensor([1, 2, 3])
plt.plot(a.numpy())

On the other hand, with JAX:

import jax.numpy as np
import matplotlib.pyplot as plt

a = np.array([1, 2, 3])
plt.plot(a)

The syntax with JAX is identical to what one might write with vanilla NumPy. This is, I think, in part because JAX's developers have also strived to adhere to NEP 18. This seemingly small difference in API design leads to compound interest-like savings in time for model developers.
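
A quick way to see this interoperability directly (matplotlib does essentially the same coercion internally): vanilla NumPy functions accept JAX arrays via the array protocol.

```python
import numpy as onp       # vanilla NumPy
import jax.numpy as np    # JAX's NumPy reimplementation

a = np.array([1.0, 2.0, 3.0])
# onp.mean coerces the JAX array to a NumPy array behind the scenes
print(onp.mean(a))  # 2.0
```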

Resources for learning about JAX

I invested one of my vacation weeks crystallizing what I had learned from working with JAX over the past year and a half, and the exercise was extremely educational. If you're interested in reading it, you can find it at my dl-workshop repository on GitHub. In there, in addition to the original content, which was a workshop on deep learning, I also try to provide "simple complex examples" of how to use JAX idioms in solving modelling problems.

Besides that, JAX's documentation is quite well-written, and you can find it at jax.readthedocs.io. In particular, it includes a very well-documented suite of "The Sharp Bits" to look out for when using JAX, geared towards both power users of vanilla NumPy and beginners. If you're using JAX and run into unexpected behaviour, I'd strongly encourage you to check out that page - it'll clear up many misconceptions you might have!

In terms of introductory material, a blog post by Colin Raffel, titled "You don't know JAX", is a very well-written introduction on how to use JAX. Eric Jang also has a blog post on implementing meta-learning in JAX, which I found very educational for both JAX syntax and meta-learning.

While the flashiest advances of the deep learning world came from 2010-2020, I personally think that the most exciting foundational advance of that era was the development of general-purpose automatic differentiation packages like autograd and JAX. At least for the Python world, they have enabled the writing of arbitrary models in a fashion highly compatible with the rest of the PyData stack, with differentiation and native compilation as first-class program transformations. The uses of gradients are varied; I'd definitely encourage you to try them out!

Who are the makers of JAX

JAX has been actively developed for over two years now, and as a project, it continues to attract talent. The originators were Dougal Maclaurin, Matt Johnson, Alex Wiltschko and David Duvenaud, while they were all at Harvard, and the team has since grown to include many prominent Pythonistas, including Jake Vanderplas and Stephan Hoyer. (There are many more whose names I don't know very well, so my apologies in advance if I have left your name out. For a full list of code contributors, the repository's contributors page is the most definitive.)