
Training a Multilayer Perceptron in Collimator

Multilayer Perceptrons (MLPs), a class of feedforward artificial neural networks, stand at the core of deep learning. Comprising multiple layers of nodes in a directed graph, with each layer fully connected to the next, MLPs embody a simple yet powerful architecture for approximating complex functions.
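
For intuition, here is a minimal JAX sketch (not part of the tutorial code; the layer sizes and random weights are arbitrary) of what a small MLP computes: each layer applies a fully connected affine transformation followed by a nonlinearity.

import jax
import jax.numpy as jnp

# Minimal illustration of an MLP forward pass (arbitrary sizes: 4 -> 8 -> 2)
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
W1, b1 = jax.random.normal(key1, (8, 4)), jnp.zeros(8)
W2, b2 = jax.random.normal(key2, (2, 8)), jnp.zeros(2)

def mlp_forward(x):
    hidden = jax.nn.sigmoid(W1 @ x + b1)  # fully connected hidden layer with nonlinearity
    return W2 @ hidden + b2               # fully connected output layer

y = mlp_forward(jnp.ones(4))  # maps a 4-dimensional input to a 2-dimensional output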

MLPs are versatile in their applications, ranging from simple tasks like linear regression to complex challenges in image recognition, natural language processing, and beyond. Their ability to learn non-linear relationships and patterns from data makes them invaluable in various fields including finance for predicting stock prices, healthcare for diagnosing diseases from medical images, and technology for powering voice recognition systems.

Key Features:

  • Function Approximation: At their heart, MLPs are universal function approximators. Given enough data and computational resources, they can learn to approximate an unknown function that maps any set of inputs to outputs.
  • Classification and Regression: MLPs are adept at both classification, where they predict discrete labels, and regression, where they forecast continuous quantities.
  • Flexibility and Scalability: The architecture of MLPs can be easily adjusted to suit the complexity of the task at hand. By modifying the number of layers and nodes, MLPs can be scaled to improve their learning capacity.

In this tutorial, we will demonstrate how to configure and train an MLP in Collimator.

Import Necessary Libraries

Let's start by importing all the necessary libraries. We'll need [.code]matplotlib[.code] for plotting, [.code]jax[.code] and [.code]equinox[.code] for model definition and training, [.code]optax[.code] for optimization, and [.code]collimator[.code] for simulation. Additionally, we use [.code]sklearn[.code] to generate and split our dataset.

%matplotlib inline
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import equinox as eqx
import optax

import collimator
from collimator.library import MLP

from sklearn.datasets import make_regression # to make dataset
from sklearn.model_selection import train_test_split

%load_ext autoreload
%autoreload 2

Preparing the Dataset

For this demonstration, we'll create a synthetic regression dataset using scikit-learn's [.code]make_regression[.code] function. This dataset will then be split into training and testing sets to evaluate our model's performance.

n_samples = 500
n_input_features = 4
n_output_features = 2

# Generate data from sklearn's make_regression
X, Y = make_regression(n_samples, n_input_features, n_targets=n_output_features, noise=10.0, random_state=42)

# split data into test and train
train_x, test_x, train_y, test_y = train_test_split(X, Y, random_state=84)

print(f"{train_x.shape=}")
print(f"{train_y.shape=}")
train_x.shape=(375, 4)
train_y.shape=(375, 2)

Collimator MLP block configuration and initialization

The MLP block in Collimator is based on the MLP implementation in the equinox library. As such, the configuration parameters closely follow the [.code]equinox[.code] implementation.
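
For reference, the underlying [.code]equinox[.code] MLP is constructed roughly as follows (shown only to highlight the shared parameter names; the sizes and key below are placeholders):

# Underlying equinox MLP (for reference only; sizes and key are placeholders)
eqx_mlp = eqx.nn.MLP(
    in_size=4,
    out_size=2,
    width_size=16,
    depth=2,
    activation=jax.nn.sigmoid,
    key=jax.random.PRNGKey(0),
)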

Here, we define the structure and configuration of our MLP model. We then create an instance of the model (the MLP block) with the specified configuration. This model will be used for both training and prediction.
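
A minimal sketch of this step is shown below. The parameter names mirror those used when reloading the model at the end of this tutorial; the hidden-layer width ([.code]n_hidden_features[.code]), depth ([.code]n_hidden_layers[.code]), and the [.code]seed[.code] argument name and value are illustrative assumptions.

# Hidden-layer sizes below are illustrative choices
n_hidden_features = 16  # width of each hidden layer
n_hidden_layers = 2     # number of hidden layers

# Create the Collimator MLP block (the seed initializes the MLP parameters)
model = MLP(
    in_size=n_input_features,
    out_size=n_output_features,
    width_size=n_hidden_features,
    depth=n_hidden_layers,
    activation_str="sigmoid",
    seed=42,
)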

Create a function to evaluate the MLP with a given input and MLP parameters

For training and prediction, we need a function that can evaluate our MLP model given an input [.code]x[.code] and the trainable MLP parameters [.code]mlp_params[.code]. To achieve this, we first create a [.code]base_context[.code] for the MLP block created above by arbitrarily fixing the block's input port. This [.code]base_context[.code] holds the initialized parameters of the MLP block (generated through the provided [.code]seed[.code]). Subsequently, in the [.code]evaluate_model[.code] function, whenever a new input [.code]x[.code] and a new set of [.code]mlp_params[.code] are available, we update the [.code]base_context[.code] with the new [.code]mlp_params[.code], fix the input port to the new input [.code]x[.code], and evaluate the block's output port.

# create a base context that can be modified later
with model.input_ports[0].fixed(jnp.zeros(n_input_features)):
    base_context = model.create_context()

# Function to evaluate the output of the MLP given MLP params and input x
def evaluate_model(x, mlp_params, base_context):
    context = base_context.with_parameter("mlp_params", mlp_params)
    with model.input_ports[0].fixed(x):
        y = model.output_ports[0].eval(context)
    return y

Visualize prediction from initial MLP parameters

Before training, it's insightful to visualize the model's predictions with its initial parameters. This gives us a baseline performance indicator.

# Extract the mlp parameters from the block
initial_mlp_parameters, _ = eqx.partition(model.mlp, eqx.is_array)

# Predict for training set
pred_y_train = jax.vmap(evaluate_model, in_axes=(0,None,None))(train_x, initial_mlp_parameters, base_context)
print(f"{pred_y_train.shape=}")

# Predict for test set
pred_y_test = jax.vmap(evaluate_model, in_axes=(0,None,None))(test_x, initial_mlp_parameters, base_context)
print(f"{pred_y_test.shape=}")

def plot_predictions(pred_y, dataset="test"):
    if dataset=="train":
        fig, (ax1, ax2) = plt.subplots(1,2,figsize=(10,4))
        ax1.plot(train_y[:,0], pred_y[:,0], 'ro', label=r"$y_1$")
        ax1.plot(train_y[:,0], train_y[:,0], '-r')
        ax2.plot(train_y[:,1], pred_y[:,1], 'bo', label=r"$y_2$")
        ax2.plot(train_y[:,1], train_y[:,1], '-b')
        fig.suptitle("Train")
    elif dataset=="test":
        fig, (ax1, ax2) = plt.subplots(1,2,figsize=(10,4))
        ax1.plot(test_y[:,0], pred_y[:,0], 'rs', label=r"$y_1$")
        ax1.plot(test_y[:,0], test_y[:,0], '-r')
        ax2.plot(test_y[:,1], pred_y[:,1], 'bs', label=r"$y_2$")
        ax2.plot(test_y[:,1], test_y[:,1], '-b')
        fig.suptitle("Test")
    else:
        raise ValueError("Invalid dataset")

    for ax in (ax1, ax2):
        ax.set_xlabel("true")
        ax.set_ylabel("predicted")
        ax.legend()

    plt.show()

plot_predictions(pred_y_train, dataset="train")
plot_predictions(pred_y_test, dataset="test")
pred_y_train.shape=(375, 2)
pred_y_test.shape=(125, 2)

Train the network

In this section, we define the training loop with a mean-squared loss. We leverage JAX's JIT compilation and automatic differentiation capabilities for efficiency and gradient computation. The [.code]optax[.code] library provides the optimization algorithm, specifically Adam in this case.

num_epochs = 15000

@jax.jit
def loss(mlp_params_flat, train_x, train_y, base_context):
    pred_y = jax.vmap(evaluate_model, in_axes=(0,None,None))(train_x, mlp_params_flat, base_context)
    return jnp.mean((pred_y - train_y)**2)

value_and_grad = jax.jit(jax.value_and_grad(loss))

# Optax optimizer
optimizer = optax.adam(learning_rate=0.05)

# Initialize optimizer state
params = initial_mlp_parameters
opt_state = optimizer.init(params)

train_loss_history = []
test_loss_history = []
for epoch in range(num_epochs):
    train_loss, grads = value_and_grad(params, train_x, train_y, base_context)
    test_loss = loss(params, test_x, test_y, base_context)

    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    train_loss_history.append(train_loss)
    test_loss_history.append(test_loss)

    # Print the function value at the current parameters
    if (epoch + 1) % 1000 == 0:
        print(
            f"Epoch [{epoch+1}/{num_epochs}]: loss = {train_loss}"
        )

fig, ax = plt.subplots(1,1,figsize=(6,4))
ax.plot(train_loss_history, "-r", label="train loss")
ax.plot(test_loss_history, "-b", label="test loss")
ax.set_xlabel("epochs")
ax.set_ylabel("loss")
ax.legend()
plt.show()
Epoch [1000/15000]: loss = 5153.29563413222
Epoch [2000/15000]: loss = 2476.626246660495
Epoch [3000/15000]: loss = 1254.2464854222098
Epoch [4000/15000]: loss = 695.1984602861155
Epoch [5000/15000]: loss = 432.2903980064276
Epoch [6000/15000]: loss = 300.0311937880019
Epoch [7000/15000]: loss = 226.00624504198203
Epoch [8000/15000]: loss = 181.83802240617
Epoch [9000/15000]: loss = 153.46777646271676
Epoch [10000/15000]: loss = 134.83779730238234
Epoch [11000/15000]: loss = 130.97021401305358
Epoch [12000/15000]: loss = 126.13660536430028
Epoch [13000/15000]: loss = 114.49505596571055
Epoch [14000/15000]: loss = 114.78093487004155
Epoch [15000/15000]: loss = 108.95603333498251

Assess performance

After training, we evaluate our model on both the training and testing sets to see the improvement and to check for overfitting.

pred_y_train = jax.vmap(evaluate_model, in_axes=(0,None,None))(train_x, params, base_context)
pred_y_test = jax.vmap(evaluate_model, in_axes=(0,None,None))(test_x, params, base_context)

plot_predictions(pred_y_train, dataset="train")
plot_predictions(pred_y_test, dataset="test")

Save the serialized model to use the trained model in the Collimator GUI

Finally, we demonstrate how to save the trained model parameters for later use and how to load them back into a model. Collimator provides a [.code]serialize[.code] utility to achieve this. This is particularly useful if you want to train an MLP model in a Python/Jupyter environment and then use the trained model in the Collimator GUI.

file_name = "saved_mlp.eqx"
model.serialize(file_name, params)

Load model and make predictions again

The code below demonstrates how the saved file can be reloaded.

config = {
    "in_size":n_input_features,
    "out_size":n_output_features,
    "width_size":n_hidden_features,
    "depth":n_hidden_layers,
    "activation_str":"sigmoid",
    "file_name": file_name,
}

loaded_model = MLP(**config)

# create a base context that can be modified
with loaded_model.input_ports[0].fixed(jnp.zeros(n_input_features)):
    base_context_loaded = loaded_model.create_context()

# Function to evaluate the output of the MLP given input x (base context with params does not need to be updated)
def evaluate_loaded_model(x, base_context_loaded):
    with loaded_model.input_ports[0].fixed(x):
        y = loaded_model.output_ports[0].eval(base_context_loaded)
    return y

pred_y_train_loaded = jax.vmap(evaluate_loaded_model, in_axes=(0,None))(train_x, base_context_loaded)
pred_y_test_loaded = jax.vmap(evaluate_loaded_model, in_axes=(0,None))(test_x, base_context_loaded)

assert jnp.allclose(pred_y_train, pred_y_train_loaded)
assert jnp.allclose(pred_y_test, pred_y_test_loaded)