
Commit 51491e3

feat: add embedding_params to BasicEmbeddingsIndex
- Added `embedding_params` attribute to `BasicEmbeddingsIndex` class.
- Updated the constructor to accept `embedding_params`.
- Modified `_init_model` method to pass `embedding_params` to `init_embedding_model`.
- Updated `init_embedding_model` function to handle `embedding_params`.
- Adjusted `NIMEmbeddingModel` and `OpenAIEmbeddingModel` to accept additional parameters.
- Updated `LLMRails` to handle default embedding parameters.
- Improved style.
1 parent 61ea72e commit 51491e3

File tree: 5 files changed (+27, -7 lines)


nemoguardrails/embeddings/basic.py

Lines changed: 6 additions & 1 deletion

@@ -47,6 +47,7 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):

     embedding_model: str
     embedding_engine: str
+    embedding_params: Dict[str, Any]
     index: AnnoyIndex
     embedding_size: int
     cache_config: EmbeddingsCacheConfig
@@ -60,6 +61,7 @@ def __init__(
         self,
         embedding_model=None,
         embedding_engine=None,
+        embedding_params=None,
         index=None,
         cache_config: Union[EmbeddingsCacheConfig, Dict[str, Any]] = None,
         search_threshold: float = None,
@@ -83,6 +85,7 @@ def __init__(
         self._embeddings = []
         self.embedding_model = embedding_model
         self.embedding_engine = embedding_engine
+        self.embedding_params = embedding_params or {}
         self._embedding_size = 0
         self.search_threshold = search_threshold or float("inf")
         if isinstance(cache_config, Dict):
@@ -132,7 +135,9 @@ def embeddings_index(self, index):
     def _init_model(self):
         """Initialize the model used for computing the embeddings."""
         self._model = init_embedding_model(
-            embedding_model=self.embedding_model, embedding_engine=self.embedding_engine
+            embedding_model=self.embedding_model,
+            embedding_engine=self.embedding_engine,
+            embedding_params=self.embedding_params,
         )

     @cache_embeddings
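
A brief usage sketch based on the constructor shown above (it assumes `nemoguardrails` with this change installed; the keys placed in `embedding_params` are illustrative, since the dict is ultimately splatted into whichever provider class the engine maps to). The dict is stored on the instance and forwarded by `_init_model` to `init_embedding_model`:

```python
from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex

# Hypothetical params; what is accepted depends on the provider class
# behind the chosen embedding_engine.
index = BasicEmbeddingsIndex(
    embedding_model="text-embedding-ada-002",
    embedding_engine="openai",
    embedding_params={"timeout": 30},
)

assert index.embedding_params == {"timeout": 30}  # defaults to {} when omitted
```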

nemoguardrails/embeddings/providers/__init__.py

Lines changed: 12 additions & 3 deletions

@@ -70,12 +70,15 @@ def register_embedding_provider(
 register_embedding_provider(nim.NVIDIAAIEndpointsEmbeddingModel)


-def init_embedding_model(embedding_model: str, embedding_engine: str) -> EmbeddingModel:
+def init_embedding_model(
+    embedding_model: str, embedding_engine: str, embedding_params: dict = {}
+) -> EmbeddingModel:
     """Initialize the embedding model.

     Args:
         embedding_model (str): The path or name of the embedding model.
         embedding_engine (str): The name of the embedding engine.
+        embedding_params (dict): Additional parameters for the embedding model.

     Returns:
         EmbeddingModel: An instance of the initialized embedding model.
@@ -84,10 +87,16 @@ def init_embedding_model(embedding_model: str, embedding_engine: str) -> Embeddi
         ValueError: If the embedding engine is invalid.
     """

-    model_key = f"{embedding_engine}-{embedding_model}"
+    embedding_params_str = (
+        "_".join([f"{key}={value}" for key, value in embedding_params.items()])
+        or "default"
+    )
+
+    model_key = f"{embedding_engine}-{embedding_model}-{embedding_params_str}"

     if model_key not in _embedding_model_cache:
-        model = EmbeddingProviderRegistry().get(embedding_engine)(embedding_model)
+        provider_class = EmbeddingProviderRegistry().get(embedding_engine)
+        model = provider_class(embedding_model=embedding_model, **embedding_params)
         _embedding_model_cache[model_key] = model

     return _embedding_model_cache[model_key]
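
A standalone sketch of the cache-key scheme introduced above (a mirror of the logic, not the library code itself): folding the parameters into the key keeps two models that differ only in their extra parameters from sharing an entry in `_embedding_model_cache`.

```python
def model_cache_key(embedding_engine: str, embedding_model: str, embedding_params: dict) -> str:
    # Mirrors init_embedding_model: params are joined as key=value pairs,
    # with "default" used when no extra params are given.
    embedding_params_str = (
        "_".join([f"{key}={value}" for key, value in embedding_params.items()])
        or "default"
    )
    return f"{embedding_engine}-{embedding_model}-{embedding_params_str}"


print(model_cache_key("openai", "text-embedding-ada-002", {}))
# -> openai-text-embedding-ada-002-default
print(model_cache_key("openai", "text-embedding-ada-002", {"timeout": 30}))
# -> openai-text-embedding-ada-002-timeout=30
```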

nemoguardrails/embeddings/providers/nim.py

Lines changed: 2 additions & 2 deletions

@@ -33,12 +33,12 @@ class NIMEmbeddingModel(EmbeddingModel):

     engine_name = "nim"

-    def __init__(self, embedding_model: str):
+    def __init__(self, embedding_model: str, **kwargs):
         try:
             from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings

             self.model = embedding_model
-            self.document_embedder = NVIDIAEmbeddings(model=embedding_model)
+            self.document_embedder = NVIDIAEmbeddings(model=embedding_model, **kwargs)

         except ImportError:
             raise ImportError(
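
An illustrative sketch of the pass-through: extra keyword arguments go straight to `NVIDIAEmbeddings`. The model name and the `base_url` keyword below are assumptions about what the installed `langchain_nvidia_ai_endpoints` version accepts, not something this diff guarantees.

```python
from nemoguardrails.embeddings.providers.nim import NIMEmbeddingModel

model = NIMEmbeddingModel(
    embedding_model="nvidia/nv-embedqa-e5-v5",  # illustrative model name
    base_url="http://localhost:8000/v1",        # assumption: a locally hosted NIM endpoint
)
```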

nemoguardrails/embeddings/providers/openai.py

Lines changed: 2 additions & 1 deletion

@@ -43,6 +43,7 @@ class OpenAIEmbeddingModel(EmbeddingModel):
     def __init__(
         self,
         embedding_model: str,
+        **kwargs,
     ):
         try:
             import openai
@@ -59,7 +60,7 @@ def __init__(
             )

         self.model = embedding_model
-        self.client = OpenAI()
+        self.client = OpenAI(**kwargs)

         self.embedding_size_dict = {
             "text-embedding-ada-002": 1536,

nemoguardrails/rails/llm/llmrails.py

Lines changed: 5 additions & 0 deletions

@@ -103,6 +103,7 @@ def __init__(
         # The default embeddings model is using FastEmbed
         self.default_embedding_model = "all-MiniLM-L6-v2"
         self.default_embedding_engine = "FastEmbed"
+        self.default_embedding_params = {}

         # We keep a cache of the events history associated with a sequence of user messages.
         # TODO: when we update the interface to allow to return a "state object", this
@@ -212,6 +213,7 @@ def __init__(
             if model.type == "embeddings":
                 self.default_embedding_model = model.model
                 self.default_embedding_engine = model.engine
+                self.default_embedding_params = model.parameters or {}
                 break

         # InteractionLogAdapters used for tracing
@@ -429,6 +431,9 @@ def _get_embeddings_search_provider_instance(
                 embedding_engine=esp_config.parameters.get(
                     "embedding_engine", self.default_embedding_engine
                 ),
+                embedding_params=esp_config.parameters.get(
+                    "embedding_parameters", self.default_embedding_params
+                ),
                 cache_config=esp_config.cache,
                 # We make sure we also pass additional relevant params.
                 **{
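
Tying the pieces together, an end-to-end sketch: parameters declared on an `embeddings` model entry become `default_embedding_params` and are forwarded to the embedding search provider unless an explicit `embedding_parameters` entry overrides them. The YAML layout of the `parameters` field is inferred from the `model.parameters` access in the diff above, and the model names are illustrative.

```python
from nemoguardrails import LLMRails, RailsConfig

yaml_content = """
models:
  - type: main
    engine: openai
    model: gpt-3.5-turbo-instruct
  - type: embeddings
    engine: openai
    model: text-embedding-ada-002
    parameters:
      timeout: 30
"""

# Assumes OPENAI_API_KEY is set for both the main and embeddings models.
config = RailsConfig.from_content(yaml_content=yaml_content)
rails = LLMRails(config)
```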
