Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,10 @@
from __future__ import annotations

import dataclasses
from collections.abc import MutableMapping
from typing import Any

from airflow.exceptions import AirflowOptionalProviderFeatureException

try:
import vertex_ray
from google._upb._message import ScalarMapContainer # type: ignore[attr-defined]
except ImportError:
# Fallback for environments where the upb module is not available.
raise AirflowOptionalProviderFeatureException(
"google._upb._message.ScalarMapContainer is not available. "
"Please install the ray package to use this feature."
)
import vertex_ray
from google.cloud import aiplatform
from google.cloud.aiplatform.vertex_ray.util import resources
from google.cloud.aiplatform_v1 import (
Expand All @@ -59,7 +50,7 @@ def serialize_cluster_obj(self, cluster_obj: resources.Cluster) -> dict:
def __encode_value(value: Any) -> Any:
if isinstance(value, (list, Repeated)):
return [__encode_value(nested_value) for nested_value in value]
if isinstance(value, ScalarMapContainer):
if not isinstance(value, dict) and isinstance(value, MutableMapping):
return {key: __encode_value(nested_value) for key, nested_value in dict(value).items()}
if dataclasses.is_dataclass(value):
return dataclasses.asdict(value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def execute(self, context: Context):
location=self.location,
cluster_id=self.cluster_id,
)
self.log.info("Cluster was gotten.")
self.log.info("Cluster data has been retrieved.")
ray_cluster_dict = self.hook.serialize_cluster_obj(ray_cluster)
return ray_cluster_dict
except NotFound as not_found_err:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@

from unittest import mock

import pytest

ScalarMapContainer = pytest.importorskip("google._upb._message.ScalarMapContainer")
from google.cloud.aiplatform.vertex_ray.util.resources import Cluster, Resources

from airflow.providers.google.cloud.hooks.vertex_ai.ray import RayHook

Expand Down Expand Up @@ -168,6 +166,55 @@ def test_list_ray_clusters(self, mock_aiplatform_init, mock_list_ray_clusters) -
mock_aiplatform_init.assert_called_once()
mock_list_ray_clusters.assert_called_once()

@mock.patch(RAY_STRING.format("aiplatform.init"))
def test_serialize_cluster_obj(self, mock_aiplatform_init) -> None:
RESOURCE_SAMPLE = {
"accelerator_count": 0,
"accelerator_type": None,
"autoscaling_spec": None,
"boot_disk_size_gb": 100,
"boot_disk_type": "pd-ssd",
"custom_image": None,
"machine_type": "n1-standard-16",
"node_count": 1,
}
SAMPLE_CLUSTER_SERIALIZED = {
"cluster_resource_name": TEST_CLUSTER_NAME,
"dashboard_address": "dashboard_addr",
"head_node_type": RESOURCE_SAMPLE,
"labels": {"label1": "val1"},
"network": "custom_network",
"psc_interface_config": None,
"python_version": TEST_PYTHON_VERSION,
"ray_logs_enabled": True,
"ray_metric_enabled": True,
"ray_version": TEST_RAY_VERSION,
"reserved_ip_ranges": [
"172.16.0.0/16",
"10.10.10.0/28",
],
"service_account": None,
"state": "RUNNING",
"worker_node_types": [RESOURCE_SAMPLE, RESOURCE_SAMPLE],
}
cluster_obj = Cluster(
cluster_resource_name=TEST_CLUSTER_NAME,
state="RUNNING", # type: ignore[arg-type]
network="custom_network",
reserved_ip_ranges=["172.16.0.0/16", "10.10.10.0/28"],
python_version=TEST_PYTHON_VERSION,
ray_version=TEST_RAY_VERSION,
head_node_type=Resources(**RESOURCE_SAMPLE), # type: ignore[arg-type]
worker_node_types=[
Resources(**RESOURCE_SAMPLE), # type: ignore[arg-type]
Resources(**RESOURCE_SAMPLE), # type: ignore[arg-type]
],
dashboard_address="dashboard_addr",
labels={"label1": "val1"},
)

assert self.hook.serialize_cluster_obj(cluster_obj) == SAMPLE_CLUSTER_SERIALIZED


class TestRayWithoutDefaultProjectIdHook:
def setup_method(self):
Expand Down