Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -646,14 +646,6 @@ async def create(
content: Union[str, List[FunctionCall]]
thought: Optional[str] = None
if result.message.tool_calls is not None:
# TODO: What are possible values for done_reason?
if result.done_reason != "tool_calls":
warnings.warn(
f"Finish reason mismatch: {result.done_reason} != tool_calls "
"when tool_calls are present. Finish reason may not be accurate. "
"This may be due to the API used that is not returning the correct finish reason.",
stacklevel=2,
)
if result.message.content is not None and result.message.content != "":
thought = result.message.content
# NOTE: If OAI response type changes, this will need to be updated
Expand Down Expand Up @@ -760,9 +752,8 @@ async def create_stream(
content_chunks.append(chunk.message.content)
if len(chunk.message.content) > 0:
yield chunk.message.content
continue

# Otherwise, get tool calls
# Get tool calls
if chunk.message.tool_calls is not None:
full_tool_calls.extend(
[
Expand Down Expand Up @@ -796,9 +787,6 @@ async def create_stream(
else:
prompt_tokens = 0

if stop_reason == "function_call":
raise ValueError("Function calls are not supported in this context")

content: Union[str, List[FunctionCall]]
thought: Optional[str] = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,77 @@ async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
assert create_result.usage.completion_tokens == 12


@pytest.mark.asyncio
async def test_create_stream_tools(monkeypatch: pytest.MonkeyPatch) -> None:
    """Verify create_stream surfaces tool calls from the final streamed chunk.

    The Ollama chat endpoint is mocked to stream the response text in
    5-character pieces and attach a single tool call to the terminal
    (done=True) chunk; the client must fold that into a CreateResult
    carrying a FunctionCall, usage counts, and the "stop" finish reason.
    """

    def add(x: int, y: int) -> str:
        return str(x + y)

    add_tool = FunctionTool(add, description="Add two numbers")
    model = "llama3.2"
    content_raw = "Hello world! This is a test response. Test response."

    async def _mock_chat(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatResponse, None]:
        # The client must explicitly request a streaming response.
        assert "stream" in kwargs
        assert kwargs["stream"] is True

        async def _mock_stream() -> AsyncGenerator[ChatResponse, None]:
            pieces = [content_raw[start : start + 5] for start in range(0, len(content_raw), 5)]
            # Intermediate chunks carry text only.
            for piece in pieces[:-1]:
                yield ChatResponse(
                    model=model,
                    done=False,
                    message=Message(role="assistant", content=piece),
                )
            # Terminal chunk: last text piece plus the tool call and usage stats.
            final_call = Message.ToolCall(
                function=Message.ToolCall.Function(
                    name=add_tool.name,
                    arguments={"x": 2, "y": 2},
                ),
            )
            yield ChatResponse(
                model=model,
                done=True,
                done_reason="stop",
                message=Message(
                    content=pieces[-1],
                    role="assistant",
                    tool_calls=[final_call],
                ),
                prompt_eval_count=10,
                eval_count=12,
            )

        return _mock_stream()

    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
    client = OllamaChatCompletionClient(model=model)
    stream = client.create_stream(
        messages=[UserMessage(content="hi", source="user")],
        tools=[add_tool],
    )
    collected: List[str | CreateResult] = [chunk async for chunk in stream]

    assert len(collected) > 0
    final = collected[-1]
    assert isinstance(final, CreateResult)
    assert isinstance(final.content, list)
    assert len(final.content) > 0
    assert isinstance(final.content[0], FunctionCall)
    assert final.content[0].name == add_tool.name
    assert final.content[0].arguments == json.dumps({"x": 2, "y": 2})
    assert final.finish_reason == "stop"
    assert final.usage is not None
    assert final.usage.prompt_tokens == 10
    assert final.usage.completion_tokens == 12


@pytest.mark.asyncio
async def test_create_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
class ResponseType(BaseModel):
Expand Down Expand Up @@ -541,7 +612,6 @@ def add(x: int, y: int) -> str:
assert ResponseType.model_validate_json(create_result.thought)


@pytest.mark.skip("TODO: Fix streaming with tools")
@pytest.mark.asyncio
@pytest.mark.parametrize("model", ["qwen2.5:0.5b", "llama3.2:1b"])
async def test_ollama_create_stream_tools(model: str, ollama_client: OllamaChatCompletionClient) -> None:
Expand Down Expand Up @@ -569,7 +639,7 @@ def add(x: int, y: int) -> str:
assert len(create_result.content) > 0
assert isinstance(create_result.content[0], FunctionCall)
assert create_result.content[0].name == add_tool.name
assert create_result.finish_reason == "function_calls"
assert create_result.finish_reason == "stop"

execution_result = FunctionExecutionResult(
content="4",
Expand Down
Loading