Resources for learning about JAX

I invested one of my vacation weeks crystallizing what I've learned from working with JAX over the past year and a half, and the exercise was extremely educational. If you're interested in reading the result, you can find it in my dl-workshop repository on GitHub. There, alongside the original content, which was a workshop on deep learning, I also try to provide "simple complex examples" of how to use JAX idioms to solve modelling problems.

Besides that, JAX's documentation is quite well-written, and you can find it at jax.readthedocs.io. In particular, it has a very well-documented page, "The Sharp Bits", covering gotchas to look out for when using JAX, geared towards both power users of vanilla NumPy and beginners. If you're using JAX and run into unexpected behaviour, I'd strongly encourage you to check out that page - it'll clear up many misconceptions you might have!
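
To give a flavour of the kind of behaviour covered there, here is a minimal sketch of one of the most commonly encountered sharp bits: JAX arrays are immutable, so NumPy-style in-place assignment doesn't work and is replaced by the functional `.at[...].set(...)` idiom. (The array values here are just my own illustration.)

```python
import jax.numpy as jnp

x = jnp.zeros(3)
# x[0] = 1.0  # would raise an error: JAX arrays are immutable
y = x.at[0].set(1.0)  # returns a new array; x itself is unchanged
print(x)  # [0. 0. 0.]
print(y)  # [1. 0. 0.]
```
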

In terms of introductory material, a blog post by Colin Raffel, titled "You don't know JAX", is a very well-written introduction to how to use JAX. Eric Jang also has a blog post on implementing meta-learning in JAX, which I found very educational on both JAX syntax and meta-learning itself.

While the flashiest advances of the deep learning world came between 2010 and 2020, I personally think that the most exciting foundational advance of that era was the development of general-purpose automatic differentiation packages such as autograd and JAX. At least in the Python world, they have enabled the writing of arbitrary models in a fashion highly compatible with the rest of the PyData stack, with differentiation and native compilation available as first-class program transformations. The uses of gradients are many and varied; I'd definitely encourage you to try JAX out!
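
To make "program transformations" concrete, here is a small sketch (the function and values are my own illustration): an ordinary numerical function written against jax.numpy can be passed through jax.grad and jax.jit, each of which returns a new, transformed function.

```python
import jax
import jax.numpy as jnp

# An ordinary Python function, written against the NumPy-like API.
def loss(w):
    return jnp.sum((w - 3.0) ** 2)

# grad and jit are program transformations: each takes a function
# and returns a new function.
dloss = jax.grad(loss)       # derivative of loss w.r.t. its argument
fast_dloss = jax.jit(dloss)  # the same derivative, compiled with XLA

w = jnp.array([1.0, 2.0, 3.0])
print(loss(w))        # 5.0
print(fast_dloss(w))  # [-4. -2.  0.]
```
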

Differential computing with JAX