diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py index f31e131b0a205..3ed03b66b63a2 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py @@ -21,15 +21,15 @@ import time from datetime import timedelta -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Literal import vertexai from google.cloud import aiplatform from vertexai.generative_models import GenerativeModel from vertexai.language_models import TextEmbeddingModel +from vertexai.preview import generative_models as preview_generative_model from vertexai.preview.caching import CachedContent from vertexai.preview.evaluation import EvalResult, EvalTask -from vertexai.preview.generative_models import GenerativeModel as preview_generative_model from vertexai.preview.tuning import sft from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook @@ -86,7 +86,7 @@ def get_cached_context_model( """Return a Generative Model with Cached Context.""" cached_content = CachedContent(cached_content_name=cached_content_name) - cached_context_model = preview_generative_model.from_cached_content(cached_content) + cached_context_model = preview_generative_model.GenerativeModel.from_cached_content(cached_content) return cached_context_model @GoogleBaseHook.fallback_to_default_project_id @@ -164,7 +164,7 @@ def supervised_fine_tuning_train( tuned_model_display_name: str | None = None, validation_dataset: str | None = None, epochs: int | None = None, - adapter_size: int | None = None, + adapter_size: Literal[1, 4, 8, 16] | None = None, learning_rate_multiplier: float | None = None, project_id: str = PROVIDE_PROJECT_ID, ) -> types_v1.TuningJob: @@ -301,7 +301,7 @@ def create_cached_content( location: str, ttl_hours: float = 1, system_instruction: str | None = None, - contents: list | None = None, + contents: list[Any] | None = None, display_name: str | None = None, project_id: str = PROVIDE_PROJECT_ID, ) -> str: diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py index 365efed44ff84..686e1674078ac 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py @@ -20,7 +20,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Literal from google.api_core import exceptions @@ -222,7 +222,7 @@ def __init__( tuned_model_display_name: str | None = None, validation_dataset: str | None = None, epochs: int | None = None, - adapter_size: int | None = None, + adapter_size: Literal[1, 4, 8, 16] | None = None, learning_rate_multiplier: float | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, @@ -474,7 +474,7 @@ def __init__( location: str, model_name: str, system_instruction: str | None = None, - contents: list | None = None, + contents: list[Any] | None = None, ttl_hours: float = 1, display_name: str | None = None, gcp_conn_id: str = "google_cloud_default",