Skip to content

Commit 0182667

Browse files
committed
feat: add mock embedding model for testing embedding providers
1 parent 51491e3 commit 0182667

File tree

1 file changed

+161
-0
lines changed

1 file changed

+161
-0
lines changed

tests/test_embeddign_providers.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import asyncio
17+
from typing import List
18+
19+
import pytest
20+
21+
from nemoguardrails.embeddings.providers import (
22+
init_embedding_model,
23+
register_embedding_provider,
24+
)
25+
from nemoguardrails.embeddings.providers.base import EmbeddingModel
26+
27+
# Keyword parameters that MockEmbeddingModel accepts; any other name passed to
# its constructor raises ValueError.
SUPPORTED_PARAMS = {"param1", "param2"}
28+
29+
30+
class MockEmbeddingModel(EmbeddingModel):
    """Deterministic stand-in embedding model used by the provider tests.

    Recognized model names and their embedding sizes:
        - mock-embedding-small: 128
        - mock-embedding-large: 256

    Accepted keyword parameters are the names in ``SUPPORTED_PARAMS``
    (param1 and param2).

    Args:
        embedding_model (str): The name of the embedding model.
        **kwargs: Extra embedding parameters; every key must be supported.

    Attributes:
        model (str): The name of the embedding model.
        embedding_size_dict (dict): Maps each known model name to its size.
        embedding_params (dict): The keyword parameters given at construction.
        embedding_size (int): The size of the embeddings.

    Raises:
        ValueError: If the model name is unknown, or if any keyword
            parameter is not supported.

    Methods:
        encode: Encode a list of documents into embeddings.
        encode_async: Run ``encode`` in the event loop's default executor.
    """

    engine_name = "mock_engine"

    def __init__(self, embedding_model: str, **kwargs):
        self.model = embedding_model
        self.embedding_size_dict = {
            "mock-embedding-small": 128,
            "mock-embedding-large": 256,
        }
        self.embedding_params = kwargs

        # The model name is validated before the parameters are inspected.
        known_sizes = self.embedding_size_dict
        if embedding_model not in known_sizes:
            raise ValueError(f"Invalid embedding model: {self.model}")

        # Report the first offending parameter, in the order it was passed.
        rejected = [name for name in kwargs if name not in SUPPORTED_PARAMS]
        if rejected:
            raise ValueError(f"Unsupported parameter: {rejected[0]}")

        self.embedding_size = known_sizes[embedding_model]

    async def encode_async(self, documents: List[str]) -> List[List[float]]:
        """Encode a list of documents into embeddings asynchronously.

        The synchronous ``encode`` is dispatched to the running loop's
        default executor and its result is awaited.

        Args:
            documents (List[str]): The list of documents to be encoded.

        Returns:
            List[List[float]]: The encoded embeddings.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.encode, documents)

    def encode(self, documents: List[str]) -> List[List[float]]:
        """Encode a list of documents into embeddings.

        Every document receives the same vector: 0.0, 1.0, ... up to
        ``embedding_size - 1``.

        Args:
            documents (List[str]): The list of documents to be encoded.

        Returns:
            List[List[float]]: The encoded embeddings.
        """
        template = list(map(float, range(self.embedding_size)))
        # Hand each document its own list so rows never alias one another.
        return [template.copy() for _ in documents]
96+
97+
98+
# Register the mock class at import time so the tests below can obtain it via
# init_embedding_model with the "mock_engine" engine name.
# NOTE(review): presumably the registry is keyed by `engine_name` — confirm.
register_embedding_provider(MockEmbeddingModel)
99+
100+
101+
def test_init_embedding_model_with_params():
    """Each supported parameter is accepted and stored on the model.

    The parameters are visited in sorted order rather than by taking an
    arbitrary element of the set: set iteration order for strings depends on
    per-process hash randomization, which made the exercised input vary
    between runs and left the other parameter untested.
    """
    embedding_model = "mock-embedding-small"
    embedding_engine = "mock_engine"
    for supported_param in sorted(SUPPORTED_PARAMS):
        embedding_params = {supported_param: "value1"}
        model = init_embedding_model(
            embedding_model, embedding_engine, embedding_params
        )
        assert isinstance(model, MockEmbeddingModel)
        assert model.model == embedding_model
        assert model.embedding_size == 128
        assert model.engine_name == embedding_engine
        assert model.embedding_params == embedding_params
112+
113+
114+
def test_init_embedding_model_without_params():
    """Initializing without extra parameters leaves embedding_params empty."""
    model_name, engine = "mock-embedding-large", "mock_engine"

    instance = init_embedding_model(model_name, engine)

    assert isinstance(instance, MockEmbeddingModel)
    assert instance.model == model_name
    assert instance.embedding_size == 256
    assert instance.engine_name == engine
    assert instance.embedding_params == {}
123+
124+
125+
def test_init_embedding_model_with_unsupported_params():
    """An unrecognized embedding parameter must raise ValueError."""
    bad_params = {"unsupported_param": "value"}
    expected_failure = pytest.raises(
        ValueError, match="Unsupported parameter: unsupported_param"
    )

    with expected_failure:
        init_embedding_model("mock-embedding-small", "mock_engine", bad_params)
131+
132+
133+
def test_init_embedding_model_with_invalid_model():
    """An unknown model name must raise ValueError, even with valid params."""
    valid_params = {"param1": "value1"}
    expected_failure = pytest.raises(
        ValueError, match="Invalid embedding model: invalid_model"
    )

    with expected_failure:
        init_embedding_model("invalid_model", "mock_engine", valid_params)
139+
140+
141+
def test_encode_method():
    """encode returns one embedding of the model's size per input document."""
    embedding_model = "mock-embedding-small"
    embedding_engine = "mock_engine"
    model = init_embedding_model(embedding_model, embedding_engine)
    assert isinstance(model, MockEmbeddingModel)
    documents = ["doc1", "doc2", "doc3"]
    embeddings = model.encode(documents)
    assert len(embeddings) == len(documents)
    # Check every vector, not only the first, so ragged output is caught.
    assert all(len(embedding) == model.embedding_size for embedding in embeddings)
150+
151+
152+
@pytest.mark.asyncio
async def test_encode_async_method():
    """encode_async returns one embedding of the model's size per document."""
    embedding_model = "mock-embedding-large"
    embedding_engine = "mock_engine"
    model = init_embedding_model(embedding_model, embedding_engine)
    assert isinstance(model, MockEmbeddingModel)
    documents = ["doc1", "doc2", "doc3"]
    embeddings = await model.encode_async(documents)
    assert len(embeddings) == len(documents)
    # Check every vector, not only the first, so ragged output is caught.
    assert all(len(embedding) == model.embedding_size for embedding in embeddings)

0 commit comments

Comments
 (0)