From c6499e86ba236b640f9bc5d1acc81cc9ce1aef9f Mon Sep 17 00:00:00 2001
From: jagodevreede <2422592+jagodevreede@users.noreply.github.com>
Date: Wed, 25 Oct 2023 01:46:38 +0200
Subject: [PATCH] [api] Added Early stopping configuration (#38) (#2806)
* [api] Added Early stopping configuration (#38)
* [api] Added Builder for Early stopping configuration (#38)
* Explicitly set NDManager for dataset in EarlyStoppingListenerTest to make the test run on JDK11 in gradle.
---
.../listener/EarlyStoppingListener.java | 281 ++++++++++++++++++
.../listener/EarlyStoppingListenerTest.java | 189 ++++++++++++
.../tests/training/listener/package-info.java | 15 +
3 files changed, 485 insertions(+)
create mode 100644 api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java
create mode 100644 integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java
create mode 100644 integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java
diff --git a/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java b/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java
new file mode 100644
index 00000000000..6c013c37715
--- /dev/null
+++ b/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java
@@ -0,0 +1,281 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.training.listener;
+
+import ai.djl.training.Trainer;
+import ai.djl.training.TrainingResult;
+
+import java.time.Duration;
+
+/**
+ * Listener that allows the training to be stopped early if the validation loss is not improving, or
+ * if time has expired.
+ *
+ * <p>Usage: Add this listener to the training config, and add it as the last one.
+ *
+ * <pre>
+ * new DefaultTrainingConfig(...)
+ *        .addTrainingListeners(EarlyStoppingListener.builder()
+ *                .optEpochPatience(1)
+ *                .optEarlyStopPctImprovement(1)
+ *                .optMaxDuration(Duration.ofMinutes(42))
+ *                .optMinEpochs(1)
+ *                .build()
+ *        );
+ * </pre>
+ *
+ * Then surround the fit with a try catch that catches the {@link
+ * EarlyStoppingListener.EarlyStoppedException}.
+ * Example:
+ *
+ * <pre>
+ * try {
+ *     EasyTrain.fit(trainer, 5, trainDataset, testDataset);
+ * } catch (EarlyStoppingListener.EarlyStoppedException e) {
+ *     // handle early stopping
+ *     log.info("Stopped early at epoch {} because: {}", e.getStopEpoch(), e.getMessage());
+ * }
+ * </pre>
+ *
+ *
+ * Note: Ensure that Metrics are set on the trainer.
+ */
+public final class EarlyStoppingListener implements TrainingListener {
+ private final double objectiveSuccess;
+
+ private final int minEpochs;
+ private final long maxMillis;
+ private final double earlyStopPctImprovement;
+ private final int epochPatience;
+
+ private long startTimeMills;
+ private double prevLoss;
+ private int numberOfEpochsWithoutImprovements;
+
+ private EarlyStoppingListener(
+ double objectiveSuccess,
+ int minEpochs,
+ long maxMillis,
+ double earlyStopPctImprovement,
+ int earlyStopPatience) {
+ this.objectiveSuccess = objectiveSuccess;
+ this.minEpochs = minEpochs;
+ this.maxMillis = maxMillis;
+ this.earlyStopPctImprovement = earlyStopPctImprovement;
+ this.epochPatience = earlyStopPatience;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void onEpoch(Trainer trainer) {
+ int currentEpoch = trainer.getTrainingResult().getEpoch();
+ // stopping criteria
+ final double loss = getLoss(trainer.getTrainingResult());
+ if (currentEpoch >= minEpochs) {
+ if (loss < objectiveSuccess) {
+ throw new EarlyStoppedException(
+ currentEpoch,
+ String.format(
+ "validation loss %s < objectiveSuccess %s",
+ loss, objectiveSuccess));
+ }
+ long elapsedMillis = System.currentTimeMillis() - startTimeMills;
+ if (elapsedMillis >= maxMillis) {
+ throw new EarlyStoppedException(
+ currentEpoch,
+ String.format("%s ms elapsed >= %s maxMillis", elapsedMillis, maxMillis));
+ }
+ // consider early stopping?
+ if (Double.isFinite(prevLoss)) {
+ double goalImprovement = prevLoss * (100 - earlyStopPctImprovement) / 100.0;
+ boolean improved = loss <= goalImprovement; // false if any NANs
+ if (improved) {
+ numberOfEpochsWithoutImprovements = 0;
+ } else {
+ numberOfEpochsWithoutImprovements++;
+ if (numberOfEpochsWithoutImprovements >= epochPatience) {
+ throw new EarlyStoppedException(
+ currentEpoch,
+ String.format(
+ "failed to achieve %s%% improvement %s times in a row",
+ earlyStopPctImprovement, epochPatience));
+ }
+ }
+ }
+ }
+ if (Double.isFinite(loss)) {
+ prevLoss = loss;
+ }
+ }
+
+ private static double getLoss(TrainingResult trainingResult) {
+ Float vLoss = trainingResult.getValidateLoss();
+ if (vLoss != null) {
+ return vLoss;
+ }
+ Float tLoss = trainingResult.getTrainLoss();
+ if (tLoss == null) {
+ return Double.NaN;
+ }
+ return tLoss;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void onTrainingBatch(Trainer trainer, BatchData batchData) {
+ // do nothing
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void onValidationBatch(Trainer trainer, BatchData batchData) {
+ // do nothing
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void onTrainingBegin(Trainer trainer) {
+ this.startTimeMills = System.currentTimeMillis();
+ this.prevLoss = Double.NaN;
+ this.numberOfEpochsWithoutImprovements = 0;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void onTrainingEnd(Trainer trainer) {
+ // do nothing
+ }
+
+ /**
+ * Creates a builder to build a {@link EarlyStoppingListener}.
+ *
+ * @return a new builder
+ */
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /** A builder for a {@link EarlyStoppingListener}. */
+ public static final class Builder {
+ private final double objectiveSuccess;
+ private int minEpochs;
+ private long maxMillis;
+ private double earlyStopPctImprovement;
+ private int epochPatience;
+
+ /** Constructs a {@link Builder} with default values. */
+ public Builder() {
+ this.objectiveSuccess = 0;
+ this.minEpochs = 0;
+ this.maxMillis = Long.MAX_VALUE;
+ this.earlyStopPctImprovement = 0;
+ this.epochPatience = 0;
+ }
+
+ /**
+ * Set the minimum # epochs, defaults to 0.
+ *
+ * @param minEpochs the minimum # epochs
+ * @return this builder
+ */
+ public Builder optMinEpochs(int minEpochs) {
+ this.minEpochs = minEpochs;
+ return this;
+ }
+
+ /**
+ * Set the maximum duration a training run should take, defaults to Long.MAX_VALUE in ms.
+ *
+ * @param duration the maximum duration a training run should take
+ * @return this builder
+ */
+ public Builder optMaxDuration(Duration duration) {
+ this.maxMillis = duration.toMillis();
+ return this;
+ }
+
+ /**
+ * Set the maximum # milliseconds a training run should take, defaults to Long.MAX_VALUE.
+ *
+ * @param maxMillis the maximum # milliseconds a training run should take
+ * @return this builder
+ */
+ public Builder optMaxMillis(long maxMillis) {
+ this.maxMillis = maxMillis;
+ return this;
+ }
+
+ /**
+ * Consider early stopping if not x% improvement, defaults to 0.
+ *
+ * @param earlyStopPctImprovement the percentage improvement to consider early stopping,
+ * must be between 0 and 100.
+ * @return this builder
+ */
+ public Builder optEarlyStopPctImprovement(double earlyStopPctImprovement) {
+ this.earlyStopPctImprovement = earlyStopPctImprovement;
+ return this;
+ }
+
+ /**
+ * Stop if insufficient improvement for x epochs in a row, defaults to 0.
+ *
+ * @param epochPatience the number of epochs without improvement to consider stopping, must
+ * be greater than 0.
+ * @return this builder
+ */
+ public Builder optEpochPatience(int epochPatience) {
+ this.epochPatience = epochPatience;
+ return this;
+ }
+
+ /**
+ * Builds a {@link EarlyStoppingListener} with the specified values.
+ *
+ * @return a new {@link EarlyStoppingListener}
+ */
+ public EarlyStoppingListener build() {
+ return new EarlyStoppingListener(
+ objectiveSuccess, minEpochs, maxMillis, earlyStopPctImprovement, epochPatience);
+ }
+ }
+
+ /**
+ * Thrown when training is stopped early, the message will contain the reason why it is stopped
+ * early.
+ */
+ public static class EarlyStoppedException extends RuntimeException {
+ private static final long serialVersionUID = 1L;
+ private final int stopEpoch;
+
+ /**
+ * Constructs an {@link EarlyStoppedException} with the specified message and epoch.
+ *
+ * @param stopEpoch the epoch at which training was stopped early
+ * @param message the message/reason why training was stopped early
+ */
+ public EarlyStoppedException(int stopEpoch, String message) {
+ super(message);
+ this.stopEpoch = stopEpoch;
+ }
+
+ /**
+ * Gets the epoch at which training was stopped early.
+ *
+ * @return the epoch at which training was stopped early.
+ */
+ public int getStopEpoch() {
+ return stopEpoch;
+ }
+ }
+}
diff --git a/integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java b/integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java
new file mode 100644
index 00000000000..c3dd34d0369
--- /dev/null
+++ b/integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java
@@ -0,0 +1,189 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.integration.tests.training.listener;
+
+import ai.djl.Model;
+import ai.djl.basicdataset.cv.classification.Mnist;
+import ai.djl.basicmodelzoo.basic.Mlp;
+import ai.djl.integration.util.TestUtils;
+import ai.djl.metric.Metrics;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.types.Shape;
+import ai.djl.nn.Activation;
+import ai.djl.training.DefaultTrainingConfig;
+import ai.djl.training.EasyTrain;
+import ai.djl.training.Trainer;
+import ai.djl.training.TrainingResult;
+import ai.djl.training.dataset.Dataset;
+import ai.djl.training.listener.EarlyStoppingListener;
+import ai.djl.training.listener.TrainingListener;
+import ai.djl.training.loss.Loss;
+import ai.djl.training.optimizer.Optimizer;
+import ai.djl.training.tracker.Tracker;
+import ai.djl.translate.TranslateException;
+
+import org.testng.Assert;
+import org.testng.annotations.AfterTest;
+import org.testng.annotations.BeforeTest;
+import org.testng.annotations.Test;
+
+import java.io.IOException;
+import java.time.Duration;
+
+public class EarlyStoppingListenerTest {
+
+ private final Optimizer sgd =
+ Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.1f)).build();
+
+ private NDManager manager;
+ private Mnist testMnistDataset;
+ private Mnist trainMnistDataset;
+
+ @BeforeTest
+ public void setUp() throws IOException, TranslateException {
+ manager = NDManager.newBaseManager(TestUtils.getEngine());
+ testMnistDataset =
+ Mnist.builder()
+ .optUsage(Dataset.Usage.TEST)
+ .optManager(manager)
+ .optLimit(8)
+ .setSampling(8, false)
+ .build();
+ testMnistDataset.prepare();
+
+ trainMnistDataset =
+ Mnist.builder()
+ .optUsage(Dataset.Usage.TRAIN)
+ .optManager(manager)
+ .optLimit(16)
+ .setSampling(8, false)
+ .build();
+ trainMnistDataset.prepare();
+ }
+
+ @AfterTest
+ public void closeResources() {
+ manager.close();
+ }
+
+ @Test
+ public void testEarlyStoppingStopsOnEpoch2() throws Exception {
+ Mlp mlpModel = new Mlp(784, 1, new int[] {256}, Activation::relu);
+
+ try (Model model = Model.newInstance("lin-reg", TestUtils.getEngine())) {
+ model.setBlock(mlpModel);
+
+ DefaultTrainingConfig config =
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optOptimizer(sgd)
+ .addTrainingListeners(TrainingListener.Defaults.logging())
+ .addTrainingListeners(
+ EarlyStoppingListener.builder()
+ .optEpochPatience(1)
+ .optEarlyStopPctImprovement(99)
+ .optMaxDuration(Duration.ofMinutes(1))
+ .optMinEpochs(1)
+ .build());
+
+ try (Trainer trainer = model.newTrainer(config)) {
+ trainer.initialize(new Shape(1, 784));
+ Metrics metrics = new Metrics();
+ trainer.setMetrics(metrics);
+
+ try {
+ // Set epoch to 5 as we expect the early stopping to stop after the second epoch
+ EasyTrain.fit(trainer, 5, trainMnistDataset, testMnistDataset);
+ } catch (EarlyStoppingListener.EarlyStoppedException e) {
+ Assert.assertEquals(
+ e.getMessage(), "failed to achieve 99.0% improvement 1 times in a row");
+ Assert.assertEquals(e.getStopEpoch(), 2);
+ }
+
+ TrainingResult trainingResult = trainer.getTrainingResult();
+ Assert.assertEquals(trainingResult.getEpoch(), 2);
+ }
+ }
+ }
+
+ @Test
+ public void testEarlyStoppingStopsOnEpoch3AsMinEpochsIs3() throws Exception {
+ Mlp mlpModel = new Mlp(784, 1, new int[] {256}, Activation::relu);
+
+ try (Model model = Model.newInstance("lin-reg", TestUtils.getEngine())) {
+ model.setBlock(mlpModel);
+
+ DefaultTrainingConfig config =
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optOptimizer(sgd)
+ .addTrainingListeners(TrainingListener.Defaults.logging())
+ .addTrainingListeners(
+ EarlyStoppingListener.builder()
+ .optEpochPatience(1)
+ .optEarlyStopPctImprovement(50)
+ .optMaxMillis(60_000)
+ .optMinEpochs(3)
+ .build());
+
+ try (Trainer trainer = model.newTrainer(config)) {
+ trainer.initialize(new Shape(1, 784));
+ Metrics metrics = new Metrics();
+ trainer.setMetrics(metrics);
+
+ try {
+ // Set epoch to 5 as we expect the early stopping to stop after the third epoch
+ EasyTrain.fit(trainer, 5, trainMnistDataset, testMnistDataset);
+ } catch (EarlyStoppingListener.EarlyStoppedException e) {
+ Assert.assertEquals(
+ e.getMessage(), "failed to achieve 50.0% improvement 1 times in a row");
+ Assert.assertEquals(e.getStopEpoch(), 3);
+ }
+
+ TrainingResult trainingResult = trainer.getTrainingResult();
+ Assert.assertEquals(trainingResult.getEpoch(), 3);
+ }
+ }
+ }
+
+ @Test
+ public void testEarlyStoppingStopsOnEpoch1AsMaxDurationIs1ms() throws Exception {
+ Mlp mlpModel = new Mlp(784, 1, new int[] {256}, Activation::relu);
+
+ try (Model model = Model.newInstance("lin-reg", TestUtils.getEngine())) {
+ model.setBlock(mlpModel);
+
+ DefaultTrainingConfig config =
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optOptimizer(sgd)
+ .addTrainingListeners(TrainingListener.Defaults.logging())
+ .addTrainingListeners(
+ EarlyStoppingListener.builder().optMaxMillis(1).build());
+
+ try (Trainer trainer = model.newTrainer(config)) {
+ trainer.initialize(new Shape(1, 784));
+ Metrics metrics = new Metrics();
+ trainer.setMetrics(metrics);
+
+ try {
+ // Set epoch to 5 as we expect the early stopping to stop after the first epoch
+ EasyTrain.fit(trainer, 5, trainMnistDataset, testMnistDataset);
+ } catch (EarlyStoppingListener.EarlyStoppedException e) {
+ Assert.assertTrue(e.getMessage().contains("ms elapsed >= 1 maxMillis"));
+ Assert.assertEquals(e.getStopEpoch(), 1);
+ }
+
+ TrainingResult trainingResult = trainer.getTrainingResult();
+ Assert.assertEquals(trainingResult.getEpoch(), 1);
+ }
+ }
+ }
+}
diff --git a/integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java b/integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java
new file mode 100644
index 00000000000..88680e5fe89
--- /dev/null
+++ b/integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java
@@ -0,0 +1,15 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+
+/** Contains tests for the training listeners in {@link ai.djl.training.listener}. */
+package ai.djl.integration.tests.training.listener;