From b1fb358fa8aeeaefb6f29574d09830ecac261b60 Mon Sep 17 00:00:00 2001
From: jacoblee93
Date: Tue, 16 Jul 2024 10:38:04 -0700
Subject: [PATCH 1/3] Adds experimental raw response field to OpenAI chat models

---
 libs/langchain-openai/src/chat_models.ts      | 36 ++++++++++++++-----
 .../src/tests/azure/chat_models.int.test.ts   | 18 ++++++++++
 .../azure/chat_models.standard.int.test.ts    | 18 ++++++++++
 .../src/tests/azure/embeddings.int.test.ts    | 18 ++++++++++
 .../src/tests/azure/llms.int.test.ts          | 18 ++++++++++
 .../tests/chat_models-extended.int.test.ts    | 26 ++++++++++++++
 libs/langchain-openai/src/types.ts            |  8 ++++-
 7 files changed, 132 insertions(+), 10 deletions(-)

diff --git a/libs/langchain-openai/src/chat_models.ts b/libs/langchain-openai/src/chat_models.ts
index c3f1f03ffc1b..033ff794621d 100644
--- a/libs/langchain-openai/src/chat_models.ts
+++ b/libs/langchain-openai/src/chat_models.ts
@@ -129,7 +129,8 @@ export function messageToOpenAIRole(message: BaseMessage): OpenAIRoleEnum {
 
 function openAIResponseToChatMessage(
   message: OpenAIClient.Chat.Completions.ChatCompletionMessage,
-  messageId: string
+  messageId: string,
+  rawResponse?: OpenAIClient.Chat.Completions.ChatCompletion
 ): BaseMessage {
   const rawToolCalls: OpenAIToolCall[] | undefined = message.tool_calls as
     | OpenAIToolCall[]
@@ -146,14 +147,18 @@ function openAIResponseToChatMessage(
           invalidToolCalls.push(makeInvalidToolCall(rawToolCall, e.message));
         }
       }
+      let additional_kwargs: Record<string, unknown> = {
+        function_call: message.function_call,
+        tool_calls: rawToolCalls,
+      };
+      if (rawResponse !== undefined) {
+        additional_kwargs.__raw_response = rawResponse;
+      }
       return new AIMessage({
         content: message.content || "",
         tool_calls: toolCalls,
         invalid_tool_calls: invalidToolCalls,
-        additional_kwargs: {
-          function_call: message.function_call,
-          tool_calls: rawToolCalls,
-        },
+        additional_kwargs,
         id: messageId,
       });
     }
@@ -166,11 +171,12 @@ function _convertDeltaToMessageChunk(
   // eslint-disable-next-line @typescript-eslint/no-explicit-any
   delta: Record<string, any>,
   messageId: string,
-  defaultRole?: OpenAIRoleEnum
+  defaultRole?: OpenAIRoleEnum,
+  rawResponse?: OpenAIClient.Chat.Completions.ChatCompletionChunk
 ) {
   const role = delta.role ?? defaultRole;
   const content = delta.content ?? "";
-  let additional_kwargs;
+  let additional_kwargs: Record<string, unknown>;
   if (delta.function_call) {
     additional_kwargs = {
       function_call: delta.function_call,
@@ -182,6 +188,9 @@ function _convertDeltaToMessageChunk(
   } else {
     additional_kwargs = {};
   }
+  if (rawResponse !== undefined) {
+    additional_kwargs.__raw_response = rawResponse;
+  }
   if (role === "user") {
     return new HumanMessageChunk({ content });
   } else if (role === "assistant") {
@@ -415,6 +424,8 @@ export class ChatOpenAI<
 
   organization?: string;
 
+  __includeRawResponse?: boolean;
+
   protected client: OpenAIClient;
 
   protected clientConfig: ClientOptions;
@@ -485,6 +496,7 @@
     this.stop = fields?.stopSequences ?? fields?.stop;
     this.stopSequences = this?.stop;
     this.user = fields?.user;
+    this.__includeRawResponse = fields?.__includeRawResponse;
 
     if (this.azureOpenAIApiKey || this.azureADTokenProvider) {
       if (!this.azureOpenAIApiInstanceName && !this.azureOpenAIBasePath) {
@@ -648,7 +660,12 @@
         if (!delta) {
           continue;
         }
-        const chunk = _convertDeltaToMessageChunk(delta, data.id, defaultRole);
+        const chunk = _convertDeltaToMessageChunk(
+          delta,
+          data.id,
+          defaultRole,
+          this.__includeRawResponse ? data : undefined
+        );
         defaultRole = delta.role ?? defaultRole;
         const newTokenIndices = {
           prompt: options.promptIndex ?? 0,
@@ -797,7 +814,8 @@
           text,
           message: openAIResponseToChatMessage(
             part.message ?? { role: "assistant" },
-            data.id
+            data.id,
+            this.__includeRawResponse ? data : undefined
           ),
         };
         generation.generationInfo = {
diff --git a/libs/langchain-openai/src/tests/azure/chat_models.int.test.ts b/libs/langchain-openai/src/tests/azure/chat_models.int.test.ts
index 65fcedce93bf..31de416cbfd4 100644
--- a/libs/langchain-openai/src/tests/azure/chat_models.int.test.ts
+++ b/libs/langchain-openai/src/tests/azure/chat_models.int.test.ts
@@ -28,6 +28,24 @@ import { AzureChatOpenAI } from "../../azure/chat_models.js";
 // Save the original value of the 'LANGCHAIN_CALLBACKS_BACKGROUND' environment variable
 const originalBackground = process.env.LANGCHAIN_CALLBACKS_BACKGROUND;
 
+beforeAll(() => {
+  if (!process.env.AZURE_OPENAI_API_KEY) {
+    process.env.AZURE_OPENAI_API_KEY = process.env.TEST_AZURE_OPENAI_API_KEY;
+  }
+  if (!process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME) {
+    process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME =
+      process.env.TEST_AZURE_OPENAI_API_DEPLOYMENT_NAME;
+  }
+  if (!process.env.AZURE_OPENAI_BASE_PATH) {
+    process.env.AZURE_OPENAI_BASE_PATH =
+      process.env.TEST_AZURE_OPENAI_BASE_PATH;
+  }
+  if (!process.env.AZURE_OPENAI_API_VERSION) {
+    process.env.AZURE_OPENAI_API_VERSION =
+      process.env.TEST_AZURE_OPENAI_API_VERSION;
+  }
+});
+
 test("Test Azure ChatOpenAI call method", async () => {
   const chat = new AzureChatOpenAI({
     modelName: "gpt-3.5-turbo",
diff --git a/libs/langchain-openai/src/tests/azure/chat_models.standard.int.test.ts b/libs/langchain-openai/src/tests/azure/chat_models.standard.int.test.ts
index cfec5cfe8eb8..5f003828e331 100644
--- a/libs/langchain-openai/src/tests/azure/chat_models.standard.int.test.ts
+++ b/libs/langchain-openai/src/tests/azure/chat_models.standard.int.test.ts
@@ -5,6 +5,24 @@ import { AIMessageChunk } from "@langchain/core/messages";
 import { AzureChatOpenAI } from "../../azure/chat_models.js";
 import { ChatOpenAICallOptions } from "../../chat_models.js";
 
+beforeAll(() => {
+  if (!process.env.AZURE_OPENAI_API_KEY) {
+    process.env.AZURE_OPENAI_API_KEY = process.env.TEST_AZURE_OPENAI_API_KEY;
+  }
+  if (!process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME) {
+    process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME =
+      process.env.TEST_AZURE_OPENAI_API_DEPLOYMENT_NAME;
+  }
+  if (!process.env.AZURE_OPENAI_BASE_PATH) {
+    process.env.AZURE_OPENAI_BASE_PATH =
+      process.env.TEST_AZURE_OPENAI_BASE_PATH;
+  }
+  if (!process.env.AZURE_OPENAI_API_VERSION) {
+    process.env.AZURE_OPENAI_API_VERSION =
+      process.env.TEST_AZURE_OPENAI_API_VERSION;
+  }
+});
+
 class AzureChatOpenAIStandardIntegrationTests extends ChatModelIntegrationTests<
   ChatOpenAICallOptions,
   AIMessageChunk
diff --git a/libs/langchain-openai/src/tests/azure/embeddings.int.test.ts b/libs/langchain-openai/src/tests/azure/embeddings.int.test.ts
index 7362fe1a73e8..8046e3458641 100644
--- a/libs/langchain-openai/src/tests/azure/embeddings.int.test.ts
+++ b/libs/langchain-openai/src/tests/azure/embeddings.int.test.ts
@@ -1,6 +1,24 @@
 import { test, expect } from "@jest/globals";
 import { AzureOpenAIEmbeddings as OpenAIEmbeddings } from "../../azure/embeddings.js";
 
+beforeAll(() => {
+  if (!process.env.AZURE_OPENAI_API_KEY) {
+    process.env.AZURE_OPENAI_API_KEY = process.env.TEST_AZURE_OPENAI_API_KEY;
+  }
+  if (!process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME) {
+    process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME =
+      process.env.TEST_AZURE_OPENAI_API_DEPLOYMENT_NAME;
+  }
+  if (!process.env.AZURE_OPENAI_BASE_PATH) {
+    process.env.AZURE_OPENAI_BASE_PATH =
+      process.env.TEST_AZURE_OPENAI_BASE_PATH;
+  }
+  if (!process.env.AZURE_OPENAI_API_VERSION) {
+    process.env.AZURE_OPENAI_API_VERSION =
+      process.env.TEST_AZURE_OPENAI_API_VERSION;
+  }
+});
+
 test("Test AzureOpenAIEmbeddings.embedQuery", async () => {
   const embeddings = new OpenAIEmbeddings();
   const res = await embeddings.embedQuery("Hello world");
diff --git a/libs/langchain-openai/src/tests/azure/llms.int.test.ts b/libs/langchain-openai/src/tests/azure/llms.int.test.ts
index fa91c27e5dc4..42f84c163429 100644
--- a/libs/langchain-openai/src/tests/azure/llms.int.test.ts
+++ b/libs/langchain-openai/src/tests/azure/llms.int.test.ts
@@ -15,6 +15,24 @@ import { AzureOpenAI } from "../../azure/llms.js";
 // Save the original value of the 'LANGCHAIN_CALLBACKS_BACKGROUND' environment variable
 const originalBackground = process.env.LANGCHAIN_CALLBACKS_BACKGROUND;
 
+beforeAll(() => {
+  if (!process.env.AZURE_OPENAI_API_KEY) {
+    process.env.AZURE_OPENAI_API_KEY = process.env.TEST_AZURE_OPENAI_API_KEY;
+  }
+  if (!process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME) {
+    process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME =
+      process.env.TEST_AZURE_OPENAI_API_DEPLOYMENT_NAME;
+  }
+  if (!process.env.AZURE_OPENAI_BASE_PATH) {
+    process.env.AZURE_OPENAI_BASE_PATH =
+      process.env.TEST_AZURE_OPENAI_BASE_PATH;
+  }
+  if (!process.env.AZURE_OPENAI_API_VERSION) {
+    process.env.AZURE_OPENAI_API_VERSION =
+      process.env.TEST_AZURE_OPENAI_API_VERSION;
+  }
+});
+
 test("Test Azure OpenAI invoke", async () => {
   const model = new AzureOpenAI({
     maxTokens: 5,
diff --git a/libs/langchain-openai/src/tests/chat_models-extended.int.test.ts b/libs/langchain-openai/src/tests/chat_models-extended.int.test.ts
index 3eca7ca0ff4f..cec415151922 100644
--- a/libs/langchain-openai/src/tests/chat_models-extended.int.test.ts
+++ b/libs/langchain-openai/src/tests/chat_models-extended.int.test.ts
@@ -291,3 +291,29 @@
   console.log(res);
   expect(res.content).toContain("24");
 });
+
+test("Test ChatOpenAI with raw response", async () => {
+  const chat = new ChatOpenAI({
+    modelName: "gpt-3.5-turbo-1106",
+    maxTokens: 128,
+    __includeRawResponse: true,
+  });
+  const message = new HumanMessage("Hello!");
+  const res = await chat.invoke([message]);
+  expect(res.additional_kwargs.__raw_response).toBeDefined();
+});
+
+test("Test ChatOpenAI with raw response (streaming)", async () => {
+  const chat = new ChatOpenAI({
+    modelName: "gpt-3.5-turbo-1106",
+    maxTokens: 128,
+    __includeRawResponse: true,
+  });
+  const message = new HumanMessage("Hello!");
+  const stream = await chat.stream([message]);
+  for await (const chunk of stream) {
+    expect(
+      chunk.additional_kwargs.__raw_response || chunk.usage_metadata
+    ).toBeDefined();
+  }
+});
diff --git a/libs/langchain-openai/src/types.ts b/libs/langchain-openai/src/types.ts
index 7bcf4e8cf16e..19e6af483d7d 100644
--- a/libs/langchain-openai/src/types.ts
+++ b/libs/langchain-openai/src/types.ts
@@ -148,7 +148,13 @@
   topLogprobs?: number;
 
   /** ChatGPT messages to pass as a prefix to the prompt */
-  prefixMessages?: OpenAIClient.Chat.CreateChatCompletionRequestMessage[];
+  prefixMessages?: OpenAIClient.Chat.ChatCompletionMessageParam[];
+
+  /**
+   * Whether to include the raw OpenAI response in the output message's "additional_kwargs" field.
+   * Currently in experimental beta.
+   */
+  __includeRawResponse?: boolean;
 }
 
 export declare interface AzureOpenAIInput {

From e9b8f19c24b1666d1a21902be6ef4308ff97d660 Mon Sep 17 00:00:00 2001
From: jacoblee93
Date: Tue, 16 Jul 2024 10:47:13 -0700
Subject: [PATCH 2/3] Fix lint

---
 libs/langchain-openai/src/chat_models.ts                     | 2 +-
 libs/langchain-openai/src/tests/azure/embeddings.int.test.ts | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/libs/langchain-openai/src/chat_models.ts b/libs/langchain-openai/src/chat_models.ts
index 033ff794621d..15fdb3b1c21f 100644
--- a/libs/langchain-openai/src/chat_models.ts
+++ b/libs/langchain-openai/src/chat_models.ts
@@ -147,7 +147,7 @@ function openAIResponseToChatMessage(
           invalidToolCalls.push(makeInvalidToolCall(rawToolCall, e.message));
         }
       }
-      let additional_kwargs: Record<string, unknown> = {
+      const additional_kwargs: Record<string, unknown> = {
         function_call: message.function_call,
         tool_calls: rawToolCalls,
       };
diff --git a/libs/langchain-openai/src/tests/azure/embeddings.int.test.ts b/libs/langchain-openai/src/tests/azure/embeddings.int.test.ts
index 8046e3458641..634cca967d74 100644
--- a/libs/langchain-openai/src/tests/azure/embeddings.int.test.ts
+++ b/libs/langchain-openai/src/tests/azure/embeddings.int.test.ts
@@ -1,3 +1,4 @@
+/* eslint-disable no-process-env */
 import { test, expect } from "@jest/globals";
 import { AzureOpenAIEmbeddings as OpenAIEmbeddings } from "../../azure/embeddings.js";
 

From 766e10531e1b5baa41adc260d80ebe1196897c4c Mon Sep 17 00:00:00 2001
From: jacoblee93
Date: Tue, 16 Jul 2024 19:20:22 -0700
Subject: [PATCH 3/3] Refactor

---
 libs/langchain-openai/src/chat_models.ts | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/libs/langchain-openai/src/chat_models.ts b/libs/langchain-openai/src/chat_models.ts
index 15fdb3b1c21f..31712178e896 100644
--- a/libs/langchain-openai/src/chat_models.ts
+++ b/libs/langchain-openai/src/chat_models.ts
@@ -129,8 +129,8 @@ export function messageToOpenAIRole(message: BaseMessage): OpenAIRoleEnum {
 
 function openAIResponseToChatMessage(
   message: OpenAIClient.Chat.Completions.ChatCompletionMessage,
-  messageId: string,
-  rawResponse?: OpenAIClient.Chat.Completions.ChatCompletion
+  rawResponse: OpenAIClient.Chat.Completions.ChatCompletion,
+  includeRawResponse?: boolean
 ): BaseMessage {
   const rawToolCalls: OpenAIToolCall[] | undefined = message.tool_calls as
     | OpenAIToolCall[]
@@ -151,7 +151,7 @@ function openAIResponseToChatMessage(
         function_call: message.function_call,
         tool_calls: rawToolCalls,
       };
-      if (rawResponse !== undefined) {
+      if (includeRawResponse) {
         additional_kwargs.__raw_response = rawResponse;
       }
       return new AIMessage({
@@ -159,7 +159,7 @@ function openAIResponseToChatMessage(
         tool_calls: toolCalls,
         invalid_tool_calls: invalidToolCalls,
         additional_kwargs,
-        id: messageId,
+        id: rawResponse.id,
       });
     }
     default:
@@ -170,9 +170,9 @@
 function _convertDeltaToMessageChunk(
   // eslint-disable-next-line @typescript-eslint/no-explicit-any
   delta: Record<string, any>,
-  messageId: string,
+  rawResponse: OpenAIClient.Chat.Completions.ChatCompletionChunk,
   defaultRole?: OpenAIRoleEnum,
-  rawResponse?: OpenAIClient.Chat.Completions.ChatCompletionChunk
+  includeRawResponse?: boolean
 ) {
   const role = delta.role ?? defaultRole;
   const content = delta.content ?? "";
@@ -188,7 +188,7 @@ function _convertDeltaToMessageChunk(
   } else {
     additional_kwargs = {};
   }
-  if (rawResponse !== undefined) {
+  if (includeRawResponse) {
     additional_kwargs.__raw_response = rawResponse;
   }
   if (role === "user") {
     return new HumanMessageChunk({ content });
@@ -210,7 +210,7 @@ function _convertDeltaToMessageChunk(
       content,
       tool_call_chunks: toolCallChunks,
       additional_kwargs,
-      id: messageId,
+      id: rawResponse.id,
     });
   } else if (role === "system") {
     return new SystemMessageChunk({ content });
@@ -662,9 +662,9 @@
         }
         const chunk = _convertDeltaToMessageChunk(
           delta,
-          data.id,
+          data,
           defaultRole,
-          this.__includeRawResponse ? data : undefined
+          this.__includeRawResponse
         );
         defaultRole = delta.role ?? defaultRole;
         const newTokenIndices = {
@@ -814,8 +814,8 @@
           text,
           message: openAIResponseToChatMessage(
             part.message ?? { role: "assistant" },
-            data.id,
-            this.__includeRawResponse ? data : undefined
+            data,
+            this.__includeRawResponse
           ),
         };
         generation.generationInfo = {