From 4e39b9e67278e689db737a2a0b4375d020adf459 Mon Sep 17 00:00:00 2001
From: Steven Ickman
Date: Wed, 20 Nov 2024 12:02:22 -0800
Subject: [PATCH 1/2] Added tool support to streaming

---
 .../teams-ai/src/models/OpenAIModel.ts | 30 +++++----
 .../teams-ai/src/planners/LLMClient.ts | 65 ++++++++++++-------
 2 files changed, 61 insertions(+), 34 deletions(-)

diff --git a/js/packages/teams-ai/src/models/OpenAIModel.ts b/js/packages/teams-ai/src/models/OpenAIModel.ts
index 11734fe5f..8d3c728b4 100644
--- a/js/packages/teams-ai/src/models/OpenAIModel.ts
+++ b/js/packages/teams-ai/src/models/OpenAIModel.ts
@@ -373,20 +373,26 @@ export class OpenAIModel implements PromptCompletionModel {
                         if (delta.content) {
                             message.content += delta.content;
                         }
 
+                        // Handle tool calls
+                        // - We don't know how many tool calls there will be so we need to add them one-by-one.
                         if (delta.tool_calls) {
-                            message.action_calls = delta.tool_calls.map(
-                                (toolCall: OpenAI.Chat.Completions.ChatCompletionChunk.Choice.Delta.ToolCall) => {
-                                    return {
-                                        id: toolCall.id,
-                                        function: {
-                                            name: toolCall.function!.name,
-                                            arguments: toolCall.function!.arguments
-                                        },
-                                        type: toolCall.type
-                                    } as ActionCall;
-                                }
-                            );
+                            // Create action calls array if it doesn't exist
+                            if (!Array.isArray(message.action_calls)) {
+                                message.action_calls = [];
+                            }
+
+                            // Add tool calls to action calls
+                            for (const toolCall of delta.tool_calls) {
+                                message.action_calls.push({
+                                    id: toolCall.id,
+                                    function: {
+                                        name: toolCall.function!.name,
+                                        arguments: toolCall.function!.arguments
+                                    },
+                                    type: toolCall.type
+                                } as ActionCall);
+                            }
                         }
 
                         // Signal chunk received
diff --git a/js/packages/teams-ai/src/planners/LLMClient.ts b/js/packages/teams-ai/src/planners/LLMClient.ts
index 50652c0da..726fcb69c 100644
--- a/js/packages/teams-ai/src/planners/LLMClient.ts
+++ b/js/packages/teams-ai/src/planners/LLMClient.ts
@@ -284,7 +284,6 @@ export class LLMClient {
         functions: PromptFunctions
     ): Promise<PromptResponse<TContent>> {
         // Define event handlers
-        let isStreaming = false;
         let streamer: StreamingResponse | undefined;
         const beforeCompletion: PromptCompletionModelBeforeCompletionEvent = (
             ctx,
@@ -301,20 +300,23 @@ export class LLMClient {
 
             // Check for a streaming response
             if (streaming) {
-                isStreaming = true;
-
-                // Create streamer and send initial message
-                streamer = new StreamingResponse(context);
-                memory.setValue('temp.streamer', streamer);
-
-                if (this._enableFeedbackLoop != null) {
-                    streamer.setFeedbackLoop(this._enableFeedbackLoop);
-                }
-
-                streamer.setGeneratedByAILabel(true);
-
-                if (this._startStreamingMessage) {
-                    streamer.queueInformativeUpdate(this._startStreamingMessage);
+                // Attach to any existing streamer
+                // - see tool call note below to understand.
+                streamer = memory.getValue('temp.streamer');
+                if (!streamer) {
+                    // Create streamer and send initial message
+                    streamer = new StreamingResponse(context);
+                    memory.setValue('temp.streamer', streamer);
+
+                    if (this._enableFeedbackLoop != null) {
+                        streamer.setFeedbackLoop(this._enableFeedbackLoop);
+                    }
+
+                    streamer.setGeneratedByAILabel(true);
+
+                    if (this._startStreamingMessage) {
+                        streamer.queueInformativeUpdate(this._startStreamingMessage);
+                    }
                 }
             }
         };
@@ -325,6 +327,12 @@ export class LLMClient {
                 return;
             }
 
+            // Ignore tool calls
+            // - see the tool call note below to understand why we're ignoring them.
+            if ((chunk.delta as any)?.tool_calls || chunk.delta?.action_calls) {
+                return;
+            }
+
             // Send chunk to client
             const text = chunk.delta?.content ?? '';
             const citations = chunk.delta?.context?.citations ?? undefined;
@@ -347,15 +355,28 @@ export class LLMClient {
         try {
             // Complete the prompt
             const response = await this.callCompletePrompt(context, memory, functions);
-            if (response.status == 'success' && isStreaming) {
-                // Delete message from response to avoid sending it twice
-                delete response.message;
-            }
 
-            // End the stream if streaming
-            // - We're not listening for the response received event because we can't await the completion of events.
+            // Handle streaming responses
             if (streamer) {
-                await streamer.endStream();
+                // Tool call handling
+                // - We need to keep the streamer around during tool calls so we're just letting them return as normal
+                //   messages minus the message content. The text content is being streamed to the client in chunks.
+                // - When the tool call completes we'll call back into ActionPlanner and end up re-attaching to the
+                //   streamer. This will result in us continuing to stream the response to the client.
+                if (Array.isArray(response.message?.action_calls)) {
+                    // Ensure content is empty for tool calls
+                    response.message!.content = '' as TContent;
+                } else {
+                    if (response.status == 'success') {
+                        // Delete message from response to avoid sending it twice
+                        delete response.message;
+                    }
+
+                    // End the stream and remove pointer from memory
+                    // - We're not listening for the response received event because we can't await the completion of events.
+                    await streamer.endStream();
+                    memory.deleteValue('temp.streamer');
+                }
             }
 
             return response;

From e71b392675939d84c3d4ea0bb164aded339f6eab Mon Sep 17 00:00:00 2001
From: Steven Ickman
Date: Wed, 20 Nov 2024 17:11:59 -0800
Subject: [PATCH 2/2] Fixed bug with tool calls being split across multiple chunks.

---
 .../teams-ai/src/models/OpenAIModel.ts | 34 ++++++++++++++-----
 1 file changed, 26 insertions(+), 8 deletions(-)

diff --git a/js/packages/teams-ai/src/models/OpenAIModel.ts b/js/packages/teams-ai/src/models/OpenAIModel.ts
index 8d3c728b4..402062a76 100644
--- a/js/packages/teams-ai/src/models/OpenAIModel.ts
+++ b/js/packages/teams-ai/src/models/OpenAIModel.ts
@@ -384,14 +384,32 @@ export class OpenAIModel implements PromptCompletionModel {
 
                             // Add tool calls to action calls
                             for (const toolCall of delta.tool_calls) {
-                                message.action_calls.push({
-                                    id: toolCall.id,
-                                    function: {
-                                        name: toolCall.function!.name,
-                                        arguments: toolCall.function!.arguments
-                                    },
-                                    type: toolCall.type
-                                } as ActionCall);
+                                // Add empty tool call to message if new index
+                                // - Note that a single tool call can span multiple chunks.
+                                const index = toolCall.index;
+                                if (index >= message.action_calls.length) {
+                                    message.action_calls.push({ id: '', function: { name: '', arguments: '' }, type: '' } as any);
+                                }
+
+                                // Set ID if provided
+                                if (toolCall.id) {
+                                    message.action_calls[index].id = toolCall.id;
+                                }
+
+                                // Set type if provided
+                                if (toolCall.type) {
+                                    message.action_calls[index].type = toolCall.type;
+                                }
+
+                                // Append function name if provided
+                                if (toolCall.function?.name) {
+                                    message.action_calls[index].function.name += toolCall.function.name;
+                                }
+
+                                // Append function arguments if provided
+                                if (toolCall.function?.arguments) {
+                                    message.action_calls[index].function.arguments += toolCall.function.arguments;
+                                }
                             }
                         }
 
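
A note on how [PATCH 1/2] hangs together: the streamer is parked in temp.streamer so that a completion which returns tool calls can leave the stream open. When the tool results come back and the planner runs the follow-up completion, beforeCompletion finds the existing streamer in memory and re-attaches to it instead of opening a second stream, and only the final non-tool-call response ends the stream. Below is a minimal sketch of that attach-or-create pattern in isolation; MemoryLike and StreamerLike are simplified stand-ins for illustration, not the actual teams-ai Memory and StreamingResponse interfaces.

    interface StreamerLike {
        queueInformativeUpdate(text: string): void;
        endStream(): Promise<void>;
    }

    interface MemoryLike {
        getValue(key: string): unknown;
        setValue(key: string, value: unknown): void;
        deleteValue(key: string): void;
    }

    function attachOrCreateStreamer(memory: MemoryLike, create: () => StreamerLike): StreamerLike {
        // A previous completion that ended in tool calls leaves its streamer behind
        let streamer = memory.getValue('temp.streamer') as StreamerLike | undefined;
        if (!streamer) {
            // First completion of the turn: create the streamer and park it so the
            // post-tool-call completion can re-attach and keep streaming
            streamer = create();
            memory.setValue('temp.streamer', streamer);
        }
        return streamer;
    }

Since the pointer lives in turn-scoped temp memory, it should not outlive the turn even if an error path skips endStream() and deleteValue().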
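For [PATCH 2/2], the underlying issue is that the OpenAI streaming API delivers a single tool call as a series of fragments: the first chunk usually carries the id, type, and function name, and later chunks carry slices of the JSON arguments string, with the index field tying each fragment to one call. The sketch below isolates that index-based merge. ToolCallDelta and ActionCallShape are illustrative stand-ins rather than the SDK types, and the while loop is a slight generalization of the patch, which pushes one empty entry per new index and so assumes indexes arrive in ascending order.

    interface ToolCallDelta {
        index: number;
        id?: string;
        type?: string;
        function?: { name?: string; arguments?: string };
    }

    interface ActionCallShape {
        id: string;
        type: string;
        function: { name: string; arguments: string };
    }

    function mergeToolCallDelta(calls: ActionCallShape[], delta: ToolCallDelta): void {
        // Grow the array until the fragment's slot exists
        while (delta.index >= calls.length) {
            calls.push({ id: '', type: '', function: { name: '', arguments: '' } });
        }
        const call = calls[delta.index];

        // id and type arrive whole, usually on the first fragment: overwrite
        if (delta.id) call.id = delta.id;
        if (delta.type) call.type = delta.type;

        // name and arguments arrive in slices: concatenate in order
        if (delta.function?.name) call.function.name += delta.function.name;
        if (delta.function?.arguments) call.function.arguments += delta.function.arguments;
    }

    // Example: one call split across two chunks
    const calls: ActionCallShape[] = [];
    mergeToolCallDelta(calls, { index: 0, id: 'call_1', type: 'function', function: { name: 'get_weather', arguments: '{"city":' } });
    mergeToolCallDelta(calls, { index: 0, function: { arguments: '"Paris"}' } });
    // calls[0].function.arguments === '{"city":"Paris"}'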