Skip to content

Commit

Permalink
feat: batch_predict method generally-available at TextEmbeddingModel.
Browse files Browse the repository at this point in the history
FUTURE_COPYBARA_INTEGRATE_REVIEW=#4425 from googleapis:release-please--branches--main 292499f
PiperOrigin-RevId: 676292399
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Sep 19, 2024
1 parent c0626fe commit 2a7cf3a
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
34 changes: 34 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4739,6 +4739,40 @@ def test_batch_prediction_for_code_generation(self):
)

def test_batch_prediction_for_text_embedding(self):
"""Tests batch prediction."""
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT
),
):
model = language_models.TextEmbeddingModel.from_pretrained(
"textembedding-gecko@001"
)

with mock.patch.object(
target=aiplatform.BatchPredictionJob,
attribute="create",
) as mock_create:
model.batch_predict(
dataset="gs://test-bucket/test_table.jsonl",
destination_uri_prefix="gs://test-bucket/results/",
model_parameters={},
)
mock_create.assert_called_once_with(
model_name=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/textembedding-gecko@001",
job_display_name=None,
gcs_source="gs://test-bucket/test_table.jsonl",
gcs_destination_prefix="gs://test-bucket/results/",
model_parameters={},
)

def test_batch_prediction_for_text_embedding_preview(self):
"""Tests batch prediction."""
aiplatform.init(
project=_TEST_PROJECT,
Expand Down
3 changes: 2 additions & 1 deletion vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2421,6 +2421,7 @@ class _TunableTextEmbeddingModelMixin(_PreviewTunableTextEmbeddingModelMixin):

class TextEmbeddingModel(
_TextEmbeddingModel,
_ModelWithBatchPredict,
_TunableTextEmbeddingModelMixin,
_CountTokensMixin,
):
Expand All @@ -2430,8 +2431,8 @@ class TextEmbeddingModel(
class _PreviewTextEmbeddingModel(
_TextEmbeddingModel,
_ModelWithBatchPredict,
_CountTokensMixin,
_PreviewTunableTextEmbeddingModelMixin,
_CountTokensMixin,
):
__name__ = "TextEmbeddingModel"
__module__ = "vertexai.preview.language_models"
Expand Down

0 comments on commit 2a7cf3a

Please sign in to comment.