Skip to content

Commit

Permalink
Add example program tracking_dummy.py
Browse files Browse the repository at this point in the history
- Refactor and improve test infrastructure
- Improve README about standalone usage
  • Loading branch information
amotl committed Sep 12, 2023
1 parent 56e6fc3 commit 2b6a590
Show file tree
Hide file tree
Showing 9 changed files with 249 additions and 38 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
- Add patch for SQLAlchemy Inspector's `get_table_names`
- Reorder CrateDB SQLAlchemy Dialect polyfills
- Add example experiment program `tracking_merlion.py`, and corresponding tests
- Add example program `tracking_dummy.py`, and improve test infrastructure
56 changes: 49 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ mlflow-cratedb cratedb --version

## Usage

This documentation section explains how to use this software successfully,
please read it carefully.

### Introduction

In order to spin up a CrateDB instance without further ado, you can use
Docker or Podman.
```shell
Expand All @@ -32,23 +37,61 @@ docker run --rm -it --publish=4200:4200 --publish=5432:5432 \
-Cdiscovery.type=single-node
```

Start the MLflow server, pointing it to your [CrateDB] instance,
running on `localhost`.
The repository includes a few [example programs](./examples), which can be used
to exercise the MLflow setup, and to get started.

The `MLFLOW_TRACKING_URI` environment variable defines whether to record outcomes
directly into the database, or by submitting them to an MLflow Tracking Server.

```shell
# Use CrateDB database directly
export MLFLOW_TRACKING_URI="crate://crate@localhost/?schema=examples"

# Use MLflow Tracking Server
export MLFLOW_TRACKING_URI=http://127.0.0.1:5000
```

### Standalone

In order to instruct MLflow to submit the experiment metadata directly to CrateDB,
configure the `MLFLOW_TRACKING_URI` environment variable to point to your CrateDB
server.

```shell
export MLFLOW_TRACKING_URI="crate://crate@localhost/?schema=mlflow"
python examples/tracking_dummy.py
```

### Tracking Server

Start the MLflow server, pointing it to your CrateDB instance, running on
`localhost`.
```shell
mlflow-cratedb server --backend-store-uri='crate://crate@localhost/?schema=mlflow' --dev
```

In order to instruct MLflow to submit the experiment metadata to the MLflow Tracking
Server, configure the `MLFLOW_TRACKING_URI` environment variable to point to it.

```shell
export MLFLOW_TRACKING_URI="http://127.0.0.1:5000"
python examples/tracking_dummy.py
```

### Remarks

Please note that you need to invoke the `mlflow-cratedb` command, which
runs MLflow amalgamated with the necessary changes to support CrateDB.

Also note that we recommend to use a dedicated schema for storing MLflows
tables. In that spirit, the default schema `"doc"` is not populated by
tables of 3rd-party systems.
Also note that we recommend to use a dedicated schema for storing MLflow's
tables. In that spirit, CrateDB's default schema `"doc"` is not populated
by any tables of 3rd-party systems.


## Development

Acquire source code and install development sandbox.
Acquire source code and install development sandbox. The authors recommend to
use a Python virtualenv.
```shell
git clone https://github.com/crate-workbench/mlflow-cratedb
cd mlflow-cratedb
Expand All @@ -59,7 +102,6 @@ pip install --editable='.[examples,develop,docs,test]'

Run linters and software tests, skipping slow tests:
```shell
source .venv/bin/activate
poe check-fast
```

Expand Down
72 changes: 72 additions & 0 deletions examples/tracking_dummy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
About
Use MLflow and CrateDB to track the metrics and parameters of a dummy ML
experiment program.
- https://github.com/crate-workbench/mlflow-cratedb
- https://mlflow.org/docs/latest/tracking.html
Usage
Before running the program, optionally define the `MLFLOW_TRACKING_URI` environment
variable, in order to record events and metrics either directly into the database,
or by submitting them to an MLflow Tracking Server.
# Use CrateDB database directly
export MLFLOW_TRACKING_URI="crate://crate@localhost/?schema=mlflow"
# Use MLflow Tracking Server
export MLFLOW_TRACKING_URI=http://127.0.0.1:5000
Resources
- https://mlflow.org/
- https://github.com/crate/crate
"""
import logging
import sys

import mlflow

logger = logging.getLogger()


def run_experiment():
"""
Run an MLflow dummy workflow, without any data.
"""
logger.info("Running experiment")
mlflow.set_experiment("dummy-experiment")

mlflow.log_metric("precision", 0.33)
mlflow.log_metric("recall", 0.48)
mlflow.log_metric("f1", 0.85)
mlflow.log_metric("mttd", 42.42)
mlflow.log_param("anomaly_threshold", 0.10)
mlflow.log_param("min_alm_window", 3600)
mlflow.log_param("alm_window_minutes", 60)
mlflow.log_param("alm_suppress_minutes", 5)
mlflow.log_param("ensemble_size", 25)


def start_adapter():
logger.info("Initializing CrateDB adapter")
import mlflow_cratedb # noqa: F401


def setup_logging():
logging.basicConfig(stream=sys.stderr, level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")


def main():
"""
Run dummy experiment.
"""
setup_logging()
start_adapter()
run_experiment()


if __name__ == "__main__":
main()
4 changes: 4 additions & 0 deletions mlflow_cratedb/patch/mlflow/db_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ def patch_dbtypes():

if db_types.CRATEDB not in db_types.DATABASE_ENGINES:
db_types.DATABASE_ENGINES.append(db_types.CRATEDB)

import mlflow.tracking._tracking_service.utils as tracking_utils

tracking_utils._tracking_store_registry.register(CRATEDB, tracking_utils._get_sqlalchemy_store)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ log_cli_level = "DEBUG"
testpaths = ["tests"]
xfail_strict = true
markers = [
"examples",
"notrackingurimock",
"slow",
]
Expand Down Expand Up @@ -189,6 +190,7 @@ extend-exclude = [

[tool.ruff.per-file-ignores]
"tests/*" = ["S101"] # Use of `assert` detected
"tests/conftest.py" = ["E402"] # Module level import not at top of file


# ===================
Expand Down
52 changes: 47 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,57 @@
import mlflow
import os

import pytest

from mlflow_cratedb import patch_all

patch_all()

import mlflow
import mlflow.store.tracking.sqlalchemy_store as mlflow_tracking
import sqlalchemy as sa

ARTIFACT_URI = "testdrive_folder"

# The canonical database schema used for test purposes is `testdrive`.
DB_URI = "crate://crate@localhost/?schema=testdrive"

@pytest.fixture(autouse=True)
def reset_environment() -> None:
"""
Make sure software tests do not pick up any environment variables.
"""
if "MLFLOW_TRACKING_URI" in os.environ:
del os.environ["MLFLOW_TRACKING_URI"]


@pytest.fixture
def db_uri() -> str:
"""
The canonical database schema used for testing purposes is `testdrive`.
"""
return "crate://crate@localhost/?schema=testdrive"


@pytest.fixture
def testdrive_engine():
yield mlflow.store.db.utils.create_sqlalchemy_engine_with_retry(DB_URI)
def engine(db_uri):
"""
Provide an SQLAlchemy engine object using the `testdrive` schema.
"""
yield mlflow.store.db.utils.create_sqlalchemy_engine_with_retry(db_uri)


@pytest.fixture
def tracking_store(engine: sa.Engine) -> mlflow_tracking.SqlAlchemyStore:
"""
A fixture for providing an instance of `SqlAlchemyStore`.
"""
yield mlflow_tracking.SqlAlchemyStore(str(engine.url), ARTIFACT_URI)


@pytest.fixture
def reset_database(engine: sa.Engine):
"""
Make sure to reset the database by dropping and re-creating tables.
"""
from mlflow_cratedb.adapter.setup_db import _setup_db_create_tables, _setup_db_drop_tables

_setup_db_drop_tables(engine=engine)
_setup_db_create_tables(engine=engine)
12 changes: 6 additions & 6 deletions tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@


@pytest.fixture
def store(testdrive_engine: sa.Engine):
def store(engine: sa.Engine):
"""
A fixture for providing an instance of `SqlAlchemyStore`.
"""
yield SqlAlchemyStore(str(testdrive_engine.url), ARTIFACT_URI)
yield SqlAlchemyStore(str(engine.url), ARTIFACT_URI)


@pytest.fixture
Expand All @@ -32,13 +32,13 @@ def store_empty(store):
yield store


def test_setup_tables(testdrive_engine: sa.Engine):
def test_setup_tables(engine: sa.Engine):
"""
Test if creating database tables works, and that they use the correct schema.
"""
_setup_db_drop_tables(engine=testdrive_engine)
_setup_db_create_tables(engine=testdrive_engine)
with testdrive_engine.connect() as connection:
_setup_db_drop_tables(engine=engine)
_setup_db_create_tables(engine=engine)
with engine.connect() as connection:
result = connection.execute(sa.text("SELECT * FROM testdrive.experiments;"))
assert result.rowcount == 0

Expand Down
78 changes: 63 additions & 15 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,77 @@
import time
from pathlib import Path

import mlflow
import pytest
import sqlalchemy as sa
from mlflow.store.tracking.dbmodels.models import SqlExperiment, SqlMetric, SqlParam
from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore

from mlflow_cratedb.adapter.setup_db import _setup_db_create_tables, _setup_db_drop_tables
from tests.util import process

# The canonical database schema used for example purposes is `examples`.
DB_URI = "crate://crate@localhost/?schema=examples"
MLFLOW_TRACKING_URI = "http://127.0.0.1:5000"
# The test cases within this file exercise two different ways of recording
# ML experiments. They can either be directly submitted to the database,
# or alternatively to an MLflow Tracking Server.
MLFLOW_TRACKING_URI_SERVER = "http://127.0.0.1:5000"


logger = logging.getLogger(__name__)


@pytest.fixture
def engine():
yield mlflow.store.db.utils.create_sqlalchemy_engine_with_retry(DB_URI)
def get_example_program_path(filename: str):
"""
Compute path to example program.
"""
return Path(__file__).parent.parent.joinpath("examples").joinpath(filename)


def test_tracking_merlion(engine: sa.Engine):
_setup_db_drop_tables(engine=engine)
_setup_db_create_tables(engine=engine)
tracking_merlion = Path(__file__).parent.parent.joinpath("examples").joinpath("tracking_merlion.py")
@pytest.mark.examples
def test_tracking_dummy(reset_database, engine: sa.Engine, tracking_store: SqlAlchemyStore, db_uri):
"""
Run a dummy experiment program, without any data.
Verify that the database has been populated appropriately.
Here, no MLflow Tracking Server is used, so the `MLFLOW_TRACKING_URI`
will be the SQLAlchemy database connection URI, i.e. the program will
directly communicate with CrateDB.
-- https://mlflow.org/docs/latest/tracking.html#backend-stores
"""

# Invoke example program.
tracking_dummy = get_example_program_path("tracking_dummy.py")
logger.info("Starting experiment program")
with process(
[sys.executable, tracking_dummy],
env={"MLFLOW_TRACKING_URI": db_uri},
stdout=sys.stdout.buffer,
stderr=sys.stderr.buffer,
) as client_process:
client_process.wait(timeout=10)
assert client_process.returncode == 0

# Verify database content.
with tracking_store.ManagedSessionMaker() as session:
assert session.query(SqlExperiment).count() == 2
assert session.query(SqlMetric).count() == 4
assert session.query(SqlParam).count() == 5


@pytest.mark.examples
def test_tracking_merlion(reset_database, engine: sa.Engine, tracking_store: SqlAlchemyStore, db_uri):
"""
Run a real experiment program, reporting to an MLflow Tracking Server.
Verify that the database has been populated appropriately.
Here, `MLFLOW_TRACKING_URI` will be the HTTP URL of the Tracking Server,
i.e. the program will submit events and metrics to it, wrapping the
connection to CrateDB.
"""
tracking_merlion = get_example_program_path("tracking_merlion.py")
cmd_server = [
"mlflow-cratedb",
"server",
"--workers=1",
f"--backend-store-uri={DB_URI}",
f"--backend-store-uri={db_uri}",
"--gunicorn-opts='--log-level=debug'",
]
cmd_client = [
Expand All @@ -44,14 +86,20 @@ def test_tracking_merlion(engine: sa.Engine):
logger.info(f"Started server with process id: {server_process.pid}")
# TODO: Wait for HTTP response.
time.sleep(4)

# Invoke example program.
logger.info("Starting client")
with process(
cmd_client,
env={"MLFLOW_TRACKING_URI": MLFLOW_TRACKING_URI},
env={"MLFLOW_TRACKING_URI": MLFLOW_TRACKING_URI_SERVER},
stdout=sys.stdout.buffer,
stderr=sys.stderr.buffer,
) as client_process:
client_process.wait(timeout=120)
assert client_process.returncode == 0

# TODO: Verify database content.
# Verify database content.
with tracking_store.ManagedSessionMaker() as session:
assert session.query(SqlExperiment).count() == 2
assert session.query(SqlMetric).count() == 4
assert session.query(SqlParam).count() == 5
Loading

0 comments on commit 2b6a590

Please sign in to comment.