forked from FlowiseAI/Flowise
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature/introducting-conversational-retrieval-tool-agent (FlowiseAI#2430
) * introducting openai-conversational-retriever-agent * fix lint * fix build * rename + update description * changing agent base from openai to tool agent * adding author for community agent
- Loading branch information
1 parent
9e8de50
commit a857119
Showing
3 changed files
with
287 additions
and
0 deletions.
There are no files selected for viewing
286 changes: 286 additions & 0 deletions
286
...ponents/nodes/agents/ConversationalRetrievalToolAgent/ConversationalRetrievalToolAgent.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,286 @@ | ||
import { flatten } from 'lodash' | ||
import { BaseMessage } from '@langchain/core/messages' | ||
import { ChainValues } from '@langchain/core/utils/types' | ||
import { RunnableSequence } from '@langchain/core/runnables' | ||
import { BaseChatModel } from '@langchain/core/language_models/chat_models' | ||
import { ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate, PromptTemplate } from '@langchain/core/prompts' | ||
import { formatToOpenAIToolMessages } from 'langchain/agents/format_scratchpad/openai_tools' | ||
import { getBaseClasses } from '../../../src/utils' | ||
import { type ToolsAgentStep } from 'langchain/agents/openai/output_parser' | ||
import { FlowiseMemory, ICommonObject, INode, INodeData, INodeParams, IUsedTool, IVisionChatModal } from '../../../src/Interface' | ||
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' | ||
import { AgentExecutor, ToolCallingAgentOutputParser } from '../../../src/agents' | ||
import { Moderation, checkInputs, streamResponse } from '../../moderation/Moderation' | ||
import { formatResponse } from '../../outputparsers/OutputParserHelpers' | ||
import type { Document } from '@langchain/core/documents' | ||
import { BaseRetriever } from '@langchain/core/retrievers' | ||
import { RESPONSE_TEMPLATE } from '../../chains/ConversationalRetrievalQAChain/prompts' | ||
import { addImagesToMessages, llmSupportsVision } from '../../../src/multiModalUtils' | ||
|
||
class ConversationalRetrievalToolAgent_Agents implements INode { | ||
label: string | ||
name: string | ||
author: string | ||
version: number | ||
description: string | ||
type: string | ||
icon: string | ||
category: string | ||
baseClasses: string[] | ||
inputs: INodeParams[] | ||
sessionId?: string | ||
badge?: string | ||
|
||
constructor(fields?: { sessionId?: string }) { | ||
this.label = 'Conversational Retrieval Tool Agent' | ||
this.name = 'conversationalRetrievalToolAgent' | ||
this.author = 'niztal(falkor)' | ||
this.version = 1.0 | ||
this.type = 'AgentExecutor' | ||
this.category = 'Agents' | ||
this.icon = 'toolAgent.png' | ||
this.description = `Agent that calls a vector store retrieval and uses Function Calling to pick the tools and args to call` | ||
this.baseClasses = [this.type, ...getBaseClasses(AgentExecutor)] | ||
this.badge = 'NEW' | ||
this.inputs = [ | ||
{ | ||
label: 'Tools', | ||
name: 'tools', | ||
type: 'Tool', | ||
list: true | ||
}, | ||
{ | ||
label: 'Memory', | ||
name: 'memory', | ||
type: 'BaseChatMemory' | ||
}, | ||
{ | ||
label: 'Tool Calling Chat Model', | ||
name: 'model', | ||
type: 'BaseChatModel', | ||
description: | ||
'Only compatible with models that are capable of function calling. ChatOpenAI, ChatMistral, ChatAnthropic, ChatVertexAI' | ||
}, | ||
{ | ||
label: 'System Message', | ||
name: 'systemMessage', | ||
type: 'string', | ||
description: 'Taking the rephrased question, search for answer from the provided context', | ||
warning: 'Prompt must include input variable: {context}', | ||
rows: 4, | ||
additionalParams: true, | ||
optional: true, | ||
default: RESPONSE_TEMPLATE | ||
}, | ||
{ | ||
label: 'Input Moderation', | ||
description: 'Detect text that could generate harmful output and prevent it from being sent to the language model', | ||
name: 'inputModeration', | ||
type: 'Moderation', | ||
optional: true, | ||
list: true | ||
}, | ||
{ | ||
label: 'Max Iterations', | ||
name: 'maxIterations', | ||
type: 'number', | ||
optional: true, | ||
additionalParams: true | ||
}, | ||
{ | ||
label: 'Vector Store Retriever', | ||
name: 'vectorStoreRetriever', | ||
type: 'BaseRetriever' | ||
} | ||
] | ||
this.sessionId = fields?.sessionId | ||
} | ||
|
||
async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> { | ||
return prepareAgent(nodeData, options, { sessionId: this.sessionId, chatId: options.chatId, input }) | ||
} | ||
|
||
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> { | ||
const memory = nodeData.inputs?.memory as FlowiseMemory | ||
const moderations = nodeData.inputs?.inputModeration as Moderation[] | ||
|
||
const isStreamable = options.socketIO && options.socketIOClientId | ||
|
||
if (moderations && moderations.length > 0) { | ||
try { | ||
// Use the output of the moderation chain as input for the OpenAI Function Agent | ||
input = await checkInputs(moderations, input) | ||
} catch (e) { | ||
await new Promise((resolve) => setTimeout(resolve, 500)) | ||
if (isStreamable) | ||
streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) | ||
return formatResponse(e.message) | ||
} | ||
} | ||
|
||
const executor = await prepareAgent(nodeData, options, { sessionId: this.sessionId, chatId: options.chatId, input }) | ||
|
||
const loggerHandler = new ConsoleCallbackHandler(options.logger) | ||
const callbacks = await additionalCallbacks(nodeData, options) | ||
|
||
let res: ChainValues = {} | ||
let sourceDocuments: ICommonObject[] = [] | ||
let usedTools: IUsedTool[] = [] | ||
|
||
if (isStreamable) { | ||
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) | ||
res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] }) | ||
if (res.sourceDocuments) { | ||
options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments)) | ||
sourceDocuments = res.sourceDocuments | ||
} | ||
if (res.usedTools) { | ||
options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools) | ||
usedTools = res.usedTools | ||
} | ||
} else { | ||
res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] }) | ||
if (res.sourceDocuments) { | ||
sourceDocuments = res.sourceDocuments | ||
} | ||
if (res.usedTools) { | ||
usedTools = res.usedTools | ||
} | ||
} | ||
|
||
let output = res?.output as string | ||
|
||
// Claude 3 Opus tends to spit out <thinking>..</thinking> as well, discard that in final output | ||
const regexPattern: RegExp = /<thinking>[\s\S]*?<\/thinking>/ | ||
const matches: RegExpMatchArray | null = output.match(regexPattern) | ||
if (matches) { | ||
for (const match of matches) { | ||
output = output.replace(match, '') | ||
} | ||
} | ||
|
||
await memory.addChatMessages( | ||
[ | ||
{ | ||
text: input, | ||
type: 'userMessage' | ||
}, | ||
{ | ||
text: output, | ||
type: 'apiMessage' | ||
} | ||
], | ||
this.sessionId | ||
) | ||
|
||
let finalRes = res?.output | ||
|
||
if (sourceDocuments.length || usedTools.length) { | ||
const finalRes: ICommonObject = { text: output } | ||
if (sourceDocuments.length) { | ||
finalRes.sourceDocuments = flatten(sourceDocuments) | ||
} | ||
if (usedTools.length) { | ||
finalRes.usedTools = usedTools | ||
} | ||
return finalRes | ||
} | ||
|
||
return finalRes | ||
} | ||
} | ||
|
||
const formatDocs = (docs: Document[]) => { | ||
return docs.map((doc, i) => `<doc id='${i}'>${doc.pageContent}</doc>`).join('\n') | ||
} | ||
|
||
const prepareAgent = async ( | ||
nodeData: INodeData, | ||
options: ICommonObject, | ||
flowObj: { sessionId?: string; chatId?: string; input?: string } | ||
) => { | ||
const model = nodeData.inputs?.model as BaseChatModel | ||
const maxIterations = nodeData.inputs?.maxIterations as string | ||
const memory = nodeData.inputs?.memory as FlowiseMemory | ||
const systemMessage = nodeData.inputs?.systemMessage as string | ||
let tools = nodeData.inputs?.tools | ||
tools = flatten(tools) | ||
const memoryKey = memory.memoryKey ? memory.memoryKey : 'chat_history' | ||
const inputKey = memory.inputKey ? memory.inputKey : 'input' | ||
const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever as BaseRetriever | ||
|
||
const prompt = ChatPromptTemplate.fromMessages([ | ||
['system', systemMessage ? systemMessage : `You are a helpful AI assistant.`], | ||
new MessagesPlaceholder(memoryKey), | ||
['human', `{${inputKey}}`], | ||
new MessagesPlaceholder('agent_scratchpad') | ||
]) | ||
|
||
if (llmSupportsVision(model)) { | ||
const visionChatModel = model as IVisionChatModal | ||
const messageContent = await addImagesToMessages(nodeData, options, model.multiModalOption) | ||
|
||
if (messageContent?.length) { | ||
visionChatModel.setVisionModel() | ||
|
||
// Pop the `agent_scratchpad` MessagePlaceHolder | ||
let messagePlaceholder = prompt.promptMessages.pop() as MessagesPlaceholder | ||
if (prompt.promptMessages.at(-1) instanceof HumanMessagePromptTemplate) { | ||
const lastMessage = prompt.promptMessages.pop() as HumanMessagePromptTemplate | ||
const template = (lastMessage.prompt as PromptTemplate).template as string | ||
const msg = HumanMessagePromptTemplate.fromTemplate([ | ||
...messageContent, | ||
{ | ||
text: template | ||
} | ||
]) | ||
msg.inputVariables = lastMessage.inputVariables | ||
prompt.promptMessages.push(msg) | ||
} | ||
|
||
// Add the `agent_scratchpad` MessagePlaceHolder back | ||
prompt.promptMessages.push(messagePlaceholder) | ||
} else { | ||
visionChatModel.revertToOriginalModel() | ||
} | ||
} | ||
|
||
if (model.bindTools === undefined) { | ||
throw new Error(`This agent requires that the "bindTools()" method be implemented on the input model.`) | ||
} | ||
|
||
const modelWithTools = model.bindTools(tools) | ||
|
||
const runnableAgent = RunnableSequence.from([ | ||
{ | ||
[inputKey]: (i: { input: string; steps: ToolsAgentStep[] }) => i.input, | ||
agent_scratchpad: (i: { input: string; steps: ToolsAgentStep[] }) => formatToOpenAIToolMessages(i.steps), | ||
[memoryKey]: async (_: { input: string; steps: ToolsAgentStep[] }) => { | ||
const messages = (await memory.getChatMessages(flowObj?.sessionId, true)) as BaseMessage[] | ||
return messages ?? [] | ||
}, | ||
context: async (i: { input: string; chatHistory?: string }) => { | ||
const relevantDocs = await vectorStoreRetriever.invoke(i.input) | ||
const formattedDocs = formatDocs(relevantDocs) | ||
return formattedDocs | ||
} | ||
}, | ||
prompt, | ||
modelWithTools, | ||
new ToolCallingAgentOutputParser() | ||
]) | ||
|
||
const executor = AgentExecutor.fromAgentAndTools({ | ||
agent: runnableAgent, | ||
tools, | ||
sessionId: flowObj?.sessionId, | ||
chatId: flowObj?.chatId, | ||
input: flowObj?.input, | ||
verbose: process.env.DEBUG === 'true' ? true : false, | ||
maxIterations: maxIterations ? parseFloat(maxIterations) : undefined | ||
}) | ||
|
||
return executor | ||
} | ||
|
||
module.exports = { nodeClass: ConversationalRetrievalToolAgent_Agents } |
Binary file added
BIN
+17.1 KB
packages/components/nodes/agents/ConversationalRetrievalToolAgent/toolAgent.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters