Skip to content

Binder

How useful are infinitely wide neural networks?

In this notebook, I would like to explore the use of the Neural Tangents package for doing cheminformatics tasks.

For the non-cheminformatician reading this, cheminformatics tasks should be treated as a surrogate for complex modelling problems that do not have an obvious input-output relationship structure that can be easily written in an equation, hence necessitating the use of a neural network.

The goal here is not to see how we can achieve state-of-the-art (SOTA) performance, but instead empirically explore how infinitely wide neural networks could possibly be leveraged to improve a data scientist's workflow and solve complex problems. As such, you won't see us try to tweak hyperparameters to squeeze out maximum performance. Rather, you'll see what we hope to be is a systematic exploration of how neural network architectures affect their infinite-width limit performance.

Motivation for this notebook

Given the popularity of neural networks and their productive use in applications, it sort of begs the question, why would one want to be interested in infinitely-wide neural networks as an applied topic?

Practical implications for predictive modelling

Putting aside the really fascinating theoretical connections between Gaussian Processes and infinitely wide neural networks, I saw some potentially attractive practical implications of infinitely wide neural networks. In particular, having read the original paper, and studying the code, I had this hunch that infinitely wide neural networks could save us deep learners from a wide swathe of Sisyphean hyperparameter tuning tasks, such as tuning learning rate, number of hidden nodes, number of epochs, and more.

Curiousity surrounding kernels

Having read the paper, I had this hunch that infinitely wide neural networks should give us different Gaussian Process kernels depending on the neural network architecture. As we know from GP literature, the kernel, which is a function of the distance between data points in input space, dictates the behaviour of a GP. Kernels are compositionally additive, and can thus be used to simultaneously model short-range and long-range behaviour, at least in 1-D examples that I have seen.

Try MoleculeNet's ESOL

The dataset that we will focus on here is MoleculeNet's solubility dataset.

%load_ext autoreload
%autoreload 2
%load_ext watermark
import pandas as pd 
import janitor
import janitor.chemistry

from jax.config import config
import seaborn as sns

sns.set_context("talk")

config.update("jax_enable_x64", True)

The dataset is already well-cleaned; the column that is the input to our models will be the "smiles" column (converted into a 2048-bit Morgan-2 fingerprint), while the target column is the "measured log solubility in mols per litre" column.

# Reset index to guarantee contiguous indexing.
# Will be useful later.
data = pd.read_csv("https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv").smiles2mol("smiles", "mol").reset_index(drop=True)
data.head()
Compound ID ESOL predicted log solubility in mols per litre Minimum Degree Molecular Weight Number of H-Bond Donors Number of Rings Number of Rotatable Bonds Polar Surface Area measured log solubility in mols per litre smiles mol
0 Amigdalin -0.974 1 457.432 7 3 7 202.32 -0.77 OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)... <rdkit.Chem.rdchem.Mol object at 0x7f2f771e76a0>
1 Fenfuram -2.885 1 201.225 1 2 2 42.24 -3.30 Cc1occc1C(=O)Nc2ccccc2 <rdkit.Chem.rdchem.Mol object at 0x7f2f771e7640>
2 citral -2.579 1 152.237 0 0 4 17.07 -2.06 CC(C)=CCCC(C)=CC(=O) <rdkit.Chem.rdchem.Mol object at 0x7f2f771e75e0>
3 Picene -6.618 2 278.354 0 5 0 0.00 -7.87 c1ccc2c(c1)ccc3c2ccc4c5ccccc5ccc43 <rdkit.Chem.rdchem.Mol object at 0x7f2f771e7580>
4 Thiophene -2.232 2 84.143 0 1 0 0.00 -1.33 c1ccsc1 <rdkit.Chem.rdchem.Mol object at 0x7f2f771e7520>
import jax.numpy as np
morgans = data.morgan_fingerprint("mol", radius=3, nbits=2048)
morgans = np.array(morgans)
morgans
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

DeviceArray([[0., 1., 0., ..., 0., 0., 0.],
             [0., 0., 0., ..., 0., 0., 0.],
             [0., 0., 0., ..., 0., 0., 0.],
             ...,
             [0., 0., 0., ..., 0., 0., 0.],
             [0., 1., 0., ..., 0., 0., 0.],
             [0., 0., 0., ..., 0., 0., 0.]], dtype=float64)
task = "SOLUBILITY"
target_col_mapping = {
    "LIPO": "exp",
    "SOLUBILITY": "measured log solubility in mols per litre",
}
targets = data[target_col_mapping[task]]
targets = np.array(targets).reshape(-1, 1)
targets.shape
(1128, 1)

Writing infinitely wide neural networks

Infinitely wide neural networks are written using the neural tangents library developed by Google Research. It is based on JAX, and provides a neural network library that lets us analytically obtain the infinite-width kernel corresponding to the particular neural network architecture specified. This is known as a Neural Network Gaussian Process (NNGP) kernel.

I will primarily be concerned with the NNGP kernel rather than the Neural Tangent Kernel (NTK). Mostly for simplicity's sake; being the lazy person I am, I'll try avoiding adding an orthogonal and combinatorial axis of complexity wherever possible.

Here is an example infinitely-wide neural network with two hidden layers, each with 4096 hidden nodes.

from jax import jit 
import neural_tangents as nt
from neural_tangents import stax

N_HIDDEN = 4096

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(N_HIDDEN, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(N_HIDDEN, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(1),
)

apply_fn = jit(apply_fn)
kernel_fn = jit(kernel_fn, static_argnums=(2,))

Initial model params

For benchmarking purposes, we want to obtain the parameters at initialization in order to compare how the NNGP kernel conditioned on observed data compares against the initialized, unfitted model.

from jax import random as rnd
key = rnd.PRNGKey(71)
input_shape, params = init_fn(key, input_shape=(2048,))

Preprocessing of data

We are going to use a fixed split, again, for simplicity purposes. To take advantage of JAX's deterministic random number generation capabilities, I have written my own train/test split function.

We will bring all of our inputs into a standardized Gaussian space, mostly to ensure that our inputs are nicely scaled to roughly the same magnitudes. (Something something good numerical behaviour.)

from sklearn.preprocessing import StandardScaler


def train_test_split(key, x, y, train_fraction=0.2):
    indices = rnd.shuffle(key, np.arange(len(x)))
    n_train = int(np.floor(len(x) * train_fraction))
    return indices, n_train


train_fraction = 0.8
indices, n_train = train_test_split(key, morgans, targets, train_fraction)
x_train = morgans[indices][:n_train]
y_train = targets[indices][:n_train]
train_samples = data.loc[indices][:n_train]

x_test = morgans[indices][n_train:]
y_test = targets[indices][n_train:]
test_samples = data.loc[indices][n_train:]

ss = StandardScaler()
x_train = ss.fit_transform(x_train)
x_test = ss.transform(x_test)
/usr/share/miniconda/envs/essays-on-data-science/lib/python3.9/site-packages/jax/_src/random.py:478: FutureWarning: jax.random.shuffle is deprecated and will be removed in a future release. Use jax.random.permutation
  warnings.warn(msg, FutureWarning)

Obtaining the NNGP and NTK kernels

Following the notebook examples released by the neural tangents team, we now get the NNGP/NTK mean and covariance.

But wait, didn't I say I wasn't going to care about NTK kernels? Yeah, I guess I waffled a bit here. I was copying and pasting code at this point.

# We generate the prediction function
# which is conditioned on the training data.
predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn,
    x_train,
    y_train,
    diag_reg=1e-4,
    learning_rate=0.1
)

nngp_mean, nngp_covariance = predict_fn(
    x_test=x_test, 
    get='nngp', 
    compute_cov=True
)

nngp_std = np.sqrt(np.diag(nngp_covariance))
ntk_mean, ntk_covariance = predict_fn(
    x_test=x_test, 
    get='ntk', 
    compute_cov=True
)

ntk_std = np.sqrt(np.diag(ntk_covariance))

Compare against standard models

You can't have a benchmarking exercise without having some models to compare against. We'll compare the NNGP/NTK means against predictions from a fitted Random Forest Regressor and a Gaussian Process Regressor, taken from our favourite standard machine learning library scikit-learn.

from sklearn.ensemble import RandomForestRegressor

rfr = RandomForestRegressor(n_estimators=200, n_jobs=-1)
rfr.fit(x_train, y_train)
/tmp/ipykernel_2917/3912822998.py:4: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().
  rfr.fit(x_train, y_train)

RandomForestRegressor(n_estimators=200, n_jobs=-1)
from sklearn.gaussian_process import GaussianProcessRegressor

gpr = GaussianProcessRegressor()
gpr.fit(x_train, y_train)
GaussianProcessRegressor()

Drumroll please...

Now for our first benchmarking exericse: how do NNGPs and NTKs perform against standard models, and against the initialized neural network itself?

We'll be using the explained variance score to measure performance. Why not others? Well, again, I'm a bit lazy, but there's another actually good reason here. Firstly, explained variance is a measure of "goodness of fit" that is scale-free. It is defined as:

1 - \frac{var(residuals)}{var(target)}

I wrote a blog post before on why I think the variance explained metric is a pretty neat one, even if it's not the most widely understood. The tl;dr of why is that it's a scale-free metric that is easily interpretable.

import matplotlib.pyplot as plt
import seaborn as sns 

from sklearn.metrics import r2_score as r2 , explained_variance_score as evs

FIG_ALPHA = 0.1
SCOREFUNC = evs

fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(20, 4), sharex=True, sharey=True)
axes = axes.flatten()
y_preds = apply_fn(params, x_test)
axes[0].scatter(y_test, y_preds, label="untrained", alpha=FIG_ALPHA)
score = SCOREFUNC(y_preds.squeeze(), y_test.squeeze())
axes[0].set_title(f"Untrained: {score:.2f}")
axes[0].set_xlabel("Actual")
axes[0].set_ylabel("Predicted")

axes[1].errorbar(y_test, ntk_mean, yerr=ntk_std * 3, marker="o", ls="", alpha=FIG_ALPHA, label="ntk")
score = SCOREFUNC(ntk_mean.squeeze(), y_test.squeeze())
axes[1].set_title(f"NTK: {score:.2f}")
axes[1].set_xlabel("Actual")


axes[2].errorbar(y_test, nngp_mean, yerr=nngp_std * 3, marker="o", ls="", alpha=FIG_ALPHA, label="nngp")
score = SCOREFUNC(nngp_mean.squeeze(), y_test.squeeze())
axes[2].set_title(f"NNGP: {score:.2f}")
axes[2].set_xlabel("Actual")

y_preds = rfr.predict(x_test)
axes[3].scatter(y_test, y_preds, alpha=FIG_ALPHA, label="random forest")
score = SCOREFUNC(y_preds.squeeze(), y_test.squeeze())
axes[3].set_title(f"sklearn RF: {score:.2f}")
axes[3].set_xlabel("Actual")

y_preds = gpr.predict(x_test)
axes[4].scatter(y_test, y_preds, alpha=FIG_ALPHA, label="random forest")
score = SCOREFUNC(y_preds.squeeze(), y_test.squeeze())
axes[4].set_title(f"sklearn GP: {score:.2f}")
axes[4].set_xlabel("Actual")


def get_x_eq_y(axes):
    minvals = []
    maxvals = []
    for ax in axes.flatten():
        xlims, ylims = ax.get_xlim(), ax.get_ylim()
        minval = min(min(xlims), min(ylims))
        maxval = max(max(xlims), max(ylims))
        minvals.append(minval)
        maxvals.append(maxval)
    return min(minvals), max(maxvals)


minval, maxval = get_x_eq_y(axes)
for ax in axes.flatten():
    ax.plot([minval, maxval], [minval, maxval], color="black", ls="--")
    ax.set_xticks(ax.get_yticks())
    sns.despine()
plt.tight_layout()

In terms of raw performance, the scikit-learn random forest model performs the best, while the scikit-learn Gaussian Process model performs the worst. This is all done without any hyperparameter tweaking though. Again, laziness is the primary reason why I didn't go deeper into optimizing these baseline models, though, you know, using something like pycaret to satisfy my laziness would be a totally valid way to go.

Now, at this point, we might be tempted to say that the NNGP model doesn't compare well vs. the Random Forest model. However, we haven't done any kind of exploration of the parameter space for the neural network models, so such a conclusion would be premature.

Optimizing the infinitely wide neural network

There's still some pieces of the infinitely wide neural network that I think we need to get a practical handle over. In particular, how does the neural network performance, and the uncertainties, change as a function of:

  • The number of nodes in a hidden layer?
  • The depth of the neural network?
  • The activation function that we use?

That's what we're going to explore in this section of the notebook.

It's always good to go into an inquiry with some expected range of answers. Here's mine, stated up front:

Number of nodes: The number of nodes shouldn't matter. We're talking about expanding each layer to infinite width, after all.

Depth: Depth might affect the structure of the corresponding NNGP kernel. However, I am not well-versed enough in the math to know exactly how.

Activation function: I suspect the activation function will also affect the structure of the NNGP kernel. A hunch tells me it may have the weakest effect compared to depth, but I can't back that up with any theory. It's just a hunch.

Let me first set up some utility functions below to help out with coding later.

# Some utility functions to help out

import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax
from jax import jit 


class NNGP:
    """sklearn-compatible NNGP class."""
    def __init__(self, n_layers:int = 1, nonlinearity=stax.Relu()):
        # Storing for information only
        self.n_layers = n_layers
        self.nonlinearity = nonlinearity

        init_fn, apply_fn, kernel_fn = make_inf_nn(n_layers, 300, nonlinearity=nonlinearity)
        self.init_fn = init_fn
        self.apply_fn = apply_fn
        self.kernel_fn = kernel_fn
        self.predict_fn = None

    def __repr__(self):
        return f"NNGP()"

    def fit(self, X, y):
        if len(y.shape) == 1:
            y = y.reshape(-1, 1)
        self.predict_fn = nt.predict.gradient_descent_mse_ensemble(
            self.kernel_fn, X, y, diag_reg=1e-4
        )

    def predict(self, X):
        mean = self.predict_fn(x_test=X, get="nngp", compute_cov=False)
        return mean


def compute_mean_std(kernel_fn, x_train, y_train, x_test):
    """Function to compute mean and std function."""
    predict_fn = nt.predict.gradient_descent_mse_ensemble(
        kernel_fn, x_train, y_train, diag_reg=1e-4
    )

    nngp_mean, nngp_covariance = predict_fn(x_test=x_test, get="nngp", compute_cov=True)

    nngp_std = np.sqrt(np.diag(nngp_covariance))
    return nngp_mean, nngp_std


def plot_results(SCOREFUNC, y_test, mean, std, ax, n_layers, n_hidden, FIG_ALPHA):
    """Helper plotting function."""
    ax.errorbar(y_test, mean, yerr=std * 3, marker="o", ls="", alpha=FIG_ALPHA)
    score = SCOREFUNC(mean.squeeze(), y_test.squeeze())
    ax.set_title(f"Score: {score:.2f}, {n_layers} layers,\n{n_hidden} hidden nodes")


def make_inf_nn(n_layers: int, n_hidden: int, nonlinearity=stax.Erf()):
    """Convenience function to make infinitely wide NN of different architectures."""
    layers = []
    for _ in range(n_layers):
        layers.append(stax.Dense(n_hidden))
        if nonlinearity:
            layers.append(nonlinearity)
    init_fn, apply_fn, kernel_fn = stax.serial(
        *layers,
        stax.Dense(1),
    )
    apply_fn = jit(apply_fn)
    kernel_fn = jit(kernel_fn, static_argnums=(2,))
    return init_fn, apply_fn, kernel_fn

Experiment setup

For number of layers, we're going to explore 1, 10 and 20 layers.

For number of nodes per hidden layer, we'll explore what happens with 1, 512, and 4096 nodes. (I mean, it's gonna go to infinity, so why would I try thousands of hidden nodes, right?)

For non-linearities, we'll do:

  • No non-linearities (i.e. simple linear projections)
  • Erf non-linearity
  • Relu non-linearity
  • Gelu non-linearity

The three non-linearities are available in neural tangents, so we're using them for convenience. At 4x3x3, we're already at 36 combinations, which I think is more than enough to see any patterns.

Linear projections (a.k.a. no non-linearities)

Let's start with linear projections and see what happens.

from itertools import product 

NUM_LAYERS = [1, 10, 20]
NUM_HIDDEN = [1, 512, 4096]

fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(12, 12), sharex=True, sharey=True)

for i, (n_layers, n_hidden) in enumerate(product(NUM_LAYERS, NUM_HIDDEN)):
    init_fn, apply_fn, kernel_fn = make_inf_nn(n_layers, n_hidden, nonlinearity=None)
    mean, std = compute_mean_std(kernel_fn, x_train, y_train, x_test)

    plot_results(SCOREFUNC, y_test, mean, std, axes.flatten()[i], n_layers, n_hidden, FIG_ALPHA)

minval, maxval = get_x_eq_y(axes)
for ax in axes.flatten():
    ax.plot([minval, maxval], [minval, maxval], color="black", ls="--")
    sns.despine()

for ax in axes[:, 0]:
    ax.set_ylabel("Predicted")

for ax in axes[-1, :]:
    ax.set_xlabel("Actual")

plt.tight_layout()

Apparently, with no pure linear projections, we:

  1. Do much better with the performance as measured by posterior expectation values.
  2. See that layer depth and number of hidden nodes have zero effect on performance and uncertainty widths.

Erf nonlinearity

Now let's see what happens when we use the Erf nonlinearity, which was the default activation function that the neural tangents demo notebooks use.

fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(12, 12), sharex=True, sharey=True)

for i, (n_layers, n_hidden) in enumerate(product(NUM_LAYERS, NUM_HIDDEN)):
    init_fn, apply_fn, kernel_fn = make_inf_nn(n_layers, n_hidden)
    mean, std = compute_mean_std(kernel_fn, x_train, y_train, x_test)

    plot_results(SCOREFUNC, y_test, mean, std, axes.flatten()[i], n_layers, n_hidden, FIG_ALPHA)

minval, maxval = get_x_eq_y(axes)
for ax in axes.flatten():
    ax.plot([minval, maxval], [minval, maxval], color="black", ls="--")
    sns.despine()

for ax in axes[:, 0]:
    ax.set_ylabel("Predicted")

for ax in axes[-1, :]:
    ax.set_xlabel("Actual")

plt.tight_layout()

We see here that the Erf non-linearity gives us poorer performance when compared to having no non-linearities as the activation function.

Additionally, there is a trend w.r.t. network depth. The deeper the network, the poorer the performance of the corresponding NNGP kernel. I suspect there may be some kind of overfitting that happens. How the overfitting happens, though, remains something of inquiry for me to learn. (My current hypothesis hinges on the hunch that deeper networks correspond to stacked GPs, though I would need further confirmation from someone better at math than I am to see if that is correct.)

An intriguing thing I see as well is that as we add depth to the network, the resulting NNGP kernel yields smaller uncertainties. Anthropomorphically, we would say that the neural network is more confident of its predictions even as it becomes generally more wrong. We should call these networks Dunning-Kruger networks.

Try with Relu non-linearity

Let's now try a popular non-linearity, the Relu non-linearity.

fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(12, 12), sharex=True, sharey=True)

for i, (n_layers, n_hidden) in enumerate(product(NUM_LAYERS, NUM_HIDDEN)):
    init_fn, apply_fn, kernel_fn = make_inf_nn(n_layers, n_hidden, nonlinearity=stax.Relu())
    mean, std = compute_mean_std(kernel_fn, x_train, y_train, x_test)

    plot_results(SCOREFUNC, y_test, mean, std, axes.flatten()[i], n_layers, n_hidden, FIG_ALPHA)

minval, maxval = get_x_eq_y(axes)
for ax in axes.flatten():
    ax.plot([minval, maxval], [minval, maxval], color="black", ls="--")
    sns.despine()

for ax in axes[:, 0]:
    ax.set_ylabel("Predicted")

for ax in axes[-1, :]:
    ax.set_xlabel("Actual")

plt.tight_layout()

A really cool result shows up here: we have better performance using the Relu activation function! The same trend w.r.t. number of hidden layers also holds: as we increase the number of hidden layers, the performance decreases, and the uncertainty decreases. The network becomes a Dunning-Kruger network.

Try with Gelu nonlinearity

fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(12, 12), sharex=True, sharey=True)

for i, (n_layers, n_hidden) in enumerate(product(NUM_LAYERS, NUM_HIDDEN)):
    init_fn, apply_fn, kernel_fn = make_inf_nn(n_layers, n_hidden, nonlinearity=stax.Gelu())
    mean, std = compute_mean_std(kernel_fn, x_train, y_train, x_test)

    plot_results(SCOREFUNC, y_test, mean, std, axes.flatten()[i], n_layers, n_hidden, FIG_ALPHA)

minval, maxval = get_x_eq_y(axes)
for ax in axes.flatten():
    ax.plot([minval, maxval], [minval, maxval], color="black", ls="--")
    sns.despine()

for ax in axes[:, 0]:
    ax.set_ylabel("Predicted")

for ax in axes[-1, :]:
    ax.set_xlabel("Actual")

plt.tight_layout()

The Gelu network is a highly interesting one! As the depth increases, the correctness of the network, i.e. performance, doesn't go down, but the uncertainties do become smaller. Anthropomorphically, the network remains just as wrong about predictions, and becomes more confident of them. As usual, the number of hidden nodes affect neither the model's expected performance nor the uncertainties.

Depth

The observable trend here is that going with deeper nets that have non-identity activation functions leads to overfitting on the training set; shallower nets have the best performance. This holds in general, with the Gelu network being the exception of the activation functions tested here.

Activation functions

Of the three tested inside the nt.stax library of nonlinearities, the Gelu activation function provided the best performance based on the posterior expectation, with the Relu activation function a close second.

The effect of activation functions on model performance was dramatic, and this was something out of expectation based on my stated prior beliefs.

What the effect of activation functions are on the width of the posterior uncertainties is something we will investigate next.

Number of hidden nodes

From the get-go, I expected that the number of hidden nodes declared in the network constructor should not play a role with NNGP performance, as after all, we're taking the number of hidden nodes to the infinite limit. As the results show, whether we use 1, 2, or 3 hidden nodes in the hidden layer, the infinite-width performance is identical.

Comparison between non-linearities' effects on uncertainties

Having seen that neural networks begin to exhibit the Dunning-Kruger effect as they get more powerful (at least in the limit of infinite width), we will limit further probing of our model to just the single layer depth networks with a single hidden node. I'm playing with fire here, just to prove the point. Take that, overparameterized networks, we're going infinite now.

To investigate how the nonlinearities affect uncertainty, we will plot the distribution of standard deviations of the test set predictions. It will be an ECDF plot, so that we can cram everything into one plot and compare directly. (Histograms are so 1950s!)

from neural_tangents import stax
nonlinearities = {
    "none": None,
    "ReLu": stax.Relu(),
    "GeLu": stax.Gelu(),
    "Erf": stax.Erf(),
}

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 4))


def ecdf(data):
    x, y = np.sort(data), np.arange(1, len(data)+1) / len(data)
    return x, y


for i, (name, nonlinearity) in enumerate(nonlinearities.items()):
    init_fn, apply_fn, kernel_fn = make_inf_nn(n_layers=1, n_hidden=1, nonlinearity=nonlinearity)
    mean, std = compute_mean_std(kernel_fn, x_train, y_train, x_test)
    x, y = ecdf(std)
    ax.plot(x, y, label=name)
ax.set_xlabel("Std. Dev")
ax.set_ylabel("Cumulative Fraction")
ax.set_title("Distribution\nof uncertainties")
ax.legend()
sns.despine()

What do we see here? Well, to begin with, infinitely wide neural networks with no nonlinearities end up being the most uncertain about their predictions. The ReLu and GeLu network end up showing practically identical distributions of uncertainties in test set predictions. Meanwhile, the Erf network, which was the worst of all of the networks tested, showed a tighter distribution of uncertainties, though roughly at the same scale as the other networks.

How do we measure the "performance" of the uncertainties?

Well, one way to do so is to ask what fraction of the uncertainties' range capture the correct value. We have the standard deviation of predictions from our GP; 3x the standard deviation gives us about 96% of the posterior probability. If the mean ± 3x standard deviation encapsulates the true value, then we know our posterior standard deviation was highly likely to capture the true value, even if the expected value was off.

def evaluate_uncertainty(mean: float, std: float, y_test: float):
    lt = y_test < mean + std * 3
    gt = y_test > mean - std * 3
    return np.logical_and(lt, gt)
from jax import vmap

results = dict()
for i, (name, nonlinearity) in enumerate(nonlinearities.items()):
    init_fn, apply_fn, kernel_fn = make_inf_nn(n_layers=1, n_hidden=1, nonlinearity=nonlinearity)
    mean, std = compute_mean_std(kernel_fn, x_train, y_train, x_test)
    result = vmap(evaluate_uncertainty)(mean, std, y_test)
    results[name] = result.squeeze()
import pandas as pd

pd.DataFrame(results).mean().plot(kind="bar", rot=45)
ax = plt.gca()
ax.set_xlabel("Activation")
ax.set_ylabel("P(correct)")
ax.set_ylim(0, 1)
sns.despine()

As we can see above, the ReLu function has the best performing uncertainties, with 70% of the posteriors harbouring the correct value. One might say that the ReLu infinite networks are best calibrated for prediction purposes.

You'd think that the linear projection networks would have well-calibrated uncertainties given their widths, and that's certainly something I thought would happen, but that turned out not to be the case.

Using the uncertainties for decision-making

In a chemistry setting, we're usually concerned with prioritization tasks. A chemist's time for testing molecules is usually finite, so we need some help rank-ordering items in a principled fashion. Well, turns out when you have Gaussian uncertainties, you can do this pretty easily.

If you know anything about Bayesian optimization, we need to have an acquisition function that helps us rank-order our predictions. The acquisition function that gets chosen is going to depend on the task at hand. Because we're dealing with solubility as a task, and in medicinal chemistry high solubility is a desirable property, we can rank order our test samples according to the probability that they are above some threshold value. This is effectively evaluating the survival function, or 1 - CDF(threshold), of a Gaussian.

test_indices = indices[n_train:]
from jax.scipy.stats import norm
def prob_above_threshold(mean: float, std: float, threshold: float):
    return 1 - norm.cdf(x=threshold, loc=mean, scale=std)
from functools import partial
prob_thresh = dict()
for i, (name, nonlinearity) in enumerate(nonlinearities.items()):
    init_fn, apply_fn, kernel_fn = make_inf_nn(n_layers=1, n_hidden=1, nonlinearity=nonlinearity)
    mean, std = compute_mean_std(kernel_fn, x_train, y_train, x_test)
    result = vmap(partial(prob_above_threshold, threshold=-0.1))(mean, std)
    prob_thresh[name] = result.squeeze()

Let's now query for molecules in the test set that are predicted to have greater than 50% probability of having solubility greater than 0. This would be akin to helping a chemist prioritize the selection of molecules where we'd be happy to bet at 1:1 odds of having a desirable property, given the model's predictions.

Remember that the uncertainties dataframe is indexed differently, so we'll need to do some fancy indexing to get back to the indexing that is used in the original dataframe.

prob_thresh_df = pd.DataFrame(prob_thresh).sort_values("ReLu", ascending=False).query("ReLu > 0.5")
prob_thresh_df
none ReLu GeLu Erf
133 1.000000 1.000000 1.000000 1.000000
158 1.000000 0.991346 0.999885 0.999994
78 1.000000 0.980182 0.920877 1.000000
121 0.999974 0.538650 0.629214 1.000000
39 1.000000 0.536250 0.007561 0.206573
test_samples.loc[test_indices[prob_thresh_df.index.values]]
Compound ID ESOL predicted log solubility in mols per litre Minimum Degree Molecular Weight Number of H-Bond Donors Number of Rings Number of Rotatable Bonds Polar Surface Area measured log solubility in mols per litre smiles mol
679 Glycerol 0.688 1 92.094 3 0 2 60.69 1.120 OCC(O)CO <rdkit.Chem.rdchem.Mol object at 0x7f2f76af3400>
785 Urea 0.832 1 60.056 2 0 0 69.11 0.960 NC(=O)N <rdkit.Chem.rdchem.Mol object at 0x7f2f76af5c40>
298 Raffinose 0.496 1 504.438 11 3 8 268.68 -0.410 OCC1OC(CO)(OC2OC(COC3OC(CO)C(O)C(O)C3O)C(O)C(O... <rdkit.Chem.rdchem.Mol object at 0x7f2f76b6a2e0>
145 2-hydroxypteridine -1.404 1 148.125 1 2 0 71.79 -1.947 Oc2ncc1nccnc1n2 <rdkit.Chem.rdchem.Mol object at 0x7f2f76b66880>
983 Ethanol 0.020 1 46.069 1 0 0 20.23 1.100 CCO <rdkit.Chem.rdchem.Mol object at 0x7f2f76afa7c0>

Hey, not bad! From what I know about chemistry, glycerol is highly soluble in water, and is used by microbiologists to protect frozen bacteria cells from being cracked open upon freeze/thawing cycles. Urea is definitely a molecule that is soluble in water; we make it all the time in our bodies as a waste product, and is carried out of our body in our urine. Raffinose is a sugar, and so with all of the oxygens and hydrogens in there, it'll likely be dissolvable in water. 2-hydroxy-something always sounds like it'll be soluble in water, and Ethanol... well, we know how that one goes :).

Conclusions

Overall, this deep dive into infinite-width neural networks tells me that we deep learners have a superb tool at our disposal for modelling complex relationships that don't have an obvious input-output structure. We can write a single layer neural network, expand it out to its infinite width form easily to get a Neural Net Gaussian Process kernel that, when conditioned on training data, could give us good test performance. (No guarantees though, we'd have to benchmark.) For expensive data collection problems like those that we find in the life sciences, this is a superb modelling tool on many levels.

Firstly, Gaussian processes are highly flexible and inherently Bayesian modelling tools, but proper kernel design is difficult; neural net design is an easier alternative language to work with to generate GPs for our modelling problems. Infinitely wide neural networks give us the best of both worlds.

Secondly, because we get a highly flexible function learner, complex functions are easily modelled without needing to do much work with hyperparameter tuning. Down with the Sisyphean tasks of twiddling with neural networks! (OK, maybe I'm taking the rant a bit too far.) Single layer infinitely wide neural networks did a practically great job on solubility with standard molecular descriptors, were as easy to wield as the cheminformatics-favourite Random Forest model, and had performance metrics effectively on par with Random Forests. We may be able to squeeze out more performance by adopting alternative model architectures that have different inductive biases, such as graph neural networks, but I have to admit, it was pretty easy to get to the performance levels we observed here.

Thirdly, because NNGPs are inherently Bayesian, the suite of tools from Bayesian decision science are at our disposal. Prioritizing, ranking, selecting... once you have the posterior likelihood density at hand, you have the principle by which these activities can be done. No more premature thresholds; delay all binarization until the end, when we have the posterior likelihood calculated. No fancy methods needed either; "likelihood is all you need".

Lessons Learned

From this exercise, it seems evident to me that the nonlinearity is the biggest thing that affects the kernel, which in turn affects the expected performance (as given by the posterior means) and uncertainty calibration (as measured by the posterior standard deviations). If you remember where my headspace was when we first started, I was completely wrong w.r.t. this point; the NNGP kernel is highly influenced by the choice of activation function. Where I was correct was that the number of hidden nodes per layer didn't matter one bit for the performance of the network, because we are, after all, going to infinity ~~and beyond~~.

Knowing that we can get pretty decent expected predictive performance and somewhat nicely calibrated uncertainties without needing to fiddle with layer width, number of layers, number of training epochs, learning rate, and the myriad of other hyperparameters that we might be interested in... is actually kind of a nice result. With the NNGP kernel, we basically eliminate a ton of hyperparameters in training neural nets, and can focus primarily on the non-linear activation function to design the best NNGP kernel. Hyperparameter tuning is the worst Sisyphean task that will bore the crap out of any data scientist.

Hence, from the perspective of workflow, using infinitely-wide neural networks is an attractive proposition for modelling problems that do not have an obvious input-output relationship. At the end of the day, by using infinitely wide neural networks, fitting a Gaussian Process with a particular kernel to the problem at hand, and this cuts down on the number of hyperparameters that we have to fiddle with. Pragmatically, I have a hunch that we can mostly stick with single layer NNGPs and leverage the uncertainty instead.

Finally, I have to comment on the lack of compute intensity needed to wield infinitely wide neural networks. Everything I did in this notebook was first prototyped on an M1 MacBook Air, done in emulation mode. No GPUs were used. None! The entire notebook runs from top to bottom in under two minutes. Take that, Transformers.

%watermark
Last updated: 2021-08-20T00:56:39.237345+00:00

Python implementation: CPython
Python version       : 3.9.6
IPython version      : 7.26.0

Compiler    : GCC 9.3.0
OS          : Linux
Release     : 5.4.0-1055-azure
Machine     : x86_64
Processor   : x86_64
CPU cores   : 2
Architecture: 64bit


%watermark --iversions
jax            : 0.2.9
janitor        : 0.21.0
pandas         : 1.3.2
matplotlib     : 3.4.3
neural_tangents: 0.3.6
seaborn        : 0.11.2