Magical NumPy with JAX (PyCon 2021)

Description

The greatest contribution of the decade in which deep learning exploded was not the big models themselves, but a generalized toolkit for training any model by gradient descent. We are now in an era where differential computing can give you the toolkit to train models of any kind. Does a Pythonista well-versed in the PyData stack have to learn an entirely new toolkit and a new array library to gain access to this power?

This tutorial's answer is: no. If you can write NumPy code, then with JAX, differential computing is at your fingertips, with no need to learn a new array library! In this tutorial, you will learn how to use the NumPy-compatible JAX API to write performant numerical models of the world and train them using gradient-based optimization. Along the way, you will write loopy numerical code without loops, think in data cubes, train up your functional programming muscles, generate random numbers completely deterministically (no, this is not an oxymoron!), and preview how to mix neural networks and probabilistic models together... leveraging everything you know about NumPy plus some newly-learned JAX magic sprinkled in!

Audience

This tutorial is for Pythonistas who wish to take their NumPy skills to the next level. As such, it is not a tutorial for Python beginners. Attendees should be comfortable with NumPy syntax; a "dot product" should not feel unfamiliar. That said, advanced knowledge of deep learning and linear algebra is not necessary: each section of the tutorial uses minimally complex examples and exercise puzzles to illustrate its point. Prior experience with functools.partial or writing closures is not necessary (we will go through what they do) but is definitely useful to have.
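
For attendees who have not seen functools.partial before, here is a minimal sketch of what it does; the power function is a hypothetical example, not tutorial material:

    from functools import partial

    def power(base, exponent):
        """Raise base to the given exponent."""
        return base ** exponent

    # partial "freezes" the exponent argument, returning a new
    # one-argument function.
    square = partial(power, exponent=2)
    print(square(3))  # prints 9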

By the end of the tutorial, attendees will be familiar with JAX primitives, including vmap (to replace for-loops), jit (for just-in-time compilation), grad (to calculate gradients), lax.scan (to write loops with carryover), and random (to generate purely deterministic random sequences), as well as how to compose a numerical program from these primitive parts. Attendees will also see these primitives applied in a few toy problem settings, whose natural extensions to complex models will be explained, and will thus be equipped to take what they have learned in the tutorial and extend it to their own problems.
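
As an illustration of how these primitives compose, here is a minimal sketch; the mean-squared-error loss and the toy data are hypothetical choices, not the tutorial's actual exercises:

    import jax
    import jax.numpy as jnp

    # A scalar-output loss: mean squared error of a linear model.
    def loss(w, x, y):
        return jnp.mean((jnp.dot(x, w) - y) ** 2)

    # grad differentiates w.r.t. the first argument; jit compiles the result.
    dloss = jax.jit(jax.grad(loss))

    x = jnp.ones((5, 3))
    y = jnp.zeros(5)
    w = jnp.array([0.1, 0.2, 0.3])
    print(dloss(w, x, y))  # gradient of the loss w.r.t. w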


  • level: advanced
  • category: machine learning

Format

The tutorial will consist of short lecture-style demonstrations, each followed by hands-on activities that reinforce the point made. Jupyter notebooks will feature prominently, hosted on Binder or run locally (at participants' choice). I am happy to address Q&A offline.

Outline

  1. Introduction (10 min)
  2. Loopless loops (20 min)
    1. An introduction to vmap and lax.scan, which let us write loops without writing explicit loop code (see the preview sketch after this outline).
    2. Demo examples: computing averages row-wise (vmap), calculating the lagging average (lax.scan).
    3. Exercises: scaling to row-wise probabilities (vmap), calculating compound interest (lax.scan).
  3. Break (10 min)
  4. Deterministic randomness (30 min)
    1. An introduction to JAX's PRNGKey system and how it contrasts with stateful random number generation systems (see the preview sketch after this outline).
    2. Emphasis on 100% reproducibility within the same session, and the need to split keys for more complex programs.
    3. Demo examples: drawing from Gaussians and Poissons deterministically, simulating stochastic processes with dependencies.
    4. Exercise: simulating a stock market using a Gaussian random walk (lax.scan)
  5. Break (10 min)
  6. Optimized learning (50 min)
    1. An introduction to grad, which gives us the derivative of a scalar-output function with respect to its input, and to JAX's built-in optimizers (see the preview sketch after this outline).
    2. Demo examples: minimizing a simple math function; finding the maximum likelihood parameter values for Gaussian data.
    3. Exercise: programming a robot to find wells in a field (grad)
  7. Break (10 min)
  8. Advanced Fun (40 min)
    1. Tutorial attendees can pick and choose their own adventure here.
    2. Pairwise concatenation of each row in a matrix with every other row (a.k.a. double for-loops using vmap; see the preview sketch after this outline)
    3. Simulating multiple instances of a Gaussian random walk using vmap and split PRNGKeys.
    4. Maximum likelihood estimation of multinomial parameters (vmap and grad)
    5. Fast message passing on graphs with linear algebra.
    6. Neural networks with stax.
  9. Concluding words (10 min)
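
Preview Sketches

To give a flavour of the "Loopless loops" section, here is a minimal sketch of vmap and lax.scan; the decay constant in the lagging average is a hypothetical choice:

    import jax
    import jax.numpy as jnp
    from jax import lax

    mat = jnp.arange(12.0).reshape(3, 4)

    # vmap maps jnp.mean over the leading axis: row-wise means, no Python loop.
    row_means = jax.vmap(jnp.mean)(mat)  # shape (3,)

    # lax.scan expresses a loop with carryover; here, an exponentially
    # weighted lagging average over a sequence.
    def step(carry, x):
        new_carry = 0.9 * carry + 0.1 * x
        return new_carry, new_carry  # (carry forward, output to collect)

    _, lagging_avg = lax.scan(step, 0.0, jnp.arange(10.0))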
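
Similarly, a minimal sketch of the "Deterministic randomness" section: explicit PRNGKeys replace global random state, and keys are split before each independent draw. The seed, shapes, and Poisson rate here are arbitrary choices:

    from jax import random

    key = random.PRNGKey(42)  # explicit, stateless seed

    # Split the key before each independent draw.
    key, subkey = random.split(key)
    gaussian_draws = random.normal(subkey, shape=(5,))

    key, subkey = random.split(key)
    poisson_draws = random.poisson(subkey, lam=3.0, shape=(5,))

    # Determinism by construction: reusing the same subkey reproduces
    # the same numbers exactly.
    same_again = random.poisson(subkey, lam=3.0, shape=(5,))  # == poisson_draws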
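
For the "Optimized learning" section, a minimal sketch of gradient descent with grad; the function being minimized, the learning rate, and the step count are hypothetical:

    import jax

    def f(x):
        return (x - 3.0) ** 2  # a simple function with its minimum at x = 3

    df = jax.grad(f)  # derivative of f w.r.t. its scalar input

    x = 0.0
    for _ in range(100):
        x = x - 0.1 * df(x)  # plain gradient descent steps
    # x is now very close to 3.0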
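
And from the "Advanced Fun" menu, a minimal sketch of replacing a double for-loop with nested vmap for pairwise row concatenation:

    import jax
    import jax.numpy as jnp

    mat = jnp.arange(6.0).reshape(3, 2)

    def concat(row_a, row_b):
        return jnp.concatenate([row_a, row_b])

    # The inner vmap varies row_b while row_a is held fixed; the outer vmap
    # varies row_a. Together they replace a double for-loop over row pairs.
    pairwise = jax.vmap(
        jax.vmap(concat, in_axes=(None, 0)), in_axes=(0, None)
    )(mat, mat)
    # pairwise.shape == (3, 3, 4)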

Past Experience

I have presented talks and tutorials at PyCon, SciPy, PyData, and ODSC. My best-known tutorial is "Network Analysis Made Simple", and I have also led tutorials on Bayesian statistical modelling. The content for this tutorial is freely available online at a website I made, and will be updated to match this proposal's content if accepted.