diff --git a/src/components/ChatContent/AssistantInput.tsx b/src/components/ChatContent/AssistantInput.tsx index 25d22e1b..b17a13e0 100644 --- a/src/components/ChatContent/AssistantInput.tsx +++ b/src/components/ChatContent/AssistantInput.tsx @@ -2,7 +2,14 @@ import React from "react"; import { Markdown, MarkdownProps } from "../Markdown"; import { Container, Box } from "@radix-ui/themes"; -import { ToolCall, ToolResult } from "../../events"; +import { + ChatContextFile, + ChatMessages, + ToolCall, + ToolResult, + isChatContextFileMessage, + isToolMessage, +} from "../../events"; import { ToolContent } from "./ToolsContent"; type ChatInputProps = Pick< @@ -11,7 +18,7 @@ type ChatInputProps = Pick< > & { message: string | null; toolCalls?: ToolCall[] | null; - toolResults: Record; + auxMessages?: ChatMessages; }; function fallbackCopying(text: string) { @@ -31,12 +38,28 @@ function fallbackCopying(text: string) { } export const AssistantInput: React.FC = (props) => { + const messages = props.auxMessages ?? []; + + const results = messages.reduce>( + (acc, message) => { + if (isToolMessage(message)) { + const result = message[1]; + return { + ...acc, + [result.tool_call_id]: result, + }; + } + return acc; + }, + {}, + ); + + const files = messages.reduce((acc, message) => { + if (!isChatContextFileMessage(message)) return acc; + return [...acc, ...message[1]]; + }, []); return ( - {props.toolCalls && ( - - )} - {props.message && ( = (props) => { )} + ); }; diff --git a/src/components/ChatContent/ChatContent.tsx b/src/components/ChatContent/ChatContent.tsx index 5c2a9e3d..c87028c2 100644 --- a/src/components/ChatContent/ChatContent.tsx +++ b/src/components/ChatContent/ChatContent.tsx @@ -1,4 +1,4 @@ -import React from "react"; +import React, { useCallback, useMemo } from "react"; import { ChatMessages, ToolResult, @@ -15,6 +15,7 @@ import { ContextFiles } from "./ContextFiles"; import { AssistantInput } from "./AssistantInput"; import { MemoryContent } from "./MemoryContent"; import { useAutoScroll } from "./useAutoScroll"; +import { takeWhile } from "../../utils"; const PlaceHolderText: React.FC = () => ( Welcome to Refact chat! How can I assist you today? @@ -40,18 +41,41 @@ export const ChatContent = React.forwardRef( isStreaming, } = props; - const { innerRef, handleScroll } = useAutoScroll({ ref, messages }); + const { innerRef, handleScroll } = useAutoScroll({ + ref, + messages, + isStreaming, + }); + + const handleRetry = useCallback( + (count: number, question: string) => { + const toSend = messages.slice(0, count).concat([["user", question]]); + onRetry(toSend); + }, + [messages, onRetry], + ); - const toolResultsMap = React.useMemo(() => { - return messages.reduce>((acc, message) => { - if (!isToolMessage(message)) return acc; - const result = message[1]; - return { - ...acc, - [result.tool_call_id]: result, - }; - }, {}); - }, [messages]); + const elements = useMemo( + () => + groupMessages( + messages, + handleRetry, + isStreaming, + isWaiting, + onNewFileClick, + onPasteClick, + canPaste, + ), + [ + canPaste, + handleRetry, + isStreaming, + isWaiting, + messages, + onNewFileClick, + onPasteClick, + ], + ); return ( ( onScroll={handleScroll} > - {messages.length === 0 && } - {messages.map((message, index) => { - if (isChatContextFileMessage(message)) { - const [, files] = message; - return ; - } - - const [role, text] = message; - - if (role === "user") { - const handleRetry = (question: string) => { - const toSend = messages - .slice(0, index) - .concat([["user", question]]); - onRetry(toSend); - }; - return ( - - {text} - - ); - } else if (role === "assistant") { - return ( - - ); - } else if (role === "tool") { - return null; - } else if (role === "context_memory") { - return ; - } else { - return null; - // return {text}; - } - })} + {elements.length === 0 && } + {elements} {isWaiting && }
@@ -115,3 +95,100 @@ export const ChatContent = React.forwardRef( ); ChatContent.displayName = "ChatContent"; + +function groupMessages( + messages: ChatMessages, + onRetry: (messageIndex: number, question: string) => void, + isStreaming: boolean, + isWaiting: boolean, + onNewFileClick: MarkdownProps["onNewFileClick"], + onPasteClick: MarkdownProps["onNewFileClick"], + canPaste: boolean, + memo: JSX.Element[] = [], + toolCallsMap: Record = {}, + count = 0, +): JSX.Element[] { + if (messages.length === 0) return memo; + const [head, ...tail] = messages; + + const key = `message-${head[0]}-${memo.length}`; + + const nextCall = ( + m: ChatMessages, + p: JSX.Element[], + toolCalls = toolCallsMap, + nextCount = count + 1, + ): JSX.Element[] => + groupMessages( + m, + onRetry, + isStreaming, + isWaiting, + onNewFileClick, + onPasteClick, + canPaste, + p, + toolCalls, + nextCount, + ); + + if (isToolMessage(head)) { + const result = head[1]; + const toolCalls = { ...toolCallsMap, [result.tool_call_id]: result }; + + return nextCall(tail, memo, toolCalls); + } + + if (isChatContextFileMessage(head)) { + const [, files] = head; + const processed = memo.concat(); + return nextCall(tail, processed); + } + + if (head[0] === "context_memory") { + const proccesed = memo.concat(); + return nextCall(tail, proccesed); + } + + if (head[0] === "user") { + const text = head[1]; + const proccesed = memo.concat( + onRetry(count, question)} + key={key} + disableRetry={isStreaming || isWaiting} + > + {text} + , + ); + + return nextCall(tail, proccesed); + } + + if (head[0] === "assistant") { + const text = head[1]; + const tools = head[2]; + + const nextContextOrToolMessages = takeWhile( + tail, + (message) => message[0] === "context_file" || message[0] === "tool", + ); + const nextTail = tail.slice(nextContextOrToolMessages.length); + const nextCount = count + nextContextOrToolMessages.length; + const proccesed = memo.concat( + , + ); + + return nextCall(nextTail, proccesed, toolCallsMap, nextCount); + } + + return nextCall(tail, memo); +} diff --git a/src/components/ChatContent/ToolsContent.tsx b/src/components/ChatContent/ToolsContent.tsx index c50ba01a..ddb8d6a2 100644 --- a/src/components/ChatContent/ToolsContent.tsx +++ b/src/components/ChatContent/ToolsContent.tsx @@ -1,11 +1,12 @@ import React from "react"; import * as Collapsible from "@radix-ui/react-collapsible"; import { Container, Flex, Text, Box, Button } from "@radix-ui/themes"; -import { ToolCall, ToolResult } from "../../events"; +import { ChatContextFile, ToolCall, ToolResult } from "../../events"; import { ChevronDownIcon } from "@radix-ui/react-icons"; import classNames from "classnames"; import styles from "./ChatContent.module.css"; import { Markdown } from "../CommandLine/Markdown"; +import { ContextFiles } from "./ContextFiles"; const Chevron: React.FC<{ open: boolean }> = ({ open }) => { return ( @@ -48,7 +49,8 @@ const Result: React.FC<{ children: string }> = ({ children }) => { const ToolMessage: React.FC<{ toolCall: ToolCall; result?: ToolResult; -}> = ({ toolCall, result }) => { + files: ChatContextFile[]; +}> = ({ toolCall, result, files }) => { const results = result?.content ?? ""; const name = toolCall.function.name ?? ""; @@ -79,6 +81,7 @@ const ToolMessage: React.FC<{ return ( {content} + ); }; @@ -86,7 +89,8 @@ const ToolMessage: React.FC<{ export const ToolContent: React.FC<{ toolCalls: ToolCall[]; results: Record; -}> = ({ toolCalls, results }) => { + files: ChatContextFile[]; +}> = ({ toolCalls, results, files }) => { const [open, setOpen] = React.useState(false); if (toolCalls.length === 0) return null; @@ -115,7 +119,11 @@ export const ToolContent: React.FC<{ const key = `${toolCall.id}-${toolCall.index}`; return ( - + ); })} diff --git a/src/components/ChatContent/useAutoScroll.ts b/src/components/ChatContent/useAutoScroll.ts index a5885ede..8d4aae2d 100644 --- a/src/components/ChatContent/useAutoScroll.ts +++ b/src/components/ChatContent/useAutoScroll.ts @@ -4,20 +4,35 @@ import { type ChatMessages } from "../../events"; type useAutoScrollProps = { ref: React.ForwardedRef; messages: ChatMessages; + isStreaming: boolean; }; -export function useAutoScroll({ ref, messages }: useAutoScrollProps) { +export function useAutoScroll({ + ref, + messages, + isStreaming, +}: useAutoScrollProps) { const innerRef = useRef(null); // eslint-disable-next-line @typescript-eslint/no-non-null-assertion useImperativeHandle(ref, () => innerRef.current!, []); - const [autoScroll, setAutoScroll] = useState(true); + const [autoScroll, setAutoScroll] = useState(false); + const [isVisable, setIsVisable] = useState(true); useEffect(() => { - if (autoScroll && innerRef.current?.scrollIntoView) { + setAutoScroll(isStreaming); + }, [isStreaming]); + + useEffect(() => { + if ( + isStreaming && + !isVisable && + autoScroll && + innerRef.current?.scrollIntoView + ) { innerRef.current.scrollIntoView({ behavior: "instant", block: "end" }); } - }, [messages, autoScroll]); + }, [messages, autoScroll, isVisable, isStreaming]); const handleScroll: React.UIEventHandler = (event) => { if (!innerRef.current) return; @@ -27,12 +42,7 @@ export function useAutoScroll({ ref, messages }: useAutoScrollProps) { top <= parent.top ? parent.top - top <= height : bottom - parent.bottom <= height; - - if (isVisable && !autoScroll) { - setAutoScroll(true); - } else if (autoScroll) { - setAutoScroll(false); - } + setIsVisable(isVisable); }; return { handleScroll, innerRef }; diff --git a/src/utils/index.ts b/src/utils/index.ts index 159b857d..bd1b6de2 100644 --- a/src/utils/index.ts +++ b/src/utils/index.ts @@ -1,3 +1,4 @@ export * from "./ApiKey"; export * from "./createSyntheticEvent"; export * from "./trimIndent"; +export * from "./takeWhile"; diff --git a/src/utils/takeWhile.ts b/src/utils/takeWhile.ts new file mode 100644 index 00000000..9d1cdc8d --- /dev/null +++ b/src/utils/takeWhile.ts @@ -0,0 +1,11 @@ +export function takeWhile( + arr: T[], + fun: (a: T) => boolean, + memo: T[] = [], +): T[] { + if (arr.length === 0) return memo; + const [head, ...tail] = arr; + if (!fun(head)) return memo; + const nextMemo = [...memo, head]; + return takeWhile(tail, fun, nextMemo); +} diff --git a/src/utils/utils.test.tsx b/src/utils/utils.test.tsx index 1a91215a..161d4f0f 100644 --- a/src/utils/utils.test.tsx +++ b/src/utils/utils.test.tsx @@ -1,5 +1,6 @@ import { describe, test, expect } from "vitest"; -import { trimIndentFromMarkdown, trimIndent } from "./trimIndent"; +import { trimIndentFromMarkdown, trimIndent, takeWhile } from "./index"; + const spaces = " "; describe("trim indent from markdown", () => { const tests = [ @@ -36,3 +37,12 @@ describe("trim indent", () => { expect(result).toBe(expected); }); }); + +describe("take while", () => { + test("when given an array and predicate it should take elements from the array until the predicate fails", () => { + const input = [1, 1, 2, 1]; + const expected = [1, 1]; + const predicate = (n: number) => n === 1; + expect(takeWhile(input, predicate)).toEqual(expected); + }); +});