diff --git a/neps/env.py b/neps/env.py index cd37d4df..c614ebac 100644 --- a/neps/env.py +++ b/neps/env.py @@ -33,7 +33,6 @@ def is_nullable(e: str) -> bool: parse=str, default="lockf", ) -assert LINUX_FILELOCK_FUNCTION in ("lockf", "flock") MAX_RETRIES_GET_NEXT_TRIAL = get_env( "NEPS_MAX_RETRIES_GET_NEXT_TRIAL", @@ -66,35 +65,17 @@ def is_nullable(e: str) -> bool: default=120, ) -SEED_SNAPSHOT_FILELOCK_POLL = get_env( - "NEPS_SEED_SNAPSHOT_FILELOCK_POLL", +# NOTE: We want this to be greater than the trials filelock, so that +# anything requesting to just update the trials is more likely to obtain it +# as those operations tend to be faster than something that requires optimizer +# state. +STATE_FILELOCK_POLL = get_env( + "NEPS_STATE_FILELOCK_POLL", parse=float, - default=0.05, -) -SEED_SNAPSHOT_FILELOCK_TIMEOUT = get_env( - "NEPS_SEED_SNAPSHOT_FILELOCK_TIMEOUT", - parse=lambda e: None if is_nullable(e) else float(e), - default=120, -) - -OPTIMIZER_INFO_FILELOCK_POLL = get_env( - "NEPS_OPTIMIZER_INFO_FILELOCK_POLL", - parse=float, - default=0.05, -) -OPTIMIZER_INFO_FILELOCK_TIMEOUT = get_env( - "NEPS_OPTIMIZER_INFO_FILELOCK_TIMEOUT", - parse=lambda e: None if is_nullable(e) else float(e), - default=120, -) - -OPTIMIZER_STATE_FILELOCK_POLL = get_env( - "NEPS_OPTIMIZER_STATE_FILELOCK_POLL", - parse=float, - default=0.05, + default=0.20, ) -OPTIMIZER_STATE_FILELOCK_TIMEOUT = get_env( - "NEPS_OPTIMIZER_STATE_FILELOCK_TIMEOUT", +STATE_FILELOCK_TIMEOUT = get_env( + "NEPS_STATE_FILELOCK_TIMEOUT", parse=lambda e: None if is_nullable(e) else float(e), default=120, ) diff --git a/neps/plot/tensorboard_eval.py b/neps/plot/tensorboard_eval.py index 380ad6b4..816f16eb 100644 --- a/neps/plot/tensorboard_eval.py +++ b/neps/plot/tensorboard_eval.py @@ -94,7 +94,7 @@ def _initiate_internal_configurations() -> None: register_notify_trial_end("NEPS_TBLOGGER", tblogger.end_of_config) # We are assuming that neps state is all filebased here - root_dir = Path(neps_state.location) + root_dir = Path(neps_state.path) assert root_dir.exists() tblogger.config_working_directory = Path(trial.metadata.location) diff --git a/neps/runtime.py b/neps/runtime.py index fb3e1e9b..308a2560 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -33,14 +33,13 @@ WorkerRaiseError, ) from neps.state._eval import evaluate_trial -from neps.state.filebased import create_or_load_filebased_neps_state +from neps.state.neps_state import NePSState from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo from neps.state.settings import DefaultReportValues, OnErrorPossibilities, WorkerSettings +from neps.state.trial import Trial if TYPE_CHECKING: from neps.optimizers.base_optimizer import BaseOptimizer - from neps.state.neps_state import NePSState - from neps.state.trial import Trial logger = logging.getLogger(__name__) @@ -64,7 +63,7 @@ def _default_worker_name() -> str: # TODO: This only works with a filebased nepsstate -def get_workers_neps_state() -> NePSState[Path]: +def get_workers_neps_state() -> NePSState: """Get the worker's NePS state.""" if _WORKER_NEPS_STATE is None: raise RuntimeError( @@ -76,7 +75,7 @@ def get_workers_neps_state() -> NePSState[Path]: return _WORKER_NEPS_STATE -def _set_workers_neps_state(state: NePSState[Path]) -> None: +def _set_workers_neps_state(state: NePSState) -> None: global _WORKER_NEPS_STATE # noqa: PLW0603 _WORKER_NEPS_STATE = state @@ -177,27 +176,7 @@ def new( _pre_sample_hooks=_pre_sample_hooks, ) - def _get_next_trial_from_state(self) -> Trial: - nxt_trial = self.state.get_next_pending_trial() - - # If we have a trial, we will use it - if nxt_trial is not None: - logger.info( - f"Worker '{self.worker_id}' got previosly sampled trial: {nxt_trial}" - ) - - # Otherwise sample a new one - else: - nxt_trial = self.state.sample_trial( - worker_id=self.worker_id, - optimizer=self.optimizer, - _sample_hooks=self._pre_sample_hooks, - ) - logger.info(f"Worker '{self.worker_id}' sampled a new trial: {nxt_trial}") - - return nxt_trial - - def _check_if_should_stop( # noqa: C901, PLR0912, PLR0911 + def _check_worker_local_settings( self, *, time_monotonic_start: float, @@ -205,8 +184,6 @@ def _check_if_should_stop( # noqa: C901, PLR0912, PLR0911 ) -> str | Literal[False]: # NOTE: Sorry this code is kind of ugly but it's pretty straightforward, just a # lot of conditional checking and making sure to check cheaper conditions first. - # It would look a little nicer with a match statement but we've got to wait - # for python 3.10 for that. # First check for stopping criterion for this worker in particular as it's # cheaper and doesn't require anything from the state. @@ -280,13 +257,16 @@ def _check_if_should_stop( # noqa: C901, PLR0912, PLR0911 f", given by `{self.settings.max_evaluation_time_for_worker_seconds=}`." ) + return False + + def _check_shared_error_stopping_criterion(self) -> str | Literal[False]: # We check this global error stopping criterion as it's much # cheaper than sweeping the state from all trials. if self.settings.on_error in ( OnErrorPossibilities.RAISE_ANY_ERROR, OnErrorPossibilities.STOP_ANY_ERROR, ): - err = self.state._shared_errors.synced().latest_err_as_raisable() + err = self.state.lock_and_get_errors().latest_err_as_raisable() if err is not None: msg = ( "An error occurred in another worker and this worker is set to stop" @@ -306,20 +286,12 @@ def _check_if_should_stop( # noqa: C901, PLR0912, PLR0911 return msg - # If there are no global stopping criterion, we can no just return early. - if ( - self.settings.max_evaluations_total is None - and self.settings.max_cost_total is None - and self.settings.max_evaluation_time_total_seconds is None - ): - return False - - # At this point, if we have some global stopping criterion, we need to sweep - # the current state of trials to determine if we should stop - # NOTE: If these `sum` turn out to somehow be a bottleneck, these could - # be precomputed and accumulated over time. This would have to be handled - # in the `NePSState` class. - trials = self.state.get_all_trials() + return False + + def _check_global_stopping_criterion( + self, + trials: Mapping[str, Trial], + ) -> str | Literal[False]: if self.settings.max_evaluations_total is not None: if self.settings.include_in_progress_evaluations_towards_maximum: # NOTE: We can just use the sum of trials in this case as they @@ -368,6 +340,8 @@ def _check_if_should_stop( # noqa: C901, PLR0912, PLR0911 return False + # Forgive me lord, for I have sinned, this function is atrocious but complicated + # due to locking. def run(self) -> None: # noqa: C901, PLR0915, PLR0912 """Run the worker. @@ -385,18 +359,27 @@ def run(self) -> None: # noqa: C901, PLR0915, PLR0912 n_failed_set_trial_state = 0 n_repeated_failed_check_should_stop = 0 while True: - # NOTE: We rely on this function to do logging and raising errors if it should try: - should_stop = self._check_if_should_stop( + # First check local worker settings + should_stop = self._check_worker_local_settings( time_monotonic_start=_time_monotonic_start, error_from_this_worker=_error_from_evaluation, ) if should_stop is not False: logger.info(should_stop) break + + # Next check global errs having occured + should_stop = self._check_shared_error_stopping_criterion() + if should_stop is not False: + logger.info(should_stop) + break + except WorkerRaiseError as e: + # If we raise a specific error, we should stop the worker raise e except Exception as e: + # An unknown exception, check our retry countk n_repeated_failed_check_should_stop += 1 if ( n_repeated_failed_check_should_stop @@ -415,8 +398,48 @@ def run(self) -> None: # noqa: C901, PLR0915, PLR0912 time.sleep(1) # Help stagger retries continue + # From here, we now begin sampling or getting the next pending trial. + # As the global stopping criterion requires us to check all trials, and + # needs to be in locked in-step with sampling try: - trial_to_eval = self._get_next_trial_from_state() + # If there are no global stopping criterion, we can no just return early. + with self.state.lock_for_sampling(): + trials = self.state._trials.latest() + + requires_checking_global_stopping_criterion = ( + self.settings.max_evaluations_total is not None + or self.settings.max_cost_total is not None + or self.settings.max_evaluation_time_total_seconds is not None + ) + if requires_checking_global_stopping_criterion: + should_stop = self._check_global_stopping_criterion(trials) + if should_stop is not False: + logger.info(should_stop) + break + + pending_trials = [ + trial + for trial in trials.values() + if trial.state == Trial.State.PENDING + ] + if len(pending_trials) > 0: + earliest_pending = sorted( + pending_trials, + key=lambda t: t.metadata.time_sampled, + )[0] + earliest_pending.set_evaluating( + time_started=time.time(), + worker_id=self.worker_id, + ) + self.state._trials.update_trial(earliest_pending) + trial_to_eval = earliest_pending + else: + sampled_trial = self.state._sample_trial( + optimizer=self.optimizer, + worker_id=self.worker_id, + ) + trial_to_eval = sampled_trial + _repeated_fail_get_next_trial_count = 0 except Exception as e: _repeated_fail_get_next_trial_count += 1 @@ -439,11 +462,6 @@ def run(self) -> None: # noqa: C901, PLR0915, PLR0912 # If we can't set this working to evaluating, then just retry the loop try: - trial_to_eval.set_evaluating( - time_started=time.time(), - worker_id=self.worker_id, - ) - self.state.put_updated_trial(trial_to_eval) n_failed_set_trial_state = 0 except VersionMismatchError: n_failed_set_trial_state += 1 @@ -512,11 +530,12 @@ def run(self) -> None: # noqa: C901, PLR0915, PLR0912 # We do not retry this, as if some other worker has # managed to manipulate this trial in the meantime, # then something has gone wrong - self.state.report_trial_evaluation( - trial=evaluated_trial, - report=report, - worker_id=self.worker_id, - ) + with self.state.lock_trials(): + self.state._report_trial_evaluation( + trial=evaluated_trial, + report=report, + worker_id=self.worker_id, + ) logger.debug("Config %s: %s", evaluated_trial.id, evaluated_trial.config) logger.debug("Loss %s: %s", evaluated_trial.id, report.loss) @@ -553,8 +572,9 @@ def _launch_runtime( # noqa: PLR0913 for _retry_count in range(MAX_RETRIES_CREATE_LOAD_STATE): try: - neps_state = create_or_load_filebased_neps_state( - directory=optimization_dir, + neps_state = NePSState.create_or_load( + path=optimization_dir, + load_only=False, optimizer_info=OptimizerInfo(optimizer_info), optimizer_state=OptimizationState( budget=( @@ -613,7 +633,17 @@ def _launch_runtime( # noqa: PLR0913 # it's not directly advertised as a parameter/env variable or otherwise. import portalocker.portalocker as portalocker_lock_module - setattr(portalocker_lock_module, "LOCKER", LINUX_FILELOCK_FUNCTION) + try: + import fcntl + + if LINUX_FILELOCK_FUNCTION.lower() == "flock": + setattr(portalocker_lock_module, "LOCKER", fcntl.flock) + elif LINUX_FILELOCK_FUNCTION.lower() == "lockf": + setattr(portalocker_lock_module, "LOCKER", fcntl.lockf) + else: + pass + except ImportError: + pass worker = DefaultWorker.new( state=neps_state, diff --git a/neps/state/__init__.py b/neps/state/__init__.py index e870d656..6b190afb 100644 --- a/neps/state/__init__.py +++ b/neps/state/__init__.py @@ -1,23 +1,11 @@ from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo -from neps.state.protocols import ( - Locker, - ReaderWriter, - Synced, - VersionedResource, - Versioner, -) from neps.state.seed_snapshot import SeedSnapshot from neps.state.trial import Trial __all__ = [ - "Locker", "SeedSnapshot", - "Synced", "BudgetInfo", "OptimizationState", "OptimizerInfo", "Trial", - "ReaderWriter", - "Versioner", - "VersionedResource", ] diff --git a/neps/state/filebased.py b/neps/state/filebased.py index 9394d0e1..3658d511 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -1,55 +1,22 @@ -"""This module houses the implementation of a NePSState that -does everything on the filesystem, i.e. locking, versioning and -storing/loading. - -The main components are: -* [`FileVersioner`][neps.state.filebased.FileVersioner]: A versioner that - stores a version tag on disk, usually for a resource like a Trial. -* [`FileLocker`][neps.state.filebased.FileLocker]: A locker that uses a file - to lock between processes. -* [`TrialRepoInDirectory`][neps.state.filebased.TrialRepoInDirectory]: A - repository of Trials that are stored in a directory. -* `ReaderWriterXXX`: Reader/writers for various resources NePSState needs -* [`load_filebased_neps_state`][neps.state.filebased.load_filebased_neps_state]: - A function to load a NePSState from a directory. -* [`create_filebased_neps_state`][neps.state.filebased.create_filebased_neps_state]: - A function to create a new NePSState in a directory. -""" - from __future__ import annotations import json import logging import pprint -from collections.abc import Iterable, Iterator +from collections.abc import Iterator from contextlib import contextmanager -from dataclasses import asdict, dataclass, field +from dataclasses import asdict, dataclass from pathlib import Path from typing import ClassVar, Final, TypeVar -from typing_extensions import override -from uuid import uuid4 import numpy as np import portalocker as pl from neps.env import ( ENV_VARS_USED, - GLOBAL_ERR_FILELOCK_POLL, - GLOBAL_ERR_FILELOCK_TIMEOUT, - OPTIMIZER_INFO_FILELOCK_POLL, - OPTIMIZER_INFO_FILELOCK_TIMEOUT, - OPTIMIZER_STATE_FILELOCK_POLL, - OPTIMIZER_STATE_FILELOCK_TIMEOUT, - SEED_SNAPSHOT_FILELOCK_POLL, - SEED_SNAPSHOT_FILELOCK_TIMEOUT, - TRIAL_FILELOCK_POLL, - TRIAL_FILELOCK_TIMEOUT, ) -from neps.exceptions import NePSError from neps.state.err_dump import ErrDump -from neps.state.neps_state import NePSState from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo -from neps.state.protocols import Locker, ReaderWriter, Synced, TrialRepo, Versioner from neps.state.seed_snapshot import SeedSnapshot from neps.state.trial import Trial from neps.utils.files import deserialize, serialize @@ -59,43 +26,16 @@ T = TypeVar("T") -def make_sha() -> str: - """Generate a str hex sha.""" - return uuid4().hex - - @dataclass -class FileVersioner(Versioner): - """A versioner that stores a version tag on disk.""" - - version_file: Path - - @override - def current(self) -> str | None: - if not self.version_file.exists(): - return None - return self.version_file.read_text() - - @override - def bump(self) -> str: - sha = make_sha() - self.version_file.write_text(sha) - return sha - - -@dataclass -class ReaderWriterTrial(ReaderWriter[Trial, Path]): +class ReaderWriterTrial: """ReaderWriter for Trial objects.""" - CHEAP_LOCKLESS_READ: ClassVar = True - CONFIG_FILENAME = "config.yaml" METADATA_FILENAME = "metadata.yaml" STATE_FILENAME = "state.txt" REPORT_FILENAME = "report.yaml" PREVIOUS_TRIAL_ID_FILENAME = "previous_trial_id.txt" - @override @classmethod def read(cls, directory: Path) -> Trial: config_path = directory / cls.CONFIG_FILENAME @@ -112,7 +52,6 @@ def read(cls, directory: Path) -> Trial: ), ) - @override @classmethod def write(cls, trial: Trial, directory: Path) -> None: config_path = directory / cls.CONFIG_FILENAME @@ -132,166 +71,13 @@ def write(cls, trial: Trial, directory: Path) -> None: serialize(asdict(trial.report), report_path) -_StaticReaderWriterTrial: Final = ReaderWriterTrial() - -CONFIG_PREFIX_LEN: Final = len("config_") +TrialReaderWriter: Final = ReaderWriterTrial() @dataclass -class TrialRepoInDirectory(TrialRepo[Path]): - """A repository of Trials that are stored in a directory.""" - - directory: Path - _cache: dict[str, Synced[Trial, Path]] = field(default_factory=dict) - - @override - def all_trial_ids(self) -> list[str]: - """List all the trial ids in this trial Repo.""" - return [ - config_path.name[CONFIG_PREFIX_LEN:] - for config_path in self.directory.iterdir() - if config_path.name.startswith("config_") and config_path.is_dir() - ] - - @override - def get_by_id( - self, - trial_id: str, - *, - lock_poll: float = TRIAL_FILELOCK_POLL, - lock_timeout: float | None = TRIAL_FILELOCK_TIMEOUT, - ) -> Synced[Trial, Path]: - """Get a Trial by its ID. - - !!! note - - This will **not** explicitly sync the trial and it is up to the caller - to do so. Most of the time, the caller should be a NePSState - object which will do that for you. However if the trial is not in the - cache, then it will be loaded from disk which requires syncing. - - Args: - trial_id: The ID of the trial to get. - lock_poll: The poll time for the file lock. - lock_timeout: The timeout for the file lock. - - Returns: - The trial with the given ID. - """ - trial = self._cache.get(trial_id) - if trial is not None: - return trial - - config_path = self.directory / f"config_{trial_id}" - if not config_path.exists(): - raise TrialRepo.TrialNotFoundError(trial_id, config_path) - - trial = Synced.load( - location=config_path, - locker=FileLocker( - lock_path=config_path / ".lock", - poll=lock_poll, - timeout=lock_timeout, - ), - versioner=FileVersioner(version_file=config_path / ".version"), - reader_writer=_StaticReaderWriterTrial, - ) - self._cache[trial_id] = trial - return trial - - @override - def put_new( - self, - trial: Trial, - *, - lock_poll: float = TRIAL_FILELOCK_POLL, - lock_timeout: float | None = TRIAL_FILELOCK_TIMEOUT, - ) -> Synced[Trial, Path]: - """Put a new Trial into the repository. - - Args: - trial: The trial to put. - lock_poll: The poll time for the file lock. - lock_timeout: The timeout for the file lock. - - Returns: - The synced trial. - - Raises: - TrialRepo.TrialAlreadyExistsError: If the trial already exists in the - repository. - """ - config_path = self.directory.absolute().resolve() / f"config_{trial.metadata.id}" - if config_path.exists(): - # This shouldn't exist, we load in the trial to see the current state of it - # to try determine wtf is going on for logging purposes. - try: - shared_trial = Synced.load( - location=config_path, - locker=FileLocker( - lock_path=config_path / ".lock", - poll=lock_poll, - timeout=lock_timeout, - ), - versioner=FileVersioner(version_file=config_path / ".version"), - reader_writer=_StaticReaderWriterTrial, - ) - already_existing_trial = shared_trial._unsynced() - extra_msg = ( - f"The existing trial is the following: {already_existing_trial}" - ) - except Exception: # noqa: BLE001 - extra_msg = "Failed to load the existing trial to provide more info." - - raise TrialRepo.TrialAlreadyExistsError( - f"Trial '{trial.metadata.id}' already exists as '{config_path}'." - f" Tried to put in the trial: {trial}." - f"\n{extra_msg}" - ) - - # HACK: We do this here as there is no way to know where a Trial will - # be located when it's created... - trial.metadata.location = str(config_path) - shared_trial = Synced.new( - data=trial, - location=config_path, - locker=FileLocker( - lock_path=config_path / ".lock", - poll=lock_poll, - timeout=lock_timeout, - ), - versioner=FileVersioner(version_file=config_path / ".version"), - reader_writer=_StaticReaderWriterTrial, - ) - self._cache[trial.metadata.id] = shared_trial - return shared_trial - - @override - def all(self) -> dict[str, Synced[Trial, Path]]: - """Get a dictionary of all the Trials in the repository. - - !!! note - See [`get_by_id()`][neps.state.filebased.TrialRepoInDirectory.get_by_id] - for notes on the trials syncing. - """ - return {trial_id: self.get_by_id(trial_id) for trial_id in self.all_trial_ids()} - - @override - def pending(self) -> Iterable[tuple[str, Trial]]: - pending = [ - (_id, trial, trial.metadata.time_sampled) - for (_id, t) in self.all().items() - if (trial := t.synced()).state == Trial.State.PENDING - ] - return iter((_id, t) for _id, t, _ in sorted(pending, key=lambda x: x[2])) - - -@dataclass -class ReaderWriterSeedSnapshot(ReaderWriter[SeedSnapshot, Path]): +class ReaderWriterSeedSnapshot: """ReaderWriter for SeedSnapshot objects.""" - CHEAP_LOCKLESS_READ: ClassVar = True - # It seems like they're all uint32 but I can't be sure. PY_RNG_STATE_DTYPE: ClassVar = np.int64 @@ -301,7 +87,6 @@ class ReaderWriterSeedSnapshot(ReaderWriter[SeedSnapshot, Path]): TORCH_CUDA_RNG_STATE_FILENAME: ClassVar = "torch_cuda_rng_state.pt" SEED_INFO_FILENAME: ClassVar = "seed_info.json" - @override @classmethod def read(cls, directory: Path) -> SeedSnapshot: seedinfo_path = directory / cls.SEED_INFO_FILENAME @@ -350,7 +135,6 @@ def read(cls, directory: Path) -> SeedSnapshot: torch_cuda_rng=torch_cuda_rng, ) - @override @classmethod def write(cls, snapshot: SeedSnapshot, directory: Path) -> None: seedinfo_path = directory / cls.SEED_INFO_FILENAME @@ -387,20 +171,16 @@ def write(cls, snapshot: SeedSnapshot, directory: Path) -> None: @dataclass -class ReaderWriterOptimizerInfo(ReaderWriter[OptimizerInfo, Path]): +class ReaderWriterOptimizerInfo: """ReaderWriter for OptimizerInfo objects.""" - CHEAP_LOCKLESS_READ: ClassVar = True - INFO_FILENAME: ClassVar = "info.yaml" - @override @classmethod def read(cls, directory: Path) -> OptimizerInfo: info_path = directory / cls.INFO_FILENAME return OptimizerInfo(info=deserialize(info_path)) - @override @classmethod def write(cls, optimizer_info: OptimizerInfo, directory: Path) -> None: info_path = directory / cls.INFO_FILENAME @@ -412,14 +192,11 @@ def write(cls, optimizer_info: OptimizerInfo, directory: Path) -> None: # handle this. # TODO(eddiebergman): May also want to consider serializing budget into a seperate entity @dataclass -class ReaderWriterOptimizationState(ReaderWriter[OptimizationState, Path]): +class ReaderWriterOptimizationState: """ReaderWriter for OptimizationState objects.""" - CHEAP_LOCKLESS_READ: ClassVar = True - STATE_FILE_NAME: ClassVar = "state.yaml" - @override @classmethod def read(cls, directory: Path) -> OptimizationState: state_path = directory / cls.STATE_FILE_NAME @@ -431,7 +208,6 @@ def read(cls, directory: Path) -> OptimizationState: budget=budget, ) - @override @classmethod def write(cls, info: OptimizationState, directory: Path) -> None: info_path = directory / cls.STATE_FILE_NAME @@ -439,24 +215,20 @@ def write(cls, info: OptimizationState, directory: Path) -> None: @dataclass -class ReaderWriterErrDump(ReaderWriter[ErrDump, Path]): +class ReaderWriterErrDump: """ReaderWriter for shared error lists.""" - CHEAP_LOCKLESS_READ: ClassVar = True - - name: str - - @override - def read(self, directory: Path) -> ErrDump: - errors_path = directory / f"{self.name}-errors.jsonl" + @classmethod + def read(cls, directory: Path) -> ErrDump: + errors_path = directory / "errors.jsonl" with errors_path.open("r") as f: data = [json.loads(line) for line in f] return ErrDump([ErrDump.SerializableTrialError(**d) for d in data]) - @override - def write(self, err_dump: ErrDump, directory: Path) -> None: - errors_path = directory / f"{self.name}-errors.jsonl" + @classmethod + def write(cls, err_dump: ErrDump, directory: Path) -> None: + errors_path = directory / "errors.jsonl" with errors_path.open("w") as f: lines = [json.dumps(asdict(trial_err)) for trial_err in err_dump.errs] f.write("\n".join(lines)) @@ -466,7 +238,7 @@ def write(self, err_dump: ErrDump, directory: Path) -> None: @dataclass -class FileLocker(Locker): +class FileLocker: """File-based locker using `portalocker`. [`FileLocker`][neps.state.locker.file.FileLocker] implements @@ -482,18 +254,6 @@ class FileLocker(Locker): def __post_init__(self) -> None: self.lock_path = self.lock_path.resolve().absolute() - @override - def is_locked(self) -> bool: - if not self.lock_path.exists(): - return False - try: - with self.lock(fail_if_locked=True): - pass - return False - except pl.exceptions.LockException: - return True - - @override @contextmanager def lock( self, @@ -521,187 +281,3 @@ def lock( " environment variables to increase the timeout:" f"\n\n{pprint.pformat(ENV_VARS_USED)}" ) from e - - -def load_filebased_neps_state(directory: Path) -> NePSState[Path]: - """Load a NePSState from a directory. - - Args: - directory: The directory to load the state from. - - Returns: - The loaded NePSState. - - Raises: - FileNotFoundError: If no NePSState is found at the given directory. - """ - if not directory.exists(): - raise FileNotFoundError(f"No NePSState found at '{directory}'.") - directory.mkdir(parents=True, exist_ok=True) - config_dir = directory / "configs" - config_dir.mkdir(parents=True, exist_ok=True) - seed_dir = directory / ".seed_state" - seed_dir.mkdir(parents=True, exist_ok=True) - error_dir = directory / ".errors" - error_dir.mkdir(parents=True, exist_ok=True) - optimizer_state_dir = directory / ".optimizer_state" - optimizer_state_dir.mkdir(parents=True, exist_ok=True) - optimizer_info_dir = directory / ".optimizer_info" - optimizer_info_dir.mkdir(parents=True, exist_ok=True) - - return NePSState( - location=str(directory.absolute().resolve()), - _trials=TrialRepoInDirectory(config_dir), - _optimizer_info=Synced.load( - location=optimizer_info_dir, - versioner=FileVersioner(version_file=optimizer_info_dir / ".version"), - locker=FileLocker( - lock_path=optimizer_info_dir / ".lock", - poll=OPTIMIZER_INFO_FILELOCK_POLL, - timeout=OPTIMIZER_INFO_FILELOCK_TIMEOUT, - ), - reader_writer=ReaderWriterOptimizerInfo(), - ), - _seed_state=Synced.load( - location=seed_dir, - reader_writer=ReaderWriterSeedSnapshot(), - versioner=FileVersioner(version_file=seed_dir / ".version"), - locker=FileLocker( - lock_path=seed_dir / ".lock", - poll=SEED_SNAPSHOT_FILELOCK_POLL, - timeout=SEED_SNAPSHOT_FILELOCK_TIMEOUT, - ), - ), - _shared_errors=Synced.load( - location=error_dir, - reader_writer=ReaderWriterErrDump("all"), - versioner=FileVersioner(version_file=error_dir / ".all.version"), - locker=FileLocker( - lock_path=error_dir / ".all.lock", - poll=GLOBAL_ERR_FILELOCK_POLL, - timeout=GLOBAL_ERR_FILELOCK_TIMEOUT, - ), - ), - _optimizer_state=Synced.load( - location=optimizer_state_dir, - reader_writer=ReaderWriterOptimizationState(), - versioner=FileVersioner(version_file=optimizer_state_dir / ".version"), - locker=FileLocker( - lock_path=optimizer_state_dir / ".lock", - poll=OPTIMIZER_STATE_FILELOCK_POLL, - timeout=OPTIMIZER_STATE_FILELOCK_TIMEOUT, - ), - ), - ) - - -def create_or_load_filebased_neps_state( - directory: Path, - *, - optimizer_info: OptimizerInfo, - optimizer_state: OptimizationState, -) -> NePSState[Path]: - """Create a new NePSState in a directory or load the existing one - if it already exists. - - !!! warning - - We check that the optimizer info in the NePSState on disk matches - the one that is passed. However we do not lock this check so it - is possible that if two processes try to create a NePSState at the - same time, both with different optimizer infos, that one will fail - to create the NePSState. This is a limitation of the current design. - - In principal, we could allow multiple optimizers to be run and share - the same set of trials. - - Args: - directory: The directory to create the state in. - optimizer_info: The optimizer info to use. - optimizer_state: The optimizer state to use. - - Returns: - The NePSState. - - Raises: - NePSError: If the optimizer info on disk does not match the one provided. - """ - is_new = not directory.exists() - directory.mkdir(parents=True, exist_ok=True) - config_dir = directory / "configs" - config_dir.mkdir(parents=True, exist_ok=True) - seed_dir = directory / ".seed_state" - seed_dir.mkdir(parents=True, exist_ok=True) - error_dir = directory / ".errors" - error_dir.mkdir(parents=True, exist_ok=True) - optimizer_state_dir = directory / ".optimizer_state" - optimizer_state_dir.mkdir(parents=True, exist_ok=True) - optimizer_info_dir = directory / ".optimizer_info" - optimizer_info_dir.mkdir(parents=True, exist_ok=True) - - # We have to do one bit of sanity checking to ensure that the optimzier - # info on disk manages the one we have recieved, otherwise we are unsure which - # optimizer is being used. - # NOTE: We assume that we do not have to worry about a race condition - # here where we have two different NePSState objects with two different optimizer - # infos trying to be created at the same time. This avoids the need to lock to - # check the optimizer info. If this assumption changes, then we would have - # to first lock before we do this check - optimizer_info_reader_writer = ReaderWriterOptimizerInfo() - if not is_new: - existing_info = optimizer_info_reader_writer.read(optimizer_info_dir) - if existing_info != optimizer_info: - raise NePSError( - "The optimizer info on disk does not match the one provided." - f"\nOn disk: {existing_info}\nProvided: {optimizer_info}" - f"\n\nLoaded the one on disk from {optimizer_info_dir}." - ) - - return NePSState( - location=str(directory.absolute().resolve()), - _trials=TrialRepoInDirectory(config_dir), - _optimizer_info=Synced.new_or_load( - data=optimizer_info, # type: ignore - location=optimizer_info_dir, - versioner=FileVersioner(version_file=optimizer_info_dir / ".version"), - locker=FileLocker( - lock_path=optimizer_info_dir / ".lock", - poll=OPTIMIZER_INFO_FILELOCK_POLL, - timeout=OPTIMIZER_INFO_FILELOCK_TIMEOUT, - ), - reader_writer=ReaderWriterOptimizerInfo(), - ), - _seed_state=Synced.new_or_load( - data=SeedSnapshot.new_capture(), - location=seed_dir, - reader_writer=ReaderWriterSeedSnapshot(), - versioner=FileVersioner(version_file=seed_dir / ".version"), - locker=FileLocker( - lock_path=seed_dir / ".lock", - poll=SEED_SNAPSHOT_FILELOCK_POLL, - timeout=SEED_SNAPSHOT_FILELOCK_TIMEOUT, - ), - ), - _shared_errors=Synced.new_or_load( - data=ErrDump(), - location=error_dir, - reader_writer=ReaderWriterErrDump("all"), - versioner=FileVersioner(version_file=error_dir / ".all.version"), - locker=FileLocker( - lock_path=error_dir / ".all.lock", - poll=GLOBAL_ERR_FILELOCK_POLL, - timeout=GLOBAL_ERR_FILELOCK_TIMEOUT, - ), - ), - _optimizer_state=Synced.new_or_load( - data=optimizer_state, - location=optimizer_state_dir, - reader_writer=ReaderWriterOptimizationState(), - versioner=FileVersioner(version_file=optimizer_state_dir / ".version"), - locker=FileLocker( - lock_path=optimizer_state_dir / ".lock", - poll=OPTIMIZER_STATE_FILELOCK_POLL, - timeout=OPTIMIZER_STATE_FILELOCK_TIMEOUT, - ), - ), - ) diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index 3d4f3186..39dc5265 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -11,70 +11,267 @@ from __future__ import annotations import logging +import pickle import time -from collections.abc import Callable +from collections.abc import Callable, Iterator +from contextlib import contextmanager from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Generic, TypeVar, overload - -from more_itertools import take - -from neps.exceptions import TrialAlreadyExistsError +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Generic, + Literal, + TypeAlias, + TypeVar, + overload, +) +from uuid import uuid4 + +from neps.env import ( + STATE_FILELOCK_POLL, + STATE_FILELOCK_TIMEOUT, + TRIAL_FILELOCK_POLL, + TRIAL_FILELOCK_TIMEOUT, +) +from neps.exceptions import NePSError, TrialAlreadyExistsError, TrialNotFoundError from neps.state.err_dump import ErrDump +from neps.state.filebased import ( + FileLocker, + ReaderWriterErrDump, + ReaderWriterOptimizationState, + ReaderWriterOptimizerInfo, + ReaderWriterSeedSnapshot, + TrialReaderWriter, +) from neps.state.optimizer import OptimizationState, OptimizerInfo -from neps.state.trial import Trial +from neps.state.seed_snapshot import SeedSnapshot +from neps.state.trial import Report, Trial if TYPE_CHECKING: from neps.optimizers.base_optimizer import BaseOptimizer - from neps.state.protocols import Synced, TrialRepo - from neps.state.seed_snapshot import SeedSnapshot logger = logging.getLogger(__name__) +N_UNSAFE_RETRIES = 10 # TODO: Technically we don't need the same Location type for all shared objects. Loc = TypeVar("Loc") T = TypeVar("T") +Version: TypeAlias = str + +Resource: TypeAlias = Literal[ + "optimizer_info", "optimizer_state", "seed_state", "errors", "configs" +] + + +def make_sha() -> Version: + """Generate a str hex sha.""" + return uuid4().hex + + +CONFIG_PREFIX_LEN = len("config_") + + +# TODO: Ergonomics of this class sucks +@dataclass +class TrialRepo: + directory: Path + version_file: Path + cache: dict[str, tuple[Trial, Version]] = field(default_factory=dict) + + def __post_init__(self) -> None: + self.directory.mkdir(parents=True, exist_ok=True) + + def list_trial_ids(self) -> list[str]: + return [ + config_path.name[CONFIG_PREFIX_LEN:] + for config_path in self.directory.iterdir() + if config_path.name.startswith("config_") and config_path.is_dir() + ] + + def latest(self) -> dict[str, Trial]: + if not self.version_file.exists(): + return {} + + with self.version_file.open("rb") as f: + versions_on_disk = pickle.load(f) # noqa: S301 + + stale = { + k: v + for k, v in versions_on_disk.items() + if self.cache.get(k, (None, "__not_found__")) != v + } + for trial_id, disk_version in stale.items(): + loaded_trial = self.load_trial_from_disk(trial_id) + self.cache[trial_id] = (loaded_trial, disk_version) + + return {k: v[0] for k, v in self.cache.items()} + + def new_trial(self, trial: Trial) -> None: + config_path = self.directory / f"config_{trial.id}" + if config_path.exists(): + raise TrialAlreadyExistsError(trial.id, config_path) + config_path.mkdir(parents=True, exist_ok=True) + self.update_trial(trial) + + def update_trial(self, trial: Trial) -> None: + new_version = make_sha() + TrialReaderWriter.write(trial, self.directory / f"config_{trial.id}") + self.cache[trial.id] = (trial, new_version) + + def write_version_file(self) -> None: + with self.version_file.open("wb") as f: + pickle.dump({k: v[1] for k, v in self.cache.items()}, f) + + def trials_in_memory(self) -> dict[str, Trial]: + return {k: v[0] for k, v in self.cache.items()} + + def load_trial_from_disk(self, trial_id: str) -> Trial: + config_path = self.directory / f"config_{trial_id}" + if not config_path.exists(): + raise TrialNotFoundError(trial_id, config_path) + + return TrialReaderWriter.read(config_path) + @dataclass -class NePSState(Generic[Loc]): +class VersionedResource(Generic[T]): + resource: T + path: Path + read: Callable[[Path], T] + write: Callable[[T, Path], None] + version_file: Path + version: Version = "__not_yet_written__" + + def latest(self) -> T: + if not self.version_file.exists(): + return self.resource + + file_version = self.version_file.read_text() + if self.version == file_version: + return self.resource + + self.resource = self.read(self.path) + self.version = file_version + return self.resource + + def update(self, new_resource: T) -> Version: + self.resource = new_resource + self.version = make_sha() + self.version_file.write_text(self.version) + self.write(new_resource, self.path) + return self.version + + @classmethod + def new( + cls, + resource: T, + path: Path, + read: Callable[[Path], T], + write: Callable[[T, Path], None], + version_file: Path, + ) -> VersionedResource[T]: + if version_file.exists(): + raise FileExistsError(f"Version file already exists at '{version_file}'.") + + write(resource, path) + version = make_sha() + version_file.write_text(version) + return cls( + resource=resource, + path=path, + read=read, + write=write, + version_file=version_file, + version=version, + ) + + @classmethod + def load( + cls, + path: Path, + *, + read: Callable[[Path], T], + write: Callable[[T, Path], None], + version_file: Path, + ) -> VersionedResource[T]: + if not path.exists(): + raise FileNotFoundError(f"Resource not found at '{path}'.") + + return cls( + resource=read(path), + path=path, + read=read, + write=write, + version_file=version_file, + version=version_file.read_text(), + ) + + +@dataclass +class NePSState: """The main state object that holds all the shared state objects.""" - location: str + path: Path - _trials: TrialRepo[Loc] = field(repr=False) - _optimizer_info: Synced[OptimizerInfo, Loc] - _seed_state: Synced[SeedSnapshot, Loc] = field(repr=False) - _optimizer_state: Synced[OptimizationState, Loc] - _shared_errors: Synced[ErrDump, Loc] = field(repr=False) + _trial_lock: FileLocker = field(repr=False) + _trials: TrialRepo = field(repr=False) - def put_updated_trial(self, trial: Trial, /) -> None: - """Update the trial with the new information. + _state_lock: FileLocker = field(repr=False) + _optimizer_info: VersionedResource[OptimizerInfo] = field(repr=False) + _seed_snapshot: VersionedResource[SeedSnapshot] = field(repr=False) + _optimizer_state: VersionedResource[OptimizationState] = field(repr=False) - Args: - trial: The trial to update. + _err_lock: FileLocker = field(repr=False) + _shared_errors: VersionedResource[ErrDump] = field(repr=False) - Raises: - VersionMismatchError: If the trial has been updated since it was last - fetched by the worker using this state. This indicates that some other - worker has updated the trial in the meantime and the changes from - this worker are rejected. - """ - shared_trial = self._trials.get_by_id(trial.id) - shared_trial.put(trial) + @contextmanager + def lock_for_sampling(self) -> Iterator[None]: + """Acquire the state lock and trials lock.""" + with self._state_lock.lock(), self._trial_lock.lock(): + yield - def get_trial_by_id(self, trial_id: str, /) -> Trial: - """Get a trial by its id.""" - return self._trials.get_by_id(trial_id).synced() + @contextmanager + def lock_trials(self) -> Iterator[None]: + """Acquire the state lock.""" + with self._trial_lock.lock(): + yield + + def lock_and_read_trials(self) -> dict[str, Trial]: + """Acquire the state lock and read the trials.""" + with self._trial_lock.lock(): + return self._trials.latest() - def sample_trial( + def lock_and_sample_trial(self, optimizer: BaseOptimizer, *, worker_id: str) -> Trial: + """Acquire the state lock and sample a trial.""" + with self.lock_for_sampling(): + return self._sample_trial(optimizer, worker_id=worker_id) + + def lock_and_report_trial_evaluation( + self, + trial: Trial, + report: Report, + *, + worker_id: str, + ) -> None: + """Acquire the state lock and report the trial evaluation.""" + with self._trial_lock.lock(), self._err_lock.lock(): + self._report_trial_evaluation(trial, report, worker_id=worker_id) + + def _sample_trial( self, optimizer: BaseOptimizer, *, worker_id: str, _sample_hooks: list[Callable] | None = None, + _trials: dict[str, Trial] | None = None, ) -> Trial: """Sample a new trial from the optimizer. + !!! warning + + Responsibility of locking is on caller. + Args: optimizer: The optimizer to sample the trial from. worker_id: The worker that is sampling the trial. @@ -83,110 +280,100 @@ def sample_trial( Returns: The new trial. """ - with ( - self._optimizer_state.acquire() as (opt_state, put_opt), - self._seed_state.acquire() as (seed_state, put_seed_state), - ): - # NOTE: We make the assumption that as we have acquired the optimizer - # state, there is not possibility of another trial being created between - # the time we read in the trials below and `ask()`ing for the next trials - # from the optimizer. If so, that means there is another source of trial - # generation that occurs outside of this function and outside the scope - # of acquiring the optimizer_state lock. - trials: dict[str, Trial] = { - trial_id: shared_trial.synced() - for trial_id, shared_trial in list(self._trials.all().items()) - } - - seed_state.set_as_global_seed_state() - - # TODO: Not sure if any existing pre_load hooks required - # it to be done after `load_results`... I hope not. - if _sample_hooks is not None: - for hook in _sample_hooks: - optimizer = hook(optimizer) - - # NOTE: Re-work this, as the part's that are recomputed - # do not need to be serialized - budget = opt_state.budget - if budget is not None: - budget = budget.clone() - - # NOTE: All other values of budget are ones that should remain - # constant, there are currently only these two which are dynamic as - # optimization unfold - budget.used_cost_budget = sum( - trial.report.cost - for trial in trials.values() - if trial.report is not None and trial.report.cost is not None - ) - budget.used_evaluations = len(trials) - - sampled_config_maybe_new_opt_state = optimizer.ask( - trials=trials, - budget_info=budget, + trials = self._trials.latest() if _trials is None else _trials + seed_state = self._seed_snapshot.latest() + opt_state = self._optimizer_state.latest() + + seed_state.set_as_global_seed_state() + + # TODO: Not sure if any existing pre_load hooks required + # it to be done after `load_results`... I hope not. + if _sample_hooks is not None: + for hook in _sample_hooks: + optimizer = hook(optimizer) + + # NOTE: Re-work this, as the part's that are recomputed + # do not need to be serialized + budget = opt_state.budget + if budget is not None: + budget = budget.clone() + + # NOTE: All other values of budget are ones that should remain + # constant, there are currently only these two which are dynamic as + # optimization unfold + budget.used_cost_budget = sum( + trial.report.cost + for trial in trials.values() + if trial.report is not None and trial.report.cost is not None ) - - if isinstance(sampled_config_maybe_new_opt_state, tuple): - sampled_config, new_opt_state = sampled_config_maybe_new_opt_state - else: - sampled_config = sampled_config_maybe_new_opt_state - new_opt_state = opt_state.shared_state - - if sampled_config.previous_config_id is not None: - previous_trial = trials.get(sampled_config.previous_config_id) - if previous_trial is None: - raise ValueError( - f"Previous trial '{sampled_config.previous_config_id}' not found." - ) - previous_trial_location = previous_trial.metadata.location + budget.used_evaluations = len(trials) + + sampled_config_maybe_new_opt_state = optimizer.ask( + trials=trials, + budget_info=budget, + ) + + if isinstance(sampled_config_maybe_new_opt_state, tuple): + sampled_config, new_opt_state = sampled_config_maybe_new_opt_state + else: + sampled_config = sampled_config_maybe_new_opt_state + new_opt_state = opt_state.shared_state + + if sampled_config.previous_config_id is not None: + previous_trial = trials.get(sampled_config.previous_config_id) + if previous_trial is None: + raise ValueError( + f"Previous trial '{sampled_config.previous_config_id}' not found." + ) + previous_trial_location = previous_trial.metadata.location + else: + previous_trial_location = None + + trial = Trial.new( + trial_id=sampled_config.id, + location="", # HACK: This will be set by the `TrialRepo` in `put_new` + config=sampled_config.config, + previous_trial=sampled_config.previous_config_id, + previous_trial_location=previous_trial_location, + time_sampled=time.time(), + worker_id=worker_id, + ) + try: + self._trials.new_trial(trial) + self._trials.write_version_file() + except TrialAlreadyExistsError as e: + if sampled_config.id in trials: + logger.warning( + "The new sampled trial was given an id of '%s', yet this already" + " exists in the loaded in trials given to the optimizer. This" + " indicates a bug with the optimizers allocation of ids.", + sampled_config.id, + ) else: - previous_trial_location = None - - trial = Trial.new( - trial_id=sampled_config.id, - location="", # HACK: This will be set by the `TrialRepo` in `put_new` - config=sampled_config.config, - previous_trial=sampled_config.previous_config_id, - previous_trial_location=previous_trial_location, - time_sampled=time.time(), - worker_id=worker_id, - ) - try: - self._trials.put_new(trial) - except TrialAlreadyExistsError as e: - if sampled_config.id in trials: - logger.warning( - "The new sampled trial was given an id of '%s', yet this already" - " exists in the loaded in trials given to the optimizer. This" - " indicates a bug with the optimizers allocation of ids.", - sampled_config.id, - ) - else: - logger.warning( - "The new sampled trial was given an id of '%s', which is not one" - " that was loaded in by the optimizer. This indicates that" - " configuration '%s' was put on disk during the time that this" - " worker had the optimizer state lock OR that after obtaining the" - " optimizer state lock, somehow this configuration failed to be" - " loaded in and passed to the optimizer.", - sampled_config.id, - sampled_config.id, - ) - raise e + logger.warning( + "The new sampled trial was given an id of '%s', which is not one" + " that was loaded in by the optimizer. This indicates that" + " configuration '%s' was put on disk during the time that this" + " worker had the optimizer state lock OR that after obtaining the" + " optimizer state lock, somehow this configuration failed to be" + " loaded in and passed to the optimizer.", + sampled_config.id, + sampled_config.id, + ) + raise e - seed_state.recapture() - put_seed_state(seed_state) - put_opt( - OptimizationState(budget=opt_state.budget, shared_state=new_opt_state) - ) + seed_state.recapture() + self._seed_snapshot.update(seed_state) + self._optimizer_state.update( + OptimizationState(budget=opt_state.budget, shared_state=new_opt_state) + ) return trial - def report_trial_evaluation( + def _report_trial_evaluation( self, trial: Trial, - report: Trial.Report, + report: Report, *, worker_id: str, ) -> None: @@ -199,61 +386,248 @@ def report_trial_evaluation( optimizer: The optimizer to update and get the state from worker_id: The worker that evaluated the trial. """ - shared_trial = self._trials.get_by_id(trial.id) - # TODO: This would fail if some other worker has already updated the trial. - # IMPORTANT: We need to attach the report to the trial before updating the things. trial.report = report - shared_trial.put(trial) + self._trials.update_trial(trial) + self._trials.write_version_file() + logger.debug("Updated trial '%s' with status '%s'", trial.id, trial.state) if report.err is not None: - with self._shared_errors.acquire() as (errs, put_errs): - trial_err = ErrDump.SerializableTrialError( - trial_id=trial.id, - worker_id=worker_id, - err_type=type(report.err).__name__, - err=str(report.err), - tb=report.tb, + with self._err_lock.lock(): + err_dump = self._shared_errors.latest() + err_dump.errs.append( + ErrDump.SerializableTrialError( + trial_id=trial.id, + worker_id=worker_id, + err_type=type(report.err).__name__, + err=str(report.err), + tb=report.tb, + ) ) - errs.append(trial_err) - put_errs(errs) + self._shared_errors.update(err_dump) + + def all_trial_ids(self) -> list[str]: + """Get all the trial ids.""" + return self._trials.list_trial_ids() - def get_errors(self) -> ErrDump: + def lock_and_get_errors(self) -> ErrDump: """Get all the errors that have occurred during the optimization.""" - return self._shared_errors.synced() + with self._err_lock.lock(): + return self._shared_errors.latest() + + def lock_and_get_optimizer_info(self) -> OptimizerInfo: + """Get the optimizer information.""" + with self._state_lock.lock(): + return self._optimizer_info.latest() + + def lock_and_get_optimizer_state(self) -> OptimizationState: + """Get the optimizer state.""" + with self._state_lock.lock(): + return self._optimizer_state.latest() + + def lock_and_get_trial_by_id(self, trial_id: str) -> Trial: + """Get a trial by its id.""" + with self._trial_lock.lock(): + return self._trials.load_trial_from_disk(trial_id) + + def unsafe_retry_get_trial_by_id(self, trial_id: str) -> Trial: + """Get a trial by id but use unsafe retries.""" + for _ in range(N_UNSAFE_RETRIES): + try: + return self._trials.load_trial_from_disk(trial_id) + except TrialNotFoundError as e: + raise e + except Exception as e: # noqa: BLE001 + logger.warning( + "Failed to get trial '%s' due to an error: %s", trial_id, e + ) + time.sleep(0.1) + continue + + raise NePSError( + f"Failed to get trial '{trial_id}' after {N_UNSAFE_RETRIES} retries." + ) + + def put_updated_trial(self, trial: Trial) -> None: + """Update the trial.""" + with self._trial_lock.lock(): + self._trials.update_trial(trial) + self._trials.write_version_file() @overload - def get_next_pending_trial(self) -> Trial | None: ... + def lock_and_get_next_pending_trial(self) -> Trial | None: ... @overload - def get_next_pending_trial(self, n: int | None = None) -> list[Trial]: ... + def lock_and_get_next_pending_trial(self, n: int | None = None) -> list[Trial]: ... + + def lock_and_get_next_pending_trial( + self, + n: int | None = None, + ) -> Trial | list[Trial] | None: + """Get the next pending trial.""" + with self._trial_lock.lock(): + trials = self._trials.latest() + pendings = sorted( + [ + trial + for trial in trials.values() + if trial.state == Trial.State.PENDING + ], + key=lambda t: t.metadata.time_sampled, + ) + if n is None: + return pendings[0] if pendings else None + return pendings[:n] + + @classmethod + def create_or_load( + cls, + path: Path, + *, + load_only: bool = False, + optimizer_info: OptimizerInfo | None = None, + optimizer_state: OptimizationState | None = None, + ) -> NePSState: + """Create a new NePSState in a directory or load the existing one + if it already exists, depending on the argument. - def get_next_pending_trial(self, n: int | None = None) -> Trial | list[Trial] | None: - """Get the next pending trial to evaluate. + !!! warning - Args: - n: The number of trials to get. If `None`, get the next trial. + We check that the optimizer info in the NePSState on disk matches + the one that is passed. However we do not lock this check so it + is possible that if two processes try to create a NePSState at the + same time, both with different optimizer infos, that one will fail + to create the NePSState. This is a limitation of the current design. - Returns: - The next trial or a list of trials if `n` is not `None`. - """ - _pending_itr = (shared_trial for _, shared_trial in self._trials.pending()) - if n is not None: - return take(n, _pending_itr) - return next(_pending_itr, None) + In principal, we could allow multiple optimizers to be run and share + the same set of trials. - def all_trial_ids(self) -> list[str]: - """Get all the trial ids that are known about.""" - return self._trials.all_trial_ids() + Args: + path: The directory to create the state in. + load_only: If True, only load the state and do not create a new one. + optimizer_info: The optimizer info to use. + optimizer_state: The optimizer state to use. - def get_all_trials(self) -> dict[str, Trial]: - """Get all the trials that are known about.""" - return {_id: trial.synced() for _id, trial in self._trials.all().items()} + Returns: + The NePSState. - def optimizer_info(self) -> OptimizerInfo: - """Get the optimizer information.""" - return self._optimizer_info.synced() + Raises: + NePSError: If the optimizer info on disk does not match the one provided. + """ + is_new = not path.exists() + if load_only: + if is_new: + raise FileNotFoundError(f"No NePSState found at '{path}'.") + else: + assert optimizer_info is not None + assert optimizer_state is not None + + path.mkdir(parents=True, exist_ok=True) + config_dir = path / "configs" + config_dir.mkdir(parents=True, exist_ok=True) + seed_dir = path / ".seed_state" + seed_dir.mkdir(parents=True, exist_ok=True) + error_dir = path / ".errors" + error_dir.mkdir(parents=True, exist_ok=True) + optimizer_state_dir = path / ".optimizer_state" + optimizer_state_dir.mkdir(parents=True, exist_ok=True) + optimizer_info_dir = path / ".optimizer_info" + optimizer_info_dir.mkdir(parents=True, exist_ok=True) + + # We have to do one bit of sanity checking to ensure that the optimzier + # info on disk manages the one we have recieved, otherwise we are unsure which + # optimizer is being used. + # NOTE: We assume that we do not have to worry about a race condition + # here where we have two different NePSState objects with two different optimizer + # infos trying to be created at the same time. This avoids the need to lock to + # check the optimizer info. If this assumption changes, then we would have + # to first lock before we do this check + if not is_new: + _optimizer_info = VersionedResource.load( + optimizer_info_dir, + read=ReaderWriterOptimizerInfo.read, + write=ReaderWriterOptimizerInfo.write, + version_file=optimizer_info_dir / ".version", + ) + _optimizer_state = VersionedResource.load( + optimizer_state_dir, + read=ReaderWriterOptimizationState.read, + write=ReaderWriterOptimizationState.write, + version_file=optimizer_state_dir / ".version", + ) + _seed_snapshot = VersionedResource.load( + seed_dir, + read=ReaderWriterSeedSnapshot.read, + write=ReaderWriterSeedSnapshot.write, + version_file=seed_dir / ".version", + ) + _shared_errors = VersionedResource.load( + error_dir, + read=ReaderWriterErrDump.read, + write=ReaderWriterErrDump.write, + version_file=error_dir / ".version", + ) + existing_info = _optimizer_info.latest() + if not load_only and existing_info != optimizer_info: + raise NePSError( + "The optimizer info on disk does not match the one provided." + f"\nOn disk: {existing_info}\nProvided: {optimizer_info}" + f"\n\nLoaded the one on disk from {optimizer_info_dir}." + ) + else: + assert optimizer_info is not None + assert optimizer_state is not None + _optimizer_info = VersionedResource.new( + resource=optimizer_info, + path=optimizer_info_dir, + read=ReaderWriterOptimizerInfo.read, + write=ReaderWriterOptimizerInfo.write, + version_file=optimizer_info_dir / ".version", + ) + _optimizer_state = VersionedResource.new( + resource=optimizer_state, + path=optimizer_state_dir, + read=ReaderWriterOptimizationState.read, + write=ReaderWriterOptimizationState.write, + version_file=optimizer_state_dir / ".version", + ) + _seed_snapshot = VersionedResource.new( + resource=SeedSnapshot.new_capture(), + path=seed_dir, + read=ReaderWriterSeedSnapshot.read, + write=ReaderWriterSeedSnapshot.write, + version_file=seed_dir / ".version", + ) + _shared_errors = VersionedResource.new( + resource=ErrDump(), + path=error_dir, + read=ReaderWriterErrDump.read, + write=ReaderWriterErrDump.write, + version_file=error_dir / ".version", + ) - def optimizer_state(self) -> OptimizationState: - """Get the optimizer state.""" - return self._optimizer_state.synced() + return cls( + path=path, + _trials=TrialRepo(config_dir, version_file=config_dir / ".versions"), + # Locks, + _trial_lock=FileLocker( + lock_path=path / ".configs.lock", + poll=TRIAL_FILELOCK_POLL, + timeout=TRIAL_FILELOCK_TIMEOUT, + ), + _state_lock=FileLocker( + lock_path=path / ".state.lock", + poll=STATE_FILELOCK_POLL, + timeout=STATE_FILELOCK_TIMEOUT, + ), + _err_lock=FileLocker( + lock_path=error_dir / "errors.lock", + poll=TRIAL_FILELOCK_POLL, + timeout=TRIAL_FILELOCK_TIMEOUT, + ), + # State + _optimizer_info=_optimizer_info, + _optimizer_state=_optimizer_state, + _seed_snapshot=_seed_snapshot, + _shared_errors=_shared_errors, + ) diff --git a/neps/state/protocols.py b/neps/state/protocols.py deleted file mode 100644 index 7bbe7a9a..00000000 --- a/neps/state/protocols.py +++ /dev/null @@ -1,577 +0,0 @@ -"""This module defines the protocols used by -[`NePSState`][neps.state.neps_state.NePSState] and -[`Synced`][neps.state.synced.Synced] to ensure atomic operations to the state itself. -""" - -from __future__ import annotations - -import logging -from collections.abc import Callable, Iterable, Iterator -from contextlib import contextmanager -from copy import deepcopy -from dataclasses import dataclass -from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar -from typing_extensions import Self - -from neps.exceptions import ( - LockFailedError, - TrialAlreadyExistsError, - TrialNotFoundError, - VersionedResourceAlreadyExistsError, - VersionedResourceDoesNotExistError, - VersionedResourceRemovedError, - VersionMismatchError, -) - -if TYPE_CHECKING: - from neps.state import Trial - -logger = logging.getLogger(__name__) - -T = TypeVar("T") -K = TypeVar("K") - -# https://github.com/MaT1g3R/option/issues/40 -K2 = TypeVar("K2") -T2 = TypeVar("T2") - -Loc_contra = TypeVar("Loc_contra", contravariant=True) - - -class Versioner(Protocol): - """A versioner that can bump the version of a resource. - - It should have some [`current()`][neps.state.protocols.Versioner.current] method - to give the current version tag of a resource and a - [`bump()`][neps.state.protocols.Versioner.bump] method to provide a new version tag. - - These [`current()`][neps.state.protocols.Versioner.current] and - [`bump()`][neps.state.protocols.Versioner.bump] methods do not need to be atomic - but they should read/write to external state, i.e. file-system, database, etc. - """ - - def current(self) -> str | None: - """Return the current version as defined by the external state, i.e. - the version of the tag on disk. - - Returns: - The current version if there is one written. - """ - ... - - def bump(self) -> str: - """Create a new external version tag. - - Returns: - The new version tag. - """ - ... - - -class Locker(Protocol): - """A locker that can be used to communicate between workers.""" - - LockFailedError: ClassVar = LockFailedError - - @contextmanager - def lock(self) -> Iterator[None]: - """Initiate the lock as a context manager, releasing it when done.""" - ... - - def is_locked(self) -> bool: - """Check if lock is...well, locked. - - Should return True if the resource is locked, even if the lock is held by the - current worker/process. - """ - ... - - -class ReaderWriter(Protocol[T, Loc_contra]): - """A reader-writer that can read and write some resource T with location Loc. - - For example, a `ReaderWriter[Trial, Path]` indicates a class that can read and write - trials, given some `Path`. - """ - - CHEAP_LOCKLESS_READ: ClassVar[bool] - """Whether reading the contents of the resource is cheap, cheap enough to be - most likely safe without a lock if outdated information is acceptable. - - This is currently used to help debugging instances of a VersionMismatchError - to see what the current state is and what was attempted to be written. - """ - - def read(self, loc: Loc_contra, /) -> T: - """Read the resource at the given location.""" - ... - - def write(self, value: T, loc: Loc_contra, /) -> None: - """Write the resource at the given location.""" - ... - - -class TrialRepo(Protocol[K]): - """A repository of trials. - - The primary purpose of this protocol is to ensure consistent access to trial, - the ability to put in a new trial and know about the trials that are stored there. - """ - - TrialAlreadyExistsError: ClassVar = TrialAlreadyExistsError - TrialNotFoundError: ClassVar = TrialNotFoundError - - def all_trial_ids(self) -> list[str]: - """List all the trial ids in this trial Repo.""" - ... - - def get_by_id(self, trial_id: str) -> Synced[Trial, K]: - """Get a trial by its id.""" - ... - - def put_new(self, trial: Trial) -> Synced[Trial, K]: - """Put a new trial in the repo.""" - ... - - def all(self) -> dict[str, Synced[Trial, K]]: - """Get all trials in the repo.""" - ... - - def pending(self) -> Iterable[tuple[str, Trial]]: - """Get all pending trials in the repo. - - !!! note - This should return trials in the order in which they should be next evaluated, - usually the order in which they were put in the repo. - """ - ... - - -@dataclass -class VersionedResource(Generic[T, K]): - """A resource that will be read if it needs to update to the latest version. - - Relies on 3 main components: - * A [`Versioner`][neps.state.protocols.Versioner] to manage the versioning of the - resource. - * A [`ReaderWriter`][neps.state.protocols.ReaderWriter] to read and write the - resource. - * The location of the resource that can be used for the reader-writer. - """ - - VersionMismatchError: ClassVar = VersionMismatchError - VersionedResourceDoesNotExistsError: ClassVar = VersionedResourceDoesNotExistError - VersionedResourceAlreadyExistsError: ClassVar = VersionedResourceAlreadyExistsError - VersionedResourceRemovedError: ClassVar = VersionedResourceRemovedError - - _current: T - _location: K - _version: str - _versioner: Versioner - _reader_writer: ReaderWriter[T, K] - - @staticmethod - def new( - *, - data: T2, - location: K2, - versioner: Versioner, - reader_writer: ReaderWriter[T2, K2], - ) -> VersionedResource[T2, K2]: - """Create a new VersionedResource. - - This will create a new resource if it doesn't exist, otherwise, - if it already exists, it will raise an error. - - Use [`load()`][neps.state.protocols.VersionedResource.load] if you want to - load an existing resource. - - Args: - data: The data to be stored. - location: The location where the data will be stored. - versioner: The versioner to be used. - reader_writer: The reader-writer to be used. - - Returns: - A new VersionedResource - - Raises: - VersionedResourceAlreadyExistsError: If a versioned resource already exists - at the given location. - """ - current_version = versioner.current() - if current_version is not None: - raise VersionedResourceAlreadyExistsError( - f"A versioned resource already already exists at '{location}'" - f" with version '{current_version}'" - ) - - version = versioner.bump() - reader_writer.write(data, location) - return VersionedResource( - _current=data, - _location=location, - _version=version, - _versioner=versioner, - _reader_writer=reader_writer, - ) - - @classmethod - def load( - cls, - *, - location: K2, - versioner: Versioner, - reader_writer: ReaderWriter[T2, K2], - ) -> VersionedResource[T2, K2]: - """Load an existing VersionedResource. - - This will load an existing resource if it exists, otherwise, it will raise an - error. - - Use [`new()`][neps.state.protocols.VersionedResource.new] if you want to - create a new resource. - - Args: - location: The location of the resource. - versioner: The versioner to be used. - reader_writer: The reader-writer to be used. - - Returns: - A VersionedResource - - Raises: - VersionedResourceDoesNotExistsError: If no versioned resource exists at - the given location. - """ - version = versioner.current() - if version is None: - raise cls.VersionedResourceDoesNotExistsError( - f"No versioned resource exists at '{location}'." - ) - data = reader_writer.read(location) - return VersionedResource( - _current=data, - _location=location, - _version=version, - _versioner=versioner, - _reader_writer=reader_writer, - ) - - def sync_and_get(self) -> T: - """Get the data and version of the resource.""" - self.sync() - return self._current - - def sync(self) -> None: - """Sync the resource with the latest version.""" - current_version = self._versioner.current() - if current_version is None: - raise self.VersionedResourceRemovedError( - f"Versioned resource at '{self._location}' has been removed!" - f" Last known version was '{self._version}'." - ) - - if self._version != current_version: - self._current = self._reader_writer.read(self._location) - self._version = current_version - - def put(self, data: T) -> None: - """Put the data and version of the resource. - - Raises: - VersionMismatchError: If the version of the resource is not the same as the - current version. This implies that the resource has been updated by - another worker. - """ - current_version = self._versioner.current() - if self._version != current_version: - # We will attempt to do a lockless read on the contents of the items, as this - # would allow us to better debug in the error raised below. - if self._reader_writer.CHEAP_LOCKLESS_READ: - current_contents = self._reader_writer.read(self._location) - extra_msg = ( - f"\nThe attempted write was: {data}\n" - f"The current contents are: {current_contents}" - ) - else: - extra_msg = "" - - raise self.VersionMismatchError( - f"Version mismatch - ours: '{self._version}', remote: '{current_version}'" - f" Tried to put data at '{self._location}'. Doing so would overwrite" - " changes made by another worker. The solution is to pull the latest" - " version of the resource and try again." - " The most possible reasons for this error is that a lock was not" - " utilized when getting this resource before putting it back." - f"{extra_msg}" - ) - - self._reader_writer.write(data, self._location) - self._current = data - self._version = self._versioner.bump() - - def current(self) -> T: - """Get the current data of the resource.""" - return self._current - - def is_stale(self) -> bool: - """Check if the resource is stale.""" - return self._version != self._versioner.current() - - def location(self) -> K: - """Get the location of the resource.""" - return self._location - - -@dataclass -class Synced(Generic[T, K]): - """Manages a versioned resource but it's methods also implement locking procedures - for accessing it. - - Its types are parametrized by two type variables: - - * `T` is the type of the data stored in the resource. - * `K` is the type of the location of the resource, for example `Path` - - This wraps a [`VersionedResource`][neps.state.protocols.VersionedResource] and - additionally provides utility to perform atmoic operations on it using a - [`Locker`][neps.state.protocols.Locker]. - - This is used by [`NePSState`][neps.state.neps_state.NePSState] to manage the state - of trials and other shared resources. - - It consists of 2 main components: - - * A [`VersionedResource`][neps.state.protocols.VersionedResource] to manage the - versioning of the resource. - * A [`Locker`][neps.state.protocols.Locker] to manage the locking of the resource. - - The primary methods to interact with a resource that is behined a `Synced` are: - - * [`synced()`][neps.state.protocols.Synced.synced] to get the data of the resource - after syncing it to it's latest verison. - * [`acquire()`][neps.state.protocols.Synced.acquire] context manager to get latest - version of the data while also mainting a lock on it. This additionally provides - a `put()` operation to put the data back. This can primarily be used to get the - data, perform some mutation on it and then put it back, while not allowing other - workers access to the data. - """ - - LockFailedError: ClassVar = Locker.LockFailedError - VersionedResourceRemovedError: ClassVar = ( - VersionedResource.VersionedResourceRemovedError - ) - VersionMismatchError: ClassVar = VersionedResource.VersionMismatchError - VersionedResourceAlreadyExistsError: ClassVar = ( - VersionedResource.VersionedResourceAlreadyExistsError - ) - VersionedResourceDoesNotExistsError: ClassVar = ( - VersionedResource.VersionedResourceDoesNotExistsError - ) - - _resource: VersionedResource[T, K] - _locker: Locker - - @classmethod - def new( - cls, - *, - locker: Locker, - data: T2, - location: K2, - versioner: Versioner, - reader_writer: ReaderWriter[T2, K2], - ) -> Synced[T2, K2]: - """Create a new Synced resource. - - This will create a new resource if it doesn't exist, otherwise, - if it already exists, it will raise an error. - - Use [`load()`][neps.state.protocols.Synced.load] if you want to load an existing - resource. Use [`new_or_load()`][neps.state.protocols.Synced.new_or_load] if you - want to create a new resource if it doesn't exist, otherwise load an existing - resource. - - Args: - locker: The locker to be used. - data: The data to be stored. - location: The location where the data will be stored. - versioner: The versioner to be used. - reader_writer: The reader-writer to be used. - - Returns: - A new Synced resource. - - Raises: - VersionedResourceAlreadyExistsError: If a versioned resource already exists - at the given location. - """ - with locker.lock(): - vr = VersionedResource.new( - data=data, - location=location, - versioner=versioner, - reader_writer=reader_writer, - ) - return Synced(_resource=vr, _locker=locker) - - @classmethod - def load( - cls, - *, - locker: Locker, - location: K2, - versioner: Versioner, - reader_writer: ReaderWriter[T2, K2], - ) -> Synced[T2, K2]: - """Load an existing Synced resource. - - This will load an existing resource if it exists, otherwise, it will raise an - error. - - Use [`new()`][neps.state.protocols.Synced.new] if you want to create a new - resource. Use [`new_or_load()`][neps.state.protocols.Synced.new_or_load] if you - want to create a new resource if it doesn't exist, otherwise load an existing - resource. - - Args: - locker: The locker to be used. - location: The location of the resource. - versioner: The versioner to be used. - reader_writer: The reader-writer to be used. - - Returns: - A Synced resource. - - Raises: - VersionedResourceDoesNotExistsError: If no versioned resource exists at - the given location. - """ - with locker.lock(): - return Synced( - _resource=VersionedResource.load( - location=location, - versioner=versioner, - reader_writer=reader_writer, - ), - _locker=locker, - ) - - @classmethod - def new_or_load( - cls, - *, - locker: Locker, - data: T2, - location: K2, - versioner: Versioner, - reader_writer: ReaderWriter[T2, K2], - ) -> Synced[T2, K2]: - """Create a new Synced resource if it doesn't exist, otherwise load it. - - This will create a new resource if it doesn't exist, otherwise, it will load - an existing resource. - - Use [`new()`][neps.state.protocols.Synced.new] if you want to create a new - resource and fail otherwise. Use [`load()`][neps.state.protocols.Synced.load] - if you want to load an existing resource and fail if it doesn't exist. - - Args: - locker: The locker to be used. - data: The data to be stored. - - !!! warning - - This will be ignored if the data already exists. - - location: The location where the data will be stored. - versioner: The versioner to be used. - reader_writer: The reader-writer to be used. - - Returns: - A Synced resource. - """ - try: - return Synced.new( - locker=locker, - data=data, - location=location, - versioner=versioner, - reader_writer=reader_writer, - ) - except VersionedResourceAlreadyExistsError: - return Synced.load( - locker=locker, - location=location, - versioner=versioner, - reader_writer=reader_writer, - ) - - def synced(self) -> T: - """Get the data of the resource atomically.""" - with self._locker.lock(): - return self._resource.sync_and_get() - - def location(self) -> K: - """Get the location of the resource.""" - return self._resource.location() - - def put(self, data: T) -> None: - """Update the data atomically.""" - with self._locker.lock(): - self._resource.put(data) - - @contextmanager - def acquire(self) -> Iterator[tuple[T, Callable[[T], None]]]: - """Acquire the lock and get the data of the resource. - - This is a context manager that returns the data of the resource and a function - to put the data back. - - !!! note - This is the primary way to get the resource, mutate it and put it back. - Otherwise you likely want [`synced()`][neps.state.protocols.Synced.synced] - or [`put()`][neps.state.protocols.Synced.put]. - - Yields: - A tuple containing the data of the resource and a function to put the data - back. - """ - with self._locker.lock(): - self._resource.sync() - yield self._resource.current(), self._put_unsafe - - def deepcopy(self) -> Self: - """Create a deep copy of the shared resource.""" - return deepcopy(self) - - def _components(self) -> tuple[T, K, Versioner, ReaderWriter[T, K], Locker]: - """Get the components of the shared resource.""" - return ( - self._resource.current(), - self._resource.location(), - self._resource._versioner, - self._resource._reader_writer, - self._locker, - ) - - def _unsynced(self) -> T: - """Get the current data of the resource **without** locking and syncing it.""" - return self._resource.current() - - def _is_stale(self) -> bool: - """Check if the data held currently is not the latest version.""" - return self._resource.is_stale() - - def _is_locked(self) -> bool: - """Check if the resource is locked.""" - return self._locker.is_locked() - - def _put_unsafe(self, data: T) -> None: - """Put the data without checking for staleness or acquiring the lock. - - !!! warning - This should only really be called if you know what you're doing. - """ - self._resource.put(data) diff --git a/neps/status/status.py b/neps/status/status.py index bb68e50d..4b47eaeb 100644 --- a/neps/status/status.py +++ b/neps/status/status.py @@ -9,9 +9,9 @@ import pandas as pd -from neps.state.filebased import load_filebased_neps_state +from neps.state.filebased import FileLocker +from neps.state.neps_state import NePSState from neps.state.trial import Trial -from neps.utils._locker import Locker from neps.utils.types import ConfigID, _ConfigResultForStats if TYPE_CHECKING: @@ -37,9 +37,8 @@ def get_summary_dict( # NOTE: We don't lock the shared state since we are just reading and don't need to # make decisions based on the state - shared_state = load_filebased_neps_state(root_directory) - - trials = shared_state.get_all_trials() + shared_state = NePSState.create_or_load(root_directory, load_only=True) + trials = shared_state.lock_and_read_trials() evaluated: dict[ConfigID, _ConfigResultForStats] = {} @@ -160,7 +159,7 @@ def status( return summary["previous_results"], summary["pending_configs"] -def _initiate_summary_csv(root_directory: str | Path) -> tuple[Path, Path, Locker]: +def _initiate_summary_csv(root_directory: str | Path) -> tuple[Path, Path, FileLocker]: """Initializes a summary CSV and an associated locker for file access control. Args: @@ -181,7 +180,7 @@ def _initiate_summary_csv(root_directory: str | Path) -> tuple[Path, Path, Locke csv_config_data = summary_csv_directory / "config_data.csv" csv_run_data = summary_csv_directory / "run_status.csv" - csv_locker = Locker(summary_csv_directory / ".csv_lock") + csv_locker = FileLocker(summary_csv_directory / ".csv_lock", poll=2, timeout=600) return ( csv_config_data, @@ -282,7 +281,7 @@ def _get_dataframes_from_summary( def _save_data_to_csv( config_data_file_path: Path, run_data_file_path: Path, - locker: Locker, + locker: FileLocker, config_data_df: pd.DataFrame, run_data_df: pd.DataFrame, ) -> None: @@ -299,7 +298,7 @@ def _save_data_to_csv( config_data_df: The DataFrame containing configuration data. run_data_df: The DataFrame containing additional run data. """ - with locker(poll=2, timeout=600): + with locker.lock(): try: pending_configs = run_data_df.loc["num_pending_configs", "value"] pending_configs_with_worker = run_data_df.loc[ diff --git a/neps/utils/_locker.py b/neps/utils/_locker.py deleted file mode 100644 index f2d430f8..00000000 --- a/neps/utils/_locker.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from collections.abc import Iterator -from contextlib import contextmanager -from pathlib import Path -from typing import IO - -import portalocker as pl - -EXCLUSIVE_NONE_BLOCKING = pl.LOCK_EX | pl.LOCK_NB - - -class Locker: - FailedToAcquireLock = pl.exceptions.LockException - - def __init__(self, lock_path: Path): - self.lock_path = lock_path - self.lock_path.touch(exist_ok=True) - - @contextmanager - def try_lock(self) -> Iterator[bool]: - try: - with self.acquire(fail_when_locked=True): - yield True - except self.FailedToAcquireLock: - yield False - - def is_locked(self) -> bool: - with self.try_lock() as acquired_lock: - return not acquired_lock - - @contextmanager - def __call__( - self, - poll: float = 1, - *, - timeout: float | None = None, - fail_when_locked: bool = False, - ) -> Iterator[IO]: - with pl.Lock( - self.lock_path, - check_interval=poll, - timeout=timeout, - flags=EXCLUSIVE_NONE_BLOCKING, - fail_when_locked=fail_when_locked, - ) as fh: - yield fh # We almost never use it but nothing better to yield - - @contextmanager - def acquire( - self, - poll: float = 1.0, - *, - timeout: float | None = None, - fail_when_locked: bool = False, - ) -> Iterator[IO]: - with self( - poll, - timeout=timeout, - fail_when_locked=fail_when_locked, - ) as fh: - yield fh diff --git a/neps/utils/cli.py b/neps/utils/cli.py index cde70357..455fec43 100644 --- a/neps/utils/cli.py +++ b/neps/utils/cli.py @@ -40,10 +40,6 @@ ) from neps.optimizers.base_optimizer import BaseOptimizer from neps.utils.run_args import load_and_return_object -from neps.state.filebased import ( - create_or_load_filebased_neps_state, - load_filebased_neps_state, -) from neps.state.neps_state import NePSState from neps.state.trial import Trial from neps.exceptions import VersionedResourceDoesNotExistError, TrialNotFoundError @@ -140,8 +136,8 @@ def init_config(args: argparse.Namespace) -> None: else: directory = Path(directory) is_new = not directory.exists() - _ = create_or_load_filebased_neps_state( - directory=directory, + _ = NePSState.create_or_load( + path=directory, optimizer_info=OptimizerInfo(optimizer_info), optimizer_state=OptimizationState( budget=( @@ -335,7 +331,7 @@ def info_config(args: argparse.Namespace) -> None: if neps_state is None: return try: - trial = neps_state.get_trial_by_id(config_id) + trial = neps_state.unsafe_retry_get_trial_by_id(config_id) except TrialNotFoundError: print(f"No trial found with ID {config_id}.") return @@ -381,7 +377,7 @@ def load_neps_errors(args: argparse.Namespace) -> None: neps_state = load_neps_state(directory_path) if neps_state is None: return - errors = neps_state.get_errors() + errors = neps_state.lock_and_get_errors() if not errors.errs: print("No errors found.") @@ -441,7 +437,7 @@ def sample_config(args: argparse.Namespace) -> None: # Sample trials for _ in range(num_configs): try: - trial = neps_state.sample_trial(optimizer, worker_id=worker_id) + trial = neps_state.lock_and_sample_trial(optimizer, worker_id=worker_id) except Exception as e: print(f"Error during configuration sampling: {e}") continue # Skip to the next iteration @@ -491,10 +487,9 @@ def status(args: argparse.Namespace) -> None: summary = get_summary_dict(directory_path, add_details=True) # Calculate the number of trials in different states + trials = neps_state.lock_and_read_trials() evaluating_trials_count = sum( - 1 - for trial in neps_state.get_all_trials().values() - if trial.state.name == "EVALUATING" + 1 for trial in trials.values() if trial.state == Trial.State.EVALUATING ) pending_trials_count = summary["num_pending_configs"] succeeded_trials_count = summary["num_evaluated_configs"] - summary["num_error"] @@ -503,7 +498,7 @@ def status(args: argparse.Namespace) -> None: # Print summary print("NePS Status:") print("-----------------------------") - print(f"Optimizer: {neps_state.optimizer_info().info['searcher_alg']}") + print(f"Optimizer: {neps_state.lock_and_get_optimizer_info().info['searcher_alg']}") print(f"Succeeded Trials: {succeeded_trials_count}") print(f"Failed Trials (Errors): {failed_trials_count}") print(f"Active Trials: {evaluating_trials_count}") @@ -514,9 +509,8 @@ def status(args: argparse.Namespace) -> None: print("-----------------------------") # Retrieve and sort the trials by time_sampled - all_trials = neps_state.get_all_trials() sorted_trials = sorted( - all_trials.values(), key=lambda t: t.metadata.time_sampled, reverse=True + trials.values(), key=lambda t: t.metadata.time_sampled, reverse=True ) # Filter trials based on state @@ -589,7 +583,7 @@ def status(args: argparse.Namespace) -> None: print("\nNo successful trial found.") # Display optimizer information - optimizer_info = neps_state.optimizer_info().info + optimizer_info = neps_state.lock_and_get_optimizer_info().info searcher_name = optimizer_info.get("searcher_name", "N/A") searcher_alg = optimizer_info.get("searcher_alg", "N/A") searcher_args = optimizer_info.get("searcher_args", {}) @@ -631,7 +625,7 @@ def sort_trial_id(trial_id: str) -> List[int]: # Convert each part to an integer for proper numeric sorting return [int(part) for part in parts] - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() sorted_trials = sorted(trials.values(), key=lambda x: sort_trial_id(x.id)) # Compute incumbents @@ -662,10 +656,10 @@ def sort_trial_id(trial_id: str) -> List[int]: print(f"Plot saved to '{plot_path}'.") -def load_neps_state(directory_path: Path) -> Optional[NePSState[Path]]: +def load_neps_state(directory_path: Path) -> Optional[NePSState]: """Load the NePS state with error handling.""" try: - return load_filebased_neps_state(directory_path) + return NePSState.create_or_load(directory_path, load_only=True) except VersionedResourceDoesNotExistError: print(f"Error: No NePS state found in the directory '{directory_path}'.") print("Ensure that the NePS run has been initialized correctly.") @@ -679,7 +673,11 @@ def compute_incumbents(sorted_trials: List[Trial]) -> List[Trial]: best_loss = float("inf") incumbents = [] for trial in sorted_trials: - if trial.report and trial.report.loss < best_loss: + if ( + trial.report is not None + and trial.report.loss is not None + and trial.report.loss < best_loss + ): best_loss = trial.report.loss incumbents.append(trial) return incumbents[::-1] # Reverse for most recent first @@ -1031,7 +1029,7 @@ def handle_report_config(args: argparse.Namespace) -> None: # Load the existing trial by ID try: - trial = neps_state.get_trial_by_id(args.trial_id) + trial = neps_state.unsafe_retry_get_trial_by_id(args.trial_id) if not trial: print(f"No trial found with ID {args.trial_id}") return @@ -1054,7 +1052,7 @@ def handle_report_config(args: argparse.Namespace) -> None: # Update NePS state try: - neps_state.report_trial_evaluation( + neps_state._report_trial_evaluation( trial=trial, report=report, worker_id=args.worker_id ) except Exception as e: diff --git a/neps/utils/common.py b/neps/utils/common.py index 1e735063..3887565e 100644 --- a/neps/utils/common.py +++ b/neps/utils/common.py @@ -160,7 +160,7 @@ def get_initial_directory(pipeline_directory: Path | str | None = None) -> Path: if pipeline_directory is not None: # TODO: Hard coded assumption config_id = Path(pipeline_directory).name.split("_", maxsplit=1)[-1] - trial = neps_state.get_trial_by_id(config_id) + trial = neps_state.unsafe_retry_get_trial_by_id(config_id) else: trial = get_in_progress_trial() @@ -169,7 +169,7 @@ def get_initial_directory(pipeline_directory: Path | str | None = None) -> Path: # Recursively find the initial directory while (prev_trial_id := trial.metadata.previous_trial_id) is not None: - trial = neps_state.get_trial_by_id(prev_trial_id) + trial = neps_state.unsafe_retry_get_trial_by_id(prev_trial_id) initial_dir = trial.metadata.location diff --git a/tests/test_runtime/test_default_report_values.py b/tests/test_runtime/test_default_report_values.py index 265d4c08..d857c69a 100644 --- a/tests/test_runtime/test_default_report_values.py +++ b/tests/test_runtime/test_default_report_values.py @@ -4,7 +4,6 @@ from neps.optimizers.random_search.optimizer import RandomSearch from neps.runtime import DefaultWorker from neps.search_spaces.search_space import SearchSpace -from neps.state.filebased import create_or_load_filebased_neps_state from neps.state.neps_state import NePSState from neps.state.optimizer import OptimizationState, OptimizerInfo from neps.state.settings import DefaultReportValues, OnErrorPossibilities, WorkerSettings @@ -13,9 +12,9 @@ @fixture -def neps_state(tmp_path: Path) -> NePSState[Path]: - return create_or_load_filebased_neps_state( - directory=tmp_path / "neps_state", +def neps_state(tmp_path: Path) -> NePSState: + return NePSState.create_or_load( + path=tmp_path / "neps_state", optimizer_info=OptimizerInfo(info={"nothing": "here"}), optimizer_state=OptimizationState(budget=None, shared_state={}), ) @@ -54,15 +53,15 @@ def eval_function(*args, **kwargs) -> float: ) worker.run() - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() n_crashed = sum( trial.state == Trial.State.CRASHED is not None for trial in trials.values() ) assert len(trials) == 1 assert n_crashed == 1 - assert neps_state.get_next_pending_trial() is None - assert len(neps_state.get_errors()) == 1 + assert neps_state.lock_and_get_next_pending_trial() is None + assert len(neps_state.lock_and_get_errors()) == 1 trial = trials.popitem()[1] assert trial.state == Trial.State.CRASHED @@ -104,15 +103,15 @@ def eval_function(*args, **kwargs) -> float: ) worker.run() - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() n_sucess = sum( trial.state == Trial.State.SUCCESS is not None for trial in trials.values() ) assert len(trials) == 1 assert n_sucess == 1 - assert neps_state.get_next_pending_trial() is None - assert len(neps_state.get_errors()) == 0 + assert neps_state.lock_and_get_next_pending_trial() is None + assert len(neps_state.lock_and_get_errors()) == 0 trial = trials.popitem()[1] assert trial.state == Trial.State.SUCCESS @@ -152,15 +151,15 @@ def eval_function(*args, **kwargs) -> float: ) worker.run() - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() n_sucess = sum( trial.state == Trial.State.SUCCESS is not None for trial in trials.values() ) assert len(trials) == 1 assert n_sucess == 1 - assert neps_state.get_next_pending_trial() is None - assert len(neps_state.get_errors()) == 0 + assert neps_state.lock_and_get_next_pending_trial() is None + assert len(neps_state.lock_and_get_errors()) == 0 trial = trials.popitem()[1] assert trial.state == Trial.State.SUCCESS diff --git a/tests/test_runtime/test_error_handling_strategies.py b/tests/test_runtime/test_error_handling_strategies.py index 05cf762a..d341aa2b 100644 --- a/tests/test_runtime/test_error_handling_strategies.py +++ b/tests/test_runtime/test_error_handling_strategies.py @@ -8,8 +8,6 @@ from neps.optimizers.random_search.optimizer import RandomSearch from neps.runtime import DefaultWorker from neps.search_spaces.search_space import SearchSpace -from neps.state.err_dump import SerializedError -from neps.state.filebased import create_or_load_filebased_neps_state from neps.state.neps_state import NePSState from neps.state.optimizer import OptimizationState, OptimizerInfo from neps.state.settings import DefaultReportValues, OnErrorPossibilities, WorkerSettings @@ -18,9 +16,9 @@ @fixture -def neps_state(tmp_path: Path) -> NePSState[Path]: - return create_or_load_filebased_neps_state( - directory=tmp_path / "neps_state", +def neps_state(tmp_path: Path) -> NePSState: + return NePSState.create_or_load( + path=tmp_path / "neps_state", optimizer_info=OptimizerInfo(info={"nothing": "here"}), optimizer_state=OptimizationState(budget=None, shared_state={}), ) @@ -61,15 +59,15 @@ def eval_function(*args, **kwargs) -> float: with pytest.raises(WorkerRaiseError): worker.run() - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() n_crashed = sum( trial.state == Trial.State.CRASHED is not None for trial in trials.values() ) assert len(trials) == 1 assert n_crashed == 1 - assert neps_state.get_next_pending_trial() is None - assert len(neps_state.get_errors()) == 1 + assert neps_state.lock_and_get_next_pending_trial() is None + assert len(neps_state.lock_and_get_errors()) == 1 def test_worker_raises_when_error_in_other_worker(neps_state: NePSState) -> None: @@ -114,15 +112,15 @@ def evaler(*args, **kwargs) -> float: with pytest.raises(WorkerRaiseError): worker2.run() - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() n_crashed = sum( trial.state == Trial.State.CRASHED is not None for trial in trials.values() ) assert len(trials) == 1 assert n_crashed == 1 - assert neps_state.get_next_pending_trial() is None - assert len(neps_state.get_errors()) == 1 + assert neps_state.lock_and_get_next_pending_trial() is None + assert len(neps_state.lock_and_get_errors()) == 1 @pytest.mark.parametrize( @@ -184,7 +182,7 @@ def __call__(self, *args, **kwargs) -> float: worker2.run() assert worker2.worker_cumulative_eval_count == 1 - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() n_success = sum( trial.state == Trial.State.SUCCESS is not None for trial in trials.values() ) @@ -195,5 +193,5 @@ def __call__(self, *args, **kwargs) -> float: assert n_crashed == 1 assert len(trials) == 2 - assert neps_state.get_next_pending_trial() is None - assert len(neps_state.get_errors()) == 1 + assert neps_state.lock_and_get_next_pending_trial() is None + assert len(neps_state.lock_and_get_errors()) == 1 diff --git a/tests/test_runtime/test_stopping_criterion.py b/tests/test_runtime/test_stopping_criterion.py index c73051a9..3e6da7ce 100644 --- a/tests/test_runtime/test_stopping_criterion.py +++ b/tests/test_runtime/test_stopping_criterion.py @@ -5,7 +5,7 @@ from neps.optimizers.random_search.optimizer import RandomSearch from neps.runtime import DefaultWorker from neps.search_spaces.search_space import SearchSpace -from neps.state.filebased import create_or_load_filebased_neps_state +from neps.state.neps_state import NePSState from neps.state.neps_state import NePSState from neps.state.optimizer import OptimizationState, OptimizerInfo from neps.state.settings import DefaultReportValues, OnErrorPossibilities, WorkerSettings @@ -14,9 +14,9 @@ @fixture -def neps_state(tmp_path: Path) -> NePSState[Path]: - return create_or_load_filebased_neps_state( - directory=tmp_path / "neps_state", +def neps_state(tmp_path: Path) -> NePSState: + return NePSState.create_or_load( + path=tmp_path / "neps_state", optimizer_info=OptimizerInfo(info={"nothing": "here"}), optimizer_state=OptimizationState(budget=None, shared_state={}), ) @@ -52,10 +52,10 @@ def eval_function(*args, **kwargs) -> float: worker.run() assert worker.worker_cumulative_eval_count == 3 - assert neps_state.get_next_pending_trial() is None - assert len(neps_state.get_errors()) == 0 + assert neps_state.lock_and_get_next_pending_trial() is None + assert len(neps_state.lock_and_get_errors()) == 0 - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() for _, trial in trials.items(): assert trial.state == Trial.State.SUCCESS assert trial.report is not None @@ -71,8 +71,8 @@ def eval_function(*args, **kwargs) -> float: ) new_worker.run() assert new_worker.worker_cumulative_eval_count == 0 - assert neps_state.get_next_pending_trial() is None - assert len(neps_state.get_errors()) == 0 + assert neps_state.lock_and_get_next_pending_trial() is None + assert len(neps_state.lock_and_get_errors()) == 0 def test_worker_evaluations_total_stopping_criterion( @@ -105,10 +105,10 @@ def eval_function(*args, **kwargs) -> float: worker.run() assert worker.worker_cumulative_eval_count == 2 - assert neps_state.get_next_pending_trial() is None - assert len(neps_state.get_errors()) == 0 + assert neps_state.lock_and_get_next_pending_trial() is None + assert len(neps_state.lock_and_get_errors()) == 0 - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() assert len(trials) == 2 for _, trial in trials.items(): assert trial.state == Trial.State.SUCCESS @@ -126,10 +126,10 @@ def eval_function(*args, **kwargs) -> float: new_worker.run() assert worker.worker_cumulative_eval_count == 2 - assert neps_state.get_next_pending_trial() is None - assert len(neps_state.get_errors()) == 0 + assert neps_state.lock_and_get_next_pending_trial() is None + assert len(neps_state.lock_and_get_errors()) == 0 - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() assert len(trials) == 4 # Now we should have 4 of them for _, trial in trials.items(): assert trial.state == Trial.State.SUCCESS @@ -155,7 +155,7 @@ def test_include_in_progress_evaluations_towards_maximum_with_work_eval_count( ) # We put in one trial as being inprogress - pending_trial = neps_state.sample_trial(optimizer, worker_id="dummy") + pending_trial = neps_state.lock_and_sample_trial(optimizer, worker_id="dummy") pending_trial.set_evaluating(time_started=0.0, worker_id="dummy") neps_state.put_updated_trial(pending_trial) @@ -173,11 +173,11 @@ def eval_function(*args, **kwargs) -> float: assert worker.worker_cumulative_eval_count == 1 assert ( - neps_state.get_next_pending_trial() is None + neps_state.lock_and_get_next_pending_trial() is None ) # should have no pending trials to be picked up - assert len(neps_state.get_errors()) == 0 + assert len(neps_state.lock_and_get_errors()) == 0 - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() assert len(trials) == 2 the_pending_trial = trials[pending_trial.id] @@ -225,11 +225,11 @@ def eval_function(*args, **kwargs) -> dict: assert worker.worker_cumulative_eval_count == 2 assert worker.worker_cumulative_eval_cost == 2.0 assert ( - neps_state.get_next_pending_trial() is None + neps_state.lock_and_get_next_pending_trial() is None ) # should have no pending trials to be picked up - assert len(neps_state.get_errors()) == 0 + assert len(neps_state.lock_and_get_errors()) == 0 - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() assert len(trials) == 2 # New worker should now not run anything as the total cost has been reached. @@ -276,11 +276,11 @@ def eval_function(*args, **kwargs) -> dict: assert worker.worker_cumulative_eval_count == 2 assert worker.worker_cumulative_eval_cost == 2.0 assert ( - neps_state.get_next_pending_trial() is None + neps_state.lock_and_get_next_pending_trial() is None ) # should have no pending trials to be picked up - assert len(neps_state.get_errors()) == 0 + assert len(neps_state.lock_and_get_errors()) == 0 - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() assert len(trials) == 2 # New worker should also run 2 more trials @@ -295,11 +295,11 @@ def eval_function(*args, **kwargs) -> dict: assert new_worker.worker_cumulative_eval_count == 2 assert new_worker.worker_cumulative_eval_cost == 2.0 assert ( - neps_state.get_next_pending_trial() is None + neps_state.lock_and_get_next_pending_trial() is None ) # should have no pending trials to be picked up - assert len(neps_state.get_errors()) == 0 + assert len(neps_state.lock_and_get_errors()) == 0 - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() assert len(trials) == 4 # 2 more trials were ran @@ -336,10 +336,10 @@ def eval_function(*args, **kwargs) -> float: assert worker.worker_cumulative_eval_count > 0 assert worker.worker_cumulative_evaluation_time_seconds <= 2.0 assert ( - neps_state.get_next_pending_trial() is None + neps_state.lock_and_get_next_pending_trial() is None ) # should have no pending trials to be picked up - assert len(neps_state.get_errors()) == 0 - len_trials_on_first_worker = len(neps_state.get_all_trials()) + assert len(neps_state.lock_and_get_errors()) == 0 + len_trials_on_first_worker = len(neps_state.lock_and_read_trials()) # New worker should also run some trials more trials new_worker = DefaultWorker.new( @@ -354,10 +354,10 @@ def eval_function(*args, **kwargs) -> float: assert new_worker.worker_cumulative_eval_count > 0 assert new_worker.worker_cumulative_evaluation_time_seconds <= 2.0 assert ( - neps_state.get_next_pending_trial() is None + neps_state.lock_and_get_next_pending_trial() is None ) # should have no pending trials to be picked up - assert len(neps_state.get_errors()) == 0 - len_trials_on_second_worker = len(neps_state.get_all_trials()) + assert len(neps_state.lock_and_get_errors()) == 0 + len_trials_on_second_worker = len(neps_state.lock_and_read_trials()) assert len_trials_on_second_worker > len_trials_on_first_worker @@ -395,10 +395,10 @@ def eval_function(*args, **kwargs) -> float: assert worker.worker_cumulative_eval_count > 0 assert worker.worker_cumulative_evaluation_time_seconds <= 1.0 assert ( - neps_state.get_next_pending_trial() is None + neps_state.lock_and_get_next_pending_trial() is None ) # should have no pending trials to be picked up - assert len(neps_state.get_errors()) == 0 - len_trials_on_first_worker = len(neps_state.get_all_trials()) + assert len(neps_state.lock_and_get_errors()) == 0 + len_trials_on_first_worker = len(neps_state.lock_and_read_trials()) # New worker should also run some trials more trials new_worker = DefaultWorker.new( @@ -413,10 +413,10 @@ def eval_function(*args, **kwargs) -> float: assert new_worker.worker_cumulative_eval_count > 0 assert new_worker.worker_cumulative_evaluation_time_seconds <= 1.0 assert ( - neps_state.get_next_pending_trial() is None + neps_state.lock_and_get_next_pending_trial() is None ) # should have no pending trials to be picked up - assert len(neps_state.get_errors()) == 0 - len_trials_on_second_worker = len(neps_state.get_all_trials()) + assert len(neps_state.lock_and_get_errors()) == 0 + len_trials_on_second_worker = len(neps_state.lock_and_read_trials()) assert len_trials_on_second_worker > len_trials_on_first_worker @@ -454,10 +454,10 @@ def eval_function(*args, **kwargs) -> float: assert worker.worker_cumulative_eval_count > 0 assert worker.worker_cumulative_evaluation_time_seconds <= 1.0 assert ( - neps_state.get_next_pending_trial() is None + neps_state.lock_and_get_next_pending_trial() is None ) # should have no pending trials to be picked up - assert len(neps_state.get_errors()) == 0 - len_trials_on_first_worker = len(neps_state.get_all_trials()) + assert len(neps_state.lock_and_get_errors()) == 0 + len_trials_on_first_worker = len(neps_state.lock_and_read_trials()) # New worker should also run some trials more trials new_worker = DefaultWorker.new( @@ -472,8 +472,8 @@ def eval_function(*args, **kwargs) -> float: assert new_worker.worker_cumulative_eval_count == 0 assert new_worker.worker_cumulative_evaluation_time_seconds == 0 assert ( - neps_state.get_next_pending_trial() is None + neps_state.lock_and_get_next_pending_trial() is None ) # should have no pending trials to be picked up - assert len(neps_state.get_errors()) == 0 - len_trials_on_second_worker = len(neps_state.get_all_trials()) + assert len(neps_state.lock_and_get_errors()) == 0 + len_trials_on_second_worker = len(neps_state.lock_and_read_trials()) assert len_trials_on_second_worker == len_trials_on_first_worker diff --git a/tests/test_state/test_filebased_neps_state.py b/tests/test_state/test_filebased_neps_state.py index 02f5a52c..87085639 100644 --- a/tests/test_state/test_filebased_neps_state.py +++ b/tests/test_state/test_filebased_neps_state.py @@ -6,10 +6,7 @@ from typing import Any from neps.exceptions import NePSError, TrialNotFoundError from neps.state.err_dump import ErrDump -from neps.state.filebased import ( - create_or_load_filebased_neps_state, - load_filebased_neps_state, -) +from neps.state.neps_state import NePSState import pytest from pytest_cases import fixture, parametrize @@ -38,21 +35,21 @@ def test_create_with_new_filebased_neps_state( optimizer_state: OptimizationState, ) -> None: new_path = tmp_path / "neps_state" - neps_state = create_or_load_filebased_neps_state( - directory=new_path, + neps_state = NePSState.create_or_load( + path=new_path, optimizer_info=optimizer_info, optimizer_state=optimizer_state, ) - assert neps_state.optimizer_info() == optimizer_info - assert neps_state.optimizer_state() == optimizer_state + assert neps_state.lock_and_get_optimizer_info() == optimizer_info + assert neps_state.lock_and_get_optimizer_state() == optimizer_state assert neps_state.all_trial_ids() == [] - assert neps_state.get_all_trials() == {} - assert neps_state.get_errors() == ErrDump(errs=[]) - assert neps_state.get_next_pending_trial() is None - assert neps_state.get_next_pending_trial(n=10) == [] + assert neps_state.lock_and_read_trials() == {} + assert neps_state.lock_and_get_errors() == ErrDump(errs=[]) + assert neps_state.lock_and_get_next_pending_trial() is None + assert neps_state.lock_and_get_next_pending_trial(n=10) == [] with pytest.raises(TrialNotFoundError): - assert neps_state.get_trial_by_id("1") + assert neps_state.lock_and_get_trial_by_id("1") def test_create_or_load_with_load_filebased_neps_state( @@ -61,8 +58,8 @@ def test_create_or_load_with_load_filebased_neps_state( optimizer_state: OptimizationState, ) -> None: new_path = tmp_path / "neps_state" - neps_state = create_or_load_filebased_neps_state( - directory=new_path, + neps_state = NePSState.create_or_load( + path=new_path, optimizer_info=optimizer_info, optimizer_state=optimizer_state, ) @@ -74,8 +71,8 @@ def test_create_or_load_with_load_filebased_neps_state( budget=BudgetInfo(max_cost_budget=20, used_cost_budget=10), shared_state={"c": "d"}, ) - neps_state2 = create_or_load_filebased_neps_state( - directory=new_path, + neps_state2 = NePSState.create_or_load( + path=new_path, optimizer_info=optimizer_info, optimizer_state=different_state, ) @@ -88,13 +85,13 @@ def test_load_on_existing_neps_state( optimizer_state: OptimizationState, ) -> None: new_path = tmp_path / "neps_state" - neps_state = create_or_load_filebased_neps_state( - directory=new_path, + neps_state = NePSState.create_or_load( + path=new_path, optimizer_info=optimizer_info, optimizer_state=optimizer_state, ) - neps_state2 = load_filebased_neps_state(directory=new_path) + neps_state2 = NePSState.create_or_load(path=new_path, load_only=True) assert neps_state == neps_state2 @@ -104,15 +101,15 @@ def test_new_or_load_on_existing_neps_state_with_different_optimizer_info( optimizer_state: OptimizationState, ) -> None: new_path = tmp_path / "neps_state" - create_or_load_filebased_neps_state( - directory=new_path, + NePSState.create_or_load( + path=new_path, optimizer_info=optimizer_info, optimizer_state=optimizer_state, ) with pytest.raises(NePSError): - create_or_load_filebased_neps_state( - directory=new_path, + NePSState.create_or_load( + path=new_path, optimizer_info=OptimizerInfo({"e": "f"}), optimizer_state=optimizer_state, ) diff --git a/tests/test_state/test_neps_state.py b/tests/test_state/test_neps_state.py index c64cb64e..78b3213b 100644 --- a/tests/test_state/test_neps_state.py +++ b/tests/test_state/test_neps_state.py @@ -15,9 +15,7 @@ Categorical, ) from neps.search_spaces.search_space import SearchSpace -from neps.state.filebased import ( - create_or_load_filebased_neps_state, -) +from neps.state.neps_state import NePSState from pytest_cases import fixture, parametrize, parametrize_with_cases, case from neps.state.neps_state import NePSState @@ -156,8 +154,8 @@ def case_neps_state_filebased( shared_state: dict[str, Any], ) -> NePSState: new_path = tmp_path / "neps_state" - return create_or_load_filebased_neps_state( - directory=new_path, + return NePSState.create_or_load( + path=new_path, optimizer_info=optimizer_info, optimizer_state=OptimizationState(budget=budget, shared_state=shared_state), ) @@ -169,15 +167,15 @@ def test_sample_trial( optimizer_and_key: tuple[BaseOptimizer, str], ) -> None: optimizer, key = optimizer_and_key - if key in REQUIRES_COST and neps_state.optimizer_state().budget is None: + if key in REQUIRES_COST and neps_state.lock_and_get_optimizer_state().budget is None: pytest.xfail(f"{key} requires a cost budget") - assert neps_state.get_all_trials() == {} - assert neps_state.get_next_pending_trial() is None - assert neps_state.get_next_pending_trial(n=10) == [] + assert neps_state.lock_and_read_trials() == {} + assert neps_state.lock_and_get_next_pending_trial() is None + assert neps_state.lock_and_get_next_pending_trial(n=10) == [] assert neps_state.all_trial_ids() == [] - trial1 = neps_state.sample_trial(optimizer=optimizer, worker_id="1") + trial1 = neps_state.lock_and_sample_trial(optimizer=optimizer, worker_id="1") for k, v in trial1.config.items(): assert k in optimizer.pipeline_space.hyperparameters assert v is not None, f"'{k}' is None in {trial1.config}" @@ -186,19 +184,19 @@ def test_sample_trial( # precise, we need to introduce a sleep -_- time.sleep(0.1) - assert neps_state.get_all_trials() == {trial1.id: trial1} - assert neps_state.get_next_pending_trial() == trial1 - assert neps_state.get_next_pending_trial(n=10) == [trial1] + assert neps_state.lock_and_read_trials() == {trial1.id: trial1} + assert neps_state.lock_and_get_next_pending_trial() == trial1 + assert neps_state.lock_and_get_next_pending_trial(n=10) == [trial1] assert neps_state.all_trial_ids() == [trial1.id] - trial2 = neps_state.sample_trial(optimizer=optimizer, worker_id="1") + trial2 = neps_state.lock_and_sample_trial(optimizer=optimizer, worker_id="1") for k, v in trial1.config.items(): assert k in optimizer.pipeline_space.hyperparameters assert v is not None, f"'{k}' is None in {trial1.config}" assert trial1 != trial2 - assert neps_state.get_all_trials() == {trial1.id: trial1, trial2.id: trial2} - assert neps_state.get_next_pending_trial() == trial1 - assert neps_state.get_next_pending_trial(n=10) == [trial1, trial2] + assert neps_state.lock_and_read_trials() == {trial1.id: trial1, trial2.id: trial2} + assert neps_state.lock_and_get_next_pending_trial() == trial1 + assert neps_state.lock_and_get_next_pending_trial(n=10) == [trial1, trial2] assert sorted(neps_state.all_trial_ids()) == [trial1.id, trial2.id] diff --git a/tests/test_state/test_synced.py b/tests/test_state/test_synced.py deleted file mode 100644 index 6294db37..00000000 --- a/tests/test_state/test_synced.py +++ /dev/null @@ -1,429 +0,0 @@ -import copy -import random - -from pytest_cases import parametrize, parametrize_with_cases, case -import numpy as np -from neps.state.err_dump import ErrDump, SerializableTrialError -from neps.state.filebased import ( - ReaderWriterErrDump, - ReaderWriterOptimizationState, - ReaderWriterOptimizerInfo, - ReaderWriterSeedSnapshot, - ReaderWriterTrial, - FileVersioner, - FileLocker, -) -from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo -import pytest -from typing import Any, Callable -from pathlib import Path -from neps.state import SeedSnapshot, Synced, Trial - - -@case -def case_trial_1(tmp_path: Path) -> tuple[Synced[Trial, Path], Callable[[Trial], None]]: - trial_id = "1" - trial = Trial.new( - trial_id=trial_id, - location="", - config={"a": "b"}, - time_sampled=0, - previous_trial=None, - previous_trial_location=None, - worker_id=0, - ) - - def _update(trial: Trial) -> None: - trial.set_submitted(time_submitted=1) - - x = Synced.new( - data=trial, - location=tmp_path / "1", - locker=FileLocker(lock_path=tmp_path / "1" / ".lock", poll=0.1, timeout=None), - versioner=FileVersioner(version_file=tmp_path / "1" / ".version"), - reader_writer=ReaderWriterTrial(), - ) - return x, _update - - -@case -def case_trial_2(tmp_path: Path) -> tuple[Synced[Trial, Path], Callable[[Trial], None]]: - trial_id = "1" - trial = Trial.new( - trial_id=trial_id, - location="", - config={"a": "b"}, - time_sampled=0, - previous_trial=None, - previous_trial_location=None, - worker_id=0, - ) - trial.set_submitted(time_submitted=1) - - def _update(trial: Trial) -> None: - trial.set_evaluating(time_started=2, worker_id="1") - - x = Synced.new( - data=trial, - location=tmp_path / "1", - locker=FileLocker(lock_path=tmp_path / "1" / ".lock", poll=0.1, timeout=None), - versioner=FileVersioner(version_file=tmp_path / "1" / ".version"), - reader_writer=ReaderWriterTrial(), - ) - return x, _update - - -@case -def case_trial_3(tmp_path: Path) -> tuple[Synced[Trial, Path], Callable[[Trial], None]]: - trial_id = "1" - trial = Trial.new( - trial_id=trial_id, - config={"a": "b"}, - location="", - time_sampled=0, - previous_trial=None, - previous_trial_location=None, - worker_id=0, - ) - trial.set_submitted(time_submitted=1) - trial.set_evaluating(time_started=2, worker_id="1") - - def _update(trial: Trial) -> None: - trial.set_complete( - time_end=3, - loss=1, - cost=1, - extra={"hi": [1, 2, 3]}, - learning_curve=[1], - report_as="success", - evaluation_duration=1, - err=None, - tb=None, - ) - - x = Synced.new( - data=trial, - location=tmp_path / "1", - locker=FileLocker(lock_path=tmp_path / "1" / ".lock", poll=0.1, timeout=None), - versioner=FileVersioner(version_file=tmp_path / "1" / ".version"), - reader_writer=ReaderWriterTrial(), - ) - return x, _update - - -@case -def case_trial_4(tmp_path: Path) -> tuple[Synced[Trial, Path], Callable[[Trial], None]]: - trial_id = "1" - trial = Trial.new( - trial_id=trial_id, - config={"a": "b"}, - location="", - time_sampled=0, - previous_trial=None, - previous_trial_location=None, - worker_id=0, - ) - trial.set_submitted(time_submitted=1) - trial.set_evaluating(time_started=2, worker_id="1") - - def _update(trial: Trial) -> None: - trial.set_complete( - time_end=3, - loss=np.nan, - cost=np.inf, - extra={"hi": [1, 2, 3]}, - report_as="failed", - learning_curve=None, - evaluation_duration=2, - err=None, - tb=None, - ) - - x = Synced.new( - data=trial, - location=tmp_path / "1", - locker=FileLocker(lock_path=tmp_path / "1" / ".lock", poll=0.1, timeout=None), - versioner=FileVersioner(version_file=tmp_path / "1" / ".version"), - reader_writer=ReaderWriterTrial(), - ) - return x, _update - - -@case -def case_trial_5(tmp_path: Path) -> tuple[Synced[Trial, Path], Callable[[Trial], None]]: - trial_id = "1" - trial = Trial.new( - trial_id=trial_id, - config={"a": "b"}, - location="", - time_sampled=0, - previous_trial=None, - previous_trial_location=None, - worker_id=0, - ) - trial.set_submitted(time_submitted=1) - trial.set_evaluating(time_started=2, worker_id=1) - - def _update(trial: Trial) -> None: - trial.set_complete( - time_end=3, - loss=np.nan, - cost=np.inf, - extra={"hi": [1, 2, 3]}, - learning_curve=None, - evaluation_duration=2, - report_as="failed", - err=ValueError("hi"), - tb="something something traceback", - ) - - x = Synced.new( - data=trial, - location=tmp_path / "1", - locker=FileLocker(lock_path=tmp_path / "1" / ".lock", poll=0.1, timeout=None), - versioner=FileVersioner(version_file=tmp_path / "1" / ".version"), - reader_writer=ReaderWriterTrial(), - ) - return x, _update - - -@case -def case_trial_6(tmp_path: Path) -> tuple[Synced[Trial, Path], Callable[[Trial], None]]: - trial_id = "1" - trial = Trial.new( - trial_id=trial_id, - config={"a": "b"}, - location="", - time_sampled=0, - previous_trial=None, - previous_trial_location=None, - worker_id=0, - ) - trial.set_submitted(time_submitted=1) - trial.set_evaluating(time_started=2, worker_id=1) - - def _update(trial: Trial) -> None: - trial.set_corrupted() - - x = Synced.new( - data=trial, - location=tmp_path / "1", - locker=FileLocker(lock_path=tmp_path / "1" / ".lock", poll=0.1, timeout=None), - versioner=FileVersioner(version_file=tmp_path / "1" / ".version"), - reader_writer=ReaderWriterTrial(), - ) - return x, _update - - -@case -def case_trial_7(tmp_path: Path) -> tuple[Synced[Trial, Path], Callable[[Trial], None]]: - trial_id = "1" - trial = Trial.new( - trial_id=trial_id, - config={"a": "b"}, - location="", - time_sampled=0, - previous_trial=None, - previous_trial_location=None, - worker_id=0, - ) - trial.set_submitted(time_submitted=1) - trial.set_evaluating(time_started=2, worker_id=1) - trial.set_complete( - time_end=3, - loss=np.nan, - cost=np.inf, - extra={"hi": [1, 2, 3]}, - learning_curve=[1, 2, 3], - report_as="failed", - evaluation_duration=2, - err=ValueError("hi"), - tb="something something traceback", - ) - - def _update(trial: Trial) -> None: - trial.reset() - - x = Synced.new( - data=trial, - location=tmp_path / "1", - locker=FileLocker(lock_path=tmp_path / "1" / ".lock", poll=0.1, timeout=None), - versioner=FileVersioner(version_file=tmp_path / "1" / ".version"), - reader_writer=ReaderWriterTrial(), - ) - return x, _update - - -@case -def case_seed_snapshot( - tmp_path: Path, -) -> tuple[Synced[SeedSnapshot, Path], Callable[[SeedSnapshot], None]]: - seed = SeedSnapshot.new_capture() - - def _update(seed: SeedSnapshot) -> None: - random.randint(0, 100) - seed.recapture() - - x = Synced.new( - data=seed, - location=tmp_path / "seeds", - locker=FileLocker(lock_path=tmp_path / "seeds" / ".lock", poll=0.1, timeout=None), - versioner=FileVersioner(version_file=tmp_path / "seeds" / ".version"), - reader_writer=ReaderWriterSeedSnapshot(), - ) - return x, _update - - -@case -@parametrize( - "err", - [ - None, - SerializableTrialError( - trial_id="1", - worker_id="2", - err_type="ValueError", - err="hi", - tb="traceback\nmore", - ), - ], -) -def case_err_dump( - tmp_path: Path, - err: None | SerializableTrialError, -) -> tuple[Synced[ErrDump, Path], Callable[[ErrDump], None]]: - err_dump = ErrDump() if err is None else ErrDump(errs=[err]) - - def _update(err_dump: ErrDump) -> None: - new_err = SerializableTrialError( - trial_id="2", - worker_id="2", - err_type="RuntimeError", - err="hi", - tb="traceback\nless", - ) - err_dump.append(new_err) - - x = Synced.new( - data=err_dump, - location=tmp_path / "err_dump", - locker=FileLocker( - lock_path=tmp_path / "err_dump" / ".lock", poll=0.1, timeout=None - ), - versioner=FileVersioner(version_file=tmp_path / "err_dump" / ".version"), - reader_writer=ReaderWriterErrDump("all"), - ) - return x, _update - - -@case -def case_optimizer_info( - tmp_path: Path, -) -> tuple[Synced[OptimizerInfo, Path], Callable[[OptimizerInfo], None]]: - optimizer_info = OptimizerInfo(info={"a": "b"}) - - def _update(optimizer_info: OptimizerInfo) -> None: - optimizer_info.info["b"] = "c" # type: ignore # NOTE: We shouldn't be mutating but anywho... - - x = Synced.new( - data=optimizer_info, - location=tmp_path / "optimizer_info", - locker=FileLocker( - lock_path=tmp_path / "optimizer_info" / ".lock", poll=0.1, timeout=None - ), - versioner=FileVersioner(version_file=tmp_path / "optimizer_info" / ".version"), - reader_writer=ReaderWriterOptimizerInfo(), - ) - return x, _update - - -@case -@pytest.mark.parametrize( - "budget", (None, BudgetInfo(max_cost_budget=10, used_cost_budget=0)) -) -@pytest.mark.parametrize("shared_state", ({}, {"a": "b"})) -def case_optimization_state( - tmp_path: Path, - budget: BudgetInfo | None, - shared_state: dict[str, Any], -) -> tuple[Synced[OptimizationState, Path], Callable[[OptimizationState], None]]: - optimization_state = OptimizationState(budget=budget, shared_state=shared_state) - - def _update(optimization_state: OptimizationState) -> None: - optimization_state.shared_state["a"] = "c" # type: ignore # NOTE: We shouldn't be mutating but anywho... - optimization_state.budget = BudgetInfo(max_cost_budget=10, used_cost_budget=5) - - x = Synced.new( - data=optimization_state, - location=tmp_path / "optimizer_info", - locker=FileLocker( - lock_path=tmp_path / "optimizer_info" / ".lock", poll=0.1, timeout=None - ), - versioner=FileVersioner(version_file=tmp_path / "optimizer_info" / ".version"), - reader_writer=ReaderWriterOptimizationState(), - ) - return x, _update - - -@parametrize_with_cases("shared, update", cases=".") -def test_initial_state(shared: Synced, update: Callable) -> None: - assert shared._is_locked() == False - assert shared._is_stale() == False - assert shared._unsynced() == shared.synced() - - -@parametrize_with_cases("shared, update", cases=".") -def test_put_updates_current_data_and_is_not_stale( - shared: Synced, update: Callable -) -> None: - current_data = shared._unsynced() - - new_data = copy.deepcopy(current_data) - update(new_data) - assert new_data != current_data - - shared.put(new_data) - assert shared._unsynced() == new_data - assert shared._is_stale() == False - assert shared._is_locked() == False - - -@parametrize_with_cases("shared1, update", cases=".") -def test_share_synced_update_and_put(shared1: Synced, update: Callable) -> None: - shared2 = shared1.deepcopy() - assert shared1 == shared2 - assert not shared1._is_locked() - assert not shared2._is_locked() - - with shared2.acquire() as (data2, put2): - assert shared1._is_locked() - assert shared2._is_locked() - update(data2) - put2(data2) - - assert not shared1._is_locked() - assert not shared2._is_locked() - - assert shared1 != shared2 - assert shared1._unsynced() != shared2._unsynced() - assert shared1._is_stale() - - shared1.synced() - assert not shared1._is_stale() - assert not shared2._is_stale() - assert shared1._unsynced() == shared2._unsynced() - - -@parametrize_with_cases("shared, update", cases=".") -def test_shared_new_fails_if_done_on_existing_resource( - shared: Synced, update: Callable -) -> None: - data, location, versioner, rw, lock = shared._components() - with pytest.raises(Synced.VersionedResourceAlreadyExistsError): - Synced.new( - data=data, - location=location, - versioner=versioner, - reader_writer=rw, - locker=lock, - )