Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🎉 feat: Optimizations and Anthropic Title Generation #2184

Merged
merged 12 commits into from
Mar 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 96 additions & 7 deletions api/app/clients/AnthropicClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@ const {
validateVisionModel,
} = require('librechat-data-provider');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { formatMessage, createContextHandlers } = require('./prompts');
const {
titleFunctionPrompt,
parseTitleFromPrompt,
truncateText,
formatMessage,
createContextHandlers,
} = require('./prompts');
const spendTokens = require('~/models/spendTokens');
const { getModelMaxTokens } = require('~/utils');
const BaseClient = require('./BaseClient');
Expand Down Expand Up @@ -108,7 +114,12 @@ class AnthropicClient extends BaseClient {
return this;
}

/**
* Get the initialized Anthropic client.
* @returns {Anthropic} The Anthropic client instance.
*/
getClient() {
/** @type {Anthropic.default.RequestOptions} */
const options = {
apiKey: this.apiKey,
};
Expand Down Expand Up @@ -176,14 +187,13 @@ class AnthropicClient extends BaseClient {
return files;
}

async recordTokenUsage({ promptTokens, completionTokens }) {
logger.debug('[AnthropicClient] recordTokenUsage:', { promptTokens, completionTokens });
async recordTokenUsage({ promptTokens, completionTokens, model, context = 'message' }) {
await spendTokens(
{
context,
user: this.user,
model: this.modelOptions.model,
context: 'message',
conversationId: this.conversationId,
model: model ?? this.modelOptions.model,
endpointTokenConfig: this.options.endpointTokenConfig,
},
{ promptTokens, completionTokens },
Expand Down Expand Up @@ -512,8 +522,15 @@ class AnthropicClient extends BaseClient {
logger.debug('AnthropicClient doesn\'t use getCompletion (all handled in sendCompletion)');
}

async createResponse(client, options) {
return this.useMessages
/**
* Creates a message or completion response using the Anthropic client.
* @param {Anthropic} client - The Anthropic client instance.
* @param {Anthropic.default.MessageCreateParams | Anthropic.default.CompletionCreateParams} options - The options for the message or completion.
* @param {boolean} useMessages - Whether to use messages or completions. Defaults to `this.useMessages`.
* @returns {Promise<Anthropic.default.Message | Anthropic.default.Completion>} The response from the Anthropic client.
*/
async createResponse(client, options, useMessages) {
return useMessages ?? this.useMessages
? await client.messages.create(options)
: await client.completions.create(options);
}
Expand Down Expand Up @@ -663,6 +680,78 @@ class AnthropicClient extends BaseClient {
getTokenCount(text) {
return this.gptEncoder.encode(text, 'all').length;
}

/**
* Generates a concise title for a conversation based on the user's input text and response.
* Involves sending a chat completion request with specific instructions for title generation.
*
* This function capitlizes on [Anthropic's function calling training](https://docs.anthropic.com/claude/docs/functions-external-tools).
*
* @param {Object} params - The parameters for the conversation title generation.
* @param {string} params.text - The user's input.
* @param {string} [params.responseText=''] - The AI's immediate response to the user.
*
* @returns {Promise<string | 'New Chat'>} A promise that resolves to the generated conversation title.
* In case of failure, it will return the default title, "New Chat".
*/
async titleConvo({ text, responseText = '' }) {
let title = 'New Chat';
const convo = `<initial_message>
${truncateText(text)}
</initial_message>
<response>
${JSON.stringify(truncateText(responseText))}
</response>`;

const { ANTHROPIC_TITLE_MODEL } = process.env ?? {};
const model = this.options.titleModel ?? ANTHROPIC_TITLE_MODEL ?? 'claude-3-haiku-20240307';
const system = titleFunctionPrompt;

const titleChatCompletion = async () => {
const content = `<conversation_context>
${convo}
</conversation_context>

Please generate a title for this conversation.`;

const titleMessage = { role: 'user', content };
const requestOptions = {
model,
temperature: 0.3,
max_tokens: 1024,
system,
stop_sequences: ['\n\nHuman:', '\n\nAssistant', '</function_calls>'],
messages: [titleMessage],
};

try {
const response = await this.createResponse(this.getClient(), requestOptions, true);
let promptTokens = response?.usage?.input_tokens;
let completionTokens = response?.usage?.output_tokens;
if (!promptTokens) {
promptTokens = this.getTokenCountForMessage(titleMessage);
promptTokens += this.getTokenCountForMessage({ role: 'system', content: system });
}
if (!completionTokens) {
completionTokens = this.getTokenCountForMessage(response.content[0]);
}
await this.recordTokenUsage({
model,
promptTokens,
completionTokens,
context: 'title',
});
const text = response.content[0].text;
title = parseTitleFromPrompt(text);
} catch (e) {
logger.error('[AnthropicClient] There was an issue generating the title', e);
}
};

await titleChatCompletion();
logger.debug('[AnthropicClient] Convo Title: ' + title);
return title;
}
}

module.exports = AnthropicClient;
1 change: 1 addition & 0 deletions api/app/clients/BaseClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ class BaseClient {
sender: this.sender,
text: addSpaceIfNeeded(generation) + completion,
promptTokens,
...(this.metadata ?? {}),
};

if (
Expand Down
22 changes: 6 additions & 16 deletions api/app/clients/OpenAIClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class OpenAIClient extends BaseClient {
/** @type {AzureOptions} */
this.azure = options.azure || false;
this.setOptions(options);
this.metadata = {};
}

// TODO: PluginsClient calls this 3x, unneeded
Expand Down Expand Up @@ -574,7 +575,6 @@ class OpenAIClient extends BaseClient {
} else if (typeof opts.onProgress === 'function' || this.options.useChatCompletion) {
reply = await this.chatCompletion({
payload,
clientOptions: opts,
onProgress: opts.onProgress,
abortController: opts.abortController,
});
Expand All @@ -594,9 +594,9 @@ class OpenAIClient extends BaseClient {
}
}

if (streamResult && typeof opts.addMetadata === 'function') {
if (streamResult) {
const { finish_reason } = streamResult.choices[0];
opts.addMetadata({ finish_reason });
this.metadata = { finish_reason };
}
return (reply ?? '').trim();
}
Expand Down Expand Up @@ -921,7 +921,6 @@ ${convo}
}

async recordTokenUsage({ promptTokens, completionTokens }) {
logger.debug('[OpenAIClient] recordTokenUsage:', { promptTokens, completionTokens });
await spendTokens(
{
user: this.user,
Expand All @@ -941,7 +940,7 @@ ${convo}
});
}

async chatCompletion({ payload, onProgress, clientOptions, abortController = null }) {
async chatCompletion({ payload, onProgress, abortController = null }) {
let error = null;
const errorCallback = (err) => (error = err);
let intermediateReply = '';
Expand All @@ -962,15 +961,6 @@ ${convo}
}

const baseURL = extractBaseURL(this.completionsUrl);
// let { messages: _msgsToLog, ...modelOptionsToLog } = modelOptions;
// if (modelOptionsToLog.messages) {
// _msgsToLog = modelOptionsToLog.messages.map((msg) => {
// let { content, ...rest } = msg;

// if (content)
// return { ...rest, content: truncateText(content) };
// });
// }
logger.debug('[OpenAIClient] chatCompletion', { baseURL, modelOptions });
const opts = {
baseURL,
Expand Down Expand Up @@ -1163,8 +1153,8 @@ ${convo}
}

const { message, finish_reason } = chatCompletion.choices[0];
if (chatCompletion && typeof clientOptions.addMetadata === 'function') {
clientOptions.addMetadata({ finish_reason });
if (chatCompletion) {
this.metadata = { finish_reason };
}

logger.debug('[OpenAIClient] chatCompletion response', chatCompletion);
Expand Down
53 changes: 53 additions & 0 deletions api/app/clients/prompts/titlePrompts.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,60 @@ ${convo}`,
return titlePrompt;
};

/**
 * System prompt instructing Claude to respond with a `submit_title` tool call
 * in Anthropic's XML function-calling format. The `<title>` parameter value
 * in the model's reply is extracted by `parseTitleFromPrompt`.
 */
const titleFunctionPrompt = `In this environment you have access to a set of tools you can use to generate the conversation title.

You may call them like this:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>

Here are the tools available:
<tools>
<tool_description>
<tool_name>submit_title</tool_name>
<description>
Submit a brief title in the conversation's language, following the parameter description closely.
</description>
<parameters>
<parameter>
<name>title</name>
<type>string</type>
<description>A concise, 5-word-or-less title for the conversation, using its same language, with no punctuation. Apply title case conventions appropriate for the language. For English, use AP Stylebook Title Case. Never directly mention the language name or the word "title"</description>
</parameter>
</parameters>
</tool_description>
</tools>`;

/**
 * Parses titles from title functions based on the provided prompt.
 * @param {string} prompt - The prompt containing the title function.
 * @returns {string} The parsed title. "New Chat" if no title is found,
 * or if the extracted title is empty after trimming.
 */
function parseTitleFromPrompt(prompt) {
  const titleRegex = /<title>(.+?)<\/title>/;
  const titleMatch = prompt.match(titleRegex);

  if (titleMatch && titleMatch[1]) {
    const title = titleMatch[1].trim();

    // Guard against a whitespace-only tag: previously this returned an empty
    // string, violating the documented "New Chat" fallback contract.
    if (title) {
      return title;
    }
  }

  return 'New Chat';
}

module.exports = {
langPrompt,
createTitlePrompt,
titleFunctionPrompt,
parseTitleFromPrompt,
};
2 changes: 1 addition & 1 deletion api/cache/getLogStores.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ const modelQueries = isEnabled(process.env.USE_REDIS)

const abortKeys = isEnabled(USE_REDIS)
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.ABORT_KEYS });
: new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: 600000 });

const namespaces = {
[CacheKeys.CONFIG_STORE]: config,
Expand Down
13 changes: 9 additions & 4 deletions api/models/spendTokens.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,15 @@ const { logger } = require('~/config');
*/
const spendTokens = async (txData, tokenUsage) => {
const { promptTokens, completionTokens } = tokenUsage;
logger.debug(`[spendTokens] conversationId: ${txData.conversationId} | Token usage: `, {
promptTokens,
completionTokens,
});
logger.debug(
`[spendTokens] conversationId: ${txData.conversationId}${
txData?.context ? ` | Context: ${txData?.context}` : ''
} | Token usage: `,
{
promptTokens,
completionTokens,
},
);
let prompt, completion;
try {
if (promptTokens >= 0) {
Expand Down
28 changes: 7 additions & 21 deletions api/server/controllers/AskController.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
const throttle = require('lodash/throttle');
const { getResponseSender, Constants } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils');
Expand All @@ -16,13 +17,10 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {

logger.debug('[AskController]', { text, conversationId, ...endpointOption });

let metadata;
let userMessage;
let promptTokens;
let userMessageId;
let responseMessageId;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const sender = getResponseSender({
...endpointOption,
model: endpointOption.modelOptions.model,
Expand All @@ -31,8 +29,6 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
const newConvo = !conversationId;
const user = req.user.id;

const addMetadata = (data) => (metadata = data);

const getReqData = (data = {}) => {
for (let key in data) {
if (key === 'userMessage') {
Expand All @@ -54,11 +50,8 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
const { client } = await initializeClient({ req, res, endpointOption });

const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();

if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
onProgress: throttle(
({ text: partialText }) => {
saveMessage({
messageId: responseMessageId,
sender,
Expand All @@ -70,12 +63,10 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
error: false,
user,
});
}

if (saveDelay < 500) {
saveDelay = 500;
}
},
},
3000,
{ trailing: false },
),
});

getText = getPartialText;
Expand Down Expand Up @@ -113,7 +104,6 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
overrideParentMessageId,
getReqData,
onStart,
addMetadata,
abortController,
onProgress: progressCallback.call(null, {
res,
Expand All @@ -128,10 +118,6 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
response.parentMessageId = overrideParentMessageId;
}

if (metadata) {
response = { ...response, ...metadata };
}

response.endpoint = endpointOption.endpoint;

const conversation = await getConvo(user, conversationId);
Expand Down
Loading
Loading