diff --git a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
index b6bc94620ff2d..f24ab3b06f719 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
@@ -645,6 +645,26 @@ The operator returns the evaluation summary metrics in :ref:`XCom <concepts:xcom>` under the ``return_value`` key.
+
+To create cached content, you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.CreateCachedContentOperator`.
+
+.. exampleinclude:: /../../providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_create_cached_content_operator]
+    :end-before: [END how_to_cloud_vertex_ai_create_cached_content_operator]
+
+To generate a response from cached content, you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.GenerateFromCachedContentOperator`.
+The operator returns the cached content response in :ref:`XCom <concepts:xcom>` under the ``return_value`` key.
+
+.. exampleinclude:: /../../providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_generate_from_cached_content_operator]
+    :end-before: [END how_to_cloud_vertex_ai_generate_from_cached_content_operator]
+
 Reference
 ^^^^^^^^^
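Note: both operators documented above hand their result to downstream tasks through the standard XCom ``return_value`` key. A minimal sketch of a consumer task, assuming a DAG that already contains the ``create_cached_content_task`` defined later in this change (the function name and task id are illustrative):

```python
from airflow.decorators import task


@task
def log_cache_name(**context):
    # CreateCachedContentOperator returns the resource name from execute(),
    # so Airflow stores it under the "return_value" XCom key.
    cache_name = context["ti"].xcom_pull(
        task_ids="create_cached_content_task", key="return_value"
    )
    print(f"Cached content resource: {cache_name}")
```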
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 02323df136bda..ac587bad765be 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -640,7 +640,7 @@
       "google-api-python-client>=2.0.2",
       "google-auth-httplib2>=0.0.1",
       "google-auth>=2.29.0",
-      "google-cloud-aiplatform>=1.63.0",
+      "google-cloud-aiplatform>=1.70.0",
       "google-cloud-automl>=2.12.0",
       "google-cloud-batch>=0.13.0",
       "google-cloud-bigquery-datatransfer>=3.13.0",
diff --git a/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py b/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
index 04242306f8739..27037baaafdad 100644
--- a/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
+++ b/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
@@ -20,12 +20,15 @@
 from __future__ import annotations
 
 import time
+from datetime import timedelta
 from typing import TYPE_CHECKING, Sequence
 
 import vertexai
 from vertexai.generative_models import GenerativeModel, Part
 from vertexai.language_models import TextEmbeddingModel, TextGenerationModel
+from vertexai.preview.caching import CachedContent
 from vertexai.preview.evaluation import EvalResult, EvalTask
+from vertexai.preview.generative_models import GenerativeModel as preview_generative_model
 from vertexai.preview.tuning import sft
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
@@ -95,6 +98,16 @@ def get_eval_task(
         )
         return eval_task
 
+    def get_cached_context_model(
+        self,
+        cached_content_name: str,
+    ) -> preview_generative_model:
+        """Return a Generative Model with Cached Context."""
+        cached_content = CachedContent(cached_content_name=cached_content_name)
+
+        cached_context_model = preview_generative_model.from_cached_content(cached_content)
+        return cached_context_model
+
     @deprecated(
         planned_removal_date="January 01, 2025",
         use_instead="Part objects included in contents parameter of "
@@ -528,3 +541,69 @@ def run_evaluation(
         )
 
         return eval_result
+
+    def create_cached_content(
+        self,
+        model_name: str,
+        location: str,
+        ttl_hours: float = 1,
+        system_instruction: str | None = None,
+        contents: list | None = None,
+        display_name: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
+    ) -> str:
+        """
+        Create CachedContent to reduce the cost of requests that contain repeat content with high input token counts.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param model_name: Required. The name of the publisher model to use for cached content.
+        :param system_instruction: Developer-set system instruction.
+        :param contents: The content to cache.
+        :param ttl_hours: The TTL for this resource in hours. The expiration time is computed as: now + TTL.
+            Defaults to one hour.
+        :param display_name: The user-generated meaningful display name of the cached content.
+        """
+        vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
+
+        response = CachedContent.create(
+            model_name=model_name,
+            system_instruction=system_instruction,
+            contents=contents,
+            ttl=timedelta(hours=ttl_hours),
+            display_name=display_name,
+        )
+
+        return response.name
+
+    def generate_from_cached_content(
+        self,
+        location: str,
+        cached_content_name: str,
+        contents: list,
+        generation_config: dict | None = None,
+        safety_settings: dict | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
+    ) -> str:
+        """
+        Generate a response from CachedContent.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param cached_content_name: Required. The name of the cached content resource.
+        :param contents: Required. The multi-part content of a message that a user or a program
+            gives to the generative model, in order to elicit a specific response.
+        :param generation_config: Optional. Generation configuration settings.
+        :param safety_settings: Optional. Per-request settings for blocking unsafe content.
+        """
+        vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
+
+        cached_context_model = self.get_cached_context_model(cached_content_name=cached_content_name)
+
+        response = cached_context_model.generate_content(
+            contents=contents,
+            generation_config=generation_config,
+            safety_settings=safety_settings,
+        )
+
+        return response.text
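Taken as a pair, the hook methods above split the caching flow in two: ``create_cached_content`` converts ``ttl_hours`` into a ``datetime.timedelta``, calls ``CachedContent.create``, and returns only the server-assigned resource name, while ``generate_from_cached_content`` rebuilds a model from that name via ``get_cached_context_model`` and queries it. A minimal sketch of driving the hook directly, e.g. from a ``PythonOperator`` callable — the project id, region, and prompt values are placeholders:

```python
from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import (
    GenerativeModelHook,
)

hook = GenerativeModelHook(gcp_conn_id="google_cloud_default")

# Cache a large shared prompt prefix once; the returned name identifies
# the cache server-side until the TTL expires.
cache_name = hook.create_cached_content(
    project_id="my-project",  # placeholder
    location="us-central1",  # placeholder
    model_name="gemini-1.5-pro-002",
    system_instruction="You are an expert researcher.",
    contents=[],  # in practice, Part objects such as Part.from_uri(...)
    ttl_hours=1,
    display_name="example-cache",
)

# Later (within the TTL), answer follow-up prompts against the cached context.
answer = hook.generate_from_cached_content(
    project_id="my-project",  # placeholder
    location="us-central1",
    cached_content_name=cache_name,
    contents=["What are the papers about?"],
)
print(answer)
```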
diff --git a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
index fddd5dcf72861..05a4abf138073 100644
--- a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
+++ b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
@@ -21,8 +21,6 @@
 
 from typing import TYPE_CHECKING, Sequence
 
-from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
-
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
@@ -742,8 +740,6 @@ def execute(self, context: Context):
         self.xcom_push(context, key="total_tokens", value=response.total_tokens)
         self.xcom_push(context, key="total_billable_characters", value=response.total_billable_characters)
 
-        return types_v1beta1.CountTokensResponse.to_dict(response)
-
 
 class RunEvaluationOperator(GoogleCloudBaseOperator):
     """
@@ -842,3 +838,155 @@ def execute(self, context: Context):
         )
 
         return response.summary_metrics
+
+
+class CreateCachedContentOperator(GoogleCloudBaseOperator):
+    """
+    Create CachedContent to reduce the cost of requests that contain repeat content with high input token counts.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param location: Required. The ID of the Google Cloud location that the service belongs to.
+    :param model_name: Required. The name of the publisher model to use for cached content.
+    :param system_instruction: Developer-set system instruction.
+    :param contents: The content to cache.
+    :param ttl_hours: The TTL for this resource in hours. The expiration time is computed as: now + TTL.
+        Defaults to one hour.
+    :param display_name: The user-generated meaningful display name of the cached content.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with the first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields = (
+        "location",
+        "project_id",
+        "impersonation_chain",
+        "model_name",
+        "contents",
+        "system_instruction",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        location: str,
+        model_name: str,
+        system_instruction: str | None = None,
+        contents: list | None = None,
+        ttl_hours: float = 1,
+        display_name: str | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.project_id = project_id
+        self.location = location
+        self.model_name = model_name
+        self.system_instruction = system_instruction
+        self.contents = contents
+        self.ttl_hours = ttl_hours
+        self.display_name = display_name
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context):
+        self.hook = GenerativeModelHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        cached_content_name = self.hook.create_cached_content(
+            project_id=self.project_id,
+            location=self.location,
+            model_name=self.model_name,
+            system_instruction=self.system_instruction,
+            contents=self.contents,
+            ttl_hours=self.ttl_hours,
+            display_name=self.display_name,
+        )
+
+        self.log.info("Cached Content Name: %s", cached_content_name)
+
+        return cached_content_name
+
+
+class GenerateFromCachedContentOperator(GoogleCloudBaseOperator):
+    """
+    Generate a response from CachedContent.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param location: Required. The ID of the Google Cloud location that the service belongs to.
+    :param cached_content_name: Required. The name of the cached content resource.
+    :param contents: Required. The multi-part content of a message that a user or a program
+        gives to the generative model, in order to elicit a specific response.
+    :param generation_config: Optional. Generation configuration settings.
+    :param safety_settings: Optional. Per-request settings for blocking unsafe content.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with the first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields = (
+        "location",
+        "project_id",
+        "impersonation_chain",
+        "cached_content_name",
+        "contents",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        location: str,
+        cached_content_name: str,
+        contents: list,
+        generation_config: dict | None = None,
+        safety_settings: dict | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.project_id = project_id
+        self.location = location
+        self.cached_content_name = cached_content_name
+        self.contents = contents
+        self.generation_config = generation_config
+        self.safety_settings = safety_settings
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context):
+        self.hook = GenerativeModelHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        cached_content_text = self.hook.generate_from_cached_content(
+            project_id=self.project_id,
+            location=self.location,
+            cached_content_name=self.cached_content_name,
+            contents=self.contents,
+            generation_config=self.generation_config,
+            safety_settings=self.safety_settings,
+        )
+
+        self.log.info("Cached Content Response: %s", cached_content_text)
+
+        return cached_content_text
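A side effect worth calling out in the hunk above: ``CountTokensOperator.execute`` no longer returns the serialized ``CountTokensResponse``, so its ``return_value`` XCom goes away; the explicit ``total_tokens`` and ``total_billable_characters`` keys it already pushes become the way to read the counts. A hedged sketch of the downstream pull (the task id is illustrative):

```python
# Inside any downstream task callable that receives the Airflow context:
total_tokens = context["ti"].xcom_pull(
    task_ids="count_tokens_task", key="total_tokens"
)
billable = context["ti"].xcom_pull(
    task_ids="count_tokens_task", key="total_billable_characters"
)
```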
diff --git a/providers/src/airflow/providers/google/provider.yaml b/providers/src/airflow/providers/google/provider.yaml
index 86a98a6962757..0d616a53e12f5 100644
--- a/providers/src/airflow/providers/google/provider.yaml
+++ b/providers/src/airflow/providers/google/provider.yaml
@@ -114,7 +114,7 @@ dependencies:
   - google-api-python-client>=2.0.2
   - google-auth>=2.29.0
   - google-auth-httplib2>=0.0.1
-  - google-cloud-aiplatform>=1.63.0
+  - google-cloud-aiplatform>=1.70.0
   - google-cloud-automl>=2.12.0
   # Excluded versions contain bug https://github.com/apache/airflow/issues/39541 which is resolved in 3.24.0
   - google-cloud-bigquery>=3.4.0,!=3.21.*,!=3.22.0,!=3.23.*
diff --git a/providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py b/providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py
index 52a8c417c6ac2..35d3fc9256e6c 100644
--- a/providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py
+++ b/providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py
@@ -27,7 +27,9 @@
 # For no Pydantic environment, we need to skip the tests
 pytest.importorskip("google.cloud.aiplatform_v1")
 
-from vertexai.generative_models import HarmBlockThreshold, HarmCategory, Tool, grounding
+from datetime import timedelta
+
+from vertexai.generative_models import HarmBlockThreshold, HarmCategory, Part, Tool, grounding
 from vertexai.preview.evaluation import MetricPromptTemplateExamples
 
 from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import (
@@ -106,6 +108,27 @@
 TEST_EXPERIMENT_RUN_NAME = "eval-experiment-airflow-operator-run"
 TEST_PROMPT_TEMPLATE = "{instruction}. Article: {context}. Summary:"
 
+TEST_CACHED_CONTENT_NAME = "test-example-cache"
+TEST_CACHED_CONTENT_PROMPT = ["What are these papers about?"]
+TEST_CACHED_MODEL = "gemini-1.5-pro-002"
+TEST_CACHED_SYSTEM_INSTRUCTION = """
+You are an expert researcher. You always stick to the facts in the sources provided, and never make up new facts.
+Now look at these research papers, and answer the following questions.
+"""
+
+TEST_CACHED_CONTENTS = [
+    Part.from_uri(
+        "gs://cloud-samples-data/generative-ai/pdf/2312.11805v3.pdf",
+        mime_type="application/pdf",
+    ),
+    Part.from_uri(
+        "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf",
+        mime_type="application/pdf",
+    ),
+]
+TEST_CACHED_TTL = 1
+TEST_CACHED_DISPLAY_NAME = "test-example-cache"
+
 BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
 GENERATIVE_MODEL_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.generative_model.{}"
 
@@ -299,3 +322,38 @@ def test_run_evaluation(self, mock_eval_task, mock_model) -> None:
             prompt_template=TEST_PROMPT_TEMPLATE,
             experiment_run_name=TEST_EXPERIMENT_RUN_NAME,
         )
+
+    @mock.patch("vertexai.preview.caching.CachedContent.create")
+    def test_create_cached_content(self, mock_cached_content_create) -> None:
+        self.hook.create_cached_content(
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            model_name=TEST_CACHED_MODEL,
+            system_instruction=TEST_CACHED_SYSTEM_INSTRUCTION,
+            contents=TEST_CACHED_CONTENTS,
+            ttl_hours=TEST_CACHED_TTL,
+            display_name=TEST_CACHED_DISPLAY_NAME,
+        )
+
+        mock_cached_content_create.assert_called_once_with(
+            model_name=TEST_CACHED_MODEL,
+            system_instruction=TEST_CACHED_SYSTEM_INSTRUCTION,
+            contents=TEST_CACHED_CONTENTS,
+            ttl=timedelta(hours=TEST_CACHED_TTL),
+            display_name=TEST_CACHED_DISPLAY_NAME,
+        )
+
+    @mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_cached_context_model"))
+    def test_generate_from_cached_content(self, mock_cached_context_model) -> None:
+        self.hook.generate_from_cached_content(
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            cached_content_name=TEST_CACHED_CONTENT_NAME,
+            contents=TEST_CACHED_CONTENT_PROMPT,
+        )
+
+        mock_cached_context_model.return_value.generate_content.assert_called_once_with(
+            contents=TEST_CACHED_CONTENT_PROMPT,
+            generation_config=None,
+            safety_settings=None,
+        )
diff --git a/providers/tests/google/cloud/operators/vertex_ai/test_generative_model.py b/providers/tests/google/cloud/operators/vertex_ai/test_generative_model.py
index e8efb9601fce0..5bdb04cb3edb3 100644
--- a/providers/tests/google/cloud/operators/vertex_ai/test_generative_model.py
+++ b/providers/tests/google/cloud/operators/vertex_ai/test_generative_model.py
@@ -28,11 +28,13 @@
 pytest.importorskip("google.cloud.aiplatform_v1")
 pytest.importorskip("google.cloud.aiplatform_v1beta1")
 vertexai = pytest.importorskip("vertexai.generative_models")
-from vertexai.generative_models import HarmBlockThreshold, HarmCategory, Tool, grounding
+from vertexai.generative_models import HarmBlockThreshold, HarmCategory, Part, Tool, grounding
 from vertexai.preview.evaluation import MetricPromptTemplateExamples
 
 from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
     CountTokensOperator,
+    CreateCachedContentOperator,
+    GenerateFromCachedContentOperator,
     GenerateTextEmbeddingsOperator,
     GenerativeModelGenerateContentOperator,
     PromptLanguageModelOperator,
@@ -540,3 +542,83 @@ def test_execute(
             safety_settings=safety_settings,
             tools=tools,
         )
+
+
+class TestVertexAICreateCachedContentOperator:
+    @mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook"))
+    def test_execute(self, mock_hook):
+        model_name = "gemini-1.5-pro-002"
+        system_instruction = """
+        You are an expert researcher. You always stick to the facts in the sources provided, and never make up new facts.
+        Now look at these research papers, and answer the following questions.
+        """
+
+        contents = [
+            Part.from_uri(
+                "gs://cloud-samples-data/generative-ai/pdf/2312.11805v3.pdf",
+                mime_type="application/pdf",
+            ),
+            Part.from_uri(
+                "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf",
+                mime_type="application/pdf",
+            ),
+        ]
+        ttl_hours = 1
+        display_name = "test-example-cache"
+
+        op = CreateCachedContentOperator(
+            task_id=TASK_ID,
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            model_name=model_name,
+            system_instruction=system_instruction,
+            contents=contents,
+            ttl_hours=ttl_hours,
+            display_name=display_name,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        op.execute(context={"ti": mock.MagicMock()})
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        mock_hook.return_value.create_cached_content.assert_called_once_with(
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            model_name=model_name,
+            system_instruction=system_instruction,
+            contents=contents,
+            ttl_hours=ttl_hours,
+            display_name=display_name,
+        )
+
+
+class TestVertexAIGenerateFromCachedContentOperator:
+    @mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook"))
+    def test_execute(self, mock_hook):
+        cached_content_name = "test"
+        contents = ["what are in these papers"]
+
+        op = GenerateFromCachedContentOperator(
+            task_id=TASK_ID,
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            cached_content_name=cached_content_name,
+            contents=contents,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        op.execute(context={"ti": mock.MagicMock()})
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        mock_hook.return_value.generate_from_cached_content.assert_called_once_with(
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            cached_content_name=cached_content_name,
+            contents=contents,
+            generation_config=None,
+            safety_settings=None,
+        )
diff --git a/providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py b/providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
index a1140afc1983f..4384626999d0a 100644
--- a/providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
+++ b/providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
@@ -25,12 +25,14 @@
 import os
 from datetime import datetime
 
-from vertexai.generative_models import HarmBlockThreshold, HarmCategory, Tool, grounding
+from vertexai.generative_models import HarmBlockThreshold, HarmCategory, Part, Tool, grounding
 from vertexai.preview.evaluation import MetricPromptTemplateExamples
 
 from airflow.models.dag import DAG
 from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
     CountTokensOperator,
+    CreateCachedContentOperator,
+    GenerateFromCachedContentOperator,
     GenerativeModelGenerateContentOperator,
     RunEvaluationOperator,
     TextEmbeddingModelGetEmbeddingsOperator,
@@ -90,6 +92,23 @@
 EXPERIMENT_RUN_NAME = "eval-experiment-airflow-operator-run"
 PROMPT_TEMPLATE = "{instruction}. Article: {context}. Summary:"
 
+CACHED_MODEL = "gemini-1.5-pro-002"
+CACHED_SYSTEM_INSTRUCTION = """
+You are an expert researcher. You always stick to the facts in the sources provided, and never make up new facts.
+Now look at these research papers, and answer the following questions.
+"""
+
+CACHED_CONTENTS = [
+    Part.from_uri(
+        "gs://cloud-samples-data/generative-ai/pdf/2312.11805v3.pdf",
+        mime_type="application/pdf",
+    ),
+    Part.from_uri(
+        "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf",
+        mime_type="application/pdf",
+    ),
+]
+
 with DAG(
     dag_id=DAG_ID,
     description="Sample DAG with generative models.",
@@ -155,6 +174,33 @@
     )
     # [END how_to_cloud_vertex_ai_run_evaluation_operator]
 
+    # [START how_to_cloud_vertex_ai_create_cached_content_operator]
+    create_cached_content_task = CreateCachedContentOperator(
+        task_id="create_cached_content_task",
+        project_id=PROJECT_ID,
+        location=REGION,
+        model_name=CACHED_MODEL,
+        system_instruction=CACHED_SYSTEM_INSTRUCTION,
+        contents=CACHED_CONTENTS,
+        ttl_hours=1,
+        display_name="example-cache",
+    )
+    # [END how_to_cloud_vertex_ai_create_cached_content_operator]
+
+    # [START how_to_cloud_vertex_ai_generate_from_cached_content_operator]
+    generate_from_cached_content_task = GenerateFromCachedContentOperator(
+        task_id="generate_from_cached_content_task",
+        project_id=PROJECT_ID,
+        location=REGION,
+        cached_content_name="{{ task_instance.xcom_pull(task_ids='create_cached_content_task', key='return_value') }}",
+        contents=["What are the papers about?"],
+        generation_config=GENERATION_CONFIG,
+        safety_settings=SAFETY_SETTINGS,
+    )
+    # [END how_to_cloud_vertex_ai_generate_from_cached_content_operator]
+
+    create_cached_content_task >> generate_from_cached_content_task
+
     from tests_common.test_utils.watcher import watcher
 
     # This test needs watcher in order to properly mark success/failure
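For contributors who want to reproduce the flow outside Airflow, the hook in this change is a thin wrapper over the ``vertexai`` preview caching API (hence the ``google-cloud-aiplatform>=1.70.0`` floor bumped above). A minimal end-to-end sketch of the underlying SDK calls, under the assumption that the preview module keeps this surface — the project and region are placeholders, and the file URI is taken from the example DAG:

```python
from datetime import timedelta

import vertexai
from vertexai.generative_models import Part
from vertexai.preview.caching import CachedContent
from vertexai.preview.generative_models import GenerativeModel

vertexai.init(project="my-project", location="us-central1")  # placeholders

# Step 1: cache the expensive shared context
# (what CreateCachedContentOperator does through the hook).
cached = CachedContent.create(
    model_name="gemini-1.5-pro-002",
    system_instruction="You are an expert researcher.",
    contents=[
        Part.from_uri(
            "gs://cloud-samples-data/generative-ai/pdf/2312.11805v3.pdf",
            mime_type="application/pdf",
        ),
    ],
    ttl=timedelta(hours=1),
    display_name="example-cache",
)

# Step 2: rebuild a model from the cache resource name and query it
# (what GenerateFromCachedContentOperator does via get_cached_context_model).
model = GenerativeModel.from_cached_content(
    CachedContent(cached_content_name=cached.name)
)
response = model.generate_content(contents=["What are the papers about?"])
print(response.text)
```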