

%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

Writing neural network models using stax

We're now going to rewrite the neural network model that we had earlier on using stax syntax, and train it using the tools that we have learned above.

Using stax.serial

Firstly, let's replicate the model using stax.serial. It's a serial composition of a Dense+Tanh layer, followed by a Dense+Sigmoid layer.

from jax.experimental import stax

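nn_init, nn_apply = stax.serial(
    # Layer widths are inferred from the (5, 41, 20) parameter shape shown
    # later in this notebook and from the single binary target column.
    stax.Dense(20),
    stax.Tanh,
    stax.Dense(1),
    stax.Sigmoid,
)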

def nn_init_wrapper(input_shape):
    def inner(key):
        return nn_init(key, input_shape)
    return inner

nn_initializer = nn_init_wrapper(input_shape=(-1, 41))
<function __main__.nn_init_wrapper.<locals>.inner(key)>

Now, we initialize one instance of the parameters.

from jax import random
key = random.PRNGKey(42)
output_shape, params_init = nn_initializer(key)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
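
To make the init/apply pair returned by stax.serial more concrete, here is a minimal sketch on a toy network. The names prefixed with toy_ and the layer sizes (3 inputs, 4 hidden units) are made up for illustration. The import fallback is there because stax lives under jax.example_libraries in newer JAX releases and under jax.experimental in older ones.

```python
try:
    from jax.example_libraries import stax  # newer JAX releases
except ImportError:
    from jax.experimental import stax  # older releases, as used in this notebook
import jax.numpy as jnp
from jax import random

# Hypothetical toy network: Dense+Tanh followed by Dense+Sigmoid.
toy_init, toy_apply = stax.serial(
    stax.Dense(4), stax.Tanh,
    stax.Dense(1), stax.Sigmoid,
)

# init takes a PRNG key and an input shape; -1 stands in for the batch axis.
toy_out_shape, toy_params = toy_init(random.PRNGKey(0), (-1, 3))

# apply takes the parameters and a batch of inputs.
toy_preds = toy_apply(toy_params, jnp.ones((5, 3)))

print(toy_out_shape)    # (-1, 1)
print(toy_preds.shape)  # (5, 1)
```

The parameters come back as one entry per layer: a (weights, bias) tuple for each Dense layer and an empty tuple for each parameter-free activation.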

We'll need a loss function to optimize as well.

from jax import grad, numpy as np, vmap
from functools import partial

def binary_cross_entropy(y_true, y_pred, tol=1e-6):
    return y_true * np.log(y_pred + tol) + (1 - y_true) * np.log(1 - y_pred + tol)

def logistic_loss(params, model, x, y):
    preds = vmap(partial(model, params))(x)
    bces = vmap(binary_cross_entropy)(y, preds)
    return -np.sum(bces)

dlogistic_loss = grad(logistic_loss)
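
As a quick sanity check on the sign convention, here is a small sketch that mirrors the loss above on a hand-made logistic regression. The toy_ names and the data are assumptions made up for this example. It checks that a confident correct prediction scores higher than a confident wrong one, and that one small gradient step lowers the loss.

```python
import jax.numpy as jnp
from jax import grad

def toy_bce(y_true, y_pred, tol=1e-6):
    # Per-sample log-likelihood: values closer to 0 are better.
    return y_true * jnp.log(y_pred + tol) + (1 - y_true) * jnp.log(1 - y_pred + tol)

def toy_logistic_loss(w, x, y):
    # Sigmoid of a linear model, then the negated sum of log-likelihoods.
    preds = 1 / (1 + jnp.exp(-(x @ w)))
    return -jnp.sum(toy_bce(y, preds))

toy_x = jnp.array([[1.0, 2.0], [-1.0, -1.0]])
toy_y = jnp.array([1.0, 0.0])
toy_w = jnp.zeros(2)

loss_before = toy_logistic_loss(toy_w, toy_x, toy_y)
# grad differentiates with respect to the first argument, the weights.
toy_grad = grad(toy_logistic_loss)(toy_w, toy_x, toy_y)
loss_after = toy_logistic_loss(toy_w - 0.1 * toy_grad, toy_x, toy_y)
print(loss_before, loss_after)
```

Because the loss is the negated sum of log-likelihoods, minimizing it is the same as maximizing the likelihood of the observed labels.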

Load in data

Now, we load in the data.

import pandas as pd
from pyprojroot import here

X = pd.read_csv(here() / 'data/biodeg_X.csv', index_col=0)
y = pd.read_csv(here() / 'data/biodeg_y.csv', index_col=0)

Test-drive functions to make sure they work

This is always important: it reveals whether there's anything wrong with our code.

logistic_loss(params_init, nn_apply, X.values, y.values)
DeviceArray(910.0804, dtype=float32)

Progressively construct our training functions

Firstly, we make sure the step function works with our logistic loss, model func, and actual data.

from jax.experimental.optimizers import adam

adam_init, update, get_params = adam(0.0005)
from dl_workshop.stax_models import step, make_scannable_step, make_training_start
from time import time

stepfunc_nn = partial(step, dlossfunc=dlogistic_loss, get_params=get_params, update=update, model=nn_apply, x=X.values, y_true=y.values)
scannable_step = make_scannable_step(stepfunc_nn)
train_nn = make_training_start(nn_initializer, adam_init, scannable_step, n_steps=3000)
start = time()
final_state, states_history = train_nn(key)
end = time()
print(end - start)

Friends, recall where we started in the tutorial.ipynb notebook: the original neural network took approximately a minute to train on a GPU (and longer on a CPU).
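
The helpers imported from dl_workshop.stax_models are not shown in this notebook. The following is only a sketch of how a step function, a scannable wrapper, and a training-start function could fit together around lax.scan; it is not the workshop's actual implementation. Every sketch_ name is hypothetical, and plain gradient descent on a tiny linear model stands in for the adam optimizer.

```python
from functools import partial

import jax.numpy as jnp
from jax import grad, lax, random

def sketch_loss(params, x, y):
    return jnp.sum((x @ params - y) ** 2)

def sketch_step(i, params, dlossfunc, x, y, lr):
    # One plain gradient-descent update (the iteration index is unused here).
    return params - lr * dlossfunc(params, x, y)

def sketch_make_scannable_step(stepfunc):
    # lax.scan expects f(carry, x) -> (new_carry, per_step_output).
    def inner(previous_params, i):
        new_params = stepfunc(i, previous_params)
        return new_params, previous_params
    return inner

def sketch_make_training_start(initializer, scannable_step, n_steps):
    def train_one_start(key):
        init_params = initializer(key)
        # Returns (final_params, history stacked along a leading axis).
        return lax.scan(scannable_step, init_params, jnp.arange(n_steps))
    return train_one_start

sketch_x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
sketch_true_w = jnp.array([2.0, -1.0])
sketch_y = sketch_x @ sketch_true_w

sketch_stepfunc = partial(
    sketch_step, dlossfunc=grad(sketch_loss), x=sketch_x, y=sketch_y, lr=0.1
)
sketch_train = sketch_make_training_start(
    lambda key: random.normal(key, (2,)),
    sketch_make_scannable_step(sketch_stepfunc),
    n_steps=200,
)
sketch_final, sketch_history = sketch_train(random.PRNGKey(0))
print(sketch_final)          # close to the true weights [2, -1]
print(sketch_history.shape)  # (200, 2)
```

Writing the loop as a scan rather than a Python for-loop lets JAX compile the whole training run as one program, which is one plausible reason for the speed-up over the original loop.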

Let's now plot the loss over training iterations. We start with a function that returns the loss from a given state object.

import matplotlib.pyplot as plt
def calculate_loss(state, get_params, model, lossfunc, x, y):
    params = get_params(state)
    return lossfunc(params, model, x, y)

calculate_loss(final_state, get_params, nn_apply, logistic_loss, X.values, y.values)
DeviceArray(101.23043, dtype=float32)

Now, we need to vmap it over all states in the states history, to get back the loss score.

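calc_loss_vmap = partial(
    calculate_loss,
    get_params=get_params,
    model=nn_apply,
    lossfunc=logistic_loss,
    x=X.values,
    y=y.values,
)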
start = time()
losses = vmap(calc_loss_vmap)(states_history)
end = time()
print(end - start)
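
The reason vmap works on a states history is that every array in the history carries an extra leading axis, one entry per training step. Here is a minimal sketch of that idea on a hand-built history of four (w, b) states; all toy_ names and values are made up for illustration.

```python
from functools import partial

import jax.numpy as jnp
from jax import vmap

def toy_state_loss(state, x, y):
    # Loss for ONE state, written with no knowledge of the history axis.
    w, b = state
    return jnp.sum((x @ w + b - y) ** 2)

toy_x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
toy_y = jnp.array([1.0, 2.0, 3.0])

# A stacked history of 4 states: every leaf has a leading axis of length 4.
toy_history = (
    jnp.array([[0.0, 0.0], [0.5, 1.0], [0.9, 1.8], [1.0, 2.0]]),  # w per step
    jnp.zeros(4),                                                 # b per step
)

# Fix the data with partial, then map over the leading axis of every leaf.
per_state_loss = partial(toy_state_loss, x=toy_x, y=toy_y)
toy_losses = vmap(per_state_loss)(toy_history)
print(toy_losses.shape)  # (4,)
```

vmap slices the leading axis off each leaf of the tuple and hands one (w, b) pair at a time to the per-state function.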


[Plot: loss over training iterations]

Training with multiple starting points

Just as above, we can also train the neural network with multiple starting points, again by vmap-ing our training function across split PRNGKeys.

keys = random.split(key, 5)

start = time()
final_states, state_histories = vmap(train_nn)(keys)
end = time()
print(end - start)

(5, 41, 20)
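
To see why the outputs pick up a leading axis of length 5, here is a small sketch of vmap-ing a training function over split keys. The toy_train function is hypothetical: it runs gradient descent on a simple quadratic from a random start, standing in for the real training function.

```python
import jax.numpy as jnp
from jax import lax, random, vmap

toy_target = jnp.array([2.0, -1.0])

def toy_train(key, n_steps=50):
    # Each key gives a different random starting point.
    w0 = random.normal(key, (2,))

    def body(w, _):
        # Gradient of sum((w - target)**2) is 2 * (w - target).
        new_w = w - 0.1 * 2.0 * (w - toy_target)
        return new_w, w

    return lax.scan(body, w0, jnp.arange(n_steps))

toy_keys = random.split(random.PRNGKey(0), 5)
toy_finals, toy_histories = vmap(toy_train)(toy_keys)
print(toy_finals.shape)     # (5, 2)
print(toy_histories.shape)  # (5, 50, 2)
```

The first axis indexes the run, the second axis of the histories indexes the training step, and the remaining axes are the parameter's own shape.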

Let's plot the losses over each of the state histories. Our last function, calc_loss_vmap, calculates the loss score for one time point, and we vmap it over a single states_history. We therefore need another function that encapsulates this behaviour and can itself be vmap-ed over all state histories.

def state_history_loss(state_history):
    losses = vmap(calc_loss_vmap)(state_history)
    return losses

losses = vmap(state_history_loss)(state_histories)
(5, 3000)
DeviceArray([[ 892.20435 ,  885.81555 ,  879.50555 , ...,  106.13542 ,
               106.08061 ,  106.025826],
             [ 747.18805 ,  740.2636  ,  733.4714  , ...,   87.10683 ,
                87.06524 ,   87.02368 ],
             [ 808.54504 ,  802.43036 ,  796.3972  , ...,   87.570694,
                87.52498 ,   87.47928 ],
             [1273.8547  , 1264.2145  , 1254.6216  , ...,  140.18053 ,
               140.12332 ,  140.06607 ],
             [ 720.6325  ,  716.0028  ,  711.449   , ...,   76.92966 ,
                76.8839  ,   76.838135]], dtype=float32)

Correctly-shaped! And now plotting it...

[Plot: loss over training iterations, one line for each of the five runs]

Now that's pretty cool! We were able to see the loss from five independent runs.

With sufficient memory, one would be able to do more runs; when I was writing this notebook early on, I saw that it was getting difficult to do on the order of tens of runs due to memory allocation issues.


In this notebook, we saw a few things in action.

Firstly, we saw how to use the stax module on a linear model. Anytime we have a new framework for doing differentiable programming, it's super important to be able to explore it in the context of a linear model, which is basically the foundation of all deep learning.

Secondly, we also explored how to leverage the JAX idioms to create fast parallelized training loops. We mixed-and-matched together jit, vmap, lax.scan, and grad into a performant training loop that was minimally nested.

A corollary of this programming style is that every piece of the code can, in principle, be properly tested, because each piece is properly isolated. Have you written training loops where you modify a little piece here and a little piece there, until you lose track of what your original working one looked like? With training functions that are minimally nested, we can easily control the behaviour explicitly using closures/partials. Even when doing experimentation, our code can run reliably and fast.

Thirdly, we saw how to apply the same lessons to training a neural network really fast with multiple starting points. The essence of the solution was to properly structure our program in progressively higher-level layers of abstraction. We carefully wrote the program to go from the innermost layer out until we hit our goal of allowing for a set of multiple starts. The key here is that each level of abstraction is very natural, and corresponds to a "unit computation" being applied consistently across an "array" of things. Once we identify that "unit computation", writing the vmap-able or lax.scan-able function becomes very easy.