feat: Pass in batch size
eddiebergman committed Dec 11, 2024
1 parent a3610bc commit 7a749b1
Showing 9 changed files with 263 additions and 110 deletions.
7 changes: 6 additions & 1 deletion neps/api.py
@@ -44,6 +44,7 @@ def run(
loss_value_on_error: None | float = Default(None),
cost_value_on_error: None | float = Default(None),
pre_load_hooks: Iterable | None = Default(None),
sample_batch_size: int | None = Default(None),
searcher: (
Literal[
"default",
@@ -98,6 +99,8 @@ def run(
cost_value_on_error: Setting this and loss_value_on_error to any float will
    suppress any error and use the given cost value instead. default: None
pre_load_hooks: List of functions that will be called before load_results().
sample_batch_size: The number of samples to ask for in a single call to the
optimizer.
searcher: Which optimizer to use. Can be a string identifier, an
instance of BaseOptimizer, or a Path to a custom optimizer.
**searcher_kwargs: Will be passed to the searcher. This is usually only needed by
@@ -236,6 +239,7 @@ def run(
ignore_errors=settings.ignore_errors,
overwrite_optimization_dir=settings.overwrite_working_directory,
pre_load_hooks=settings.pre_load_hooks,
sample_batch_size=settings.sample_batch_size,
)

if settings.post_run_summary:
@@ -278,7 +282,8 @@ def _run_args(
"mobster",
"asha",
]
| BaseOptimizer | dict
| BaseOptimizer
| dict
) = "default",
**searcher_kwargs,
) -> tuple[BaseOptimizer, dict]:
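For context, a minimal sketch of how the new argument might be used from user code. The objective function, search-space definition, and parameter names other than sample_batch_size are illustrative assumptions, not part of this commit:

import neps

def run_pipeline(learning_rate: float) -> float:
    # Hypothetical objective: a quadratic bowl standing in for a training run.
    return (learning_rate - 0.01) ** 2

neps.run(
    run_pipeline=run_pipeline,
    pipeline_space={
        "learning_rate": neps.FloatParameter(lower=1e-4, upper=1e-1, log=True),
    },
    root_directory="neps_results",
    max_evaluations_total=20,
    sample_batch_size=4,  # new in this commit: request 4 configs per optimizer call
)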
33 changes: 11 additions & 22 deletions neps/exceptions.py
@@ -2,6 +2,8 @@

from __future__ import annotations

from typing import Any


class NePSError(Exception):
"""Base class for all NePS exceptions.
@@ -11,35 +13,22 @@ class NePSError(Exception):
"""


class VersionMismatchError(NePSError):
"""Raised when the version of a resource does not match the expected version."""


class VersionedResourceAlreadyExistsError(NePSError):
"""Raised when a version already exists when trying to create a new versioned
data.
"""


class VersionedResourceRemovedError(NePSError):
"""Raised when a version already exists when trying to create a new versioned
data.
"""


class VersionedResourceDoesNotExistError(NePSError):
"""Raised when a versioned resource does not exist at a location."""


class LockFailedError(NePSError):
"""Raised when a lock cannot be acquired."""


class TrialAlreadyExistsError(VersionedResourceAlreadyExistsError):
class TrialAlreadyExistsError(NePSError):
"""Raised when a trial already exists in the store."""

def __init__(self, trial_id: str, *args: Any) -> None:
super().__init__(trial_id, *args)
self.trial_id = trial_id

def __str__(self) -> str:
return f"Trial with id {self.trial_id} already exists!"


class TrialNotFoundError(VersionedResourceDoesNotExistError):
class TrialNotFoundError(NePSError):
"""Raised when a trial already exists in the store."""


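Since TrialAlreadyExistsError now stores the offending id on the exception, callers can branch on e.trial_id instead of parsing the message, which is what the runtime change further down relies on. A minimal sketch, where the store and trial objects are assumed to exist:

from neps.exceptions import TrialAlreadyExistsError

try:
    store.new_trial(trial)  # hypothetical call site
except TrialAlreadyExistsError as e:
    print(f"trial '{e.trial_id}' already exists")  # structured access to the id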
21 changes: 19 additions & 2 deletions neps/optimizers/base_optimizer.py
@@ -4,7 +4,7 @@
from abc import abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, overload

from neps.state.trial import Report, Trial

@@ -106,12 +106,29 @@ def __init__(
self.learning_curve_on_error = learning_curve_on_error
self.ignore_errors = ignore_errors

@overload
def ask(
self,
trials: Mapping[str, Trial],
budget_info: BudgetInfo | None,
n: int,
) -> list[SampledConfig]: ...

@overload
def ask(
self,
trials: Mapping[str, Trial],
budget_info: BudgetInfo | None,
n: None = None,
) -> SampledConfig: ...

@abstractmethod
def ask(
self,
trials: Mapping[str, Trial],
budget_info: BudgetInfo | None,
) -> SampledConfig:
n: int | None = None,
) -> SampledConfig | list[SampledConfig]:
"""Sample a new configuration.
Args:
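The paired @overload declarations mean a type checker can infer the return shape from the n argument alone. A sketch of the contract, assuming optimizer is any concrete BaseOptimizer and trials / budget are in scope:

single = optimizer.ask(trials, budget)        # inferred as SampledConfig
batch = optimizer.ask(trials, budget, n=8)    # inferred as list[SampledConfig]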
59 changes: 48 additions & 11 deletions neps/optimizers/bayesian_optimization/optimizer.py
@@ -127,35 +127,58 @@ def __init__(
self.cost_on_log_scale = cost_on_log_scale
self.device = device
self.sample_default_first = sample_default_first
self.n_initial_design = initial_design_size
self.init_design: list[dict[str, Any]] | None = None

if initial_design_size is not None:
self.n_initial_design = initial_design_size
else:
self.n_initial_design = len(pipeline_space.numerical) + len(
pipeline_space.categoricals
)

@override
def ask(
self,
trials: Mapping[str, Trial],
budget_info: BudgetInfo | None = None,
) -> SampledConfig:
n: int | None = None,
) -> SampledConfig | list[SampledConfig]:
_n = 1 if n is None else n
n_sampled = len(trials)
config_id = str(n_sampled + 1)
config_ids = iter(str(n_sampled + i) for i in range(1, _n + 1))
space = self.pipeline_space

# If we haven't passed the initial design phase
if self.init_design is None:
if n is not None and self.n_initial_design < n:
init_design_size = n
else:
init_design_size = self.n_initial_design
self.init_design = make_initial_design(
space=space,
encoder=self.encoder,
sample_default_first=self.sample_default_first,
sampler=self.prior if self.prior is not None else "sobol",
seed=None, # TODO: Seeding
sample_size=(
"ndim" if self.n_initial_design is None else self.n_initial_design
),
sample_size=init_design_size,
sample_fidelity="max",
)

if n_sampled < len(self.init_design):
return SampledConfig(id=config_id, config=self.init_design[n_sampled])
sampled_configs = [
SampledConfig(id=config_id, config=config)
for config_id, config in zip(
config_ids, self.init_design[n_sampled : n_sampled + _n], strict=False
)
]
if len(sampled_configs) == _n:
if n is None:
return sampled_configs[0]

return sampled_configs

assert len(sampled_configs) < _n

_n = _n - len(sampled_configs)

# Otherwise, we encode trials and setup to fit and acquire from a GP
data, encoder = encode_trials_for_gp(
Expand Down Expand Up @@ -185,7 +208,7 @@ def ask(
prior = None if pibo_exp_term < 1e-4 else self.prior

gp = make_default_single_obj_gp(x=data.x, y=data.y, encoder=encoder)
candidate = fit_and_acquire_from_gp(
candidates = fit_and_acquire_from_gp(
gp=gp,
x_train=data.x,
encoder=encoder,
@@ -200,11 +223,25 @@ def ask(
prune_baseline=True,
),
prior=prior,
n_candidates_required=_n,
pibo_exp_term=pibo_exp_term,
costs=data.cost if self.use_cost else None,
cost_percentage_used=cost_percent,
costs_on_log_scale=self.cost_on_log_scale,
)

config = encoder.decode(candidate)[0]
return SampledConfig(id=config_id, config=config)
config_ids = list(config_ids)  # materialize the remaining ids for the GP candidates

configs = encoder.decode(candidates)
sampled_configs.extend(
[
SampledConfig(id=config_id, config=config)
for config_id, config in zip(config_ids, configs, strict=True)
]
)

if n is None:
return sampled_configs[0]

return sampled_configs
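A subtlety above is the id bookkeeping: with n_sampled existing trials (holding ids "1" through "n_sampled"), fresh ids must start at n_sampled + 1, which is why the generator ranges over 1.._n inclusive rather than 0.._n. A worked check of that arithmetic:

n_sampled = 10  # existing trials hold ids "1".."10"
_n = 3          # batch size requested

config_ids = [str(n_sampled + i) for i in range(1, _n + 1)]
assert config_ids == ["11", "12", "13"]  # no collision with existing id "10"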
23 changes: 18 additions & 5 deletions neps/optimizers/random_search/optimizer.py
@@ -56,9 +56,22 @@ def ask(
self,
trials: Mapping[str, Trial],
budget_info: BudgetInfo | None,
) -> SampledConfig:
n: int | None = None,
) -> SampledConfig | list[SampledConfig]:
n_trials = len(trials)
config = self.sampler.sample_one(to=self.encoder.domains)
config_dict = self.encoder.decode_one(config)
config_id = str(n_trials + 1)
return SampledConfig(config=config_dict, id=config_id, previous_config_id=None)
_n = 1 if n is None else n
configs = self.sampler.sample(_n, to=self.encoder.domains)
config_dicts = self.encoder.decode(configs)
if n is None:
config = config_dicts[0]
config_id = str(n_trials + 1)
return SampledConfig(config=config, id=config_id, previous_config_id=None)

return [
SampledConfig(
config=config,
id=str(n_trials + i + 1),
previous_config_id=None,
)
for i, config in enumerate(config_dicts)
]
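The random-search ask thereby honours the same contract as the base class: n=None yields a bare SampledConfig, an integer yields a list. A sketch, assuming rs is a constructed RandomSearch instance:

one = rs.ask(trials={}, budget_info=None)         # SampledConfig with id "1"
five = rs.ask(trials={}, budget_info=None, n=5)   # list of five configs
assert [c.id for c in five] == ["1", "2", "3", "4", "5"]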
55 changes: 36 additions & 19 deletions neps/runtime.py
@@ -300,10 +302,12 @@ def _check_global_stopping_criterion(
) -> str | Literal[False]:
if self.settings.max_evaluations_total is not None:
if self.settings.include_in_progress_evaluations_towards_maximum:
# NOTE: We can just use the sum of trials in this case as they
# either have a report, are pending or being evaluated. There
# are also crashed and unknown states which we include into this.
count = len(trials)
count = sum(
1
for _, trial in trials.items()
if trial.metadata.state
not in (Trial.State.PENDING, Trial.State.SUBMITTED)
)
else:
# This indicates they have completed.
count = sum(1 for _, trial in trials.items() if trial.report is not None)
@@ -399,32 +401,45 @@ def _get_next_trial(self) -> Trial | Literal["break"]:
)
return earliest_pending

sampled_trial = self.state._sample_trial(
sampled_trials = self.state._sample_trial(
optimizer=self.optimizer,
worker_id=self.worker_id,
trials=trials,
n=self.settings.batch_size,
)
if isinstance(sampled_trials, Trial):
    this_workers_trial = sampled_trials
else:
    # Evaluate the first trial here; the rest are persisted below as pending.
    this_workers_trial = sampled_trials[0]

with self.state._trial_lock.lock(worker_id=self.worker_id), gc_disabled():
this_workers_trial.set_evaluating(
time_started=time.time(),
worker_id=self.worker_id,
)
try:
sampled_trial.set_evaluating(
time_started=time.time(),
worker_id=self.worker_id,
)
self.state._trials.new_trial(sampled_trial)
logger.info(
"Worker '%s' sampled new trial: %s.",
self.worker_id,
sampled_trial.id,
)
return sampled_trial
self.state._trials.new_trial(sampled_trials)
if isinstance(sampled_trials, Trial):
logger.info(
"Worker '%s' sampled new trial: %s.",
self.worker_id,
this_workers_trial.id,
)
else:
logger.info(
"Worker '%s' sampled new trials: %s.",
self.worker_id,
",".join(trial.id for trial in sampled_trials),
)
return this_workers_trial
except TrialAlreadyExistsError as e:
if sampled_trial.id in trials:
if e.trial_id in trials:
logger.error(
"The new sampled trial was given an id of '%s', yet this"
" exists in the loaded in trials given to the optimizer. This"
" indicates a bug with the optimizers allocation of ids.",
sampled_trial.id,
e.trial_id,
)
else:
_grace = DefaultWorker._GRACE
@@ -439,7 +454,7 @@ def _get_next_trial(self) -> Trial | Literal["break"]:
" '%s's to '%s's. You can control the initial"
" grace with 'NEPS_FS_SYNC_GRACE_BASE' and the increment with"
" 'NEPS_FS_SYNC_GRACE_INC'.",
sampled_trial.id,
e.trial_id,
_grace,
_grace + _inc,
)
@@ -595,6 +610,7 @@ def _launch_runtime( # noqa: PLR0913
overwrite_optimization_dir: bool,
max_evaluations_total: int | None,
max_evaluations_for_worker: int | None,
sample_batch_size: int | None,
pre_load_hooks: Iterable[Callable[[BaseOptimizer], BaseOptimizer]] | None,
) -> None:
if overwrite_optimization_dir and optimization_dir.exists():
@@ -643,6 +659,7 @@
if ignore_errors
else OnErrorPossibilities.RAISE_ANY_ERROR
),
batch_size=sample_batch_size,
default_report_values=DefaultReportValues(
loss_value_on_error=loss_value_on_error,
cost_value_on_error=cost_value_on_error,
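End to end, a worker now asks the shared state for up to batch_size trials in one optimizer call, evaluates the first itself, and persists the remainder as pending for other workers. A simplified sketch of that flow; names follow this diff, and the surrounding objects are assumed to be set up by the runtime:

sampled = state._sample_trial(
    optimizer=optimizer,
    worker_id="worker-0",
    trials=trials,
    n=settings.batch_size,  # None keeps the old single-trial behaviour
)
if isinstance(sampled, Trial):
    this_workers_trial = sampled      # single trial: evaluate it directly
else:
    this_workers_trial = sampled[0]   # batch: evaluate the first locally;
    # the rest are stored via new_trial(...) and picked up by other workers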
(Diff for the remaining 3 changed files not shown.)
