diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 70a28e15ed60..5d9c61300422 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -11,7 +11,7 @@ from autogen.token_count_utils import count_token try: - from openai import OpenAI, APIError + from openai import OpenAI, APIError, AzureOpenAI from openai.types.chat import ChatCompletion from openai.types.chat.chat_completion import ChatCompletionMessage, Choice from openai.types.completion import Completion @@ -137,7 +137,19 @@ def _client(self, config, openai_config): """ openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}} self._process_for_azure(openai_config, config) - client = OpenAI(**openai_config) + + # TODO: move this to _process_for_azure and rewrite it + if config.get("api_type", "").startswith("azure"): + azure_config = config.copy() + if azure_config.get("model"): + azure_config["azure_deployment"] = config.get("model").replace("gpt-3.5", "gpt-35") + azure_config.pop("model") + azure_config["azure_endpoint"] = config.get("base_url") + azure_config.pop("base_url") + config.pop("api_type") + client = AzureOpenAI(**config) + else: + client = OpenAI(**openai_config) return client @classmethod