diff --git a/CHANGES.md b/CHANGES.md index 6b51f4c..2734e4b 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -12,3 +12,4 @@ - Add patch for SQLAlchemy Inspector's `get_table_names` - Reorder CrateDB SQLAlchemy Dialect polyfills - Add example experiment program `tracking_merlion.py`, and corresponding tests +- Add example program `tracking_dummy.py`, and improve test infrastructure diff --git a/README.md b/README.md index 865a997..2e99cc3 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,11 @@ mlflow-cratedb cratedb --version ## Usage +This documentation section explains how to use this software successfully. +Please read it carefully. + +### Introduction + In order to spin up a CrateDB instance without further ado, you can use Docker or Podman. ```shell @@ -32,23 +37,61 @@ docker run --rm -it --publish=4200:4200 --publish=5432:5432 \ -Cdiscovery.type=single-node ``` -Start the MLflow server, pointing it to your [CrateDB] instance, -running on `localhost`. +The repository includes a few [example programs](./examples), which can be used +to exercise the MLflow setup, and to get started. + +The `MLFLOW_TRACKING_URI` environment variable defines whether outcomes are recorded +directly into the database, or submitted to an MLflow Tracking Server. + +```shell +# Use CrateDB database directly +export MLFLOW_TRACKING_URI="crate://crate@localhost/?schema=examples" + +# Use MLflow Tracking Server +export MLFLOW_TRACKING_URI=http://127.0.0.1:5000 +``` + +### Standalone + +In order to instruct MLflow to submit the experiment metadata directly to CrateDB, +configure the `MLFLOW_TRACKING_URI` environment variable to point to your CrateDB +server. + +```shell +export MLFLOW_TRACKING_URI="crate://crate@localhost/?schema=mlflow" +python examples/tracking_dummy.py +``` + +### Tracking Server + +Start the MLflow server, pointing it to your CrateDB instance, running on +`localhost`.
```shell mlflow-cratedb server --backend-store-uri='crate://crate@localhost/?schema=mlflow' --dev ``` +In order to instruct MLflow to submit the experiment metadata to the MLflow Tracking +Server, configure the `MLFLOW_TRACKING_URI` environment variable to point to it. + +```shell +export MLFLOW_TRACKING_URI="http://127.0.0.1:5000" +python examples/tracking_dummy.py +``` + +### Remarks + Please note that you need to invoke the `mlflow-cratedb` command, which runs MLflow amalgamated with the necessary changes to support CrateDB. -Also note that we recommend to use a dedicated schema for storing MLflows -tables. In that spirit, the default schema `"doc"` is not populated by -tables of 3rd-party systems. +Also note that we recommend using a dedicated schema for storing MLflow's +tables. In that spirit, CrateDB's default schema `"doc"` is not populated +by any tables of 3rd-party systems. ## Development -Acquire source code and install development sandbox. +Acquire source code and install development sandbox. The authors recommend +using a Python virtualenv. ```shell git clone https://github.com/crate-workbench/mlflow-cratedb cd mlflow-cratedb @@ -59,7 +102,6 @@ pip install --editable='.[examples,develop,docs,test]' Run linters and software tests, skipping slow tests: ```shell -source .venv/bin/activate poe check-fast ``` diff --git a/examples/tracking_dummy.py b/examples/tracking_dummy.py new file mode 100644 index 0000000..466d168 --- /dev/null +++ b/examples/tracking_dummy.py @@ -0,0 +1,72 @@ +""" +About + +Use MLflow and CrateDB to track the metrics and parameters of a dummy ML +experiment program. + +- https://github.com/crate-workbench/mlflow-cratedb +- https://mlflow.org/docs/latest/tracking.html + +Usage + +Before running the program, optionally define the `MLFLOW_TRACKING_URI` environment +variable, in order to record events and metrics either directly into the database, +or by submitting them to an MLflow Tracking Server.
+ + # Use CrateDB database directly + export MLFLOW_TRACKING_URI="crate://crate@localhost/?schema=mlflow" + + # Use MLflow Tracking Server + export MLFLOW_TRACKING_URI=http://127.0.0.1:5000 + +Resources + +- https://mlflow.org/ +- https://github.com/crate/crate +""" +import logging +import sys + +import mlflow + +logger = logging.getLogger() + + +def run_experiment(): + """ + Run an MLflow dummy workflow, without any data. + """ + logger.info("Running experiment") + mlflow.set_experiment("dummy-experiment") + + mlflow.log_metric("precision", 0.33) + mlflow.log_metric("recall", 0.48) + mlflow.log_metric("f1", 0.85) + mlflow.log_metric("mttd", 42.42) + mlflow.log_param("anomaly_threshold", 0.10) + mlflow.log_param("min_alm_window", 3600) + mlflow.log_param("alm_window_minutes", 60) + mlflow.log_param("alm_suppress_minutes", 5) + mlflow.log_param("ensemble_size", 25) + + +def start_adapter(): + logger.info("Initializing CrateDB adapter") + import mlflow_cratedb # noqa: F401 + + +def setup_logging(): + logging.basicConfig(stream=sys.stderr, level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") + + +def main(): + """ + Run dummy experiment. 
+ """ + setup_logging() + start_adapter() + run_experiment() + + +if __name__ == "__main__": + main() diff --git a/mlflow_cratedb/patch/mlflow/db_types.py b/mlflow_cratedb/patch/mlflow/db_types.py index e57cc17..a40889c 100644 --- a/mlflow_cratedb/patch/mlflow/db_types.py +++ b/mlflow_cratedb/patch/mlflow/db_types.py @@ -11,3 +11,7 @@ def patch_dbtypes(): if db_types.CRATEDB not in db_types.DATABASE_ENGINES: db_types.DATABASE_ENGINES.append(db_types.CRATEDB) + + import mlflow.tracking._tracking_service.utils as tracking_utils + + tracking_utils._tracking_store_registry.register(CRATEDB, tracking_utils._get_sqlalchemy_store) diff --git a/pyproject.toml b/pyproject.toml index f4a8fa2..9267a52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,6 +123,7 @@ log_cli_level = "DEBUG" testpaths = ["tests"] xfail_strict = true markers = [ + "examples", "notrackingurimock", "slow", ] @@ -189,6 +190,7 @@ extend-exclude = [ [tool.ruff.per-file-ignores] "tests/*" = ["S101"] # Use of `assert` detected +"tests/conftest.py" = ["E402"] # Module level import not at top of file # =================== diff --git a/tests/conftest.py b/tests/conftest.py index 6f5fbdd..06d4c5a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,57 @@ -import mlflow +import os + import pytest from mlflow_cratedb import patch_all patch_all() +import mlflow +import mlflow.store.tracking.sqlalchemy_store as mlflow_tracking +import sqlalchemy as sa + +ARTIFACT_URI = "testdrive_folder" -# The canonical database schema used for test purposes is `testdrive`. -DB_URI = "crate://crate@localhost/?schema=testdrive" + +@pytest.fixture(autouse=True) +def reset_environment() -> None: + """ + Make sure software tests do not pick up any environment variables. + """ + if "MLFLOW_TRACKING_URI" in os.environ: + del os.environ["MLFLOW_TRACKING_URI"] + + +@pytest.fixture +def db_uri() -> str: + """ + The canonical database schema used for testing purposes is `testdrive`. 
+ """ + return "crate://crate@localhost/?schema=testdrive" @pytest.fixture -def testdrive_engine(): - yield mlflow.store.db.utils.create_sqlalchemy_engine_with_retry(DB_URI) +def engine(db_uri): + """ + Provide an SQLAlchemy engine object using the `testdrive` schema. + """ + yield mlflow.store.db.utils.create_sqlalchemy_engine_with_retry(db_uri) + + +@pytest.fixture +def tracking_store(engine: sa.Engine) -> mlflow_tracking.SqlAlchemyStore: + """ + A fixture for providing an instance of `SqlAlchemyStore`. + """ + yield mlflow_tracking.SqlAlchemyStore(str(engine.url), ARTIFACT_URI) + + +@pytest.fixture +def reset_database(engine: sa.Engine): + """ + Make sure to reset the database by dropping and re-creating tables. + """ + from mlflow_cratedb.adapter.setup_db import _setup_db_create_tables, _setup_db_drop_tables + + _setup_db_drop_tables(engine=engine) + _setup_db_create_tables(engine=engine) diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 72a2cbc..a5875b6 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -10,11 +10,11 @@ @pytest.fixture -def store(testdrive_engine: sa.Engine): +def store(engine: sa.Engine): """ A fixture for providing an instance of `SqlAlchemyStore`. """ - yield SqlAlchemyStore(str(testdrive_engine.url), ARTIFACT_URI) + yield SqlAlchemyStore(str(engine.url), ARTIFACT_URI) @pytest.fixture @@ -32,13 +32,13 @@ def store_empty(store): yield store -def test_setup_tables(testdrive_engine: sa.Engine): +def test_setup_tables(engine: sa.Engine): """ Test if creating database tables works, and that they use the correct schema. 
""" - _setup_db_drop_tables(engine=testdrive_engine) - _setup_db_create_tables(engine=testdrive_engine) - with testdrive_engine.connect() as connection: + _setup_db_drop_tables(engine=engine) + _setup_db_create_tables(engine=engine) + with engine.connect() as connection: result = connection.execute(sa.text("SELECT * FROM testdrive.experiments;")) assert result.rowcount == 0 diff --git a/tests/test_examples.py b/tests/test_examples.py index befb718..809c64b 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -3,35 +3,77 @@ import time from pathlib import Path -import mlflow import pytest import sqlalchemy as sa +from mlflow.store.tracking.dbmodels.models import SqlExperiment, SqlMetric, SqlParam +from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore -from mlflow_cratedb.adapter.setup_db import _setup_db_create_tables, _setup_db_drop_tables from tests.util import process -# The canonical database schema used for example purposes is `examples`. -DB_URI = "crate://crate@localhost/?schema=examples" -MLFLOW_TRACKING_URI = "http://127.0.0.1:5000" +# The test cases within this file exercise two different ways of recording +# ML experiments. They can either be directly submitted to the database, +# or alternatively to an MLflow Tracking Server. +MLFLOW_TRACKING_URI_SERVER = "http://127.0.0.1:5000" logger = logging.getLogger(__name__) -@pytest.fixture -def engine(): - yield mlflow.store.db.utils.create_sqlalchemy_engine_with_retry(DB_URI) +def get_example_program_path(filename: str): + """ + Compute path to example program. 
+ """ + return Path(__file__).parent.parent.joinpath("examples").joinpath(filename) -def test_tracking_merlion(engine: sa.Engine): - _setup_db_drop_tables(engine=engine) - _setup_db_create_tables(engine=engine) - tracking_merlion = Path(__file__).parent.parent.joinpath("examples").joinpath("tracking_merlion.py") +@pytest.mark.examples +def test_tracking_dummy(reset_database, engine: sa.Engine, tracking_store: SqlAlchemyStore, db_uri): + """ + Run a dummy experiment program, without any data. + Verify that the database has been populated appropriately. + + Here, no MLflow Tracking Server is used, so the `MLFLOW_TRACKING_URI` + will be the SQLAlchemy database connection URI, i.e. the program will + directly communicate with CrateDB. + + -- https://mlflow.org/docs/latest/tracking.html#backend-stores + """ + + # Invoke example program. + tracking_dummy = get_example_program_path("tracking_dummy.py") + logger.info("Starting experiment program") + with process( + [sys.executable, tracking_dummy], + env={"MLFLOW_TRACKING_URI": db_uri}, + stdout=sys.stdout.buffer, + stderr=sys.stderr.buffer, + ) as client_process: + client_process.wait(timeout=10) + assert client_process.returncode == 0 + + # Verify database content. + with tracking_store.ManagedSessionMaker() as session: + assert session.query(SqlExperiment).count() == 2 + assert session.query(SqlMetric).count() == 4 + assert session.query(SqlParam).count() == 5 + + +@pytest.mark.examples +def test_tracking_merlion(reset_database, engine: sa.Engine, tracking_store: SqlAlchemyStore, db_uri): + """ + Run a real experiment program, reporting to an MLflow Tracking Server. + Verify that the database has been populated appropriately. + + Here, `MLFLOW_TRACKING_URI` will be the HTTP URL of the Tracking Server, + i.e. the program will submit events and metrics to it, wrapping the + connection to CrateDB. 
+ """ + tracking_merlion = get_example_program_path("tracking_merlion.py") cmd_server = [ "mlflow-cratedb", "server", "--workers=1", - f"--backend-store-uri={DB_URI}", + f"--backend-store-uri={db_uri}", "--gunicorn-opts='--log-level=debug'", ] cmd_client = [ @@ -44,14 +86,20 @@ def test_tracking_merlion(engine: sa.Engine): logger.info(f"Started server with process id: {server_process.pid}") # TODO: Wait for HTTP response. time.sleep(4) + + # Invoke example program. logger.info("Starting client") with process( cmd_client, - env={"MLFLOW_TRACKING_URI": MLFLOW_TRACKING_URI}, + env={"MLFLOW_TRACKING_URI": MLFLOW_TRACKING_URI_SERVER}, stdout=sys.stdout.buffer, stderr=sys.stderr.buffer, ) as client_process: client_process.wait(timeout=120) assert client_process.returncode == 0 - # TODO: Verify database content. + # Verify database content. + with tracking_store.ManagedSessionMaker() as session: + assert session.query(SqlExperiment).count() == 2 + assert session.query(SqlMetric).count() == 4 + assert session.query(SqlParam).count() == 5 diff --git a/tests/test_mlflow.py b/tests/test_mlflow.py index 4146ec8..0947a11 100644 --- a/tests/test_mlflow.py +++ b/tests/test_mlflow.py @@ -3,12 +3,12 @@ from mlflow_cratedb.adapter.setup_db import _setup_db_create_tables, _setup_db_drop_tables -def test_all_tables_exist(testdrive_engine): +def test_all_tables_exist(engine): """ Cover `patch_sqlalchemy_inspector`: SQLAlchemy's Inspector needs a patch to honor the `schema` parameter. """ - _setup_db_drop_tables(engine=testdrive_engine) - assert mlflow.store.db.utils._all_tables_exist(testdrive_engine) is False - _setup_db_create_tables(engine=testdrive_engine) - assert mlflow.store.db.utils._all_tables_exist(testdrive_engine) is True + _setup_db_drop_tables(engine=engine) + assert mlflow.store.db.utils._all_tables_exist(engine) is False + _setup_db_create_tables(engine=engine) + assert mlflow.store.db.utils._all_tables_exist(engine) is True