diff --git a/libs/vertexai/tests/integration_tests/test_embeddings.py b/libs/vertexai/tests/integration_tests/test_embeddings.py index af109ed3..ffbcaa2c 100644 --- a/libs/vertexai/tests/integration_tests/test_embeddings.py +++ b/libs/vertexai/tests/integration_tests/test_embeddings.py @@ -9,7 +9,6 @@ from vertexai.vision_models import MultiModalEmbeddingModel # type: ignore from langchain_google_vertexai.embeddings import ( - EmbeddingTaskTypes, GoogleEmbeddingModelType, VertexAIEmbeddings, ) @@ -19,11 +18,6 @@ ("multimodalembedding@001", 1408), ] -_EMBEDDING_TASK_TYPES = [ - "RETRIEVAL_QUERY", - "RETRIEVAL_DOCUMENT", -] - @pytest.mark.release def test_initialization() -> None: @@ -60,22 +54,17 @@ def test_langchain_google_vertexai_embedding_documents( @pytest.mark.release -@pytest.mark.parametrize( - "embeddings_task_type", - _EMBEDDING_TASK_TYPES, -) @pytest.mark.parametrize( "model_name, embeddings_dim", _EMBEDDING_MODELS, ) def test_langchain_google_vertexai_embedding_documents_with_task_type( - embeddings_task_type: EmbeddingTaskTypes, model_name: str, embeddings_dim: int, ) -> None: documents = ["foo bar"] * 8 model = VertexAIEmbeddings(model_name) - output = model.embed_documents(documents, embeddings_task_type=embeddings_task_type) + output = model.embed_documents(documents) assert len(output) == 8 for embedding in output: assert len(embedding) == embeddings_dim @@ -96,22 +85,17 @@ def test_langchain_google_vertexai_embedding_query(model_name, embeddings_dim) - @pytest.mark.release -@pytest.mark.parametrize( - "embeddings_task_type", - _EMBEDDING_TASK_TYPES, -) @pytest.mark.parametrize( "model_name, embeddings_dim", _EMBEDDING_MODELS, ) def test_langchain_google_vertexai_embedding_query_with_task_type( - embeddings_task_type: EmbeddingTaskTypes, model_name: str, embeddings_dim: int, ) -> None: document = "foo bar" model = VertexAIEmbeddings(model_name) - output = model.embed_query(document, embeddings_task_type=embeddings_task_type) + output = model.embed_query(document) assert len(output) == embeddings_dim diff --git a/libs/vertexai/tests/unit_tests/test_embeddings.py b/libs/vertexai/tests/unit_tests/test_embeddings.py index d41cc147..c5ab0fa2 100644 --- a/libs/vertexai/tests/unit_tests/test_embeddings.py +++ b/libs/vertexai/tests/unit_tests/test_embeddings.py @@ -1,5 +1,5 @@ from typing import Any, Dict -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest from pydantic import model_validator @@ -29,6 +29,49 @@ def test_langchain_google_vertexai_no_dups_dynamic_batch_size() -> None: assert len(batches) == 2 +@patch.object(VertexAIEmbeddings, "embed") +def test_embed_documents_with_question_answering_task(mock_embed) -> None: + mock_embeddings = MockVertexAIEmbeddings("text-embedding-005") + texts = [f"text {i}" for i in range(5)] + + embedding_dimension = 768 + embeddings_task_type = "QUESTION_ANSWERING" + + mock_embed.return_value = [[0.001] * embedding_dimension for _ in texts] + + embeddings = mock_embeddings.embed_documents( + texts=texts, embeddings_task_type=embeddings_task_type + ) + + assert isinstance(embeddings, list) + assert len(embeddings) == len(texts) + assert len(embeddings[0]) == embedding_dimension + + # Verify embed() was called correctly + mock_embed.assert_called_once_with(texts, 0, embeddings_task_type) + + +@patch.object(VertexAIEmbeddings, "embed") +def test_embed_query_with_question_answering_task(mock_embed) -> None: + mock_embeddings = MockVertexAIEmbeddings("text-embedding-005") + text = "text 0" + + embedding_dimension = 768 + embeddings_task_type = "QUESTION_ANSWERING" + + mock_embed.return_value = [[0.001] * embedding_dimension] + + embedding = mock_embeddings.embed_query( + text=text, embeddings_task_type=embeddings_task_type + ) + + assert isinstance(embedding, list) + assert len(embedding) == embedding_dimension + + # Verify embed() was called correctly + mock_embed.assert_called_once_with([text], 1, embeddings_task_type) + + class MockVertexAIEmbeddings(VertexAIEmbeddings): """ A mock class for avoiding instantiating VertexAI and the EmbeddingModel client