@@ -51,59 +51,31 @@ protected internal override async ValueTask OnCheckpointRestoredAsync(IWorkflowC
 
     protected override async ValueTask TakeTurnAsync(List<ChatMessage> messages, IWorkflowContext context, bool? emitEvents, CancellationToken cancellationToken = default)
     {
-        emitEvents ??= this._emitEvents;
-        IAsyncEnumerable<AgentRunResponseUpdate> agentStream = this._agent.RunStreamingAsync(messages, this.EnsureThread(context), cancellationToken: cancellationToken);
-
-        List<AIContent> updates = [];
-        ChatMessage? currentStreamingMessage = null;
-
-        await foreach (AgentRunResponseUpdate update in agentStream.ConfigureAwait(false))
+        if (emitEvents ?? this._emitEvents)
         {
-            if (string.IsNullOrEmpty(update.MessageId))
-            {
-                // Ignore updates that don't have a message ID.
-                continue;
-            }
+            // Run the agent in streaming mode only when agent run update events are to be emitted.
+            IAsyncEnumerable<AgentRunResponseUpdate> agentStream = this._agent.RunStreamingAsync(messages, this.EnsureThread(context), cancellationToken: cancellationToken);
 
-            if (emitEvents ?? this._emitEvents)
+            List<AgentRunResponseUpdate> updates = [];
+
+            await foreach (AgentRunResponseUpdate update in agentStream.ConfigureAwait(false))
             {
                 await context.AddEventAsync(new AgentRunUpdateEvent(this.Id, update), cancellationToken).ConfigureAwait(false);
-            }
 
-            // TODO: FunctionCall request handling, and user info request handling.
-            // In some sense: We should just let it be handled as a ChatMessage, though we should consider
-            // providing some mechanisms to help the user complete the request, or route it out of the
-            // workflow.
-
-            if (currentStreamingMessage is null || currentStreamingMessage.MessageId != update.MessageId)
-            {
-                await PublishCurrentMessageAsync().ConfigureAwait(false);
-                currentStreamingMessage = new(update.Role ?? ChatRole.Assistant, update.Contents)
-                {
-                    AuthorName = update.AuthorName,
-                    CreatedAt = update.CreatedAt,
-                    MessageId = update.MessageId,
-                    RawRepresentation = update.RawRepresentation,
-                    AdditionalProperties = update.AdditionalProperties
-                };
+                // TODO: FunctionCall request handling, and user info request handling.
+                // In some sense: We should just let it be handled as a ChatMessage, though we should consider
+                // providing some mechanisms to help the user complete the request, or route it out of the
+                // workflow.
+                updates.Add(update);
             }
 
-            updates.AddRange(update.Contents);
+            await context.SendMessageAsync(updates.ToAgentRunResponse().Messages, cancellationToken: cancellationToken).ConfigureAwait(false);
         }
-
-        await PublishCurrentMessageAsync().ConfigureAwait(false);
-
-        async ValueTask PublishCurrentMessageAsync()
+        else
         {
-            if (currentStreamingMessage is not null && updates.Count > 0)
-            {
-                currentStreamingMessage.Contents = updates;
-                updates = [];
-
-                await context.SendMessageAsync(currentStreamingMessage, cancellationToken: cancellationToken).ConfigureAwait(false);
-            }
-
-            currentStreamingMessage = null;
+            // Otherwise, run the agent in non-streaming mode.
+            AgentRunResponse response = await this._agent.RunAsync(messages, this.EnsureThread(context), cancellationToken: cancellationToken).ConfigureAwait(false);
+            await context.SendMessageAsync(response.Messages, cancellationToken: cancellationToken).ConfigureAwait(false);
        }
     }
 }
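Note on the hunk above: the removed lines aggregated streaming updates into ChatMessage instances by hand, keyed on MessageId, while the new code forwards the raw updates and lets updates.ToAgentRunResponse() produce the final messages. Below is a rough, self-contained sketch of that MessageId-based coalescing idea; it is not the framework's actual implementation, and the StreamedChunk record is a hypothetical stand-in for AgentRunResponseUpdate.

using System.Collections.Generic;
using Microsoft.Extensions.AI;

// Hypothetical stand-in for AgentRunResponseUpdate: one streamed chunk of text
// tagged with the message it belongs to. Illustrative only.
internal sealed record StreamedChunk(string? MessageId, ChatRole? Role, string Text);

internal static class StreamingCoalescer
{
    // Groups consecutive chunks that share a MessageId into one ChatMessage,
    // mirroring what the removed hand-rolled aggregation did.
    public static List<ChatMessage> Coalesce(IEnumerable<StreamedChunk> chunks)
    {
        List<ChatMessage> messages = [];
        ChatMessage? current = null;

        foreach (StreamedChunk chunk in chunks)
        {
            // The old implementation ignored updates without a message ID.
            if (string.IsNullOrEmpty(chunk.MessageId))
            {
                continue;
            }

            // A new MessageId starts a new message.
            if (current is null || current.MessageId != chunk.MessageId)
            {
                current = new ChatMessage(chunk.Role ?? ChatRole.Assistant, new List<AIContent>())
                {
                    MessageId = chunk.MessageId,
                };
                messages.Add(current);
            }

            current.Contents.Add(new TextContent(chunk.Text));
        }

        return messages;
    }
}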
@@ -116,7 +116,7 @@ internal sealed class TestWorkflowContext(string executorId, bool concurrentRuns
 {
     private readonly StateManager _stateManager = new();
 
-    public List<List<ChatMessage>> Updates { get; } = [];
+    public List<ChatMessage> Updates { get; } = [];
 
     public ValueTask AddEventAsync(WorkflowEvent workflowEvent, CancellationToken cancellationToken = default) =>
         default;
@@ -145,11 +145,11 @@ public ValueTask SendMessageAsync(object message, string? targetId = null, Cance
     {
         if (message is List<ChatMessage> messages)
         {
-            this.Updates.Add(messages);
+            this.Updates.AddRange(messages);
         }
         else if (message is ChatMessage chatMessage)
         {
-            this.Updates.Add([chatMessage]);
+            this.Updates.Add(chatMessage);
         }
 
         return default;
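With Updates flattened to a single List<ChatMessage>, both shapes the executor may forward (a single message or a batch) land in the same list. A hedged usage sketch of the revised test double follows; the test name and string literals are illustrative and not part of this PR.

[Fact]
public async Task Updates_Flatten_Single_And_Batched_Messages()
{
    // Hypothetical usage of the revised TestWorkflowContext test double.
    TestWorkflowContext context = new("executor-1");

    // A single message and a batch of messages both flatten into Updates.
    await context.SendMessageAsync(new ChatMessage(ChatRole.Assistant, "first"));
    await context.SendMessageAsync(new List<ChatMessage>
    {
        new(ChatRole.Assistant, "second"),
        new(ChatRole.Assistant, "third"),
    });

    // With the old List<List<ChatMessage>> shape this would have been two entries.
    context.Updates.Should().HaveCount(3);
}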
@@ -176,52 +176,24 @@ public async Task Test_AIAgentStreamingMessage_AggregationAsync()
             "Quisque dignissim ante odio, at facilisis orci porta a. Duis mi augue, fringilla eu egestas a, pellentesque sed lacus."
         ];
 
-        string[][] splits = MessageStrings.Select(t => t.Split()).ToArray();
-        foreach (string[] messageSplits in splits)
-        {
-            for (int i = 0; i < messageSplits.Length - 1; i++)
-            {
-                messageSplits[i] += ' ';
-            }
-        }
-
         List<ChatMessage> expected = TestAIAgent.ToChatMessages(MessageStrings);
 
         TestAIAgent agent = new(expected);
         AIAgentHostExecutor host = new(agent);
 
         TestWorkflowContext collectingContext = new(host.Id);
 
-        await host.TakeTurnAsync(new TurnToken(emitEvents: false), collectingContext);
+        await host.TakeTurnAsync(new TurnToken(emitEvents: true), collectingContext);
 
         // The first empty message is skipped.
         collectingContext.Updates.Should().HaveCount(MessageStrings.Length - 1);
 
         for (int i = 1; i < MessageStrings.Length; i++)
         {
             string expectedText = MessageStrings[i];
-            string[] expectedSplits = splits[i];
 
-            ChatMessage equivalent = expected[i];
-            List<ChatMessage> collected = collectingContext.Updates[i - 1];
-
-            collected.Should().HaveCount(1);
-            collected[0].Text.Should().Be(expectedText);
-            collected[0].Contents.Should().HaveCount(splits[i].Length);
+            ChatMessage collected = collectingContext.Updates[i - 1];
 
-            Action<AIContent>[] splitCheckActions = splits[i].Select(MakeSplitCheckAction).ToArray();
-            Assert.Collection(collected[0].Contents, splitCheckActions);
-        }
-
-        Action<AIContent> MakeSplitCheckAction(string splitString)
-        {
-            return Check;
-
-            void Check(AIContent content)
-            {
-                TextContent? text = content as TextContent;
-                text!.Text.Should().Be(splitString);
-            }
-        }
+            collected.Text.Should().Be(expectedText);
         }
     }
 }
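For reference, the deleted per-chunk assertions leaned on the fact that the whitespace splits (with the trailing space the old test re-appended) concatenate back to the original string; once the executor forwards fully aggregated messages, comparing the whole Text is sufficient. A tiny standalone sketch of that invariant, not part of the PR, runnable as a top-level program:

// Re-joining the whitespace splits reproduces the original message text.
string original = "Lorem ipsum dolor sit amet";
string[] chunks = original.Split();
for (int i = 0; i < chunks.Length - 1; i++)
{
    chunks[i] += ' ';
}

System.Console.WriteLine(string.Concat(chunks) == original); // True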