Add support for OpenAI Compatible Streaming output and delete unreachable code #417

Merged · 5 commits · Dec 9, 2024
examples/lightrag_openai_compatible_stream_demo.py (55 additions, 0 deletions)
@@ -0,0 +1,55 @@
import os
import inspect
from lightrag import LightRAG
from lightrag.llm import openai_complete, openai_embedding
from lightrag.utils import EmbeddingFunc
from lightrag.lightrag import always_get_an_event_loop
from lightrag import QueryParam

# WorkingDir
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
WORKING_DIR = os.path.join(ROOT_DIR, "dickens")
if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)
print(f"WorkingDir: {WORKING_DIR}")

api_key = "empty"  # placeholder; local OpenAI-compatible servers typically ignore the key
rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=openai_complete,
    llm_model_name="qwen2.5-14b-instruct@4bit",
    llm_model_max_async=4,
    llm_model_max_token_size=32768,
    llm_model_kwargs={"base_url": "http://127.0.0.1:1234/v1", "api_key": api_key},
    embedding_func=EmbeddingFunc(
        embedding_dim=1024,
        max_token_size=8192,
        func=lambda texts: openai_embedding(
            texts=texts,
            model="text-embedding-bge-m3",
            base_url="http://127.0.0.1:1234/v1",
            api_key=api_key,
        ),
    ),
)

with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

resp = rag.query(
    "What are the top themes in this story?",
    param=QueryParam(mode="hybrid", stream=True),
)


async def print_stream(stream):
    async for chunk in stream:
        if chunk:
            print(chunk, end="", flush=True)


loop = always_get_an_event_loop()
if inspect.isasyncgen(resp):
    # stream=True produced an async generator of text chunks
    loop.run_until_complete(print_stream(resp))
else:
    # cached or non-streaming results come back as a plain string
    print(resp)
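
For callers that want the full answer as a single string instead of printing it incrementally, the same async generator can be drained into a buffer. A minimal sketch, reusing only names the demo above already defines:

async def collect_stream(stream) -> str:
    # Drain the async generator and join the streamed chunks.
    parts = []
    async for chunk in stream:
        if chunk:
            parts.append(chunk)
    return "".join(parts)

# Usage: full_text = loop.run_until_complete(collect_stream(resp))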
lightrag/llm.py (34 additions, 5 deletions)
@@ -76,11 +76,24 @@ async def openai_complete_if_cache(
     response = await openai_async_client.chat.completions.create(
         model=model, messages=messages, **kwargs
     )
-    content = response.choices[0].message.content
-    if r"\u" in content:
-        content = content.encode("utf-8").decode("unicode_escape")
-
-    return content
+    if hasattr(response, "__aiter__"):
+
+        async def inner():
+            async for chunk in response:
+                content = chunk.choices[0].delta.content
+                if content is None:
+                    continue
+                if r"\u" in content:
+                    content = content.encode("utf-8").decode("unicode_escape")
+                yield content
+
+        return inner()
+    else:
+        content = response.choices[0].message.content
+        if r"\u" in content:
+            content = content.encode("utf-8").decode("unicode_escape")
+        return content


@retry(
@@ -306,7 +319,7 @@ async def ollama_model_if_cache(

     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
     if stream:
-        """ cannot cache stream response """
+        """cannot cache stream response"""

         async def inner():
             async for chunk in response:
@@ -447,6 +460,22 @@ class GPTKeywordExtractionFormat(BaseModel):
     low_level_keywords: List[str]


+async def openai_complete(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> Union[str, AsyncIterator[str]]:
+    keyword_extraction = kwargs.pop("keyword_extraction", None)
+    if keyword_extraction:
+        kwargs["response_format"] = "json"
+    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+    return await openai_complete_if_cache(
+        model_name,
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+
+
 async def gpt_4o_complete(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
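Taken together, the llm.py changes mean openai_complete_if_cache (and the new openai_complete wrapper) returns a plain str for ordinary calls and an async iterator of text deltas when stream=True is forwarded to the OpenAI client; the SDK's streaming response is recognized by its __aiter__ attribute. A caller therefore has to branch on the shape of the result. A small sketch of that dispatch, with result as a hypothetical variable holding either return type:

import inspect

async def consume(result):
    # result is a str for non-streaming calls,
    # an async generator of deltas when stream=True was passed.
    if inspect.isasyncgen(result):
        async for delta in result:
            print(delta, end="", flush=True)
        print()
    else:
        print(result)
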
lightrag/utils.py (1 addition, 1 deletion)
@@ -488,7 +488,7 @@ class CacheData:


 async def save_to_cache(hashing_kv, cache_data: CacheData):
-    if hashing_kv is None:
+    if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
         return

     mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
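
The widened guard skips caching for streaming responses: an async generator can be consumed only once, so there is no finished string to persist. The check relies on async iterables exposing __aiter__ while plain strings do not; a quick illustration of the distinction (not code from this PR):

async def gen():
    yield "chunk"

agen = gen()
assert hasattr(agen, "__aiter__")             # streaming response: skipped
assert not hasattr("full text", "__aiter__")  # plain string: cached as before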