```
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
```

# Closures and Partials

We're going to take a quick detour and look at the idea of "partially evaluating a function". This is important because it lets us construct functions that are compatible with the requirements of `vmap`, `lax.scan`, and others in JAX: they have the correct function signature, yet we keep the flexibility to put in whatever else the function needs to work correctly.

There are two ways to do this: you can either use `functools.partial`, or you can use function closures. Let's see how to do each.

## Partially evaluating a function using `functools.partial`

For simplicity's sake, let's explore the idea using a function that adds two numbers:

```
def add(a, b):
    return a + b
```

Now, let's say we wanted to fix `b` to the value `3`, thus generating an `add_three` function. We can do this in two ways. The first is with `functools.partial`:

```
from functools import partial
add_three = partial(add, b=3)
```

We can now call `add_three` on any value of `a`:

```
add_three(20)
```

If we inspect the function `add_three`:

```
add_three?
```

We see that `add_three` accepts one *positional* argument, `a`, and its value of `b` has been set to a default of `3`.
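Outside of IPython, where the `?` syntax is unavailable, the same information can be read with the standard library's `inspect.signature`. The following is a minimal sketch that redefines `add` so that it runs on its own; the variable name `params` is just for illustration.

```python
import inspect
from functools import partial


def add(a, b):
    return a + b


add_three = partial(add, b=3)

# Look up how each parameter of the partially evaluated function is exposed.
params = inspect.signature(add_three).parameters

# `a` is still free and can be passed positionally.
print(params["a"].kind)
# `b` carries the bound value as its default and has become keyword-only.
print(params["b"].kind, params["b"].default)
```

Note that `b` is only given a default here, so it can still be overridden by keyword, for example `add_three(20, b=5)`.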

What if we wanted to fix `a` to `3` instead?

```
add_three_v2 = partial(add, a=3)
add_three_v2?
```

Notice how the function signature has changed: `a` now has a default of `3`, while `b` is not set and has to be passed by keyword. This has implications for how we use the function.

Calling the function this way will error out:

```
>>> add_three_v2(3)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-109-e78f540eb25e> in <module>
----> 1 add_three_v2(3)
TypeError: add() got multiple values for argument 'a'
```

That is because an argument passed in without a keyword is interpreted as the first positional argument, `a`, which, as you can see, has already been set.

On the other hand, calling the function this way will not:

```
add_three_v2(b=3)
```
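If keeping positional calls working matters, `partial` can also bind arguments positionally. The sketch below redefines `add` to be self-contained; the name `add_three_v4` is hypothetical and does not appear elsewhere in this notebook. It also reproduces the `TypeError` from above and catches it.

```python
from functools import partial


def add(a, b):
    return a + b


# Binding by keyword fixes `a`, so `b` has to be passed by keyword.
add_three_v2 = partial(add, a=3)
try:
    add_three_v2(20)  # 20 is also assigned to `a`, which collides with a=3
except TypeError as err:
    error_message = str(err)

# Binding positionally fills the leftmost parameter (`a`) instead,
# which leaves `b` free to be passed positionally.
add_three_v4 = partial(add, 3)
result = add_three_v4(20)  # same as add(3, 20)
```

Positional binding always fills parameters from the left, so it cannot be used to fix `b` while leaving `a` free; for that case, keyword binding or a closure is the tool to reach for.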

## Creating closures

Another pattern that we can use is closures. A closure is an inner ("closed") function that captures information from the function that encloses it (the "closing" function); the closing function builds the closed function and returns it. Confused? Let me illustrate:

```
def closing_function(a):
    def closed_function(b):
        return a + b
    return closed_function
```
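To see the captured information concretely, here is a small sketch that repeats the definition above and then calls it. The names `add_ten`, `total`, and `captured` are illustrative. Python keeps the captured variable `a` in a cell on the returned function's `__closure__` attribute.

```python
def closing_function(a):
    def closed_function(b):
        return a + b
    return closed_function


# Calling the closing function fixes `a` and hands back the closed function.
add_ten = closing_function(10)
total = add_ten(5)  # 10 + 5

# The value of `a` lives on in a closure cell attached to the closed function.
captured = add_ten.__closure__[0].cell_contents
```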

Using this pattern, we can rewrite `add_three` using closures:

```
def make_add_something(value):
    def closed_function(b):
        return b + value
    return closed_function

add_three_v3 = make_add_something(3)
add_three_v3(5)
```

```
add_three_v3?
```

Now, you'll notice that the signature of `add_three_v3` exactly matches that of the closed function.

When writing array programs using JAX, this is the key design pattern you'll want to implement: Always return a function that has the function signature that you need.
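As a sketch of what this looks like in practice, assume JAX is installed and consider two factories: `make_add_something` from above and a hypothetical `make_decay_step`. `vmap` maps a function over the leading axis of its inputs, and `lax.scan` expects a step function of the form `f(carry, x) -> (new_carry, output)`. Closures let us fix the extra parameters while returning exactly those signatures.

```python
import jax.numpy as jnp
from jax import lax, vmap


def make_add_something(value):
    def closed_function(b):
        return b + value
    return closed_function


def make_decay_step(decay):
    # Hypothetical example: `decay` is fixed by the closure, so `step`
    # keeps the (carry, x) -> (new_carry, output) signature lax.scan expects.
    def step(carry, x):
        new_carry = decay * carry + x
        return new_carry, new_carry
    return step


xs = jnp.array([1.0, 1.0, 1.0])

# vmap applies the one-argument closed function to every element of xs.
added = vmap(make_add_something(3.0))(xs)

# lax.scan threads the carry through xs, collecting each intermediate value.
final, history = lax.scan(make_decay_step(0.5), jnp.zeros(()), xs)
```

Neither `value` nor `decay` appears in the signatures that `vmap` and `lax.scan` see, yet both are available inside the function bodies.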

Naming things is the hardest activity in programming, because we are giving categorical names to things whose category isn't always clear. Fret not: the pattern I'll give you is the following:

```
def SOME_FUNCTION_generator(argument1, argument2, keyword_argument1=default_value1):
    """To simplify things, just name the closing function <some_function>_generator."""
    def inner(arg1, arg2, kwarg1=default_value1):
        """This function should follow the API that is needed."""
        return something
    return inner
```
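A concrete instance of this template might look like the following. The names `scale_and_shift_generator`, `scale`, `shift`, and `double_plus_one` are made up for illustration; the point is only that the inner function ends up with the one-argument signature that a caller such as the built-in `map` needs.

```python
def scale_and_shift_generator(scale, shift=0.0):
    """Closing function: fixes `scale` and `shift` ahead of time."""
    def inner(x):
        """Closed function: takes only `x`, the signature the caller needs."""
        return scale * x + shift
    return inner


double_plus_one = scale_and_shift_generator(2.0, shift=1.0)

# `map` calls its function with a single argument per element.
outputs = list(map(double_plus_one, [0.0, 1.0, 2.0]))
```

The same generator can be called again with different arguments to produce as many differently configured functions as needed, all sharing one signature.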