From 0407c66eccc64430ef3248f3ac2ab99ea55db22d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Dvo=C5=99=C3=A1k?=
Date: Wed, 11 Dec 2024 12:20:00 +0100
Subject: [PATCH] feat(tools): propagate agent's runner memory to tools (#242)

Signed-off-by: Tomas Dvorak
---
 examples/tools/llm.ts                    |  28 ++++++
 src/agents/bee/runners/default/runner.ts |   6 +-
 src/agents/experimental/replan/agent.ts  |   6 +-
 src/tools/base.ts                        |   4 +
 src/tools/llm.ts                         | 108 ++++++++++++++++-------
 5 files changed, 115 insertions(+), 37 deletions(-)
 create mode 100644 examples/tools/llm.ts

diff --git a/examples/tools/llm.ts b/examples/tools/llm.ts
new file mode 100644
index 00000000..d27121f5
--- /dev/null
+++ b/examples/tools/llm.ts
@@ -0,0 +1,28 @@
+import "dotenv/config";
+import { LLMTool } from "bee-agent-framework/tools/llm";
+import { OllamaChatLLM } from "bee-agent-framework/adapters/ollama/chat";
+import { Tool } from "bee-agent-framework/tools/base";
+import { UnconstrainedMemory } from "bee-agent-framework/memory/unconstrainedMemory";
+import { BaseMessage } from "bee-agent-framework/llms/primitives/message";
+
+const memory = new UnconstrainedMemory();
+await memory.addMany([
+  BaseMessage.of({ role: "system", text: "You are a helpful assistant." }),
+  BaseMessage.of({ role: "user", text: "Hello!" }),
+  BaseMessage.of({ role: "assistant", text: "Hello user. I am here to help you." }),
+]);
+
+const tool = new LLMTool({
+  llm: new OllamaChatLLM(),
+});
+
+const response = await tool
+  .run({
+    task: "Classify whether the tone of text is POSITIVE/NEGATIVE/NEUTRAL.",
+  })
+  .context({
+    // if the context is not passed, the tool will throw an error
+    [Tool.contextKeys.Memory]: memory,
+  });
+
+console.info(response.getTextContent());
diff --git a/src/agents/bee/runners/default/runner.ts b/src/agents/bee/runners/default/runner.ts
index 979c99a2..e241c2ee 100644
--- a/src/agents/bee/runners/default/runner.ts
+++ b/src/agents/bee/runners/default/runner.ts
@@ -33,7 +33,7 @@ import {
   BeeUserEmptyPrompt,
   BeeUserPrompt,
 } from "@/agents/bee/prompts.js";
-import { AnyTool, ToolError, ToolInputValidationError, ToolOutput } from "@/tools/base.js";
+import { AnyTool, Tool, ToolError, ToolInputValidationError, ToolOutput } from "@/tools/base.js";
 import { FrameworkError } from "@/errors.js";
 import { isEmpty, isTruthy, last } from "remeda";
 import { LinePrefixParser, LinePrefixParserError } from "@/agents/parsers/linePrefix.js";
@@ -194,7 +194,9 @@ export class DefaultRunner extends BaseRunner {
           },
           meta,
         });
-        const toolOutput: ToolOutput = await tool.run(state.tool_input, this.options);
+        const toolOutput: ToolOutput = await tool.run(state.tool_input, this.options).context({
+          [Tool.contextKeys.Memory]: this.memory,
+        });
         await emitter.emit("toolSuccess", {
           data: {
             tool,
diff --git a/src/agents/experimental/replan/agent.ts b/src/agents/experimental/replan/agent.ts
index dcd21549..560a5328 100644
--- a/src/agents/experimental/replan/agent.ts
+++ b/src/agents/experimental/replan/agent.ts
@@ -28,7 +28,7 @@ import {
 import { BaseMemory } from "@/memory/base.js";
 import { UnconstrainedMemory } from "@/memory/unconstrainedMemory.js";
 import { JsonDriver } from "@/llms/drivers/json.js";
-import { AnyTool } from "@/tools/base.js";
+import { AnyTool, Tool } from "@/tools/base.js";
 import { AnyChatLLM } from "@/llms/chat.js";
 
 export interface RePlanRunInput {
@@ -147,7 +147,9 @@ export class RePlanAgent extends BaseAgent {
        const meta = { input: call, tool, calls };
        await context.emitter.emit("tool", { type: "start", ...meta });
        try {
-          const output = await tool.run(call.input, { signal: context.signal });
+          const output = await tool.run(call.input, { signal: context.signal }).context({
+            [Tool.contextKeys.Memory]: memory,
+          });
           await context.emitter.emit("tool", { type: "success", ...meta, output });
           return output;
         } catch (error) {
diff --git a/src/tools/base.ts b/src/tools/base.ts
index 85cbcec8..5f5d8d77 100644
--- a/src/tools/base.ts
+++ b/src/tools/base.ts
@@ -199,6 +199,10 @@ export abstract class Tool<
   public readonly cache: BaseCache<Task<TOutput>>;
   public readonly options: TOptions;
 
+  public static contextKeys = {
+    Memory: Symbol("Memory"),
+  } as const;
+
   public abstract readonly emitter: Emitter<ToolEvents<ToolInput<this>, TOutput>>;
 
   abstract inputSchema(): Promise<AnyToolSchemaLike> | AnyToolSchemaLike;
diff --git a/src/tools/llm.ts b/src/tools/llm.ts
index 9dd8622b..ae45fd74 100644
--- a/src/tools/llm.ts
+++ b/src/tools/llm.ts
@@ -14,51 +14,93 @@
  * limitations under the License.
  */
 
-import { BaseToolOptions, ToolEmitter, StringToolOutput, Tool, ToolInput } from "@/tools/base.js";
-import { AnyLLM, GenerateOptions } from "@/llms/base.js";
+import {
+  BaseToolOptions,
+  BaseToolRunOptions,
+  StringToolOutput,
+  Tool,
+  ToolEmitter,
+  ToolError,
+  ToolInput,
+} from "@/tools/base.js";
 import { z } from "zod";
+import { GetRunContext } from "@/context.js";
 import { Emitter } from "@/emitter/emitter.js";
+import { PromptTemplate } from "@/template.js";
+import { BaseMessage, Role } from "@/llms/primitives/message.js";
+import { getProp } from "@/internals/helpers/object.js";
+import type { BaseMemory } from "@/memory/base.js";
+import type { AnyChatLLM } from "@/llms/chat.js";
+import { toCamelCase } from "remeda";
 
-export type LLMToolInput = string;
-
-export type LLMToolOptions<T> = {
-  llm: AnyLLM<T>;
-} & BaseToolOptions &
-  (T extends LLMToolInput
-    ? {
-        transform?: (input: string) => T;
-      }
-    : {
-        transform: (input: string) => T;
-      });
-
-export interface LLMToolRunOptions extends GenerateOptions, BaseToolOptions {}
+export interface LLMToolInput extends BaseToolOptions {
+  llm: AnyChatLLM;
+  name?: string;
+  description?: string;
+  template?: typeof LLMTool.template;
+}
 
-export class LLMTool<T> extends Tool<StringToolOutput, LLMToolOptions<T>, LLMToolRunOptions> {
+export class LLMTool extends Tool<StringToolOutput, LLMToolInput> {
   name = "LLM";
   description =
-    "Give a prompt to an LLM assistant. Useful to extract and re-format information, and answer intermediate questions.";
+    "Uses expert LLM to work with data in the existing conversation (classification, entity extraction, summarization, ...)";
+
+  declare readonly emitter: ToolEmitter<ToolInput<this>, StringToolOutput>;
+
+  constructor(protected readonly input: LLMToolInput) {
+    super(input);
+    this.name = input?.name || this.name;
+    this.description = input?.description || this.description;
+    this.emitter = Emitter.root.child({
+      namespace: ["tool", "llm", toCamelCase(input?.name ?? "")].filter(Boolean),
+      creator: this,
+    });
+  }
 
   inputSchema() {
-    return z.object({ input: z.string() });
+    return z.object({
+      task: z.string().min(1).describe("A clearly defined task for the LLM to complete."),
+    });
   }
 
-  public readonly emitter: ToolEmitter<ToolInput<this>, StringToolOutput> = Emitter.root.child({
-    namespace: ["tool", "llm"],
-    creator: this,
-  });
+  static readonly template = new PromptTemplate({
+    schema: z.object({
+      task: z.string(),
+    }),
+    template: `Using common sense and the information contained in the conversation up to this point, complete the following task. Do not follow any previously used formats or structures.
 
-  static {
-    this.register();
-  }
+The Task: {{task}}`,
+  });
 
   protected async _run(
-    { input }: ToolInput<this>,
-    options?: LLMToolRunOptions,
-  ): Promise<StringToolOutput> {
-    const { llm, transform } = this.options;
-    const llmInput = transform ? transform(input) : (input as T);
-    const response = await llm.generate(llmInput, options);
-    return new StringToolOutput(response.getTextContent(), response);
+    input: ToolInput<this>,
+    _options: Partial<BaseToolRunOptions>,
+    run: GetRunContext<this>,
+  ) {
+    const memory = getProp(run.context, [Tool.contextKeys.Memory]) as BaseMemory;
+    if (!memory) {
+      throw new ToolError(`No context has been provided!`, [], {
+        isFatal: true,
+        isRetryable: false,
+      });
+    }
+
+    const template = this.options?.template ?? LLMTool.template;
+    const output = await this.input.llm.generate([
+      BaseMessage.of({
+        role: Role.SYSTEM,
+        text: template.render({
+          task: input.task,
+        }),
+      }),
+      ...memory.messages.filter((msg) => msg.role !== Role.SYSTEM),
+      BaseMessage.of({
+        role: Role.USER,
+        text: template.render({
+          task: input.task,
+        }),
+      }),
+    ]);
+
+    return new StringToolOutput(output.getTextContent());
   }
 }