Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions nemoguardrails/actions/llm/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ def __init__(self, config: RailsConfig, llm: BaseLLM, verbose: bool = False):
self.kb = None
self._init_kb()

# If we have a customized embedding model, we'll use it.
self.embedding_model = "all-MiniLM-L6-v2"
for model in self.config.models:
if 'embedding' in model.type:
self.embedding_model = model.model
assert model.engine == "SentenceTransformer"
break

def _init_user_message_index(self):
"""Initializes the index of user messages."""

Expand All @@ -86,7 +94,7 @@ def _init_user_message_index(self):
if len(items) == 0:
return

self.user_message_index = BasicEmbeddingsIndex()
self.user_message_index = BasicEmbeddingsIndex(self.embedding_model)
self.user_message_index.add_items(items)

# NOTE: this should be very fast; otherwise it needs to be moved to a separate thread.
Expand All @@ -107,7 +115,7 @@ def _init_bot_message_index(self):
if len(items) == 0:
return

self.bot_message_index = BasicEmbeddingsIndex()
self.bot_message_index = BasicEmbeddingsIndex(self.embedding_model)
self.bot_message_index.add_items(items)

# NOTE: this should be very fast; otherwise it needs to be moved to a separate thread.
Expand Down Expand Up @@ -141,7 +149,7 @@ def _init_flows_index(self):
if len(items) == 0:
return

self.flows_index = BasicEmbeddingsIndex()
self.flows_index = BasicEmbeddingsIndex(self.embedding_model)
self.flows_index.add_items(items)

# NOTE: this should be very fast; otherwise it needs to be moved to a separate thread.
Expand All @@ -154,7 +162,7 @@ def _init_kb(self):
return

documents = [doc.content for doc in self.config.docs]
self.kb = KnowledgeBase(documents=documents)
self.kb = KnowledgeBase(documents=documents, embedding_model=self.embedding_model)
self.kb.init()
self.kb.build()

Expand Down
5 changes: 3 additions & 2 deletions nemoguardrails/kb/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):
It uses Annoy to perform the search.
"""

def __init__(self, index=None):
def __init__(self, embedding_model=None, index=None):
self._model = None
self._items = []
self._embeddings = []
self.embedding_model = embedding_model

# When the index is provided, it means it's from the cache.
self._index = index
Expand All @@ -42,7 +43,7 @@ def embeddings_index(self):

def _init_model(self):
    """Initialize the model used for computing the embeddings.

    Uses ``self.embedding_model`` when a model name was supplied; otherwise
    falls back to "all-MiniLM-L6-v2", the name that was hard-coded here
    before the model became configurable.
    """
    # `embedding_model` defaults to None in `__init__`, so an index created
    # without an explicit model name must still resolve to a real model.
    model_name = self.embedding_model or "all-MiniLM-L6-v2"
    self._model = SentenceTransformer(model_name)

def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Compute embeddings for a list of texts."""
Expand Down
7 changes: 4 additions & 3 deletions nemoguardrails/kb/kb.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@
class KnowledgeBase:
"""Basic implementation of a knowledge base."""

def __init__(self, documents: List[str]):
def __init__(self, documents: List[str], embedding_model: str):
self.documents = documents
self.chunks = []
self.index = None
self.embedding_model = embedding_model

def init(self):
"""Initialize the knowledge base.
Expand Down Expand Up @@ -79,10 +80,10 @@ def build(self):
ann_index = AnnoyIndex(embedding_size, "angular")
ann_index.load(cache_file)

self.index = BasicEmbeddingsIndex(index=ann_index)
self.index = BasicEmbeddingsIndex(embedding_model=self.embedding_model, index=ann_index)
self.index.add_items(index_items)
else:
self.index = BasicEmbeddingsIndex()
self.index = BasicEmbeddingsIndex(self.embedding_model)
self.index.add_items(index_items)
self.index.build()

Expand Down
4 changes: 2 additions & 2 deletions nemoguardrails/rails/llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,11 @@ def from_path(config_path: str):
)

elif file.endswith(".yml") or file.endswith(".yaml"):
with open(full_path) as f:
with open(full_path, 'r', encoding='utf-8') as f:
_raw_config = yaml.safe_load(f.read())

elif file.endswith(".co"):
with open(full_path) as f:
with open(full_path, 'r', encoding='utf-8') as f:
_raw_config = parse_colang_file(file, content=f.read())

_join_config(raw_config, _raw_config)
Expand Down