Eric J Ma's Website

Numba: My first attempt at being serious with it

written by Eric J. Ma on 2017-02-08 | tags: numba open source data science optimization coding snippets


This evening, I saw a Tweet about using numba, and I thought, it's about time I give it a proper shot. I had been solving some dynamic programming problems just for fun, and I thought this would be a good test case for numba's capabilities.

The DP problem I was trying to solve was that of collecting apples on a grid. Here's how the problem is posed:

I have a number of apples distributed randomly on a grid. I start at the top-left corner, and I'm only allowed to move downwards or to the right. Along the way, I pick up the apples in every cell I pass through. What's the maximum number of apples I can collect?
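For a small grid, the question can be answered without any DP by trying every path. This sketch assumes the walk ends at the bottom-right corner (with non-negative apple counts, stopping earlier never helps). On an r x c grid, every such path makes r - 1 down moves and c - 1 right moves, so a path is fixed by choosing which of its moves go down. The function name brute_force_apples and the 3 x 3 toy grid are hypothetical, for illustration only.

```python
from itertools import combinations
from math import comb


def brute_force_apples(grid):
    """Try every down/right path from the top-left to the bottom-right cell.

    Assumption: the walk ends at the bottom-right corner.
    """
    n_rows, n_cols = len(grid), len(grid[0])
    n_moves = (n_rows - 1) + (n_cols - 1)
    best = float("-inf")
    # A path is fully determined by which of its moves are "down".
    for down_steps in combinations(range(n_moves), n_rows - 1):
        row, col = 0, 0
        total = grid[0][0]
        for step in range(n_moves):
            if step in down_steps:
                row += 1
            else:
                col += 1
            total += grid[row][col]
        best = max(best, total)
    return best


grid = [[1, 3, 1],
        [1, 5, 1],
        [4, 2, 1]]
best = brute_force_apples(grid)  # 12, via 1 -> 3 -> 5 -> 2 -> 1

# Number of distinct down/right paths on a 200 x 200 grid.
n_paths_200 = comb(2 * 200 - 2, 200 - 1)
```

The catch is the path count: a 200 x 200 grid has comb(398, 199) paths, more than 10^100, which is why this calls for DP rather than enumeration. On small grids, though, brute force is a handy cross-check against the bottom-right cell of a DP table.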

This is a classic 2-dimensional DP problem. I simulated some random integers:

import numpy as np

n = 200  # the grid is n x n
arr = np.random.randint(low=0, high=100, size=(n, n))  # 0 to 99 apples per cell
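Stated as a recurrence (the notation here is mine, not from the original post): let A be the apple grid (arr) and S[i, j] the most apples collectable on any down/right path from the top-left cell to cell (i, j) (sum_apples). Every cell can only be entered from the left or from above, so it extends the better of those two running totals:

```latex
S_{i,j} = A_{i,j} + \max\bigl(S_{i-1,j},\ S_{i,j-1}\bigr),
\qquad S_{i,j} = 0 \ \text{whenever } i < 0 \text{ or } j < 0
```

Filling S row by row guarantees both predecessors are ready before each cell is computed, and the answer sits in the bottom-right cell, S_{n-1,n-1}.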

I then wrote out my solution as two otherwise identical functions: one plain Python, and one JIT-compiled with numba's @jit decorator.

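# Let's collect apples.

import numpy as np
from numba import jit

@jit
def collect_apples(arr):
    # sum_apples[row, col] is the most apples collectable on any
    # down/right path from the top-left corner to (row, col).
    sum_apples = np.zeros(arr.shape)
    for row in range(arr.shape[0]):
        for col in range(arr.shape[1]):
            # Best running total if we arrive from the left...
            if col != 0:
                val_left = sum_apples[row, col - 1]
            else:
                val_left = 0.0
            # ...or from above.
            if row != 0:
                val_up = sum_apples[row - 1, col]
            else:
                val_up = 0.0
            sum_apples[row, col] = arr[row, col] + max(val_left, val_up)
    # The maximum number of apples is in the bottom-right cell, sum_apples[-1, -1].
    return sum_apples

# Same logic without the @jit decorator, as a pure-Python baseline.
def collect_apples_nonjit(arr):
    sum_apples = np.zeros(arr.shape)
    for row in range(arr.shape[0]):
        for col in range(arr.shape[1]):
            if col != 0:
                val_left = sum_apples[row, col - 1]
            else:
                val_left = 0.0
            if row != 0:
                val_up = sum_apples[row - 1, col]
            else:
                val_up = 0.0
            sum_apples[row, col] = arr[row, col] + max(val_left, val_up)
    return sum_apples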
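Keeping two copies of the same body in sync is error-prone, since a fix to one copy can easily miss the other. Because numba's jit is an ordinary decorator, one option is to write the function once and compile it by calling numba's njit (shorthand for jit(nopython=True)) on it. A minimal sketch follows; collect_apples_py, collect_apples_fast, and the toy grid are hypothetical names, and the try/except fallback to plain Python is an assumption added so the sketch still runs where numba isn't installed.

```python
import numpy as np

try:
    from numba import njit  # njit is shorthand for jit(nopython=True)
except ImportError:
    # Assumption: fall back to plain Python so the sketch runs without numba.
    def njit(func):
        return func


def collect_apples_py(arr):
    """Table of the best running totals over down/right paths."""
    sum_apples = np.zeros(arr.shape)
    n_rows, n_cols = arr.shape
    for row in range(n_rows):
        for col in range(n_cols):
            # Best total arriving from the left or from above (0 off the grid).
            val_left = sum_apples[row, col - 1] if col > 0 else 0.0
            val_up = sum_apples[row - 1, col] if row > 0 else 0.0
            sum_apples[row, col] = arr[row, col] + max(val_left, val_up)
    return sum_apples


# One source of truth: compile the same Python function instead of copying it.
collect_apples_fast = njit(collect_apples_py)

grid = np.array([[1, 3, 1],
                 [1, 5, 1],
                 [4, 2, 1]])
best = collect_apples_fast(grid)[-1, -1]  # 12.0, read from the bottom-right cell
```

numba's dispatcher objects also keep the uncompiled function as .py_func (for example, collect_apples.py_func(arr)), which is another way to get a non-JIT baseline for timing without a second copy of the code.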

Here are the performance results, timed with IPython's %%timeit cell magic:

%%timeit
collect_apples(arr)

The slowest run took 4.27 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 99.7 µs per loop

%%timeit
collect_apples_nonjit(arr)

10 loops, best of 3: 50.3 ms per loop

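Wow! That's an over 500-fold speedup (50.3 ms vs. 99.7 µs per call), all obtained for free just by adding the @jit decorator. One caveat: numba compiles collect_apples the first time it is called, and that one-off compilation cost isn't reflected in the best-of-3 per-loop figure.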
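%%timeit only works inside IPython or Jupyter. In a plain script, a rough stand-in is the standard library's timeit.repeat. The sketch below (the helper name best_per_call is hypothetical) reports the best per-call time over a few runs, in the spirit of the "best of 3" lines above, and makes one untimed call first so that, for a @jit function, numba's one-off compilation stays out of the measurement.

```python
import timeit


def best_per_call(func, *args, number=100, repeat=3):
    """Best average seconds per call over `repeat` runs of `number` calls each."""
    # Untimed warm-up call: for a numba-jitted function this triggers compilation.
    func(*args)
    runs = timeit.repeat(lambda: func(*args), number=number, repeat=repeat)
    # Each entry in `runs` is the total time for `number` calls.
    return min(runs) / number


# Example with a stand-in workload.
per_call = best_per_call(sorted, list(range(10_000)), number=50)
```

With the two functions defined earlier, best_per_call(collect_apples_nonjit, arr) / best_per_call(collect_apples, arr) gives the speedup ratio directly.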

