🚀 feat: o1 #4019

Merged
merged 7 commits on Sep 12, 2024
16 changes: 13 additions & 3 deletions api/app/clients/AnthropicClient.js
@@ -64,6 +64,12 @@ class AnthropicClient extends BaseClient {
/** Whether or not the model supports Prompt Caching
* @type {boolean} */
this.supportsCacheControl;
+ /** The key for the usage object's input tokens
+ * @type {string} */
+ this.inputTokensKey = 'input_tokens';
+ /** The key for the usage object's output tokens
+ * @type {string} */
+ this.outputTokensKey = 'output_tokens';
}

setOptions(options) {
@@ -200,15 +206,15 @@
}

/**
- * Calculates the correct token count for the current message based on the token count map and API usage.
+ * Calculates the correct token count for the current user message based on the token count map and API usage.
* Edge case: If the calculation results in a negative value, it returns the original estimate.
* If revisiting a conversation with a chat history entirely composed of token estimates,
* the cumulative token count going forward should become more accurate as the conversation progresses.
* @param {Object} params - The parameters for the calculation.
* @param {Record<string, number>} params.tokenCountMap - A map of message IDs to their token counts.
* @param {string} params.currentMessageId - The ID of the current message to calculate.
* @param {AnthropicStreamUsage} params.usage - The usage object returned by the API.
- * @returns {number} The correct token count for the current message.
+ * @returns {number} The correct token count for the current user message.
*/
calculateCurrentTokenCount({ tokenCountMap, currentMessageId, usage }) {
const originalEstimate = tokenCountMap[currentMessageId] || 0;
@@ -680,7 +686,11 @@ class AnthropicClient extends BaseClient {
*/
checkPromptCacheSupport(modelName) {
const modelMatch = matchModelName(modelName, EModelEndpoint.anthropic);
- if (modelMatch === 'claude-3-5-sonnet' || modelMatch === 'claude-3-haiku') {
+ if (
+ modelMatch === 'claude-3-5-sonnet' ||
+ modelMatch === 'claude-3-haiku' ||
+ modelMatch === 'claude-3-opus'
+ ) {
return true;
}
return false;
12 changes: 9 additions & 3 deletions api/app/clients/BaseClient.js
@@ -42,6 +42,12 @@ class BaseClient {
this.conversationId;
/** @type {string} */
this.responseMessageId;
+ /** The key for the usage object's input tokens
+ * @type {string} */
+ this.inputTokensKey = 'prompt_tokens';
+ /** The key for the usage object's output tokens
+ * @type {string} */
+ this.outputTokensKey = 'completion_tokens';
}

setOptions() {
@@ -604,8 +610,8 @@
* @type {StreamUsage | null} */
const usage = this.getStreamUsage != null ? this.getStreamUsage() : null;

- if (usage != null && Number(usage.output_tokens) > 0) {
- responseMessage.tokenCount = usage.output_tokens;
+ if (usage != null && Number(usage[this.outputTokensKey]) > 0) {
+ responseMessage.tokenCount = usage[this.outputTokensKey];
completionTokens = responseMessage.tokenCount;
await this.updateUserMessageTokenCount({ usage, tokenCountMap, userMessage, opts });
} else {
@@ -655,7 +661,7 @@
/** @type {boolean} */
const shouldUpdateCount =
this.calculateCurrentTokenCount != null &&
- Number(usage.input_tokens) > 0 &&
+ Number(usage[this.inputTokensKey]) > 0 &&
(this.options.resendFiles ||
(!this.options.resendFiles && !this.options.attachments?.length)) &&
!this.options.promptPrefix;
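
To make the key abstraction concrete, here is a minimal sketch (illustrative only, not part of the diff) of how the subclass-defined `inputTokensKey`/`outputTokensKey` let shared BaseClient logic read either provider's usage shape:

```js
// Assumed usage shapes, mirroring the keys set in each client's constructor.
const anthropicUsage = { input_tokens: 120, output_tokens: 48 };
const openAIUsage = { prompt_tokens: 120, completion_tokens: 48 };

// Shared logic only touches usage through the configured key.
function readOutputTokens(client, usage) {
  return Number(usage[client.outputTokensKey]) || 0;
}

readOutputTokens({ outputTokensKey: 'output_tokens' }, anthropicUsage); // 48
readOutputTokens({ outputTokensKey: 'completion_tokens' }, openAIUsage); // 48
```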
109 changes: 105 additions & 4 deletions api/app/clients/OpenAIClient.js
@@ -19,6 +19,7 @@ const {
constructAzureURL,
getModelMaxTokens,
genAzureChatCompletion,
+ getModelMaxOutputTokens,
} = require('~/utils');
const {
truncateText,
@@ -64,6 +65,9 @@ class OpenAIClient extends BaseClient {

/** @type {string | undefined} - The API Completions URL */
this.completionsUrl;

+ /** @type {OpenAIUsageMetadata | undefined} */
+ this.usage;
}

// TODO: PluginsClient calls this 3x, unneeded
@@ -138,7 +142,8 @@

const { model } = this.modelOptions;

- this.isChatCompletion = this.useOpenRouter || !!reverseProxy || model.includes('gpt');
+ this.isChatCompletion =
+ /\bo1\b/i.test(model) || model.includes('gpt') || this.useOpenRouter || !!reverseProxy;
this.isChatGptModel = this.isChatCompletion;
if (
model.includes('text-davinci') ||
@@ -169,7 +174,14 @@
logger.debug('[OpenAIClient] maxContextTokens', this.maxContextTokens);
}

- this.maxResponseTokens = this.modelOptions.max_tokens || 1024;
+ this.maxResponseTokens =
+ this.modelOptions.max_tokens ??
+ getModelMaxOutputTokens(
+ model,
+ this.options.endpointType ?? this.options.endpoint,
+ this.options.endpointTokenConfig,
+ ) ??
+ 1024;
this.maxPromptTokens =
this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
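
As a rough illustration of the model detection introduced above (inputs assumed), the `\bo1\b` word boundary lets dated and suffixed variants match while leaving other models alone:

```js
const isO1 = (model) => /\bo1\b/i.test(model);

isO1('o1-preview');         // true: "-" counts as a word boundary
isO1('o1-mini-2024-09-12'); // true
isO1('O1');                 // true: the "i" flag ignores case
isO1('gpt-4o');             // false: no standalone "o1" token
```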

@@ -533,7 +545,8 @@
promptPrefix = this.augmentedPrompt + promptPrefix;
}

- if (promptPrefix) {
+ const isO1Model = /\bo1\b/i.test(this.modelOptions.model);
+ if (promptPrefix && !isO1Model) {
promptPrefix = `Instructions:\n${promptPrefix.trim()}`;
instructions = {
role: 'system',
Expand Down Expand Up @@ -561,6 +574,16 @@ class OpenAIClient extends BaseClient {
messages,
};

+ /** EXPERIMENTAL */
+ if (promptPrefix && isO1Model) {
+ const lastUserMessageIndex = payload.findLastIndex((message) => message.role === 'user');
+ if (lastUserMessageIndex !== -1) {
+ payload[
+ lastUserMessageIndex
+ ].content = `${promptPrefix}\n${payload[lastUserMessageIndex].content}`;
+ }
+ }

if (tokenCountMap) {
tokenCountMap.instructions = instructions?.tokenCount;
result.tokenCountMap = tokenCountMap;
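
Because o1 models reject `system` messages, the experimental block above folds the prefix into the most recent user turn instead. A before/after sketch with an assumed payload (not part of the diff):

```js
const promptPrefix = 'Answer tersely.';
const payload = [
  { role: 'user', content: 'Hi' },
  { role: 'assistant', content: 'Hello! How can I help?' },
  { role: 'user', content: 'Summarize our chat.' },
];

// Same logic as the diff: prepend the prefix to the last user message.
const lastUserMessageIndex = payload.findLastIndex((message) => message.role === 'user');
if (lastUserMessageIndex !== -1) {
  payload[lastUserMessageIndex].content = `${promptPrefix}\n${payload[lastUserMessageIndex].content}`;
}

// payload[2].content === 'Answer tersely.\nSummarize our chat.'
```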
@@ -885,6 +908,56 @@ ${convo}
return title;
}

+ /**
+ * Get stream usage as returned by this client's API response.
+ * @returns {OpenAIUsageMetadata} The stream usage object.
+ */
+ getStreamUsage() {
+ if (
+ typeof this.usage === 'object' &&
+ typeof this.usage.completion_tokens_details === 'object'
+ ) {
+ const outputTokens = Math.abs(
+ this.usage.completion_tokens_details.reasoning_tokens - this.usage[this.outputTokensKey],
+ );
+ return {
+ ...this.usage.completion_tokens_details,
+ [this.inputTokensKey]: this.usage[this.inputTokensKey],
+ [this.outputTokensKey]: outputTokens,
+ };
+ }
+ return this.usage;
+ }
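
A worked example of the normalization above, with assumed numbers. OpenAI counts reasoning tokens inside `completion_tokens`, so the method splits out the visible output tokens while keeping `reasoning_tokens` from the details object:

```js
// Hypothetical raw usage from a non-streaming o1 chat completion:
const rawUsage = {
  prompt_tokens: 100,
  completion_tokens: 900,
  completion_tokens_details: { reasoning_tokens: 600 },
};

// With the default OpenAI keys, getStreamUsage() would return:
// {
//   reasoning_tokens: 600,   // spread from completion_tokens_details
//   prompt_tokens: 100,      // copied via this.inputTokensKey
//   completion_tokens: 300,  // Math.abs(600 - 900): visible output only
// }
```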

+ /**
+ * Calculates the correct token count for the current user message based on the token count map and API usage.
+ * Edge case: If the calculation results in a negative value, it returns the original estimate.
+ * If revisiting a conversation with a chat history entirely composed of token estimates,
+ * the cumulative token count going forward should become more accurate as the conversation progresses.
+ * @param {Object} params - The parameters for the calculation.
+ * @param {Record<string, number>} params.tokenCountMap - A map of message IDs to their token counts.
+ * @param {string} params.currentMessageId - The ID of the current message to calculate.
+ * @param {OpenAIUsageMetadata} params.usage - The usage object returned by the API.
+ * @returns {number} The correct token count for the current user message.
+ */
+ calculateCurrentTokenCount({ tokenCountMap, currentMessageId, usage }) {
+ const originalEstimate = tokenCountMap[currentMessageId] || 0;
+
+ if (!usage || typeof usage[this.inputTokensKey] !== 'number') {
+ return originalEstimate;
+ }
+
+ tokenCountMap[currentMessageId] = 0;
+ const totalTokensFromMap = Object.values(tokenCountMap).reduce((sum, count) => {
+ const numCount = Number(count);
+ return sum + (isNaN(numCount) ? 0 : numCount);
+ }, 0);
+ const totalInputTokens = usage[this.inputTokensKey] ?? 0;
+
+ const currentMessageTokens = totalInputTokens - totalTokensFromMap;
+ return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate;
+ }
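
A worked example of the reconciliation above (values assumed): the true cost of the current message is whatever the API-reported input total leaves after subtracting the rest of the map.

```js
// Minimal re-implementation of the method's logic, for illustration only:
function reconcile(tokenCountMap, currentMessageId, inputTokens) {
  const originalEstimate = tokenCountMap[currentMessageId] || 0;
  tokenCountMap[currentMessageId] = 0;
  const rest = Object.values(tokenCountMap).reduce((sum, n) => sum + (Number(n) || 0), 0);
  const current = inputTokens - rest;
  return current > 0 ? current : originalEstimate; // negative result -> keep estimate
}

// Two prior messages were estimated at 30 and 40 tokens; the API reports
// 100 input tokens total, so the current message must be about 30.
reconcile({ 'msg-1': 30, 'msg-2': 40, 'msg-3': 25 }, 'msg-3', 100); // 30
```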

async summarizeMessages({ messagesToRefine, remainingContextTokens }) {
logger.debug('[OpenAIClient] Summarizing messages...');
let context = messagesToRefine;
@@ -1000,7 +1073,16 @@ ${convo}
}
}

- async recordTokenUsage({ promptTokens, completionTokens, context = 'message' }) {
+ /**
+ * @param {object} params
+ * @param {number} params.promptTokens
+ * @param {number} params.completionTokens
+ * @param {OpenAIUsageMetadata} [params.usage]
+ * @param {string} [params.context='message']
+ * @returns {Promise<void>}
+ */
+ async recordTokenUsage({ promptTokens, completionTokens, usage, context = 'message' }) {
await spendTokens(
{
context,
@@ -1011,6 +1093,19 @@
},
{ promptTokens, completionTokens },
);

+ if (typeof usage === 'object' && typeof usage.reasoning_tokens === 'number') {
+ await spendTokens(
+ {
+ context: 'reasoning',
+ model: this.modelOptions.model,
+ conversationId: this.conversationId,
+ user: this.user ?? this.options.req.user?.id,
+ endpointTokenConfig: this.options.endpointTokenConfig,
+ },
+ { completionTokens: usage.reasoning_tokens },
+ );
+ }
}

getTokenCountForResponse(response) {
@@ -1191,6 +1286,10 @@ ${convo}
/** @type {(value: void | PromiseLike<void>) => void} */
let streamResolve;

+ if (modelOptions.stream && /\bo1\b/i.test(modelOptions.model)) {
+ delete modelOptions.stream;
+ }
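
o1 models did not support streaming at release, which is why the flag is dropped above so the request falls through to the non-streaming path. A small sketch with assumed options:

```js
const modelOptions = { model: 'o1-preview', stream: true };

if (modelOptions.stream && /\bo1\b/i.test(modelOptions.model)) {
  delete modelOptions.stream; // request a single, complete response instead
}

// modelOptions is now { model: 'o1-preview' }
```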

if (modelOptions.stream) {
streamPromise = new Promise((resolve) => {
streamResolve = resolve;
@@ -1269,6 +1368,8 @@ ${convo}
}

const { choices } = chatCompletion;
+ this.usage = chatCompletion.usage;

if (!Array.isArray(choices) || choices.length === 0) {
logger.warn('[OpenAIClient] Chat completion response has no choices');
return intermediateReply.join('');
9 changes: 9 additions & 0 deletions api/models/tx.js
@@ -37,6 +37,9 @@ const tokenValues = Object.assign(
'4k': { prompt: 1.5, completion: 2 },
'16k': { prompt: 3, completion: 4 },
'gpt-3.5-turbo-1106': { prompt: 1, completion: 2 },
+ 'o1-preview': { prompt: 15, completion: 60 },
+ 'o1-mini': { prompt: 3, completion: 12 },
+ o1: { prompt: 15, completion: 60 },
'gpt-4o-2024-08-06': { prompt: 2.5, completion: 10 },
'gpt-4o-mini': { prompt: 0.15, completion: 0.6 },
'gpt-4o': { prompt: 5, completion: 15 },
@@ -95,6 +98,12 @@ const getValueKey = (model, endpoint) => {
return 'gpt-3.5-turbo-1106';
} else if (modelName.includes('gpt-3.5')) {
return '4k';
+ } else if (modelName.includes('o1-preview')) {
+ return 'o1-preview';
+ } else if (modelName.includes('o1-mini')) {
+ return 'o1-mini';
+ } else if (modelName.includes('o1')) {
+ return 'o1';
} else if (modelName.includes('gpt-4o-2024-08-06')) {
return 'gpt-4o-2024-08-06';
} else if (modelName.includes('gpt-4o-mini')) {
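
Two details worth noting in the tx.js changes: the rates appear to be USD per million tokens, matching OpenAI's o1 pricing at release ($15/1M input and $60/1M output for o1-preview; $3/1M and $12/1M for o1-mini), and branch order matters in `getValueKey`, since the bare `includes('o1')` check would otherwise swallow the more specific names. Illustrative calls, assuming `getValueKey` is exported:

```js
// The 'o1-preview' and 'o1-mini' branches run before the generic 'o1' one,
// so dated variants resolve to their specific pricing keys.
getValueKey('o1-preview-2024-09-12'); // 'o1-preview'
getValueKey('o1-mini-2024-09-12');    // 'o1-mini'
getValueKey('o1');                    // 'o1' (generic fallback)
```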
4 changes: 4 additions & 0 deletions api/server/middleware/abortMiddleware.js
@@ -173,6 +173,10 @@ const handleAbortError = async (res, req, error, data) => {
errorText = `{"type":"${ErrorTypes.INVALID_REQUEST}"}`;
}

+ if (error?.message?.includes('does not support \'system\'')) {
+ errorText = `{"type":"${ErrorTypes.NO_SYSTEM_MESSAGES}"}`;
+ }

const respondWithError = async (partialText) => {
let options = {
sender,
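
The new guard above maps an upstream 400 error about `system` messages to a typed error the client can localize. A sketch of the flow, with the upstream wording and the `ErrorTypes` value assumed:

```js
// Hypothetical upstream failure when a system message reaches an o1 model:
const error = new Error("This model does not support 'system' messages.");
const ErrorTypes = { NO_SYSTEM_MESSAGES: 'no_system_messages' }; // value assumed

let errorText;
if (error?.message?.includes("does not support 'system'")) {
  errorText = `{"type":"${ErrorTypes.NO_SYSTEM_MESSAGES}"}`;
}
// errorText === '{"type":"no_system_messages"}'
```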
14 changes: 13 additions & 1 deletion api/typedefs.js
@@ -1443,7 +1443,19 @@
*/

/**
- * @typedef {AnthropicStreamUsage} StreamUsage - Stream usage for all providers (currently only Anthropic)
+ * @exports OpenAIUsageMetadata
+ * @typedef {Object} OpenAIUsageMetadata - Usage statistics for a completion request, as returned by the OpenAI API.
+ * @property {number} [completion_tokens] - Number of completion tokens used for the request.
+ * @property {number} [prompt_tokens] - Number of prompt tokens used for the request.
+ * @property {number} [total_tokens] - Total number of tokens used (prompt + completion).
+ * @property {number} [reasoning_tokens] - Total number of tokens used for reasoning (OpenAI o1 models).
+ * @property {Object} [completion_tokens_details] - Further details on the completion tokens used (OpenAI o1 models).
+ * @property {number} [completion_tokens_details.reasoning_tokens] - Total number of tokens used for reasoning (OpenAI o1 models).
+ * @memberof typedefs
+ */
+
+ /**
+ * @typedef {AnthropicStreamUsage | OpenAIUsageMetadata | UsageMetadata} StreamUsage - Stream usage for all providers (currently Anthropic, OpenAI, and LangChain)
*/
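
For reference, a payload that satisfies the new typedef (values assumed):

```js
/** @type {OpenAIUsageMetadata} */
const usage = {
  prompt_tokens: 100,
  completion_tokens: 900,
  total_tokens: 1000,
  completion_tokens_details: {
    reasoning_tokens: 600, // billed as completion tokens by OpenAI
  },
};
```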

/* Native app/client methods */