Skip to content

Commit

Permalink
refactor: unify response and agent response (run-llama#930)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcusschiesser authored Jun 17, 2024
1 parent 834f492 commit 436bc41
Show file tree
Hide file tree
Showing 29 changed files with 217 additions and 195 deletions.
5 changes: 5 additions & 0 deletions .changeset/five-ants-watch.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"llamaindex": minor
---

Unify chat engine response and agent response
2 changes: 1 addition & 1 deletion examples/agent/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async function main() {
message: "How much is 5 + 5? then divide by 2",
});

console.log(response.response.message);
console.log(response.message);
}

void main().then(() => {
Expand Down
2 changes: 1 addition & 1 deletion examples/agent/react_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ async function main() {
});

// Chat with the agent
const { response } = await agent.chat({
const response = await agent.chat({
message: "Divide 16 by 2 then add 20",
});

Expand Down
2 changes: 1 addition & 1 deletion examples/agent/step_wise_query_tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async function main() {
tools: [queryEngineTool],
});

const { response } = await agent.chat({
const response = await agent.chat({
message: "What was his salary?",
});

Expand Down
4 changes: 1 addition & 3 deletions examples/agent/stream_openai_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,7 @@ async function main() {

console.log("Response:");

for await (const {
response: { delta },
} of stream) {
for await (const { delta } of stream) {
process.stdout.write(delta);
}
}
Expand Down
4 changes: 1 addition & 3 deletions examples/agent/wiki.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ async function main() {
stream: true,
});

for await (const {
response: { delta },
} of response) {
for await (const { delta } of response) {
process.stdout.write(delta);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export default {
// @ts-expect-error: see https://github.com/cloudflare/workerd/issues/2067
new TransformStream({
transform: (chunk, controller) => {
controller.enqueue(textEncoder.encode(chunk.response.delta));
controller.enqueue(textEncoder.encode(chunk.delta));
},
}),
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ export async function chatWithAgent(
uiStream.update("response:");
},
write: async (message) => {
uiStream.append(message.response.delta);
uiStream.append(message.delta);
},
}),
)
Expand Down
19 changes: 9 additions & 10 deletions packages/core/e2e/node/claude.e2e.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { consola } from "consola";
import { Anthropic, FunctionTool, Settings, type LLM } from "llamaindex";
import { AnthropicAgent } from "llamaindex/agent/anthropic";
import { extractText } from "llamaindex/llm/utils";
import { ok, strictEqual } from "node:assert";
import { ok } from "node:assert";
import { beforeEach, test } from "node:test";
import { getWeatherTool, sumNumbersTool } from "./fixtures/tools.js";
import { mockLLMEvent } from "./utils.js";
Expand Down Expand Up @@ -71,12 +71,11 @@ await test("anthropic agent", async (t) => {
},
],
});
const { response, sources } = await agent.chat({
const response = await agent.chat({
message: "What is the weather in San Francisco?",
});
consola.debug("response:", response.message.content);

strictEqual(sources.length, 1);
ok(extractText(response.message.content).includes("35"));
});

Expand Down Expand Up @@ -110,7 +109,7 @@ await test("anthropic agent", async (t) => {
const agent = new AnthropicAgent({
tools: [showUniqueId],
});
const { response } = await agent.chat({
const response = await agent.chat({
message: "My name is Alex Yang. What is my unique id?",
});
consola.debug("response:", response.message.content);
Expand All @@ -122,7 +121,7 @@ await test("anthropic agent", async (t) => {
tools: [sumNumbersTool],
});

const { response } = await anthropicAgent.chat({
const response = await anthropicAgent.chat({
message: "how much is 1 + 1?",
});

Expand All @@ -137,35 +136,35 @@ await test("anthropic agent with multiple chat", async (t) => {
tools: [getWeatherTool],
});
{
const { response } = await agent.chat({
const response = await agent.chat({
message: 'Hello? Response to me "Yes"',
});
consola.debug("response:", response.message.content);
ok(extractText(response.message.content).includes("Yes"));
}
{
const { response } = await agent.chat({
const response = await agent.chat({
message: 'Hello? Response to me "No"',
});
consola.debug("response:", response.message.content);
ok(extractText(response.message.content).includes("No"));
}
{
const { response } = await agent.chat({
const response = await agent.chat({
message: 'Hello? Response to me "Maybe"',
});
consola.debug("response:", response.message.content);
ok(extractText(response.message.content).includes("Maybe"));
}
{
const { response } = await agent.chat({
const response = await agent.chat({
message: "What is the weather in San Francisco?",
});
consola.debug("response:", response.message.content);
ok(extractText(response.message.content).includes("72"));
}
{
const { response } = await agent.chat({
const response = await agent.chat({
message: "What is the weather in Shanghai?",
});
consola.debug("response:", response.message.content);
Expand Down
24 changes: 8 additions & 16 deletions packages/core/e2e/node/openai.e2e.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import {
SummaryIndex,
VectorStoreIndex,
type LLM,
type ToolOutput,
} from "llamaindex";
import { extractText } from "llamaindex/llm/utils";
import { ok, strictEqual } from "node:assert";
Expand Down Expand Up @@ -93,7 +92,7 @@ await test("gpt-4-turbo", async (t) => {
},
],
});
const { response } = await agent.chat({
const response = await agent.chat({
message: "What is the weather in San Jose?",
});
consola.debug("response:", response.message.content);
Expand All @@ -109,7 +108,7 @@ await test("agent system prompt", async (t) => {
systemPrompt:
"You are a pirate. You MUST speak every words staring with a 'Arhgs'",
});
const { response } = await agent.chat({
const response = await agent.chat({
message: "What is the weather in San Francisco?",
});
consola.debug("response:", response.message.content);
Expand Down Expand Up @@ -187,7 +186,7 @@ For questions about more specific sections, please use the vector_tool.`,
});

strictEqual(mockCall.mock.callCount(), 0);
const { response } = await agent.chat({
const response = await agent.chat({
message:
"What's the summary of Alex? Does he live in Brazil based on the brief information? Return yes or no.",
});
Expand Down Expand Up @@ -224,12 +223,11 @@ await test("agent with object function call", async (t) => {
),
],
});
const { response, sources } = await agent.chat({
const response = await agent.chat({
message: "What is the weather in San Francisco?",
});
consola.debug("response:", response.message.content);

strictEqual(sources.length, 1);
ok(extractText(response.message.content).includes("72"));
});
});
Expand Down Expand Up @@ -257,12 +255,11 @@ await test("agent", async (t) => {
},
],
});
const { response, sources } = await agent.chat({
const response = await agent.chat({
message: "What is the weather in San Francisco?",
});
consola.debug("response:", response.message.content);

strictEqual(sources.length, 1);
ok(extractText(response.message.content).includes("35"));
});

Expand Down Expand Up @@ -296,10 +293,9 @@ await test("agent", async (t) => {
const agent = new OpenAIAgent({
tools: [showUniqueId],
});
const { response, sources } = await agent.chat({
const response = await agent.chat({
message: "My name is Alex Yang. What is my unique id?",
});
strictEqual(sources.length, 1);
ok(extractText(response.message.content).includes(uniqueId));
});

Expand All @@ -308,11 +304,10 @@ await test("agent", async (t) => {
tools: [sumNumbersTool],
});

const { response, sources } = await openaiAgent.chat({
const response = await openaiAgent.chat({
message: "how much is 1 + 1?",
});

strictEqual(sources.length, 1);
ok(extractText(response.message.content).includes("2"));
});
});
Expand All @@ -333,15 +328,12 @@ await test("agent stream", async (t) => {
});

let message = "";
let soruces: ToolOutput[] = [];

for await (const { response, sources: _sources } of stream) {
for await (const response of stream) {
message += response.delta;
soruces = _sources;
}

strictEqual(fn.mock.callCount(), 2);
strictEqual(soruces.length, 2);
ok(message.includes("28"));
Settings.callbackManager.off("llm-tool-call", fn);
});
Expand Down
4 changes: 2 additions & 2 deletions packages/core/e2e/node/react.e2e.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ await test("react agent", async (t) => {
const agent = new ReActAgent({
tools: [getWeatherTool],
});
const { response } = await agent.chat({
const response = await agent.chat({
stream: false,
message: "What is the weather like in San Francisco?",
});
Expand All @@ -41,7 +41,7 @@ await test("react agent stream", async (t) => {
});

let content = "";
for await (const { response } of stream) {
for await (const response of stream) {
content += response.delta;
}
ok(content.includes("72"));
Expand Down
90 changes: 90 additions & 0 deletions packages/core/src/EngineResponse.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import type { NodeWithScore } from "./Node.js";
import type {
ChatMessage,
ChatResponse,
ChatResponseChunk,
} from "./llm/types.js";
import { extractText } from "./llm/utils.js";

/**
 * Unified response type returned by chat engines, query engines and agents.
 *
 * Implements both {@link ChatResponse} and {@link ChatResponseChunk} so that a
 * single type can represent a complete (non-streaming) response as well as one
 * chunk of a streaming response. Which of the two it is, is tracked by the
 * private `#stream` flag set at construction time.
 */
export class EngineResponse implements ChatResponse, ChatResponseChunk {
  /** Retrieved nodes that grounded this response, when produced by a query engine. */
  sourceNodes?: NodeWithScore[];

  /** Free-form metadata attached to the response. */
  metadata: Record<string, unknown> = {};

  /** The assistant message carrying the response content. */
  message: ChatMessage;
  /** Raw provider payload, or null when the response was synthesized from a string. */
  raw: object | null;

  // True when this instance represents a streaming chunk; gates the `delta` getter.
  #stream: boolean;

  // Construction goes through the static factories below; the constructor is
  // private so callers cannot create an instance with an inconsistent #stream flag.
  private constructor(
    chatResponse: ChatResponse,
    stream: boolean,
    sourceNodes?: NodeWithScore[],
  ) {
    this.message = chatResponse.message;
    this.raw = chatResponse.raw;
    this.sourceNodes = sourceNodes;
    this.#stream = stream;
  }

  /**
   * Build an {@link EngineResponse} from a plain response string.
   *
   * @param response - The textual response content.
   * @param stream - Whether this instance represents a streaming chunk.
   * @param sourceNodes - Optional retrieved nodes that grounded the response.
   */
  static fromResponse(
    response: string,
    stream: boolean,
    sourceNodes?: NodeWithScore[],
  ): EngineResponse {
    return new EngineResponse(
      EngineResponse.toChatResponse(response),
      stream,
      sourceNodes,
    );
  }

  // Wrap a bare string (and optional raw payload) into a ChatResponse with the
  // "assistant" role, the shape the constructor expects.
  private static toChatResponse(
    response: string,
    raw: object | null = null,
  ): ChatResponse {
    return {
      message: {
        content: response,
        role: "assistant",
      },
      raw,
    };
  }

  /**
   * Build a non-streaming {@link EngineResponse} from a complete {@link ChatResponse}.
   *
   * @param chatResponse - The complete LLM chat response.
   * @param sourceNodes - Optional retrieved nodes that grounded the response.
   */
  static fromChatResponse(
    chatResponse: ChatResponse,
    sourceNodes?: NodeWithScore[],
  ): EngineResponse {
    return new EngineResponse(chatResponse, false, sourceNodes);
  }

  /**
   * Build a streaming {@link EngineResponse} from a single {@link ChatResponseChunk}.
   *
   * @param chunk - One streamed chunk; its `delta` becomes the message content.
   * @param sourceNodes - Optional retrieved nodes that grounded the response.
   */
  static fromChatResponseChunk(
    chunk: ChatResponseChunk,
    sourceNodes?: NodeWithScore[],
  ): EngineResponse {
    return new EngineResponse(
      // Use the class name rather than static `this` so the target of the call
      // is unambiguous even if the class is ever subclassed.
      EngineResponse.toChatResponse(chunk.delta, chunk.raw),
      true,
      sourceNodes,
    );
  }

  /**
   * The response content as plain text.
   *
   * @deprecated Use {@link message} instead.
   */
  get response(): string {
    return extractText(this.message.content);
  }

  /**
   * The streamed text delta. Only meaningful for streaming responses; for
   * non-streaming instances a warning is logged and the full text is returned.
   */
  get delta(): string {
    if (!this.#stream) {
      console.warn(
        "delta is only available for streaming responses. Consider using 'message' instead.",
      );
    }
    return extractText(this.message.content);
  }

  toString() {
    // `response` is typed `string` and never nullish, so no fallback is needed.
    return this.response;
  }
}
23 changes: 0 additions & 23 deletions packages/core/src/Response.ts

This file was deleted.

Loading

0 comments on commit 436bc41

Please sign in to comment.