```
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
```

# Writing linear models with `stax`

In this notebook, I'll show you how to use JAX's `stax` submodule to write arbitrary models.

## Prerequisites

I'm assuming you have read through the `jax-programming.ipynb` notebook, as well as the `tutorial.ipynb` notebook.
The main `tutorial.ipynb` notebook gives you a general introduction to differential programming using `grad`,
while the `jax-programming.ipynb` notebook gives you a flavour of the other four main JAX idioms:
`vmap`, `lax.scan`, `random.PRNGKey`, and `jit`.

## What is `stax`?

Most deep learning libraries use objects as the data structure for a neural network layer.
As such, the tunable parameters of the layer, for example `w` and `b` for a linear ("dense") layer,
are class attributes associated with the forward function.

In some sense, because a neural network layer is nothing more than a math function,
specifying the layer in terms of a function might also make sense.
`stax`, then, is a new take on writing neural network models using pure functions rather than objects.

## How does `stax` work?

The way that `stax` layers work is as follows.
Every neural network layer is nothing more than a math function with a "forward" pass.
Neural network models typically have their parameters
initialized into the right shapes using random number generators.
Put these two together, and we have a *pair* of functions that specify a layer:

- An `init_fun` function that *initializes* parameters into the correct shapes, and
- An `apply_fun` function that *applies* the specified math transformations onto incoming data, using parameters of the correct shape.
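
To make this pair structure concrete before we look at a real layer, here's a minimal sketch of a made-up layer written in this style: one that scales its input by a single learnable constant. This layer isn't part of `stax`; it exists purely to illustrate the pattern.

```
from jax import random

def Scale():
    # init_fun: draw a single learnable scalar parameter.
    def init_fun(rng, input_shape):
        alpha = random.normal(rng, ())
        return input_shape, (alpha,)

    # apply_fun: multiply incoming data by that scalar.
    def apply_fun(params, inputs):
        (alpha,) = params
        return alpha * inputs

    return init_fun, apply_fun
```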

## Example: Linear layer

Let's see an example of this in action by studying the implementation of the linear ("dense") layer in `stax`.

```
# Note: on newer versions of JAX, stax lives at jax.example_libraries.stax instead.
from jax.experimental import stax
```

```
stax.Dense??
```
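
If you're reading a static render of this notebook, the output of `??` won't show up. For reference, `stax.Dense` looks roughly like this, paraphrased from the `stax` source; exact initializer defaults may vary between JAX versions.

```
from jax import random
from jax import numpy as np
from jax.nn.initializers import glorot_normal, normal

def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
    """Constructor for a dense (fully-connected) layer, paraphrased from stax."""
    def init_fun(rng, input_shape):
        # Output keeps all leading dims, but the trailing dim becomes out_dim.
        output_shape = input_shape[:-1] + (out_dim,)
        k1, k2 = random.split(rng)
        W = W_init(k1, (input_shape[-1], out_dim))
        b = b_init(k2, (out_dim,))
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        return np.dot(inputs, W) + b

    return init_fun, apply_fun
```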

As you can see, the `apply_fun` specifies the linear transformation.
It accepts a parameter called `params`,
which gets tuple-unpacked into the appropriate `W` and `b`.

**Notice how the `params` argument matches up with the second output of `init_fun`!**

The `init_fun` always accepts an `rng` parameter,
which is a PRNG key of the kind returned by JAX's `jax.random.PRNGKey()`.
It also accepts an `input_shape` parameter,
which specifies the elementary shape of one sample of data.
So if your entire dataset is of shape `(n_samples, n_columns)`,
then you would pass in `(n_columns,)`,
as you would want to ignore the sample dimension;
this lets us take advantage of `vmap` to map our model function
over each and every i.i.d. sample in our dataset.
The `init_fun` also returns the `output_shape`,
which is used later when we chain layers together.

Let's see how we can use the `Dense` layer to specify a linear regression model.

### Create the initialization and application function pairs

Firstly, we create the `init_fun` and `apply_fun` pair:

```
init_fun, apply_fun = stax.Dense(1)
```

### Initialize the parameters

Now, let's initialize parameters using the `init_fun`.

Let's assume that our data has only 4 columns.

```
from jax import random, numpy as np
key = random.PRNGKey(42)
output_shape, params_initial = init_fun(key, input_shape=(4,))
```

```
params_initial
```

### Apply parameters and data through function

We'll create some randomly generated data.

```
X = random.normal(key, shape=(200, 4))
X[0:5], X.shape
```

Here are some `y_true` values that I've snuck in.

```
y_true = np.dot(X, np.array([1, 2, 3, 4])) + 5
y_true = y_true.reshape(-1, 1)
y_true[0:5], y_true.shape
```

Now, we'll pass data through the linear model!

```
apply_fun??
```

```
from jax import vmap
from functools import partial
y_pred = vmap(partial(apply_fun, params_initial))(X)
y_pred[0:5], y_pred.shape
```

Voilà! We have a simple linear model implemented just like that.

## Optimization

Next question: how do we *optimize* the parameters using JAX?

Instead of writing a training loop on our own, we can take advantage of JAX's optimizers, which are also written in a functional paradigm!

JAX's optimizers are constructed as a "triplet" of functions:

- `init`: Takes `params` and initializes them as a `state`, which is structured in a fashion that `update` can operate on.
- `update`: Takes in `i`, `g`, and `state`, which respectively are:
    - `i`: The current loop iteration
    - `g`: Gradients calculated from `grad`!
    - `state`: The current state of the parameters.
- `get_params`: Takes in the `state` at a given point, and returns the parameters structured correctly.

```
from jax import jit, grad
# Note: on newer versions of JAX, this is jax.example_libraries.optimizers instead.
from jax.experimental.optimizers import adam

init, update, get_params = adam(step_size=1e-1)
update = jit(update)
get_params = jit(get_params)
```

### Loss Function

We're still missing a piece here: the loss function. For illustration purposes, let's use the mean squared error.

```
def mseloss(params, model, x, y_true):
    y_preds = vmap(partial(model, params))(x)
    return np.mean(np.power(y_preds - y_true, 2))

dmseloss = grad(mseloss)
```

### "Step" portion of update loop

Now, we're going to define the "step" portion of the update loop.

```
from dl_workshop.stax_models import step
step??
```
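
As before, the `??` output won't appear in a static render. Judging from how `step` is partialled out and called below, it presumably looks something like this sketch; the actual `dl_workshop.stax_models` implementation may differ in its details.

```
def step(i, state, get_params, dlossfunc, update, model, x, y_true):
    """One optimization step: pull params out of state, compute gradients, update."""
    params = get_params(state)               # unpack parameters from optimizer state
    g = dlossfunc(params, model, x, y_true)  # gradients of the loss w.r.t. params
    return update(i, g, state)               # apply the optimizer update rule
```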

### JIT compilation

Because `step` takes so many parameters (in order to remain pure, and not rely on notebook state),
we're going to bind some of them using `functools.partial`.

I'm also going to show you what happens when we JIT-compile vs. don't JIT-compile the function.

```
step_partial = partial(
    step,
    get_params=get_params,
    dlossfunc=dmseloss,
    update=update,
    model=apply_fun,
    x=X,
    y_true=y_true,
)
step_partial_jit = jit(step_partial)
```

### Explicit loops

Firstly, let's see what kind of code we'd write if we *did* write the loop explicitly.

```
from time import time
start = time()
state = init(params_initial)
for i in range(1000):
    params = get_params(state)
    g = dmseloss(params, apply_fun, X, y_true)
    state = update(i, g, state)
end = time()
print(end - start)
```

### Partialled out loop step

Now, let's run the loop with the partialled out function.

```
start = time()
state = init(params_initial)
for i in range(1000):
    state = step_partial(i, state)
end = time()
print(end - start)
```

### JIT-compiled loop!

This is a much cleaner loop, but we did have to do some work up front.

What happens if we now use the JIT-ed function?

```
start = time()
state = init(params_initial)
for i in range(1000):
    state = step_partial_jit(i, state)
end = time()
print(end - start)
```

Whoa, holy smokes, that's fast! At least 10X faster using JIT-compilation.

### `lax.scan` loop

Now we'll use some JAX trickery to write a training loop without ever writing a for-loop.

```
from dl_workshop.stax_models import make_scannable_step
make_scannable_step??
```
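
Another sketch for readers of a static render: `make_scannable_step` presumably wraps a step function into the `(carry, x) -> (carry, y)` signature that `lax.scan` expects, roughly like so; the actual implementation may differ.

```
def make_scannable_step(step_fun):
    """Wrap step_fun(i, state) into the (carry, x) -> (carry, y) form lax.scan needs."""
    def scannable_step(state, i):
        new_state = step_fun(i, state)
        # Carry the new state forward, and also emit it so scan stacks up a history.
        return new_state, new_state
    return scannable_step
```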

```
from jax import lax
scannable_step = make_scannable_step(step_partial_jit)
start = time()
initial_state = init(params_initial)
final_state, states_history = lax.scan(scannable_step, initial_state, np.arange(1000))
end = time()
print(end - start)
```

```
get_params(final_state)
```

### `vmap`-ed training loop over multiple starting points

Now, we're going to do the ultimate: we'll create 100 different parameter initializations and run our training loop over each of them.

```
from dl_workshop.stax_models import make_training_start
make_training_start??
```
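
One last sketch, with the same caveat: judging from the call below, `make_training_start` presumably closes over everything needed to run one full training loop from a single PRNG key. The argument names here are my guesses.

```
from jax import lax
from jax import numpy as np

def make_training_start(params_init_fun, optimizer_init, scannable_step, n_steps):
    """Return a function that runs a full training loop starting from one PRNG key."""
    def train(key):
        _, params = params_init_fun(key)        # draw fresh initial parameters
        initial_state = optimizer_init(params)  # wrap them in optimizer state
        final_state, states_history = lax.scan(
            scannable_step, initial_state, np.arange(n_steps)
        )
        return final_state, states_history
    return train
```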

```
from jax import lax
train_linear = make_training_start(
    partial(init_fun, input_shape=(-1, 4)), init, scannable_step, 1000
)
start = time()
N_INITIALIZATIONS = 100
initialization_keys = random.split(key, N_INITIALIZATIONS)
final_states, states_histories = vmap(train_linear)(initialization_keys)
end = time()
print(end - start)
```

```
w_final, b_final = vmap(get_params)(final_states)
w_final.squeeze()[0:5]
```

```
b_final.squeeze()[0:5]
```

Looks like we were also able to run the whole optimization pretty fast, *and* recover the correct parameters over multiple training starts.

### JIT-compiled training loop

What happens if we JIT-compile the training function before `vmap`-ing it?

```
start = time()
N_INITIALIZATIONS = 100
initialization_keys = random.split(key, N_INITIALIZATIONS)
train_linear_jit = jit(train_linear)
final_states, states_histories = vmap(train_linear_jit)(initialization_keys)
vmap(get_params)(final_states) # this line exists to just block the computation until it completes.
end = time()
print(end - start)
```

HOOOOOLY SMOKES! Did you see that? With JIT-compilation, we essentially took the training time down to be identical to training on one starting point. Naturally, I don't expect this result to hold 100% of the time, but it's pretty darn rad to see that live.

The craziest piece here is that we could `vmap` our training loop over multiple starting points and get massive speedups there.