Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@

import time
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Literal

import vertexai
from google.cloud import aiplatform
from vertexai.generative_models import GenerativeModel
from vertexai.language_models import TextEmbeddingModel
from vertexai.preview import generative_models as preview_generative_model
from vertexai.preview.caching import CachedContent
from vertexai.preview.evaluation import EvalResult, EvalTask
from vertexai.preview.generative_models import GenerativeModel as preview_generative_model
from vertexai.preview.tuning import sft

from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
Expand Down Expand Up @@ -86,7 +86,7 @@ def get_cached_context_model(
"""Return a Generative Model with Cached Context."""
cached_content = CachedContent(cached_content_name=cached_content_name)

cached_context_model = preview_generative_model.from_cached_content(cached_content)
cached_context_model = preview_generative_model.GenerativeModel.from_cached_content(cached_content)
return cached_context_model

@GoogleBaseHook.fallback_to_default_project_id
Expand Down Expand Up @@ -164,7 +164,7 @@ def supervised_fine_tuning_train(
tuned_model_display_name: str | None = None,
validation_dataset: str | None = None,
epochs: int | None = None,
adapter_size: int | None = None,
adapter_size: Literal[1, 4, 8, 16] | None = None,
learning_rate_multiplier: float | None = None,
project_id: str = PROVIDE_PROJECT_ID,
) -> types_v1.TuningJob:
Expand Down Expand Up @@ -301,7 +301,7 @@ def create_cached_content(
location: str,
ttl_hours: float = 1,
system_instruction: str | None = None,
contents: list | None = None,
contents: list[Any] | None = None,
display_name: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Literal

from google.api_core import exceptions

Expand Down Expand Up @@ -222,7 +222,7 @@ def __init__(
tuned_model_display_name: str | None = None,
validation_dataset: str | None = None,
epochs: int | None = None,
adapter_size: int | None = None,
adapter_size: Literal[1, 4, 8, 16] | None = None,
learning_rate_multiplier: float | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
Expand Down Expand Up @@ -474,7 +474,7 @@ def __init__(
location: str,
model_name: str,
system_instruction: str | None = None,
contents: list | None = None,
contents: list[Any] | None = None,
ttl_hours: float = 1,
display_name: str | None = None,
gcp_conn_id: str = "google_cloud_default",
Expand Down