Training a Bayesian neural network
In this example, we train a Bayesian neural network to learn the function \(f(x)=x^2\).
First, we have to import the necessary modules.
# Default imports
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import UQpy.scientific_machine_learning as sml
torch.manual_seed(123)
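We define the network architecture using the nn.Sequential object, with BayesianLinear layers in place of ordinary linear layers,
and wrap it in a FeedForwardNeuralNetwork to obtain the Bayesian neural network.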
width = 8
network = nn.Sequential(
    sml.BayesianLinear(1, width),
    nn.ReLU(),
    sml.BayesianLinear(width, width),
    nn.ReLU(),
    sml.BayesianLinear(width, 1),
)
model = sml.FeedForwardNeuralNetwork(network)
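To make the idea of a Bayesian layer concrete, the following is a minimal sketch in plain PyTorch of a linear layer whose weights and biases are Gaussian random variables. It is an illustration only and is not the UQpy implementation: the class name SketchBayesianLinear, the softplus parameterization of the standard deviation, and the sampling flag are all assumptions made for this example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchBayesianLinear(nn.Module):
    """Hypothetical illustration; NOT the UQpy BayesianLinear implementation."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # Each weight and bias has a learnable mean and a learnable "rho".
        self.weight_mu = nn.Parameter(0.1 * torch.randn(out_features, in_features))
        self.weight_rho = nn.Parameter(torch.full((out_features, in_features), -3.0))
        self.bias_mu = nn.Parameter(torch.zeros(out_features))
        self.bias_rho = nn.Parameter(torch.full((out_features,), -3.0))
        self.sampling = True  # assumed flag, loosely mirroring model.sample(...)

    def forward(self, x):
        if self.sampling:
            # Reparameterization: parameter = mean + std * standard normal noise.
            # softplus(rho) keeps the standard deviation positive.
            weight = self.weight_mu + F.softplus(self.weight_rho) * torch.randn_like(self.weight_mu)
            bias = self.bias_mu + F.softplus(self.bias_rho) * torch.randn_like(self.bias_mu)
        else:
            # Deterministic mode: use the mean of each distribution.
            weight, bias = self.weight_mu, self.bias_mu
        return F.linear(x, weight, bias)


layer = SketchBayesianLinear(1, 8)
x_demo = torch.ones(4, 1)

layer.sampling = False
deterministic_a = layer(x_demo)
deterministic_b = layer(x_demo)

layer.sampling = True
random_a = layer(x_demo)
random_b = layer(x_demo)
```

With sampling turned off, two forward passes give identical outputs. With sampling turned on, each pass draws new parameters, so the outputs differ from call to call.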
With the neural network defined, we turn our attention to the training data. We want to learn the function \(f(x)=x^2\) and define the training data using the pytorch Dataset and DataLoader.
For more information on defining the training data, see the pytorch documentation at https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
class QuadraticDataset(Dataset):
    def __init__(self, n_samples=200):
        self.n_samples = n_samples
        self.x = torch.linspace(-5, 5, n_samples, dtype=torch.float).reshape(-1, 1)
        self.y = self.x**2

    def __len__(self):
        self.n_samples
        return self.n_samples

    def __getitem__(self, item):
        return self.x[item], self.y[item]
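As a quick check of how a DataLoader consumes this dataset, the sketch below repeats the class definition so that it runs on its own, then iterates over the batches once. The batch size of 40 matches the value used for training later in this example; the variable names loader, batch_shapes, and targets_match are illustrative.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class QuadraticDataset(Dataset):
    def __init__(self, n_samples=200):
        self.n_samples = n_samples
        self.x = torch.linspace(-5, 5, n_samples, dtype=torch.float).reshape(-1, 1)
        self.y = self.x**2

    def __len__(self):
        return self.n_samples

    def __getitem__(self, item):
        return self.x[item], self.y[item]


loader = DataLoader(QuadraticDataset(), batch_size=40, shuffle=True)

batch_shapes = []
targets_match = True
for x_batch, y_batch in loader:
    # Each batch is a pair of tensors stacked along the first dimension.
    batch_shapes.append((tuple(x_batch.shape), tuple(y_batch.shape)))
    # Shuffling reorders the samples but keeps every (x, y) pair together.
    targets_match = targets_match and torch.allclose(y_batch, x_batch**2)

n_batches = len(loader)
```

With 200 samples and a batch size of 40, one pass over the loader yields 5 batches, each holding inputs and targets of shape (40, 1).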
Before we continue with training the network, let’s get the initial prediction of the neural network on the data.
initial_prediction = model(QuadraticDataset().x)
So far we have the neural network and training data. The BBBTrainer combines the two along with a
pytorch optimization algorithm to learn the network parameters.
We instantiate the BBBTrainer, train the network, then print the initial and final loss alongside a model summary.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_data = DataLoader(QuadraticDataset(), batch_size=40, shuffle=True)
trainer = sml.BBBTrainer(model, optimizer)
print("Starting Training...", end="")
trainer.run(train_data=train_data, epochs=500, beta=1e-5, num_samples=10)
print("done")
print("Initial loss:", trainer.history["train_loss"][0])
print("Final loss:", trainer.history["train_loss"][-1])
model.summary()
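Bayes by backpropagation typically minimizes a data-fit term plus a weighted divergence between the learned parameter distributions and a prior. The sketch below illustrates that structure in plain PyTorch, assuming Gaussian posterior and prior distributions and assuming that beta plays the role of the divergence weight. The function gaussian_kl, the example means and standard deviations, and the placeholder data_loss are hypothetical and are not taken from the BBBTrainer source.

```python
import torch
from torch.distributions import Normal, kl_divergence


def gaussian_kl(mu_q, sigma_q, mu_p, sigma_p):
    """Closed-form KL(q || p) for independent univariate Gaussians, summed."""
    return (
        torch.log(sigma_p / sigma_q)
        + (sigma_q**2 + (mu_q - mu_p) ** 2) / (2 * sigma_p**2)
        - 0.5
    ).sum()


# Hypothetical posterior over two parameters and a standard normal prior.
mu_q = torch.tensor([0.3, -1.2])
sigma_q = torch.tensor([0.5, 0.1])
mu_p = torch.zeros(2)
sigma_p = torch.ones(2)

kl = gaussian_kl(mu_q, sigma_q, mu_p, sigma_p)

# Cross-check against the PyTorch built-in divergence.
reference = kl_divergence(Normal(mu_q, sigma_q), Normal(mu_p, sigma_p)).sum()

beta = 1e-5                    # same value as passed to trainer.run above
data_loss = torch.tensor(2.0)  # placeholder for a data-fit term such as MSE
total_loss = data_loss + beta * kl
```

A small beta, such as the 1e-5 used above, lets the data-fit term dominate the objective, while a larger beta pulls the parameter distributions more strongly toward the prior.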
We compare the initial and final predictions and plot the loss history using matplotlib.
x = QuadraticDataset().x
y = QuadraticDataset().y
model.train(False)
model.sample(False)
final_prediction = model(x)
fig, ax = plt.subplots()
ax.plot(
x.detach().numpy(),
initial_prediction.detach().numpy(),
label="Initial Prediction",
color="tab:blue",
)
ax.plot(
x.detach().numpy(),
final_prediction.detach().numpy(),
label="Final Prediction",
color="tab:orange",
)
ax.plot(
x.detach().numpy(),
y.detach().numpy(),
label="Exact",
color="black",
linestyle="dashed",
)
ax.set_title("Initial and Final NN Predictions")
ax.set(xlabel="x", ylabel="f(x)")
ax.legend()
train_loss = trainer.history["train_loss"].detach().numpy()
fig, ax = plt.subplots()
ax.semilogy(train_loss)
ax.set_title("Bayes By Backpropagation Training Loss")
ax.set(xlabel="Epoch", ylabel="Loss")
plt.show()
The Bayesian neural network is a probabilistic model. Each of its parameters, in this case the weights and biases,
is governed by a Gaussian distribution. We can get a deterministic output from the BNN by setting
model.sample(False), which sets each parameter to the mean of its distribution.
We can obtain error bars on the model’s output by sampling the parameters from their governing distributions.
This is done by setting model.sample(True), evaluating the forward model many times,
and then computing the sample variance of the outputs.
model.sample(False)
print("BNN is deterministic:", model.is_deterministic())
mean = model(x)
model.sample(True)
print("BNN is deterministic:", model.is_deterministic())
n = 10_000
samples = torch.zeros(len(x), n)
for i in range(n):
    samples[:, i] = model(x).squeeze()
variance = torch.var(samples, dim=1)
standard_deviation = torch.sqrt(variance)
x_plot = x.squeeze().detach().numpy()
mu = mean.squeeze().detach().numpy()
sigma = standard_deviation.squeeze().detach().numpy()
fig, ax = plt.subplots()
ax.plot(x_plot, mu, label=r"$\mu$")
ax.plot(x_plot, y.detach().numpy(), label="Exact", color="black", linestyle="dashed")
ax.fill_between(
    x_plot, mu - (3 * sigma), mu + (3 * sigma), label=r"$\mu \pm 3\sigma$", alpha=0.3
)
ax.set_title(r"Bayesian Neural Network $\mu \pm 3\sigma$")
ax.set(xlabel="x", ylabel="f(x)")
ax.legend()
plt.show()
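The Monte Carlo procedure above can be checked on a toy problem where the answer is known. The sketch below uses a hypothetical stochastic model, \(y = w x\) with \(w \sim N(2, 0.5^2)\), in place of the trained BNN, so the exact standard deviation of the output is \(0.5 |x|\). The names stochastic_model, w_mean, and w_std are assumptions made for this illustration.

```python
import torch

torch.manual_seed(0)

w_mean, w_std = 2.0, 0.5  # hypothetical Gaussian distribution of one weight


def stochastic_model(x):
    # Draw a fresh weight on every call, as a sampling BNN would.
    w = w_mean + w_std * torch.randn(1)
    return w * x


x_toy = torch.linspace(-1, 1, 5).reshape(-1, 1)
n_draws = 20_000

# No gradients are needed for post-training sampling.
with torch.no_grad():
    toy_samples = torch.stack(
        [stochastic_model(x_toy).squeeze() for _ in range(n_draws)], dim=1
    )

mc_mean = toy_samples.mean(dim=1)  # Monte Carlo estimate of the mean
mc_std = toy_samples.std(dim=1)    # Monte Carlo estimate of the std

exact_mean = w_mean * x_toy.squeeze()
exact_std = w_std * x_toy.squeeze().abs()
```

The estimates in mc_mean and mc_std approach exact_mean and exact_std as n_draws grows, with an error that shrinks roughly in proportion to one over the square root of the number of draws.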