Skip to content

Commit ec07145

Browse files
committed
Refactor imports so the package can be used without Annoy/SentenceTransformers when a custom embedding search provider is configured.
1 parent 57f2b6c commit ec07145

File tree

4 files changed

+9
-3
lines changed

4 files changed

+9
-3
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1919

2020
- Moved to using `nest_asyncio` for [implementing the blocking API](./docs/user_guide/advanced/nested-async-loop.md). Fixes [#3](https://github.com/NVIDIA/NeMo-Guardrails/issues/3) and [#32](https://github.com/NVIDIA/NeMo-Guardrails/issues/32).
2121
- Improved event property validation in `new_event_dict`.
22+
- Refactored imports to allow installing from source without Annoy/SentenceTransformers (in which case a custom embedding search provider is required).
2223

2324
### Fixed
2425

nemoguardrails/embeddings/basic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from typing import List
1717

1818
from annoy import AnnoyIndex
19-
from sentence_transformers import SentenceTransformer
2019
from torch import cuda
2120

2221
from nemoguardrails.embeddings.index import EmbeddingModel, EmbeddingsIndex, IndexItem
@@ -115,6 +114,8 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
115114
"""Embedding model using sentence-transformers."""
116115

117116
def __init__(self, embedding_model: str):
117+
from sentence_transformers import SentenceTransformer
118+
118119
device = "cuda" if cuda.is_available() else "cpu"
119120
self.model = SentenceTransformer(embedding_model, device=device)
120121
# Get the embedding dimension of the model

nemoguardrails/kb/kb.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
from annoy import AnnoyIndex
2323

24-
from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex
2524
from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
2625
from nemoguardrails.kb.utils import split_markdown_in_topic_chunks
2726
from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, KnowledgeBaseConfig
@@ -89,6 +88,8 @@ async def build(self):
8988
and os.path.exists(cache_file)
9089
and os.path.exists(embedding_size_file)
9190
):
91+
from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex
92+
9293
log.info(cache_file)
9394
self.index = cast(
9495
BasicEmbeddingsIndex,
@@ -116,6 +117,8 @@ async def build(self):
116117
# For the default Embedding Search provider, which uses annoy, we also
117118
# persist the index after it's computed.
118119
if self.config.embedding_search_provider.name == "default":
120+
from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex
121+
119122
# We also save the file for future use
120123
os.makedirs(CACHE_FOLDER, exist_ok=True)
121124
basic_index = cast(BasicEmbeddingsIndex, self.index)

nemoguardrails/rails/llm/llmrails.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from nemoguardrails.actions.math import wolfram_alpha_request
3232
from nemoguardrails.actions.output_moderation import output_moderation
3333
from nemoguardrails.actions.retrieve_relevant_chunks import retrieve_relevant_chunks
34-
from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex
3534
from nemoguardrails.embeddings.index import EmbeddingsIndex
3635
from nemoguardrails.flows.runtime import Runtime
3736
from nemoguardrails.kb.kb import KnowledgeBase
@@ -227,6 +226,8 @@ def _get_embeddings_search_provider_instance(
227226
esp_config = EmbeddingSearchProvider()
228227

229228
if esp_config.name == "default":
229+
from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex
230+
230231
return BasicEmbeddingsIndex(
231232
embedding_model=esp_config.parameters.get(
232233
"embedding_model", self.default_embedding_model

0 commit comments

Comments
 (0)