```
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
```

```
from IPython.display import YouTubeVideo, display
YouTubeVideo("6pnl7Eu2wN0")
```

# Eliminating for-loops that have carry-over using `lax.scan`

We are now going to see how we can eliminate for-loops that have carry-over using `lax.scan`

.

From the JAX docs, `lax.scan`

replaces a for-loop with carry-over,
with some of my own annotations added in for clarity:

Scan a function over leading array axes while carrying along state.

The semantics are described as follows:

```
def scan(f, init, xs, length=None):
if xs is None:
xs = [None] * length
carry = init
ys = []
for x in xs:
carry, y = f(carry, x) # carry is the carryover
ys.append(y) # the `y`s get accumulated into a stacked array
return carry, np.stack(ys)
```

A key requirement of the function `f`

,
which is the function that gets scanned over the array `xs`

,
is that it must have only two positional arguments in there,
one for `carry`

and one for `x`

.
You'll see how we can thus apply `functools.partial`

to construct functions that have this signature
from other functions that have more arguments present.

Let's see some concrete examples of this in action.

## Example: Cumulative Summation

One example where we might use a for-loop is in the cumulative sum or product of an array. Here, we need the current loop information to update the information from the previous loop. Let's see it in action for the cumulative sum:

```
import jax.numpy as np
a = np.array([1, 2, 3, 5, 7, 11, 13, 17])
result = []
res = 0
for el in a:
res += el
result.append(res)
np.array(result)
```

This is identical to the cumulative sum:

```
np.cumsum(a)
```

Now, let's write it using `lax.scan`

, so we can see the pattern in action:

```
from jax import lax
def cumsum(res, el):
"""
- `res`: The result from the previous loop.
- `el`: The current array element.
"""
res = res + el
return res, res # ("carryover", "accumulated")
result_init = 0
final, result = lax.scan(cumsum, result_init, a)
result
```

As you can see, scanned function has to return two things:

- One object that gets carried over to the next loop (
`carryover`

), and - Another object that gets "accumulated" into an array (
`accumulated`

).

The starting initial value, `result_init`

, is passed into the `scanfunc`

as `res`

on the first call of the `scanfunc`

. On subsequent calls, the first `res`

is passed back into the `scanfunc`

as the new `res`

.

## Exercise 1: Simulating compound interest

We can use `lax.scan`

to generate data that simulates
the generation of wealth by compound interest.
Here's an implementation using a plain vanilla for-loop:

```
wealth_record = []
starting_wealth = 100.0
interest_factor = 1.01
num_timesteps = 100
prev_wealth = starting_wealth
for t in range(num_timesteps):
new_wealth = prev_wealth * interest_factor
wealth_record.append(prev_wealth)
prev_wealth = new_wealth
wealth_record = np.array(wealth_record)
```

Now, your challenge is to implement it in a `lax.scan`

form.
Implement the `wealth_at_time`

function below.

```
from functools import partial
def wealth_at_time(prev_wealth, time, interest_factor):
# The lax.scannable function to compute wealth at a given time.
# your answer here
pass
# Comment out the import to test your answer
from dl_workshop.jax_idioms import lax_scan_ex_1 as wealth_at_time
wealth_func = partial(wealth_at_time, interest_factor=interest_factor)
timesteps = np.arange(num_timesteps)
final, result = lax.scan(wealth_func, init=starting_wealth, xs=timesteps)
assert np.allclose(wealth_record, result)
```

The two are equivalent, so we know we have the `lax.scan`

implementation right.

```
import matplotlib.pyplot as plt
plt.plot(wealth_record, label="for-loop")
plt.plot(result, label="lax.scan")
plt.legend();
```

## Example: Simulating compound interest from multiple starting points

Previously, was one simulation of wealth generation by compound interest
from one starting amount of money.
Now, let's simulate the wealth generation
for different starting wealth levels;
onemay choose any 300 starting points however one likes.
This will be a demonstration of how to compose `lax.scan`

with `vmap`

to do computation without loops.

To do so, you'll likely want to start with a function
that accepts a scalar starting wealth
and generates the simulated time series from there,
and then `vmap`

that function across multiple starting points (which is an array itself).

```
from jax import vmap
def simulate_compound_interest(
starting_wealth: np.ndarray, timesteps: np.ndarray
):
final, result = lax.scan(wealth_func, init=starting_wealth, xs=timesteps)
return final, result
num_timesteps = np.arange(200)
starting_wealths = np.arange(300).astype(float)
simulation_func = partial(simulate_compound_interest, timesteps=np.arange(200))
final, growth = vmap(simulation_func)(starting_wealths)
growth.shape
```

```
plt.plot(growth[1])
plt.plot(growth[2])
plt.plot(growth[3]);
```

## Exercise 2: Stick breaking process

The stick breaking process is one that is important in Bayesian non-parametric modelling, where we want to model something that may have potentially an infinite number of components while being biased towards a smaller subset of components.

The stick-breaking process uses the following generative process:

- Take a stick of length 1.
- Draw a number between 0 and 1 from a Beta distribution (we will modify this step for this notebook).
- Break that fraction of the stick, and leave it aside in a pile.
- Repeat steps 2 and 3 with the fraction leftover after breaking the stick.

We repeat *ad infinitum* (in theory)
or until a pre-specified large number of stick breaks have happened (in practice).

In the exercise below, your task is to write the stick-breaking process
in terms of a `lax.scan`

operation.
Because we have not yet covered drawing random numbers using JAX,
the breaking fraction will be a fixed variable rather than a random variable.
Here's the vanilla NumPy + Python equivalent for you to reference.

```
# NumPy equivalent
num_breaks = 30
breaking_fraction = 0.1
sticks = []
stick_length = 1.0
for i in range(num_breaks):
stick = stick_length * breaking_fraction
sticks.append(stick)
stick_length = stick_length - stick
sticks = np.array(sticks)
sticks
```

```
def lax_scan_ex_2(num_breaks: int, frac: float):
# Your answer goes here!
pass
# Comment out the import if you want to test your answer.
from dl_workshop.jax_idioms import lax_scan_ex_2
sticksres = lax_scan_ex_2(num_breaks, breaking_fraction)
assert np.allclose(sticksres, sticks)
```