Skip to content

Commit

Permalink
feat: GenAI - Batch Prediction - Added support for tuned GenAI models
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 646136098
  • Loading branch information
jaycee-li authored and copybara-github committed Jun 24, 2024
1 parent a31ac4d commit a90ee8d
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 17 deletions.
126 changes: 124 additions & 2 deletions tests/unit/vertexai/test_batch_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,15 @@
import vertexai
from google.cloud.aiplatform import base as aiplatform_base
from google.cloud.aiplatform import initializer as aiplatform_initializer
from google.cloud.aiplatform.compat.services import job_service_client
from google.cloud.aiplatform.compat.services import (
job_service_client,
model_service_client,
)
from google.cloud.aiplatform.compat.types import (
batch_prediction_job as gca_batch_prediction_job_compat,
io as gca_io_compat,
job_state as gca_job_state_compat,
model as gca_model,
)
from vertexai.preview import batch_prediction
from vertexai.generative_models import GenerativeModel
Expand All @@ -43,6 +47,7 @@

# Short publisher model name and the publisher resource name derived from it.
_TEST_GEMINI_MODEL_NAME = "gemini-1.0-pro"
_TEST_GEMINI_MODEL_RESOURCE_NAME = f"publishers/google/models/{_TEST_GEMINI_MODEL_NAME}"
# Project-scoped model resource name used by the tuned-model fixtures and tests.
_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME = "projects/123/locations/us-central1/models/456"
# Publisher model whose name does not contain "gemini"; used for rejection tests.
_TEST_PALM_MODEL_NAME = "text-bison"
_TEST_PALM_MODEL_RESOURCE_NAME = f"publishers/google/models/{_TEST_PALM_MODEL_NAME}"

Expand Down Expand Up @@ -122,6 +127,48 @@ def get_batch_prediction_job_with_gcs_output_mock():
yield get_job_mock


@pytest.fixture
def get_batch_prediction_job_with_tuned_gemini_model_mock():
    """Patch ``JobServiceClient.get_batch_prediction_job`` for a tuned-model job.

    The patched call returns a job whose ``model`` is the project-scoped
    test resource name, whose state is ``_TEST_JOB_STATE_SUCCESS`` and whose
    output info points at the test GCS prefix. Yields the mock so tests can
    assert on how it was called.
    """
    tuned_model_job = gca_batch_prediction_job_compat.BatchPredictionJob(
        name=_TEST_BATCH_PREDICTION_JOB_NAME,
        display_name=_TEST_DISPLAY_NAME,
        model=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME,
        state=_TEST_JOB_STATE_SUCCESS,
        output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo(
            gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX
        ),
    )
    with mock.patch.object(
        job_service_client.JobServiceClient,
        "get_batch_prediction_job",
        return_value=tuned_model_job,
    ) as patched_get_job:
        yield patched_get_job


@pytest.fixture
def get_gemini_model_mock():
    """Patch ``ModelServiceClient.get_model`` to return a GENIE-sourced model.

    The returned model carries the project-scoped test resource name and a
    ``model_source_info.source_type`` of ``GENIE``. Yields the mock so tests
    can assert on how it was called.
    """
    genie_sourced_model = gca_model.Model(
        name=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME,
        model_source_info=gca_model.ModelSourceInfo(
            source_type=gca_model.ModelSourceInfo.ModelSourceType.GENIE
        ),
    )
    with mock.patch.object(
        model_service_client.ModelServiceClient,
        "get_model",
        return_value=genie_sourced_model,
    ) as patched_get_model:
        yield patched_get_model


@pytest.fixture
def get_non_gemini_model_mock():
    """Patch ``ModelServiceClient.get_model`` to return a model without source info.

    The returned model only has its ``name`` set, so its
    ``model_source_info`` is left at the message default. Yields the mock.
    """
    model_without_source_info = gca_model.Model(
        name=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME,
    )
    with mock.patch.object(
        model_service_client.ModelServiceClient,
        "get_model",
        return_value=model_without_source_info,
    ) as patched_get_model:
        yield patched_get_model


@pytest.fixture
def get_batch_prediction_job_invalid_model_mock():
with mock.patch.object(
Expand Down Expand Up @@ -205,6 +252,21 @@ def test_init_batch_prediction_job(
name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY
)

def test_init_batch_prediction_job_with_tuned_gemini_model(
    self,
    get_batch_prediction_job_with_tuned_gemini_model_mock,
    get_gemini_model_mock,
):
    """Constructing a job backed by a tuned model looks up the job and the model once each."""
    job_lookup = get_batch_prediction_job_with_tuned_gemini_model_mock
    model_lookup = get_gemini_model_mock
    default_retry = aiplatform_base._DEFAULT_RETRY

    batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)

    job_lookup.assert_called_once_with(
        name=_TEST_BATCH_PREDICTION_JOB_NAME,
        retry=default_retry,
    )
    model_lookup.assert_called_once_with(
        name=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME,
        retry=default_retry,
    )

@pytest.mark.usefixtures("get_batch_prediction_job_invalid_model_mock")
def test_init_batch_prediction_job_invalid_model(self):
with pytest.raises(
Expand All @@ -217,6 +279,23 @@ def test_init_batch_prediction_job_invalid_model(self):
):
batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)

@pytest.mark.usefixtures(
    "get_batch_prediction_job_with_tuned_gemini_model_mock",
    "get_non_gemini_model_mock",
)
def test_init_batch_prediction_job_with_invalid_tuned_model(self):
    """Constructing a job whose project-scoped model is not GENIE-sourced raises ValueError."""
    expected_message = (
        f"BatchPredictionJob '{_TEST_BATCH_PREDICTION_JOB_ID}' "
        f"runs with the model '{_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME}', "
        "which is not a GenAI model."
    )

    with pytest.raises(ValueError, match=expected_message):
        batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)

@pytest.mark.usefixtures("get_batch_prediction_job_with_gcs_output_mock")
def test_submit_batch_prediction_job_with_gcs_input(
self, create_batch_prediction_job_mock
Expand Down Expand Up @@ -368,16 +447,59 @@ def test_submit_batch_prediction_job_with_bq_input_without_output_uri_prefix(
timeout=None,
)

@pytest.mark.usefixtures("create_batch_prediction_job_mock")
def test_submit_batch_prediction_job_with_tuned_model(
    self,
    get_gemini_model_mock,
):
    """Submitting with a tuned-model resource name succeeds and looks the model up once."""
    submit_kwargs = {
        "source_model": _TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME,
        "input_dataset": _TEST_BQ_INPUT_URI,
    }

    submitted_job = batch_prediction.BatchPredictionJob.submit(**submit_kwargs)

    assert submitted_job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB
    get_gemini_model_mock.assert_called_once_with(
        name=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME,
        retry=aiplatform_base._DEFAULT_RETRY,
    )

def test_submit_batch_prediction_job_with_invalid_source_model(self):
    """Submitting with a publisher model that is not a Gemini model raises ValueError.

    The short name ``_TEST_PALM_MODEL_NAME`` is passed in, while the error
    message is expected to quote the publisher resource name
    ``_TEST_PALM_MODEL_RESOURCE_NAME``.
    """
    # Only one `match=` may be passed: the stale "is not a GenAI model."
    # expectation is dropped because `submit` now raises the
    # "is not a Generative AI model." wording.
    with pytest.raises(
        ValueError,
        match=(
            f"Model '{_TEST_PALM_MODEL_RESOURCE_NAME}' is not a Generative AI model."
        ),
    ):
        batch_prediction.BatchPredictionJob.submit(
            source_model=_TEST_PALM_MODEL_NAME,
            input_dataset=_TEST_GCS_INPUT_URI,
        )

@pytest.mark.usefixtures("get_non_gemini_model_mock")
def test_submit_batch_prediction_job_with_non_gemini_tuned_model(self):
    """Submitting with a project-scoped model that is not GENIE-sourced raises ValueError."""
    expected_message = (
        f"Model '{_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME}' "
        "is not a Generative AI model."
    )

    with pytest.raises(ValueError, match=expected_message):
        batch_prediction.BatchPredictionJob.submit(
            source_model=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME,
            input_dataset=_TEST_GCS_INPUT_URI,
        )

def test_submit_batch_prediction_job_with_invalid_model_name(self):
    """Submitting with a model name in an unsupported format raises ValueError."""
    invalid_model_name = "invalid/model/name"
    expected_message = f"Invalid format for model name: {invalid_model_name}."

    with pytest.raises(ValueError, match=expected_message):
        batch_prediction.BatchPredictionJob.submit(
            source_model=invalid_model_name,
            input_dataset=_TEST_GCS_INPUT_URI,
        )

def test_submit_batch_prediction_job_with_invalid_input_dataset(self):
with pytest.raises(
ValueError,
Expand Down
54 changes: 39 additions & 15 deletions vertexai/batch_prediction/_batch_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from google.cloud.aiplatform import base as aiplatform_base
from google.cloud.aiplatform import initializer as aiplatform_initializer
from google.cloud.aiplatform import jobs
from google.cloud.aiplatform import models
from google.cloud.aiplatform import utils as aiplatform_utils
from google.cloud.aiplatform_v1 import types as gca_types
from vertexai import generative_models
Expand All @@ -32,6 +33,7 @@
_LOGGER = aiplatform_base.Logger(__name__)

# Unanchored: applied with `re.search`, so it matches any model name that
# contains this publisher prefix followed by "gemini".
_GEMINI_MODEL_PATTERN = r"publishers/google/models/gemini"
# Anchored at both ends: a project-scoped model resource name whose project
# and model segments are digits only and whose location segment is lowercase
# letters, digits and hyphens.
_GEMINI_TUNED_MODEL_PATTERN = r"^projects/[0-9]+?/locations/[0-9a-z-]+?/models/[0-9]+?$"


class BatchPredictionJob(aiplatform_base._VertexAiResourceNounPlus):
Expand Down Expand Up @@ -64,8 +66,7 @@ def __init__(self, batch_prediction_job_name: str):
self._gca_resource = self._get_gca_resource(
resource_name=batch_prediction_job_name
)
# TODO(b/338452508) Support tuned GenAI models.
if not re.search(_GEMINI_MODEL_PATTERN, self.model_name):
if not self._is_genai_model(self.model_name):
raise ValueError(
f"BatchPredictionJob '{batch_prediction_job_name}' "
f"runs with the model '{self.model_name}', "
Expand Down Expand Up @@ -117,9 +118,12 @@ def submit(
Args:
source_model (Union[str, generative_models.GenerativeModel]):
Model name or a GenerativeModel instance for batch prediction.
Supported formats: "gemini-1.0-pro", "models/gemini-1.0-pro",
and "publishers/google/models/gemini-1.0-pro"
A GenAI model name or a tuned model name or a GenerativeModel instance
for batch prediction.
Supported formats for model name: "gemini-1.0-pro",
"models/gemini-1.0-pro", and "publishers/google/models/gemini-1.0-pro"
Supported formats for tuned model name: "789" and
"projects/123/locations/456/models/789"
input_dataset (Union[str,List[str]]):
GCS URI(-s) or Bigquery URI to your input data to run batch
prediction on. Example: "gs://path/to/input/data.jsonl" or
Expand All @@ -142,12 +146,13 @@ def submit(
set in vertexai.init().
"""
# Handle model name
# TODO(b/338452508) Support tuned GenAI models.
model_name = cls._reconcile_model_name(
source_model._model_name
if isinstance(source_model, generative_models.GenerativeModel)
else source_model
)
if not cls._is_genai_model(model_name):
raise ValueError(f"Model '{model_name}' is not a Generative AI model.")

# Handle input URI
gcs_source = None
Expand Down Expand Up @@ -244,9 +249,7 @@ def delete(self):
def list(cls, filter=None) -> List["BatchPredictionJob"]:
    """Lists all BatchPredictionJob instances that run with GenAI models.

    Args:
        filter: Passed through unchanged to ``cls._list`` as its ``filter``
            argument.

    Returns:
        The result of ``cls._list``, restricted to jobs whose ``model``
        field is accepted by ``cls._is_genai_model``.
    """
    # `cls_filter` must be passed exactly once. The regex-only filter is
    # dropped because it would exclude jobs that ran on tuned models.
    # NOTE(review): for project-scoped model names `_is_genai_model` loads
    # the model resource, so this may issue one lookup per listed job.
    return cls._list(
        cls_filter=lambda gca_resource: cls._is_genai_model(gca_resource.model),
        filter=filter,
    )

Expand All @@ -263,23 +266,44 @@ def _dashboard_uri(self) -> Optional[str]:

@classmethod
def _reconcile_model_name(cls, model_name: str) -> str:
    """Reconciles model name to a publisher model resource name or a tuned model resource name.

    Only the *format* of the name is checked here; whether the model is a
    GenAI model is decided separately by ``_is_genai_model``.

    Args:
        model_name: One of
            - a bare model name, e.g. "gemini-1.0-pro";
            - a name starting with "models/", e.g. "models/gemini-1.0-pro";
            - a name starting with "publishers/google/models/";
            - a project-scoped resource name matching
              ``_GEMINI_TUNED_MODEL_PATTERN``.

    Returns:
        The name expanded to start with "publishers/google/models/" for the
        first two forms, or the input unchanged for the last two.

    Raises:
        ValueError: If ``model_name`` is empty, or contains "/" but matches
            none of the supported formats.
    """
    if not model_name:
        raise ValueError("model_name must not be empty")

    if "/" not in model_name:
        # model name (e.g., gemini-1.0-pro)
        # NOTE(review): a bare numeric tuned-model ID (e.g. "789") also takes
        # this branch and is expanded to a publisher model name, not to a
        # project-scoped one — confirm whether bare IDs should be supported.
        model_name = "publishers/google/models/" + model_name
    elif model_name.startswith("models/"):
        # publisher model name (e.g., models/gemini-1.0-pro)
        model_name = "publishers/google/" + model_name
    elif (
        # publisher model full name
        not model_name.startswith("publishers/google/models/")
        # tuned model full resource name
        and not re.search(_GEMINI_TUNED_MODEL_PATTERN, model_name)
    ):
        raise ValueError(f"Invalid format for model name: {model_name}.")

    # No Gemini-pattern check here: it would reject every project-scoped
    # (tuned) model name accepted above.
    return model_name

@classmethod
def _is_genai_model(cls, model_name: str) -> bool:
    """Reports whether ``model_name`` refers to a GenAI model.

    Args:
        model_name: A model resource name.

    Returns:
        True if the name matches ``_GEMINI_MODEL_PATTERN``, or if it matches
        ``_GEMINI_TUNED_MODEL_PATTERN`` and the loaded model's
        ``model_source_info.source_type`` is ``GENIE``. False otherwise.
    """
    # Publisher Gemini models are recognised from the name alone.
    if re.search(_GEMINI_MODEL_PATTERN, model_name):
        return True

    # Anything that is not a project-scoped model name cannot be a tuned model.
    if not re.search(_GEMINI_TUNED_MODEL_PATTERN, model_name):
        return False

    # A tuned model has to be loaded so its source type can be inspected.
    tuned_model = models.Model(model_name)
    source_type = tuned_model.gca_resource.model_source_info.source_type
    genie_source = gca_types.model.ModelSourceInfo.ModelSourceType.GENIE
    if source_type == genie_source:
        return True
    return False

@classmethod
def _complete_bq_uri(cls, uri: Optional[str] = None):
"""Completes a BigQuery uri to a BigQuery table uri."""
Expand Down

0 comments on commit a90ee8d

Please sign in to comment.