38 changes: 38 additions & 0 deletions docs/databricks_model_registry.md
@@ -0,0 +1,38 @@
# Databricks Model Registry Guide

This project targets the Databricks **Workspace Model Registry** by default because many subscriptions (including ours) do not provide Unity Catalog access. The integration is designed so that switching to Unity Catalog later only requires configuration changes—no code changes.

## Default Behaviour

- Hydra config `logging=mlflow_logger` sets `tracking_uri=databricks` and `registry_uri=databricks`.
- `simplexity.utils.mlflow_utils.resolve_registry_uri` downgrades Unity Catalog URIs (`databricks-uc`) to workspace URIs when `allow_workspace_fallback=True` (the default) and emits a warning so you know a downgrade happened (see the sketch after this list).
- `MLFlowLogger` and `MLFlowPersister.from_experiment` both call `resolve_registry_uri`, so any code path that uses Simplexity helpers gets the same fallback logic.
- `examples/mlflow_workspace_registry_demo.py` mirrors this behaviour and can be used to sanity-check Databricks connectivity.
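
A minimal sketch of the fallback behaviour, assuming the call signature used in `examples/mlflow_workspace_registry_demo.py` (two positional URIs plus the keyword flag) and assuming the workspace URI `databricks` is the downgrade target:

```python
from simplexity.utils.mlflow_utils import resolve_registry_uri

# Fallback enabled (the default): a Unity Catalog URI is downgraded to the
# workspace registry and a warning is emitted.
uri = resolve_registry_uri("databricks", "databricks-uc", allow_workspace_fallback=True)
assert uri == "databricks"  # assumed downgrade target

# Fallback disabled: the Unity Catalog URI is forwarded unchanged.
uri = resolve_registry_uri("databricks", "databricks-uc", allow_workspace_fallback=False)
assert uri == "databricks-uc"
```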

## Preparing for a Future Unity Catalog Migration

To keep migration friction low we expose an `allow_workspace_fallback` flag everywhere MLflow clients are created.

- **Logger config** (`simplexity/configs/logging/mlflow_logger.yaml`; see the example after this list):
- Set `registry_uri: databricks-uc` once your workspace is UC-enabled.
- Flip `allow_workspace_fallback: false` to stop the automatic downgrade.
- **Programmatic use**: `MLFlowLogger(..., allow_workspace_fallback=False)` or `MLFlowPersister.from_experiment(..., allow_workspace_fallback=False)` preserves Unity Catalog URIs.
- **Environment variables**: you can still rely on `MLFLOW_TRACKING_URI` / `MLFLOW_REGISTRY_URI`. When fallback is disabled those values are forwarded unchanged.
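
For example, a UC-ready `simplexity/configs/logging/mlflow_logger.yaml` (a sketch mirroring the portion of the config visible in this diff) would change only the last two values:

```yaml
instance:
  experiment_name: /Shared/${experiment_name}
  run_name: ${run_name}
  tracking_uri: databricks
  registry_uri: databricks-uc       # switch from the workspace registry
  allow_workspace_fallback: false   # surface UC errors instead of downgrading
```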

Because the flag defaults to `True`, current jobs continue working even if a Unity Catalog URI is supplied accidentally—Simplexity automatically falls back to the workspace registry and logs a warning. When you are ready to migrate, toggling the flag allows UC usage without touching the codebase.

## Suggested Migration Checklist

1. **Enable Unity Catalog in Databricks** and make sure the MLflow registry permissions are set up (see the official Databricks migration guide).
2. **Create the Unity Catalog equivalents** of any workspace-registered models if you plan to keep history—Databricks provides automated migration jobs for this.
3. **Update configuration**:
- Set `registry_uri` (and optionally `tracking_uri`) to the appropriate `databricks-uc` endpoint.
- Set `allow_workspace_fallback: false` to surface real UC connectivity errors instead of silently downgrading.
4. **Smoke test** using `examples/mlflow_workspace_registry_demo.py` with the updated config, as shown in the command sketch after this list. The script will now run against UC and should register the demo model there.
5. **Monitor warnings**: once fallback is disabled, any remaining downgrade warnings indicate stale configs or code paths that still pass the workspace URI.
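
A sketch of a smoke-test invocation using Hydra command-line overrides; the registered model name is a hypothetical placeholder (Unity Catalog names use the three-part `catalog.schema.model` form):

```bash
# Hypothetical overrides; adjust the model name to your catalog and schema.
python examples/mlflow_workspace_registry_demo.py \
  registered_model_name=main.default.registry_demo \
  registry_uri=databricks-uc \
  allow_workspace_fallback=false
```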

## Operational Notes

- Keeping fallback enabled during the transition phase is helpful because it avoids runtime failures, but remember that models will continue to land in the workspace registry until you turn it off.
- After migration you can remove the fallback flag entirely or leave it `False` so that future regressions are caught early.
- If you need parallel workspace/UC logging (for validation) you can run two Hydra jobs with different logger configs—no application code changes are required (see the sketch below).
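
A sketch of the parallel-validation idea, assuming a Hydra entry point named `train.py` (hypothetical) and overriding the logger config's `instance` keys from the command line:

```bash
# Workspace registry (current default config).
python train.py logging=mlflow_logger

# Same job pointed at Unity Catalog via overrides (hypothetical values).
python train.py logging=mlflow_logger \
  logging.instance.registry_uri=databricks-uc \
  logging.instance.allow_workspace_fallback=false
```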
287 changes: 287 additions & 0 deletions examples/mlflow_workspace_registry_demo.py
@@ -0,0 +1,287 @@
"""Demonstrate saving and loading a PyTorch model with the MLflow workspace registry."""

from __future__ import annotations

import os
import sys
import time
import urllib.parse
from dataclasses import dataclass, field

import hydra
import mlflow
from hydra.core.config_store import ConfigStore
from mlflow.entities.model_registry import ModelVersion
from omegaconf import MISSING

from simplexity.utils.mlflow_utils import resolve_registry_uri

try:
import torch
from torch import nn
except ImportError as exc: # pragma: no cover - script guard
raise SystemExit(
"PyTorch is required for this demo. Install it with `pip install torch` "
"or add the `pytorch` extra when installing this project."
) from exc


WORKSPACE_REGISTRY_URI = "databricks"


class TinyClassifier(nn.Module):
"""A tiny classifier for testing."""

def __init__(self) -> None:
super().__init__()
self.model = nn.Sequential(
nn.Linear(4, 16),
nn.ReLU(),
nn.Linear(16, 2),
)

def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore[override]
"""Forward pass."""
return self.model(x)


@dataclass
class DemoConfig:
"""Configuration for the MLflow workspace registry demo."""

experiment: str = "WorkspaceRegistryDemo"
run_name: str | None = None
registered_model_name: str = MISSING
tracking_uri: str | None = field(default_factory=lambda: os.getenv("MLFLOW_TRACKING_URI"))
registry_uri: str | None = field(default_factory=lambda: os.getenv("MLFLOW_REGISTRY_URI", WORKSPACE_REGISTRY_URI))
artifact_path: str = "pytorch-model"
poll_interval: float = 2.0
poll_timeout: float = 300.0
databricks_host: str | None = field(default_factory=lambda: os.getenv("DATABRICKS_HOST"))
allow_workspace_fallback: bool = True


CONFIG_NAME = "mlflow_workspace_registry_demo"
LEGACY_CONFIG_NAME = "mlflow_unity_catalog_demo"

config_store = ConfigStore.instance()
config_store.store(name=CONFIG_NAME, node=DemoConfig)
config_store.store(name=LEGACY_CONFIG_NAME, node=DemoConfig)


def ensure_experiment(client: mlflow.MlflowClient, name: str) -> str:
"""Ensure an experiment exists."""
experiment = client.get_experiment_by_name(name)
if experiment:
return experiment.experiment_id
return client.create_experiment(name)


def await_model_version_ready(
client: mlflow.MlflowClient,
model_name: str,
version: str,
poll_interval: float,
poll_timeout: float,
) -> ModelVersion:
"""Wait for a model version to be ready."""
deadline = time.monotonic() + poll_timeout
while True:
current = client.get_model_version(name=model_name, version=version)
if current.status == "READY":
return current
if current.status == "FAILED":
raise RuntimeError(f"Model version {model_name}/{version} failed to register: {current.status_message}")
if time.monotonic() > deadline:
raise TimeoutError(f"Model version {model_name}/{version} did not become READY within {poll_timeout}s")
time.sleep(poll_interval)


def search_model_version_for_run(
client: mlflow.MlflowClient,
model_name: str,
run_id: str,
) -> ModelVersion:
"""Search for a model version for a run."""
versions = client.search_model_versions(f"name = '{model_name}' and run_id = '{run_id}'")
if not versions:
raise RuntimeError(
"No model versions were created for this run. Ensure the run has permission to register a model."
)
    # Pick the highest version explicitly; search_model_versions does not
    # guarantee a newest-first ordering across MLflow versions.
    return max(versions, key=lambda mv: int(mv.version))


def build_databricks_urls(
host: str | None,
experiment_id: str,
run_id: str,
model_name: str,
model_version: str,
) -> tuple[str | None, str | None]:
"""Build Databricks URLs for a model version."""
if not host:
return None, None
base = host.rstrip("/")
encoded_name = urllib.parse.quote(model_name, safe="")
run_url = f"{base}/#mlflow/experiments/{experiment_id}/runs/{run_id}"
model_url = f"{base}/#mlflow/models/{encoded_name}/versions/{model_version}"
return run_url, model_url


def run_demo(config: DemoConfig) -> None:
"""Run the MLflow workspace registry demo."""
resolved_registry_uri = resolve_registry_uri(
config.tracking_uri,
config.registry_uri,
allow_workspace_fallback=config.allow_workspace_fallback,
)
if config.tracking_uri:
mlflow.set_tracking_uri(config.tracking_uri)
if resolved_registry_uri:
mlflow.set_registry_uri(resolved_registry_uri)

client = mlflow.MlflowClient(tracking_uri=mlflow.get_tracking_uri(), registry_uri=mlflow.get_registry_uri())
experiment_id = ensure_experiment(client, config.experiment)

torch.manual_seed(7)
model = TinyClassifier()
sample_input = torch.randn(4, 4)

run_id: str = "" # Initialize to avoid "possibly unbound" error
model_version: ModelVersion | None = None # Initialize to avoid "possibly unbound" error

with mlflow.start_run(experiment_id=experiment_id, run_name=config.run_name) as run:
run_id = run.info.run_id
mlflow.log_params({"demo": "workspace_registry", "framework": "pytorch", "layers": len(list(model.modules()))})

# First log the model without registering it
mlflow.pytorch.log_model( # type: ignore[attr-defined]
model,
artifact_path=config.artifact_path,
)

# Then register the model separately
try:
client.create_registered_model(config.registered_model_name)
print(f"Created registered model: {config.registered_model_name}")
except Exception as e:
if "already exists" in str(e).lower():
print(f"Registered model {config.registered_model_name} already exists")
else:
raise

# Create model version using the model URI from the logged model
model_uri = f"runs:/{run_id}/{config.artifact_path}"
model_version = client.create_model_version(
name=config.registered_model_name,
source=model_uri,
run_id=run_id,
description="Demo model from workspace registry",
)
print(f"Created model version: {model_version.version}")

predictions = model(sample_input).detach()
mlflow.log_artifact(
_dump_tensor(predictions, "predictions.txt"),
artifact_path="artifacts",
)

# Wait for model version to be ready
if model_version is None:
raise RuntimeError("Failed to create model version")
ready_version = await_model_version_ready(
client,
config.registered_model_name,
model_version.version,
config.poll_interval,
config.poll_timeout,
)

model_uri = f"models:/{config.registered_model_name}/{ready_version.version}"
loaded_model = mlflow.pytorch.load_model(model_uri) # type: ignore[attr-defined]
restored_model = TinyClassifier()
restored_model.load_state_dict(loaded_model.state_dict())

verification_input = torch.randn(2, 4)
original_output = model(verification_input)
restored_output = restored_model(verification_input)
if not torch.allclose(original_output, restored_output, atol=1e-5):
raise RuntimeError("Loaded weights differ from the original model outputs.")

run_url, model_url = build_databricks_urls(
config.databricks_host,
experiment_id,
run_id,
config.registered_model_name,
ready_version.version,
)

info_lines = [
"MLflow workspace registry demo complete!",
f"Run ID: {run_id}",
f"Model URI: {model_uri}",
f"Model version status: {ready_version.status}",
]
if run_url:
info_lines.append(f"Run UI: {run_url}")
if model_url:
info_lines.append(f"Model UI: {model_url}")
print("\n".join(info_lines))


def _dump_tensor(tensor: torch.Tensor, filename: str) -> str:
"""Dump a tensor to a file."""
path = os.path.join(_ensure_temp_dir(), filename)
with open(path, "w", encoding="utf-8") as handle:
for row in tensor.tolist():
handle.write(",".join(f"{value:.6f}" for value in row))
handle.write("\n")
return path


_TEMP_DIR: str | None = None


def _ensure_temp_dir() -> str:
"""Ensure a temporary directory exists."""
global _TEMP_DIR
if _TEMP_DIR is None:
import tempfile

_TEMP_DIR = tempfile.mkdtemp(prefix="mlflow-workspace-demo-")
return _TEMP_DIR


def _cleanup_temp_dir() -> None:
"""Cleanup the temporary directory."""
global _TEMP_DIR
if _TEMP_DIR and os.path.isdir(_TEMP_DIR):
import shutil

shutil.rmtree(_TEMP_DIR, ignore_errors=True)
_TEMP_DIR = None


def _register_atexit() -> None:
"""Register an atexit handler to cleanup the temporary directory."""
import atexit

atexit.register(_cleanup_temp_dir)


_register_atexit()


@hydra.main(version_base="1.2", config_name=CONFIG_NAME)
def main(config: DemoConfig) -> None:
"""Main entry point for the MLflow workspace registry demo."""
try:
run_demo(config)
except (RuntimeError, TimeoutError) as error:
print(f"Error: {error}", file=sys.stderr)
sys.exit(1)


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -21,7 +21,7 @@ dependencies = [
"jax",
"jupyter",
"matplotlib",
"mlflow",
"mlflow>=3.0.0",
"optax",
"orbax-checkpoint",
"pandas",
2 changes: 2 additions & 0 deletions simplexity/configs/logging/config.py
@@ -29,6 +29,8 @@ class MLFlowLoggerConfig(LoggingInstanceConfig):
experiment_name: str
run_name: str
tracking_uri: str
+    registry_uri: str | None = None
+    allow_workspace_fallback: bool = True


@dataclass
2 changes: 2 additions & 0 deletions simplexity/configs/logging/mlflow_logger.yaml
@@ -4,3 +4,5 @@ instance:
experiment_name: /Shared/${experiment_name}
run_name: ${run_name}
tracking_uri: databricks
+  registry_uri: databricks
+  allow_workspace_fallback: true