
from IPython.display import YouTubeVideo, display

YouTubeVideo("hzJJAmX27Wg")

Introduction

In this chapter, we'll introduce you to some idiomatic JAX tools that help you write performant numerical array programs. The main takeaways from this series of notebooks are how to:

  1. Replace slow Python for loops with fast, just-in-time compiled JAX loop constructs,
  2. Generate random numbers deterministically (and yes, this is not an oxymoron!) for reproducibility,
  3. Freely mix and match these tools using JAX's composable function transforms.
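To make the first two takeaways concrete, here is a minimal sketch. The function names `cumsum_python` and `cumsum_scan` are hypothetical illustrations, not part of this repository: the first builds a cumulative sum with an eager Python loop, while the second expresses the same loop with `jax.lax.scan` so that `jax.jit` can compile it as a single operation. The last few lines show that JAX's pseudorandom numbers are a pure function of an explicit key: the same key always yields the same draws, and `jax.random.split` produces fresh keys.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cumsum_python(xs):
    # Hypothetical baseline: an eager Python loop, one iteration at a time.
    total = 0.0
    out = []
    for x in xs:
        total = total + x
        out.append(total)
    return jnp.array(out)


@jax.jit
def cumsum_scan(xs):
    # Hypothetical JAX version: lax.scan threads a "carry" through the loop,
    # and jit compiles the entire loop instead of running it step by step.
    def step(total, x):
        total = total + x
        return total, total  # (new carry, per-step output)

    init = jnp.zeros((), dtype=xs.dtype)  # carry starts at 0 with xs's dtype
    _, out = lax.scan(step, init, xs)
    return out


xs = jnp.arange(5.0)
result = cumsum_scan(xs)  # holds [0, 1, 3, 6, 10]

# Deterministic randomness: the draws depend only on the key.
key = jax.random.PRNGKey(42)
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))  # same key, so identical to `a`

# Split the key to get a new, independent stream of random numbers.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))  # different key, so different draws
```

Reusing a key is what makes results reproducible; splitting it is how you ask for "new" randomness without any hidden global state.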

Contrasting a new approach with what we're used to doing is an effective way to teach and learn, so in each section we'll be explicit about exactly what we're replacing when we write these numerical array programs. My hope is that this will show you clearly that structuring your array programs in a composable, atomic fashion lets you take advantage of JAX's composable function transforms to write fast, compiled functions. For good measure, we'll also contrast these against pure Python programs, so you can witness for yourself how powerful JAX's ideas are... and appreciate how much effort has gone into making the whole thing NumPy-compatible!
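As a rough sketch of what "composable and atomic" buys you: write one small function for a single example, then stack transforms on top of it. The loss function `squared_error` below is a hypothetical illustration (a one-parameter linear model), not code from this repository. `jax.grad` differentiates it, `jax.vmap` maps that gradient over a batch, and `jax.jit` compiles the result; a pure Python list comprehension computes the same numbers for comparison.

```python
import jax
import jax.numpy as jnp


def squared_error(w, x, y):
    # Hypothetical atomic function: loss of the model w * x on one example.
    return (w * x - y) ** 2


# grad differentiates with respect to the first argument, w.
per_example_grad = jax.grad(squared_error)

# vmap maps over a batch of (x, y) pairs while holding w fixed (in_axes=None).
batched_grad = jax.vmap(per_example_grad, in_axes=(None, 0, 0))

# jit compiles the composed function.
fast_batched_grad = jax.jit(batched_grad)

w = 2.0
xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([1.0, 1.0, 1.0])
grads = fast_batched_grad(w, xs, ys)

# Pure Python equivalent using the analytic derivative d/dw = 2 * x * (w * x - y).
grads_python = [2 * x * (w * x - y) for x, y in zip([1.0, 2.0, 3.0], [1.0, 1.0, 1.0])]
# Both give [2.0, 12.0, 30.0]
```

Because each transform takes a function and returns a function, the order and combination are up to you; the atomic `squared_error` never has to know about batching or compilation.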

Prerequisites

To get the most out of this notebook, you need only be familiar with the NumPy API and with writing functions. An appreciation of functools.partial will help a bit, because we use it a lot when writing JAX programs. However, not everybody has prior experience with partially applied functions, so we will introduce the idea midway through, in a just-in-time fashion as well.
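For readers new to functools.partial, here is a brief sketch of what it does and one common way it shows up alongside JAX. The functions `scale` and `first_n_squares` are hypothetical examples. `partial` fixes some arguments of a function and returns a new function of the remaining ones; a typical JAX pattern is `partial(jax.jit, static_argnums=...)` as a decorator, which marks an argument as static so that a value like an array length is known at compile time.

```python
from functools import partial

import jax
import jax.numpy as jnp


def scale(x, factor):
    # Hypothetical helper: multiply x by factor.
    return x * factor


# partial "freezes" factor=2.0, giving a new one-argument function.
double = partial(scale, factor=2.0)


# partial lets us pass options to jax.jit while still using it as a decorator.
# static_argnums=0 makes n a compile-time constant, which jnp.arange needs
# because the output shape depends on n.
@partial(jax.jit, static_argnums=0)
def first_n_squares(n):
    return jnp.arange(n) ** 2


doubled = double(3.0)        # 6.0
squares = first_n_squares(4)  # holds [0, 1, 4, 9]
```

Each distinct value of a static argument triggers a fresh compilation, so this pattern suits arguments that take only a few different values.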

If you've gone through tutorial.ipynb, the main tutorial notebook for this repository, then you'll already have some appreciation of JAX's composable transforms. You'll also have seen how we wrote some loops in there, and hopefully appreciate how much faster things run when we use JAX's looping constructs instead.