Notes on differential computing

Where differential computing gets used

In case you've found yourself living like a digital hermit for the past decade (no judgment, sometimes I do fantasize about going offline for a year), deep learning has been where automatic differentiation (AD) has seen the most use. In deep learning, the core technical problem is optimizing a model's parameters to minimize some loss function. An AD system can automatically calculate the full set of partial derivatives of the loss function w.r.t. each parameter in the model, and each partial derivative can then be used to update its respective parameter in the direction that decreases the loss.
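As a minimal sketch of that loop, assuming a toy linear model y_hat = w * x + b and noise-free synthetic data (both hypothetical, chosen only for illustration), jax.grad returns the partial derivative of the loss w.r.t. every parameter, and each parameter is nudged against its own partial derivative:

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data: y = 2x + 1 with no noise
x = jnp.linspace(-1.0, 1.0, 20)
y = 2.0 * x + 1.0

def mse_loss(params, x, y):
    """Mean squared error of a linear model y_hat = w * x + b."""
    y_hat = params["w"] * x + params["b"]
    return jnp.mean((y_hat - y) ** 2)

params = {"w": 0.0, "b": 0.0}
grad_fn = jax.grad(mse_loss)  # partial derivatives w.r.t. every entry of params

learning_rate = 0.1
for _ in range(200):
    grads = grad_fn(params, x, y)
    # Move each parameter against its own partial derivative
    params = {k: params[k] - learning_rate * grads[k] for k in params}
```

Plain gradient descent with a fixed learning rate is used here for transparency; deep learning code typically swaps in an optimizer such as Adam.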

Because deep learning models and their applications proliferated in the 2010-2020 decade, AD systems were most commonly associated with neural networks and deep learning. However, that is not the only place where AD systems show up.

For example, AD is used in the Bayesian statistical modelling world. Hamiltonian Monte Carlo samplers use AD to compute gradients of the log probability density, which guide the direction in which the sampler's next MCMC step is taken. AD systems can also be used to fit the parameters of non-neural network models, such as Gaussian Mixture Models and Hidden Markov Models, to data. We can even use AD in a class of problems called "input design" problems, where we optimize not the model's parameters but its inputs with respect to some output (assuming we know how to cast the inputs into some continuous numerical space).
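Input design can be sketched the same way, assuming a hypothetical, already-fitted model whose score peaks at a made-up point CENTER: we take the gradient w.r.t. the input x rather than any parameters, and step uphill.

```python
import jax
import jax.numpy as jnp

# Hypothetical frozen "model": its parameters are fixed; only the input varies.
CENTER = jnp.array([1.0, -0.5])

def model(x):
    """Scalar score that peaks when the input x equals CENTER."""
    return jnp.exp(-jnp.sum((x - CENTER) ** 2))

dmodel_dx = jax.grad(model)  # derivative w.r.t. the *input*, not parameters

x = jnp.zeros(2)  # initial design
step_size = 0.5
for _ in range(100):
    x = x + step_size * dmodel_dx(x)  # gradient *ascent* on the score
```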

What automatic differentiation systems exist

Where do AD systems live? Firstly, they live inside deep learning frameworks such as PyTorch and TensorFlow. Without an AD system, neither of these frameworks could compute the gradients that model training relies on.

Secondly, they also live in independent packages. In Julia, two such packages are Zygote.jl and AutoGrad.jl, both of which are actively developed. autograd, the Python package that AutoGrad.jl was modelled on, is also the precursor to JAX, which I think of as automatic differentiation on steroids.

What makes JAX special

API-compatible differentiable computing

The Python scientific computing stack, also known as the PyData or SciPy stack, provides a large library of numerical functions that can be composed together into higher-order programs. JAX provides an API-compatible reimplementation of the stack's differentiable functions, with what I think is near-complete coverage.

As such, users familiar with NumPy and SciPy can, with minimal changes to lines of code, write automatically differentiable versions of their existing programs, and develop new programs that are automatically differentiable.

How does this work? I have written a longer-form collection of teaching materials on this, but here is a quick example. If I have a silly program like the following one:

from jax.scipy.linalg import cholesky
import jax.numpy as np

def some_function(params, data):
    # do stuff with params, for example, a Cholesky decomposition
    # (params must be a symmetric positive-definite matrix):
    U = cholesky(params)
    # followed by a sum of sine transform, for whatever reason it might be needed
    return np.sum(np.sin(U))

I can get the gradient function easily, using JAX's provided grad function:

from jax import grad

dsome_function = grad(some_function)

dsome_function has the same function signature as some_function, but instead of returning a scalar value, it returns the derivative of some_function w.r.t. params (the first argument), in the same (possibly nested) data structure as params. That is to say, if params were a tuple, a dictionary, a list, or a nested combination of these (what JAX calls a "pytree"), the output of dsome_function would have the same structure.
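A small sketch of that structure-preserving behaviour, using a hypothetical loss whose parameters live in a dict that itself holds a list (rather than the Cholesky-based some_function, so that no positive-definite input is needed):

```python
import jax
import jax.numpy as jnp

def loss(params, data):
    # params is nested: a dict holding an array and a list of two scalars
    w = params["weights"]
    b0, b1 = params["biases"]
    return jnp.sum(jnp.sin(w * data) + b0 * b1)

params = {"weights": jnp.array([0.1, 0.2, 0.3]), "biases": [1.0, 2.0]}
data = jnp.array([1.0, 2.0, 3.0])

grads = jax.grad(loss)(params, data)
# grads mirrors params: a dict with "weights" (shape (3,)) and "biases" (a list of 2)
```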

grad can do many more fancy things, such as differentiating through loops and control flow, and computing second- through nth-order derivatives; I'd encourage you to check out the docs to learn more. (That is out of scope for this essay, as I'm focusing on high-level points.)
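As a rough illustration of both points, assuming two toy functions (poly and piecewise, made up for this sketch): an ordinary Python for loop and an if statement are differentiated directly, and applying grad twice gives a second derivative.

```python
import jax

def poly(x, n_terms=4):
    """Sum of x**k for k = 1..n_terms, built with an ordinary Python loop."""
    total = 0.0
    for k in range(1, n_terms + 1):
        total = total + x ** k
    return total

def piecewise(x):
    # Python if/else on the input value works under grad (but not under jit)
    if x > 0:
        return x ** 2
    return -x

dpoly = jax.grad(poly)              # first derivative: 1 + 2x + 3x^2 + 4x^3
d2poly = jax.grad(jax.grad(poly))   # second derivative: grad composed with itself
dpiecewise = jax.grad(piecewise)
```

Under jit, the Python if on a traced value would fail; jax.lax.cond or jnp.where are the usual replacements there.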

Providing grad as a first-class citizen in an API-compatible fashion with the scientific Python computing stack makes it very easy to adopt differential computing tooling in one's programs.

API-compatibility with the rest of the PyData stack

A design choice made early on by the JAX developers was NumPy and SciPy API compatibility, with a small number of well-documented differences (mostly in random number generation, which uses explicit keys, and in arrays being immutable). Incidentally, this practice is also adopted by Dask and CuPy, which give us distributed and GPU-backed arrays respectively. It reflects a healthy dose of respect for what already exists, and for end-users as well.
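Both differences show up in a few lines. This sketch uses the jax.random key API and the .at[...].set(...) update syntax; the variable names are arbitrary:

```python
import jax
import jax.numpy as jnp

# Randomness: explicit, splittable PRNG keys instead of NumPy's global state
key = jax.random.PRNGKey(0)
a = jax.random.normal(key, shape=(3,))
b = jax.random.normal(key, shape=(3,))      # same key -> identical numbers
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, shape=(3,))   # fresh subkey -> fresh numbers

# Immutability: instead of x[0] = 10.0, use the functional .at[] syntax
x = jnp.zeros(3)
y = x.at[0].set(10.0)  # returns a new array; x itself is unchanged
```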

I think a contrasting example best illustrates this point. Consider a PyTorch or TensorFlow array vs. a JAX NumPy array.

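To plot a PyTorch array's values in matplotlib, one often has to convert it to a NumPy array first (for a tensor on a GPU or one that tracks gradients, via a.detach().cpu().numpy()), and even creating the array uses its own constructor:

import torch
a = torch.tensor([1, 2, 3])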

On the other hand, with JAX:

import jax.numpy as np
a = np.array([1, 2, 3])

The syntax with JAX is identical to what one might write with vanilla NumPy. This is, I think, in part because JAX's developers have also strived to adhere to NEP18. This little API difference leads to compound interest-like savings in time for model developers.

Composable program transforms

JAX's grad function is merely the "gateway drug" to a bigger idea: composable program transforms. A program transform turns one program into another, and grad is one example; because a transform can be applied to the output of another, they compose. Other transforms include vmap and jit, which, like grad, accept Python functions and return new Python functions; lax.scan, a related primitive, takes a Python function and runs it as a compiled loop. jit, in particular, can accelerate a program anywhere from 2- to 100-fold on a CPU, depending on what your reasonable baseline comparison is.
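A minimal sketch of composition, assuming a toy scalar function f: grad produces its derivative, vmap maps that derivative over a batch, and jit compiles the result, with each step simply wrapping the function returned by the previous one.

```python
import jax
import jax.numpy as jnp

def f(x):
    """A scalar function of a scalar input."""
    return jnp.sin(x) * x

df = jax.grad(f)                       # derivative for one scalar input
batched_df = jax.vmap(df)              # ... mapped over a batch of inputs
fast_batched_df = jax.jit(batched_df)  # ... compiled with XLA

xs = jnp.linspace(0.0, 3.0, 5)
result = fast_batched_df(xs)           # equals x * cos(x) + sin(x) elementwise
```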

In particular, the latter three of the aforementioned transforms (vmap, lax.scan, and jit) allow for highly performant loop code, written without explicit Python loops, that can also be composed together. I go into further detail in a differential learning workshop that I have been developing.
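As one example of loop code written without a Python loop, here is a cumulative sum expressed with lax.scan (a toy stand-in, since jnp.cumsum already does this), compiled with jit, and then differentiated with grad:

```python
import jax
import jax.numpy as jnp
from jax import lax

def step(carry, x):
    """One loop iteration: carry is the running total, x the current element."""
    new_total = carry + x
    return new_total, new_total  # (next carry, value stacked into the output)

def cumulative_sum(xs):
    init = jnp.zeros((), dtype=xs.dtype)
    _, totals = lax.scan(step, init, xs)
    return totals

xs = jnp.arange(1.0, 6.0)  # [1, 2, 3, 4, 5]
totals = jax.jit(cumulative_sum)(xs)

# scan composes with grad: d(sum of running totals)/d(xs)
grad_of_total = jax.grad(lambda v: cumulative_sum(v).sum())(xs)
```

Each element xs[i] appears in every running total from position i onward, which is why the gradient counts down.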

There are other automatic program transformations in active development, and one exciting realm I see is the probabilistic programming world. In PyMC3, for example, an automated transform happens when we take a model written in PyMC3's syntax, a domain-specific language (DSL) implemented in Python, and transform/compile it into a compute graph that gives us a likelihood function. It's as if PyMC3 gives us a likelihood(func) analogous to JAX's grad(func). (If you've tried writing probabilistic model likelihoods by hand, you'll know how much of a convenience this is!)


This is the landing page for my notes.

This is 100% inspired by Andy Matuschak's famous notes page. I'm not technically skilled enough to replicate the full "Andy Mode", though, so I just did some simple hacks. If you're curious how these notes are compiled, check out the summary in How these notes are made into HTML pages.

This is my "notes garden". I tend to it on a daily basis, and it contains some of my less fully-formed thoughts. Nothing here is intended to be cited, as the link structure evolves over time. The notes are best viewed on a desktop/laptop computer, because of the use of hovers for previews.

There's no formal "navigation" or "search" for these pages. To go somewhere, click on any of the "high-level" notes below, and enjoy.

  1. Notes on statistics
  2. Notes on differential computing
  3. The State of Data Science
  4. Network science
  5. Scholarly readings
  6. Software skills for data scientists
  7. The Data Science Programming Newsletter MOC
  8. Life and computer hacks
  9. Reading Bazaar
  10. Blog drafts
  11. Conference Proposals


Key Information

Open PRs


  • Next place to submit to: PeerJ Computer Science. (see: PeerJ CS submission guidelines)
  • Other locations in consideration: JCIM Application Notes
  • Rejected at:
    • Bioinformatics

The main uses of jax-unirep

The main use of jax-unirep is to get "standardized" representations of proteins that can easily be used in downstream ML applications.

Infinitely Wide Neural Networks

How useful are infinitely wide neural networks? I tried writing a Jupyter notebook to explore the idea.

Workshop on Differential Learning

Key Information:

Feedback from Kannan Sankar and Mei Xiao after today's differential computing tutorial.

  • Binary cross entropy formula: the terms in the summation should carry an _i subscript, indexing each example.
  • vmap is a compiled for-loop
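Both bullets can be tied together in one sketch. Written with the per-example index i, binary cross entropy over N examples is BCE = -(1/N) Σ_i [y_i log p_i + (1 - y_i) log(1 - p_i)] (averaging over N is assumed here; some write it as a plain sum). Below, vmap maps a single-example function over the index i, playing the role of the for-loop, and the result matches an explicit Python loop. The labels and probabilities are made up.

```python
import jax
import jax.numpy as jnp

def bce_single(y_i, p_i):
    """Binary cross entropy for one example i (y_i in {0, 1}, p_i in (0, 1))."""
    return -(y_i * jnp.log(p_i) + (1.0 - y_i) * jnp.log(1.0 - p_i))

def bce(y, p):
    # vmap maps bce_single over the example index i, in place of a for-loop
    per_example = jax.vmap(bce_single)(y, p)
    return jnp.mean(per_example)

y = jnp.array([1.0, 0.0, 1.0, 0.0])
p = jnp.array([0.9, 0.2, 0.6, 0.4])
loss = bce(y, p)

# The same quantity written as an explicit Python loop over i
loop_loss = sum(bce_single(y[i], p[i]) for i in range(len(y))) / len(y)
```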

Other thoughts

  • I think I can do my own version of "what each of JAX's primitives do", then show how they get composed together.

Differential computing with JAX

Estimating a multivariate Gaussian's parameters by gradient descent

I've seen this idea in action before, but never tried my hand at it until yesterday at work.

Turns out, if you have data that can be modelled by a multivariate Gaussian distribution, you can optimize the parameters of that Gaussian to maximize the likelihood of data under the multivariate Gaussian model.

The tricky piece of this problem is that a multivariate Gaussian distribution's covariance parameter has to be structured: it must be symmetric and positive semi-definite (strictly positive definite for the likelihood to be defined). (see Covariance matrix) When pairing this with gradient descent, though, a key problem we face is that there is no guarantee that a gradient update to the square, covariance-like matrix we pass in will preserve that structure. What can we do to solve this problem?

We use the Cholesky decomposition.

For example, using JAX's scipy wrapper (for automatic differentiation-compatible Cholesky decomposition):

import jax.numpy as np
from jax.scipy.linalg import cholesky

a = np.array([
    [1, 0.8],
    [0.8, 1],
])

U = cholesky(a)  # returns the upper triangle by default (lower=False)

From U, we can reconstitute a by taking the dot product of the transpose of U with U:

a_hat = np.dot(np.transpose(U), U)

assert np.allclose(a_hat, a)  # equal up to floating-point error

That's all cool and such, but how does this apply to fitting a multivariate Gaussian to data? As mentioned above, the key problem is that when we initialize the parameters we want to optimize, we often sample them from a Gaussian (or another unbounded) distribution. To initialize a covariance matrix, we might be tempted to draw a random square matrix, but there's no easy way to guarantee that it has the structure of a covariance matrix.

That's where reconstituting a covariance matrix from an upper triangular matrix can help. To sweeten the deal, because of JAX's NumPy-compatible API, we can initialize however we want, and then perform a transformation step to get back to a covariance matrix:

from jax import random as rnd

key = rnd.PRNGKey(42)
init_cov = rnd.normal(key, shape=(5, 5))  # let's make a 5x5 covariance matrix

def transform_to_covariance_matrix(sq_mat):
    U = np.triu(sq_mat)  # keep only the upper triangle
    U_T = np.transpose(U)
    return np.dot(U_T, U)  # U^T U is symmetric and positive semi-definite
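A quick sanity check on the transform, assuming a 5x5 random initialization like the one above: the raw draw is not symmetric, but the transformed matrix is symmetric with non-negative eigenvalues, as a covariance matrix should be.

```python
import jax.numpy as jnp
from jax import random as rnd

def transform_to_covariance_matrix(sq_mat):
    """Map an arbitrary square matrix to U^T U, where U is its upper triangle."""
    U = jnp.triu(sq_mat)
    return jnp.dot(U.T, U)

key = rnd.PRNGKey(42)
raw = rnd.normal(key, shape=(5, 5))  # unconstrained, generally not symmetric
cov = transform_to_covariance_matrix(raw)

eigenvalues = jnp.linalg.eigvalsh(cov)  # eigvalsh assumes a symmetric input
```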

And with that, we can optimize an untransformed covariance parameter like init_cov to its maximum-likelihood value. Let me show you how this works.

Firstly, we define the likelihood of observing our data under a multivariate Gaussian model:

from jax.scipy.stats import multivariate_normal
from jax import vmap

def loglike(params, data):
    mu, untransformed_cov = params
    cov = transform_to_covariance_matrix(untransformed_cov)
    def logpdf_func(datum):
        """logpdf of multivariate normal for one datum."""
        # signature is logpdf(x, mean, cov): the observation goes first
        return multivariate_normal.logpdf(datum, mu, cov)
    logp = vmap(logpdf_func)(data)
    return np.sum(logp)

Notice two pieces:
1. We used some pretty cool/rad JAX tooling, like vmap, to map over the sample dimension without writing a for-loop!
2. We also assumed that our covariance matrix is passed into the loglike function in an untransformed form, and we transform it into the correct form inside the function.
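To see what vmap buys us, here is a small sketch with a made-up 2-D Gaussian and three data points (all values hypothetical): mapping the single-datum logpdf with vmap gives the same numbers as a Python loop, and jax.scipy.stats.multivariate_normal.logpdf also broadcasts over a leading batch axis directly.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

mu = jnp.zeros(2)
cov = jnp.array([[1.0, 0.8], [0.8, 1.0]])
data = jnp.array([[0.0, 0.0], [1.0, 0.5], [-0.3, 0.2]])

def logpdf_func(datum):
    # logpdf(x, mean, cov): the observation comes first
    return multivariate_normal.logpdf(datum, mu, cov)

vmapped = jax.vmap(logpdf_func)(data)                 # shape (3,), no Python loop
looped = jnp.array([logpdf_func(d) for d in data])    # the for-loop version
batched = multivariate_normal.logpdf(data, mu, cov)   # broadcasting over rows
```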

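Now, we perform gradient-based optimization. Optimizers such as Adam minimize their objective, so to maximize the likelihood we differentiate the negative log-likelihood instead. We define the derivative function, which we'll use later for calculating gradients:

from jax import grad

def neg_loglike(params, data):
    return -loglike(params, data)

dloglike = grad(neg_loglike)  # gradient of the *negative* log-likelihood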

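We then initialize our parameters (here, data is assumed to be an array of shape (n_samples, 5) holding the observations):

k_mu, k_cov = rnd.split(key)  # separate keys, so mu and the covariance get independent draws
mu = rnd.normal(k_mu, shape=(5,))
untransformed_cov = rnd.normal(k_cov, shape=(5, 5))

params = mu, untransformed_cov  # package them up into a convenient variable.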

Finally, we use one of the optimizers that ship with JAX, calling on jit to make computations fast:

from jax import jit
from jax.example_libraries.optimizers import adam  # formerly jax.experimental.optimizers

init, update, get_params = adam(step_size=0.005)
get_params = jit(get_params); update = jit(update)
dloglike = jit(dloglike); loglike = jit(loglike)

state = init(params)
for i in range(300):
    params = get_params(state)
    g = dloglike(params, data)
    state = update(i, g, state)
mu_opt, untransformed_cov_opt = get_params(state)
cov_opt = transform_to_covariance_matrix(untransformed_cov_opt)

The optimized covariance matrix will be pretty darn close to the true one!

What is differential computing

I think at its core, we can think of differential computing as "computing where derivatives (from calculus) are a first-class citizen". This is analogous to probabilistic computing, in which probabilistic constructs (such as probability distributions and MCMC samplers) are given "first class" status in a language. By "first class" status, I mean that the relevant computing constructs are a well-developed category of constructs in the language, with clearly defined interfaces.

In differential computing, being able to evaluate gradients of a math function is at the heart of the language. (How the gradients are used is a matter of application.) The objective of a differential computing system is to let us write a math function as a program in that language and have the computer automatically evaluate its gradients. This is where automatic differentiation (AD) systems come into play. They take a program, perhaps written as a Python function, and automatically transform it into another program, perhaps another Python function, that evaluates the derivative of the original function.
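As a minimal sketch of that program-to-program transformation, assuming a toy function f chosen only so that its derivative is easy to write by hand, jax.grad returns a new Python function that agrees with the hand-derived derivative:

```python
import jax
import jax.numpy as jnp

def f(x):
    """The 'original program': an ordinary Python function."""
    return x ** 3 + jnp.sin(x)

df = jax.grad(f)  # a new Python function that evaluates f'(x)

def df_by_hand(x):
    """The derivative worked out with pen and paper: 3x^2 + cos(x)."""
    return 3 * x ** 2 + jnp.cos(x)
```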

All of this falls under the paradigm of AD systems, as I mentioned earlier. Symbolic differentiation can be considered one subclass of AD systems; this is a point PyMC3 developer Brandon Willard would make. Packages such as autograd and now JAX (also by the autograd authors) are another subclass of AD systems, which leverage the chain rule and a recorded "tape" of the math operations called in the program to automatically construct the gradient function.
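To make the chain-rule-plus-tape idea concrete, here is a toy reverse-mode AD in pure Python. It is a hypothetical teaching sketch that supports only +, * and sin on scalars, and is not how autograd or JAX are actually implemented: each operation records its inputs and local partial derivatives on a tape, and the backward pass replays the tape in reverse, accumulating gradients via the chain rule.

```python
import math

class Var:
    """A scalar value that records how it was computed on a shared tape."""

    def __init__(self, value, tape):
        self.value = value
        self.grad = 0.0
        self.tape = tape

    def _record(self, value, parents):
        # parents: list of (parent Var, local partial derivative d(out)/d(parent))
        out = Var(value, self.tape)
        self.tape.append((out, parents))
        return out

    def __add__(self, other):
        return self._record(self.value + other.value, [(self, 1.0), (other, 1.0)])

    def __mul__(self, other):
        return self._record(self.value * other.value,
                            [(self, other.value), (other, self.value)])

    def sin(self):
        return self._record(math.sin(self.value), [(self, math.cos(self.value))])


def backward(output):
    """Reverse pass: walk the tape backwards, applying the chain rule."""
    output.grad = 1.0
    for node, parents in reversed(output.tape):
        for parent, local_grad in parents:
            parent.grad += local_grad * node.grad


tape = []
x = Var(2.0, tape)
y = Var(3.0, tape)
z = x * y + x.sin()   # z = x*y + sin(x)
backward(z)           # now x.grad = y + cos(x) and y.grad = x
```

Because the tape is replayed once in reverse, a single backward pass yields the derivative of the output w.r.t. every input, which is why reverse mode suits loss functions with many parameters.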