Workshop on Differential Learning
Key Information:
Feedback from Kannan Sankar and Mei Xiao after today's differential computing tutorial.
_i indicator in the summation.
vmap is a compiled for-loop
Other thoughts
vmap
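vmap is a JAX function transformation that maps a function across one axis of an array. It returns a vectorized version of the function that processes a whole batch in a single call, with the mapped axis chosen via the in_axes argument.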
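As a minimal sketch (the function `dot` and the arrays are made up for illustration), this maps a single-vector dot product over axis 0 of two batches and checks it against the explicit Python for-loop that vmap replaces:

```python
import jax.numpy as jnp
from jax import vmap

def dot(a, b):
    # Dot product of two 1-D vectors; written for a single pair of vectors.
    return jnp.sum(a * b)

xs = jnp.arange(12.0).reshape(4, 3)  # batch of 4 vectors of length 3
ys = jnp.ones((4, 3))

# Map `dot` over axis 0 of both arguments (one call per row pair).
batched_dot = vmap(dot, in_axes=(0, 0))
result = batched_dot(xs, ys)

# The equivalent explicit Python for-loop, for comparison.
looped = jnp.stack([dot(x, y) for x, y in zip(xs, ys)])
```

Setting an entry of in_axes to None instead broadcasts that argument unchanged to every mapped call.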
Notes on differential computing
This is an overview page of my notes on differential computing and JAX.
Contents:
Stuff I've built/done with JAX:
lax.scan
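lax.scan gives us loops with carryover: each iteration receives a carry value produced by the previous iteration, and the per-iteration outputs are stacked into an array.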
Examples include:
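A minimal sketch of that pattern, using a running (cumulative) sum as an illustrative example: the step function takes `(carry, x)` and returns `(new_carry, output)`, and `lax.scan` returns the final carry together with the stacked outputs. The names `step` and `xs` are hypothetical.

```python
import jax.numpy as jnp
from jax import lax

def step(carry, x):
    # carry: running total passed in from the previous iteration.
    new_carry = carry + x
    # Return (carry for the next iteration, output recorded for this step).
    return new_carry, new_carry

xs = jnp.array([1.0, 2.0, 3.0, 4.0])
init = jnp.zeros(())  # scalar float32 carry, same dtype as the elements of xs
final, cumulative = lax.scan(step, init, xs)
```

Here `final` is the last carry and `cumulative` holds the carry after every step.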
jit
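jit performs just-in-time (JIT) compilation of a function to lower-level code via XLA, bypassing the Python interpreter overhead. The function is traced and compiled on its first call; later calls with the same input shapes and dtypes reuse the compiled version.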
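A minimal sketch of wrapping a function with jit; the function `f` is an arbitrary elementwise example chosen so the output is easy to check (it should be all ones):

```python
import jax.numpy as jnp
from jax import jit

def f(x):
    # A small chain of elementwise ops that XLA can fuse when compiled.
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

f_jit = jit(f)

x = jnp.linspace(0.0, 3.0, 5)
# First call traces and compiles; block_until_ready waits for async dispatch.
out = f_jit(x).block_until_ready()
# Second call with the same shape/dtype reuses the compiled version.
out_again = f_jit(x)
```

The compiled and uncompiled versions compute the same values; jit only changes how the computation is executed.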
grad
grad takes a function f and returns a transformed function (let's call it df). Calling df with the same arguments you would pass to f gives the derivative of f evaluated at those arguments.
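A minimal sketch with hand-checkable derivatives; `f`, `loss`, and `w` are illustrative names. By default grad differentiates with respect to the first positional argument, and f must return a scalar.

```python
import jax.numpy as jnp
from jax import grad

def f(x):
    # f(x) = x^3 + 2x, so f'(x) = 3x^2 + 2
    return x ** 3 + 2.0 * x

df = grad(f)
value = df(2.0)  # f'(2) = 3 * 4 + 2 = 14

def loss(w):
    # Scalar-valued function of a vector: sum of squares, gradient is 2w.
    return jnp.sum(w ** 2)

w = jnp.array([1.0, -2.0, 3.0])
grad_w = grad(loss)(w)  # same shape as w
```

Inputs to the differentiated argument should be floating point; integer inputs raise an error.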