From 14c590b0026ae01ef5b8ca64e60414cf49abec1e Mon Sep 17 00:00:00 2001 From: Safoine El khabich Date: Thu, 1 Aug 2024 14:02:14 +0100 Subject: [PATCH 1/2] Enable Unity Catalog for MLFlow experiment tracking on Databricks --- .../mlflow/experiment_trackers/mlflow_experiment_tracker.py | 4 +++- .../mlflow/flavors/mlflow_experiment_tracker_flavor.py | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py b/src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py index ab88848ed60..2a8cc91f582 100644 --- a/src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py +++ b/src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py @@ -57,6 +57,7 @@ DATABRICKS_USERNAME = "DATABRICKS_USERNAME" DATABRICKS_PASSWORD = "DATABRICKS_PASSWORD" DATABRICKS_TOKEN = "DATABRICKS_TOKEN" +DATABRICKS_UNITY_CATALOG = "databricks-uc" class MLFlowExperimentTracker(BaseExperimentTracker): @@ -285,7 +286,6 @@ def configure_mlflow(self) -> None: """Configures the MLflow tracking URI and any additional credentials.""" tracking_uri = self.get_tracking_uri() mlflow.set_tracking_uri(tracking_uri) - mlflow.set_registry_uri(tracking_uri) if is_databricks_tracking_uri(tracking_uri): if self.config.databricks_host: @@ -296,6 +296,8 @@ def configure_mlflow(self) -> None: os.environ[DATABRICKS_PASSWORD] = self.config.tracking_password if self.config.tracking_token: os.environ[DATABRICKS_TOKEN] = self.config.tracking_token + if self.config.enable_unity_catalog: + mlflow.set_registry_uri(DATABRICKS_UNITY_CATALOG) else: os.environ[MLFLOW_TRACKING_URI] = tracking_uri if self.config.tracking_username: diff --git a/src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py b/src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py index cb1952bf469..e103b221fb7 100644 --- a/src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py +++ b/src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py @@ -98,6 +98,8 @@ class MLFlowExperimentTrackerConfig( databricks_host: The host of the Databricks workspace with the MLflow managed server to connect to. This is only required if `tracking_uri` value is set to `"databricks"`. + enable_unity_catalog: If `True`, will enable the Unity Catalog for + logging and registering models. """ tracking_uri: Optional[str] = None @@ -106,6 +108,7 @@ class MLFlowExperimentTrackerConfig( tracking_token: Optional[str] = SecretField(default=None) tracking_insecure_tls: bool = False databricks_host: Optional[str] = None + enable_unity_catalog: bool = False @model_validator(mode="after") def _ensure_authentication_if_necessary( From 08ccd8f28cf24be5429aee459fa0da1cd570877f Mon Sep 17 00:00:00 2001 From: Safoine El Khabich <34200873+safoinme@users.noreply.github.com> Date: Fri, 2 Aug 2024 12:49:50 +0100 Subject: [PATCH 2/2] Update src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py Co-authored-by: Alex Strick van Linschoten --- .../mlflow/flavors/mlflow_experiment_tracker_flavor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py b/src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py index e103b221fb7..368820199b3 100644 --- a/src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py +++ b/src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py @@ -98,7 +98,7 @@ class MLFlowExperimentTrackerConfig( databricks_host: The host of the Databricks workspace with the MLflow managed server to connect to. This is only required if `tracking_uri` value is set to `"databricks"`. - enable_unity_catalog: If `True`, will enable the Unity Catalog for + enable_unity_catalog: If `True`, will enable the Databricks Unity Catalog for logging and registering models. """