Skip to content

Commit d9bba83

Browse files
committed
Add streaming support
1 parent 87c9e63 commit d9bba83

File tree

3 files changed

+222
-1
lines changed

3 files changed

+222
-1
lines changed

llms_wrapper/llms.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,7 @@ def query(
593593
return_response: bool = False,
594594
debug=False,
595595
litellm_debug=None,
596+
stream=True,
596597
recursive_call_info: Optional[Dict[str, any]] = None,
597598
**kwargs,
598599
) -> Dict[str, any]:
@@ -609,6 +610,9 @@ def query(
609610
return_response: whether or not the complete response should get returned
610611
debug: if True, emits debug messages to aid development and debugging
611612
litellm_debug: if True, litellm debug logging is enabled, if False, disabled, if None, use debug setting
613+
stream: if True, the returned object contains the stream that can be iterated over. Streaming
614+
may not work for all models.
615+
recursive_call_info: internal use only
612616
kwargs: any additional keyword arguments to pass on to the LLM
613617
614618
Returns:
@@ -657,6 +661,11 @@ def query(
657661
fmap = toolnames2funcs(tools)
658662
else:
659663
fmap = {}
664+
if stream:
665+
# TODO: check if model supports streaming
666+
# if streaming is enabled, we always return the original response
667+
return_response = True
668+
completion_kwargs["stream"] = True
660669
ret = {}
661670
# before adding the kwargs, save the recursive_call_info and remove it from kwargs
662671
if debug:
@@ -687,6 +696,13 @@ def query(
687696
model=llm["llm"],
688697
messages=messages,
689698
**completion_kwargs)
699+
if stream:
700+
# TODO: for now we take a shortcut here and simply return the original response
701+
# as "response".
702+
ret["response"] = response
703+
ret["ok"] = True
704+
ret["error"] = ""
705+
return ret
690706
elapsed = time.time() - start
691707
logger.debug(f"Full Response: {response}")
692708
llm["_elapsed_time"] += elapsed
@@ -743,10 +759,14 @@ def query(
743759
if debug:
744760
print(f"DEBUG: checking for tool_calls: {response_message}, have tools: {tools is not None}")
745761
if tools is not None:
762+
# TODO: if streaming is enabled we need to gather the complete response before
763+
# we can process the tool calls
746764
if hasattr(response_message, "tool_calls") and response_message.tool_calls is not None:
747765
tool_calls = response_message.tool_calls
748766
else:
749767
tool_calls = []
768+
if stream:
769+
raise ValueError("Error: streaming is not supported for tool calls yet")
750770
if debug:
751771
print(f"DEBUG: got {len(tool_calls)} tool calls:")
752772
for tool_call in tool_calls:

llms_wrapper/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
import importlib.metadata
2-
__version__ = "0.2.0"
2+
__version__ = "0.3.0"
33

notebooks/test-streaming.ipynb

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "744579c6-ec0c-4c73-bb38-5c99a566056f",
6+
"metadata": {},
7+
"source": [
8+
"# test-streaming.ipynb\n",
9+
"\n",
10+
"Test the API implementation of streaming"
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": 1,
16+
"id": "a6da6e33-e3dd-45d3-abad-ad48a617b1db",
17+
"metadata": {},
18+
"outputs": [],
19+
"source": [
20+
"import os, sys\n",
21+
"from typing import Optional, List, Dict\n",
22+
"sys.path.append(os.path.join(\"..\"))\n",
23+
"from llms_wrapper.llms import LLMS, toolnames2funcs, get_func_by_name\n",
24+
"from llms_wrapper.config import update_llm_config"
25+
]
26+
},
27+
{
28+
"cell_type": "code",
29+
"execution_count": 2,
30+
"id": "7401b8b9-af81-4a1e-9bfa-bef00ec3ea68",
31+
"metadata": {},
32+
"outputs": [
33+
{
34+
"data": {
35+
"text/plain": [
36+
"['openai/gpt-4o',\n",
37+
" 'openai/gpt-4o-mini',\n",
38+
" 'gemini/gemini-2.0-flash-exp',\n",
39+
" 'gemini/gemini-1.5-flash',\n",
40+
" 'gemini/gemini-1.5-pro',\n",
41+
" 'anthropic/claude-3-5-sonnet-20240620',\n",
42+
" 'anthropic/claude-3-opus-20240229',\n",
43+
" 'mistral/mistral-large-latest',\n",
44+
" 'xai/grok-beta',\n",
45+
" 'groq/llama3-70b-8192',\n",
46+
" 'groq/llama-3.3-70b-versatile',\n",
47+
" 'deepseek/deepseek-chat']"
48+
]
49+
},
50+
"execution_count": 2,
51+
"metadata": {},
52+
"output_type": "execute_result"
53+
}
54+
],
55+
"source": [
56+
"config = dict(\n",
57+
" llms=[\n",
58+
" # OpenAI\n",
59+
" # https://platform.openai.com/docs/models\n",
60+
" dict(llm=\"openai/gpt-4o\"),\n",
61+
" dict(llm=\"openai/gpt-4o-mini\"),\n",
62+
" # dict(llm=\"openai/o1\"), # restricted\n",
63+
" # dict(llm=\"openai/o1-mini\"), # restricted\n",
64+
" # Google Gemini\n",
65+
" # https://ai.google.dev/gemini-api/docs/models/gemini\n",
66+
" dict(llm=\"gemini/gemini-2.0-flash-exp\"),\n",
67+
" dict(llm=\"gemini/gemini-1.5-flash\"),\n",
68+
" dict(llm=\"gemini/gemini-1.5-pro\"),\n",
69+
" # Anthropic\n",
70+
" # https://docs.anthropic.com/en/docs/about-claude/models\n",
71+
" dict(llm=\"anthropic/claude-3-5-sonnet-20240620\"),\n",
72+
" dict(llm=\"anthropic/claude-3-opus-20240229\"),\n",
73+
" # Mistral\n",
74+
" # https://docs.mistral.ai/getting-started/models/models_overview/\n",
75+
" dict(llm=\"mistral/mistral-large-latest\"),\n",
76+
" # XAI\n",
77+
" # dict(llm=\"xai/grok-2\"), # not mapped by litellm yet?\n",
78+
" dict(llm=\"xai/grok-beta\"),\n",
79+
" # Groq\n",
80+
" # https://console.groq.com/docs/models\n",
81+
" dict(llm=\"groq/llama3-70b-8192\"),\n",
82+
" dict(llm=\"groq/llama-3.3-70b-versatile\"),\n",
83+
" # Deepseek\n",
84+
" # https://api-docs.deepseek.com/quick_start/pricing\n",
85+
" dict(llm=\"deepseek/deepseek-chat\"),\n",
86+
" ],\n",
87+
" providers = dict(\n",
88+
" openai = dict(api_key_env=\"MY_OPENAI_API_KEY\"),\n",
89+
" gemini = dict(api_key_env=\"MY_GEMINI_API_KEY\"),\n",
90+
" anthropic = dict(api_key_env=\"MY_ANTHROPIC_API_KEY\"),\n",
91+
" mistral = dict(api_key_env=\"MY_MISTRAL_API_KEY\"),\n",
92+
" xai = dict(api_key_env=\"MY_XAI_API_KEY\"), \n",
93+
" groq = dict(api_key_env=\"MY_GROQ_API_KEY\"),\n",
94+
" deepseek = dict(api_key_env=\"MY_DEEPSEEK_API_KEY\"),\n",
95+
" )\n",
96+
")\n",
97+
"config = update_llm_config(config)\n",
98+
"llms = LLMS(config)\n",
99+
"llms.list_aliases()"
100+
]
101+
},
102+
{
103+
"cell_type": "markdown",
104+
"id": "09dfe4bc",
105+
"metadata": {},
106+
"source": [
107+
"## Test streaming\n",
108+
"\n"
109+
]
110+
},
111+
{
112+
"cell_type": "code",
113+
"execution_count": 14,
114+
"id": "a89258b6",
115+
"metadata": {},
116+
"outputs": [
117+
{
118+
"data": {
119+
"text/plain": [
120+
"[{'content': 'What is a monoid? Give me a simple example. Provide your answer as plain text, do not use Markdown formatting.',\n",
121+
" 'role': 'user'}]"
122+
]
123+
},
124+
"execution_count": 14,
125+
"metadata": {},
126+
"output_type": "execute_result"
127+
}
128+
],
129+
"source": [
130+
"msgs = LLMS.make_messages(\"What is a monoid? Give me a simple example. Provide your answer as plain text, do not use Markdown formatting.\")\n",
131+
"msgs"
132+
]
133+
},
134+
{
135+
"cell_type": "code",
136+
"execution_count": 15,
137+
"id": "491b0ddb",
138+
"metadata": {},
139+
"outputs": [
140+
{
141+
"name": "stdout",
142+
"output_type": "stream",
143+
"text": [
144+
"A monoid is an algebraic structure with a single associative binary operation and an identity element. Specifically, a set M is a monoid if it is equipped with a binary operation (let's call it *) that satisfies the following properties:\n",
145+
"\n",
146+
"1. **Associativity**: For all elements a, b, and c in M, the equation (a * b) * c = a * (b * c) holds.\n",
147+
"2. **Identity Element**: There exists an element e in M such that for every element a in M, the equation e * a = a * e = a holds.\n",
148+
"\n",
149+
"A simple example of a monoid is the set of natural numbers (including zero) with the operation of addition. \n",
150+
"\n",
151+
"- The set is {0, 1, 2, 3, ...}.\n",
152+
"- The binary operation is addition (+).\n",
153+
"- The identity element is 0 because adding 0 to any natural number does not change the number (0 + a = a + 0 = a).\n",
154+
"- Addition is associative because for any natural numbers a, b, and c, the equation (a + b) + c = a + (b + c) is always true."
155+
]
156+
}
157+
],
158+
"source": [
159+
"ret = llms.query(\"openai/gpt-4o\", msgs, temperature=0.5, max_tokens=1000, stream=True)\n",
160+
"if ret[\"ok\"]:\n",
161+
" for chunk in ret[\"response\"]:\n",
162+
" choice0 = chunk.choices[0]\n",
163+
" if choice0.finish_reason == \"stop\":\n",
164+
" break \n",
165+
" content = choice0.delta.content \n",
166+
" print(content, end=\"\", flush=True)\n",
167+
"else:\n",
168+
" print(\"Error:\", ret[\"error\"])"
169+
]
170+
},
171+
{
172+
"cell_type": "code",
173+
"execution_count": null,
174+
"id": "0eeea159",
175+
"metadata": {},
176+
"outputs": [],
177+
"source": []
178+
}
179+
],
180+
"metadata": {
181+
"kernelspec": {
182+
"display_name": "llms_wrapper",
183+
"language": "python",
184+
"name": "python3"
185+
},
186+
"language_info": {
187+
"codemirror_mode": {
188+
"name": "ipython",
189+
"version": 3
190+
},
191+
"file_extension": ".py",
192+
"mimetype": "text/x-python",
193+
"name": "python",
194+
"nbconvert_exporter": "python",
195+
"pygments_lexer": "ipython3",
196+
"version": "3.11.11"
197+
}
198+
},
199+
"nbformat": 4,
200+
"nbformat_minor": 5
201+
}

0 commit comments

Comments
 (0)