diff --git a/python/ray/air/checkpoint.py b/python/ray/air/checkpoint.py index 07a2ee7734c1..f1232922f4c3 100644 --- a/python/ray/air/checkpoint.py +++ b/python/ray/air/checkpoint.py @@ -435,6 +435,7 @@ def from_checkpoint(cls, other: "Checkpoint") -> "Checkpoint": generic :py:class:`Checkpoint` object. Examples: + >>> result = TorchTrainer.fit(...) # doctest: +SKIP >>> checkpoint = TorchCheckpoint.from_checkpoint(result.checkpoint) # doctest: +SKIP # noqa: E501 >>> model = checkpoint.get_model() # doctest: +SKIP diff --git a/python/ray/tune/stopper/stopper.py b/python/ray/tune/stopper/stopper.py index 0a29b1fed35d..c61900402c63 100644 --- a/python/ray/tune/stopper/stopper.py +++ b/python/ray/tune/stopper/stopper.py @@ -1,4 +1,5 @@ import abc +from typing import Any, Dict from ray.util.annotations import PublicAPI @@ -15,34 +16,42 @@ class Stopper(abc.ABC): >>> import time >>> from ray import air, tune + >>> from ray.air import session >>> from ray.tune import Stopper >>> >>> class TimeStopper(Stopper): ... def __init__(self): ... self._start = time.time() - ... self._deadline = 5 + ... self._deadline = 5 # Stop all trials after 5 seconds ... ... def __call__(self, trial_id, result): ... return False ... ... def stop_all(self): ... return time.time() - self._start > self._deadline - >>> + ... + >>> def train_fn(config): + ... for i in range(100): + ... time.sleep(1) + ... session.report({"iter": i}) + ... >>> tuner = tune.Tuner( - ... tune.Trainable, - ... tune_config=tune.TuneConfig(num_samples=200), - ... run_config=air.RunConfig(stop=TimeStopper()) + ... train_fn, + ... tune_config=tune.TuneConfig(num_samples=2), + ... run_config=air.RunConfig(stop=TimeStopper()), ... ) - >>> tuner.fit() - == Status ==... + >>> print("[ignore]"); result_grid = tuner.fit() # doctest: +ELLIPSIS + [ignore]... + >>> all(result.metrics["time_total_s"] < 6 for result in result_grid) + True """ - def __call__(self, trial_id, result): + def __call__(self, trial_id: str, result: Dict[str, Any]) -> bool: """Returns true if the trial should be terminated given the result.""" raise NotImplementedError - def stop_all(self): + def stop_all(self) -> bool: """Returns true if the experiment should be terminated.""" raise NotImplementedError @@ -56,28 +65,39 @@ class CombinedStopper(Stopper): Examples: - >>> from ray.tune.stopper import (CombinedStopper, - ... MaximumIterationStopper, TrialPlateauStopper) + >>> import numpy as np + >>> from ray import air, tune + >>> from ray.air import session + >>> from ray.tune.stopper import ( + ... CombinedStopper, + ... MaximumIterationStopper, + ... TrialPlateauStopper, + ... ) >>> >>> stopper = CombinedStopper( ... MaximumIterationStopper(max_iter=20), - ... TrialPlateauStopper(metric="my_metric") + ... TrialPlateauStopper(metric="my_metric"), ... ) - >>> + >>> def train_fn(config): + ... for i in range(25): + ... session.report({"my_metric": np.random.normal(0, 1 - i / 25)}) + ... >>> tuner = tune.Tuner( - ... tune.Trainable, - ... run_config=air.RunConfig(stop=stopper) + ... train_fn, + ... run_config=air.RunConfig(stop=stopper), ... ) - >>> tuner.fit() - == Status ==... + >>> print("[ignore]"); result_grid = tuner.fit() # doctest: +ELLIPSIS + [ignore]... + >>> all(result.metrics["training_iteration"] <= 20 for result in result_grid) + True """ def __init__(self, *stoppers: Stopper): self._stoppers = stoppers - def __call__(self, trial_id, result): + def __call__(self, trial_id: str, result: Dict[str, Any]) -> bool: return any(s(trial_id, result) for s in self._stoppers) - def stop_all(self): + def stop_all(self) -> bool: return any(s.stop_all() for s in self._stoppers)