

%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

Closures and Partials

We're going to take a quick detour to look at the idea of "partially evaluating" a function. This matters because it lets us construct functions that satisfy the requirements of vmap, lax.scan and other JAX transformations (that is, they have the function signature those transformations expect) while still giving us the flexibility to supply any extra inputs the function needs to work correctly.

There are two ways to do this: you can either use functools.partial, or you can use function closures. Let's see how to do this.

Partially evaluating a function using functools.partial

For simplicity's sake, let's explore the idea using a function that adds two numbers:

def add(a, b):
    return a + b

Now, let's say we want to fix b to the value 3, producing an add_three function. There are two ways to do this; the first uses functools.partial:

from functools import partial

add_three = partial(add, b=3)

We can now call add_three on any value of a:
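The notebook's output cell is not shown here, so as a minimal sketch (the input values are arbitrary, chosen only for illustration):

```python
from functools import partial

def add(a, b):
    return a + b

add_three = partial(add, b=3)

# `a` can be passed positionally or by name; `b` is already fixed at 3
print(add_three(20))    # 23
print(add_three(a=-3))  # 0
```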


If we inspect the function add_three:
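In IPython you could type `add_three?`; a plain-Python sketch of the same inspection uses the standard library's `inspect` module together with the attributes that `partial` objects expose:

```python
import inspect
from functools import partial

def add(a, b):
    return a + b

add_three = partial(add, b=3)

# partial objects remember the wrapped function and the frozen keyword arguments
print(add_three.func)                # the original `add` function
print(add_three.keywords)            # {'b': 3}
print(inspect.signature(add_three))  # (a, *, b=3)
```

The `*` in the signature shows that `b` is now keyword-only with a default of 3.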


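We see that add_three accepts one positional argument, a, while b has been given a default value of 3. Strictly speaking, b becomes a keyword-only argument, so it can still be overridden by name: add_three(1, b=10) returns 11.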

What if we wanted to fix a to 3 instead?

add_three_v2 = partial(add, a=3)

Notice that the function signature has changed: now a has been fixed while b has not. This has implications for how we use the function.
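A quick sketch with the standard `inspect` module makes the new signature visible:

```python
import inspect
from functools import partial

def add(a, b):
    return a + b

add_three_v2 = partial(add, a=3)

# Fixing `a` by keyword turns `a`, and every parameter after it, into keyword-only parameters
print(inspect.signature(add_three_v2))  # (*, a=3, b)
```

Because `b` now sits after the `*`, it can only be supplied by name.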

Calling the function this way will error out:

>>> add_three_v2(3)
TypeError                                 Traceback (most recent call last)
<ipython-input-109-e78f540eb25e> in <module>
----> 1 add_three_v2(3)

TypeError: add() got multiple values for argument 'a'

That is because an argument passed without a keyword is bound to the first positional parameter, a, which partial has already set. Python therefore sees two values for a and raises the error.

On the other hand, calling the function this way will not:
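A minimal sketch of the working call, plus one alternative: freezing `a` positionally with `partial(add, 3)` (an option not used in the original notebook) leaves the next positional slot free for `b`:

```python
from functools import partial

def add(a, b):
    return a + b

add_three_v2 = partial(add, a=3)

# Passing `b` by name avoids the clash with the already-fixed `a`
print(add_three_v2(b=3))  # 6

# Alternative: freeze `a` positionally, so a bare positional argument goes to `b`
add_three_positional = partial(add, 3)
print(add_three_positional(3))  # 6
```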


Creating closures

Another pattern we can use is the closure. A closure is a function defined inside another function that captures ("closes over") variables from the enclosing function's scope; the enclosing, or closing, function then returns this closed function. Confused? Let me illustrate:

def closing_function(a):
    def closed_function(b):
        return a + b
    return closed_function
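Calling `closing_function` fixes `a` and hands back `closed_function`. A small sketch (the value 5 is arbitrary) also shows where Python stores the captured value:

```python
def closing_function(a):
    def closed_function(b):
        return a + b
    return closed_function

add_five = closing_function(5)  # `a` is captured as 5

print(add_five(10))  # 15
# The captured value lives in a cell attached to the closed function
print(add_five.__closure__[0].cell_contents)  # 5
```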

Using this pattern, we can rewrite add_three using closures:

def make_add_something(value):
    def closed_function(b):
        return b + value
    return closed_function

add_three_v3 = make_add_something(3)

Now, you'll notice that the signature of add_three_v3 is exactly that of the closed function: it takes a single argument, b.
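A short sketch confirming this with the standard `inspect` module:

```python
import inspect

def make_add_something(value):
    def closed_function(b):
        return b + value
    return closed_function

add_three_v3 = make_add_something(3)

print(inspect.signature(add_three_v3))  # (b)
print(add_three_v3.__name__)            # closed_function
print(add_three_v3(4))                  # 7
```

Unlike the partial version, `value` does not appear in the signature at all, so a caller cannot accidentally override it.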

When writing array programs with JAX, this is the key design pattern to use: always return a function that has exactly the signature you need.
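For example, `jax.lax.scan` expects a step function with the signature `(carry, x) -> (new_carry, y)`. The sketch below is illustrative: the decaying running sum and the name `make_decayed_sum_step` are hypothetical, chosen only to show a closure injecting an extra parameter (`decay`) while keeping the signature that `scan` requires.

```python
import jax.numpy as jnp
from jax import lax

def make_decayed_sum_step(decay):
    """Hypothetical closing function: captures `decay` for the step function."""
    def step(carry, x):
        # lax.scan requires exactly this shape: (carry, x) -> (new_carry, y)
        new_carry = decay * carry + x
        return new_carry, new_carry  # also emit the running value as y
    return step

step = make_decayed_sum_step(0.5)
init = jnp.zeros(())  # float32 scalar, matching the carry type step returns
final, history = lax.scan(step, init, jnp.array([1.0, 1.0, 1.0]))
# final == 1.75; history holds the values 1.0, 1.5, 1.75
```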

Naming things is one of the hardest parts of programming, because we are assigning categorical names to things whose category isn't always clear. Fret not: here is the naming pattern I use:

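def some_function_generator(argument1, argument2, keyword_argument1=None):
    """To simplify things, just name the closing function <some_function>_generator."""
    def inner(arg1, arg2, kwarg1=None):
        """This function should follow the API that is needed."""
        # Placeholder: replace with the real computation. It can use both the
        # captured values (argument1, argument2, keyword_argument1) and
        # inner's own arguments (arg1, arg2, kwarg1).
        something = (argument1, argument2, keyword_argument1, arg1, arg2, kwarg1)
        return something
    return inner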
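A concrete instance of the pattern, as a hedged sketch: `weighted_sum_generator` and its inputs are hypothetical. The closing function captures `weights`, so the inner function takes just one vector, which is the signature `jax.vmap` needs to map over a batch of rows.

```python
import jax
import jax.numpy as jnp

def weighted_sum_generator(weights):
    """Hypothetical example of the <some_function>_generator naming pattern."""
    def inner(x):
        # vmap maps this over the leading axis of the batch, so inner sees one vector
        return jnp.dot(weights, x)
    return inner

weights = jnp.array([1.0, 2.0, 3.0])
batch = jnp.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0],
                   [1.0, 1.0, 1.0]])

weighted_sum = weighted_sum_generator(weights)
result = jax.vmap(weighted_sum)(batch)  # values: 1.0, 2.0, 6.0
```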