Magical NumPy with JAX (SciPy 2021)

scipy2021

Title

Magical NumPy with JAX

Abstract

The most significant contribution of the decade in which deep learning exploded was not the big models themselves but a generalized toolkit for training any model by gradient descent. We're now in an era where differential computing can give you the toolkit to train models of any kind. Does a Pythonista well-versed in the PyData stack have to learn an entirely new toolkit, and a new array library, to get access to this power?

This tutorial's answer: if you can write NumPy code, then with JAX, differential computing is at your fingertips, with no need to learn a new array library! In this tutorial, you will learn how to use the NumPy-compatible JAX API to write performant numerical models of the world and train them using gradient-based optimization. Along the way, you will write loopy numerical code without loops, think in data cubes, get your functional programming muscles trained up, generate random numbers deterministically (no, this is not an oxymoron!), and preview how to mix neural networks and probabilistic models... leveraging everything you know about NumPy plus some newly learned JAX magic sprinkled in!
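As a minimal sketch of the idea (the toy quadratic objective is an illustrative choice, not taken from the tutorial materials), the function below is ordinary NumPy-style code written against jax.numpy. jax.grad turns it into a function that returns its gradient, which a plain gradient-descent loop can then use to find the minimum.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Toy objective with its minimum at x = 3.0 in every coordinate;
    # written with jax.numpy exactly as one would write it with NumPy.
    # jax.grad requires a scalar output, hence the sum.
    return jnp.sum((x - 3.0) ** 2)


# jax.grad transforms f into a new function that returns df/dx.
df = jax.grad(f)

x = jnp.array([0.0, 10.0])
learning_rate = 0.1
for _ in range(100):
    # Plain gradient descent: take a small step against the gradient.
    x = x - learning_rate * df(x)

print(x)  # both entries end up very close to 3.0
```

The learning rate of 0.1 and the 100 steps are arbitrary choices for this toy problem; the tutorial also introduces JAX's built-in optimizers, which are not shown here.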

Keywords

differential computing
numpy
jax
array computing

Track

Machine Learning and Data Science

Short Summary (<100 words)

JAX's magic brings your NumPy game to the next level! Come learn how to write loop-less numerical loops, optimize any differentiable function, jit-compile your programs, and make your random numbers reproducible: basically, equip yourself with a bag of tricks for writing robust numerical programs.
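To illustrate what jit-compiling means in practice, here is a minimal sketch that assumes a hand-written Gaussian log-density as the example function (an illustrative choice, not from the tutorial). jax.jit wraps an ordinary jax.numpy function and compiles it with XLA; the compiled version returns the same values as the original.

```python
import jax
import jax.numpy as jnp


def gaussian_logpdf(x, mu, sigma):
    # Element-wise log-density of a Gaussian, written in plain NumPy style.
    return (
        -0.5 * ((x - mu) / sigma) ** 2
        - jnp.log(sigma)
        - 0.5 * jnp.log(2 * jnp.pi)
    )


# jax.jit traces the function for the given input shapes and dtypes and
# compiles it with XLA; later calls with matching shapes reuse the compiled code.
fast_logpdf = jax.jit(gaussian_logpdf)

x = jnp.linspace(-3.0, 3.0, 7)
print(jnp.allclose(fast_logpdf(x, 0.0, 1.0), gaussian_logpdf(x, 0.0, 1.0)))
```

Compilation happens on the first call for a given input shape and dtype, so any speedup shows up on repeated calls rather than on the first one.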

Tutorial Description

The most significant contribution of the decade in which deep learning exploded was not the big models themselves but a generalized toolkit for training any model by gradient descent. We're now in an era where differential computing can give you the toolkit to train models of any kind. Does a Pythonista well-versed in the PyData stack have to learn an entirely new toolkit, and a new array library, to get access to this power?

This tutorial's answer: if you can write NumPy code, then with JAX, differential computing is at your fingertips, with no need to learn a new array library! In this tutorial, you will learn how to use the NumPy-compatible JAX API to write performant numerical models of the world and train them using gradient-based optimization. Along the way, you will write loopy numerical code without loops, think in data cubes, get your functional programming muscles trained up, generate random numbers deterministically (no, this is not an oxymoron!), learn tips and tricks for keeping array programs sane, and preview how to mix neural networks and probabilistic models... leveraging everything you know about NumPy plus some newly learned JAX magic sprinkled in!
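Deterministic random numbers can be sketched as follows, assuming an arbitrary seed of 42 and small illustrative shapes. JAX's random functions take an explicit key, so reusing a key reproduces a draw exactly, while jax.random.split produces new keys for fresh draws. The last lines combine split keys with vmap to run several independent draws at once, the same pattern the advanced section uses for multiple random walks.

```python
import jax
import jax.numpy as jnp

# Random functions in JAX are pure: the same key always yields the same numbers.
key = jax.random.PRNGKey(42)
a = jax.random.normal(key, shape=(3,))
b = jax.random.normal(key, shape=(3,))
print(jnp.array_equal(a, b))  # True: reusing a key reproduces the draw

# For fresh, independent randomness, split the key instead of reusing it.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, shape=(3,))

# split can also hand out many keys at once, e.g. one per simulation,
# and vmap then runs one draw per key without a Python loop.
keys = jax.random.split(key, 5)
draws = jax.vmap(lambda k: jax.random.normal(k, shape=(3,)))(keys)
print(draws.shape)  # (5, 3)
```

Because every draw is tied to an explicit key, rerunning the program with the same seed reproduces every number, which is what makes stochastic programs debuggable.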

Prerequisite knowledge: if you're comfortable with the NumPy API, you'll be well-equipped for this tutorial. This tutorial goes beyond simple deep learning: instead of learning how to use a deep learning framework, you'll leave with a toolkit for writing high-performance numerical models of the world and optimizing them with gradient descent. Familiarity with Jupyter will help. Local setup is not necessary; Binder will be available as an option for participants.

Tutorial Outline

  1. Introduction (10 min)
  2. Loopless loops (20 min)
    1. An introduction to vmap and lax.scan, which let you express loops without writing loop code.
    2. Demo examples: computing averages row-wise (vmap), calculating the lagging average (lax.scan).
    3. Exercises: scaling to row-wise probabilities (vmap), calculating compound interest (lax.scan).
  3. Break (10 min)
  4. Deterministic randomness (30 min)
    1. An introduction to JAX's PRNGKey system and how it contrasts with stateful random number generation systems.
    2. Emphasis on 100% reproducibility within the same session, and on the need to split keys in more complex programs.
    3. Demo examples: drawing from Gaussians and Poissons deterministically, simulating stochastic processes with dependencies.
    4. Exercise: simulating a stock market using a Gaussian random walk (lax.scan)
  5. Break (10 min)
  6. Optimized learning (50 min)
    1. An introduction to grad, which gives us the gradient of a scalar-output function with respect to its input, and to JAX's built-in optimizers
    2. Demo examples: minimizing a simple math function; finding the maximum likelihood parameter values for Gaussian data.
    3. Exercise: programming a robot to find wells in a field (grad)
  7. Break (10 min)
  8. Advanced Fun (40 min)
    1. Tutorial attendees can choose their own adventure here.
    2. Pairwise concatenation of each row in a matrix (a.k.a. double for-loops using vmap)
    3. Simulating multiple instances of a Gaussian random walk using vmap and split PRNGKeys.
    4. Maximum likelihood estimation of multinomial parameters (vmap and grad)
    5. Fast message passing on graphs with linear algebra.
    6. Neural networks with stax.
  9. Concluding words (10 min)
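The loopless-loop demos in the outline above can be sketched roughly as follows. The data are made up, and the scan example computes a cumulative running mean as one simple reading of a "lagging average"; the tutorial's own demo may define it differently. vmap maps a one-row function over the rows of a matrix, and lax.scan threads a carry (here a running total and a count) through a sequence without a Python for-loop.

```python
import jax
import jax.numpy as jnp
from jax import lax

data = jnp.arange(12.0).reshape(3, 4)


# vmap: write the function for a single row, then map it over axis 0.
def row_mean(row):
    return jnp.mean(row)


row_means = jax.vmap(row_mean)(data)
print(row_means)  # row means: 1.5, 5.5, 9.5


# lax.scan: carry state from one step to the next without a Python loop.
def running_mean_step(carry, x):
    total, count = carry
    total = total + x
    count = count + 1
    # Return the updated carry and the value to stack into the output.
    return (total, count), total / count


xs = jnp.array([2.0, 4.0, 6.0, 8.0])
# The initial carry uses explicit arrays so its dtypes match the updated carry.
init = (jnp.array(0.0), jnp.array(0))
_, running_means = lax.scan(running_mean_step, init, xs)
print(running_means)  # running means: 2.0, 3.0, 4.0, 5.0
```

The same scan pattern, with a different step function, fits the compound-interest and Gaussian random-walk exercises that the outline pairs with lax.scan.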

Additional Tutorial Information

A preliminary version of the tutorial is available on GitHub: https://github.com/ericmjl/dl-workshop.

Tutorial Prerequisites

Prerequisite knowledge: if you're comfortable with the NumPy API, you'll be well-equipped for this tutorial. This tutorial goes beyond simple deep learning: instead of learning how to use a deep learning framework, you'll leave with a toolkit for writing high-performance numerical models of the world and optimizing them with gradient descent. Familiarity with Jupyter will help. Local setup is not necessary; Binder will be available as an option for participants.

Instructor Bio

Eric is a data scientist at the Novartis Institutes for Biomedical Research. There, he conducts biomedical data science research, with a focus on using Bayesian statistical methods in the service of making medicines for patients. Prior to Novartis, he was an Insight Health Data Fellow in the summer of 2017, and defended his doctoral thesis in the Department of Biological Engineering at MIT in the spring of 2017.

Eric is also an open source software developer, and has led the development of pyjanitor, a clean API for cleaning data in Python, and nxviz, a visualization package for NetworkX. In addition, he gives back to the open source community through code contributions to multiple projects.

His personal life motto is found in the Gospel of Luke 12:48.