diff --git a/python/packages/autogen-core/tests/test_model_context.py b/python/packages/autogen-core/tests/test_model_context.py index 7a37c30a220b..bcd2a9c87d9e 100644 --- a/python/packages/autogen-core/tests/test_model_context.py +++ b/python/packages/autogen-core/tests/test_model_context.py @@ -137,7 +137,10 @@ async def test_token_limited_model_context_with_token_limit( await model_context.add_message(msg) retrieved = await model_context.get_messages() - assert len(retrieved) == 1 # Token limit set very low, will remove 2 of the messages + # Token limit set low, will remove some messages + # OpenAI: keeps 2 messages (29 tokens with limit 30) + # Ollama: keeps 1 message (20 tokens with limit 20) + assert len(retrieved) < len(messages) # Some messages removed due to token limit assert retrieved != messages # Will not be equal to the original messages await model_context.clear() @@ -151,7 +154,7 @@ async def test_token_limited_model_context_with_token_limit( await model_context.clear() await model_context.load_state(state) retrieved = await model_context.get_messages() - assert len(retrieved) == 1 + assert len(retrieved) < len(messages) # Some messages removed due to token limit assert retrieved != messages diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py index 2cbb235cefa3..a80e912534ab 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py @@ -393,6 +393,17 @@ def count_tokens_openai( elif field == "description": tool_tokens += 2 tool_tokens += len(encoding.encode(v["description"])) # pyright: ignore + elif field == "anyOf": + tool_tokens -= 3 + for o in v["anyOf"]: # type: ignore + tool_tokens += 3 + tool_tokens += len(encoding.encode(str(o["type"]))) # pyright: ignore + elif field == "default": + tool_tokens += 2 + tool_tokens += len(encoding.encode(json.dumps(v["default"]))) + elif field == "title": + tool_tokens += 2 + tool_tokens += len(encoding.encode(str(v["title"]))) # pyright: ignore elif field == "enum": tool_tokens -= 3 for o in v["enum"]: # pyright: ignore @@ -404,7 +415,9 @@ def count_tokens_openai( if len(parameters["properties"]) == 0: # pyright: ignore tool_tokens -= 2 num_tokens += tool_tokens - num_tokens += 12 + + if oai_tools: + num_tokens += 12 return num_tokens diff --git a/python/packages/autogen-ext/tests/models/test_openai_model_client.py b/python/packages/autogen-ext/tests/models/test_openai_model_client.py index 58558cceb5f4..2c1e19521ac4 100644 --- a/python/packages/autogen-ext/tests/models/test_openai_model_client.py +++ b/python/packages/autogen-ext/tests/models/test_openai_model_client.py @@ -2,7 +2,7 @@ import json import logging import os -from typing import Annotated, Any, AsyncGenerator, Dict, List, Literal, Tuple, TypeVar +from typing import Annotated, Any, AsyncGenerator, Dict, List, Literal, Optional, Tuple, TypeVar from unittest.mock import AsyncMock, MagicMock import httpx @@ -450,11 +450,27 @@ def tool1(test: str, test2: str) -> str: def tool2(test1: int, test2: List[int]) -> str: return str(test1) + str(test2) - tools = [FunctionTool(tool1, description="example tool 1"), FunctionTool(tool2, description="example tool 2")] + def tool3(test1: Annotated[Optional[str], "example"] = None, test2: Literal["1", "2"] = "2") -> str: + return str(test1) + str(test2) + + tools = [ + FunctionTool(tool1, description="example tool 1"), + FunctionTool(tool2, description="example tool 2"), + FunctionTool(tool3, description="example tool 3"), + ] mockcalculate_vision_tokens = MagicMock() monkeypatch.setattr("autogen_ext.models.openai._openai_client.calculate_vision_tokens", mockcalculate_vision_tokens) + # Test count_tokens without tools + num_tokens = client.count_tokens(messages) + assert num_tokens + + # Check that calculate_vision_tokens was called + mockcalculate_vision_tokens.assert_called_once() + mockcalculate_vision_tokens.reset_mock() + + # Test count_tokens with tools num_tokens = client.count_tokens(messages, tools=tools) assert num_tokens