From 13b174b2022fb0ac89ade046223943a90f94482f Mon Sep 17 00:00:00 2001 From: Mamadou DICKO <63923024+mamadoudicko@users.noreply.github.com> Date: Thu, 30 Nov 2023 12:49:04 +0100 Subject: [PATCH] feat: add optimistic update on new message (#1764) Demo: https://github.com/StanGirard/quivr/assets/63923024/3aecb83f-3acd-46d4-900d-a042814c6638 Issue: https://github.com/StanGirard/quivr/issues/1753 --- frontend/app/chat/[chatId]/hooks/useChat.ts | 3 +-- .../chat/[chatId]/hooks/useHandleStream.ts | 9 ++++++- .../app/chat/[chatId]/hooks/useQuestion.ts | 16 +++++++++--- .../utils/generatePlaceHolderMessage.ts | 23 +++++++++++++++++ .../lib/context/ChatProvider/ChatProvider.tsx | 25 ++++--------------- .../ChatProvider/mocks/ChatProviderMock.tsx | 3 +-- frontend/lib/context/ChatProvider/types.ts | 3 +-- 7 files changed, 52 insertions(+), 30 deletions(-) create mode 100644 frontend/app/chat/[chatId]/utils/generatePlaceHolderMessage.ts diff --git a/frontend/app/chat/[chatId]/hooks/useChat.ts b/frontend/app/chat/[chatId]/hooks/useChat.ts index d64e41b8de24..bc770d9b664a 100644 --- a/frontend/app/chat/[chatId]/hooks/useChat.ts +++ b/frontend/app/chat/[chatId]/hooks/useChat.ts @@ -91,9 +91,8 @@ export const useChat = () => { prompt_id: currentPromptId ?? undefined, }; - await addStreamQuestion(currentChatId, chatQuestion); - callback?.(); + await addStreamQuestion(currentChatId, chatQuestion); if (shouldUpdateUrl) { router.replace(`/chat/${currentChatId}`); diff --git a/frontend/app/chat/[chatId]/hooks/useHandleStream.ts b/frontend/app/chat/[chatId]/hooks/useHandleStream.ts index 0ad9b479ffc0..0001774e44d2 100644 --- a/frontend/app/chat/[chatId]/hooks/useHandleStream.ts +++ b/frontend/app/chat/[chatId]/hooks/useHandleStream.ts @@ -7,9 +7,11 @@ export const useHandleStream = () => { const { updateStreamingHistory } = useChatContext(); const handleStream = async ( - reader: ReadableStreamDefaultReader + reader: ReadableStreamDefaultReader, + onFirstChunk: () => void ): Promise => { const decoder = new TextDecoder("utf-8"); + let isFirstChunk = true; const handleStreamRecursively = async () => { const { done, value } = await reader.read(); @@ -18,6 +20,11 @@ export const useHandleStream = () => { return; } + if (isFirstChunk) { + isFirstChunk = false; + onFirstChunk(); + } + const dataStrings = decoder .decode(value) .trim() diff --git a/frontend/app/chat/[chatId]/hooks/useQuestion.ts b/frontend/app/chat/[chatId]/hooks/useQuestion.ts index f3031ec0608c..0fdd3b220d3b 100644 --- a/frontend/app/chat/[chatId]/hooks/useQuestion.ts +++ b/frontend/app/chat/[chatId]/hooks/useQuestion.ts @@ -1,10 +1,12 @@ import { useTranslation } from "react-i18next"; +import { useChatContext } from "@/lib/context"; import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext"; import { useFetch, useToast } from "@/lib/hooks"; import { useHandleStream } from "./useHandleStream"; import { ChatQuestion } from "../types"; +import { generatePlaceHolderMessage } from "../utils/generatePlaceHolderMessage"; interface UseChatService { addStreamQuestion: ( @@ -20,6 +22,7 @@ export const useQuestion = (): UseChatService => { const { t } = useTranslation(["chat"]); const { publish } = useToast(); const { handleStream } = useHandleStream(); + const { removeMessage, updateStreamingHistory } = useChatContext(); const handleFetchError = async (response: Response) => { if (response.status === 429) { @@ -36,8 +39,6 @@ export const useQuestion = (): UseChatService => { variant: "danger", text: errorMessage.detail, }); - - return; }; const addStreamQuestion = async ( @@ -48,6 +49,13 @@ export const useQuestion = (): UseChatService => { "Content-Type": "application/json", Accept: "text/event-stream", }; + + const placeHolderMessage = generatePlaceHolderMessage({ + user_message: chatQuestion.question ?? "", + chat_id: chatId, + }); + updateStreamingHistory(placeHolderMessage); + const body = JSON.stringify(chatQuestion); try { @@ -66,7 +74,9 @@ export const useQuestion = (): UseChatService => { throw new Error(t("resposeBodyNull", { ns: "chat" })); } - await handleStream(response.body.getReader()); + await handleStream(response.body.getReader(), () => + removeMessage(placeHolderMessage.message_id) + ); } catch (error) { publish({ variant: "danger", diff --git a/frontend/app/chat/[chatId]/utils/generatePlaceHolderMessage.ts b/frontend/app/chat/[chatId]/utils/generatePlaceHolderMessage.ts new file mode 100644 index 000000000000..982898b507c6 --- /dev/null +++ b/frontend/app/chat/[chatId]/utils/generatePlaceHolderMessage.ts @@ -0,0 +1,23 @@ +import { ChatMessage } from "../types"; + +type GeneratePlaceHolderMessageProps = { + user_message: string; + chat_id: string; +}; + +export const generatePlaceHolderMessage = ({ + user_message, + chat_id, +}: GeneratePlaceHolderMessageProps): ChatMessage => { + const message_id = new Date().getTime().toString(); + const message_time = new Date().toISOString(); + const assistant = ""; + + return { + message_id, + message_time, + assistant, + chat_id, + user_message, + }; +}; diff --git a/frontend/lib/context/ChatProvider/ChatProvider.tsx b/frontend/lib/context/ChatProvider/ChatProvider.tsx index 726e9b93e336..2f1c2123db9f 100644 --- a/frontend/lib/context/ChatProvider/ChatProvider.tsx +++ b/frontend/lib/context/ChatProvider/ChatProvider.tsx @@ -18,10 +18,6 @@ export const ChatProvider = ({ const [messages, setMessages] = useState([]); const [notifications, setNotifications] = useState([]); - const addToHistory = (message: ChatMessage) => { - setMessages((prevHistory) => [...prevHistory, message]); - }; - const updateStreamingHistory = (streamedChat: ChatMessage): void => { setMessages((prevHistory: ChatMessage[]) => { const updatedHistory = prevHistory.find( @@ -38,20 +34,10 @@ export const ChatProvider = ({ }); }; - const updateHistory = (chat: ChatMessage): void => { - setMessages((prevHistory: ChatMessage[]) => { - const updatedHistory = prevHistory.find( - (item) => item.message_id === chat.message_id - ) - ? prevHistory.map((item: ChatMessage) => - item.message_id === chat.message_id - ? { ...item, assistant: chat.assistant } - : item - ) - : [...prevHistory, chat]; - - return updatedHistory; - }); + const removeMessage = (id: string): void => { + setMessages((prevHistory: ChatMessage[]) => + prevHistory.filter((item) => item.message_id !== id) + ); }; return ( @@ -59,9 +45,8 @@ export const ChatProvider = ({ value={{ messages, setMessages, - addToHistory, - updateHistory, updateStreamingHistory, + removeMessage, notifications, setNotifications, }} diff --git a/frontend/lib/context/ChatProvider/mocks/ChatProviderMock.tsx b/frontend/lib/context/ChatProvider/mocks/ChatProviderMock.tsx index 47e4381f48f9..8c940facb2ba 100644 --- a/frontend/lib/context/ChatProvider/mocks/ChatProviderMock.tsx +++ b/frontend/lib/context/ChatProvider/mocks/ChatProviderMock.tsx @@ -14,11 +14,10 @@ export const ChatProviderMock = ({ value={{ messages: [], setMessages: () => void 0, - addToHistory: () => void 0, - updateHistory: () => void 0, updateStreamingHistory: () => void 0, notifications: [], setNotifications: () => void 0, + removeMessage: () => void 0, }} > {children} diff --git a/frontend/lib/context/ChatProvider/types.ts b/frontend/lib/context/ChatProvider/types.ts index b6af853b33ac..039d62e32e36 100644 --- a/frontend/lib/context/ChatProvider/types.ts +++ b/frontend/lib/context/ChatProvider/types.ts @@ -11,9 +11,8 @@ export type ChatConfig = { export type ChatContextProps = { messages: ChatMessage[]; setMessages: (history: ChatMessage[]) => void; - addToHistory: (message: ChatMessage) => void; - updateHistory: (chat: ChatMessage) => void; updateStreamingHistory: (streamedChat: ChatMessage) => void; notifications: Notification[]; setNotifications: (notifications: Notification[]) => void; + removeMessage: (id: string) => void; };