diff --git a/src/index.ts b/src/index.ts index 422e26c92..0c6c7badb 100644 --- a/src/index.ts +++ b/src/index.ts @@ -485,7 +485,13 @@ export class AzureOpenAI extends OpenAI { } protected override async prepareOptions(opts: Core.FinalRequestOptions): Promise { - if (opts.headers?.['Authorization'] || opts.headers?.['api-key']) { + /** + * The user should provide a bearer token provider if they want + * to use Azure AD authentication. The user shouldn't set the + * Authorization header manually because the header is overwritten + * with the Azure AD token if a bearer token provider is provided. + */ + if (opts.headers?.['api-key']) { return super.prepareOptions(opts); } const token = await this._getAzureADToken(); diff --git a/tests/lib/azure.test.ts b/tests/lib/azure.test.ts index 6bb6e0d1e..064a0098c 100644 --- a/tests/lib/azure.test.ts +++ b/tests/lib/azure.test.ts @@ -254,6 +254,43 @@ describe('instantiate azure client', () => { /The `apiKey` and `azureADTokenProvider` arguments are mutually exclusive; only one can be passed at a time./, ); }); + + test('AAD token is refreshed', async () => { + let fail = true; + const testFetch = async (url: RequestInfo, req: RequestInit | undefined): Promise => { + if (fail) { + fail = false; + return new Response(undefined, { + status: 429, + headers: { + 'Retry-After': '0.1', + }, + }); + } + return new Response( + JSON.stringify({ auth: (req?.headers as Record)['authorization'] }), + { headers: { 'content-type': 'application/json' } }, + ); + }; + let counter = 0; + async function azureADTokenProvider() { + return `token-${counter++}`; + } + const client = new AzureOpenAI({ + baseURL: 'http://localhost:5000/', + azureADTokenProvider, + apiVersion, + fetch: testFetch, + }); + expect( + await client.chat.completions.create({ + model, + messages: [{ role: 'system', content: 'Hello' }], + }), + ).toStrictEqual({ + auth: 'Bearer token-1', + }); + }); }); test('with endpoint', () => {