optim: reading/loading of items
eddiebergman committed Dec 2, 2024
1 parent a44f551 commit bb18db7
Showing 5 changed files with 221 additions and 158 deletions.
167 changes: 83 additions & 84 deletions neps/runtime.py
@@ -23,12 +23,11 @@
LINUX_FILELOCK_FUNCTION,
MAX_RETRIES_CREATE_LOAD_STATE,
MAX_RETRIES_GET_NEXT_TRIAL,
MAX_RETRIES_SET_EVALUATING,
MAX_RETRIES_WORKER_CHECK_SHOULD_STOP,
)
from neps.exceptions import (
NePSError,
VersionMismatchError,
TrialAlreadyExistsError,
WorkerFailedToGetPendingTrialsError,
WorkerRaiseError,
)
@@ -340,9 +339,84 @@ def _check_global_stopping_criterion(

return False

@property
def _requires_global_stopping_criterion(self) -> bool:
return (
self.settings.max_evaluations_total is not None
or self.settings.max_cost_total is not None
or self.settings.max_evaluation_time_total_seconds is not None
)

def _get_next_trial(self) -> Trial | Literal["break"]:
# If there is no global stopping criterion, we can just return early.
with self.state._state_lock.lock():
# With the trial lock held, we load everything in; if there is a pending
# trial, use that and return.
with self.state._trial_lock.lock():
trials = self.state._trials.latest()

if self._requires_global_stopping_criterion:
should_stop = self._check_global_stopping_criterion(trials)
if should_stop is not False:
logger.info(should_stop)
return "break"

pending_trials = [
trial
for trial in trials.values()
if trial.state == Trial.State.PENDING
]

if len(pending_trials) > 0:
earliest_pending = sorted(
pending_trials,
key=lambda t: t.metadata.time_sampled,
)[0]
earliest_pending.set_evaluating(
time_started=time.time(),
worker_id=self.worker_id,
)
self.state._trials.update_trial(
earliest_pending,
hints=["metadata", "state"],
)
return earliest_pending

# Otherwise, we release the trial lock while sampling
sampled_trial = self.state._sample_trial(
optimizer=self.optimizer,
worker_id=self.worker_id,
trials=trials,
)

with self.state._trial_lock.lock():
try:
self.state._trials.new_trial(sampled_trial)
return sampled_trial
except TrialAlreadyExistsError as e:
if sampled_trial.id in trials:
logger.warning(
"The new sampled trial was given an id of '%s', yet this already"
" exists in the loaded in trials given to the optimizer. This"
" indicates a bug with the optimizers allocation of ids.",
sampled_trial.id,
)
else:
logger.warning(
"The new sampled trial was given an id of '%s', which is not one"
" that was loaded in by the optimizer. This indicates that"
" configuration '%s' was put on disk during the time that this"
" worker had the optimizer state lock OR that after obtaining the"
" optimizer state lock, somehow this configuration failed to be"
" loaded in and passed to the optimizer.",
sampled_trial.id,
sampled_trial.id,
)
raise e

# Forgive me lord, for I have sinned; this function is atrocious, but it is
# complicated due to the locking.
def run(self) -> None: # noqa: C901, PLR0915, PLR0912
def run(self) -> None: # noqa: C901, PLR0915
"""Run the worker.
Will keep running until one of the criteria defined by the `WorkerSettings`
@@ -356,7 +430,6 @@ def run(self) -> None:  # noqa: C901, PLR0915, PLR0912
_error_from_evaluation: Exception | None = None

_repeated_fail_get_next_trial_count = 0
n_failed_set_trial_state = 0
n_repeated_failed_check_should_stop = 0
while True:
try:
@@ -400,50 +473,13 @@ def run(self) -> None:  # noqa: C901, PLR0915, PLR0912

# From here, we now begin sampling or getting the next pending trial.
# As the global stopping criterion requires us to check all trials, and
# needs to be in locked in-step with sampling
# needs to be locked in-step with sampling, it is done inside
# _get_next_trial
try:
# If there is no global stopping criterion, we can just return early.
with self.state._state_lock.lock():
with self.state._trial_lock.lock():
trials = self.state._trials.latest()

requires_checking_global_stopping_criterion = (
self.settings.max_evaluations_total is not None
or self.settings.max_cost_total is not None
or self.settings.max_evaluation_time_total_seconds is not None
)
if requires_checking_global_stopping_criterion:
should_stop = self._check_global_stopping_criterion(trials)
if should_stop is not False:
logger.info(should_stop)
break

pending_trials = [
trial
for trial in trials.values()
if trial.state == Trial.State.PENDING
]
if len(pending_trials) > 0:
earliest_pending = sorted(
pending_trials,
key=lambda t: t.metadata.time_sampled,
)[0]
earliest_pending.set_evaluating(
time_started=time.time(),
worker_id=self.worker_id,
)
with self.state._trial_lock.lock():
self.state._trials.update_trial(earliest_pending)
trial_to_eval = earliest_pending
else:
sampled_trial = self.state._sample_trial(
optimizer=self.optimizer,
worker_id=self.worker_id,
_trials=trials,
)
trial_to_eval = sampled_trial

_repeated_fail_get_next_trial_count = 0
trial_to_eval = self._get_next_trial()
if trial_to_eval == "break":
break
_repeated_fail_get_next_trial_count = 0
except Exception as e:
_repeated_fail_get_next_trial_count += 1
logger.debug(
@@ -452,7 +488,6 @@ def run(self) -> None:  # noqa: C901, PLR0915, PLR0912
exc_info=True,
)
time.sleep(1) # Help stagger retries

# NOTE: This is to prevent any infinite loops if we can't get a trial
if _repeated_fail_get_next_trial_count >= MAX_RETRIES_GET_NEXT_TRIAL:
raise WorkerFailedToGetPendingTrialsError(
@@ -463,42 +498,6 @@ def run(self) -> None:  # noqa: C901, PLR0915, PLR0912

continue

# If we can't set this working to evaluating, then just retry the loop
try:
n_failed_set_trial_state = 0
except VersionMismatchError:
n_failed_set_trial_state += 1
logger.debug(
"Another worker has managed to change trial '%s'"
" while this worker '%s' was trying to set it to"
" evaluating. This is fine and likely means the other worker is"
" evaluating it, this worker will attempt to sample new trial.",
trial_to_eval.id,
self.worker_id,
exc_info=True,
)
time.sleep(1) # Help stagger retries
except Exception:
n_failed_set_trial_state += 1
logger.error(
"Unexpected error from worker '%s' trying to set trial"
" '%s' to evaluating.",
self.worker_id,
trial_to_eval.id,
exc_info=True,
)
time.sleep(1) # Help stagger retries

# NOTE: This is to prevent infinite looping if it somehow keeps getting
# the same trial and can't set it to evaluating.
if n_failed_set_trial_state != 0:
if n_failed_set_trial_state >= MAX_RETRIES_SET_EVALUATING:
raise WorkerFailedToGetPendingTrialsError(
f"Worker {self.worker_id} failed to set trial to evaluating"
f" {MAX_RETRIES_SET_EVALUATING} times in a row. Bailing!"
)
continue

# We (this worker) have managed to set it to evaluating; now we can evaluate it
with _set_global_trial(trial_to_eval):
evaluated_trial, report = evaluate_trial(
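The heart of this commit in `runtime.py` is pulling the trial-acquisition logic out of `run()` into `_get_next_trial`: take the optimizer state lock, check the global stopping criterion and claim the earliest pending trial under the trial lock, and only if nothing is pending, release the trial lock while sampling before re-acquiring it to persist the new trial. Below is a minimal, self-contained sketch of that control flow; `threading.Lock`, `Store`, and `Trial` here are illustrative stand-ins, not the NePS file locks or API.

# Sketch of the consolidated "get next trial" flow, under simplified
# assumptions: in-memory store, thread locks instead of file locks.
from __future__ import annotations

import threading
import time
from dataclasses import dataclass, field
from itertools import count
from typing import Literal

@dataclass
class Trial:
    id: str
    state: str = "pending"  # "pending" -> "evaluating" -> ...
    time_sampled: float = field(default_factory=time.time)

class Store:
    """In-memory stand-in for the on-disk trial store and optimizer."""

    def __init__(self) -> None:
        self.trials: dict[str, Trial] = {}
        self._ids = count()

    def sample(self) -> Trial:
        return Trial(id=str(next(self._ids)))

state_lock = threading.Lock()  # guards optimizer state while sampling
trial_lock = threading.Lock()  # guards reads/writes of the trial store

def get_next_trial(store: Store, budget: int) -> Trial | Literal["break"]:
    with state_lock:
        # Hold the trial lock only while reading and claiming trials.
        with trial_lock:
            trials = dict(store.trials)  # snapshot, like `latest()`

            # Global stopping criterion, checked in-step with sampling.
            if len(trials) >= budget:
                return "break"

            pending = [t for t in trials.values() if t.state == "pending"]
            if pending:
                earliest = min(pending, key=lambda t: t.time_sampled)
                earliest.state = "evaluating"  # claim it for this worker
                return earliest

        # The trial lock is released while sampling, which may be slow.
        new_trial = store.sample()
        new_trial.state = "evaluating"  # this worker will evaluate it

        with trial_lock:
            if new_trial.id in store.trials:  # duplicate id -> optimizer bug
                raise RuntimeError(f"Trial '{new_trial.id}' already exists")
            store.trials[new_trial.id] = new_trial
            return new_trial

store = Store()
while (trial := get_next_trial(store, budget=3)) != "break":
    print("evaluating", trial.id)

Sampling outside the trial lock is the key design choice: other workers can claim pending trials while this worker runs the (potentially slow) optimizer.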
98 changes: 74 additions & 24 deletions neps/state/filebased.py
@@ -7,7 +7,7 @@
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import ClassVar, Final, TypeVar
from typing import ClassVar, Final, Literal, TypeAlias, TypeVar

import numpy as np
import portalocker as pl
@@ -25,15 +25,23 @@
K = TypeVar("K")
T = TypeVar("T")

TrialWriteHint: TypeAlias = Literal["metadata", "report", "state", "config"]


@dataclass
class ReaderWriterTrial:
"""ReaderWriter for Trial objects."""

# Report and config are kept as yaml since they are the files most likely
# to be read by a human.
CONFIG_FILENAME = "config.yaml"
METADATA_FILENAME = "metadata.yaml"
STATE_FILENAME = "state.txt"
REPORT_FILENAME = "report.yaml"

# Metadata is stored as json since it's more likely to be machine-read,
# and json is much faster.
METADATA_FILENAME = "metadata.json"

STATE_FILENAME = "state.txt"
PREVIOUS_TRIAL_ID_FILENAME = "previous_trial_id.txt"

@classmethod
@@ -43,32 +51,71 @@ def read(cls, directory: Path) -> Trial:
state_path = directory / cls.STATE_FILENAME
report_path = directory / cls.REPORT_FILENAME

with metadata_path.open("r") as f:
metadata = json.load(f)

return Trial(
config=deserialize(config_path),
metadata=Trial.MetaData(**deserialize(metadata_path)),
metadata=Trial.MetaData(**metadata),
state=Trial.State(state_path.read_text(encoding="utf-8").strip()),
report=(
Trial.Report(**deserialize(report_path)) if report_path.exists() else None
),
)

@classmethod
def write(cls, trial: Trial, directory: Path) -> None:
def write(
cls,
trial: Trial,
directory: Path,
*,
hints: list[TrialWriteHint] | TrialWriteHint | None = None,
) -> None:
config_path = directory / cls.CONFIG_FILENAME
metadata_path = directory / cls.METADATA_FILENAME
state_path = directory / cls.STATE_FILENAME

serialize(trial.config, config_path)
serialize(asdict(trial.metadata), metadata_path)
state_path.write_text(trial.state.value, encoding="utf-8")

if trial.metadata.previous_trial_id is not None:
previous_trial_path = directory / cls.PREVIOUS_TRIAL_ID_FILENAME
previous_trial_path.write_text(trial.metadata.previous_trial_id)

if trial.report is not None:
report_path = directory / cls.REPORT_FILENAME
serialize(asdict(trial.report), report_path)
if isinstance(hints, str):
match hints:
case "config":
serialize(trial.config, config_path)
case "metadata":
with metadata_path.open("w") as f:
json.dump(asdict(trial.metadata), f)

if trial.metadata.previous_trial_id is not None:
previous_trial_path = directory / cls.PREVIOUS_TRIAL_ID_FILENAME
previous_trial_path.write_text(trial.metadata.previous_trial_id)
case "report":
if trial.report is None:
raise ValueError(
"Cannot write report 'hint' when report is None."
)

report_path = directory / cls.REPORT_FILENAME
serialize(asdict(trial.report), report_path)
case "state":
state_path.write_text(trial.state.value, encoding="utf-8")
case _:
raise ValueError(f"Invalid hint: {hints}")
elif hints is None:
# We don't know what changed, so write everything
serialize(trial.config, config_path)
with metadata_path.open("w") as f:
json.dump(asdict(trial.metadata), f)

if trial.metadata.previous_trial_id is not None:
previous_trial_path = directory / cls.PREVIOUS_TRIAL_ID_FILENAME
previous_trial_path.write_text(trial.metadata.previous_trial_id)

state_path.write_text(trial.state.value, encoding="utf-8")

if trial.report is not None:
report_path = directory / cls.REPORT_FILENAME
serialize(asdict(trial.report), report_path)
else:
for hint in hints:
cls.write(trial, directory, hints=hint)


TrialReaderWriter: Final = ReaderWriterTrial()
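
The new `hints` parameter lets a caller that knows exactly what changed, like the `update_trial(earliest_pending, hints=["metadata", "state"])` call in `runtime.py` above, rewrite only the affected files instead of all of them. Below is a stripped-down, runnable sketch of the same dispatch pattern, assuming a hypothetical `Trial` stand-in and JSON-only files rather than the real NePS classes:

# Sketch of hint-based partial writes: only the named pieces touch disk.
# `Trial` and the file layout are hypothetical stand-ins for illustration.
import json
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal

WriteHint = Literal["config", "metadata", "state"]

@dataclass
class Trial:
    config: dict = field(default_factory=dict)
    metadata: dict = field(default_factory=dict)
    state: str = "pending"

def write(
    trial: Trial,
    directory: Path,
    hints: list[WriteHint] | WriteHint | None = None,
) -> None:
    """Write only the pieces named by `hints`; None means write everything."""
    if hints is None:
        hints = ["config", "metadata", "state"]
    elif isinstance(hints, str):
        hints = [hints]
    for hint in hints:
        match hint:
            case "config":
                (directory / "config.json").write_text(json.dumps(trial.config))
            case "metadata":
                (directory / "metadata.json").write_text(json.dumps(trial.metadata))
            case "state":
                (directory / "state.txt").write_text(trial.state)
            case _:
                raise ValueError(f"Invalid hint: {hint}")

with tempfile.TemporaryDirectory() as d:
    trial = Trial(config={"lr": 0.01}, metadata={"worker_id": "w0"})
    write(trial, Path(d))                 # first write: all three files
    trial.state = "evaluating"
    write(trial, Path(d), hints="state")  # later: only state.txt is rewritten
    print((Path(d) / "state.txt").read_text())  # -> evaluating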
@@ -155,7 +202,9 @@ def write(cls, snapshot: SeedSnapshot, directory: Path) -> None:
"py_rng_version": py_rng_version,
"py_guass_next": py_guass_next,
}
serialize(seed_info, seedinfo_path)
with seedinfo_path.open("w") as f:
json.dump(seed_info, f)

np_rng_state = snapshot.np_rng[1]
np_rng_state.tofile(np_rng_path)

@@ -195,23 +244,24 @@ def write(cls, optimizer_info: OptimizerInfo, directory: Path) -> None:
class ReaderWriterOptimizationState:
"""ReaderWriter for OptimizationState objects."""

STATE_FILE_NAME: ClassVar = "state.yaml"
STATE_FILE_NAME: ClassVar = "state.json"

@classmethod
def read(cls, directory: Path) -> OptimizationState:
state_path = directory / cls.STATE_FILE_NAME
state = deserialize(state_path)
with state_path.open("r") as f:
state = json.load(f)

shared_state = state.get("shared_state") or {}
budget_info = state.get("budget")
budget = BudgetInfo(**budget_info) if budget_info is not None else None
return OptimizationState(
shared_state=state.get("shared_state") or {},
budget=budget,
)
return OptimizationState(shared_state=shared_state, budget=budget)

@classmethod
def write(cls, info: OptimizationState, directory: Path) -> None:
info_path = directory / cls.STATE_FILE_NAME
serialize(asdict(info), info_path)
with info_path.open("w") as f:
json.dump(asdict(info), f)


@dataclass
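A second theme across `filebased.py` is swapping the YAML-based `serialize`/`deserialize` helpers for plain `json` on files that are machine-read (trial metadata, seed info, optimizer state), as the comments in the diff note. The sketch below is a rough micro-benchmark of why: the stdlib `json` module is typically an order of magnitude faster than PyYAML for small dicts like these. The field names and numbers are illustrative, not the real `OptimizationState` schema, and absolute timings vary by machine.

# Rough micro-benchmark of the yaml -> json swap for small state dicts.
# Requires PyYAML (pip install pyyaml); timings are machine-dependent.
import json
import timeit

import yaml

state = {"shared_state": {}, "budget": {"max_cost_total": 100.0, "used_cost": 3.5}}

n = 2_000
json_s = timeit.timeit(lambda: json.dumps(state), number=n)
yaml_s = timeit.timeit(lambda: yaml.safe_dump(state), number=n)
print(f"json.dumps    : {json_s:.3f}s for {n} dumps")
print(f"yaml.safe_dump: {yaml_s:.3f}s for {n} dumps")

# Round-trip check: both formats preserve this simple structure.
assert json.loads(json.dumps(state)) == state
assert yaml.safe_load(yaml.safe_dump(state)) == state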