groq[minor]: Implement streaming tool calls #6203

Merged 7 commits on Jul 25, 2024
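For context, here is a minimal sketch of the usage this PR enables, mirroring the integration test added below; the model name, tool schema, and prompt come from that test and are illustrative only.

import { AIMessageChunk } from "@langchain/core/messages";
import { tool } from "@langchain/core/tools";
import { concat } from "@langchain/core/utils/stream";
import { z } from "zod";
import { ChatGroq } from "@langchain/groq";

// Bind a tool, then stream: with this change, tool call fragments arrive as
// tool_call_chunks on each AIMessageChunk instead of only on a buffered,
// non-streaming response.
const model = new ChatGroq({ model: "llama-3.1-70b-versatile", temperature: 0 });
const weatherTool = tool((_) => "The temperature is 24 degrees with hail.", {
  name: "get_current_weather",
  description: "Get the current weather in a given location.",
  schema: z.object({ location: z.string() }),
});
const stream = await model
  .bindTools([weatherTool])
  .stream("What is the weather in San Francisco?");
let finalMessage: AIMessageChunk | undefined;
for await (const chunk of stream) {
  // concat merges content and tool_call_chunks across chunks.
  finalMessage = finalMessage ? concat(finalMessage, chunk) : chunk;
}
console.log(finalMessage?.tool_calls);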
4 changes: 2 additions & 2 deletions libs/langchain-groq/package.json
@@ -35,9 +35,9 @@
"author": "LangChain",

Hey there! 👋 I noticed that the dependency changes in the package.json file might impact the project's peer/dev/hard dependencies. I've flagged this for your review as it's important to ensure compatibility and stability. Keep up the great work! 🚀

"license": "MIT",
"dependencies": {
"@langchain/core": ">=0.2.16 <0.3.0",
"@langchain/core": ">=0.2.18 <0.3.0",
"@langchain/openai": "~0.2.4",
"groq-sdk": "^0.3.2",
"groq-sdk": "^0.5.0",
"zod": "^3.22.4",
"zod-to-json-schema": "^3.22.5"
},
217 changes: 136 additions & 81 deletions libs/langchain-groq/src/chat_models.ts
@@ -8,6 +8,8 @@ import {
LangSmithParams,
type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import * as ChatCompletionsAPI from "groq-sdk/resources/chat/completions";
import * as CompletionsAPI from "groq-sdk/resources/completions";
import {
AIMessage,
AIMessageChunk,
@@ -19,6 +21,7 @@ import {
ToolMessage,
OpenAIToolCall,
isAIMessage,
BaseMessageChunk,
} from "@langchain/core/messages";
import {
ChatGeneration,
@@ -32,7 +35,6 @@
} from "@langchain/openai";
import { isZodSchema } from "@langchain/core/utils/types";
import Groq from "groq-sdk";
import { ChatCompletionChunk } from "groq-sdk/lib/chat_completions_ext";
import {
ChatCompletion,
ChatCompletionCreateParams,
@@ -146,8 +148,8 @@ export function messageToGroqRole(message: BaseMessage): GroqRoleEnum {

function convertMessagesToGroqParams(
messages: BaseMessage[]
): Array<ChatCompletion.Choice.Message> {
return messages.map((message): ChatCompletion.Choice.Message => {
): Array<ChatCompletionsAPI.ChatCompletionMessage> {
return messages.map((message): ChatCompletionsAPI.ChatCompletionMessage => {
if (typeof message.content !== "string") {
throw new Error("Non string message content not supported");
}
@@ -172,12 +174,12 @@ function convertMessagesToGroqParams(
completionParam.tool_call_id = (message as ToolMessage).tool_call_id;
}
}
return completionParam as ChatCompletion.Choice.Message;
return completionParam as ChatCompletionsAPI.ChatCompletionMessage;
});
}

function groqResponseToChatMessage(
message: ChatCompletion.Choice.Message
message: ChatCompletionsAPI.ChatCompletionMessage
): BaseMessage {
const rawToolCalls: OpenAIToolCall[] | undefined = message.tool_calls as
| OpenAIToolCall[]
@@ -206,10 +208,34 @@ function groqResponseToChatMessage(
}
}

function _convertDeltaToolCallToToolCallChunk(
toolCalls?: ChatCompletionsAPI.ChatCompletionChunk.Choice.Delta.ToolCall[],
index?: number
): ToolCallChunk[] | undefined {
if (!toolCalls?.length) return undefined;

return toolCalls.map((tc) => ({
id: tc.id,
name: tc.function?.name,
args: tc.function?.arguments,
type: "tool_call_chunk",
index,
}));
}

function _convertDeltaToMessageChunk(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
delta: Record<string, any>
) {
delta: Record<string, any>,
index: number
): {
message: BaseMessageChunk;
toolCallData?: {
id: string;
name: string;
index: number;
type: "tool_call_chunk";
}[];
} {
const { role } = delta;
const content = delta.content ?? "";
let additional_kwargs;
@@ -225,13 +251,43 @@ function _convertDeltaToMessageChunk(
additional_kwargs = {};
}
if (role === "user") {
return new HumanMessageChunk({ content });
return {
message: new HumanMessageChunk({ content }),
};
} else if (role === "assistant") {
return new AIMessageChunk({ content, additional_kwargs });
const toolCallChunks = _convertDeltaToolCallToToolCallChunk(
delta.tool_calls,
index
);
return {
message: new AIMessageChunk({
content,
additional_kwargs,
tool_call_chunks: toolCallChunks
? toolCallChunks.map((tc) => ({
type: tc.type,
args: tc.args,
index: tc.index,
}))
: undefined,
}),
toolCallData: toolCallChunks
? toolCallChunks.map((tc) => ({
id: tc.id ?? "",
name: tc.name ?? "",
index: tc.index ?? index,
type: "tool_call_chunk",
}))
: undefined,
};
} else if (role === "system") {
return new SystemMessageChunk({ content });
return {
message: new SystemMessageChunk({ content }),
};
} else {
return new ChatMessageChunk({ content, role });
return {
message: new ChatMessageChunk({ content, role }),
};
}
}
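As an aside on how these per-delta fragments recombine downstream: when AIMessageChunks are merged, tool_call_chunks sharing the same index are concatenated, so the id and name yielded once plus the argument fragments yielded later reassemble into a complete tool call. A rough sketch, assuming @langchain/core's concat helper and standard chunk-merging behavior; the ids and argument strings are illustrative.

import { AIMessageChunk } from "@langchain/core/messages";
import { concat } from "@langchain/core/utils/stream";

// The first chunk carries the tool call id and name; later chunks carry only
// argument fragments at the same index.
const first = new AIMessageChunk({
  content: "",
  tool_call_chunks: [
    { id: "call_123", name: "get_current_weather", args: "", index: 0, type: "tool_call_chunk" },
  ],
});
const second = new AIMessageChunk({
  content: "",
  tool_call_chunks: [{ args: '{"location":', index: 0, type: "tool_call_chunk" }],
});
const third = new AIMessageChunk({
  content: "",
  tool_call_chunks: [{ args: ' "San Francisco"}', index: 0, type: "tool_call_chunk" }],
});

const merged = concat(concat(first, second), third);
// Once the concatenated args form valid JSON, merged.tool_calls should contain
// something like:
// [{ name: "get_current_weather", args: { location: "San Francisco" }, id: "call_123" }]
console.log(merged.tool_calls);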

@@ -322,16 +378,16 @@ export class ChatGroq extends BaseChatModel<
ls_provider: "groq",
ls_model_name: this.model,
ls_model_type: "chat",
ls_temperature: params.temperature,
ls_max_tokens: params.max_tokens,
ls_temperature: params.temperature ?? this.temperature,
ls_max_tokens: params.max_tokens ?? this.maxTokens,
ls_stop: options.stop,
};
}

async completionWithRetry(
request: ChatCompletionCreateParamsStreaming,
options?: OpenAICoreRequestOptions
): Promise<AsyncIterable<ChatCompletionChunk>>;
): Promise<AsyncIterable<ChatCompletionsAPI.ChatCompletionChunk>>;

async completionWithRetry(
request: ChatCompletionCreateParamsNonStreaming,
@@ -341,7 +397,9 @@
async completionWithRetry(
request: ChatCompletionCreateParams,
options?: OpenAICoreRequestOptions
): Promise<AsyncIterable<ChatCompletionChunk> | ChatCompletion> {
): Promise<
AsyncIterable<ChatCompletionsAPI.ChatCompletionChunk> | ChatCompletion
> {
return this.caller.call(async () =>
this.client.chat.completions.create(request, options)
);
@@ -391,76 +449,73 @@
): AsyncGenerator<ChatGenerationChunk> {

Hey team, I've flagged this PR for your review as it introduces a new external HTTP request using this.completionWithRetry. This comment is to ensure the change is reviewed and aligned with our project's requirements. Let me know if you need further clarification.

const params = this.invocationParams(options);
const messagesMapped = convertMessagesToGroqParams(messages);
if (options.tools !== undefined && options.tools.length > 0) {
const result = await this._generateNonStreaming(
messages,
options,
runManager
);
const generationMessage = result.generations[0].message as AIMessage;
if (
generationMessage === undefined ||
typeof generationMessage.content !== "string"
) {
throw new Error("Could not parse Groq output.");
const response = await this.completionWithRetry(
{
...params,
messages: messagesMapped,
stream: true,
},
{
signal: options?.signal,
headers: options?.headers,
}
const toolCallChunks: ToolCallChunk[] | undefined =
generationMessage.tool_calls?.map((toolCall, i) => ({
name: toolCall.name,
args: JSON.stringify(toolCall.args),
id: toolCall.id,
index: i,
type: "tool_call_chunk",
}));
yield new ChatGenerationChunk({
message: new AIMessageChunk({
content: generationMessage.content,
additional_kwargs: generationMessage.additional_kwargs,
tool_call_chunks: toolCallChunks,
}),
text: generationMessage.content,
});
} else {
const response = await this.completionWithRetry(
{
...params,
messages: messagesMapped,
stream: true,
},
);
let role = "";
const toolCall: {
id: string;
name: string;
index: number;
type: "tool_call_chunk";
}[] = [];
for await (const data of response) {
const choice = data?.choices[0];
if (!choice) {
continue;
}
// The `role` field is populated in the first delta of the response
// but is not present in subsequent deltas. Extract it when available.
if (choice.delta?.role) {
role = choice.delta.role;
}

const { message, toolCallData } = _convertDeltaToMessageChunk(
{
signal: options?.signal,
headers: options?.headers,
}
...choice.delta,
role,
} ?? {},
choice.index
);
let role = "";
for await (const data of response) {
const choice = data?.choices[0];
if (!choice) {
continue;
}
// The `role` field is populated in the first delta of the response
// but is not present in subsequent deltas. Extract it when available.
if (choice.delta?.role) {
role = choice.delta.role;
}
const chunk = new ChatGenerationChunk({
message: _convertDeltaToMessageChunk(
{
...choice.delta,
role,
} ?? {}
),
text: choice.delta.content ?? "",
generationInfo: {
finishReason: choice.finish_reason,
},

if (toolCallData) {
// First, ensure the ID is not already present in toolCall
const newToolCallData = toolCallData.filter((tc) =>
toolCall.every((t) => t.id !== tc.id)
);
toolCall.push(...newToolCallData);

// Yield here, ensuring the ID and name fields are only yielded once.
yield new ChatGenerationChunk({
message: new AIMessageChunk({
content: "",
tool_call_chunks: newToolCallData,
}),
text: "",
});
yield chunk;
void runManager?.handleLLMNewToken(chunk.text ?? "");
}
if (options.signal?.aborted) {
throw new Error("AbortError");
}

const chunk = new ChatGenerationChunk({
message,
text: choice.delta.content ?? "",
generationInfo: {
finishReason: choice.finish_reason,
},
});
yield chunk;
void runManager?.handleLLMNewToken(chunk.text ?? "");
}

if (options.signal?.aborted) {
throw new Error("AbortError");
}
}

@@ -518,7 +573,7 @@ export class ChatGroq extends BaseChatModel<
completion_tokens: completionTokens,
prompt_tokens: promptTokens,
total_tokens: totalTokens,
} = data.usage as ChatCompletion.Usage;
} = data.usage as CompletionsAPI.CompletionUsage;

if (completionTokens) {
tokenUsage.completionTokens =
48 changes: 47 additions & 1 deletion libs/langchain-groq/src/tests/chat_models.int.test.ts
@@ -1,5 +1,13 @@
import { test } from "@jest/globals";
import { AIMessage, HumanMessage, ToolMessage } from "@langchain/core/messages";
import {
AIMessage,
AIMessageChunk,
HumanMessage,
ToolMessage,
} from "@langchain/core/messages";
import { tool } from "@langchain/core/tools";
import { z } from "zod";
import { concat } from "@langchain/core/utils/stream";
import { ChatGroq } from "../chat_models.js";

test("invoke", async () => {
@@ -195,3 +203,41 @@ test("Few shotting with tool calls", async () => {
console.log(res);
expect(res.content).toContain("24");
});

test("Groq can stream tool calls", async () => {
const model = new ChatGroq({
model: "llama-3.1-70b-versatile",
temperature: 0,
});

const weatherTool = tool((_) => "The temperature is 24 degrees with hail.", {
name: "get_current_weather",
schema: z.object({
location: z
.string()
.describe("The location to get the current weather for."),
}),
description: "Get the current weather in a given location.",
});

const modelWithTools = model.bindTools([weatherTool]);

const stream = await modelWithTools.stream(
"What is the weather in San Francisco?"
);

let finalMessage: AIMessageChunk | undefined;
for await (const chunk of stream) {
finalMessage = !finalMessage ? chunk : concat(finalMessage, chunk);
}

expect(finalMessage).toBeDefined();
if (!finalMessage) return;

expect(finalMessage.tool_calls?.[0]).toBeDefined();
if (!finalMessage.tool_calls?.[0]) return;

expect(finalMessage.tool_calls?.[0].name).toBe("get_current_weather");
expect(finalMessage.tool_calls?.[0].args).toHaveProperty("location");
expect(finalMessage.tool_calls?.[0].id).toBeDefined();
});
@@ -19,7 +19,7 @@ class ChatGroqStandardIntegrationTests extends ChatModelIntegrationTests<
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
constructorArgs: {
model: "mixtral-8x7b-32768",
model: "llama-3.1-70b-versatile",
},
});
}