-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
openai[minor],core[minor]: Add support for passing strict in openai tools #7
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,14 +34,29 @@ export function convertToOpenAIFunction( | |
*/ | ||
export function convertToOpenAITool( | ||
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
tool: StructuredToolInterface | Record<string, any> | RunnableToolLike | ||
tool: StructuredToolInterface | Record<string, any> | RunnableToolLike, | ||
fields?: { | ||
/** | ||
* If `true`, model output is guaranteed to exactly match the JSON Schema | ||
* provided in the function definition. | ||
*/ | ||
strict?: boolean; | ||
} | ||
): ToolDefinition { | ||
let toolDef: ToolDefinition | undefined; | ||
if (isStructuredTool(tool) || isRunnableToolLike(tool)) { | ||
return { | ||
toolDef = { | ||
type: "function", | ||
function: convertToOpenAIFunction(tool), | ||
}; | ||
} else { | ||
toolDef = tool as ToolDefinition; | ||
} | ||
|
||
if (fields?.strict !== undefined) { | ||
toolDef.function.strict = fields.strict; | ||
} | ||
|
||
return tool as ToolDefinition; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's an issue with the return statement in the
|
||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,212 @@ | ||||||
import { z } from "zod"; | ||||||
import { zodToJsonSchema } from "zod-to-json-schema"; | ||||||
import { it, expect, describe, beforeAll, afterAll, jest } from "@jest/globals"; | ||||||
import { ChatOpenAI } from "../chat_models.js"; | ||||||
|
||||||
|
||||||
describe("strict tool calling", () => { | ||||||
const weatherTool = { | ||||||
type: "function" as const, | ||||||
function: { | ||||||
name: "get_current_weather", | ||||||
description: "Get the current weather in a location", | ||||||
parameters: zodToJsonSchema(z.object({ | ||||||
location: z.string().describe("The location to get the weather for"), | ||||||
})) | ||||||
} | ||||||
} | ||||||
|
||||||
// Store the original value of LANGCHAIN_TRACING_V2 | ||||||
let oldLangChainTracingValue: string | undefined; | ||||||
// Before all tests, save the current LANGCHAIN_TRACING_V2 value | ||||||
beforeAll(() => { | ||||||
oldLangChainTracingValue = process.env.LANGCHAIN_TRACING_V2; | ||||||
}) | ||||||
// After all tests, restore the original LANGCHAIN_TRACING_V2 value | ||||||
afterAll(() => { | ||||||
if (oldLangChainTracingValue !== undefined) { | ||||||
process.env.LANGCHAIN_TRACING_V2 = oldLangChainTracingValue; | ||||||
} else { | ||||||
// If it was undefined, remove the environment variable | ||||||
delete process.env.LANGCHAIN_TRACING_V2; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid Using the Using the - delete process.env.LANGCHAIN_TRACING_V2;
+ process.env.LANGCHAIN_TRACING_V2 = undefined; Committable suggestion
Suggested change
ToolsBiome
|
||||||
} | ||||||
}) | ||||||
|
||||||
it("Can accept strict as a call arg via .bindTools", async () => { | ||||||
const mockFetch = jest.fn<(url: any, init?: any) => Promise<any>>(); | ||||||
mockFetch.mockImplementation((url, options): Promise<any> => { | ||||||
// Store the request details for later inspection | ||||||
mockFetch.mock.calls.push({ url, options } as any); | ||||||
|
||||||
// Return a mock response | ||||||
return Promise.resolve({ | ||||||
ok: true, | ||||||
json: () => Promise.resolve({}), | ||||||
}) as Promise<any>; | ||||||
}); | ||||||
|
||||||
const model = new ChatOpenAI({ | ||||||
model: "gpt-4", | ||||||
configuration: { | ||||||
fetch: mockFetch, | ||||||
}, | ||||||
maxRetries: 0, | ||||||
}); | ||||||
|
||||||
const modelWithTools = model.bindTools([weatherTool], { strict: true }); | ||||||
|
||||||
// This will fail since we're not returning a valid response in our mocked fetch function. | ||||||
await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow(); | ||||||
|
||||||
expect(mockFetch).toHaveBeenCalled(); | ||||||
const [_url, options] = mockFetch.mock.calls[0]; | ||||||
|
||||||
if (options && options.body) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use Optional Chaining for Safety Consider using optional chaining to safely access properties and avoid potential runtime errors. - if (options && options.body) {
+ if (options?.body) { Also applies to: 110-110, 154-154, 199-199 ToolsBiome
|
||||||
expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({ | ||||||
type: "function", | ||||||
function: { | ||||||
...weatherTool.function, | ||||||
// This should be added to the function call because `strict` was passed to `bindTools` | ||||||
strict: true, | ||||||
} | ||||||
})]); | ||||||
} else { | ||||||
throw new Error("Body not found in request.") | ||||||
} | ||||||
}); | ||||||
|
||||||
it("Can accept strict as a call arg via .bind", async () => { | ||||||
const mockFetch = jest.fn<(url: any, init?: any) => Promise<any>>(); | ||||||
mockFetch.mockImplementation((url, options): Promise<any> => { | ||||||
// Store the request details for later inspection | ||||||
mockFetch.mock.calls.push({ url, options } as any); | ||||||
|
||||||
// Return a mock response | ||||||
return Promise.resolve({ | ||||||
ok: true, | ||||||
json: () => Promise.resolve({}), | ||||||
}) as Promise<any>; | ||||||
}); | ||||||
|
||||||
const model = new ChatOpenAI({ | ||||||
model: "gpt-4", | ||||||
configuration: { | ||||||
fetch: mockFetch, | ||||||
}, | ||||||
maxRetries: 0, | ||||||
}); | ||||||
|
||||||
const modelWithTools = model.bind({ | ||||||
tools: [weatherTool], | ||||||
strict: true | ||||||
}); | ||||||
|
||||||
// This will fail since we're not returning a valid response in our mocked fetch function. | ||||||
await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow(); | ||||||
|
||||||
expect(mockFetch).toHaveBeenCalled(); | ||||||
const [_url, options] = mockFetch.mock.calls[0]; | ||||||
|
||||||
if (options && options.body) { | ||||||
expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({ | ||||||
type: "function", | ||||||
function: { | ||||||
...weatherTool.function, | ||||||
// This should be added to the function call because `strict` was passed to `bind` | ||||||
strict: true, | ||||||
} | ||||||
})]); | ||||||
} else { | ||||||
throw new Error("Body not found in request.") | ||||||
} | ||||||
}); | ||||||
|
||||||
it("Sets strict to true if the model name starts with 'gpt-'", async () => { | ||||||
const mockFetch = jest.fn<(url: any, init?: any) => Promise<any>>(); | ||||||
mockFetch.mockImplementation((url, options): Promise<any> => { | ||||||
// Store the request details for later inspection | ||||||
mockFetch.mock.calls.push({ url, options } as any); | ||||||
|
||||||
// Return a mock response | ||||||
return Promise.resolve({ | ||||||
ok: true, | ||||||
json: () => Promise.resolve({}), | ||||||
}) as Promise<any>; | ||||||
}); | ||||||
|
||||||
const model = new ChatOpenAI({ | ||||||
model: "gpt-4", | ||||||
configuration: { | ||||||
fetch: mockFetch, | ||||||
}, | ||||||
maxRetries: 0, | ||||||
}); | ||||||
|
||||||
// Do NOT pass `strict` here since we're checking that it's set to true by default | ||||||
const modelWithTools = model.bindTools([weatherTool]); | ||||||
|
||||||
// This will fail since we're not returning a valid response in our mocked fetch function. | ||||||
await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow(); | ||||||
|
||||||
expect(mockFetch).toHaveBeenCalled(); | ||||||
const [_url, options] = mockFetch.mock.calls[0]; | ||||||
|
||||||
if (options && options.body) { | ||||||
expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({ | ||||||
type: "function", | ||||||
function: { | ||||||
...weatherTool.function, | ||||||
// This should be added to the function call because `strict` was passed to `bind` | ||||||
strict: true, | ||||||
} | ||||||
})]); | ||||||
} else { | ||||||
throw new Error("Body not found in request.") | ||||||
} | ||||||
}); | ||||||
|
||||||
it("Strict is false if supportsStrictToolCalling is false", async () => { | ||||||
const mockFetch = jest.fn<(url: any, init?: any) => Promise<any>>(); | ||||||
mockFetch.mockImplementation((url, options): Promise<any> => { | ||||||
// Store the request details for later inspection | ||||||
mockFetch.mock.calls.push({ url, options } as any); | ||||||
|
||||||
// Return a mock response | ||||||
return Promise.resolve({ | ||||||
ok: true, | ||||||
json: () => Promise.resolve({}), | ||||||
}) as Promise<any>; | ||||||
}); | ||||||
|
||||||
const model = new ChatOpenAI({ | ||||||
model: "gpt-4", | ||||||
configuration: { | ||||||
fetch: mockFetch, | ||||||
}, | ||||||
maxRetries: 0, | ||||||
supportsStrictToolCalling: false, | ||||||
}); | ||||||
|
||||||
// Do NOT pass `strict` here since we're checking that it's set to true by default | ||||||
const modelWithTools = model.bindTools([weatherTool]); | ||||||
|
||||||
// This will fail since we're not returning a valid response in our mocked fetch function. | ||||||
await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow(); | ||||||
|
||||||
expect(mockFetch).toHaveBeenCalled(); | ||||||
const [_url, options] = mockFetch.mock.calls[0]; | ||||||
|
||||||
if (options && options.body) { | ||||||
expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({ | ||||||
type: "function", | ||||||
function: { | ||||||
...weatherTool.function, | ||||||
// This should be added to the function call because `strict` was passed to `bind` | ||||||
strict: false, | ||||||
} | ||||||
})]); | ||||||
} else { | ||||||
throw new Error("Body not found in request.") | ||||||
} | ||||||
}); | ||||||
}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ import { zodToJsonSchema } from "zod-to-json-schema"; | |
import { ChatPromptTemplate } from "@langchain/core/prompts"; | ||
import { AIMessage } from "@langchain/core/messages"; | ||
import { ChatOpenAI } from "../chat_models.js"; | ||
import { test, expect } from "@jest/globals"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It appears that you've added an import for 'test' and 'expect' from '@jest/globals'. However, these are typically globally available in Jest test files without explicit imports. To improve code cleanliness and avoid potential conflicts, you can remove this redundant import line.
|
||
|
||
test("withStructuredOutput zod schema function calling", async () => { | ||
const model = new ChatOpenAI({ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -155,6 +155,13 @@ export interface OpenAIChatInput extends OpenAIBaseInput { | |
* Currently in experimental beta. | ||
*/ | ||
__includeRawResponse?: boolean; | ||
|
||
/** | ||
* Whether the model supports the 'strict' argument when passing in tools. | ||
* Defaults to `true` if `modelName`/`model` starts with 'gpt-' otherwise | ||
* defaults to `false`. | ||
*/ | ||
supportsStrictToolCalling?: boolean; | ||
Comment on lines
+159
to
+164
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new 'supportsStrictToolCalling' field has been added to the OpenAIChatInput interface, which is a good addition. However, it would be helpful to add a comment explaining how this field is used in the actual implementation. For example, does setting this to true enforce strict adherence to the tool's schema? How does it affect the model's behavior when calling tools? Adding this information would make the code more self-documenting and easier for other developers to understand and use correctly.
|
||
} | ||
|
||
export declare interface AzureOpenAIInput { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix the Return Statement in
convertToOpenAITool
The return statement currently returns the original
tool
parameter. It should return thetoolDef
variable, which contains the updated tool definition with thestrict
field.Fix the return statement as follows: