diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_transformation/registry.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_transformation/registry.py index 5c1187fb5224..c3148110d057 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_transformation/registry.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_transformation/registry.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, get_args from autogen_core.models import LLMMessage, ModelFamily @@ -87,10 +87,13 @@ def _find_model_family(api: str, model: str) -> str: Finds the best matching model family for the given model. Search via prefix matching (e.g. "gpt-4o" → "gpt-4o-1.0"). """ + len_family = 0 family = ModelFamily.UNKNOWN for _family in MESSAGE_TRANSFORMERS[api].keys(): if model.startswith(_family): - family = _family + if len(_family) > len_family: + family = _family + len_family = len(_family) return family @@ -108,13 +111,14 @@ def get_transformer(api: str, model: str, model_family: str) -> TransformerMap: Keeping this as a function (instead of direct dict access) improves long-term flexibility. 
# Defensive guard: should be unreachable, since _find_model_family falls back to a registered family
"gpt-4o") == ModelFamily.GPT_4O + assert _find_model_family("openai", "gemini-2.0-flash") == ModelFamily.GEMINI_2_0_FLASH + assert _find_model_family("openai", "claude-3-5-haiku-20241022") == ModelFamily.CLAUDE_3_5_HAIKU + assert _find_model_family("openai", "error") == ModelFamily.UNKNOWN + + # TODO: add integration tests for Azure OpenAI using AAD token.