Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

openai[minor],core[minor]: Add support for passing strict in openai tools #7

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions langchain-core/src/language_models/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,15 @@ export interface FunctionDefinition {
* how to call the function.
*/
description?: string;

/**
* Whether to enable strict schema adherence when generating the function call. If
* set to true, the model will follow the exact schema defined in the `parameters`
* field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn
* more about Structured Outputs in the
* [function calling guide](https://platform.openai.com/docs/guides/function-calling).
*/
strict?: boolean;
}

export interface ToolDefinition {
Expand Down
19 changes: 17 additions & 2 deletions langchain-core/src/utils/function_calling.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,29 @@ export function convertToOpenAIFunction(
*/
export function convertToOpenAITool(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
tool: StructuredToolInterface | Record<string, any> | RunnableToolLike
tool: StructuredToolInterface | Record<string, any> | RunnableToolLike,
fields?: {
/**
* If `true`, model output is guaranteed to exactly match the JSON Schema
* provided in the function definition.
*/
strict?: boolean;
}
): ToolDefinition {
let toolDef: ToolDefinition | undefined;
if (isStructuredTool(tool) || isRunnableToolLike(tool)) {
return {
toolDef = {
type: "function",
function: convertToOpenAIFunction(tool),
};
} else {
toolDef = tool as ToolDefinition;
}

if (fields?.strict !== undefined) {
toolDef.function.strict = fields.strict;
}

Comment on lines +56 to +59
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix the Return Statement in convertToOpenAITool

The return statement currently returns the original tool parameter. It should return the toolDef variable, which contains the updated tool definition with the strict field.

Fix the return statement as follows:

-  return tool as ToolDefinition;
+  return toolDef;

Committable suggestion was skipped due to low confidence.

return tool as ToolDefinition;
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

category Functionality severity potentially major

There's an issue with the return statement in the convertToOpenAITool function. Currently, it's returning the original tool parameter, which hasn't been modified with the strict field. Instead, it should return the toolDef variable, which contains the updated tool definition. Please update the return statement to return toolDef; to ensure the function returns the correctly modified tool definition.

Chat with Korbit by mentioning @korbit-ai, and give a 👍 or 👎 to help Korbit improve your reviews.

}

Expand Down
2 changes: 1 addition & 1 deletion libs/langchain-openai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"dependencies": {
"@langchain/core": ">=0.2.16 <0.3.0",
"js-tiktoken": "^1.0.12",
"openai": "^4.49.1",
"openai": "^4.55.0",
"zod": "^3.22.4",
"zod-to-json-schema": "^3.22.3"
},
Expand Down
37 changes: 29 additions & 8 deletions libs/langchain-openai/src/chat_models.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { type ClientOptions, OpenAI as OpenAIClient } from "openai";
import { type ClientOptions, OpenAI as OpenAIClient, } from "openai";

import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
Expand Down Expand Up @@ -299,6 +299,16 @@ export interface ChatOpenAICallOptions
* call multiple tools in one response.
*/
parallel_tool_calls?: boolean;
/**
* If `true`, model output is guaranteed to exactly match the JSON Schema
* provided in the tool definition.
* Enabled by default for `"gpt-"` models.
*/
strict?: boolean;
}

export interface ChatOpenAIFields extends Partial<OpenAIChatInput>, Partial<AzureOpenAIInput>, BaseChatModelParams {
configuration?: ClientOptions & LegacyOpenAIInput;
}

/**
Expand Down Expand Up @@ -441,12 +451,15 @@ export class ChatOpenAI<

protected clientConfig: ClientOptions;

/**
* Whether the model supports the 'strict' argument when passing in tools.
* Defaults to `true` if `modelName`/`model` starts with 'gpt-' otherwise
* defaults to `false`.
*/
supportsStrictToolCalling?: boolean;

constructor(
fields?: Partial<OpenAIChatInput> &
Partial<AzureOpenAIInput> &
BaseChatModelParams & {
configuration?: ClientOptions & LegacyOpenAIInput;
},
fields?: ChatOpenAIFields,
/** @deprecated */
configuration?: ClientOptions & LegacyOpenAIInput
) {
Expand Down Expand Up @@ -541,6 +554,12 @@ export class ChatOpenAI<
...configuration,
...fields?.configuration,
};

// Assume only "gpt-..." models support strict tool calling as of 08/06/24.
this.supportsStrictToolCalling =
fields?.supportsStrictToolCalling !== undefined
? fields.supportsStrictToolCalling
: this.modelName.startsWith("gpt-");
}

getLsParams(options: this["ParsedCallOptions"]): LangSmithParams {
Expand All @@ -563,8 +582,9 @@ export class ChatOpenAI<
)[],
kwargs?: Partial<CallOptions>
): Runnable<BaseLanguageModelInput, AIMessageChunk, CallOptions> {
const strict = kwargs?.strict !== undefined ? kwargs.strict : this.supportsStrictToolCalling;
return this.bind({
tools: tools.map(convertToOpenAITool),
tools: tools.map((tool) => convertToOpenAITool(tool, { strict })),
...kwargs,
} as Partial<CallOptions>);
}
Expand All @@ -578,6 +598,7 @@ export class ChatOpenAI<
streaming?: boolean;
}
): Omit<OpenAIClient.Chat.ChatCompletionCreateParams, "messages"> {
const strict = options?.strict !== undefined ? options.strict : this.supportsStrictToolCalling;
function isStructuredToolArray(
tools?: unknown[]
): tools is StructuredToolInterface[] {
Expand Down Expand Up @@ -615,7 +636,7 @@ export class ChatOpenAI<
functions: options?.functions,
function_call: options?.function_call,
tools: isStructuredToolArray(options?.tools)
? options?.tools.map(convertToOpenAITool)
? options?.tools.map((tool) => convertToOpenAITool(tool, { strict }))
: options?.tools,
tool_choice: formatToOpenAIToolChoice(options?.tool_choice),
response_format: options?.response_format,
Expand Down
212 changes: 212 additions & 0 deletions libs/langchain-openai/src/tests/chat_models.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import { z } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";
import { it, expect, describe, beforeAll, afterAll, jest } from "@jest/globals";
import { ChatOpenAI } from "../chat_models.js";


describe("strict tool calling", () => {
const weatherTool = {
type: "function" as const,
function: {
name: "get_current_weather",
description: "Get the current weather in a location",
parameters: zodToJsonSchema(z.object({
location: z.string().describe("The location to get the weather for"),
}))
}
}

// Store the original value of LANGCHAIN_TRACING_V2
let oldLangChainTracingValue: string | undefined;
// Before all tests, save the current LANGCHAIN_TRACING_V2 value
beforeAll(() => {
oldLangChainTracingValue = process.env.LANGCHAIN_TRACING_V2;
})
// After all tests, restore the original LANGCHAIN_TRACING_V2 value
afterAll(() => {
if (oldLangChainTracingValue !== undefined) {
process.env.LANGCHAIN_TRACING_V2 = oldLangChainTracingValue;
} else {
// If it was undefined, remove the environment variable
delete process.env.LANGCHAIN_TRACING_V2;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid Using the delete Operator

Using the delete operator can impact performance. Consider setting the variable to undefined instead.

-      delete process.env.LANGCHAIN_TRACING_V2;
+      process.env.LANGCHAIN_TRACING_V2 = undefined;
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
delete process.env.LANGCHAIN_TRACING_V2;
process.env.LANGCHAIN_TRACING_V2 = undefined;
Tools
Biome

[error] 31-31: Avoid the delete operator which can impact performance.

Unsafe fix: Use an undefined assignment instead.

(lint/performance/noDelete)

}
})

it("Can accept strict as a call arg via .bindTools", async () => {
const mockFetch = jest.fn<(url: any, init?: any) => Promise<any>>();
mockFetch.mockImplementation((url, options): Promise<any> => {
// Store the request details for later inspection
mockFetch.mock.calls.push({ url, options } as any);

// Return a mock response
return Promise.resolve({
ok: true,
json: () => Promise.resolve({}),
}) as Promise<any>;
});

const model = new ChatOpenAI({
model: "gpt-4",
configuration: {
fetch: mockFetch,
},
maxRetries: 0,
});

const modelWithTools = model.bindTools([weatherTool], { strict: true });

// This will fail since we're not returning a valid response in our mocked fetch function.
await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow();

expect(mockFetch).toHaveBeenCalled();
const [_url, options] = mockFetch.mock.calls[0];

if (options && options.body) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use Optional Chaining for Safety

Consider using optional chaining to safely access properties and avoid potential runtime errors.

-    if (options && options.body) {
+    if (options?.body) {

Also applies to: 110-110, 154-154, 199-199

Tools
Biome

[error] 64-64: Change to an optional chain.

Unsafe fix: Change to an optional chain.

(lint/complexity/useOptionalChain)

expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({
type: "function",
function: {
...weatherTool.function,
// This should be added to the function call because `strict` was passed to `bindTools`
strict: true,
}
})]);
} else {
throw new Error("Body not found in request.")
}
});

it("Can accept strict as a call arg via .bind", async () => {
const mockFetch = jest.fn<(url: any, init?: any) => Promise<any>>();
mockFetch.mockImplementation((url, options): Promise<any> => {
// Store the request details for later inspection
mockFetch.mock.calls.push({ url, options } as any);

// Return a mock response
return Promise.resolve({
ok: true,
json: () => Promise.resolve({}),
}) as Promise<any>;
});

const model = new ChatOpenAI({
model: "gpt-4",
configuration: {
fetch: mockFetch,
},
maxRetries: 0,
});

const modelWithTools = model.bind({
tools: [weatherTool],
strict: true
});

// This will fail since we're not returning a valid response in our mocked fetch function.
await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow();

expect(mockFetch).toHaveBeenCalled();
const [_url, options] = mockFetch.mock.calls[0];

if (options && options.body) {
expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({
type: "function",
function: {
...weatherTool.function,
// This should be added to the function call because `strict` was passed to `bind`
strict: true,
}
})]);
} else {
throw new Error("Body not found in request.")
}
});

it("Sets strict to true if the model name starts with 'gpt-'", async () => {
const mockFetch = jest.fn<(url: any, init?: any) => Promise<any>>();
mockFetch.mockImplementation((url, options): Promise<any> => {
// Store the request details for later inspection
mockFetch.mock.calls.push({ url, options } as any);

// Return a mock response
return Promise.resolve({
ok: true,
json: () => Promise.resolve({}),
}) as Promise<any>;
});

const model = new ChatOpenAI({
model: "gpt-4",
configuration: {
fetch: mockFetch,
},
maxRetries: 0,
});

// Do NOT pass `strict` here since we're checking that it's set to true by default
const modelWithTools = model.bindTools([weatherTool]);

// This will fail since we're not returning a valid response in our mocked fetch function.
await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow();

expect(mockFetch).toHaveBeenCalled();
const [_url, options] = mockFetch.mock.calls[0];

if (options && options.body) {
expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({
type: "function",
function: {
...weatherTool.function,
// This should be added to the function call because `strict` was passed to `bind`
strict: true,
}
})]);
} else {
throw new Error("Body not found in request.")
}
});

it("Strict is false if supportsStrictToolCalling is false", async () => {
const mockFetch = jest.fn<(url: any, init?: any) => Promise<any>>();
mockFetch.mockImplementation((url, options): Promise<any> => {
// Store the request details for later inspection
mockFetch.mock.calls.push({ url, options } as any);

// Return a mock response
return Promise.resolve({
ok: true,
json: () => Promise.resolve({}),
}) as Promise<any>;
});

const model = new ChatOpenAI({
model: "gpt-4",
configuration: {
fetch: mockFetch,
},
maxRetries: 0,
supportsStrictToolCalling: false,
});

// Do NOT pass `strict` here since we're checking that it's set to true by default
const modelWithTools = model.bindTools([weatherTool]);

// This will fail since we're not returning a valid response in our mocked fetch function.
await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow();

expect(mockFetch).toHaveBeenCalled();
const [_url, options] = mockFetch.mock.calls[0];

if (options && options.body) {
expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({
type: "function",
function: {
...weatherTool.function,
// This should be added to the function call because `strict` was passed to `bind`
strict: false,
}
})]);
} else {
throw new Error("Body not found in request.")
}
});
})
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { zodToJsonSchema } from "zod-to-json-schema";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { AIMessage } from "@langchain/core/messages";
import { ChatOpenAI } from "../chat_models.js";
import { test, expect } from "@jest/globals";
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

category Functionality

It appears that you've added an import for 'test' and 'expect' from '@jest/globals'. However, these are typically globally available in Jest test files without explicit imports. To improve code cleanliness and avoid potential conflicts, you can remove this redundant import line.

Chat with Korbit by mentioning @korbit-ai, and give a 👍 or 👎 to help Korbit improve your reviews.


test("withStructuredOutput zod schema function calling", async () => {
const model = new ChatOpenAI({
Expand Down
7 changes: 7 additions & 0 deletions libs/langchain-openai/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,13 @@ export interface OpenAIChatInput extends OpenAIBaseInput {
* Currently in experimental beta.
*/
__includeRawResponse?: boolean;

/**
* Whether the model supports the 'strict' argument when passing in tools.
* Defaults to `true` if `modelName`/`model` starts with 'gpt-' otherwise
* defaults to `false`.
*/
supportsStrictToolCalling?: boolean;
Comment on lines +159 to +164
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

category Functionality

The new 'supportsStrictToolCalling' field has been added to the OpenAIChatInput interface, which is a good addition. However, it would be helpful to add a comment explaining how this field is used in the actual implementation. For example, does setting this to true enforce strict adherence to the tool's schema? How does it affect the model's behavior when calling tools? Adding this information would make the code more self-documenting and easier for other developers to understand and use correctly.

Chat with Korbit by mentioning @korbit-ai, and give a 👍 or 👎 to help Korbit improve your reviews.

}

export declare interface AzureOpenAIInput {
Expand Down
Loading