%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from IPython.display import YouTubeVideo, display
YouTubeVideo("YB74pwPyAT0")
Replace simple for-loops with vmap
The first JAX thing we will look at is the vmap
function.
What does vmap do?
From the JAX docs on vmap
Vectorizing map. Creates a function which maps fun over argument axes.
Basically the idea here is to take a function
and apply it to every "element" along a particular array axis.
The key skill to learn to use vmap
is to be able to
decompose a computation into its repeatable element.
Let's take a look at a few examples to make this clear.
Example: Squaring every element in an array
This is the first example that we will walk through,
which involves applying a function over every element in a vector.
By default, vmap
takes in a function f
and returns a function f_prime
that maps f
over the leading axis of an array.
The axis along which the array is mapped is configurable,
and we'll see that in a moment.
For now, let's explore what the default behaviour of vmap
is.
In the example below, we start with a function square
that takes in scalars and returns scalars.
(Whether they are floats or integers doesn't really matter,
but floats are a generalization of integers,
so we'll work with floats.)
If we are being stringent about types,
we won't allow ourselves to pass an array into the square function,
even though NumPy technically allows us to do so.
vmap
gives us the following equivalent function:
def func(x):
    ...
    return result

def vmapped_func(array):
    result = []
    for element in array:
        result.append(func(element))
    result = np.stack(result)
    return result
Hence, we can apply a function across the leading (first) axis of an array. In the case of a vector, there is only one axis, so we simply apply the function to all elements of the array.
import jax.numpy as np
a = np.arange(20) # (20,)
def square(x: float) -> float:
return x ** 2
from jax import vmap
mapped_sq = vmap(square) # this is a function!
mapped_sq(a)
vmap returns a function
There is one very important thing to remember here!
vmap
takes in a function and returns another function.
(mapped_sq
in the example above.)
We still have to pass an array into the returned function,
otherwise we won't get a result.
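To make this concrete, here is a small illustrative cell (not part of the original example) showing that mapped_sq is just a function object until we call it on an array:
print(mapped_sq)     # a transformed function, not an array
print(mapped_sq(a))  # only now do we get the squared values back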
Example: Summing every row or column of a 2D matrix
In this next example, we will see how to apply a reduction function (e.g. summation) across every row or column in a matrix. This example will allow us to see how to configure the in_axes argument of vmap.
def sum_vector(x: np.ndarray) -> np.ndarray:
"""Assumes `x` is a vector"""
return np.sum(x)
a = np.arange(20).reshape((4, 5))
a
# Apply `sum_vector` across each element along the 0th axis.
vmap(sum_vector)(a) # shape: (4,)
# Apply `sum_vector` across each element along the 1st axis.
vmap(sum_vector, in_axes=1)(a) # shape: (5,)
Those of you who are experienced with NumPy are probably thinking,
"Couldn't we just specify the axis argument of np.sum, such as np.sum(a, axis=1)?"
Yes, but there's more:
- Using vmap nudges us to think about the elementary and repeatable computation that is used.
- We practice this skill by thinking about it on a trivial example.
Moreover, if we think carefully about the semantic meaning of our array data structures, we can avoid magic axis numbers showing up in our code. (An example of this is consistently keeping the time axis on the leading axis.) And as we all know, the fewer magic numbers there are inside code, the easier it is for us to read it.
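As a quick sanity check (a small sketch, not part of the original example), we can confirm that the vmap-ed reductions agree with specifying the axis on np.sum directly:
# Mapping over rows matches summing along axis 1;
# mapping over columns (in_axes=1) matches summing along axis 0.
assert np.allclose(vmap(sum_vector)(a), np.sum(a, axis=1))
assert np.allclose(vmap(sum_vector, in_axes=1)(a), np.sum(a, axis=0))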
Example: Softmax function
Here is another example involving the softmax function. (The softmax function is provided for you below.) This sort of operation is usually done when we want to take every row in a matrix, which might contain negative numbers, and convert each row into a probability vector, giving us a stack of probability vectors. (To learn more about the softmax function, check out the Wikipedia article.)
def softmax(x: np.ndarray) -> np.ndarray:
"""Vector-wise softmax transform."""
return np.exp(x) / np.sum(np.exp(x))
a = np.arange(20).reshape((4, 5))
a
vmap(softmax)(a)
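As a quick check (a small addition, not part of the original notebook), every transformed row should sum to one, since each row is now a probability vector:
# Each row of the vmapped softmax output should sum to 1.
assert np.allclose(np.sum(vmap(softmax)(a), axis=-1), 1.0)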
Example: Solving for angle of a triangle
When solving for an angle inside a right triangle, we need to know two of its side lengths.
Say we know the opposite and adjacent side lengths for a single triangle.
We can then solve for the corresponding angle by taking np.arctan(opp / adj).
def angle(opp: float, adj: float) -> float:
    return np.arctan(opp / adj)
angle(5, 3)
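As an aside (not part of the original example), np.arctan returns the angle in radians; if degrees are more intuitive, we can wrap the result in np.degrees:
# Convert from radians to degrees (roughly 59 degrees here).
np.degrees(angle(5, 3))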
Imagine we had two vectors, one of opps and the other of adjs. vmap can automatically transform the angle function into one that operates on vector pairs of opp and adj (assuming they both have the same length).
opps = np.arange(20)
adjs = np.linspace(3, 30, 20)
vmap(angle)(opps, adjs)
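As an extension (a small sketch beyond the original example), in_axes also accepts a tuple, with None marking an argument that should be held fixed rather than mapped. For instance, we could pair every opposite side with the same adjacent side:
# Map over `opps` only; broadcast a single adjacent side length to all of them.
vmap(angle, in_axes=(0, None))(opps, 3.0)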
Exercises
Let's do some exercises to flex your newly found vmap muscles!
Everything you need to know, you have already picked up above;
all that's left is to practice combining these ideas creatively.
Use the puzzles below, which are ordered in increasing complexity,
to challenge yourself!
Exercise: vmap-ed dot products
Message passing is a fundamental operation in the network analysis and graph neural network worlds. It is defined by taking a square adjacency-like matrix (also known as the diffusion matrix) of a graph and matrix multiplying it against a node feature matrix (also known as the message matrix). In NumPy code:
def mp(a: np.ndarray, f: np.ndarray) -> np.ndarray:
"""Message passing operator.
- `a`: An adjacency-like matrix of size (num_nodes, num_nodes).
- `f`: A message matrix of size (num_nodes, num_feats).
"""
return np.dot(a, f)
Suppose we have 13 graphs, each with 7 nodes, where each node has a message vector of length 11.
We'd like to perform a message passing operation on each of those graphs.
Your task is to implement this using vmap.
from jax import random
num_nodes = 7
num_graphs = 13
num_feats = 11
key = random.PRNGKey(90)
As = random.bernoulli(key, p=0.1, shape=(num_graphs, num_nodes, num_nodes))
Fs = random.normal(key, shape=(num_graphs, num_nodes, num_feats))
The naive implementation should look something like this:
def naive_mp(As, Fs):
res = []
for a, f in zip(As, Fs):
res.append(np.dot(a, f))
return np.stack(res)
# Your answer here.
def vmapped_message_passing(As, Fs):
"""Your answer here!"""
result = vmap(mp)(As, Fs)
return result
Verify that your answer is correct.
result = vmapped_message_passing(As, Fs)
assert result.shape == naive_mp(As, Fs).shape
assert not np.allclose(result, Fs)
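For a stronger check (a small addition), we can also compare the values against the naive loop, not just the shapes:
# The vmapped version should match the naive loop up to floating point.
assert np.allclose(result, naive_mp(As, Fs))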
Exercise: Chained vmaps
We're going to try our hand at constructing a slightly more complex program.
This program takes in a stack of datasets with three dimensions, (n_datasets, n_rows, n_columns).
Within each dataset, the program first calculates the cumulative product along each row,
then sums up each row's cumulative products (collapsing the columns),
and finally applies this same operation across all datasets stacked together.
This one is a bit more challenging!
To help you along here, the shape of the data are such:
- There are 11 stacks of data.
- Each stack of data has 31 rows, and 7 columns.
The result of this program should still have 11 stacks and 31 rows, but the columns are now collapsed: each entry is the sum of the cumulative products along the corresponding row.
To get this answer right,
no magic numbers are allowed (e.g. for accessing particular axes).
At least two vmaps are necessary here.
data = random.normal(key, shape=(11, 31, 7))
def ex2_numpy_equivalent(data):
result = []
for d in data:
cp = np.cumprod(d, axis=-1)
s = np.sum(cp, axis=1)
result.append(s)
return np.stack(result)
def loopless_loops_ex2(data):
"""Your solution here!"""
pass
# Comment out the import if you want to test your answer.
from dl_workshop.jax_idioms import loopless_loops_ex2
assert loopless_loops_ex2(data).shape == ex2_numpy_equivalent(data).shape
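If your solution (or the imported reference) is correct, the values should match as well, not just the shapes (a small additional check):
# Value check in addition to the shape check above.
assert np.allclose(loopless_loops_ex2(data), ex2_numpy_equivalent(data))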
Exercise: Double for-loops
This one is a favourite of mine,
and took me an afternoon of on-and-off thinking to reason about clearly.
Graphs, a.k.a. networks, are made up of nodes and edges,
and each node can be represented by a vector of information (i.e. node features).
Stacking all the nodes' vectors together
gives us a node feature matrix.
In graph attention networks, one step requires us to pairwise concatenate
every node's feature vector with every other node's.
For example, if every node had a length n_features
feature vector,
then concatenating two nodes' vectors together
should give us a length 2 * n_features
vector.
Doing this pairwise across all nodes in a graph
would give us an (n_nodes, n_nodes, 2 * n_features)
tensor.
Your challenge below is to write the vmapped version of the following:
num_nodes = 13
num_feats = 17
node_feats = random.normal(key, shape=(num_nodes, num_feats))
def ex3_numpy_equivalent(node_feats):
result = []
for node1 in node_feats:
node1_concats = []
for node2 in node_feats:
cc = np.concatenate([node1, node2])
node1_concats.append(cc)
result.append(np.stack(node1_concats))
return np.stack(result)
def loopless_loops_ex3(node_feats):
"""Your solution here!"""
pass
# Comment out the import if you want to test your answer.
from dl_workshop.jax_idioms import loopless_loops_ex3
assert (
loopless_loops_ex3(node_feats).shape
== ex3_numpy_equivalent(node_feats).shape
)
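Similarly, a value-level check in addition to the shape check (a small addition):
# The vmapped solution should reproduce the double for-loop's values.
assert np.allclose(loopless_loops_ex3(node_feats), ex3_numpy_equivalent(node_feats))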
Summary
To recap, the semantics of vmap
basically follow this logic:
Take an elementary computation and repeat it across the leading axis of an array.
The elementary computation shouldn't know anything about the leading axis.
You can then create the vmap-ed function that knows about the leading axis
by passing the function through vmap and getting back another function.
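To close the loop, here is the whole idea in a few lines (a recap sketch reusing the square function from earlier):
# `square` knows nothing about batches; `vmap` lifts it over the leading axis.
batched_square = vmap(square)
batched_square(np.arange(5))  # -> [0, 1, 4, 9, 16]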