Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use selected llamacloud for multiagent #359

Merged
5 changes: 5 additions & 0 deletions .changeset/stupid-paws-push.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"create-llama": patch
---

feat: use selected llamacloud for multiagent
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,26 @@ function getValidAnnotation(annotation: JSONValue): Annotation {
}
return { type: annotation.type, data: annotation.data };
}

// validate and get all annotations of a specific type or role from the frontend messages
export function getAnnotations<
T extends Annotation["data"] = Annotation["data"],
>(
messages: Message[],
options?: {
role?: Message["role"]; // message role
type?: Annotation["type"]; // annotation type
},
): {
type: string;
data: T;
}[] {
const messagesByRole = options?.role
? messages.filter((msg) => msg.role === options?.role)
: messages;
const annotations = getAllAnnotations(messagesByRole);
const annotationsByType = options?.type
? annotations.filter((a) => a.type === options.type)
: annotations;
return annotationsByType as { type: string; data: T }[];
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import { StopEvent } from "@llamaindex/core/workflow";
import { Message, streamToResponse } from "ai";
import { Request, Response } from "express";
import { ChatMessage, ChatResponseChunk } from "llamaindex";
import { ChatResponseChunk } from "llamaindex";
import { createWorkflow } from "./workflow/factory";
import { toDataStream, workflowEventsToStreamData } from "./workflow/stream";

export const chat = async (req: Request, res: Response) => {
try {
const { messages }: { messages: Message[] } = req.body;
const { messages, data }: { messages: Message[]; data?: any } = req.body;
const userMessage = messages.pop();
if (!messages || !userMessage || userMessage.role !== "user") {
return res.status(400).json({
Expand All @@ -16,8 +16,7 @@ export const chat = async (req: Request, res: Response) => {
});
}

const chatHistory = messages as ChatMessage[];
const agent = createWorkflow(chatHistory);
const agent = createWorkflow(messages, data);
const result = agent.run<AsyncGenerator<ChatResponseChunk>>(
userMessage.content,
) as unknown as Promise<StopEvent<AsyncGenerator<ChatResponseChunk>>>;
Expand Down
7 changes: 3 additions & 4 deletions templates/components/multiagent/typescript/nextjs/route.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { initObservability } from "@/app/observability";
import { StopEvent } from "@llamaindex/core/workflow";
import { Message, StreamingTextResponse } from "ai";
import { ChatMessage, ChatResponseChunk } from "llamaindex";
import { ChatResponseChunk } from "llamaindex";
import { NextRequest, NextResponse } from "next/server";
import { initSettings } from "./engine/settings";
import { createWorkflow } from "./workflow/factory";
Expand All @@ -16,7 +16,7 @@ export const dynamic = "force-dynamic";
export async function POST(request: NextRequest) {
try {
const body = await request.json();
const { messages }: { messages: Message[] } = body;
const { messages, data }: { messages: Message[]; data?: any } = body;
thucpn marked this conversation as resolved.
Show resolved Hide resolved
const userMessage = messages.pop();
if (!messages || !userMessage || userMessage.role !== "user") {
return NextResponse.json(
Expand All @@ -28,8 +28,7 @@ export async function POST(request: NextRequest) {
);
}

const chatHistory = messages as ChatMessage[];
const agent = createWorkflow(chatHistory);
const agent = createWorkflow(messages, data);
// TODO: fix type in agent.run in LITS
const result = agent.run<AsyncGenerator<ChatResponseChunk>>(
userMessage.content,
Expand Down
21 changes: 13 additions & 8 deletions templates/components/multiagent/typescript/workflow/agents.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import { ChatMessage } from "llamaindex";
import { FunctionCallingAgent } from "./single-agent";
import { lookupTools } from "./tools";
import { getQueryEngineTool, lookupTools } from "./tools";

export const createResearcher = async (chatHistory: ChatMessage[]) => {
const tools = await lookupTools([
"query_index",
"wikipedia_tool",
"duckduckgo_search",
"image_generator",
]);
export const createResearcher = async (
chatHistory: ChatMessage[],
params?: any,
) => {
const queryEngineTool = await getQueryEngineTool(params);
const tools = (
await lookupTools([
"wikipedia_tool",
"duckduckgo_search",
"image_generator",
])
).concat(queryEngineTool ? [queryEngineTool] : []);

return new FunctionCallingAgent({
name: "researcher",
Expand Down
29 changes: 15 additions & 14 deletions templates/components/multiagent/typescript/workflow/factory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import {
Workflow,
WorkflowEvent,
} from "@llamaindex/core/workflow";
import { Message } from "ai";
import { ChatMessage, ChatResponseChunk, Settings } from "llamaindex";
import { getAnnotations } from "../llamaindex/streaming/annotations";
import {
createPublisher,
createResearcher,
Expand All @@ -25,19 +27,15 @@ class WriteEvent extends WorkflowEvent<{
class ReviewEvent extends WorkflowEvent<{ input: string }> {}
class PublishEvent extends WorkflowEvent<{ input: string }> {}

const prepareChatHistory = (chatHistory: ChatMessage[]) => {
const prepareChatHistory = (chatHistory: Message[]): ChatMessage[] => {
// By default, the chat history only contains the assistant and user messages
// all the agents messages are stored in annotation data which is not visible to the LLM

const MAX_AGENT_MESSAGES = 10;

// Construct a new agent message from agent messages
// Get annotations from assistant messages
const agentAnnotations = chatHistory
.filter((msg) => msg.role === "assistant")
.flatMap((msg) => msg.annotations || [])
.filter((annotation) => annotation.type === "agent")
.slice(-MAX_AGENT_MESSAGES);
const agentAnnotations = getAnnotations<{ agent: string; text: string }>(
chatHistory,
{ role: "assistant", type: "agent" },
).slice(-MAX_AGENT_MESSAGES);

const agentMessages = agentAnnotations
.map(
Expand All @@ -59,13 +57,13 @@ const prepareChatHistory = (chatHistory: ChatMessage[]) => {
...chatHistory.slice(0, -1),
agentMessage,
chatHistory.slice(-1)[0],
];
] as ChatMessage[];
}
return chatHistory;
return chatHistory as ChatMessage[];
};

export const createWorkflow = (chatHistory: ChatMessage[]) => {
const chatHistoryWithAgentMessages = prepareChatHistory(chatHistory);
export const createWorkflow = (messages: Message[], params?: any) => {
const chatHistoryWithAgentMessages = prepareChatHistory(messages);
const runAgent = async (
context: Context,
agent: Workflow,
Expand Down Expand Up @@ -123,7 +121,10 @@ Decision (respond with either 'not_publish' or 'publish'):`;
};

const research = async (context: Context, ev: ResearchEvent) => {
const researcher = await createResearcher(chatHistoryWithAgentMessages);
const researcher = await createResearcher(
chatHistoryWithAgentMessages,
params,
);
const researchRes = await runAgent(context, researcher, {
message: ev.data.input,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ export class FunctionCallingAgent extends Workflow {
fullResponse = chunk;
}

if (fullResponse) {
if (fullResponse?.options && Object.keys(fullResponse.options).length) {
memory.put({
role: "assistant",
content: "",
Expand Down
6 changes: 4 additions & 2 deletions templates/components/multiagent/typescript/workflow/tools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import path from "path";
import { getDataSource } from "../engine";
import { createTools } from "../engine/tools/index";

const getQueryEngineTool = async (): Promise<QueryEngineTool | null> => {
const index = await getDataSource();
export const getQueryEngineTool = async (
params?: any,
): Promise<QueryEngineTool | null> => {
const index = await getDataSource(params);
thucpn marked this conversation as resolved.
Show resolved Hide resolved
if (!index) {
return null;
}
Expand Down
Loading