Skip to content

Commit

Permalink
Bugfix/Pass state to tool node for agents (#3139)
Browse files Browse the repository at this point in the history
pass state to tool node for agents
  • Loading branch information
HenryHengZJ authored Sep 3, 2024
1 parent 2a21f18 commit e691838
Showing 1 changed file with 30 additions and 2 deletions.
32 changes: 30 additions & 2 deletions packages/components/nodes/sequentialagents/Agent/Agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ import {
ISeqAgentNode,
IDatabaseEntity,
IUsedTool,
IDocument
IDocument,
IStateWithMessages
} from '../../../src/Interface'
import { ToolCallingAgentOutputParser, AgentExecutor, SOURCE_DOCUMENTS_PREFIX } from '../../../src/agents'
import { getInputVariables, getVars, handleEscapeCharacters, prepareSandboxVars } from '../../../src/utils'
Expand All @@ -34,6 +35,7 @@ import {
} from '../commonUtils'
import { END, StateGraph } from '@langchain/langgraph'
import { StructuredTool } from '@langchain/core/tools'
import { DynamicStructuredTool } from '../../tools/CustomTool/core'

const defaultApprovalPrompt = `You are about to execute tool: {tools}. Ask if user want to proceed`
const examplePrompt = 'You are a research assistant who can search for up-to-date info using search engine.'
Expand Down Expand Up @@ -904,18 +906,44 @@ class ToolNode<T extends BaseMessage[] | MessagesState> extends RunnableCallable
}

private async run(input: BaseMessage[] | MessagesState, config: RunnableConfig): Promise<BaseMessage[] | MessagesState> {
const message = Array.isArray(input) ? input[input.length - 1] : input.messages[input.messages.length - 1]
let messages: BaseMessage[]

// Check if input is an array of BaseMessage[]
if (Array.isArray(input)) {
messages = input
}
// Check if input is IStateWithMessages
else if ((input as IStateWithMessages).messages) {
messages = (input as IStateWithMessages).messages
}
// Handle MessagesState type
else {
messages = (input as MessagesState).messages
}

// Get the last message
const message = messages[messages.length - 1]

if (message._getType() !== 'ai') {
throw new Error('ToolNode only accepts AIMessages as input.')
}

// Extract all properties except messages for IStateWithMessages
const { messages: _, ...inputWithoutMessages } = Array.isArray(input) ? { messages: input } : input
const ChannelsWithoutMessages = {
state: inputWithoutMessages
}

const outputs = await Promise.all(
(message as AIMessage).tool_calls?.map(async (call) => {
const tool = this.tools.find((tool) => tool.name === call.name)
if (tool === undefined) {
throw new Error(`Tool ${call.name} not found.`)
}
if (tool && tool instanceof DynamicStructuredTool) {
// @ts-ignore
tool.setFlowObject(ChannelsWithoutMessages)
}
let output = await tool.invoke(call.args, config)
let sourceDocuments: Document[] = []
if (output?.includes(SOURCE_DOCUMENTS_PREFIX)) {
Expand Down

0 comments on commit e691838

Please sign in to comment.