Skip to content

Commit f76f92d

Browse files
authored
Fix not supported field warnings in count_tokens_openai (#6987)
1 parent fb03c1c commit f76f92d

File tree

3 files changed

+37
-5
lines changed

3 files changed

+37
-5
lines changed

python/packages/autogen-core/tests/test_model_context.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,10 @@ async def test_token_limited_model_context_with_token_limit(
137137
await model_context.add_message(msg)
138138

139139
retrieved = await model_context.get_messages()
140-
assert len(retrieved) == 1 # Token limit set very low, will remove 2 of the messages
140+
# Token limit set low, will remove some messages
141+
# OpenAI: keeps 2 messages (29 tokens with limit 30)
142+
# Ollama: keeps 1 message (20 tokens with limit 20)
143+
assert len(retrieved) < len(messages) # Some messages removed due to token limit
141144
assert retrieved != messages # Will not be equal to the original messages
142145

143146
await model_context.clear()
@@ -151,7 +154,7 @@ async def test_token_limited_model_context_with_token_limit(
151154
await model_context.clear()
152155
await model_context.load_state(state)
153156
retrieved = await model_context.get_messages()
154-
assert len(retrieved) == 1
157+
assert len(retrieved) < len(messages) # Some messages removed due to token limit
155158
assert retrieved != messages
156159

157160

python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,17 @@ def count_tokens_openai(
393393
elif field == "description":
394394
tool_tokens += 2
395395
tool_tokens += len(encoding.encode(v["description"])) # pyright: ignore
396+
elif field == "anyOf":
397+
tool_tokens -= 3
398+
for o in v["anyOf"]: # type: ignore
399+
tool_tokens += 3
400+
tool_tokens += len(encoding.encode(str(o["type"]))) # pyright: ignore
401+
elif field == "default":
402+
tool_tokens += 2
403+
tool_tokens += len(encoding.encode(json.dumps(v["default"])))
404+
elif field == "title":
405+
tool_tokens += 2
406+
tool_tokens += len(encoding.encode(str(v["title"]))) # pyright: ignore
396407
elif field == "enum":
397408
tool_tokens -= 3
398409
for o in v["enum"]: # pyright: ignore
@@ -404,7 +415,9 @@ def count_tokens_openai(
404415
if len(parameters["properties"]) == 0: # pyright: ignore
405416
tool_tokens -= 2
406417
num_tokens += tool_tokens
407-
num_tokens += 12
418+
419+
if oai_tools:
420+
num_tokens += 12
408421
return num_tokens
409422

410423

python/packages/autogen-ext/tests/models/test_openai_model_client.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
import logging
44
import os
5-
from typing import Annotated, Any, AsyncGenerator, Dict, List, Literal, Tuple, TypeVar
5+
from typing import Annotated, Any, AsyncGenerator, Dict, List, Literal, Optional, Tuple, TypeVar
66
from unittest.mock import AsyncMock, MagicMock
77

88
import httpx
@@ -450,11 +450,27 @@ def tool1(test: str, test2: str) -> str:
450450
def tool2(test1: int, test2: List[int]) -> str:
451451
return str(test1) + str(test2)
452452

453-
tools = [FunctionTool(tool1, description="example tool 1"), FunctionTool(tool2, description="example tool 2")]
453+
def tool3(test1: Annotated[Optional[str], "example"] = None, test2: Literal["1", "2"] = "2") -> str:
454+
return str(test1) + str(test2)
455+
456+
tools = [
457+
FunctionTool(tool1, description="example tool 1"),
458+
FunctionTool(tool2, description="example tool 2"),
459+
FunctionTool(tool3, description="example tool 3"),
460+
]
454461

455462
mockcalculate_vision_tokens = MagicMock()
456463
monkeypatch.setattr("autogen_ext.models.openai._openai_client.calculate_vision_tokens", mockcalculate_vision_tokens)
457464

465+
# Test count_tokens without tools
466+
num_tokens = client.count_tokens(messages)
467+
assert num_tokens
468+
469+
# Check that calculate_vision_tokens was called
470+
mockcalculate_vision_tokens.assert_called_once()
471+
mockcalculate_vision_tokens.reset_mock()
472+
473+
# Test count_tokens with tools
458474
num_tokens = client.count_tokens(messages, tools=tools)
459475
assert num_tokens
460476

0 commit comments

Comments
 (0)