Commit ec2c70e

fix(vertex_ai.py): fix streaming logic

krrishdholakia committed Apr 23, 2024
1 parent 0bb8a44 commit ec2c70e
Showing 2 changed files with 8 additions and 8 deletions.
litellm/llms/vertex_ai.py (11 changes: 5 additions & 6 deletions)
@@ -560,7 +560,7 @@ def completion(
                     stream=True,
                     tools=tools,
                 )
-                optional_params["stream"] = True
+
                 return model_response

             request_str += f"response = llm_model.generate_content({content})\n"
@@ -632,7 +632,7 @@ def completion(
                     },
                 )
                 model_response = chat.send_message_streaming(prompt, **optional_params)
-                optional_params["stream"] = True
+
                 return model_response

             request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
@@ -664,7 +664,7 @@ def completion(
                     },
                 )
                 model_response = llm_model.predict_streaming(prompt, **optional_params)
-                optional_params["stream"] = True
+
                 return model_response

             request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
@@ -1045,8 +1045,7 @@ async def async_streaming(
             generation_config=optional_params,
             tools=tools,
         )
-        optional_params["stream"] = True
-        optional_params["tools"] = tools
+
     elif mode == "chat":
         chat = llm_model.start_chat()
         optional_params.pop(
@@ -1065,7 +1064,7 @@
             },
         )
         response = chat.send_message_streaming_async(prompt, **optional_params)
-        optional_params["stream"] = True
+
     elif mode == "text":
         optional_params.pop(
             "stream", None
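Each hunk above drops a line that wrote "stream" (and, in async_streaming, "tools") back into optional_params after the streaming request had already been dispatched. Because optional_params is a plain dict shared across code paths, re-adding the flag there can leak it into later calls that expand the same dict as keyword arguments. A minimal sketch of that aliasing hazard, using hypothetical names rather than litellm's real call chain:

# Minimal sketch of the dict-mutation hazard the hunks above remove.
# fake_streaming_call and the dict contents are hypothetical, not litellm's API.

def fake_streaming_call(prompt: str, optional_params: dict) -> str:
    stream = optional_params.pop("stream", False)  # consumed for this request
    response = f"response(prompt={prompt!r}, stream={stream})"
    optional_params["stream"] = True  # the removed pattern: mutates the shared dict
    return response

params = {"temperature": 0.2}
fake_streaming_call("hi", params)
# A later, non-streaming call that expands the same dict now inherits the flag:
print(params)  # {'temperature': 0.2, 'stream': True}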
litellm/main.py (5 changes: 3 additions & 2 deletions)
@@ -1683,13 +1683,14 @@ def completion(
                 or optional_params.pop("vertex_ai_credentials", None)
                 or get_secret("VERTEXAI_CREDENTIALS")
             )
+            new_params = deepcopy(optional_params)
             if "claude-3" in model:
                 model_response = vertex_ai_anthropic.completion(
                     model=model,
                     messages=messages,
                     model_response=model_response,
                     print_verbose=print_verbose,
-                    optional_params=optional_params,
+                    optional_params=new_params,
                     litellm_params=litellm_params,
                     logger_fn=logger_fn,
                     encoding=encoding,
@@ -1705,7 +1706,7 @@ def completion(
                     messages=messages,
                     model_response=model_response,
                     print_verbose=print_verbose,
-                    optional_params=optional_params,
+                    optional_params=new_params,
                     litellm_params=litellm_params,
                     logger_fn=logger_fn,
                     encoding=encoding,
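The main.py change guards against the same class of bug from the caller's side: optional_params is deep-copied into new_params before being handed to the Vertex AI handlers, so any keys a handler adds or pops no longer bleed back into the dict that completion() continues to use. A short sketch of the pattern, assuming only that the handler may mutate its params argument (handler is illustrative, standing in for vertex_ai.completion):

# Sketch of the defensive-copy pattern applied in main.py (names illustrative).
from copy import deepcopy

def handler(optional_params: dict) -> None:
    optional_params["stream"] = True  # handler-local mutation

optional_params = {"temperature": 0.2, "max_tokens": 256}
new_params = deepcopy(optional_params)  # isolate the handler's mutations
handler(new_params)
assert "stream" not in optional_params  # the caller's dict stays clean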
