```
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
```

```
from IPython.display import YouTubeVideo, display
YouTubeVideo("YB74pwPyAT0")
```

# Replace simple for-loops with `vmap`

The first JAX thing we will look at is the `vmap` function.
What does `vmap` do?
From the JAX docs on `vmap`:

> Vectorizing map. Creates a function which maps `fun` over argument axes.

Basically the idea here is to take a function
and apply it to every "element" along a particular array axis.
The key skill in learning to use `vmap` is being able to
decompose a computation into its repeatable element.
Let's take a look at a few examples to make this clear.

## Example: Squaring every element in an array

This is the first example that we will walk through,
which involves applying a function over every element in a vector.
By default, `vmap` takes in a function `f`
and returns a function `f_prime` that maps `f`
over the *leading* axis of an array.
The axis along which the array is mapped is configurable,
and we'll see that in a moment.
For now, let's explore what the default behaviour of `vmap` is.

In the example below, we start with a function `square`
that takes in scalars and returns scalars.
(Whether they are floats or integers doesn't really matter,
but floats are the generalization of integers,
so we'll work with those.)
If we are being stringent about types,
we won't allow ourselves to pass an array into the `square` function,
even though NumPy technically allows us to do so.
`vmap` gives us the following equivalent function:

```
def func(x):
    ...
    return result


def vmapped_func(array):
    result = []
    for element in array:
        # Apply the elementary function to each element along the leading axis.
        result.append(func(element))
    result = np.stack(result)
    return result
```

Hence, we can apply a function across the leading (first) axis of an array. In the case of a vector, there is only one axis, so we simply apply the function to every element of the array.

```
import jax.numpy as np
from jax import vmap

a = np.arange(20)  # shape: (20,)


def square(x: float) -> float:
    return x ** 2


mapped_sq = vmap(square)  # this is a function!
mapped_sq(a)
```
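To make the loop equivalence concrete, here is a minimal sketch that writes out the loop-and-stack version by hand and compares it against `vmap`. The helper name `looped_square` is made up for this illustration and is not part of JAX.

```python
import jax.numpy as np  # same alias the rest of this notebook uses
from jax import vmap


def square(x):
    """Elementary computation: square a single scalar."""
    return x ** 2


def looped_square(array):
    """Hypothetical helper: explicit loop over the leading axis, then stack."""
    result = []
    for element in array:
        result.append(square(element))
    return np.stack(result)


a = np.arange(5.0)  # shape: (5,)
looped = looped_square(a)
vmapped = vmap(square)(a)

# Both routes should give the same values; `vmap` just hides the loop.
print(np.allclose(looped, vmapped))
```

The loop version makes the "repeatable element" explicit, and `vmap` lets us keep only that element in our code.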

`vmap` returns a function

There is one very important thing to remember here!
`vmap` takes in a function and returns another function
(`mapped_sq` in the example above).
We still have to pass an array into the returned function,
otherwise we won't get a result.
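Because `vmap` maps a function to a function, it can also be applied as a decorator. The sketch below assumes nothing beyond `jax.vmap` itself; the names `square_all` and `mapped` are illustrative only.

```python
import jax.numpy as np
from jax import vmap


@vmap  # same as writing `square_all = vmap(square_all)` after the definition
def square_all(x):
    return x ** 2


# `vmap(...)` on its own only builds a new function; nothing is computed yet.
mapped = vmap(lambda x: x + 1.0)
is_function = callable(mapped)

# A result only appears once we call the returned function on an array.
values = square_all(np.arange(4.0))
```

The decorator form is convenient when the un-mapped version of the function is never needed on its own.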

## Example: Summing every row or column of a 2D matrix

In this next example, we will see how to apply a reduction function (e.g. summation) across every row or column in a matrix. This example will allow us to see how to configure the `in_axes` argument of `vmap`.

```
def sum_vector(x: np.ndarray) -> np.ndarray:
    """Assumes `x` is a vector"""
    return np.sum(x)


a = np.arange(20).reshape((4, 5))
a
```

```
# Apply `sum_vector` across each element along the 0th axis.
vmap(sum_vector)(a) # shape: (4,)
```

```
# Apply `sum_vector` across each element along the 1st axis.
vmap(sum_vector, in_axes=1)(a) # shape: (5,)
```

Those of you who are experienced with NumPy are probably thinking,
"Couldn't we just specify the `axis` argument of `np.sum`, such as `np.sum(a, axis=1)`?"
Yes, but there's more:

- Using `vmap` nudges us to think about the elementary and repeatable computation that is used.
- We practice this skill by thinking about it on a trivial example.

Moreover, if we think carefully about the semantic meaning of our array data structures, we can avoid magic axis numbers showing up in our code. (An example of this is consistently keeping the time axis on the leading axis.) And as we all know, the fewer magic numbers there are inside code, the easier it is for us to read it.
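One detail that is easy to trip over when comparing the two styles: `in_axes` names the axis that is *mapped over* (and therefore kept), while the `axis` argument of `np.sum` names the axis that is *reduced* (and therefore removed). A small sketch of the correspondence, reusing the `sum_vector` function from above:

```python
import jax.numpy as np
from jax import vmap


def sum_vector(x):
    """Assumes `x` is a vector."""
    return np.sum(x)


a = np.arange(20).reshape((4, 5))

# Map over axis 0 (the rows): one sum per row, shape (4,).
row_sums = vmap(sum_vector)(a)
# Map over axis 1 (the columns): one sum per column, shape (5,).
col_sums = vmap(sum_vector, in_axes=1)(a)

# The `axis` argument of `np.sum` is the axis that gets collapsed instead.
same_rows = np.array_equal(row_sums, np.sum(a, axis=1))
same_cols = np.array_equal(col_sums, np.sum(a, axis=0))
```

So `vmap(sum_vector, in_axes=1)` corresponds to `np.sum(a, axis=0)`, not `np.sum(a, axis=1)`.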

## Example: Softmax function

Here is another example involving the softmax function. (We have provided the softmax function for you.) This sort of operation is usually done when we want to take every row in a matrix, which might contain negative numbers, and convert the rows into a stack of probability vectors. (To learn more about the softmax function, check out the Wikipedia article.)

```
def softmax(x: np.ndarray) -> np.ndarray:
    """Vector-wise softmax transform."""
    return np.exp(x) / np.sum(np.exp(x))
```

```
a = np.arange(20).reshape((4, 5))
a
```

```
vmap(softmax)(a)
```
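A quick sanity check on the `vmap`-ed softmax is that every row of the output is a probability vector, i.e. its entries sum to 1 (up to floating-point error). The sketch below repeats the definitions so that it runs on its own; the tolerance is an arbitrary choice for 32-bit floats.

```python
import jax.numpy as np
from jax import vmap


def softmax(x):
    """Vector-wise softmax transform."""
    return np.exp(x) / np.sum(np.exp(x))


a = np.arange(20).reshape((4, 5))
probs = vmap(softmax)(a)  # shape: (4, 5), one probability vector per row

# Each row should sum to (approximately) one.
row_totals = np.sum(probs, axis=1)
print(np.allclose(row_totals, 1.0, atol=1e-5))
```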

## Example: Solving for angle of a triangle

When solving for an angle inside a right triangle, we need to know two of its side lengths.
Say we know the opposite and adjacent side lengths for a single triangle.
We can then solve for the corresponding angle by taking `np.arctan(opp / adj)`.

```
def angle(opp: float, adj: float):
    return np.arctan(opp / adj)


angle(5, 3)
```

Imagine we had two vectors, one for `opp`s and the other for `adj`s.
`vmap` can automatically transform the `angle` function
into one that operates on vector pairs of `opp` and `adj`
(assuming they both have the same length).

```
opps = np.arange(20)
adjs = np.linspace(3, 30, 20)
vmap(angle)(opps, adjs)
```
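When a function takes several arguments, `in_axes` can also be a tuple with one entry per argument, and an entry of `None` tells `vmap` not to map over that argument. The sketch below is an illustration that goes slightly beyond the example above: it holds the adjacent side fixed at an arbitrary value of `3.0` while mapping over the opposite sides.

```python
import jax.numpy as np
from jax import vmap


def angle(opp, adj):
    return np.arctan(opp / adj)


opps = np.arange(20)
adjs = np.linspace(3, 30, 20)

# Default: map over the leading axis of *both* arguments, pairing them up.
paired = vmap(angle)(opps, adjs)  # shape: (20,)

# Map over `opp` only; `None` means `adj` is passed through unchanged.
fixed_adj = vmap(angle, in_axes=(0, None))(opps, 3.0)  # shape: (20,)
```

The elementary `angle` function stays scalar-in, scalar-out in both cases; only the `in_axes` specification changes.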

## Exercises

Let's go to some exercises to flex your newly-found `vmap` muscles!
Everything you need to know you have picked up above;
all that's left is getting practice creatively combining the pieces.
Use the puzzles below, which are ordered in increasing complexity,
to challenge your skillsets!

### Exercise: `vmap`-ed dot products

Message passing is a fundamental operation in the network analysis and graph neural network worlds. It is defined by taking a square adjacency-like matrix (also known as the diffusion matrix) of a graph and matrix multiplying it against a node feature matrix (also known as the message matrix). In NumPy code:

```
def mp(a: np.ndarray, f: np.ndarray) -> np.ndarray:
    """Message passing operator.

    - `a`: An adjacency-like matrix of size (num_nodes, num_nodes).
    - `f`: A message matrix of size (num_nodes, num_feats).
    """
    return np.dot(a, f)
```

Suppose we have 13 graphs, each with 7 nodes, such that each node has a message vector of length 11.
We'd like to perform a message passing operation on each of those graphs.
Your task is to implement this using `vmap`.

```
from jax import random
num_nodes = 7
num_graphs = 13
num_feats = 11
key = random.PRNGKey(90)
As = random.bernoulli(key, p=0.1, shape=(num_graphs, num_nodes, num_nodes))
Fs = random.normal(key, shape=(num_graphs, num_nodes, num_feats))
```

The naive implementation should look something like this:

```
def naive_mp(As, Fs):
    res = []
    for a, f in zip(As, Fs):
        res.append(np.dot(a, f))
    return np.stack(res)
```

```
# Your answer here.
def vmapped_message_passing(As, Fs):
    """Your answer here!"""
    result = vmap(mp)(As, Fs)
    return result
```

Verify that your answer is correct.

```
result = vmapped_message_passing(As, Fs)
assert result.shape == naive_mp(As, Fs).shape
assert not np.allclose(result, Fs)
```
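The shape check above does not compare values. One way to cross-check the values is against a batched `einsum`, which expresses the same per-graph matrix product. This is a sketch under two assumptions of mine, not of the original exercise: the key is split with `random.split` so that `As` and `Fs` are drawn independently, and the adjacency matrices are cast to floats before multiplying.

```python
import jax.numpy as np
from jax import random, vmap


def mp(a, f):
    """Message passing operator for a single graph."""
    return np.dot(a, f)


key = random.PRNGKey(90)
key_a, key_f = random.split(key)  # separate keys for independent draws
As = random.bernoulli(key_a, p=0.1, shape=(13, 7, 7)).astype(np.float32)
Fs = random.normal(key_f, shape=(13, 7, 11))

batched = vmap(mp)(As, Fs)  # shape: (13, 7, 11)
# g indexes graphs; (i, j) @ (j, k) is the per-graph matrix product.
reference = np.einsum("gij,gjk->gik", As, Fs)
print(np.allclose(batched, reference, atol=1e-5))
```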

### Exercise: Chained `vmap`s

We're going to try our hand at constructing a slightly more complex program.
This program takes in one dataset of three dimensions,
`(n_datasets, n_rows, n_columns)`.
The program first calculates
the cumulative product across each row in a dataset,
then sums them up (collapsing the columns) across each dataset,
and finally applies this same operation across all datasets stacked together.
This one is a bit more challenging!

To help you along here, the shape of the data is as follows:

- There are 11 stacks of data.
- Each stack of data has 31 rows, and 7 columns.

The result of this program should still have 11 stacks and 31 rows, but the columns are collapsed: each entry is the sum, across a row's columns, of the cumulative product along that row.

To get this answer right,
no magic numbers are allowed (e.g. for accessing particular axes).
At least two `vmap`s are necessary here.

```
data = random.normal(key, shape=(11, 31, 7))


def ex2_numpy_equivalent(data):
    result = []
    for d in data:
        cp = np.cumprod(d, axis=-1)
        s = np.sum(cp, axis=1)
        result.append(s)
    return np.stack(result)


def loopless_loops_ex2(data):
    """Your solution here!"""
    pass


# Comment out the import if you want to test your answer.
from dl_workshop.jax_idioms import loopless_loops_ex2

assert loopless_loops_ex2(data).shape == ex2_numpy_equivalent(data).shape
```
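If you get stuck, one possible decomposition (a sketch, not necessarily the workshop's reference solution) is to write the computation for a single row and then lift it twice with `vmap`. The names `row_cumprod_sum`, `per_dataset`, and `per_stack` are made up for this illustration, and the PRNG seed is arbitrary.

```python
import jax.numpy as np
from jax import random, vmap


def row_cumprod_sum(row):
    """Elementary computation on one row: cumulative product, then sum."""
    return np.sum(np.cumprod(row))


per_dataset = vmap(row_cumprod_sum)  # maps over the rows of one dataset
per_stack = vmap(per_dataset)        # maps over the stacked datasets

data = random.normal(random.PRNGKey(0), shape=(11, 31, 7))
out = per_stack(data)  # shape: (11, 31)

# Loop-based reference, mirroring the NumPy-style equivalent above.
reference = np.stack(
    [np.sum(np.cumprod(d, axis=-1), axis=1) for d in data]
)
```

Note that no axis numbers appear in the `vmap`-ed version: the row function only ever sees a vector.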

### Exercise: Double for-loops

This one is a favourite of mine,
and took me an afternoon of on-and-off thinking to reason about clearly.
Graphs, a.k.a. networks, are composed of nodes and edges,
and nodes can be represented by a vector of information (i.e. node features).
Stacking all the nodes' vectors together
gives us a *node feature matrix*.
In graph attention networks, one step requires us to pairwise concatenate
every node's vector with every other node's vector.
For example, if every node had a length `n_features` feature vector,
then concatenating two nodes' vectors together
should give us a length `2 * n_features` vector.
Doing this pairwise across all nodes in a graph
would give us an `(n_nodes, n_nodes, 2 * n_features)` tensor.

Your challenge below is to write the vmapped version of the following:

```
num_nodes = 13
num_feats = 17
node_feats = random.normal(key, shape=(num_nodes, num_feats))


def ex3_numpy_equivalent(node_feats):
    result = []
    for node1 in node_feats:
        node1_concats = []
        for node2 in node_feats:
            cc = np.concatenate([node1, node2])
            node1_concats.append(cc)
        result.append(np.stack(node1_concats))
    return np.stack(result)


def loopless_loops_ex3(node_feats):
    """Your solution here!"""
    pass


# Comment out the import if you want to test your answer.
from dl_workshop.jax_idioms import loopless_loops_ex3

assert (
    loopless_loops_ex3(node_feats).shape
    == ex3_numpy_equivalent(node_feats).shape
)
```
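As a hint for the double loop, each `for` can become its own `vmap`, with `in_axes` deciding which argument that loop runs over. The following is one possible sketch, not necessarily the workshop's reference solution; `concat_pair`, `one_vs_all`, and `all_vs_all` are hypothetical names and the PRNG seed is arbitrary.

```python
import jax.numpy as np
from jax import random, vmap


def concat_pair(node1, node2):
    """Elementary computation: concatenate two nodes' feature vectors."""
    return np.concatenate([node1, node2])


# Inner loop: hold node1 fixed, map over every node2.
one_vs_all = vmap(concat_pair, in_axes=(None, 0))
# Outer loop: map over every node1, pass the full matrix through as node2s.
all_vs_all = vmap(one_vs_all, in_axes=(0, None))

node_feats = random.normal(random.PRNGKey(0), shape=(13, 17))
pairs = all_vs_all(node_feats, node_feats)  # shape: (13, 13, 34)

# Spot check one entry against a direct concatenation.
expected_2_5 = np.concatenate([node_feats[2], node_feats[5]])
```

Here `pairs[i, j]` holds node `i`'s features followed by node `j`'s, matching the order of the nested loops.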

## Summary

To recap, the semantics of `vmap` basically follow this logic:
take an elementary computation and repeat it across the leading axis of an array.
The elementary computation shouldn't know anything about the leading axis.
You can then create the `vmap`-ed function that knows about the leading axis
by passing the function through `vmap` and getting back another function.