diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py index 3ed03b66b63a2..871ab81d10d1e 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py @@ -35,7 +35,6 @@ from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook if TYPE_CHECKING: - from google.cloud.aiplatform_v1 import types as types_v1 from google.cloud.aiplatform_v1beta1 import types as types_v1beta1 @@ -50,7 +49,7 @@ def get_text_embedding_model(self, pretrained_model: str): def get_generative_model( self, pretrained_model: str, - system_instruction: str | None = None, + system_instruction: Any | None = None, generation_config: dict | None = None, safety_settings: dict | None = None, tools: list | None = None, @@ -82,7 +81,7 @@ def get_eval_task( def get_cached_context_model( self, cached_content_name: str, - ) -> preview_generative_model: + ) -> Any: """Return a Generative Model with Cached Context.""" cached_content = CachedContent(cached_content_name=cached_content_name) @@ -167,7 +166,7 @@ def supervised_fine_tuning_train( adapter_size: Literal[1, 4, 8, 16] | None = None, learning_rate_multiplier: float | None = None, project_id: str = PROVIDE_PROJECT_ID, - ) -> types_v1.TuningJob: + ) -> Any: """ Use the Supervised Fine Tuning API to create a tuning job. @@ -300,7 +299,7 @@ def create_cached_content( model_name: str, location: str, ttl_hours: float = 1, - system_instruction: str | None = None, + system_instruction: Any | None = None, contents: list[Any] | None = None, display_name: str | None = None, project_id: str = PROVIDE_PROJECT_ID, diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py index 686e1674078ac..20257dad196c8 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py @@ -473,7 +473,7 @@ def __init__( project_id: str, location: str, model_name: str, - system_instruction: str | None = None, + system_instruction: Any | None = None, contents: list[Any] | None = None, ttl_hours: float = 1, display_name: str | None = None,