Add more early stopping options #6795

Closed
jlperla opened this issue Apr 1, 2021 · 5 comments
jlperla commented Apr 1, 2021

🚀 Feature

Additional early stopping features

@tchaton

First, a max_time option, which should probably sit alongside max_epochs in the main trainer loop. Why an additional limit? Because (1) you never know in advance how long an epoch will take - especially when you are tinkering with hyperparameters; and (2) sometimes you want to allot a fixed amount of time and see which version of the model does best within it.
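
A minimal sketch of how this might look from the user's side; the max_time argument and the timedelta format are the proposed interface here, not an existing Trainer argument at the time of writing:

    from datetime import timedelta
    import pytorch_lightning as pl

    # hypothetical usage: cap the whole run at 12 hours of wall-clock time,
    # regardless of how many epochs that turns out to be
    trainer = pl.Trainer(
        max_epochs=1000,               # epoch limit still applies
        max_time=timedelta(hours=12),  # proposed wall-clock limit
    )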

Second, a few variations on the EarlyStopping callback which is based on a metric.

  1. A way to stop because things have diverged completely and you doubt the model can recover (e.g. the metric is too big, too small, or isnan).
  2. A way to stop because things have converged completely in terms of the quality of the approximation, so there is no point doing further iterations. This is distinct from stopping because the metric has ceased to improve - which is what the callback currently does.

For both of these, I think it is useful to have a min_epochs or similar option to ensure that it doesn't stop right away. I think that is what #6705 is supposed to do though, so it isn't needed here?

Finally, it would be great if PL had a way to log the reason for stopping so that it can be seen in the logs and is available within the Grid experiments view. I am not sure of the best way to do that, but maybe the callback could save a string in the logs?
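
As a rough sketch of the interface being proposed here (the stopping_threshold and divergence_threshold parameter names follow the implementation below and were not an existing EarlyStopping API at the time of writing):

    from pytorch_lightning.callbacks import EarlyStopping

    # hypothetical configuration: declare success once the monitored value
    # drops below 1e-6, and fail fast if it blows up past 1e6 (or goes NaN)
    early_stop = EarlyStopping(
        monitor="residual",
        mode="min",
        stopping_threshold=1e-6,
        divergence_threshold=1e6,
    )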

Implementation

I implemented these two features in the current callback with something like:

    # assumes `from time import time`, `import numpy as np`, `import torch`
    def __init__(
        self,
        monitor: str = 'early_stop_on',
        min_delta: float = 0.0,
        patience: int = 3,
        verbose: bool = False,
        mode: str = 'min',
        strict: bool = True,
        stopping_threshold: float = 0.0,
        divergence_threshold: float = 1e6,
    ):
        super().__init__()
        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.strict = strict
        self.min_delta = min_delta
        self.wait_count = 0
        self.stopped_epoch = 0
        self.mode = mode
        self.warned_result_obj = False

        self.__init_monitor_mode()

        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
        torch_inf = torch.tensor(np.Inf)
        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf

        # ADDED: thresholds for convergence (success) and divergence (failure),
        # plus wall-clock bookkeeping for a time-based limit
        self.stopping_threshold = stopping_threshold
        self.divergence_threshold = divergence_threshold
        self.last_time = time()
        self.elapsed_time = 0.0
Then I added something like:

    def _run_early_stopping_check(self, trainer, pl_module):
        """
        Checks whether the early stopping condition is met
        and if so tells the trainer to stop the training.
        """
        # ADDED: track elapsed wall-clock time between checks
        self.elapsed_time += time() - self.last_time
        self.last_time = time()

        logs = trainer.callback_metrics

        if (
            trainer.fast_dev_run  # disable early_stopping with fast_dev_run
            or not self._validate_condition_metric(logs)  # short circuit if metric not present
        ):
            return

        current = logs.get(self.monitor)

        # when in dev debugging
        trainer.dev_debugger.track_early_stopping_history(self, current)

        if self.monitor_op(current - self.min_delta, self.best_score):
            self.best_score = current
            self.wait_count = 0
        else:
            self.wait_count += 1

        # ADDED: three distinct stop reasons, each reported to stdout
        # (OKCYAN / ENDC are ANSI colour-escape constants defined elsewhere)
        if self.wait_count >= self.patience:
            # no improvement: exceeded the patience budget
            self.stopped_epoch = trainer.current_epoch
            trainer.should_stop = True
            print(f"\n{OKCYAN}Stopping. Above patience of {self.patience} epochs without improvement of {self.min_delta}{ENDC}")
        elif self.monitor_op(current, self.stopping_threshold):
            # converged: the metric has crossed the success threshold
            self.stopped_epoch = trainer.current_epoch
            trainer.should_stop = True
            print(f"\n{OKCYAN}Stopping. Below tolerance {self.monitor} = {logs[self.monitor]} <= {self.stopping_threshold}{ENDC}")
        elif self.monitor_op(-current, -self.divergence_threshold) or torch.isnan(current):
            # diverged: the metric has crossed the failure threshold or is NaN
            self.stopped_epoch = trainer.current_epoch
            trainer.should_stop = True
            print(f"\n{OKCYAN}Stopping. Divergence {self.monitor} = {logs[self.monitor]} >= {self.divergence_threshold}{ENDC}")
        # stop every ddp process if any world process decides to stop
        trainer.should_stop = trainer.training_type_plugin.reduce_boolean_decision(trainer.should_stop)
jlperla added the feature (Is an improvement or enhancement) and help wanted (Open to be worked on) labels on Apr 1, 2021
tchaton added the priority: 0 (High priority task) label on Apr 1, 2021
awaelchli (Contributor) commented

For both of these, I think it is useful to have a min_epochs or similar option to ensure that it doesn't stop right away. I think that is what #6705 is supposed to do though, so it isn't needed here?

Yes, with that fix, min_epochs / min_steps in the Trainer will force that amount of training progress before stopping is allowed.
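
For example (a sketch, assuming that fix; min_epochs is an existing Trainer argument and early_stop is the callback sketched earlier):

    # the callback cannot stop the run before 10 full epochs have completed
    trainer = pl.Trainer(min_epochs=10, callbacks=[early_stop])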

awaelchli (Contributor) commented

@jlperla Thomas assigned me. Work in progress on the time-based stopping is here: #6823, and on the verbose feedback here: #6811.

jlperla commented Apr 5, 2021

Amazing, thanks!

awaelchli commented Apr 6, 2021

For the convergence/divergence, I could imagine something like the following (which is what I believe you implemented above):

    # stop training if val_acc reaches 0.95 (conditioned on patience??)
    EarlyStopping(monitor="val_acc", mode="max", boundary=0.95)

    # stop training if val_acc reaches a value below 0.1 (divergence) or a value above 0.95.
    # Continue training as long as we are within the band
    EarlyStopping(monitor="val_acc", mode="max", boundary=(0.1, 0.95))

However, I must say I am a bit skeptical about the usefulness of such thresholding. In my opinion, the patience + min_delta criterion sufficiently covers both the convergence and the divergence cases.

jlperla commented Apr 6, 2021

However, I must say I am a bit skeptical about the usefulness of such thresholding. In my opinion, the patience + min_delta criterion sufficiently covers both the convergence and the divergence cases.

I have no doubt that is true in your applications. But I think it is important to understand the use case that makes this essential before designing any code around it. To me, the patience criterion is only useful for stopping a failing experiment early so you can try something different.

Forget completely about computer vision or whatever your application is, and even forget about machine learning. Think back to when you have used a serious optimizer or solver to maximize or minimize a function, solve a system of equations, etc. For that, look at https://nlopt.readthedocs.io/en/latest/NLopt_Reference/#return-values or https://www.artelys.com/docs/knitro/3_referenceManual/knitromatlabReference.html#return-codes-exit-flags or https://coin-or.github.io/Ipopt/OUTPUT.html or pretty much any of them.

In many applications, PyTorch Lightning is used in a way very similar to these optimizers. For example, you know exactly what the solution should be (e.g. everything should be zero if solving a big system of equations, the gradient should be zero if solving an optimization problem, etc.). For anything I do now, and likely will ever do, I am effectively trying to solve a big system of nonlinear equations and am evaluating the residual. It isn't successful unless that residual is below some stopping threshold, 1e-6 or whatever.
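
For concreteness, a minimal sketch of that use case (the module and the residual_norm helper are made up for illustration; self.log is the standard Lightning logging call):

    import pytorch_lightning as pl

    class NonlinearSystemSolver(pl.LightningModule):
        """Illustrative module whose 'loss' is a residual we want driven to ~0."""

        def validation_step(self, batch, batch_idx):
            residual = self.residual_norm(batch)  # hypothetical helper
            # a stopping_threshold-style criterion would watch this logged value
            # and declare success once it falls below, say, 1e-6
            self.log("residual", residual)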

Just to be clear on the analogy here to early stopping.

  • epoch > max_epochs <-> Iteration limit reached etc. in solvers
  • time > max_time <-> Time limit reached in the solvers
  • metric < stopping_tolerance (or > stopping_tolerance, depending on the mode) <-> Locally optimal solution found. This is really the only success.
  • no improvement beyond the patience level <-> Current feasible solution estimate cannot be improved. This may or may not be acceptable as a solution; typically it wouldn't be, or you would analyze it and then change your stopping_tolerance if it was.

Hopefully that helps. These are just completely different use cases than you may be used to, but trust me when I say they are not particular to my usage. Anyone who uses PL as an optimizer would need this stuff.

@luiscape I think it might make sense for you to get involved a little here, to think about how this would work with the sort of Grid controller we discussed. Multistart optimization methods / hyperparameter tuning would need this sort of thing.


So with that, the boundary approach is not the right interface, for a few reasons:

  • Divergence and convergence are very different: one is a failure and the other a success.
  • You would almost always leave whatever defaults are in place for the divergence criterion.
  • It would make it very confusing to interpret return codes if success and failure were handled symmetrically.
