Skip to content

Commit

Permalink
Added support for PyTorch Lightning in the DDP backend.
Browse files Browse the repository at this point in the history
  • Loading branch information
Gopalji Gaur committed Dec 8, 2024
1 parent 5e40590 commit c2e4cc8
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 1 deletion.
50 changes: 49 additions & 1 deletion neps/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
WorkerRaiseError,
)
from neps.state._eval import evaluate_trial
from neps.state.filebased import create_or_load_filebased_neps_state
from neps.state.filebased import (
create_or_load_filebased_neps_state,
load_filebased_neps_state,
)
from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo
from neps.state.settings import DefaultReportValues, OnErrorPossibilities, WorkerSettings
from neps.state.trial import Trial
Expand All @@ -43,6 +46,24 @@ def _default_worker_name() -> str:
return f"{os.getpid()}-{isoformat}"


def _is_ddp_and_not_rank_zero() -> bool:
import torch.distributed as dist

# Check for environment variables typically set by DDP
ddp_env_vars = ["WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"]
rank_env_vars = ["RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK"]

# Check if PyTorch distributed is initialized
if (dist.is_available() and dist.is_initialized()) or all(
var in os.environ for var in ddp_env_vars
):
for var in rank_env_vars:
rank = os.environ.get(var)
if rank is not None:
return int(rank) != 0
return False


N_FAILED_GET_NEXT_PENDING_ATTEMPTS_BEFORE_ERROR = 0
N_FAILED_TO_SET_TRIAL_STATE = 10

Expand Down Expand Up @@ -488,6 +509,26 @@ def run(self) -> None: # noqa: C901, PLR0915
)


def _launch_ddp_runtime(
    *,
    evaluation_fn: Callable[..., float | Mapping[str, Any]],
    optimization_dir: Path,
) -> None:
    """Evaluation loop for non-rank-0 DDP workers.

    Rank 0 runs the full NePS worker (sampling + reporting); every other rank
    only mirrors the evaluation so that collective ops inside
    ``evaluation_fn`` line up across ranks. This loop polls the shared
    filesystem state for the trial rank 0 is currently evaluating and runs
    ``evaluation_fn`` on its config exactly once per trial.

    Args:
        evaluation_fn: The user's evaluation function; called with the
            trial's config as keyword arguments.
        optimization_dir: Directory of the existing filebased NePS state.

    Note:
        This loop never returns on its own; the process is expected to be
        torn down with the rest of the DDP job when rank 0 finishes.
    """
    import time

    neps_state = load_filebased_neps_state(directory=optimization_dir)

    # Remember the last trial this rank evaluated so a trial is never
    # evaluated twice (only the id is needed for that comparison).
    prev_trial_id: str | None = None
    while True:
        current_trial = neps_state.get_current_evaluating_trial()
        if current_trial is not None and current_trial.id != prev_trial_id:
            evaluation_fn(**current_trial.config)
            prev_trial_id = current_trial.id
        else:
            # Back off briefly instead of busy-spinning: the original loop
            # re-read the filebased state as fast as possible, pinning a CPU
            # core and hammering the filesystem while waiting for rank 0 to
            # sample the next trial.
            time.sleep(0.1)


# TODO: This should be done directly in `api.run` at some point to make it clearer at an
# entry point how the worker is set up to run if someone reads the entry point code.
def _launch_runtime( # noqa: PLR0913
Expand All @@ -506,6 +547,13 @@ def _launch_runtime( # noqa: PLR0913
max_evaluations_for_worker: int | None,
pre_load_hooks: Iterable[Callable[[BaseOptimizer], BaseOptimizer]] | None,
) -> None:
if _is_ddp_and_not_rank_zero():
# Do not launch a new worker if we are in a DDP setup and not rank 0
_launch_ddp_runtime(
evaluation_fn=evaluation_fn, optimization_dir=optimization_dir
)
return

if overwrite_optimization_dir and optimization_dir.exists():
logger.info(
f"Overwriting optimization directory '{optimization_dir}' as"
Expand Down
9 changes: 9 additions & 0 deletions neps/state/filebased.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,15 @@ def pending(self) -> Iterable[tuple[str, Synced[Trial, Path]]]:
]
return iter((_id, t) for _id, t, _ in sorted(pending, key=lambda x: x[2]))

@override
def evaluating(self) -> Iterable[tuple[str, Synced[Trial, Path]]]:
    """Yield ``(trial_id, synced_trial)`` pairs for all EVALUATING trials.

    Pairs are ordered by the time the trial was sampled, oldest first.
    """
    stamped: list[tuple[float, str, Synced[Trial, Path]]] = []
    for trial_id, synced in self.all().items():
        trial = synced.synced()
        if trial.state == Trial.State.EVALUATING:
            stamped.append((trial.metadata.time_sampled, trial_id, synced))
    stamped.sort(key=lambda entry: entry[0])
    return ((trial_id, synced) for _, trial_id, synced in stamped)


@dataclass
class ReaderWriterTrial(ReaderWriter[Trial, Path]):
Expand Down
6 changes: 6 additions & 0 deletions neps/state/neps_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,12 @@ def get_next_pending_trial(self, n: int | None = None) -> Trial | list[Trial] |
return take(n, _pending_itr)
return next(_pending_itr, None)

def get_current_evaluating_trial(self) -> Trial | None:
    """Return the first trial currently in the EVALUATING state, if any.

    Returns:
        The synced :class:`Trial` that is first in the repository's
        evaluating order, or ``None`` when no trial is evaluating.
    """
    first_entry = next(iter(self._trials.evaluating()), None)
    if first_entry is None:
        return None
    _, shared_trial = first_entry
    return shared_trial.synced()

def all_trial_ids(self) -> set[str]:
"""Get all the trial ids that are known about."""
return self._trials.all_trial_ids()
Expand Down
4 changes: 4 additions & 0 deletions neps/state/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ def pending(self) -> Iterable[tuple[str, Synced[Trial, K]]]:
"""
...

def evaluating(self) -> Iterable[tuple[str, Synced[Trial, K]]]:
    """Get all trials currently in the EVALUATING state.

    Returns:
        An iterable of ``(trial_id, synced_trial)`` pairs, one for each
        trial in the repo whose state is EVALUATING.
    """
    ...


@dataclass
class VersionedResource(Generic[T, K]):
Expand Down

0 comments on commit c2e4cc8

Please sign in to comment.