From afd43afb60b230a00ce5a5effab900130252f5cb Mon Sep 17 00:00:00 2001 From: Danny Avila <110412045+danny-avila@users.noreply.github.com> Date: Thu, 17 Aug 2023 12:50:05 -0400 Subject: [PATCH] feat(GPT/Anthropic): Continue Regenerating & Generation Buttons (#808) * feat(useMessageHandler.js/ts): Refactor and add features to handle user messages, support multiple endpoints/models, generate placeholder responses, regeneration, and stopGeneration function fix(conversation.ts, buildTree.ts): Import TMessage type, handle null parentMessageId feat(schemas.ts): Update and add schemas for various AI services, add default values, optional fields, and endpoint-to-schema mapping, create parseConvo function chore(useMessageHandler.js, schemas.ts): Remove unused imports, variables, and chatGPT enum * wip: add generation buttons * refactor(cleanupPreset.ts): simplify cleanupPreset function refactor(getDefaultConversation.js): remove unused code and simplify getDefaultConversation function feat(utils): add getDefaultConversation function This commit adds a new utility function called `getDefaultConversation` to the `client/src/utils/getDefaultConversation.ts` file. This function is responsible for generating a default conversation object based on the provided parameters. The `getDefaultConversation` function takes in an object with the following properties: - `conversation`: The conversation object to be used as a base. - `endpointsConfig`: The configuration object containing information about the available endpoints. - `preset`: An optional preset object that can be used to override the default behavior. The function first tries to determine the target endpoint based on the preset object. If a valid endpoint is found, it is used as the target endpoint. If not, the function tries to retrieve the last conversation setup from the local storage and uses its endpoint if it is valid. If neither the preset nor the local storage contains a valid endpoint, the function falls back to a default endpoint. Once the target endpoint is determined, * fix(utils): remove console.error statement in buildDefaultConversation function fix(schemas): add default values for catch blocks in openAISchema, googleSchema, bingAISchema, anthropicSchema, chatGPTBrowserSchema, and gptPluginsSchema * fix: endpoint not changing on change of preset from other endpoint, wip: refactor * refactor: preset items to TSX * refactor: convert resetConvo to TS * refactor(getDefaultConversation.ts): move defaultEndpoints array to the top of the file for better readability refactor(getDefaultConversation.ts): extract getDefaultEndpoint function for better code organization and reusability * feat(svg): add ContinueIcon component feat(svg): add RegenerateIcon component feat(svg): add ContinueIcon and RegenerateIcon components to index.ts * feat(Button.tsx): add onClick and className props to Button component feat(GenerationButtons.tsx): add logic to display Regenerate or StopGenerating button based on isSubmitting and messages feat(Regenerate.tsx): create Regenerate component with RegenerateIcon and handleRegenerate function feat(StopGenerating.tsx): create StopGenerating component with StopGeneratingIcon and handleStopGenerating function * fix(TextChat.jsx): reorder imports and variables for better readability fix(TextChat.jsx): fix typo in condition for isNotAppendable variable fix(TextChat.jsx): remove unused handleStopGenerating function fix(ContinueIcon.tsx): remove unnecessary closing tags for polygon elements fix(useMessageHandler.ts): add missing type annotations for handleStopGenerating and handleRegenerate functions fix(useMessageHandler.ts): remove unused variables in return statement * fix(getDefaultConversation.ts): refactor code to use getLocalStorageItems function feat(getLocalStorageItems.ts): add utility function to retrieve items from local storage * fix(OpenAIClient.js): add support for streaming result in sendCompletion method feat(OpenAIClient.js): add finish_reason metadata to opts in sendCompletion method feat(Message.js): add finish_reason field to Message model feat(messageSchema.js): add finish_reason field to messageSchema feat(openAI.js): parse chatGptLabel and promptPrefix from req.body and pass rest of the modelOptions to endpointOption feat(openAI.js): add addMetadata function to store metadata in ask function feat(openAI.js): add metadata to response if available feat(schemas.ts): add finish_reason field to tMessageSchema * feat(types.ts): add TOnClick and TGenButtonProps types for button components feat(Continue.tsx): create Continue component for generating button feat(GenerationButtons.tsx): update GenerationButtons component to use Continue component feat(Regenerate.tsx): create Regenerate component for regenerating button feat(Stop.tsx): create Stop component for stop generating button * feat(MessageHandler.jsx): add MessageHandler component to handle messages and conversations fix(Root.jsx): fix import paths for Nav and MessageHandler components * feat(useMessageHandler.ts): add support for generation parameter in ask function feat(useMessageHandler.ts): add support for isEdited parameter in ask function feat(useMessageHandler.ts): add support for continueGeneration function fix(createPayload.ts): replace endpoint URL when isEdited parameter is true * chore(client): set skipLibCheck to true in tsconfig.json * fix(useMessageHandler.ts): remove unused clientId variable fix(schemas.ts): make clientId field in tMessageSchema nullable and optional * wip: edit route for continue generation * refactor(api): move handlers to root of routes dir * fix(useMessageHandler.ts): initialize currentMessages to an empty array if messages is null fix(useMessageHandler.ts): update initialResponse text to use responseText variable fix(useMessageHandler.ts): update setMessages logic for isRegenerate case fix(MessageHandler.jsx): update setMessages logic for cancelHandler, createdHandler, and finalHandler * fix(schemas.ts): make createdAt and updatedAt fields optional and set default values using new Date().toISOString() fix(schemas.ts): change type annotation of TMessage from infer to input * refactor(useMessageHandler.ts): rename AskProps type to TAskProps refactor(useMessageHandler.ts): remove generation property from ask function arguments refactor(useMessageHandler.ts): use nullish coalescing operator (??) instead of logical OR (||) refactor(useMessageHandler.ts): pass the responseMessageId to message prop of submission * fix(BaseClient.js): use nullish coalescing operator (??) instead of logical OR (||) for default values * fix(BaseClient.js): fix responseMessageId assignment in handleStartMethods method feat(BaseClient.js): add support for isEdited flag in sendMessage method feat(BaseClient.js): add generation to responseMessage text in sendMessage method * fix(openAI.js): remove unused imports and commented out code feat(openAI.js): add support for generation parameter in request body fix(openAI.js): remove console.log statement fix(openAI.js): remove unused variables and parameters fix(openAI.js): update response text in case of error fix(openAI.js): handle error and abort message in case of error fix(handlers.js): add generation parameter to createOnProgress function fix(useMessageHandler.ts): update responseText variable to use generation parameter * refactor(api/middleware): move inside server dir * refactor: add endpoint specific, modular functions to build options and initialize clients, create server/utils, move middleware, separate utils into api general utils and server specific utils * fix(abortMiddleware.js): import getConvo and getConvoTitle functions from models feat(abortMiddleware.js): add abortAsk function to abortController to handle aborting of requests fix(openAI.js): import buildOptions and initializeClient functions from endpoints/openAI refactor(openAI.js): use getAbortData function to get data for abortAsk function * refactor: move endpoint specific logic to an endpoints dir * refactor(PluginService.js): fix import path for encrypt and decrypt functions in PluginService.js * feat(openAI): add new endpoint for adding a title to a conversation - Added a new file `addTitle.js` in the `api/server/routes/endpoints/openAI` directory. - The `addTitle.js` file exports a function `addTitle` that takes in request parameters and performs the following actions: - If the `parentMessageId` is `'00000000-0000-0000-0000-000000000000'` and `newConvo` is true, it proceeds with the following steps: - Calls the `titleConvo` function from the `titleConvo` module, passing in the necessary parameters. - Calls the `saveConvo` function from the `saveConvo` module, passing in the user ID and conversation details. - Updated the `index.js` file in the `api/server/routes/endpoints/openAI` directory to export the `addTitle` function. - This change adds * fix(abortMiddleware.js): remove console.log statement refactor(gptPlugins.js): update imports and function parameters feat(gptPlugins.js): add support for abortController and getAbortData refactor(openAI.js): update imports and function parameters feat(openAI.js): add support for abortController and getAbortData fix(openAI.js): refactor code to use modularized functions and middleware fix(buildOptions.js): refactor code to use destructuring and update variable names * refactor(askChatGPTBrowser.js, bingAI.js, google.js): remove duplicate code for setting response headers feat(askChatGPTBrowser.js, bingAI.js, google.js): add setHeaders middleware to set response headers * feat(middleware): validateEndpoint, refactor buildOption to only be concerned of endpointOption * fix(abortMiddleware.js): add 'finish_reason' property with value 'incomplete' to responseMessage object fix(abortMessage.js): remove console.log statement for aborted message fix(handlers.js): modify tokens assignment to handle empty generation string and trailing space * fix(BaseClient.js): import addSpaceIfNeeded function from server/utils fix(BaseClient.js): add space before generation in text property fix(index.js): remove getCitations and citeText exports feat(buildEndpointOption.js): add buildEndpointOption middleware fix(index.js): import buildEndpointOption middleware fix(anthropic.js): remove buildOptions function and use endpointOption from req.body fix(gptPlugins.js): remove buildOptions function and use endpointOption from req.body fix(openAI.js): remove buildOptions function and use endpointOption from req.body feat(utils): add citations.js and handleText.js modules fix(utils): fix import statements in index.js module * refactor(gptPlugins.js): use getResponseSender function from librechat-data-provider * feat(gptPlugins): complete 'continue generating' * wip: anthropic continue regen * feat(middleware): add validateRegistration middleware A new middleware function called `validateRegistration` has been added to the list of exported middleware functions in `index.js`. This middleware is responsible for validating registration data before allowing the registration process to proceed. * feat(Anthropic): complete continue regen * chore: add librechat-data-provider to api/package.json * fix(ci): backend-review will mock meilisearch, also installs data-provider as now needed * chore(ci): remove unneeded SEARCH env var * style(GenerationButtons): make text shorter for sake of space economy, even though this diverges from chat.openai.com * style(GenerationButtons/ScrollToBottom): adjust visibility/position based on screen size * chore(client): 'Editting' typo * feat(GenerationButtons.tsx): add support for endpoint prop in GenerationButtons component feat(OptionsBar.tsx): pass endpoint prop to GenerationButtons component feat(useGenerations.ts): create useGenerations hook to handle generation logic fix(schemas.ts): add searchResult field to tMessageSchema * refactor(HoverButtons): convert to TSX and utilize new useGenerations hook * fix(abortMiddleware): handle error with res headers set, or abortController not found, to ensure proper API error is sent to the client, chore(BaseClient): remove console log for onStart message meant for debugging * refactor(api): remove librechat-data-provider dep for now as it complicates deployed docker build stage, re-use code in CJS, located in server/endpoints/schemas * chore: remove console.logs from test files * ci: add backend tests for AnthropicClient, focusing on new buildMessages logic * refactor(FakeClient): use actual BaseClient sendMessage method for testing * test(BaseClient.test.js): add test for loading chat history test(BaseClient.test.js): add test for sendMessage logic with isEdited flag * fix(buildEndpointOption.js): add support for azureOpenAI in buildFunction object wip(endpoints.js): fetch Azure models from Azure OpenAI API if opts.azure is true * fix(Button.tsx): add data-testid attribute to button component fix(SelectDropDown.tsx): add data-testid attribute to Listbox.Button component fix(messages.spec.ts): add waitForServerStream function to consolidate logic for awaiting the server response feat(messages.spec.ts): add test for stopping and continuing message and improve browser/page context order and closing * refactor(onProgress): speed up time to save initial message for editable routes * chore: disable AI message editing (for now), was accidentally allowed * refactor: ensure continue is only supported for latest message style: improve styling in dark mode and across all hover buttons/icons, including making edit icon for AI invisible (for now) * fix: add test id to generation buttons so they never resolve to 2+ items * chore(package.json): add 'packages/' to the list of ignored directories chore(data-provider/package.json): bump version to 0.1.5 --- .github/workflows/backend-review.yml | 10 +- api/app/clients/AnthropicClient.js | 46 ++- api/app/clients/BaseClient.js | 71 ++-- api/app/clients/OpenAIClient.js | 9 + api/app/clients/PluginsClient.js | 3 +- api/app/clients/specs/AnthropicClient.test.js | 139 +++++++ api/app/clients/specs/BaseClient.test.js | 63 ++- api/app/clients/specs/FakeClient.js | 81 ---- api/app/clients/specs/OpenAIClient.test.js | 5 + api/app/clients/specs/PluginsClient.test.js | 1 - api/app/clients/tools/util/addOpenAPISpecs.js | 1 - .../clients/tools/util/handleTools.test.js | 1 - api/app/index.js | 4 - api/lib/parse/getCitations.js | 18 - api/models/Message.js | 2 + api/models/schema/messageSchema.js | 3 + api/server/index.js | 1 + api/server/middleware/abortControllers.js | 2 + api/server/middleware/abortMiddleware.js | 106 +++++ api/server/middleware/buildEndpointOption.js | 20 + api/server/middleware/index.js | 17 + api/{ => server}/middleware/requireJwtAuth.js | 0 .../middleware/requireLocalAuth.js | 2 +- api/server/middleware/setHeaders.js | 12 + api/server/middleware/validateEndpoint.js | 19 + .../middleware/validateRegistration.js | 0 api/server/routes/ask/anthropic.js | 252 +++++------- api/server/routes/ask/askChatGPTBrowser.js | 16 +- api/server/routes/ask/bingAI.js | 14 +- api/server/routes/ask/google.js | 14 +- api/server/routes/ask/gptPlugins.js | 338 ++++++---------- api/server/routes/ask/index.js | 3 - api/server/routes/ask/openAI.js | 343 +++++++--------- api/server/routes/auth.js | 4 +- api/server/routes/convos.js | 2 +- api/server/routes/edit/anthropic.js | 139 +++++++ api/server/routes/edit/gptPlugins.js | 185 +++++++++ api/server/routes/edit/index.js | 13 + api/server/routes/edit/openAI.js | 141 +++++++ api/server/routes/endpoints.js | 22 +- .../endpoints/anthropic/buildOptions.js | 15 + .../routes/endpoints/anthropic/index.js | 8 + .../endpoints/anthropic/initializeClient.js | 12 + .../endpoints/gptPlugins/buildOptions.js | 31 ++ .../routes/endpoints/gptPlugins/index.js | 7 + .../endpoints/gptPlugins/initializeClient.js | 30 ++ .../routes/endpoints/openAI/addTitle.js | 22 ++ .../routes/endpoints/openAI/buildOptions.js | 15 + api/server/routes/endpoints/openAI/index.js | 9 + .../endpoints/openAI/initializeClient.js | 27 ++ api/server/routes/endpoints/schemas.js | 369 ++++++++++++++++++ api/server/routes/index.js | 2 + api/server/routes/messages.js | 2 +- api/server/routes/plugins.js | 2 +- api/server/routes/presets.js | 2 +- api/server/routes/search.js | 2 +- api/server/routes/tokenizer.js | 2 +- api/server/routes/user.js | 2 +- api/server/services/AuthService.js | 6 +- api/server/services/PluginService.js | 2 +- .../citeText.js => server/utils/citations.js} | 17 +- api/{ => server}/utils/crypto.js | 0 .../utils/emails/passwordReset.handlebars | 0 .../emails/requestPasswordReset.handlebars | 0 .../ask/handlers.js => utils/handleText.js} | 9 +- api/server/utils/index.js | 11 + api/{ => server}/utils/sendEmail.js | 0 api/utils/abortMessage.js | 18 - api/utils/index.js | 6 - client/src/common/types.ts | 14 + .../Input/EndpointMenu/EndpointMenu.jsx | 14 +- .../{PresetItem.jsx => PresetItem.tsx} | 16 +- .../{PresetItems.jsx => PresetItems.tsx} | 3 +- .../components/Input/Generations/Button.tsx | 27 ++ .../components/Input/Generations/Continue.tsx | 12 + .../Input/Generations/GenerationButtons.tsx | 61 +++ .../Input/Generations/Regenerate.tsx | 12 + .../src/components/Input/Generations/Stop.tsx | 12 + .../src/components/Input/Generations/index.ts | 1 + client/src/components/Input/OptionsBar.tsx | 8 +- client/src/components/Input/RowButton.tsx | 16 - client/src/components/Input/TextChat.jsx | 21 +- .../src/components/Messages/HoverButtons.jsx | 87 ----- .../src/components/Messages/HoverButtons.tsx | 86 ++++ client/src/components/Messages/Message.jsx | 8 +- ...{ScrollToBottom.jsx => ScrollToBottom.tsx} | 8 +- client/src/components/svg/Clipboard.tsx | 2 +- client/src/components/svg/ContinueIcon.tsx | 21 + client/src/components/svg/RegenerateIcon.tsx | 6 +- .../src/components/svg/StopGeneratingIcon.tsx | 6 +- client/src/components/svg/index.ts | 2 + client/src/components/ui/SelectDropDown.tsx | 1 + client/src/hooks/index.ts | 1 + client/src/hooks/useGenerations.ts | 55 +++ client/src/hooks/useMessageHandler.js | 219 ----------- client/src/hooks/useMessageHandler.ts | 201 ++++++++++ .../index.jsx => routes/MessageHandler.jsx} | 33 +- client/src/routes/Root.jsx | 8 +- client/src/store/conversation.ts | 10 +- client/src/utils/buildTree.ts | 2 +- client/src/utils/cleanupPreset.ts | 113 +----- client/src/utils/getDefaultConversation.js | 199 ---------- client/src/utils/getDefaultConversation.ts | 96 +++++ client/src/utils/getLocalStorageItems.ts | 22 ++ client/src/utils/index.ts | 1 + .../utils/{resetConvo.js => resetConvo.ts} | 8 +- client/tsconfig.json | 2 +- e2e/specs/messages.spec.ts | 60 ++- package.json | 3 +- packages/data-provider/package.json | 2 +- packages/data-provider/src/createPayload.ts | 8 +- packages/data-provider/src/schemas.ts | 294 +++++++++++++- packages/data-provider/src/types.ts | 47 +-- 113 files changed, 3029 insertions(+), 1549 deletions(-) create mode 100644 api/app/clients/specs/AnthropicClient.test.js delete mode 100644 api/lib/parse/getCitations.js create mode 100644 api/server/middleware/abortControllers.js create mode 100644 api/server/middleware/abortMiddleware.js create mode 100644 api/server/middleware/buildEndpointOption.js create mode 100644 api/server/middleware/index.js rename api/{ => server}/middleware/requireJwtAuth.js (100%) rename api/{ => server}/middleware/requireLocalAuth.js (92%) create mode 100644 api/server/middleware/setHeaders.js create mode 100644 api/server/middleware/validateEndpoint.js rename api/{ => server}/middleware/validateRegistration.js (100%) create mode 100644 api/server/routes/edit/anthropic.js create mode 100644 api/server/routes/edit/gptPlugins.js create mode 100644 api/server/routes/edit/index.js create mode 100644 api/server/routes/edit/openAI.js create mode 100644 api/server/routes/endpoints/anthropic/buildOptions.js create mode 100644 api/server/routes/endpoints/anthropic/index.js create mode 100644 api/server/routes/endpoints/anthropic/initializeClient.js create mode 100644 api/server/routes/endpoints/gptPlugins/buildOptions.js create mode 100644 api/server/routes/endpoints/gptPlugins/index.js create mode 100644 api/server/routes/endpoints/gptPlugins/initializeClient.js create mode 100644 api/server/routes/endpoints/openAI/addTitle.js create mode 100644 api/server/routes/endpoints/openAI/buildOptions.js create mode 100644 api/server/routes/endpoints/openAI/index.js create mode 100644 api/server/routes/endpoints/openAI/initializeClient.js create mode 100644 api/server/routes/endpoints/schemas.js rename api/{lib/parse/citeText.js => server/utils/citations.js} (67%) rename api/{ => server}/utils/crypto.js (100%) rename api/{ => server}/utils/emails/passwordReset.handlebars (100%) rename api/{ => server}/utils/emails/requestPasswordReset.handlebars (100%) rename api/server/{routes/ask/handlers.js => utils/handleText.js} (93%) create mode 100644 api/server/utils/index.js rename api/{ => server}/utils/sendEmail.js (100%) delete mode 100644 api/utils/abortMessage.js rename client/src/components/Input/EndpointMenu/{PresetItem.jsx => PresetItem.tsx} (87%) rename client/src/components/Input/EndpointMenu/{PresetItems.jsx => PresetItems.tsx} (82%) create mode 100644 client/src/components/Input/Generations/Button.tsx create mode 100644 client/src/components/Input/Generations/Continue.tsx create mode 100644 client/src/components/Input/Generations/GenerationButtons.tsx create mode 100644 client/src/components/Input/Generations/Regenerate.tsx create mode 100644 client/src/components/Input/Generations/Stop.tsx create mode 100644 client/src/components/Input/Generations/index.ts delete mode 100644 client/src/components/Input/RowButton.tsx delete mode 100644 client/src/components/Messages/HoverButtons.jsx create mode 100644 client/src/components/Messages/HoverButtons.tsx rename client/src/components/Messages/{ScrollToBottom.jsx => ScrollToBottom.tsx} (79%) create mode 100644 client/src/components/svg/ContinueIcon.tsx create mode 100644 client/src/hooks/useGenerations.ts delete mode 100644 client/src/hooks/useMessageHandler.js create mode 100644 client/src/hooks/useMessageHandler.ts rename client/src/{components/MessageHandler/index.jsx => routes/MessageHandler.jsx} (89%) delete mode 100644 client/src/utils/getDefaultConversation.js create mode 100644 client/src/utils/getDefaultConversation.ts create mode 100644 client/src/utils/getLocalStorageItems.ts rename client/src/utils/{resetConvo.js => resetConvo.ts} (54%) diff --git a/.github/workflows/backend-review.yml b/.github/workflows/backend-review.yml index 7c79cadce84..e5c86caa4c5 100644 --- a/.github/workflows/backend-review.yml +++ b/.github/workflows/backend-review.yml @@ -1,10 +1,5 @@ name: Backend Unit Tests on: - # push: - # branches: - # - main - # - dev - # - release/* pull_request: branches: - main @@ -23,6 +18,7 @@ jobs: JWT_SECRET: ${{ secrets.JWT_SECRET }} CREDS_KEY: ${{ secrets.CREDS_KEY }} CREDS_IV: ${{ secrets.CREDS_IV }} + NODE_ENV: ci steps: - uses: actions/checkout@v2 - name: Use Node.js 20.x @@ -34,8 +30,8 @@ jobs: - name: Install dependencies run: npm ci - # - name: Install Linux X64 Sharp - # run: npm install --platform=linux --arch=x64 --verbose sharp + - name: Install Data Provider + run: npm run build:data-provider - name: Run unit tests run: cd api && npm run test:ci diff --git a/api/app/clients/AnthropicClient.js b/api/app/clients/AnthropicClient.js index cf9571c69b8..ebec514040c 100644 --- a/api/app/clients/AnthropicClient.js +++ b/api/app/clients/AnthropicClient.js @@ -1,4 +1,3 @@ -const Keyv = require('keyv'); // const { Agent, ProxyAgent } = require('undici'); const BaseClient = require('./BaseClient'); const { @@ -15,8 +14,6 @@ const tokenizersCache = {}; class AnthropicClient extends BaseClient { constructor(apiKey, options = {}, cacheOptions = {}) { super(apiKey, options, cacheOptions); - cacheOptions.namespace = cacheOptions.namespace || 'anthropic'; - this.conversationsCache = new Keyv(cacheOptions); this.apiKey = apiKey || process.env.ANTHROPIC_API_KEY; this.sender = 'Anthropic'; this.userLabel = HUMAN_PROMPT; @@ -107,6 +104,23 @@ class AnthropicClient extends BaseClient { content: message?.content ?? message.text, })); + let lastAuthor = ''; + let groupedMessages = []; + + for (let message of formattedMessages) { + // If last author is not same as current author, add to new group + if (lastAuthor !== message.author) { + groupedMessages.push({ + author: message.author, + content: [message.content], + }); + lastAuthor = message.author; + // If same author, append content to the last group + } else { + groupedMessages[groupedMessages.length - 1].content.push(message.content); + } + } + let identityPrefix = ''; if (this.options.userLabel) { identityPrefix = `\nHuman's name: ${this.options.userLabel}`; @@ -129,8 +143,12 @@ class AnthropicClient extends BaseClient { promptPrefix = `${identityPrefix}${promptPrefix}`; } - const promptSuffix = `${promptPrefix}${this.assistantLabel}\n`; // Prompt AI to respond. - let currentTokenCount = this.getTokenCount(promptSuffix); + // Prompt AI to respond, empty if last message was from AI + let isEdited = lastAuthor === this.assistantLabel; + const promptSuffix = isEdited ? '' : `${promptPrefix}${this.assistantLabel}\n`; + let currentTokenCount = isEdited + ? this.getTokenCount(promptPrefix) + : this.getTokenCount(promptSuffix); let promptBody = ''; const maxTokenCount = this.maxPromptTokens; @@ -148,10 +166,13 @@ class AnthropicClient extends BaseClient { }; const buildPromptBody = async () => { - if (currentTokenCount < maxTokenCount && formattedMessages.length > 0) { - const message = formattedMessages.pop(); + if (currentTokenCount < maxTokenCount && groupedMessages.length > 0) { + const message = groupedMessages.pop(); const isCreatedByUser = message.author === this.userLabel; - const messageString = `${message.author}\n${message.content}${this.endToken}\n`; + // Use promptPrefix if message is edited assistant' + const messagePrefix = + isCreatedByUser || !isEdited ? message.author : `${promptPrefix}${message.author}`; + const messageString = `${messagePrefix}\n${message.content}${this.endToken}\n`; let newPromptBody = `${messageString}${promptBody}`; context.unshift(message); @@ -182,6 +203,12 @@ class AnthropicClient extends BaseClient { } promptBody = newPromptBody; currentTokenCount = newTokenCount; + + // Switch off isEdited after using it for the first time + if (isEdited) { + isEdited = false; + } + // wait for next tick to avoid blocking the event loop await new Promise((resolve) => setImmediate(resolve)); return buildPromptBody(); @@ -197,7 +224,8 @@ class AnthropicClient extends BaseClient { context.shift(); } - const prompt = `${promptBody}${promptSuffix}`; + let prompt = `${promptBody}${promptSuffix}`; + // Add 2 tokens for metadata after all messages have been counted. currentTokenCount += 2; diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index baaa0990d36..5d55d33fdcc 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -5,11 +5,12 @@ const { ChatOpenAI } = require('langchain/chat_models/openai'); const { loadSummarizationChain } = require('langchain/chains'); const { refinePrompt } = require('./prompts/refinePrompt'); const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('../../models'); +const { addSpaceIfNeeded } = require('../../server/utils'); class BaseClient { constructor(apiKey, options = {}) { this.apiKey = apiKey; - this.sender = options.sender || 'AI'; + this.sender = options.sender ?? 'AI'; this.contextStrategy = null; this.currentDateString = new Date().toLocaleDateString('en-us', { year: 'numeric', @@ -51,18 +52,20 @@ class BaseClient { if (opts && typeof opts === 'object') { this.setOptions(opts); } - const user = opts.user || null; - const conversationId = opts.conversationId || crypto.randomUUID(); - const parentMessageId = opts.parentMessageId || '00000000-0000-0000-0000-000000000000'; - const userMessageId = opts.overrideParentMessageId || crypto.randomUUID(); - const responseMessageId = crypto.randomUUID(); + const user = opts.user ?? null; + const conversationId = opts.conversationId ?? crypto.randomUUID(); + const parentMessageId = opts.parentMessageId ?? '00000000-0000-0000-0000-000000000000'; + const userMessageId = opts.overrideParentMessageId ?? crypto.randomUUID(); + const responseMessageId = opts.responseMessageId ?? crypto.randomUUID(); const saveOptions = this.getSaveOptions(); - this.abortController = opts.abortController || new AbortController(); - this.currentMessages = (await this.loadHistory(conversationId, parentMessageId)) ?? []; + const head = opts.isEdited ? responseMessageId : parentMessageId; + this.currentMessages = (await this.loadHistory(conversationId, head)) ?? []; + this.abortController = opts.abortController ?? new AbortController(); return { ...opts, user, + head, conversationId, parentMessageId, userMessageId, @@ -72,7 +75,7 @@ class BaseClient { } createUserMessage({ messageId, parentMessageId, conversationId, text }) { - const userMessage = { + return { messageId, parentMessageId, conversationId, @@ -80,19 +83,27 @@ class BaseClient { text, isCreatedByUser: true, }; - return userMessage; } async handleStartMethods(message, opts) { - const { user, conversationId, parentMessageId, userMessageId, responseMessageId, saveOptions } = - await this.setMessageOptions(opts); - - const userMessage = this.createUserMessage({ - messageId: userMessageId, - parentMessageId, + const { + user, + head, conversationId, - text: message, - }); + parentMessageId, + userMessageId, + responseMessageId, + saveOptions, + } = await this.setMessageOptions(opts); + + const userMessage = opts.isEdited + ? this.currentMessages[this.currentMessages.length - 2] + : this.createUserMessage({ + messageId: userMessageId, + parentMessageId, + conversationId, + text: message, + }); if (typeof opts?.getIds === 'function') { opts.getIds({ @@ -109,6 +120,7 @@ class BaseClient { return { ...opts, user, + head, conversationId, responseMessageId, saveOptions, @@ -373,7 +385,7 @@ class BaseClient { if (this.options.debug) { console.debug('<-------------------------PAYLOAD/TOKEN COUNT MAP------------------------->'); - console.debug('Payload:', payload); + // console.debug('Payload:', payload); console.debug('Token Count Map:', tokenCountMap); console.debug('Prompt Tokens', promptTokens, remainingContextTokens, this.maxContextTokens); } @@ -382,13 +394,16 @@ class BaseClient { } async sendMessage(message, opts = {}) { - const { user, conversationId, responseMessageId, saveOptions, userMessage } = + const { user, head, isEdited, conversationId, responseMessageId, saveOptions, userMessage } = await this.handleStartMethods(message, opts); this.user = user; // It's not necessary to push to currentMessages // depending on subclass implementation of handling messages - this.currentMessages.push(userMessage); + // When this is an edit, all messages are already in currentMessages, both user and response + if (!isEdited) { + this.currentMessages.push(userMessage); + } let { prompt: payload, @@ -398,13 +413,13 @@ class BaseClient { this.currentMessages, // When the userMessage is pushed to currentMessages, the parentMessage is the userMessageId. // this only matters when buildMessages is utilizing the parentMessageId, and may vary on implementation - userMessage.messageId, + isEdited ? head : userMessage.messageId, this.getBuildMessagesOptions(opts), ); if (this.options.debug) { console.debug('payload'); - console.debug(payload); + // console.debug(payload); } if (tokenCountMap) { @@ -423,7 +438,11 @@ class BaseClient { this.handleTokenCountMap(tokenCountMap); } - await this.saveMessageToDatabase(userMessage, saveOptions, user); + if (!isEdited) { + await this.saveMessageToDatabase(userMessage, saveOptions, user); + } + + const generation = isEdited ? this.currentMessages[this.currentMessages.length - 1].text : ''; const responseMessage = { messageId: responseMessageId, conversationId, @@ -431,7 +450,7 @@ class BaseClient { isCreatedByUser: false, model: this.modelOptions.model, sender: this.sender, - text: await this.sendCompletion(payload, opts), + text: addSpaceIfNeeded(generation) + (await this.sendCompletion(payload, opts)), promptTokens, }; @@ -453,7 +472,7 @@ class BaseClient { console.debug('Loading history for conversation', conversationId, parentMessageId); } - const messages = (await getMessages({ conversationId })) || []; + const messages = (await getMessages({ conversationId })) ?? []; if (messages.length === 0) { return []; diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 53f4815d740..87b25b1226e 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -314,6 +314,7 @@ class OpenAIClient extends BaseClient { async sendCompletion(payload, opts = {}) { let reply = ''; let result = null; + let streamResult = null; if (typeof opts.onProgress === 'function') { await this.getCompletion( payload, @@ -321,6 +322,10 @@ class OpenAIClient extends BaseClient { if (progressMessage === '[DONE]') { return; } + + if (progressMessage.choices) { + streamResult = progressMessage; + } const token = this.isChatCompletion ? progressMessage.choices?.[0]?.delta?.content : progressMessage.choices?.[0]?.text; @@ -355,6 +360,10 @@ class OpenAIClient extends BaseClient { } } + if (streamResult && typeof opts.addMetadata === 'function') { + const { finish_reason } = streamResult.choices[0]; + opts.addMetadata({ finish_reason }); + } return reply.trim(); } diff --git a/api/app/clients/PluginsClient.js b/api/app/clients/PluginsClient.js index d29f03517ae..dc427b4ae41 100644 --- a/api/app/clients/PluginsClient.js +++ b/api/app/clients/PluginsClient.js @@ -345,7 +345,8 @@ Only respond with your conversational reply to the following User Message: } async sendMessage(message, opts = {}) { - const completionMode = this.options.tools.length === 0; + // If a message is edited, no tools can be used. + const completionMode = this.options.tools.length === 0 || opts.isEdited; if (completionMode) { this.setOptions(opts); return super.sendMessage(message, opts); diff --git a/api/app/clients/specs/AnthropicClient.test.js b/api/app/clients/specs/AnthropicClient.test.js new file mode 100644 index 00000000000..52324914b9d --- /dev/null +++ b/api/app/clients/specs/AnthropicClient.test.js @@ -0,0 +1,139 @@ +const AnthropicClient = require('../AnthropicClient'); +const HUMAN_PROMPT = '\n\nHuman:'; +const AI_PROMPT = '\n\nAssistant:'; + +describe('AnthropicClient', () => { + let client; + const model = 'claude-2'; + const parentMessageId = '1'; + const messages = [ + { role: 'user', isCreatedByUser: true, text: 'Hello', messageId: parentMessageId }, + { role: 'assistant', isCreatedByUser: false, text: 'Hi', messageId: '2', parentMessageId }, + { + role: 'user', + isCreatedByUser: true, + text: 'What\'s up', + messageId: '3', + parentMessageId: '2', + }, + ]; + + beforeEach(() => { + const options = { + modelOptions: { + model, + temperature: 0.7, + }, + }; + client = new AnthropicClient('test-api-key'); + client.setOptions(options); + }); + + describe('setOptions', () => { + it('should set the options correctly', () => { + expect(client.apiKey).toBe('test-api-key'); + expect(client.modelOptions.model).toBe(model); + expect(client.modelOptions.temperature).toBe(0.7); + }); + }); + + describe('getSaveOptions', () => { + it('should return the correct save options', () => { + const options = client.getSaveOptions(); + expect(options).toHaveProperty('modelLabel'); + expect(options).toHaveProperty('promptPrefix'); + }); + }); + + describe('buildMessages', () => { + it('should handle promptPrefix from options when promptPrefix argument is not provided', async () => { + client.options.promptPrefix = 'Test Prefix from options'; + const result = await client.buildMessages(messages, parentMessageId); + const { prompt } = result; + expect(prompt).toContain('Test Prefix from options'); + }); + + it('should build messages correctly for chat completion', async () => { + const result = await client.buildMessages(messages, '2'); + expect(result).toHaveProperty('prompt'); + expect(result.prompt).toContain(HUMAN_PROMPT); + expect(result.prompt).toContain('Hello'); + expect(result.prompt).toContain(AI_PROMPT); + expect(result.prompt).toContain('Hi'); + }); + + it('should group messages by the same author', async () => { + const groupedMessages = messages.map((m) => ({ ...m, isCreatedByUser: true, role: 'user' })); + const result = await client.buildMessages(groupedMessages, '3'); + expect(result.context).toHaveLength(1); + + // Check that HUMAN_PROMPT appears only once in the prompt + const matches = result.prompt.match(new RegExp(HUMAN_PROMPT, 'g')); + expect(matches).toHaveLength(1); + + groupedMessages.push({ + role: 'assistant', + isCreatedByUser: false, + text: 'I heard you the first time', + messageId: '4', + parentMessageId: '3', + }); + + const result2 = await client.buildMessages(groupedMessages, '4'); + expect(result2.context).toHaveLength(2); + + // Check that HUMAN_PROMPT appears only once in the prompt + const human_matches = result2.prompt.match(new RegExp(HUMAN_PROMPT, 'g')); + const ai_matches = result2.prompt.match(new RegExp(AI_PROMPT, 'g')); + expect(human_matches).toHaveLength(1); + expect(ai_matches).toHaveLength(1); + }); + + it('should handle isEdited condition', async () => { + const editedMessages = [ + { role: 'user', isCreatedByUser: true, text: 'Hello', messageId: '1' }, + { role: 'assistant', isCreatedByUser: false, text: 'Hi', messageId: '2', parentMessageId }, + ]; + + const trimmedLabel = AI_PROMPT.trim(); + const result = await client.buildMessages(editedMessages, '2'); + expect(result.prompt.trim().endsWith(trimmedLabel)).toBeFalsy(); + + // Add a human message at the end to test the opposite + editedMessages.push({ + role: 'user', + isCreatedByUser: true, + text: 'Hi again', + messageId: '3', + parentMessageId: '2', + }); + const result2 = await client.buildMessages(editedMessages, '3'); + expect(result2.prompt.trim().endsWith(trimmedLabel)).toBeTruthy(); + }); + + it('should build messages correctly with a promptPrefix', async () => { + const promptPrefix = 'Test Prefix'; + client.options.promptPrefix = promptPrefix; + const result = await client.buildMessages(messages, parentMessageId); + const { prompt } = result; + expect(prompt).toBeDefined(); + expect(prompt).toContain(promptPrefix); + const textAfterPrefix = prompt.split(promptPrefix)[1]; + expect(textAfterPrefix).toContain(AI_PROMPT); + + const editedMessages = messages.slice(0, -1); + const result2 = await client.buildMessages(editedMessages, parentMessageId); + const textAfterPrefix2 = result2.prompt.split(promptPrefix)[1]; + expect(textAfterPrefix2).toContain(AI_PROMPT); + }); + + it('should handle identityPrefix from options', async () => { + client.options.userLabel = 'John'; + client.options.modelLabel = 'Claude-2'; + const result = await client.buildMessages(messages, parentMessageId); + const { prompt } = result; + expect(prompt).toContain('Human\'s name: John'); + expect(prompt).toContain('You are Claude-2'); + }); + }); +}); diff --git a/api/app/clients/specs/BaseClient.test.js b/api/app/clients/specs/BaseClient.test.js index d81bfe6274e..10d5868cb0f 100644 --- a/api/app/clients/specs/BaseClient.test.js +++ b/api/app/clients/specs/BaseClient.test.js @@ -45,6 +45,18 @@ const fakeMessages = []; const userMessage = 'Hello, ChatGPT!'; const apiKey = 'fake-api-key'; +const messageHistory = [ + { role: 'user', isCreatedByUser: true, text: 'Hello', messageId: '1' }, + { role: 'assistant', isCreatedByUser: false, text: 'Hi', messageId: '2', parentMessageId: '1' }, + { + role: 'user', + isCreatedByUser: true, + text: 'What\'s up', + messageId: '3', + parentMessageId: '2', + }, +]; + describe('BaseClient', () => { let TestClient; const options = { @@ -277,9 +289,54 @@ describe('BaseClient', () => { }); test('should return chat history', async () => { - const chatMessages = await TestClient.loadHistory(conversationId, parentMessageId); - expect(TestClient.currentMessages).toHaveLength(4); - expect(chatMessages[0].text).toEqual(userMessage); + TestClient = initializeFakeClient(apiKey, options, messageHistory); + const chatMessages = await TestClient.loadHistory(conversationId, '2'); + expect(TestClient.currentMessages).toHaveLength(2); + expect(chatMessages[0].text).toEqual('Hello'); + + const chatMessages2 = await TestClient.loadHistory(conversationId, '3'); + expect(TestClient.currentMessages).toHaveLength(3); + expect(chatMessages2[chatMessages2.length - 1].text).toEqual('What\'s up'); + }); + + /* Most of the new sendMessage logic revolving around edited/continued AI messages + * can be summarized by the following test. The condition will load the entire history up to + * the message that is being edited, which will trigger the AI API to 'continue' the response. + * The 'userMessage' is only passed by convention and is not necessary for the generation. + */ + it('should not push userMessage to currentMessages when isEdited is true and vice versa', async () => { + const overrideParentMessageId = 'user-message-id'; + const responseMessageId = 'response-message-id'; + const newHistory = messageHistory.slice(); + newHistory.push({ + role: 'assistant', + isCreatedByUser: false, + text: 'test message', + messageId: responseMessageId, + parentMessageId: '3', + }); + + TestClient = initializeFakeClient(apiKey, options, newHistory); + const sendMessageOptions = { + isEdited: true, + overrideParentMessageId, + parentMessageId: '3', + responseMessageId, + }; + + await TestClient.sendMessage('test message', sendMessageOptions); + const currentMessages = TestClient.currentMessages; + expect(currentMessages[currentMessages.length - 1].messageId).not.toEqual( + overrideParentMessageId, + ); + + // Test the opposite case + sendMessageOptions.isEdited = false; + await TestClient.sendMessage('test message', sendMessageOptions); + const currentMessages2 = TestClient.currentMessages; + expect(currentMessages2[currentMessages2.length - 1].messageId).toEqual( + overrideParentMessageId, + ); }); test('setOptions is called with the correct arguments', async () => { diff --git a/api/app/clients/specs/FakeClient.js b/api/app/clients/specs/FakeClient.js index 5cd7556bcf5..e46cee687e1 100644 --- a/api/app/clients/specs/FakeClient.js +++ b/api/app/clients/specs/FakeClient.js @@ -1,4 +1,3 @@ -const crypto = require('crypto'); const BaseClient = require('../BaseClient'); const { maxTokensMap } = require('../../../utils'); @@ -87,86 +86,6 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => { return 'Mock response text'; }); - TestClient.sendMessage = jest.fn().mockImplementation(async (message, opts = {}) => { - if (opts && typeof opts === 'object') { - TestClient.setOptions(opts); - } - - const user = opts.user || null; - const conversationId = opts.conversationId || crypto.randomUUID(); - const parentMessageId = opts.parentMessageId || '00000000-0000-0000-0000-000000000000'; - const userMessageId = opts.overrideParentMessageId || crypto.randomUUID(); - const saveOptions = TestClient.getSaveOptions(); - - this.pastMessages = await TestClient.loadHistory( - conversationId, - TestClient.options?.parentMessageId, - ); - - const userMessage = { - text: message, - sender: TestClient.sender, - isCreatedByUser: true, - messageId: userMessageId, - parentMessageId, - conversationId, - }; - - const response = { - sender: TestClient.sender, - text: 'Hello, User!', - isCreatedByUser: false, - messageId: crypto.randomUUID(), - parentMessageId: userMessage.messageId, - conversationId, - }; - - fakeMessages.push(userMessage); - fakeMessages.push(response); - - if (typeof opts.getIds === 'function') { - opts.getIds({ - userMessage, - conversationId, - responseMessageId: response.messageId, - }); - } - - if (typeof opts.onStart === 'function') { - opts.onStart(userMessage); - } - - let { prompt: payload, tokenCountMap } = await TestClient.buildMessages( - this.currentMessages, - userMessage.messageId, - TestClient.getBuildMessagesOptions(opts), - ); - - if (tokenCountMap) { - payload = payload.map((message, i) => { - const { tokenCount, ...messageWithoutTokenCount } = message; - // userMessage is always the last one in the payload - if (i === payload.length - 1) { - userMessage.tokenCount = message.tokenCount; - console.debug( - `Token count for user message: ${tokenCount}`, - `Instruction Tokens: ${tokenCountMap.instructions || 'N/A'}`, - ); - } - return messageWithoutTokenCount; - }); - TestClient.handleTokenCountMap(tokenCountMap); - } - - await TestClient.saveMessageToDatabase(userMessage, saveOptions, user); - response.text = await TestClient.sendCompletion(payload, opts); - if (tokenCountMap && TestClient.getTokenCountForResponse) { - response.tokenCount = TestClient.getTokenCountForResponse(response); - } - await TestClient.saveMessageToDatabase(response, saveOptions, user); - return response; - }); - TestClient.buildMessages = jest.fn(async (messages, parentMessageId) => { const orderedMessages = TestClient.constructor.getMessagesForConversation( messages, diff --git a/api/app/clients/specs/OpenAIClient.test.js b/api/app/clients/specs/OpenAIClient.test.js index 41aeb4f3b49..dd4de5cc7c5 100644 --- a/api/app/clients/specs/OpenAIClient.test.js +++ b/api/app/clients/specs/OpenAIClient.test.js @@ -1,5 +1,7 @@ const OpenAIClient = require('../OpenAIClient'); +jest.mock('meilisearch'); + describe('OpenAIClient', () => { let client, client2; const model = 'gpt-4'; @@ -25,6 +27,9 @@ describe('OpenAIClient', () => { content: 'Refined answer', tokenCount: 30, }); + client.buildPrompt = jest + .fn() + .mockResolvedValue({ prompt: messages.map((m) => m.text).join('\n') }); client.constructor.freeAndResetAllEncoders(); }); diff --git a/api/app/clients/specs/PluginsClient.test.js b/api/app/clients/specs/PluginsClient.test.js index 59218c6206e..a11d490b221 100644 --- a/api/app/clients/specs/PluginsClient.test.js +++ b/api/app/clients/specs/PluginsClient.test.js @@ -111,7 +111,6 @@ describe('PluginsClient', () => { }); const response = await TestAgent.sendMessage(userMessage); - console.log(response); parentMessageId = response.messageId; conversationId = response.conversationId; expect(response).toEqual(expectedResult); diff --git a/api/app/clients/tools/util/addOpenAPISpecs.js b/api/app/clients/tools/util/addOpenAPISpecs.js index 2d5756f1948..8b87be9941d 100644 --- a/api/app/clients/tools/util/addOpenAPISpecs.js +++ b/api/app/clients/tools/util/addOpenAPISpecs.js @@ -20,7 +20,6 @@ async function addOpenAPISpecs(availableTools) { } return availableTools; } catch (error) { - console.log('addOpenAPISpecs error', error); return availableTools; } } diff --git a/api/app/clients/tools/util/handleTools.test.js b/api/app/clients/tools/util/handleTools.test.js index 674543ba293..f0c6ee6601a 100644 --- a/api/app/clients/tools/util/handleTools.test.js +++ b/api/app/clients/tools/util/handleTools.test.js @@ -83,7 +83,6 @@ describe('Tool Handlers', () => { it('returns valid tools given input tools and user authentication', async () => { const validTools = await validateTools(fakeUser._id, initialTools); expect(validTools).toBeDefined(); - console.log('validateTools: validTools', validTools); expect(validTools.some((tool) => tool === pluginKey)).toBeTruthy(); expect(validTools.length).toBeGreaterThan(0); }); diff --git a/api/app/index.js b/api/app/index.js index 95624829a93..fe1462331e0 100644 --- a/api/app/index.js +++ b/api/app/index.js @@ -3,15 +3,11 @@ const { askBing } = require('./bingai'); const clients = require('./clients'); const titleConvo = require('./titleConvo'); const titleConvoBing = require('./titleConvoBing'); -const getCitations = require('../lib/parse/getCitations'); -const citeText = require('../lib/parse/citeText'); module.exports = { browserClient, askBing, titleConvo, titleConvoBing, - getCitations, - citeText, ...clients, }; diff --git a/api/lib/parse/getCitations.js b/api/lib/parse/getCitations.js deleted file mode 100644 index f99363d1453..00000000000 --- a/api/lib/parse/getCitations.js +++ /dev/null @@ -1,18 +0,0 @@ -// const regex = / \[\d+\..*?\]\(.*?\)/g; -const regex = / \[.*?]\(.*?\)/g; - -const getCitations = (res) => { - const adaptiveCards = res.details.adaptiveCards; - const textBlocks = adaptiveCards && adaptiveCards[0].body; - if (!textBlocks) { - return ''; - } - let links = textBlocks[textBlocks.length - 1]?.text.match(regex); - if (links?.length === 0 || !links) { - return ''; - } - links = links.map((link) => link.trim()); - return links.join('\n - '); -}; - -module.exports = getCitations; diff --git a/api/models/Message.js b/api/models/Message.js index 20d06486b8d..0fa2265e88e 100644 --- a/api/models/Message.js +++ b/api/models/Message.js @@ -14,6 +14,7 @@ module.exports = { error, unfinished, cancelled, + finish_reason = null, tokenCount = null, plugin = null, model = null, @@ -29,6 +30,7 @@ module.exports = { sender, text, isCreatedByUser, + finish_reason, error, unfinished, cancelled, diff --git a/api/models/schema/messageSchema.js b/api/models/schema/messageSchema.js index 792b2d545af..98d6cef239d 100644 --- a/api/models/schema/messageSchema.js +++ b/api/models/schema/messageSchema.js @@ -67,6 +67,9 @@ const messageSchema = mongoose.Schema( type: Boolean, default: false, }, + finish_reason: { + type: String, + }, _meiliIndex: { type: Boolean, required: false, diff --git a/api/server/index.js b/api/server/index.js index 2480dc25f56..6ff02656a38 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -84,6 +84,7 @@ config.validate(); // Validate the config app.use('/api/user', routes.user); app.use('/api/search', routes.search); app.use('/api/ask', routes.ask); + app.use('/api/edit', routes.edit); app.use('/api/messages', routes.messages); app.use('/api/convos', routes.convos); app.use('/api/presets', routes.presets); diff --git a/api/server/middleware/abortControllers.js b/api/server/middleware/abortControllers.js new file mode 100644 index 00000000000..31acbfe3891 --- /dev/null +++ b/api/server/middleware/abortControllers.js @@ -0,0 +1,2 @@ +// abortControllers.js +module.exports = new Map(); diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js new file mode 100644 index 00000000000..f678e448682 --- /dev/null +++ b/api/server/middleware/abortMiddleware.js @@ -0,0 +1,106 @@ +const { saveMessage, getConvo, getConvoTitle } = require('../../models'); +const { sendMessage, handleError } = require('../utils'); +const abortControllers = require('./abortControllers'); + +async function abortMessage(req, res) { + const { abortKey } = req.body; + + if (!abortControllers.has(abortKey) && !res.headersSent) { + return res.status(404).send('Request not found'); + } + + const { abortController } = abortControllers.get(abortKey); + const ret = await abortController.abortCompletion(); + console.log('Aborted request', abortKey); + abortControllers.delete(abortKey); + res.send(JSON.stringify(ret)); +} + +const handleAbort = () => { + return async (req, res) => { + try { + return await abortMessage(req, res); + } catch (err) { + console.error(err); + } + }; +}; + +const createAbortController = (res, req, endpointOption, getAbortData) => { + const abortController = new AbortController(); + const onStart = (userMessage) => { + sendMessage(res, { message: userMessage, created: true }); + abortControllers.set(userMessage.conversationId, { abortController, ...endpointOption }); + + res.on('finish', function () { + abortControllers.delete(userMessage.conversationId); + }); + }; + + abortController.abortCompletion = async function () { + abortController.abort(); + const { conversationId, userMessage, ...responseData } = getAbortData(); + + const responseMessage = { + ...responseData, + finish_reason: 'incomplete', + model: endpointOption.modelOptions.model, + unfinished: false, + cancelled: true, + error: false, + }; + + saveMessage(responseMessage); + + return { + title: await getConvoTitle(req.user.id, conversationId), + final: true, + conversation: await getConvo(req.user.id, conversationId), + requestMessage: userMessage, + responseMessage: responseMessage, + }; + }; + + return { abortController, onStart }; +}; + +const handleAbortError = async (res, req, error, data) => { + console.error(error); + const { sender, conversationId, messageId, parentMessageId, partialText } = data; + + const respondWithError = async () => { + const errorMessage = { + sender, + messageId, + conversationId, + parentMessageId, + unfinished: false, + cancelled: false, + error: true, + text: error.message, + }; + if (abortControllers.has(conversationId)) { + const { abortController } = abortControllers.get(conversationId); + abortController.abort(); + abortControllers.delete(conversationId); + } + await saveMessage(errorMessage); + handleError(res, errorMessage); + }; + + if (partialText?.length > 2) { + try { + return await abortMessage(req, res); + } catch (err) { + return respondWithError(); + } + } else { + return respondWithError(); + } +}; + +module.exports = { + handleAbort, + createAbortController, + handleAbortError, +}; diff --git a/api/server/middleware/buildEndpointOption.js b/api/server/middleware/buildEndpointOption.js new file mode 100644 index 00000000000..ea6ad637e84 --- /dev/null +++ b/api/server/middleware/buildEndpointOption.js @@ -0,0 +1,20 @@ +const openAI = require('../routes/endpoints/openAI'); +const gptPlugins = require('../routes/endpoints/gptPlugins'); +const anthropic = require('../routes/endpoints/anthropic'); +const { parseConvo } = require('../routes/endpoints/schemas'); + +const buildFunction = { + openAI: openAI.buildOptions, + azureOpenAI: openAI.buildOptions, + gptPlugins: gptPlugins.buildOptions, + anthropic: anthropic.buildOptions, +}; + +function buildEndpointOption(req, res, next) { + const { endpoint } = req.body; + const parsedBody = parseConvo(endpoint, req.body); + req.body.endpointOption = buildFunction[endpoint](endpoint, parsedBody); + next(); +} + +module.exports = buildEndpointOption; diff --git a/api/server/middleware/index.js b/api/server/middleware/index.js new file mode 100644 index 00000000000..14262600863 --- /dev/null +++ b/api/server/middleware/index.js @@ -0,0 +1,17 @@ +const abortMiddleware = require('./abortMiddleware'); +const setHeaders = require('./setHeaders'); +const requireJwtAuth = require('./requireJwtAuth'); +const requireLocalAuth = require('./requireLocalAuth'); +const validateEndpoint = require('./validateEndpoint'); +const buildEndpointOption = require('./buildEndpointOption'); +const validateRegistration = require('./validateRegistration'); + +module.exports = { + ...abortMiddleware, + setHeaders, + requireJwtAuth, + requireLocalAuth, + validateEndpoint, + buildEndpointOption, + validateRegistration, +}; diff --git a/api/middleware/requireJwtAuth.js b/api/server/middleware/requireJwtAuth.js similarity index 100% rename from api/middleware/requireJwtAuth.js rename to api/server/middleware/requireJwtAuth.js diff --git a/api/middleware/requireLocalAuth.js b/api/server/middleware/requireLocalAuth.js similarity index 92% rename from api/middleware/requireLocalAuth.js rename to api/server/middleware/requireLocalAuth.js index b8700412bd3..107d370e855 100644 --- a/api/middleware/requireLocalAuth.js +++ b/api/server/middleware/requireLocalAuth.js @@ -1,5 +1,5 @@ const passport = require('passport'); -const DebugControl = require('../utils/debug.js'); +const DebugControl = require('../../utils/debug.js'); function log({ title, parameters }) { DebugControl.log.functionName(title); diff --git a/api/server/middleware/setHeaders.js b/api/server/middleware/setHeaders.js new file mode 100644 index 00000000000..c1b58e2a5ab --- /dev/null +++ b/api/server/middleware/setHeaders.js @@ -0,0 +1,12 @@ +function setHeaders(req, res, next) { + res.writeHead(200, { + Connection: 'keep-alive', + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache, no-transform', + 'Access-Control-Allow-Origin': '*', + 'X-Accel-Buffering': 'no', + }); + next(); +} + +module.exports = setHeaders; diff --git a/api/server/middleware/validateEndpoint.js b/api/server/middleware/validateEndpoint.js new file mode 100644 index 00000000000..6e9c914c8eb --- /dev/null +++ b/api/server/middleware/validateEndpoint.js @@ -0,0 +1,19 @@ +const { handleError } = require('../utils'); + +function validateEndpoint(req, res, next) { + const { endpoint } = req.body; + + if (!req.body.text || req.body.text.length === 0) { + return handleError(res, { text: 'Prompt empty or too short' }); + } + + const pathEndpoint = req.baseUrl.split('/')[3]; + + if (endpoint !== pathEndpoint) { + return handleError(res, { text: 'Illegal request: Endpoint mismatch' }); + } + + next(); +} + +module.exports = validateEndpoint; diff --git a/api/middleware/validateRegistration.js b/api/server/middleware/validateRegistration.js similarity index 100% rename from api/middleware/validateRegistration.js rename to api/server/middleware/validateRegistration.js diff --git a/api/server/routes/ask/anthropic.js b/api/server/routes/ask/anthropic.js index 6ede82c4d27..beca07b85ab 100644 --- a/api/server/routes/ask/anthropic.js +++ b/api/server/routes/ask/anthropic.js @@ -1,72 +1,43 @@ const express = require('express'); const router = express.Router(); -const crypto = require('crypto'); -const { titleConvo, AnthropicClient } = require('../../../app'); -const requireJwtAuth = require('../../../middleware/requireJwtAuth'); -const { abortMessage } = require('../../../utils'); +const { getResponseSender } = require('../endpoints/schemas'); +const { initializeClient } = require('../endpoints/anthropic'); +const { + handleAbort, + createAbortController, + handleAbortError, + setHeaders, + requireJwtAuth, + validateEndpoint, + buildEndpointOption, +} = require('../../middleware'); const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models'); -const { handleError, sendMessage, createOnProgress } = require('./handlers'); - -const abortControllers = new Map(); - -router.post('/abort', requireJwtAuth, async (req, res) => { - try { - return await abortMessage(req, res, abortControllers); - } catch (err) { - console.error(err); - } -}); - -router.post('/', requireJwtAuth, async (req, res) => { - const { endpoint, text, parentMessageId, conversationId: oldConversationId } = req.body; - if (text.length === 0) { - return handleError(res, { text: 'Prompt empty or too short' }); - } - if (endpoint !== 'anthropic') { - return handleError(res, { text: 'Illegal request' }); - } - - const endpointOption = { - promptPrefix: req.body?.promptPrefix ?? null, - modelLabel: req.body?.modelLabel ?? null, - token: req.body?.token ?? null, - modelOptions: { - model: req.body?.model ?? 'claude-1', - temperature: req.body?.temperature ?? 1, - maxOutputTokens: req.body?.maxOutputTokens ?? 1024, - topP: req.body?.topP ?? 0.7, - topK: req.body?.topK ?? 5, - }, - }; - - const conversationId = oldConversationId || crypto.randomUUID(); - - return await ask({ - text, - endpointOption, - conversationId, - parentMessageId, - req, - res, - }); -}); - -const ask = async ({ text, endpointOption, parentMessageId = null, conversationId, req, res }) => { - res.writeHead(200, { - Connection: 'keep-alive', - 'Content-Type': 'text/event-stream', - 'Cache-Control': 'no-cache, no-transform', - 'Access-Control-Allow-Origin': '*', - 'X-Accel-Buffering': 'no', - }); - - let userMessage; - let userMessageId; - let responseMessageId; - let lastSavedTimestamp = 0; - const { overrideParentMessageId = null } = req.body; +const { sendMessage, createOnProgress } = require('../../utils'); + +router.post('/abort', requireJwtAuth, handleAbort()); + +router.post( + '/', + requireJwtAuth, + validateEndpoint, + buildEndpointOption, + setHeaders, + async (req, res) => { + let { + text, + endpointOption, + conversationId, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; + console.log('ask log'); + console.dir({ text, conversationId, endpointOption }, { depth: null }); + let userMessage; + let userMessageId; + let responseMessageId; + let lastSavedTimestamp = 0; + let saveDelay = 100; - try { const getIds = (data) => { userMessage = data.userMessage; userMessageId = data.userMessage.messageId; @@ -79,116 +50,95 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI const { onProgress: progressCallback, getPartialText } = createOnProgress({ onProgress: ({ text: partialText }) => { const currentTimestamp = Date.now(); - if (currentTimestamp - lastSavedTimestamp > 500) { + + if (currentTimestamp - lastSavedTimestamp > saveDelay) { lastSavedTimestamp = currentTimestamp; saveMessage({ messageId: responseMessageId, - sender: 'Anthropic', + sender: getResponseSender(endpointOption), conversationId, - parentMessageId: overrideParentMessageId || userMessageId, + parentMessageId: overrideParentMessageId ?? userMessageId, text: partialText, unfinished: true, cancelled: false, error: false, }); } + + if (saveDelay < 500) { + saveDelay = 500; + } }, }); + try { + const getAbortData = () => ({ + conversationId, + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + userMessage, + }); - const abortController = new AbortController(); - abortController.abortAsk = async function () { - this.abort(); + const { abortController, onStart } = createAbortController( + res, + req, + endpointOption, + getAbortData, + ); - const responseMessage = { - messageId: responseMessageId, - sender: 'Anthropic', + const { client } = initializeClient(req, endpointOption); + + let response = await client.sendMessage(text, { + getIds, + debug: false, + user: req.user.id, conversationId, - parentMessageId: overrideParentMessageId || userMessageId, - text: getPartialText(), - model: endpointOption.modelOptions.model, - unfinished: false, - cancelled: true, - error: false, - }; + parentMessageId, + overrideParentMessageId, + ...endpointOption, + onProgress: progressCallback.call(null, { + res, + text, + parentMessageId: overrideParentMessageId ?? userMessageId, + }), + onStart, + abortController, + }); + + if (overrideParentMessageId) { + response.parentMessageId = overrideParentMessageId; + } - saveMessage(responseMessage); + await saveConvo(req.user.id, { + ...endpointOption, + ...endpointOption.modelOptions, + conversationId, + endpoint: 'anthropic', + }); - return { + await saveMessage(response); + sendMessage(res, { title: await getConvoTitle(req.user.id, conversationId), final: true, conversation: await getConvo(req.user.id, conversationId), requestMessage: userMessage, - responseMessage: responseMessage, - }; - }; - - const onStart = (userMessage) => { - sendMessage(res, { message: userMessage, created: true }); - abortControllers.set(userMessage.conversationId, { abortController, ...endpointOption }); - }; - - const client = new AnthropicClient(endpointOption.token); - - let response = await client.sendMessage(text, { - getIds, - debug: false, - user: req.user.id, - conversationId, - parentMessageId, - overrideParentMessageId, - ...endpointOption, - onProgress: progressCallback.call(null, { - res, - text, - parentMessageId: overrideParentMessageId || userMessageId, - }), - onStart, - abortController, - }); - - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } - - await saveConvo(req.user.id, { - ...endpointOption, - ...endpointOption.modelOptions, - conversationId, - endpoint: 'anthropic', - }); - - await saveMessage(response); - sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), - final: true, - conversation: await getConvo(req.user.id, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); + responseMessage: response, + }); + res.end(); - if (parentMessageId == '00000000-0000-0000-0000-000000000000') { - const title = await titleConvo({ text, response }); - await saveConvo(req.user.id, { + // TODO: add anthropic titling + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, conversationId, - title, + sender: getResponseSender(endpointOption), + messageId: responseMessageId, + parentMessageId: userMessageId, }); } - } catch (error) { - console.error(error); - const errorMessage = { - messageId: responseMessageId, - sender: 'Anthropic', - conversationId, - parentMessageId, - unfinished: false, - cancelled: false, - error: true, - text: error.message, - }; - await saveMessage(errorMessage); - handleError(res, errorMessage); - } -}; + }, +); module.exports = router; diff --git a/api/server/routes/ask/askChatGPTBrowser.js b/api/server/routes/ask/askChatGPTBrowser.js index 576f5810810..2b2472b89e6 100644 --- a/api/server/routes/ask/askChatGPTBrowser.js +++ b/api/server/routes/ask/askChatGPTBrowser.js @@ -1,13 +1,12 @@ const express = require('express'); const crypto = require('crypto'); const router = express.Router(); -// const { getChatGPTBrowserModels } = require('../endpoints'); const { browserClient } = require('../../../app/'); const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models'); -const { handleError, sendMessage, createOnProgress, handleText } = require('./handlers'); -const requireJwtAuth = require('../../../middleware/requireJwtAuth'); +const { handleError, sendMessage, createOnProgress, handleText } = require('../../utils'); +const { requireJwtAuth, setHeaders } = require('../../middleware'); -router.post('/', requireJwtAuth, async (req, res) => { +router.post('/', requireJwtAuth, setHeaders, async (req, res) => { const { endpoint, text, @@ -86,15 +85,6 @@ const ask = async ({ }) => { let { text, parentMessageId: userParentMessageId, messageId: userMessageId } = userMessage; const userId = req.user.id; - - res.writeHead(200, { - Connection: 'keep-alive', - 'Content-Type': 'text/event-stream', - 'Cache-Control': 'no-cache, no-transform', - 'Access-Control-Allow-Origin': '*', - 'X-Accel-Buffering': 'no', - }); - let responseMessageId = crypto.randomUUID(); let getPartialMessage = null; try { diff --git a/api/server/routes/ask/bingAI.js b/api/server/routes/ask/bingAI.js index ced293105a9..f8e834fb97b 100644 --- a/api/server/routes/ask/bingAI.js +++ b/api/server/routes/ask/bingAI.js @@ -3,10 +3,10 @@ const crypto = require('crypto'); const router = express.Router(); const { titleConvoBing, askBing } = require('../../../app'); const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models'); -const { handleError, sendMessage, createOnProgress, handleText } = require('./handlers'); -const requireJwtAuth = require('../../../middleware/requireJwtAuth'); +const { handleError, sendMessage, createOnProgress, handleText } = require('../../utils'); +const { requireJwtAuth, setHeaders } = require('../../middleware'); -router.post('/', requireJwtAuth, async (req, res) => { +router.post('/', requireJwtAuth, setHeaders, async (req, res) => { const { endpoint, text, @@ -103,14 +103,6 @@ const ask = async ({ let responseMessageId = crypto.randomUUID(); - res.writeHead(200, { - Connection: 'keep-alive', - 'Content-Type': 'text/event-stream', - 'Cache-Control': 'no-cache, no-transform', - 'Access-Control-Allow-Origin': '*', - 'X-Accel-Buffering': 'no', - }); - if (preSendRequest) { sendMessage(res, { message: userMessage, created: true }); } diff --git a/api/server/routes/ask/google.js b/api/server/routes/ask/google.js index f3d25cbcd4a..17775e2f38c 100644 --- a/api/server/routes/ask/google.js +++ b/api/server/routes/ask/google.js @@ -2,12 +2,11 @@ const express = require('express'); const router = express.Router(); const crypto = require('crypto'); const { titleConvo, GoogleClient } = require('../../../app'); -// const GoogleClient = require('../../../app/google/GoogleClient'); const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models'); -const { handleError, sendMessage, createOnProgress } = require('./handlers'); -const requireJwtAuth = require('../../../middleware/requireJwtAuth'); +const { handleError, sendMessage, createOnProgress } = require('../../utils'); +const { requireJwtAuth, setHeaders } = require('../../middleware'); -router.post('/', requireJwtAuth, async (req, res) => { +router.post('/', requireJwtAuth, setHeaders, async (req, res) => { const { endpoint, text, parentMessageId, conversationId: oldConversationId } = req.body; if (text.length === 0) { return handleError(res, { text: 'Prompt empty or too short' }); @@ -50,13 +49,6 @@ router.post('/', requireJwtAuth, async (req, res) => { }); const ask = async ({ text, endpointOption, parentMessageId = null, conversationId, req, res }) => { - res.writeHead(200, { - Connection: 'keep-alive', - 'Content-Type': 'text/event-stream', - 'Cache-Control': 'no-cache, no-transform', - 'Access-Control-Allow-Origin': '*', - 'X-Accel-Buffering': 'no', - }); let userMessage; let userMessageId; let responseMessageId; diff --git a/api/server/routes/ask/gptPlugins.js b/api/server/routes/ask/gptPlugins.js index 7a336fe97de..2ce3d21d8ef 100644 --- a/api/server/routes/ask/gptPlugins.js +++ b/api/server/routes/ask/gptPlugins.js @@ -1,112 +1,56 @@ const express = require('express'); const router = express.Router(); -const { titleConvo, validateTools, PluginsClient } = require('../../../app'); -const { abortMessage, getAzureCredentials } = require('../../../utils'); -const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models'); +const { getResponseSender } = require('../endpoints/schemas'); +const { validateTools } = require('../../../app'); +const { addTitle } = require('../endpoints/openAI'); +const { initializeClient } = require('../endpoints/gptPlugins'); +const { saveMessage, getConvoTitle, getConvo } = require('../../../models'); +const { sendMessage, createOnProgress, formatSteps, formatAction } = require('../../utils'); const { - handleError, - sendMessage, - createOnProgress, - formatSteps, - formatAction, -} = require('./handlers'); -const requireJwtAuth = require('../../../middleware/requireJwtAuth'); - -const abortControllers = new Map(); - -router.post('/abort', requireJwtAuth, async (req, res) => { - try { - return await abortMessage(req, res, abortControllers); - } catch (err) { - console.error(err); - } -}); - -router.post('/', requireJwtAuth, async (req, res) => { - const { endpoint, text, parentMessageId, conversationId } = req.body; - if (text.length === 0) { - return handleError(res, { text: 'Prompt empty or too short' }); - } - if (endpoint !== 'gptPlugins') { - return handleError(res, { text: 'Illegal request' }); - } - - const agentOptions = req.body?.agentOptions ?? { - agent: 'functions', - skipCompletion: true, - model: 'gpt-3.5-turbo', - temperature: 0, - // top_p: 1, - // presence_penalty: 0, - // frequency_penalty: 0 - }; - - const tools = req.body?.tools.map((tool) => tool.pluginKey) ?? []; - // build endpoint option - const endpointOption = { - chatGptLabel: tools.length === 0 ? req.body?.chatGptLabel ?? null : null, - promptPrefix: tools.length === 0 ? req.body?.promptPrefix ?? null : null, - tools, - modelOptions: { - model: req.body?.model ?? 'gpt-4', - temperature: req.body?.temperature ?? 0, - top_p: req.body?.top_p ?? 1, - presence_penalty: req.body?.presence_penalty ?? 0, - frequency_penalty: req.body?.frequency_penalty ?? 0, - }, - agentOptions: { - ...agentOptions, - // agent: 'functions' - }, - }; - - console.log('ask log'); - console.dir({ text, conversationId, endpointOption }, { depth: null }); - - // eslint-disable-next-line no-use-before-define - return await ask({ - text, - endpoint, - endpointOption, - conversationId, - parentMessageId, - req, - res, - }); -}); - -const ask = async ({ - text, - endpoint, - endpointOption, - parentMessageId = null, - conversationId, - req, - res, -}) => { - res.writeHead(200, { - Connection: 'keep-alive', - 'Content-Type': 'text/event-stream', - 'Cache-Control': 'no-cache, no-transform', - 'Access-Control-Allow-Origin': '*', - 'X-Accel-Buffering': 'no', - }); - let userMessage; - let userMessageId; - let responseMessageId; - let lastSavedTimestamp = 0; - const newConvo = !conversationId; - const { overrideParentMessageId = null } = req.body; - const user = req.user.id; - - const plugin = { - loading: true, - inputs: [], - latest: null, - outputs: null, - }; + handleAbort, + createAbortController, + handleAbortError, + setHeaders, + requireJwtAuth, + validateEndpoint, + buildEndpointOption, +} = require('../../middleware'); + +router.post('/abort', requireJwtAuth, handleAbort()); + +router.post( + '/', + requireJwtAuth, + validateEndpoint, + buildEndpointOption, + setHeaders, + async (req, res) => { + let { + text, + endpointOption, + conversationId, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; + console.log('ask log'); + console.dir({ text, conversationId, endpointOption }, { depth: null }); + let metadata; + let userMessage; + let userMessageId; + let responseMessageId; + let lastSavedTimestamp = 0; + let saveDelay = 100; + const newConvo = !conversationId; + const user = req.user.id; + + const plugin = { + loading: true, + inputs: [], + latest: null, + outputs: null, + }; - try { + const addMetadata = (data) => (metadata = data); const getIds = (data) => { userMessage = data.userMessage; userMessageId = userMessage.messageId; @@ -128,11 +72,11 @@ const ask = async ({ plugin.loading = false; } - if (currentTimestamp - lastSavedTimestamp > 500) { + if (currentTimestamp - lastSavedTimestamp > saveDelay) { lastSavedTimestamp = currentTimestamp; saveMessage({ messageId: responseMessageId, - sender: 'ChatGPT', + sender: getResponseSender(endpointOption), conversationId, parentMessageId: overrideParentMessageId || userMessageId, text: partialText, @@ -142,63 +86,13 @@ const ask = async ({ error: false, }); } + + if (saveDelay < 500) { + saveDelay = 500; + } }, }); - const abortController = new AbortController(); - abortController.abortAsk = async function () { - this.abort(); - - const responseMessage = { - messageId: responseMessageId, - sender: endpointOption?.chatGptLabel || 'ChatGPT', - conversationId, - parentMessageId: overrideParentMessageId || userMessageId, - text: getPartialText(), - plugin: { ...plugin, loading: false }, - model: endpointOption.modelOptions.model, - unfinished: false, - cancelled: true, - error: false, - }; - - saveMessage(responseMessage); - - return { - title: await getConvoTitle(req.user.id, conversationId), - final: true, - conversation: await getConvo(req.user.id, conversationId), - requestMessage: userMessage, - responseMessage: responseMessage, - }; - }; - - const onStart = (userMessage) => { - sendMessage(res, { message: userMessage, created: true }); - abortControllers.set(userMessage.conversationId, { abortController, ...endpointOption }); - }; - - endpointOption.tools = await validateTools(user, endpointOption.tools); - const clientOptions = { - debug: true, - endpoint, - reverseProxyUrl: process.env.OPENAI_REVERSE_PROXY || null, - proxy: process.env.PROXY || null, - ...endpointOption, - }; - - let openAIApiKey = req.body?.token ?? process.env.OPENAI_API_KEY; - if (process.env.PLUGINS_USE_AZURE) { - clientOptions.azure = getAzureCredentials(); - openAIApiKey = clientOptions.azure.azureOpenAIApiKey; - } - - if (openAIApiKey && openAIApiKey.includes('azure') && !clientOptions.azure) { - clientOptions.azure = JSON.parse(req.body?.token) ?? getAzureCredentials(); - openAIApiKey = clientOptions.azure.azureOpenAIApiKey; - } - const chatAgent = new PluginsClient(openAIApiKey, clientOptions); - const onAgentAction = (action, start = false) => { const formattedAction = formatAction(action); plugin.inputs.push(formattedAction); @@ -219,70 +113,86 @@ const ask = async ({ // console.log('CHAIN END', plugin.outputs); }; - let response = await chatAgent.sendMessage(text, { - getIds, - user, - parentMessageId, + const getAbortData = () => ({ + sender: getResponseSender(endpointOption), conversationId, - overrideParentMessageId, - onAgentAction, - onChainEnd, - onStart, - ...endpointOption, - onProgress: progressCallback.call(null, { - res, - text, - plugin, - parentMessageId: overrideParentMessageId || userMessageId, - }), - abortController, + messageId: responseMessageId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + plugin: { ...plugin, loading: false }, + userMessage, }); + const { abortController, onStart } = createAbortController( + res, + req, + endpointOption, + getAbortData, + ); + + try { + endpointOption.tools = await validateTools(user, endpointOption.tools); + const { client, azure, openAIApiKey } = initializeClient(req, endpointOption); + + let response = await client.sendMessage(text, { + user, + conversationId, + parentMessageId, + overrideParentMessageId, + getIds, + onAgentAction, + onChainEnd, + onStart, + addMetadata, + ...endpointOption, + onProgress: progressCallback.call(null, { + res, + text, + plugin, + parentMessageId: overrideParentMessageId || userMessageId, + }), + abortController, + }); - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } + if (overrideParentMessageId) { + response.parentMessageId = overrideParentMessageId; + } - console.log('CLIENT RESPONSE'); - console.dir(response, { depth: null }); - response.plugin = { ...plugin, loading: false }; - await saveMessage(response); + if (metadata) { + response = { ...response, ...metadata }; + } - sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), - final: true, - conversation: await getConvo(req.user.id, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); + console.log('CLIENT RESPONSE'); + console.dir(response, { depth: null }); + response.plugin = { ...plugin, loading: false }; + await saveMessage(response); - if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) { - const title = await titleConvo({ + sendMessage(res, { + title: await getConvoTitle(req.user.id, conversationId), + final: true, + conversation: await getConvo(req.user.id, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); + addTitle(req, { text, + newConvo, response, openAIApiKey, - azure: !!clientOptions.azure, + parentMessageId, + azure: !!azure, }); - await saveConvo(req.user.id, { - conversationId: conversationId, - title, + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, + conversationId, + sender: getResponseSender(endpointOption), + messageId: responseMessageId, + parentMessageId: userMessageId, }); } - } catch (error) { - console.error(error); - const errorMessage = { - messageId: responseMessageId, - sender: 'ChatGPT', - conversationId, - parentMessageId: userMessageId, - unfinished: false, - cancelled: false, - error: true, - text: error.message, - }; - await saveMessage(errorMessage); - handleError(res, errorMessage); - } -}; + }, +); module.exports = router; diff --git a/api/server/routes/ask/index.js b/api/server/routes/ask/index.js index d088d97b174..77da50f68a0 100644 --- a/api/server/routes/ask/index.js +++ b/api/server/routes/ask/index.js @@ -1,7 +1,5 @@ const express = require('express'); const router = express.Router(); -// const askAzureOpenAI = require('./askAzureOpenAI';) -// const askOpenAI = require('./askOpenAI'); const openAI = require('./openAI'); const google = require('./google'); const bingAI = require('./bingAI'); @@ -9,7 +7,6 @@ const gptPlugins = require('./gptPlugins'); const askChatGPTBrowser = require('./askChatGPTBrowser'); const anthropic = require('./anthropic'); -// router.use('/azureOpenAI', askAzureOpenAI); router.use(['/azureOpenAI', '/openAI'], openAI); router.use('/google', google); router.use('/bingAI', bingAI); diff --git a/api/server/routes/ask/openAI.js b/api/server/routes/ask/openAI.js index db8e3a3cbd0..be236956af5 100644 --- a/api/server/routes/ask/openAI.js +++ b/api/server/routes/ask/openAI.js @@ -1,231 +1,160 @@ const express = require('express'); const router = express.Router(); -const { titleConvo, OpenAIClient } = require('../../../app'); -const { getAzureCredentials, abortMessage } = require('../../../utils'); -const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models'); -const { handleError, sendMessage, createOnProgress } = require('./handlers'); -const requireJwtAuth = require('../../../middleware/requireJwtAuth'); - -const abortControllers = new Map(); - -router.post('/abort', requireJwtAuth, async (req, res) => { - try { - return await abortMessage(req, res, abortControllers); - } catch (err) { - console.error(err); - } -}); - -router.post('/', requireJwtAuth, async (req, res) => { - const { endpoint, text, parentMessageId, conversationId } = req.body; - if (text.length === 0) { - return handleError(res, { text: 'Prompt empty or too short' }); - } - const isOpenAI = endpoint === 'openAI' || endpoint === 'azureOpenAI'; - if (!isOpenAI) { - return handleError(res, { text: 'Illegal request' }); - } - - // build endpoint option - const endpointOption = { - chatGptLabel: req.body?.chatGptLabel ?? null, - promptPrefix: req.body?.promptPrefix ?? null, - modelOptions: { - model: req.body?.model ?? 'gpt-3.5-turbo', - temperature: req.body?.temperature ?? 1, - top_p: req.body?.top_p ?? 1, - presence_penalty: req.body?.presence_penalty ?? 0, - frequency_penalty: req.body?.frequency_penalty ?? 0, - }, - }; - - console.log('ask log'); - console.dir({ text, conversationId, endpointOption }, { depth: null }); - - // eslint-disable-next-line no-use-before-define - return await ask({ - text, - endpointOption, - conversationId, - parentMessageId, - endpoint, - req, - res, - }); -}); - -const ask = async ({ - text, - endpointOption, - parentMessageId = null, - endpoint, - conversationId, - req, - res, -}) => { - res.writeHead(200, { - Connection: 'keep-alive', - 'Content-Type': 'text/event-stream', - 'Cache-Control': 'no-cache, no-transform', - 'Access-Control-Allow-Origin': '*', - 'X-Accel-Buffering': 'no', - }); - let userMessage; - let userMessageId; - let responseMessageId; - let lastSavedTimestamp = 0; - const newConvo = !conversationId; - const { overrideParentMessageId = null } = req.body; - const user = req.user.id; - - const getIds = (data) => { - userMessage = data.userMessage; - userMessageId = userMessage.messageId; - responseMessageId = data.responseMessageId; - if (!conversationId) { - conversationId = data.conversationId; - } - }; - - const { onProgress: progressCallback, getPartialText } = createOnProgress({ - onProgress: ({ text: partialText }) => { - const currentTimestamp = Date.now(); - - if (currentTimestamp - lastSavedTimestamp > 500) { - lastSavedTimestamp = currentTimestamp; - saveMessage({ - messageId: responseMessageId, - sender: 'ChatGPT', - conversationId, - parentMessageId: overrideParentMessageId || userMessageId, - text: partialText, - model: endpointOption.modelOptions.model, - unfinished: true, - cancelled: false, - error: false, - }); +const { getResponseSender } = require('../endpoints/schemas'); +const { sendMessage, createOnProgress } = require('../../utils'); +const { addTitle, initializeClient } = require('../endpoints/openAI'); +const { saveMessage, getConvoTitle, getConvo } = require('../../../models'); +const { + handleAbort, + createAbortController, + handleAbortError, + setHeaders, + requireJwtAuth, + validateEndpoint, + buildEndpointOption, +} = require('../../middleware'); + +router.post('/abort', requireJwtAuth, handleAbort()); + +router.post( + '/', + requireJwtAuth, + validateEndpoint, + buildEndpointOption, + setHeaders, + async (req, res) => { + let { + text, + endpointOption, + conversationId, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; + console.log('ask log'); + console.dir({ text, conversationId, endpointOption }, { depth: null }); + let metadata; + let userMessage; + let userMessageId; + let responseMessageId; + let lastSavedTimestamp = 0; + let saveDelay = 100; + const newConvo = !conversationId; + const user = req.user.id; + + const addMetadata = (data) => (metadata = data); + + const getIds = (data) => { + userMessage = data.userMessage; + userMessageId = userMessage.messageId; + responseMessageId = data.responseMessageId; + if (!conversationId) { + conversationId = data.conversationId; } - }, - }); + }; - const abortController = new AbortController(); - abortController.abortAsk = async function () { - this.abort(); + const { onProgress: progressCallback, getPartialText } = createOnProgress({ + onProgress: ({ text: partialText }) => { + const currentTimestamp = Date.now(); + + if (currentTimestamp - lastSavedTimestamp > saveDelay) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + conversationId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: partialText, + model: endpointOption.modelOptions.model, + unfinished: true, + cancelled: false, + error: false, + }); + } + + if (saveDelay < 500) { + saveDelay = 500; + } + }, + }); - const responseMessage = { - messageId: responseMessageId, - sender: endpointOption?.chatGptLabel || 'ChatGPT', + const getAbortData = () => ({ + sender: getResponseSender(endpointOption), conversationId, - parentMessageId: overrideParentMessageId || userMessageId, + messageId: responseMessageId, + parentMessageId: overrideParentMessageId ?? userMessageId, text: getPartialText(), - model: endpointOption.modelOptions.model, - unfinished: false, - cancelled: true, - error: false, - }; - - saveMessage(responseMessage); - - return { - title: await getConvoTitle(req.user.id, conversationId), - final: true, - conversation: await getConvo(req.user.id, conversationId), - requestMessage: userMessage, - responseMessage: responseMessage, - }; - }; - - const onStart = (userMessage) => { - sendMessage(res, { message: userMessage, created: true }); - abortControllers.set(userMessage.conversationId, { abortController, ...endpointOption }); - }; - - try { - const clientOptions = { - // debug: true, - // contextStrategy: 'refine', - reverseProxyUrl: process.env.OPENAI_REVERSE_PROXY || null, - proxy: process.env.PROXY || null, - endpoint, - ...endpointOption, - }; + userMessage, + }); - let openAIApiKey = req.body?.token ?? process.env.OPENAI_API_KEY; + const { abortController, onStart } = createAbortController( + res, + req, + endpointOption, + getAbortData, + ); - if (process.env.AZURE_API_KEY && endpoint === 'azureOpenAI') { - clientOptions.azure = JSON.parse(req.body?.token) ?? getAzureCredentials(); - openAIApiKey = clientOptions.azure.azureOpenAIApiKey; - } + try { + const { client, openAIApiKey } = initializeClient(req, endpointOption); - const client = new OpenAIClient(openAIApiKey, clientOptions); + let response = await client.sendMessage(text, { + user, + parentMessageId, + conversationId, + overrideParentMessageId, + getIds, + onStart, + addMetadata, + abortController, + onProgress: progressCallback.call(null, { + res, + text, + parentMessageId: overrideParentMessageId || userMessageId, + }), + }); - let response = await client.sendMessage(text, { - user, - parentMessageId, - conversationId, - overrideParentMessageId, - getIds, - onStart, - onProgress: progressCallback.call(null, { - res, - text, - parentMessageId: overrideParentMessageId || userMessageId, - }), - abortController, - }); + if (overrideParentMessageId) { + response.parentMessageId = overrideParentMessageId; + } - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } + if (metadata) { + response = { ...response, ...metadata }; + } - console.log( - 'promptTokens, completionTokens:', - response.promptTokens, - response.completionTokens, - ); - await saveMessage(response); - - sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), - final: true, - conversation: await getConvo(req.user.id, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); + console.log( + 'promptTokens, completionTokens:', + response.promptTokens, + response.completionTokens, + ); + await saveMessage(response); + + sendMessage(res, { + title: await getConvoTitle(req.user.id, conversationId), + final: true, + conversation: await getConvo(req.user.id, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); - if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) { - const title = await titleConvo({ + addTitle(req, { text, + newConvo, response, openAIApiKey, - azure: endpoint === 'azureOpenAI', + parentMessageId, + azure: endpointOption.endpoint === 'azureOpenAI', }); - await saveConvo(req.user.id, { + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, conversationId, - title, - }); - } - } catch (error) { - console.error(error); - const partialText = getPartialText(); - if (partialText?.length > 2) { - return await abortMessage(req, res, abortControllers); - } else { - const errorMessage = { + sender: getResponseSender(endpointOption), messageId: responseMessageId, - sender: 'ChatGPT', - conversationId, parentMessageId: userMessageId, - unfinished: false, - cancelled: false, - error: true, - text: error.message, - }; - await saveMessage(errorMessage); - handleError(res, errorMessage); + }); } - } -}; + }, +); module.exports = router; diff --git a/api/server/routes/auth.js b/api/server/routes/auth.js index 1f0c660c670..ff2e7cab0e3 100644 --- a/api/server/routes/auth.js +++ b/api/server/routes/auth.js @@ -7,9 +7,7 @@ const { } = require('../controllers/AuthController'); const { loginController } = require('../controllers/auth/LoginController'); const { logoutController } = require('../controllers/auth/LogoutController'); -const requireJwtAuth = require('../../middleware/requireJwtAuth'); -const requireLocalAuth = require('../../middleware/requireLocalAuth'); -const validateRegistration = require('../../middleware/validateRegistration'); +const { requireJwtAuth, requireLocalAuth, validateRegistration } = require('../middleware'); const router = express.Router(); diff --git a/api/server/routes/convos.js b/api/server/routes/convos.js index 28e29bea70f..66b3ffc0ac2 100644 --- a/api/server/routes/convos.js +++ b/api/server/routes/convos.js @@ -2,7 +2,7 @@ const express = require('express'); const router = express.Router(); const { getConvo, saveConvo } = require('../../models'); const { getConvosByPage, deleteConvos } = require('../../models/Conversation'); -const requireJwtAuth = require('../../middleware/requireJwtAuth'); +const requireJwtAuth = require('../middleware/requireJwtAuth'); router.get('/', requireJwtAuth, async (req, res) => { const pageNumber = req.query.pageNumber || 1; diff --git a/api/server/routes/edit/anthropic.js b/api/server/routes/edit/anthropic.js new file mode 100644 index 00000000000..7e141cbdfad --- /dev/null +++ b/api/server/routes/edit/anthropic.js @@ -0,0 +1,139 @@ +const express = require('express'); +const router = express.Router(); +const { getResponseSender } = require('../endpoints/schemas'); +const { initializeClient } = require('../endpoints/anthropic'); +const { + handleAbort, + createAbortController, + handleAbortError, + setHeaders, + requireJwtAuth, + validateEndpoint, + buildEndpointOption, +} = require('../../middleware'); +const { saveMessage, getConvoTitle, getConvo } = require('../../../models'); +const { sendMessage, createOnProgress } = require('../../utils'); + +router.post('/abort', requireJwtAuth, handleAbort()); + +router.post( + '/', + requireJwtAuth, + validateEndpoint, + buildEndpointOption, + setHeaders, + async (req, res) => { + let { + text, + generation, + endpointOption, + conversationId, + responseMessageId, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; + console.log('edit log'); + console.dir({ text, conversationId, endpointOption }, { depth: null }); + let metadata; + let userMessage; + let lastSavedTimestamp = 0; + let saveDelay = 100; + const userMessageId = parentMessageId; + + const addMetadata = (data) => (metadata = data); + const getIds = (data) => (userMessage = data.userMessage); + + const { onProgress: progressCallback, getPartialText } = createOnProgress({ + generation, + onProgress: ({ text: partialText }) => { + const currentTimestamp = Date.now(); + if (currentTimestamp - lastSavedTimestamp > saveDelay) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + conversationId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: partialText, + unfinished: true, + cancelled: false, + error: false, + }); + } + + if (saveDelay < 500) { + saveDelay = 500; + } + }, + }); + try { + const getAbortData = () => ({ + conversationId, + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + userMessage, + }); + + const { abortController, onStart } = createAbortController( + res, + req, + endpointOption, + getAbortData, + ); + + const { client } = initializeClient(req, endpointOption); + + let response = await client.sendMessage(text, { + user: req.user.id, + isEdited: true, + conversationId, + parentMessageId, + responseMessageId, + overrideParentMessageId, + ...endpointOption, + onProgress: progressCallback.call(null, { + res, + text, + parentMessageId: overrideParentMessageId ?? userMessageId, + }), + getIds, + onStart, + addMetadata, + abortController, + }); + + if (metadata) { + response = { ...response, ...metadata }; + } + + if (overrideParentMessageId) { + response.parentMessageId = overrideParentMessageId; + } + + await saveMessage(response); + sendMessage(res, { + title: await getConvoTitle(req.user.id, conversationId), + final: true, + conversation: await getConvo(req.user.id, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); + + // TODO: add anthropic titling + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, + conversationId, + sender: getResponseSender(endpointOption), + messageId: responseMessageId, + parentMessageId: userMessageId, + }); + } + }, +); + +module.exports = router; diff --git a/api/server/routes/edit/gptPlugins.js b/api/server/routes/edit/gptPlugins.js new file mode 100644 index 00000000000..a0c81d46ca6 --- /dev/null +++ b/api/server/routes/edit/gptPlugins.js @@ -0,0 +1,185 @@ +const express = require('express'); +const router = express.Router(); +const { getResponseSender } = require('../endpoints/schemas'); +const { validateTools } = require('../../../app'); +const { initializeClient } = require('../endpoints/gptPlugins'); +const { saveMessage, getConvoTitle, getConvo } = require('../../../models'); +const { sendMessage, createOnProgress, formatSteps, formatAction } = require('../../utils'); +const { + handleAbort, + createAbortController, + handleAbortError, + setHeaders, + requireJwtAuth, + validateEndpoint, + buildEndpointOption, +} = require('../../middleware'); + +router.post('/abort', requireJwtAuth, handleAbort()); + +router.post( + '/', + requireJwtAuth, + validateEndpoint, + buildEndpointOption, + setHeaders, + async (req, res) => { + let { + text, + generation, + endpointOption, + conversationId, + responseMessageId, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; + console.log('edit log'); + console.dir({ text, conversationId, endpointOption }, { depth: null }); + let metadata; + let userMessage; + let lastSavedTimestamp = 0; + let saveDelay = 100; + const userMessageId = parentMessageId; + const user = req.user.id; + + const plugin = { + loading: true, + inputs: [], + latest: null, + outputs: null, + }; + + const addMetadata = (data) => (metadata = data); + const getIds = (data) => (userMessage = data.userMessage); + + const { + onProgress: progressCallback, + sendIntermediateMessage, + getPartialText, + } = createOnProgress({ + generation, + onProgress: ({ text: partialText }) => { + const currentTimestamp = Date.now(); + + if (plugin.loading === true) { + plugin.loading = false; + } + + if (currentTimestamp - lastSavedTimestamp > saveDelay) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + conversationId, + parentMessageId: overrideParentMessageId || userMessageId, + text: partialText, + model: endpointOption.modelOptions.model, + unfinished: true, + cancelled: false, + error: false, + }); + } + + if (saveDelay < 500) { + saveDelay = 500; + } + }, + }); + + const onAgentAction = (action, start = false) => { + const formattedAction = formatAction(action); + plugin.inputs.push(formattedAction); + plugin.latest = formattedAction.plugin; + if (!start) { + saveMessage(userMessage); + } + sendIntermediateMessage(res, { plugin }); + // console.log('PLUGIN ACTION', formattedAction); + }; + + const onChainEnd = (data) => { + let { intermediateSteps: steps } = data; + plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.'; + plugin.loading = false; + saveMessage(userMessage); + sendIntermediateMessage(res, { plugin }); + // console.log('CHAIN END', plugin.outputs); + }; + + const getAbortData = () => ({ + sender: getResponseSender(endpointOption), + conversationId, + messageId: responseMessageId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + plugin: { ...plugin, loading: false }, + userMessage, + }); + const { abortController, onStart } = createAbortController( + res, + req, + endpointOption, + getAbortData, + ); + + try { + endpointOption.tools = await validateTools(user, endpointOption.tools); + const { client } = initializeClient(req, endpointOption); + + let response = await client.sendMessage(text, { + user, + isEdited: true, + conversationId, + parentMessageId, + responseMessageId, + overrideParentMessageId, + getIds, + onAgentAction, + onChainEnd, + onStart, + addMetadata, + ...endpointOption, + onProgress: progressCallback.call(null, { + res, + text, + plugin, + parentMessageId: overrideParentMessageId || userMessageId, + }), + abortController, + }); + + if (overrideParentMessageId) { + response.parentMessageId = overrideParentMessageId; + } + + if (metadata) { + response = { ...response, ...metadata }; + } + + console.log('CLIENT RESPONSE'); + console.dir(response, { depth: null }); + response.plugin = { ...plugin, loading: false }; + await saveMessage(response); + + sendMessage(res, { + title: await getConvoTitle(req.user.id, conversationId), + final: true, + conversation: await getConvo(req.user.id, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, + conversationId, + sender: getResponseSender(endpointOption), + messageId: responseMessageId, + parentMessageId: userMessageId, + }); + } + }, +); + +module.exports = router; diff --git a/api/server/routes/edit/index.js b/api/server/routes/edit/index.js new file mode 100644 index 00000000000..7eda18b8a1d --- /dev/null +++ b/api/server/routes/edit/index.js @@ -0,0 +1,13 @@ +const express = require('express'); +const router = express.Router(); +const openAI = require('./openAI'); +const gptPlugins = require('./gptPlugins'); +const anthropic = require('./anthropic'); +// const google = require('./google'); + +router.use(['/azureOpenAI', '/openAI'], openAI); +router.use('/gptPlugins', gptPlugins); +router.use('/anthropic', anthropic); +// router.use('/google', google); + +module.exports = router; diff --git a/api/server/routes/edit/openAI.js b/api/server/routes/edit/openAI.js new file mode 100644 index 00000000000..1ad0bcbebf7 --- /dev/null +++ b/api/server/routes/edit/openAI.js @@ -0,0 +1,141 @@ +const express = require('express'); +const router = express.Router(); +const { getResponseSender } = require('../endpoints/schemas'); +const { initializeClient } = require('../endpoints/openAI'); +const { saveMessage, getConvoTitle, getConvo } = require('../../../models'); +const { sendMessage, createOnProgress } = require('../../utils'); +const { + handleAbort, + createAbortController, + handleAbortError, + setHeaders, + requireJwtAuth, + validateEndpoint, + buildEndpointOption, +} = require('../../middleware'); + +router.post('/abort', requireJwtAuth, handleAbort()); + +router.post( + '/', + requireJwtAuth, + validateEndpoint, + buildEndpointOption, + setHeaders, + async (req, res) => { + let { + text, + generation, + endpointOption, + conversationId, + responseMessageId, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; + console.log('edit log'); + console.dir({ text, conversationId, endpointOption }, { depth: null }); + let metadata; + let userMessage; + let lastSavedTimestamp = 0; + let saveDelay = 100; + const userMessageId = parentMessageId; + + const addMetadata = (data) => (metadata = data); + const getIds = (data) => (userMessage = data.userMessage); + + const { onProgress: progressCallback, getPartialText } = createOnProgress({ + generation, + onProgress: ({ text: partialText }) => { + const currentTimestamp = Date.now(); + + if (currentTimestamp - lastSavedTimestamp > saveDelay) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + conversationId, + parentMessageId: overrideParentMessageId || userMessageId, + text: partialText, + model: endpointOption.modelOptions.model, + unfinished: true, + cancelled: false, + error: false, + }); + } + + if (saveDelay < 500) { + saveDelay = 500; + } + }, + }); + + const getAbortData = () => ({ + sender: getResponseSender(endpointOption), + conversationId, + messageId: responseMessageId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + userMessage, + }); + + const { abortController, onStart } = createAbortController( + res, + req, + endpointOption, + getAbortData, + ); + + try { + const { client } = initializeClient(req, endpointOption); + + let response = await client.sendMessage(text, { + user: req.user.id, + isEdited: true, + conversationId, + parentMessageId, + responseMessageId, + overrideParentMessageId, + getIds, + onStart, + addMetadata, + abortController, + onProgress: progressCallback.call(null, { + res, + text, + parentMessageId: overrideParentMessageId || userMessageId, + }), + }); + + if (metadata) { + response = { ...response, ...metadata }; + } + + console.log( + 'promptTokens, completionTokens:', + response.promptTokens, + response.completionTokens, + ); + await saveMessage(response); + + sendMessage(res, { + title: await getConvoTitle(req.user.id, conversationId), + final: true, + conversation: await getConvo(req.user.id, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, + conversationId, + sender: getResponseSender(endpointOption), + messageId: responseMessageId, + parentMessageId: userMessageId, + }); + } + }, +); + +module.exports = router; diff --git a/api/server/routes/endpoints.js b/api/server/routes/endpoints.js index b6016244c8b..bee04998e18 100644 --- a/api/server/routes/endpoints.js +++ b/api/server/routes/endpoints.js @@ -3,37 +3,45 @@ const express = require('express'); const router = express.Router(); const { availableTools } = require('../../app/clients/tools'); const { addOpenAPISpecs } = require('../../app/clients/tools/util/addOpenAPISpecs'); +// const { getAzureCredentials, genAzureChatCompletion } = require('../../utils/'); const openAIApiKey = process.env.OPENAI_API_KEY; const azureOpenAIApiKey = process.env.AZURE_API_KEY; +const useAzurePlugins = !!process.env.PLUGINS_USE_AZURE; const userProvidedOpenAI = openAIApiKey ? openAIApiKey === 'user_provided' : azureOpenAIApiKey === 'user_provided'; const fetchOpenAIModels = async (opts = { azure: false, plugins: false }, _models = []) => { let models = _models.slice() ?? []; + let apiKey = openAIApiKey; + let basePath = 'https://api.openai.com/v1'; if (opts.azure) { - /* TODO: Add Azure models from api/models */ return models; + // const azure = getAzureCredentials(); + // basePath = (genAzureChatCompletion(azure)) + // .split('/deployments')[0] + // .concat(`/models?api-version=${azure.azureOpenAIApiVersion}`); + // apiKey = azureOpenAIApiKey; } - let basePath = 'https://api.openai.com/v1/'; const reverseProxyUrl = process.env.OPENAI_REVERSE_PROXY; if (reverseProxyUrl) { basePath = reverseProxyUrl.match(/.*v1/)[0]; } - if (basePath.includes('v1')) { + if (basePath.includes('v1') || opts.azure) { try { - const res = await axios.get(`${basePath}/models`, { + const res = await axios.get(`${basePath}${opts.azure ? '' : '/models'}`, { headers: { - Authorization: `Bearer ${openAIApiKey}`, + Authorization: `Bearer ${apiKey}`, }, }); models = res.data.data.map((item) => item.id); + // console.log(`Fetched ${models.length} models from ${opts.azure ? 'Azure ' : ''}OpenAI API`); } catch (err) { - console.error(err); + console.log(`Failed to fetch models from ${opts.azure ? 'Azure ' : ''}OpenAI API`); } } @@ -149,7 +157,7 @@ router.get('/', async function (req, res) { const gptPlugins = openAIApiKey || azureOpenAIApiKey ? { - availableModels: await getOpenAIModels({ plugins: true }), + availableModels: await getOpenAIModels({ azure: useAzurePlugins, plugins: true }), plugins, availableAgents: ['classic', 'functions'], userProvide: userProvidedOpenAI, diff --git a/api/server/routes/endpoints/anthropic/buildOptions.js b/api/server/routes/endpoints/anthropic/buildOptions.js new file mode 100644 index 00000000000..2b0143d2b07 --- /dev/null +++ b/api/server/routes/endpoints/anthropic/buildOptions.js @@ -0,0 +1,15 @@ +const buildOptions = (endpoint, parsedBody) => { + const { modelLabel, promptPrefix, ...rest } = parsedBody; + const endpointOption = { + endpoint, + modelLabel, + promptPrefix, + modelOptions: { + ...rest, + }, + }; + + return endpointOption; +}; + +module.exports = buildOptions; diff --git a/api/server/routes/endpoints/anthropic/index.js b/api/server/routes/endpoints/anthropic/index.js new file mode 100644 index 00000000000..84e4bd5973a --- /dev/null +++ b/api/server/routes/endpoints/anthropic/index.js @@ -0,0 +1,8 @@ +const buildOptions = require('./buildOptions'); +const initializeClient = require('./initializeClient'); + +module.exports = { + // addTitle, // todo + buildOptions, + initializeClient, +}; diff --git a/api/server/routes/endpoints/anthropic/initializeClient.js b/api/server/routes/endpoints/anthropic/initializeClient.js new file mode 100644 index 00000000000..eab0487f95c --- /dev/null +++ b/api/server/routes/endpoints/anthropic/initializeClient.js @@ -0,0 +1,12 @@ +const { AnthropicClient } = require('../../../../app'); + +const initializeClient = (req) => { + let anthropicApiKey = req.body?.token ?? process.env.ANTHROPIC_API_KEY; + const client = new AnthropicClient(anthropicApiKey); + return { + client, + anthropicApiKey, + }; +}; + +module.exports = initializeClient; diff --git a/api/server/routes/endpoints/gptPlugins/buildOptions.js b/api/server/routes/endpoints/gptPlugins/buildOptions.js new file mode 100644 index 00000000000..ebf4116ec3a --- /dev/null +++ b/api/server/routes/endpoints/gptPlugins/buildOptions.js @@ -0,0 +1,31 @@ +const buildOptions = (endpoint, parsedBody) => { + const { + chatGptLabel, + promptPrefix, + agentOptions, + tools, + model, + temperature, + top_p, + presence_penalty, + frequency_penalty, + } = parsedBody; + const endpointOption = { + endpoint, + tools: tools.map((tool) => tool.pluginKey) ?? [], + chatGptLabel, + promptPrefix, + agentOptions, + modelOptions: { + model, + temperature, + top_p, + presence_penalty, + frequency_penalty, + }, + }; + + return endpointOption; +}; + +module.exports = buildOptions; diff --git a/api/server/routes/endpoints/gptPlugins/index.js b/api/server/routes/endpoints/gptPlugins/index.js new file mode 100644 index 00000000000..39944683067 --- /dev/null +++ b/api/server/routes/endpoints/gptPlugins/index.js @@ -0,0 +1,7 @@ +const buildOptions = require('./buildOptions'); +const initializeClient = require('./initializeClient'); + +module.exports = { + buildOptions, + initializeClient, +}; diff --git a/api/server/routes/endpoints/gptPlugins/initializeClient.js b/api/server/routes/endpoints/gptPlugins/initializeClient.js new file mode 100644 index 00000000000..6035a6f5a43 --- /dev/null +++ b/api/server/routes/endpoints/gptPlugins/initializeClient.js @@ -0,0 +1,30 @@ +const { PluginsClient } = require('../../../../app'); +const { getAzureCredentials } = require('../../../../utils'); + +const initializeClient = (req, endpointOption) => { + const clientOptions = { + debug: true, + reverseProxyUrl: process.env.OPENAI_REVERSE_PROXY || null, + proxy: process.env.PROXY || null, + ...endpointOption, + }; + + let openAIApiKey = req.body?.token ?? process.env.OPENAI_API_KEY; + if (process.env.PLUGINS_USE_AZURE) { + clientOptions.azure = getAzureCredentials(); + openAIApiKey = clientOptions.azure.azureOpenAIApiKey; + } + + if (openAIApiKey && openAIApiKey.includes('azure') && !clientOptions.azure) { + clientOptions.azure = JSON.parse(req.body?.token) ?? getAzureCredentials(); + openAIApiKey = clientOptions.azure.azureOpenAIApiKey; + } + const client = new PluginsClient(openAIApiKey, clientOptions); + return { + client, + azure: clientOptions.azure, + openAIApiKey, + }; +}; + +module.exports = initializeClient; diff --git a/api/server/routes/endpoints/openAI/addTitle.js b/api/server/routes/endpoints/openAI/addTitle.js new file mode 100644 index 00000000000..c3f6c2ad2af --- /dev/null +++ b/api/server/routes/endpoints/openAI/addTitle.js @@ -0,0 +1,22 @@ +const { titleConvo } = require('../../../../app'); +const { saveConvo } = require('../../../../models'); + +const addTitle = async ( + req, + { text, azure, response, newConvo, parentMessageId, openAIApiKey }, +) => { + if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) { + const title = await titleConvo({ + text, + azure, + response, + openAIApiKey, + }); + await saveConvo(req.user.id, { + conversationId: response.conversationId, + title, + }); + } +}; + +module.exports = addTitle; diff --git a/api/server/routes/endpoints/openAI/buildOptions.js b/api/server/routes/endpoints/openAI/buildOptions.js new file mode 100644 index 00000000000..a1ad232bb73 --- /dev/null +++ b/api/server/routes/endpoints/openAI/buildOptions.js @@ -0,0 +1,15 @@ +const buildOptions = (endpoint, parsedBody) => { + const { chatGptLabel, promptPrefix, ...rest } = parsedBody; + const endpointOption = { + endpoint, + chatGptLabel, + promptPrefix, + modelOptions: { + ...rest, + }, + }; + + return endpointOption; +}; + +module.exports = buildOptions; diff --git a/api/server/routes/endpoints/openAI/index.js b/api/server/routes/endpoints/openAI/index.js new file mode 100644 index 00000000000..772b1efb118 --- /dev/null +++ b/api/server/routes/endpoints/openAI/index.js @@ -0,0 +1,9 @@ +const addTitle = require('./addTitle'); +const buildOptions = require('./buildOptions'); +const initializeClient = require('./initializeClient'); + +module.exports = { + addTitle, + buildOptions, + initializeClient, +}; diff --git a/api/server/routes/endpoints/openAI/initializeClient.js b/api/server/routes/endpoints/openAI/initializeClient.js new file mode 100644 index 00000000000..4e4910c5b04 --- /dev/null +++ b/api/server/routes/endpoints/openAI/initializeClient.js @@ -0,0 +1,27 @@ +const { OpenAIClient } = require('../../../../app'); +const { getAzureCredentials } = require('../../../../utils'); + +const initializeClient = (req, endpointOption) => { + const clientOptions = { + // debug: true, + // contextStrategy: 'refine', + reverseProxyUrl: process.env.OPENAI_REVERSE_PROXY || null, + proxy: process.env.PROXY || null, + ...endpointOption, + }; + + let openAIApiKey = req.body?.token ?? process.env.OPENAI_API_KEY; + + if (process.env.AZURE_API_KEY && endpointOption.endpoint === 'azureOpenAI') { + clientOptions.azure = JSON.parse(req.body?.token) ?? getAzureCredentials(); + openAIApiKey = clientOptions.azure.azureOpenAIApiKey; + } + + const client = new OpenAIClient(openAIApiKey, clientOptions); + return { + client, + openAIApiKey, + }; +}; + +module.exports = initializeClient; diff --git a/api/server/routes/endpoints/schemas.js b/api/server/routes/endpoints/schemas.js new file mode 100644 index 00000000000..e7e9f30e8d9 --- /dev/null +++ b/api/server/routes/endpoints/schemas.js @@ -0,0 +1,369 @@ +const { z } = require('zod'); + +const EModelEndpoint = { + azureOpenAI: 'azureOpenAI', + openAI: 'openAI', + bingAI: 'bingAI', + chatGPTBrowser: 'chatGPTBrowser', + google: 'google', + gptPlugins: 'gptPlugins', + anthropic: 'anthropic', +}; + +const eModelEndpointSchema = z.nativeEnum(EModelEndpoint); + +/* +const tMessageSchema = z.object({ + messageId: z.string(), + clientId: z.string().nullable().optional(), + conversationId: z.string().nullable(), + parentMessageId: z.string().nullable(), + sender: z.string(), + text: z.string(), + isCreatedByUser: z.boolean(), + error: z.boolean(), + createdAt: z + .string() + .optional() + .default(() => new Date().toISOString()), + updatedAt: z + .string() + .optional() + .default(() => new Date().toISOString()), + current: z.boolean().optional(), + unfinished: z.boolean().optional(), + submitting: z.boolean().optional(), + searchResult: z.boolean().optional(), + finish_reason: z.string().optional(), +}); + +const tPresetSchema = tConversationSchema + .omit({ + conversationId: true, + createdAt: true, + updatedAt: true, + title: true, + }) + .merge( + z.object({ + conversationId: z.string().optional(), + presetId: z.string().nullable().optional(), + title: z.string().nullable().optional(), + }), + ); +*/ + +const tPluginAuthConfigSchema = z.object({ + authField: z.string(), + label: z.string(), + description: z.string(), +}); + +const tPluginSchema = z.object({ + name: z.string(), + pluginKey: z.string(), + description: z.string(), + icon: z.string(), + authConfig: z.array(tPluginAuthConfigSchema), + authenticated: z.boolean().optional(), + isButton: z.boolean().optional(), +}); + +const tExampleSchema = z.object({ + input: z.object({ + content: z.string(), + }), + output: z.object({ + content: z.string(), + }), +}); + +const tAgentOptionsSchema = z.object({ + agent: z.string(), + skipCompletion: z.boolean(), + model: z.string(), + temperature: z.number(), +}); + +const tConversationSchema = z.object({ + conversationId: z.string().nullable(), + title: z.string(), + user: z.string().optional(), + endpoint: eModelEndpointSchema.nullable(), + suggestions: z.array(z.string()).optional(), + messages: z.array(z.string()).optional(), + tools: z.array(tPluginSchema).optional(), + createdAt: z.string(), + updatedAt: z.string(), + systemMessage: z.string().nullable().optional(), + modelLabel: z.string().nullable().optional(), + examples: z.array(tExampleSchema).optional(), + chatGptLabel: z.string().nullable().optional(), + userLabel: z.string().optional(), + model: z.string().nullable().optional(), + promptPrefix: z.string().nullable().optional(), + temperature: z.number().optional(), + topP: z.number().optional(), + topK: z.number().optional(), + context: z.string().nullable().optional(), + top_p: z.number().optional(), + frequency_penalty: z.number().optional(), + presence_penalty: z.number().optional(), + jailbreak: z.boolean().optional(), + jailbreakConversationId: z.string().nullable().optional(), + conversationSignature: z.string().nullable().optional(), + parentMessageId: z.string().optional(), + clientId: z.string().nullable().optional(), + invocationId: z.number().nullable().optional(), + toneStyle: z.string().nullable().optional(), + maxOutputTokens: z.number().optional(), + agentOptions: tAgentOptionsSchema.nullable().optional(), +}); + +const openAISchema = tConversationSchema + .pick({ + model: true, + chatGptLabel: true, + promptPrefix: true, + temperature: true, + top_p: true, + presence_penalty: true, + frequency_penalty: true, + }) + .transform((obj) => ({ + ...obj, + model: obj.model ?? 'gpt-3.5-turbo', + chatGptLabel: obj.chatGptLabel ?? null, + promptPrefix: obj.promptPrefix ?? null, + temperature: obj.temperature ?? 1, + top_p: obj.top_p ?? 1, + presence_penalty: obj.presence_penalty ?? 0, + frequency_penalty: obj.frequency_penalty ?? 0, + })) + .catch(() => ({ + model: 'gpt-3.5-turbo', + chatGptLabel: null, + promptPrefix: null, + temperature: 1, + top_p: 1, + presence_penalty: 0, + frequency_penalty: 0, + })); + +const googleSchema = tConversationSchema + .pick({ + model: true, + modelLabel: true, + promptPrefix: true, + examples: true, + temperature: true, + maxOutputTokens: true, + topP: true, + topK: true, + }) + .transform((obj) => ({ + ...obj, + model: obj.model ?? 'chat-bison', + modelLabel: obj.modelLabel ?? null, + promptPrefix: obj.promptPrefix ?? null, + temperature: obj.temperature ?? 0.2, + maxOutputTokens: obj.maxOutputTokens ?? 1024, + topP: obj.topP ?? 0.95, + topK: obj.topK ?? 40, + })) + .catch(() => ({ + model: 'chat-bison', + modelLabel: null, + promptPrefix: null, + temperature: 0.2, + maxOutputTokens: 1024, + topP: 0.95, + topK: 40, + })); + +const bingAISchema = tConversationSchema + .pick({ + jailbreak: true, + systemMessage: true, + context: true, + toneStyle: true, + jailbreakConversationId: true, + conversationSignature: true, + clientId: true, + invocationId: true, + }) + .transform((obj) => ({ + ...obj, + model: '', + jailbreak: obj.jailbreak ?? false, + systemMessage: obj.systemMessage ?? null, + context: obj.context ?? null, + toneStyle: obj.toneStyle ?? 'creative', + jailbreakConversationId: obj.jailbreakConversationId ?? null, + conversationSignature: obj.conversationSignature ?? null, + clientId: obj.clientId ?? null, + invocationId: obj.invocationId ?? 1, + })) + .catch(() => ({ + model: '', + jailbreak: false, + systemMessage: null, + context: null, + toneStyle: 'creative', + jailbreakConversationId: null, + conversationSignature: null, + clientId: null, + invocationId: 1, + })); + +const anthropicSchema = tConversationSchema + .pick({ + model: true, + modelLabel: true, + promptPrefix: true, + temperature: true, + maxOutputTokens: true, + topP: true, + topK: true, + }) + .transform((obj) => ({ + ...obj, + model: obj.model ?? 'claude-1', + modelLabel: obj.modelLabel ?? null, + promptPrefix: obj.promptPrefix ?? null, + temperature: obj.temperature ?? 1, + maxOutputTokens: obj.maxOutputTokens ?? 1024, + topP: obj.topP ?? 0.7, + topK: obj.topK ?? 5, + })) + .catch(() => ({ + model: 'claude-1', + modelLabel: null, + promptPrefix: null, + temperature: 1, + maxOutputTokens: 1024, + topP: 0.7, + topK: 5, + })); + +const chatGPTBrowserSchema = tConversationSchema + .pick({ + model: true, + }) + .transform((obj) => ({ + ...obj, + model: obj.model ?? 'text-davinci-002-render-sha', + })) + .catch(() => ({ + model: 'text-davinci-002-render-sha', + })); + +const gptPluginsSchema = tConversationSchema + .pick({ + model: true, + chatGptLabel: true, + promptPrefix: true, + temperature: true, + top_p: true, + presence_penalty: true, + frequency_penalty: true, + tools: true, + agentOptions: true, + }) + .transform((obj) => ({ + ...obj, + model: obj.model ?? 'gpt-3.5-turbo', + chatGptLabel: obj.chatGptLabel ?? null, + promptPrefix: obj.promptPrefix ?? null, + temperature: obj.temperature ?? 0.8, + top_p: obj.top_p ?? 1, + presence_penalty: obj.presence_penalty ?? 0, + frequency_penalty: obj.frequency_penalty ?? 0, + tools: obj.tools ?? [], + agentOptions: obj.agentOptions ?? { + agent: 'functions', + skipCompletion: true, + model: 'gpt-3.5-turbo', + temperature: 0, + }, + })) + .catch(() => ({ + model: 'gpt-3.5-turbo', + chatGptLabel: null, + promptPrefix: null, + temperature: 0.8, + top_p: 1, + presence_penalty: 0, + frequency_penalty: 0, + tools: [], + agentOptions: { + agent: 'functions', + skipCompletion: true, + model: 'gpt-3.5-turbo', + temperature: 0, + }, + })); + +const endpointSchemas = { + openAI: openAISchema, + azureOpenAI: openAISchema, + google: googleSchema, + bingAI: bingAISchema, + anthropic: anthropicSchema, + chatGPTBrowser: chatGPTBrowserSchema, + gptPlugins: gptPluginsSchema, +}; + +function getFirstDefinedValue(possibleValues) { + let returnValue; + for (const value of possibleValues) { + if (value) { + returnValue = value; + break; + } + } + return returnValue; +} + +const parseConvo = (endpoint, conversation, possibleValues) => { + const schema = endpointSchemas[endpoint]; + + if (!schema) { + throw new Error(`Unknown endpoint: ${endpoint}`); + } + + const convo = schema.parse(conversation); + + if (possibleValues && convo) { + convo.model = getFirstDefinedValue(possibleValues.model) ?? convo.model; + } + + return convo; +}; + +const getResponseSender = (endpointOption) => { + const { endpoint, chatGptLabel, modelLabel, jailbreak } = endpointOption; + + if (['openAI', 'azureOpenAI', 'gptPlugins', 'chatGPTBrowser'].includes(endpoint)) { + return chatGptLabel ?? 'ChatGPT'; + } + + if (endpoint === 'bingAI') { + return jailbreak ? 'Sydney' : 'BingAI'; + } + + if (endpoint === 'anthropic') { + return modelLabel ?? 'Anthropic'; + } + + if (endpoint === 'google') { + return modelLabel ?? 'PaLM2'; + } + + return ''; +}; + +module.exports = { + parseConvo, + getResponseSender, +}; diff --git a/api/server/routes/index.js b/api/server/routes/index.js index 18d2a44fc49..c78af3fcff5 100644 --- a/api/server/routes/index.js +++ b/api/server/routes/index.js @@ -1,4 +1,5 @@ const ask = require('./ask'); +const edit = require('./edit'); const messages = require('./messages'); const convos = require('./convos'); const presets = require('./presets'); @@ -15,6 +16,7 @@ const config = require('./config'); module.exports = { search, ask, + edit, messages, convos, presets, diff --git a/api/server/routes/messages.js b/api/server/routes/messages.js index a13b4272bca..0530ebc2639 100644 --- a/api/server/routes/messages.js +++ b/api/server/routes/messages.js @@ -1,7 +1,7 @@ const express = require('express'); const router = express.Router(); const { getMessages } = require('../../models/Message'); -const requireJwtAuth = require('../../middleware/requireJwtAuth'); +const requireJwtAuth = require('../middleware/requireJwtAuth'); router.get('/:conversationId', requireJwtAuth, async (req, res) => { const { conversationId } = req.params; diff --git a/api/server/routes/plugins.js b/api/server/routes/plugins.js index cb931632423..4a7715a6186 100644 --- a/api/server/routes/plugins.js +++ b/api/server/routes/plugins.js @@ -1,6 +1,6 @@ const express = require('express'); const { getAvailablePluginsController } = require('../controllers/PluginController'); -const requireJwtAuth = require('../../middleware/requireJwtAuth'); +const requireJwtAuth = require('../middleware/requireJwtAuth'); const router = express.Router(); diff --git a/api/server/routes/presets.js b/api/server/routes/presets.js index 8a08f0b509e..127a8e5b6be 100644 --- a/api/server/routes/presets.js +++ b/api/server/routes/presets.js @@ -2,7 +2,7 @@ const express = require('express'); const router = express.Router(); const { getPresets, savePreset, deletePresets } = require('../../models'); const crypto = require('crypto'); -const requireJwtAuth = require('../../middleware/requireJwtAuth'); +const requireJwtAuth = require('../middleware/requireJwtAuth'); router.get('/', requireJwtAuth, async (req, res) => { const presets = (await getPresets(req.user.id)).map((preset) => { diff --git a/api/server/routes/search.js b/api/server/routes/search.js index 9f495bb386a..955e58da97a 100644 --- a/api/server/routes/search.js +++ b/api/server/routes/search.js @@ -5,7 +5,7 @@ const { Message } = require('../../models/Message'); const { Conversation, getConvosQueried } = require('../../models/Conversation'); const { reduceHits } = require('../../lib/utils/reduceHits'); const { cleanUpPrimaryKeyValue } = require('../../lib/utils/misc'); -const requireJwtAuth = require('../../middleware/requireJwtAuth'); +const requireJwtAuth = require('../middleware/requireJwtAuth'); const cache = new Map(); diff --git a/api/server/routes/tokenizer.js b/api/server/routes/tokenizer.js index 995263b0d01..9fd79d0acfa 100644 --- a/api/server/routes/tokenizer.js +++ b/api/server/routes/tokenizer.js @@ -4,7 +4,7 @@ const { Tiktoken } = require('@dqbd/tiktoken/lite'); const { load } = require('@dqbd/tiktoken/load'); const registry = require('@dqbd/tiktoken/registry.json'); const models = require('@dqbd/tiktoken/model_to_encoding.json'); -const requireJwtAuth = require('../../middleware/requireJwtAuth'); +const requireJwtAuth = require('../middleware/requireJwtAuth'); router.post('/', requireJwtAuth, async (req, res) => { try { diff --git a/api/server/routes/user.js b/api/server/routes/user.js index 293ce4cf630..b90e3d965b6 100644 --- a/api/server/routes/user.js +++ b/api/server/routes/user.js @@ -1,5 +1,5 @@ const express = require('express'); -const requireJwtAuth = require('../../middleware/requireJwtAuth'); +const requireJwtAuth = require('../middleware/requireJwtAuth'); const { getUserController, updateUserPluginsController } = require('../controllers/UserController'); const router = express.Router(); diff --git a/api/server/services/AuthService.js b/api/server/services/AuthService.js index 309a15a2ba2..52380bc8a19 100644 --- a/api/server/services/AuthService.js +++ b/api/server/services/AuthService.js @@ -1,10 +1,10 @@ -const User = require('../../models/User'); -const Token = require('../../models/schema/tokenSchema'); const crypto = require('crypto'); const bcrypt = require('bcryptjs'); +const User = require('../../models/User'); +const Token = require('../../models/schema/tokenSchema'); const { registerSchema } = require('../../strategies/validators'); -const { sendEmail } = require('../../utils'); const config = require('../../../config/loader'); +const { sendEmail } = require('../utils'); const domains = config.domains; /** diff --git a/api/server/services/PluginService.js b/api/server/services/PluginService.js index 970f16f6d92..b70005dffa3 100644 --- a/api/server/services/PluginService.js +++ b/api/server/services/PluginService.js @@ -1,5 +1,5 @@ const PluginAuth = require('../../models/schema/pluginAuthSchema'); -const { encrypt, decrypt } = require('../../utils/'); +const { encrypt, decrypt } = require('../utils/'); const getUserPluginAuthValue = async (user, authField) => { try { diff --git a/api/lib/parse/citeText.js b/api/server/utils/citations.js similarity index 67% rename from api/lib/parse/citeText.js rename to api/server/utils/citations.js index 8fc1cea8b4f..33136c18b8d 100644 --- a/api/lib/parse/citeText.js +++ b/api/server/utils/citations.js @@ -1,4 +1,19 @@ const citationRegex = /\[\^\d+?\^\]/g; +const regex = / \[.*?]\(.*?\)/g; + +const getCitations = (res) => { + const adaptiveCards = res.details.adaptiveCards; + const textBlocks = adaptiveCards && adaptiveCards[0].body; + if (!textBlocks) { + return ''; + } + let links = textBlocks[textBlocks.length - 1]?.text.match(regex); + if (links?.length === 0 || !links) { + return ''; + } + links = links.map((link) => link.trim()); + return links.join('\n - '); +}; const citeText = (res, noLinks = false) => { let result = res.text || res; @@ -32,4 +47,4 @@ const citeText = (res, noLinks = false) => { return result; }; -module.exports = citeText; +module.exports = { getCitations, citeText }; diff --git a/api/utils/crypto.js b/api/server/utils/crypto.js similarity index 100% rename from api/utils/crypto.js rename to api/server/utils/crypto.js diff --git a/api/utils/emails/passwordReset.handlebars b/api/server/utils/emails/passwordReset.handlebars similarity index 100% rename from api/utils/emails/passwordReset.handlebars rename to api/server/utils/emails/passwordReset.handlebars diff --git a/api/utils/emails/requestPasswordReset.handlebars b/api/server/utils/emails/requestPasswordReset.handlebars similarity index 100% rename from api/utils/emails/requestPasswordReset.handlebars rename to api/server/utils/emails/requestPasswordReset.handlebars diff --git a/api/server/routes/ask/handlers.js b/api/server/utils/handleText.js similarity index 93% rename from api/server/routes/ask/handlers.js rename to api/server/utils/handleText.js index d917c65ca4a..b5efa08d87b 100644 --- a/api/server/routes/ask/handlers.js +++ b/api/server/utils/handleText.js @@ -1,8 +1,10 @@ const _ = require('lodash'); const citationRegex = /\[\^\d+?\^]/g; -const { getCitations, citeText } = require('../../../app'); +const { getCitations, citeText } = require('./citations'); const cursor = ''; +const addSpaceIfNeeded = (text) => (text.length > 0 && !text.endsWith(' ') ? text + ' ' : text); + const handleError = (res, message) => { res.write(`event: error\ndata: ${JSON.stringify(message)}\n\n`); res.end(); @@ -15,12 +17,12 @@ const sendMessage = (res, message, event = 'message') => { res.write(`event: ${event}\ndata: ${JSON.stringify(message)}\n\n`); }; -const createOnProgress = ({ onProgress: _onProgress }) => { +const createOnProgress = ({ generation = '', onProgress: _onProgress }) => { let i = 0; let code = ''; - let tokens = ''; let precode = ''; let codeBlock = false; + let tokens = addSpaceIfNeeded(generation); const progressCallback = async (partial, { res, text, plugin, bing = false, ...rest }) => { let chunk = partial === text ? '' : partial; @@ -155,4 +157,5 @@ module.exports = { handleText, formatSteps, formatAction, + addSpaceIfNeeded, }; diff --git a/api/server/utils/index.js b/api/server/utils/index.js new file mode 100644 index 00000000000..e76d5b43650 --- /dev/null +++ b/api/server/utils/index.js @@ -0,0 +1,11 @@ +const cryptoUtils = require('./crypto'); +const handleText = require('./handleText'); +const citations = require('./citations'); +const sendEmail = require('./sendEmail'); + +module.exports = { + ...cryptoUtils, + ...handleText, + ...citations, + sendEmail, +}; diff --git a/api/utils/sendEmail.js b/api/server/utils/sendEmail.js similarity index 100% rename from api/utils/sendEmail.js rename to api/server/utils/sendEmail.js diff --git a/api/utils/abortMessage.js b/api/utils/abortMessage.js deleted file mode 100644 index fea33eb4c79..00000000000 --- a/api/utils/abortMessage.js +++ /dev/null @@ -1,18 +0,0 @@ -async function abortMessage(req, res, abortControllers) { - const { abortKey } = req.body; - console.log('req.body', req.body); - if (!abortControllers.has(abortKey)) { - return res.status(404).send('Request not found'); - } - - const { abortController } = abortControllers.get(abortKey); - - abortControllers.delete(abortKey); - const ret = await abortController.abortAsk(); - console.log('Aborted request', abortKey); - console.log('Aborted message:', ret); - - res.send(JSON.stringify(ret)); -} - -module.exports = abortMessage; diff --git a/api/utils/index.js b/api/utils/index.js index 0a4dd75bf5f..7de983c0f03 100644 --- a/api/utils/index.js +++ b/api/utils/index.js @@ -1,16 +1,10 @@ const azureUtils = require('./azureUtils'); -const cryptoUtils = require('./crypto'); const { tiktokenModels, maxTokensMap } = require('./tokens'); -const sendEmail = require('./sendEmail'); -const abortMessage = require('./abortMessage'); const findMessageContent = require('./findMessageContent'); module.exports = { - ...cryptoUtils, ...azureUtils, maxTokensMap, tiktokenModels, - sendEmail, - abortMessage, findMessageContent, }; diff --git a/client/src/common/types.ts b/client/src/common/types.ts index dfcf5d5446a..9507782358d 100644 --- a/client/src/common/types.ts +++ b/client/src/common/types.ts @@ -48,3 +48,17 @@ export type TSetOptionsPayload = { checkPluginSelection: (value: string) => boolean; setTools: (newValue: string) => void; }; + +export type TPresetItemProps = { + preset: TPreset; + value: TPreset; + onSelect: (preset: TPreset) => void; + onChangePreset: (preset: TPreset) => void; + onDeletePreset: (preset: TPreset) => void; +}; + +export type TOnClick = (e: React.MouseEvent) => void; + +export type TGenButtonProps = { + onClick: TOnClick; +}; diff --git a/client/src/components/Input/EndpointMenu/EndpointMenu.jsx b/client/src/components/Input/EndpointMenu/EndpointMenu.jsx index d7bbf439e9b..cfc2c25bc4a 100644 --- a/client/src/components/Input/EndpointMenu/EndpointMenu.jsx +++ b/client/src/components/Input/EndpointMenu/EndpointMenu.jsx @@ -103,10 +103,18 @@ export default function NewConversationMenu() { }; // set the current model + const isModular = modularEndpoints.has(endpoint); const onSelectPreset = (newPreset) => { setMenuOpen(false); + if (!newPreset) { + return; + } - if (modularEndpoints.has(endpoint) && modularEndpoints.has(newPreset?.endpoint)) { + if ( + isModular && + modularEndpoints.has(newPreset?.endpoint) && + endpoint === newPreset?.endpoint + ) { const currentConvo = getDefaultConversation({ conversation, endpointsConfig, @@ -118,10 +126,6 @@ export default function NewConversationMenu() { return; } - if (!newPreset) { - return; - } - newConversation({}, newPreset); }; diff --git a/client/src/components/Input/EndpointMenu/PresetItem.jsx b/client/src/components/Input/EndpointMenu/PresetItem.tsx similarity index 87% rename from client/src/components/Input/EndpointMenu/PresetItem.jsx rename to client/src/components/Input/EndpointMenu/PresetItem.tsx index ca5d6b86a71..47ac0fb9f97 100644 --- a/client/src/components/Input/EndpointMenu/PresetItem.jsx +++ b/client/src/components/Input/EndpointMenu/PresetItem.tsx @@ -1,7 +1,14 @@ +import type { TPresetItemProps } from '~/common'; +import type { TPreset } from 'librechat-data-provider'; import { DropdownMenuRadioItem, EditIcon, TrashIcon } from '~/components'; import { getIcon } from '~/components/Endpoints'; -export default function PresetItem({ preset = {}, value, onChangePreset, onDeletePreset }) { +export default function PresetItem({ + preset = {} as TPreset, + value, + onChangePreset, + onDeletePreset, +}: TPresetItemProps) { const { endpoint } = preset; const icon = getIcon({ @@ -14,9 +21,9 @@ export default function PresetItem({ preset = {}, value, onChangePreset, onDelet const getPresetTitle = () => { let _title = `${endpoint}`; + const { chatGptLabel, modelLabel, model, jailbreak, toneStyle } = preset; if (endpoint === 'azureOpenAI' || endpoint === 'openAI') { - const { chatGptLabel, model } = preset; if (model) { _title += `: ${model}`; } @@ -24,7 +31,6 @@ export default function PresetItem({ preset = {}, value, onChangePreset, onDelet _title += ` as ${chatGptLabel}`; } } else if (endpoint === 'google') { - const { modelLabel, model } = preset; if (model) { _title += `: ${model}`; } @@ -32,7 +38,6 @@ export default function PresetItem({ preset = {}, value, onChangePreset, onDelet _title += ` as ${modelLabel}`; } } else if (endpoint === 'bingAI') { - const { jailbreak, toneStyle } = preset; if (toneStyle) { _title += `: ${toneStyle}`; } @@ -40,12 +45,10 @@ export default function PresetItem({ preset = {}, value, onChangePreset, onDelet _title += ' as Sydney'; } } else if (endpoint === 'chatGPTBrowser') { - const { model } = preset; if (model) { _title += `: ${model}`; } } else if (endpoint === 'gptPlugins') { - const { model } = preset; if (model) { _title += `: ${model}`; } @@ -60,6 +63,7 @@ export default function PresetItem({ preset = {}, value, onChangePreset, onDelet // regular model return ( diff --git a/client/src/components/Input/EndpointMenu/PresetItems.jsx b/client/src/components/Input/EndpointMenu/PresetItems.tsx similarity index 82% rename from client/src/components/Input/EndpointMenu/PresetItems.jsx rename to client/src/components/Input/EndpointMenu/PresetItems.tsx index a7048bc6855..5e6e47b5097 100644 --- a/client/src/components/Input/EndpointMenu/PresetItems.jsx +++ b/client/src/components/Input/EndpointMenu/PresetItems.tsx @@ -1,10 +1,11 @@ import React from 'react'; import PresetItem from './PresetItem'; +import type { TPreset } from 'librechat-data-provider'; export default function PresetItems({ presets, onSelect, onChangePreset, onDeletePreset }) { return ( <> - {presets.map((preset) => ( + {presets.map((preset: TPreset) => ( ) => void; + className?: string; +}) { + return ( + + ); +} diff --git a/client/src/components/Input/Generations/Continue.tsx b/client/src/components/Input/Generations/Continue.tsx new file mode 100644 index 00000000000..7d472b30f62 --- /dev/null +++ b/client/src/components/Input/Generations/Continue.tsx @@ -0,0 +1,12 @@ +import type { TGenButtonProps } from '~/common'; +import { ContinueIcon } from '~/components/svg'; +import Button from './Button'; + +export default function Continue({ onClick }: TGenButtonProps) { + return ( + + ); +} diff --git a/client/src/components/Input/Generations/GenerationButtons.tsx b/client/src/components/Input/Generations/GenerationButtons.tsx new file mode 100644 index 00000000000..7c51f59e095 --- /dev/null +++ b/client/src/components/Input/Generations/GenerationButtons.tsx @@ -0,0 +1,61 @@ +import type { TMessage } from 'librechat-data-provider'; +import { useMessageHandler, useMediaQuery, useGenerations } from '~/hooks'; +import { cn } from '~/utils'; +import Regenerate from './Regenerate'; +import Continue from './Continue'; +import Stop from './Stop'; + +type GenerationButtonsProps = { + endpoint: string; + showPopover: boolean; + opacityClass: string; +}; + +export default function GenerationButtons({ + endpoint, + showPopover, + opacityClass, +}: GenerationButtonsProps) { + const { + messages, + isSubmitting, + latestMessage, + handleContinue, + handleRegenerate, + handleStopGenerating, + } = useMessageHandler(); + const isSmallScreen = useMediaQuery('(max-width: 768px)'); + const { continueSupported, regenerateEnabled } = useGenerations({ + endpoint, + message: latestMessage as TMessage, + isSubmitting, + }); + + if (isSmallScreen) { + return null; + } + + let button: React.ReactNode = null; + + if (isSubmitting) { + button = ; + } else if (continueSupported) { + button = ; + } else if (messages && messages.length > 0 && regenerateEnabled) { + button = ; + } + + return ( +
+
+
+
+ {button} +
+
+
+ ); +} diff --git a/client/src/components/Input/Generations/Regenerate.tsx b/client/src/components/Input/Generations/Regenerate.tsx new file mode 100644 index 00000000000..8187ca32608 --- /dev/null +++ b/client/src/components/Input/Generations/Regenerate.tsx @@ -0,0 +1,12 @@ +import type { TGenButtonProps } from '~/common'; +import { RegenerateIcon } from '~/components/svg'; +import Button from './Button'; + +export default function Regenerate({ onClick }: TGenButtonProps) { + return ( + + ); +} diff --git a/client/src/components/Input/Generations/Stop.tsx b/client/src/components/Input/Generations/Stop.tsx new file mode 100644 index 00000000000..41579d6c5f9 --- /dev/null +++ b/client/src/components/Input/Generations/Stop.tsx @@ -0,0 +1,12 @@ +import type { TGenButtonProps } from '~/common'; +import { StopGeneratingIcon } from '~/components/svg'; +import Button from './Button'; + +export default function Stop({ onClick }: TGenButtonProps) { + return ( + + ); +} diff --git a/client/src/components/Input/Generations/index.ts b/client/src/components/Input/Generations/index.ts new file mode 100644 index 00000000000..bbf5aeb41b0 --- /dev/null +++ b/client/src/components/Input/Generations/index.ts @@ -0,0 +1 @@ +export { default as GenerationButtons } from './GenerationButtons'; diff --git a/client/src/components/Input/OptionsBar.tsx b/client/src/components/Input/OptionsBar.tsx index 12dfc63360f..988271c46b0 100644 --- a/client/src/components/Input/OptionsBar.tsx +++ b/client/src/components/Input/OptionsBar.tsx @@ -12,7 +12,7 @@ import { Button } from '~/components/ui'; import { cn, cardStyle } from '~/utils/'; import { useSetOptions } from '~/hooks'; import { ModelSelect } from './ModelSelect'; -import GenerationButtons from './GenerationButtons'; +import { GenerationButtons } from './Generations'; import store from '~/store'; export default function OptionsBar() { @@ -76,7 +76,11 @@ export default function OptionsBar() { : () => setShowPopover((prev) => !prev); return (
- +
- {children} - {text} - {/* -Regenerate response */} - - ); -} diff --git a/client/src/components/Input/TextChat.jsx b/client/src/components/Input/TextChat.jsx index e35caca9cbe..ec2a4f22650 100644 --- a/client/src/components/Input/TextChat.jsx +++ b/client/src/components/Input/TextChat.jsx @@ -10,22 +10,18 @@ import { cn } from '~/utils'; import store from '~/store'; export default function TextChat({ isSearchView = false }) { - const inputRef = useRef(null); - const isComposing = useRef(false); - - const [text, setText] = useRecoilState(store.text); - const { theme } = useContext(ThemeContext); + const { ask, isSubmitting, handleStopGenerating, latestMessage, endpointsConfig } = + useMessageHandler(); const conversation = useRecoilValue(store.conversation); - const latestMessage = useRecoilValue(store.latestMessage); - - const endpointsConfig = useRecoilValue(store.endpointsConfig); - const isSubmitting = useRecoilValue(store.isSubmitting); const setShowBingToneSetting = useSetRecoilState(store.showBingToneSetting); + const [text, setText] = useRecoilState(store.text); + const { theme } = useContext(ThemeContext); + const isComposing = useRef(false); + const inputRef = useRef(null); // TODO: do we need this? const disabled = false; - const { ask, stopGenerating } = useMessageHandler(); const isNotAppendable = latestMessage?.unfinished & !isSubmitting || latestMessage?.error; const { conversationId, jailbreak } = conversation || {}; @@ -60,11 +56,6 @@ export default function TextChat({ isSearchView = false }) { setText(''); }; - const handleStopGenerating = (e) => { - e.preventDefault(); - stopGenerating(); - }; - const handleKeyDown = (e) => { if (e.key === 'Enter' && isSubmitting) { return; diff --git a/client/src/components/Messages/HoverButtons.jsx b/client/src/components/Messages/HoverButtons.jsx deleted file mode 100644 index 4d481c7a3e3..00000000000 --- a/client/src/components/Messages/HoverButtons.jsx +++ /dev/null @@ -1,87 +0,0 @@ -import React from 'react'; -import { cn } from '~/utils/'; -import Clipboard from '../svg/Clipboard'; -import CheckMark from '../svg/CheckMark'; -import EditIcon from '../svg/EditIcon'; -import RegenerateIcon from '../svg/RegenerateIcon'; - -export default function HoverButtons({ - isEditting, - enterEdit, - copyToClipboard, - conversation, - isSubmitting, - message, - regenerate, -}) { - const { endpoint } = conversation; - const [isCopied, setIsCopied] = React.useState(false); - - const branchingSupported = - // azureOpenAI, openAI, chatGPTBrowser support branching, so edit enabled // 5/21/23: Bing is allowing editing and Message regenerating - !![ - 'azureOpenAI', - 'openAI', - 'chatGPTBrowser', - 'google', - 'bingAI', - 'gptPlugins', - 'anthropic', - ].find((e) => e === endpoint); - // Sydney in bingAI supports branching, so edit enabled - - const editEnabled = - !message?.error && - message?.isCreatedByUser && - !message?.searchResult && - !isEditting && - branchingSupported; - - // for now, once branching is supported, regerate will be enabled - let regenerateEnabled = - // !message?.error && - !message?.isCreatedByUser && - !message?.searchResult && - !isEditting && - !isSubmitting && - branchingSupported; - - return ( -
- {editEnabled ? ( - - ) : null} - {regenerateEnabled ? ( - - ) : null} - - -
- ); -} diff --git a/client/src/components/Messages/HoverButtons.tsx b/client/src/components/Messages/HoverButtons.tsx new file mode 100644 index 00000000000..5eec313ddba --- /dev/null +++ b/client/src/components/Messages/HoverButtons.tsx @@ -0,0 +1,86 @@ +import { useState } from 'react'; +import type { TConversation, TMessage } from 'librechat-data-provider'; +import { Clipboard, CheckMark, EditIcon, RegenerateIcon, ContinueIcon } from '~/components/svg'; +import { useGenerations } from '~/hooks'; +import { cn } from '~/utils'; + +type THoverButtons = { + isEditing: boolean; + enterEdit: () => void; + copyToClipboard: (setIsCopied: (isCopied: boolean) => void) => void; + conversation: TConversation; + isSubmitting: boolean; + message: TMessage; + regenerate: () => void; + handleContinue: (e: React.MouseEvent) => void; +}; + +export default function HoverButtons({ + isEditing, + enterEdit, + copyToClipboard, + conversation, + isSubmitting, + message, + regenerate, + handleContinue, +}: THoverButtons) { + const { endpoint } = conversation; + const [isCopied, setIsCopied] = useState(false); + const { editEnabled, regenerateEnabled, continueSupported } = useGenerations({ + isEditing, + isSubmitting, + message, + endpoint: endpoint ?? '', + }); + + return ( +
+ + + {regenerateEnabled ? ( + + ) : null} + {continueSupported ? ( + + ) : null} +
+ ); +} diff --git a/client/src/components/Messages/Message.jsx b/client/src/components/Messages/Message.jsx index 87ce0c40342..a27e6bbc43b 100644 --- a/client/src/components/Messages/Message.jsx +++ b/client/src/components/Messages/Message.jsx @@ -1,6 +1,6 @@ /* eslint-disable react-hooks/exhaustive-deps */ import { useState, useEffect, useRef } from 'react'; -import { useRecoilValue, useSetRecoilState } from 'recoil'; +import { useSetRecoilState } from 'recoil'; import copy from 'copy-to-clipboard'; import Plugin from './Plugin'; import SubRow from './Content/SubRow'; @@ -25,13 +25,12 @@ export default function Message({ setSiblingIdx, }) { const { text, searchResult, isCreatedByUser, error, submitting, unfinished } = message; - const isSubmitting = useRecoilValue(store.isSubmitting); const setLatestMessage = useSetRecoilState(store.latestMessage); const [abortScroll, setAbort] = useState(false); const textEditor = useRef(null); const last = !message?.children?.length; const edit = message.messageId == currentEditId; - const { ask, regenerate } = useMessageHandler(); + const { isSubmitting, ask, regenerate, handleContinue } = useMessageHandler(); const { switchToConversation } = store.useConversation(); const blinker = submitting && isSubmitting; const getConversationQuery = useGetConversationByIdQuery(message.conversationId, { @@ -223,12 +222,13 @@ export default function Message({ )}
enterEdit()} regenerate={() => regenerateMessage()} + handleContinue={handleContinue} copyToClipboard={copyToClipboard} /> diff --git a/client/src/components/Messages/ScrollToBottom.jsx b/client/src/components/Messages/ScrollToBottom.tsx similarity index 79% rename from client/src/components/Messages/ScrollToBottom.jsx rename to client/src/components/Messages/ScrollToBottom.tsx index 281f0d420f9..3c76555ff2e 100644 --- a/client/src/components/Messages/ScrollToBottom.jsx +++ b/client/src/components/Messages/ScrollToBottom.tsx @@ -1,10 +1,12 @@ -import React from 'react'; +type Props = { + scrollHandler: React.MouseEventHandler; +}; -export default function ScrollToBottom({ scrollHandler }) { +export default function ScrollToBottom({ scrollHandler }: Props) { return (