Skip to content

Commit 2ede0f4

Browse files
authored
feat: add embedding_params to BasicEmbeddingsIndex (#898)
* feat: add embedding_params to BasicEmbeddingsIndex
  - Added `embedding_params` attribute to the `BasicEmbeddingsIndex` class.
  - Updated the constructor to accept `embedding_params`.
  - Modified the `_init_model` method to pass `embedding_params` to `init_embedding_model`.
  - Updated the `init_embedding_model` function to handle `embedding_params`.
  - Adjusted `NIMEmbeddingModel` and `OpenAIEmbeddingModel` to accept additional parameters.
  - Updated `LLMRails` to handle default embedding parameters.
* feat: add mock embedding model for testing embedding providers
* feat: add kwargs to FastEmbedEmbeddingModel init
* feat: add kwargs to SentenceTransformerEmbeddingModel init
* test: add test for additional params in FastEmbed
1 parent 466fd06 commit 2ede0f4

File tree

9 files changed

+200
-12
lines changed

9 files changed

+200
-12
lines changed

nemoguardrails/embeddings/basic.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):
4747

4848
embedding_model: str
4949
embedding_engine: str
50+
embedding_params: Dict[str, Any]
5051
index: AnnoyIndex
5152
embedding_size: int
5253
cache_config: EmbeddingsCacheConfig
@@ -60,6 +61,7 @@ def __init__(
6061
self,
6162
embedding_model=None,
6263
embedding_engine=None,
64+
embedding_params=None,
6365
index=None,
6466
cache_config: Union[EmbeddingsCacheConfig, Dict[str, Any]] = None,
6567
search_threshold: float = None,
@@ -83,6 +85,7 @@ def __init__(
8385
self._embeddings = []
8486
self.embedding_model = embedding_model
8587
self.embedding_engine = embedding_engine
88+
self.embedding_params = embedding_params or {}
8689
self._embedding_size = 0
8790
self.search_threshold = search_threshold or float("inf")
8891
if isinstance(cache_config, Dict):
@@ -132,7 +135,9 @@ def embeddings_index(self, index):
132135
def _init_model(self):
133136
"""Initialize the model used for computing the embeddings."""
134137
self._model = init_embedding_model(
135-
embedding_model=self.embedding_model, embedding_engine=self.embedding_engine
138+
embedding_model=self.embedding_model,
139+
embedding_engine=self.embedding_engine,
140+
embedding_params=self.embedding_params,
136141
)
137142

138143
@cache_embeddings

nemoguardrails/embeddings/providers/__init__.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,15 @@ def register_embedding_provider(
7070
register_embedding_provider(nim.NVIDIAAIEndpointsEmbeddingModel)
7171

7272

def init_embedding_model(
    embedding_model: str, embedding_engine: str, embedding_params: dict = None
) -> EmbeddingModel:
    """Initialize (or fetch from cache) the embedding model.

    Instances are cached per (engine, model, params) combination so that
    repeated calls with the same configuration reuse the same provider object.

    Args:
        embedding_model (str): The path or name of the embedding model.
        embedding_engine (str): The name of the embedding engine.
        embedding_params (dict): Additional parameters for the embedding model.
            Defaults to no extra parameters.

    Returns:
        EmbeddingModel: An instance of the initialized embedding model.

    Raises:
        ValueError: If the embedding engine is invalid.
    """
    # Use a None sentinel instead of a mutable `{}` default argument,
    # which would be shared across all calls.
    if embedding_params is None:
        embedding_params = {}

    # Fold the extra parameters into the cache key so that models created
    # with different parameters do not collide in the cache.
    embedding_params_str = (
        "_".join(f"{key}={value}" for key, value in embedding_params.items())
        or "default"
    )

    model_key = f"{embedding_engine}-{embedding_model}-{embedding_params_str}"

    if model_key not in _embedding_model_cache:
        provider_class = EmbeddingProviderRegistry().get(embedding_engine)
        model = provider_class(embedding_model=embedding_model, **embedding_params)
        _embedding_model_cache[model_key] = model

    return _embedding_model_cache[model_key]

nemoguardrails/embeddings/providers/fastembed.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class FastEmbedEmbeddingModel(EmbeddingModel):
4141

4242
engine_name = "FastEmbed"
4343

44-
def __init__(self, embedding_model: str):
44+
def __init__(self, embedding_model: str, **kwargs):
4545
from fastembed import TextEmbedding as Embedding
4646

4747
# Enabling a short form model name for all-MiniLM-L6-v2.
@@ -50,13 +50,13 @@ def __init__(self, embedding_model: str):
5050
self.embedding_model = embedding_model
5151

5252
try:
53-
self.model = Embedding(embedding_model)
53+
self.model = Embedding(embedding_model, **kwargs)
5454
except ValueError as ex:
5555
# Sometimes the cached model in the temporary folder gets removed,
5656
# but the folder still exists, which causes an error. In this case,
5757
# we fall back to an explicit cache directory.
5858
if "Could not find model.onnx in" in str(ex):
59-
self.model = Embedding(embedding_model, cache_dir=".cache")
59+
self.model = Embedding(embedding_model, cache_dir=".cache", **kwargs)
6060
else:
6161
raise ex
6262

nemoguardrails/embeddings/providers/nim.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ class NIMEmbeddingModel(EmbeddingModel):
3333

3434
engine_name = "nim"
3535

36-
def __init__(self, embedding_model: str):
36+
def __init__(self, embedding_model: str, **kwargs):
3737
try:
3838
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
3939

4040
self.model = embedding_model
41-
self.document_embedder = NVIDIAEmbeddings(model=embedding_model)
41+
self.document_embedder = NVIDIAEmbeddings(model=embedding_model, **kwargs)
4242

4343
except ImportError:
4444
raise ImportError(

nemoguardrails/embeddings/providers/openai.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class OpenAIEmbeddingModel(EmbeddingModel):
4343
def __init__(
4444
self,
4545
embedding_model: str,
46+
**kwargs,
4647
):
4748
try:
4849
import openai
@@ -59,7 +60,7 @@ def __init__(
5960
)
6061

6162
self.model = embedding_model
62-
self.client = OpenAI()
63+
self.client = OpenAI(**kwargs)
6364

6465
self.embedding_size_dict = {
6566
"text-embedding-ada-002": 1536,

nemoguardrails/embeddings/providers/sentence_transformers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
4141

4242
engine_name = "SentenceTransformers"
4343

44-
def __init__(self, embedding_model: str):
44+
def __init__(self, embedding_model: str, **kwargs):
4545
try:
4646
from sentence_transformers import SentenceTransformer
4747
except ImportError:
@@ -58,7 +58,7 @@ def __init__(self, embedding_model: str):
5858
)
5959

6060
device = "cuda" if cuda.is_available() else "cpu"
61-
self.model = SentenceTransformer(embedding_model, device=device)
61+
self.model = SentenceTransformer(embedding_model, device=device, **kwargs)
6262
# Get the embedding dimension of the model
6363
self.embedding_size = self.model.get_sentence_embedding_dimension()
6464

nemoguardrails/rails/llm/llmrails.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __init__(
103103
# The default embeddings model is using FastEmbed
104104
self.default_embedding_model = "all-MiniLM-L6-v2"
105105
self.default_embedding_engine = "FastEmbed"
106+
self.default_embedding_params = {}
106107

107108
# We keep a cache of the events history associated with a sequence of user messages.
108109
# TODO: when we update the interface to allow to return a "state object", this
@@ -212,6 +213,7 @@ def __init__(
212213
if model.type == "embeddings":
213214
self.default_embedding_model = model.model
214215
self.default_embedding_engine = model.engine
216+
self.default_embedding_params = model.parameters or {}
215217
break
216218

217219
# InteractionLogAdapters used for tracing
@@ -433,6 +435,9 @@ def _get_embeddings_search_provider_instance(
433435
embedding_engine=esp_config.parameters.get(
434436
"embedding_engine", self.default_embedding_engine
435437
),
438+
embedding_params=esp_config.parameters.get(
439+
"embedding_parameters", self.default_embedding_params
440+
),
436441
cache_config=esp_config.cache,
437442
# We make sure we also pass additional relevant params.
438443
**{

tests/test_embedding_providers.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import asyncio
17+
from typing import List
18+
19+
import pytest
20+
21+
from nemoguardrails.embeddings.providers import (
22+
init_embedding_model,
23+
register_embedding_provider,
24+
)
25+
from nemoguardrails.embeddings.providers.base import EmbeddingModel
26+
27+
# Keyword arguments that MockEmbeddingModel accepts; any other kwarg raises.
SUPPORTED_PARAMS = {"param1", "param2"}
28+
29+
30+
class MockEmbeddingModel(EmbeddingModel):
    """In-memory stand-in embedding model used to exercise the provider API.

    Supported embedding models:
    - mock-embedding-small: Embedding size of 128.
    - mock-embedding-large: Embedding size of 256.
    Supported parameters:
    - param1
    - param2

    Args:
        embedding_model (str): The name of the embedding model.

    Attributes:
        model (str): The name of the embedding model.
        embedding_size (int): The size of the embeddings.

    Methods:
        encode: Encode a list of documents into embeddings.
    """

    engine_name = "mock_engine"

    def __init__(self, embedding_model: str, **kwargs):
        self.model = embedding_model
        self.embedding_size_dict = {
            "mock-embedding-small": 128,
            "mock-embedding-large": 256,
        }

        self.embedding_params = kwargs

        # Validate the model name first, then the extra parameters.
        if self.model not in self.embedding_size_dict:
            raise ValueError(f"Invalid embedding model: {self.model}")

        for name in self.embedding_params:
            if name not in SUPPORTED_PARAMS:
                raise ValueError(f"Unsupported parameter: {name}")

        self.embedding_size = self.embedding_size_dict[self.model]

    async def encode_async(self, documents: List[str]) -> List[List[float]]:
        """Encode a list of documents into embeddings asynchronously.

        Args:
            documents (List[str]): The list of documents to be encoded.

        Returns:
            List[List[float]]: The encoded embeddings.
        """
        loop = asyncio.get_running_loop()
        # Run the synchronous encode on the default executor.
        return await loop.run_in_executor(None, self.encode, documents)

    def encode(self, documents: List[str]) -> List[List[float]]:
        """Encode a list of documents into embeddings.

        Args:
            documents (List[str]): The list of documents to be encoded.

        Returns:
            List[List[float]]: The encoded embeddings.
        """
        # Deterministic fake vector: [0.0, 1.0, ..., size-1] for every document.
        dims = range(self.embedding_size)
        return [list(map(float, dims)) for _ in documents]
96+
97+
98+
# Register the mock engine in the shared provider registry so that
# init_embedding_model can resolve embedding_engine="mock_engine".
register_embedding_provider(MockEmbeddingModel)
99+
100+
101+
def test_init_embedding_model_with_params():
    """A supported extra parameter is forwarded to the provider instance."""
    params = {next(iter(SUPPORTED_PARAMS)): "value1"}
    model = init_embedding_model("mock-embedding-small", "mock_engine", params)

    assert isinstance(model, MockEmbeddingModel)
    assert model.model == "mock-embedding-small"
    assert model.embedding_size == 128
    assert model.engine_name == "mock_engine"
    assert model.embedding_params == params
112+
113+
114+
def test_init_embedding_model_without_params():
    """Omitting embedding_params yields a model with no extra parameters."""
    model = init_embedding_model("mock-embedding-large", "mock_engine")

    assert isinstance(model, MockEmbeddingModel)
    assert model.model == "mock-embedding-large"
    assert model.embedding_size == 256
    assert model.engine_name == "mock_engine"
    assert model.embedding_params == {}
123+
124+
125+
def test_init_embedding_model_with_unsupported_params():
    """An unknown parameter name is rejected with a ValueError."""
    with pytest.raises(ValueError, match="Unsupported parameter: unsupported_param"):
        init_embedding_model(
            "mock-embedding-small", "mock_engine", {"unsupported_param": "value"}
        )
131+
132+
133+
def test_init_embedding_model_with_invalid_model():
    """An unknown model name is rejected with a ValueError."""
    with pytest.raises(ValueError, match="Invalid embedding model: invalid_model"):
        init_embedding_model("invalid_model", "mock_engine", {"param1": "value1"})
139+
140+
141+
def test_encode_method():
    """encode returns one embedding of the model's size per input document."""
    model = init_embedding_model("mock-embedding-small", "mock_engine")
    assert isinstance(model, MockEmbeddingModel)

    docs = ["doc1", "doc2", "doc3"]
    vectors = model.encode(docs)

    assert len(vectors) == len(docs)
    assert len(vectors[0]) == model.embedding_size
150+
151+
152+
@pytest.mark.asyncio
153+
async def test_encode_async_method():
154+
embedding_model = "mock-embedding-large"
155+
embedding_engine = "mock_engine"
156+
model = init_embedding_model(embedding_model, embedding_engine)
157+
assert isinstance(model, MockEmbeddingModel)
158+
documents = ["doc1", "doc2", "doc3"]
159+
embeddings = await model.encode_async(documents)
160+
assert len(embeddings) == len(documents)
161+
assert len(embeddings[0]) == model.embedding_size

tests/test_embeddings_fastembed.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ def test_sync_embeddings():
2626
assert len(result[0]) == 384
2727

2828

29+
def test_additional_params_with_fastembed():
    """Extra keyword args (max_length, lazy_load) are accepted by FastEmbed."""
    model = FastEmbedEmbeddingModel("all-MiniLM-L6-v2", max_length=512, lazy_load=True)
    embeddings = model.encode(["test"])

    # all-MiniLM-L6-v2 produces 384-dimensional embeddings.
    assert len(embeddings[0]) == 384
34+
35+
2936
@pytest.mark.asyncio
3037
async def test_async_embeddings():
3138
model = FastEmbedEmbeddingModel("all-MiniLM-L6-v2")

0 commit comments

Comments
 (0)