diff --git a/docs/core_docs/docs/integrations/chat/cohere.mdx b/docs/core_docs/docs/integrations/chat/cohere.mdx index 40e44f32f2c6..c455bed66b90 100644 --- a/docs/core_docs/docs/integrations/chat/cohere.mdx +++ b/docs/core_docs/docs/integrations/chat/cohere.mdx @@ -62,6 +62,14 @@ import StatefulChatExample from "@examples/models/chat/cohere/stateful_conversat You can see the LangSmith traces from this example [here](https://smith.langchain.com/public/8e67b05a-4e63-414e-ac91-a91acf21b262/r) and [here](https://smith.langchain.com/public/50fabc25-46fe-4727-a59c-7e4eb0de8e70/r) ::: +### Tools + +The Cohere API supports tool calling, along with multi-hop-tool calling. The following example demonstrates how to call tools: + +import ToolCallingExample from "@examples/models/chat/cohere/tool_calling.ts"; + +{ToolCallingExample} + ### RAG Cohere also comes out of the box with RAG support. diff --git a/examples/src/models/chat/cohere/chat_cohere.ts b/examples/src/models/chat/cohere/chat_cohere.ts index 04ffed68aa5e..778eb9ab63f2 100644 --- a/examples/src/models/chat/cohere/chat_cohere.ts +++ b/examples/src/models/chat/cohere/chat_cohere.ts @@ -3,7 +3,6 @@ import { ChatPromptTemplate } from "@langchain/core/prompts"; const model = new ChatCohere({ apiKey: process.env.COHERE_API_KEY, // Default - model: "command", // Default }); const prompt = ChatPromptTemplate.fromMessages([ ["ai", "You are a helpful assistant"], diff --git a/examples/src/models/chat/cohere/chat_stream_cohere.ts b/examples/src/models/chat/cohere/chat_stream_cohere.ts index 559fd9f4415f..a7ddd822608e 100644 --- a/examples/src/models/chat/cohere/chat_stream_cohere.ts +++ b/examples/src/models/chat/cohere/chat_stream_cohere.ts @@ -4,7 +4,6 @@ import { StringOutputParser } from "@langchain/core/output_parsers"; const model = new ChatCohere({ apiKey: process.env.COHERE_API_KEY, // Default - model: "command", // Default }); const prompt = ChatPromptTemplate.fromMessages([ ["ai", "You are a helpful assistant"], diff --git a/examples/src/models/chat/cohere/connectors.ts b/examples/src/models/chat/cohere/connectors.ts index fd252dc7c76f..a16c2ed677c3 100644 --- a/examples/src/models/chat/cohere/connectors.ts +++ b/examples/src/models/chat/cohere/connectors.ts @@ -3,7 +3,6 @@ import { HumanMessage } from "@langchain/core/messages"; const model = new ChatCohere({ apiKey: process.env.COHERE_API_KEY, // Default - model: "command", // Default }); const response = await model.invoke( diff --git a/examples/src/models/chat/cohere/rag.ts b/examples/src/models/chat/cohere/rag.ts index 240225a33a46..b572dc8a1efe 100644 --- a/examples/src/models/chat/cohere/rag.ts +++ b/examples/src/models/chat/cohere/rag.ts @@ -3,7 +3,6 @@ import { HumanMessage } from "@langchain/core/messages"; const model = new ChatCohere({ apiKey: process.env.COHERE_API_KEY, // Default - model: "command", // Default }); const documents = [ diff --git a/examples/src/models/chat/cohere/stateful_conversation.ts b/examples/src/models/chat/cohere/stateful_conversation.ts index 1edc61a47ab2..e126c4bf6bce 100644 --- a/examples/src/models/chat/cohere/stateful_conversation.ts +++ b/examples/src/models/chat/cohere/stateful_conversation.ts @@ -3,7 +3,6 @@ import { HumanMessage } from "@langchain/core/messages"; const model = new ChatCohere({ apiKey: process.env.COHERE_API_KEY, // Default - model: "command", // Default }); const conversationId = `demo_test_id-${Math.random()}`; diff --git a/examples/src/models/chat/cohere/tool_calling.ts b/examples/src/models/chat/cohere/tool_calling.ts new file mode 100644 index 000000000000..f08a5e6cb343 --- /dev/null +++ b/examples/src/models/chat/cohere/tool_calling.ts @@ -0,0 +1,57 @@ +import { ChatCohere } from "@langchain/cohere"; +import { HumanMessage } from "@langchain/core/messages"; +import { z } from "zod"; +import { DynamicStructuredTool } from "@langchain/core/tools"; + +const model = new ChatCohere({ + apiKey: process.env.COHERE_API_KEY, // Default +}); + +const magicFunctionTool = new DynamicStructuredTool({ + name: "magic_function", + description: "Apply a magic function to the input number", + schema: z.object({ + num: z.number().describe("The number to apply the magic function for"), + }), + func: async ({ num }) => { + return `The magic function of ${num} is ${num + 5}`; + }, +}); + +const tools = [magicFunctionTool]; +const modelWithTools = model.bindTools(tools); + +const messages = [new HumanMessage("What is the magic function of number 5?")]; +const response = await modelWithTools.invoke(messages); +/* + AIMessage { + content: 'I will use the magic_function tool to answer this question.', + name: undefined, + additional_kwargs: { + response_id: 'd0b189e5-3dbf-493c-93f8-99ed4b01d96d', + generationId: '8982a68f-c64c-48f8-bf12-0b4bea0018b6', + chatHistory: [ [Object], [Object] ], + finishReason: 'COMPLETE', + meta: { apiVersion: [Object], billedUnits: [Object], tokens: [Object] }, + toolCalls: [ [Object] ] + }, + response_metadata: { + estimatedTokenUsage: { completionTokens: 54, promptTokens: 920, totalTokens: 974 }, + response_id: 'd0b189e5-3dbf-493c-93f8-99ed4b01d96d', + generationId: '8982a68f-c64c-48f8-bf12-0b4bea0018b6', + chatHistory: [ [Object], [Object] ], + finishReason: 'COMPLETE', + meta: { apiVersion: [Object], billedUnits: [Object], tokens: [Object] }, + toolCalls: [ [Object] ] + }, + tool_calls: [ + { + name: 'magic_function', + args: [Object], + id: '4ec98550-ba9a-4043-adfe-566230e5' + } + ], + invalid_tool_calls: [], + usage_metadata: { input_tokens: 920, output_tokens: 54, total_tokens: 974 } + } +*/ diff --git a/libs/langchain-cohere/.eslintrc.cjs b/libs/langchain-cohere/.eslintrc.cjs index d533e6deffb6..59171b108443 100644 --- a/libs/langchain-cohere/.eslintrc.cjs +++ b/libs/langchain-cohere/.eslintrc.cjs @@ -33,6 +33,7 @@ module.exports = { "@typescript-eslint/no-unused-vars": ["warn", { args: "none" }], "@typescript-eslint/no-floating-promises": "error", "@typescript-eslint/no-misused-promises": "error", + "arrow-body-style": 0, camelcase: 0, "class-methods-use-this": 0, "import/extensions": [2, "ignorePackages"], diff --git a/libs/langchain-cohere/package.json b/libs/langchain-cohere/package.json index 4227d83388e7..c8dfc8bf1477 100644 --- a/libs/langchain-cohere/package.json +++ b/libs/langchain-cohere/package.json @@ -35,8 +35,11 @@ "author": "LangChain", "license": "MIT", "dependencies": { - "@langchain/core": ">=0.2.5 <0.3.0", - "cohere-ai": "^7.10.5" + "@langchain/core": ">=0.2.14 <0.3.0", + "cohere-ai": "^7.10.5", + "uuid": "^10.0.0", + "zod": "^3.23.8", + "zod-to-json-schema": "^3.23.1" }, "devDependencies": { "@jest/globals": "^29.5.0", diff --git a/libs/langchain-cohere/src/chat_models.ts b/libs/langchain-cohere/src/chat_models.ts index cc8184a5bb05..bd9d7001979b 100644 --- a/libs/langchain-cohere/src/chat_models.ts +++ b/libs/langchain-cohere/src/chat_models.ts @@ -1,12 +1,22 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ import { CohereClient, Cohere } from "cohere-ai"; +import { ToolResult } from "cohere-ai/api/index.js"; +import { zodToJsonSchema } from "zod-to-json-schema"; import { MessageType, type BaseMessage, MessageContent, AIMessage, + isAIMessage, } from "@langchain/core/messages"; -import { type BaseLanguageModelCallOptions } from "@langchain/core/language_models/base"; +import { + BaseLanguageModelInput, + ToolDefinition, + isOpenAITool, + type BaseLanguageModelCallOptions, +} from "@langchain/core/language_models/base"; +import { isStructuredTool } from "@langchain/core/utils/function_calling"; import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; import { type BaseChatModelParams, @@ -21,6 +31,14 @@ import { import { AIMessageChunk } from "@langchain/core/messages"; import { getEnvironmentVariable } from "@langchain/core/utils/env"; import { NewTokenIndices } from "@langchain/core/callbacks/base"; +import { + ToolMessage, + ToolCall, + ToolCallChunk, +} from "@langchain/core/messages/tool"; +import * as uuid from "uuid"; +import { StructuredToolInterface } from "@langchain/core/tools"; +import { Runnable } from "@langchain/core/runnables"; /** * Input interface for ChatCohere @@ -65,15 +83,62 @@ interface TokenUsage { totalTokens?: number; } -export interface CohereChatCallOptions +export interface ChatCohereCallOptions extends BaseLanguageModelCallOptions, - Partial>, - Partial>, - Pick {} + Partial>, + Partial>, + Pick { + tools?: ( + | StructuredToolInterface + | Cohere.Tool + | Record + | ToolDefinition + )[]; +} + +/** @deprecated Import as ChatCohereCallOptions instead. */ +export interface CohereChatCallOptions extends ChatCohereCallOptions {} + +function convertToDocuments( + observations: MessageContent +): Array> { + /** Converts observations into a 'document' dict */ + const documents: Array> = []; + let observationsList: Array> = []; + + if (typeof observations === "string") { + // strings are turned into a key/value pair and a key of 'output' is added. + observationsList = [{ output: observations }]; + } else if ( + // eslint-disable-next-line no-instanceof/no-instanceof + observations instanceof Map || + (typeof observations === "object" && + observations !== null && + !Array.isArray(observations)) + ) { + // single mappings are transformed into a list to simplify the rest of the code. + observationsList = [observations]; + } else if (!Array.isArray(observations)) { + // all other types are turned into a key/value pair within a list + observationsList = [{ output: observations }]; + } + + for (let doc of observationsList) { + // eslint-disable-next-line no-instanceof/no-instanceof + if (!(doc instanceof Map) && (typeof doc !== "object" || doc === null)) { + // types that aren't Mapping are turned into a key/value pair. + doc = { output: doc }; + } + documents.push(doc); + } -function convertMessagesToCohereMessages( - messages: Array -): Array { + return documents; +} + +function convertMessageToCohereMessage( + message: BaseMessage, + toolResults: ToolResult[] +): Cohere.Message { const getRole = (role: MessageType) => { switch (role) { case "system": @@ -82,9 +147,11 @@ function convertMessagesToCohereMessages( return "USER"; case "ai": return "CHATBOT"; + case "tool": + return "TOOL"; default: throw new Error( - `Unknown message type: '${role}'. Accepted types: 'human', 'ai', 'system'` + `Unknown message type: '${role}'. Accepted types: 'human', 'ai', 'system', 'tool'` ); } }; @@ -102,10 +169,108 @@ function convertMessagesToCohereMessages( ); }; - return messages.map((message) => ({ - role: getRole(message._getType()), - message: getContent(message.content), - })); + const getToolCall = (message: BaseMessage): Cohere.ToolCall[] => { + if (isAIMessage(message) && message.tool_calls) { + return message.tool_calls.map((toolCall) => ({ + name: toolCall.name, + parameters: toolCall.args, + })); + } + return []; + }; + if (message._getType().toLowerCase() === "ai") { + return { + role: getRole(message._getType()), + message: getContent(message.content), + toolCalls: getToolCall(message), + }; + } else if (message._getType().toLowerCase() === "tool") { + return { + role: getRole(message._getType()), + message: getContent(message.content), + toolResults, + }; + } else if ( + message._getType().toLowerCase() === "human" || + message._getType().toLowerCase() === "system" + ) { + return { + role: getRole(message._getType()), + message: getContent(message.content), + }; + } else { + throw new Error( + "Got unknown message type. Supported types are AIMessage, ToolMessage, HumanMessage, and SystemMessage" + ); + } +} + +function isCohereTool(tool: any): tool is Cohere.Tool { + return ( + "name" in tool && "description" in tool && "parameterDefinitions" in tool + ); +} + +function isToolMessage(message: BaseMessage): message is ToolMessage { + return message._getType() === "tool"; +} + +function _convertJsonSchemaToCohereTool(jsonSchema: Record) { + const parameterDefinitionsProperties = + "properties" in jsonSchema ? jsonSchema.properties : {}; + let parameterDefinitionsRequired = + "required" in jsonSchema ? jsonSchema.required : []; + + const parameterDefinitionsFinal: Record = {}; + + // Iterate through all properties + Object.keys(parameterDefinitionsProperties).forEach((propertyName) => { + // Create the property in the new object + parameterDefinitionsFinal[propertyName] = + parameterDefinitionsProperties[propertyName]; + // Set the required property based on the 'required' array + if (parameterDefinitionsRequired === undefined) { + parameterDefinitionsRequired = []; + } + parameterDefinitionsFinal[propertyName].required = + parameterDefinitionsRequired.includes(propertyName); + }); + return parameterDefinitionsFinal; +} + +function _formatToolsToCohere( + tools: ChatCohereCallOptions["tools"] +): Cohere.Tool[] | undefined { + if (!tools) { + return undefined; + } else if (tools.every(isCohereTool)) { + return tools; + } else if (tools.every(isOpenAITool)) { + return tools.map((tool) => { + return { + name: tool.function.name, + description: tool.function.description ?? "", + parameterDefinitions: _convertJsonSchemaToCohereTool( + tool.function.parameters + ), + }; + }); + } else if (tools.every(isStructuredTool)) { + return tools.map((tool) => { + const parameterDefinitionsFromZod = zodToJsonSchema(tool.schema); + return { + name: tool.name, + description: tool.description, + parameterDefinitions: _convertJsonSchemaToCohereTool( + parameterDefinitionsFromZod + ), + }; + }); + } else { + throw new Error( + `Can not pass in a mix of tool schema types to ChatCohere.` + ); + } } /** @@ -114,7 +279,7 @@ function convertMessagesToCohereMessages( * ```typescript * const model = new ChatCohere({ * apiKey: process.env.COHERE_API_KEY, // Default - * model: "command" // Default + * model: "command-r-plus" // Default * }); * const response = await model.invoke([ * new HumanMessage("How tall are the largest pengiuns?") @@ -122,7 +287,7 @@ function convertMessagesToCohereMessages( * ``` */ export class ChatCohere< - CallOptions extends CohereChatCallOptions = CohereChatCallOptions + CallOptions extends ChatCohereCallOptions = ChatCohereCallOptions > extends BaseChatModel implements ChatCohereInput @@ -135,7 +300,7 @@ export class ChatCohere< client: CohereClient; - model = "command"; + model = "command-r-plus"; temperature = 0.3; @@ -189,6 +354,8 @@ export class ChatCohere< searchQueriesOnly: options.searchQueriesOnly, documents: options.documents, temperature: options.temperature ?? this.temperature, + forceSingleStep: options.forceSingleStep, + tools: options.tools, }; // Filter undefined entries return Object.fromEntries( @@ -196,6 +363,243 @@ export class ChatCohere< ); } + override bindTools( + tools: ( + | Cohere.Tool + | Record + | StructuredToolInterface + | ToolDefinition + )[], + kwargs?: Partial + ): Runnable { + return this.bind({ + tools: _formatToolsToCohere(tools), + ...kwargs, + } as Partial); + } + + /** @ignore */ + private _getChatRequest( + messages: BaseMessage[], + options: this["ParsedCallOptions"] + ): Cohere.ChatRequest { + const params = this.invocationParams(options); + + const toolResults = this._messagesToCohereToolResultsCurrChatTurn(messages); + const chatHistory = []; + let messageStr: string = ""; + let tempToolResults: { + call: Cohere.ToolCall; + outputs: any; + }[] = []; + + if (!params.forceSingleStep) { + for (let i = 0; i < messages.length - 1; i += 1) { + const message = messages[i]; + // If there are multiple tool messages, then we need to aggregate them into one single tool message to pass into chat history + if (message._getType().toLowerCase() === "tool") { + tempToolResults = tempToolResults.concat( + this._messageToCohereToolResults(messages, i) + ); + + if ( + i === messages.length - 1 || + !(messages[i + 1]._getType().toLowerCase() === "tool") + ) { + const cohere_message = convertMessageToCohereMessage( + message, + tempToolResults + ); + chatHistory.push(cohere_message); + tempToolResults = []; + } + } else { + chatHistory.push(convertMessageToCohereMessage(message, [])); + } + } + + messageStr = + toolResults.length > 0 + ? "" + : messages[messages.length - 1].content.toString(); + } else { + messageStr = ""; + + // if force_single_step is set to True, then message is the last human message in the conversation + for (let i = 0; i < messages.length - 1; i += 1) { + const message = messages[i]; + if (isAIMessage(message) && message.tool_calls) { + continue; + } + + // If there are multiple tool messages, then we need to aggregate them into one single tool message to pass into chat history + if (message._getType().toLowerCase() === "tool") { + tempToolResults = tempToolResults.concat( + this._messageToCohereToolResults(messages, i) + ); + + if ( + i === messages.length - 1 || + !(messages[i + 1]._getType().toLowerCase() === "tool") + ) { + const cohereMessage = convertMessageToCohereMessage( + message, + tempToolResults + ); + chatHistory.push(cohereMessage); + tempToolResults = []; + } + } else { + chatHistory.push(convertMessageToCohereMessage(message, [])); + } + } + + // Add the last human message in the conversation to the message string + for (let i = messages.length - 1; i >= 0; i -= 1) { + const message = messages[i]; + if (message._getType().toLowerCase() === "human" && message.content) { + messageStr = message.content.toString(); + break; + } + } + } + const req: Cohere.ChatRequest = { + message: messageStr, + chatHistory, + toolResults: toolResults.length > 0 ? toolResults : undefined, + ...params, + }; + + return req; + } + + private _getCurrChatTurnMessages(messages: BaseMessage[]): BaseMessage[] { + // Get the messages for the current chat turn. + const currentChatTurnMessages: BaseMessage[] = []; + for (let i = messages.length - 1; i >= 0; i -= 1) { + const message = messages[i]; + currentChatTurnMessages.push(message); + if (message._getType().toLowerCase() === "human") { + break; + } + } + return currentChatTurnMessages.reverse(); + } + + private _messagesToCohereToolResultsCurrChatTurn( + messages: BaseMessage[] + ): Array<{ + call: Cohere.ToolCall; + outputs: ReturnType; + }> { + /** Get tool_results from messages. */ + const toolResults: Array<{ + call: Cohere.ToolCall; + outputs: ReturnType; + }> = []; + const currChatTurnMessages = this._getCurrChatTurnMessages(messages); + + for (const message of currChatTurnMessages) { + if (isToolMessage(message)) { + const toolMessage = message; + const previousAiMsgs = currChatTurnMessages.filter( + (msg) => isAIMessage(msg) && msg.tool_calls !== undefined + ) as AIMessage[]; + if (previousAiMsgs.length > 0) { + const previousAiMsg = previousAiMsgs[previousAiMsgs.length - 1]; + if (previousAiMsg.tool_calls) { + toolResults.push( + ...previousAiMsg.tool_calls + .filter( + (lcToolCall) => lcToolCall.id === toolMessage.tool_call_id + ) + .map((lcToolCall) => ({ + call: { + name: lcToolCall.name, + parameters: lcToolCall.args, + }, + outputs: convertToDocuments(toolMessage.content), + })) + ); + } + } + } + } + return toolResults; + } + + private _messageToCohereToolResults( + messages: BaseMessage[], + toolMessageIndex: number + ): Array<{ call: Cohere.ToolCall; outputs: any }> { + /** Get tool_results from messages. */ + const toolResults: Array<{ call: Cohere.ToolCall; outputs: any }> = []; + const toolMessage = messages[toolMessageIndex]; + + if (!isToolMessage(toolMessage)) { + throw new Error( + "The message index does not correspond to an instance of ToolMessage" + ); + } + + const messagesUntilTool = messages.slice(0, toolMessageIndex); + const previousAiMessage = messagesUntilTool + .filter((message) => isAIMessage(message) && message.tool_calls) + .slice(-1)[0] as AIMessage; + + if (previousAiMessage.tool_calls) { + toolResults.push( + ...previousAiMessage.tool_calls + .filter((lcToolCall) => lcToolCall.id === toolMessage.tool_call_id) + .map((lcToolCall) => ({ + call: { + name: lcToolCall.name, + parameters: lcToolCall.args, + }, + outputs: convertToDocuments(toolMessage.content), + })) + ); + } + + return toolResults; + } + + private _formatCohereToolCalls(toolCalls: Cohere.ToolCall[] | null = null): { + id: string; + function: { + name: string; + arguments: Record; + }; + type: string; + }[] { + if (!toolCalls) { + return []; + } + + const formattedToolCalls = []; + for (const toolCall of toolCalls) { + formattedToolCalls.push({ + id: uuid.v4().substring(0, 32), + function: { + name: toolCall.name, + arguments: toolCall.parameters, // Convert arguments to string + }, + type: "function", + }); + } + return formattedToolCalls; + } + + private _convertCohereToolCallToLangchain( + toolCalls: Record[] + ): ToolCall[] { + return toolCalls.map((toolCall) => ({ + name: toolCall.function.name, + args: toolCall.function.arguments, + id: toolCall.id, + })); + } + /** @ignore */ async _generate( messages: BaseMessage[], @@ -203,26 +607,9 @@ export class ChatCohere< runManager?: CallbackManagerForLLMRun ): Promise { const tokenUsage: TokenUsage = {}; - const params = this.invocationParams(options); - const cohereMessages = convertMessagesToCohereMessages(messages); // The last message in the array is the most recent, all other messages // are apart of the chat history. - const lastMessage = cohereMessages[cohereMessages.length - 1]; - if (lastMessage.role === "TOOL") { - throw new Error( - "Cohere does not support tool messages as the most recent message in chat history." - ); - } - const { message } = lastMessage; - const chatHistory: Cohere.Message[] = []; - if (cohereMessages.length > 1) { - chatHistory.push(...cohereMessages.slice(0, -1)); - } - const input = { - ...params, - message, - chatHistory, - }; + const request = this._getChatRequest(messages, options); // Handle streaming if (this.streaming) { @@ -251,8 +638,7 @@ export class ChatCohere< async () => { let response; try { - response = await this.client.chat(input); - // eslint-disable-next-line @typescript-eslint/no-explicit-any + response = await this.client.chat(request); } catch (e: any) { e.status = e.status ?? e.statusCode; throw e; @@ -281,6 +667,19 @@ export class ChatCohere< const generationInfo: Record = { ...response }; delete generationInfo.text; + if (response.toolCalls && response.toolCalls.length > 0) { + // Only populate tool_calls when 1) present on the response and + // 2) has one or more calls. + generationInfo.toolCalls = this._formatCohereToolCalls( + response.toolCalls + ); + } + let toolCalls: ToolCall[] = []; + if ("toolCalls" in generationInfo) { + toolCalls = this._convertCohereToolCallToLangchain( + generationInfo.toolCalls as Record[] + ); + } const generations: ChatGeneration[] = [ { @@ -288,6 +687,7 @@ export class ChatCohere< message: new AIMessage({ content: response.text, additional_kwargs: generationInfo, + tool_calls: toolCalls, usage_metadata: { input_tokens: tokenUsage.promptTokens ?? 0, output_tokens: tokenUsage.completionTokens ?? 0, @@ -308,33 +708,13 @@ export class ChatCohere< options: this["ParsedCallOptions"], runManager?: CallbackManagerForLLMRun ): AsyncGenerator { - const params = this.invocationParams(options); - const cohereMessages = convertMessagesToCohereMessages(messages); - // The last message in the array is the most recent, all other messages - // are apart of the chat history. - const lastMessage = cohereMessages[cohereMessages.length - 1]; - if (lastMessage.role === "TOOL") { - throw new Error( - "Cohere does not support tool messages as the most recent message in chat history." - ); - } - const { message } = lastMessage; - const chatHistory: Cohere.Message[] = []; - if (cohereMessages.length > 1) { - chatHistory.push(...cohereMessages.slice(0, -1)); - } - const input = { - ...params, - message, - chatHistory, - }; + const request = this._getChatRequest(messages, options); // All models have a built in `this.caller` property for retries const stream = await this.caller.call(async () => { let stream; try { - stream = await this.client.chatStream(input); - // eslint-disable-next-line @typescript-eslint/no-explicit-any + stream = await this.client.chatStream(request); } catch (e: any) { e.status = e.status ?? e.statusCode; throw e; @@ -372,6 +752,30 @@ export class ChatCohere< // stream-end events contain the final token count const input_tokens = chunk.response.meta?.tokens?.inputTokens ?? 0; const output_tokens = chunk.response.meta?.tokens?.outputTokens ?? 0; + const chunkGenerationInfo: Record = { + ...chunk.response, + }; + + if (chunk.response.toolCalls && chunk.response.toolCalls.length > 0) { + // Only populate tool_calls when 1) present on the response and + // 2) has one or more calls. + chunkGenerationInfo.toolCalls = this._formatCohereToolCalls( + chunk.response.toolCalls + ); + } + + let toolCallChunks: ToolCallChunk[] = []; + const toolCalls = chunkGenerationInfo.toolCalls ?? []; + + if (toolCalls.length > 0) { + toolCallChunks = toolCalls.map((toolCall: any) => ({ + name: toolCall.function.name, + args: toolCall.function.arguments, + id: toolCall.id, + index: toolCall.index, + })); + } + yield new ChatGenerationChunk({ text: "", message: new AIMessageChunk({ @@ -379,6 +783,7 @@ export class ChatCohere< additional_kwargs: { eventType: "stream-end", }, + tool_call_chunks: toolCallChunks, usage_metadata: { input_tokens, output_tokens, @@ -387,13 +792,13 @@ export class ChatCohere< }), generationInfo: { eventType: "stream-end", + ...chunkGenerationInfo, }, }); } } } - /** @ignore */ _combineLLMOutput(...llmOutputs: CohereLLMOutput[]): CohereLLMOutput { return llmOutputs.reduce<{ [key in keyof CohereLLMOutput]: Required; diff --git a/libs/langchain-cohere/src/tests/chat_models.int.test.ts b/libs/langchain-cohere/src/tests/chat_models.int.test.ts index 5da7660249d2..857283937850 100644 --- a/libs/langchain-cohere/src/tests/chat_models.int.test.ts +++ b/libs/langchain-cohere/src/tests/chat_models.int.test.ts @@ -1,6 +1,12 @@ /* eslint-disable no-promise-executor-return */ import { test, expect } from "@jest/globals"; -import { AIMessageChunk, HumanMessage } from "@langchain/core/messages"; +import { + AIMessageChunk, + HumanMessage, + ToolMessage, +} from "@langchain/core/messages"; +import { z } from "zod"; +import { DynamicStructuredTool } from "@langchain/core/tools"; import { ChatCohere } from "../chat_models.js"; test("ChatCohere can invoke", async () => { @@ -140,3 +146,56 @@ test("Invoke token count usage_metadata", async () => { res.usage_metadata.input_tokens + res.usage_metadata.output_tokens ); }); + +test("Test model tool calling", async () => { + const model = new ChatCohere({ + model: "command-r-plus", + temperature: 0, + }); + const webSearchTool = new DynamicStructuredTool({ + name: "web_search", + description: "Search the web and return the answer", + schema: z.object({ + search_query: z + .string() + .describe("The search query to surf the internet for"), + }) as any /* eslint-disable-line @typescript-eslint/no-explicit-any */, + func: async ({ search_query }) => `${search_query}`, + }); + + const tools = [webSearchTool]; + const modelWithTools = model.bindTools(tools); + + const messages = [ + new HumanMessage( + "Who is the president of Singapore?? USE TOOLS TO SEARCH INTERNET!!!!" + ), + ]; + const res = await modelWithTools.invoke(messages); + console.log(res); + expect(res?.usage_metadata).toBeDefined(); + if (!res?.usage_metadata) { + return; + } + expect(res.usage_metadata.total_tokens).toBe( + res.usage_metadata.input_tokens + res.usage_metadata.output_tokens + ); + expect(res.tool_calls).toBeDefined(); + expect(res.tool_calls?.length).toBe(1); + const tool_id = res.response_metadata.toolCalls[0].id; + messages.push(res); + messages.push( + new ToolMessage( + "Aidan Gomez is the president of Singapore", + tool_id, + "web_search" + ) + ); + const resWithToolResults = await modelWithTools.invoke(messages); + console.log(resWithToolResults); + expect(resWithToolResults?.usage_metadata).toBeDefined(); + if (!resWithToolResults?.usage_metadata) { + return; + } + expect(resWithToolResults.content).toContain("Aidan Gomez"); +}); diff --git a/libs/langchain-cohere/src/tests/chat_models.standard.int.test.ts b/libs/langchain-cohere/src/tests/chat_models.standard.int.test.ts index ec47bbf4bf8b..5b6e0d025f1b 100644 --- a/libs/langchain-cohere/src/tests/chat_models.standard.int.test.ts +++ b/libs/langchain-cohere/src/tests/chat_models.standard.int.test.ts @@ -2,10 +2,10 @@ import { test, expect } from "@jest/globals"; import { ChatModelIntegrationTests } from "@langchain/standard-tests"; import { AIMessageChunk } from "@langchain/core/messages"; -import { ChatCohere, CohereChatCallOptions } from "../chat_models.js"; +import { ChatCohere, ChatCohereCallOptions } from "../chat_models.js"; class ChatCohereStandardIntegrationTests extends ChatModelIntegrationTests< - CohereChatCallOptions, + ChatCohereCallOptions, AIMessageChunk > { constructor() { @@ -16,11 +16,19 @@ class ChatCohereStandardIntegrationTests extends ChatModelIntegrationTests< } super({ Cls: ChatCohere, - chatModelHasToolCalling: false, - chatModelHasStructuredOutput: false, + chatModelHasToolCalling: true, + chatModelHasStructuredOutput: true, constructorArgs: {}, }); } + + async testToolMessageHistoriesListContent() { + this.skipTestMessage( + "testToolMessageHistoriesListContent", + "ChatCohere", + "Anthropic-style tool calling is not supported." + ); + } } const testClass = new ChatCohereStandardIntegrationTests(); diff --git a/libs/langchain-cohere/src/tests/chat_models.standard.test.ts b/libs/langchain-cohere/src/tests/chat_models.standard.test.ts index dbfc2813ae83..6c01666c9168 100644 --- a/libs/langchain-cohere/src/tests/chat_models.standard.test.ts +++ b/libs/langchain-cohere/src/tests/chat_models.standard.test.ts @@ -2,17 +2,17 @@ import { test, expect } from "@jest/globals"; import { ChatModelUnitTests } from "@langchain/standard-tests"; import { AIMessageChunk } from "@langchain/core/messages"; -import { ChatCohere, CohereChatCallOptions } from "../chat_models.js"; +import { ChatCohere, ChatCohereCallOptions } from "../chat_models.js"; class ChatCohereStandardUnitTests extends ChatModelUnitTests< - CohereChatCallOptions, + ChatCohereCallOptions, AIMessageChunk > { constructor() { super({ Cls: ChatCohere, - chatModelHasToolCalling: false, - chatModelHasStructuredOutput: false, + chatModelHasToolCalling: true, + chatModelHasStructuredOutput: true, constructorArgs: {}, }); // This must be set so method like `.bindTools` or `.withStructuredOutput` diff --git a/libs/langchain-scripts/tsconfig.json b/libs/langchain-scripts/tsconfig.json index e0bc2f01fad7..75ced1455bd1 100644 --- a/libs/langchain-scripts/tsconfig.json +++ b/libs/langchain-scripts/tsconfig.json @@ -32,4 +32,4 @@ "docs", "bin/" ] -} \ No newline at end of file +} diff --git a/package.json b/package.json index 7df23f31dabc..e1cb30cdcbf0 100644 --- a/package.json +++ b/package.json @@ -61,7 +61,8 @@ "typedoc-plugin-markdown@next": "patch:typedoc-plugin-markdown@npm%3A4.0.0-next.6#./.yarn/patches/typedoc-plugin-markdown-npm-4.0.0-next.6-96b4b47746.patch", "voy-search@0.6.2": "patch:voy-search@npm%3A0.6.2#./.yarn/patches/voy-search-npm-0.6.2-d4aca30a0e.patch", "@langchain/core": "workspace:*", - "better-sqlite3": "9.4.0" + "better-sqlite3": "9.4.0", + "zod": "3.23.8" }, "lint-staged": { "**/*.{ts,tsx}": [ diff --git a/yarn.lock b/yarn.lock index b82e3b5d8ff1..97ea6b5018e3 100644 --- a/yarn.lock +++ b/yarn.lock @@ -10422,7 +10422,7 @@ __metadata: resolution: "@langchain/cohere@workspace:libs/langchain-cohere" dependencies: "@jest/globals": ^29.5.0 - "@langchain/core": ">=0.2.5 <0.3.0" + "@langchain/core": ">=0.2.14 <0.3.0" "@langchain/scripts": ~0.0.14 "@langchain/standard-tests": 0.0.0 "@swc/core": ^1.3.90 @@ -10447,6 +10447,9 @@ __metadata: rollup: ^4.5.2 ts-jest: ^29.1.0 typescript: <5.2.0 + uuid: ^10.0.0 + zod: ^3.23.8 + zod-to-json-schema: ^3.23.1 languageName: unknown linkType: soft @@ -41090,7 +41093,7 @@ __metadata: languageName: node linkType: hard -"zod-to-json-schema@npm:^3.23.0": +"zod-to-json-schema@npm:^3.23.0, zod-to-json-schema@npm:^3.23.1": version: 3.23.1 resolution: "zod-to-json-schema@npm:3.23.1" peerDependencies: @@ -41099,21 +41102,7 @@ __metadata: languageName: node linkType: hard -"zod@npm:^3.22.3, zod@npm:^3.22.4": - version: 3.22.4 - resolution: "zod@npm:3.22.4" - checksum: 80bfd7f8039b24fddeb0718a2ec7c02aa9856e4838d6aa4864335a047b6b37a3273b191ef335bf0b2002e5c514ef261ffcda5a589fb084a48c336ffc4cdbab7f - languageName: node - linkType: hard - -"zod@npm:^3.22.5": - version: 3.23.4 - resolution: "zod@npm:3.23.4" - checksum: 58f6e298c51d9ae01a1b1a1692ac7f00774b466d9a287a1ff8d61ff1fbe0ae9b0f050ae1cf1a8f71e4c6ccd0333a3cc340f339360fab5f5046cc954d10525a54 - languageName: node - linkType: hard - -"zod@npm:^3.23.8": +"zod@npm:3.23.8": version: 3.23.8 resolution: "zod@npm:3.23.8" checksum: 15949ff82118f59c893dacd9d3c766d02b6fa2e71cf474d5aa888570c469dbf5446ac5ad562bb035bf7ac9650da94f290655c194f4a6de3e766f43febd432c5c