What makes JAX special

API-compatible differentiable computing

The Python scientific computing stack, also known as the PyData or SciPy stack, provides a large library of numerical routines that can be composed into larger programs. What JAX provides is a fully API-compatible reimplementation of the stack's differentiable functions, with what I think is near-complete coverage.

As such, users familiar with NumPy and SciPy can, with minimal changes to their code, write automatically differentiable versions of their existing programs and develop new programs that are automatically differentiable.

How does this work? I have written a longer-form collection of teaching materials on this, but here is a quick example. If I have a silly program like the following one:

from jax.scipy.linalg import cholesky
import jax.numpy as np

def some_function(params, data):
    # do stuff with params, for example, a cholesky decomposition:
    U = cholesky(params)
    # followed by a sum of sine transform, for whatever reason it might be needed
    return np.sum(np.sin(U))

I can get the gradient function easily, using JAX's provided grad function:

from jax import grad

dsome_function = grad(some_function)

dsome_function has the same function signature as some_function, but instead of returning a scalar value, it returns the derivative of some_function w.r.t. params (the first argument), in the same (and possibly nested) data structure as params. That is to say, if params were a tuple, a dictionary, a list, or any other native Python construct, dsome_function's return value would have the same structure.
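
To make this concrete, here is a minimal sketch (the loss function and the parameter names w and b below are invented for illustration) showing that the gradient comes back in the same structure as the dictionary of parameters passed in:

import jax.numpy as np
from jax import grad

def loss(params, data):
    # a made-up scalar-valued loss using dictionary-structured params
    return np.sum((params["w"] * data + params["b"]) ** 2)

dloss = grad(loss)

params = {"w": 2.0, "b": -1.0}
data = np.arange(3.0)
grads = dloss(params, data)
# grads is a dictionary with the same keys as params: {"w": ..., "b": ...}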

grad can do many more fancy things, such as differentiating through loops and flow control, and computing second- through nth-order derivatives, and I'd encourage you to check out the docs to learn more. (That is out of scope for this essay, as I'm focusing on high-level points.)
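
As a small sketch of both ideas at once (the function f below is invented for illustration), here is a function containing a plain Python loop, differentiated once and then twice by composing grad with itself:

from jax import grad

def f(x):
    # a simple polynomial, x + x**2 + x**3, built with a Python loop
    total = 0.0
    for power in range(1, 4):
        total = total + x ** power
    return total

df = grad(f)         # first derivative: 1 + 2x + 3x**2
d2f = grad(grad(f))  # second derivative: 2 + 6x

df(2.0)   # 17.0
d2f(2.0)  # 14.0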

Providing grad as a first-class citizen, in a fashion that is API-compatible with the scientific Python computing stack, makes it very easy to adopt differential computing tooling in one's programs.

API-compatibility with the rest of the PyData stack

A design choice made early on by the JAX developers was full NumPy and SciPy API compatibility, with minimal differences (mostly in the realm of random number generation) that are very well documented. Incidentally, this practice is also adopted by Dask and CuPy, which give us distributed and GPU-backed arrays respectively. This practice reflects a healthy dose of respect for what already exists and for end-users as well.
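
The best-known of those documented differences is random number generation: instead of NumPy's global random state, JAX uses explicit PRNG keys that are split before each use. A quick sketch:

from jax import random

key = random.PRNGKey(42)           # explicit, reproducible random state
key, subkey = random.split(key)    # split off a fresh key before each draw
x = random.normal(subkey, shape=(3,))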

I think a contrasting example best illustrates this point. Consider a PyTorch or TensorFlow tensor vs. a JAX NumPy array.

To plot a PyTorch tensor's values in matplotlib, one must first convert it to a NumPy array:

import torch
import matplotlib.pyplot as plt

a = torch.tensor([1, 2, 3])
plt.plot(a.numpy())

On the other hand, with JAX:

import jax.numpy as np
import matplotlib.pyplot as plt

a = np.array([1, 2, 3])
plt.plot(a)

The syntax with JAX is identical to what one would write with vanilla NumPy. This is, I think, in part because JAX's developers have also striven to adhere to NEP 18. This small API difference leads to compound-interest-like savings in time for model developers.

Differential computing with JAX