DeepEnsemble: training networks in parallel #897

Closed · wants to merge 10 commits
247 changes: 229 additions & 18 deletions tests/unit/models/keras/test_models.py
@@ -42,6 +42,7 @@
negative_log_likelihood,
sample_with_replacement,
)
from trieste.models.keras.builders import build_keras_ensemble
from trieste.models.optimizer import KerasOptimizer, TrainingData
from trieste.models.utils import (
get_last_optimization_result,
@@ -538,25 +539,40 @@ def test_deep_ensemble_prepare_data_call(
model, _, _ = trieste_deep_ensemble_model(example_data, ensemble_size, bootstrap_data, False)

# call with whole dataset
data = model.prepare_dataset(example_data)
assert isinstance(data, tuple)
for ensemble_data in data:
assert isinstance(ensemble_data, dict)
assert len(ensemble_data.keys()) == ensemble_size
for member_data in ensemble_data:
if bootstrap_data:
assert tf.reduce_any(ensemble_data[member_data] != x)
else:
assert tf.reduce_all(ensemble_data[member_data] == x)
for inp, out in zip(data[0], data[1]):
assert "".join(filter(str.isdigit, inp)) == "".join(filter(str.isdigit, out))
inputs, outputs = model.prepare_dataset(example_data)

# Check input shape and structure
assert isinstance(inputs, tf.Tensor)
assert inputs.shape[0] == ensemble_size # First dimension is ensemble size
assert inputs.shape[1:] == x.shape # Rest of dimensions match original data

# Check outputs structure
assert isinstance(outputs, list)
assert len(outputs) == ensemble_size

# Check bootstrapping behavior
for i in range(ensemble_size):
if bootstrap_data:
# Both inputs and outputs should be different from original data
assert tf.reduce_any(inputs[i] != x)
assert tf.reduce_any(outputs[i] != y)
# But inputs and outputs should be aligned (same indices used)
input_indices = tf.argsort(inputs[i, :, 0])
output_indices = tf.argsort(outputs[i][:, 0])
tf.assert_equal(input_indices, output_indices)
else:
# Without bootstrapping, data should be identical
tf.assert_equal(inputs[i], x)
tf.assert_equal(outputs[i], y)
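# With bootstrapping each member is trained on a resample-with-replacement of the
# dataset (cf. sample_with_replacement imported above), so some rows almost surely
# differ from the originals, which is what the reduce_any checks above rely on.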

# call with query points alone
inputs = model.prepare_query_points(example_data.query_points)
assert isinstance(inputs, dict)
assert len(inputs.keys()) == ensemble_size
for member_data in inputs:
assert tf.reduce_all(inputs[member_data] == x)
# Test prepare_query_points
query_points = tf.random.uniform([10, 1])
prepared_points = model.prepare_query_points(query_points)
assert prepared_points.shape[0] == ensemble_size
assert prepared_points.shape[1:] == query_points.shape
# All ensemble members should get the same query points for prediction
for i in range(ensemble_size):
tf.assert_equal(prepared_points[i], query_points)


def test_deep_ensemble_deep_copyable() -> None:
@@ -777,3 +793,198 @@ def test_deep_ensemble_log(

assert mocked_summary_scalar.call_count == num_scalars
assert mocked_summary_histogram.call_count == num_histogram


@random_seed
def test_deep_ensemble_parallel_training_performance() -> None:
"""
Verify that doubling the ensemble size does not double training time, as a check that
the networks train in parallel. Some overhead is allowed, but the ratio should be well below 2x.
"""
# Create a larger dataset with more features to increase computation per network
example_data = _get_example_data([100000, 10], [100000, 1]) # 10D input, 1D output

# Test with different ensemble sizes
ensemble_sizes = [5, 10]
ensemble_units = [500, 352]
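# The widths appear chosen to keep the total parameter count roughly constant across
# ensemble sizes: with 3 hidden layers and a 10-D input, a width-500 network has roughly
# 507k parameters (5 x 507k ~ 2.54M) while a width-352 network has roughly 253k (10 x 253k ~ 2.53M).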
training_times = []

for units, size in zip(ensemble_units, ensemble_sizes):
# Create a larger network to increase computation
keras_ensemble = build_keras_ensemble(
example_data,
size,
num_hidden_layers=3, # More layers
units=units, # More units per layer
independent_normal=True # Simpler output distribution
)
optimizer = tf_keras.optimizers.Adam()
fit_args = {
"batch_size": 512, # Larger batch size for better parallelization
"epochs": 3, # More epochs to amortize setup costs
"callbacks": [],
"verbose": 1,
}
optimizer_wrapper = KerasOptimizer(optimizer, fit_args)
model = DeepEnsemble(
keras_ensemble,
optimizer_wrapper,
True,
compile_args={"jit_compile": True} # Enable XLA compilation
)
model.model.summary()

# Time the training
start_time = tf.timestamp()
model.optimize(example_data)
end_time = tf.timestamp()
training_times.append(float(end_time - start_time))

print(f"Training times: {training_times}")
print(f"Time ratio (10/5 networks): {training_times[1] / training_times[0]:.3f}")

# Allow more overhead but still expect significant parallelization benefit
assert training_times[1] / training_times[0] < 1.7, (
f"Training time ratio {training_times[1] / training_times[0]:.3f} suggests "
f"training may not be parallel"
)



@random_seed
def test_pure_tf_ensemble_parallel_training() -> None:
"""
Test parallel training performance using a pure TensorFlow implementation
of a deep ensemble, while keeping the total parameter count constant across ensemble sizes.
"""
# Generate synthetic data
n_points = 100000
input_dim = 10
output_dim = 1
x = tf.random.normal([n_points, input_dim], dtype=tf.float64)
y = tf.random.normal([n_points, output_dim], dtype=tf.float64)

def get_network_width(ensemble_size: int, target_params: int, input_dim: int) -> int:
"""Calculate network width to maintain constant total parameters"""
# Same as before...
a = 2 * ensemble_size
b = ensemble_size * (input_dim + 5)
c = 2 * ensemble_size - target_params

discriminant = b * b - 4 * a * c
if discriminant < 0:
raise ValueError("No real solution exists for network width")

width = int((-b + np.sqrt(discriminant)) / (2 * a))
if width <= 0:
raise ValueError("Calculated width is not positive")

return width
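# For reference, a sketch of the count behind the quadratic above, assuming three
# hidden layers of width w and separate mean/variance output heads (as built below):
#   params per network = (input_dim*w + w) + (w*w + w) + (w*w + w) + (2*w + 2)
#                      = 2*w**2 + (input_dim + 5)*w + 2
# Setting ensemble_size * (2*w**2 + (input_dim + 5)*w + 2) = target_params gives
# a = 2*ensemble_size, b = ensemble_size*(input_dim + 5), c = 2*ensemble_size - target_params.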

def create_network_params(ensemble_size):
"""Create parameters for all networks at once"""
# Calculate width to maintain constant total parameters
# Target total params based on 5 networks with width 500
target_params = 5 * (input_dim * 500 + 500 + 500 * 500 + 500 + 500 * 500 + 500 + 500 * 2 + 2)
width = get_network_width(ensemble_size, target_params, input_dim)
print(f"Using width {width} for ensemble size {ensemble_size} to maintain {target_params} total parameters")

# Initialize all weights in a single tensor for each layer
# Shape: [ensemble_size, in_dim, out_dim]
tf_input_dim = tf.sqrt(tf.constant(float(input_dim), dtype=tf.float64))
tf_width = tf.sqrt(tf.constant(float(width), dtype=tf.float64))
weights = {
'h1': tf.Variable(tf.random.normal([ensemble_size, input_dim, width], dtype=tf.float64) / tf_input_dim),
'h2': tf.Variable(tf.random.normal([ensemble_size, width, width], dtype=tf.float64) / tf_width),
'h3': tf.Variable(tf.random.normal([ensemble_size, width, width], dtype=tf.float64) / tf_width),
'mean': tf.Variable(tf.random.normal([ensemble_size, width, output_dim], dtype=tf.float64) / tf_width),
'var': tf.Variable(tf.random.normal([ensemble_size, width, output_dim], dtype=tf.float64) / tf_width)
}
# Initialize all biases in a single tensor for each layer
# Shape: [ensemble_size, out_dim]
biases = {
'h1': tf.Variable(tf.zeros([ensemble_size, width], dtype=tf.float64)),
'h2': tf.Variable(tf.zeros([ensemble_size, width], dtype=tf.float64)),
'h3': tf.Variable(tf.zeros([ensemble_size, width], dtype=tf.float64)),
'mean': tf.Variable(tf.zeros([ensemble_size, output_dim], dtype=tf.float64)),
'var': tf.Variable(tf.zeros([ensemble_size, output_dim], dtype=tf.float64))
}
return weights, biases

@tf.function(jit_compile=True) # Enable XLA optimization
def forward_parallel(x, weights, biases):
"""Forward pass through all networks in parallel using a single operation"""
# Expand and tile input once: [batch_size, input_dim] -> [ensemble_size, batch_size, input_dim]
batch_size = tf.shape(x)[0]
x = tf.expand_dims(x, 0) # [1, batch_size, input_dim]
x = tf.tile(x, [weights['h1'].shape[0], 1, 1]) # [ensemble_size, batch_size, input_dim]

# Compute all layers in parallel using einsum
# Each operation processes all networks simultaneously
h1 = tf.nn.relu(tf.einsum('ebi,eij->ebj', x, weights['h1']) + tf.expand_dims(biases['h1'], 1))
h2 = tf.nn.relu(tf.einsum('ebi,eij->ebj', h1, weights['h2']) + tf.expand_dims(biases['h2'], 1))
h3 = tf.nn.relu(tf.einsum('ebi,eij->ebj', h2, weights['h3']) + tf.expand_dims(biases['h3'], 1))

# Output layer - compute mean and variance in parallel
mean = tf.einsum('ebi,eij->ebj', h3, weights['mean']) + tf.expand_dims(biases['mean'], 1)
var = tf.nn.softplus(tf.einsum('ebi,eij->ebj', h3, weights['var']) + tf.expand_dims(biases['var'], 1))

return mean, var
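# The einsum 'ebi,eij->ebj' is a batched matmul over the ensemble axis: for each
# member e it computes x[e] @ W[e], equivalent to looping tf.matmul over members
# but expressed as a single fused op.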

@tf.function(jit_compile=True) # Enable XLA optimization
def nll_loss(y_true, mean, var):
"""Vectorized negative log likelihood loss computation"""
# Expand targets once: [batch_size, output_dim] -> [ensemble_size, batch_size, output_dim]
y_expanded = tf.expand_dims(y_true, 0)
y_tiled = tf.tile(y_expanded, [mean.shape[0], 1, 1])

# Compute loss for all networks in a single operation
return 0.5 * tf.reduce_mean(
tf.math.log(var) + tf.square(y_tiled - mean) / var,
axis=[1, 2] # Average over batch and output dimensions
)
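# This is the per-member Gaussian negative log likelihood up to the additive constant
# 0.5*log(2*pi), which is dropped as it does not affect gradients:
#   NLL(y; mu, var) = 0.5 * (log(var) + (y - mu)**2 / var) + 0.5*log(2*pi)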

def train_ensemble(ensemble_size, batch_size=512, epochs=3):
# Create ensemble parameters
weights, biases = create_network_params(ensemble_size)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# Build the input pipeline once: shuffle, batch and prefetch
dataset = tf.data.Dataset.from_tensor_slices((x, y))
dataset = dataset.shuffle(10000).batch(batch_size)
dataset = dataset.prefetch(tf.data.AUTOTUNE) # Enable prefetching

# Training loop
start_time = tf.timestamp()
for epoch in range(epochs):
# Iterating the dataset anew each epoch reshuffles it
for batch_x, batch_y in dataset:
with tf.GradientTape() as tape:
# Forward pass for all networks in parallel
mean, var = forward_parallel(batch_x, weights, biases)
losses = nll_loss(batch_y, mean, var)
total_loss = tf.reduce_mean(losses) # Average loss across ensemble

# Gradient computation and update - single operation for all networks
trainable_vars = list(weights.values()) + list(biases.values())
grads = tape.gradient(total_loss, trainable_vars)
optimizer.apply_gradients(zip(grads, trainable_vars))

end_time = tf.timestamp()
return float(end_time - start_time)

# Test training time for different ensemble sizes
time_5 = train_ensemble(5)
time_10 = train_ensemble(10)

print(f"Training times - 5 networks: {time_5:.2f}s, 10 networks: {time_10:.2f}s")
print(f"Time ratio (10/5 networks): {time_10/time_5:.3f}")

# Check if training scales sub-linearly
assert time_10 / time_5 < 1.7, (
f"Pure TF implementation training time ratio {time_10/time_5:.3f} suggests "
f"training may not be parallel"
)
59 changes: 34 additions & 25 deletions trieste/models/keras/architectures.py
@@ -114,14 +114,30 @@ def ensemble_size(self) -> int:
def _build_ensemble(self) -> tf_keras.Model:
"""
Builds the ensemble model by combining all the individual networks in a single Keras model.
This method relies on ``connect_layers`` method of :class:`KerasEnsembleNetwork` objects
to construct individual networks.
Uses a single input layer with an additional dimension for ensemble members, allowing
different data for each network while enabling parallel training.

:return: The Keras model.
"""
inputs, outputs = zip(*[network.connect_layers() for network in self._networks])
# Create input layer with ensemble dimension as the last dimension
input_shape = self._networks[0].input_tensor_spec.shape + (self.ensemble_size,)
input_tensor = tf_keras.Input(
shape=input_shape,
dtype=self._networks[0].input_tensor_spec.dtype,
name="ensemble_input",
batch_size=None, # Allow dynamic batch size
)

return tf_keras.Model(inputs=inputs, outputs=outputs)
# Connect each network to its slice of the input tensor
outputs = []
for i, network in enumerate(self._networks):
# Extract the input for this network
network_input = tf.gather(input_tensor, i, axis=-1) # Get slice from last dimension
output = network.connect_layers(network_input, i)
outputs.append(output)

# Return model with list of outputs
return tf_keras.Model(inputs=input_tensor, outputs=outputs)
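# A minimal sketch of the expected input layout (hypothetical names), assuming the
# inputs for each member are stacked along a trailing ensemble axis:
#   stacked = tf.stack([x_member_0, x_member_1, ...], axis=-1)  # [batch, input_dim, ensemble_size]
# tf.gather(stacked, i, axis=-1) above then recovers member i's [batch, input_dim] slice.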

def __getstate__(self) -> dict[str, Any]:
# When pickling use to_json to save the model.
@@ -221,13 +237,14 @@ def flattened_output_shape(self) -> int:
return int(np.prod(self.output_tensor_spec.shape))

@abstractmethod
def connect_layers(self) -> tuple[tf.Tensor, tf.Tensor]:
def connect_layers(self, input_tensor: tf.Tensor, ensemble_index: int) -> tf.Tensor:
"""
Connects the layers of the neural network. Architecture, layers and layer specifications
need to be defined by the subclasses.

:return: Input and output tensor of the network, required by :class:`tf.keras.Model` to
build a model.
:param input_tensor: Input tensor to connect the layers to.
:param ensemble_index: Index of this network in the ensemble.
:return: Output tensor of the network.
"""
raise NotImplementedError

@@ -302,15 +319,7 @@ def __init__(
self._hidden_layer_args = hidden_layer_args
self._independent = independent

def _gen_input_tensor(self) -> tf_keras.Input:
input_tensor = tf_keras.Input(
shape=self.input_tensor_spec.shape,
dtype=self.input_tensor_spec.dtype,
name=self.input_layer_name,
)
return input_tensor

def _gen_hidden_layers(self, input_tensor: tf.Tensor) -> tf.Tensor:
def _gen_hidden_layers(self, input_tensor: tf.Tensor, ensemble_index: int) -> tf.Tensor:
for index, hidden_layer_args in enumerate(self._hidden_layer_args):
layer_name = f"{self.network_name}dense_{index}"
layer = tf_keras.layers.Dense(
@@ -354,21 +363,21 @@ def distribution_fn(inputs: TensorType) -> tfp.distributions.Distribution:

return distribution

def connect_layers(self) -> tuple[tf.Tensor, tf.Tensor]:
def connect_layers(self, input_tensor: tf.Tensor, ensemble_index: int) -> tf.Tensor:
"""
Connect all layers in the network. We start by generating an input tensor based on input
tensor specification. Next we generate a sequence of hidden dense layers based on
hidden layer arguments. Finally, we generate a dense layer whose nodes act as parameters of
a Gaussian distribution in the final probabilistic layer.
Connect all layers in the network. We start with the input tensor, generate a sequence
of hidden dense layers based on hidden layer arguments, and finally generate a dense layer
whose nodes act as parameters of a Gaussian distribution in the final probabilistic layer.

:return: Input and output tensor of the sequence of layers.
:param input_tensor: Input tensor to connect the layers to.
:param ensemble_index: Index of this network in the ensemble.
:return: Output tensor of the sequence of layers.
"""
input_tensor = self._gen_input_tensor()
hidden_tensor = self._gen_hidden_layers(input_tensor)
hidden_tensor = self._gen_hidden_layers(input_tensor, ensemble_index)

if self.flattened_output_shape == 1:
output_tensor = self._gen_single_output_layer(hidden_tensor)
else:
output_tensor = self._gen_multi_output_layer(hidden_tensor)

return input_tensor, output_tensor
return output_tensor