Skip to content

Commit

Permalink
Updated the azure client to support AAD auth. (#2879)
Browse files Browse the repository at this point in the history
  • Loading branch information
afourney authored and victordibia committed Jul 30, 2024
1 parent 906c4c7 commit 34a0e99
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 4 deletions.
35 changes: 32 additions & 3 deletions autogen/logger/sqlite_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,16 @@ def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any]) -> N

args = to_dict(
init_args,
exclude=("self", "__class__", "api_key", "organization", "base_url", "azure_endpoint"),
exclude=(
"self",
"__class__",
"api_key",
"organization",
"base_url",
"azure_endpoint",
"azure_ad_token",
"azure_ad_token_provider",
),
no_recursive=(Agent,),
)

Expand Down Expand Up @@ -301,7 +310,17 @@ def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLM
return

args = to_dict(
init_args, exclude=("self", "__class__", "api_key", "organization", "base_url", "azure_endpoint")
init_args,
exclude=(
"self",
"__class__",
"api_key",
"organization",
"base_url",
"azure_endpoint",
"azure_ad_token",
"azure_ad_token_provider",
),
)

query = """
Expand All @@ -323,7 +342,17 @@ def log_new_client(
return

args = to_dict(
init_args, exclude=("self", "__class__", "api_key", "organization", "base_url", "azure_endpoint")
init_args,
exclude=(
"self",
"__class__",
"api_key",
"organization",
"base_url",
"azure_endpoint",
"azure_ad_token",
"azure_ad_token_provider",
),
)

query = """
Expand Down
8 changes: 8 additions & 0 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,14 @@ def _configure_azure_openai(self, config: Dict[str, Any], openai_config: Dict[st
openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "")
openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None))

# Create a default Azure token provider if requested
if openai_config.get("azure_ad_token_provider") == "DEFAULT":
import azure.identity

openai_config["azure_ad_token_provider"] = azure.identity.get_bearer_token_provider(
azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)

def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None:
"""Create a client with the given config to override openai_config,
after removing extra kwargs.
Expand Down
2 changes: 1 addition & 1 deletion autogen/oai/openai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from openai.types.beta.assistant import Assistant
from packaging.version import parse

NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version"]
NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version", "azure_ad_token", "azure_ad_token_provider"]
DEFAULT_AZURE_API_VERSION = "2024-02-15-preview"
OAI_PRICE1K = {
# https://openai.com/api/pricing/
Expand Down

0 comments on commit 34a0e99

Please sign in to comment.