What JAX enables and doesn't enable
JAX as a package doesn't pretend to be a replacement for established deep learning frameworks; it doesn't provide deep learning abstractions as first-class citizens. Its focus is on the much more generally useful idea of composable program transformations. Comparing it against a deep learning framework is a bit of a red herring: a distraction from what JAX actually enables.
What JAX actually enables us to do is write numerical programs using the NumPy API that are performant and automatically differentiable. vmap and lax.scan help us eliminate Python loop overhead in our code; jit just-in-time compiles code to accelerate it; and grad gives us differentiability, opening the door to writing optimization routines that solve real-world problems. (See the short notes on grad, jit, vmap, and lax.scan below.)
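To give a flavour of what that looks like, here is a minimal, made-up sketch (toy data and function names of my own choosing, not code from any of the projects mentioned below): an ordinary NumPy-style loss function that we can differentiate and compile without rewriting it.

```python
import jax.numpy as np  # the NumPy-compatible API; `jnp` is another common alias
from jax import grad, jit, random

# A plain NumPy-style mean squared error loss for a linear model.
def mse_loss(params, x, y):
    w, b = params
    y_pred = np.dot(x, w) + b
    return np.mean((y - y_pred) ** 2)

# grad differentiates w.r.t. the first argument (params);
# jit compiles the resulting gradient function.
dloss = jit(grad(mse_loss))

# Toy data, purely for illustration.
key = random.PRNGKey(42)
x = random.normal(key, (100, 3))
y = np.dot(x, np.array([1.0, -2.0, 0.5])) + 0.3

# A bare-bones gradient descent loop.
params = (np.zeros(3), 0.0)
for _ in range(100):
    dw, db = dloss(params, x, y)
    params = (params[0] - 0.1 * dw, params[1] - 0.1 * db)
```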
At work, I have used JAX productively in both neural network and non-neural network settings, with the unifying theme being gradient-based optimization of model parameters. With JAX, I can seamlessly move between problem classes while using the PyData community's idiomatic NumPy API. We have used JAX to implement Hierarchical Dirichlet Process autoregressive multivariate Gaussian hidden Markov models (what a mouthful!), LSTM recurrent neural networks, graph neural networks, simple feed-forward neural networks, linear models, and more... and train them using the same gradient descent tooling available to us in JAX.
The upside here is that we could hand-craft each model and tailor it to the problem at hand, and the code was written in a very explicit fashion that exposed the layers of abstraction we needed. This is also the downside of writing JAX code: we had to write a lot of it, partly because some of the abstractions we needed hadn't been implemented anywhere yet, and partly because others exist elsewhere but aren't easily available in JAX.
One thing I do want to highlight, though: by leveraging simple tricks from the neural network and probabilistic programming worlds (such as optimizing in unbounded rather than bounded space), we were able to train the covariance matrices in our multivariate Gaussian HMMs using gradient descent rather than expectation-maximization, and it just worked. It was amazing to see in action.
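To make the trick concrete, here is a minimal sketch (an illustration with names I've made up, not our production code) of reparameterizing a covariance matrix so that it can be trained by gradient descent in unbounded space: optimize the entries of a lower-triangular Cholesky factor, storing the diagonal on the log scale.

```python
import jax.numpy as np
from jax import grad
from jax.scipy.stats import multivariate_normal

def covariance_from_unconstrained(flat_params, dim):
    """Map an unconstrained real vector to a valid covariance matrix.

    flat_params has dim * (dim + 1) / 2 entries parameterizing a
    lower-triangular Cholesky factor, with the diagonal stored on the
    log scale, so any real-valued input yields a positive definite matrix.
    """
    L = np.zeros((dim, dim)).at[np.tril_indices(dim)].set(flat_params)
    L = L.at[np.diag_indices(dim)].set(np.exp(np.diag(L)))
    return L @ L.T

def nll(flat_params, data, dim):
    # Negative log-likelihood of zero-mean Gaussian data under the
    # reparameterized covariance; differentiable end to end.
    cov = covariance_from_unconstrained(flat_params, dim)
    return -np.sum(multivariate_normal.logpdf(data, np.zeros(dim), cov))

dnll = grad(nll)  # gradients w.r.t. the unconstrained parameters
```

Because the map from unconstrained parameters to a positive definite matrix is differentiable, grad flows straight through it, and the covariance update becomes just another gradient descent step.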
Now, the lack of deep learning abstractions in JAX doesn't mean other computing frameworks can't use JAX as a backend! A flurry of development after JAX's initial release led to a suite of deep learning libraries and probabilistic programming languages targeting JAX as an array backend, precisely because it provides composable, Python-compatible program transformations.
For deep learning libraries, an experimental stax module exists inside JAX; my intern Arkadij Kummer and I used it productively in a JAX-based reimplementation of an LSTM model for protein engineering. flax, also developed by Googlers, provides a PyTorch-like API that builds on top of the functional programming paradigm encouraged by JAX. The Neural Tangents package for infinitely wide, Bayesian neural networks follows stax's idioms, with well-documented (though unexplained) differences.
For probabilistic programming languages, even TensorFlow Probability has a JAX backend as an alternative to its TensorFlow backend. PyMC3, which is built on top of Theano, is getting a JAX-ified Theano backend too, while mcx, written by the French software developer Rémi Louf, is a pedagogical PPL built entirely on JAX.
vmap
vmap is a JAX function transformation that maps a function across one axis of an array, vectorizing it without an explicit Python loop.
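A minimal, made-up example: take a function written for a single vector and map it across the rows of a matrix.

```python
import jax.numpy as np
from jax import vmap

def dot_with_weights(x, w):
    # Written as if x were a single vector.
    return np.dot(x, w)

w = np.arange(3.0)
xs = np.arange(12.0).reshape(4, 3)  # a batch of 4 vectors

# Map over axis 0 of xs, leaving w un-mapped.
batched = vmap(dot_with_weights, in_axes=(0, None))
result = batched(xs, w)  # shape (4,)
```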
lax.scan
lax.scan gives us loops with carryover.
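As a minimal sketch of the idiom, here is a cumulative sum written as a scan, with the running total as the carried state (the same pattern underlies RNN- and HMM-style computations):

```python
import jax.numpy as np
from jax import lax

def step(carry, x):
    # carry is the running total; x is the current element.
    new_carry = carry + x
    return new_carry, new_carry  # (next carry, per-step output)

xs = np.arange(5.0)
final_carry, cumulative = lax.scan(step, 0.0, xs)
# final_carry == 10.0, cumulative == [0., 1., 3., 6., 10.]
```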
jit
jit does Just-In-Time (JIT) compilation of a function to lower-level code, thus bypassing Python interpreter overhead.
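A made-up example of the usage pattern: decorate (or wrap) a numerical function and let JAX compile it on first call.

```python
import jax.numpy as np
from jax import jit

@jit
def normalize(x):
    # Standardize x to zero mean and unit variance.
    return (x - np.mean(x)) / np.std(x)

x = np.arange(10.0)
y = normalize(x)  # compiled on the first call, fast thereafter
```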
grad
grad returns a transformed version of a function f: a gradient function (let's call it df) that, when called with the same arguments we would pass into f, gives us the derivative of f evaluated at those arguments.
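Concretely, with a toy function:

```python
from jax import grad

def f(x):
    return x ** 2 + 3.0 * x

df = grad(f)
df(2.0)  # 2 * 2.0 + 3.0 = 7.0
```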