diff --git a/examples/lightrag_openai_compatible_stream_demo.py b/examples/lightrag_openai_compatible_stream_demo.py
new file mode 100644
index 00000000..9345ada5
--- /dev/null
+++ b/examples/lightrag_openai_compatible_stream_demo.py
@@ -0,0 +1,55 @@
+import os
+import inspect
+from lightrag import LightRAG
+from lightrag.llm import openai_complete, openai_embedding
+from lightrag.utils import EmbeddingFunc
+from lightrag.lightrag import always_get_an_event_loop
+from lightrag import QueryParam
+
+# WorkingDir
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+WORKING_DIR = os.path.join(ROOT_DIR, "dickens")
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+print(f"WorkingDir: {WORKING_DIR}")
+
+api_key = "empty"
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=openai_complete,
+    llm_model_name="qwen2.5-14b-instruct@4bit",
+    llm_model_max_async=4,
+    llm_model_max_token_size=32768,
+    llm_model_kwargs={"base_url": "http://127.0.0.1:1234/v1", "api_key": api_key},
+    embedding_func=EmbeddingFunc(
+        embedding_dim=1024,
+        max_token_size=8192,
+        func=lambda texts: openai_embedding(
+            texts=texts,
+            model="text-embedding-bge-m3",
+            base_url="http://127.0.0.1:1234/v1",
+            api_key=api_key,
+        ),
+    ),
+)
+
+with open("./book.txt", "r", encoding="utf-8") as f:
+    rag.insert(f.read())
+
+resp = rag.query(
+    "What are the top themes in this story?",
+    param=QueryParam(mode="hybrid", stream=True),
+)
+
+
+async def print_stream(stream):
+    async for chunk in stream:
+        if chunk:
+            print(chunk, end="", flush=True)
+
+
+loop = always_get_an_event_loop()
+if inspect.isasyncgen(resp):
+    loop.run_until_complete(print_stream(resp))
+else:
+    print(resp)
diff --git a/lightrag/llm.py b/lightrag/llm.py
index 72af880e..568fce04
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -76,11 +76,24 @@ async def openai_complete_if_cache(
     response = await openai_async_client.chat.completions.create(
         model=model, messages=messages, **kwargs
     )
-    content = response.choices[0].message.content
-    if r"\u" in content:
-        content = content.encode("utf-8").decode("unicode_escape")
-    return content
+    if hasattr(response, "__aiter__"):
+
+        async def inner():
+            async for chunk in response:
+                content = chunk.choices[0].delta.content
+                if content is None:
+                    continue
+                if r"\u" in content:
+                    content = content.encode("utf-8").decode("unicode_escape")
+                yield content
+
+        return inner()
+    else:
+        content = response.choices[0].message.content
+        if r"\u" in content:
+            content = content.encode("utf-8").decode("unicode_escape")
+        return content


 @retry(
@@ -306,7 +319,7 @@ async def ollama_model_if_cache(
     response = await ollama_client.chat(model=model, messages=messages, **kwargs)

     if stream:
-        """ cannot cache stream response """
+        """cannot cache stream response"""

         async def inner():
             async for chunk in response:
@@ -447,6 +460,22 @@ class GPTKeywordExtractionFormat(BaseModel):
     low_level_keywords: List[str]


+async def openai_complete(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> Union[str, AsyncIterator[str]]:
+    keyword_extraction = kwargs.pop("keyword_extraction", None)
+    if keyword_extraction:
+        kwargs["response_format"] = "json"
+    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+    return await openai_complete_if_cache(
+        model_name,
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+
+
 async def gpt_4o_complete(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 32d5c87f..d79cc1a2
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -488,7 +488,7 @@ class CacheData:


 async def save_to_cache(hashing_kv, cache_data: CacheData):
-    if hashing_kv is None:
+    if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
         return

     mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
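
For reference, a minimal sketch of exercising the new streaming branch by calling openai_complete_if_cache directly, outside the bundled demo. The endpoint, model name, and prompt are placeholders borrowed from the example script above; the sketch assumes openai_complete_if_cache accepts base_url/api_key keyword arguments and forwards stream=True to the OpenAI client, as in this patch.

import asyncio
import inspect

from lightrag.llm import openai_complete_if_cache

BASE_URL = "http://127.0.0.1:1234/v1"  # placeholder OpenAI-compatible endpoint
MODEL = "qwen2.5-14b-instruct@4bit"  # placeholder model name


async def main() -> None:
    # With stream=True forwarded to the client, the patched function returns an
    # async generator of text chunks; without it, a plain str as before.
    resp = await openai_complete_if_cache(
        MODEL,
        "Summarize the main themes of the story in two sentences.",
        base_url=BASE_URL,
        api_key="empty",
        stream=True,
    )
    if inspect.isasyncgen(resp):
        async for chunk in resp:
            print(chunk, end="", flush=True)
        print()
    else:
        print(resp)


if __name__ == "__main__":
    asyncio.run(main())

Checking inspect.isasyncgen on the caller side mirrors what the demo script does with always_get_an_event_loop, since a non-streaming (or cached) response still comes back as a plain string.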