Closures and Partials
We're going to take a quick detour and look at the idea of "partially evaluating a function". This matters because it lets us construct functions that have exactly the signature required by lax.scan and other JAX functions, while still giving us the flexibility to pass in whatever extra values the function needs to work correctly.
There are two ways to do this: you can either use
functools.partial, or you can use function closures. Let's see how to do this.
Partially evaluating a function using functools.partial
For simplicity's sake, let's explore the idea using a function that adds two numbers:
    def add(a, b):
        return a + b
Now, let's say we wanted to fix b to the value 3, thus generating an add_three function. We can do this in two ways. The first is by using functools.partial:
    from functools import partial

    add_three = partial(add, b=3)
We can now call add_three on any value of a.
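As a quick check, the call might look like the sketch below. It redefines add and add_three so that it runs on its own; the input values 5 and 10 are illustrative choices, not taken from the original notebook.

```python
from functools import partial


def add(a, b):
    return a + b


# Fix b=3; the resulting callable only needs a value for a.
add_three = partial(add, b=3)

result = add_three(5)  # equivalent to add(5, b=3)
print(result)  # 8

# Keyword arguments bound by partial can still be overridden at call time.
overridden = add_three(5, b=10)  # equivalent to add(5, b=10)
print(overridden)  # 15
```

Note that a keyword bound through partial behaves like a default: it is used unless the caller supplies a different value.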
If we inspect the function, we see that add_three accepts one positional argument, a, and that b has been set to a default of 3.
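Outside of a notebook, one way to do this kind of inspection is with the standard library's inspect.signature, as in the sketch below. The variable names sig, a_param, and b_param are illustrative only.

```python
import inspect
from functools import partial


def add(a, b):
    return a + b


add_three = partial(add, b=3)

sig = inspect.signature(add_three)
print(sig)

# b is still a parameter, but it now carries the bound value as its default.
b_param = sig.parameters["b"]
print(b_param.default)  # 3

# a has no default, so it must still be supplied by the caller.
a_param = sig.parameters["a"]
print(a_param.default is inspect.Parameter.empty)  # True

# The partial object itself also records what was bound.
print(add_three.func is add)  # True
print(add_three.keywords)  # {'b': 3}
```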
What if we wanted to fix a instead?
    add_three_v2 = partial(add, a=3)
    add_three_v2?  # IPython/Jupyter introspection syntax
Notice how the function signature has now changed: a has been set, while b has not. This has implications for how we use the function.
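To make the change in signature concrete, the sketch below uses inspect.signature from the standard library to look at how each parameter is reported. The names params, a_is_kw_only, and b_is_kw_only are illustrative.

```python
import inspect
from functools import partial


def add(a, b):
    return a + b


add_three_v2 = partial(add, a=3)

params = inspect.signature(add_three_v2).parameters

# Binding a by keyword makes inspect report a (and every parameter after it)
# as keyword-only, which hints that b must be passed by name.
a_is_kw_only = params["a"].kind is inspect.Parameter.KEYWORD_ONLY
b_is_kw_only = params["b"].kind is inspect.Parameter.KEYWORD_ONLY
print(a_is_kw_only, b_is_kw_only)  # True True
print(params["a"].default)  # 3
```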
Calling the function this way will error out:
    >>> add_three_v2(3)
    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    <ipython-input-109-e78f540eb25e> in <module>
    ----> 1 add_three_v2(3)

    TypeError: add() got multiple values for argument 'a'
That is because when we pass in the argument with no keyword specified, it is interpreted as the first positional argument, which as you can see, has already been set.
On the other hand, calling the function this way will not:
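A minimal sketch of the working call, assuming the same add and add_three_v2 definitions as above, passes b by keyword. The try/except at the end is only there to contrast it with the failing positional call; positional_call_failed is an illustrative name.

```python
from functools import partial


def add(a, b):
    return a + b


add_three_v2 = partial(add, a=3)

# Passing b by keyword avoids the clash with the already-bound a.
result = add_three_v2(b=3)
print(result)  # 6

# The positional call fails, because 3 would be assigned to a a second time.
try:
    add_three_v2(3)
    positional_call_failed = False
except TypeError:
    positional_call_failed = True
print(positional_call_failed)  # True
```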
Another pattern that we can use is closures. A closure is an inner ("closed") function that is returned by an outer ("closing") function and retains access to the outer function's variables. Confused? Let me illustrate:
    def closing_function(a):
        def closed_function(b):
            return a + b
        return closed_function
Using this pattern, we can rewrite
add_three using closures:
    def make_add_something(value):
        def closed_function(b):
            return b + value
        return closed_function

    add_three_v3 = make_add_something(3)
    add_three_v3(5)
Now, you'll notice that the signature of add_three_v3 exactly matches that of the closed function.
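One way to verify this is sketched below, again using inspect.signature. It also peeks at the function's __closure__ attribute to show where the captured value is stored; the variable name captured is illustrative.

```python
import inspect


def make_add_something(value):
    def closed_function(b):
        return b + value

    return closed_function


add_three_v3 = make_add_something(3)

# The returned function exposes only the inner function's parameters.
print(inspect.signature(add_three_v3))  # (b)

# The captured value lives in the closure cells, not in the signature.
captured = add_three_v3.__closure__[0].cell_contents
print(captured)  # 3

print(add_three_v3(5))  # 8
```

Unlike the functools.partial version, the fixed value does not appear in the signature at all, so callers cannot accidentally collide with it or override it.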
When writing array programs using JAX, this is the key design pattern you'll want to implement: Always return a function that has the function signature that you need.
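As a rough illustration of the pattern, consider a scan-style loop. lax.scan expects a step function of the form f(carry, x) -> (new_carry, y). To keep the sketch free of dependencies, simple_scan below is a hypothetical pure-Python stand-in for that contract, not JAX's implementation, and make_scaled_cumsum_step and scale are made-up names for the example.

```python
def simple_scan(f, init, xs):
    """Hypothetical pure-Python stand-in for a scan-style loop.

    f must have the signature f(carry, x) -> (new_carry, y).
    """
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def make_scaled_cumsum_step(scale):
    """Closing function: captures scale, returns a (carry, x) step function."""

    def step(carry, x):
        new_carry = carry + scale * x
        # Emit the running total as this step's output.
        return new_carry, new_carry

    return step


step_fn = make_scaled_cumsum_step(scale=2)
final_carry, outputs = simple_scan(step_fn, 0, [1, 2, 3])
print(final_carry)  # 12
print(outputs)  # [2, 6, 12]
```

The step function needs scale to do its work, but the loop only ever passes carry and x. The closure supplies the extra value without changing the signature the loop requires.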
Naming things is the hardest activity in programming, because we are giving categorical names to things whose category isn't always clear. Fret not: the pattern I'll give you is the following:
    def SOME_FUNCTION_generator(argument1, argument2, keyword_argument1=default_value1):
        """To simplify things, just give the name of the closing function <some_function>_generator."""
        def inner(arg1, arg2, kwarg1=default_value1):
            """This function should follow the API that is needed."""
            return something
        return inner
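A concrete instance of this template might look like the sketch below. The names affine_generator, slope, intercept, and double_plus_one are invented for illustration, and the single-argument inner signature is just an assumed requirement.

```python
def affine_generator(slope, intercept=0.0):
    """Closing function, named <some_function>_generator per the template."""

    def inner(x):
        """Has the single-argument signature that the caller needs."""
        return slope * x + intercept

    return inner


double_plus_one = affine_generator(2.0, intercept=1.0)
print(double_plus_one(3.0))  # 7.0
```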