diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/ray.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/ray.py index 5aede4a0465e7..76d94299d4967 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/ray.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/ray.py @@ -20,19 +20,10 @@ from __future__ import annotations import dataclasses +from collections.abc import MutableMapping from typing import Any -from airflow.exceptions import AirflowOptionalProviderFeatureException - -try: - import vertex_ray - from google._upb._message import ScalarMapContainer # type: ignore[attr-defined] -except ImportError: - # Fallback for environments where the upb module is not available. - raise AirflowOptionalProviderFeatureException( - "google._upb._message.ScalarMapContainer is not available. " - "Please install the ray package to use this feature." - ) +import vertex_ray from google.cloud import aiplatform from google.cloud.aiplatform.vertex_ray.util import resources from google.cloud.aiplatform_v1 import ( @@ -59,7 +50,7 @@ def serialize_cluster_obj(self, cluster_obj: resources.Cluster) -> dict: def __encode_value(value: Any) -> Any: if isinstance(value, (list, Repeated)): return [__encode_value(nested_value) for nested_value in value] - if isinstance(value, ScalarMapContainer): + if not isinstance(value, dict) and isinstance(value, MutableMapping): return {key: __encode_value(nested_value) for key, nested_value in dict(value).items()} if dataclasses.is_dataclass(value): return dataclasses.asdict(value) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py index e06c18ea1c42a..95a284c1defc8 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py @@ -282,7 +282,7 @@ def execute(self, context: Context): location=self.location, cluster_id=self.cluster_id, ) - self.log.info("Cluster was gotten.") + self.log.info("Cluster data has been retrieved.") ray_cluster_dict = self.hook.serialize_cluster_obj(ray_cluster) return ray_cluster_dict except NotFound as not_found_err: diff --git a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_ray.py b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_ray.py index 34feb72e82182..20e5e8829d662 100644 --- a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_ray.py +++ b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_ray.py @@ -19,9 +19,7 @@ from unittest import mock -import pytest - -ScalarMapContainer = pytest.importorskip("google._upb._message.ScalarMapContainer") +from google.cloud.aiplatform.vertex_ray.util.resources import Cluster, Resources from airflow.providers.google.cloud.hooks.vertex_ai.ray import RayHook @@ -168,6 +166,55 @@ def test_list_ray_clusters(self, mock_aiplatform_init, mock_list_ray_clusters) - mock_aiplatform_init.assert_called_once() mock_list_ray_clusters.assert_called_once() + @mock.patch(RAY_STRING.format("aiplatform.init")) + def test_serialize_cluster_obj(self, mock_aiplatform_init) -> None: + RESOURCE_SAMPLE = { + "accelerator_count": 0, + "accelerator_type": None, + "autoscaling_spec": None, + "boot_disk_size_gb": 100, + "boot_disk_type": "pd-ssd", + "custom_image": None, + "machine_type": "n1-standard-16", + "node_count": 1, + } + SAMPLE_CLUSTER_SERIALIZED = { + "cluster_resource_name": TEST_CLUSTER_NAME, + "dashboard_address": "dashboard_addr", + "head_node_type": RESOURCE_SAMPLE, + "labels": {"label1": "val1"}, + "network": "custom_network", + "psc_interface_config": None, + "python_version": TEST_PYTHON_VERSION, + "ray_logs_enabled": True, + "ray_metric_enabled": True, + "ray_version": TEST_RAY_VERSION, + "reserved_ip_ranges": [ + "172.16.0.0/16", + "10.10.10.0/28", + ], + "service_account": None, + "state": "RUNNING", + "worker_node_types": [RESOURCE_SAMPLE, RESOURCE_SAMPLE], + } + cluster_obj = Cluster( + cluster_resource_name=TEST_CLUSTER_NAME, + state="RUNNING", # type: ignore[arg-type] + network="custom_network", + reserved_ip_ranges=["172.16.0.0/16", "10.10.10.0/28"], + python_version=TEST_PYTHON_VERSION, + ray_version=TEST_RAY_VERSION, + head_node_type=Resources(**RESOURCE_SAMPLE), # type: ignore[arg-type] + worker_node_types=[ + Resources(**RESOURCE_SAMPLE), # type: ignore[arg-type] + Resources(**RESOURCE_SAMPLE), # type: ignore[arg-type] + ], + dashboard_address="dashboard_addr", + labels={"label1": "val1"}, + ) + + assert self.hook.serialize_cluster_obj(cluster_obj) == SAMPLE_CLUSTER_SERIALIZED + class TestRayWithoutDefaultProjectIdHook: def setup_method(self):