Skip to content

Commit

Permalink
wip: render context files with the tool that called them
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcMcIntosh committed Jun 18, 2024
1 parent e719719 commit deccbc9
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 79 deletions.
40 changes: 34 additions & 6 deletions src/components/ChatContent/AssistantInput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@ import React from "react";
import { Markdown, MarkdownProps } from "../Markdown";

import { Container, Box } from "@radix-ui/themes";
import { ToolCall, ToolResult } from "../../events";
import {
ChatContextFile,
ChatMessages,
ToolCall,
ToolResult,
isChatContextFileMessage,
isToolMessage,
} from "../../events";
import { ToolContent } from "./ToolsContent";

type ChatInputProps = Pick<
Expand All @@ -11,7 +18,7 @@ type ChatInputProps = Pick<
> & {
message: string | null;
toolCalls?: ToolCall[] | null;
toolResults: Record<string, ToolResult>;
auxMessages?: ChatMessages;
};

function fallbackCopying(text: string) {
Expand All @@ -31,12 +38,28 @@ function fallbackCopying(text: string) {
}

export const AssistantInput: React.FC<ChatInputProps> = (props) => {
const messages = props.auxMessages ?? [];

const results = messages.reduce<Record<string, ToolResult>>(
(acc, message) => {
if (isToolMessage(message)) {
const result = message[1];
return {
...acc,
[result.tool_call_id]: result,
};
}
return acc;
},
{},
);

const files = messages.reduce<ChatContextFile[]>((acc, message) => {
if (!isChatContextFileMessage(message)) return acc;
return [...acc, ...message[1]];
}, []);
return (
<Container position="relative">
{props.toolCalls && (
<ToolContent toolCalls={props.toolCalls} results={props.toolResults} />
)}

{props.message && (
<Box py="4">
<Markdown
Expand All @@ -59,6 +82,11 @@ export const AssistantInput: React.FC<ChatInputProps> = (props) => {
</Markdown>
</Box>
)}
<ToolContent
toolCalls={props.toolCalls ?? []}
results={results}
files={files}
/>
</Container>
);
};
193 changes: 135 additions & 58 deletions src/components/ChatContent/ChatContent.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import React from "react";
import React, { useCallback, useMemo } from "react";
import {
ChatMessages,
ToolResult,
Expand All @@ -15,6 +15,7 @@ import { ContextFiles } from "./ContextFiles";
import { AssistantInput } from "./AssistantInput";
import { MemoryContent } from "./MemoryContent";
import { useAutoScroll } from "./useAutoScroll";
import { takeWhile } from "../../utils";

const PlaceHolderText: React.FC = () => (
<Text>Welcome to Refact chat! How can I assist you today?</Text>
Expand All @@ -40,18 +41,41 @@ export const ChatContent = React.forwardRef<HTMLDivElement, ChatContentProps>(
isStreaming,
} = props;

const { innerRef, handleScroll } = useAutoScroll({ ref, messages });
const { innerRef, handleScroll } = useAutoScroll({
ref,
messages,
isStreaming,
});

const handleRetry = useCallback(
(count: number, question: string) => {
const toSend = messages.slice(0, count).concat([["user", question]]);
onRetry(toSend);
},
[messages, onRetry],
);

const toolResultsMap = React.useMemo(() => {
return messages.reduce<Record<string, ToolResult>>((acc, message) => {
if (!isToolMessage(message)) return acc;
const result = message[1];
return {
...acc,
[result.tool_call_id]: result,
};
}, {});
}, [messages]);
const elements = useMemo(
() =>
groupMessages(
messages,
handleRetry,
isStreaming,
isWaiting,
onNewFileClick,
onPasteClick,
canPaste,
),
[
canPaste,
handleRetry,
isStreaming,
isWaiting,
messages,
onNewFileClick,
onPasteClick,
],
);

return (
<ScrollArea
Expand All @@ -60,52 +84,8 @@ export const ChatContent = React.forwardRef<HTMLDivElement, ChatContentProps>(
onScroll={handleScroll}
>
<Flex direction="column" className={styles.content} p="2" gap="2">
{messages.length === 0 && <PlaceHolderText />}
{messages.map((message, index) => {
if (isChatContextFileMessage(message)) {
const [, files] = message;
return <ContextFiles key={index} files={files} />;
}

const [role, text] = message;

if (role === "user") {
const handleRetry = (question: string) => {
const toSend = messages
.slice(0, index)
.concat([["user", question]]);
onRetry(toSend);
};
return (
<UserInput
onRetry={handleRetry}
key={index}
disableRetry={isStreaming || isWaiting}
>
{text}
</UserInput>
);
} else if (role === "assistant") {
return (
<AssistantInput
onNewFileClick={onNewFileClick}
onPasteClick={onPasteClick}
canPaste={canPaste}
key={index}
message={text}
toolCalls={message[2]}
toolResults={toolResultsMap}
/>
);
} else if (role === "tool") {
return null;
} else if (role === "context_memory") {
return <MemoryContent key={index} items={text} />;
} else {
return null;
// return <Markdown key={index}>{text}</Markdown>;
}
})}
{elements.length === 0 && <PlaceHolderText />}
{elements}
{isWaiting && <Spinner />}
<div ref={innerRef} />
</Flex>
Expand All @@ -115,3 +95,100 @@ export const ChatContent = React.forwardRef<HTMLDivElement, ChatContentProps>(
);

ChatContent.displayName = "ChatContent";

function groupMessages(
messages: ChatMessages,
onRetry: (messageIndex: number, question: string) => void,
isStreaming: boolean,
isWaiting: boolean,
onNewFileClick: MarkdownProps["onNewFileClick"],
onPasteClick: MarkdownProps["onNewFileClick"],
canPaste: boolean,
memo: JSX.Element[] = [],
toolCallsMap: Record<string, ToolResult> = {},
count = 0,
): JSX.Element[] {
if (messages.length === 0) return memo;
const [head, ...tail] = messages;

const key = `message-${head[0]}-${memo.length}`;

const nextCall = (
m: ChatMessages,
p: JSX.Element[],
toolCalls = toolCallsMap,
nextCount = count + 1,
): JSX.Element[] =>
groupMessages(
m,
onRetry,
isStreaming,
isWaiting,
onNewFileClick,
onPasteClick,
canPaste,
p,
toolCalls,
nextCount,
);

if (isToolMessage(head)) {
const result = head[1];
const toolCalls = { ...toolCallsMap, [result.tool_call_id]: result };

return nextCall(tail, memo, toolCalls);
}

if (isChatContextFileMessage(head)) {
const [, files] = head;
const processed = memo.concat(<ContextFiles key={key} files={files} />);
return nextCall(tail, processed);
}

if (head[0] === "context_memory") {
const proccesed = memo.concat(<MemoryContent key={key} items={head[1]} />);
return nextCall(tail, proccesed);
}

if (head[0] === "user") {
const text = head[1];
const proccesed = memo.concat(
<UserInput
onRetry={(question) => onRetry(count, question)}
key={key}
disableRetry={isStreaming || isWaiting}
>
{text}
</UserInput>,
);

return nextCall(tail, proccesed);
}

if (head[0] === "assistant") {
const text = head[1];
const tools = head[2];

const nextContextOrToolMessages = takeWhile(
tail,
(message) => message[0] === "context_file" || message[0] === "tool",
);
const nextTail = tail.slice(nextContextOrToolMessages.length);
const nextCount = count + nextContextOrToolMessages.length;
const proccesed = memo.concat(
<AssistantInput
onNewFileClick={onNewFileClick}
onPasteClick={onPasteClick}
canPaste={canPaste}
key={key}
message={text}
toolCalls={tools}
auxMessages={nextContextOrToolMessages}
/>,
);

return nextCall(nextTail, proccesed, toolCallsMap, nextCount);
}

return nextCall(tail, memo);
}
16 changes: 12 additions & 4 deletions src/components/ChatContent/ToolsContent.tsx
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import React from "react";
import * as Collapsible from "@radix-ui/react-collapsible";
import { Container, Flex, Text, Box, Button } from "@radix-ui/themes";
import { ToolCall, ToolResult } from "../../events";
import { ChatContextFile, ToolCall, ToolResult } from "../../events";
import { ChevronDownIcon } from "@radix-ui/react-icons";
import classNames from "classnames";
import styles from "./ChatContent.module.css";
import { Markdown } from "../CommandLine/Markdown";
import { ContextFiles } from "./ContextFiles";

const Chevron: React.FC<{ open: boolean }> = ({ open }) => {
return (
Expand Down Expand Up @@ -48,7 +49,8 @@ const Result: React.FC<{ children: string }> = ({ children }) => {
const ToolMessage: React.FC<{
toolCall: ToolCall;
result?: ToolResult;
}> = ({ toolCall, result }) => {
files: ChatContextFile[];
}> = ({ toolCall, result, files }) => {
const results = result?.content ?? "";
const name = toolCall.function.name ?? "";

Expand Down Expand Up @@ -79,14 +81,16 @@ const ToolMessage: React.FC<{
return (
<Flex gap="2" direction="column">
<Result>{content}</Result>
<ContextFiles files={files} />
</Flex>
);
};

export const ToolContent: React.FC<{
toolCalls: ToolCall[];
results: Record<string, ToolResult>;
}> = ({ toolCalls, results }) => {
files: ChatContextFile[];
}> = ({ toolCalls, results, files }) => {
const [open, setOpen] = React.useState(false);

if (toolCalls.length === 0) return null;
Expand Down Expand Up @@ -115,7 +119,11 @@ export const ToolContent: React.FC<{
const key = `${toolCall.id}-${toolCall.index}`;
return (
<Box key={key} py="2">
<ToolMessage toolCall={toolCall} result={result} />
<ToolMessage
toolCall={toolCall}
result={result}
files={files}
/>
</Box>
);
})}
Expand Down
Loading

0 comments on commit deccbc9

Please sign in to comment.