diff --git a/providers/google/docs/operators/cloud/vertex_ai.rst b/providers/google/docs/operators/cloud/vertex_ai.rst index f870623bee046..646cb4fee51a0 100644 --- a/providers/google/docs/operators/cloud/vertex_ai.rst +++ b/providers/google/docs/operators/cloud/vertex_ai.rst @@ -741,6 +741,18 @@ To update cluster you can use :start-after: [START how_to_cloud_vertex_ai_update_ray_cluster_operator] :end-before: [END how_to_cloud_vertex_ai_update_ray_cluster_operator] +Interacting with experiment run +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To delete experiment run you can use +:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.DeleteExperimentRunOperator`. + +.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_delete_experiment_run_operator] + :end-before: [END how_to_cloud_vertex_ai_delete_experiment_run_operator] + Reference ^^^^^^^^^ diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py index d6d5a6cbadc20..f31e131b0a205 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py @@ -24,6 +24,7 @@ from typing import TYPE_CHECKING import vertexai +from google.cloud import aiplatform from vertexai.generative_models import GenerativeModel from vertexai.language_models import TextEmbeddingModel from vertexai.preview.caching import CachedContent @@ -359,3 +360,32 @@ def generate_from_cached_content( ) return response.text + + +class ExperimentRunHook(GoogleBaseHook): + """Use the Vertex AI SDK for Python to create and manage your experiment runs.""" + + @GoogleBaseHook.fallback_to_default_project_id + def delete_experiment_run( + self, + experiment_run_name: str, + experiment_name: str, + location: str, + project_id: str = PROVIDE_PROJECT_ID, + delete_backing_tensorboard_run: bool = False, + ) -> None: + """ + Delete experiment run from the experiment. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param experiment_name: Required. The name of the evaluation experiment. + :param experiment_run_name: Required. The specific run name or ID for this experiment. + :param delete_backing_tensorboard_run: Whether to delete the backing Vertex AI TensorBoard run + that stores time series metrics for this run. + """ + self.log.info("Next experiment run will be deleted: %s", experiment_run_name) + experiment_run = aiplatform.ExperimentRun( + run_name=experiment_run_name, experiment=experiment_name, project=project_id, location=location + ) + experiment_run.delete(delete_backing_tensorboard_run=delete_backing_tensorboard_run) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py index ec8d7749cd0a7..bf7ffa5475c7f 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py @@ -22,7 +22,13 @@ from collections.abc import Sequence from typing import TYPE_CHECKING -from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook +from google.api_core import exceptions + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import ( + ExperimentRunHook, + GenerativeModelHook, +) from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator if TYPE_CHECKING: @@ -580,3 +586,68 @@ def execute(self, context: Context): self.log.info("Cached Content Response: %s", cached_content_text) return cached_content_text + + +class DeleteExperimentRunOperator(GoogleCloudBaseOperator): + """ + Use the Rapid Evaluation API to evaluate a model. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param experiment_name: Required. The name of the evaluation experiment. + :param experiment_run_name: Required. The specific run name or ID for this experiment. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields = ( + "location", + "project_id", + "impersonation_chain", + "experiment_name", + "experiment_run_name", + ) + + def __init__( + self, + *, + project_id: str, + location: str, + experiment_name: str, + experiment_run_name: str, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.experiment_name = experiment_name + self.experiment_run_name = experiment_run_name + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context) -> None: + self.hook = ExperimentRunHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + try: + self.hook.delete_experiment_run( + project_id=self.project_id, + location=self.location, + experiment_name=self.experiment_name, + experiment_run_name=self.experiment_run_name, + ) + except exceptions.NotFound: + raise AirflowException(f"Experiment Run with name {self.experiment_run_name} not found") + + self.log.info("Deleted experiment run: %s", self.experiment_run_name) diff --git a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py index eb453b2d8544d..413c1f08ec0cd 100644 --- a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py +++ b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py @@ -33,6 +33,7 @@ from airflow.providers.google.cloud.operators.vertex_ai.generative_model import ( CountTokensOperator, CreateCachedContentOperator, + DeleteExperimentRunOperator, GenerateFromCachedContentOperator, GenerativeModelGenerateContentOperator, RunEvaluationOperator, @@ -238,6 +239,16 @@ def get_actual_models() -> dict[str, str]: ) # [END how_to_cloud_vertex_ai_run_evaluation_operator] + # [START how_to_cloud_vertex_ai_delete_experiment_run_operator] + delete_experiment_run = DeleteExperimentRunOperator( + task_id="delete_experiment_run_task", + project_id=PROJECT_ID, + location=REGION, + experiment_name=EXPERIMENT_NAME, + experiment_run_name=EXPERIMENT_RUN_NAME, + ) + # [END how_to_cloud_vertex_ai_delete_experiment_run_operator] + # [START how_to_cloud_vertex_ai_create_cached_content_operator] create_cached_content_task = CreateCachedContentOperator( task_id="create_cached_content_task", @@ -264,6 +275,7 @@ def get_actual_models() -> dict[str, str]: # [END how_to_cloud_vertex_ai_generate_from_cached_content_operator] create_cached_content_task >> generate_from_cached_content_task + run_evaluation_task >> delete_experiment_run from tests_common.test_utils.watcher import watcher diff --git a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_generative_model.py b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_generative_model.py index 2683936e7f57e..fbd8fde9c0db3 100644 --- a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_generative_model.py +++ b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_generative_model.py @@ -30,6 +30,7 @@ from airflow.providers.google.cloud.operators.vertex_ai.generative_model import ( CountTokensOperator, CreateCachedContentOperator, + DeleteExperimentRunOperator, GenerateFromCachedContentOperator, GenerativeModelGenerateContentOperator, RunEvaluationOperator, @@ -355,3 +356,31 @@ def test_execute(self, mock_hook): generation_config=None, safety_settings=None, ) + + +class TestVertexAIDeleteExperimentRunOperator: + @mock.patch(VERTEX_AI_PATH.format("generative_model.ExperimentRunHook")) + def test_execute(self, mock_hook): + test_experiment_name = "test_experiment_name" + test_experiment_run_name = "test_experiment_run_name" + + op = DeleteExperimentRunOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + experiment_name=test_experiment_name, + experiment_run_name=test_experiment_run_name, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + op.execute(context={"ti": mock.MagicMock()}) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + mock_hook.return_value.delete_experiment_run.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + experiment_name=test_experiment_name, + experiment_run_name=test_experiment_run_name, + )