%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
Writing linear models with stax
In this notebook, I'll show how to use JAX's
stax submodule to write arbitrary models.
I'm assuming you have read through the
tutorial.ipynb notebook, as well as the
jax-programming.ipynb notebook. The
tutorial.ipynb notebook gives you a general introduction to differential programming using
grad, while the
jax-programming.ipynb notebook gives you a flavour of the other main JAX idioms.
Most deep learning libraries use objects as the data structure for a neural network layer.
As such, the tunable parameters of the layer, for example
the weights w and biases b of a linear ("dense") layer,
are class attributes associated with the forward function.
Because a neural network layer is nothing more than a math function,
it also makes sense to specify the layer in terms of a function.
stax, then, is a new take on writing neural network models using pure functions rather than objects.
The way that
stax layers work is as follows.
Every neural network layer is nothing more than a math function with a "forward" pass.
Neural network models typically have their parameters
initialized into the right shapes using random number generators.
Put these two together, and we have a pair of functions that specify a layer:
- an init_fun function, which initializes parameters into the correct shapes, and
- an apply_fun function, which applies the specified math transformations onto incoming data, using parameters of the correct shape.
Example: Linear layer
Let's see an example of this in action, by studying the implementation of the linear ("dense") layer in stax.
from jax.experimental import stax
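If you inspect stax.Dense (for example with stax.Dense??), you'll see something roughly like the sketch below. This is a lightly simplified rendition for illustration, not the verbatim library source, so the initializer details may differ between JAX versions:

from jax import random
from jax import numpy as np
from jax.nn.initializers import glorot_normal, normal

def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
    """Layer constructor for a dense (fully-connected) layer."""
    def init_fun(rng, input_shape):
        # The output keeps all leading dimensions; the last becomes out_dim.
        output_shape = input_shape[:-1] + (out_dim,)
        k1, k2 = random.split(rng)
        W = W_init(k1, (input_shape[-1], out_dim))
        b = b_init(k2, (out_dim,))
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        # Unpack the parameters and apply the linear transformation.
        W, b = params
        return np.dot(inputs, W) + b

    return init_fun, apply_fun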
As you can see, the
apply_fun specifies the linear transformation.
It accepts a parameter called params,
which gets tuple-unpacked into the appropriate W and b.
Notice how the
params argument matches up with the second output of init_fun.
init_fun always accepts an
rng parameter, which is a key returned from JAX's PRNGKey function.
It also accepts an input_shape parameter,
which specifies what the elementary shape of one sample of data is.
So if your entire dataset is of shape
(n_samples, n_columns),
then you would pass in (n_columns,) as the input_shape,
since we want to ignore the sample dimension,
thus allowing us to take advantage of
vmap to map our model function
over each and every i.i.d. sample in our dataset.
init_fun also returns the output_shape,
which is used later when we chain layers together.
Let's see how we can use the Dense layer to specify a linear regression model.
Create the initialization and application function pairs
Firstly, we create the init_fun and apply_fun pair:
init_fun, apply_fun = stax.Dense(1)
Initialize the parameters
Now, let's initialize parameters using the init_fun.
Let's assume that our data has only 4 columns.
from jax import random, numpy as np

key = random.PRNGKey(42)
output_shape, params_initial = init_fun(key, input_shape=(4,))
params_initial  # display the initialized parameters
(DeviceArray([[ 0.43186307], [ 0.9563157 ], [ 0.9000483 ], [-0.32341173]], dtype=float32), DeviceArray([0.01369469], dtype=float32))
Apply parameters and data through function
We'll create some randomly generated data.
X = random.normal(key, shape=(200, 4))
X[0:5], X.shape
(DeviceArray([[ 0.03751167, 0.8777058 , -1.2008178 , -1.3824965 ], [-0.37519178, 1.5957882 , -1.5086783 , -0.75612265], [-0.51650995, -1.056697 , 0.99382603, -0.3520223 ], [-1.4446373 , -0.51537496, 0.30910558, -0.31770143], [ 0.9590066 , -0.6170032 , -0.37160665, -0.7339001 ]], dtype=float32), (200, 4))
We'll also generate y_true values, using some "true" parameter values that I've snuck in.
y_true = np.dot(X, np.array([1, 2, 3, 4])) + 5
y_true = y_true.reshape(-1, 1)
y_true[0:5], y_true.shape
(DeviceArray([[-2.3395162 ], [ 0.26585913], [ 3.9434848 ], [ 2.1811237 ], [ 0.6745796 ]], dtype=float32), (200, 1))
Now, we'll pass data through the linear model!
from jax import vmap
from functools import partial

y_pred = vmap(partial(apply_fun, params_initial))(X)
y_pred[0:5], y_pred.shape
(DeviceArray([[ 0.23557997], [ 0.26439613], [-0.21156326], [-0.72209364], [-0.2593076 ]], dtype=float32), (200, 1))
Voilà! We have a simple linear model implemented just like that.
Next question: how do we optimize the parameters using JAX?
Instead of writing a training loop on our own, we can take advantage of JAX's optimizers, which are also written in a functional paradigm!
JAX's optimizers are constructed as a "triplet" set of functions:
- init: Takes in the params and initializes them as a state, which is structured in a fashion that update can operate on.
- update: Takes in i, g, and state, and returns an updated state. These arguments respectively are:
  - i: The current loop iteration,
  - g: Gradients calculated from the gradient of the loss function, and
  - state: The current state of the parameters.
- get_params: Takes in the state at a given point, and returns the parameters structured correctly.
from jax import jit, grad
from jax.experimental.optimizers import adam

init, update, get_params = adam(step_size=1e-1)
update = jit(update)
get_params = jit(get_params)
We're still missing a piece here, that is the loss function. For illustration purposes, let's use the mean squared error.
def mseloss(params, model, x, y_true):
    y_preds = vmap(partial(model, params))(x)
    return np.mean(np.power(y_preds - y_true, 2))

dmseloss = grad(mseloss)
"Step" portion of update loop
Now, we're going to define the "step" portion of the update loop.
from dl_workshop.stax_models import step

step??
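The source of step lives in the dl_workshop.stax_models package, so it isn't reproduced verbatim here. Based on how it gets partialled out below, and on the explicit training loop further down, it is essentially a sketch like this (argument names taken from the partial call; the real implementation may differ in details):

def step(i, state, get_params, dlossfunc, update, model, x, y_true):
    # Pull the parameters out of the optimizer state,
    # compute gradients of the loss w.r.t. them,
    # and advance the optimizer state by one step.
    params = get_params(state)
    g = dlossfunc(params, model, x, y_true)
    state = update(i, g, state)
    return state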
Because it takes so many parameters (in order to remain pure, and not rely on notebook state),
we're going to bind some of them using partial.
I'm also going to show you what happens when we JIT-compile vs. don't JIT-compile the function.
step_partial = partial(
    step,
    get_params=get_params,
    dlossfunc=dmseloss,
    update=update,
    model=apply_fun,
    x=X,
    y_true=y_true,
)
step_partial_jit = jit(step_partial)
Firstly, let's see what kind of code we'd write if we did write the loop explicitly.
from time import time

start = time()
state = init(params_initial)
for i in range(1000):
    params = get_params(state)
    g = dmseloss(params, apply_fun, X, y_true)
    state = update(i, g, state)
end = time()
print(end - start)
Partialled out loop step
Now, let's run the loop with the partialled out function.
start = time()
state = init(params_initial)
for i in range(1000):
    state = step_partial(i, state)
end = time()
print(end - start)
This is a much cleaner loop, but we did have to do some work up-front.
What happens if we now use the JIT-ed function?
start = time()
state = init(params_initial)
for i in range(1000):
    state = step_partial_jit(i, state)
end = time()
print(end - start)
Whoa, holy smokes, that's fast! At least 10X faster using JIT-compilation.
Now we'll use some JAX trickery to write a training loop without ever writing a for-loop.
from dl_workshop.stax_models import make_scannable_step

make_scannable_step??
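Again, the real source lives in dl_workshop.stax_models; the idea, though, is just an adapter. lax.scan expects a function with the signature (carry, x) -> (new_carry, y), so we wrap the step function, carrying the optimizer state and treating the iteration number as x. A plausible sketch:

def make_scannable_step(step_fun):
    def scannable_step(state, i):
        new_state = step_fun(i, state)
        # Return the new state as both the carry and the recorded output,
        # so lax.scan also hands back the full history of states.
        return new_state, new_state
    return scannable_step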
from jax import lax

scannable_step = make_scannable_step(step_partial_jit)

start = time()
initial_state = init(params_initial)
final_state, states_history = lax.scan(scannable_step, initial_state, np.arange(1000))
end = time()
print(end - start)
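To check that the scanned loop actually converged, we pull the parameters out of the final optimizer state (the output below presumably comes from a cell along these lines):

get_params(final_state)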
(DeviceArray([[1. ], [2.0000002], [3. ], [3.9999995]], dtype=float32), DeviceArray([5.0000005], dtype=float32))
vmap-ed training loop over multiple starting points
Now, we're going to do the ultimate: we'll create 100 different parameter initializations and run our training loop over each of them.
from dl_workshop.stax_models import make_training_start

make_training_start??
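As before, the actual implementation is in dl_workshop.stax_models. Judging from how it gets called below, it takes a model initializer, an optimizer init, a scannable step function, and a number of iterations, and returns a function of a PRNG key that runs one whole training loop. A sketch of what it plausibly looks like (using the lax and np names already imported above):

def make_training_start(model_init_fun, optimizer_init, scannable_step, n_steps):
    def train(key):
        # Initialize fresh parameters from this key, then run the
        # lax.scan-based training loop for n_steps iterations.
        _, params = model_init_fun(key)
        initial_state = optimizer_init(params)
        final_state, states_history = lax.scan(
            scannable_step, initial_state, np.arange(n_steps)
        )
        return final_state, states_history
    return train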
from jax import lax

train_linear = make_training_start(
    partial(init_fun, input_shape=(-1, 4)), init, scannable_step, 1000
)

start = time()
N_INITIALIZATIONS = 100
initialization_keys = random.split(key, N_INITIALIZATIONS)
final_states, states_histories = vmap(train_linear)(initialization_keys)
end = time()
print(end - start)
w_final, b_final = vmap(get_params)(final_states)
w_final.squeeze()[0:5]
DeviceArray([[1.0000002, 2. , 2.9999998, 4.0000005], [1.0000001, 2. , 3. , 4.0000005], [1.0000002, 2. , 3. , 4.000001 ], [1.0000001, 2. , 3. , 3.9999998], [1.0000001, 2. , 3.0000002, 4. ]], dtype=float32)
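The biases also converge to the true value of 5 (the output below presumably comes from a cell like this):

b_final.squeeze()[0:5]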
DeviceArray([5.000001, 5.000001, 5.000001, 5.000001, 5.000001], dtype=float32)
Looks like we were also able to run the whole optimization pretty fast, and recover the correct parameters over multiple training starts.
JIT-compiled training loop
What happens if we JIT-compile the training function before vmapping it over the initializations?
start = time()
N_INITIALIZATIONS = 100
initialization_keys = random.split(key, N_INITIALIZATIONS)

train_linear_jit = jit(train_linear)
final_states, states_histories = vmap(train_linear_jit)(initialization_keys)
vmap(get_params)(final_states)  # this line exists just to block until the computation completes.
end = time()
print(end - start)
HOOOOOLY SMOKES! Did you see that? With JIT-compilation, we essentially took the training time down to be identical to training on one starting point. Naturally, I don't expect this result to hold 100% of the time, but it's pretty darn rad to see that live.
The craziest piece here is that we could
vmap our training loop over multiple starting points and get massive speedups there.