

%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

Gradient-Based Optimization

Implicit in what you were doing is something we formally call "gradient-based optimization". This is a very important idea to understand: if you understand it for a linear model, you will understand how it works for more complex models. Hence, we are going to take a short crash-course detour on what gradient-based optimization is.


At the risk of ticking off mathematicians for a sloppy definition, for this book's purposes, a useful way of defining the derivative is:

How much our output changes as we take a small step on the inputs, taken in the limit of going to very small steps.

If we have a function:

f(w) = w^2 + 3w - 5

What is the derivative of f(w) with respect to w? From first-year undergraduate calculus, we should be able to calculate this:

f'(w) = 2w + 3

As a matter of style, we will use apostrophe marks to indicate derivatives: one apostrophe denotes the first derivative, two apostrophes denote the second derivative, and so on.

Minimizing f(w) Analytically

What is the value of w that minimizes f(w)? Again, from undergraduate calculus, we know that at a minimum of a function (whether global or local), the first derivative will be equal to zero, i.e. f'(w) = 0. By taking advantage of this property, we can analytically solve for the value of w at the minimum.

2w + 3 = 0


w = -\frac{3}{2} = -1.5

To check whether the point where f'(w) = 0 is a minimum or a maximum, we can use another piece of knowledge from first-year undergraduate calculus: the sign of the second derivative at that point tells us which one it is.

  • If the second derivative is positive at the stationary point, then the point is a minimum. (Smiley faces are positive!)
  • If the second derivative is negative at the stationary point, then the point is a maximum. (Frowning faces are negative!)


f''(w) = 2

We can see that f''(w) > 0 for all w, hence the stationary point we found is a minimum. Because the second derivative is positive everywhere, f is convex, so this minimum is also the global minimum.
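As a cross-check that does not rely on derivatives, we can complete the square. This is a short worked derivation rather than part of the original argument, but it arrives at the same location for the minimum and also gives the value of f there:

```latex
\begin{aligned}
f(w) &= w^2 + 3w - 5 \\
     &= \left(w + \tfrac{3}{2}\right)^2 - \tfrac{9}{4} - 5 \\
     &= \left(w + \tfrac{3}{2}\right)^2 - \tfrac{29}{4}
\end{aligned}
```

The squared term is never negative and is zero only at w = -\frac{3}{2}, so the smallest value f can take is f(-\frac{3}{2}) = -\frac{29}{4} = -7.25.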

Minimizing f(w) Computationally

An alternative way of looking at this is to take advantage of f'(w), the gradient, evaluated at a particular w. The gradient points in the direction in which the function increases most steeply. If you repeatedly take sufficiently small steps in the negative direction of the gradient, you will move toward a minimum of the function. If you take small steps in the positive direction of the gradient, you will move toward a maximum (if it exists).

Exercise: Implement gradient functions by hand

Let's implement this for the function f(w), using NumPy.

Firstly, implement the aforementioned function f below.

# Exercise: Write f(w) as a function.

def f(w):
    """Your answer here."""
    return None

from dl_workshop.answers import f



This function is the objective function that we wish to optimize, where "optimization" means finding a minimum or maximum.

Now, implement the gradient function \frac{df}{dw} below in the function df:

# Exercise: Write df(w) as a function. 
def df(w):
    """Your answer here"""
    return None

from dl_workshop.answers import df

This function is the gradient of the objective w.r.t. the parameter of interest. It will help us find out the direction in which to change the parameter w in order to optimize the objective function.
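One way to sanity-check a hand-written gradient is to compare it against a finite-difference approximation, which follows directly from the "small step on the inputs" definition of the derivative given earlier. The sketch below is self-contained, so it redefines f and df with one possible answer to the exercises above; the helper name `finite_difference` is made up for illustration and is not part of the workshop package.

```python
def f(w):
    """Objective from this section: f(w) = w^2 + 3w - 5."""
    return w**2 + 3 * w - 5


def df(w):
    """Hand-derived gradient: f'(w) = 2w + 3."""
    return 2 * w + 3


def finite_difference(func, w, h):
    """Forward-difference estimate of the slope of func at w with step h.

    Hypothetical helper for illustration only.
    """
    return (func(w + h) - func(w)) / h


w_check = 2.0
for h in [1.0, 0.1, 0.001, 1e-6]:
    approx = finite_difference(f, w_check, h)
    # The gap to the analytic gradient shrinks as the step h shrinks.
    print(f"h={h:g}  finite difference={approx:.6f}  analytic={df(w_check):.6f}")
```

As h gets smaller, the finite-difference estimate approaches the value returned by df. If the two disagree for small h, the hand-written gradient is the first place to look for a mistake.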

Now, pick a number at random to start with. You can specify a number explicitly, or use a random number generator to draw a number.

# Exercise: Pick a number to start w at.
w = 10.0  # start with a float

This gives us a starting point for optimization.

Finally, write an "optimization loop", in which you adjust the value of w in the negative direction of the gradient of f w.r.t. w (i.e. \frac{df}{dw}).

# Now, adjust the value of w 1000 times, taking small steps in the negative direction of the gradient.
for i in range(1000):
    w = w - df(w) * 0.01  # 0.01 is the size of the step taken.


Congratulations, you have just implemented gradient descent!

Gradient descent is an optimization routine: a way of programming a computer to do optimization for you so that you don't have to do it by hand.
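The step size matters. For this particular f, each update multiplies the distance between w and the minimum at -1.5 by (1 - 2 × step size), so a step size of 0.01 shrinks that distance by a factor of 0.98 per step, while a step size above 1 makes it grow. The sketch below wraps the loop in a function to show both cases; `gradient_descent` and its parameter names are illustrative choices, not part of the workshop package, and df is redefined here so the block runs on its own.

```python
def df(w):
    """Hand-derived gradient of f(w) = w^2 + 3w - 5."""
    return 2 * w + 3


def gradient_descent(grad_fn, w_start, step_size, n_steps):
    """Repeatedly step against the gradient and return the final w.

    Hypothetical helper for illustration only.
    """
    w = w_start
    for _ in range(n_steps):
        w = w - grad_fn(w) * step_size
    return w


# Small steps: w settles very close to the analytic minimum at -1.5.
w_small = gradient_descent(df, w_start=10.0, step_size=0.01, n_steps=1000)

# Steps that are too large: each update overshoots, and w moves away from -1.5.
w_large = gradient_descent(df, w_start=10.0, step_size=1.1, n_steps=50)

print("small step size:", w_small)
print("large step size:", w_large)
```

The first result agrees with the analytic solution found earlier, while the second illustrates why the steps need to be small.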

Minimizing f(w) with jax

jax is a Python package for automatically computing gradients; it provides what is known as an "automatic differentiation" system on top of the NumPy API. This way, we do not have to derive the gradient function by hand; rather, jax knows how to automatically take the derivative of a Python function w.r.t. its first argument, leveraging the chain rule to calculate gradients. With jax, our example above changes only slightly:

from jax import grad
import jax
from tqdm.autonotebook import tqdm

# This is what changes: we use jax's `grad` function to automatically return a gradient function.
df = grad(f)

# Exercise: Pick a number to start w at.
w = -10.0

# Now, adjust the value of w 1000 times, taking small steps in the negative direction of the gradient.
for i in range(1000):
    w = w - df(w) * 0.01  # 0.01 is the size of the step taken.
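To build some trust in the automatically derived gradient, we can compare it against the hand-derived expressions from earlier in this section. The sketch below assumes jax is installed and redefines f with one possible answer to the first exercise so that it runs on its own; applying `grad` twice is used here to obtain the second derivative.

```python
from jax import grad


def f(w):
    """Objective from this section: f(w) = w^2 + 3w - 5."""
    return w**2 + 3 * w - 5


df_auto = grad(f)         # first derivative, derived automatically
d2f_auto = grad(df_auto)  # second derivative: grad applied to the gradient

w_check = 2.0  # jax's grad expects a floating-point input
first = float(df_auto(w_check))
second = float(d2f_auto(w_check))

# Compare against the hand-derived f'(w) = 2w + 3 and f''(w) = 2.
print("autodiff first derivative:", first, " by hand:", 2 * w_check + 3)
print("autodiff second derivative:", second, " by hand:", 2.0)
```

Both values should match the hand calculations up to floating-point precision.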



In this section, we saw one way to program a computer to automatically leverage gradients to find the optimum of a polynomial function. This builds our knowledge and intuition for the next section, in which we find the optimal point of a linear regression loss function.