Closures and Partials
We're going to take a quick detour and look at the idea of "partially evaluating" a function. This is important because it lets us construct functions that are compatible with the requirements of vmap, lax.scan, and other parts of JAX, i.e. functions that have the correct signature, while still giving us the flexibility to supply whatever extra inputs the function needs to work correctly.
There are two ways to do this: you can use functools.partial, or you can use function closures. Let's see how each works.
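To make the motivation concrete without requiring JAX, here is a rough pure-Python analogy. It assumes the built-in map as a stand-in for vmap, since both call the function they are given with one input at a time. The scale function and its factor parameter are hypothetical. They were chosen only to show a function that needs one more input than map will supply.

```python
from functools import partial


def scale(x, factor):
    """Hypothetical function that needs an extra input besides x."""
    return x * factor


xs = [1, 2, 3]

# map calls its function with one argument per element, so passing
# scale directly fails: factor is never supplied.
try:
    list(map(scale, xs))
except TypeError as err:
    print("TypeError:", err)

# Fixing factor up front yields a one-argument function that map accepts.
scale_by_ten = partial(scale, factor=10)
print(list(map(scale_by_ten, xs)))  # [10, 20, 30]
```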
Partially evaluating a function using functools.partial
For simplicity's sake, let's explore the idea using a function that adds two numbers:
def add(a, b):
    return a + b
Now, let's say we wanted to fix b to the value 3, thus generating an add_three function. We can do this in two ways. The first is with functools.partial:
from functools import partial
add_three = partial(add, b=3)
We can now call add_three on any value of a:
add_three(20)
If we inspect the function add_three:
add_three?
We see that add_three accepts one positional argument, a, and that b has become a keyword-only argument with a default value of 3.
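The add_three? syntax is IPython-specific. In plain Python, a minimal equivalent is inspect.signature. A partial object also exposes the wrapped function and the fixed arguments through its func, args, and keywords attributes. The sketch below shows both. The final call illustrates that b is only a default and can still be overridden.

```python
import inspect
from functools import partial


def add(a, b):
    return a + b


add_three = partial(add, b=3)

# Outside IPython, inspect.signature reports what `add_three?` shows.
# Fixing b by keyword turns it into a keyword-only parameter with default 3.
print(inspect.signature(add_three))  # (a, *, b=3)

# A partial object also records what it wraps and what it has fixed.
print(add_three.func is add)  # True
print(add_three.args)         # ()
print(add_three.keywords)     # {'b': 3}

# Because b is only a default, it can still be overridden by keyword.
print(add_three(20, b=100))   # 120
```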
What if we wanted to fix a to 3 instead?
add_three_v2 = partial(add, a=3)
add_three_v2?
Notice how the function signature has now changed: a has been set, while b has not. Because a was fixed by keyword, both a and b are now keyword-only parameters. This has implications for how we use the function.
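To see the new signature without IPython, we can again ask inspect.signature. This sketch assumes the same add function as above. It shows that fixing a by keyword makes every parameter keyword-only, which is why b must be passed by name.

```python
import inspect
from functools import partial


def add(a, b):
    return a + b


add_three_v2 = partial(add, a=3)

# Fixing a by keyword makes a, and every parameter after it, keyword-only.
sig = inspect.signature(add_three_v2)
print(sig)  # (*, a=3, b)

# Every remaining parameter is keyword-only, so b can only be passed by name.
print(all(p.kind is inspect.Parameter.KEYWORD_ONLY for p in sig.parameters.values()))  # True
```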
Calling the function this way will error out:
>>> add_three_v2(3)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-109-e78f540eb25e> in <module>
----> 1 add_three_v2(3)
TypeError: add() got multiple values for argument 'a'
That is because an argument passed without a keyword is bound to the first positional parameter, a. Since partial has already set a by keyword, add receives two values for a.
On the other hand, calling the function this way will not:
add_three_v2(b=3)
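One way around this, sketched below, is to fix a positionally rather than by keyword. partial(add, 3) prepends 3 to the positional arguments, so the next positional argument is bound to b. add_three_positional is an illustrative name, and the try block simply reproduces the error from above in a form you can catch.

```python
from functools import partial


def add(a, b):
    return a + b


add_three_v2 = partial(add, a=3)

# Reproduce the error: the positional 3 is bound to a, which is also given as a=3.
try:
    add_three_v2(3)
except TypeError as err:
    print("TypeError:", err)

# Passing b by keyword works, as shown above.
print(add_three_v2(b=3))  # 6

# Fixing a positionally instead lets the remaining argument be positional.
add_three_positional = partial(add, 3)
print(add_three_positional(5))  # 8
```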
Creating closures
Another pattern we can use is the closure. A closing (outer) function defines and returns a closed (inner) function, and the closed function keeps access to the closing function's variables even after the closing function has returned. Confused? Let me illustrate:
def closing_function(a):
    def closed_function(b):
        return a + b
    return closed_function
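To convince yourself that the closed function carries information from the closing function, call it after the closing function has returned and then look at what it captured. add_five is an illustrative name. __closure__ and __code__.co_freevars are standard attributes of Python function and code objects.

```python
def closing_function(a):
    def closed_function(b):
        return a + b
    return closed_function


add_five = closing_function(5)

# closing_function has already returned, but closed_function still sees a.
print(add_five(10))  # 15

# Python stores the captured variable in a "cell" attached to the function.
print(add_five.__code__.co_freevars)          # ('a',)
print(add_five.__closure__[0].cell_contents)  # 5
```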
Using this pattern, we can rewrite add_three using closures:
def make_add_something(value):
    def closed_function(b):
        return b + value
    return closed_function
add_three_v3 = make_add_something(3)
add_three_v3(5)
add_three_v3?
Now, you'll notice that the signature of add_three_v3 exactly follows that of the closed function.
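As with the partial version, inspect.signature offers a plain-Python view of what add_three_v3? reports. Unlike partial(add, b=3), the closure leaves b as an ordinary parameter, so it can be passed positionally or by name. add_ten is a hypothetical second instance, included to show that each call to make_add_something produces an independent closure.

```python
import inspect


def make_add_something(value):
    def closed_function(b):
        return b + value
    return closed_function


add_three_v3 = make_add_something(3)
add_ten = make_add_something(10)  # hypothetical second instance

# The returned function's signature is exactly that of closed_function.
print(inspect.signature(add_three_v3))  # (b)
print(add_three_v3.__name__)            # closed_function

# b is an ordinary parameter, so positional and keyword calls both work.
print(add_three_v3(5), add_three_v3(b=5))  # 8 8

# Each call to make_add_something creates a separate closure with its own value.
print(add_three_v3(5), add_ten(5))  # 8 15
```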
When writing array programs using JAX, this is the key design pattern you'll want to implement: always return a function that has the signature you need.
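As an illustration, lax.scan expects a step function of the form (carry, x) -> (new_carry, y). The sketch below builds such a step function with a closure and runs it with python_scan, a hypothetical pure-Python loop that only mimics the calling convention of lax.scan. It is not JAX's implementation. make_decayed_sum_step and decay are likewise made up for the example.

```python
def python_scan(f, init, xs):
    """Hypothetical pure-Python stand-in for the calling convention of lax.scan.

    f must have the signature f(carry, x) -> (new_carry, y).
    Returns the final carry and the list of per-step outputs.
    """
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def make_decayed_sum_step(decay):
    """Closure returning a step function with the (carry, x) signature."""
    def step(carry, x):
        # decay is captured from make_decayed_sum_step, not passed per step.
        new_carry = decay * carry + x
        return new_carry, new_carry
    return step


step = make_decayed_sum_step(decay=0.5)
final, history = python_scan(step, 0.0, [1.0, 2.0, 3.0])
print(final, history)  # 4.25 [1.0, 2.5, 4.25]
```

With JAX installed, the same step should be usable as jax.lax.scan(step, 0.0, jnp.array([1.0, 2.0, 3.0])). In that case the per-step outputs come back stacked into an array rather than a list.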
Naming things is the hardest activity in programming, because we are giving categorical names to things whose category isn't always clear. Fret not: here is the pattern I suggest:
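default_value1 = None  # placeholder default; substitute whatever your function needs


def SOME_FUNCTION_generator(argument1, argument2, keyword_argument1=default_value1):
    """To simplify things, just give the name of the closing function <some_function>_generator."""
    def inner(arg1, arg2, kwarg1=default_value1):
        """This function should follow the API that is needed."""
        # Placeholder: compute the result from inner's own arguments
        # and the captured argument1, argument2 and keyword_argument1.
        something = ...
        return something
    return inner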
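Here is one concrete instance of that naming pattern. It assumes the consumer wants a function of a single input x, as map (or vmap over one argument) would. polynomial_generator and its coefficients argument are hypothetical. The closing function captures the coefficients, and inner has exactly the one-argument signature the consumer needs.

```python
def polynomial_generator(coefficients):
    """Hypothetical example of the <some_function>_generator pattern.

    Captures the polynomial's coefficients and returns a function of x alone,
    the single-argument signature that map (or vmap) expects.
    """
    def inner(x):
        """Evaluate c0 + c1*x + c2*x**2 + ... using Horner's method."""
        result = 0
        for c in reversed(coefficients):
            result = result * x + c
        return result
    return inner


quadratic = polynomial_generator([1, 2, 3])  # 1 + 2x + 3x^2
print(list(map(quadratic, [0, 1, 2])))  # [1, 6, 17]
```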