
Commit

Feature/sq lite database (#9)
* First working version of sqlite

* Make existing file-based DB subclass of an abstract DBObject

* Import and use FileDatabase

* Added SQLDatabase logic

* move to SQLDatabase

* remove unwanted stuff

* add metric property to ASHA

* Add DB constants

* rename SQLDatabase + use constants as DB keys + define types in DB

* update readme

* skip speed test

* setup dev env in CI

* fix CI

* fix + cpuonly

* fix

* fix

* fix

* fix

* update readme

* fix test and readme

---------

Co-authored-by: xxemmexx <emme@pm.me>
dcfidalgo and xxemmexx authored Nov 10, 2023
1 parent 391d467 commit 785006e
Showing 14 changed files with 395 additions and 125 deletions.
42 changes: 30 additions & 12 deletions .github/workflows/ci.yml
@@ -10,6 +10,9 @@ jobs:

tests:
runs-on: ubuntu-latest
defaults:
run:
shell: bash -el {0}
# For the slurm action to work, you have to supply a mysql service as defined below.
services:
mysql:
@@ -22,22 +25,37 @@
steps:
- name: Set up SLURM
uses: koesterlab/setup-slurm-action@v1
- name: Setup Mambaforge
uses: conda-incubator/setup-miniconda@v2
with:
miniforge-variant: Mambaforge
miniforge-version: latest
use-mamba: true
activate-environment: slurm_sweeps
- name: Check out repo
uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v3
- name: Get Date for Conda cache
id: get-date
run: echo "today=$(/bin/date -u '+%Y%m%d')" >> $GITHUB_OUTPUT
shell: bash
- name: Cache Conda env
uses: actions/cache@v3
env:
CACHE_NUMBER: 0
with:
python-version: 3.9
- name: Install package
run: |
python -m pip install --upgrade pip
python -m pip install -e .
- name: Install test dependencies
run: |
python -m pip install pytest pytest-cov
path: ${{ env.CONDA }}/envs
key:
conda-${{
runner.os }}-${{
steps.get-date.outputs.today }}-${{
hashFiles('environment.yml') }}-${{
env.CACHE_NUMBER }}
id: cache
- name: Update environment
run: mamba env update -n slurm_sweeps -f environment.yml
if: steps.cache.outputs.cache-hit != 'true'
- name: Running tests
run: |
pytest --cov=slurm_sweeps
run: pytest --cov=slurm_sweeps
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
env:
56 changes: 39 additions & 17 deletions README.md
@@ -4,7 +4,7 @@
slurm sweeps
</h1>
<p align="center"><b>A simple tool to perform parameter sweeps on SLURM clusters.</b></p>
<p align="center">
<p align="center">
<a href="https://github.com/dcfidalgo/slurm_sweeps/blob/main/LICENSE">
<img alt="License" src="https://img.shields.io/github/license/dcfidalgo/slurm_sweeps.svg?color=blue">
</a>
@@ -32,7 +32,6 @@ pip install .

### Dependencies
- cloudpickle
- fasteners
- numpy
- pandas
- pyyaml
@@ -53,7 +52,8 @@ def train(cfg: dict):
for epoch in range(cfg["epochs"]):
sleep(0.5)
loss = (cfg["parameter"] - 1) ** 2 * epoch
logger.log("loss", loss, epoch)
# log your metric
logger.log({"loss": loss}, epoch)


# Define your experiment
@@ -79,7 +79,7 @@ Write a small SLURM script `test_ss.slurm` that runs the code above:
```bash
#!/bin/bash -l
#SBATCH --nodes=2
#SBATCH --tasks-per-node=18
#SBATCH --ntasks-per-node=18
#SBATCH --cpus-per-task=4
#SBATCH --mem-per-cpu=1GB

@@ -102,24 +102,25 @@ See the `tests` folder for an advanced example of training a PyTorch model with

```python
def __init__(
self,
train: Callable,
cfg: Dict,
name: str = "MySweep",
local_dir: str = "./slurm_sweeps",
local_dir: Union[str, Path] = "./slurm_sweeps",
backend: Optional[Backend] = None,
asha: Optional[ASHA] = None,
database: Optional[Database] = None,
restore: bool = False,
exist_ok: bool = False,
overwrite: bool = False,
)
```

Run an HPO experiment using random search and the Asynchronous Successive Halving Algorithm (ASHA).
Set up an HPO experiment.

**Arguments**:

- `train` - A train function that takes as input a `cfg` dict.
- `cfg` - A dict passed on to the `train` function.
- `train` - A train function that takes as input the `cfg` dict.
- `cfg` - A dict passed on to the `train` function.
It must contain the search spaces via `slurm_sweeps.Uniform`, `slurm_sweeps.Choice`, etc.
- `name` - The name of the experiment.
- `local_dir` - Where to store and run the experiments. In this directory
@@ -128,29 +129,28 @@ Run an HPO experiment using random search and the Asynchronous Successive Halving Algorithm (ASHA).
otherwise we choose the standard `Backend` that simply executes the trial in another process.
- `asha` - An optional ASHA instance to cancel less promising trials. By default, it is None.
- `database` - A database instance to store the trial's (intermediate) results.
By default, it will create the database at `{local_dir}/.database`.
By default, we will create the database at `{local_dir}/slurm_sweeps.db`.
- `restore` - Restore an experiment with the same name?
- `exist_ok` - Replace an existing experiment with the same name?

<a id="slurm_sweeps.experiment.Experiment.run"></a>
- `overwrite` - Overwrite an existing experiment with the same name?

#### `Experiment.run`

```python
def run(
self,
n_trials: int = 1,
max_concurrent_trials: Optional[int] = None,
summary_interval_in_sec: float = 5.0,
nr_of_rows_in_summary: int = 10,
summarize_cfg_and_metrics: Union[bool, List[str]] = True
) -> pandas.DataFrame
summarize_cfg_and_metrics: Union[bool, List[str]] = True,
) -> pd.DataFrame
```

Run the experiment.

**Arguments**:

- `n_trials` - Number of trials to run.
- `n_trials` - Number of trials to run. For grid searches this parameter is ignored.
- `max_concurrent_trials` - The maximum number of trials running concurrently. By default, we will set this to
the number of cpus available, or the number of total Slurm tasks divided by the number of trial Slurm
tasks requested.
@@ -159,10 +159,32 @@ Run the experiment.
- `summarize_cfg_and_metrics` - Should we include the cfg and the metrics in the summary table?
You can also pass in a list of strings to only select a few cfg and metric keys.


**Returns**:

A DataFrame of the database.
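Taken together, the documented signatures above suggest usage along these lines. This is a hypothetical sketch, not code from the repository: the lambda train function, search-space values, and trial count are all illustrative, and `slurm_sweeps` must be installed for it to actually run (the import is guarded so the sketch degrades gracefully).

```python
# Hypothetical usage sketch based on the Experiment docs above.
try:
    import slurm_sweeps as ss
except ImportError:  # slurm_sweeps not installed
    ss = None


def run_sweep():
    experiment = ss.Experiment(
        train=lambda cfg: None,                   # your training function
        cfg={"parameter": ss.Uniform(0.0, 2.0)},  # search space lives in the cfg dict
        name="MySweep",
        overwrite=True,                           # replace an existing "MySweep"
    )
    # run() executes the trials and returns a pandas DataFrame of the database
    return experiment.run(n_trials=4)


if __name__ == "__main__" and ss is not None:
    print(run_sweep())
```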

### `CLASS slurm_sweeps.SlurmBackend`

```python
def __init__(
self,
exclusive: bool = True,
nodes: int = 1,
ntasks: int = 1,
args: str = ""
)
```

Execute the training runs on a Slurm cluster via `srun`.

Pass an instance of this class to your experiment.

**Arguments**:

- `exclusive` - Add the `--exclusive` switch.
- `nodes` - How many nodes do you request for your srun?
- `ntasks` - How many tasks do you request for your srun?
- `args` - Additional command line arguments for srun, formatted as a string.
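For instance, a backend matching the SLURM script from the README above might be built as follows. Again a hedged sketch: the argument values are illustrative, and a Slurm cluster is required to actually launch anything.

```python
# Hypothetical sketch: constructing a SlurmBackend per the signature above.
try:
    from slurm_sweeps import SlurmBackend
except ImportError:  # slurm_sweeps not installed
    SlurmBackend = None

if SlurmBackend is not None:
    backend = SlurmBackend(
        exclusive=True,   # add srun's --exclusive switch
        nodes=1,          # nodes requested per trial srun
        ntasks=1,         # tasks requested per trial srun
        args="--cpus-per-task=4 --mem-per-cpu=1GB",  # extra srun flags as one string
    )
    # Then pass it to the experiment: Experiment(..., backend=backend)
```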

## Contact
David Carreto Fidalgo (david.carreto.fidalgo@mpcdf.mpg.de)
20 changes: 11 additions & 9 deletions environment.yml
@@ -4,17 +4,19 @@ channels:
- pytorch
- nvidia
dependencies:
- python=3.9
- python>3.9.0
- pytorch
- pytorch-cuda=11.8
# - cpuonly
# - pytorch-cuda=11.8
- cpuonly
- torchvision
- ipython
- pytest=7.3.1
- pre_commit=3.3.1
- blackd=22.10.0
- wandb
- pytorch-lightning
- pip
- pip:
- -e .
# for tests
- pytest
- pytest-cov
- fasteners
- lightning
- wandb
# for development
- pre_commit
1 change: 0 additions & 1 deletion pyproject.toml
@@ -10,7 +10,6 @@ version = "0.1.dev"
requires-python = ">=3.9"
dependencies = [
"cloudpickle",
"fasteners",
"numpy",
"pandas",
"pyyaml",
2 changes: 1 addition & 1 deletion src/slurm_sweeps/__init__.py
@@ -2,7 +2,7 @@

from .asha import ASHA
from .backend import Backend, SlurmBackend
from .database import Database
from .database import FileDatabase, SqlDatabase
from .experiment import Experiment
from .logger import Logger
from .sampler import Choice, Grid, LogUniform, Uniform
5 changes: 5 additions & 0 deletions src/slurm_sweeps/asha.py
@@ -42,6 +42,11 @@ def __init__(
self._min_t * (self._rf**i) for i in reversed(range(rung_max + 1))
]

@property
def metric(self):
"""The metric to optimize."""
return self._metric
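The rung boundaries computed in `__init__` above can be made concrete with toy numbers. The values below (`min_t=1`, reduction factor 4, `rung_max=3`) are illustrative, not library defaults:

```python
# Reproduces the rung computation from ASHA.__init__ above with toy values:
# min_t (minimum iterations), rf (reduction factor), rung_max (highest rung).
min_t, rf, rung_max = 1, 4, 3
rungs = [min_t * (rf**i) for i in reversed(range(rung_max + 1))]
print(rungs)  # highest rung first: [64, 16, 4, 1]
```

Each rung is a factor `rf` below the one above it, which is what lets ASHA promote only the top `1/rf` fraction of trials at every boundary.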

def find_trials_to_prune(self, database: "pd.DataFrame") -> List[str]:
"""Check the database and find trials to prune.
2 changes: 2 additions & 0 deletions src/slurm_sweeps/backend.py
@@ -70,6 +70,8 @@ def _build_args(train_path: Path, cfg_path: Path) -> str:
class SlurmBackend(Backend):
"""Execute the training runs on a Slurm cluster via `srun`.
Pass an instance of this class to your experiment.
Args:
exclusive: Add the `--exclusive` switch.
nodes: How many nodes do you request for your srun?
8 changes: 6 additions & 2 deletions src/slurm_sweeps/constants.py
@@ -4,8 +4,12 @@
EXPERIMENT_NAME = "SLURMSWEEPS_EXPERIMENT_NAME"

# DB keys
ITERATION = "iteration"
TRIAL_ID = "trial_id"
CFG = "_cfg"
ITERATION = "_iteration"
LOGGED = "_logged"
TIMESTAMP = "_timestamp"
TRIAL_ID = "_trial_id"


# Storage keys
TRAIN_PKL = "train.pkl"
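To make the renamed DB keys concrete, a logged-metric row addressed via these constants might look like the following. This is purely illustrative: the actual schema is internal to `slurm_sweeps`, and the field values (and the meaning of `_logged`) are assumptions.

```python
# Illustrative only: a database row keyed by the DB-key constants above.
CFG = "_cfg"
ITERATION = "_iteration"
LOGGED = "_logged"
TIMESTAMP = "_timestamp"
TRIAL_ID = "_trial_id"

row = {
    TRIAL_ID: "trial-0001",            # hypothetical trial identifier
    ITERATION: 3,                      # iteration the metric was logged at
    CFG: {"parameter": 0.5},           # the trial's sampled config
    LOGGED: 1,                         # hypothetical bookkeeping flag
    TIMESTAMP: "2023-11-10 12:00:00",  # when the row was written
}
print(sorted(row))
```

The leading underscores keep these internal columns from colliding with user-logged metric names such as `loss`.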
