From 16022e008cc2eef122c5cd254cdb52d2c1b7e026 Mon Sep 17 00:00:00 2001 From: Christian Yarros Date: Mon, 16 Dec 2024 18:13:13 +0000 Subject: [PATCH] Add Google Vertex AI Feature Store - Feature View Sync Operators, Sensor (#44891) --- .../operators/cloud/vertex_ai.rst | 32 ++++ generated/provider_dependencies.json | 2 +- .../cloud/hooks/vertex_ai/feature_store.py | 147 ++++++++++++++++ .../operators/vertex_ai/feature_store.py | 163 ++++++++++++++++++ .../cloud/sensors/vertex_ai/__init__.py | 16 ++ .../cloud/sensors/vertex_ai/feature_store.py | 112 ++++++++++++ .../airflow/providers/google/provider.yaml | 7 +- .../hooks/vertex_ai/test_feature_store.py | 112 ++++++++++++ .../operators/vertex_ai/test_feature_store.py | 119 +++++++++++++ .../google/cloud/sensors/test_vertex_ai.py | 148 ++++++++++++++++ .../example_vertex_ai_feature_store.py | 90 ++++++++++ 11 files changed, 946 insertions(+), 2 deletions(-) create mode 100644 providers/src/airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py create mode 100644 providers/src/airflow/providers/google/cloud/operators/vertex_ai/feature_store.py create mode 100644 providers/src/airflow/providers/google/cloud/sensors/vertex_ai/__init__.py create mode 100644 providers/src/airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py create mode 100644 providers/tests/google/cloud/hooks/vertex_ai/test_feature_store.py create mode 100644 providers/tests/google/cloud/operators/vertex_ai/test_feature_store.py create mode 100644 providers/tests/google/cloud/sensors/test_vertex_ai.py create mode 100644 providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_feature_store.py diff --git a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst index 173b23dfa3002..12b86d25d9196 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst @@ 
-644,6 +644,38 @@ The operator returns the cached content response in :ref:`XCom ` :start-after: [START how_to_cloud_vertex_ai_generate_from_cached_content_operator] :end-before: [END how_to_cloud_vertex_ai_generate_from_cached_content_operator] +Interacting with Vertex AI Feature Store +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To get a feature view sync job you can use +:class:`~airflow.providers.google.cloud.operators.vertex_ai.feature_store.GetFeatureViewSyncOperator`. +The operator returns sync job results in :ref:`XCom <concepts:xcom>` under ``return_value`` key. + +.. exampleinclude:: /../../providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_feature_store.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_feature_store_get_feature_view_sync_operator] + :end-before: [END how_to_cloud_vertex_ai_feature_store_get_feature_view_sync_operator] + +To sync a feature view you can use +:class:`~airflow.providers.google.cloud.operators.vertex_ai.feature_store.SyncFeatureViewOperator`. +The operator returns the sync job name in :ref:`XCom <concepts:xcom>` under ``return_value`` key. + +.. exampleinclude:: /../../providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_feature_store.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_feature_store_sync_feature_view_operator] + :end-before: [END how_to_cloud_vertex_ai_feature_store_sync_feature_view_operator] + +To check if Feature View Sync succeeded you can use +:class:`~airflow.providers.google.cloud.sensors.vertex_ai.feature_store.FeatureViewSyncSensor`. + +.. 
exampleinclude:: /../../providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_feature_store.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_feature_store_feature_view_sync_sensor] + :end-before: [END how_to_cloud_vertex_ai_feature_store_feature_view_sync_sensor] + Reference ^^^^^^^^^ diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 34cf23b9798b9..7a9307ade337f 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -646,7 +646,7 @@ "google-api-python-client>=2.0.2", "google-auth-httplib2>=0.0.1", "google-auth>=2.29.0", - "google-cloud-aiplatform>=1.70.0", + "google-cloud-aiplatform>=1.73.0", "google-cloud-automl>=2.12.0", "google-cloud-batch>=0.13.0", "google-cloud-bigquery-datatransfer>=3.13.0", diff --git a/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py b/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py new file mode 100644 index 0000000000000..69c7b69f8dad5 --- /dev/null +++ b/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""This module contains a Google Cloud Vertex AI Feature Store hook.""" + +from __future__ import annotations + +from google.api_core.client_options import ClientOptions +from google.cloud.aiplatform_v1beta1 import ( + FeatureOnlineStoreAdminServiceClient, +) + +from airflow.exceptions import AirflowException +from airflow.providers.google.common.consts import CLIENT_INFO +from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook + + +class FeatureStoreHook(GoogleBaseHook): + """ + Hook for interacting with Google Cloud Vertex AI Feature Store. + + This hook provides an interface to manage Feature Store resources in Vertex AI, + including feature views and their synchronization operations. It handles authentication + and provides methods for common Feature Store operations. + + :param gcp_conn_id: The connection ID to use for connecting to Google Cloud Platform. + Defaults to 'google_cloud_default'. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials. Can be either a single account or a chain of accounts required to + get the access_token of the last account in the list, which will be impersonated + in the request. If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. If set as a sequence, the identities + from the list must grant Service Account Token Creator IAM role to the directly + preceding identity, with first account from the list granting this role to the + originating account. + """ + + def get_feature_online_store_admin_service_client( + self, + location: str | None = None, + ) -> FeatureOnlineStoreAdminServiceClient: + """ + Create and returns a FeatureOnlineStoreAdminServiceClient object. + + This method initializes a client for interacting with the Feature Store API, + handling proper endpoint configuration based on the specified location. + + :param location: Optional. 
The Google Cloud region where the service is located. + If provided and not 'global', the client will be configured to use the + region-specific API endpoint. + """ + if location and location != "global": + client_options = ClientOptions(api_endpoint=f"{location}-aiplatform.googleapis.com:443") + else: + client_options = ClientOptions() + return FeatureOnlineStoreAdminServiceClient( + credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options + ) + + def get_feature_view_sync( + self, + location: str, + feature_view_sync_name: str, + ) -> dict: + """ + Retrieve the status and details of a Feature View synchronization operation. + + This method fetches information about a specific feature view sync operation, + including its current status, timing information, and synchronization metrics. + + :param location: The Google Cloud region where the feature store is located + (e.g., 'us-central1', 'us-east1'). + :param feature_view_sync_name: The full resource name of the feature view + sync operation to retrieve. 
+ """ + client = self.get_feature_online_store_admin_service_client(location) + + try: + response = client.get_feature_view_sync(name=feature_view_sync_name) + + report = { + "name": feature_view_sync_name, + "start_time": int(response.run_time.start_time.seconds), + } + + if hasattr(response.run_time, "end_time") and response.run_time.end_time.seconds: + report["end_time"] = int(response.run_time.end_time.seconds) + report["sync_summary"] = { + "row_synced": int(response.sync_summary.row_synced), + "total_slot": int(response.sync_summary.total_slot), + } + + return report + + except Exception as e: + self.log.error("Failed to get feature view sync: %s", str(e)) + raise AirflowException(str(e)) + + @GoogleBaseHook.fallback_to_default_project_id + def sync_feature_view( + self, + location: str, + feature_online_store_id: str, + feature_view_id: str, + project_id: str = PROVIDE_PROJECT_ID, + ) -> str: + """ + Initiate a synchronization operation for a Feature View. + + This method triggers a sync operation that updates the online serving data + for a feature view based on the latest data in the underlying batch source. + The sync operation ensures that the online feature values are up-to-date + for real-time serving. + + :param location: The Google Cloud region where the feature store is located + (e.g., 'us-central1', 'us-east1'). + :param feature_online_store_id: The ID of the online feature store that + contains the feature view to be synchronized. + :param feature_view_id: The ID of the feature view to synchronize. + :param project_id: The ID of the Google Cloud project that contains the + feature store. If not provided, will attempt to determine from the + environment. 
+ """ + client = self.get_feature_online_store_admin_service_client(location) + feature_view = f"projects/{project_id}/locations/{location}/featureOnlineStores/{feature_online_store_id}/featureViews/{feature_view_id}" + + try: + response = client.sync_feature_view(feature_view=feature_view) + + return str(response.feature_view_sync) + + except Exception as e: + self.log.error("Failed to sync feature view: %s", str(e)) + raise AirflowException(str(e)) diff --git a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/feature_store.py b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/feature_store.py new file mode 100644 index 0000000000000..318ff25a3a915 --- /dev/null +++ b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/feature_store.py @@ -0,0 +1,163 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""This module contains Google Vertex AI Feature Store operators.""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from airflow.providers.google.cloud.hooks.vertex_ai.feature_store import FeatureStoreHook +from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class SyncFeatureViewOperator(GoogleCloudBaseOperator): + """ + Initiate a synchronization operation for a Feature View in Vertex AI Feature Store. + + This operator triggers a sync operation that updates the online serving data for a feature view + based on the latest data in the underlying batch source. The sync operation ensures that + the online feature values are up-to-date for real-time serving. + + :param project_id: Required. The ID of the Google Cloud project that contains the feature store. + This is used to identify which project's resources to interact with. + :param location: Required. The location of the feature store (e.g., 'us-central1', 'us-east1'). + This specifies the Google Cloud region where the feature store resources are located. + :param feature_online_store_id: Required. The ID of the online feature store that contains + the feature view to be synchronized. This store serves as the online serving layer. + :param feature_view_id: Required. The ID of the feature view to synchronize. This identifies + the specific view that needs to have its online values updated from the batch source. + :param gcp_conn_id: The connection ID to use for connecting to Google Cloud Platform. + Defaults to 'google_cloud_default'. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials. Can be either a single account or a chain of accounts required to + get the access_token of the last account in the list, which will be impersonated + in the request. 
If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. If set as a sequence, the identities + from the list must grant Service Account Token Creator IAM role to the directly + preceding identity, with first account from the list granting this role to the + originating account. + """ + + template_fields: Sequence[str] = ( + "project_id", + "location", + "feature_online_store_id", + "feature_view_id", + ) + + def __init__( + self, + *, + project_id: str, + location: str, + feature_online_store_id: str, + feature_view_id: str, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.feature_online_store_id = feature_online_store_id + self.feature_view_id = feature_view_id + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context) -> str: + """Execute the feature view sync operation.""" + self.hook = FeatureStoreHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Submitting Feature View sync job now...") + response = self.hook.sync_feature_view( + project_id=self.project_id, + location=self.location, + feature_online_store_id=self.feature_online_store_id, + feature_view_id=self.feature_view_id, + ) + self.log.info("Retrieved Feature View sync: %s", response) + + return response + + +class GetFeatureViewSyncOperator(GoogleCloudBaseOperator): + """ + Retrieve the status and details of a Feature View synchronization operation. + + This operator fetches information about a specific feature view sync operation, + including its current status, timing information, and synchronization metrics. + It's typically used to monitor the progress of a sync operation initiated by + the SyncFeatureViewOperator. + + :param location: Required. 
The location of the feature store (e.g., 'us-central1', 'us-east1'). + This specifies the Google Cloud region where the feature store resources are located. + :param feature_view_sync_name: Required. The full resource name of the feature view + sync operation to retrieve. This is typically the return value from a + SyncFeatureViewOperator execution. + :param gcp_conn_id: The connection ID to use for connecting to Google Cloud Platform. + Defaults to 'google_cloud_default'. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials. Can be either a single account or a chain of accounts required to + get the access_token of the last account in the list, which will be impersonated + in the request. If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. If set as a sequence, the identities + from the list must grant Service Account Token Creator IAM role to the directly + preceding identity, with first account from the list granting this role to the + originating account. 
+ """ + + template_fields: Sequence[str] = ( + "location", + "feature_view_sync_name", + ) + + def __init__( + self, + *, + location: str, + feature_view_sync_name: str, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.feature_view_sync_name = feature_view_sync_name + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context) -> dict[str, Any]: + """Execute the get feature view sync operation.""" + self.hook = FeatureStoreHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Retrieving Feature View sync job now...") + response = self.hook.get_feature_view_sync( + location=self.location, feature_view_sync_name=self.feature_view_sync_name + ) + self.log.info("Retrieved Feature View sync: %s", self.feature_view_sync_name) + self.log.info(response) + + return response diff --git a/providers/src/airflow/providers/google/cloud/sensors/vertex_ai/__init__.py b/providers/src/airflow/providers/google/cloud/sensors/vertex_ai/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/src/airflow/providers/google/cloud/sensors/vertex_ai/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/src/airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py b/providers/src/airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py new file mode 100644 index 0000000000000..88503db2e0cb6 --- /dev/null +++ b/providers/src/airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""This module contains a Vertex AI Feature Store sensor.""" + +from __future__ import annotations + +import time +from collections.abc import Sequence +from typing import TYPE_CHECKING + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.vertex_ai.feature_store import FeatureStoreHook +from airflow.sensors.base import BaseSensorOperator + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class FeatureViewSyncSensor(BaseSensorOperator): + """ + Sensor to monitor the state of a Vertex AI Feature View sync operation. + + :param feature_view_sync_name: The name of the feature view sync operation to monitor. (templated) + :param location: Required. The Cloud region in which to handle the request. (templated) + :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. + :param wait_timeout: How many seconds to wait for sync to complete. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials. 
+ """ + + template_fields: Sequence[str] = ("location", "feature_view_sync_name") + ui_color = "#f0eee4" + + def __init__( + self, + *, + feature_view_sync_name: str, + location: str, + gcp_conn_id: str = "google_cloud_default", + wait_timeout: int | None = None, + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.feature_view_sync_name = feature_view_sync_name + self.location = location + self.gcp_conn_id = gcp_conn_id + self.wait_timeout = wait_timeout + self.impersonation_chain = impersonation_chain + self.start_sensor_time: float | None = None + + def execute(self, context: Context) -> None: + self.start_sensor_time = time.monotonic() + super().execute(context) + + def _duration(self): + return time.monotonic() - self.start_sensor_time + + def poke(self, context: Context) -> bool: + hook = FeatureStoreHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + try: + response = hook.get_feature_view_sync( + location=self.location, + feature_view_sync_name=self.feature_view_sync_name, + ) + + # Check if the sync has completed by verifying end_time exists + if response.get("end_time", 0) > 0: + self.log.info( + "Feature View sync %s completed. 
Rows synced: %d, Total slots: %d", + self.feature_view_sync_name, + int(response.get("sync_summary", "").get("row_synced", "")), + int(response.get("sync_summary", "").get("total_slot", "")), + ) + return True + + if self.wait_timeout and self._duration() > self.wait_timeout: + raise AirflowException( + f"Timeout: Feature View sync {self.feature_view_sync_name} " + f"not completed after {self.wait_timeout}s" + ) + + self.log.info("Waiting for Feature View sync %s to complete.", self.feature_view_sync_name) + return False + + except Exception as e: + if self.wait_timeout and self._duration() > self.wait_timeout: + raise AirflowException( + f"Timeout: Feature View sync {self.feature_view_sync_name} " + f"not completed after {self.wait_timeout}s" + ) + self.log.info("Error checking sync status, will retry: %s", str(e)) + return False diff --git a/providers/src/airflow/providers/google/provider.yaml b/providers/src/airflow/providers/google/provider.yaml index a869560a9450f..442e1cecccce9 100644 --- a/providers/src/airflow/providers/google/provider.yaml +++ b/providers/src/airflow/providers/google/provider.yaml @@ -118,7 +118,7 @@ dependencies: - google-api-python-client>=2.0.2 - google-auth>=2.29.0 - google-auth-httplib2>=0.0.1 - - google-cloud-aiplatform>=1.70.0 + - google-cloud-aiplatform>=1.73.0 - google-cloud-automl>=2.12.0 # Excluded versions contain bug https://github.com/apache/airflow/issues/39541 which is resolved in 3.24.0 - google-cloud-bigquery>=3.4.0,!=3.21.*,!=3.22.0,!=3.23.* @@ -693,6 +693,7 @@ operators: - airflow.providers.google.cloud.operators.vertex_ai.model_service - airflow.providers.google.cloud.operators.vertex_ai.pipeline_job - airflow.providers.google.cloud.operators.vertex_ai.generative_model + - airflow.providers.google.cloud.operators.vertex_ai.feature_store - integration-name: Google Looker python-modules: - airflow.providers.google.cloud.operators.looker @@ -743,6 +744,9 @@ sensors: - integration-name: Google Cloud Pub/Sub python-modules: 
- airflow.providers.google.cloud.sensors.pubsub + - integration-name: Google Vertex AI + python-modules: + - airflow.providers.google.cloud.sensors.vertex_ai.feature_store - integration-name: Google Cloud Workflows python-modules: - airflow.providers.google.cloud.sensors.workflows @@ -963,6 +967,7 @@ hooks: - airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job - airflow.providers.google.cloud.hooks.vertex_ai.generative_model - airflow.providers.google.cloud.hooks.vertex_ai.prediction_service + - airflow.providers.google.cloud.hooks.vertex_ai.feature_store - integration-name: Google Looker python-modules: - airflow.providers.google.cloud.hooks.looker diff --git a/providers/tests/google/cloud/hooks/vertex_ai/test_feature_store.py b/providers/tests/google/cloud/hooks/vertex_ai/test_feature_store.py new file mode 100644 index 0000000000000..aff2a9612bf0a --- /dev/null +++ b/providers/tests/google/cloud/hooks/vertex_ai/test_feature_store.py @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from unittest import mock + +from airflow.providers.google.cloud.hooks.vertex_ai.feature_store import FeatureStoreHook + +from providers.tests.google.cloud.utils.base_gcp_mock import ( + mock_base_gcp_hook_default_project_id, +) + +BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}" +FEATURE_STORE_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.feature_store.{}" + +TEST_GCP_CONN_ID = "test-gcp-conn-id" +TEST_PROJECT_ID = "test-project" +TEST_LOCATION = "us-central1" +TEST_FEATURE_ONLINE_STORE_ID = "test-store" +TEST_FEATURE_VIEW_ID = "test-view" +TEST_FEATURE_VIEW = f"projects/{TEST_PROJECT_ID}/locations/{TEST_LOCATION}/featureOnlineStores/{TEST_FEATURE_ONLINE_STORE_ID}/featureViews/{TEST_FEATURE_VIEW_ID}" +TEST_FEATURE_VIEW_SYNC_NAME = f"{TEST_FEATURE_VIEW}/featureViewSyncs/sync123" + + +class TestFeatureStoreHook: + def setup_method(self): + with mock.patch( + BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id + ): + self.hook = FeatureStoreHook(gcp_conn_id=TEST_GCP_CONN_ID) + + @mock.patch(FEATURE_STORE_STRING.format("FeatureOnlineStoreAdminServiceClient"), autospec=True) + @mock.patch(BASE_STRING.format("GoogleBaseHook.get_credentials")) + def test_get_feature_online_store_admin_service_client(self, mock_get_credentials, mock_client): + self.hook.get_feature_online_store_admin_service_client(location=TEST_LOCATION) + mock_client.assert_called_once_with( + credentials=mock_get_credentials.return_value, client_info=mock.ANY, client_options=mock.ANY + ) + client_options = mock_client.call_args[1]["client_options"] + assert client_options.api_endpoint == f"{TEST_LOCATION}-aiplatform.googleapis.com:443" + + mock_client.reset_mock() + self.hook.get_feature_online_store_admin_service_client() + mock_client.assert_called_once_with( + credentials=mock_get_credentials.return_value, client_info=mock.ANY, client_options=mock.ANY + ) + client_options = 
mock_client.call_args[1]["client_options"] + assert not client_options.api_endpoint + + @mock.patch(FEATURE_STORE_STRING.format("FeatureStoreHook.get_feature_online_store_admin_service_client")) + def test_get_feature_view_sync(self, mock_client_getter): + mock_client = mock.MagicMock() + mock_client_getter.return_value = mock_client + + # Create a mock response with the expected structure + mock_response = mock.MagicMock() + mock_response.run_time.start_time.seconds = 1 + mock_response.run_time.end_time.seconds = 1 + mock_response.sync_summary.row_synced = 1 + mock_response.sync_summary.total_slot = 1 + + mock_client.get_feature_view_sync.return_value = mock_response + + expected_result = { + "name": TEST_FEATURE_VIEW_SYNC_NAME, + "start_time": 1, + "end_time": 1, + "sync_summary": {"row_synced": 1, "total_slot": 1}, + } + + result = self.hook.get_feature_view_sync( + location=TEST_LOCATION, + feature_view_sync_name=TEST_FEATURE_VIEW_SYNC_NAME, + ) + + mock_client.get_feature_view_sync.assert_called_once_with(name=TEST_FEATURE_VIEW_SYNC_NAME) + assert result == expected_result + + @mock.patch(FEATURE_STORE_STRING.format("FeatureStoreHook.get_feature_online_store_admin_service_client")) + def test_sync_feature_view(self, mock_client_getter): + mock_client = mock.MagicMock() + mock_client_getter.return_value = mock_client + + # Create a mock response with the expected structure + mock_response = mock.MagicMock() + mock_response.feature_view_sync = "test-sync-operation-name" + mock_client.sync_feature_view.return_value = mock_response + + result = self.hook.sync_feature_view( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + feature_online_store_id=TEST_FEATURE_ONLINE_STORE_ID, + feature_view_id=TEST_FEATURE_VIEW_ID, + ) + + mock_client.sync_feature_view.assert_called_once_with(feature_view=TEST_FEATURE_VIEW) + assert result == "test-sync-operation-name" diff --git a/providers/tests/google/cloud/operators/vertex_ai/test_feature_store.py 
b/providers/tests/google/cloud/operators/vertex_ai/test_feature_store.py new file mode 100644 index 0000000000000..5340b69d72009 --- /dev/null +++ b/providers/tests/google/cloud/operators/vertex_ai/test_feature_store.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from unittest import mock + +from airflow.providers.google.cloud.operators.vertex_ai.feature_store import ( + GetFeatureViewSyncOperator, + SyncFeatureViewOperator, +) + +VERTEX_AI_PATH = "airflow.providers.google.cloud.operators.vertex_ai.{}" + +TASK_ID = "test_task_id" +GCP_PROJECT = "test-project" +GCP_LOCATION = "us-central1" +GCP_CONN_ID = "test-conn" +IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] +FEATURE_ONLINE_STORE_ID = "test-store" +FEATURE_VIEW_ID = "test-view" +FEATURE_VIEW_SYNC_NAME = f"projects/{GCP_PROJECT}/locations/{GCP_LOCATION}/featureOnlineStores/{FEATURE_ONLINE_STORE_ID}/featureViews/{FEATURE_VIEW_ID}/featureViewSyncs/sync123" + + +class TestSyncFeatureViewOperator: + @mock.patch(VERTEX_AI_PATH.format("feature_store.FeatureStoreHook")) + def test_execute(self, mock_hook_class): + # Create the mock hook and set up its return value + mock_hook = mock.MagicMock() + mock_hook_class.return_value = mock_hook + + # Set up the return value for sync_feature_view to match the hook implementation + mock_hook.sync_feature_view.return_value = FEATURE_VIEW_SYNC_NAME + + op = SyncFeatureViewOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + feature_online_store_id=FEATURE_ONLINE_STORE_ID, + feature_view_id=FEATURE_VIEW_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + response = op.execute(context={"ti": mock.MagicMock()}) + + # Verify hook initialization + mock_hook_class.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + # Verify hook method call + mock_hook.sync_feature_view.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + feature_online_store_id=FEATURE_ONLINE_STORE_ID, + feature_view_id=FEATURE_VIEW_ID, + ) + + # Verify response matches expected value + assert response == FEATURE_VIEW_SYNC_NAME + + +class TestGetFeatureViewSyncOperator: + 
@mock.patch(VERTEX_AI_PATH.format("feature_store.FeatureStoreHook")) + def test_execute(self, mock_hook_class): + # Create the mock hook and set up expected response + mock_hook = mock.MagicMock() + mock_hook_class.return_value = mock_hook + + expected_response = { + "name": FEATURE_VIEW_SYNC_NAME, + "start_time": 1000, + "end_time": 2000, + "sync_summary": {"row_synced": 500, "total_slot": 4}, + } + + # Set up the return value for get_feature_view_sync to match the hook implementation + mock_hook.get_feature_view_sync.return_value = expected_response + + op = GetFeatureViewSyncOperator( + task_id=TASK_ID, + location=GCP_LOCATION, + feature_view_sync_name=FEATURE_VIEW_SYNC_NAME, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + response = op.execute(context={"ti": mock.MagicMock()}) + + # Verify hook initialization + mock_hook_class.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + # Verify hook method call + mock_hook.get_feature_view_sync.assert_called_once_with( + location=GCP_LOCATION, + feature_view_sync_name=FEATURE_VIEW_SYNC_NAME, + ) + + # Verify response matches expected structure + assert response == expected_response diff --git a/providers/tests/google/cloud/sensors/test_vertex_ai.py b/providers/tests/google/cloud/sensors/test_vertex_ai.py new file mode 100644 index 0000000000000..5e3171e988510 --- /dev/null +++ b/providers/tests/google/cloud/sensors/test_vertex_ai.py @@ -0,0 +1,148 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock +from unittest.mock import Mock + +import pytest + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.sensors.vertex_ai.feature_store import FeatureViewSyncSensor + +TASK_ID = "test-task" +GCP_CONN_ID = "test-conn" +GCP_LOCATION = "us-central1" +FEATURE_VIEW_SYNC_NAME = "projects/123/locations/us-central1/featureViews/test-view/operations/sync-123" +TIMEOUT = 120 + + +class TestFeatureViewSyncSensor: + def create_sync_response(self, end_time=None, row_synced=None, total_slot=None): + response = {} + if end_time is not None: + response["end_time"] = end_time + if row_synced is not None and total_slot is not None: + response["sync_summary"] = {"row_synced": str(row_synced), "total_slot": str(total_slot)} + return response + + @mock.patch("airflow.providers.google.cloud.sensors.vertex_ai.feature_store.FeatureStoreHook") + def test_sync_completed(self, mock_hook): + mock_hook.return_value.get_feature_view_sync.return_value = self.create_sync_response( + end_time=1234567890, row_synced=1000, total_slot=5 + ) + + sensor = FeatureViewSyncSensor( + task_id=TASK_ID, + feature_view_sync_name=FEATURE_VIEW_SYNC_NAME, + location=GCP_LOCATION, + gcp_conn_id=GCP_CONN_ID, + timeout=TIMEOUT, + ) + ret = sensor.poke(context={}) + + mock_hook.return_value.get_feature_view_sync.assert_called_once_with( + location=GCP_LOCATION, + feature_view_sync_name=FEATURE_VIEW_SYNC_NAME, + ) + assert ret + + 
@mock.patch("airflow.providers.google.cloud.sensors.vertex_ai.feature_store.FeatureStoreHook") + def test_sync_running(self, mock_hook): + mock_hook.return_value.get_feature_view_sync.return_value = self.create_sync_response( + end_time=0, row_synced=0, total_slot=5 + ) + + sensor = FeatureViewSyncSensor( + task_id=TASK_ID, + feature_view_sync_name=FEATURE_VIEW_SYNC_NAME, + location=GCP_LOCATION, + gcp_conn_id=GCP_CONN_ID, + timeout=TIMEOUT, + ) + ret = sensor.poke(context={}) + + mock_hook.return_value.get_feature_view_sync.assert_called_once_with( + location=GCP_LOCATION, + feature_view_sync_name=FEATURE_VIEW_SYNC_NAME, + ) + assert not ret + + @mock.patch("airflow.providers.google.cloud.sensors.vertex_ai.feature_store.FeatureStoreHook") + def test_sync_error_with_retry(self, mock_hook): + mock_hook.return_value.get_feature_view_sync.side_effect = Exception("API Error") + + sensor = FeatureViewSyncSensor( + task_id=TASK_ID, + feature_view_sync_name=FEATURE_VIEW_SYNC_NAME, + location=GCP_LOCATION, + gcp_conn_id=GCP_CONN_ID, + timeout=TIMEOUT, + ) + ret = sensor.poke(context={}) + + mock_hook.return_value.get_feature_view_sync.assert_called_once_with( + location=GCP_LOCATION, + feature_view_sync_name=FEATURE_VIEW_SYNC_NAME, + ) + assert not ret + + @mock.patch("airflow.providers.google.cloud.sensors.vertex_ai.feature_store.FeatureStoreHook") + def test_timeout_during_running(self, mock_hook): + mock_hook.return_value.get_feature_view_sync.return_value = self.create_sync_response( + end_time=0, row_synced=0, total_slot=5 + ) + + sensor = FeatureViewSyncSensor( + task_id=TASK_ID, + feature_view_sync_name=FEATURE_VIEW_SYNC_NAME, + location=GCP_LOCATION, + gcp_conn_id=GCP_CONN_ID, + timeout=TIMEOUT, + wait_timeout=300, + ) + + sensor._duration = Mock() + sensor._duration.return_value = 301 + + with pytest.raises( + AirflowException, + match=f"Timeout: Feature View sync {FEATURE_VIEW_SYNC_NAME} not completed after 300s", + ): + sensor.poke(context={}) + + 
@mock.patch("airflow.providers.google.cloud.sensors.vertex_ai.feature_store.FeatureStoreHook") + def test_timeout_during_error(self, mock_hook): + mock_hook.return_value.get_feature_view_sync.side_effect = Exception("API Error") + + sensor = FeatureViewSyncSensor( + task_id=TASK_ID, + feature_view_sync_name=FEATURE_VIEW_SYNC_NAME, + location=GCP_LOCATION, + gcp_conn_id=GCP_CONN_ID, + timeout=TIMEOUT, + wait_timeout=300, + ) + + sensor._duration = Mock() + sensor._duration.return_value = 301 + + with pytest.raises( + AirflowException, + match=f"Timeout: Feature View sync {FEATURE_VIEW_SYNC_NAME} not completed after 300s", + ): + sensor.poke(context={}) diff --git a/providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_feature_store.py b/providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_feature_store.py new file mode 100644 index 0000000000000..3d0794fa85cbd --- /dev/null +++ b/providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_feature_store.py @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG for Google Vertex AI Feature Store operations. 
+""" + +from __future__ import annotations + +import os +from datetime import datetime + +from airflow import DAG +from airflow.providers.google.cloud.operators.vertex_ai.feature_store import ( + GetFeatureViewSyncOperator, + SyncFeatureViewOperator, +) +from airflow.providers.google.cloud.sensors.vertex_ai.feature_store import FeatureViewSyncSensor + +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") +DAG_ID = "vertex_ai_feature_store_dag" +REGION = "us-central1" + +FEATURE_ONLINE_STORE_ID = "my_feature_online_store_unique" +FEATURE_VIEW_ID = "feature_view_publications" + +with DAG( + dag_id=DAG_ID, + description="Sample DAG with Vertex AI Feature Store operations.", + schedule="@once", + start_date=datetime(2024, 1, 1), + catchup=False, + tags=["example", "vertex_ai", "feature_store"], +) as dag: + # [START how_to_cloud_vertex_ai_feature_store_sync_feature_view_operator] + sync_task = SyncFeatureViewOperator( + task_id="sync_task", + project_id=PROJECT_ID, + location=REGION, + feature_online_store_id=FEATURE_ONLINE_STORE_ID, + feature_view_id=FEATURE_VIEW_ID, + ) + # [END how_to_cloud_vertex_ai_feature_store_sync_feature_view_operator] + + # [START how_to_cloud_vertex_ai_feature_store_feature_view_sync_sensor] + wait_for_sync = FeatureViewSyncSensor( + task_id="wait_for_sync", + location=REGION, + feature_view_sync_name="{{ task_instance.xcom_pull(task_ids='sync_task', key='return_value')}}", + poke_interval=60, # Check every minute + timeout=600, # Timeout after 10 minutes + mode="reschedule", + ) + # [END how_to_cloud_vertex_ai_feature_store_feature_view_sync_sensor] + + # [START how_to_cloud_vertex_ai_feature_store_get_feature_view_sync_operator] + get_task = GetFeatureViewSyncOperator( + task_id="get_task", + location=REGION, + feature_view_sync_name="{{ task_instance.xcom_pull(task_ids='sync_task', key='return_value')}}", + ) + # [END how_to_cloud_vertex_ai_feature_store_get_feature_view_sync_operator] + + sync_task >> wait_for_sync >> 
get_task + + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)