Skip to content

Commit

Permalink
feat: accept timeout as parameter in empiricalrc.json (#119)
Browse files Browse the repository at this point in the history
Co-authored-by: Saikat Mitra <saikatmitra91@gmail.com>
  • Loading branch information
KaustubhKumar05 and saikatmitra91 authored Apr 15, 2024
1 parent d087119 commit 9822db6
Show file tree
Hide file tree
Showing 11 changed files with 71 additions and 16 deletions.
6 changes: 6 additions & 0 deletions .changeset/good-spoons-own.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@empiricalrun/ai": minor
"@empiricalrun/types": patch
---

feat: accept timeout as parameter in empiricalrc.json
18 changes: 18 additions & 0 deletions docs/models/basics.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,24 @@ For example, Mistral models support a `safePrompt` parameter for [guardrailing](
]
```

#### Configuring request timeout

You can set the timeout duration in milliseconds under model `parameters` in the `empiricalrc.json` file. This can be useful for prompt completions that are expected to take longer, for example when running slower models like Claude Opus. If no value is specified, a default timeout of 30 seconds is applied.

```json empiricalrc.json
"runs": [
{
"type": "model",
"provider": "anthropic",
"model": "claude-3-opus",
"prompt": "Hey I'm {{user_name}}",
"parameters": {
"timeout": 10000
}
}
]
```

#### Limitations

- These parameters are not supported today: `logit_bias`, `tools`, `tool_choice`, `user`, `stream`
Expand Down
1 change: 1 addition & 0 deletions packages/ai/src/constants/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
// Default request timeout in milliseconds (30 seconds), applied by every
// provider when the user does not set `timeout` under model `parameters`
// in empiricalrc.json. Note: the Mistral client expects seconds, so callers
// there divide this value by 1000 before passing it on.
export const DEFAULT_TIMEOUT = 30000;
10 changes: 8 additions & 2 deletions packages/ai/src/providers/anthropic/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { ChatCompletionMessageParam } from "openai/resources/chat/completions.mj
import promiseRetry from "promise-retry";
import { BatchTaskManager, getPassthroughParams } from "../../utils";
import { AIError, AIErrorEnum } from "../../error";
import { DEFAULT_TIMEOUT } from "../../constants";

const batchTaskManager = new BatchTaskManager(5);

Expand Down Expand Up @@ -54,10 +55,15 @@ const createChatCompletion: ICreateChatCompletion = async (body) => {
"process.env.ANTHROPIC_API_KEY is not set",
);
}
const { model, messages, ...config } = body;
const timeout = config.timeout || DEFAULT_TIMEOUT;
if (config.timeout) {
delete config.timeout;
}
const anthropic = new Anthropic({
apiKey: process.env.ANTHROPIC_API_KEY,
timeout: timeout,
});
const { model, messages, ...config } = body;
const { contents, systemPrompt } = convertOpenAIToAnthropicAI(messages);
const { executionDone } = await batchTaskManager.waitForTurn();
try {
Expand Down Expand Up @@ -130,7 +136,7 @@ const createChatCompletion: ICreateChatCompletion = async (body) => {
executionDone();
throw new AIError(
AIErrorEnum.FAILED_CHAT_COMPLETION,
`failed chat completion for model ${body.model} with message ${(e as Error).message} `,
`Failed to fetch output from model ${body.model}: ${(e as Error).message}`,
);
}
};
Expand Down
9 changes: 6 additions & 3 deletions packages/ai/src/providers/fireworks/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ const batchTaskManager = new BatchTaskManager(10);

const createChatCompletion: ICreateChatCompletion = async (body) => {
const { model, messages, ...config } = body;
if (config.timeout) {
delete config.timeout;
}
const payload = JSON.stringify({
model: `accounts/fireworks/models/${model}`,
messages,
Expand All @@ -35,6 +38,7 @@ const createChatCompletion: ICreateChatCompletion = async (body) => {
const { executionDone } = await batchTaskManager.waitForTurn();

try {
const startedAt = Date.now();
const completion = await promiseRetry<IChatCompletion>(
(retry) => {
return fetch("https://api.fireworks.ai/inference/v1/chat/completions", {
Expand Down Expand Up @@ -67,12 +71,11 @@ const createChatCompletion: ICreateChatCompletion = async (body) => {
},
{
randomize: true,
minTimeout: 1000,
},
);

const latency = Date.now() - startedAt;
executionDone();
return completion;
return { ...completion, latency };
} catch (err) {
throw new AIError(AIErrorEnum.FAILED_CHAT_COMPLETION, "Unknown error");
}
Expand Down
7 changes: 4 additions & 3 deletions packages/ai/src/providers/google/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import { BatchTaskManager } from "../../utils";
import crypto from "crypto";
import promiseRetry from "promise-retry";
import { AIError, AIErrorEnum } from "../../error";
import { DEFAULT_TIMEOUT } from "../../constants";

const batch = new BatchTaskManager(5);

Expand Down Expand Up @@ -62,7 +63,8 @@ const createChatCompletion: ICreateChatCompletion = async (body) => {
}
const { model, messages } = body;
const googleAI = new GoogleGenerativeAI(process.env.GOOGLE_API_KEY!);
const modelInstance = googleAI.getGenerativeModel({ model });
const timeout = body.timeout || DEFAULT_TIMEOUT;
const modelInstance = googleAI.getGenerativeModel({ model }, { timeout });
const contents = massageOpenAIMessagesToGoogleAI(messages);
const { executionDone } = await batch.waitForTurn();
try {
Expand All @@ -79,7 +81,6 @@ const createChatCompletion: ICreateChatCompletion = async (body) => {
},
{
randomize: true,
minTimeout: 2000,
},
);
executionDone();
Expand Down Expand Up @@ -131,7 +132,7 @@ const createChatCompletion: ICreateChatCompletion = async (body) => {
executionDone();
throw new AIError(
AIErrorEnum.FAILED_CHAT_COMPLETION,
`failed chat completion for model ${body.model} with message ${(e as Error).message}`,
`Failed to fetch output from model ${body.model} with message ${(e as Error).message}`,
);
}
};
Expand Down
17 changes: 14 additions & 3 deletions packages/ai/src/providers/mistral/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
import { BatchTaskManager, getPassthroughParams } from "../../utils";
import { ToolCalls, ResponseFormat } from "@mistralai/mistralai";
import { AIError, AIErrorEnum } from "../../error";
import { DEFAULT_TIMEOUT } from "../../constants";

type MistralChatMessage = {
role: string;
Expand All @@ -29,9 +30,19 @@ const createChatCompletion: ICreateChatCompletion = async function (body) {
);
}
const MistralClient = await importMistral();
const mistralai = new MistralClient(process.env.MISTRAL_API_KEY);
const { executionDone } = await batch.waitForTurn();
const { model, messages, ...config } = body;
const mistralai = new MistralClient(
process.env.MISTRAL_API_KEY,
undefined,
// type issue in https://github.com/mistralai/client-js/blob/e33a2f3e5f6fb88fd083e8e7d9c3c081d1c7c0e4/src/client.js#L51, will submit a PR later
// @ts-ignore default value for retries
5,
(config.timeout || DEFAULT_TIMEOUT) / 1000, // Mistral expects values in seconds
);
if (config.timeout) {
delete config.timeout;
}
const { executionDone } = await batch.waitForTurn();
try {
// typecasting as there is a minor difference in role being openai enum vs string
const mistralMessages = messages as MistralChatMessage[];
Expand All @@ -56,7 +67,7 @@ const createChatCompletion: ICreateChatCompletion = async function (body) {
executionDone();
throw new AIError(
AIErrorEnum.FAILED_CHAT_COMPLETION,
`failed chat completion for model ${body.model} with message ${(err as Error).message}`,
`Failed to fetch output from model ${body.model} with message ${(err as Error).message}`,
);
}
};
Expand Down
9 changes: 8 additions & 1 deletion packages/ai/src/providers/openai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
import OpenAI from "openai";
import promiseRetry from "promise-retry";
import { AIError, AIErrorEnum } from "../../error";
import { DEFAULT_TIMEOUT } from "../../constants";

const createChatCompletion: ICreateChatCompletion = async (body) => {
const apiKey = process.env.OPENAI_API_KEY;
Expand All @@ -15,9 +16,15 @@ const createChatCompletion: ICreateChatCompletion = async (body) => {
"process.env.OPENAI_API_KEY is not set",
);
}
const timeout = body.timeout || DEFAULT_TIMEOUT;
if (body.timeout) {
delete body.timeout;
}
const openai = new OpenAI({
apiKey: process.env.OPENAI_API_KEY,
timeout,
});

try {
const startedAt = Date.now();
const completions = await promiseRetry<IChatCompletion>(
Expand Down Expand Up @@ -50,7 +57,7 @@ const createChatCompletion: ICreateChatCompletion = async (body) => {
} catch (err) {
throw new AIError(
AIErrorEnum.FAILED_CHAT_COMPLETION,
`Failed completion for OpenAI ${body.model}: ${(err as any)?.error?.message}`,
`Failed to fetch output from model ${body.model}: ${(err as any)?.error?.message}`,
);
}
};
Expand Down
1 change: 1 addition & 0 deletions packages/ai/src/utils/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ function isReservedParameter(paramName: string) {
"seed",
"stop",
"top_logprobs",
"timeout",
];
return reservedParameters.indexOf(paramName) >= 0;
}
Expand Down
1 change: 1 addition & 0 deletions packages/types/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ interface ModelParameters {
seed?: number;
stop?: string | Array<string>;
top_logprobs?: number;
timeout?: number;

// For other models, we coerce the above known parameters to appropriate slots
// If users require other parameters, we support passthrough for other key names
Expand Down
8 changes: 4 additions & 4 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 9822db6

Please sign in to comment.