%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from IPython.display import YouTubeVideo, display


Replace simple for-loops with vmap

The first JAX feature we will look at is the vmap function. What does vmap do? From the JAX docs on vmap:

Vectorizing map. Creates a function which maps fun over argument axes.

The idea is to take a function and apply it to every "element" along a particular array axis. The key skill in using vmap is decomposing a computation into its repeatable element. Let's take a look at a few examples to make this clear.

Example: Squaring every element in an array

This is the first example that we will walk through, which involves applying a function over every element in a vector. By default, vmap takes in a function f and returns a function f_prime that maps f over the leading axis of an array. The axis along which the array is mapped is configurable, and we'll see that in a moment. For now, let's explore what the default behaviour of vmap is.

In the example below, we start with a function square that takes in scalars and returns scalars. (Whether they are floats or integers doesn't really matter, but floats are the generalization of integers, so we'll work with those.) If we are being stringent about types, we won't allow ourselves to pass an array into the square function, even though NumPy technically allows us to do so. Conceptually, vmap turns a function func into the equivalent of vmapped_func below:

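def func(x):
    result = ...  # some computation on a single element `x`
    return result

def vmapped_func(array):
    result = []
    for element in array:
        # Apply `func` to one element at a time along the leading axis.
        result.append(func(element))
    result = np.stack(result)
    return result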

Hence, we can apply a function across the leading (first) axis of an array. In the case of a vector, there is only one axis, so we simply apply the function to every element in the array.

import jax.numpy as np
from jax import vmap

a = np.arange(20)   # shape: (20,)

def square(x: float) -> float:
    return x ** 2

mapped_sq = vmap(square)  # this is a function!
mapped_sq(a)              # call it on the array to get a result
DeviceArray([  0,   1,   4,   9,  16,  25,  36,  49,  64,  81, 100, 121,
             144, 169, 196, 225, 256, 289, 324, 361], dtype=int32)

vmap returns a function

There is one very important thing to remember here! vmap takes in a function and returns another function. (mapped_sq in the example above.) We still have to pass an array into the returned function, otherwise we won't get a result.
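
To make this concrete, here is a minimal sketch (the `values` array is a made-up input, not taken from the cells above) that checks the function returned by vmap against an explicit Python loop over the same elements:

```python
import jax.numpy as np
from jax import vmap


def square(x: float) -> float:
    return x ** 2


values = np.arange(5.0)  # hypothetical example input, shape (5,)

# Explicit loop: apply `square` to one scalar at a time, then stack the results.
looped = np.stack([square(v) for v in values])

# vmap: build the mapped function once...
mapped_sq = vmap(square)
# ...then call it on the whole array to actually get a result.
mapped = mapped_sq(values)
```

Both `looped` and `mapped` should hold the same squared values. The difference is that vmap handles the batching for us, so no Python loop is needed.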

Example: Summing every row or column of a 2D matrix

In this next example, we will see how to apply a reduction function (e.g. summation) across every row or column in a matrix. This example will allow us to see how to configure the in_axes argument of vmap.

def sum_vector(x: np.ndarray) -> np.ndarray:
    """Assumes `x` is a vector"""
    return np.sum(x)

a = np.arange(20).reshape((4, 5))
a
DeviceArray([[ 0,  1,  2,  3,  4],
             [ 5,  6,  7,  8,  9],
             [10, 11, 12, 13, 14],
             [15, 16, 17, 18, 19]], dtype=int32)
# Apply `sum_vector` across each element along the 0th axis.
vmap(sum_vector)(a)          # shape: (4,)
DeviceArray([10, 35, 60, 85], dtype=int32)
# Apply `sum_vector` across each element along the 1st axis.
vmap(sum_vector, in_axes=1)(a)  # shape: (5,)
DeviceArray([30, 34, 38, 42, 46], dtype=int32)

Those of you who are experienced with NumPy are probably thinking, "Couldn't we just specify the axis argument of np.sum, such as np.sum(axis=1)?" Yes, but there's more:

  1. Using vmap nudges us to think about the elementary and repeatable computation that is used.
  2. We practice this skill by thinking about it on a trivial example.

Moreover, if we think carefully about the semantic meaning of our array data structures, we can avoid magic axis numbers showing up in our code. (An example of this is consistently keeping the time axis on the leading axis.) And as we all know, the fewer magic numbers there are inside code, the easier it is for us to read it.
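
For readers who want to see the correspondence with the axis argument spelled out, the following sketch (using a separately named array `mat`, chosen here for illustration) compares the two vmap calls with their np.sum counterparts:

```python
import jax.numpy as np
from jax import vmap


def sum_vector(x):
    """Assumes `x` is a vector."""
    return np.sum(x)


mat = np.arange(20).reshape((4, 5))

# Mapping over axis 0 hands `sum_vector` one row at a time,
# which matches summing along axis 1.
row_sums = vmap(sum_vector)(mat)             # shape: (4,)

# Mapping over axis 1 hands `sum_vector` one column at a time,
# which matches summing along axis 0.
col_sums = vmap(sum_vector, in_axes=1)(mat)  # shape: (5,)
```

Note that in_axes names the axis being mapped over, while the axis argument of np.sum names the axis being collapsed, which is why the numbers are swapped.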

Example: Softmax function

Here is another example involving the softmax function. (The softmax function is provided for you.) This sort of operation is usually done when we want to take every row in a matrix, which might have negative numbers, and convert the rows into a stack of probability vectors. (To learn more about the softmax function, check out the Wikipedia article.)

def softmax(x: np.ndarray) -> np.ndarray:
    """Vector-wise softmax transform."""
    return np.exp(x) / np.sum(np.exp(x))

a = np.arange(20).reshape((4, 5))
a
DeviceArray([[ 0,  1,  2,  3,  4],
             [ 5,  6,  7,  8,  9],
             [10, 11, 12, 13, 14],
             [15, 16, 17, 18, 19]], dtype=int32)
vmap(softmax)(a)  # apply `softmax` to each row
DeviceArray([[0.01165623, 0.03168492, 0.08612854, 0.23412164, 0.6364086 ],
             [0.01165623, 0.03168492, 0.08612854, 0.23412165, 0.6364086 ],
             [0.01165623, 0.03168492, 0.08612855, 0.23412165, 0.6364086 ],
             [0.01165623, 0.03168492, 0.08612854, 0.23412165, 0.6364086 ]], dtype=float32)
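
As a sanity check on the row-wise behaviour, here is a small sketch that verifies every row sums to one. It uses a float-valued array named `logits` (a name chosen here for illustration) and compares against `jax.nn.softmax` purely as a reference; it is not part of the original example.

```python
import jax
import jax.numpy as np
from jax import vmap


def softmax(x):
    """Vector-wise softmax transform."""
    return np.exp(x) / np.sum(np.exp(x))


logits = np.arange(20.0).reshape((4, 5))  # hypothetical float input

# One probability vector per row, shape (4, 5).
probs = vmap(softmax)(logits)

# Each row should sum to (approximately) 1.
row_totals = np.sum(probs, axis=1)

# Reference result from JAX's built-in softmax along the last axis.
reference = jax.nn.softmax(logits, axis=-1)
```

Because the vector-wise `softmax` never mentions an axis, all of the row handling lives in the vmap call.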

Example: Solving for angle of a triangle

When solving for an angle inside a right triangle, we need to know two of its side lengths. Say we know the opposite and adjacent side lengths for a single triangle. We can then solve for the corresponding angle by taking np.arctan(opp/adj).

def angle(opp: float, adj: float):
    return np.arctan(opp / adj)

angle(5, 3)
DeviceArray(1.0303768, dtype=float32)

Imagine we had two vectors, one for opps and the other for adjs. vmap can automatically transform the angle function into one that operates on vector pairs of opp and adj (assuming they both have the same length).

opps = np.arange(20)
adjs = np.linspace(3, 30, 20)

vmap(angle)(opps, adjs)
DeviceArray([0.        , 0.22244725, 0.32983664, 0.39169994, 0.43163884,
             0.45947224, 0.4799505 , 0.4956367 , 0.508031  , 0.51806855,
             0.52636147, 0.5333274 , 0.53926075, 0.5443749 , 0.5488283 ,
             0.55274117, 0.55620605, 0.55929583, 0.56206805, 0.5645694 ],            dtype=float32)
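
A related case, sketched here as an aside, is when only one argument varies. Assuming every triangle shares a single adjacent length (the `fixed_adj` value below is made up), in_axes can be given as a tuple with None for the argument that should not be mapped:

```python
import jax.numpy as np
from jax import vmap


def angle(opp, adj):
    return np.arctan(opp / adj)


opps = np.arange(20.0)
fixed_adj = 3.0  # hypothetical: one adjacent length shared by every triangle

# in_axes=(0, None): map over axis 0 of `opps`, pass `fixed_adj` through unchanged.
angles = vmap(angle, in_axes=(0, None))(opps, fixed_adj)  # shape: (20,)
```

The `angle` function itself is untouched; only the mapping specification changes.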


Let's go to some exercises to flex your newly-found vmap muscles! Everything you need to know you have picked up above; all that's left is getting practice creatively combining them together. Use the puzzles below, which are ordered in increasing complexity, to challenge your skillsets!

Exercise: vmap-ed dot products

Message passing is a fundamental operation in the network analysis and graph neural network worlds. It is defined by taking a square adjacency-like matrix (also known as the diffusion matrix) of a graph and matrix multiplying it against a node feature matrix (also known as the message matrix). In NumPy code:

def mp(a: np.ndarray, f: np.ndarray) -> np.ndarray:
    """Message passing operator.

    - `a`: An adjacency-like matrix of size (num_nodes, num_nodes).
    - `f`: A message matrix of size (num_nodes, num_feats).
    """
    return np.dot(a, f)

Suppose we have 13 graphs, each of size 7 nodes such that each node has a message vector of length 11. We'd like to perform a message passing operation on each of those graphs. Your task is to implement this using vmap.

from jax import random

num_nodes = 7
num_graphs = 13
num_feats = 11

key = random.PRNGKey(90)
As = random.bernoulli(key, p=0.1, shape=(num_graphs, num_nodes, num_nodes))
Fs = random.normal(key, shape=(num_graphs, num_nodes, num_feats))

The naive implementation should look something like this:

def naive_mp(As, Fs):
    res = []
    for a, f in zip(As, Fs):
        res.append(np.dot(a, f))
    return np.stack(res)
# Your answer here.
def vmapped_message_passing(As, Fs):
    """Your answer here!"""
    result = vmap(mp)(As, Fs)
    return result

Verify that your answer is correct.

result = vmapped_message_passing(As, Fs)
assert result.shape == naive_mp(As, Fs).shape
assert not np.allclose(result, Fs)
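
The assertions above only compare shapes and check that the result differs from the input. A stricter check, sketched below, compares the values themselves. The sketch makes two assumptions that differ from the cells above: it splits the random key so the two draws are independent, and it casts the Bernoulli draws to floats.

```python
import jax.numpy as np
from jax import random, vmap


def mp(a, f):
    """Message passing on a single graph."""
    return np.dot(a, f)


# Assumption: split the key so `As` and `Fs` come from independent draws.
key_a, key_f = random.split(random.PRNGKey(90))
# Cast booleans to floats so every path multiplies the same dtype.
As = random.bernoulli(key_a, p=0.1, shape=(13, 7, 7)).astype(np.float32)
Fs = random.normal(key_f, shape=(13, 7, 11))

# Loop over graphs one at a time.
looped = np.stack([mp(a, f) for a, f in zip(As, Fs)])
# Same computation with the loop replaced by vmap.
mapped = vmap(mp)(As, Fs)
# matmul also batches over the leading axis, which gives a third reference.
batched = np.matmul(As, Fs)
```

If all three agree up to floating-point tolerance, the vmapped version is doing the same work as the loop.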

Exercise: Chained vmaps

We're going to try our hand at constructing a slightly more complex program. This program takes in one dataset of three dimensions, (n_datasets, n_rows, n_columns). The program first calculates the cumulative product along each row of a dataset, then sums each row's cumulative products (collapsing the columns), and finally applies this same operation to all datasets stacked together. This one is a bit more challenging!

To help you along here, the shape of the data are such:

  • There are 11 stacks of data.
  • Each stack of data has 31 rows, and 7 columns.

The result of this program should still have 11 stacks and 31 rows, but the columns are collapsed: each entry is the sum of the cumulative products along the corresponding row, so the output shape is (11, 31).

To get this answer right, no magic numbers are allowed (e.g. for accessing particular axes). At least two vmaps are necessary here.
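
As a hint on how vmaps chain, here is a sketch on a different elementary function, so the exercise itself stays open. The names `value_range`, `range_per_row`, `range_per_dataset` and `stack` are made up for illustration.

```python
import jax.numpy as np
from jax import random, vmap


def value_range(row):
    """Elementary computation: a vector in, a scalar out."""
    return np.max(row) - np.min(row)


# First vmap: one matrix in, one value per row out.
range_per_row = vmap(value_range)
# Second vmap: a stack of matrices in, one vector of row values per matrix out.
range_per_dataset = vmap(range_per_row)

stack = random.normal(random.PRNGKey(0), shape=(3, 4, 5))  # hypothetical small input
out = range_per_dataset(stack)  # shape: (3, 4)
```

Each vmap peels off one leading axis, and the elementary function never needs an axis number.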

from dl_workshop.jax_idioms import loopless_loops_ex2

data = random.normal(key, shape=(11, 31, 7))

def ex2_numpy_equivalent(data):
    result = []
    for d in data:
        cp = np.cumprod(d, axis=-1)
        s = np.sum(cp, axis=1)
        result.append(s)
    return np.stack(result)

def loopless_loops_ex2(data):
    """Your solution here!"""

# Comment out the import if you want to test your answer.
from dl_workshop.jax_idioms import loopless_loops_ex2

assert loopless_loops_ex2(data).shape == ex2_numpy_equivalent(data).shape

Exercise: Double for-loops

This one is a favourite of mine, and took me an afternoon of on-and-off thinking to reason about clearly. Graphs, a.k.a. networks, are composed of nodes and edges, and nodes can be represented by a vector of information (i.e. node features). Stacking all the nodes' vectors together gives us a node feature matrix. In graph attention networks, one step requires concatenating every node's vector with every node's vector, pairwise. For example, if every node had a length n_features feature vector, then concatenating two nodes' vectors together should give us a length 2 * n_features vector. Doing this pairwise across all nodes in a graph would give us an (n_nodes, n_nodes, 2 * n_features) tensor.
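
Before attempting the exercise, it may help to see the all-pairs pattern on a simpler elementary function. The sketch below is not the exercise solution; `similarity`, `one_vs_all`, `all_vs_all` and `feats` are illustrative names, and a dot product stands in for the concatenation.

```python
import jax.numpy as np
from jax import random, vmap


def similarity(u, v):
    """Elementary computation on one pair of vectors."""
    return np.dot(u, v)


# Inner vmap: hold `u` fixed (None) and map over the rows of the second argument.
one_vs_all = vmap(similarity, in_axes=(None, 0))
# Outer vmap: map over the rows of the first argument, pass the second through whole.
all_vs_all = vmap(one_vs_all, in_axes=(0, None))

feats = random.normal(random.PRNGKey(0), shape=(4, 3))  # hypothetical small input
pairwise = all_vs_all(feats, feats)  # shape: (4, 4)
```

Each vmap replaces one of the two for-loops, and the in_axes tuples say which argument each loop iterates over.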

Your challenge below is to write the vmapped version of the following:

num_nodes = 13
num_feats = 17
node_feats = random.normal(key, shape=(13, 17))

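def ex3_numpy_equivalent(node_feats):
    result = []
    for node1 in node_feats:
        node1_concats = []
        for node2 in node_feats:
            cc = np.concatenate([node1, node2])
            node1_concats.append(cc)
        # Stack all concatenations for `node1`: (n_nodes, 2 * n_features).
        result.append(np.stack(node1_concats))

    return np.stack(result)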

def loopless_loops_ex3(node_feats):
    """Your solution here!"""

# Comment out the import if you want to test your answer.
from dl_workshop.jax_idioms import loopless_loops_ex3

assert (
    loopless_loops_ex3(node_feats).shape
    == ex3_numpy_equivalent(node_feats).shape
)


To recap, the semantics of vmap basically follow this logic: Take an elementary computation and repeat it across the leading axis of an array. The elementary computation shouldn't know anything about the leading axis. You can then create the vmap-ed function that knows about the leading axis by passing the function through vmap and getting back another function.