
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

Gradient-Based Optimization

Implicit in what you were doing was something we formally call "gradient-based optimization". This is a very important idea to understand: if you get it for a linear model, you will understand how it works for more complex models. Hence, we are going to take a short crash-course detour into what gradient-based optimization is.

Derivatives

At the risk of ticking off mathematicians for a sloppy definition, for this book's purposes, a useful way of defining the derivative is:

How much our output changes as we take a small step in the input, in the limit as that step becomes very small.
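In symbols, writing h for the size of the step, this is the familiar limit definition from calculus:

f'(w) = \lim_{h \to 0} \frac{f(w + h) - f(w)}{h}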

If we have a function:

f(w) = w^2 + 3w - 5

What is the derivative of f(w) with respect to w? From first-year undergraduate calculus, we should be able to calculate this:

f'(w) = 2w + 3

As a matter of style, we will use apostrophes to indicate derivatives: one apostrophe means the first derivative, two apostrophes mean the second derivative, and so on.
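We can sanity-check the hand-derived f'(w) numerically by taking the "small step" definition literally. The sketch below uses a central difference, (f(w + h) - f(w - h)) / (2h), a common variant of the one-sided limit that is usually more accurate for a fixed h. The step size h = 1e-5 and the helper names df_analytic and df_numeric are our own illustrative choices, not part of the workshop code.

```python
# Hypothetical helper names (df_analytic, df_numeric) chosen for this sketch.
def f(w):
    """The objective from the text: f(w) = w^2 + 3w - 5."""
    return w**2 + 3 * w - 5


def df_analytic(w):
    """Hand-derived derivative: f'(w) = 2w + 3."""
    return 2 * w + 3


def df_numeric(w, h=1e-5):
    """Central finite-difference estimate of f'(w) using a small step h."""
    return (f(w + h) - f(w - h)) / (2 * h)


for w in [-3.0, 0.0, 2.5]:
    # The analytic and numeric columns should agree up to floating-point rounding.
    print(w, df_analytic(w), df_numeric(w))
```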

Minimizing f(w) Analytically

What is the value of w that minimizes f(w)? Again, from undergraduate calculus, we know that at a minimum of a differentiable function (whether global or local), the first derivative is equal to zero, i.e. f'(w) = 0. By taking advantage of this property, we can analytically solve for the value of w at the minimum:

2w + 3 = 0

Hence,

w = -\frac{3}{2} = -1.5

To check whether the point where f'(w) = 0 is a minimum or a maximum, we can use another piece of knowledge from first-year undergraduate calculus: the sign of the second derivative at that point tells us which one it is.

  • If the second derivative is positive, then the point is a minimum. (Smiley faces are positive!)
  • If the second derivative is negative, then the point is a maximum. (Frowning faces are negative!)

Hence,

f''(w) = 2

We can see that f''(w) > 0, hence the stationary point we found is a minimum. In fact, because f''(w) is positive for every value of w, it is the global minimum.
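As a quick numerical cross-check of the analytic result, the sketch below evaluates the first and second derivatives at w = -1.5 and confirms that moving away from that point in either direction increases f. The step sizes in `deltas` are arbitrary values picked for illustration.

```python
def f(w):
    return w**2 + 3 * w - 5


def df(w):
    return 2 * w + 3  # f'(w)


def d2f(w):
    return 2.0  # f''(w) is the constant 2


w_star = -3 / 2
print(df(w_star))        # first derivative at the stationary point: 0.0
print(d2f(w_star) > 0)   # True, so the stationary point is a minimum
print(f(w_star))         # -7.25

# Stepping away from w_star in either direction only increases f.
deltas = [0.1, 0.5, 1.0]
for delta in deltas:
    print(delta, f(w_star - delta) > f(w_star), f(w_star + delta) > f(w_star))
```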

Minimizing f(w) Computationally

An alternative approach is to take advantage of f'(w), the gradient, evaluated at a particular w. A useful property of the gradient is that if you repeatedly take sufficiently small steps in the negative direction of the gradient, you will approach a minimum of the function (possibly only a local one). Conversely, if you take small steps in the positive direction of the gradient, you will move towards a maximum (if one exists).
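Why should stepping against the gradient help? A first-order Taylor expansion (our addition, valid only when the step size \eta > 0 is small) makes this concrete:

f(w - \eta f'(w)) \approx f(w) - \eta \, [f'(w)]^2

Because [f'(w)]^2 \geq 0, a small enough step in the negative gradient direction cannot increase f to first order, and it strictly decreases f whenever f'(w) \neq 0.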

Exercise: Implement gradient functions by hand

Let's implement this for the function f(w) using NumPy.

Firstly, implement the aforementioned function f below.

# Exercise: Write f(w) as a function.

def f(w):
    """Your answer here."""
    return None

from dl_workshop.answers import f

f(2.5)

8.75

This function is the objective function that we wish to optimize, where "optimization" means finding the minima or maxima.
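One possible way to fill in the exercise is sketched below; the workshop's own reference answer lives in dl_workshop.answers and may be written differently. Using only arithmetic operators keeps the function usable with plain floats and NumPy arrays, and with jax later on.

```python
def f(w):
    """Objective function f(w) = w^2 + 3w - 5."""
    return w**2 + 3 * w - 5


print(f(2.5))  # 8.75, matching the output above
```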

Now, implement the gradient function \frac{df}{dw} below in the function df:

# Exercise: Write df(w) as a function. 
def df(w):
    """Your answer here"""
    return None

from dl_workshop.answers import df
df(2.5)
8.0

This function is the gradient of the objective w.r.t. the parameter of interest. It will help us find out the direction in which to change the parameter w in order to optimize the objective function.
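A possible answer for the gradient exercise, again only a sketch that may differ from the version in dl_workshop.answers, simply encodes the hand-derived f'(w) = 2w + 3:

```python
def df(w):
    """Hand-derived gradient of f w.r.t. w: f'(w) = 2w + 3."""
    return 2 * w + 3


print(df(2.5))  # 8.0, matching the output above
```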

Now, pick a number at random to start with. You can specify a number explicitly, or use a random number generator to draw a number.

# Exercise: Pick a number to start w at.
w = 10.0  # start with a float

This gives us a starting point for optimization.
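If you prefer to draw the starting point with a random number generator, a minimal sketch using NumPy's Generator API could look like this. The seed (42) and the normal distribution with standard deviation 10 are arbitrary assumptions; any reasonable choice works for this convex function.

```python
import numpy as np

# Seed picked arbitrarily so the draw is reproducible; any seed works.
rng = np.random.default_rng(seed=42)
# Draw one starting value from a normal distribution with mean 0 and std 10.
w = float(rng.normal(loc=0.0, scale=10.0))
print(w)
```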

Finally, write an "optimization loop", in which you adjust the value of w in the negative direction of the gradient of f w.r.t. w (i.e. \frac{df}{dw}).

# Now, adjust the value of w 1000 times, taking small steps in the negative direction of the gradient.
for i in range(1000):
    w = w - df(w) * 0.01  # 0.01 is the size of the step taken.

print(w)
-1.4999999806458753
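Why does the loop land at -1.49999998... rather than exactly -1.5? For this particular quadratic, a short derivation (ours, not from the workshop) shows that each update w <- w - \eta (2w + 3) can be rewritten as (w + 1.5) <- (1 - 2\eta)(w + 1.5), so the distance to the minimum shrinks by a factor of (1 - 2\eta) per step. The sketch below compares the loop against that closed form, then shows that a step size with |1 - 2\eta| > 1 (\eta = 1.1, chosen for illustration) makes the iterates diverge. The helper name run_gd is hypothetical.

```python
def df(w):
    return 2 * w + 3  # f'(w) for f(w) = w^2 + 3w - 5


def run_gd(w0, step_size, n_steps):
    """Plain gradient descent on f, mirroring the loop above."""
    w = w0
    for _ in range(n_steps):
        w = w - df(w) * step_size
    return w


w0, eta, n = 10.0, 0.01, 1000
loop_result = run_gd(w0, eta, n)
# Closed form: the gap to the minimum at -1.5 shrinks by (1 - 2 * eta) each step.
closed_form = -1.5 + (w0 + 1.5) * (1 - 2 * eta) ** n
print(loop_result, closed_form)

# With eta = 1.1 the factor is 1 - 2.2 = -1.2, so |w + 1.5| grows every step.
print(run_gd(w0, 1.1, 20))
```

The same reasoning says any step size between 0 and 1 shrinks the gap for this function; in general, too large a step can overshoot the minimum.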

Congratulations, you have just implemented gradient descent!

Gradient descent is an optimization routine: a way of programming a computer to do optimization for you so that you don't have to do it by hand.
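The loop above always runs for a fixed 1000 steps. A common refinement, which the workshop code does not use, is to stop once the gradient is close to zero. The sketch below packages gradient descent as a reusable function; the name gradient_descent, the tolerance tol, and the default values are our own assumptions.

```python
from typing import Callable, Tuple


def gradient_descent(
    grad_fn: Callable[[float], float],
    w0: float,
    step_size: float = 0.01,
    max_steps: int = 10_000,
    tol: float = 1e-8,
) -> Tuple[float, int]:
    """Minimize a 1-D function given its gradient.

    Stops early once |gradient| < tol, otherwise after max_steps updates.
    Returns the final w and the number of updates taken.
    """
    w = w0
    for step in range(max_steps):
        g = grad_fn(w)
        if abs(g) < tol:
            return w, step
        w = w - step_size * g
    return w, max_steps


# Gradient of f(w) = w^2 + 3w - 5, passed in as a plain function.
w_opt, n_taken = gradient_descent(lambda w: 2 * w + 3, w0=10.0)
print(w_opt, n_taken)
```

Passing the gradient in as an argument means the same routine works whether the gradient is written by hand or produced automatically.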

Minimizing f(w) with jax

jax is a Python package for automatically computing gradients: it provides what is known as an "automatic differentiation" system on top of the NumPy API. This way, we do not have to derive the gradient function by hand; instead, jax automatically takes the derivative of a Python function w.r.t. its first argument, applying the chain rule to calculate gradients. With jax, our example above changes only slightly:

from jax import grad

# This is what changes: we use jax's `grad` function to automatically return a gradient function.
# `f` is the objective defined earlier; `grad` differentiates it w.r.t. its first argument.
df = grad(f)

# Exercise: Pick a number to start w at.
w = -10.0

# Now, adjust the value of w 1000 times, taking small steps in the negative direction of the gradient.
for i in range(1000):
    w = w - df(w) * 0.01  # 0.01 is the size of the step taken.

print(w)
-1.5000029
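Because `grad` returns an ordinary Python function, it can be composed and combined with other jax transformations. The sketch below (our illustration, not workshop code) recovers the hand-derived f''(w) = 2 by applying `grad` twice, uses `argnums` to differentiate a hypothetical two-argument function g w.r.t. its second argument, and compiles the whole descent loop with `jax.jit` and `jax.lax.fori_loop` instead of calling `df` from a Python loop. The printed result differs slightly from -1.5 because jax computes in 32-bit floats by default.

```python
import jax
import jax.numpy as jnp
from jax import grad


def f(w):
    return w**2 + 3 * w - 5


# grad composes: differentiating twice recovers the hand-derived f''(w) = 2.
df = grad(f)
d2f = grad(grad(f))
print(df(2.5), d2f(2.5))


def g(w, b):
    """Hypothetical two-argument function, used only to illustrate argnums."""
    return w**2 * b


dg_dw = grad(g)             # default argnums=0: d/dw (w^2 b) = 2 w b
dg_db = grad(g, argnums=1)  # d/db (w^2 b) = w^2
print(dg_dw(3.0, 2.0), dg_db(3.0, 2.0))


# Compile the full 1000-step descent into a single XLA computation.
@jax.jit
def descend(w0):
    def step(i, w):
        # Same update rule as the Python loop above, with step size 0.01.
        return w - df(w) * 0.01

    return jax.lax.fori_loop(0, 1000, step, w0)


print(descend(jnp.float32(-10.0)))
```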

Summary

In this section, we saw one way to program a computer to automatically leverage gradients to find the optima of a polynomial function. This builds our knowledge and intuition for the next section, in which we find the optimal point of a linear regression loss function.