Skip to content

Commit

Permalink
chore (test): test url search params explicitly (#2138)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel authored Jun 28, 2024
1 parent 3282fd1 commit 2a387d6
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 57 deletions.
101 changes: 45 additions & 56 deletions packages/azure/src/azure-openai-provider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,15 @@ const TEST_PROMPT: LanguageModelV1Prompt = [
{ role: 'user', content: [{ type: 'text', text: 'Hello' }] },
];

const provider = createAzure({
resourceName: 'test-resource',
apiKey: 'test-api-key',
});

describe('chat', () => {
describe('doGenerate', () => {
const server = new JsonTestServer(
'https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-05-01-preview',
'https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions',
);

server.setupTestEnvironment();
Expand Down Expand Up @@ -42,6 +47,21 @@ describe('chat', () => {
};
}

it('should set the correct api version', async () => {
prepareJsonResponse();

await provider('test-deployment').doGenerate({
inputFormat: 'prompt',
mode: { type: 'regular' },
prompt: TEST_PROMPT,
});

const searchParams = await server.getRequestUrlSearchParams();
expect(searchParams.get('api-version')).toStrictEqual(
'2024-05-01-preview',
);
});

it('should pass headers', async () => {
prepareJsonResponse();

Expand Down Expand Up @@ -77,7 +97,7 @@ describe('chat', () => {
describe('completion', () => {
describe('doGenerate', () => {
const server = new JsonTestServer(
'https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo-instruct/completions?api-version=2024-05-01-preview',
'https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo-instruct/completions',
);

server.setupTestEnvironment();
Expand Down Expand Up @@ -122,6 +142,21 @@ describe('completion', () => {
};
}

it('should set the correct api version', async () => {
prepareJsonCompletionResponse({ content: 'Hello World!' });

await provider.completion('gpt-35-turbo-instruct').doGenerate({
inputFormat: 'prompt',
mode: { type: 'regular' },
prompt: TEST_PROMPT,
});

const searchParams = await server.getRequestUrlSearchParams();
expect(searchParams.get('api-version')).toStrictEqual(
'2024-05-01-preview',
);
});

it('should pass headers', async () => {
prepareJsonCompletionResponse({ content: 'Hello World!' });

Expand Down Expand Up @@ -161,14 +196,9 @@ describe('embedding', () => {
];
const testValues = ['sunny day at the beach', 'rainy day in the city'];

const provider = createAzure({
resourceName: 'test-resource',
apiKey: 'test-api-key',
});

describe('doEmbed', () => {
const server = new JsonTestServer(
'https://test-resource.openai.azure.com/openai/deployments/my-embedding/embeddings?api-version=2024-05-01-preview',
'https://test-resource.openai.azure.com/openai/deployments/my-embedding/embeddings',
);

const model = provider.embedding('my-embedding');
Expand All @@ -192,58 +222,17 @@ describe('embedding', () => {
};
}

it('should extract embedding', async () => {
it('should set the correct api version', async () => {
prepareJsonResponse();

const { embeddings } = await model.doEmbed({ values: testValues });

expect(embeddings).toStrictEqual(dummyEmbeddings);
});

it('should expose the raw response headers', async () => {
prepareJsonResponse();

server.responseHeaders = {
'test-header': 'test-value',
};

const { rawResponse } = await model.doEmbed({ values: testValues });

expect(rawResponse?.headers).toStrictEqual({
// default headers:
'content-length': '226',
'content-type': 'application/json',

// custom header
'test-header': 'test-value',
});
});

it('should pass the model and the values', async () => {
prepareJsonResponse();

await model.doEmbed({ values: testValues });

expect(await server.getRequestBodyJson()).toStrictEqual({
model: 'my-embedding',
input: testValues,
encoding_format: 'float',
await model.doEmbed({
values: testValues,
});
});

it('should pass the dimensions setting', async () => {
prepareJsonResponse();

await provider
.embedding('my-embedding', { dimensions: 64 })
.doEmbed({ values: testValues });

expect(await server.getRequestBodyJson()).toStrictEqual({
model: 'my-embedding',
input: testValues,
encoding_format: 'float',
dimensions: 64,
});
const searchParams = await server.getRequestUrlSearchParams();
expect(searchParams.get('api-version')).toStrictEqual(
'2024-05-01-preview',
);
});

it('should pass headers', async () => {
Expand Down
15 changes: 14 additions & 1 deletion packages/google/src/google-generative-ai-language-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ describe('doGenerate', () => {

describe('doStream', () => {
const server = new StreamingTestServer(
'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent?alt=sse',
'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent',
);

server.setupTestEnvironment();
Expand Down Expand Up @@ -397,6 +397,19 @@ describe('doStream', () => {
});
});

it('should set streaming mode search param', async () => {
prepareStreamResponse({ content: [''] });

await model.doStream({
inputFormat: 'prompt',
mode: { type: 'regular' },
prompt: TEST_PROMPT,
});

const searchParams = await server.getRequestUrlSearchParams();
expect(searchParams.get('alt')).toStrictEqual('sse');
});

it('should pass headers', async () => {
prepareStreamResponse({ content: [] });

Expand Down
5 changes: 5 additions & 0 deletions packages/provider-utils/src/test/json-test-server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ export class JsonTestServer {
return headersObject;
}

async getRequestUrlSearchParams() {
expect(this.request).toBeDefined();
return new URL(this.request!.url).searchParams;
}

setupTestEnvironment() {
beforeAll(() => this.server.listen());
beforeEach(() => {
Expand Down
5 changes: 5 additions & 0 deletions packages/provider-utils/src/test/streaming-test-server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ export class StreamingTestServer {
return headersObject;
}

async getRequestUrlSearchParams() {
expect(this.request).toBeDefined();
return new URL(this.request!.url).searchParams;
}

setupTestEnvironment() {
beforeAll(() => this.server.listen());
beforeEach(() => {
Expand Down

0 comments on commit 2a387d6

Please sign in to comment.