Conference Proposals

Magical NumPy with JAX (PyCon 2021)

pycon2021

Magical NumPy with JAX

Description

The greatest contribution of the decade in which deep learning exploded was not the big models themselves, but a generalized toolkit for training any model by gradient descent. We're now in an era where differential computing can give you the tools to train models of any kind. Does a Pythonista well-versed in the PyData stack have to learn an entirely new toolkit and a new array library to gain access to this power?

This tutorial's answer is as follows: if you can write NumPy code, then with JAX, differential computing is at your fingertips, with no need to learn a new array library! In this tutorial, you will learn how to use the NumPy-compatible JAX API to write performant numerical models of the world and train them using gradient-based optimization. Along the way, you will write loopy numerical code without loops, think in data cubes, get your functional programming muscles trained up, generate random numbers completely deterministically (no, this is not an oxymoron!), and preview how to mix neural networks and probabilistic models together... leveraging everything you know about NumPy plus some newly-learned JAX magic sprinkled in!

Audience

This tutorial is for Pythonistas who wish to take their NumPy skills to the next level. As such, this is not a tutorial for Python beginners. Attendees should know how to use NumPy syntax; for them, a "dot product" should not feel unfamiliar. That said, advanced knowledge in deep learning and linear algebra is not necessary; for each of the sections in the tutorial, there will be minimally complex examples and exercise puzzles that illustrate the point. Prior experience using functools.partial or writing closures is not necessary (as we will go through what they do) but definitely useful to have.

By the end of the tutorial, attendees will be familiar with the core JAX primitives: vmap (to replace for-loops), jit (for just-in-time compilation), grad (to calculate gradients), lax.scan (to write loops with carry-over), and random (to generate purely deterministic random sequences), as well as how to compose these primitives into a larger numerical program. Attendees will also see these primitives applied in a few toy problem settings, along with an explanation of how they extend naturally to more complex models, and will thus be equipped to take what they have learned in the tutorial and apply it to their own problems.
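
To illustrate the kind of composition covered, here is a minimal sketch (not taken from the tutorial materials; the loss function and data are made up for illustration) of grad and jit working together on plain NumPy-style code:

```python
import jax
import jax.numpy as jnp

# A plain NumPy-style function: mean squared error of a linear model.
def loss(params, x, y):
    w, b = params
    return jnp.mean((w * x + b - y) ** 2)

# Compose primitives: grad builds the gradient function, jit compiles it.
dloss = jax.jit(jax.grad(loss))

x = jnp.linspace(0.0, 1.0, 50)
y = 3.0 * x + 1.0
grads = dloss((0.0, 0.0), x, y)  # gradients w.r.t. (w, b)
```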


  • level: advanced
  • category: machine learning

Format

The tutorial will consist of short lecture-style demonstrations, each followed by hands-on activities that reinforce the point just made. Jupyter notebooks will feature prominently, run either on Binder or locally (at each participant's choice). I am happy to address Q&A offline.

Outline

  1. Introduction (10 min)
  2. Loopless loops (20 min)
    1. An introduction to vmap and lax.scan, which let you write loops without writing explicit loop code (see the first sketch after this outline).
    2. Demo examples: computing averages row-wise (vmap), calculating the lagging average (lax.scan).
    3. Exercises: scaling to row-wise probabilities (vmap), calculating compound interest (lax.scan).
  3. Break (10 min)
  4. Deterministic randomness (30 min)
    1. An introduction to JAX's PRNGKey system and its contrast to stateful random number generation systems.
    2. Emphasis on 100% reproducibility within the same session, and the need to split keys in more complex programs (see the second sketch after this outline).
    3. Demo examples: drawing from Gaussians and Poissons deterministically, simulating stochastic processes with dependencies.
    4. Exercise: simulating a stock market using a Gaussian random walk (lax.scan)
  5. Break (10 min)
  6. Optimized learning (50 min)
    1. An introduction to grad, which gives us the derivative of a scalar-output function with respect to its inputs, and to JAX's built-in optimizers (see the third sketch after this outline).
    2. Demo examples: minimizing a simple math function; finding the maximum likelihood parameter values for Gaussian data.
    3. Exercises: programming a robot to find wells in a field (grad)
  7. Break (10 min)
  8. Advanced Fun (40 min)
    1. Tutorial attendees can pick-and-choose their own adventure here.
    2. Pairwise concatenation of the rows of a matrix (i.e. a double for-loop, written with nested vmap)
    3. Simulating multiple instances of a Gaussian random walk using vmap and split PRNGKeys.
    4. Maximum likelihood estimation of multinomial parameters (vmap and grad)
    5. Fast message passing on graphs with linear algebra.
    6. Neural networks with stax.
  9. Concluding words (10 min)
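
The following sketches preview the outline topics above; they are not the tutorial's actual code, and the data in each is made up for illustration. First, loopless loops: row-wise averages with vmap and a lagging (running) average with lax.scan.

```python
import jax
import jax.numpy as jnp

# Row-wise averages without a for-loop: vmap maps jnp.mean over rows.
mat = jnp.arange(12.0).reshape(3, 4)
row_means = jax.vmap(jnp.mean)(mat)  # one mean per row

# A lagging (running) average without a for-loop: lax.scan carries state.
def cumsum_step(total, x):
    total = total + x
    return total, total  # (new carry, per-step output)

xs = jnp.array([1.0, 2.0, 3.0, 4.0])
_, cumulative = jax.lax.scan(cumsum_step, jnp.array(0.0), xs)
lagging_average = cumulative / jnp.arange(1, xs.size + 1)
```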
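
Second, deterministic randomness: explicit PRNG keys, split into independent streams, in contrast to NumPy's stateful global generator.

```python
import jax

key = jax.random.PRNGKey(42)  # explicit state: the same key always yields the same draws
key, g_key, p_key = jax.random.split(key, 3)  # split into independent streams

gaussians = jax.random.normal(g_key, shape=(5,))
poissons = jax.random.poisson(p_key, lam=3.0, shape=(5,))
# Re-using g_key reproduces gaussians exactly; fresh draws require a fresh split.
```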
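
Third, optimized learning with grad, written here as plain gradient-descent updates rather than JAX's optimizer utilities.

```python
import jax

# Minimize a simple function f(x) = (x - 3)^2 by gradient descent.
def f(x):
    return (x - 3.0) ** 2

df = jax.grad(f)  # derivative of a scalar-output function w.r.t. its input

x = 0.0
for _ in range(100):
    x = x - 0.1 * df(x)  # gradient-descent update; x approaches 3.0
```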

Past Experience

I have presented talks and tutorials at PyCon, SciPy, PyData, and ODSC. My best-known tutorial is "Network Analysis Made Simple", and I have also led tutorials on Bayesian statistical modelling. The content for this tutorial is freely available online at this website that I made, and will be updated to reflect this proposal's content if accepted.

index

This is the landing page for my notes.

This is 100% inspired by Andy Matuschak's famous notes page. I'm not technically skilled enough to replicate the full "Andy Mode", though, so I just did some simple hacks. If you're curious how these notes are compiled, check out the summary in How these notes are made into HTML pages.

This is my "notes garden". I tend to it on a daily basis, and it contains some of my less fully-formed thoughts. Nothing here is intended to be cited, as the link structure evolves over time. The notes are best viewed on a desktop/laptop computer, because of the use of hovers for previews.

There's no formal "navigation" or "search" for these pages. To go somewhere, click on any of the "high-level" notes below, and enjoy.

  1. Notes on statistics
  2. Notes on differential computing
  3. The State of Data Science
  4. Network science
  5. Scholarly readings
  6. Software skills for data scientists
  7. The Data Science Programming Newsletter MOC
  8. Life and computer hacks
  9. Reading Bazaar
  10. Blog drafts
  11. Conference Proposals

Magical NumPy with JAX (SciPy 2021)

scipy2021

Title

Magical NumPy with JAX

Abstract

The most significant contribution of the decade in which deep learning exploded was not the big models themselves, but a generalized toolkit for training any model by gradient descent. We're now in an era where differential computing can give you the tools to train models of any kind. Does a Pythonista well-versed in the PyData stack have to learn an entirely new toolkit and a new array library to gain access to this power?

This tutorial's answer is as follows: if you can write NumPy code, then with JAX, differential computing is at your fingertips, with no need to learn a new array library! In this tutorial, you will learn how to use the NumPy-compatible JAX API to write performant numerical models of the world and train them using gradient-based optimization. Along the way, you will write loopy numerical code without loops, think in data cubes, get your functional programming muscles trained up, generate random numbers deterministically (no, this is not an oxymoron!), and preview how to mix neural networks and probabilistic models... leveraging everything you know about NumPy plus some newly-learned JAX magic sprinkled in!

Keywords

differential computing
numpy
jax
array computing

Track

Machine Learning and Data Science

Short Summary (<100 words)

JAX's magic brings your NumPy game to the next level! Come learn how to write loop-less numerical loops, optimize any function, jit-compile your programs, and take deterministic control of your random numbers - in short, equip yourself with a bag of tricks for writing robust numerical programs.

Tutorial Description

The most significant contribution of the decade in which deep learning exploded was not the big models themselves, but a generalized toolkit for training any model by gradient descent. We're now in an era where differential computing can give you the tools to train models of any kind. Does a Pythonista well-versed in the PyData stack have to learn an entirely new toolkit and a new array library to gain access to this power?

This tutorial's answer is as follows: if you can write NumPy code, then with JAX, differential computing is at your fingertips, with no need to learn a new array library! In this tutorial, you will learn how to use the NumPy-compatible JAX API to write performant numerical models of the world and train them using gradient-based optimization. Along the way, you will write loopy numerical code without loops, think in data cubes, get your functional programming muscles trained up, generate random numbers deterministically (no, this is not an oxymoron!), learn tips and tricks for handling array programs sanely, and preview how to mix neural networks and probabilistic models... leveraging everything you know about NumPy plus some newly-learned JAX magic sprinkled in!

Pre-requisite knowledge: if you're comfortable with the NumPy API, you'll be well-equipped for this tutorial. This tutorial will take you beyond simple deep learning: instead of learning how to use a deep learning framework, you'll leave with a toolkit for writing high-performance numerical models of the world and optimizing them with gradient descent. Familiarity with Jupyter will help. Local setup is not necessary; Binder will be an option for tutorial participants.

Tutorial Outline

  1. Introduction (10 min)
  2. Loopless loops (20 min)
    1. An introduction to vmap and lax.scan, which let you write loops without writing explicit loop code.
    2. Demo examples: computing averages row-wise (vmap), calculating the lagging average (lax.scan).
    3. Exercises: scaling to row-wise probabilities (vmap), calculating compound interest (lax.scan).
  3. Break (10 min)
  4. Deterministic randomness (30 min)
    1. An introduction to JAX's PRNGKey system and its contrast to stateful random number generation systems.
    2. Emphasis on 100% reproducibility within the same session, and the need to split keys in more complex programs.
    3. Demo examples: drawing from Gaussians and Poissons deterministically, simulating stochastic processes with dependencies.
    4. Exercise: simulating a stock market using a Gaussian random walk (lax.scan)
  5. Break (10 min)
  6. Optimized learning (50 min)
    1. An introduction to grad, which gives us the derivative of a scalar-output function with respect to its inputs, and to JAX's built-in optimizers.
    2. Demo examples: minimizing a simple math function; finding the maximum likelihood parameter values for Gaussian data.
    3. Exercises: programming a robot to find wells in a field (grad)
  7. Break (10 min)
  8. Advanced Fun (40 min)
    1. Tutorial attendees can pick-and-choose their own adventure here.
    2. Pairwise concatenation of the rows of a matrix (i.e. a double for-loop, written with nested vmap)
    3. Simulating multiple instances of a Gaussian random walk using vmap and split PRNGKeys.
    4. Maximum likelihood estimation of multinomial parameters (vmap and grad)
    5. Fast message passing on graphs with linear algebra.
    6. Neural networks with stax.
  9. Concluding words (10 min)

Additional Tutorial Information

A preliminary version of the tutorial is available on GitHub: https://github.com/ericmjl/dl-workshop.

Tutorial Prerequisites

Pre-requisite knowledge: if you're comfortable with the NumPy API, you'll be well-equipped for this tutorial. This tutorial will take you beyond simple deep learning: instead of learning how to use a deep learning framework, you'll leave with a toolkit for writing high-performance numerical models of the world and optimizing them with gradient descent. Familiarity with Jupyter will help. Local setup is not necessary; Binder will be an option for tutorial participants.

Instructor Bio

Eric is a data scientist at the Novartis Institutes for Biomedical Research. There, he conducts biomedical data science research, with a focus on using Bayesian statistical methods in the service of making medicines for patients. Prior to Novartis, he was an Insight Health Data Fellow in the summer of 2017, and defended his doctoral thesis in the Department of Biological Engineering at MIT in the spring of 2017.

Eric is also an open source software developer, and has led the development of pyjanitor, a clean API for cleaning data in Python, and nxviz, a visualization package for NetworkX. In addition, he gives back to the open source community through code contributions to multiple projects.

His personal life motto is found in the Gospel of Luke 12:48.

Bayesian Data Science by Simulation (SciPy 2021)

scipy2021

Title

Bayesian Data Science by Simulation

Abstract

This tutorial introduces participants to using a probabilistic programming language, PyMC3, to perform a variety of statistical inference tasks. Through hands-on instruction on real-world examples (simplified for pedagogical purposes), we will show you how to do parameter estimation and inference, with a specific focus on building towards generalized Bayesian A/B/C/D/E… testing (a.k.a. multi-group experimental comparison), hierarchical modelling, and arbitrary curve regression.

Keywords

bayesian
simulation
data science

Additional Tutorial Information

The tutorial material is available at https://github.com/ericmjl/bayesian-stats-modelling-tutorial.

Tutorial Prerequisites

Tutorial participants should be familiar with the NumPy API, as well as matplotlib for plotting. They should feel comfortable operating in a Jupyter environment. Binder is available as a backup compute option for participants.

Bayesian Data Science by Simulation (PyCon 2021)

pycon2021

Title: Bayesian Data Science by Probabilistic Programming

Description

This tutorial introduces participants to using a probabilistic programming language, PyMC3, to perform a variety of statistical inference tasks. Through hands-on instruction on real-world examples (simplified for pedagogical purposes), we will show you how to do parameter estimation and inference, with a specific focus on building towards generalized Bayesian A/B/C/D/E… testing (a.k.a. multi-group experimental comparison), hierarchical modelling, and arbitrary curve regression.

Audience
This tutorial is intended for Pythonistas who are interested in using a probabilistic programming language to learn how to do flexible Bayesian data analysis without needing to know the fancy math behind it. 

Tutorial participants should come equipped with working knowledge of numpy, probability distributions and where they are commonly used, and frequentist statistics. By the end of the tutorial, participants will have code that they can use and modify for their own problems. 

Participants will also have at least one round of practice with the Bayesian modelling loop, starting from model (re-)formulation and ending in model checking. More generally, by the end of both sessions, participants should be equipped with the ability to think through and describe a problem using arbitrary (but suitable) statistical distributions and link functions.

Format
This is a hands-on tutorial, with a series of exercises for each topic covered. Roughly 60% of the time will be spent on exercises, and 40% on lecture-style material and discussion.

Internet access is not strictly necessary, but can be useful (to access Binder and its text editor and terminal emulator) if you encounter difficulties in setting up locally.

The general educational strategy is to focus on one model class (the Beta-Binomial model) for an extended period, and use it to gradually introduce more advanced concepts, such as vectorization in PyMC3/Theano and hierarchical modelling. Repetition of the workflow will also reinforce good practices. At the end, a lecture-style delivery will introduce Bayesian regression modelling.

Outline

The timings below indicate at what time we begin that section. Everything is relative to 0 minutes.

0th min: Introduction

In this section, we will cover some basic topics that are useful for the tutorial:

  • Probability as “assigning credibility over values”.
  • Probability distributions:
    • key parameters and shapes (i.e. likelihood functions)
    • what process they are modelling.

A combination of simulation and lectures will be used in this section.

15th min: Warm-Up With The Coin Flip

This section gets participants warmed up and familiar with PyMC3 and Theano syntax. We will use the classic coin flip to build intuition for the Beta-Binomial model, a workhorse model that can be used in many places (a minimal sketch follows the list below). Along the way, we will learn the basics of PyMC3, including:

  • How to structure a Bayesian model in PyMC3: priors, likelihood.
  • How to sample from the posterior: the "Inference Button".
  • How to check a model for correctness: posterior predictive checks.
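
A minimal sketch of this kind of model, assuming illustrative coin-flip counts (the tutorial's actual data and priors may differ):

```python
import pymc3 as pm

# Illustrative data: 7 heads out of 10 tosses.
n_tosses, n_heads = 10, 7

with pm.Model() as coin_model:
    p = pm.Beta("p", alpha=1, beta=1)                            # prior on P(heads)
    obs = pm.Binomial("obs", n=n_tosses, p=p, observed=n_heads)  # likelihood
    trace = pm.sample(2000)                                      # the "Inference Button"
    ppc = pm.sample_posterior_predictive(trace)                  # posterior predictive check
```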

40th min: Break for 5 min.

45th min: Extending the coin flip to two groups

This section extends the Beta-Binomial model to two groups. Here, we will compare the results of A/B testing an e-commerce site design, and use it to introduce Bayesian estimation as a way to provide richer information than a simple t-test.

We will begin by implementing the model the "manual" way, in which there is explicit duplication of code. Then, we will see how to vectorize this model (see the sketch after this section). Many visuals will be provided to help participants understand what is going on.

Through this example, we will extend our knowledge with the ability to use vectorization to express what we would otherwise write in a for-loop.

More importantly, we will engage in a comparison with how we might do this in a frequentist setting, discover that the modelling assumptions of the t-test do not fit the problem setting, and conclude that a flexible modelling language is necessary.
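
A minimal sketch of the vectorized two-group model, assuming hypothetical visitor and conversion counts for the two designs:

```python
import numpy as np
import pymc3 as pm

# Hypothetical A/B test data: visitors and conversions for designs A and B.
visitors = np.array([1000, 1000])
conversions = np.array([46, 61])

with pm.Model() as ab_model:
    # One Beta prior per group, expressed as a single vectorized statement.
    p = pm.Beta("p", alpha=1, beta=1, shape=2)
    obs = pm.Binomial("obs", n=visitors, p=p, observed=conversions)
    diff = pm.Deterministic("diff", p[1] - p[0])  # the quantity we care about
    trace = pm.sample(2000)
```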

1 hr 30th min: Break, 10 min.

1 hr 40th min: Hierarchical Beta-Bernoulli

In this section, we will extend our use of the Beta-Binomial model in a multi-group setting. Here, we will use hockey goalies’ save percentage as an example. A particular quirk of this dataset is that there are players that have very few data points.

We will first implement the model using lessons learned from the two-group case (i.e. how to vectorize our model), but soon after fitting the model and critiquing it, participants should discover that there are qualitative issues: namely, implausibly wide posterior distributions over the abilities of players with very few data points. (It should take 15 minutes or so to reach here.)

We will then introduce the idea of a hierarchical model, and use a code-along format to build up the model, mainly by working backwards from the likelihood up to the parental priors involved (see the sketch after this section). (It should take another 15 minutes to reach here.)

Following that, we will look at a comparison between the posterior distributions for the non-hierarchical and hierarchical model. (10 minutes to finish this).

By the end of this section, participants should have a fairly complete view of the Bayesian modelling workflow, and should have a well-grounded anchoring example of the general way to model the data.
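
A minimal sketch of the hierarchical model, assuming hypothetical shots-faced and saves-made counts per goalie (the tutorial's exact parameterization may differ):

```python
import numpy as np
import pymc3 as pm

# Hypothetical data: shots faced and saves made by each of five goalies.
shots = np.array([800, 650, 30, 900, 12])
saves = np.array([720, 580, 28, 810, 9])

with pm.Model() as hierarchical_model:
    # Parental (hyper)priors shared across all goalies.
    a = pm.Exponential("a", lam=1 / 30)
    b = pm.Exponential("b", lam=1 / 30)
    # Per-goalie save probabilities drawn from the shared Beta prior.
    p = pm.Beta("p", alpha=a, beta=b, shape=len(shots))
    obs = pm.Binomial("obs", n=shots, p=p, observed=saves)
    trace = pm.sample(2000, target_accept=0.9)
```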

2 hr 30th minute: Break 10 minutes

2 hr 40th minute: Bayesian Regression Modelling

This section is more of a lecture than a hands-on section, though there will be pre-written code for participants to read and execute at their own pace if they prefer.

In this section, we will take the ideas learned before - that there are parameters in a model that are directly used in the likelihood distribution - and extend that idea to regression modelling, where model parameters are linked to the likelihood function parameters by an equation. In doing so, we will introduce the idea of a “link function”, and show how this is the general idea behind arbitrary curve regression.

There will be two examples, one for logistic regression (sketched after this section) and one for exponential decay. We are intentionally avoiding linear regression because the point is to show that any kind of "link function" is possible (including neural networks!).
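
A minimal sketch of the logistic-regression example, in which a linear function of x passes through an inverse-logit link to give the Bernoulli probability (the data here is simulated purely for illustration):

```python
import numpy as np
import pymc3 as pm

# Simulated binary-outcome data for illustration.
rng = np.random.default_rng(0)
x = rng.normal(size=200)
y = rng.binomial(1, 1 / (1 + np.exp(-(2.0 * x - 0.5))))

with pm.Model() as logistic_model:
    w = pm.Normal("w", mu=0, sigma=5)
    b = pm.Normal("b", mu=0, sigma=5)
    # The link function: map the linear predictor onto a probability.
    p = pm.Deterministic("p", pm.math.invlogit(w * x + b))
    obs = pm.Bernoulli("obs", p=p, observed=y)
    trace = pm.sample(2000)
```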

3 hr 10th minute: Conclusion

We will summarize by giving participants the following framework:

  • Model = Parameters + Priors + Data + Structure (Equations) + Likelihood
  • Model + Sampler -> Posterior
  • Bayesian Estimation -> Hierarchical Bayesian Estimation
  • Single Group -> Two Group Comparison -> Multi-Group Comparison
  • Direct Estimation vs. Arbitrary Link Functions

3 hr 20th minute: End