Skip to content

Commit

Permalink
feat: LLM - TextEmbeddingModel - Added support for structural inputs …
Browse files Browse the repository at this point in the history
…(`TextEmbeddingInput`), `auto_truncate` parameter and result `statistics`

PiperOrigin-RevId: 558465128
  • Loading branch information
Ark-kun authored and copybara-github committed Aug 20, 2023
1 parent 76b95b9 commit cbf9b6e
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 24 deletions.
16 changes: 11 additions & 5 deletions tests/system/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,17 @@ def test_text_embedding(self):
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

model = TextEmbeddingModel.from_pretrained("google/textembedding-gecko@001")
embeddings = model.get_embeddings(["What is life?"])
assert embeddings
for embedding in embeddings:
vector = embedding.values
assert len(vector) == 768
# One short text, one long text (to check truncation)
texts = ["What is life?", "What is life?" * 1000]
embeddings = model.get_embeddings(texts)
assert len(embeddings) == 2
assert len(embeddings[0].values) == 768
assert embeddings[0].statistics.token_count > 0
assert not embeddings[0].statistics.truncated

assert len(embeddings[1].values) == 768
assert embeddings[1].statistics.token_count > 1000
assert embeddings[1].statistics.truncated

def test_tuning(self, shared_state):
"""Test tuning, listing and loading models."""
Expand Down
46 changes: 43 additions & 3 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def reverse_string_2(s):""",
# Canned prediction payload returned by the mocked PredictionService:
# one embedding vector plus its per-text token statistics.
_TEST_TEXT_EMBEDDING_PREDICTION = {
    "embeddings": {
        # [1.0] * N already builds a fresh list; the redundant list() wrapper
        # was removed.
        "values": [1.0] * _TEXT_EMBEDDING_VECTOR_LENGTH,
        "statistics": {"truncated": False, "token_count": 4.0},
    }
}

Expand Down Expand Up @@ -2170,18 +2171,57 @@ def test_text_embedding(self):

gca_predict_response = gca_prediction_service.PredictResponse()
gca_predict_response.predictions.append(_TEST_TEXT_EMBEDDING_PREDICTION)
gca_predict_response.predictions.append(_TEST_TEXT_EMBEDDING_PREDICTION)

expected_embedding = _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]
with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
attribute="predict",
return_value=gca_predict_response,
):
embeddings = model.get_embeddings(["What is life?"])
) as mock_predict:
embeddings = model.get_embeddings(
[
"What is life?",
language_models.TextEmbeddingInput(
text="Foo",
task_type="RETRIEVAL_DOCUMENT",
title="Bar",
),
language_models.TextEmbeddingInput(
text="Baz",
task_type="CLASSIFICATION",
),
],
auto_truncate=False,
)
prediction_instances = mock_predict.call_args[1]["instances"]
assert prediction_instances == [
{"content": "What is life?"},
{
"content": "Foo",
"taskType": "RETRIEVAL_DOCUMENT",
"title": "Bar",
},
{
"content": "Baz",
"taskType": "CLASSIFICATION",
},
]
prediction_parameters = mock_predict.call_args[1]["parameters"]
assert not prediction_parameters["autoTruncate"]
assert embeddings
for embedding in embeddings:
vector = embedding.values
assert len(vector) == _TEXT_EMBEDDING_VECTOR_LENGTH
assert vector == _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]["values"]
assert vector == expected_embedding["values"]
assert (
embedding.statistics.token_count
== expected_embedding["statistics"]["token_count"]
)
assert (
embedding.statistics.truncated
== expected_embedding["statistics"]["truncated"]
)

def test_text_embedding_ga(self):
"""Tests the text embedding model."""
Expand Down
2 changes: 2 additions & 0 deletions vertexai/language_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
CodeGenerationModel,
InputOutputTextPair,
TextEmbedding,
TextEmbeddingInput,
TextEmbeddingModel,
TextGenerationModel,
TextGenerationResponse,
Expand All @@ -37,6 +38,7 @@
"CodeGenerationModel",
"InputOutputTextPair",
"TextEmbedding",
"TextEmbeddingInput",
"TextEmbeddingModel",
"TextGenerationModel",
"TextGenerationResponse",
Expand Down
97 changes: 81 additions & 16 deletions vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,8 +692,33 @@ def send_message(
return response_obj


@dataclasses.dataclass
class TextEmbeddingInput:
    """Structural text embedding input.

    Attributes:
        text: The main text content to embed.
        task_type: The name of the downstream task the embeddings will be used for.
            Valid values:
            RETRIEVAL_QUERY
                Specifies the given text is a query in a search/retrieval setting.
            RETRIEVAL_DOCUMENT
                Specifies the given text is a document from the corpus being searched.
            SEMANTIC_SIMILARITY
                Specifies the given text will be used for STS (semantic textual
                similarity).
            CLASSIFICATION
                Specifies that the given text will be classified.
            CLUSTERING
                Specifies that the embeddings will be used for clustering.
        title: Optional identifier of the text content.
    """

    text: str
    task_type: Optional[str] = None
    title: Optional[str] = None


class TextEmbeddingModel(_LanguageModel):
"""TextEmbeddingModel converts text into a vector of floating-point numbers.
"""TextEmbeddingModel class calculates embeddings for the given texts.
Examples::
Expand All @@ -711,36 +736,76 @@ class TextEmbeddingModel(_LanguageModel):
"gs://google-cloud-aiplatform/schema/predict/instance/text_embedding_1.0.0.yaml"
)

def get_embeddings(self, texts: List[str]) -> List["TextEmbedding"]:
instances = [{"content": str(text)} for text in texts]
def get_embeddings(
    self,
    texts: List[Union[str, TextEmbeddingInput]],
    *,
    auto_truncate: bool = True,
) -> List["TextEmbedding"]:
    """Calculates embeddings for the given texts.

    Args:
        texts: A list of texts or `TextEmbeddingInput` objects to embed.
        auto_truncate: Whether to automatically truncate long texts.
            Default: True.

    Returns:
        A list of `TextEmbedding` objects, one per input text, in order.

    Raises:
        TypeError: If an element of ``texts`` is neither a `str` nor a
            `TextEmbeddingInput`.
    """
    instances = []
    for text in texts:
        if isinstance(text, TextEmbeddingInput):
            instance = {"content": text.text}
            # Optional structural fields are only sent when set, matching
            # the service's camelCase request schema.
            if text.task_type:
                instance["taskType"] = text.task_type
            if text.title:
                instance["title"] = text.title
        elif isinstance(text, str):
            instance = {"content": text}
        else:
            raise TypeError(f"Unsupported text embedding input type: {text}.")
        instances.append(instance)
    parameters = {"autoTruncate": auto_truncate}

    prediction_response = self._endpoint.predict(
        instances=instances,
        parameters=parameters,
    )

    results = []
    for prediction in prediction_response.predictions:
        embeddings = prediction["embeddings"]
        statistics = embeddings["statistics"]
        result = TextEmbedding(
            values=embeddings["values"],
            statistics=TextEmbeddingStatistics(
                token_count=statistics["token_count"],
                truncated=statistics["truncated"],
            ),
            # The full raw response is attached to every embedding for
            # debugging purposes.
            _prediction_response=prediction_response,
        )
        results.append(result)

    return results


class _PreviewTextEmbeddingModel(TextEmbeddingModel, _ModelWithBatchPredict):
    """Preview variant of `TextEmbeddingModel`; mixes in `_ModelWithBatchPredict`."""

    _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE


@dataclasses.dataclass
class TextEmbeddingStatistics:
    """Text embedding statistics.

    Attributes:
        token_count: Number of tokens in the input text as counted by the
            service.
        truncated: Whether the input text was truncated before embedding.
    """

    token_count: int
    truncated: bool


@dataclasses.dataclass
class TextEmbedding:
    """Text embedding vector and statistics.

    Attributes:
        values: The embedding vector values.
        statistics: Service-reported statistics (token count, truncation flag)
            for the embedded text.
    """

    values: List[float]
    statistics: TextEmbeddingStatistics
    # Raw service response, kept for debugging; private by underscore convention.
    _prediction_response: aiplatform.models.Prediction = None


@dataclasses.dataclass
Expand Down
2 changes: 2 additions & 0 deletions vertexai/preview/language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
CodeChatSession,
InputOutputTextPair,
TextEmbedding,
TextEmbeddingInput,
TextGenerationResponse,
)

Expand Down Expand Up @@ -60,6 +61,7 @@
"EvaluationTextClassificationSpec",
"InputOutputTextPair",
"TextEmbedding",
"TextEmbeddingInput",
"TextEmbeddingModel",
"TextGenerationModel",
"TextGenerationResponse",
Expand Down

0 comments on commit cbf9b6e

Please sign in to comment.