diff --git a/python/packages/autogen-ext/src/autogen_ext/models/ollama/_ollama_client.py b/python/packages/autogen-ext/src/autogen_ext/models/ollama/_ollama_client.py
index 1f4641fa25b0..e8abf67614a9 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/ollama/_ollama_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/ollama/_ollama_client.py
@@ -646,14 +646,6 @@ async def create(
         content: Union[str, List[FunctionCall]]
         thought: Optional[str] = None
         if result.message.tool_calls is not None:
-            # TODO: What are possible values for done_reason?
-            if result.done_reason != "tool_calls":
-                warnings.warn(
-                    f"Finish reason mismatch: {result.done_reason} != tool_calls "
-                    "when tool_calls are present. Finish reason may not be accurate. "
-                    "This may be due to the API used that is not returning the correct finish reason.",
-                    stacklevel=2,
-                )
             if result.message.content is not None and result.message.content != "":
                 thought = result.message.content
             # NOTE: If OAI response type changes, this will need to be updated
@@ -760,9 +752,8 @@ async def create_stream(
                 content_chunks.append(chunk.message.content)
                 if len(chunk.message.content) > 0:
                     yield chunk.message.content
-                continue
 
-            # Otherwise, get tool calls
+            # Get tool calls
             if chunk.message.tool_calls is not None:
                 full_tool_calls.extend(
                     [
@@ -796,9 +787,6 @@ async def create_stream(
         else:
             prompt_tokens = 0
 
-        if stop_reason == "function_call":
-            raise ValueError("Function calls are not supported in this context")
-
         content: Union[str, List[FunctionCall]]
         thought: Optional[str] = None
 
diff --git a/python/packages/autogen-ext/tests/models/test_ollama_chat_completion_client.py b/python/packages/autogen-ext/tests/models/test_ollama_chat_completion_client.py
index 6e7389bdd3fc..7499e31beec3 100644
--- a/python/packages/autogen-ext/tests/models/test_ollama_chat_completion_client.py
+++ b/python/packages/autogen-ext/tests/models/test_ollama_chat_completion_client.py
@@ -206,6 +206,77 @@ async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
     assert create_result.usage.completion_tokens == 12
 
 
+@pytest.mark.asyncio
+async def test_create_stream_tools(monkeypatch: pytest.MonkeyPatch) -> None:
+    def add(x: int, y: int) -> str:
+        return str(x + y)
+
+    add_tool = FunctionTool(add, description="Add two numbers")
+    model = "llama3.2"
+    content_raw = "Hello world! This is a test response. Test response."
+
+    async def _mock_chat(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatResponse, None]:
+        assert "stream" in kwargs
+        assert kwargs["stream"] is True
+
+        async def _mock_stream() -> AsyncGenerator[ChatResponse, None]:
+            chunks = [content_raw[i : i + 5] for i in range(0, len(content_raw), 5)]
+            # Simulate streaming by yielding chunks of the response
+            for chunk in chunks[:-1]:
+                yield ChatResponse(
+                    model=model,
+                    done=False,
+                    message=Message(
+                        role="assistant",
+                        content=chunk,
+                    ),
+                )
+            yield ChatResponse(
+                model=model,
+                done=True,
+                done_reason="stop",
+                message=Message(
+                    content=chunks[-1],
+                    role="assistant",
+                    tool_calls=[
+                        Message.ToolCall(
+                            function=Message.ToolCall.Function(
+                                name=add_tool.name,
+                                arguments={"x": 2, "y": 2},
+                            ),
+                        ),
+                    ],
+                ),
+                prompt_eval_count=10,
+                eval_count=12,
+            )
+
+        return _mock_stream()
+
+    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
+    client = OllamaChatCompletionClient(model=model)
+    stream = client.create_stream(
+        messages=[
+            UserMessage(content="hi", source="user"),
+        ],
+        tools=[add_tool],
+    )
+    chunks: List[str | CreateResult] = []
+    async for chunk in stream:
+        chunks.append(chunk)
+    assert len(chunks) > 0
+    assert isinstance(chunks[-1], CreateResult)
+    assert isinstance(chunks[-1].content, list)
+    assert len(chunks[-1].content) > 0
+    assert isinstance(chunks[-1].content[0], FunctionCall)
+    assert chunks[-1].content[0].name == add_tool.name
+    assert chunks[-1].content[0].arguments == json.dumps({"x": 2, "y": 2})
+    assert chunks[-1].finish_reason == "stop"
+    assert chunks[-1].usage is not None
+    assert chunks[-1].usage.prompt_tokens == 10
+    assert chunks[-1].usage.completion_tokens == 12
+
+
 @pytest.mark.asyncio
 async def test_create_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
     class ResponseType(BaseModel):
@@ -541,7 +612,6 @@ def add(x: int, y: int) -> str:
     assert ResponseType.model_validate_json(create_result.thought)
 
 
-@pytest.mark.skip("TODO: Fix streaming with tools")
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model", ["qwen2.5:0.5b", "llama3.2:1b"])
 async def test_ollama_create_stream_tools(model: str, ollama_client: OllamaChatCompletionClient) -> None:
@@ -569,7 +639,7 @@ def add(x: int, y: int) -> str:
     assert len(create_result.content) > 0
     assert isinstance(create_result.content[0], FunctionCall)
     assert create_result.content[0].name == add_tool.name
-    assert create_result.finish_reason == "function_calls"
+    assert create_result.finish_reason == "stop"
 
     execution_result = FunctionExecutionResult(
         content="4",
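
Reviewer note: with the early `continue` and the `function_call` ValueError removed, `create_stream` now yields text chunks and accumulates tool calls from the same stream, finishing with a `CreateResult` instead of raising. Below is a minimal consumer-side sketch, not part of the patch, assuming a local Ollama server with the `llama3.2` model pulled; the `add` tool and prompt are illustrative, mirroring the tests above.

import asyncio

from autogen_core.models import CreateResult, UserMessage
from autogen_core.tools import FunctionTool
from autogen_ext.models.ollama import OllamaChatCompletionClient


def add(x: int, y: int) -> str:
    return str(x + y)


async def main() -> None:
    client = OllamaChatCompletionClient(model="llama3.2")
    add_tool = FunctionTool(add, description="Add two numbers")
    stream = client.create_stream(
        messages=[UserMessage(content="What is 2 + 2?", source="user")],
        tools=[add_tool],
    )
    # Text chunks arrive first; the final item is a CreateResult whose
    # content may be a list of FunctionCall objects when the model calls a tool.
    async for item in stream:
        if isinstance(item, CreateResult):
            print("finish_reason:", item.finish_reason)
            print("content:", item.content)
        else:
            print("chunk:", item)


asyncio.run(main())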