From 8789e1188d6f378f0ccd80c1b610382b4262dbab Mon Sep 17 00:00:00 2001 From: Stainless Bot Date: Mon, 6 May 2024 14:35:59 +0000 Subject: [PATCH] fix(azure): update build script --- scripts/utils/fix-index-exports.cjs | 2 +- src/index.ts | 28 +++++++++++++++++----------- tests/lib/azure.test.ts | 19 ++++++++++++------- 3 files changed, 30 insertions(+), 19 deletions(-) diff --git a/scripts/utils/fix-index-exports.cjs b/scripts/utils/fix-index-exports.cjs index 72b0b8fd0..ee5cebb85 100644 --- a/scripts/utils/fix-index-exports.cjs +++ b/scripts/utils/fix-index-exports.cjs @@ -9,6 +9,6 @@ const indexJs = let before = fs.readFileSync(indexJs, 'utf8'); let after = before.replace( /^\s*exports\.default\s*=\s*(\w+)/m, - 'exports = module.exports = $1;\nexports.default = $1', + 'exports = module.exports = $1;\nmodule.exports.AzureOpenAI = AzureOpenAI;\nexports.default = $1', ); fs.writeFileSync(indexJs, after, 'utf8'); diff --git a/src/index.ts b/src/index.ts index dbade2f86..438a46779 100644 --- a/src/index.ts +++ b/src/index.ts @@ -339,12 +339,12 @@ export interface AzureClientOptions extends ClientOptions { * A function that returns an access token for Microsoft Entra (formerly known as Azure Active Directory), * which will be invoked on every request. */ - azureADTokenProvider?: (() => string) | undefined; + azureADTokenProvider?: (() => Promise) | undefined; } /** API Client for interfacing with the Azure OpenAI API. */ export class AzureOpenAI extends OpenAI { - private _azureADTokenProvider: (() => string) | undefined; + private _azureADTokenProvider: (() => Promise) | undefined; apiVersion: string = ''; /** * API Client for interfacing with the Azure OpenAI API. @@ -451,9 +451,9 @@ export class AzureOpenAI extends OpenAI { return super.buildRequest(options); } - private _getAzureADToken(): string | undefined { + private async _getAzureADToken(): Promise { if (typeof this._azureADTokenProvider === 'function') { - const token = this._azureADTokenProvider(); + const token = await this._azureADTokenProvider(); if (!token || typeof token !== 'string') { throw new Errors.OpenAIError( `Expected 'azureADTokenProvider' argument to return a string but it returned ${token}`, @@ -465,17 +465,23 @@ export class AzureOpenAI extends OpenAI { } protected override authHeaders(opts: Core.FinalRequestOptions): Core.Headers { + return {}; + } + + protected override async prepareOptions(opts: Core.FinalRequestOptions): Promise { if (opts.headers?.['Authorization'] || opts.headers?.['api-key']) { - return {}; + return super.prepareOptions(opts); } - const token = this._getAzureADToken(); + const token = await this._getAzureADToken(); + opts.headers ??= {}; if (token) { - return { Authorization: `Bearer ${token}` }; - } - if (this.apiKey !== API_KEY_SENTINEL) { - return { 'api-key': this.apiKey }; + opts.headers['Authorization'] = `Bearer ${token}`; + } else if (this.apiKey !== API_KEY_SENTINEL) { + opts.headers['api-key'] = this.apiKey; + } else { + throw new Errors.OpenAIError('Unable to handle auth'); } - throw new Errors.OpenAIError('Unable to handle auth'); + return super.prepareOptions(opts); } } diff --git a/tests/lib/azure.test.ts b/tests/lib/azure.test.ts index e2b967903..4895273be 100644 --- a/tests/lib/azure.test.ts +++ b/tests/lib/azure.test.ts @@ -222,16 +222,21 @@ describe('instantiate azure client', () => { }); describe('Azure Active Directory (AD)', () => { - test('with azureADTokenProvider', () => { + test('with azureADTokenProvider', async () => { + const testFetch = async (url: RequestInfo, { headers }: RequestInit = {}): Promise => { + return new Response(JSON.stringify({ a: 1 }), { headers }); + }; const client = new AzureOpenAI({ baseURL: 'http://localhost:5000/', - azureADTokenProvider: () => 'my token', + azureADTokenProvider: async () => 'my token', apiVersion, + fetch: testFetch, }); - expect(client.buildRequest({ method: 'post', path: 'https://example.com' }).req.headers).toHaveProperty( - 'authorization', - 'Bearer my token', - ); + expect( + (await client.request({ method: 'post', path: 'https://example.com' }).asResponse()).headers.get( + 'authorization', + ), + ).toEqual('Bearer my token'); }); test('apiKey and azureADTokenProvider cant be combined', () => { @@ -239,7 +244,7 @@ describe('instantiate azure client', () => { () => new AzureOpenAI({ baseURL: 'http://localhost:5000/', - azureADTokenProvider: () => 'my token', + azureADTokenProvider: async () => 'my token', apiKey: 'My API Key', apiVersion, }),