import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from torch.nn.functional import softplus
rc = {
    "figure.constrained_layout.use": True,
    "axes.titlesize": 20,
    "figure.figsize": (12, 9),
}
sns.set_theme(style="darkgrid", palette="colorblind", rc=rc)
Marginalized Neural Network (aka Neural Linear Model)
Based on a tweet by Andrew Jesson. The code below is adapted from the notebook Andrew shared.
Thanks to David Holzmüller, Kevin Patrick Murphy, and Jason Hartford for correctly identifying this approach as the Neural Linear Model (NLM). The NLM seems to have first appeared as the Marginalized Neural Network in Marginalized Neural Network Mixtures for Large-Scale Regression, and later in Scalable Bayesian Optimization Using Deep Neural Networks. Benchmarking the Neural Linear Model for Regression compares the NLM to mean-field variational inference and Monte Carlo dropout Bayesian neural networks, and concludes that methods such as the NLM, which do exact inference over a subset of parameters, may perform better than methods that do variational inference over all parameters.
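To make “exact inference over a subset of parameters” concrete: the NLM treats the penultimate-layer features \(\phi(x)\) as fixed and does Bayesian linear regression on the last-layer weights only. Assuming a unit Gaussian prior \(w \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\) and Gaussian observation noise with variance \(\sigma^2\) (the prior implied by the identity matrix added in the code below), the standard closed-form posterior is
\[
p(w \mid X, y) = \mathcal{N}(w \mid \mu_w, \Sigma_w), \qquad
\Sigma_w = \left(\sigma^{-2}\,\phi(X)^{\top}\phi(X) + \mathbf{I}\right)^{-1}, \qquad
\mu_w = \sigma^{-2}\,\Sigma_w\,\phi(X)^{\top} y,
\]
so the epistemic variance of \(f(x_*) = w^{\top}\phi(x_*)\) is \(\phi(x_*)^{\top}\Sigma_w\,\phi(x_*)\). The notebook uses the trained network’s own last layer for the predictive mean and this posterior covariance for the uncertainty.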
class Simulated(torch.utils.data.Dataset):
    def __init__(self, n=1000, sigma=0.1, in_between=False, outliers_percentage=None):
        if in_between:
            # Sample inputs from two disjoint intervals, roughly [-3, -1] and [1, 3],
            # leaving a gap around zero to probe in-between uncertainty.
            mask = torch.distributions.Bernoulli(0.5).sample(torch.Size([n, 1]))
            self.data = (-2 * torch.rand(n, 1).float() - 1) * mask \
                + (2 * torch.rand(n, 1).float() + 1) * (1.0 - mask)
        else:
            self.data = torch.randn(n, 1).float()
        self.targets = f(self.data, sigma).float()
        # Optionally replace a random subset of targets with uniform outliers.
        if outliers_percentage:
            num_outliers = int((outliers_percentage / 100) * n)
            outliers = torch.distributions.Uniform(
                torch.tensor([-5.0]), torch.tensor([5.0])
            ).sample(torch.Size([num_outliers]))
            mask = torch.randint_like(outliers, high=n).long()
            self.targets[mask] = outliers.float()

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]
def f(x, sigma=0.1):
    # Degree-5 polynomial modulated by a sine and a Gaussian envelope, plus observation noise.
    w = torch.tensor([-0.6667, -0.6012, -1.0172, -0.7687, 1.4680, -0.1678])
    fx = 0
    for i in range(len(w)):
        fx += w[i] * (x**i)
    fx *= np.sin(np.pi * x)
    fx *= np.exp(-0.5 * (x**2)) / np.sqrt(2 * np.pi)
    return fx.squeeze(-1) + torch.randn(len(x)) * sigma
num_features = 100
n_train = 200
dr = 0.05

model = torch.nn.Sequential(
    torch.nn.Linear(1, num_features),
    torch.nn.ReLU(),
    torch.nn.Dropout(dr),
    torch.nn.Linear(num_features, num_features),
    torch.nn.ReLU(),
    torch.nn.Dropout(dr),
    torch.nn.Linear(num_features, num_features),
    torch.nn.ReLU(),
    torch.nn.Dropout(dr),
    torch.nn.Linear(num_features, 1),
)
log_sigma = torch.nn.Parameter(-2 * torch.ones(1))
ds = Simulated(n=n_train)
dl = torch.utils.data.DataLoader(
    ds,
    batch_size=32,
    shuffle=True,
    drop_last=True,
)
optimizer = torch.optim.AdamW(
    [
        {
            "params": model.parameters(),
            "lr": 1e-3,
            "weight_decay": (1 - dr) / (2 * len(ds)),
        },
        {
            "params": log_sigma,
            "lr": 1e-2,
            "weight_decay": 1e-6,
        },
    ]
)
model.train()
for epoch in range(500):
    train_loss = []
    for batch in dl:
        x, y = batch
        mu = model(x)
        dist = torch.distributions.Normal(mu, softplus(log_sigma))
        loss = -dist.log_prob(y.unsqueeze(-1)).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, loss: {np.mean(train_loss):.02f}")
Epoch 0, loss: 2.83
Epoch 20, loss: 0.10
Epoch 40, loss: -0.08
Epoch 60, loss: -0.28
Epoch 80, loss: -0.36
Epoch 100, loss: -0.38
Epoch 120, loss: -0.48
Epoch 140, loss: -0.67
Epoch 160, loss: -0.67
Epoch 180, loss: -0.63
Epoch 200, loss: -0.73
Epoch 220, loss: -0.67
Epoch 240, loss: -0.70
Epoch 260, loss: -0.76
Epoch 280, loss: -0.71
Epoch 300, loss: -0.58
Epoch 320, loss: -0.84
Epoch 340, loss: -0.75
Epoch 360, loss: -0.79
Epoch 380, loss: -0.83
Epoch 400, loss: -0.79
Epoch 420, loss: -0.71
Epoch 440, loss: -0.62
Epoch 460, loss: -0.74
Epoch 480, loss: -0.89
Computing uncertainty
We need to compute the penultimate-layer features \(\phi(X_{\text{test}})\) and \(\phi(X_{\text{train}})\).
x_test = torch.linspace(-4, 4, 150).unsqueeze(-1)
model.eval()
with torch.no_grad():
    mu_pred = model(x_test)
    # model[:-1] drops the final linear layer, giving the penultimate-layer features phi.
    phi_test = model[:-1](x_test)
    phi_train = model[:-1](ds.data)
print(phi_train.shape, phi_test.shape)
torch.Size([200, 100]) torch.Size([150, 100])
Now we calculate \(\sigma^{-2}\phi(X_{\text{train}})^{\top} \phi(X_{\text{train}}) + \mathbf{I}\), which has shape num_features x num_features. This is different from a standard GP, where the corresponding matrix has shape n_train x n_train.
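For context, and not derived in the original notebook, this is the usual weight-space vs. function-space duality: the same quantity could instead be computed with an n_train x n_train kernel matrix,
\[
\phi(x)^{\top}\left(\sigma^{-2}\Phi^{\top}\Phi + \mathbf{I}\right)^{-1}\phi(x)
= k(x, x) - k(x, X_{\text{train}})\left(K + \sigma^{2}\mathbf{I}\right)^{-1}k(X_{\text{train}}, x),
\]
where \(\Phi = \phi(X_{\text{train}})\), \(k(x, x') = \phi(x)^{\top}\phi(x')\), and \(K = \Phi\Phi^{\top}\). With num_features = 100 and n_train = 200, the weight-space form on the left is the cheaper solve, and its size stays fixed as the training set grows.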
I = torch.eye(num_features)
A = phi_train.T @ phi_train / torch.square(softplus(log_sigma).detach()) + I
Now we solve for \(\phi(x)^{\top}\left(\sigma^{-2}\phi(X_{\text{train}})^{\top} \phi(X_{\text{train}}) + \mathbf{I}\right)^{-1}\phi(x)\). We can use the identity \(\phi(x)^{\top} A^{-1} \phi(x) = ||L^{-1}\phi(x)||^2_2\), where \(L\) is the lower-triangular Cholesky factor of \(A = \sigma^{-2}\phi(X_{\text{train}})^{\top} \phi(X_{\text{train}}) + \mathbf{I}\) (so \(A = LL^{\top}\)); a quick numerical check of this identity follows the code below. Thank you Andreas Kirsch.
As pointed out by Yarin Gal, this is in fact Bayesian linear regression in penultimate-layer feature space, which is indeed “sort of a deep kernel GP.”
L = torch.linalg.cholesky(A)
print(L.shape)
# Solve L v = phi_test^T, so that v^T v = phi_test A^{-1} phi_test^T.
v = torch.linalg.solve_triangular(L, phi_test.T, upper=False)
cov = v.T @ v + 1e-5 * torch.eye(len(phi_test))  # small jitter for numerical stability
f_stddev = torch.sqrt(torch.diag(cov))
torch.Size([100, 100])
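As a quick sanity check (not part of the original notebook), the Cholesky route should agree with forming \(A^{-1}\) explicitly; the maximum absolute difference should be near zero.
# Optional check (assumption: A is well conditioned enough to invert directly):
# phi_test @ A^{-1} @ phi_test.T should match v.T @ v from the Cholesky solve.
cov_direct = phi_test @ torch.linalg.inv(A) @ phi_test.T
print(torch.max(torch.abs(cov_direct - v.T @ v)))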
_ = plt.figure(figsize=(12, 9))
_ = sns.scatterplot(x=ds.data.ravel(), y=ds.targets.ravel())
_ = sns.lineplot(x=x_test.ravel(), y=mu_pred.ravel())
_ = plt.fill_between(
    x=x_test.ravel(),
    y1=mu_pred.ravel() + 1.96 * f_stddev,
    y2=mu_pred.ravel() - 1.96 * f_stddev,
    alpha=0.3,
)
_ = plt.ylim(-2, 2)
Pretty Samples
dist = torch.distributions.MultivariateNormal(loc=mu_pred.ravel(), covariance_matrix=cov)
f_samples = dist.sample(torch.Size([20]))
_ = sns.scatterplot(x=ds.data.ravel(), y=ds.targets.ravel())
for func in f_samples:
    _ = sns.lineplot(x=x_test.ravel(), y=func.to("cpu"), color="C0", alpha=0.1)
_ = plt.ylim(-2, 2)
In-between uncertainty
Great comment from Yingzhen Li on in-between uncertainty! Let’s take a look: with in_between=True the training inputs come from two disjoint intervals, so we can check whether the model’s predictive uncertainty grows in the gap between them.
num_features = 100
n_train = 200
dr = 0.05

model = torch.nn.Sequential(
    torch.nn.Linear(1, num_features),
    torch.nn.ReLU(),
    torch.nn.Dropout(dr),
    torch.nn.Linear(num_features, num_features),
    torch.nn.ReLU(),
    torch.nn.Dropout(dr),
    torch.nn.Linear(num_features, num_features),
    torch.nn.ReLU(),
    torch.nn.Dropout(dr),
    torch.nn.Linear(num_features, 1),
)
log_sigma = torch.nn.Parameter(-2 * torch.ones(1))
ds = Simulated(n=n_train, in_between=True)
dl = torch.utils.data.DataLoader(
    ds,
    batch_size=32,
    shuffle=True,
    drop_last=True,
)
optimizer = torch.optim.AdamW(
    [
        {
            "params": model.parameters(),
            "lr": 1e-3,
            "weight_decay": (1 - dr) / (2 * len(ds)),
        },
        {
            "params": log_sigma,
            "lr": 1e-2,
            "weight_decay": 1e-6,
        },
    ]
)
model.train()
for epoch in range(500):
    train_loss = []
    for batch in dl:
        x, y = batch
        mu = model(x)
        dist = torch.distributions.Normal(mu, softplus(log_sigma))
        loss = -dist.log_prob(y.unsqueeze(-1)).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, loss: {np.mean(train_loss):.02f}")
Epoch 0, loss: 13.52
Epoch 20, loss: 0.87
Epoch 40, loss: 0.26
Epoch 60, loss: -0.04
Epoch 80, loss: -0.14
Epoch 100, loss: -0.13
Epoch 120, loss: -0.25
Epoch 140, loss: -0.03
Epoch 160, loss: -0.42
Epoch 180, loss: -0.35
Epoch 200, loss: -0.41
Epoch 220, loss: -0.47
Epoch 240, loss: -0.36
Epoch 260, loss: -0.41
Epoch 280, loss: -0.44
Epoch 300, loss: -0.54
Epoch 320, loss: -0.48
Epoch 340, loss: -0.55
Epoch 360, loss: -0.55
Epoch 380, loss: -0.57
Epoch 400, loss: -0.46
Epoch 420, loss: -0.52
Epoch 440, loss: -0.44
Epoch 460, loss: -0.60
Epoch 480, loss: -0.52
x_test = torch.linspace(-4, 4, 150).unsqueeze(-1)
model.eval()
with torch.no_grad():
    mu_pred = model(x_test)
    phi_test = model[:-1](x_test)
    phi_train = model[:-1](ds.data)

I = torch.eye(num_features)
A = phi_train.T @ phi_train / torch.square(softplus(log_sigma).detach()) + I

L = torch.linalg.cholesky(A)
v = torch.linalg.solve_triangular(L, phi_test.T, upper=False)
cov = v.T @ v + 1e-5 * torch.eye(len(phi_test))
f_stddev = torch.sqrt(torch.diag(cov))
_ = plt.figure(figsize=(12, 9))
_ = sns.scatterplot(x=ds.data.ravel(), y=ds.targets.ravel())
_ = sns.lineplot(x=x_test.ravel(), y=mu_pred.ravel())
_ = plt.fill_between(
    x=x_test.ravel(),
    y1=mu_pred.ravel() + 1.96 * f_stddev,
    y2=mu_pred.ravel() - 1.96 * f_stddev,
    alpha=0.3,
)
dist = torch.distributions.MultivariateNormal(loc=mu_pred.ravel(), covariance_matrix=cov)
f_samples = dist.sample(torch.Size([50]))

_ = sns.scatterplot(x=ds.data.ravel(), y=ds.targets.ravel())
for func in f_samples:
    _ = sns.lineplot(x=x_test.ravel(), y=func.to("cpu"), color="C0", alpha=0.1)
_ = plt.ylim(-2, 2)
With Outliers
Adding 5% outliers: the final training loss rises from about -0.89 to about 1.27.
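This is expected with a Gaussian likelihood: each training point contributes
\[
-\log \mathcal{N}(y \mid \mu, \sigma^{2}) = \frac{(y - \mu)^{2}}{2\sigma^{2}} + \log \sigma + \tfrac{1}{2}\log 2\pi
\]
to the loss, which grows quadratically in the residual. Outliers drawn from \(\mathcal{U}(-5, 5)\) sit far from the clean curve, so even 5% of them either dominate the average negative log-likelihood or push the learned \(\sigma\) upward, which in turn worsens the loss on the clean points.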
num_features = 100
n_train = 200
dr = 0.05

model = torch.nn.Sequential(
    torch.nn.Linear(1, num_features),
    torch.nn.ReLU(),
    torch.nn.Dropout(dr),
    torch.nn.Linear(num_features, num_features),
    torch.nn.ReLU(),
    torch.nn.Dropout(dr),
    torch.nn.Linear(num_features, num_features),
    torch.nn.ReLU(),
    torch.nn.Dropout(dr),
    torch.nn.Linear(num_features, 1),
)
log_sigma = torch.nn.Parameter(-2 * torch.ones(1))
ds = Simulated(n=n_train, outliers_percentage=5)
dl = torch.utils.data.DataLoader(
    ds,
    batch_size=32,
    shuffle=True,
    drop_last=True,
)
optimizer = torch.optim.AdamW(
    [
        {
            "params": model.parameters(),
            "lr": 1e-3,
            "weight_decay": (1 - dr) / (2 * len(ds)),
        },
        {
            "params": log_sigma,
            "lr": 1e-2,
            "weight_decay": 1e-6,
        },
    ]
)
model.train()
for epoch in range(500):
    train_loss = []
    for batch in dl:
        x, y = batch
        mu = model(x)
        dist = torch.distributions.Normal(mu, softplus(log_sigma))
        loss = -dist.log_prob(y.unsqueeze(-1)).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, loss: {np.mean(train_loss):.02f}")
x_test = torch.linspace(-4, 4, 150).unsqueeze(-1)
model.eval()
with torch.no_grad():
    mu_pred = model(x_test)
    phi_test = model[:-1](x_test)
    phi_train = model[:-1](ds.data)
print(phi_train.shape, phi_test.shape)

I = torch.eye(num_features)
A = phi_train.T @ phi_train / torch.square(softplus(log_sigma).detach()) + I

L = torch.linalg.cholesky(A)
print(L.shape)
v = torch.linalg.solve_triangular(L, phi_test.T, upper=False)
cov = v.T @ v + 1e-5 * torch.eye(len(phi_test))
f_stddev = torch.sqrt(torch.diag(cov))
_ = plt.figure(figsize=(12, 9))
_ = sns.scatterplot(x=ds.data.ravel(), y=ds.targets.ravel())
_ = sns.lineplot(x=x_test.ravel(), y=mu_pred.ravel())
_ = plt.fill_between(
    x=x_test.ravel(),
    y1=mu_pred.ravel() + 1.96 * f_stddev,
    y2=mu_pred.ravel() - 1.96 * f_stddev,
    alpha=0.3,
)
_ = plt.ylim(-2, 2)
Epoch 0, loss: 25.33
Epoch 20, loss: 5.53
Epoch 40, loss: 3.51
Epoch 60, loss: 2.53
Epoch 80, loss: 2.00
Epoch 100, loss: 1.35
Epoch 120, loss: 1.61
Epoch 140, loss: 1.49
Epoch 160, loss: 1.43
Epoch 180, loss: 1.39
Epoch 200, loss: 1.33
Epoch 220, loss: 1.34
Epoch 240, loss: 1.23
Epoch 260, loss: 1.27
Epoch 280, loss: 1.21
Epoch 300, loss: 1.32
Epoch 320, loss: 1.29
Epoch 340, loss: 1.27
Epoch 360, loss: 1.29
Epoch 380, loss: 1.28
Epoch 400, loss: 1.29
Epoch 420, loss: 1.28
Epoch 440, loss: 1.22
Epoch 460, loss: 1.17
Epoch 480, loss: 1.27
torch.Size([200, 100]) torch.Size([150, 100])
torch.Size([100, 100])