[JS] feat: streaming support for Tools Augmentation #2195

Open · wants to merge 2 commits into base: main
46 changes: 35 additions & 11 deletions js/packages/teams-ai/src/models/OpenAIModel.ts
@@ -373,20 +373,44 @@ export class OpenAIModel implements PromptCompletionModel {
if (delta.content) {
message.content += delta.content;
}

Collaborator (review comment): Do these changes need tests?

+// Handle tool calls
+// - We don't know how many tool calls there will be so we need to add them one-by-one.
if (delta.tool_calls) {
-    message.action_calls = delta.tool_calls.map(
-        (toolCall: OpenAI.Chat.Completions.ChatCompletionChunk.Choice.Delta.ToolCall) => {
-            return {
-                id: toolCall.id,
-                function: {
-                    name: toolCall.function!.name,
-                    arguments: toolCall.function!.arguments
-                },
-                type: toolCall.type
-            } as ActionCall;
-        }
-    );
+    // Create action calls array if it doesn't exist
+    if (!Array.isArray(message.action_calls)) {
+        message.action_calls = [];
+    }
+
+    // Add tool calls to action calls
+    for (const toolCall of delta.tool_calls) {
+        // Add empty tool call to message if new index
+        // - Note that a single tool call can span multiple chunks.
+        const index = toolCall.index;
+        if (index >= message.action_calls.length) {
+            message.action_calls.push({ id: '', function: { name: '', arguments: '' }, type: '' } as any);
+        }
+
+        // Set ID if provided
+        if (toolCall.id) {
+            message.action_calls[index].id = toolCall.id;
+        }
+
+        // Set type if provided
+        if (toolCall.type) {
+            message.action_calls[index].type = toolCall.type;
+        }
+
+        // Append function name if provided
+        if (toolCall.function?.name) {
+            message.action_calls[index].function.name += toolCall.function.name;
+        }
+
+        // Append function arguments if provided
+        if (toolCall.function?.arguments) {
+            message.action_calls[index].function.arguments += toolCall.function.arguments;
+        }
+    }
}

// Signal chunk received
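For context on why the per-index accumulation above is needed: when a streamed completion includes tool calls, the id, name, and arguments of each call arrive in fragments spread across several chunks, keyed by an index. The sketch below is illustrative only; the ToolCallDelta and ActionCall shapes, the accumulate helper, and the sample chunks are simplified assumptions for demonstration, not code from this PR.

// Illustrative sketch: folding streamed tool-call deltas into complete calls.
// All types and sample payloads here are simplified assumptions.
interface ToolCallDelta {
    index: number;
    id?: string;
    type?: string;
    function?: { name?: string; arguments?: string };
}

interface ActionCall {
    id: string;
    type: string;
    function: { name: string; arguments: string };
}

function accumulate(chunks: ToolCallDelta[][]): ActionCall[] {
    const calls: ActionCall[] = [];
    for (const chunk of chunks) {
        for (const delta of chunk) {
            // A single tool call can span multiple chunks, so grow the array lazily by index.
            if (delta.index >= calls.length) {
                calls.push({ id: '', type: '', function: { name: '', arguments: '' } });
            }
            const call = calls[delta.index];
            if (delta.id) call.id = delta.id;
            if (delta.type) call.type = delta.type;
            if (delta.function?.name) call.function.name += delta.function.name;
            if (delta.function?.arguments) call.function.arguments += delta.function.arguments;
        }
    }
    return calls;
}

// Example: the arguments for call 0 arrive split across three chunks.
const result = accumulate([
    [{ index: 0, id: 'call_1', type: 'function', function: { name: 'getWeather' } }],
    [{ index: 0, function: { arguments: '{"city":' } }],
    [{ index: 0, function: { arguments: '"Seattle"}' } }]
]);
console.log(result[0].function.arguments); // {"city":"Seattle"}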
65 changes: 43 additions & 22 deletions js/packages/teams-ai/src/planners/LLMClient.ts
@@ -284,7 +284,6 @@ export class LLMClient<TContent = any> {
functions: PromptFunctions
): Promise<PromptResponse<TContent>> {
// Define event handlers
-let isStreaming = false;
let streamer: StreamingResponse | undefined;
const beforeCompletion: PromptCompletionModelBeforeCompletionEvent = (
ctx,
Expand All @@ -301,20 +300,23 @@ export class LLMClient<TContent = any> {

// Check for a streaming response
if (streaming) {
-    isStreaming = true;
-
-    // Create streamer and send initial message
-    streamer = new StreamingResponse(context);
-    memory.setValue('temp.streamer', streamer);
-
-    if (this._enableFeedbackLoop != null) {
-        streamer.setFeedbackLoop(this._enableFeedbackLoop);
-    }
-
-    streamer.setGeneratedByAILabel(true);
-
-    if (this._startStreamingMessage) {
-        streamer.queueInformativeUpdate(this._startStreamingMessage);
-    }
+    // Attach to any existing streamer
+    // - see tool call note below to understand.
+    streamer = memory.getValue('temp.streamer');
+    if (!streamer) {
+        // Create streamer and send initial message
+        streamer = new StreamingResponse(context);
+        memory.setValue('temp.streamer', streamer);
+
+        if (this._enableFeedbackLoop != null) {
+            streamer.setFeedbackLoop(this._enableFeedbackLoop);
+        }
+
+        streamer.setGeneratedByAILabel(true);
+
+        if (this._startStreamingMessage) {
+            streamer.queueInformativeUpdate(this._startStreamingMessage);
+        }
+    }
}
};
@@ -325,6 +327,12 @@
return;
}

+// Ignore tool calls
+// - see the tool call note below to understand why we're ignoring them.
+if ((chunk.delta as any)?.tool_calls || chunk.delta?.action_calls) {
+    return;
+}

// Send chunk to client
const text = chunk.delta?.content ?? '';
const citations = chunk.delta?.context?.citations ?? undefined;
@@ -347,15 +355,28 @@
try {
// Complete the prompt
const response = await this.callCompletePrompt(context, memory, functions);
-    if (response.status == 'success' && isStreaming) {
-        // Delete message from response to avoid sending it twice
-        delete response.message;
-    }
-
-    // End the stream if streaming
-    // - We're not listening for the response received event because we can't await the completion of events.
+    // Handle streaming responses
    if (streamer) {
-        await streamer.endStream();
+        // Tool call handling
+        // - We need to keep the streamer around during tool calls so we're just letting them return as normal
+        //   messages minus the message content. The text content is being streamed to the client in chunks.
+        // - When the tool call completes we'll call back into ActionPlanner and end up re-attaching to the
+        //   streamer. This will result in us continuing to stream the response to the client.
+        if (Array.isArray(response.message?.action_calls)) {
+            // Ensure content is empty for tool calls
+            response.message!.content = '' as TContent;
+        } else {
+            if (response.status == 'success') {
+                // Delete message from response to avoid sending it twice
+                delete response.message;
+            }
+
+            // End the stream and remove pointer from memory
+            // - We're not listening for the response received event because we can't await the completion of events.
+            await streamer.endStream();
+            memory.deleteValue('temp.streamer');
+        }
}

return response;
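Taken together, the LLMClient changes keep a single StreamingResponse alive across the tool-call round trips within one turn: it is created on the first completion, reused while the model keeps returning tool calls, and only ended once a final text response arrives. The sketch below is a simplified model of that lifecycle; the FakeStreamer and FakeResponse types, the Map-based memory, and the staged responses are hypothetical stand-ins, not the actual teams-ai APIs.

// Simplified lifecycle sketch (hypothetical types; not the actual teams-ai API).
interface FakeResponse {
    status: 'success' | 'error';
    message?: { content: string; action_calls?: { name: string }[] };
}

class FakeStreamer {
    queueTextChunk(text: string): void {
        console.log('chunk:', text);
    }
    async endStream(): Promise<void> {
        console.log('stream ended');
    }
}

const memory = new Map<string, unknown>();

async function completeWithStreaming(step: number): Promise<FakeResponse> {
    // Attach to any existing streamer, or create one on the first completion of the turn.
    let streamer = memory.get('temp.streamer') as FakeStreamer | undefined;
    if (!streamer) {
        streamer = new FakeStreamer();
        memory.set('temp.streamer', streamer);
    }

    // Pretend the first completion returns tool calls and the second returns text.
    const response: FakeResponse =
        step === 0
            ? { status: 'success', message: { content: '', action_calls: [{ name: 'getWeather' }] } }
            : { status: 'success', message: { content: 'It is sunny in Seattle.' } };

    if (Array.isArray(response.message?.action_calls)) {
        // Tool calls: keep the streamer alive; the planner will run the tools and call back in.
        response.message!.content = '';
    } else {
        // Final text response: end the stream and drop the pointer from memory.
        await streamer.endStream();
        memory.delete('temp.streamer');
    }
    return response;
}

(async () => {
    await completeWithStreaming(0); // returns tool calls, streamer kept in memory
    await completeWithStreaming(1); // final answer, stream ended and pointer removed
})();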