Workshop on Differential Learning

Key Information:

Feedback from Kannan Sankar and Mei Xiao after today's differential computing tutorial.

  • Binary cross entropy formula: the summation should have an _i subscript on each term (corrected form below).
  • vmap is effectively a compiled for-loop.
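
For reference, here is the standard form of binary cross entropy with the _i indices made explicit; N is the number of samples, y_i the true label, and p_i the predicted probability:

$$
\mathrm{BCE} = -\frac{1}{N} \sum_{i=1}^{N} \left[ y_i \log(p_i) + (1 - y_i) \log(1 - p_i) \right]
$$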

Other thoughts

  • I think I can do my own version of "what each of JAX's primitives does", then show how they get composed together.

vmap

vmap is a JAX function transformation that maps a function across one axis of an array, vectorizing it without an explicit Python for-loop.
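
A minimal sketch of how this looks in practice (the function and array here are illustrative, not from the tutorial):

```python
import jax.numpy as jnp
from jax import vmap

# A function written as if it only ever sees a single 1-D vector.
def square_norm(x):
    return jnp.dot(x, x)

# vmap maps square_norm over axis 0 of the batch; no explicit loop needed.
batch = jnp.arange(12.0).reshape(4, 3)  # 4 samples, each a 3-vector
print(vmap(square_norm)(batch))         # shape (4,)
```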

lax.scan

lax.scan gives us compiled loops with carryover: a carry value is threaded from one iteration to the next.

Examples include:

  • MCMC sampling
  • gradient descent training routines
  • recurrent neural networks
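
A minimal sketch of the carryover pattern, using a cumulative sum as an assumed toy example:

```python
import jax.numpy as jnp
from jax import lax

# step takes (carry, x) and returns (new_carry, output_for_this_step).
def step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry

xs = jnp.arange(1.0, 6.0)
final, cumulative = lax.scan(step, 0.0, xs)
print(final)       # 15.0
print(cumulative)  # [ 1.  3.  6. 10. 15.]
```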

jit

jit does Just-In-Time (JIT) compilation of a function to lower-level XLA code, thus bypassing Python interpreter overhead.
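
A minimal sketch (the function name is illustrative); the first call triggers tracing and compilation, and later calls reuse the compiled version:

```python
import jax.numpy as jnp
from jax import jit

def sum_of_squares(x):
    return jnp.sum(x ** 2)

fast_sum_of_squares = jit(sum_of_squares)

x = jnp.arange(1_000_000.0)
print(fast_sum_of_squares(x))  # compiled on first call, cached thereafter
```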

grad

grad returns a transformed version of a function f: a gradient function (let's call it df) which, when called with the same arguments we would pass to f, gives us the derivative of f evaluated at those arguments.
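
A minimal sketch with a function whose derivative we can check by hand:

```python
from jax import grad

# f(x) = x^3, so the analytical derivative is 3x^2.
def f(x):
    return x ** 3

df = grad(f)
print(df(2.0))  # 12.0, matching 3 * 2.0**2
```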