Skip to content

Commit

Permalink
feat: add support for managed identity for Azure OpenAI (run-llama#922)
Browse files Browse the repository at this point in the history
Co-authored-by: Alex Yang <himself65@outlook.com>
  • Loading branch information
manekinekko and himself65 authored Jun 11, 2024
1 parent c8cfc6c commit a51ed8d
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 52 deletions.
5 changes: 5 additions & 0 deletions .changeset/giant-gorillas-explain.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"llamaindex": patch
---

feat: add support for managed identity for Azure OpenAI
19 changes: 6 additions & 13 deletions packages/core/src/embeddings/OpenAIEmbedding.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import type { ClientOptions as OpenAIClientOptions } from "openai";
import type { AzureOpenAIConfig } from "../llm/azure.js";
import {
getAzureBaseUrl,
getAzureConfigFromEnv,
getAzureModel,
shouldUseAzure,
Expand Down Expand Up @@ -67,28 +66,22 @@ export class OpenAIEmbedding extends BaseEmbedding {
this.additionalSessionOptions = init?.additionalSessionOptions;

if (init?.azure || shouldUseAzure()) {
const azureConfig = getAzureConfigFromEnv({
const azureConfig = {
...getAzureConfigFromEnv({
model: getAzureModel(this.model),
}),
...init?.azure,
model: getAzureModel(this.model),
});

if (!azureConfig.apiKey) {
throw new Error(
"Azure API key is required for OpenAI Azure models. Please set the AZURE_OPENAI_KEY environment variable.",
);
}
};

this.apiKey = azureConfig.apiKey;
this.session =
init?.session ??
getOpenAISession({
azure: true,
apiKey: this.apiKey,
baseURL: getAzureBaseUrl(azureConfig),
maxRetries: this.maxRetries,
timeout: this.timeout,
defaultQuery: { "api-version": azureConfig.apiVersion },
...this.additionalSessionOptions,
...azureConfig,
});
} else {
this.apiKey = init?.apiKey ?? undefined;
Expand Down
21 changes: 12 additions & 9 deletions packages/core/src/llm/azure.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import { getEnv } from "@llamaindex/env";

export interface AzureOpenAIConfig {
apiKey?: string;
endpoint?: string;
apiVersion?: string;
import type { AzureClientOptions } from "openai";

export interface AzureOpenAIConfig extends AzureClientOptions {
/** @deprecated use "deployment" instead */
deploymentName?: string;
}

Expand Down Expand Up @@ -81,6 +81,12 @@ const DEFAULT_API_VERSION = "2023-05-15";
export function getAzureConfigFromEnv(
init?: Partial<AzureOpenAIConfig> & { model?: string },
): AzureOpenAIConfig {
const deployment =
init?.deploymentName ??
init?.deployment ??
getEnv("AZURE_OPENAI_DEPLOYMENT") ?? // From Azure docs
getEnv("AZURE_OPENAI_API_DEPLOYMENT_NAME") ?? // LCJS compatible
init?.model; // Fall back to model name, Python compatible
return {
apiKey:
init?.apiKey ??
Expand All @@ -98,11 +104,8 @@ export function getAzureConfigFromEnv(
getEnv("OPENAI_API_VERSION") ?? // Python compatible
getEnv("AZURE_OPENAI_API_VERSION") ?? // LCJS compatible
DEFAULT_API_VERSION,
deploymentName:
init?.deploymentName ??
getEnv("AZURE_OPENAI_DEPLOYMENT") ?? // From Azure docs
getEnv("AZURE_OPENAI_API_DEPLOYMENT_NAME") ?? // LCJS compatible
init?.model, // Fall back to model name, Python compatible
deploymentName: deployment, // LCJS compatible
deployment, // For Azure OpenAI
};
}

Expand Down
46 changes: 16 additions & 30 deletions packages/core/src/llm/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import type {
ClientOptions,
ClientOptions as OpenAIClientOptions,
} from "openai";
import { OpenAI as OrigOpenAI } from "openai";
import { AzureOpenAI, OpenAI as OrigOpenAI } from "openai";

import type {
ChatCompletionAssistantMessageParam,
Expand All @@ -23,7 +23,6 @@ import { getCallbackManager } from "../internal/settings/CallbackManager.js";
import type { BaseTool } from "../types.js";
import type { AzureOpenAIConfig } from "./azure.js";
import {
getAzureBaseUrl,
getAzureConfigFromEnv,
getAzureModel,
shouldUseAzure,
Expand All @@ -43,30 +42,23 @@ import type {
} from "./types.js";
import { extractText, wrapLLMEvent } from "./utils.js";

export class AzureOpenAI extends OrigOpenAI {
protected override authHeaders() {
return { "api-key": this.apiKey };
}
}

export class OpenAISession {
openai: OrigOpenAI;

constructor(options: ClientOptions & { azure?: boolean } = {}) {
if (!options.apiKey) {
options.apiKey = getEnv("OPENAI_API_KEY");
}

if (!options.apiKey) {
throw new Error("Set OpenAI Key in OPENAI_API_KEY env variable"); // Overriding OpenAI package's error message
}

if (options.azure) {
this.openai = new AzureOpenAI(options);
this.openai = new AzureOpenAI(options as AzureOpenAIConfig);
} else {
if (!options.apiKey) {
options.apiKey = getEnv("OPENAI_API_KEY");
}

if (!options.apiKey) {
throw new Error("Set OpenAI Key in OPENAI_API_KEY env variable"); // Overriding OpenAI package's error message
}

this.openai = new OrigOpenAI({
...options,
// defaultHeaders: { "OpenAI-Beta": "assistants=v1" },
});
}
}
Expand Down Expand Up @@ -195,28 +187,22 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> {
this.additionalSessionOptions = init?.additionalSessionOptions;

if (init?.azure || shouldUseAzure()) {
const azureConfig = getAzureConfigFromEnv({
const azureConfig = {
...getAzureConfigFromEnv({
model: getAzureModel(this.model),
}),
...init?.azure,
model: getAzureModel(this.model),
});

if (!azureConfig.apiKey) {
throw new Error(
"Azure API key is required for OpenAI Azure models. Please set the AZURE_OPENAI_KEY environment variable.",
);
}
};

this.apiKey = azureConfig.apiKey;
this.session =
init?.session ??
getOpenAISession({
azure: true,
apiKey: this.apiKey,
baseURL: getAzureBaseUrl(azureConfig),
maxRetries: this.maxRetries,
timeout: this.timeout,
defaultQuery: { "api-version": azureConfig.apiVersion },
...this.additionalSessionOptions,
...azureConfig,
});
} else {
this.apiKey = init?.apiKey ?? undefined;
Expand Down

0 comments on commit a51ed8d

Please sign in to comment.