

%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

Writing linear models with stax

In this notebook, I'll show how to use JAX's stax submodule to write arbitrary models.


I'm assuming you have read through the jax-programming.ipynb notebook, as well as the tutorial.ipynb notebook.

The main tutorial.ipynb notebook gives you a general introduction to differentiable programming using grad, while the jax-programming.ipynb notebook gives you a flavour of the other four main JAX idioms: vmap, lax.scan, random.PRNGKey, and jit.

What is stax?

Most deep learning libraries use objects as the data structure for a neural network layer. As such, the tunable parameters of the layer (for example, w and b for a linear ("dense") layer) are class attributes stored alongside the forward function.

Because a neural network layer is nothing more than a math function, it makes just as much sense to specify the layer as a function. stax, then, is a new take on writing neural network models using pure functions rather than objects.

How does stax work?

The way that stax layers work is as follows. Every neural network layer is nothing more than a math function with a "forward" pass. Neural network models typically have their parameters initialized into the right shapes using random number generators. Put these two together, and we have a pair of functions that specify a layer:

  • An init_fun function, that initializes parameters into the correct shapes, and
  • An apply_fun function, that applies the specified math transformations onto incoming data, using parameters of the correct shape.
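To make the pair concrete, here is a minimal sketch of a layer with no tunable parameters, written in the same style. The constructor name relu_layer is hypothetical (stax ships its own Relu); the point is only the (init_fun, apply_fun) convention: init_fun returns (output_shape, params), and apply_fun is a pure function of (params, inputs).

```python
import jax.numpy as jnp
from jax import random


def relu_layer():
    """Hypothetical stax-style layer constructor returning (init_fun, apply_fun)."""

    def init_fun(rng, input_shape):
        # ReLU has no tunable parameters: params is an empty tuple,
        # and the output shape equals the input shape.
        return input_shape, ()

    def apply_fun(params, inputs, **kwargs):
        # The forward pass depends only on its arguments (a pure function).
        return jnp.maximum(inputs, 0.0)

    return init_fun, apply_fun


relu_init, relu_apply = relu_layer()
relu_shape, relu_params = relu_init(random.PRNGKey(0), (3,))
relu_out = relu_apply(relu_params, jnp.array([-1.0, 0.5, 2.0]))
```

Because the layer carries no hidden state, the same pair of functions can be reused anywhere, with the parameters passed in explicitly each time.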

Example: Linear layer

Let's see an example of this in action by studying the implementation of the linear ("dense") layer in stax:

from jax.experimental import stax  # in newer JAX versions: from jax.example_libraries import stax
stax.Dense??

As you can see, the apply_fun specifies the linear transformation. It accepts a parameter called params, which gets tuple-unpacked into the appropriate W and b.

Notice how the params argument matches up with the second output of init_fun! The init_fun always accepts an rng parameter, which is a key created by jax.random.PRNGKey(). It also accepts an input_shape parameter, which specifies the shape of a single sample of data. So if your entire dataset has shape (n_samples, n_columns), you would pass in (n_columns,). We leave out the sample dimension on purpose, so that we can later use vmap to map our model function over each i.i.d. sample in the dataset. The init_fun also returns the output_shape, which is used later when we chain layers together.
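The same pattern can be written by hand for a dense layer. The sketch below is a simplified re-implementation for illustration, not stax's source: the constructor name my_dense is hypothetical, and it initializes with a plain scaled normal, whereas stax's own Dense uses dedicated initializers (Glorot-normal weights by default).

```python
import jax.numpy as jnp
from jax import random


def my_dense(out_dim):
    """Hypothetical Dense-style layer in the stax (init_fun, apply_fun) pattern."""

    def init_fun(rng, input_shape):
        # input_shape describes ONE sample, e.g. (n_columns,)
        output_shape = input_shape[:-1] + (out_dim,)
        k1, k2 = random.split(rng)
        # Simplified initialization (assumption): scaled standard normal draws.
        W = random.normal(k1, (input_shape[-1], out_dim)) * 0.1
        b = random.normal(k2, (out_dim,)) * 0.1
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        W, b = params  # the same (W, b) tuple that init_fun returned
        return jnp.dot(inputs, W) + b

    return init_fun, apply_fun


dense_init, dense_apply = my_dense(1)
dense_shape, (dense_W, dense_b) = dense_init(random.PRNGKey(0), (4,))
print(dense_shape, dense_W.shape, dense_b.shape)  # (1,) (4, 1) (1,)
```

The second output of dense_init is exactly what dense_apply expects as its first argument, which is the contract described above.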

Let's see how we can use the Dense layer to specify a linear regression model.

Create the initialization and application function pairs

Firstly, we create the init_fun and apply_fun pair:

init_fun, apply_fun = stax.Dense(1)

Initialize the parameters

Now, let's initialize parameters using the init_fun.

Let's assume that our data has only 4 columns.

from jax import random, numpy as np

key = random.PRNGKey(42)

output_shape, params_initial = init_fun(key, input_shape=(4,))
params_initial  # the (W, b) tuple

(DeviceArray([[ 0.43186307],
              [ 0.9563157 ],
              [ 0.9000483 ],
              [-0.32341173]], dtype=float32),
 DeviceArray([0.01369469], dtype=float32))

Apply parameters and data through function

We'll create some randomly generated data.

X = random.normal(key, shape=(200, 4))
X[0:5], X.shape
(DeviceArray([[ 0.03751167,  0.8777058 , -1.2008178 , -1.3824965 ],
              [-0.37519178,  1.5957882 , -1.5086783 , -0.75612265],
              [-0.51650995, -1.056697  ,  0.99382603, -0.3520223 ],
              [-1.4446373 , -0.51537496,  0.30910558, -0.31770143],
              [ 0.9590066 , -0.6170032 , -0.37160665, -0.7339001 ]],            dtype=float32),
 (200, 4))

Here are some y_true values that I've snuck in.

y_true = np.dot(X, np.array([1, 2, 3, 4])) + 5  # true weights (1, 2, 3, 4), true bias 5
y_true = y_true.reshape(-1, 1)
y_true[0:5], y_true.shape
(DeviceArray([[-2.3395162 ],
              [ 0.26585913],
              [ 3.9434848 ],
              [ 2.1811237 ],
              [ 0.6745796 ]], dtype=float32),
 (200, 1))

Now, we'll pass data through the linear model!

from jax import vmap
from functools import partial

y_pred = vmap(partial(apply_fun, params_initial))(X)
y_pred[0:5], y_pred.shape
(DeviceArray([[ 0.23557997],
              [ 0.26439613],
              [-0.2593076 ]], dtype=float32),
 (200, 1))

Voilà! We have a simple linear model implemented just like that.
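One detail worth knowing about the forward pass above: for a dense layer, vmap is not strictly required, because jnp.dot already handles a leading batch axis. The sketch below uses a hypothetical hand-written lin_apply (same math as a dense layer) to check that mapping over samples and calling on the whole batch agree. vmap becomes essential for functions written strictly for one sample.

```python
from functools import partial

import jax.numpy as jnp
from jax import random, vmap


def lin_apply(params, x):
    # Dense-layer math for illustration: x @ W + b
    W, b = params
    return jnp.dot(x, W) + b


k_w, k_x = random.split(random.PRNGKey(0))
lin_params = (random.normal(k_w, (4, 1)), jnp.zeros((1,)))
data = random.normal(k_x, (200, 4))

per_sample = vmap(partial(lin_apply, lin_params))(data)  # maps over axis 0 of data
batched = lin_apply(lin_params, data)                     # jnp.dot handles the batch axis
print(per_sample.shape, batched.shape)  # (200, 1) (200, 1)
```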


Next question: how do we optimize the parameters using JAX?

Instead of writing the optimizer's update rules on our own, we can take advantage of JAX's optimizers, which are also written in a functional paradigm!

JAX's optimizers are constructed as a "triplet" set of functions:

  • init: Takes the initial params and packs them into an optimizer state, structured in a fashion that update can operate on.
  • update: Takes in i, g, and state, which respectively are:
    • i: The current loop iteration
    • g: Gradients calculated from grad!
    • state: The current optimizer state (the parameters plus any bookkeeping the optimizer needs, such as Adam's moment estimates).
  • get_params: Takes in the state at a given point, and returns the parameters structured correctly.

from jax import jit, grad
from jax.experimental.optimizers import adam  # in newer JAX versions: from jax.example_libraries.optimizers import adam

init, update, get_params = adam(step_size=1e-1)
update = jit(update)
get_params = jit(get_params)
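To make the triplet idea concrete without depending on a particular JAX version, here is a minimal sketch of plain gradient descent written in the same (init, update, get_params) style. The names sgd_triplet and toy_loss are hypothetical, and this is vanilla SGD, not Adam: its state is simply the parameter pytree itself.

```python
import jax.numpy as jnp
from jax import grad, tree_util


def sgd_triplet(step_size):
    """Hypothetical plain-SGD optimizer in the (init, update, get_params) style."""

    def init(params):
        return params  # for vanilla SGD the state is just the params

    def update(i, g, state):
        # i is unused by plain SGD; schedules and Adam's bias correction use it.
        return tree_util.tree_map(lambda p, gp: p - step_size * gp, state, g)

    def get_params(state):
        return state

    return init, update, get_params


sgd_init, sgd_update, sgd_get_params = sgd_triplet(step_size=0.1)


def toy_loss(params):
    w, b = params
    return (w - 3.0) ** 2 + (b + 1.0) ** 2  # minimum at w=3, b=-1


opt_state = sgd_init((jnp.array(0.0), jnp.array(0.0)))
for i in range(100):
    g = grad(toy_loss)(sgd_get_params(opt_state))
    opt_state = sgd_update(i, g, opt_state)
```

The loop body has the same shape as the explicit training loop later in this notebook, which is the payoff of the functional design: swapping optimizers only changes which triplet you unpack.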

Loss Function

We're still missing one piece: the loss function. For illustration purposes, let's use the mean squared error.

def mseloss(params, model, x, y_true):
    y_preds = vmap(partial(model, params))(x)
    return np.mean(np.power(y_preds - y_true, 2))

dmseloss = grad(mseloss)
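A useful property of grad here: it differentiates with respect to the first argument only, and returns a gradient with exactly the same (W, b) structure as params, which is what the optimizer's update expects. The sketch below checks that against the closed-form MSE gradients; lin_model, mse and the toy data are hypothetical stand-ins for apply_fun, mseloss, X and y_true.

```python
from functools import partial

import jax.numpy as jnp
from jax import grad, random, vmap


def lin_model(params, x):
    W, b = params
    return jnp.dot(x, W) + b


def mse(params, model, x, y_true):
    y_preds = vmap(partial(model, params))(x)
    return jnp.mean((y_preds - y_true) ** 2)


k1, k2 = random.split(random.PRNGKey(1))
xs = random.normal(k1, (50, 4))
ys = jnp.dot(xs, jnp.array([1.0, 2.0, 3.0, 4.0]))[:, None] + 5.0
p0 = (random.normal(k2, (4, 1)), jnp.zeros((1,)))

# Only params (argument 0) is differentiated; model, x, y_true pass through.
gW, gb = grad(mse)(p0, lin_model, xs, ys)
print(gW.shape, gb.shape)  # (4, 1) (1,)
```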

"Step" portion of update loop

Now, we're going to define the "step" portion of the update loop.

from dl_workshop.stax_models import step
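The source of step isn't shown here. Based on how it is called (step(i, state) after binding get_params, dlossfunc, update, model, x and y_true) and on the explicit loop later in this notebook, a plausible sketch follows. The name step_sketch and the toy pieces (toy_model, toy_mse, toy_sgd_update) are hypothetical; the real dl_workshop function may differ.

```python
from functools import partial

import jax.numpy as jnp
from jax import grad, jit, vmap


def step_sketch(i, state, get_params, dlossfunc, update, model, x, y_true):
    """Hypothetical version of step: one pure optimizer step."""
    params = get_params(state)               # unpack params from optimizer state
    g = dlossfunc(params, model, x, y_true)  # gradients w.r.t. params
    return update(i, g, state)               # return the new optimizer state


# Toy usage with a hypothetical plain-SGD update in place of adam's triplet.
def toy_model(params, x):
    W, b = params
    return jnp.dot(x, W) + b


def toy_mse(params, model, x, y_true):
    return jnp.mean((vmap(partial(model, params))(x) - y_true) ** 2)


def toy_sgd_update(i, g, state):
    return tuple(p - 0.1 * gp for p, gp in zip(state, g))


xs = jnp.arange(8.0).reshape(4, 2) / 8.0
ys = xs @ jnp.array([[1.0], [-2.0]]) + 0.5
one_step = jit(partial(step_sketch, get_params=lambda s: s, dlossfunc=grad(toy_mse),
                       update=toy_sgd_update, model=toy_model, x=xs, y_true=ys))
state0 = (jnp.zeros((2, 1)), jnp.zeros((1,)))
state1 = one_step(0, state0)
```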

JIT compilation

Because step takes so many arguments (in order to remain pure, and not rely on notebook state), we're going to bind most of them using functools.partial, leaving only i and state free.

I'm also going to show you what happens when we JIT-compile vs. don't JIT-compile the function.

step_partial = partial(step, get_params=get_params, dlossfunc=dmseloss, update=update, model=apply_fun, x=X, y_true=y_true)
step_partial_jit = jit(step_partial)

Explicit loops

Firstly, let's see what kind of code we'd write if we wrote the loop explicitly.

from time import time
start = time()
state = init(params_initial)
for i in range(1000):
    params = get_params(state)
    g = dmseloss(params, apply_fun, X, y_true)
    state = update(i, g, state)
end = time()
print(end - start)

Partialled out loop step

Now, let's run the loop with the partialled out function.

start = time()
state = init(params_initial)
for i in range(1000):
    state = step_partial(i, state)
end = time()
print(end - start)

JIT-compiled loop!

This is a much cleaner loop, but we did have to do some work up front.

What happens if we now use the JIT-ed function?

start = time()
state = init(params_initial)
for i in range(1000):
    state = step_partial_jit(i, state)
end = time()
print(end - start)

Whoa, holy smokes, that's fast! At least 10X faster using JIT-compilation.
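A caveat on wall-clock timings like these: JAX compiles a jitted function on its first call and dispatches work asynchronously, so a naive timer can include compile time or stop before the device has finished. A sketch of a fairer pattern, using a hypothetical heavy function: warm up once, then call block_until_ready() before reading the clock.

```python
import time

import jax.numpy as jnp
from jax import jit


@jit
def heavy(x):
    # arbitrary array work to time
    return jnp.sin(x) @ jnp.cos(x).T


x = jnp.ones((200, 200))
heavy(x).block_until_ready()  # warm-up: the first call traces and compiles

t0 = time.perf_counter()
for _ in range(10):
    out = heavy(x)
out.block_until_ready()       # wait for the asynchronous work to finish
elapsed = time.perf_counter() - t0
print(f"{elapsed / 10:.6f} s per call")
```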

lax.scan loop

Now we'll use some JAX trickery to write a training loop without ever writing a Python for-loop.

from dl_workshop.stax_models import make_scannable_step
from jax import lax

scannable_step = make_scannable_step(step_partial_jit)

start = time()
initial_state = init(params_initial)
final_state, states_history = lax.scan(scannable_step, initial_state, np.arange(1000))
end = time()
print(end - start)
get_params(final_state)  # display the fitted (W, b)

(DeviceArray([[1.       ],
              [3.       ],
              [3.9999995]], dtype=float32),
 DeviceArray([5.0000005], dtype=float32))
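lax.scan expects a function of the form f(carry, x) -> (new_carry, y), while our step has the signature step(i, state). A plausible sketch of what make_scannable_step does is to swap the argument order and emit each new state as the per-iteration output, which is what lax.scan stacks into states_history. The name make_scannable_step_sketch and the toy halve step are hypothetical.

```python
import jax.numpy as jnp
from jax import lax


def make_scannable_step_sketch(step_fn):
    """Hypothetical adapter from step_fn(i, state) to lax.scan's f(carry, x)."""

    def scannable_step(state, i):
        new_state = step_fn(i, state)
        return new_state, new_state  # (carry forward, record in history)

    return scannable_step


def halve(i, s):
    # toy "optimizer step": halve the state each iteration (i unused)
    return s / 2.0


final, history = lax.scan(make_scannable_step_sketch(halve), jnp.array(8.0), jnp.arange(3))
```

Here final is the last carry and history stacks every intermediate state along a new leading axis.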

vmap-ed training loop over multiple starting points

Now, we're going to do the ultimate: we'll create 100 different parameter initializations and run our training loop over each of them.

from dl_workshop.stax_models import make_training_start
from jax import lax

N_INITIALIZATIONS = 100  # number of independent random starting points
train_linear = make_training_start(partial(init_fun, input_shape=(-1, 4)), init, scannable_step, 1000)
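From the way it is called, make_training_start appears to build a function that maps a PRNG key to a full training run. A plausible sketch: draw parameters from the key, wrap them in optimizer state, then lax.scan over the steps. Because the result is a pure function of the key, vmap over a batch of keys runs many training starts at once. The name make_training_start_sketch and the toy pieces are hypothetical.

```python
import jax.numpy as jnp
from jax import lax, random, vmap


def make_training_start_sketch(init_fun, opt_init, scannable_step, n_steps):
    """Hypothetical stand-in: key -> (final_state, states_history)."""

    def train(key):
        _, params = init_fun(key)   # fresh random parameters from this key
        state = opt_init(params)    # wrap them in optimizer state
        return lax.scan(scannable_step, state, jnp.arange(n_steps))

    return train


def toy_init_fun(key):
    return (), random.normal(key, ())  # (output_shape, params)


def toy_step(state, i):
    new_state = state + 0.1 * (2.0 - state)  # move 10% of the way toward 2.0
    return new_state, new_state


toy_train = make_training_start_sketch(toy_init_fun, lambda p: p, toy_step, 200)
toy_keys = random.split(random.PRNGKey(0), 5)
finals, histories = vmap(toy_train)(toy_keys)  # 5 independent runs in one call
```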

start = time()
initialization_keys = random.split(key, N_INITIALIZATIONS)
final_states, states_histories = vmap(train_linear)(initialization_keys)
end = time()
print(end - start)

w_final, b_final = vmap(get_params)(final_states)
w_final.squeeze()[0:5]  # first 5 fitted weight vectors
DeviceArray([[1.0000002, 2.       , 2.9999998, 4.0000005],
             [1.0000001, 2.       , 3.       , 4.0000005],
             [1.0000002, 2.       , 3.       , 4.000001 ],
             [1.0000001, 2.       , 3.       , 3.9999998],
             [1.0000001, 2.       , 3.0000002, 4.       ]], dtype=float32)

b_final.squeeze()[0:5]  # first 5 fitted biases
DeviceArray([5.000001, 5.000001, 5.000001, 5.000001, 5.000001], dtype=float32)

Looks like we were also able to run the whole optimization pretty fast, and recover the correct parameters over multiple training starts.

JIT-compiled training loop

What happens if we JIT-compile the training function before vmapping it over the initializations?

start = time()
initialization_keys = random.split(key, N_INITIALIZATIONS)
train_linear_jit = jit(train_linear)
final_states, states_histories = vmap(train_linear_jit)(initialization_keys)
w_final, b_final = vmap(get_params)(final_states)
b_final.block_until_ready()  # JAX dispatches asynchronously; wait for the results before stopping the timer
end = time()
print(end - start)

HOOOOOLY SMOKES! Did you see that? With JIT-compilation, we essentially brought the training time down to roughly what it takes to train from a single starting point. Naturally, I don't expect this result to hold 100% of the time, but it's pretty darn rad to see that live.

The craziest piece here is that we could vmap our training loop over multiple starting points and get massive speedups there.