diff --git a/docs/core_docs/docs/integrations/chat/cohere.mdx b/docs/core_docs/docs/integrations/chat/cohere.mdx
index 40e44f32f2c6..c455bed66b90 100644
--- a/docs/core_docs/docs/integrations/chat/cohere.mdx
+++ b/docs/core_docs/docs/integrations/chat/cohere.mdx
@@ -62,6 +62,14 @@ import StatefulChatExample from "@examples/models/chat/cohere/stateful_conversat
You can see the LangSmith traces from this example [here](https://smith.langchain.com/public/8e67b05a-4e63-414e-ac91-a91acf21b262/r) and [here](https://smith.langchain.com/public/50fabc25-46fe-4727-a59c-7e4eb0de8e70/r)
:::
+### Tools
+
+The Cohere API supports tool calling, along with multi-hop-tool calling. The following example demonstrates how to call tools:
+
+import ToolCallingExample from "@examples/models/chat/cohere/tool_calling.ts";
+
+{ToolCallingExample}
+
### RAG
Cohere also comes out of the box with RAG support.
diff --git a/examples/src/models/chat/cohere/chat_cohere.ts b/examples/src/models/chat/cohere/chat_cohere.ts
index 04ffed68aa5e..778eb9ab63f2 100644
--- a/examples/src/models/chat/cohere/chat_cohere.ts
+++ b/examples/src/models/chat/cohere/chat_cohere.ts
@@ -3,7 +3,6 @@ import { ChatPromptTemplate } from "@langchain/core/prompts";
const model = new ChatCohere({
apiKey: process.env.COHERE_API_KEY, // Default
- model: "command", // Default
});
const prompt = ChatPromptTemplate.fromMessages([
["ai", "You are a helpful assistant"],
diff --git a/examples/src/models/chat/cohere/chat_stream_cohere.ts b/examples/src/models/chat/cohere/chat_stream_cohere.ts
index 559fd9f4415f..a7ddd822608e 100644
--- a/examples/src/models/chat/cohere/chat_stream_cohere.ts
+++ b/examples/src/models/chat/cohere/chat_stream_cohere.ts
@@ -4,7 +4,6 @@ import { StringOutputParser } from "@langchain/core/output_parsers";
const model = new ChatCohere({
apiKey: process.env.COHERE_API_KEY, // Default
- model: "command", // Default
});
const prompt = ChatPromptTemplate.fromMessages([
["ai", "You are a helpful assistant"],
diff --git a/examples/src/models/chat/cohere/connectors.ts b/examples/src/models/chat/cohere/connectors.ts
index fd252dc7c76f..a16c2ed677c3 100644
--- a/examples/src/models/chat/cohere/connectors.ts
+++ b/examples/src/models/chat/cohere/connectors.ts
@@ -3,7 +3,6 @@ import { HumanMessage } from "@langchain/core/messages";
const model = new ChatCohere({
apiKey: process.env.COHERE_API_KEY, // Default
- model: "command", // Default
});
const response = await model.invoke(
diff --git a/examples/src/models/chat/cohere/rag.ts b/examples/src/models/chat/cohere/rag.ts
index 240225a33a46..b572dc8a1efe 100644
--- a/examples/src/models/chat/cohere/rag.ts
+++ b/examples/src/models/chat/cohere/rag.ts
@@ -3,7 +3,6 @@ import { HumanMessage } from "@langchain/core/messages";
const model = new ChatCohere({
apiKey: process.env.COHERE_API_KEY, // Default
- model: "command", // Default
});
const documents = [
diff --git a/examples/src/models/chat/cohere/stateful_conversation.ts b/examples/src/models/chat/cohere/stateful_conversation.ts
index 1edc61a47ab2..e126c4bf6bce 100644
--- a/examples/src/models/chat/cohere/stateful_conversation.ts
+++ b/examples/src/models/chat/cohere/stateful_conversation.ts
@@ -3,7 +3,6 @@ import { HumanMessage } from "@langchain/core/messages";
const model = new ChatCohere({
apiKey: process.env.COHERE_API_KEY, // Default
- model: "command", // Default
});
const conversationId = `demo_test_id-${Math.random()}`;
diff --git a/examples/src/models/chat/cohere/tool_calling.ts b/examples/src/models/chat/cohere/tool_calling.ts
new file mode 100644
index 000000000000..f08a5e6cb343
--- /dev/null
+++ b/examples/src/models/chat/cohere/tool_calling.ts
@@ -0,0 +1,57 @@
+import { ChatCohere } from "@langchain/cohere";
+import { HumanMessage } from "@langchain/core/messages";
+import { z } from "zod";
+import { DynamicStructuredTool } from "@langchain/core/tools";
+
+const model = new ChatCohere({
+ apiKey: process.env.COHERE_API_KEY, // Default
+});
+
+const magicFunctionTool = new DynamicStructuredTool({
+ name: "magic_function",
+ description: "Apply a magic function to the input number",
+ schema: z.object({
+ num: z.number().describe("The number to apply the magic function for"),
+ }),
+ func: async ({ num }) => {
+ return `The magic function of ${num} is ${num + 5}`;
+ },
+});
+
+const tools = [magicFunctionTool];
+const modelWithTools = model.bindTools(tools);
+
+const messages = [new HumanMessage("What is the magic function of number 5?")];
+const response = await modelWithTools.invoke(messages);
+/*
+ AIMessage {
+ content: 'I will use the magic_function tool to answer this question.',
+ name: undefined,
+ additional_kwargs: {
+ response_id: 'd0b189e5-3dbf-493c-93f8-99ed4b01d96d',
+ generationId: '8982a68f-c64c-48f8-bf12-0b4bea0018b6',
+ chatHistory: [ [Object], [Object] ],
+ finishReason: 'COMPLETE',
+ meta: { apiVersion: [Object], billedUnits: [Object], tokens: [Object] },
+ toolCalls: [ [Object] ]
+ },
+ response_metadata: {
+ estimatedTokenUsage: { completionTokens: 54, promptTokens: 920, totalTokens: 974 },
+ response_id: 'd0b189e5-3dbf-493c-93f8-99ed4b01d96d',
+ generationId: '8982a68f-c64c-48f8-bf12-0b4bea0018b6',
+ chatHistory: [ [Object], [Object] ],
+ finishReason: 'COMPLETE',
+ meta: { apiVersion: [Object], billedUnits: [Object], tokens: [Object] },
+ toolCalls: [ [Object] ]
+ },
+ tool_calls: [
+ {
+ name: 'magic_function',
+ args: [Object],
+ id: '4ec98550-ba9a-4043-adfe-566230e5'
+ }
+ ],
+ invalid_tool_calls: [],
+ usage_metadata: { input_tokens: 920, output_tokens: 54, total_tokens: 974 }
+ }
+*/
diff --git a/libs/langchain-cohere/.eslintrc.cjs b/libs/langchain-cohere/.eslintrc.cjs
index d533e6deffb6..59171b108443 100644
--- a/libs/langchain-cohere/.eslintrc.cjs
+++ b/libs/langchain-cohere/.eslintrc.cjs
@@ -33,6 +33,7 @@ module.exports = {
"@typescript-eslint/no-unused-vars": ["warn", { args: "none" }],
"@typescript-eslint/no-floating-promises": "error",
"@typescript-eslint/no-misused-promises": "error",
+ "arrow-body-style": 0,
camelcase: 0,
"class-methods-use-this": 0,
"import/extensions": [2, "ignorePackages"],
diff --git a/libs/langchain-cohere/package.json b/libs/langchain-cohere/package.json
index 4227d83388e7..c8dfc8bf1477 100644
--- a/libs/langchain-cohere/package.json
+++ b/libs/langchain-cohere/package.json
@@ -35,8 +35,11 @@
"author": "LangChain",
"license": "MIT",
"dependencies": {
- "@langchain/core": ">=0.2.5 <0.3.0",
- "cohere-ai": "^7.10.5"
+ "@langchain/core": ">=0.2.14 <0.3.0",
+ "cohere-ai": "^7.10.5",
+ "uuid": "^10.0.0",
+ "zod": "^3.23.8",
+ "zod-to-json-schema": "^3.23.1"
},
"devDependencies": {
"@jest/globals": "^29.5.0",
diff --git a/libs/langchain-cohere/src/chat_models.ts b/libs/langchain-cohere/src/chat_models.ts
index cc8184a5bb05..bd9d7001979b 100644
--- a/libs/langchain-cohere/src/chat_models.ts
+++ b/libs/langchain-cohere/src/chat_models.ts
@@ -1,12 +1,22 @@
+/* eslint-disable @typescript-eslint/no-explicit-any */
import { CohereClient, Cohere } from "cohere-ai";
+import { ToolResult } from "cohere-ai/api/index.js";
+import { zodToJsonSchema } from "zod-to-json-schema";
import {
MessageType,
type BaseMessage,
MessageContent,
AIMessage,
+ isAIMessage,
} from "@langchain/core/messages";
-import { type BaseLanguageModelCallOptions } from "@langchain/core/language_models/base";
+import {
+ BaseLanguageModelInput,
+ ToolDefinition,
+ isOpenAITool,
+ type BaseLanguageModelCallOptions,
+} from "@langchain/core/language_models/base";
+import { isStructuredTool } from "@langchain/core/utils/function_calling";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
type BaseChatModelParams,
@@ -21,6 +31,14 @@ import {
import { AIMessageChunk } from "@langchain/core/messages";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { NewTokenIndices } from "@langchain/core/callbacks/base";
+import {
+ ToolMessage,
+ ToolCall,
+ ToolCallChunk,
+} from "@langchain/core/messages/tool";
+import * as uuid from "uuid";
+import { StructuredToolInterface } from "@langchain/core/tools";
+import { Runnable } from "@langchain/core/runnables";
/**
* Input interface for ChatCohere
@@ -65,15 +83,62 @@ interface TokenUsage {
totalTokens?: number;
}
-export interface CohereChatCallOptions
+export interface ChatCohereCallOptions
extends BaseLanguageModelCallOptions,
- Partial>,
- Partial>,
- Pick {}
+ Partial>,
+ Partial>,
+ Pick {
+ tools?: (
+ | StructuredToolInterface
+ | Cohere.Tool
+ | Record
+ | ToolDefinition
+ )[];
+}
+
+/** @deprecated Import as ChatCohereCallOptions instead. */
+export interface CohereChatCallOptions extends ChatCohereCallOptions {}
+
+function convertToDocuments(
+ observations: MessageContent
+): Array> {
+ /** Converts observations into a 'document' dict */
+ const documents: Array> = [];
+ let observationsList: Array> = [];
+
+ if (typeof observations === "string") {
+ // strings are turned into a key/value pair and a key of 'output' is added.
+ observationsList = [{ output: observations }];
+ } else if (
+ // eslint-disable-next-line no-instanceof/no-instanceof
+ observations instanceof Map ||
+ (typeof observations === "object" &&
+ observations !== null &&
+ !Array.isArray(observations))
+ ) {
+ // single mappings are transformed into a list to simplify the rest of the code.
+ observationsList = [observations];
+ } else if (!Array.isArray(observations)) {
+ // all other types are turned into a key/value pair within a list
+ observationsList = [{ output: observations }];
+ }
+
+ for (let doc of observationsList) {
+ // eslint-disable-next-line no-instanceof/no-instanceof
+ if (!(doc instanceof Map) && (typeof doc !== "object" || doc === null)) {
+ // types that aren't Mapping are turned into a key/value pair.
+ doc = { output: doc };
+ }
+ documents.push(doc);
+ }
-function convertMessagesToCohereMessages(
- messages: Array
-): Array {
+ return documents;
+}
+
+function convertMessageToCohereMessage(
+ message: BaseMessage,
+ toolResults: ToolResult[]
+): Cohere.Message {
const getRole = (role: MessageType) => {
switch (role) {
case "system":
@@ -82,9 +147,11 @@ function convertMessagesToCohereMessages(
return "USER";
case "ai":
return "CHATBOT";
+ case "tool":
+ return "TOOL";
default:
throw new Error(
- `Unknown message type: '${role}'. Accepted types: 'human', 'ai', 'system'`
+ `Unknown message type: '${role}'. Accepted types: 'human', 'ai', 'system', 'tool'`
);
}
};
@@ -102,10 +169,108 @@ function convertMessagesToCohereMessages(
);
};
- return messages.map((message) => ({
- role: getRole(message._getType()),
- message: getContent(message.content),
- }));
+ const getToolCall = (message: BaseMessage): Cohere.ToolCall[] => {
+ if (isAIMessage(message) && message.tool_calls) {
+ return message.tool_calls.map((toolCall) => ({
+ name: toolCall.name,
+ parameters: toolCall.args,
+ }));
+ }
+ return [];
+ };
+ if (message._getType().toLowerCase() === "ai") {
+ return {
+ role: getRole(message._getType()),
+ message: getContent(message.content),
+ toolCalls: getToolCall(message),
+ };
+ } else if (message._getType().toLowerCase() === "tool") {
+ return {
+ role: getRole(message._getType()),
+ message: getContent(message.content),
+ toolResults,
+ };
+ } else if (
+ message._getType().toLowerCase() === "human" ||
+ message._getType().toLowerCase() === "system"
+ ) {
+ return {
+ role: getRole(message._getType()),
+ message: getContent(message.content),
+ };
+ } else {
+ throw new Error(
+ "Got unknown message type. Supported types are AIMessage, ToolMessage, HumanMessage, and SystemMessage"
+ );
+ }
+}
+
+function isCohereTool(tool: any): tool is Cohere.Tool {
+ return (
+ "name" in tool && "description" in tool && "parameterDefinitions" in tool
+ );
+}
+
+function isToolMessage(message: BaseMessage): message is ToolMessage {
+ return message._getType() === "tool";
+}
+
+function _convertJsonSchemaToCohereTool(jsonSchema: Record) {
+ const parameterDefinitionsProperties =
+ "properties" in jsonSchema ? jsonSchema.properties : {};
+ let parameterDefinitionsRequired =
+ "required" in jsonSchema ? jsonSchema.required : [];
+
+ const parameterDefinitionsFinal: Record = {};
+
+ // Iterate through all properties
+ Object.keys(parameterDefinitionsProperties).forEach((propertyName) => {
+ // Create the property in the new object
+ parameterDefinitionsFinal[propertyName] =
+ parameterDefinitionsProperties[propertyName];
+ // Set the required property based on the 'required' array
+ if (parameterDefinitionsRequired === undefined) {
+ parameterDefinitionsRequired = [];
+ }
+ parameterDefinitionsFinal[propertyName].required =
+ parameterDefinitionsRequired.includes(propertyName);
+ });
+ return parameterDefinitionsFinal;
+}
+
+function _formatToolsToCohere(
+ tools: ChatCohereCallOptions["tools"]
+): Cohere.Tool[] | undefined {
+ if (!tools) {
+ return undefined;
+ } else if (tools.every(isCohereTool)) {
+ return tools;
+ } else if (tools.every(isOpenAITool)) {
+ return tools.map((tool) => {
+ return {
+ name: tool.function.name,
+ description: tool.function.description ?? "",
+ parameterDefinitions: _convertJsonSchemaToCohereTool(
+ tool.function.parameters
+ ),
+ };
+ });
+ } else if (tools.every(isStructuredTool)) {
+ return tools.map((tool) => {
+ const parameterDefinitionsFromZod = zodToJsonSchema(tool.schema);
+ return {
+ name: tool.name,
+ description: tool.description,
+ parameterDefinitions: _convertJsonSchemaToCohereTool(
+ parameterDefinitionsFromZod
+ ),
+ };
+ });
+ } else {
+ throw new Error(
+ `Can not pass in a mix of tool schema types to ChatCohere.`
+ );
+ }
}
/**
@@ -114,7 +279,7 @@ function convertMessagesToCohereMessages(
* ```typescript
* const model = new ChatCohere({
* apiKey: process.env.COHERE_API_KEY, // Default
- * model: "command" // Default
+ * model: "command-r-plus" // Default
* });
* const response = await model.invoke([
* new HumanMessage("How tall are the largest pengiuns?")
@@ -122,7 +287,7 @@ function convertMessagesToCohereMessages(
* ```
*/
export class ChatCohere<
- CallOptions extends CohereChatCallOptions = CohereChatCallOptions
+ CallOptions extends ChatCohereCallOptions = ChatCohereCallOptions
>
extends BaseChatModel
implements ChatCohereInput
@@ -135,7 +300,7 @@ export class ChatCohere<
client: CohereClient;
- model = "command";
+ model = "command-r-plus";
temperature = 0.3;
@@ -189,6 +354,8 @@ export class ChatCohere<
searchQueriesOnly: options.searchQueriesOnly,
documents: options.documents,
temperature: options.temperature ?? this.temperature,
+ forceSingleStep: options.forceSingleStep,
+ tools: options.tools,
};
// Filter undefined entries
return Object.fromEntries(
@@ -196,6 +363,243 @@ export class ChatCohere<
);
}
+ override bindTools(
+ tools: (
+ | Cohere.Tool
+ | Record
+ | StructuredToolInterface
+ | ToolDefinition
+ )[],
+ kwargs?: Partial
+ ): Runnable {
+ return this.bind({
+ tools: _formatToolsToCohere(tools),
+ ...kwargs,
+ } as Partial);
+ }
+
+ /** @ignore */
+ private _getChatRequest(
+ messages: BaseMessage[],
+ options: this["ParsedCallOptions"]
+ ): Cohere.ChatRequest {
+ const params = this.invocationParams(options);
+
+ const toolResults = this._messagesToCohereToolResultsCurrChatTurn(messages);
+ const chatHistory = [];
+ let messageStr: string = "";
+ let tempToolResults: {
+ call: Cohere.ToolCall;
+ outputs: any;
+ }[] = [];
+
+ if (!params.forceSingleStep) {
+ for (let i = 0; i < messages.length - 1; i += 1) {
+ const message = messages[i];
+ // If there are multiple tool messages, then we need to aggregate them into one single tool message to pass into chat history
+ if (message._getType().toLowerCase() === "tool") {
+ tempToolResults = tempToolResults.concat(
+ this._messageToCohereToolResults(messages, i)
+ );
+
+ if (
+ i === messages.length - 1 ||
+ !(messages[i + 1]._getType().toLowerCase() === "tool")
+ ) {
+ const cohere_message = convertMessageToCohereMessage(
+ message,
+ tempToolResults
+ );
+ chatHistory.push(cohere_message);
+ tempToolResults = [];
+ }
+ } else {
+ chatHistory.push(convertMessageToCohereMessage(message, []));
+ }
+ }
+
+ messageStr =
+ toolResults.length > 0
+ ? ""
+ : messages[messages.length - 1].content.toString();
+ } else {
+ messageStr = "";
+
+ // if force_single_step is set to True, then message is the last human message in the conversation
+ for (let i = 0; i < messages.length - 1; i += 1) {
+ const message = messages[i];
+ if (isAIMessage(message) && message.tool_calls) {
+ continue;
+ }
+
+ // If there are multiple tool messages, then we need to aggregate them into one single tool message to pass into chat history
+ if (message._getType().toLowerCase() === "tool") {
+ tempToolResults = tempToolResults.concat(
+ this._messageToCohereToolResults(messages, i)
+ );
+
+ if (
+ i === messages.length - 1 ||
+ !(messages[i + 1]._getType().toLowerCase() === "tool")
+ ) {
+ const cohereMessage = convertMessageToCohereMessage(
+ message,
+ tempToolResults
+ );
+ chatHistory.push(cohereMessage);
+ tempToolResults = [];
+ }
+ } else {
+ chatHistory.push(convertMessageToCohereMessage(message, []));
+ }
+ }
+
+ // Add the last human message in the conversation to the message string
+ for (let i = messages.length - 1; i >= 0; i -= 1) {
+ const message = messages[i];
+ if (message._getType().toLowerCase() === "human" && message.content) {
+ messageStr = message.content.toString();
+ break;
+ }
+ }
+ }
+ const req: Cohere.ChatRequest = {
+ message: messageStr,
+ chatHistory,
+ toolResults: toolResults.length > 0 ? toolResults : undefined,
+ ...params,
+ };
+
+ return req;
+ }
+
+ private _getCurrChatTurnMessages(messages: BaseMessage[]): BaseMessage[] {
+ // Get the messages for the current chat turn.
+ const currentChatTurnMessages: BaseMessage[] = [];
+ for (let i = messages.length - 1; i >= 0; i -= 1) {
+ const message = messages[i];
+ currentChatTurnMessages.push(message);
+ if (message._getType().toLowerCase() === "human") {
+ break;
+ }
+ }
+ return currentChatTurnMessages.reverse();
+ }
+
+ private _messagesToCohereToolResultsCurrChatTurn(
+ messages: BaseMessage[]
+ ): Array<{
+ call: Cohere.ToolCall;
+ outputs: ReturnType;
+ }> {
+ /** Get tool_results from messages. */
+ const toolResults: Array<{
+ call: Cohere.ToolCall;
+ outputs: ReturnType;
+ }> = [];
+ const currChatTurnMessages = this._getCurrChatTurnMessages(messages);
+
+ for (const message of currChatTurnMessages) {
+ if (isToolMessage(message)) {
+ const toolMessage = message;
+ const previousAiMsgs = currChatTurnMessages.filter(
+ (msg) => isAIMessage(msg) && msg.tool_calls !== undefined
+ ) as AIMessage[];
+ if (previousAiMsgs.length > 0) {
+ const previousAiMsg = previousAiMsgs[previousAiMsgs.length - 1];
+ if (previousAiMsg.tool_calls) {
+ toolResults.push(
+ ...previousAiMsg.tool_calls
+ .filter(
+ (lcToolCall) => lcToolCall.id === toolMessage.tool_call_id
+ )
+ .map((lcToolCall) => ({
+ call: {
+ name: lcToolCall.name,
+ parameters: lcToolCall.args,
+ },
+ outputs: convertToDocuments(toolMessage.content),
+ }))
+ );
+ }
+ }
+ }
+ }
+ return toolResults;
+ }
+
+ private _messageToCohereToolResults(
+ messages: BaseMessage[],
+ toolMessageIndex: number
+ ): Array<{ call: Cohere.ToolCall; outputs: any }> {
+ /** Get tool_results from messages. */
+ const toolResults: Array<{ call: Cohere.ToolCall; outputs: any }> = [];
+ const toolMessage = messages[toolMessageIndex];
+
+ if (!isToolMessage(toolMessage)) {
+ throw new Error(
+ "The message index does not correspond to an instance of ToolMessage"
+ );
+ }
+
+ const messagesUntilTool = messages.slice(0, toolMessageIndex);
+ const previousAiMessage = messagesUntilTool
+ .filter((message) => isAIMessage(message) && message.tool_calls)
+ .slice(-1)[0] as AIMessage;
+
+ if (previousAiMessage.tool_calls) {
+ toolResults.push(
+ ...previousAiMessage.tool_calls
+ .filter((lcToolCall) => lcToolCall.id === toolMessage.tool_call_id)
+ .map((lcToolCall) => ({
+ call: {
+ name: lcToolCall.name,
+ parameters: lcToolCall.args,
+ },
+ outputs: convertToDocuments(toolMessage.content),
+ }))
+ );
+ }
+
+ return toolResults;
+ }
+
+ private _formatCohereToolCalls(toolCalls: Cohere.ToolCall[] | null = null): {
+ id: string;
+ function: {
+ name: string;
+ arguments: Record;
+ };
+ type: string;
+ }[] {
+ if (!toolCalls) {
+ return [];
+ }
+
+ const formattedToolCalls = [];
+ for (const toolCall of toolCalls) {
+ formattedToolCalls.push({
+ id: uuid.v4().substring(0, 32),
+ function: {
+ name: toolCall.name,
+ arguments: toolCall.parameters, // Convert arguments to string
+ },
+ type: "function",
+ });
+ }
+ return formattedToolCalls;
+ }
+
+ private _convertCohereToolCallToLangchain(
+ toolCalls: Record[]
+ ): ToolCall[] {
+ return toolCalls.map((toolCall) => ({
+ name: toolCall.function.name,
+ args: toolCall.function.arguments,
+ id: toolCall.id,
+ }));
+ }
+
/** @ignore */
async _generate(
messages: BaseMessage[],
@@ -203,26 +607,9 @@ export class ChatCohere<
runManager?: CallbackManagerForLLMRun
): Promise {
const tokenUsage: TokenUsage = {};
- const params = this.invocationParams(options);
- const cohereMessages = convertMessagesToCohereMessages(messages);
// The last message in the array is the most recent, all other messages
// are apart of the chat history.
- const lastMessage = cohereMessages[cohereMessages.length - 1];
- if (lastMessage.role === "TOOL") {
- throw new Error(
- "Cohere does not support tool messages as the most recent message in chat history."
- );
- }
- const { message } = lastMessage;
- const chatHistory: Cohere.Message[] = [];
- if (cohereMessages.length > 1) {
- chatHistory.push(...cohereMessages.slice(0, -1));
- }
- const input = {
- ...params,
- message,
- chatHistory,
- };
+ const request = this._getChatRequest(messages, options);
// Handle streaming
if (this.streaming) {
@@ -251,8 +638,7 @@ export class ChatCohere<
async () => {
let response;
try {
- response = await this.client.chat(input);
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ response = await this.client.chat(request);
} catch (e: any) {
e.status = e.status ?? e.statusCode;
throw e;
@@ -281,6 +667,19 @@ export class ChatCohere<
const generationInfo: Record = { ...response };
delete generationInfo.text;
+ if (response.toolCalls && response.toolCalls.length > 0) {
+ // Only populate tool_calls when 1) present on the response and
+ // 2) has one or more calls.
+ generationInfo.toolCalls = this._formatCohereToolCalls(
+ response.toolCalls
+ );
+ }
+ let toolCalls: ToolCall[] = [];
+ if ("toolCalls" in generationInfo) {
+ toolCalls = this._convertCohereToolCallToLangchain(
+ generationInfo.toolCalls as Record[]
+ );
+ }
const generations: ChatGeneration[] = [
{
@@ -288,6 +687,7 @@ export class ChatCohere<
message: new AIMessage({
content: response.text,
additional_kwargs: generationInfo,
+ tool_calls: toolCalls,
usage_metadata: {
input_tokens: tokenUsage.promptTokens ?? 0,
output_tokens: tokenUsage.completionTokens ?? 0,
@@ -308,33 +708,13 @@ export class ChatCohere<
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator {
- const params = this.invocationParams(options);
- const cohereMessages = convertMessagesToCohereMessages(messages);
- // The last message in the array is the most recent, all other messages
- // are apart of the chat history.
- const lastMessage = cohereMessages[cohereMessages.length - 1];
- if (lastMessage.role === "TOOL") {
- throw new Error(
- "Cohere does not support tool messages as the most recent message in chat history."
- );
- }
- const { message } = lastMessage;
- const chatHistory: Cohere.Message[] = [];
- if (cohereMessages.length > 1) {
- chatHistory.push(...cohereMessages.slice(0, -1));
- }
- const input = {
- ...params,
- message,
- chatHistory,
- };
+ const request = this._getChatRequest(messages, options);
// All models have a built in `this.caller` property for retries
const stream = await this.caller.call(async () => {
let stream;
try {
- stream = await this.client.chatStream(input);
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ stream = await this.client.chatStream(request);
} catch (e: any) {
e.status = e.status ?? e.statusCode;
throw e;
@@ -372,6 +752,30 @@ export class ChatCohere<
// stream-end events contain the final token count
const input_tokens = chunk.response.meta?.tokens?.inputTokens ?? 0;
const output_tokens = chunk.response.meta?.tokens?.outputTokens ?? 0;
+ const chunkGenerationInfo: Record = {
+ ...chunk.response,
+ };
+
+ if (chunk.response.toolCalls && chunk.response.toolCalls.length > 0) {
+ // Only populate tool_calls when 1) present on the response and
+ // 2) has one or more calls.
+ chunkGenerationInfo.toolCalls = this._formatCohereToolCalls(
+ chunk.response.toolCalls
+ );
+ }
+
+ let toolCallChunks: ToolCallChunk[] = [];
+ const toolCalls = chunkGenerationInfo.toolCalls ?? [];
+
+ if (toolCalls.length > 0) {
+ toolCallChunks = toolCalls.map((toolCall: any) => ({
+ name: toolCall.function.name,
+ args: toolCall.function.arguments,
+ id: toolCall.id,
+ index: toolCall.index,
+ }));
+ }
+
yield new ChatGenerationChunk({
text: "",
message: new AIMessageChunk({
@@ -379,6 +783,7 @@ export class ChatCohere<
additional_kwargs: {
eventType: "stream-end",
},
+ tool_call_chunks: toolCallChunks,
usage_metadata: {
input_tokens,
output_tokens,
@@ -387,13 +792,13 @@ export class ChatCohere<
}),
generationInfo: {
eventType: "stream-end",
+ ...chunkGenerationInfo,
},
});
}
}
}
- /** @ignore */
_combineLLMOutput(...llmOutputs: CohereLLMOutput[]): CohereLLMOutput {
return llmOutputs.reduce<{
[key in keyof CohereLLMOutput]: Required;
diff --git a/libs/langchain-cohere/src/tests/chat_models.int.test.ts b/libs/langchain-cohere/src/tests/chat_models.int.test.ts
index 5da7660249d2..857283937850 100644
--- a/libs/langchain-cohere/src/tests/chat_models.int.test.ts
+++ b/libs/langchain-cohere/src/tests/chat_models.int.test.ts
@@ -1,6 +1,12 @@
/* eslint-disable no-promise-executor-return */
import { test, expect } from "@jest/globals";
-import { AIMessageChunk, HumanMessage } from "@langchain/core/messages";
+import {
+ AIMessageChunk,
+ HumanMessage,
+ ToolMessage,
+} from "@langchain/core/messages";
+import { z } from "zod";
+import { DynamicStructuredTool } from "@langchain/core/tools";
import { ChatCohere } from "../chat_models.js";
test("ChatCohere can invoke", async () => {
@@ -140,3 +146,56 @@ test("Invoke token count usage_metadata", async () => {
res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
);
});
+
+test("Test model tool calling", async () => {
+ const model = new ChatCohere({
+ model: "command-r-plus",
+ temperature: 0,
+ });
+ const webSearchTool = new DynamicStructuredTool({
+ name: "web_search",
+ description: "Search the web and return the answer",
+ schema: z.object({
+ search_query: z
+ .string()
+ .describe("The search query to surf the internet for"),
+ }) as any /* eslint-disable-line @typescript-eslint/no-explicit-any */,
+ func: async ({ search_query }) => `${search_query}`,
+ });
+
+ const tools = [webSearchTool];
+ const modelWithTools = model.bindTools(tools);
+
+ const messages = [
+ new HumanMessage(
+ "Who is the president of Singapore?? USE TOOLS TO SEARCH INTERNET!!!!"
+ ),
+ ];
+ const res = await modelWithTools.invoke(messages);
+ console.log(res);
+ expect(res?.usage_metadata).toBeDefined();
+ if (!res?.usage_metadata) {
+ return;
+ }
+ expect(res.usage_metadata.total_tokens).toBe(
+ res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
+ );
+ expect(res.tool_calls).toBeDefined();
+ expect(res.tool_calls?.length).toBe(1);
+ const tool_id = res.response_metadata.toolCalls[0].id;
+ messages.push(res);
+ messages.push(
+ new ToolMessage(
+ "Aidan Gomez is the president of Singapore",
+ tool_id,
+ "web_search"
+ )
+ );
+ const resWithToolResults = await modelWithTools.invoke(messages);
+ console.log(resWithToolResults);
+ expect(resWithToolResults?.usage_metadata).toBeDefined();
+ if (!resWithToolResults?.usage_metadata) {
+ return;
+ }
+ expect(resWithToolResults.content).toContain("Aidan Gomez");
+});
diff --git a/libs/langchain-cohere/src/tests/chat_models.standard.int.test.ts b/libs/langchain-cohere/src/tests/chat_models.standard.int.test.ts
index ec47bbf4bf8b..5b6e0d025f1b 100644
--- a/libs/langchain-cohere/src/tests/chat_models.standard.int.test.ts
+++ b/libs/langchain-cohere/src/tests/chat_models.standard.int.test.ts
@@ -2,10 +2,10 @@
import { test, expect } from "@jest/globals";
import { ChatModelIntegrationTests } from "@langchain/standard-tests";
import { AIMessageChunk } from "@langchain/core/messages";
-import { ChatCohere, CohereChatCallOptions } from "../chat_models.js";
+import { ChatCohere, ChatCohereCallOptions } from "../chat_models.js";
class ChatCohereStandardIntegrationTests extends ChatModelIntegrationTests<
- CohereChatCallOptions,
+ ChatCohereCallOptions,
AIMessageChunk
> {
constructor() {
@@ -16,11 +16,19 @@ class ChatCohereStandardIntegrationTests extends ChatModelIntegrationTests<
}
super({
Cls: ChatCohere,
- chatModelHasToolCalling: false,
- chatModelHasStructuredOutput: false,
+ chatModelHasToolCalling: true,
+ chatModelHasStructuredOutput: true,
constructorArgs: {},
});
}
+
+ async testToolMessageHistoriesListContent() {
+ this.skipTestMessage(
+ "testToolMessageHistoriesListContent",
+ "ChatCohere",
+ "Anthropic-style tool calling is not supported."
+ );
+ }
}
const testClass = new ChatCohereStandardIntegrationTests();
diff --git a/libs/langchain-cohere/src/tests/chat_models.standard.test.ts b/libs/langchain-cohere/src/tests/chat_models.standard.test.ts
index dbfc2813ae83..6c01666c9168 100644
--- a/libs/langchain-cohere/src/tests/chat_models.standard.test.ts
+++ b/libs/langchain-cohere/src/tests/chat_models.standard.test.ts
@@ -2,17 +2,17 @@
import { test, expect } from "@jest/globals";
import { ChatModelUnitTests } from "@langchain/standard-tests";
import { AIMessageChunk } from "@langchain/core/messages";
-import { ChatCohere, CohereChatCallOptions } from "../chat_models.js";
+import { ChatCohere, ChatCohereCallOptions } from "../chat_models.js";
class ChatCohereStandardUnitTests extends ChatModelUnitTests<
- CohereChatCallOptions,
+ ChatCohereCallOptions,
AIMessageChunk
> {
constructor() {
super({
Cls: ChatCohere,
- chatModelHasToolCalling: false,
- chatModelHasStructuredOutput: false,
+ chatModelHasToolCalling: true,
+ chatModelHasStructuredOutput: true,
constructorArgs: {},
});
// This must be set so method like `.bindTools` or `.withStructuredOutput`
diff --git a/libs/langchain-scripts/tsconfig.json b/libs/langchain-scripts/tsconfig.json
index e0bc2f01fad7..75ced1455bd1 100644
--- a/libs/langchain-scripts/tsconfig.json
+++ b/libs/langchain-scripts/tsconfig.json
@@ -32,4 +32,4 @@
"docs",
"bin/"
]
-}
\ No newline at end of file
+}
diff --git a/package.json b/package.json
index 7df23f31dabc..e1cb30cdcbf0 100644
--- a/package.json
+++ b/package.json
@@ -61,7 +61,8 @@
"typedoc-plugin-markdown@next": "patch:typedoc-plugin-markdown@npm%3A4.0.0-next.6#./.yarn/patches/typedoc-plugin-markdown-npm-4.0.0-next.6-96b4b47746.patch",
"voy-search@0.6.2": "patch:voy-search@npm%3A0.6.2#./.yarn/patches/voy-search-npm-0.6.2-d4aca30a0e.patch",
"@langchain/core": "workspace:*",
- "better-sqlite3": "9.4.0"
+ "better-sqlite3": "9.4.0",
+ "zod": "3.23.8"
},
"lint-staged": {
"**/*.{ts,tsx}": [
diff --git a/yarn.lock b/yarn.lock
index b82e3b5d8ff1..97ea6b5018e3 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -10422,7 +10422,7 @@ __metadata:
resolution: "@langchain/cohere@workspace:libs/langchain-cohere"
dependencies:
"@jest/globals": ^29.5.0
- "@langchain/core": ">=0.2.5 <0.3.0"
+ "@langchain/core": ">=0.2.14 <0.3.0"
"@langchain/scripts": ~0.0.14
"@langchain/standard-tests": 0.0.0
"@swc/core": ^1.3.90
@@ -10447,6 +10447,9 @@ __metadata:
rollup: ^4.5.2
ts-jest: ^29.1.0
typescript: <5.2.0
+ uuid: ^10.0.0
+ zod: ^3.23.8
+ zod-to-json-schema: ^3.23.1
languageName: unknown
linkType: soft
@@ -41090,7 +41093,7 @@ __metadata:
languageName: node
linkType: hard
-"zod-to-json-schema@npm:^3.23.0":
+"zod-to-json-schema@npm:^3.23.0, zod-to-json-schema@npm:^3.23.1":
version: 3.23.1
resolution: "zod-to-json-schema@npm:3.23.1"
peerDependencies:
@@ -41099,21 +41102,7 @@ __metadata:
languageName: node
linkType: hard
-"zod@npm:^3.22.3, zod@npm:^3.22.4":
- version: 3.22.4
- resolution: "zod@npm:3.22.4"
- checksum: 80bfd7f8039b24fddeb0718a2ec7c02aa9856e4838d6aa4864335a047b6b37a3273b191ef335bf0b2002e5c514ef261ffcda5a589fb084a48c336ffc4cdbab7f
- languageName: node
- linkType: hard
-
-"zod@npm:^3.22.5":
- version: 3.23.4
- resolution: "zod@npm:3.23.4"
- checksum: 58f6e298c51d9ae01a1b1a1692ac7f00774b466d9a287a1ff8d61ff1fbe0ae9b0f050ae1cf1a8f71e4c6ccd0333a3cc340f339360fab5f5046cc954d10525a54
- languageName: node
- linkType: hard
-
-"zod@npm:^3.23.8":
+"zod@npm:3.23.8":
version: 3.23.8
resolution: "zod@npm:3.23.8"
checksum: 15949ff82118f59c893dacd9d3c766d02b6fa2e71cf474d5aa888570c469dbf5446ac5ad562bb035bf7ac9650da94f290655c194f4a6de3e766f43febd432c5c