From c6499e86ba236b640f9bc5d1acc81cc9ce1aef9f Mon Sep 17 00:00:00 2001 From: jagodevreede <2422592+jagodevreede@users.noreply.github.com> Date: Wed, 25 Oct 2023 01:46:38 +0200 Subject: [PATCH] [api] Added Early stopping configuration (#38) (#2806) * [api] Added Early stopping configuration (#38) * [api] Added Builder for Early stopping configuration (#38) * Explicitly set NDManager for dataset in EarlyStoppingListenerTest to make the test run on JDK11 in gradle. --- .../listener/EarlyStoppingListener.java | 281 ++++++++++++++++++ .../listener/EarlyStoppingListenerTest.java | 189 ++++++++++++ .../tests/training/listener/package-info.java | 15 + 3 files changed, 485 insertions(+) create mode 100644 api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java create mode 100644 integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java create mode 100644 integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java diff --git a/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java b/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java new file mode 100644 index 00000000000..6c013c37715 --- /dev/null +++ b/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java @@ -0,0 +1,281 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.training.listener; + +import ai.djl.training.Trainer; +import ai.djl.training.TrainingResult; + +import java.time.Duration; + +/** + * Listener that allows the training to be stopped early if the validation loss is not improving, or + * if time has expired.
+ * + *

Usage: Add this listener to the training config, and add it as the last one. + * + *

+ *  new DefaultTrainingConfig(...)
+ *        .addTrainingListeners(EarlyStoppingListener.builder()
+ *                .setEpochPatience(1)
+ *                .setEarlyStopPctImprovement(1)
+ *                .setMaxDuration(Duration.ofMinutes(42))
+ *                .setMinEpochs(1)
+ *                .build()
+ *        );
+ * 
+ * + *

Then surround the fit with a try catch that catches the {@link + * EarlyStoppingListener.EarlyStoppedException}.
+ * Example: + * + *

+ * try {
+ *   EasyTrain.fit(trainer, 5, trainDataset, testDataset);
+ * } catch (EarlyStoppingListener.EarlyStoppedException e) {
+ *   // handle early stopping
+ *   log.info("Stopped early at epoch {} because: {}", e.getEpoch(), e.getMessage());
+ * }
+ * 
+ * + *
+ * Note: Ensure that Metrics are set on the trainer. + */ +public final class EarlyStoppingListener implements TrainingListener { + private final double objectiveSuccess; + + private final int minEpochs; + private final long maxMillis; + private final double earlyStopPctImprovement; + private final int epochPatience; + + private long startTimeMills; + private double prevLoss; + private int numberOfEpochsWithoutImprovements; + + private EarlyStoppingListener( + double objectiveSuccess, + int minEpochs, + long maxMillis, + double earlyStopPctImprovement, + int earlyStopPatience) { + this.objectiveSuccess = objectiveSuccess; + this.minEpochs = minEpochs; + this.maxMillis = maxMillis; + this.earlyStopPctImprovement = earlyStopPctImprovement; + this.epochPatience = earlyStopPatience; + } + + /** {@inheritDoc} */ + @Override + public void onEpoch(Trainer trainer) { + int currentEpoch = trainer.getTrainingResult().getEpoch(); + // stopping criteria + final double loss = getLoss(trainer.getTrainingResult()); + if (currentEpoch >= minEpochs) { + if (loss < objectiveSuccess) { + throw new EarlyStoppedException( + currentEpoch, + String.format( + "validation loss %s < objectiveSuccess %s", + loss, objectiveSuccess)); + } + long elapsedMillis = System.currentTimeMillis() - startTimeMills; + if (elapsedMillis >= maxMillis) { + throw new EarlyStoppedException( + currentEpoch, + String.format("%s ms elapsed >= %s maxMillis", elapsedMillis, maxMillis)); + } + // consider early stopping? + if (Double.isFinite(prevLoss)) { + double goalImprovement = prevLoss * (100 - earlyStopPctImprovement) / 100.0; + boolean improved = loss <= goalImprovement; // false if any NANs + if (improved) { + numberOfEpochsWithoutImprovements = 0; + } else { + numberOfEpochsWithoutImprovements++; + if (numberOfEpochsWithoutImprovements >= epochPatience) { + throw new EarlyStoppedException( + currentEpoch, + String.format( + "failed to achieve %s%% improvement %s times in a row", + earlyStopPctImprovement, epochPatience)); + } + } + } + } + if (Double.isFinite(loss)) { + prevLoss = loss; + } + } + + private static double getLoss(TrainingResult trainingResult) { + Float vLoss = trainingResult.getValidateLoss(); + if (vLoss != null) { + return vLoss; + } + Float tLoss = trainingResult.getTrainLoss(); + if (tLoss == null) { + return Double.NaN; + } + return tLoss; + } + + /** {@inheritDoc} */ + @Override + public void onTrainingBatch(Trainer trainer, BatchData batchData) { + // do nothing + } + + /** {@inheritDoc} */ + @Override + public void onValidationBatch(Trainer trainer, BatchData batchData) { + // do nothing + } + + /** {@inheritDoc} */ + @Override + public void onTrainingBegin(Trainer trainer) { + this.startTimeMills = System.currentTimeMillis(); + this.prevLoss = Double.NaN; + this.numberOfEpochsWithoutImprovements = 0; + } + + /** {@inheritDoc} */ + @Override + public void onTrainingEnd(Trainer trainer) { + // do nothing + } + + /** + * Creates a builder to build a {@link EarlyStoppingListener}. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** A builder for a {@link EarlyStoppingListener}. */ + public static final class Builder { + private final double objectiveSuccess; + private int minEpochs; + private long maxMillis; + private double earlyStopPctImprovement; + private int epochPatience; + + /** Constructs a {@link Builder} with default values. */ + public Builder() { + this.objectiveSuccess = 0; + this.minEpochs = 0; + this.maxMillis = Long.MAX_VALUE; + this.earlyStopPctImprovement = 0; + this.epochPatience = 0; + } + + /** + * Set the minimum # epochs, defaults to 0. + * + * @param minEpochs the minimum # epochs + * @return this builder + */ + public Builder optMinEpochs(int minEpochs) { + this.minEpochs = minEpochs; + return this; + } + + /** + * Set the maximum duration a training run should take, defaults to Long.MAX_VALUE in ms. + * + * @param duration the maximum duration a training run should take + * @return this builder + */ + public Builder optMaxDuration(Duration duration) { + this.maxMillis = duration.toMillis(); + return this; + } + + /** + * Set the maximum # milliseconds a training run should take, defaults to Long.MAX_VALUE. + * + * @param maxMillis the maximum # milliseconds a training run should take + * @return this builder + */ + public Builder optMaxMillis(int maxMillis) { + this.maxMillis = maxMillis; + return this; + } + + /** + * Consider early stopping if not x% improvement, defaults to 0. + * + * @param earlyStopPctImprovement the percentage improvement to consider early stopping, + * must be between 0 and 100. + * @return this builder + */ + public Builder optEarlyStopPctImprovement(double earlyStopPctImprovement) { + this.earlyStopPctImprovement = earlyStopPctImprovement; + return this; + } + + /** + * Stop if insufficient improvement for x epochs in a row, defaults to 0. + * + * @param epochPatience the number of epochs without improvement to consider stopping, must + * be greater than 0. + * @return this builder + */ + public Builder optEpochPatience(int epochPatience) { + this.epochPatience = epochPatience; + return this; + } + + /** + * Builds a {@link EarlyStoppingListener} with the specified values. + * + * @return a new {@link EarlyStoppingListener} + */ + public EarlyStoppingListener build() { + return new EarlyStoppingListener( + objectiveSuccess, minEpochs, maxMillis, earlyStopPctImprovement, epochPatience); + } + } + + /** + * Thrown when training is stopped early, the message will contain the reason why it is stopped + * early. + */ + public static class EarlyStoppedException extends RuntimeException { + private static final long serialVersionUID = 1L; + private final int stopEpoch; + + /** + * Constructs an {@link EarlyStoppedException} with the specified message and epoch. + * + * @param stopEpoch the epoch at which training was stopped early + * @param message the message/reason why training was stopped early + */ + public EarlyStoppedException(int stopEpoch, String message) { + super(message); + this.stopEpoch = stopEpoch; + } + + /** + * Gets the epoch at which training was stopped early. + * + * @return the epoch at which training was stopped early. + */ + public int getStopEpoch() { + return stopEpoch; + } + } +} diff --git a/integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java b/integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java new file mode 100644 index 00000000000..c3dd34d0369 --- /dev/null +++ b/integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java @@ -0,0 +1,189 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.integration.tests.training.listener; + +import ai.djl.Model; +import ai.djl.basicdataset.cv.classification.Mnist; +import ai.djl.basicmodelzoo.basic.Mlp; +import ai.djl.integration.util.TestUtils; +import ai.djl.metric.Metrics; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Activation; +import ai.djl.training.DefaultTrainingConfig; +import ai.djl.training.EasyTrain; +import ai.djl.training.Trainer; +import ai.djl.training.TrainingResult; +import ai.djl.training.dataset.Dataset; +import ai.djl.training.listener.EarlyStoppingListener; +import ai.djl.training.listener.TrainingListener; +import ai.djl.training.loss.Loss; +import ai.djl.training.optimizer.Optimizer; +import ai.djl.training.tracker.Tracker; +import ai.djl.translate.TranslateException; + +import org.testng.Assert; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.time.Duration; + +public class EarlyStoppingListenerTest { + + private final Optimizer sgd = + Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.1f)).build(); + + private NDManager manager; + private Mnist testMnistDataset; + private Mnist trainMnistDataset; + + @BeforeTest + public void setUp() throws IOException, TranslateException { + manager = NDManager.newBaseManager(TestUtils.getEngine()); + testMnistDataset = + Mnist.builder() + .optUsage(Dataset.Usage.TEST) + .optManager(manager) + .optLimit(8) + .setSampling(8, false) + .build(); + testMnistDataset.prepare(); + + trainMnistDataset = + Mnist.builder() + .optUsage(Dataset.Usage.TRAIN) + .optManager(manager) + .optLimit(16) + .setSampling(8, false) + .build(); + trainMnistDataset.prepare(); + } + + @AfterTest + public void closeResources() { + manager.close(); + } + + @Test + public void testEarlyStoppingStopsOnEpoch2() throws Exception { + Mlp mlpModel = new Mlp(784, 1, new int[] {256}, Activation::relu); + + try (Model model = Model.newInstance("lin-reg", TestUtils.getEngine())) { + model.setBlock(mlpModel); + + DefaultTrainingConfig config = + new DefaultTrainingConfig(Loss.l2Loss()) + .optOptimizer(sgd) + .addTrainingListeners(TrainingListener.Defaults.logging()) + .addTrainingListeners( + EarlyStoppingListener.builder() + .optEpochPatience(1) + .optEarlyStopPctImprovement(99) + .optMaxDuration(Duration.ofMinutes(1)) + .optMinEpochs(1) + .build()); + + try (Trainer trainer = model.newTrainer(config)) { + trainer.initialize(new Shape(1, 784)); + Metrics metrics = new Metrics(); + trainer.setMetrics(metrics); + + try { + // Set epoch to 5 as we expect the early stopping to stop after the second epoch + EasyTrain.fit(trainer, 5, trainMnistDataset, testMnistDataset); + } catch (EarlyStoppingListener.EarlyStoppedException e) { + Assert.assertEquals( + e.getMessage(), "failed to achieve 99.0% improvement 1 times in a row"); + Assert.assertEquals(e.getStopEpoch(), 2); + } + + TrainingResult trainingResult = trainer.getTrainingResult(); + Assert.assertEquals(trainingResult.getEpoch(), 2); + } + } + } + + @Test + public void testEarlyStoppingStopsOnEpoch3AsMinEpochsIs3() throws Exception { + Mlp mlpModel = new Mlp(784, 1, new int[] {256}, Activation::relu); + + try (Model model = Model.newInstance("lin-reg", TestUtils.getEngine())) { + model.setBlock(mlpModel); + + DefaultTrainingConfig config = + new DefaultTrainingConfig(Loss.l2Loss()) + .optOptimizer(sgd) + .addTrainingListeners(TrainingListener.Defaults.logging()) + .addTrainingListeners( + EarlyStoppingListener.builder() + .optEpochPatience(1) + .optEarlyStopPctImprovement(50) + .optMaxMillis(60_000) + .optMinEpochs(3) + .build()); + + try (Trainer trainer = model.newTrainer(config)) { + trainer.initialize(new Shape(1, 784)); + Metrics metrics = new Metrics(); + trainer.setMetrics(metrics); + + try { + // Set epoch to 5 as we expect the early stopping to stop after the second epoch + EasyTrain.fit(trainer, 5, trainMnistDataset, testMnistDataset); + } catch (EarlyStoppingListener.EarlyStoppedException e) { + Assert.assertEquals( + e.getMessage(), "failed to achieve 50.0% improvement 1 times in a row"); + Assert.assertEquals(e.getStopEpoch(), 3); + } + + TrainingResult trainingResult = trainer.getTrainingResult(); + Assert.assertEquals(trainingResult.getEpoch(), 3); + } + } + } + + @Test + public void testEarlyStoppingStopsOnEpoch1AsMaxDurationIs1ms() throws Exception { + Mlp mlpModel = new Mlp(784, 1, new int[] {256}, Activation::relu); + + try (Model model = Model.newInstance("lin-reg", TestUtils.getEngine())) { + model.setBlock(mlpModel); + + DefaultTrainingConfig config = + new DefaultTrainingConfig(Loss.l2Loss()) + .optOptimizer(sgd) + .addTrainingListeners(TrainingListener.Defaults.logging()) + .addTrainingListeners( + EarlyStoppingListener.builder().optMaxMillis(1).build()); + + try (Trainer trainer = model.newTrainer(config)) { + trainer.initialize(new Shape(1, 784)); + Metrics metrics = new Metrics(); + trainer.setMetrics(metrics); + + try { + // Set epoch to 5 as we expect the early stopping to stop after the second epoch + EasyTrain.fit(trainer, 5, trainMnistDataset, testMnistDataset); + } catch (EarlyStoppingListener.EarlyStoppedException e) { + Assert.assertTrue(e.getMessage().contains("ms elapsed >= 1 maxMillis")); + Assert.assertEquals(e.getStopEpoch(), 1); + } + + TrainingResult trainingResult = trainer.getTrainingResult(); + Assert.assertEquals(trainingResult.getEpoch(), 1); + } + } + } +} diff --git a/integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java b/integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java new file mode 100644 index 00000000000..88680e5fe89 --- /dev/null +++ b/integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains tests using the listeners {@link ai.djl.training}. */ +package ai.djl.integration.tests.training.listener;