Skip to content

Commit

Permalink
fix: Moved tests into unit_tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Anshuman Mitra committed Jan 29, 2025
1 parent e9fa377 commit 9caf9ab
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 19 deletions.
20 changes: 2 additions & 18 deletions libs/vertexai/tests/integration_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from vertexai.vision_models import MultiModalEmbeddingModel # type: ignore

from langchain_google_vertexai.embeddings import (
EmbeddingTaskTypes,
GoogleEmbeddingModelType,
VertexAIEmbeddings,
)
Expand All @@ -19,11 +18,6 @@
("multimodalembedding@001", 1408),
]

_EMBEDDING_TASK_TYPES = [
"RETRIEVAL_QUERY",
"RETRIEVAL_DOCUMENT",
]


@pytest.mark.release
def test_initialization() -> None:
Expand Down Expand Up @@ -60,22 +54,17 @@ def test_langchain_google_vertexai_embedding_documents(


@pytest.mark.release
@pytest.mark.parametrize(
"embeddings_task_type",
_EMBEDDING_TASK_TYPES,
)
@pytest.mark.parametrize(
"model_name, embeddings_dim",
_EMBEDDING_MODELS,
)
def test_langchain_google_vertexai_embedding_documents_with_task_type(
embeddings_task_type: EmbeddingTaskTypes,
model_name: str,
embeddings_dim: int,
) -> None:
documents = ["foo bar"] * 8
model = VertexAIEmbeddings(model_name)
output = model.embed_documents(documents, embeddings_task_type=embeddings_task_type)
output = model.embed_documents(documents)
assert len(output) == 8
for embedding in output:
assert len(embedding) == embeddings_dim
Expand All @@ -96,22 +85,17 @@ def test_langchain_google_vertexai_embedding_query(model_name, embeddings_dim) -


@pytest.mark.release
@pytest.mark.parametrize(
"embeddings_task_type",
_EMBEDDING_TASK_TYPES,
)
@pytest.mark.parametrize(
"model_name, embeddings_dim",
_EMBEDDING_MODELS,
)
def test_langchain_google_vertexai_embedding_query_with_task_type(
embeddings_task_type: EmbeddingTaskTypes,
model_name: str,
embeddings_dim: int,
) -> None:
document = "foo bar"
model = VertexAIEmbeddings(model_name)
output = model.embed_query(document, embeddings_task_type=embeddings_task_type)
output = model.embed_query(document)
assert len(output) == embeddings_dim


Expand Down
45 changes: 44 additions & 1 deletion libs/vertexai/tests/unit_tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Any, Dict
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest
from pydantic import model_validator
Expand Down Expand Up @@ -29,6 +29,49 @@ def test_langchain_google_vertexai_no_dups_dynamic_batch_size() -> None:
assert len(batches) == 2


@patch.object(VertexAIEmbeddings, "embed")
def test_embed_documents_with_question_answering_task(mock_embed) -> None:
mock_embeddings = MockVertexAIEmbeddings("text-embedding-005")
texts = [f"text {i}" for i in range(5)]

embedding_dimension = 768
embeddings_task_type = "QUESTION_ANSWERING"

mock_embed.return_value = [[0.001] * embedding_dimension for _ in texts]

embeddings = mock_embeddings.embed_documents(
texts=texts, embeddings_task_type=embeddings_task_type
)

assert isinstance(embeddings, list)
assert len(embeddings) == len(texts)
assert len(embeddings[0]) == embedding_dimension

# Verify embed() was called correctly
mock_embed.assert_called_once_with(texts, 0, embeddings_task_type)


@patch.object(VertexAIEmbeddings, "embed")
def test_embed_query_with_question_answering_task(mock_embed) -> None:
mock_embeddings = MockVertexAIEmbeddings("text-embedding-005")
text = "text 0"

embedding_dimension = 768
embeddings_task_type = "QUESTION_ANSWERING"

mock_embed.return_value = [[0.001] * embedding_dimension]

embedding = mock_embeddings.embed_query(
text=text, embeddings_task_type=embeddings_task_type
)

assert isinstance(embedding, list)
assert len(embedding) == embedding_dimension

# Verify embed() was called correctly
mock_embed.assert_called_once_with([text], 1, embeddings_task_type)


class MockVertexAIEmbeddings(VertexAIEmbeddings):
"""
A mock class for avoiding instantiating VertexAI and the EmbeddingModel client
Expand Down

0 comments on commit 9caf9ab

Please sign in to comment.