Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core[patch]: Allow dynamic tools to be initialized with JSON schema #6306

Merged
merged 3 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 54 additions & 20 deletions langchain-core/src/tools/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import { ZodObjectAny } from "../types/zod.js";
import { MessageContent } from "../messages/base.js";
import { AsyncLocalStorageProviderSingleton } from "../singletons/index.js";
import { _isToolCall, ToolInputParsingException } from "./utils.js";
import { isZodSchema } from "../utils/types/is_zod_schema.js";

export { ToolInputParsingException };

Expand Down Expand Up @@ -319,16 +320,19 @@ export interface DynamicToolInput extends BaseDynamicToolInput {
* Interface for the input parameters of the DynamicStructuredTool class.
*/
export interface DynamicStructuredToolInput<
T extends ZodObjectAny = ZodObjectAny
// eslint-disable-next-line @typescript-eslint/no-explicit-any
T extends ZodObjectAny | Record<string, any> = ZodObjectAny
> extends BaseDynamicToolInput {
func: (
input: BaseDynamicToolInput["responseFormat"] extends "content_and_artifact"
? ToolCall
: z.infer<T>,
: T extends ZodObjectAny
? z.infer<T>
: T,
runManager?: CallbackManagerForToolRun,
config?: RunnableConfig
) => Promise<ToolReturnType>;
schema: T;
schema: T extends ZodObjectAny ? T : T;
}

/**
Expand Down Expand Up @@ -382,10 +386,14 @@ export class DynamicTool extends Tool {
* description, designed to work with structured data. It extends the
* StructuredTool class and overrides the _call method to execute the
* provided function when the tool is called.
*
* Schema can be passed as Zod or JSON schema. The tool will not validate
* input if JSON schema is passed.
*/
export class DynamicStructuredTool<
T extends ZodObjectAny = ZodObjectAny
> extends StructuredTool<T> {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
T extends ZodObjectAny | Record<string, any> = ZodObjectAny
> extends StructuredTool<T extends ZodObjectAny ? T : ZodObjectAny> {
static lc_name() {
return "DynamicStructuredTool";
}
Expand All @@ -396,22 +404,24 @@ export class DynamicStructuredTool<

func: DynamicStructuredToolInput<T>["func"];

schema: T;
schema: T extends ZodObjectAny ? T : ZodObjectAny;

constructor(fields: DynamicStructuredToolInput<T>) {
super(fields);
this.name = fields.name;
this.description = fields.description;
this.func = fields.func;
this.returnDirect = fields.returnDirect ?? this.returnDirect;
this.schema = fields.schema;
this.schema = (
isZodSchema(fields.schema) ? fields.schema : z.object({})
) as T extends ZodObjectAny ? T : ZodObjectAny;
}

/**
* @deprecated Use .invoke() instead. Will be removed in 0.3.0.
*/
async call(
arg: z.output<T> | ToolCall,
arg: (T extends ZodObjectAny ? z.output<T> : T) | ToolCall,
configArg?: RunnableConfig | Callbacks,
/** @deprecated */
tags?: string[]
Expand All @@ -424,11 +434,12 @@ export class DynamicStructuredTool<
}

protected _call(
arg: z.output<T> | ToolCall,
arg: (T extends ZodObjectAny ? z.output<T> : T) | ToolCall,
runManager?: CallbackManagerForToolRun,
parentConfig?: RunnableConfig
): Promise<ToolReturnType> {
return this.func(arg, runManager, parentConfig);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return this.func(arg as any, runManager, parentConfig);
}
}

Expand All @@ -447,10 +458,16 @@ export abstract class BaseToolkit {

/**
* Parameters for the tool function.
* @template {ZodObjectAny | z.ZodString = ZodObjectAny} RunInput The input schema for the tool. Either any Zod object, or a Zod string.
* Schema can be provided as Zod or JSON schema.
* If you pass JSON schema, tool inputs will not be validated.
* @template {ZodObjectAny | z.ZodString | Record<string, any> = ZodObjectAny} RunInput The input schema for the tool. Either any Zod object, a Zod string, or JSON schema.
*/
interface ToolWrapperParams<
RunInput extends ZodObjectAny | z.ZodString = ZodObjectAny
RunInput extends
| ZodObjectAny
| z.ZodString
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| Record<string, any> = ZodObjectAny
> extends ToolParams {
/**
* The name of the tool. If using with an LLM, this
Expand Down Expand Up @@ -483,8 +500,11 @@ interface ToolWrapperParams<
/**
* Creates a new StructuredTool instance with the provided function, name, description, and schema.
*
* Schema can be provided as Zod or JSON schema.
* If you pass JSON schema, tool inputs will not be validated.
*
* @function
* @template {ZodObjectAny | z.ZodString = ZodObjectAny} T The input schema for the tool. Either any Zod object, or a Zod string.
* @template {ZodObjectAny | z.ZodString | Record<string, any> = ZodObjectAny} T The input schema for the tool. Either any Zod object, a Zod string, or JSON schema instance.
*
* @param {RunnableFunc<z.output<T>, ToolReturnType>} func - The function to invoke when the tool is called.
* @param {ToolWrapperParams<T>} fields - An object containing the following properties:
Expand All @@ -494,18 +514,27 @@ interface ToolWrapperParams<
*
* @returns {DynamicStructuredTool<T>} A new StructuredTool instance.
*/
export function tool<T extends z.ZodString = z.ZodString>(
export function tool<T extends z.ZodString>(
func: RunnableFunc<z.output<T>, ToolReturnType>,
fields: ToolWrapperParams<T>
): DynamicTool;

export function tool<T extends ZodObjectAny = ZodObjectAny>(
export function tool<T extends ZodObjectAny>(
func: RunnableFunc<z.output<T>, ToolReturnType>,
fields: ToolWrapperParams<T>
): DynamicStructuredTool<T>;

export function tool<T extends ZodObjectAny | z.ZodString = ZodObjectAny>(
func: RunnableFunc<z.output<T>, ToolReturnType>,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export function tool<T extends Record<string, any>>(
func: RunnableFunc<T, ToolReturnType>,
fields: ToolWrapperParams<T>
): DynamicStructuredTool<T>;

export function tool<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
T extends ZodObjectAny | z.ZodString | Record<string, any> = ZodObjectAny
>(
func: RunnableFunc<T extends ZodObjectAny ? z.output<T> : T, ToolReturnType>,
fields: ToolWrapperParams<T>
):
| DynamicStructuredTool<T extends ZodObjectAny ? T : ZodObjectAny>
Expand All @@ -518,7 +547,9 @@ export function tool<T extends ZodObjectAny | z.ZodString = ZodObjectAny>(
fields.description ??
fields.schema?.description ??
`${fields.name} tool`,
func,
// TS doesn't restrict the type here based on the guard above
// eslint-disable-next-line @typescript-eslint/no-explicit-any
func: func as any,
});
}

Expand All @@ -528,7 +559,8 @@ export function tool<T extends ZodObjectAny | z.ZodString = ZodObjectAny>(
return new DynamicStructuredTool<T extends ZodObjectAny ? T : ZodObjectAny>({
...fields,
description,
schema: fields.schema as T extends ZodObjectAny ? T : ZodObjectAny,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
schema: fields.schema as any,
// TODO: Consider moving into DynamicStructuredTool constructor
func: async (input, runManager, config) => {
return new Promise((resolve, reject) => {
Expand All @@ -539,7 +571,9 @@ export function tool<T extends ZodObjectAny | z.ZodString = ZodObjectAny>(
childConfig,
async () => {
try {
resolve(func(input, childConfig));
// TS doesn't restrict the type here based on the guard above
// eslint-disable-next-line @typescript-eslint/no-explicit-any
resolve(func(input as any, childConfig));
} catch (e) {
reject(e);
}
Expand Down
99 changes: 98 additions & 1 deletion langchain-core/src/tools/tests/tools.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { test, expect } from "@jest/globals";
import { z } from "zod";
import { tool } from "../index.js";
import { DynamicStructuredTool, tool } from "../index.js";
import { ToolMessage } from "../../messages/tool.js";

test("Tool should error if responseFormat is content_and_artifact but the function doesn't return a tuple", async () => {
Expand Down Expand Up @@ -115,3 +115,100 @@ test("Tool can accept single string input", async () => {
const result = await stringTool.invoke("b");
expect(result).toBe("ba");
});

test("Tool declared with JSON schema", async () => {
const weatherSchema = {
type: "object",
properties: {
location: {
type: "string",
description: "A place",
},
},
required: ["location"],
};
const weatherTool = tool(
(_) => {
return "Sunny";
},
{
name: "weather",
schema: weatherSchema,
}
);

const weatherTool2 = new DynamicStructuredTool({
name: "weather",
description: "get the weather",
func: async (_) => {
return "Sunny";
},
schema: weatherSchema,
});
// No validation on JSON schema tools
await weatherTool.invoke({
somethingSilly: true,
});
await weatherTool2.invoke({
somethingSilly: true,
});
});

test("Tool input typing is enforced", async () => {
const weatherSchema = z.object({
location: z.string(),
});

const weatherTool = tool(
(_) => {
return "Sunny";
},
{
name: "weather",
schema: weatherSchema,
}
);

const weatherTool2 = new DynamicStructuredTool({
name: "weather",
description: "get the weather",
func: async (_) => {
return "Sunny";
},
schema: weatherSchema,
});

const weatherTool3 = tool(
async (_) => {
return "Sunny";
},
{
name: "weather",
description: "get the weather",
schema: z.string(),
}
);

await expect(async () => {
await weatherTool.invoke({
// @ts-expect-error Invalid argument
badval: "someval",
});
}).rejects.toThrow();
const res = await weatherTool.invoke({
location: "somewhere",
});
expect(res).toEqual("Sunny");
await expect(async () => {
await weatherTool2.invoke({
// @ts-expect-error Invalid argument
badval: "someval",
});
}).rejects.toThrow();
const res2 = await weatherTool2.invoke({
location: "someval",
});
expect(res2).toEqual("Sunny");
const res3 = await weatherTool3.invoke("blah");
expect(res3).toEqual("Sunny");
});
Loading