Skip to content

Commit

Permalink
feat(community): Introduce callbacks to IBM Watsonx SDK (#7329)
Browse files Browse the repository at this point in the history
Co-authored-by: jacoblee93 <jacoblee93@gmail.com>
  • Loading branch information
FilipZmijewski and jacoblee93 authored Dec 10, 2024
1 parent 99eb5d2 commit ee0e8a2
Show file tree
Hide file tree
Showing 7 changed files with 314 additions and 179 deletions.
2 changes: 1 addition & 1 deletion libs/langchain-community/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
"@google-cloud/storage": "^7.7.0",
"@gradientai/nodejs-sdk": "^1.2.0",
"@huggingface/inference": "^2.6.4",
"@ibm-cloud/watsonx-ai": "^1.1.0",
"@ibm-cloud/watsonx-ai": "^1.3.0",
"@jest/globals": "^29.5.0",
"@lancedb/lancedb": "^0.13.0",
"@langchain/core": "workspace:*",
Expand Down
37 changes: 27 additions & 10 deletions libs/langchain-community/src/chat_models/ibm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import {
} from "@langchain/core/outputs";
import { AsyncCaller } from "@langchain/core/utils/async_caller";
import {
RequestCallbacks,
TextChatMessagesTextChatMessageAssistant,
TextChatParameterTools,
TextChatParams,
Expand Down Expand Up @@ -81,12 +82,14 @@ export interface WatsonxDeltaStream {
/**
 * Per-call parameters accepted by Watsonx chat invocations.
 * Inherits all IBM `TextChatParams` fields except `modelId` and `toolChoice`,
 * which are managed by the wrapper itself.
 */
export interface WatsonxCallParams
  extends Partial<Omit<TextChatParams, "modelId" | "toolChoice">> {
  // Maximum retry attempts for a failed SDK request — TODO confirm consumer (AsyncCaller?)
  maxRetries?: number;
  // Low-level IBM SDK request/response hooks (see `RequestCallbacks` import above),
  // forwarded to `textChat`/`textChatStream` alongside the request payload.
  watsonxCallbacks?: RequestCallbacks;
}
/**
 * Call options for `ChatWatsonx`, combining LangChain's base chat-model call
 * options (minus `stop`) with Watsonx-specific per-call parameters.
 *
 * Note: `watsonxCallbacks` is intentionally NOT re-declared here — it is
 * already inherited from `WatsonxCallParams` with the same type, so the
 * previous duplicate declaration was redundant.
 */
export interface WatsonxCallOptionsChat
  extends Omit<BaseChatModelCallOptions, "stop">,
    WatsonxCallParams {
  // Index of the prompt this call corresponds to — TODO confirm semantics against caller
  promptIndex?: number;
  // Tool-selection strategy: a concrete tool spec, a tool name, or the
  // "auto"/"any" keywords understood by the Watsonx text-chat API.
  tool_choice?: TextChatParameterTools | string | "auto" | "any";
}

type ChatWatsonxToolType = BindToolsInput | TextChatParameterTools;
Expand Down Expand Up @@ -420,6 +423,8 @@ export class ChatWatsonx<

streaming: boolean;

watsonxCallbacks?: RequestCallbacks;

constructor(fields: ChatWatsonxInput & WatsonxAuth) {
super(fields);
if (
Expand Down Expand Up @@ -450,7 +455,7 @@ export class ChatWatsonx<
this.n = fields?.n ?? this.n;
this.model = fields?.model ?? this.model;
this.version = fields?.version ?? this.version;

this.watsonxCallbacks = fields?.watsonxCallbacks ?? this.watsonxCallbacks;
const {
watsonxAIApikey,
watsonxAIAuthType,
Expand Down Expand Up @@ -502,6 +507,10 @@ export class ChatWatsonx<
return { ...params, ...toolChoiceResult };
}

invocationCallbacks(options: this["ParsedCallOptions"]) {
return options.watsonxCallbacks ?? this.watsonxCallbacks;
}

override bindTools(
tools: ChatWatsonxToolType[],
kwargs?: Partial<CallOptions>
Expand Down Expand Up @@ -590,15 +599,19 @@ export class ChatWatsonx<
...this.invocationParams(options),
...this.scopeId(),
};
const watsonxCallbacks = this.invocationCallbacks(options);
const watsonxMessages = _convertMessagesToWatsonxMessages(
messages,
this.model
);
const callback = () =>
this.service.textChat({
...params,
messages: watsonxMessages,
});
this.service.textChat(
{
...params,
messages: watsonxMessages,
},
watsonxCallbacks
);
const { result } = await this.completionWithRetry(callback, options);
const generations: ChatGeneration[] = [];
for (const part of result.choices) {
Expand Down Expand Up @@ -638,12 +651,16 @@ export class ChatWatsonx<
messages,
this.model
);
const watsonxCallbacks = this.invocationCallbacks(options);
const callback = () =>
this.service.textChatStream({
...params,
messages: watsonxMessages,
returnObject: true,
});
this.service.textChatStream(
{
...params,
messages: watsonxMessages,
returnObject: true,
},
watsonxCallbacks
);
const stream = await this.completionWithRetry(callback, options);
let defaultRole;
let usage: TextChatUsage | undefined;
Expand Down
118 changes: 114 additions & 4 deletions libs/langchain-community/src/chat_models/tests/ibm.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,9 @@ describe("Tests for chat", () => {
}
);
const llmWithTools = service.bindTools([calculatorTool]);
const res = await llmWithTools.invoke("What is 3 * 12");
const res = await llmWithTools.invoke(
"You are bad at calculations and need to use calculator at all times. What is 3 * 12"
);

expect(res).toBeInstanceOf(AIMessage);
expect(res.tool_calls?.[0].name).toBe("calculator");
Expand Down Expand Up @@ -572,7 +574,7 @@ describe("Tests for chat", () => {
);
const llmWithTools = service.bindTools([calculatorTool]);
const res = await llmWithTools.invoke(
"What is 3 * 12? Also, what is 11 + 49?"
"You are bad at calculations and need to use calculator at all times. What is 3 * 12? Also, what is 11 + 49?"
);

expect(res).toBeInstanceOf(AIMessage);
Expand Down Expand Up @@ -619,7 +621,9 @@ describe("Tests for chat", () => {
},
],
});
const res = await modelWithTools.invoke("What is 32 * 122");
const res = await modelWithTools.invoke(
"You are bad at calculations and need to use calculator at all times. What is 32 * 122"
);

expect(res).toBeInstanceOf(AIMessage);
expect(res.tool_calls?.[0].name).toBe("calculator");
Expand Down Expand Up @@ -666,7 +670,7 @@ describe("Tests for chat", () => {

const modelWithTools = service.bindTools(tools);
const res = await modelWithTools.invoke(
"What is 3 * 12? Also, what is 11 + 49?"
"You are bad at calculations and need to use calculator at all times. What is 3 * 12? Also, what is 11 + 49?"
);

expect(res).toBeInstanceOf(AIMessage);
Expand Down Expand Up @@ -831,4 +835,110 @@ describe("Tests for chat", () => {
expect(typeof result.number2).toBe("number");
});
});

  // Integration tests verifying that `watsonxCallbacks` (the IBM SDK's
  // request/response hooks) passed to the ChatWatsonx constructor actually
  // fire during streaming calls, independently of LangChain's own callbacks.
  // NOTE(review): these hit a live service unless WATSONX_AI_SERVICE_URL /
  // WATSONX_AI_PROJECT_ID are set; the "testString" fallbacks presumably make
  // the call fail fast in CI — confirm.
  describe("Test watsonx callbacks", () => {
    test("Single request callback", async () => {
      let callbackFlag = false;
      const service = new ChatWatsonx({
        model: "mistralai/mistral-large",
        version: "2024-05-31",
        serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
        projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
        maxTokens: 10,
        watsonxCallbacks: {
          // Flips the flag as soon as the SDK reports the outgoing request.
          requestCallback(req) {
            callbackFlag = !!req;
          },
        },
      });
      const hello = await service.stream("Print hello world");
      const chunks = [];
      // Drain the stream fully so the request is actually issued.
      for await (const chunk of hello) {
        chunks.push(chunk);
      }
      expect(callbackFlag).toBe(true);
    });
    test("Single response callback", async () => {
      let callbackFlag = false;
      const service = new ChatWatsonx({
        model: "mistralai/mistral-large",
        version: "2024-05-31",
        serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
        projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
        maxTokens: 10,
        watsonxCallbacks: {
          // Flips the flag when the SDK surfaces the raw response.
          responseCallback(res) {
            callbackFlag = !!res;
          },
        },
      });
      const hello = await service.stream("Print hello world");
      const chunks = [];
      // Drain the stream fully so the response callback has a chance to run.
      for await (const chunk of hello) {
        chunks.push(chunk);
      }
      expect(callbackFlag).toBe(true);
    });
    test("Both callbacks", async () => {
      // Request and response hooks registered together must both fire on one call.
      let callbackFlagReq = false;
      let callbackFlagRes = false;
      const service = new ChatWatsonx({
        model: "mistralai/mistral-large",
        version: "2024-05-31",
        serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
        projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
        maxTokens: 10,
        watsonxCallbacks: {
          requestCallback(req) {
            callbackFlagReq = !!req;
          },
          responseCallback(res) {
            callbackFlagRes = !!res;
          },
        },
      });
      const hello = await service.stream("Print hello world");
      const chunks = [];
      for await (const chunk of hello) {
        chunks.push(chunk);
      }
      expect(callbackFlagReq).toBe(true);
      expect(callbackFlagRes).toBe(true);
    });
    test("Multiple callbacks", async () => {
      // SDK-level watsonxCallbacks and LangChain-level `callbacks` must
      // coexist: all three hooks fire on the same streamed invocation.
      let callbackFlagReq = false;
      let callbackFlagRes = false;
      let langchainCallback = false;

      const service = new ChatWatsonx({
        model: "mistralai/mistral-large",
        version: "2024-05-31",
        serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
        projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
        maxTokens: 10,
        // LangChain-side lifecycle callback, unrelated to the SDK hooks below.
        callbacks: CallbackManager.fromHandlers({
          async handleLLMEnd(output) {
            expect(output.generations).toBeDefined();
            langchainCallback = !!output;
          },
        }),
        watsonxCallbacks: {
          requestCallback(req) {
            callbackFlagReq = !!req;
          },
          responseCallback(res) {
            callbackFlagRes = !!res;
          },
        },
      });
      const hello = await service.stream("Print hello world");
      const chunks = [];
      for await (const chunk of hello) {
        chunks.push(chunk);
      }
      expect(callbackFlagReq).toBe(true);
      expect(callbackFlagRes).toBe(true);
      expect(langchainCallback).toBe(true);
    });
  });
});
2 changes: 1 addition & 1 deletion libs/langchain-community/src/document_compressors/ibm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ export class WatsonxRerank
? {
index: document.index,
relevanceScore: document.score,
input: document?.input,
input: document?.input.text,
}
: {
index: document.index,
Expand Down
Loading

0 comments on commit ee0e8a2

Please sign in to comment.