diff --git a/.changeset/yellow-jokes-protect.md b/.changeset/yellow-jokes-protect.md
new file mode 100644
index 000000000..2e9d7f199
--- /dev/null
+++ b/.changeset/yellow-jokes-protect.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Add multi-agent template for TypeScript
diff --git a/e2e/multiagent_template.spec.ts b/e2e/multiagent_template.spec.ts
index 619b8cd15..061c8a472 100644
--- a/e2e/multiagent_template.spec.ts
+++ b/e2e/multiagent_template.spec.ts
@@ -10,19 +10,19 @@ import type {
 } from "../helpers";
 import { createTestDir, runCreateLlama, type AppType } from "./utils";
 
-const templateFramework: TemplateFramework = "fastapi";
+const templateFramework: TemplateFramework = process.env.FRAMEWORK
+  ? (process.env.FRAMEWORK as TemplateFramework)
+  : "fastapi";
 const dataSource: string = "--example-file";
 const templateUI: TemplateUI = "shadcn";
 const templatePostInstallAction: TemplatePostInstallAction = "runApp";
-const appType: AppType = "--frontend";
+const appType: AppType = templateFramework === "nextjs" ? "" : "--frontend";
 const userMessage = "Write a blog post about physical standards for letters";
 
 test.describe(`Test multiagent template ${templateFramework} ${dataSource} ${templateUI} ${appType} ${templatePostInstallAction}`, async () => {
   test.skip(
-    process.platform !== "linux" ||
-      process.env.FRAMEWORK !== "fastapi" ||
-      process.env.DATASOURCE === "--no-files",
-    "The multiagent template currently only works with FastAPI and files. We also only run on Linux to speed up tests.",
+    process.platform !== "linux" || process.env.DATASOURCE === "--no-files",
+    "The multiagent template currently only works with files. We also only run on Linux to speed up tests.",
   );
   let port: number;
   let externalPort: number;
diff --git a/helpers/typescript.ts b/helpers/typescript.ts
index 6dc3d6b78..ffebae4a3 100644
--- a/helpers/typescript.ts
+++ b/helpers/typescript.ts
@@ -33,8 +33,7 @@ export const installTSTemplate = async ({
    * Copy the template files to the target directory.
    */
   console.log("\nInitializing project with template:", template, "\n");
-  const type = template === "multiagent" ? "streaming" : template; // use nextjs streaming template for multiagent
-  const templatePath = path.join(templatesDir, "types", type, framework);
+  const templatePath = path.join(templatesDir, "types", "streaming", framework);
   const copySource = ["**"];
 
   await copy(copySource, root, {
@@ -124,6 +123,30 @@ export const installTSTemplate = async ({
     cwd: path.join(compPath, "vectordbs", "typescript", vectorDb ?? "none"),
   });
 
+  if (template === "multiagent") {
+    const multiagentPath = path.join(compPath, "multiagent", "typescript");
+
+    // copy workflow code for multiagent template
+    await copy("**", path.join(root, relativeEngineDestPath, "workflow"), {
+      parents: true,
+      cwd: path.join(multiagentPath, "workflow"),
+    });
+
+    if (framework === "nextjs") {
+      // patch route.ts file
+      await copy("**", path.join(root, relativeEngineDestPath), {
+        parents: true,
+        cwd: path.join(multiagentPath, "nextjs"),
+      });
+    } else if (framework === "express") {
+      // patch chat.controller.ts file
+      await copy("**", path.join(root, relativeEngineDestPath), {
+        parents: true,
+        cwd: path.join(multiagentPath, "express"),
+      });
+    }
+  }
+
   // copy loader component (TS only supports llama_parse and file for now)
   const loaderFolder = useLlamaParse ? "llama_parse" : "file";
   await copy("**", enginePath, {
@@ -145,6 +168,11 @@ export const installTSTemplate = async ({
     cwd: path.join(compPath, "engines", "typescript", engine),
   });
 
+  // copy settings to engine folder
+  await copy("**", enginePath, {
+    cwd: path.join(compPath, "settings", "typescript"),
+  });
+
   /**
    * Copy the selected UI files to the target directory and reference it.
    */
diff --git a/questions.ts b/questions.ts
index 973aeba5a..3619447cd 100644
--- a/questions.ts
+++ b/questions.ts
@@ -410,10 +410,7 @@ export const askQuestions = async (
     return; // early return - no further questions needed for llamapack projects
   }
 
-  if (program.template === "multiagent") {
-    // TODO: multi-agents currently only supports FastAPI
-    program.framework = preferences.framework = "fastapi";
-  } else if (program.template === "extractor") {
+  if (program.template === "extractor") {
     // Extractor template only supports FastAPI, empty data sources, and llamacloud
     // So we just use example file for extractor template, this allows user to choose vector database later
     program.dataSources = [EXAMPLE_FILE];
diff --git a/templates/components/multiagent/typescript/express/chat.controller.ts b/templates/components/multiagent/typescript/express/chat.controller.ts
new file mode 100644
index 000000000..46be6d781
--- /dev/null
+++ b/templates/components/multiagent/typescript/express/chat.controller.ts
@@ -0,0 +1,41 @@
+import { StopEvent } from "@llamaindex/core/workflow";
+import { Message, streamToResponse } from "ai";
+import { Request, Response } from "express";
+import { ChatMessage, ChatResponseChunk } from "llamaindex";
+import { createWorkflow } from "./workflow/factory";
+import { toDataStream, workflowEventsToStreamData } from "./workflow/stream";
+
+export const chat = async (req: Request, res: Response) => {
+  try {
+    const { messages }: { messages: Message[] } = req.body;
+    const userMessage = messages.pop();
+    if (!messages || !userMessage || userMessage.role !== "user") {
+      return res.status(400).json({
+        error:
+          "messages are required in the request body and the last message must be from the user",
+      });
+    }
+
+    const chatHistory = messages as ChatMessage[];
+    const agent = createWorkflow(chatHistory);
+    const result = agent.run<AsyncGenerator<ChatResponseChunk>>(
+      userMessage.content,
+    ) as unknown as Promise<StopEvent<AsyncGenerator<ChatResponseChunk>>>;
+
+    // convert the workflow events to a vercel AI stream data object
+    const agentStreamData = await workflowEventsToStreamData(
+      agent.streamEvents(),
+    );
+    // convert the workflow result to a vercel AI content stream
+    const stream = toDataStream(result, {
+      onFinal: () => agentStreamData.close(),
+    });
+
+    return streamToResponse(stream, res, {}, agentStreamData);
+  } catch (error) {
+    console.error("[LlamaIndex]", error);
+    return res.status(500).json({
+      detail: (error as Error).message,
+    });
+  }
+};
from "./workflow/stream"; + +initObservability(); +initSettings(); + +export const runtime = "nodejs"; +export const dynamic = "force-dynamic"; + +export async function POST(request: NextRequest) { + try { + const body = await request.json(); + const { messages }: { messages: Message[] } = body; + const userMessage = messages.pop(); + if (!messages || !userMessage || userMessage.role !== "user") { + return NextResponse.json( + { + error: + "messages are required in the request body and the last message must be from the user", + }, + { status: 400 }, + ); + } + + const chatHistory = messages as ChatMessage[]; + const agent = createWorkflow(chatHistory); + // TODO: fix type in agent.run in LITS + const result = agent.run>( + userMessage.content, + ) as unknown as Promise>>; + // convert the workflow events to a vercel AI stream data object + const agentStreamData = await workflowEventsToStreamData( + agent.streamEvents(), + ); + // convert the workflow result to a vercel AI content stream + const stream = toDataStream(result, { + onFinal: () => agentStreamData.close(), + }); + return new StreamingTextResponse(stream, {}, agentStreamData); + } catch (error) { + console.error("[LlamaIndex]", error); + return NextResponse.json( + { + detail: (error as Error).message, + }, + { + status: 500, + }, + ); + } +} diff --git a/templates/components/multiagent/typescript/workflow/agents.ts b/templates/components/multiagent/typescript/workflow/agents.ts new file mode 100644 index 000000000..7abaa6d0a --- /dev/null +++ b/templates/components/multiagent/typescript/workflow/agents.ts @@ -0,0 +1,51 @@ +import { ChatMessage, QueryEngineTool } from "llamaindex"; +import { getDataSource } from "../engine"; +import { FunctionCallingAgent } from "./single-agent"; + +const getQueryEngineTool = async () => { + const index = await getDataSource(); + if (!index) { + throw new Error( + "StorageContext is empty - call 'npm run generate' to generate the storage first.", + ); + } + + const topK = process.env.TOP_K ? parseInt(process.env.TOP_K) : undefined; + return new QueryEngineTool({ + queryEngine: index.asQueryEngine({ + similarityTopK: topK, + }), + metadata: { + name: "query_index", + description: `Use this tool to retrieve information about the text corpus from the index.`, + }, + }); +}; + +export const createResearcher = async (chatHistory: ChatMessage[]) => { + return new FunctionCallingAgent({ + name: "researcher", + tools: [await getQueryEngineTool()], + systemPrompt: + "You are a researcher agent. You are given a researching task. You must use your tools to complete the research.", + chatHistory, + }); +}; + +export const createWriter = (chatHistory: ChatMessage[]) => { + return new FunctionCallingAgent({ + name: "writer", + systemPrompt: + "You are an expert in writing blog posts. You are given a task to write a blog post. Don't make up any information yourself.", + chatHistory, + }); +}; + +export const createReviewer = (chatHistory: ChatMessage[]) => { + return new FunctionCallingAgent({ + name: "reviewer", + systemPrompt: + "You are an expert in reviewing blog posts. You are given a task to review a blog post. Review the post for logical inconsistencies, ask critical questions, and provide suggestions for improvement. Furthermore, proofread the post for grammar and spelling errors. Only if the post is good enough for publishing, then you MUST return 'The post is good.'. 
In all other cases return your review.", + chatHistory, + }); +}; diff --git a/templates/components/multiagent/typescript/workflow/factory.ts b/templates/components/multiagent/typescript/workflow/factory.ts new file mode 100644 index 000000000..8853b08b2 --- /dev/null +++ b/templates/components/multiagent/typescript/workflow/factory.ts @@ -0,0 +1,133 @@ +import { + Context, + StartEvent, + StopEvent, + Workflow, + WorkflowEvent, +} from "@llamaindex/core/workflow"; +import { ChatMessage, ChatResponseChunk } from "llamaindex"; +import { createResearcher, createReviewer, createWriter } from "./agents"; +import { AgentInput, AgentRunEvent } from "./type"; + +const TIMEOUT = 360 * 1000; +const MAX_ATTEMPTS = 2; + +class ResearchEvent extends WorkflowEvent<{ input: string }> {} +class WriteEvent extends WorkflowEvent<{ + input: string; + isGood: boolean; +}> {} +class ReviewEvent extends WorkflowEvent<{ input: string }> {} + +export const createWorkflow = (chatHistory: ChatMessage[]) => { + const runAgent = async ( + context: Context, + agent: Workflow, + input: AgentInput, + ) => { + const run = agent.run(new StartEvent({ input })); + for await (const event of agent.streamEvents()) { + if (event.data instanceof AgentRunEvent) { + context.writeEventToStream(event.data); + } + } + return await run; + }; + + const start = async (context: Context, ev: StartEvent) => { + context.set("task", ev.data.input); + return new ResearchEvent({ + input: `Research for this task: ${ev.data.input}`, + }); + }; + + const research = async (context: Context, ev: ResearchEvent) => { + const researcher = await createResearcher(chatHistory); + const researchRes = await runAgent(context, researcher, { + message: ev.data.input, + }); + const researchResult = researchRes.data.result; + return new WriteEvent({ + input: `Write a blog post given this task: ${context.get("task")} using this research content: ${researchResult}`, + isGood: false, + }); + }; + + const write = async (context: Context, ev: WriteEvent) => { + context.set("attempts", context.get("attempts", 0) + 1); + const tooManyAttempts = context.get("attempts") > MAX_ATTEMPTS; + if (tooManyAttempts) { + context.writeEventToStream( + new AgentRunEvent({ + name: "writer", + msg: `Too many attempts (${MAX_ATTEMPTS}) to write the blog post. Proceeding with the current version.`, + }), + ); + } + + if (ev.data.isGood || tooManyAttempts) { + // The text is ready for publication, we just use the writer to stream the output + const writer = createWriter(chatHistory); + const content = context.get("result"); + + return (await runAgent(context, writer, { + message: `You're blog post is ready for publication. Please respond with just the blog post. 
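All three factory functions above wrap the same FunctionCallingAgent class, so adding another specialist follows the same pattern. A sketch only, the "publisher" name and prompt are invented for illustration and are not part of the template:

    import { ChatMessage } from "llamaindex";
    import { FunctionCallingAgent } from "./single-agent";

    // Hypothetical extra agent, shown only to illustrate the factory pattern.
    export const createPublisher = (chatHistory: ChatMessage[]) => {
      return new FunctionCallingAgent({
        name: "publisher",
        systemPrompt:
          "You are an expert in publishing blog posts. Format the given post for publication without changing its content.",
        chatHistory,
      });
    };
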
diff --git a/templates/components/multiagent/typescript/workflow/factory.ts b/templates/components/multiagent/typescript/workflow/factory.ts
new file mode 100644
index 000000000..8853b08b2
--- /dev/null
+++ b/templates/components/multiagent/typescript/workflow/factory.ts
@@ -0,0 +1,133 @@
+import {
+  Context,
+  StartEvent,
+  StopEvent,
+  Workflow,
+  WorkflowEvent,
+} from "@llamaindex/core/workflow";
+import { ChatMessage, ChatResponseChunk } from "llamaindex";
+import { createResearcher, createReviewer, createWriter } from "./agents";
+import { AgentInput, AgentRunEvent } from "./type";
+
+const TIMEOUT = 360 * 1000;
+const MAX_ATTEMPTS = 2;
+
+class ResearchEvent extends WorkflowEvent<{ input: string }> {}
+class WriteEvent extends WorkflowEvent<{
+  input: string;
+  isGood: boolean;
+}> {}
+class ReviewEvent extends WorkflowEvent<{ input: string }> {}
+
+export const createWorkflow = (chatHistory: ChatMessage[]) => {
+  const runAgent = async (
+    context: Context,
+    agent: Workflow,
+    input: AgentInput,
+  ) => {
+    const run = agent.run(new StartEvent({ input }));
+    for await (const event of agent.streamEvents()) {
+      if (event.data instanceof AgentRunEvent) {
+        context.writeEventToStream(event.data);
+      }
+    }
+    return await run;
+  };
+
+  const start = async (context: Context, ev: StartEvent) => {
+    context.set("task", ev.data.input);
+    return new ResearchEvent({
+      input: `Research for this task: ${ev.data.input}`,
+    });
+  };
+
+  const research = async (context: Context, ev: ResearchEvent) => {
+    const researcher = await createResearcher(chatHistory);
+    const researchRes = await runAgent(context, researcher, {
+      message: ev.data.input,
+    });
+    const researchResult = researchRes.data.result;
+    return new WriteEvent({
+      input: `Write a blog post given this task: ${context.get("task")} using this research content: ${researchResult}`,
+      isGood: false,
+    });
+  };
+
+  const write = async (context: Context, ev: WriteEvent) => {
+    context.set("attempts", context.get("attempts", 0) + 1);
+    const tooManyAttempts = context.get("attempts") > MAX_ATTEMPTS;
+    if (tooManyAttempts) {
+      context.writeEventToStream(
+        new AgentRunEvent({
+          name: "writer",
+          msg: `Too many attempts (${MAX_ATTEMPTS}) to write the blog post. Proceeding with the current version.`,
+        }),
+      );
+    }
+
+    if (ev.data.isGood || tooManyAttempts) {
+      // The text is ready for publication, so we just use the writer to stream the output
+      const writer = createWriter(chatHistory);
+      const content = context.get("result");
+
+      return (await runAgent(context, writer, {
+        message: `Your blog post is ready for publication. Please respond with just the blog post. Blog post: \`\`\`${content}\`\`\``,
+        streaming: true,
+      })) as unknown as StopEvent<AsyncGenerator<ChatResponseChunk>>;
+    }
+
+    const writer = createWriter(chatHistory);
+    const writeRes = await runAgent(context, writer, {
+      message: ev.data.input,
+    });
+    const writeResult = writeRes.data.result;
+    context.set("result", writeResult); // store the last result
+    return new ReviewEvent({ input: writeResult });
+  };
+
+  const review = async (context: Context, ev: ReviewEvent) => {
+    const reviewer = createReviewer(chatHistory);
+    const reviewRes = await reviewer.run(
+      new StartEvent({ input: { message: ev.data.input } }),
+    );
+    const reviewResult = reviewRes.data.result;
+    const oldContent = context.get("result");
+    const postIsGood = reviewResult.toLowerCase().includes("post is good");
+    context.writeEventToStream(
+      new AgentRunEvent({
+        name: "reviewer",
+        msg: `The post is ${postIsGood ? "" : "not "}good enough for publishing. Sending back to the writer${
+          postIsGood ? " for publication." : "."
+        }`,
+      }),
+    );
+    if (postIsGood) {
+      return new WriteEvent({
+        input: "",
+        isGood: true,
+      });
+    }
+
+    return new WriteEvent({
+      input: `Improve the writing of a given blog post by using a given review.
+            Blog post:
+            \`\`\`
+            ${oldContent}
+            \`\`\`
+
+            Review:
+            \`\`\`
+            ${reviewResult}
+            \`\`\``,
+      isGood: false,
+    });
+  };
+
+  const workflow = new Workflow({ timeout: TIMEOUT, validate: true });
+  workflow.addStep(StartEvent, start, { outputs: ResearchEvent });
+  workflow.addStep(ResearchEvent, research, { outputs: WriteEvent });
+  workflow.addStep(WriteEvent, write, { outputs: [ReviewEvent, StopEvent] });
+  workflow.addStep(ReviewEvent, review, { outputs: WriteEvent });
+
+  return workflow;
+};
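The same workflow can also be driven outside an HTTP route. A rough sketch under the template's assumptions (settings initialized, `npm run generate` already run); it mirrors how route.ts and stream.ts consume the run result and the event stream, and the script itself is not part of the template:

    import { ChatResponseChunk } from "llamaindex";
    import { createWorkflow } from "./factory";
    import { AgentRunEvent } from "./type";

    async function main() {
      const workflow = createWorkflow([]);
      // Start the run first, then observe progress events while it is pending.
      const result = workflow.run(
        "Write a blog post about physical standards for letters",
      );

      for await (const event of workflow.streamEvents()) {
        if (event instanceof AgentRunEvent) {
          console.log(`[${event.data.name}] ${event.data.msg}`);
        }
      }

      // The final write step streams the post as chunks (same cast as route.ts).
      const stopEvent = await result;
      const generator = stopEvent.data.result as AsyncGenerator<ChatResponseChunk>;
      for await (const chunk of generator) {
        process.stdout.write(chunk.delta ?? "");
      }
    }

    main().catch(console.error);
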
diff --git a/templates/components/multiagent/typescript/workflow/single-agent.ts b/templates/components/multiagent/typescript/workflow/single-agent.ts
new file mode 100644
index 000000000..568697dfb
--- /dev/null
+++ b/templates/components/multiagent/typescript/workflow/single-agent.ts
@@ -0,0 +1,236 @@
+import {
+  Context,
+  StartEvent,
+  StopEvent,
+  Workflow,
+  WorkflowEvent,
+} from "@llamaindex/core/workflow";
+import {
+  BaseToolWithCall,
+  ChatMemoryBuffer,
+  ChatMessage,
+  ChatResponse,
+  ChatResponseChunk,
+  Settings,
+  ToolCall,
+  ToolCallLLM,
+  ToolCallLLMMessageOptions,
+  callTool,
+} from "llamaindex";
+import { AgentInput, AgentRunEvent } from "./type";
+
+class InputEvent extends WorkflowEvent<{
+  input: ChatMessage[];
+}> {}
+
+class ToolCallEvent extends WorkflowEvent<{
+  toolCalls: ToolCall[];
+}> {}
+
+export class FunctionCallingAgent extends Workflow {
+  name: string;
+  llm: ToolCallLLM;
+  memory: ChatMemoryBuffer;
+  tools: BaseToolWithCall[];
+  systemPrompt?: string;
+  writeEvents: boolean;
+  role?: string;
+
+  constructor(options: {
+    name: string;
+    llm?: ToolCallLLM;
+    chatHistory?: ChatMessage[];
+    tools?: BaseToolWithCall[];
+    systemPrompt?: string;
+    writeEvents?: boolean;
+    role?: string;
+    verbose?: boolean;
+    timeout?: number;
+  }) {
+    super({
+      verbose: options?.verbose ?? false,
+      timeout: options?.timeout ?? 360,
+    });
+    this.name = options?.name;
+    this.llm = options.llm ?? (Settings.llm as ToolCallLLM);
+    this.checkToolCallSupport();
+    this.memory = new ChatMemoryBuffer({
+      llm: this.llm,
+      chatHistory: options.chatHistory,
+    });
+    this.tools = options?.tools ?? [];
+    this.systemPrompt = options.systemPrompt;
+    this.writeEvents = options?.writeEvents ?? true;
+    this.role = options?.role;
+
+    // add steps
+    this.addStep(StartEvent, this.prepareChatHistory, {
+      outputs: InputEvent,
+    });
+    this.addStep(InputEvent, this.handleLLMInput, {
+      outputs: [ToolCallEvent, StopEvent],
+    });
+    this.addStep(ToolCallEvent, this.handleToolCalls, {
+      outputs: InputEvent,
+    });
+  }
+
+  private get chatHistory() {
+    return this.memory.getMessages();
+  }
+
+  private async prepareChatHistory(
+    ctx: Context,
+    ev: StartEvent<AgentInput>,
+  ): Promise<InputEvent> {
+    const { message, streaming } = ev.data.input;
+    ctx.set("streaming", streaming);
+    this.writeEvent(`Start to work on: ${message}`, ctx);
+    if (this.systemPrompt) {
+      this.memory.put({ role: "system", content: this.systemPrompt });
+    }
+    this.memory.put({ role: "user", content: message });
+    return new InputEvent({ input: this.chatHistory });
+  }
+
+  private async handleLLMInput(
+    ctx: Context,
+    ev: InputEvent,
+  ): Promise<StopEvent<string | AsyncGenerator> | ToolCallEvent> {
+    if (ctx.get("streaming")) {
+      return await this.handleLLMInputStream(ctx, ev);
+    }
+
+    const result = await this.llm.chat({
+      messages: this.chatHistory,
+      tools: this.tools,
+    });
+    this.memory.put(result.message);
+
+    const toolCalls = this.getToolCallsFromResponse(result);
+    if (toolCalls.length) {
+      return new ToolCallEvent({ toolCalls });
+    }
+    this.writeEvent("Finished task", ctx);
+    return new StopEvent({ result: result.message.content.toString() });
+  }
+
+  private async handleLLMInputStream(
+    context: Context,
+    ev: InputEvent,
+  ): Promise<StopEvent<AsyncGenerator> | ToolCallEvent> {
+    const { llm, tools, memory } = this;
+    const llmArgs = { messages: this.chatHistory, tools };
+
+    const responseGenerator = async function* () {
+      const responseStream = await llm.chat({ ...llmArgs, stream: true });
+
+      let fullResponse = null;
+      let yieldedIndicator = false;
+      for await (const chunk of responseStream) {
+        const hasToolCalls = chunk.options && "toolCall" in chunk.options;
+        if (!hasToolCalls) {
+          if (!yieldedIndicator) {
+            yield false;
+            yieldedIndicator = true;
+          }
+          yield chunk;
+        } else if (!yieldedIndicator) {
+          yield true;
+          yieldedIndicator = true;
+        }
+
+        fullResponse = chunk;
+      }
+
+      if (fullResponse) {
+        memory.put({
+          role: "assistant",
+          content: "",
+          options: fullResponse.options,
+        });
+        yield fullResponse;
+      }
+    };
+
+    const generator = responseGenerator();
+    const isToolCall = await generator.next();
+    if (isToolCall.value) {
+      const fullResponse = await generator.next();
+      const toolCalls = this.getToolCallsFromResponse(
+        fullResponse.value as ChatResponseChunk,
+      );
+      return new ToolCallEvent({ toolCalls });
+    }
+
+    this.writeEvent("Finished task", context);
+    return new StopEvent({ result: generator });
+  }
+
+  private async handleToolCalls(
+    ctx: Context,
+    ev: ToolCallEvent,
+  ): Promise<InputEvent> {
+    const { toolCalls } = ev.data;
+
+    const toolMsgs: ChatMessage[] = [];
+
+    for (const call of toolCalls) {
+      const targetTool = this.tools.find(
+        (tool) => tool.metadata.name === call.name,
+      );
+      // TODO: make logger optional in callTool in framework
+      const toolOutput = await callTool(targetTool, call, {
+        log: () => {},
+        error: console.error.bind(console),
+        warn: () => {},
+      });
+      toolMsgs.push({
+        content: JSON.stringify(toolOutput.output),
+        role: "user",
+        options: {
+          toolResult: {
+            result: toolOutput.output,
+            isError: toolOutput.isError,
+            id: call.id,
+          },
+        },
+      });
+    }
+
+    for (const msg of toolMsgs) {
+      this.memory.put(msg);
+    }
+
+    return new InputEvent({ input: this.memory.getMessages() });
+  }
+
+  private writeEvent(msg: string, context: Context) {
+    if (!this.writeEvents) return;
+    context.writeEventToStream({
+      data: new AgentRunEvent({ name: this.name, msg }),
+    });
+  }
+
+  private checkToolCallSupport() {
+    const { supportToolCall } = this.llm as ToolCallLLM;
+    if (!supportToolCall) throw new Error("LLM does not support tool calls");
+  }
+
+  private getToolCallsFromResponse(
+    response:
+      | ChatResponse<ToolCallLLMMessageOptions>
+      | ChatResponseChunk<ToolCallLLMMessageOptions>,
+  ): ToolCall[] {
+    let options;
+    if ("message" in response) {
+      options = response.message.options;
+    } else {
+      options = response.options;
+    }
+    if (options && "toolCall" in options) {
+      return options.toolCall as ToolCall[];
+    }
+    return [];
+  }
+}
diff --git a/templates/components/multiagent/typescript/workflow/stream.ts b/templates/components/multiagent/typescript/workflow/stream.ts
new file mode 100644
index 000000000..6502be55c
--- /dev/null
+++ b/templates/components/multiagent/typescript/workflow/stream.ts
@@ -0,0 +1,65 @@
+import { StopEvent } from "@llamaindex/core/workflow";
+import {
+  createCallbacksTransformer,
+  createStreamDataTransformer,
+  StreamData,
+  trimStartOfStreamHelper,
+  type AIStreamCallbacksAndOptions,
+} from "ai";
+import { ChatResponseChunk } from "llamaindex";
+import { AgentRunEvent } from "./type";
+
+export function toDataStream(
+  result: Promise<StopEvent<AsyncGenerator<ChatResponseChunk>>>,
+  callbacks?: AIStreamCallbacksAndOptions,
+) {
+  return toReadableStream(result)
+    .pipeThrough(createCallbacksTransformer(callbacks))
+    .pipeThrough(createStreamDataTransformer());
+}
+
+function toReadableStream(
+  result: Promise<StopEvent<AsyncGenerator<ChatResponseChunk>>>,
+) {
+  const trimStartOfStream = trimStartOfStreamHelper();
+  return new ReadableStream({
+    start(controller) {
+      controller.enqueue(""); // Kickstart the stream
+    },
+    async pull(controller): Promise<void> {
+      const stopEvent = await result;
+      const generator = stopEvent.data.result;
+      const { value, done } = await generator.next();
+      if (done) {
+        controller.close();
+        return;
+      }
+
+      const text = trimStartOfStream(value.delta ?? "");
+      if (text) controller.enqueue(text);
+    },
+  });
+}
+
+export async function workflowEventsToStreamData(
+  events: AsyncIterable<any>,
+): Promise<StreamData> {
+  const streamData = new StreamData();
+
+  (async () => {
+    for await (const event of events) {
+      if (event instanceof AgentRunEvent) {
+        const { name, msg } = event.data;
+        if ((streamData as any).isClosed) {
+          break;
+        }
+        streamData.appendMessageAnnotation({
+          type: "agent",
+          data: { agent: name, text: msg },
+        });
+      }
+    }
+  })();
+
+  return streamData;
+}
diff --git a/templates/components/multiagent/typescript/workflow/type.ts b/templates/components/multiagent/typescript/workflow/type.ts
new file mode 100644
index 000000000..2e2fdedbd
--- /dev/null
+++ b/templates/components/multiagent/typescript/workflow/type.ts
@@ -0,0 +1,11 @@
+import { WorkflowEvent } from "@llamaindex/core/workflow";
+
+export type AgentInput = {
+  message: string;
+  streaming?: boolean;
+};
+
+export class AgentRunEvent extends WorkflowEvent<{
+  name: string;
+  msg: string;
+}> {}
"gpt-4o-mini", - maxTokens: process.env.LLM_MAX_TOKENS - ? Number(process.env.LLM_MAX_TOKENS) - : undefined, - }); - Settings.embedModel = new OpenAIEmbedding({ - model: process.env.EMBEDDING_MODEL, - dimensions: process.env.EMBEDDING_DIM - ? parseInt(process.env.EMBEDDING_DIM) - : undefined, - }); -} - -function initAzureOpenAI() { - // Map Azure OpenAI model names to OpenAI model names (only for TS) - const AZURE_OPENAI_MODEL_MAP: Record = { - "gpt-35-turbo": "gpt-3.5-turbo", - "gpt-35-turbo-16k": "gpt-3.5-turbo-16k", - "gpt-4o": "gpt-4o", - "gpt-4": "gpt-4", - "gpt-4-32k": "gpt-4-32k", - "gpt-4-turbo": "gpt-4-turbo", - "gpt-4-turbo-2024-04-09": "gpt-4-turbo", - "gpt-4-vision-preview": "gpt-4-vision-preview", - "gpt-4-1106-preview": "gpt-4-1106-preview", - "gpt-4o-2024-05-13": "gpt-4o-2024-05-13", - }; - - const azureConfig = { - apiKey: process.env.AZURE_OPENAI_KEY, - endpoint: process.env.AZURE_OPENAI_ENDPOINT, - apiVersion: - process.env.AZURE_OPENAI_API_VERSION || process.env.OPENAI_API_VERSION, - }; - - Settings.llm = new OpenAI({ - model: - AZURE_OPENAI_MODEL_MAP[process.env.MODEL ?? "gpt-35-turbo"] ?? - "gpt-3.5-turbo", - maxTokens: process.env.LLM_MAX_TOKENS - ? Number(process.env.LLM_MAX_TOKENS) - : undefined, - azure: { - ...azureConfig, - deployment: process.env.AZURE_OPENAI_LLM_DEPLOYMENT, - }, - }); - - Settings.embedModel = new OpenAIEmbedding({ - model: process.env.EMBEDDING_MODEL, - dimensions: process.env.EMBEDDING_DIM - ? parseInt(process.env.EMBEDDING_DIM) - : undefined, - azure: { - ...azureConfig, - deployment: process.env.AZURE_OPENAI_EMBEDDING_DEPLOYMENT, - }, - }); -} - -function initOllama() { - const config = { - host: process.env.OLLAMA_BASE_URL ?? "http://127.0.0.1:11434", - }; - Settings.llm = new Ollama({ - model: process.env.MODEL ?? "", - config, - }); - Settings.embedModel = new OllamaEmbedding({ - model: process.env.EMBEDDING_MODEL ?? 
"", - config, - }); -} - -function initGroq() { - const embedModelMap: Record = { - "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2", - "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2", - }; - - Settings.llm = new Groq({ - model: process.env.MODEL!, - }); - - Settings.embedModel = new HuggingFaceEmbedding({ - modelType: embedModelMap[process.env.EMBEDDING_MODEL!], - }); -} - -function initAnthropic() { - const embedModelMap: Record = { - "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2", - "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2", - }; - Settings.llm = new Anthropic({ - model: process.env.MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS, - }); - Settings.embedModel = new HuggingFaceEmbedding({ - modelType: embedModelMap[process.env.EMBEDDING_MODEL!], - }); -} - -function initGemini() { - Settings.llm = new Gemini({ - model: process.env.MODEL as GEMINI_MODEL, - }); - Settings.embedModel = new GeminiEmbedding({ - model: process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL, - }); -} - -function initMistralAI() { - Settings.llm = new MistralAI({ - model: process.env.MODEL as keyof typeof ALL_AVAILABLE_MISTRAL_MODELS, - }); - Settings.embedModel = new MistralAIEmbedding({ - model: process.env.EMBEDDING_MODEL as MistralAIEmbeddingModelType, - }); -} diff --git a/templates/types/streaming/nextjs/package.json b/templates/types/streaming/nextjs/package.json index 2839737d2..775bee983 100644 --- a/templates/types/streaming/nextjs/package.json +++ b/templates/types/streaming/nextjs/package.json @@ -12,6 +12,7 @@ "dependencies": { "@apidevtools/swagger-parser": "^10.1.0", "@e2b/code-interpreter": "^0.0.5", + "@llamaindex/core": "^0.2.6", "@llamaindex/pdf-viewer": "^1.1.3", "@radix-ui/react-collapsible": "^1.0.3", "@radix-ui/react-hover-card": "^1.0.7",