Skip to content

Commit 13bdf7c

Browse files
committed
Fixes
1 parent cebc4a6 commit 13bdf7c

File tree

2 files changed

+42
-25
lines changed

2 files changed

+42
-25
lines changed

examples/demo_llama_index_guardrails.py

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
from nemoguardrails import LLMRails, RailsConfig
2-
from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader, LLMPredictor
3-
from llama_index.indices.query.base import BaseQueryEngine
42
from langchain.llms.base import BaseLLM
53

6-
from typing import Callable
4+
from typing import Callable, Any, Coroutine
75

86
COLANG_CONFIG = """
97
define user express greeting
@@ -26,7 +24,7 @@
2624
2725
# Question answering flow
2826
define flow
29-
user express question
27+
user ...
3028
$answer = execute llama_index_query(query=$last_user_message)
3129
bot $answer
3230
@@ -40,29 +38,49 @@
4038
"""
4139

4240

43-
def _get_llama_index_query_engine(llm: BaseLLM):
44-
docs = SimpleDirectoryReader(
45-
input_files=["../examples/grounding_rail/kb/report.md"]
46-
).load_data()
47-
llm_predictor = LLMPredictor(llm=llm)
48-
index = GPTVectorStoreIndex.from_documents(docs, llm_predictor=llm_predictor)
49-
default_query_engine = index.as_query_engine()
50-
return default_query_engine
51-
52-
53-
def _get_callable_query_engine(
54-
query_engine: BaseQueryEngine
55-
) -> Callable[[str], str]:
56-
async def get_query_response(query: str) -> str:
57-
return query_engine.query(query).response
58-
59-
return get_query_response
41+
def demo():
42+
try:
43+
import llama_index
44+
from llama_index.indices.query.base import BaseQueryEngine
45+
from llama_index.response.schema import StreamingResponse
6046

47+
except ImportError:
48+
raise ImportError(
49+
"Could not import llama_index, please install it with "
50+
"`pip install llama_index`."
51+
)
6152

62-
def demo():
6353
config = RailsConfig.from_content(COLANG_CONFIG, YAML_CONFIG)
6454
app = LLMRails(config)
65-
query_engine: BaseQueryEngine = _get_llama_index_query_engine(app.llm)
55+
56+
def _get_llama_index_query_engine(llm: BaseLLM):
57+
docs = llama_index.SimpleDirectoryReader(
58+
input_files=["../examples/grounding_rail/kb/report.md"]
59+
).load_data()
60+
llm_predictor = llama_index.LLMPredictor(llm=llm)
61+
index = llama_index.GPTVectorStoreIndex.from_documents(
62+
docs, llm_predictor=llm_predictor
63+
)
64+
default_query_engine = index.as_query_engine()
65+
return default_query_engine
66+
67+
def _get_callable_query_engine(
68+
query_engine: BaseQueryEngine,
69+
) -> Callable[[str], Coroutine[Any, Any, str]]:
70+
async def get_query_response(query: str) -> str:
71+
response = query_engine.query(query)
72+
if isinstance(response, StreamingResponse):
73+
typed_response = response.get_response()
74+
else:
75+
typed_response = response
76+
response_str = typed_response.response
77+
if response_str is None:
78+
return ""
79+
return response_str
80+
81+
return get_query_response
82+
83+
query_engine = _get_llama_index_query_engine(app.llm)
6684
app.register_action(
6785
_get_callable_query_engine(query_engine), name="llama_index_query"
6886
)

requirements.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,4 @@ starlette==0.26.1
1212
uvicorn==0.21.1
1313
httpx==0.23.3
1414
simpleeval==0.9.13
15-
typing-extensions==4.5.0
16-
llama_index==0.6.14
15+
typing-extensions==4.5.0

0 commit comments

Comments
 (0)