apply keras workaround, update tests to incorporate validation
LarsKue committed Jun 11, 2024
1 parent 340f2a6 commit 7cc0529
Showing 4 changed files with 37 additions and 6 deletions.
@@ -1,5 +1,6 @@
 
 import keras
+import warnings
 
 from bayesflow.experimental.configurators import BaseConfigurator
 from bayesflow.experimental.networks import InferenceNetwork, SummaryNetwork
@@ -23,7 +24,23 @@ def train_step(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
         raise NotImplementedError
 
     def test_step(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
-        return self.compute_metrics(data)
+        metrics = self.compute_metrics(data, stage="validation")
+        self._loss_tracker.update_state(metrics["loss"])
+        return metrics
+
+    def evaluate(self, *args, **kwargs):
+        val_logs = super().evaluate(*args, **kwargs)
+
+        if val_logs is None:
+            # https://github.com/keras-team/keras/issues/19835
+            warnings.warn(f"Found no validation logs due to a bug in keras. "
+                          f"Applying workaround, but incorrect loss values may be logged. "
+                          f"If possible, increase the size of your dataset, "
+                          f"or lower the number of validation steps used.")
+
+            val_logs = {}
+
+        return val_logs
 
     # noinspection PyMethodOverriding
     def compute_metrics(self, data: dict[str, Tensor], stage: str = "training") -> dict[str, Tensor]:
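The evaluate override above works around keras.Model.evaluate returning None when zero validation steps execute (keras-team/keras#19835). A minimal standalone sketch of the same pattern, assuming any compiled keras.Model; safe_evaluate is an illustrative name, not part of the commit:

import warnings

import keras

def safe_evaluate(model: keras.Model, *args, **kwargs):
    # Model.evaluate may return None instead of a logs dict when no
    # validation batches run; substitute an empty dict so downstream code
    # that expects a mapping does not crash.
    logs = model.evaluate(*args, **kwargs)
    if logs is None:
        warnings.warn("evaluate() returned no logs; substituting an empty dict.")
        logs = {}
    return logs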
@@ -22,4 +22,6 @@ def train_step(self, data):
         with torch.no_grad():
             self.optimizer.apply(gradients, trainable_weights)
 
+        self._loss_tracker.update_state(loss)
+
         return metrics
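This hunk feeds the batch loss into self._loss_tracker so that keras can report an epoch-level "loss" in progress bars and history. A small sketch of how such a tracker aggregates values, assuming it is a keras.metrics.Mean as in stock keras models (an assumption; the attribute's type is not shown in this diff):

import keras

loss_tracker = keras.metrics.Mean(name="loss")
loss_tracker.update_state(0.5)  # loss from batch 1
loss_tracker.update_state(0.3)  # loss from batch 2
print(float(loss_tracker.result()))  # 0.4, the running mean logged per epoch
loss_tracker.reset_state()  # keras resets trackers at the start of each epoch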
tests/test_two_moons/conftest.py (14 changes: 11 additions & 3 deletions)
@@ -28,9 +28,17 @@ def sample(self, batch_shape):
 
 
 @pytest.fixture()
-def dataset(simulator):
-    from bayesflow.experimental.datasets import OnlineDataset
-    return OnlineDataset(simulator, workers=4, max_queue_size=16, batch_size=16)
+def train_dataset(simulator, batch_size):
+    from bayesflow.experimental.datasets import OfflineDataset
+    data = simulator.sample((16 * batch_size,))
+    return OfflineDataset(data, workers=4, max_queue_size=16, batch_size=batch_size)
+
+
+@pytest.fixture()
+def validation_dataset(simulator, batch_size):
+    from bayesflow.experimental.datasets import OfflineDataset
+    data = simulator.sample((4 * batch_size,))
+    return OfflineDataset(data, workers=4, max_queue_size=16, batch_size=batch_size)
 
 
 @pytest.fixture()
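The fixtures above replace the single OnlineDataset with two fixed-size OfflineDataset splits. A quick illustration of the resulting sizes, assuming batch_size = 16 (an assumption; the real value comes from a pytest fixture not shown here):

batch_size = 16  # assumed; supplied by a fixture in the actual test suite
train_samples = 16 * batch_size       # 256 draws -> 16 training batches
validation_samples = 4 * batch_size   # 64 draws  -> 4 validation batches
# A non-empty validation split guarantees evaluate() runs at least one step,
# which avoids the zero-step case behind the keras workaround above.
print(train_samples // batch_size, validation_samples // batch_size)  # 16 4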
tests/test_two_moons/test_fit.py (8 changes: 6 additions & 2 deletions)
@@ -5,10 +5,14 @@
 from tests.utils import InterruptFitCallback, FitInterruptedError
 
 
-def test_fit(approximator, dataset):
+def test_fit(approximator, train_dataset, validation_dataset):
     # TODO: verify the model learns something by comparing a metric before and after training
     approximator.compile(optimizer="AdamW")
-    approximator.fit(dataset, epochs=10, steps_per_epoch=10)
+    approximator.fit(
+        x=train_dataset,
+        validation_data=validation_dataset,
+        epochs=2,
+    )
 
 
 @pytest.mark.skip(reason="not implemented")
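One way to resolve the TODO above is to compare the validation loss before and after fitting. A hedged sketch, not part of the commit; it assumes the approximator behaves like a keras.Model, where evaluate(..., return_dict=True) returns a metrics dict:

def test_fit_improves_loss(approximator, train_dataset, validation_dataset):
    # Hypothetical test: fitting should reduce validation loss relative
    # to the untrained model. Epoch count is illustrative.
    approximator.compile(optimizer="AdamW")
    before = approximator.evaluate(validation_dataset, return_dict=True)["loss"]
    approximator.fit(x=train_dataset, validation_data=validation_dataset, epochs=2)
    after = approximator.evaluate(validation_dataset, return_dict=True)["loss"]
    assert after < before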
