Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove file database #14

Merged
merged 6 commits into from
Nov 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def __init__(
local_dir: Union[str, Path] = "./slurm_sweeps",
backend: Optional[Backend] = None,
asha: Optional[ASHA] = None,
database: Optional[Database] = None,
restore: bool = False,
overwrite: bool = False,
)
Expand All @@ -124,12 +123,10 @@ Set up an HPO experiment.
It must define the search spaces using `slurm_sweeps.Uniform`, `slurm_sweeps.Choice`, etc.
- `name` - The name of the experiment.
- `local_dir` - Where to store and run the experiments. In this directory
we will create a folder with the experiment name.
we will create the database `slurm_sweeps.db` and a folder with the experiment name.
- `backend` - A backend to execute the trials. By default, we choose the `SlurmBackend` if Slurm is available,
otherwise we choose the standard `Backend` that simply executes the trial in another process.
- `asha` - An optional ASHA instance to cancel less promising trials. By default, it is None.
- `database` - A database instance to store the trial's (intermediate) results.
By default, we will create the database at `{local_dir}/slurm_sweeps.db`.
- `asha` - An optional ASHA instance to cancel less promising trials.
- `restore` - Whether to restore an existing experiment with the same name.
- `overwrite` - Whether to overwrite an existing experiment with the same name.

Expand Down
1 change: 0 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ dependencies:
# for tests
- pytest
- pytest-cov
- fasteners
- lightning
- wandb
# for development
Expand Down
1 change: 0 additions & 1 deletion src/slurm_sweeps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from .asha import ASHA
from .backend import Backend, SlurmBackend
from .database import FileDatabase, SqlDatabase
from .experiment import Experiment
from .logger import Logger
from .sampler import Choice, Grid, LogUniform, Uniform
Expand Down
92 changes: 11 additions & 81 deletions src/slurm_sweeps/database.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import abc
import datetime
import json
import sqlite3
from contextlib import contextmanager
from pathlib import Path
Expand All @@ -10,93 +7,26 @@

from .constants import CFG, ITERATION, TIMESTAMP, TRIAL_ID

try:
import fasteners
except ModuleNotFoundError:
_has_fasteners = False
else:
_has_fasteners = True

class SqlDatabase:
"""An SQLite database that stores the trials and their metrics.

class Database(abc.ABC):
def __init__(self, path: Union[str, Path] = "./slurm_sweeps.db"):
self._path = Path(path).resolve()
Args:
path: The path to the database file.
"""

@property
def path(self):
return self._path

@abc.abstractmethod
def create(self, experiment: str, overwrite: bool = False):
pass

@abc.abstractmethod
def write(self, experiment: str, row: Dict):
pass

@abc.abstractmethod
def read(self, experiment: str) -> pd.DataFrame:
pass


class FileDatabase(Database):
def __init__(self, path: Union[str, Path] = "./slurm_sweeps.db"):
if not _has_fasteners:
raise ModuleNotFoundError(
"You need to install 'fasteners' to use the FileDatabase: "
"`pip install fasteners`"
)

super().__init__(path=path)
self.path.mkdir(parents=True, exist_ok=True)

def _get_file_path_and_lock(
self, experiment
) -> Tuple[Path, fasteners.InterProcessReaderWriterLock]:
path = self.path / f"{experiment}.txt"
lock = fasteners.InterProcessReaderWriterLock(f"{path}.lock")

return path, lock

def create(self, experiment: str, overwrite: bool = False):
path, _ = self._get_file_path_and_lock(experiment)
try:
path.touch(exist_ok=overwrite)
except FileExistsError as err:
raise ExperimentExistsError(experiment) from err
with path.open(mode="w"):
pass

def write(self, experiment: str, row: Dict):
path, lock = self._get_file_path_and_lock(experiment)
# quick check if file exists
with path.open():
pass

row[TIMESTAMP] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
json_str = json.dumps(row)

lock.acquire_write_lock()
with path.open(mode="a") as f:
f.write(json_str + "\n")
lock.release_write_lock()

def read(self, experiment: str) -> pd.DataFrame:
path, lock = self._get_file_path_and_lock(experiment)
lock.acquire_read_lock()
database_df = pd.read_json(path, lines=True)
lock.release_read_lock()

return database_df

self._path = Path(path).resolve()

class SqlDatabase(Database):
def __init__(self, path: Union[str, Path] = "./slurm_sweeps.db"):
super().__init__(path)
if not self.path.exists():
with self._connection() as con:
con.execute("vacuum")

@property
def path(self):
"""The path to the database file."""
return self._path

@contextmanager
def _connection(self):
connection = sqlite3.connect(self.path, isolation_level=None)
Expand Down
11 changes: 4 additions & 7 deletions src/slurm_sweeps/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
TRIAL_ID,
WAITING_TIME_IN_SEC,
)
from .database import Database, ExperimentExistsError, SqlDatabase
from .database import ExperimentExistsError, SqlDatabase
from .sampler import Sampler
from .storage import Storage
from .trial import Status, Trial
Expand All @@ -39,12 +39,10 @@ class Experiment:
It must define the search spaces using `slurm_sweeps.Uniform`, `slurm_sweeps.Choice`, etc.
name: The name of the experiment.
local_dir: Where to store and run the experiments. In this directory
we will create a folder with the experiment name.
we will create the database `slurm_sweeps.db` and a folder with the experiment name.
backend: A backend to execute the trials. By default, we choose the `SlurmBackend` if Slurm is available,
otherwise we choose the standard `Backend` that simply executes the trial in another process.
asha: An optional ASHA instance to cancel less promising trials. By default, it is None.
database: A database instance to store the trial's (intermediate) results.
By default, we will create the database at `{local_dir}/slurm_sweeps.db`.
asha: An optional ASHA instance to cancel less promising trials.
restore: Whether to restore an existing experiment with the same name.
overwrite: Whether to overwrite an existing experiment with the same name.
"""
Expand All @@ -57,7 +55,6 @@ def __init__(
local_dir: Union[str, Path] = "./slurm_sweeps",
backend: Optional[Backend] = None,
asha: Optional[ASHA] = None,
database: Optional[Database] = None,
restore: bool = False,
overwrite: bool = False,
):
Expand All @@ -73,7 +70,7 @@ def __init__(
if asha:
self._storage.dump(asha, ASHA_PKL)

self._database = database or SqlDatabase(self._local_dir / "slurm_sweeps.db")
self._database = SqlDatabase(self._local_dir / "slurm_sweeps.db")
if not restore:
self._database.create(experiment=self._name, overwrite=overwrite)

Expand Down
9 changes: 3 additions & 6 deletions src/slurm_sweeps/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
STORAGE_PATH,
TRIAL_ID,
)
from .database import FileDatabase, SqlDatabase
from .database import SqlDatabase
from .storage import Storage
from .trial import Trial

Expand All @@ -30,12 +30,9 @@ def __init__(self, cfg: Dict):
self._trial = Trial(cfg=cfg)

self._experiment_name = os.environ[EXPERIMENT_NAME]
if Path(os.environ[DB_PATH]).is_dir():
self._database = FileDatabase(os.environ[DB_PATH])
elif Path(os.environ[DB_PATH]).is_file():
self._database = SqlDatabase(os.environ[DB_PATH])
else:
if not Path(os.environ[DB_PATH]).is_file():
raise FileNotFoundError(f"Did not find a database at {os.environ[DB_PATH]}")
self._database = SqlDatabase(os.environ[DB_PATH])

self._asha: Optional[ASHA] = None
try:
Expand Down
102 changes: 69 additions & 33 deletions tests/test_database.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,100 @@
import sqlite3
from functools import partial
from multiprocessing import Pool
from typing import Dict, Union
from typing import Dict

import numpy as np
import pandas as pd
import pytest

from slurm_sweeps.constants import TIMESTAMP
from slurm_sweeps.database import ExperimentExistsError, FileDatabase, SqlDatabase
from slurm_sweeps.constants import CFG, ITERATION, TIMESTAMP, TRIAL_ID
from slurm_sweeps.database import ExperimentExistsError, SqlDatabase


@pytest.fixture(params=["file", "sql"])
def database(request, tmp_path) -> Union[FileDatabase, SqlDatabase]:
@pytest.fixture
def database(tmp_path) -> SqlDatabase:
db_path = tmp_path / "slurm_sweeps.db"
if request.param == "file":
return FileDatabase(db_path)
if request.param == "sql":
return SqlDatabase(db_path)
raise NotImplementedError(
f"'database' fixture not implemented for '{request.param}'"
)


@pytest.mark.parametrize(
"database, is_file_or_dir",
[("file", "is_dir"), ("sql", "is_file")],
indirect=["database"],
)
def test_init(database, is_file_or_dir):
assert getattr(database.path, is_file_or_dir)()
return SqlDatabase(db_path)


def test_fasteners_not_installed(monkeypatch):
monkeypatch.setattr("slurm_sweeps.database._has_fasteners", False)
with pytest.raises(ModuleNotFoundError):
FileDatabase()
def test_init(database):
assert database.path.is_file()


def test_create(database):
experiment = "test_experiment"
database.create(experiment)
database.write(experiment, {"test": "test"})
assert len(database.read(experiment)) == 1

con = sqlite3.connect(database.path)
check_exists = con.execute(
"select name from sqlite_master where type='table' and name='test_experiment';"
).fetchone()
check_columns = con.execute(f"pragma table_info({experiment})").fetchall()
con.close()
assert check_exists == ("test_experiment",)
assert [(col[1], col[2], col[4]) for col in check_columns] == [
(TIMESTAMP, "datetime", "strftime('%Y-%m-%d %H:%M:%f', 'NOW')"),
(TRIAL_ID, "TEXT", None),
(ITERATION, "INTEGER", None),
(CFG, "TEXT", None),
]

with pytest.raises(ExperimentExistsError):
database.create(experiment)

database.write("test_experiment", {"test": 0})
database.create(experiment, overwrite=True)
assert len(database.read(experiment)) == 0
# check if empty again
con = sqlite3.connect(database.path)
response = con.execute("select exists (select 1 from test_experiment);").fetchone()
con.close()
assert response == (0,)


def test_write_and_read(database):
database.create("test_experiment")
database.write("test_experiment", {"test_int": 1, "test_float": 0.5})

con = sqlite3.connect(database.path)
response = con.execute("select count(*) from test_experiment").fetchone()
con.close()
assert response == (1,)

df = database.read("test_experiment")
assert isinstance(df, pd.DataFrame)
assert list(df.columns) == [
TIMESTAMP,
TRIAL_ID,
ITERATION,
CFG,
"test_int",
"test_float",
]
assert type(df[TIMESTAMP].iloc[0]) is str
assert (
df.iloc[:, 1:]
.compare(
pd.DataFrame(
{
TRIAL_ID: [np.nan],
ITERATION: [np.nan],
CFG: [np.nan],
"test_int": [1],
"test_float": [0.5],
}
)
)
.empty
)


def read_or_write(
mode: str, database: Union[FileDatabase, SqlDatabase], experiment: str, row: Dict
):
def read_or_write(mode: str, database: SqlDatabase, experiment: str, row: Dict):
if mode == "w":
database.write(experiment, row)
else:
database.read(experiment)


@pytest.mark.parametrize("database", ["file", "sql"], indirect=["database"])
def test_concurrent_write_read(database):
experiment = "test_db_write_read"
n = 250
Expand Down Expand Up @@ -91,7 +127,7 @@ def test_nan_values(database):
assert np.isnan(df["loss"].iloc[0])


@pytest.mark.skip("Only for speed comparisons")
# @pytest.mark.skip("Only for speed comparisons")
def test_speed(monkeypatch, database):
from slurm_sweeps import Logger
from slurm_sweeps.constants import DB_PATH, EXPERIMENT_NAME
Expand Down
2 changes: 1 addition & 1 deletion tests/test_readme.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import pytest

import slurm_sweeps as ss
from slurm_sweeps import SqlDatabase
from slurm_sweeps.constants import ITERATION
from slurm_sweeps.database import SqlDatabase


def is_slurm_available() -> bool:
Expand Down
Loading