Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions providers/google/docs/operators/cloud/vertex_ai.rst
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,18 @@ To update cluster you can use
:start-after: [START how_to_cloud_vertex_ai_update_ray_cluster_operator]
:end-before: [END how_to_cloud_vertex_ai_update_ray_cluster_operator]

Interacting with experiment run
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

To delete experiment run you can use
:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.DeleteExperimentRunOperator`.

.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_vertex_ai_delete_experiment_run_operator]
:end-before: [END how_to_cloud_vertex_ai_delete_experiment_run_operator]

Reference
^^^^^^^^^

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import TYPE_CHECKING

import vertexai
from google.cloud import aiplatform
from vertexai.generative_models import GenerativeModel
from vertexai.language_models import TextEmbeddingModel
from vertexai.preview.caching import CachedContent
Expand Down Expand Up @@ -359,3 +360,32 @@ def generate_from_cached_content(
)

return response.text


class ExperimentRunHook(GoogleBaseHook):
"""Use the Vertex AI SDK for Python to create and manage your experiment runs."""

@GoogleBaseHook.fallback_to_default_project_id
def delete_experiment_run(
self,
experiment_run_name: str,
experiment_name: str,
location: str,
project_id: str = PROVIDE_PROJECT_ID,
delete_backing_tensorboard_run: bool = False,
) -> None:
"""
Delete experiment run from the experiment.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param experiment_name: Required. The name of the evaluation experiment.
:param experiment_run_name: Required. The specific run name or ID for this experiment.
:param delete_backing_tensorboard_run: Whether to delete the backing Vertex AI TensorBoard run
that stores time series metrics for this run.
"""
self.log.info("Next experiment run will be deleted: %s", experiment_run_name)
experiment_run = aiplatform.ExperimentRun(
run_name=experiment_run_name, experiment=experiment_name, project=project_id, location=location
)
experiment_run.delete(delete_backing_tensorboard_run=delete_backing_tensorboard_run)
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING

from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook
from google.api_core import exceptions

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import (
ExperimentRunHook,
GenerativeModelHook,
)
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator

if TYPE_CHECKING:
Expand Down Expand Up @@ -580,3 +586,68 @@ def execute(self, context: Context):
self.log.info("Cached Content Response: %s", cached_content_text)

return cached_content_text


class DeleteExperimentRunOperator(GoogleCloudBaseOperator):
"""
Use the Rapid Evaluation API to evaluate a model.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param experiment_name: Required. The name of the evaluation experiment.
:param experiment_run_name: Required. The specific run name or ID for this experiment.
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
"""

template_fields = (
"location",
"project_id",
"impersonation_chain",
"experiment_name",
"experiment_run_name",
)

def __init__(
self,
*,
project_id: str,
location: str,
experiment_name: str,
experiment_run_name: str,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.project_id = project_id
self.location = location
self.experiment_name = experiment_name
self.experiment_run_name = experiment_run_name
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context: Context) -> None:
self.hook = ExperimentRunHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)

try:
self.hook.delete_experiment_run(
project_id=self.project_id,
location=self.location,
experiment_name=self.experiment_name,
experiment_run_name=self.experiment_run_name,
)
except exceptions.NotFound:
raise AirflowException(f"Experiment Run with name {self.experiment_run_name} not found")

self.log.info("Deleted experiment run: %s", self.experiment_run_name)
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
CountTokensOperator,
CreateCachedContentOperator,
DeleteExperimentRunOperator,
GenerateFromCachedContentOperator,
GenerativeModelGenerateContentOperator,
RunEvaluationOperator,
Expand Down Expand Up @@ -238,6 +239,16 @@ def get_actual_models() -> dict[str, str]:
)
# [END how_to_cloud_vertex_ai_run_evaluation_operator]

# [START how_to_cloud_vertex_ai_delete_experiment_run_operator]
delete_experiment_run = DeleteExperimentRunOperator(
task_id="delete_experiment_run_task",
project_id=PROJECT_ID,
location=REGION,
experiment_name=EXPERIMENT_NAME,
experiment_run_name=EXPERIMENT_RUN_NAME,
)
# [END how_to_cloud_vertex_ai_delete_experiment_run_operator]

# [START how_to_cloud_vertex_ai_create_cached_content_operator]
create_cached_content_task = CreateCachedContentOperator(
task_id="create_cached_content_task",
Expand All @@ -264,6 +275,7 @@ def get_actual_models() -> dict[str, str]:
# [END how_to_cloud_vertex_ai_generate_from_cached_content_operator]

create_cached_content_task >> generate_from_cached_content_task
run_evaluation_task >> delete_experiment_run

from tests_common.test_utils.watcher import watcher

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
CountTokensOperator,
CreateCachedContentOperator,
DeleteExperimentRunOperator,
GenerateFromCachedContentOperator,
GenerativeModelGenerateContentOperator,
RunEvaluationOperator,
Expand Down Expand Up @@ -355,3 +356,31 @@ def test_execute(self, mock_hook):
generation_config=None,
safety_settings=None,
)


class TestVertexAIDeleteExperimentRunOperator:
@mock.patch(VERTEX_AI_PATH.format("generative_model.ExperimentRunHook"))
def test_execute(self, mock_hook):
test_experiment_name = "test_experiment_name"
test_experiment_run_name = "test_experiment_run_name"

op = DeleteExperimentRunOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
location=GCP_LOCATION,
experiment_name=test_experiment_name,
experiment_run_name=test_experiment_run_name,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
mock_hook.return_value.delete_experiment_run.assert_called_once_with(
project_id=GCP_PROJECT,
location=GCP_LOCATION,
experiment_name=test_experiment_name,
experiment_run_name=test_experiment_run_name,
)