Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, List, get_args

from autogen_core.models import LLMMessage, ModelFamily

Expand Down Expand Up @@ -87,10 +87,13 @@ def _find_model_family(api: str, model: str) -> str:
Finds the best matching model family for the given model.
Search via prefix matching (e.g. "gpt-4o" → "gpt-4o-1.0").
"""
len_family = 0
family = ModelFamily.UNKNOWN
for _family in MESSAGE_TRANSFORMERS[api].keys():
if model.startswith(_family):
family = _family
if len(_family) > len_family:
family = _family
len_family = len(_family)
return family


Expand All @@ -108,13 +111,14 @@ def get_transformer(api: str, model: str, model_family: str) -> TransformerMap:
Keeping this as a function (instead of direct dict access) improves long-term flexibility.
"""

if model_family == ModelFamily.UNKNOWN:
if model_family not in set(get_args(ModelFamily.ANY)) or model_family == ModelFamily.UNKNOWN:
# fallback to finding the best matching model family
model_family = _find_model_family(api, model)

transformer = MESSAGE_TRANSFORMERS.get(api, {}).get(model_family, {})

if not transformer:
# Just in case, we should never reach here
raise ValueError(f"No transformer found for model family '{model_family}'")

return transformer
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
to_oai_type,
)
from autogen_ext.models.openai._transformation import TransformerMap, get_transformer
from autogen_ext.models.openai._transformation.registry import _find_model_family # pyright: ignore[reportPrivateUsage]
from openai.resources.beta.chat.completions import ( # type: ignore
AsyncChatCompletionStreamManager as BetaAsyncChatCompletionStreamManager, # type: ignore
)
Expand Down Expand Up @@ -2394,11 +2395,6 @@ def get_regitered_transformer(client: OpenAIChatCompletionClient) -> Transformer
assert get_regitered_transformer(client1) == get_regitered_transformer(client2)


def test_openai_model_registry_find_wrong() -> None:
with pytest.raises(ValueError, match="No transformer found for model family"):
get_transformer("openai", "gpt-7", "foobar")


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model",
Expand Down Expand Up @@ -2451,4 +2447,13 @@ def test_rstrip_railing_whitespace_at_last_assistant_content() -> None:
assert result[-1].content == "foobar"


def test_find_model_family() -> None:
assert _find_model_family("openai", "gpt-4") == ModelFamily.GPT_4
assert _find_model_family("openai", "gpt-4-latest") == ModelFamily.GPT_4
assert _find_model_family("openai", "gpt-4o") == ModelFamily.GPT_4O
assert _find_model_family("openai", "gemini-2.0-flash") == ModelFamily.GEMINI_2_0_FLASH
assert _find_model_family("openai", "claude-3-5-haiku-20241022") == ModelFamily.CLAUDE_3_5_HAIKU
assert _find_model_family("openai", "error") == ModelFamily.UNKNOWN


# TODO: add integration tests for Azure OpenAI using AAD token.
Loading