-
Notifications
You must be signed in to change notification settings - Fork 148
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: support multi agent for ts #300
Changes from 1 commit
413593b
622b84b
f464b40
0ebcb9f
f43f00a
6c05872
2c7a538
5daf519
b875618
b030a3d
33ce593
de5ba29
aff4f0c
c4041e2
f659721
54d74f8
d69cd42
054ee5b
7297edf
3ebc3ec
8cfabc5
305296b
ea3bbcf
325c7ca
45f7529
234b15e
32c3d89
7079b68
6ecd5f8
0679c37
fa45102
2fb502e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,8 @@ import { | |
Settings, | ||
ToolCall, | ||
ToolCallLLM, | ||
ToolCallLLMMessageOptions, | ||
callTool, | ||
} from "llamaindex"; | ||
import { AgentInput, AgentRunEvent } from "./type"; | ||
|
||
|
@@ -33,7 +35,6 @@ export class FunctionCallingAgent extends Workflow { | |
systemPrompt?: string; | ||
writeEvents: boolean; | ||
role?: string; | ||
toolCalled: boolean = false; | ||
|
||
constructor(options: { | ||
name: string; | ||
|
@@ -75,26 +76,18 @@ export class FunctionCallingAgent extends Workflow { | |
} | ||
|
||
private get chatHistory() { | ||
return this.memory.getAllMessages(); | ||
} | ||
|
||
private get toolsByName() { | ||
return this.tools.reduce((acc: Record<string, BaseToolWithCall>, tool) => { | ||
acc[tool.metadata.name] = tool; | ||
return acc; | ||
}, {}); | ||
return this.memory.getMessages(); | ||
} | ||
|
||
private async prepareChatHistory( | ||
ctx: Context, | ||
ev: StartEvent<AgentInput>, | ||
): Promise<InputEvent> { | ||
this.toolCalled = false; | ||
const { message, streaming } = ev.data.input; | ||
ctx.set("streaming", streaming); | ||
this.writeEvent(`Start to work on: ${message}`, ctx); | ||
if (this.systemPrompt) { | ||
this.memory.put({ role: "assistant", content: this.systemPrompt }); | ||
this.memory.put({ role: "system", content: this.systemPrompt }); | ||
} | ||
this.memory.put({ role: "user", content: message }); | ||
return new InputEvent({ input: this.chatHistory }); | ||
|
@@ -112,8 +105,10 @@ export class FunctionCallingAgent extends Workflow { | |
messages: this.chatHistory, | ||
tools: this.tools, | ||
}); | ||
this.memory.put(result.message); | ||
|
||
const toolCalls = this.getToolCallsFromResponse(result); | ||
if (toolCalls.length && !this.toolCalled) { | ||
if (toolCalls.length) { | ||
return new ToolCallEvent({ toolCalls }); | ||
} | ||
this.writeEvent("Finished task", ctx); | ||
|
@@ -151,7 +146,8 @@ export class FunctionCallingAgent extends Workflow { | |
if (fullResponse) { | ||
memory.put({ | ||
role: "assistant", | ||
content: fullResponse.delta, | ||
content: "", | ||
options: fullResponse.options, | ||
}); | ||
yield fullResponse; | ||
} | ||
|
@@ -162,7 +158,7 @@ export class FunctionCallingAgent extends Workflow { | |
if (isToolCall.value) { | ||
const fullResponse = await generator.next(); | ||
const toolCalls = this.getToolCallsFromResponse( | ||
fullResponse.value as ChatResponseChunk<object>, | ||
fullResponse.value as ChatResponseChunk<ToolCallLLMMessageOptions>, | ||
); | ||
return new ToolCallEvent({ toolCalls }); | ||
} | ||
|
@@ -175,48 +171,38 @@ export class FunctionCallingAgent extends Workflow { | |
ctx: Context, | ||
ev: ToolCallEvent, | ||
): Promise<InputEvent> { | ||
this.toolCalled = true; | ||
const { toolCalls } = ev.data; | ||
|
||
const toolMsgs: ChatMessage[] = []; | ||
for (const toolCall of toolCalls) { | ||
const tool = this.toolsByName[toolCall.name]; | ||
const options = { | ||
tool_call_id: toolCall.id, | ||
name: tool.metadata.name, | ||
}; | ||
if (!tool) { | ||
toolMsgs.push({ | ||
role: "assistant", | ||
content: `Tool ${toolCall.name} does not exist`, | ||
options, | ||
}); | ||
continue; | ||
} | ||
|
||
try { | ||
const toolInput = JSON.parse(toolCall.input.toString()); | ||
const toolOutput = await tool.call(toolInput); | ||
toolMsgs.push({ | ||
role: "assistant", | ||
content: toolOutput.toString(), | ||
options, | ||
}); | ||
} catch (e) { | ||
console.error(e); | ||
toolMsgs.push({ | ||
role: "assistant", | ||
content: `Encountered error in tool call: ${toolCall.name}`, | ||
options, | ||
}); | ||
} | ||
for (const call of toolCalls) { | ||
const targetTool = this.tools.find( | ||
(tool) => tool.metadata.name === call.name, | ||
); | ||
// TODO: make logger optional in callTool in framework | ||
const toolOutput = await callTool(targetTool, call, { | ||
log: () => {}, | ||
error: console.error.bind(console), | ||
warn: () => {}, | ||
}); | ||
toolMsgs.push({ | ||
content: JSON.stringify(toolOutput.output), | ||
role: "user", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, it will be transformed by LITS There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. was also confusing me |
||
options: { | ||
toolResult: { | ||
result: toolOutput.output, | ||
isError: toolOutput.isError, | ||
id: call.id, | ||
}, | ||
}, | ||
}); | ||
} | ||
|
||
for (const msg of toolMsgs) { | ||
this.memory.put(msg); | ||
} | ||
|
||
return new InputEvent({ input: this.memory.getAllMessages() }); | ||
return new InputEvent({ input: this.memory.getMessages() }); | ||
} | ||
|
||
private writeEvent(msg: string, context: Context) { | ||
|
@@ -231,10 +217,10 @@ export class FunctionCallingAgent extends Workflow { | |
if (!supportToolCall) throw new Error("LLM does not support tool calls"); | ||
} | ||
|
||
// TODO: in LITS, llm should have a method to get tool calls from response | ||
// then we don't need to use toolCalled flag | ||
private getToolCallsFromResponse( | ||
response: ChatResponse<object> | ChatResponseChunk<object>, | ||
response: | ||
| ChatResponse<ToolCallLLMMessageOptions> | ||
| ChatResponseChunk<ToolCallLLMMessageOptions>, | ||
): ToolCall[] { | ||
let options; | ||
if ("message" in response) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@himself65 using
callTool
from the framework is very helpful, but would be nice if the logger object would be optional