Skip to content

Commit

Permalink
[Azure] Refresh AAD token on retry (#1003)
Browse files Browse the repository at this point in the history
* [Azure] Refresh AAD token on retry

* add context
  • Loading branch information
deyaaeldeen authored Aug 20, 2024
1 parent 39731a6 commit fe8bbaa
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,13 @@ export class AzureOpenAI extends OpenAI {
}

protected override async prepareOptions(opts: Core.FinalRequestOptions<unknown>): Promise<void> {
if (opts.headers?.['Authorization'] || opts.headers?.['api-key']) {
/**
* The user should provide a bearer token provider if they want
* to use Azure AD authentication. The user shouldn't set the
* Authorization header manually because the header is overwritten
* with the Azure AD token if a bearer token provider is provided.
*/
if (opts.headers?.['api-key']) {
return super.prepareOptions(opts);
}
const token = await this._getAzureADToken();
Expand Down
37 changes: 37 additions & 0 deletions tests/lib/azure.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,43 @@ describe('instantiate azure client', () => {
/The `apiKey` and `azureADTokenProvider` arguments are mutually exclusive; only one can be passed at a time./,
);
});

test('AAD token is refreshed', async () => {
let fail = true;
const testFetch = async (url: RequestInfo, req: RequestInit | undefined): Promise<Response> => {
if (fail) {
fail = false;
return new Response(undefined, {
status: 429,
headers: {
'Retry-After': '0.1',
},
});
}
return new Response(
JSON.stringify({ auth: (req?.headers as Record<string, string>)['authorization'] }),
{ headers: { 'content-type': 'application/json' } },
);
};
let counter = 0;
async function azureADTokenProvider() {
return `token-${counter++}`;
}
const client = new AzureOpenAI({
baseURL: 'http://localhost:5000/',
azureADTokenProvider,
apiVersion,
fetch: testFetch,
});
expect(
await client.chat.completions.create({
model,
messages: [{ role: 'system', content: 'Hello' }],
}),
).toStrictEqual({
auth: 'Bearer token-1',
});
});
});

test('with endpoint', () => {
Expand Down

0 comments on commit fe8bbaa

Please sign in to comment.