Commit e57b138

Added metrics from custom train_step/test_step are now returned. (#19529)

This works the same way as in Keras 2: if the set of keys returned by the custom step doesn't match the model's metrics, the step's logs are returned directly instead of the aggregated metric results.
hertschuh authored Apr 17, 2024
1 parent 1937d48 commit e57b138
Showing 6 changed files with 125 additions and 19 deletions.
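Before the per-file diffs, here is a minimal illustrative sketch of the Keras 2-style behavior this commit restores: a custom train_step that adds an extra entry to the logs it returns. This is not part of the commit; MyModel, the data shapes, and the hyperparameters are made up, and it assumes the TensorFlow or PyTorch backend (the JAX backend uses the stateless train_step(self, state, data) signature, as the JaxCustomTrainTestStepModel in the new test file below shows).

import numpy as np
import keras

class MyModel(keras.Model):
    def train_step(self, data):
        # Run the built-in training step, then append an extra entry to
        # the per-batch logs.
        logs = super().train_step(data)
        logs["my_custom_metric"] = 10.0
        return logs

inputs = keras.Input(shape=(4,))
outputs = keras.layers.Dense(3)(inputs)
model = MyModel(inputs, outputs)
model.compile(optimizer="sgd", loss="mse", metrics=["mean_squared_error"])

x = np.ones((32, 4))
y = np.zeros((32, 3))
history = model.fit(x, y, batch_size=16)
# With this change, history.history contains "my_custom_metric" alongside
# "loss" and "mean_squared_error"; previously the extra key was dropped
# because epoch logs were rebuilt solely from the model's own metrics.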
12 changes: 7 additions & 5 deletions keras/src/backend/jax/trainer.py
@@ -437,7 +437,8 @@ def fit(
}

# Callbacks
callbacks.on_train_batch_end(step, self._pythonify_logs(logs))
logs = self._pythonify_logs(logs)
callbacks.on_train_batch_end(step, logs)
if self.stop_training:
break

@@ -446,12 +447,12 @@ def fit(
# bottleneck.
self.jax_state_sync()

# Override with model metrics instead of last step logs
# Override with model metrics instead of last step logs if needed.
# The jax spmd_mode is need for multi-process context, since the
# metrics values are replicated, and we don't want to do a all
# gather, and only need the local copy of the value.
with jax.spmd_mode("allow_all"):
epoch_logs = self.get_metrics_result()
epoch_logs = dict(self._get_metrics_result_or_logs(logs))

# Run validation.
if validation_data is not None and self._should_eval(
@@ -585,7 +586,8 @@ def evaluate(
"non_trainable_variables": non_trainable_variables,
"metrics_variables": metrics_variables,
}
callbacks.on_test_batch_end(step, self._pythonify_logs(logs))
logs = self._pythonify_logs(logs)
callbacks.on_test_batch_end(step, logs)
if self.stop_evaluating:
break

@@ -596,7 +598,7 @@ def evaluate(
# metrics values are replicated, and we don't want to do a all
# gather, and only need the local copy of the value.
with jax.spmd_mode("allow_all"):
logs = self.get_metrics_result()
logs = self._get_metrics_result_or_logs(logs)
callbacks.on_test_end(logs)
self._jax_state = None
if return_dict:
5 changes: 3 additions & 2 deletions keras/src/backend/numpy/trainer.py
@@ -269,10 +269,11 @@ def evaluate(
for step, data in epoch_iterator.enumerate_epoch():
callbacks.on_test_batch_begin(step)
logs = self.test_function(data)
callbacks.on_test_batch_end(step, self._pythonify_logs(logs))
logs = self._pythonify_logs(logs)
callbacks.on_test_batch_end(step, logs)
if self.stop_evaluating:
break
logs = self.get_metrics_result()
logs = self._get_metrics_result_or_logs(logs)
callbacks.on_test_end(logs)

if return_dict:
14 changes: 7 additions & 7 deletions keras/src/backend/tensorflow/trainer.py
@@ -312,14 +312,13 @@ def fit(
for step, iterator in epoch_iterator.enumerate_epoch():
callbacks.on_train_batch_begin(step)
logs = self.train_function(iterator)
callbacks.on_train_batch_end(
step, self._pythonify_logs(logs)
)
logs = self._pythonify_logs(logs)
callbacks.on_train_batch_end(step, logs)
if self.stop_training:
break

# Override with model metrics instead of last step logs
epoch_logs = self.get_metrics_result()
# Override with model metrics instead of last step logs if needed.
epoch_logs = dict(self._get_metrics_result_or_logs(logs))

# Run validation.
if validation_data is not None and self._should_eval(
@@ -424,10 +423,11 @@ def evaluate(
for step, iterator in epoch_iterator.enumerate_epoch():
callbacks.on_test_batch_begin(step)
logs = self.test_function(iterator)
callbacks.on_test_batch_end(step, self._pythonify_logs(logs))
logs = self._pythonify_logs(logs)
callbacks.on_test_batch_end(step, logs)
if self.stop_evaluating:
break
logs = self.get_metrics_result()
logs = self._get_metrics_result_or_logs(logs)
callbacks.on_test_end(logs)

if return_dict:
12 changes: 7 additions & 5 deletions keras/src/backend/torch/trainer.py
@@ -252,14 +252,15 @@ def fit(
callbacks.on_train_batch_begin(step)

logs = self.train_function(data)
logs = self._pythonify_logs(logs)

# Callbacks
callbacks.on_train_batch_end(step, self._pythonify_logs(logs))
callbacks.on_train_batch_end(step, logs)
if self.stop_training:
break

# Override with model metrics instead of last step logs
epoch_logs = self.get_metrics_result()
# Override with model metrics instead of last step logs if needed.
epoch_logs = dict(self._get_metrics_result_or_logs(logs))

# Switch the torch Module back to testing mode.
self.eval()
@@ -368,10 +369,11 @@ def evaluate(
for step, data in epoch_iterator.enumerate_epoch():
callbacks.on_test_batch_begin(step)
logs = self.test_function(data)
callbacks.on_test_batch_end(step, self._pythonify_logs(logs))
logs = self._pythonify_logs(logs)
callbacks.on_test_batch_end(step, logs)
if self.stop_evaluating:
break
logs = self.get_metrics_result()
logs = self._get_metrics_result_or_logs(logs)
callbacks.on_test_end(logs)

if return_dict:
31 changes: 31 additions & 0 deletions keras/src/trainers/trainer.py
@@ -899,6 +899,37 @@ def _pythonify_logs(self, logs):
result[key] = value
return result

def _get_metrics_result_or_logs(self, logs):
"""Returns model metrics as a dict if the keys match with input logs.
When the training / evalution is performed with an asynchronous steps,
the last scheduled `train / test_step` may not give the latest metrics
because it is not guaranteed to be executed the last. This method gets
metrics from the model directly instead of relying on the return from
last step function.
When the user has custom train / test step functions, the metrics
returned may be different from `Model.metrics`. In those instances,
this function will be no-op and return the logs passed in.
Args:
logs: A `dict` of metrics returned by train / test step function.
Returns:
A `dict` containing values of the metrics listed in `self.metrics`
when logs and model metrics keys match. Otherwise it returns input
`logs`.
"""
metric_logs = self.get_metrics_result()
# Verify that train / test step logs passed and metric logs have
# matching keys. It could be different when using custom step functions,
# in which case we return the logs from the last step.
if isinstance(logs, dict) and set(logs.keys()) == set(
metric_logs.keys()
):
return metric_logs
return logs

def _flatten_metrics_in_order(self, logs):
"""Turns `logs` dict into a list as per key order of `metrics_names`."""
metric_names = []
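The docstring of the new helper above boils down to a key-set comparison. As a rough standalone illustration of that decision (not the Keras implementation; pick_epoch_logs, epoch_metrics, and step_logs are invented names):

def pick_epoch_logs(epoch_metrics, step_logs):
    # Prefer the aggregated model metrics, but only when the step returned
    # exactly the same set of keys; otherwise pass the step's logs through.
    if isinstance(step_logs, dict) and set(step_logs) == set(epoch_metrics):
        return epoch_metrics
    return step_logs

# Keys match: the aggregated metrics win (the last asynchronously scheduled
# step may not reflect the full epoch).
pick_epoch_logs({"loss": 0.10, "mse": 0.20}, {"loss": 0.12, "mse": 0.25})
# -> {"loss": 0.10, "mse": 0.20}

# Extra custom key: the step's own logs are returned unchanged.
pick_epoch_logs(
    {"loss": 0.10, "mse": 0.20},
    {"loss": 0.12, "mse": 0.25, "my_custom_metric": 10.0},
)
# -> {"loss": 0.12, "mse": 0.25, "my_custom_metric": 10.0}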
70 changes: 70 additions & 0 deletions keras/src/trainers/trainer_test.py
@@ -44,6 +44,30 @@ def __init__(self, units):
Trainer.__init__(self)


class CustomTrainTestStepModel(ExampleModel):
def train_step(self, data):
logs = super().train_step(data)
logs["my_custom_metric"] = 10.0
return logs

def test_step(self, data):
logs = super().test_step(data)
logs["my_custom_metric"] = 5.0
return logs


class JaxCustomTrainTestStepModel(ExampleModel):
def train_step(self, state, data):
logs, state = super().train_step(state, data)
logs["my_custom_metric"] = 10.0
return logs, state

def test_step(self, state, data):
logs, state = super().test_step(state, data)
logs["my_custom_metric"] = 5.0
return logs, state


class StructModel(Trainer, layers.Layer):
def __init__(self, units):
layers.Layer.__init__(self)
@@ -308,6 +332,27 @@ def test_fit_with_val_split(
self.assertIn("loss", history)
self.assertIn("val_loss", history)

@pytest.mark.requires_trainable_backend
def test_fit_with_custom_train_step(self):
if backend.backend() == "jax":
model = JaxCustomTrainTestStepModel(units=3)
else:
model = CustomTrainTestStepModel(units=3)
x = np.ones((100, 4))
y = np.zeros((100, 3))
batch_size = 16

model.compile(
optimizer=optimizers.SGD(),
loss=losses.MeanSquaredError(),
metrics=[metrics.MeanSquaredError()],
)
history = model.fit(x, y, batch_size=batch_size)
history = history.history
self.assertIn("loss", history)
self.assertIn("mean_squared_error", history)
self.assertAllClose(history["my_custom_metric"], 10.0)

@parameterized.named_parameters(
named_product(
generator_type=["tf", "jax", "scipy"], mode=["eager", "graph"]
@@ -375,6 +420,31 @@ def test_evaluate_flow(self, run_eagerly, jit_compile):
self.assertIn("mean_squared_error", output)
self.assertAllClose(output["mean_squared_error"], 16.0)

@parameterized.named_parameters([("flat", False), ("dict", True)])
@pytest.mark.requires_trainable_backend
def test_evaluate_with_custom_test_step(self, return_dict):
if backend.backend() == "jax":
model = JaxCustomTrainTestStepModel(units=3)
else:
model = CustomTrainTestStepModel(units=3)
x = np.ones((100, 4))
y = np.zeros((100, 3))
batch_size = 16

model.compile(
optimizer=optimizers.SGD(),
loss=losses.MeanSquaredError(),
metrics=[metrics.MeanSquaredError()],
)
output = model.evaluate(
x, y, batch_size=batch_size, return_dict=return_dict
)
self.assertLen(output, 3)
if return_dict:
self.assertAllClose(output["my_custom_metric"], 5.0)
else:
self.assertAllClose(output[-1], 5.0) # Custom metrics go last.

@parameterized.named_parameters(
named_product(
generator_type=["tf", "jax", "scipy"], mode=["eager", "graph"]
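The new evaluate test also relies on the existing convention for flat outputs: loss first, then compiled metrics, with custom entries last. A rough sketch of what the test exercises, reusing the CustomTrainTestStepModel and the setup defined in the test file above (non-JAX backends assumed; the values are the ones hard-coded in the test):

model = CustomTrainTestStepModel(units=3)
model.compile(
    optimizer=optimizers.SGD(),
    loss=losses.MeanSquaredError(),
    metrics=[metrics.MeanSquaredError()],
)
x = np.ones((100, 4))
y = np.zeros((100, 3))

# return_dict=True: the custom entry shows up under its own key.
out = model.evaluate(x, y, batch_size=16, return_dict=True)
# e.g. {"loss": ..., "mean_squared_error": ..., "my_custom_metric": 5.0}

# return_dict=False: loss and compiled metrics first, custom entries last.
out = model.evaluate(x, y, batch_size=16, return_dict=False)
# out[-1] == 5.0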
