Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from autogen.token_count_utils import count_token

try:
from openai import OpenAI, APIError
from openai import OpenAI, APIError, AzureOpenAI
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion import Completion
Expand Down Expand Up @@ -137,7 +137,19 @@ def _client(self, config, openai_config):
"""
openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}}
self._process_for_azure(openai_config, config)
client = OpenAI(**openai_config)

# TODO: move this to _process_for_azure and rewrite it
if config.get("api_type", "").startswith("azure"):
azure_config = config.copy()
if azure_config.get("model"):
azure_config["azure_deployment"] = config.get("model").replace("gpt-3.5", "gpt-35")
azure_config.pop("model")
azure_config["azure_endpoint"] = config.get("base_url")
azure_config.pop("base_url")
config.pop("api_type")
client = AzureOpenAI(**config)
else:
client = OpenAI(**openai_config)
return client

@classmethod
Expand Down