Magical NumPy with JAX (SciPy 2021)
differential computing
numpy
jax
array computing
Machine Learning and Data Science
JAX's magic brings your NumPy game to the next level! Come learn how to write loop-less numerical loops, optimize any function, jit-compile your programs, and take reliable control of stochastic numbers - in short, equip yourself with a bag of tricks to help you write robust numerical programs.
The most significant contribution of the decade in which deep learning exploded was not the big models themselves, but a generalized toolkit for training any model by gradient descent. We're now in an era where differential computing can give you the toolkit to train models of any kind. Does a Pythonista well-versed in the PyData stack have to learn an entirely new toolkit, a new array library, to have access to this power?
This tutorial's answer is as follows: If you can write NumPy code, then with JAX, differential computing is at your fingertips with no need to learn a new array library! In this tutorial, you will learn how to use the NumPy-compatible JAX API to write performant numerical models of the world and train them using gradient-based optimization. Along the way, you will write loopy numerical code without loops, think in data cubes, get your functional programming muscles trained up, generate random numbers deterministically (no, this is not an oxymoron!), learn tips and tricks for handling array programs sanely, and preview how to mix neural networks and probabilistic models... leveraging everything you know about NumPy plus some newly-learned JAX magic sprinkled in!
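To make the "if you can write NumPy, you can do differential computing" claim concrete, here is a minimal sketch of the pattern. It is not taken from the tutorial materials: the function and variable names (loss, params, and so on) are illustrative, and only the standard jax.numpy, jax.grad, and jax.jit APIs are assumed.

```python
import jax
import jax.numpy as jnp

# An ordinary NumPy-style function: mean squared error of a straight-line model.
def loss(params, x, y):
    w, b = params
    return jnp.mean((w * x + b - y) ** 2)

x = jnp.linspace(0.0, 1.0, 50)
y = 3.0 * x + 1.0  # data generated by w=3, b=1

# grad gives the derivative of the scalar-valued loss w.r.t. its first argument;
# jit compiles the whole computation with XLA.
dloss = jax.jit(jax.grad(loss))

params = (0.0, 0.0)
for _ in range(200):
    dw, db = dloss(params, x, y)
    params = (params[0] - 0.1 * dw, params[1] - 0.1 * db)  # plain gradient descent
```

Writing a plain NumPy-style function and then transforming it with grad, jit, or vmap is the kind of workflow the tutorial builds toward.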
Pre-requisite knowledge: If you're comfortable with the NumPy API, then you'll be well-equipped for this tutorial. This tutorial will take you beyond simple deep learning: instead of learning how to use a deep learning framework, you'll leave with a toolkit to write high-performance numerical models of the world and optimize them with gradient descent. Familiarity with Jupyter will help. Local setup is not necessary; Binder will be an option for tutorial participants.
Tutorial outline:

- vmap and lax.scan, which allow for loop writing without writing loop code.
  - … (vmap), calculating the lagging average (lax.scan).
  - … (vmap), calculating compound interest (lax.scan).
- The PRNGKey system and its contrast to stateful random number generation systems.
  - … (lax.scan)
- grad, which gives us the derivative of a scalar-output function w.r.t. its input, and JAX's built-in optimizers.
  - … (grad)
  - … (vmap)
  - … vmap and split PRNGKeys.
  - … (vmap and grad)
- … stax.

A preliminary version of the tutorial is available on GitHub: https://github.com/ericmjl/dl-workshop.
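For a flavor of the loop-less loops in the outline above, here is a rough sketch of compound interest computed with lax.scan and broadcast over several principals with vmap. It is illustrative only (not code from the linked repository), the function names are made up, and only the standard jax.vmap and jax.lax.scan APIs are assumed.

```python
import jax.numpy as jnp
from jax import lax, vmap

def compound(principal, rates):
    # Thread the balance through time as the scan "carry"; no Python for-loop needed.
    def step(balance, rate):
        new_balance = balance * (1.0 + rate)
        return new_balance, new_balance  # (carry forward, record this period)
    _, history = lax.scan(step, principal, rates)
    return history

rates = jnp.full(12, 0.01)             # 1% interest per period, 12 periods
principals = jnp.array([100.0, 500.0])

print(compound(principals[0], rates))  # balance after each period
# vmap maps the same function over every principal at once: a loop with no loop code.
print(vmap(compound, in_axes=(0, None))(principals, rates))
```

Similarly, a small sketch of the deterministic-randomness idea: JAX draws random numbers from an explicit PRNGKey, so the same key always reproduces the same values, and splitting a key yields independent streams. Again, these are standard jax.random calls, not tutorial code.

```python
from jax import random

key = random.PRNGKey(42)           # randomness is an explicit, immutable value
key, subkey = random.split(key)    # split rather than reuse to get fresh draws
draws = random.normal(subkey, shape=(3,))
print(draws)                       # re-running with the same key gives identical numbers
```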
Eric is a data scientist at the Novartis Institutes for Biomedical Research. There, he conducts biomedical data science research, with a focus on using Bayesian statistical methods in the service of making medicines for patients. Prior to Novartis, he was an Insight Health Data Fellow in the summer of 2017, and defended his doctoral thesis in the Department of Biological Engineering at MIT in the spring of 2017.
Eric is also an open source software developer, and has led the development of pyjanitor, a clean API for cleaning data in Python, and nxviz, a visualization package for NetworkX. In addition, he gives back to the open source community through code contributions to multiple projects.
His personal life motto is found in the Gospel of Luke 12:48.