%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
Writing linear models with stax
In this notebook, I'll show you how to use JAX's stax submodule to write arbitrary models.
Prerequisites
I'm assuming you have read through the jax-programming.ipynb notebook, as well as the tutorial.ipynb notebook.
The main tutorial.ipynb notebook gives you a general introduction to differential programming using grad,
while the jax-programming.ipynb notebook gives you a flavour of the other four main JAX idioms: vmap, lax.scan, random.PRNGKey, and jit.
What is stax?
Most deep learning libraries use objects as the data structure for a neural network layer.
As such, the tunable parameters of the layer, for example w and b for a linear ("dense") layer,
are class attributes associated with the forward function.
In some sense, because a neural network layer is nothing more than a math function,
specifying the layer in terms of a function might also make sense.
stax, then, is a new take on writing neural network models using pure functions rather than objects.
How does stax work?
The way that stax layers work is as follows.
Every neural network layer is nothing more than a math function with a "forward" pass.
Neural network models typically have their parameters
initialized into the right shapes using random number generators.
Put these two together, and we have a pair of functions that specify a layer:
- An init_fun function that initializes parameters into the correct shapes, and
- An apply_fun function that applies the specified math transformations onto incoming data, using parameters of the correct shape.
Example: Linear layer
Let's see an example of this in action by studying the implementation of the linear ("dense") layer in stax:
from jax.experimental import stax
stax.Dense??
As you can see, the apply_fun specifies the linear transformation.
It accepts a parameter called params,
which gets tuple-unpacked into the appropriate W and b.
Notice how the params argument matches up with the second output of init_fun!
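If the ?? output above is hard to read, here is a paraphrase of what a dense apply function boils down to. This is my own sketch rather than the verbatim stax source, but it captures the tuple-unpacking and the affine transformation:
from jax import numpy as np

def dense_apply(params, inputs, **kwargs):
    # Sketch of a dense apply function (not the verbatim stax source):
    # unpack the parameters, then apply the affine transformation.
    W, b = params
    return np.dot(inputs, W) + b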
The init_fun always accepts an rng key, which is generated by JAX's jax.random.PRNGKey().
It also accepts an input_shape parameter,
which specifies the elementary shape of one sample of data.
So if your entire dataset is of shape (n_samples, n_columns),
then you would pass in (n_columns,),
ignoring the sample dimension;
this lets us take advantage of vmap to map our model function
over each and every i.i.d. sample in our dataset.
The init_fun also returns the output_shape,
which is used later when we chain layers together.
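For symmetry, here is a from-scratch sketch of what a dense init function has to do. The initializer used here (plain standard normals) is an assumption for illustration; the real stax.Dense uses its own initialization scheme.
from jax import random

def dense_init(rng, input_shape, out_dim=1):
    # Sketch (not the stax source): draw W and b in the right shapes,
    # and report the output shape so that layers can be chained later.
    w_key, b_key = random.split(rng)
    W = random.normal(w_key, (input_shape[-1], out_dim))
    b = random.normal(b_key, (out_dim,))
    output_shape = input_shape[:-1] + (out_dim,)
    return output_shape, (W, b)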
Let's see how we can use the Dense layer to specify a linear regression model.
Create the initialization and application function pairs
Firstly, we create the init_fun and apply_fun pair:
init_fun, apply_fun = stax.Dense(1)
Initialize the parameters
Now, let's initialize parameters using the init_fun.
Let's assume that we have data with only 4 columns.
from jax import random, numpy as np
key = random.PRNGKey(42)
output_shape, params_initial = init_fun(key, input_shape=(4,))
params_initial
Pass parameters and data through the function
We'll create some randomly generated data.
X = random.normal(key, shape=(200, 4))
X[0:5], X.shape
Here are some y_true values that I've snuck in.
y_true = np.dot(X, np.array([1, 2, 3, 4])) + 5
y_true = y_true.reshape(-1, 1)
y_true[0:5], y_true.shape
Now, we'll pass data through the linear model!
apply_fun??
from jax import vmap
from functools import partial
y_pred = vmap(partial(apply_fun, params_initial))(X)
y_pred[0:5], y_pred.shape
VoilĂ ! We have a simple linear model implemented just like that.
Optimization
Next question: how do we optimize the parameters using JAX?
Instead of writing a training loop on our own, we can take advantage of JAX's optimizers, which are also written in a functional paradigm!
JAX's optimizers are constructed as a "triplet" set of functions (a short usage sketch follows this list):
- init: Takes params and initializes them into a state, which is structured in a fashion that update can operate on.
- update: Takes in i, g, and state, which respectively are:
  - i: The current loop iteration
  - g: Gradients calculated from grad!
  - state: The current state of the parameters.
- get_params: Takes in the state at a given point, and returns the parameters structured correctly.
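Here is the minimal one-step sketch promised above. The toy quadratic loss and the use of the sgd optimizer are my own illustrative choices; the notebook itself uses adam below.
from jax import grad, numpy as np
from jax.experimental.optimizers import sgd

# Illustrative one-step usage of the (init, update, get_params) triplet,
# with a toy quadratic loss.
toy_loss = lambda params: np.sum(params ** 2)
opt_init, opt_update, opt_get_params = sgd(step_size=0.1)
opt_state = opt_init(np.array([1.0, -2.0]))         # init: params -> state
grads = grad(toy_loss)(opt_get_params(opt_state))   # get_params: state -> params
opt_state = opt_update(0, grads, opt_state)         # update: (i, g, state) -> state
opt_get_params(opt_state)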
from jax import jit, grad
from jax.experimental.optimizers import adam
init, update, get_params = adam(step_size=1e-1)
update = jit(update)
get_params = jit(get_params)
Loss Function
We're still missing a piece here: the loss function. For illustration purposes, let's use the mean squared error.
def mseloss(params, model, x, y_true):
    # Map the model over every sample, then average the squared errors.
    y_preds = vmap(partial(model, params))(x)
    return np.mean(np.power(y_preds - y_true, 2))
dmseloss = grad(mseloss)
"Step" portion of update loop
Now, we're going to define the "step" portion of the update loop.
from dl_workshop.stax_models import step
step??
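Since the ?? output isn't reproduced in this text, here is a sketch of what a step function with this calling convention typically looks like. The actual dl_workshop.stax_models.step may differ in its details, but this is consistent with how the function is called below.
def step(i, state, get_params, dlossfunc, update, model, x, y_true):
    # Sketch only: pull params out of the optimizer state, compute gradients
    # of the loss, and apply one optimizer update.
    params = get_params(state)
    g = dlossfunc(params, model, x, y_true)
    return update(i, g, state)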
JIT compilation
Because it takes so many parameters (in order to remain pure, and not rely on notebook state),
we're going to bind some of them using functools.partial.
I'm also going to show you what happens when we JIT-compile vs. don't JIT-compile the function.
step_partial = partial(step, get_params=get_params, dlossfunc=dmseloss, update=update, model=apply_fun, x=X, y_true=y_true)
step_partial_jit = jit(step_partial)
Explicit loops
Firstly, let's see what the code looks like if we write the loop explicitly.
from time import time
start = time()
state = init(params_initial)
for i in range(1000):
    params = get_params(state)
    g = dmseloss(params, apply_fun, X, y_true)
    state = update(i, g, state)
end = time()
print(end - start)
Partialled out loop step
Now, let's run the loop with the partialled out function.
start = time()
state = init(params_initial)
for i in range(1000):
    state = step_partial(i, state)
end = time()
print(end - start)
JIT-compiled loop!
This is a much cleaner loop, but we did have to do some work up-front.
What happens if we now use the JIT-ed function?
start = time()
state = init(params_initial)
for i in range(1000):
    state = step_partial_jit(i, state)
end = time()
print(end - start)
Whoa, holy smokes, that's fast! At least 10X faster using JIT-compilation.
lax.scan loop
Now we'll use some JAX trickery to write a training loop without ever writing a for-loop.
from dl_workshop.stax_models import make_scannable_step
make_scannable_step??
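Again, the ?? output isn't shown here, so below is a sketch of the shape such a factory usually takes; the real dl_workshop.stax_models.make_scannable_step may be written differently.
def make_scannable_step(step_fun):
    # Sketch only: adapt step_fun(i, state) into the (carry, x) -> (carry, y)
    # signature that lax.scan expects, recording the state history as we go.
    def scannable_step(state, i):
        new_state = step_fun(i, state)
        return new_state, new_state
    return scannable_step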
from jax import lax
scannable_step = make_scannable_step(step_partial_jit)
start = time()
initial_state = init(params_initial)
final_state, states_history = lax.scan(scannable_step, initial_state, np.arange(1000))
end = time()
print(end - start)
get_params(final_state)
vmap-ed training loop over multiple starting points
Now, we're going to do the ultimate: we'll create 100 different parameter initializations and run our training loop over each of them.
from dl_workshop.stax_models import make_training_start
make_training_start??
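As before, here is a sketch of what a training-start factory plausibly does, consistent with how it is called below; the actual dl_workshop.stax_models.make_training_start may differ.
from jax import lax, numpy as np

def make_training_start(param_init_fun, opt_init, scannable_step, n_steps):
    # Sketch only: from a PRNG key, initialize model params, wrap them in
    # optimizer state, and scan the step function n_steps times.
    def train(key):
        _, params = param_init_fun(key)
        initial_state = opt_init(params)
        final_state, states_history = lax.scan(
            scannable_step, initial_state, np.arange(n_steps)
        )
        return final_state, states_history
    return train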
from jax import lax
train_linear = make_training_start(partial(init_fun, input_shape=(-1, 4)), init, scannable_step, 1000)
start = time()
N_INITIALIZATIONS = 100
initialization_keys = random.split(key, N_INITIALIZATIONS)
final_states, states_histories = vmap(train_linear)(initialization_keys)
end = time()
print(end - start)
w_final, b_final = vmap(get_params)(final_states)
w_final.squeeze()[0:5]
b_final.squeeze()[0:5]
Looks like we were also able to run the whole optimization pretty fast, and recover the correct parameters over multiple training starts.
JIT-compiled training loop
What happens if we JIT-compile the training function before vmap-ing it over the initializations?
start = time()
N_INITIALIZATIONS = 100
initialization_keys = random.split(key, N_INITIALIZATIONS)
train_linear_jit = jit(train_linear)
final_states, states_histories = vmap(train_linear_jit)(initialization_keys)
vmap(get_params)(final_states)  # this line just blocks until the computation completes.
end = time()
print(end - start)
HOOOOOLY SMOKES! Did you see that? With JIT-compilation, we essentially took the training time down to be identical to training on one starting point. Naturally, I don't expect this result to hold 100% of the time, but it's pretty darn rad to see that live.
The craziest piece here is that we could vmap our training loop over multiple starting points and get massive speedups there.