Skip to content

Commit 9972cce

Browse files
authored
Hongyi LlamaIndex NeMo demo (#40)
1 parent 34696ff commit 9972cce

File tree

2 files changed

+95
-1
lines changed

2 files changed

+95
-1
lines changed
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from nemoguardrails import LLMRails, RailsConfig
2+
from langchain.llms.base import BaseLLM
3+
4+
from typing import Callable, Any, Coroutine
5+
6+
COLANG_CONFIG = """
7+
define user express greeting
8+
"hi"
9+
10+
define user express ill intent
11+
"I hate you"
12+
"I want to destroy the world"
13+
14+
define bot express cannot respond
15+
"I'm sorry I cannot help you with that."
16+
17+
define user express question
18+
"What is the current unemployment rate?"
19+
20+
# Basic guardrail example
21+
define flow
22+
user express ill intent
23+
bot express cannot respond
24+
25+
# Question answering flow
26+
define flow
27+
user ...
28+
$answer = execute llama_index_query(query=$last_user_message)
29+
bot $answer
30+
31+
"""
32+
33+
YAML_CONFIG = """
34+
models:
35+
- type: main
36+
engine: openai
37+
model: text-davinci-003
38+
"""
39+
40+
41+
def demo():
    """Run a NeMo Guardrails app that answers questions from a llama_index KB.

    Builds a rails app from the in-module Colang/YAML configs, indexes the
    sample knowledge-base document, registers the query engine as the
    `llama_index_query` action, and prints the app's answer to one sample
    question.

    Raises:
        ImportError: if llama_index is not installed.
    """
    try:
        import llama_index
        from llama_index.indices.query.base import BaseQueryEngine
        from llama_index.response.schema import StreamingResponse

    except ImportError:
        raise ImportError(
            "Could not import llama_index, please install it with "
            "`pip install llama_index`."
        )

    config = RailsConfig.from_content(COLANG_CONFIG, YAML_CONFIG)
    app = LLMRails(config)

    def _build_query_engine(llm: BaseLLM):
        # Load the sample report, index it with the rails app's LLM, and
        # return the index's default query engine.
        documents = llama_index.SimpleDirectoryReader(
            input_files=["../examples/grounding_rail/kb/report.md"]
        ).load_data()
        predictor = llama_index.LLMPredictor(llm=llm)
        vector_index = llama_index.GPTVectorStoreIndex.from_documents(
            documents, llm_predictor=predictor
        )
        return vector_index.as_query_engine()

    def _wrap_as_action(
        engine: BaseQueryEngine,
    ) -> Callable[[str], Coroutine[Any, Any, str]]:
        # Adapt the query engine into the async callable that rails actions
        # expect, normalizing streaming and empty responses to plain strings.
        async def get_query_response(query: str) -> str:
            raw = engine.query(query)
            answer = raw.get_response() if isinstance(raw, StreamingResponse) else raw
            text = answer.response
            return "" if text is None else text

        return get_query_response

    # Expose the knowledge-base lookup under the name the Colang flow executes.
    app.register_action(
        _wrap_as_action(_build_query_engine(app.llm)), name="llama_index_query"
    )

    history = [{"role": "user", "content": "What is the current unemployment rate?"}]
    result = app.generate(messages=history)
    print(result)
92+
93+
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    demo()

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ starlette==0.26.1
1212
uvicorn==0.21.1
1313
httpx==0.23.3
1414
simpleeval==0.9.13
15-
typing-extensions==4.5.0
15+
typing-extensions==4.5.0

0 commit comments

Comments
 (0)