Magical NumPy with JAX (PyCon 2021)
The greatest contribution of the decade in which deep learning exploded was not the big models themselves, but a generalized toolkit for training any model by gradient descent. We're now in an era where differential computing can give you the toolkit to train models of any kind. Does a Pythonista well-versed in the PyData stack have to learn an entirely new toolkit, a new array library, to have access to this power?
This tutorial's answer is as follows: if you can write NumPy code, then with JAX, differential computing is at your fingertips with no need to learn a new array library! In this tutorial, you will learn how to use the NumPy-compatible JAX API to write performant numerical models of the world and train them using gradient-based optimization. Along the way, you will write loopy numerical code without loops, think in data cubes, get your functional programming muscles trained up, generate random numbers completely deterministically (no, this is not an oxymoron!), and preview how to mix neural networks and probabilistic models together... leveraging everything you know about NumPy plus some newly-learned JAX magic sprinkled in!
This tutorial is for Pythonistas who wish to take their NumPy skills to the next level. As such, it is not a tutorial for Python beginners. Attendees should know how to use NumPy syntax; a "dot product" should not feel unfamiliar to them. That said, advanced knowledge of deep learning and linear algebra is not necessary; each section of the tutorial has minimally complex examples and exercise puzzles that illustrate the point. Prior experience using `functools.partial` or writing closures is not necessary (we will go through what they do), but it is definitely useful to have.
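For attendees who want a preview, both tools amount to fixing some arguments of a function ahead of time; a minimal illustrative sketch (not the tutorial's exact example) looks like this:

```python
from functools import partial

def power(base, exponent):
    return base ** exponent

# functools.partial: fix the exponent ahead of time.
square = partial(power, exponent=2)

# A closure: the inner function "remembers" exponent from the enclosing scope.
def make_power(exponent):
    def f(base):
        return base ** exponent
    return f

cube = make_power(3)
print(square(4), cube(2))  # 16 8
```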
By the end of the tutorial, attendees will be familiar with JAX primitives, including `vmap` (to replace for-loops), `jit` (for just-in-time compilation), `grad` (to calculate gradients), `lax.scan` (to write loops with carryover), and `random` (to generate purely deterministic random sequences), and will know how to compose a numerical program from multiple primitive parts. Attendees will also see these primitives applied in a few toy problem settings, with their natural extensions to complex models explained, and will thus be equipped to take what they have learned in the tutorial and extend it to their own problems.
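As a taste of how these primitives compose, here is a minimal sketch (illustrative only; the tutorial materials use their own examples) that differentiates a scalar-output loss, vectorizes it over a batch with `vmap`, and compiles the result with `jit`:

```python
import jax
import jax.numpy as jnp

def squared_error(w, x, y):
    """Scalar-output loss for a single data point."""
    return (jnp.dot(w, x) - y) ** 2

# grad differentiates w.r.t. the first argument (w);
# vmap maps over a batch of (x, y) pairs while holding w fixed;
# jit compiles the whole batched-gradient function.
batched_grads = jax.jit(jax.vmap(jax.grad(squared_error), in_axes=(None, 0, 0)))

w = jnp.ones(3)
xs = jnp.arange(12.0).reshape(4, 3)  # 4 data points, 3 features each
ys = jnp.arange(4.0)
print(batched_grads(w, xs, ys).shape)  # (4, 3): one gradient per data point
```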
The tutorial will consist of short lecture-style demonstrations followed by hands-on activities that reinforce each point. Jupyter notebooks will feature prominently, hosted on Binder or run locally (at participants' choice). I am happy to address Q&A offline.
The outline begins with `vmap` and `lax.scan`, which allow for loop writing without writing loop code. Hands-on pieces in this section include calculating the lagging average (`lax.scan`) and calculating compound interest (`lax.scan`), each paired with a `vmap`-based counterpart.
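For a flavour of the `lax.scan` pattern, the compound-interest exercise might be sketched roughly as follows (the actual exercise code in the tutorial materials may differ):

```python
import jax.numpy as jnp
from jax import lax

def step(balance, rate):
    """One step of carryover: grow the balance by this period's rate."""
    new_balance = balance * (1.0 + rate)
    return new_balance, new_balance  # (carry, per-step output)

rates = jnp.array([0.05, 0.03, 0.04, 0.02])
final_balance, history = lax.scan(step, 100.0, rates)
print(final_balance)  # balance after all periods
print(history)        # balance recorded after each period
```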
Next comes the `PRNGKey` system and its contrast to stateful random number generation systems, again with a hands-on piece built on `lax.scan`.
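The deterministic-randomness idea in miniature (illustrative only): random state is an explicit key that you split rather than mutate, so the same key always yields the same numbers:

```python
from jax import random

key = random.PRNGKey(42)           # explicit random state; same key, same draws, every run
key, subkey = random.split(key)    # split the key instead of mutating hidden state
draws = random.normal(subkey, shape=(3,))
print(draws)                       # fully reproducible
```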
The tutorial then turns to `grad`, which gives us the derivative of a scalar-output function w.r.t. its input, and JAX's built-in optimizers, with a hands-on piece exercising `grad`.
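The core loop is short enough to sketch here; the hand-rolled update below stands in for JAX's built-in optimizers, which follow the same shape (this is an illustration, not the tutorial's exact exercise):

```python
import jax.numpy as jnp
from jax import grad

def loss(theta):
    """A scalar-output function of the parameters."""
    return jnp.sum((theta - 3.0) ** 2)

dloss = grad(loss)  # gradient has the same shape as theta

theta = jnp.zeros(2)
for _ in range(100):
    theta = theta - 0.1 * dloss(theta)  # plain gradient descent
print(theta)  # approximately [3., 3.]
```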
Later sections build on these primitives with further hands-on pieces that use `vmap`, `vmap` with split PRNGKeys, and `vmap` together with `grad`, before closing with neural network building blocks from `stax`.
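As a preview of that last piece, `stax` expresses a network as an `(init_fun, apply_fun)` pair. A minimal sketch follows; note that the import path has moved between JAX versions (formerly `jax.experimental.stax`, now `jax.example_libraries.stax`):

```python
from jax import random
from jax.example_libraries import stax  # in older JAX versions: jax.experimental.stax

# A small feed-forward network as an (init_fun, apply_fun) pair.
init_fun, apply_fun = stax.serial(
    stax.Dense(32), stax.Relu,
    stax.Dense(1),
)

key = random.PRNGKey(0)
init_key, data_key = random.split(key)
output_shape, params = init_fun(init_key, (-1, 10))  # -1 stands for the batch dimension
x = random.normal(data_key, (5, 10))
print(apply_fun(params, x).shape)  # (5, 1)
```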
I have presented talks and tutorials at PyCon, SciPy, PyData and ODSC. My most well-known tutorial is "Network Analysis Made Simple", and I have also led tutorials on Bayesian statistical modelling before. The content for this tutorial is freely available online at this website that I made, and will be updated based on this tutorial proposal's content if accepted.