%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
Writing neural network models using stax
We're now going to rewrite the neural network model that we built earlier, this time using stax syntax, and train it using the idioms that we learned above.
Using stax.serial
Firstly, let's replicate the model using stax.serial. It's a serial composition of a Dense+Tanh layer, followed by a Dense+Sigmoid layer.
from jax.experimental import stax
nn_init, nn_apply = stax.serial(
    stax.Dense(20),
    stax.Tanh,
    stax.Dense(1),
    stax.Sigmoid
)
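It helps to keep the signatures of the two returned functions in mind; these follow the stax convention and match how we use them below (ignoring optional keyword arguments):
# nn_init:  (rng_key, input_shape) -> (output_shape, initial_params)
# nn_apply: (params, inputs)       -> outputs
Because nn_init also needs the input shape, we wrap it so that the resulting initializer becomes a function of the PRNGKey alone.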
def nn_init_wrapper(input_shape):
    def inner(key):
        return nn_init(key, input_shape)
    return inner
nn_initializer = nn_init_wrapper(input_shape=(-1, 41))
nn_initializer
Now, we initialize one instance of the parameters.
from jax import random
key = random.PRNGKey(42)
output_shape, params_init = nn_initializer(key)
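If you want to peek at what came back, one optional check (not part of the original workflow) is to map over the parameter pytree and pull out the array shapes:
from jax import tree_util

# Each Dense layer contributes a (weights, biases) pair; Tanh and Sigmoid hold no parameters.
tree_util.tree_map(lambda p: p.shape, params_init)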
We'll need a loss function to optimize as well.
from jax import grad, numpy as np, vmap
from functools import partial
def binary_cross_entropy(y_true, y_pred, tol=1e-6):
    return y_true * np.log(y_pred + tol) + (1 - y_true) * np.log(1 - y_pred + tol)

def logistic_loss(params, model, x, y):
    preds = vmap(partial(model, params))(x)
    bces = vmap(binary_cross_entropy)(y, preds)
    return -np.sum(bces)
dlogistic_loss = grad(logistic_loss)
Load in data
Now, we load in the data.
import pandas as pd
from pyprojroot import here
X = pd.read_csv(here() / 'data/biodeg_X.csv', index_col=0)
y = pd.read_csv(here() / 'data/biodeg_y.csv', index_col=0)
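While we're here, it's worth confirming that the data line up with the (-1, 41) input shape we declared earlier. Assuming the biodegradability CSVs have 41 feature columns and a single label column, a quick check would be:
# Expect (n_samples, 41) for X and (n_samples, 1) for y.
X.shape, y.shape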
Test-drive functions to make sure they work
Always important. It'll reveal whether there's anything wrong with our code.
logistic_loss(params_init, nn_apply, X.values, y.values)
Progressively construct our training functions
Firstly, we make sure the step function works with our logistic loss, model function, and actual data.
from jax.experimental.optimizers import adam
adam_init, update, get_params = adam(0.0005)
from dl_workshop.stax_models import step, make_scannable_step, make_training_start
from time import time
stepfunc_nn = partial(
    step,
    dlossfunc=dlogistic_loss,
    get_params=get_params,
    update=update,
    model=nn_apply,
    x=X.values,
    y_true=y.values,
)
scannable_step = make_scannable_step(stepfunc_nn)
train_nn = make_training_start(nn_initializer, adam_init, scannable_step, n_steps=3000)
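The step, make_scannable_step, and make_training_start helpers live in the dl_workshop package, so their source isn't shown here. As a rough sketch of the pattern (an assumption about their structure, not the package's actual code), they might look something like this:
from jax import lax

# Sketch only: the real implementations live in dl_workshop.stax_models.

def step(i, state, dlossfunc, get_params, update, model, x, y_true):
    # One optimizer update: differentiate the loss w.r.t. the params and apply the adam update.
    params = get_params(state)
    grads = dlossfunc(params, model, x, y_true)
    return update(i, grads, state)

def make_scannable_step(stepfunc):
    # Wrap a step function into the (carry, x) -> (carry, y) form that lax.scan expects.
    def inner(state, i):
        new_state = stepfunc(i, state)
        return new_state, new_state
    return inner

def make_training_start(initializer, optimizer_init, scannable_step, n_steps):
    # Initialize params from a PRNGKey, then scan the step function n_steps times.
    def train(key):
        _, params = initializer(key)
        state = optimizer_init(params)
        final_state, states_history = lax.scan(scannable_step, state, np.arange(n_steps))
        return final_state, states_history
    return train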
start = time()
final_state, states_history = train_nn(key)
end = time()
print(end - start)
Friends, if you remember where we started in the tutorial.ipynb notebook, the original neural network took approximately a minute to train on a GPU (and longer on a CPU).
Let's now start by plotting the loss over training iterations. We start with a function that returns the loss from a given state object.
import matplotlib.pyplot as plt
def calculate_loss(state, get_params, model, lossfunc, x, y):
    params = get_params(state)
    return lossfunc(params, model, x, y)
calculate_loss(final_state, get_params, nn_apply, logistic_loss, X.values, y.values)
Now, we need to vmap it over all states in the states history to get back the loss scores.
calc_loss_vmap = partial(
    calculate_loss,
    get_params=get_params,
    model=nn_apply,
    lossfunc=logistic_loss,
    x=X.values,
    y=y.values
)
start = time()
losses = vmap(calc_loss_vmap)(states_history)
end = time()
print(end - start)
plt.plot(losses)
Training with multiple starting points
Just as above, we can also train the neural network with multiple starting points, again by vmap-ing our training function across split PRNGKeys.
keys = random.split(key, 5)
start = time()
final_states, state_histories = vmap(train_nn)(keys)
end = time()
print(end - start)
# The leading axis should have size 5 (one slice per split key), e.g. (5, 41, 20) for the first Dense layer's weights.
get_params(final_states)[0][0].shape
Let's plot the losses over each of the state histories. Our last function, calc_loss_vmap, calculates the loss score for one time point, and we vmap it over a single states_history; so we need another function that encapsulates this behaviour and can itself be vmap-ed over all state histories.
def state_history_loss(state_history):
    losses = vmap(calc_loss_vmap)(state_history)
    return losses
losses = vmap(state_history_loss)(state_histories)
losses.shape
losses
Correctly-shaped! And now plotting it...
plt.plot(losses.T)
Now that's pretty cool! We were able to see the loss from five independent runs.
With sufficient memory, one could do even more runs; when I was writing this notebook, I found that getting up to tens of runs became difficult due to memory allocation issues.
Summary
In this notebook, we saw a few things in action.
Firstly, we saw how to use the stax module on a linear model. Anytime we have a new framework for doing differential programming, it's super important to be able to explore it in the context of a linear model, which is basically the foundation of all deep learning.
Secondly, we also explored how to leverage the JAX idioms to create fast, parallelized training loops. We mixed and matched jit, vmap, lax.scan, and grad into a performant training loop that was minimally nested.
A corollary of this programming style is that every piece of the code can, in principle, be properly tested, because each piece is properly isolated. Have you ever written training loops where you modify a little piece here and a little piece there, until you've lost track of what your original working version looked like? With training functions that are minimally nested, we can control the behaviour explicitly and easily using closures/partials. Even when experimenting, our code runs reliably and fast.
Thirdly, we saw how to apply the same lessons to training a neural network really fast with multiple starting points. The essence of the solution was to structure our program in progressively higher-level layers of abstraction. We carefully wrote the program from the innermost layer outward until we hit our goal of allowing for multiple starts. The key here is that each level of abstraction is very natural, and corresponds to a "unit computation" being applied consistently across an "array" of things. Once we identify that unit computation, writing the vmap-able or lax.scan-able function becomes very easy.
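To make that pattern concrete in a toy setting (deliberately unrelated to the model above), the same "unit computation" idea looks like this:
from jax import lax, vmap, numpy as np

def unit(x):
    # The unit computation: acts on one element.
    return x ** 2

# Apply it consistently across an "array" of things with vmap.
vmap(unit)(np.arange(5.0))

def scan_step(carry, x):
    # The unit computation for lax.scan: one accumulation step.
    new_carry = carry + x
    return new_carry, new_carry

# Apply it consistently across a sequence of steps with lax.scan.
final, history = lax.scan(scan_step, 0.0, np.arange(5.0))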