Commit
fix(ollama.py): fix sync ollama streaming
krrishdholakia committed Dec 17, 2023
1 parent 13d088b commit a3c7a34
Showing 3 changed files with 21 additions and 43 deletions.
48 changes: 15 additions & 33 deletions litellm/llms/ollama.py
@@ -148,39 +148,21 @@ def get_ollama_response_stream(
         return response

     else:
-        return ollama_completion_stream(url=url, data=data)
-
-def ollama_completion_stream(url, data):
-    session = requests.Session()
-
-    with session.post(url, json=data, stream=True) as resp:
-        if resp.status_code != 200:
-            raise OllamaError(status_code=resp.status_code, message=resp.text)
-        for line in resp.iter_lines():
-            if line:
-                try:
-                    json_chunk = line.decode("utf-8")
-                    chunks = json_chunk.split("\n")
-                    for chunk in chunks:
-                        if chunk.strip() != "":
-                            j = json.loads(chunk)
-                            if "error" in j:
-                                completion_obj = {
-                                    "role": "assistant",
-                                    "content": "",
-                                    "error": j
-                                }
-                                yield completion_obj
-                            if "response" in j:
-                                completion_obj = {
-                                    "role": "assistant",
-                                    "content": "",
-                                }
-                                completion_obj["content"] = j["response"]
-                                yield completion_obj
-                except Exception as e:
-                    traceback.print_exc()
-    session.close()
+        return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
+
+def ollama_completion_stream(url, data, logging_obj):
+    with httpx.stream(
+        url=url,
+        json=data,
+        method="POST",
+        timeout=litellm.request_timeout
+    ) as response:
+        if response.status_code != 200:
+            raise OllamaError(status_code=response.status_code, message=response.text)
+
+        streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.iter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj)
+        for transformed_chunk in streamwrapper:
+            yield transformed_chunk


 async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
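
For context, a minimal sketch (not part of this commit) of what the sync stream looks like at the HTTP level: Ollama's default /api/generate endpoint emits newline-delimited JSON, and response.iter_lines() is what the new code above hands to litellm.CustomStreamWrapper. The base URL, model name, prompt, and timeout below are illustrative assumptions.

# Hedged sketch, not from the commit: consuming Ollama's NDJSON stream directly with httpx.
import json
import httpx

def stream_ollama_raw(prompt, base_url="http://localhost:11434", model="llama2"):
    data = {"model": model, "prompt": prompt, "stream": True}
    with httpx.stream("POST", f"{base_url}/api/generate", json=data, timeout=600) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            if not line:
                continue
            chunk = json.loads(line)          # one JSON object per line
            if "response" in chunk:
                yield chunk["response"]       # incremental text delta
            if chunk.get("done"):
                break

if __name__ == "__main__":
    for token in stream_ollama_raw("Why is the sky blue?"):
        print(token, end="", flush=True)
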
8 changes: 1 addition & 7 deletions litellm/main.py
@@ -1320,14 +1320,8 @@ def completion(

             ## LOGGING
             generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding)
-            if acompletion is True:
+            if acompletion is True or optional_params.get("stream", False) == True:
                 return generator
-            if optional_params.get("stream", False) == True:
-                # assume all ollama responses are streamed
-                response = CustomStreamWrapper(
-                    generator, model, custom_llm_provider="ollama", logging_obj=logging
-                )
-                return response
             else:
                 response_string = ""
                 for chunk in generator:
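
Caller-side effect of the main.py change, as a hedged sketch (not from this commit): with stream=True, completion() now returns the generator produced by get_ollama_response_stream directly, already wrapped for streaming, so the caller just iterates it. Assumes a local Ollama server with the llama2 model pulled; the chunk access mirrors litellm's OpenAI-style streaming shape.

# Hedged sketch, not from the commit: the sync streaming path the change above routes through.
import litellm

def demo_sync_ollama_stream():
    response = litellm.completion(
        model="ollama/llama2",                 # assumes llama2 is pulled locally
        messages=[{"role": "user", "content": "Why is the sky blue?"}],
        stream=True,
    )
    for chunk in response:
        # streamed chunks follow litellm's OpenAI-style format
        delta = chunk.choices[0].delta.content
        if delta:
            print(delta, end="", flush=True)

if __name__ == "__main__":
    demo_sync_ollama_stream()
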
8 changes: 5 additions & 3 deletions litellm/tests/test_ollama_local.py
@@ -33,16 +33,18 @@
 # try:
 #     response = completion(
 #         model="ollama/llama2",
-#         messages=messages,
+#         messages=[{"role": "user", "content": "Hey, how's it going?"}],
 #         max_tokens=200,
 #         request_timeout = 10,
-
+#         stream=True
 #     )
+#     for chunk in response:
+#         print(chunk)
 #     print(response)
 # except Exception as e:
 #     pytest.fail(f"Error occurred: {e}")

-# # test_completion_ollama()
+# test_completion_ollama()

 # def test_completion_ollama_with_api_base():
 #     try:
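
A hedged, pytest-style rendering of the commented-out streaming test above (not part of this commit). It requires a running local Ollama server with llama2 pulled, which is presumably why the original stays commented out; the test name is illustrative.

# Hedged sketch, not from the commit: uncommented version of the streaming test above.
import pytest
from litellm import completion

def test_completion_ollama_streaming_local():
    try:
        response = completion(
            model="ollama/llama2",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
            max_tokens=200,
            stream=True,
        )
        chunks = list(response)
        assert len(chunks) > 0  # at least one streamed chunk should arrive
        for chunk in chunks:
            print(chunk)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
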
