From f67d48989a3428caa8022c99b259c460b75ce22e Mon Sep 17 00:00:00 2001 From: Christian Yarros Date: Mon, 22 Apr 2024 16:28:17 +0000 Subject: [PATCH] add templated fields for google llm operators --- .../operators/vertex_ai/generative_model.py | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py b/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py index 8ac08c3b5f275..da1436a6ab3cc 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py @@ -33,11 +33,11 @@ class PromptLanguageModelOperator(GoogleCloudBaseOperator): Uses the Vertex AI PaLM API to generate natural language text. :param project_id: Required. The ID of the Google Cloud project that the - service belongs to. + service belongs to (templated). :param location: Required. The ID of the Google Cloud location that the - service belongs to. + service belongs to (templated). :param prompt: Required. Inputs or queries that a user or a program gives - to the Vertex AI PaLM API, in order to elicit a specific response. + to the Vertex AI PaLM API, in order to elicit a specific response (templated). :param pretrained_model: By default uses the pre-trained model `text-bison`, optimized for performing natural language tasks such as classification, summarization, extraction, content creation, and ideation. @@ -60,6 +60,8 @@ class PromptLanguageModelOperator(GoogleCloudBaseOperator): account from the list granting this role to the originating account (templated). """ + template_fields = ("location", "project_id", "impersonation_chain", "prompt") + def __init__( self, *, @@ -116,11 +118,11 @@ class GenerateTextEmbeddingsOperator(GoogleCloudBaseOperator): Uses the Vertex AI PaLM API to generate natural language text. :param project_id: Required. The ID of the Google Cloud project that the - service belongs to. + service belongs to (templated). :param location: Required. The ID of the Google Cloud location that the - service belongs to. + service belongs to (templated). :param prompt: Required. Inputs or queries that a user or a program gives - to the Vertex AI PaLM API, in order to elicit a specific response. + to the Vertex AI PaLM API, in order to elicit a specific response (templated). :param pretrained_model: By default uses the pre-trained model `textembedding-gecko`, optimized for performing text embeddings. :param gcp_conn_id: The connection ID to use connecting to Google Cloud. @@ -134,6 +136,8 @@ class GenerateTextEmbeddingsOperator(GoogleCloudBaseOperator): account from the list granting this role to the originating account (templated). """ + template_fields = ("location", "project_id", "impersonation_chain", "prompt") + def __init__( self, *, @@ -178,11 +182,11 @@ class PromptMultimodalModelOperator(GoogleCloudBaseOperator): Use the Vertex AI Gemini Pro foundation model to generate natural language text. :param project_id: Required. The ID of the Google Cloud project that the - service belongs to. + service belongs to (templated). :param location: Required. The ID of the Google Cloud location that the - service belongs to. + service belongs to (templated). :param prompt: Required. Inputs or queries that a user or a program gives - to the Multi-modal model, in order to elicit a specific response. + to the Multi-modal model, in order to elicit a specific response (templated). :param pretrained_model: By default uses the pre-trained model `gemini-pro`, supporting prompts with text-only input, including natural language tasks, multi-turn text and code chat, and code generation. It can @@ -198,6 +202,8 @@ class PromptMultimodalModelOperator(GoogleCloudBaseOperator): account from the list granting this role to the originating account (templated). """ + template_fields = ("location", "project_id", "impersonation_chain", "prompt") + def __init__( self, *, @@ -240,11 +246,11 @@ class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator): Use the Vertex AI Gemini Pro foundation model to generate natural language text. :param project_id: Required. The ID of the Google Cloud project that the - service belongs to. + service belongs to (templated). :param location: Required. The ID of the Google Cloud location that the - service belongs to. + service belongs to (templated). :param prompt: Required. Inputs or queries that a user or a program gives - to the Multi-modal model, in order to elicit a specific response. + to the Multi-modal model, in order to elicit a specific response (templated). :param pretrained_model: By default uses the pre-trained model `gemini-pro-vision`, supporting prompts with text-only input, including natural language tasks, multi-turn text and code chat, and code generation. It can @@ -263,6 +269,8 @@ class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator): account from the list granting this role to the originating account (templated). """ + template_fields = ("location", "project_id", "impersonation_chain", "prompt") + def __init__( self, *,