Skip to content

Commit 606ca6b

Browse files
committed
fix utf-8 encoding issue, add embedding_model configuration
Signed-off-by: Yuekai Zhang <zhangyuekai@foxmail.com>
1 parent 9972cce commit 606ca6b

File tree

4 files changed

+21
-11
lines changed

4 files changed

+21
-11
lines changed

nemoguardrails/actions/llm/generation.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ def __init__(self, config: RailsConfig, llm: BaseLLM, verbose: bool = False):
7171
self.kb = None
7272
self._init_kb()
7373

74+
# If we have a customized embedding model, we'll use it.
75+
self.embedding_model = "all-MiniLM-L6-v2"
76+
for model in self.config.models:
77+
if 'embedding' in model.type:
78+
self.embedding_model = model.model
79+
assert model.engine == "SentenceTransformer"
80+
break
81+
7482
def _init_user_message_index(self):
7583
"""Initializes the index of user messages."""
7684

@@ -86,7 +94,7 @@ def _init_user_message_index(self):
8694
if len(items) == 0:
8795
return
8896

89-
self.user_message_index = BasicEmbeddingsIndex()
97+
self.user_message_index = BasicEmbeddingsIndex(self.embedding_model)
9098
self.user_message_index.add_items(items)
9199

92100
# NOTE: this should be very fast, otherwise needs to be moved to separate thread.
@@ -107,7 +115,7 @@ def _init_bot_message_index(self):
107115
if len(items) == 0:
108116
return
109117

110-
self.bot_message_index = BasicEmbeddingsIndex()
118+
self.bot_message_index = BasicEmbeddingsIndex(self.embedding_model)
111119
self.bot_message_index.add_items(items)
112120

113121
# NOTE: this should be very fast, otherwise needs to be moved to separate thread.
@@ -141,7 +149,7 @@ def _init_flows_index(self):
141149
if len(items) == 0:
142150
return
143151

144-
self.flows_index = BasicEmbeddingsIndex()
152+
self.flows_index = BasicEmbeddingsIndex(self.embedding_model)
145153
self.flows_index.add_items(items)
146154

147155
# NOTE: this should be very fast, otherwise needs to be moved to separate thread.
@@ -154,7 +162,7 @@ def _init_kb(self):
154162
return
155163

156164
documents = [doc.content for doc in self.config.docs]
157-
self.kb = KnowledgeBase(documents=documents)
165+
self.kb = KnowledgeBase(documents=documents, embedding_model=self.embedding_model)
158166
self.kb.init()
159167
self.kb.build()
160168

nemoguardrails/kb/basic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):
2828
It uses Annoy to perform the search.
2929
"""
3030

31-
def __init__(self, index=None):
31+
def __init__(self, embedding_model=None, index=None):
3232
self._model = None
3333
self._items = []
3434
self._embeddings = []
35+
self.embedding_model = embedding_model
3536

3637
# When the index is provided, it means it's from the cache.
3738
self._index = index
@@ -42,7 +43,7 @@ def embeddings_index(self):
4243

4344
def _init_model(self):
4445
"""Initialize the model used for computing the embeddings."""
45-
self._model = SentenceTransformer("all-MiniLM-L6-v2")
46+
self._model = SentenceTransformer(self.embedding_model)
4647

4748
def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
4849
"""Compute embeddings for a list of texts."""

nemoguardrails/kb/kb.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,11 @@
3333
class KnowledgeBase:
3434
"""Basic implementation of a knowledge base."""
3535

36-
def __init__(self, documents: List[str]):
36+
def __init__(self, documents: List[str], embedding_model: str):
3737
self.documents = documents
3838
self.chunks = []
3939
self.index = None
40+
self.embedding_model = embedding_model
4041

4142
def init(self):
4243
"""Initialize the knowledge base.
@@ -79,10 +80,10 @@ def build(self):
7980
ann_index = AnnoyIndex(embedding_size, "angular")
8081
ann_index.load(cache_file)
8182

82-
self.index = BasicEmbeddingsIndex(index=ann_index)
83+
self.index = BasicEmbeddingsIndex(embedding_model=self.embedding_model, index=ann_index)
8384
self.index.add_items(index_items)
8485
else:
85-
self.index = BasicEmbeddingsIndex()
86+
self.index = BasicEmbeddingsIndex(self.embedding_model)
8687
self.index.add_items(index_items)
8788
self.index.build()
8889

nemoguardrails/rails/llm/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,11 +210,11 @@ def from_path(config_path: str):
210210
)
211211

212212
elif file.endswith(".yml") or file.endswith(".yaml"):
213-
with open(full_path) as f:
213+
with open(full_path, 'r', encoding='utf-8') as f:
214214
_raw_config = yaml.safe_load(f.read())
215215

216216
elif file.endswith(".co"):
217-
with open(full_path) as f:
217+
with open(full_path, 'r', encoding='utf-8') as f:
218218
_raw_config = parse_colang_file(file, content=f.read())
219219

220220
_join_config(raw_config, _raw_config)

0 commit comments

Comments
 (0)