diff --git a/.changeset/five-ants-watch.md b/.changeset/five-ants-watch.md new file mode 100644 index 0000000000..02bdc3d584 --- /dev/null +++ b/.changeset/five-ants-watch.md @@ -0,0 +1,5 @@ +--- +"llamaindex": minor +--- + +Unify chat engine response and agent response diff --git a/examples/agent/openai.ts b/examples/agent/openai.ts index 67d0d4b556..a0780b7c7e 100644 --- a/examples/agent/openai.ts +++ b/examples/agent/openai.ts @@ -53,7 +53,7 @@ async function main() { message: "How much is 5 + 5? then divide by 2", }); - console.log(response.response.message); + console.log(response.message); } void main().then(() => { diff --git a/examples/agent/react_agent.ts b/examples/agent/react_agent.ts index 764f17feae..bb094237f3 100644 --- a/examples/agent/react_agent.ts +++ b/examples/agent/react_agent.ts @@ -68,7 +68,7 @@ async function main() { }); // Chat with the agent - const { response } = await agent.chat({ + const response = await agent.chat({ message: "Divide 16 by 2 then add 20", }); diff --git a/examples/agent/step_wise_query_tool.ts b/examples/agent/step_wise_query_tool.ts index 92a5ebe0f2..829d35f26e 100644 --- a/examples/agent/step_wise_query_tool.ts +++ b/examples/agent/step_wise_query_tool.ts @@ -31,7 +31,7 @@ async function main() { tools: [queryEngineTool], }); - const { response } = await agent.chat({ + const response = await agent.chat({ message: "What was his salary?", }); diff --git a/examples/agent/stream_openai_agent.ts b/examples/agent/stream_openai_agent.ts index 111853fba8..4d8d6e8fcb 100644 --- a/examples/agent/stream_openai_agent.ts +++ b/examples/agent/stream_openai_agent.ts @@ -68,9 +68,7 @@ async function main() { console.log("Response:"); - for await (const { - response: { delta }, - } of stream) { + for await (const { delta } of stream) { process.stdout.write(delta); } } diff --git a/examples/agent/wiki.ts b/examples/agent/wiki.ts index 379c6cdacd..e4100e9909 100644 --- a/examples/agent/wiki.ts +++ b/examples/agent/wiki.ts @@ -16,9 +16,7 @@ async function main() { stream: true, }); - for await (const { - response: { delta }, - } of response) { + for await (const { delta } of response) { process.stdout.write(delta); } } diff --git a/packages/core/e2e/examples/cloudflare-worker-agent/src/index.ts b/packages/core/e2e/examples/cloudflare-worker-agent/src/index.ts index eeb45a3fa5..d96fe9d745 100644 --- a/packages/core/e2e/examples/cloudflare-worker-agent/src/index.ts +++ b/packages/core/e2e/examples/cloudflare-worker-agent/src/index.ts @@ -21,7 +21,7 @@ export default { // @ts-expect-error: see https://github.com/cloudflare/workerd/issues/2067 new TransformStream({ transform: (chunk, controller) => { - controller.enqueue(textEncoder.encode(chunk.response.delta)); + controller.enqueue(textEncoder.encode(chunk.delta)); }, }), ); diff --git a/packages/core/e2e/examples/nextjs-agent/src/actions/index.tsx b/packages/core/e2e/examples/nextjs-agent/src/actions/index.tsx index 8aa6b12f9d..b092fd0b89 100644 --- a/packages/core/e2e/examples/nextjs-agent/src/actions/index.tsx +++ b/packages/core/e2e/examples/nextjs-agent/src/actions/index.tsx @@ -23,7 +23,7 @@ export async function chatWithAgent( uiStream.update("response:"); }, write: async (message) => { - uiStream.append(message.response.delta); + uiStream.append(message.delta); }, }), ) diff --git a/packages/core/e2e/node/claude.e2e.ts b/packages/core/e2e/node/claude.e2e.ts index faa96cd735..c8fde2d908 100644 --- a/packages/core/e2e/node/claude.e2e.ts +++ b/packages/core/e2e/node/claude.e2e.ts @@ -2,7 +2,7 @@ import { 
consola } from "consola"; import { Anthropic, FunctionTool, Settings, type LLM } from "llamaindex"; import { AnthropicAgent } from "llamaindex/agent/anthropic"; import { extractText } from "llamaindex/llm/utils"; -import { ok, strictEqual } from "node:assert"; +import { ok } from "node:assert"; import { beforeEach, test } from "node:test"; import { getWeatherTool, sumNumbersTool } from "./fixtures/tools.js"; import { mockLLMEvent } from "./utils.js"; @@ -71,12 +71,11 @@ await test("anthropic agent", async (t) => { }, ], }); - const { response, sources } = await agent.chat({ + const response = await agent.chat({ message: "What is the weather in San Francisco?", }); consola.debug("response:", response.message.content); - strictEqual(sources.length, 1); ok(extractText(response.message.content).includes("35")); }); @@ -110,7 +109,7 @@ await test("anthropic agent", async (t) => { const agent = new AnthropicAgent({ tools: [showUniqueId], }); - const { response } = await agent.chat({ + const response = await agent.chat({ message: "My name is Alex Yang. What is my unique id?", }); consola.debug("response:", response.message.content); @@ -122,7 +121,7 @@ await test("anthropic agent", async (t) => { tools: [sumNumbersTool], }); - const { response } = await anthropicAgent.chat({ + const response = await anthropicAgent.chat({ message: "how much is 1 + 1?", }); @@ -137,35 +136,35 @@ await test("anthropic agent with multiple chat", async (t) => { tools: [getWeatherTool], }); { - const { response } = await agent.chat({ + const response = await agent.chat({ message: 'Hello? Response to me "Yes"', }); consola.debug("response:", response.message.content); ok(extractText(response.message.content).includes("Yes")); } { - const { response } = await agent.chat({ + const response = await agent.chat({ message: 'Hello? Response to me "No"', }); consola.debug("response:", response.message.content); ok(extractText(response.message.content).includes("No")); } { - const { response } = await agent.chat({ + const response = await agent.chat({ message: 'Hello? Response to me "Maybe"', }); consola.debug("response:", response.message.content); ok(extractText(response.message.content).includes("Maybe")); } { - const { response } = await agent.chat({ + const response = await agent.chat({ message: "What is the weather in San Francisco?", }); consola.debug("response:", response.message.content); ok(extractText(response.message.content).includes("72")); } { - const { response } = await agent.chat({ + const response = await agent.chat({ message: "What is the weather in Shanghai?", }); consola.debug("response:", response.message.content); diff --git a/packages/core/e2e/node/openai.e2e.ts b/packages/core/e2e/node/openai.e2e.ts index e9bfceb587..6545846740 100644 --- a/packages/core/e2e/node/openai.e2e.ts +++ b/packages/core/e2e/node/openai.e2e.ts @@ -13,7 +13,6 @@ import { SummaryIndex, VectorStoreIndex, type LLM, - type ToolOutput, } from "llamaindex"; import { extractText } from "llamaindex/llm/utils"; import { ok, strictEqual } from "node:assert"; @@ -93,7 +92,7 @@ await test("gpt-4-turbo", async (t) => { }, ], }); - const { response } = await agent.chat({ + const response = await agent.chat({ message: "What is the weather in San Jose?", }); consola.debug("response:", response.message.content); @@ -109,7 +108,7 @@ await test("agent system prompt", async (t) => { systemPrompt: "You are a pirate. 
You MUST speak every words staring with a 'Arhgs'", }); - const { response } = await agent.chat({ + const response = await agent.chat({ message: "What is the weather in San Francisco?", }); consola.debug("response:", response.message.content); @@ -187,7 +186,7 @@ For questions about more specific sections, please use the vector_tool.`, }); strictEqual(mockCall.mock.callCount(), 0); - const { response } = await agent.chat({ + const response = await agent.chat({ message: "What's the summary of Alex? Does he live in Brazil based on the brief information? Return yes or no.", }); @@ -224,12 +223,11 @@ await test("agent with object function call", async (t) => { ), ], }); - const { response, sources } = await agent.chat({ + const response = await agent.chat({ message: "What is the weather in San Francisco?", }); consola.debug("response:", response.message.content); - strictEqual(sources.length, 1); ok(extractText(response.message.content).includes("72")); }); }); @@ -257,12 +255,11 @@ await test("agent", async (t) => { }, ], }); - const { response, sources } = await agent.chat({ + const response = await agent.chat({ message: "What is the weather in San Francisco?", }); consola.debug("response:", response.message.content); - strictEqual(sources.length, 1); ok(extractText(response.message.content).includes("35")); }); @@ -296,10 +293,9 @@ await test("agent", async (t) => { const agent = new OpenAIAgent({ tools: [showUniqueId], }); - const { response, sources } = await agent.chat({ + const response = await agent.chat({ message: "My name is Alex Yang. What is my unique id?", }); - strictEqual(sources.length, 1); ok(extractText(response.message.content).includes(uniqueId)); }); @@ -308,11 +304,10 @@ await test("agent", async (t) => { tools: [sumNumbersTool], }); - const { response, sources } = await openaiAgent.chat({ + const response = await openaiAgent.chat({ message: "how much is 1 + 1?", }); - strictEqual(sources.length, 1); ok(extractText(response.message.content).includes("2")); }); }); @@ -333,15 +328,12 @@ await test("agent stream", async (t) => { }); let message = ""; - let soruces: ToolOutput[] = []; - for await (const { response, sources: _sources } of stream) { + for await (const response of stream) { message += response.delta; - soruces = _sources; } strictEqual(fn.mock.callCount(), 2); - strictEqual(soruces.length, 2); ok(message.includes("28")); Settings.callbackManager.off("llm-tool-call", fn); }); diff --git a/packages/core/e2e/node/react.e2e.ts b/packages/core/e2e/node/react.e2e.ts index 54626cd205..f500e2041d 100644 --- a/packages/core/e2e/node/react.e2e.ts +++ b/packages/core/e2e/node/react.e2e.ts @@ -19,7 +19,7 @@ await test("react agent", async (t) => { const agent = new ReActAgent({ tools: [getWeatherTool], }); - const { response } = await agent.chat({ + const response = await agent.chat({ stream: false, message: "What is the weather like in San Francisco?", }); @@ -41,7 +41,7 @@ await test("react agent stream", async (t) => { }); let content = ""; - for await (const { response } of stream) { + for await (const response of stream) { content += response.delta; } ok(content.includes("72")); diff --git a/packages/core/src/EngineResponse.ts b/packages/core/src/EngineResponse.ts new file mode 100644 index 0000000000..8cfa9d195a --- /dev/null +++ b/packages/core/src/EngineResponse.ts @@ -0,0 +1,90 @@ +import type { NodeWithScore } from "./Node.js"; +import type { + ChatMessage, + ChatResponse, + ChatResponseChunk, +} from "./llm/types.js"; +import { extractText } from 
"./llm/utils.js"; + +export class EngineResponse implements ChatResponse, ChatResponseChunk { + sourceNodes?: NodeWithScore[]; + + metadata: Record = {}; + + message: ChatMessage; + raw: object | null; + + #stream: boolean; + + private constructor( + chatResponse: ChatResponse, + stream: boolean, + sourceNodes?: NodeWithScore[], + ) { + this.message = chatResponse.message; + this.raw = chatResponse.raw; + this.sourceNodes = sourceNodes; + this.#stream = stream; + } + + static fromResponse( + response: string, + stream: boolean, + sourceNodes?: NodeWithScore[], + ): EngineResponse { + return new EngineResponse( + EngineResponse.toChatResponse(response), + stream, + sourceNodes, + ); + } + + private static toChatResponse( + response: string, + raw: object | null = null, + ): ChatResponse { + return { + message: { + content: response, + role: "assistant", + }, + raw, + }; + } + + static fromChatResponse( + chatResponse: ChatResponse, + sourceNodes?: NodeWithScore[], + ): EngineResponse { + return new EngineResponse(chatResponse, false, sourceNodes); + } + + static fromChatResponseChunk( + chunk: ChatResponseChunk, + sourceNodes?: NodeWithScore[], + ): EngineResponse { + return new EngineResponse( + this.toChatResponse(chunk.delta, chunk.raw), + true, + sourceNodes, + ); + } + + // @deprecated use 'message' instead + get response(): string { + return extractText(this.message.content); + } + + get delta(): string { + if (!this.#stream) { + console.warn( + "delta is only available for streaming responses. Consider using 'message' instead.", + ); + } + return extractText(this.message.content); + } + + toString() { + return this.response ?? ""; + } +} diff --git a/packages/core/src/Response.ts b/packages/core/src/Response.ts deleted file mode 100644 index c922173d2e..0000000000 --- a/packages/core/src/Response.ts +++ /dev/null @@ -1,23 +0,0 @@ -import type { NodeWithScore } from "./Node.js"; - -/** - * Response is the output of a LLM - */ -export class Response { - response: string; - sourceNodes?: NodeWithScore[]; - metadata: Record = {}; - - constructor(response: string, sourceNodes?: NodeWithScore[]) { - this.response = response; - this.sourceNodes = sourceNodes || []; - } - - protected _getFormattedSources() { - throw new Error("Not implemented yet"); - } - - toString() { - return this.response ?? 
""; - } -} diff --git a/packages/core/src/agent/anthropic.ts b/packages/core/src/agent/anthropic.ts index ac415b6e9c..1f917cf3ea 100644 --- a/packages/core/src/agent/anthropic.ts +++ b/packages/core/src/agent/anthropic.ts @@ -1,3 +1,4 @@ +import { EngineResponse } from "../EngineResponse.js"; import { Settings } from "../Settings.js"; import { type ChatEngineParamsNonStreaming, @@ -5,15 +6,9 @@ import { } from "../engines/chat/index.js"; import { stringifyJSONToMessageContent } from "../internal/utils.js"; import { Anthropic } from "../llm/anthropic.js"; -import type { ToolCallLLMMessageOptions } from "../llm/index.js"; import { ObjectRetriever } from "../objects/index.js"; import type { BaseToolWithCall } from "../types.js"; -import { - AgentRunner, - AgentWorker, - type AgentChatResponse, - type AgentParamsBase, -} from "./base.js"; +import { AgentRunner, AgentWorker, type AgentParamsBase } from "./base.js"; import type { TaskHandler } from "./types.js"; import { callTool } from "./utils.js"; @@ -56,9 +51,7 @@ export class AnthropicAgent extends AgentRunner { createStore = AgentRunner.defaultCreateStore; - async chat( - params: ChatEngineParamsNonStreaming, - ): Promise>; + async chat(params: ChatEngineParamsNonStreaming): Promise; async chat(params: ChatEngineParamsStreaming): Promise; override async chat( params: ChatEngineParamsNonStreaming | ChatEngineParamsStreaming, diff --git a/packages/core/src/agent/base.ts b/packages/core/src/agent/base.ts index 1005c15594..a5e18df06c 100644 --- a/packages/core/src/agent/base.ts +++ b/packages/core/src/agent/base.ts @@ -1,4 +1,5 @@ import { ReadableStream, TransformStream, randomUUID } from "@llamaindex/env"; +import { EngineResponse } from "../EngineResponse.js"; import { Settings } from "../Settings.js"; import { type ChatEngine, @@ -9,13 +10,7 @@ import { wrapEventCaller } from "../internal/context/EventCaller.js"; import { consoleLogger, emptyLogger } from "../internal/logger.js"; import { getCallbackManager } from "../internal/settings/CallbackManager.js"; import { isAsyncIterable } from "../internal/utils.js"; -import type { - ChatMessage, - ChatResponse, - ChatResponseChunk, - LLM, - MessageContent, -} from "../llm/index.js"; +import type { ChatMessage, LLM, MessageContent } from "../llm/index.js"; import type { BaseToolWithCall, ToolOutput } from "../types.js"; import type { AgentTaskContext, @@ -101,16 +96,6 @@ export function createTaskOutputStream< }); } -export type AgentStreamChatResponse = { - response: ChatResponseChunk; - sources: ToolOutput[]; -}; - -export type AgentChatResponse = { - response: ChatResponse; - sources: ToolOutput[]; -}; - export type AgentRunnerParams< AI extends LLM, Store extends object = {}, @@ -210,11 +195,7 @@ export abstract class AgentRunner< > ? 
AdditionalMessageOptions : never, -> implements - ChatEngine< - AgentChatResponse, - ReadableStream> - > +> implements ChatEngine { readonly #llm: AI; readonly #tools: @@ -320,47 +301,30 @@ export abstract class AgentRunner< }); } - async chat( - params: ChatEngineParamsNonStreaming, - ): Promise>; + async chat(params: ChatEngineParamsNonStreaming): Promise; async chat( params: ChatEngineParamsStreaming, - ): Promise>>; + ): Promise>; @wrapEventCaller async chat( params: ChatEngineParamsNonStreaming | ChatEngineParamsStreaming, - ): Promise< - | AgentChatResponse - | ReadableStream> - > { + ): Promise> { const task = this.createTask(params.message, !!params.stream); for await (const stepOutput of task) { // update chat history for each round this.#chatHistory = [...stepOutput.taskStep.context.store.messages]; if (stepOutput.isLast) { - const { output, taskStep } = stepOutput; + const { output } = stepOutput; if (isAsyncIterable(output)) { - return output.pipeThrough< - AgentStreamChatResponse - >( + return output.pipeThrough( new TransformStream({ transform(chunk, controller) { - controller.enqueue({ - response: chunk, - get sources() { - return [...taskStep.context.store.toolOutputs]; - }, - }); + controller.enqueue(EngineResponse.fromChatResponseChunk(chunk)); }, }), ); } else { - return { - response: output, - get sources() { - return [...taskStep.context.store.toolOutputs]; - }, - } satisfies AgentChatResponse; + return EngineResponse.fromChatResponse(output); } } } diff --git a/packages/core/src/engines/chat/CondenseQuestionChatEngine.ts b/packages/core/src/engines/chat/CondenseQuestionChatEngine.ts index 2f885bfb5c..adf099dcd7 100644 --- a/packages/core/src/engines/chat/CondenseQuestionChatEngine.ts +++ b/packages/core/src/engines/chat/CondenseQuestionChatEngine.ts @@ -1,11 +1,11 @@ import type { ChatHistory } from "../../ChatHistory.js"; import { getHistory } from "../../ChatHistory.js"; +import type { EngineResponse } from "../../EngineResponse.js"; import type { CondenseQuestionPrompt } from "../../Prompt.js"; import { defaultCondenseQuestionPrompt, messagesToHistoryStr, } from "../../Prompt.js"; -import type { Response } from "../../Response.js"; import type { ServiceContext } from "../../ServiceContext.js"; import { llmFromSettingsOrContext } from "../../Settings.js"; import { wrapEventCaller } from "../../internal/context/EventCaller.js"; @@ -80,12 +80,14 @@ export class CondenseQuestionChatEngine }); } - chat(params: ChatEngineParamsStreaming): Promise>; - chat(params: ChatEngineParamsNonStreaming): Promise; + chat( + params: ChatEngineParamsStreaming, + ): Promise>; + chat(params: ChatEngineParamsNonStreaming): Promise; @wrapEventCaller async chat( params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming, - ): Promise> { + ): Promise> { const { message, stream } = params; const chatHistory = params.chatHistory ? 
getHistory(params.chatHistory) diff --git a/packages/core/src/engines/chat/ContextChatEngine.ts b/packages/core/src/engines/chat/ContextChatEngine.ts index 5c9625f292..dcd90d2192 100644 --- a/packages/core/src/engines/chat/ContextChatEngine.ts +++ b/packages/core/src/engines/chat/ContextChatEngine.ts @@ -1,11 +1,11 @@ import type { ChatHistory } from "../../ChatHistory.js"; import { getHistory } from "../../ChatHistory.js"; +import { EngineResponse } from "../../EngineResponse.js"; import type { ContextSystemPrompt } from "../../Prompt.js"; -import { Response } from "../../Response.js"; import type { BaseRetriever } from "../../Retriever.js"; import { Settings } from "../../Settings.js"; import { wrapEventCaller } from "../../internal/context/EventCaller.js"; -import type { ChatMessage, ChatResponseChunk, LLM } from "../../llm/index.js"; +import type { ChatMessage, LLM } from "../../llm/index.js"; import type { MessageContent, MessageType } from "../../llm/types.js"; import { extractText, @@ -24,8 +24,7 @@ import type { /** * ContextChatEngine uses the Index to get the appropriate context for each query. - * The context is stored in the system prompt, and the chat history is preserved, - * ideally allowing the appropriate context to be surfaced for each query. + * The context is stored in the system prompt, and the chat history is preserved, ideally allowing the appropriate context to be surfaced for each query. */ export class ContextChatEngine extends PromptMixin implements ChatEngine { chatModel: LLM; @@ -60,12 +59,14 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { }; } - chat(params: ChatEngineParamsStreaming): Promise>; - chat(params: ChatEngineParamsNonStreaming): Promise; + chat( + params: ChatEngineParamsStreaming, + ): Promise>; + chat(params: ChatEngineParamsNonStreaming): Promise; @wrapEventCaller async chat( params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming, - ): Promise> { + ): Promise> { const { message, stream } = params; const chatHistory = params.chatHistory ?
getHistory(params.chatHistory) @@ -88,17 +89,14 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { chatHistory.addMessage({ content: accumulator, role: "assistant" }); }, }), - (r: ChatResponseChunk) => new Response(r.delta, requestMessages.nodes), + (r) => EngineResponse.fromChatResponseChunk(r, requestMessages.nodes), ); } const response = await this.chatModel.chat({ messages: requestMessages.messages, }); chatHistory.addMessage(response.message); - return new Response( - extractText(response.message.content), - requestMessages.nodes, - ); + return EngineResponse.fromChatResponse(response, requestMessages.nodes); } reset() { diff --git a/packages/core/src/engines/chat/SimpleChatEngine.ts b/packages/core/src/engines/chat/SimpleChatEngine.ts index dd6ab0a6d3..248831e0bb 100644 --- a/packages/core/src/engines/chat/SimpleChatEngine.ts +++ b/packages/core/src/engines/chat/SimpleChatEngine.ts @@ -1,14 +1,10 @@ import type { ChatHistory } from "../../ChatHistory.js"; import { getHistory } from "../../ChatHistory.js"; -import { Response } from "../../Response.js"; +import { EngineResponse } from "../../EngineResponse.js"; import { Settings } from "../../Settings.js"; import { wrapEventCaller } from "../../internal/context/EventCaller.js"; -import type { ChatResponseChunk, LLM } from "../../llm/index.js"; -import { - extractText, - streamConverter, - streamReducer, -} from "../../llm/utils.js"; +import type { LLM } from "../../llm/index.js"; +import { streamConverter, streamReducer } from "../../llm/utils.js"; import type { ChatEngine, ChatEngineParamsNonStreaming, @@ -28,12 +24,14 @@ export class SimpleChatEngine implements ChatEngine { this.llm = init?.llm ?? Settings.llm; } - chat(params: ChatEngineParamsStreaming): Promise>; - chat(params: ChatEngineParamsNonStreaming): Promise; + chat( + params: ChatEngineParamsStreaming, + ): Promise>; + chat(params: ChatEngineParamsNonStreaming): Promise; @wrapEventCaller async chat( params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming, - ): Promise> { + ): Promise> { const { message, stream } = params; const chatHistory = params.chatHistory @@ -55,7 +53,7 @@ export class SimpleChatEngine implements ChatEngine { chatHistory.addMessage({ content: accumulator, role: "assistant" }); }, }), - (r: ChatResponseChunk) => new Response(r.delta), + EngineResponse.fromChatResponseChunk, ); } @@ -63,7 +61,7 @@ export class SimpleChatEngine implements ChatEngine { messages: await chatHistory.requestMessages(), }); chatHistory.addMessage(response.message); - return new Response(extractText(response.message.content)); + return EngineResponse.fromChatResponse(response); } reset() { diff --git a/packages/core/src/engines/chat/types.ts b/packages/core/src/engines/chat/types.ts index 0b00f1d1db..925187b6fe 100644 --- a/packages/core/src/engines/chat/types.ts +++ b/packages/core/src/engines/chat/types.ts @@ -1,6 +1,6 @@ import type { ChatHistory } from "../../ChatHistory.js"; +import type { EngineResponse } from "../../EngineResponse.js"; import type { NodeWithScore } from "../../Node.js"; -import type { Response } from "../../Response.js"; import type { ChatMessage } from "../../llm/index.js"; import type { MessageContent } from "../../llm/types.js"; @@ -33,7 +33,7 @@ export interface ChatEngineParamsNonStreaming extends ChatEngineParamsBase { */ export interface ChatEngine< // synchronous response - R = Response, + R = EngineResponse, // asynchronous response AR extends AsyncIterable = AsyncIterable, > { diff --git 
a/packages/core/src/engines/query/RetrieverQueryEngine.ts b/packages/core/src/engines/query/RetrieverQueryEngine.ts index 153cb42657..9b265e54f7 100644 --- a/packages/core/src/engines/query/RetrieverQueryEngine.ts +++ b/packages/core/src/engines/query/RetrieverQueryEngine.ts @@ -1,5 +1,5 @@ +import type { EngineResponse } from "../../EngineResponse.js"; import type { NodeWithScore } from "../../Node.js"; -import type { Response } from "../../Response.js"; import type { BaseRetriever } from "../../Retriever.js"; import { wrapEventCaller } from "../../internal/context/EventCaller.js"; import type { BaseNodePostprocessor } from "../../postprocessors/index.js"; @@ -67,12 +67,14 @@ export class RetrieverQueryEngine extends PromptMixin implements QueryEngine { return await this.applyNodePostprocessors(nodes, query); } - query(params: QueryEngineParamsStreaming): Promise>; - query(params: QueryEngineParamsNonStreaming): Promise; + query( + params: QueryEngineParamsStreaming, + ): Promise>; + query(params: QueryEngineParamsNonStreaming): Promise; @wrapEventCaller async query( params: QueryEngineParamsStreaming | QueryEngineParamsNonStreaming, - ): Promise> { + ): Promise> { const { query, stream } = params; const nodesWithScore = await this.retrieve(query); if (stream) { diff --git a/packages/core/src/engines/query/RouterQueryEngine.ts b/packages/core/src/engines/query/RouterQueryEngine.ts index 197fec0766..6ffa0f9662 100644 --- a/packages/core/src/engines/query/RouterQueryEngine.ts +++ b/packages/core/src/engines/query/RouterQueryEngine.ts @@ -1,5 +1,5 @@ +import { EngineResponse } from "../../EngineResponse.js"; import type { NodeWithScore } from "../../Node.js"; -import { Response } from "../../Response.js"; import type { ServiceContext } from "../../ServiceContext.js"; import { llmFromSettingsOrContext } from "../../Settings.js"; import { PromptMixin } from "../../prompts/index.js"; @@ -24,10 +24,10 @@ type RouterQueryEngineMetadata = { async function combineResponses( summarizer: TreeSummarize, - responses: Response[], + responses: EngineResponse[], queryBundle: QueryBundle, verbose: boolean = false, -): Promise { +): Promise { if (verbose) { console.log("Combining responses from multiple query engines."); } @@ -48,7 +48,7 @@ async function combineResponses( textChunks: responseStrs, }); - return new Response(summary, sourceNodes); + return EngineResponse.fromResponse(summary, false, sourceNodes); } /** @@ -108,11 +108,13 @@ export class RouterQueryEngine extends PromptMixin implements QueryEngine { }); } - query(params: QueryEngineParamsStreaming): Promise>; - query(params: QueryEngineParamsNonStreaming): Promise; + query( + params: QueryEngineParamsStreaming, + ): Promise>; + query(params: QueryEngineParamsNonStreaming): Promise; async query( params: QueryEngineParamsStreaming | QueryEngineParamsNonStreaming, - ): Promise> { + ): Promise> { const { query, stream } = params; const response = await this.queryRoute({ queryStr: query }); @@ -124,11 +126,11 @@ export class RouterQueryEngine extends PromptMixin implements QueryEngine { return response; } - private async queryRoute(queryBundle: QueryBundle): Promise { + private async queryRoute(queryBundle: QueryBundle): Promise { const result = await this.selector.select(this.metadatas, queryBundle); if (result.selections.length > 1) { - const responses: Response[] = []; + const responses: EngineResponse[] = []; for (let i = 0; i < result.selections.length; i++) { const engineInd = result.selections[i]; const logStr = `Selecting query engine 
${engineInd}: ${result.selections[i]}.`; diff --git a/packages/core/src/engines/query/SubQuestionQueryEngine.ts b/packages/core/src/engines/query/SubQuestionQueryEngine.ts index 188e0ab9aa..98e17528e7 100644 --- a/packages/core/src/engines/query/SubQuestionQueryEngine.ts +++ b/packages/core/src/engines/query/SubQuestionQueryEngine.ts @@ -1,7 +1,7 @@ +import type { EngineResponse } from "../../EngineResponse.js"; import type { NodeWithScore } from "../../Node.js"; import { TextNode } from "../../Node.js"; import { LLMQuestionGenerator } from "../../QuestionGenerator.js"; -import type { Response } from "../../Response.js"; import type { ServiceContext } from "../../ServiceContext.js"; import { PromptMixin } from "../../prompts/Mixin.js"; import type { BaseSynthesizer } from "../../synthesizers/index.js"; @@ -74,12 +74,14 @@ export class SubQuestionQueryEngine extends PromptMixin implements QueryEngine { }); } - query(params: QueryEngineParamsStreaming): Promise>; - query(params: QueryEngineParamsNonStreaming): Promise; + query( + params: QueryEngineParamsStreaming, + ): Promise>; + query(params: QueryEngineParamsNonStreaming): Promise; @wrapEventCaller async query( params: QueryEngineParamsStreaming | QueryEngineParamsNonStreaming, - ): Promise> { + ): Promise> { const { query, stream } = params; const subQuestions = await this.questionGen.generate(this.metadatas, query); diff --git a/packages/core/src/evaluation/types.ts b/packages/core/src/evaluation/types.ts index 94ced95bd4..a38146b0fd 100644 --- a/packages/core/src/evaluation/types.ts +++ b/packages/core/src/evaluation/types.ts @@ -1,4 +1,4 @@ -import { Response } from "../Response.js"; +import { EngineResponse } from "../EngineResponse.js"; export type EvaluationResult = { query?: string; @@ -22,7 +22,7 @@ export type EvaluatorParams = { export type EvaluatorResponseParams = { query: string | null; - response: Response; + response: EngineResponse; }; export interface BaseEvaluator { evaluate(params: EvaluatorParams): Promise; diff --git a/packages/core/src/index.edge.ts b/packages/core/src/index.edge.ts index 37a9f0993e..6c4907f2a4 100644 --- a/packages/core/src/index.edge.ts +++ b/packages/core/src/index.edge.ts @@ -1,10 +1,10 @@ export * from "./ChatHistory.js"; +export * from "./EngineResponse.js"; export * from "./Node.js"; export * from "./OutputParser.js"; export * from "./Prompt.js"; export * from "./PromptHelper.js"; export * from "./QuestionGenerator.js"; -export * from "./Response.js"; export * from "./Retriever.js"; export * from "./ServiceContext.js"; export { Settings } from "./Settings.js"; diff --git a/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts b/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts index 3eca33ebd7..21aa5efb6f 100644 --- a/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts +++ b/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts @@ -1,5 +1,5 @@ +import { EngineResponse } from "../EngineResponse.js"; import { MetadataMode } from "../Node.js"; -import { Response } from "../Response.js"; import type { ServiceContext } from "../ServiceContext.js"; import { llmFromSettingsOrContext } from "../Settings.js"; import { streamConverter } from "../llm/utils.js"; @@ -49,14 +49,14 @@ export class MultiModalResponseSynthesizer synthesize( params: SynthesizeParamsStreaming, - ): Promise>; - synthesize(params: SynthesizeParamsNonStreaming): Promise; + ): Promise>; + synthesize(params: SynthesizeParamsNonStreaming): Promise; async synthesize({ query, 
nodesWithScore, stream, }: SynthesizeParamsStreaming | SynthesizeParamsNonStreaming): Promise< - AsyncIterable | Response + AsyncIterable | EngineResponse > { const nodes = nodesWithScore.map(({ node }) => node); const prompt = await createMessageContent( @@ -73,14 +73,13 @@ export class MultiModalResponseSynthesizer prompt, stream, }); - return streamConverter( - response, - ({ text }) => new Response(text, nodesWithScore), + return streamConverter(response, ({ text }) => + EngineResponse.fromResponse(text, true, nodesWithScore), ); } const response = await llm.complete({ prompt, }); - return new Response(response.text, nodesWithScore); + return EngineResponse.fromResponse(response.text, false, nodesWithScore); } } diff --git a/packages/core/src/synthesizers/ResponseSynthesizer.ts b/packages/core/src/synthesizers/ResponseSynthesizer.ts index 22a8903eee..a9faab1d89 100644 --- a/packages/core/src/synthesizers/ResponseSynthesizer.ts +++ b/packages/core/src/synthesizers/ResponseSynthesizer.ts @@ -1,5 +1,5 @@ +import { EngineResponse } from "../EngineResponse.js"; import { MetadataMode } from "../Node.js"; -import { Response } from "../Response.js"; import type { ServiceContext } from "../ServiceContext.js"; import { streamConverter } from "../llm/utils.js"; import { PromptMixin } from "../prompts/Mixin.js"; @@ -57,14 +57,14 @@ export class ResponseSynthesizer synthesize( params: SynthesizeParamsStreaming, - ): Promise>; - synthesize(params: SynthesizeParamsNonStreaming): Promise; + ): Promise>; + synthesize(params: SynthesizeParamsNonStreaming): Promise; async synthesize({ query, nodesWithScore, stream, }: SynthesizeParamsStreaming | SynthesizeParamsNonStreaming): Promise< - AsyncIterable | Response + AsyncIterable | EngineResponse > { const textChunks: string[] = nodesWithScore.map(({ node }) => node.getContent(this.metadataMode), @@ -75,15 +75,14 @@ export class ResponseSynthesizer textChunks, stream, }); - return streamConverter( - response, - (chunk) => new Response(chunk, nodesWithScore), + return streamConverter(response, (chunk) => + EngineResponse.fromResponse(chunk, true, nodesWithScore), ); } const response = await this.responseBuilder.getResponse({ query, textChunks, }); - return new Response(response, nodesWithScore); + return EngineResponse.fromResponse(response, false, nodesWithScore); } } diff --git a/packages/core/src/synthesizers/types.ts b/packages/core/src/synthesizers/types.ts index abfcd9b16b..67e998a6ae 100644 --- a/packages/core/src/synthesizers/types.ts +++ b/packages/core/src/synthesizers/types.ts @@ -1,6 +1,6 @@ +import type { EngineResponse } from "../EngineResponse.js"; import type { NodeWithScore } from "../Node.js"; import type { PromptMixin } from "../prompts/Mixin.js"; -import type { Response } from "../Response.js"; export interface SynthesizeParamsBase { query: string; @@ -21,8 +21,8 @@ export interface SynthesizeParamsNonStreaming extends SynthesizeParamsBase { export interface BaseSynthesizer { synthesize( params: SynthesizeParamsStreaming, - ): Promise>; - synthesize(params: SynthesizeParamsNonStreaming): Promise; + ): Promise>; + synthesize(params: SynthesizeParamsNonStreaming): Promise; } export interface ResponseBuilderParamsBase { diff --git a/packages/core/src/types.ts b/packages/core/src/types.ts index f187649431..c1f1292827 100644 --- a/packages/core/src/types.ts +++ b/packages/core/src/types.ts @@ -2,7 +2,7 @@ * Top level types to avoid circular dependencies */ import { type JSONSchemaType } from "ajv"; -import type { Response } from 
"./Response.js"; +import type { EngineResponse } from "./EngineResponse.js"; /** * Parameters for sending a query. @@ -27,8 +27,10 @@ export interface QueryEngine { * Query the query engine and get a response. * @param params */ - query(params: QueryEngineParamsStreaming): Promise>; - query(params: QueryEngineParamsNonStreaming): Promise; + query( + params: QueryEngineParamsStreaming, + ): Promise>; + query(params: QueryEngineParamsNonStreaming): Promise; } type Known = diff --git a/packages/experimental/src/engines/query/JSONQueryEngine.ts b/packages/experimental/src/engines/query/JSONQueryEngine.ts index c911407968..4f5e469e51 100644 --- a/packages/experimental/src/engines/query/JSONQueryEngine.ts +++ b/packages/experimental/src/engines/query/JSONQueryEngine.ts @@ -1,6 +1,6 @@ import jsonpath from "jsonpath"; -import { Response } from "llamaindex"; +import { EngineResponse } from "llamaindex"; import { serviceContextFromDefaults, type ServiceContext } from "llamaindex"; @@ -147,11 +147,13 @@ export class JSONQueryEngine implements QueryEngine { return JSON.stringify(this.jsonSchema); } - query(params: QueryEngineParamsStreaming): Promise>; - query(params: QueryEngineParamsNonStreaming): Promise; + query( + params: QueryEngineParamsStreaming, + ): Promise>; + query(params: QueryEngineParamsNonStreaming): Promise; async query( params: QueryEngineParamsStreaming | QueryEngineParamsNonStreaming, - ): Promise> { + ): Promise> { const { query, stream } = params; if (stream) { @@ -200,7 +202,7 @@ export class JSONQueryEngine implements QueryEngine { jsonPathResponseStr, }; - const response = new Response(responseStr, []); + const response = EngineResponse.fromResponse(responseStr, false); response.metadata = responseMetadata;