diff --git a/README.md b/README.md index b9dc65c..2f35f23 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,6 @@ def __init__( local_dir: Union[str, Path] = "./slurm_sweeps", backend: Optional[Backend] = None, asha: Optional[ASHA] = None, - database: Optional[Database] = None, restore: bool = False, overwrite: bool = False, ) @@ -124,12 +123,10 @@ Set up an HPO experiment. It must contain the search spaces via `slurm_sweeps.Uniform`, `slurm_sweeps.Choice`, etc. - `name` - The name of the experiment. - `local_dir` - Where to store and run the experiments. In this directory - we will create a folder with the experiment name. + we will create the database `slurm_sweeps.db` and a folder with the experiment name. - `backend` - A backend to execute the trials. By default, we choose the `SlurmBackend` if Slurm is available, otherwise we choose the standard `Backend` that simply executes the trial in another process. -- `asha` - An optional ASHA instance to cancel less promising trials. By default, it is None. -- `database` - A database instance to store the trial's (intermediate) results. - By default, we will create the database at `{local_dir}/slurm_sweeps.db`. +- `asha` - An optional ASHA instance to cancel less promising trials. - `restore` - Restore an experiment with the same name? - `overwrite` - Overwrite an existing experiment with the same name? diff --git a/environment.yml b/environment.yml index 49e13dd..fa74291 100644 --- a/environment.yml +++ b/environment.yml @@ -15,7 +15,6 @@ dependencies: # for tests - pytest - pytest-cov - - fasteners - lightning - wandb # for development diff --git a/src/slurm_sweeps/__init__.py b/src/slurm_sweeps/__init__.py index 9892372..d94a446 100644 --- a/src/slurm_sweeps/__init__.py +++ b/src/slurm_sweeps/__init__.py @@ -2,7 +2,6 @@ from .asha import ASHA from .backend import Backend, SlurmBackend -from .database import FileDatabase, SqlDatabase from .experiment import Experiment from .logger import Logger from .sampler import Choice, Grid, LogUniform, Uniform diff --git a/src/slurm_sweeps/database.py b/src/slurm_sweeps/database.py index 1e0b147..1dbe13e 100644 --- a/src/slurm_sweeps/database.py +++ b/src/slurm_sweeps/database.py @@ -1,6 +1,3 @@ -import abc -import datetime -import json import sqlite3 from contextlib import contextmanager from pathlib import Path @@ -10,93 +7,26 @@ from .constants import CFG, ITERATION, TIMESTAMP, TRIAL_ID -try: - import fasteners -except ModuleNotFoundError: - _has_fasteners = False -else: - _has_fasteners = True +class SqlDatabase: + """An SQLite database that stores the trials and their metrics. -class Database(abc.ABC): - def __init__(self, path: Union[str, Path] = "./slurm_sweeps.db"): - self._path = Path(path).resolve() + Args: + path: The path to the database file. + """ - @property - def path(self): - return self._path - - @abc.abstractmethod - def create(self, experiment: str, overwrite: bool = False): - pass - - @abc.abstractmethod - def write(self, experiment: str, row: Dict): - pass - - @abc.abstractmethod - def read(self, experiment: str) -> pd.DataFrame: - pass - - -class FileDatabase(Database): def __init__(self, path: Union[str, Path] = "./slurm_sweeps.db"): - if not _has_fasteners: - raise ModuleNotFoundError( - "You need to install 'fasteners' to use the FileDatabase: " - "`pip install fasteners`" - ) - - super().__init__(path=path) - self.path.mkdir(parents=True, exist_ok=True) - - def _get_file_path_and_lock( - self, experiment - ) -> Tuple[Path, fasteners.InterProcessReaderWriterLock]: - path = self.path / f"{experiment}.txt" - lock = fasteners.InterProcessReaderWriterLock(f"{path}.lock") - - return path, lock - - def create(self, experiment: str, overwrite: bool = False): - path, _ = self._get_file_path_and_lock(experiment) - try: - path.touch(exist_ok=overwrite) - except FileExistsError as err: - raise ExperimentExistsError(experiment) from err - with path.open(mode="w"): - pass - - def write(self, experiment: str, row: Dict): - path, lock = self._get_file_path_and_lock(experiment) - # quick check if file exists - with path.open(): - pass - - row[TIMESTAMP] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] - json_str = json.dumps(row) - - lock.acquire_write_lock() - with path.open(mode="a") as f: - f.write(json_str + "\n") - lock.release_write_lock() - - def read(self, experiment: str) -> pd.DataFrame: - path, lock = self._get_file_path_and_lock(experiment) - lock.acquire_read_lock() - database_df = pd.read_json(path, lines=True) - lock.release_read_lock() - - return database_df - + self._path = Path(path).resolve() -class SqlDatabase(Database): - def __init__(self, path: Union[str, Path] = "./slurm_sweeps.db"): - super().__init__(path) if not self.path.exists(): with self._connection() as con: con.execute("vacuum") + @property + def path(self): + """The path to the database file.""" + return self._path + @contextmanager def _connection(self): connection = sqlite3.connect(self.path, isolation_level=None) diff --git a/src/slurm_sweeps/experiment.py b/src/slurm_sweeps/experiment.py index 37286d4..411bec0 100644 --- a/src/slurm_sweeps/experiment.py +++ b/src/slurm_sweeps/experiment.py @@ -22,7 +22,7 @@ TRIAL_ID, WAITING_TIME_IN_SEC, ) -from .database import Database, ExperimentExistsError, SqlDatabase +from .database import ExperimentExistsError, SqlDatabase from .sampler import Sampler from .storage import Storage from .trial import Status, Trial @@ -39,12 +39,10 @@ class Experiment: It must contain the search spaces via `slurm_sweeps.Uniform`, `slurm_sweeps.Choice`, etc. name: The name of the experiment. local_dir: Where to store and run the experiments. In this directory - we will create a folder with the experiment name. + we will create the database `slurm_sweeps.db` and a folder with the experiment name. backend: A backend to execute the trials. By default, we choose the `SlurmBackend` if Slurm is available, otherwise we choose the standard `Backend` that simply executes the trial in another process. - asha: An optional ASHA instance to cancel less promising trials. By default, it is None. - database: A database instance to store the trial's (intermediate) results. - By default, we will create the database at `{local_dir}/slurm_sweeps.db`. + asha: An optional ASHA instance to cancel less promising trials. restore: Restore an experiment with the same name? overwrite: Overwrite an existing experiment with the same name? """ @@ -57,7 +55,6 @@ def __init__( local_dir: Union[str, Path] = "./slurm_sweeps", backend: Optional[Backend] = None, asha: Optional[ASHA] = None, - database: Optional[Database] = None, restore: bool = False, overwrite: bool = False, ): @@ -73,7 +70,7 @@ def __init__( if asha: self._storage.dump(asha, ASHA_PKL) - self._database = database or SqlDatabase(self._local_dir / "slurm_sweeps.db") + self._database = SqlDatabase(self._local_dir / "slurm_sweeps.db") if not restore: self._database.create(experiment=self._name, overwrite=overwrite) diff --git a/src/slurm_sweeps/logger.py b/src/slurm_sweeps/logger.py index 601fa44..efedc6e 100644 --- a/src/slurm_sweeps/logger.py +++ b/src/slurm_sweeps/logger.py @@ -14,7 +14,7 @@ STORAGE_PATH, TRIAL_ID, ) -from .database import FileDatabase, SqlDatabase +from .database import SqlDatabase from .storage import Storage from .trial import Trial @@ -30,12 +30,9 @@ def __init__(self, cfg: Dict): self._trial = Trial(cfg=cfg) self._experiment_name = os.environ[EXPERIMENT_NAME] - if Path(os.environ[DB_PATH]).is_dir(): - self._database = FileDatabase(os.environ[DB_PATH]) - elif Path(os.environ[DB_PATH]).is_file(): - self._database = SqlDatabase(os.environ[DB_PATH]) - else: + if not Path(os.environ[DB_PATH]).is_file(): raise FileNotFoundError(f"Did not find a database at {os.environ[DB_PATH]}") + self._database = SqlDatabase(os.environ[DB_PATH]) self._asha: Optional[ASHA] = None try: diff --git a/tests/test_database.py b/tests/test_database.py index d3abcff..aa79fb6 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,64 +1,100 @@ +import sqlite3 from functools import partial from multiprocessing import Pool -from typing import Dict, Union +from typing import Dict import numpy as np +import pandas as pd import pytest -from slurm_sweeps.constants import TIMESTAMP -from slurm_sweeps.database import ExperimentExistsError, FileDatabase, SqlDatabase +from slurm_sweeps.constants import CFG, ITERATION, TIMESTAMP, TRIAL_ID +from slurm_sweeps.database import ExperimentExistsError, SqlDatabase -@pytest.fixture(params=["file", "sql"]) -def database(request, tmp_path) -> Union[FileDatabase, SqlDatabase]: +@pytest.fixture +def database(tmp_path) -> SqlDatabase: db_path = tmp_path / "slurm_sweeps.db" - if request.param == "file": - return FileDatabase(db_path) - if request.param == "sql": - return SqlDatabase(db_path) - raise NotImplementedError( - f"'database' fixture not implemented for '{request.param}'" - ) - - -@pytest.mark.parametrize( - "database, is_file_or_dir", - [("file", "is_dir"), ("sql", "is_file")], - indirect=["database"], -) -def test_init(database, is_file_or_dir): - assert getattr(database.path, is_file_or_dir)() + return SqlDatabase(db_path) -def test_fasteners_not_installed(monkeypatch): - monkeypatch.setattr("slurm_sweeps.database._has_fasteners", False) - with pytest.raises(ModuleNotFoundError): - FileDatabase() +def test_init(database): + assert database.path.is_file() def test_create(database): experiment = "test_experiment" database.create(experiment) - database.write(experiment, {"test": "test"}) - assert len(database.read(experiment)) == 1 + + con = sqlite3.connect(database.path) + check_exists = con.execute( + "select name from sqlite_master where type='table' and name='test_experiment';" + ).fetchone() + check_columns = con.execute(f"pragma table_info({experiment})").fetchall() + con.close() + assert check_exists == ("test_experiment",) + assert [(col[1], col[2], col[4]) for col in check_columns] == [ + (TIMESTAMP, "datetime", "strftime('%Y-%m-%d %H:%M:%f', 'NOW')"), + (TRIAL_ID, "TEXT", None), + (ITERATION, "INTEGER", None), + (CFG, "TEXT", None), + ] with pytest.raises(ExperimentExistsError): database.create(experiment) + database.write("test_experiment", {"test": 0}) database.create(experiment, overwrite=True) - assert len(database.read(experiment)) == 0 + # check if empty again + con = sqlite3.connect(database.path) + response = con.execute("select exists (select 1 from test_experiment);").fetchone() + con.close() + assert response == (0,) + + +def test_write_and_read(database): + database.create("test_experiment") + database.write("test_experiment", {"test_int": 1, "test_float": 0.5}) + + con = sqlite3.connect(database.path) + response = con.execute("select count(*) from test_experiment").fetchone() + con.close() + assert response == (1,) + + df = database.read("test_experiment") + assert isinstance(df, pd.DataFrame) + assert list(df.columns) == [ + TIMESTAMP, + TRIAL_ID, + ITERATION, + CFG, + "test_int", + "test_float", + ] + assert type(df[TIMESTAMP].iloc[0]) is str + assert ( + df.iloc[:, 1:] + .compare( + pd.DataFrame( + { + TRIAL_ID: [np.nan], + ITERATION: [np.nan], + CFG: [np.nan], + "test_int": [1], + "test_float": [0.5], + } + ) + ) + .empty + ) -def read_or_write( - mode: str, database: Union[FileDatabase, SqlDatabase], experiment: str, row: Dict -): +def read_or_write(mode: str, database: SqlDatabase, experiment: str, row: Dict): if mode == "w": database.write(experiment, row) else: database.read(experiment) -@pytest.mark.parametrize("database", ["file", "sql"], indirect=["database"]) def test_concurrent_write_read(database): experiment = "test_db_write_read" n = 250 @@ -91,7 +127,7 @@ def test_nan_values(database): assert np.isnan(df["loss"].iloc[0]) -@pytest.mark.skip("Only for speed comparisons") +# @pytest.mark.skip("Only for speed comparisons") def test_speed(monkeypatch, database): from slurm_sweeps import Logger from slurm_sweeps.constants import DB_PATH, EXPERIMENT_NAME diff --git a/tests/test_readme.py b/tests/test_readme.py index bc9585b..b2aa154 100644 --- a/tests/test_readme.py +++ b/tests/test_readme.py @@ -4,8 +4,8 @@ import pytest import slurm_sweeps as ss -from slurm_sweeps import SqlDatabase from slurm_sweeps.constants import ITERATION +from slurm_sweeps.database import SqlDatabase def is_slurm_available() -> bool: