Workshop on Differential Learning
Key Information:
Feedback from Kannan Sankar and Mei Xiao after today's differential computing tutorial.
_i indicator in the summation.
vmap is a compiled for-loop.
Other thoughts
vmap
vmap is a JAX function transformation that maps a function over one axis of an array (by default, the leading axis), as if the function were applied to each slice in turn.
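A minimal sketch of the "compiled for-loop" and "_i in the summation" ideas, assuming a toy dot-product function (dot, w, X and v are illustrative names, not from these notes). The explicit Python loop and the vmap call should agree, summing the mapped outputs gives the sum over i, and in_axes controls which axis is mapped and which arguments are reused unchanged.

```python
import jax
import jax.numpy as jnp

def dot(w, x):
    # Dot product of two 1-D vectors (hypothetical example function)
    return jnp.dot(w, x)

w = jnp.array([1.0, 2.0])
X = jnp.array([[1.0, 0.0],
               [0.0, 1.0],
               [1.0, 1.0]])  # shape (3, 2): three rows x_i of length 2

# Explicit Python for-loop over the rows x_i of X
loop_scores = jnp.stack([dot(w, x_i) for x_i in X])

# vmap over axis 0 of X (one row at a time); in_axes=None reuses the same w
vmap_scores = jax.vmap(dot, in_axes=(None, 0))(w, X)  # shape (3,)

# Summing the mapped outputs gives sum_i dot(w, x_i)
total = jnp.sum(vmap_scores)

# Mapping over axis 1 instead treats each column of X as an x_i
v = jnp.ones(3)
col_scores = jax.vmap(dot, in_axes=(None, 1))(v, X)  # shape (2,)
```

Wrapping the vmapped function in jax.jit is what makes it a compiled loop; vmap on its own produces a vectorized (batched) version of dot.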
Notes on differential computing
This is an overview page of my notes on differential computing and JAX.
Contents:
Stuff I've built/done with JAX:
lax.scan
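lax.scan gives us loops with carryover: each iteration receives a state (the "carry") from the previous iteration, returns an updated carry, and can also emit a per-step output.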
Examples include:
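As a minimal illustration, assuming a cumulative sum as the task (step and xs are illustrative names), the carry is the running total and the per-step output is the total so far. lax.scan returns the final carry together with the stacked per-step outputs.

```python
import jax.numpy as jnp
from jax import lax

def step(carry, x):
    # carry: running total from previous iterations; x: current element
    new_carry = carry + x
    # Return (carry for the next iteration, output for this iteration)
    return new_carry, new_carry

xs = jnp.array([1.0, 2.0, 3.0, 4.0])
# Initial carry has the same dtype as xs so the carry type stays fixed
init = jnp.zeros((), dtype=xs.dtype)
final_total, running_totals = lax.scan(step, init, xs)
# final_total is 10; running_totals holds 1, 3, 6, 10
```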
jit
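jit does just-in-time (JIT) compilation of a function: on the first call, JAX traces the function and compiles it with XLA into lower-level code, and later calls with the same input shapes and dtypes reuse that compiled code, bypassing the Python interpreter overhead.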
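A minimal sketch, assuming a small hypothetical function f built from elementwise operations: jax.jit(f) returns a compiled version that computes the same values as f (up to floating-point rounding), while the first call pays the tracing and compilation cost.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Several elementwise operations followed by a reduction
    return jnp.sum(jnp.tanh(x) * 2.0 + x ** 2)

f_jit = jax.jit(f)

x = jnp.arange(5.0)
# First call traces f and compiles it with XLA; later calls with the
# same shape and dtype of x reuse the compiled version.
y_compiled = f_jit(x)
y_plain = f(x)
```

Because compilation is keyed on shapes and dtypes, calling f_jit with a differently shaped array triggers a fresh trace and compile.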
grad
grad takes a function f and returns a transformed function (let's call it df). When df is called with the same arguments as f, it returns the derivative of f evaluated at those arguments. By default, the derivative is taken with respect to the first argument, and f must return a scalar.
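A minimal sketch with two hypothetical functions: a one-variable polynomial whose derivative is easy to check by hand, and a squared-error loss where grad differentiates with respect to the first argument w only.

```python
import jax
import jax.numpy as jnp

def f(x):
    # f(x) = x**3 + 2x, so f'(x) = 3x**2 + 2
    return x ** 3 + 2.0 * x

df = jax.grad(f)
slope = df(2.0)  # 3 * 4 + 2 = 14

def loss(w, x, y):
    # Scalar squared error of a one-parameter linear model
    return jnp.sum((x * w - y) ** 2)

x = jnp.array([1.0, 2.0])
y = jnp.array([1.0, 2.0])
# Gradient w.r.t. w (argnums=0 by default): sum(2 * (x*w - y) * x)
dloss_dw = jax.grad(loss)(2.0, x, y)  # 2*1*1 + 2*2*2 = 10
```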