Skip to content

Commit

Permalink
✨ Add support for loading model with alias in MlflowModelRegistryData…
Browse files Browse the repository at this point in the history
…set (#553)
  • Loading branch information
Galileo-Galilei committed Aug 29, 2024
1 parent 4714525 commit db131ab
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 11 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## [Unreleased]

### Added

- :sparkles: Add support for loading model with alias in ``MlflowModelRegistryDataset`` [#553](https://github.com/Galileo-Galilei/kedro-mlflow/issues/553)

### Changed

- :boom: :pushpin: Officially drop support for ``mlflow<1.29.0`` which was implicit since the introduction of ``km.random_name`` resolver in [#481](https://github.com/Galileo-Galilei/kedro-mlflow/issues/481) ([#571](https://github.com/Galileo-Galilei/kedro-mlflow/issues/571))
Expand Down
13 changes: 7 additions & 6 deletions docs/source/07_python_objects/01_DataSets.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ my_model:

### ``MlflowModelLocalFileSystemDataset``

The ``MlflowModelTrackingDataset`` accepts the following arguments:
The ``MlflowModelLocalFileSystemDataset`` accepts the following arguments:

- flavor (str): Built-in or custom MLflow model flavor module. Must be Python-importable.
- filepath (str): Path to store the dataset locally.
Expand Down Expand Up @@ -163,11 +163,12 @@ my_model:

The ``MlflowModelRegistryDataset`` accepts the following arguments:

- model_name (str): The name of the registered model is the mlflow registry
- stage_or_version (str): A valid stage (either "staging" or "production") or version number for the registred model.Default to "latest" which fetch the last version and the higher "stage" available.
- flavor (str): Built-in or custom MLflow model flavor module. Must be Python-importable.
- pyfunc_workflow (str, optional): Either `python_model` or `loader_module`. See [mlflow workflows](https://www.mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#workflows).
- load_args (Dict[str, Any], optional): Arguments to `load_model` function from specified `flavor`. Defaults to None.
- ``model_name`` (str): The name of the registered model is the mlflow registry
- ``stage_or_version`` (str): A valid stage (either "staging" or "production") or version number for the registred model.Default to None,(internally converted to "latest" if no alias si provided) which fetch the last version and the higher "stage" available.
- ``alias`` (str): A valid alias, which is used instead of stage to filter model since mlflow 2.9.0. Will raise an error if both ``stage_or_version`` and ``alias`` are provided.
- ``flavor`` (str): Built-in or custom MLflow model flavor module. Must be Python-importable.
- ``pyfunc_workflow`` (str, optional): Either `python_model` or `loader_module`. See [mlflow workflows](https://www.mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#workflows).
- ``load_args`` (Dict[str, Any], optional): Arguments to `load_model` function from specified `flavor`. Defaults to None.

We assume you have registered a mlflow model first, either [with the ``MlflowClient``](https://mlflow.org/docs/latest/model-registry.html#adding-an-mlflow-model-to-the-model-registry) or [within the mlflow ui](https://mlflow.org/docs/latest/model-registry.html#ui-workflow), e.g. :

Expand Down
22 changes: 20 additions & 2 deletions kedro_mlflow/io/models/mlflow_model_registry_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any, Dict, Optional, Union

from kedro.io.core import DatasetError

from kedro_mlflow.io.models.mlflow_abstract_model_dataset import (
MlflowAbstractModelDataSet,
)
Expand All @@ -11,7 +13,8 @@ class MlflowModelRegistryDataset(MlflowAbstractModelDataSet):
def __init__(
self,
model_name: str,
stage_or_version: Union[str, int] = "latest",
stage_or_version: Union[str, int, None] = None,
alias: Optional[str] = None,
flavor: Optional[str] = "mlflow.pyfunc",
pyfunc_workflow: Optional[str] = "python_model",
load_args: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -46,9 +49,23 @@ def __init__(
version=None,
)

if alias is None and stage_or_version is None:
# reassign stage_or_version to "latest"
stage_or_version = "latest"

if alias and stage_or_version:
raise DatasetError(
f"You cannot specify 'alias' and 'stage_or_version' simultaneously ({alias=} and {stage_or_version=})"
)

self.model_name = model_name
self.stage_or_version = stage_or_version
self.model_uri = f"models:/{model_name}/{stage_or_version}"
self.alias = alias
self.model_uri = (
f"models:/{model_name}@{alias}"
if alias
else f"models:/{model_name}/{stage_or_version}"
)

def _load(self) -> Any:
"""Loads an MLflow model from local path or from MLflow run.
Expand All @@ -74,6 +91,7 @@ def _describe(self) -> Dict[str, Any]:
model_uri=self.model_uri,
model_name=self.model_name,
stage_or_version=self.stage_or_version,
alias=self.alias,
flavor=self._flavor,
pyfunc_workflow=self._pyfunc_workflow,
# load_args=self._load_args,
Expand Down
61 changes: 58 additions & 3 deletions tests/io/models/test_mlflow_model_registry_dataset.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import re

import mlflow
import pytest
from kedro.io.core import DatasetError
from mlflow import MlflowClient
from mlflow import __version__ as mlflow_version
from sklearn.tree import DecisionTreeClassifier

from kedro_mlflow.io.models import MlflowModelRegistryDataset

MLFLOW_VERSION_TUPLE = tuple(
int(x) for x in re.findall("([0-9]+)\.([0-9]+)\.([0-9]+)", mlflow_version)[0]
)


def test_mlflow_model_registry_save_not_implemented(tmp_path):
ml_ds = MlflowModelRegistryDataset(model_name="demo_model")
Expand All @@ -16,14 +23,25 @@ def test_mlflow_model_registry_save_not_implemented(tmp_path):
ml_ds.save(DecisionTreeClassifier())


def test_mlflow_model_registry_alias_and_stage_or_version_fails(tmp_path):
with pytest.raises(
DatasetError,
match=r"You cannot specify 'alias' and 'stage_or_version' simultaneously",
):
MlflowModelRegistryDataset(
model_name="demo_model", alias="my_alias", stage_or_version="my_stage"
)


def test_mlflow_model_registry_load_given_stage_or_version(tmp_path, monkeypatch):
# we must change the working directory because when
# using mlflow with a local database tracking, the artifacts
# are stored in a relative mlruns/ folder so we need to have
# the same working directory that the one of the tracking uri
monkeypatch.chdir(tmp_path)
tracking_uri = r"sqlite:///" + (tmp_path / "mlruns3.db").as_posix()
mlflow.set_tracking_uri(tracking_uri)
tracking_and_registry_uri = r"sqlite:///" + (tmp_path / "mlruns3.db").as_posix()
mlflow.set_tracking_uri(tracking_and_registry_uri)
mlflow.set_registry_uri(tracking_and_registry_uri)

# setup: we train 3 version of a model under a single
# registered model and stage the 2nd one
Expand All @@ -36,7 +54,9 @@ def test_mlflow_model_registry_load_given_stage_or_version(tmp_path, monkeypatch
)
runs[i + 1] = mlflow.active_run().info.run_id

client = MlflowClient(tracking_uri=tracking_uri)
client = MlflowClient(
tracking_uri=tracking_and_registry_uri, registry_uri=tracking_and_registry_uri
)
client.transition_model_version_stage(name="demo_model", version=2, stage="Staging")

# case 1: no version is provided, we take the last one
Expand All @@ -55,3 +75,38 @@ def test_mlflow_model_registry_load_given_stage_or_version(tmp_path, monkeypatch
ml_ds = MlflowModelRegistryDataset(model_name="demo_model", stage_or_version="1")
loaded_model = ml_ds.load()
assert loaded_model.metadata.run_id == runs[1]


@pytest.mark.skipif(
MLFLOW_VERSION_TUPLE < (2, 9, 0), reason="Requires mlflow 2.9.0 or higher"
)
def test_mlflow_model_registry_load_given_alias(tmp_path, monkeypatch):
# we must change the working directory because when
# using mlflow with a local database tracking, the artifacts
# are stored in a relative mlruns/ folder so we need to have
# the same working directory that the one of the tracking uri
monkeypatch.chdir(tmp_path)
tracking_and_registry_uri = r"sqlite:///" + (tmp_path / "mlruns4.db").as_posix()
mlflow.set_tracking_uri(tracking_and_registry_uri)
mlflow.set_registry_uri(tracking_and_registry_uri)

# setup: we train 3 version of a model under a single
# registered model and alias the 2nd one
runs = {}
for i in range(2):
with mlflow.start_run():
model = DecisionTreeClassifier()
mlflow.sklearn.log_model(
model, artifact_path="demo_model", registered_model_name="demo_model"
)
runs[i + 1] = mlflow.active_run().info.run_id

client = MlflowClient(
tracking_uri=tracking_and_registry_uri, registry_uri=tracking_and_registry_uri
)
client.set_registered_model_alias(name="demo_model", alias="champion", version=1)

# case 2: an alias is provided, we take the last model with this stage
ml_ds = MlflowModelRegistryDataset(model_name="demo_model", alias="champion")
loaded_model = ml_ds.load()
assert loaded_model.metadata.run_id == runs[1]

0 comments on commit db131ab

Please sign in to comment.