From 8e3fb9006abe2cf1a4315db39335a89b5590c9d4 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Wed, 24 Jul 2024 17:01:48 -0700 Subject: [PATCH 1/5] implemented and added test --- libs/langchain-groq/package.json | 4 +- libs/langchain-groq/src/chat_models.ts | 143 ++++++++---------- .../src/tests/chat_models.int.test.ts | 40 ++++- .../tests/chat_models.standard.int.test.ts | 2 +- yarn.lock | 60 +------- 5 files changed, 112 insertions(+), 137 deletions(-) diff --git a/libs/langchain-groq/package.json b/libs/langchain-groq/package.json index d4ee60e3c8f4..29ab170beb16 100644 --- a/libs/langchain-groq/package.json +++ b/libs/langchain-groq/package.json @@ -35,9 +35,9 @@ "author": "LangChain", "license": "MIT", "dependencies": { - "@langchain/core": ">=0.2.16 <0.3.0", + "@langchain/core": ">=0.2.18 <0.3.0", "@langchain/openai": "~0.2.4", - "groq-sdk": "^0.3.2", + "groq-sdk": "^0.5.0", "zod": "^3.22.4", "zod-to-json-schema": "^3.22.5" }, diff --git a/libs/langchain-groq/src/chat_models.ts b/libs/langchain-groq/src/chat_models.ts index ca7c302cfd7f..6c8844602b28 100644 --- a/libs/langchain-groq/src/chat_models.ts +++ b/libs/langchain-groq/src/chat_models.ts @@ -8,6 +8,8 @@ import { LangSmithParams, type BaseChatModelParams, } from "@langchain/core/language_models/chat_models"; +import * as ChatCompletionsAPI from "groq-sdk/resources/chat/completions"; +import * as CompletionsAPI from "groq-sdk/resources/completions"; import { AIMessage, AIMessageChunk, @@ -32,7 +34,6 @@ import { } from "@langchain/openai"; import { isZodSchema } from "@langchain/core/utils/types"; import Groq from "groq-sdk"; -import { ChatCompletionChunk } from "groq-sdk/lib/chat_completions_ext"; import { ChatCompletion, ChatCompletionCreateParams, @@ -146,8 +147,8 @@ export function messageToGroqRole(message: BaseMessage): GroqRoleEnum { function convertMessagesToGroqParams( messages: BaseMessage[] -): Array { - return messages.map((message): ChatCompletion.Choice.Message => { +): Array { + return messages.map((message): ChatCompletionsAPI.ChatCompletionMessage => { if (typeof message.content !== "string") { throw new Error("Non string message content not supported"); } @@ -172,12 +173,12 @@ function convertMessagesToGroqParams( completionParam.tool_call_id = (message as ToolMessage).tool_call_id; } } - return completionParam as ChatCompletion.Choice.Message; + return completionParam as ChatCompletionsAPI.ChatCompletionMessage; }); } function groqResponseToChatMessage( - message: ChatCompletion.Choice.Message + message: ChatCompletionsAPI.ChatCompletionMessage ): BaseMessage { const rawToolCalls: OpenAIToolCall[] | undefined = message.tool_calls as | OpenAIToolCall[] @@ -206,6 +207,19 @@ function groqResponseToChatMessage( } } +function _convertDeltaToolCallToToolCallChunk(toolCalls?: ChatCompletionsAPI.ChatCompletionChunk.Choice.Delta.ToolCall[]): ToolCallChunk[] | undefined { + if (!toolCalls?.length) return undefined; + + return toolCalls.map((tc) => { + return { + id: tc.id, + name: tc.function?.name, + args: tc.function?.arguments, + type: "tool_call_chunk" + } + }) +} + function _convertDeltaToMessageChunk( // eslint-disable-next-line @typescript-eslint/no-explicit-any delta: Record @@ -227,7 +241,11 @@ function _convertDeltaToMessageChunk( if (role === "user") { return new HumanMessageChunk({ content }); } else if (role === "assistant") { - return new AIMessageChunk({ content, additional_kwargs }); + return new AIMessageChunk({ + content, + additional_kwargs, + tool_call_chunks: _convertDeltaToolCallToToolCallChunk(delta.tool_calls), + }); } else if (role === "system") { return new SystemMessageChunk({ content }); } else { @@ -322,8 +340,8 @@ export class ChatGroq extends BaseChatModel< ls_provider: "groq", ls_model_name: this.model, ls_model_type: "chat", - ls_temperature: params.temperature, - ls_max_tokens: params.max_tokens, + ls_temperature: params.temperature ?? this.temperature, + ls_max_tokens: params.max_tokens ?? this.maxTokens, ls_stop: options.stop, }; } @@ -331,7 +349,7 @@ export class ChatGroq extends BaseChatModel< async completionWithRetry( request: ChatCompletionCreateParamsStreaming, options?: OpenAICoreRequestOptions - ): Promise>; + ): Promise>; async completionWithRetry( request: ChatCompletionCreateParamsNonStreaming, @@ -341,7 +359,7 @@ export class ChatGroq extends BaseChatModel< async completionWithRetry( request: ChatCompletionCreateParams, options?: OpenAICoreRequestOptions - ): Promise | ChatCompletion> { + ): Promise | ChatCompletion> { return this.caller.call(async () => this.client.chat.completions.create(request, options) ); @@ -391,76 +409,45 @@ export class ChatGroq extends BaseChatModel< ): AsyncGenerator { const params = this.invocationParams(options); const messagesMapped = convertMessagesToGroqParams(messages); - if (options.tools !== undefined && options.tools.length > 0) { - const result = await this._generateNonStreaming( - messages, - options, - runManager - ); - const generationMessage = result.generations[0].message as AIMessage; - if ( - generationMessage === undefined || - typeof generationMessage.content !== "string" - ) { - throw new Error("Could not parse Groq output."); + const response = await this.completionWithRetry( + { + ...params, + messages: messagesMapped, + stream: true, + }, + { + signal: options?.signal, + headers: options?.headers, } - const toolCallChunks: ToolCallChunk[] | undefined = - generationMessage.tool_calls?.map((toolCall, i) => ({ - name: toolCall.name, - args: JSON.stringify(toolCall.args), - id: toolCall.id, - index: i, - type: "tool_call_chunk", - })); - yield new ChatGenerationChunk({ - message: new AIMessageChunk({ - content: generationMessage.content, - additional_kwargs: generationMessage.additional_kwargs, - tool_call_chunks: toolCallChunks, - }), - text: generationMessage.content, - }); - } else { - const response = await this.completionWithRetry( - { - ...params, - messages: messagesMapped, - stream: true, - }, - { - signal: options?.signal, - headers: options?.headers, - } - ); - let role = ""; - for await (const data of response) { - const choice = data?.choices[0]; - if (!choice) { - continue; - } - // The `role` field is populated in the first delta of the response - // but is not present in subsequent deltas. Extract it when available. - if (choice.delta?.role) { - role = choice.delta.role; - } - const chunk = new ChatGenerationChunk({ - message: _convertDeltaToMessageChunk( - { - ...choice.delta, - role, - } ?? {} - ), - text: choice.delta.content ?? "", - generationInfo: { - finishReason: choice.finish_reason, - }, - }); - yield chunk; - void runManager?.handleLLMNewToken(chunk.text ?? ""); + ); + let role = ""; + for await (const data of response) { + const choice = data?.choices[0]; + if (!choice) { + continue; } - if (options.signal?.aborted) { - throw new Error("AbortError"); + // The `role` field is populated in the first delta of the response + // but is not present in subsequent deltas. Extract it when available. + if (choice.delta?.role) { + role = choice.delta.role; } + const chunk = new ChatGenerationChunk({ + message: _convertDeltaToMessageChunk( + { + ...choice.delta, + role, + } ?? {} + ), + text: choice.delta.content ?? "", + generationInfo: { + finishReason: choice.finish_reason, + }, + }); + yield chunk; + void runManager?.handleLLMNewToken(chunk.text ?? ""); + } + if (options.signal?.aborted) { + throw new Error("AbortError"); } } @@ -518,7 +505,7 @@ export class ChatGroq extends BaseChatModel< completion_tokens: completionTokens, prompt_tokens: promptTokens, total_tokens: totalTokens, - } = data.usage as ChatCompletion.Usage; + } = data.usage as CompletionsAPI.CompletionUsage if (completionTokens) { tokenUsage.completionTokens = diff --git a/libs/langchain-groq/src/tests/chat_models.int.test.ts b/libs/langchain-groq/src/tests/chat_models.int.test.ts index c2839786a39b..b3c954b22faa 100644 --- a/libs/langchain-groq/src/tests/chat_models.int.test.ts +++ b/libs/langchain-groq/src/tests/chat_models.int.test.ts @@ -1,6 +1,9 @@ import { test } from "@jest/globals"; -import { AIMessage, HumanMessage, ToolMessage } from "@langchain/core/messages"; +import { AIMessage, AIMessageChunk, HumanMessage, ToolMessage } from "@langchain/core/messages"; import { ChatGroq } from "../chat_models.js"; +import { tool } from "@langchain/core/tools"; +import { z } from "zod"; +import { concat } from "@langchain/core/utils/stream"; test("invoke", async () => { const chat = new ChatGroq({ @@ -195,3 +198,38 @@ test("Few shotting with tool calls", async () => { console.log(res); expect(res.content).toContain("24"); }); + +test("Groq can stream tool calls", async () => { + const model = new ChatGroq({ + model: "llama-3.1-70b-versatile", + temperature: 0, + }); + + const weatherTool = tool((_) => { + return "The temperature is 24 degrees with hail."; + }, { + name: "get_current_weather", + schema: z.object({ + location: z.string().describe("The location to get the current weather for."), + }), + description: "Get the current weather in a given location.", + }) + + const modelWithTools = model.bindTools([weatherTool]); + + const stream = await modelWithTools.stream("What is the weather in San Francisco?"); + + let finalMessage: AIMessageChunk | undefined; + for await (const chunk of stream) { + finalMessage = !finalMessage ? chunk : concat(finalMessage, chunk); + } + + expect(finalMessage).toBeDefined(); + if (!finalMessage) return; + + expect(finalMessage.tool_calls?.[0]).toBeDefined(); + if (!finalMessage.tool_calls?.[0]) return; + + expect(finalMessage.tool_calls?.[0].name).toBe("get_current_weather"); + expect(finalMessage.tool_calls?.[0].args).toHaveProperty("location") +}) \ No newline at end of file diff --git a/libs/langchain-groq/src/tests/chat_models.standard.int.test.ts b/libs/langchain-groq/src/tests/chat_models.standard.int.test.ts index 9e1a2774771f..8065a3c0103b 100644 --- a/libs/langchain-groq/src/tests/chat_models.standard.int.test.ts +++ b/libs/langchain-groq/src/tests/chat_models.standard.int.test.ts @@ -19,7 +19,7 @@ class ChatGroqStandardIntegrationTests extends ChatModelIntegrationTests< chatModelHasToolCalling: true, chatModelHasStructuredOutput: true, constructorArgs: { - model: "mixtral-8x7b-32768", + model: "llama3-groq-70b-8192-tool-use-preview", }, }); } diff --git a/yarn.lock b/yarn.lock index c5dcede32a3e..0272bdecbc9c 100644 --- a/yarn.lock +++ b/yarn.lock @@ -11753,7 +11753,7 @@ __metadata: eslint-plugin-import: ^2.27.5 eslint-plugin-no-instanceof: ^1.0.1 eslint-plugin-prettier: ^4.2.1 - groq-sdk: ^0.3.2 + groq-sdk: ^0.5.0 jest: ^29.5.0 jest-environment-node: ^29.6.4 prettier: ^2.8.3 @@ -20745,13 +20745,6 @@ __metadata: languageName: node linkType: hard -"base-64@npm:^0.1.0": - version: 0.1.0 - resolution: "base-64@npm:0.1.0" - checksum: 5a42938f82372ab5392cbacc85a5a78115cbbd9dbef9f7540fa47d78763a3a8bd7d598475f0d92341f66285afd377509851a9bb5c67bbecb89686e9255d5b3eb - languageName: node - linkType: hard - "base-64@npm:^1.0.0": version: 1.0.0 resolution: "base-64@npm:1.0.0" @@ -21561,13 +21554,6 @@ __metadata: languageName: node linkType: hard -"charenc@npm:0.0.2": - version: 0.0.2 - resolution: "charenc@npm:0.0.2" - checksum: 81dcadbe57e861d527faf6dd3855dc857395a1c4d6781f4847288ab23cffb7b3ee80d57c15bba7252ffe3e5e8019db767757ee7975663ad2ca0939bb8fcaf2e5 - languageName: node - linkType: hard - "cheerio-select@npm:^2.1.0": version: 2.1.0 resolution: "cheerio-select@npm:2.1.0" @@ -22792,13 +22778,6 @@ __metadata: languageName: node linkType: hard -"crypt@npm:0.0.2": - version: 0.0.2 - resolution: "crypt@npm:0.0.2" - checksum: baf4c7bbe05df656ec230018af8cf7dbe8c14b36b98726939cef008d473f6fe7a4fad906cfea4062c93af516f1550a3f43ceb4d6615329612c6511378ed9fe34 - languageName: node - linkType: hard - "crypto-js@npm:^4.2.0": version: 4.2.0 resolution: "crypto-js@npm:4.2.0" @@ -23988,16 +23967,6 @@ __metadata: languageName: node linkType: hard -"digest-fetch@npm:^1.3.0": - version: 1.3.0 - resolution: "digest-fetch@npm:1.3.0" - dependencies: - base-64: ^0.1.0 - md5: ^2.3.0 - checksum: 8ebdb4b9ef02b1ac0da532d25c7d08388f2552813dfadabfe7c4630e944bb4a48093b997fc926440a10e1ccf4912f2ce9adcf2d6687b0518dab8480e08f22f9d - languageName: node - linkType: hard - "dingbat-to-unicode@npm:^1.0.1": version: 1.0.1 resolution: "dingbat-to-unicode@npm:1.0.1" @@ -27767,20 +27736,19 @@ __metadata: languageName: node linkType: hard -"groq-sdk@npm:^0.3.2": - version: 0.3.2 - resolution: "groq-sdk@npm:0.3.2" +"groq-sdk@npm:^0.5.0": + version: 0.5.0 + resolution: "groq-sdk@npm:0.5.0" dependencies: "@types/node": ^18.11.18 "@types/node-fetch": ^2.6.4 abort-controller: ^3.0.0 agentkeepalive: ^4.2.1 - digest-fetch: ^1.3.0 form-data-encoder: 1.7.2 formdata-node: ^4.3.2 node-fetch: ^2.6.7 web-streams-polyfill: ^3.2.1 - checksum: 78cdc02ac8e87d5c47c2857def55d14249ee1b698f11d06db01a86227716a3e4e2312224996168f7edee51992862082dd4dfcdfec54b765d698855db9971e525 + checksum: 051ca56e99e4a2440080943c831b109687dd346b24155d3f085113df1ad0639cb95724c14a05611f7314d340db8bf342af425eb11905c97bc6a6948cd7262f04 languageName: node linkType: hard @@ -28850,13 +28818,6 @@ __metadata: languageName: node linkType: hard -"is-buffer@npm:~1.1.6": - version: 1.1.6 - resolution: "is-buffer@npm:1.1.6" - checksum: 4a186d995d8bbf9153b4bd9ff9fd04ae75068fe695d29025d25e592d9488911eeece84eefbd8fa41b8ddcc0711058a71d4c466dcf6f1f6e1d83830052d8ca707 - languageName: node - linkType: hard - "is-callable@npm:^1.1.3, is-callable@npm:^1.1.4, is-callable@npm:^1.2.7": version: 1.2.7 resolution: "is-callable@npm:1.2.7" @@ -32207,17 +32168,6 @@ __metadata: languageName: node linkType: hard -"md5@npm:^2.3.0": - version: 2.3.0 - resolution: "md5@npm:2.3.0" - dependencies: - charenc: 0.0.2 - crypt: 0.0.2 - is-buffer: ~1.1.6 - checksum: a63cacf4018dc9dee08c36e6f924a64ced735b37826116c905717c41cebeb41a522f7a526ba6ad578f9c80f02cb365033ccd67fe186ffbcc1a1faeb75daa9b6e - languageName: node - linkType: hard - "mdast-squeeze-paragraphs@npm:^4.0.0": version: 4.0.0 resolution: "mdast-squeeze-paragraphs@npm:4.0.0" From 6015579f7f5b73eaae29b0530c4623a7d45dbdf8 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Wed, 24 Jul 2024 17:03:35 -0700 Subject: [PATCH 2/5] chore: lint files --- libs/langchain-groq/src/chat_models.ts | 18 +++++---- .../src/tests/chat_models.int.test.ts | 38 ++++++++++++------- .../tests/chat_models.standard.int.test.ts | 2 +- 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/libs/langchain-groq/src/chat_models.ts b/libs/langchain-groq/src/chat_models.ts index 6c8844602b28..861c2aae6789 100644 --- a/libs/langchain-groq/src/chat_models.ts +++ b/libs/langchain-groq/src/chat_models.ts @@ -207,17 +207,17 @@ function groqResponseToChatMessage( } } -function _convertDeltaToolCallToToolCallChunk(toolCalls?: ChatCompletionsAPI.ChatCompletionChunk.Choice.Delta.ToolCall[]): ToolCallChunk[] | undefined { +function _convertDeltaToolCallToToolCallChunk( + toolCalls?: ChatCompletionsAPI.ChatCompletionChunk.Choice.Delta.ToolCall[] +): ToolCallChunk[] | undefined { if (!toolCalls?.length) return undefined; - return toolCalls.map((tc) => { - return { + return toolCalls.map((tc) => ({ id: tc.id, name: tc.function?.name, args: tc.function?.arguments, - type: "tool_call_chunk" - } - }) + type: "tool_call_chunk", + })); } function _convertDeltaToMessageChunk( @@ -359,7 +359,9 @@ export class ChatGroq extends BaseChatModel< async completionWithRetry( request: ChatCompletionCreateParams, options?: OpenAICoreRequestOptions - ): Promise | ChatCompletion> { + ): Promise< + AsyncIterable | ChatCompletion + > { return this.caller.call(async () => this.client.chat.completions.create(request, options) ); @@ -505,7 +507,7 @@ export class ChatGroq extends BaseChatModel< completion_tokens: completionTokens, prompt_tokens: promptTokens, total_tokens: totalTokens, - } = data.usage as CompletionsAPI.CompletionUsage + } = data.usage as CompletionsAPI.CompletionUsage; if (completionTokens) { tokenUsage.completionTokens = diff --git a/libs/langchain-groq/src/tests/chat_models.int.test.ts b/libs/langchain-groq/src/tests/chat_models.int.test.ts index b3c954b22faa..e28012be269f 100644 --- a/libs/langchain-groq/src/tests/chat_models.int.test.ts +++ b/libs/langchain-groq/src/tests/chat_models.int.test.ts @@ -1,9 +1,14 @@ import { test } from "@jest/globals"; -import { AIMessage, AIMessageChunk, HumanMessage, ToolMessage } from "@langchain/core/messages"; -import { ChatGroq } from "../chat_models.js"; +import { + AIMessage, + AIMessageChunk, + HumanMessage, + ToolMessage, +} from "@langchain/core/messages"; import { tool } from "@langchain/core/tools"; import { z } from "zod"; import { concat } from "@langchain/core/utils/stream"; +import { ChatGroq } from "../chat_models.js"; test("invoke", async () => { const chat = new ChatGroq({ @@ -205,19 +210,24 @@ test("Groq can stream tool calls", async () => { temperature: 0, }); - const weatherTool = tool((_) => { - return "The temperature is 24 degrees with hail."; - }, { - name: "get_current_weather", - schema: z.object({ - location: z.string().describe("The location to get the current weather for."), - }), - description: "Get the current weather in a given location.", - }) + const weatherTool = tool( + (_) => "The temperature is 24 degrees with hail.", + { + name: "get_current_weather", + schema: z.object({ + location: z + .string() + .describe("The location to get the current weather for."), + }), + description: "Get the current weather in a given location.", + } + ); const modelWithTools = model.bindTools([weatherTool]); - const stream = await modelWithTools.stream("What is the weather in San Francisco?"); + const stream = await modelWithTools.stream( + "What is the weather in San Francisco?" + ); let finalMessage: AIMessageChunk | undefined; for await (const chunk of stream) { @@ -231,5 +241,5 @@ test("Groq can stream tool calls", async () => { if (!finalMessage.tool_calls?.[0]) return; expect(finalMessage.tool_calls?.[0].name).toBe("get_current_weather"); - expect(finalMessage.tool_calls?.[0].args).toHaveProperty("location") -}) \ No newline at end of file + expect(finalMessage.tool_calls?.[0].args).toHaveProperty("location"); +}); diff --git a/libs/langchain-groq/src/tests/chat_models.standard.int.test.ts b/libs/langchain-groq/src/tests/chat_models.standard.int.test.ts index 8065a3c0103b..82c4e3c392f8 100644 --- a/libs/langchain-groq/src/tests/chat_models.standard.int.test.ts +++ b/libs/langchain-groq/src/tests/chat_models.standard.int.test.ts @@ -19,7 +19,7 @@ class ChatGroqStandardIntegrationTests extends ChatModelIntegrationTests< chatModelHasToolCalling: true, chatModelHasStructuredOutput: true, constructorArgs: { - model: "llama3-groq-70b-8192-tool-use-preview", + model: "llama-3.1-70b-versatile", }, }); } From 40ef156e31a9161b395be31c98f5c82dcb5c31c1 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Wed, 24 Jul 2024 17:03:51 -0700 Subject: [PATCH 3/5] ayrn --- yarn.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yarn.lock b/yarn.lock index 0272bdecbc9c..a72575e60f0a 100644 --- a/yarn.lock +++ b/yarn.lock @@ -11735,7 +11735,7 @@ __metadata: resolution: "@langchain/groq@workspace:libs/langchain-groq" dependencies: "@jest/globals": ^29.5.0 - "@langchain/core": ">=0.2.16 <0.3.0" + "@langchain/core": ">=0.2.18 <0.3.0" "@langchain/openai": "workspace:^" "@langchain/scripts": ~0.0.20 "@langchain/standard-tests": 0.0.0 From 270b401fa9a83bf787ee469c5d83703e5eabc811 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Wed, 24 Jul 2024 17:15:35 -0700 Subject: [PATCH 4/5] chore: lint files --- libs/langchain-groq/src/chat_models.ts | 10 ++++----- .../src/tests/chat_models.int.test.ts | 21 ++++++++----------- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/libs/langchain-groq/src/chat_models.ts b/libs/langchain-groq/src/chat_models.ts index 861c2aae6789..5967becbd9d1 100644 --- a/libs/langchain-groq/src/chat_models.ts +++ b/libs/langchain-groq/src/chat_models.ts @@ -213,11 +213,11 @@ function _convertDeltaToolCallToToolCallChunk( if (!toolCalls?.length) return undefined; return toolCalls.map((tc) => ({ - id: tc.id, - name: tc.function?.name, - args: tc.function?.arguments, - type: "tool_call_chunk", - })); + id: tc.id, + name: tc.function?.name, + args: tc.function?.arguments, + type: "tool_call_chunk", + })); } function _convertDeltaToMessageChunk( diff --git a/libs/langchain-groq/src/tests/chat_models.int.test.ts b/libs/langchain-groq/src/tests/chat_models.int.test.ts index e28012be269f..902565218366 100644 --- a/libs/langchain-groq/src/tests/chat_models.int.test.ts +++ b/libs/langchain-groq/src/tests/chat_models.int.test.ts @@ -210,18 +210,15 @@ test("Groq can stream tool calls", async () => { temperature: 0, }); - const weatherTool = tool( - (_) => "The temperature is 24 degrees with hail.", - { - name: "get_current_weather", - schema: z.object({ - location: z - .string() - .describe("The location to get the current weather for."), - }), - description: "Get the current weather in a given location.", - } - ); + const weatherTool = tool((_) => "The temperature is 24 degrees with hail.", { + name: "get_current_weather", + schema: z.object({ + location: z + .string() + .describe("The location to get the current weather for."), + }), + description: "Get the current weather in a given location.", + }); const modelWithTools = model.bindTools([weatherTool]); From 203f3669b336acf7072b1bc8522184a1149e88b0 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Thu, 25 Jul 2024 14:30:58 -0700 Subject: [PATCH 5/5] ensure name/id fields are only yielded once for streaming tool calls --- libs/langchain-groq/src/chat_models.ts | 100 +++++++++++++++--- .../src/tests/chat_models.int.test.ts | 1 + 2 files changed, 84 insertions(+), 17 deletions(-) diff --git a/libs/langchain-groq/src/chat_models.ts b/libs/langchain-groq/src/chat_models.ts index 5967becbd9d1..b2291dc552ce 100644 --- a/libs/langchain-groq/src/chat_models.ts +++ b/libs/langchain-groq/src/chat_models.ts @@ -21,6 +21,7 @@ import { ToolMessage, OpenAIToolCall, isAIMessage, + BaseMessageChunk, } from "@langchain/core/messages"; import { ChatGeneration, @@ -208,7 +209,8 @@ function groqResponseToChatMessage( } function _convertDeltaToolCallToToolCallChunk( - toolCalls?: ChatCompletionsAPI.ChatCompletionChunk.Choice.Delta.ToolCall[] + toolCalls?: ChatCompletionsAPI.ChatCompletionChunk.Choice.Delta.ToolCall[], + index?: number ): ToolCallChunk[] | undefined { if (!toolCalls?.length) return undefined; @@ -217,13 +219,23 @@ function _convertDeltaToolCallToToolCallChunk( name: tc.function?.name, args: tc.function?.arguments, type: "tool_call_chunk", + index, })); } function _convertDeltaToMessageChunk( // eslint-disable-next-line @typescript-eslint/no-explicit-any - delta: Record -) { + delta: Record, + index: number +): { + message: BaseMessageChunk; + toolCallData?: { + id: string; + name: string; + index: number; + type: "tool_call_chunk"; + }[]; +} { const { role } = delta; const content = delta.content ?? ""; let additional_kwargs; @@ -239,17 +251,43 @@ function _convertDeltaToMessageChunk( additional_kwargs = {}; } if (role === "user") { - return new HumanMessageChunk({ content }); + return { + message: new HumanMessageChunk({ content }), + }; } else if (role === "assistant") { - return new AIMessageChunk({ - content, - additional_kwargs, - tool_call_chunks: _convertDeltaToolCallToToolCallChunk(delta.tool_calls), - }); + const toolCallChunks = _convertDeltaToolCallToToolCallChunk( + delta.tool_calls, + index + ); + return { + message: new AIMessageChunk({ + content, + additional_kwargs, + tool_call_chunks: toolCallChunks + ? toolCallChunks.map((tc) => ({ + type: tc.type, + args: tc.args, + index: tc.index, + })) + : undefined, + }), + toolCallData: toolCallChunks + ? toolCallChunks.map((tc) => ({ + id: tc.id ?? "", + name: tc.name ?? "", + index: tc.index ?? index, + type: "tool_call_chunk", + })) + : undefined, + }; } else if (role === "system") { - return new SystemMessageChunk({ content }); + return { + message: new SystemMessageChunk({ content }), + }; } else { - return new ChatMessageChunk({ content, role }); + return { + message: new ChatMessageChunk({ content, role }), + }; } } @@ -423,6 +461,12 @@ export class ChatGroq extends BaseChatModel< } ); let role = ""; + const toolCall: { + id: string; + name: string; + index: number; + type: "tool_call_chunk"; + }[] = []; for await (const data of response) { const choice = data?.choices[0]; if (!choice) { @@ -433,13 +477,34 @@ export class ChatGroq extends BaseChatModel< if (choice.delta?.role) { role = choice.delta.role; } + + const { message, toolCallData } = _convertDeltaToMessageChunk( + { + ...choice.delta, + role, + } ?? {}, + choice.index + ); + + if (toolCallData) { + // First, ensure the ID is not already present in toolCall + const newToolCallData = toolCallData.filter((tc) => + toolCall.every((t) => t.id !== tc.id) + ); + toolCall.push(...newToolCallData); + + // Yield here, ensuring the ID and name fields are only yielded once. + yield new ChatGenerationChunk({ + message: new AIMessageChunk({ + content: "", + tool_call_chunks: newToolCallData, + }), + text: "", + }); + } + const chunk = new ChatGenerationChunk({ - message: _convertDeltaToMessageChunk( - { - ...choice.delta, - role, - } ?? {} - ), + message, text: choice.delta.content ?? "", generationInfo: { finishReason: choice.finish_reason, @@ -448,6 +513,7 @@ export class ChatGroq extends BaseChatModel< yield chunk; void runManager?.handleLLMNewToken(chunk.text ?? ""); } + if (options.signal?.aborted) { throw new Error("AbortError"); } diff --git a/libs/langchain-groq/src/tests/chat_models.int.test.ts b/libs/langchain-groq/src/tests/chat_models.int.test.ts index 902565218366..d760aa669335 100644 --- a/libs/langchain-groq/src/tests/chat_models.int.test.ts +++ b/libs/langchain-groq/src/tests/chat_models.int.test.ts @@ -239,4 +239,5 @@ test("Groq can stream tool calls", async () => { expect(finalMessage.tool_calls?.[0].name).toBe("get_current_weather"); expect(finalMessage.tool_calls?.[0].args).toHaveProperty("location"); + expect(finalMessage.tool_calls?.[0].id).toBeDefined(); });