Skip to content

Commit

Permalink
core[patch]: Allow runnable tools to take single string/ToolCall in…
Browse files Browse the repository at this point in the history
…puts (#6096)

* core[patch]: Allow runnable tools to take single string inputs

* add test for tool func

* chore: lint files

* cr

* cr

* cr

* fix types

* rename ZodAny to ZodObjectAny

* docstring nits

* fiox

* cr

* cr
  • Loading branch information
bracesproul authored Jul 17, 2024
1 parent 2c000dd commit 1c4b2d8
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 56 deletions.
1 change: 1 addition & 0 deletions examples/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
"@langchain/google-vertexai": "workspace:*",
"@langchain/google-vertexai-web": "workspace:*",
"@langchain/groq": "workspace:*",
"@langchain/langgraph": "^0.0.28",
"@langchain/mistralai": "workspace:*",
"@langchain/mongodb": "workspace:*",
"@langchain/nomic": "workspace:*",
Expand Down
46 changes: 41 additions & 5 deletions langchain-core/src/runnables/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ import {
isIterableIterator,
isIterator,
} from "./iter.js";
import { _isToolCall, ToolInputParsingException } from "../tools/utils.js";
import { ToolCall } from "../messages/tool.js";

export { type RunnableInterface, RunnableBatchOptions };

Expand Down Expand Up @@ -1095,7 +1097,7 @@ export abstract class Runnable<
name?: string;
description?: string;
schema: z.ZodType<T>;
}): RunnableToolLike<z.ZodType<T>, RunOutput> {
}): RunnableToolLike<z.ZodType<T | ToolCall>, RunOutput> {
return convertRunnableToTool<T, RunOutput>(this, fields);
}
}
Expand Down Expand Up @@ -2828,8 +2830,29 @@ export class RunnableToolLike<
schema: RunInput;

constructor(fields: RunnableToolLikeArgs<RunInput, RunOutput>) {
const sequence = RunnableSequence.from([
RunnableLambda.from(async (input) => {
let toolInput: z.TypeOf<RunInput>;

if (_isToolCall(input)) {
try {
toolInput = await this.schema.parseAsync(input.args);
} catch (e) {
throw new ToolInputParsingException(
`Received tool input did not match expected schema`,
JSON.stringify(input.args)
);
}
} else {
toolInput = input;
}
return toolInput;
}).withConfig({ runName: `${fields.name}:parse_input` }),
fields.bound,
]).withConfig({ runName: fields.name });

super({
bound: fields.bound,
bound: sequence,
config: fields.config ?? {},
});

Expand Down Expand Up @@ -2863,11 +2886,24 @@ export function convertRunnableToTool<RunInput, RunOutput>(
description?: string;
schema: z.ZodType<RunInput>;
}
): RunnableToolLike<z.ZodType<RunInput>, RunOutput> {
): RunnableToolLike<z.ZodType<RunInput | ToolCall>, RunOutput> {
const name = fields.name ?? runnable.getName();
const description = fields.description ?? fields.schema.description;
const description = fields.description ?? fields.schema?.description;

if (fields.schema.constructor === z.ZodString) {
return new RunnableToolLike<z.ZodType<RunInput | ToolCall>, RunOutput>({
name,
description,
schema: z
.object({
input: z.string(),
})
.transform((input) => input.input) as z.ZodType,
bound: runnable,
});
}

return new RunnableToolLike<z.ZodType<RunInput>, RunOutput>({
return new RunnableToolLike<z.ZodType<RunInput | ToolCall>, RunOutput>({
name,
description,
schema: fields.schema,
Expand Down
42 changes: 42 additions & 0 deletions langchain-core/src/runnables/tests/runnable_tools.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import { z } from "zod";
import { RunnableLambda, RunnableToolLike } from "../base.js";
import { FakeRetriever } from "../../utils/testing/index.js";
import { Document } from "../../documents/document.js";

test("Runnable asTool works", async () => {
const schema = z.object({
Expand Down Expand Up @@ -137,3 +139,43 @@ test("Runnable asTool uses Zod schema description if not provided", async () =>

expect(tool.description).toBe(description);
});

test("Runnable asTool can accept a string zod schema", async () => {
const lambda = RunnableLambda.from<string, string>((input) => {
return `${input}a`;
}).asTool({
name: "string_tool",
description: "A tool that appends 'a' to the input string",
schema: z.string(),
});

const result = await lambda.invoke("b");
expect(result).toBe("ba");
});

test("Runnables which dont accept ToolCalls as inputs can accept ToolCalls", async () => {
const pageContent = "Dogs are pretty cool, man!";
const retriever = new FakeRetriever({
output: [
new Document({
pageContent,
}),
],
});
const tool = retriever.asTool({
name: "pet_info_retriever",
description: "Get information about pets.",
schema: z.string(),
});

const result = await tool.invoke({
type: "tool_call",
name: "pet_info_retriever",
args: {
input: "dogs",
},
id: "string",
});
expect(result).toHaveLength(1);
expect(result[0].pageContent).toBe(pageContent);
});
102 changes: 52 additions & 50 deletions langchain-core/src/tools/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ import {
} from "../runnables/config.js";
import type { RunnableFunc, RunnableInterface } from "../runnables/base.js";
import { ToolCall, ToolMessage } from "../messages/tool.js";
import { ZodAny } from "../types/zod.js";
import { ZodObjectAny } from "../types/zod.js";
import { MessageContent } from "../messages/base.js";
import { AsyncLocalStorageProviderSingleton } from "../singletons/index.js";
import { _isToolCall, ToolInputParsingException } from "./utils.js";

export { ToolInputParsingException };

export type ResponseFormat = "content" | "content_and_artifact" | string;

Expand All @@ -44,21 +47,7 @@ export interface ToolParams extends BaseLangChainParams {
responseFormat?: ResponseFormat;
}

/**
* Custom error class used to handle exceptions related to tool input parsing.
* It extends the built-in `Error` class and adds an optional `output`
* property that can hold the output that caused the exception.
*/
export class ToolInputParsingException extends Error {
output?: string;

constructor(message: string, output?: string) {
super(message);
this.output = output;
}
}

export interface StructuredToolInterface<T extends ZodAny = ZodAny>
export interface StructuredToolInterface<T extends ZodObjectAny = ZodObjectAny>
extends RunnableInterface<
(z.output<T> extends string ? string : never) | z.input<T> | ToolCall,
ToolReturnType
Expand Down Expand Up @@ -96,7 +85,7 @@ export interface StructuredToolInterface<T extends ZodAny = ZodAny>
* Base class for Tools that accept input of any shape defined by a Zod schema.
*/
export abstract class StructuredTool<
T extends ZodAny = ZodAny
T extends ZodObjectAny = ZodObjectAny
> extends BaseLangChain<
(z.output<T> extends string ? string : never) | z.input<T> | ToolCall,
ToolReturnType
Expand Down Expand Up @@ -259,7 +248,7 @@ export abstract class StructuredTool<
}
}

export interface ToolInterface<T extends ZodAny = ZodAny>
export interface ToolInterface<T extends ZodObjectAny = ZodObjectAny>
extends StructuredToolInterface<T> {
/**
* @deprecated Use .invoke() instead. Will be removed in 0.3.0.
Expand All @@ -279,7 +268,7 @@ export interface ToolInterface<T extends ZodAny = ZodAny>
/**
* Base class for Tools that accept input as a string.
*/
export abstract class Tool extends StructuredTool<ZodAny> {
export abstract class Tool extends StructuredTool<ZodObjectAny> {
schema = z
.object({ input: z.string().optional() })
.transform((obj) => obj.input);
Expand Down Expand Up @@ -328,8 +317,9 @@ export interface DynamicToolInput extends BaseDynamicToolInput {
/**
* Interface for the input parameters of the DynamicStructuredTool class.
*/
export interface DynamicStructuredToolInput<T extends ZodAny = ZodAny>
extends BaseDynamicToolInput {
export interface DynamicStructuredToolInput<
T extends ZodObjectAny = ZodObjectAny
> extends BaseDynamicToolInput {
func: (
input: BaseDynamicToolInput["responseFormat"] extends "content_and_artifact"
? ToolCall
Expand Down Expand Up @@ -393,7 +383,7 @@ export class DynamicTool extends Tool {
* provided function when the tool is called.
*/
export class DynamicStructuredTool<
T extends ZodAny = ZodAny
T extends ZodObjectAny = ZodObjectAny
> extends StructuredTool<T> {
static lc_name() {
return "DynamicStructuredTool";
Expand Down Expand Up @@ -456,11 +446,11 @@ export abstract class BaseToolkit {

/**
* Parameters for the tool function.
* @template {ZodAny} RunInput The input schema for the tool.
* @template {any} RunOutput The output type for the tool.
* @template {ZodObjectAny | z.ZodString = ZodObjectAny} RunInput The input schema for the tool. Either any Zod object, or a Zod string.
*/
interface ToolWrapperParams<RunInput extends ZodAny = ZodAny>
extends ToolParams {
interface ToolWrapperParams<
RunInput extends ZodObjectAny | z.ZodString = ZodObjectAny
> extends ToolParams {
/**
* The name of the tool. If using with an LLM, this
* will be passed as the tool name.
Expand Down Expand Up @@ -491,33 +481,54 @@ interface ToolWrapperParams<RunInput extends ZodAny = ZodAny>

/**
* Creates a new StructuredTool instance with the provided function, name, description, and schema.
*
* @function
* @template {RunInput extends ZodAny = ZodAny} RunInput The input schema for the tool. This corresponds to the input type when the tool is invoked.
* @template {RunOutput = any} RunOutput The output type for the tool. This corresponds to the output type when the tool is invoked.
* @template {FuncInput extends z.infer<RunInput> | ToolCall = z.infer<RunInput>} FuncInput The input type for the function.
* @template {ZodObjectAny | z.ZodString = ZodObjectAny} T The input schema for the tool. Either any Zod object, or a Zod string.
*
* @param {RunnableFunc<z.infer<RunInput> | ToolCall, RunOutput>} func - The function to invoke when the tool is called.
* @param fields - An object containing the following properties:
* @param {RunnableFunc<z.output<T>, ToolReturnType>} func - The function to invoke when the tool is called.
* @param {ToolWrapperParams<T>} fields - An object containing the following properties:
* @param {string} fields.name The name of the tool.
* @param {string | undefined} fields.description The description of the tool. Defaults to either the description on the Zod schema, or `${fields.name} tool`.
* @param {z.ZodObject<any, any, any, any>} fields.schema The Zod schema defining the input for the tool.
* @param {ZodObjectAny | z.ZodString | undefined} fields.schema The Zod schema defining the input for the tool. If undefined, it will default to a Zod string schema.
*
* @returns {DynamicStructuredTool<RunInput, RunOutput>} A new StructuredTool instance.
* @returns {DynamicStructuredTool<T>} A new StructuredTool instance.
*/
export function tool<T extends ZodAny = ZodAny>(
export function tool<T extends z.ZodString = z.ZodString>(
func: RunnableFunc<z.output<T>, ToolReturnType>,
fields: ToolWrapperParams<T>
): DynamicStructuredTool<T> {
const schema =
fields.schema ??
z.object({ input: z.string().optional() }).transform((obj) => obj.input);
): DynamicTool;

export function tool<T extends ZodObjectAny = ZodObjectAny>(
func: RunnableFunc<z.output<T>, ToolReturnType>,
fields: ToolWrapperParams<T>
): DynamicStructuredTool<T>;

export function tool<T extends ZodObjectAny | z.ZodString = ZodObjectAny>(
func: RunnableFunc<z.output<T>, ToolReturnType>,
fields: ToolWrapperParams<T>
):
| DynamicStructuredTool<T extends ZodObjectAny ? T : ZodObjectAny>
| DynamicTool {
// If the schema is not provided, or it's a string schema, create a DynamicTool
if (!fields.schema || !("shape" in fields.schema) || !fields.schema.shape) {
return new DynamicTool({
name: fields.name,
description:
fields.description ??
fields.schema?.description ??
`${fields.name} tool`,
responseFormat: fields.responseFormat,
func,
});
}

const description =
fields.description ?? schema.description ?? `${fields.name} tool`;
return new DynamicStructuredTool({
fields.description ?? fields.schema.description ?? `${fields.name} tool`;

return new DynamicStructuredTool<T extends ZodObjectAny ? T : ZodObjectAny>({
name: fields.name,
description,
schema: schema as T,
schema: fields.schema as T extends ZodObjectAny ? T : ZodObjectAny,
// TODO: Consider moving into DynamicStructuredTool constructor
func: async (input, runManager, config) => {
return new Promise((resolve, reject) => {
Expand All @@ -540,15 +551,6 @@ export function tool<T extends ZodAny = ZodAny>(
});
}

function _isToolCall(toolCall?: unknown): toolCall is ToolCall {
return !!(
toolCall &&
typeof toolCall === "object" &&
"type" in toolCall &&
toolCall.type === "tool_call"
);
}

function _formatToolOutput(params: {
content: unknown;
name: string;
Expand Down
16 changes: 16 additions & 0 deletions langchain-core/src/tools/tests/tools.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,19 @@ test("Returns tool message if responseFormat is content_and_artifact and returns
expect(toolResult.artifact).toEqual({ location: "San Francisco" });
expect(toolResult.name).toBe("weather");
});

test("Tool can accept single string input", async () => {
const stringTool = tool<z.ZodString>(
(input: string): string => {
return `${input}a`;
},
{
name: "string_tool",
description: "A tool that appends 'a' to the input string",
schema: z.string(),
}
);

const result = await stringTool.invoke("b");
expect(result).toBe("ba");
});
24 changes: 24 additions & 0 deletions langchain-core/src/tools/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import { ToolCall } from "../messages/tool.js";

export function _isToolCall(toolCall?: unknown): toolCall is ToolCall {
return !!(
toolCall &&
typeof toolCall === "object" &&
"type" in toolCall &&
toolCall.type === "tool_call"
);
}

/**
* Custom error class used to handle exceptions related to tool input parsing.
* It extends the built-in `Error` class and adds an optional `output`
* property that can hold the output that caused the exception.
*/
export class ToolInputParsingException extends Error {
output?: string;

constructor(message: string, output?: string) {
super(message);
this.output = output;
}
}
2 changes: 1 addition & 1 deletion langchain-core/src/types/zod.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { z } from "zod";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type ZodAny = z.ZodObject<any, any, any, any>;
export type ZodObjectAny = z.ZodObject<any, any, any, any>;
Loading

0 comments on commit 1c4b2d8

Please sign in to comment.