feat(tools): propagate agent's runner memory to tools (#242)

Signed-off-by: Tomas Dvorak <toomas2d@gmail.com>

Tomas2D authored Dec 11, 2024
1 parent ca2e09b commit 0407c66

Showing 5 changed files with 115 additions and 37 deletions.
28 changes: 28 additions & 0 deletions examples/tools/llm.ts
@@ -0,0 +1,28 @@
+import "dotenv/config";
+import { LLMTool } from "bee-agent-framework/tools/llm";
+import { OllamaChatLLM } from "bee-agent-framework/adapters/ollama/chat";
+import { Tool } from "bee-agent-framework/tools/base";
+import { UnconstrainedMemory } from "bee-agent-framework/memory/unconstrainedMemory";
+import { BaseMessage } from "bee-agent-framework/llms/primitives/message";
+
+const memory = new UnconstrainedMemory();
+await memory.addMany([
+  BaseMessage.of({ role: "system", text: "You are a helpful assistant." }),
+  BaseMessage.of({ role: "user", text: "Hello!" }),
+  BaseMessage.of({ role: "assistant", text: "Hello user. I am here to help you." }),
+]);
+
+const tool = new LLMTool({
+  llm: new OllamaChatLLM(),
+});
+
+const response = await tool
+  .run({
+    task: "Classify whether the tone of text is POSITIVE/NEGATIVE/NEUTRAL.",
+  })
+  .context({
+    // if the context is not passed, the tool will throw an error
+    [Tool.contextKeys.Memory]: memory,
+  });
+
+console.info(response.getTextContent());
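
The explicit .context(...) call above is only needed when the tool runs standalone; inside an agent, the runner now injects its own memory (see the runner diff below). A minimal sketch of that flow, assuming the BeeAgent setup used in the framework's other examples — the prompt text is illustrative:

import "dotenv/config";
import { BeeAgent } from "bee-agent-framework/agents/bee/agent";
import { LLMTool } from "bee-agent-framework/tools/llm";
import { OllamaChatLLM } from "bee-agent-framework/adapters/ollama/chat";
import { UnconstrainedMemory } from "bee-agent-framework/memory/unconstrainedMemory";

const llm = new OllamaChatLLM();
const agent = new BeeAgent({
  llm,
  memory: new UnconstrainedMemory(),
  // The agent's runner attaches its memory to every tool run (see runner.ts below),
  // so LLMTool receives the conversation without a manual .context() call.
  tools: [new LLMTool({ llm })],
});

const response = await agent.run({ prompt: "Summarize our conversation so far." });
console.info(response.result.text);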
6 changes: 4 additions & 2 deletions src/agents/bee/runners/default/runner.ts
@@ -33,7 +33,7 @@ import {
   BeeUserEmptyPrompt,
   BeeUserPrompt,
 } from "@/agents/bee/prompts.js";
-import { AnyTool, ToolError, ToolInputValidationError, ToolOutput } from "@/tools/base.js";
+import { AnyTool, Tool, ToolError, ToolInputValidationError, ToolOutput } from "@/tools/base.js";
 import { FrameworkError } from "@/errors.js";
 import { isEmpty, isTruthy, last } from "remeda";
 import { LinePrefixParser, LinePrefixParserError } from "@/agents/parsers/linePrefix.js";
@@ -194,7 +194,9 @@ export class DefaultRunner extends BaseRunner {
         },
         meta,
       });
-      const toolOutput: ToolOutput = await tool.run(state.tool_input, this.options);
+      const toolOutput: ToolOutput = await tool.run(state.tool_input, this.options).context({
+        [Tool.contextKeys.Memory]: this.memory,
+      });
       await emitter.emit("toolSuccess", {
         data: {
           tool,
6 changes: 4 additions & 2 deletions src/agents/experimental/replan/agent.ts
@@ -28,7 +28,7 @@ import {
 import { BaseMemory } from "@/memory/base.js";
 import { UnconstrainedMemory } from "@/memory/unconstrainedMemory.js";
 import { JsonDriver } from "@/llms/drivers/json.js";
-import { AnyTool } from "@/tools/base.js";
+import { AnyTool, Tool } from "@/tools/base.js";
 import { AnyChatLLM } from "@/llms/chat.js";

 export interface RePlanRunInput {
@@ -147,7 +147,9 @@ export class RePlanAgent extends BaseAgent<RePlanRunInput, RePlanRunOutput> {
       const meta = { input: call, tool, calls };
       await context.emitter.emit("tool", { type: "start", ...meta });
       try {
-        const output = await tool.run(call.input, { signal: context.signal });
+        const output = await tool.run(call.input, { signal: context.signal }).context({
+          [Tool.contextKeys.Memory]: memory,
+        });
         await context.emitter.emit("tool", { type: "success", ...meta, output });
         return output;
       } catch (error) {
4 changes: 4 additions & 0 deletions src/tools/base.ts
@@ -199,6 +199,10 @@ export abstract class Tool<
   public readonly cache: BaseCache<Task<TOutput>>;
   public readonly options: TOptions;

+  public static contextKeys = {
+    Memory: Symbol("Memory"),
+  } as const;
+
   public abstract readonly emitter: Emitter<ToolEvents<any, TOutput>>;

   abstract inputSchema(): Promise<AnyToolSchemaLike> | AnyToolSchemaLike;
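
With this key in place, any custom tool can read the propagated memory from its run context — a minimal sketch against the new API; the tool name and behavior are hypothetical, and the lookup mirrors what src/tools/llm.ts does below:

import { z } from "zod";
import {
  BaseToolRunOptions,
  StringToolOutput,
  Tool,
  ToolEmitter,
  ToolError,
  ToolInput,
} from "bee-agent-framework/tools/base";
import { Emitter } from "bee-agent-framework/emitter/emitter";
import { GetRunContext } from "bee-agent-framework/context";
import type { BaseMemory } from "bee-agent-framework/memory/base";

// Hypothetical example tool: counts the messages in the agent's conversation.
export class ConversationStatsTool extends Tool<StringToolOutput> {
  name = "ConversationStats";
  description = "Returns the number of messages in the current conversation.";

  public readonly emitter: ToolEmitter<ToolInput<this>, StringToolOutput> = Emitter.root.child({
    namespace: ["tool", "conversationStats"],
    creator: this,
  });

  inputSchema() {
    return z.object({});
  }

  protected async _run(
    _input: ToolInput<this>,
    _options: Partial<BaseToolRunOptions>,
    run: GetRunContext<this>,
  ) {
    // src/tools/llm.ts uses the internal getProp helper; direct symbol indexing works too.
    const memory = (run.context as Record<symbol, any>)[Tool.contextKeys.Memory] as
      | BaseMemory
      | undefined;
    if (!memory) {
      throw new ToolError("Memory was not provided in the run context!");
    }
    return new StringToolOutput(`The conversation has ${memory.messages.length} messages.`);
  }
}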
108 changes: 75 additions & 33 deletions src/tools/llm.ts
@@ -14,51 +14,93 @@
  * limitations under the License.
  */

-import { BaseToolOptions, ToolEmitter, StringToolOutput, Tool, ToolInput } from "@/tools/base.js";
-import { AnyLLM, GenerateOptions } from "@/llms/base.js";
+import {
+  BaseToolOptions,
+  BaseToolRunOptions,
+  StringToolOutput,
+  Tool,
+  ToolEmitter,
+  ToolError,
+  ToolInput,
+} from "@/tools/base.js";
 import { z } from "zod";
+import { GetRunContext } from "@/context.js";
 import { Emitter } from "@/emitter/emitter.js";
+import { PromptTemplate } from "@/template.js";
+import { BaseMessage, Role } from "@/llms/primitives/message.js";
+import { getProp } from "@/internals/helpers/object.js";
+import type { BaseMemory } from "@/memory/base.js";
+import type { AnyChatLLM } from "@/llms/chat.js";
+import { toCamelCase } from "remeda";

-export type LLMToolInput = string;
-
-export type LLMToolOptions<T> = {
-  llm: AnyLLM<T>;
-} & BaseToolOptions &
-  (T extends LLMToolInput
-    ? {
-        transform?: (input: string) => T;
-      }
-    : {
-        transform: (input: string) => T;
-      });
-
-export interface LLMToolRunOptions extends GenerateOptions, BaseToolOptions {}
+export interface LLMToolInput extends BaseToolOptions {
+  llm: AnyChatLLM;
+  name?: string;
+  description?: string;
+  template?: typeof LLMTool.template;
+}

-export class LLMTool<T> extends Tool<StringToolOutput, LLMToolOptions<T>, LLMToolRunOptions> {
+export class LLMTool extends Tool<StringToolOutput, LLMToolInput> {
   name = "LLM";
   description =
-    "Give a prompt to an LLM assistant. Useful to extract and re-format information, and answer intermediate questions.";
+    "Uses expert LLM to work with data in the existing conversation (classification, entity extraction, summarization, ...)";
+
+  declare readonly emitter: ToolEmitter<ToolInput<this>, StringToolOutput>;
+
+  constructor(protected readonly input: LLMToolInput) {
+    super(input);
+    this.name = input?.name || this.name;
+    this.description = input?.description || this.description;
+    this.emitter = Emitter.root.child({
+      namespace: ["tool", "llm", toCamelCase(input?.name ?? "")].filter(Boolean),
+      creator: this,
+    });
+  }

   inputSchema() {
-    return z.object({ input: z.string() });
+    return z.object({
+      task: z.string().min(1).describe("A clearly defined task for the LLM to complete."),
+    });
   }

-  public readonly emitter: ToolEmitter<ToolInput<this>, StringToolOutput> = Emitter.root.child({
-    namespace: ["tool", "llm"],
-    creator: this,
-  });
+  static readonly template = new PromptTemplate({
+    schema: z.object({
+      task: z.string(),
+    }),
+    template: `Using common sense and the information contained in the conversation up to this point, complete the following task. Do not follow any previously used formats or structures.
+
+The Task: {{task}}`,
+  });

   static {
     this.register();
   }

   protected async _run(
-    { input }: ToolInput<this>,
-    options?: LLMToolRunOptions,
-  ): Promise<StringToolOutput> {
-    const { llm, transform } = this.options;
-    const llmInput = transform ? transform(input) : (input as T);
-    const response = await llm.generate(llmInput, options);
-    return new StringToolOutput(response.getTextContent(), response);
+    input: ToolInput<this>,
+    _options: Partial<BaseToolRunOptions>,
+    run: GetRunContext<this>,
+  ) {
+    const memory = getProp(run.context, [Tool.contextKeys.Memory]) as BaseMemory;
+    if (!memory) {
+      throw new ToolError(`No context has been provided!`, [], {
+        isFatal: true,
+        isRetryable: false,
+      });
+    }
+
+    const template = this.options?.template ?? LLMTool.template;
+    const output = await this.input.llm.generate([
+      BaseMessage.of({
+        role: Role.SYSTEM,
+        text: template.render({
+          task: input.task,
+        }),
+      }),
+      ...memory.messages.filter((msg) => msg.role !== Role.SYSTEM),
+      BaseMessage.of({
+        role: Role.USER,
+        text: template.render({
+          task: input.task,
+        }),
+      }),
+    ]);
+
+    return new StringToolOutput(output.getTextContent());
   }
 }
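
Because LLMToolInput now accepts optional name, description, and template overrides, the tool can be specialized without subclassing — a small sketch; the fork usage follows the framework's PromptTemplate conventions, and the concrete names and prompt text are illustrative:

import { LLMTool } from "bee-agent-framework/tools/llm";
import { OllamaChatLLM } from "bee-agent-framework/adapters/ollama/chat";

const summarizer = new LLMTool({
  llm: new OllamaChatLLM(),
  name: "Summarizer",
  description: "Summarizes the conversation up to this point.",
  // Fork the default template and swap in a task-specific instruction.
  template: LLMTool.template.fork((config) => {
    config.template = `Summarize the conversation so far, then complete the following task.

The Task: {{task}}`;
  }),
});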
