Skip to content

Commit

Permalink
Merge pull request #432 from hchen2020/server-sent-events
Browse files Browse the repository at this point in the history
Server-sent Events.
  • Loading branch information
Oceania2018 authored Apr 30, 2024
2 parents 0fa8e5c + 0ae2c4b commit ed4f567
Show file tree
Hide file tree
Showing 15 changed files with 142 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ public class RoleDialogModel : ITrackableMessage

public string? SecondaryContent { get; set; }

/// <summary>
/// Indicator message used to provide UI feedback for function execution
/// </summary>
public string? Indication { get; set; }

[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string CurrentAgentId { get; set; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,11 @@ namespace BotSharp.Abstraction.Functions;
public interface IFunctionCallback
{
string Name { get; }

/// <summary>
/// Indicator message used to provide UI feedback for function execution
/// </summary>
string Indication => string.Empty;

Task<bool> Execute(RoleDialogModel message);
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ public interface IRoutingHandler

void SetDialogs(List<RoleDialogModel> dialogs);

Task<bool> Handle(IRoutingService routing, FunctionCallFromLlm inst, RoleDialogModel message);
Task<bool> Handle(IRoutingService routing, FunctionCallFromLlm inst, RoleDialogModel message, Func<RoleDialogModel, Task> onFunctionExecuting);
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using BotSharp.Abstraction.Routing.Models;

namespace BotSharp.Abstraction.Routing;

public interface IRoutingService
Expand Down Expand Up @@ -30,9 +28,9 @@ public interface IRoutingService

List<RoutingHandlerDef> GetHandlers(Agent router);
void ResetRecursiveCounter();
Task<bool> InvokeAgent(string agentId, List<RoleDialogModel> dialogs);
Task<bool> InvokeFunction(string name, RoleDialogModel message);
Task<RoleDialogModel> InstructLoop(RoleDialogModel message, List<RoleDialogModel> dialogs);
Task<bool> InvokeAgent(string agentId, List<RoleDialogModel> dialogs, Func<RoleDialogModel, Task> onFunctionExecuting);
Task<bool> InvokeFunction(string name, RoleDialogModel messages, Func<RoleDialogModel, Task>? onFunctionExecuting = null);
Task<RoleDialogModel> InstructLoop(RoleDialogModel message, List<RoleDialogModel> dialogs, Func<RoleDialogModel, Task> onFunctionExecuting);

/// <summary>
/// Talk to a specific Agent directly, bypassing the Router
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ public interface IExecutor
Task<RoleDialogModel> Execute(IRoutingService routing,
FunctionCallFromLlm inst,
RoleDialogModel message,
List<RoleDialogModel> dialogs);
List<RoleDialogModel> dialogs,
Func<RoleDialogModel, Task> onFunctionExecuting);
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public async Task<bool> SendMessage(string agentId,
var settings = _services.GetRequiredService<RoutingSettings>();

response = agent.Type == AgentType.Routing ?
await routing.InstructLoop(message, dialogs) :
await routing.InstructLoop(message, dialogs, onFunctionExecuting) :
await routing.InstructDirect(agent, message);

routing.ResetRecursiveCounter();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public async Task<bool> Execute(RoleDialogModel message)
routing.Context.Replace(targetAgent.Id);
message.CurrentAgentId = targetAgent.Id;

var response = await routing.InstructLoop(message, dialogs);
var response = await routing.InstructLoop(message, dialogs, null);

message.Content = response.Content;
message.StopCompletion = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public ResponseToUserRoutingHandler(IServiceProvider services, ILogger<ResponseT
{
}

public async Task<bool> Handle(IRoutingService routing, FunctionCallFromLlm inst, RoleDialogModel message)
public async Task<bool> Handle(IRoutingService routing, FunctionCallFromLlm inst, RoleDialogModel message, Func<RoleDialogModel, Task> onFunctionExecuting)
{
var response = new RoleDialogModel(AgentRole.Assistant, inst.Response)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public RetrieveDataFromAgentRoutingHandler(IServiceProvider services, ILogger<Re
{
}

public async Task<bool> Handle(IRoutingService routing, FunctionCallFromLlm inst, RoleDialogModel message)
public async Task<bool> Handle(IRoutingService routing, FunctionCallFromLlm inst, RoleDialogModel message, Func<RoleDialogModel, Task> onFunctionExecuting)
{
var context = _services.GetRequiredService<IRoutingContext>();
var agentId = context.GetCurrentAgentId();
Expand All @@ -47,7 +47,7 @@ public async Task<bool> Handle(IRoutingService routing, FunctionCallFromLlm inst
}
};

var ret = await routing.InvokeAgent(agentId, dialogs);
var ret = await routing.InvokeAgent(agentId, dialogs, onFunctionExecuting);
var response = dialogs.Last();
inst.Response = response.Content;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public RouteToAgentRoutingHandler(IServiceProvider services, ILogger<RouteToAgen
{
}

public async Task<bool> Handle(IRoutingService routing, FunctionCallFromLlm inst, RoleDialogModel message)
public async Task<bool> Handle(IRoutingService routing, FunctionCallFromLlm inst, RoleDialogModel message, Func<RoleDialogModel, Task> onFunctionExecuting)
{
var states = _services.GetRequiredService<IConversationStateService>();
var goalAgent = states.GetState(StateConst.EXPECTED_GOAL_AGENT);
Expand Down Expand Up @@ -85,7 +85,7 @@ await hook.OnNewTaskDetected(message, inst.NextActionReason)
}
else
{
ret = await routing.InvokeAgent(agentId, _dialogs);
ret = await routing.InvokeAgent(agentId, _dialogs, onFunctionExecuting);
}

var response = _dialogs.Last();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using BotSharp.Abstraction.Functions.Models;
using BotSharp.Abstraction.Routing;
using BotSharp.Abstraction.Routing.Planning;

namespace BotSharp.Core.Routing.Planning;
Expand All @@ -18,15 +16,16 @@ public InstructExecutor(IServiceProvider services, ILogger<InstructExecutor> log
public async Task<RoleDialogModel> Execute(IRoutingService routing,
FunctionCallFromLlm inst,
RoleDialogModel message,
List<RoleDialogModel> dialogs)
List<RoleDialogModel> dialogs,
Func<RoleDialogModel, Task> onFunctionExecuting)
{
message.Instruction = inst;

var handlers = _services.GetServices<IRoutingHandler>();
var handler = handlers.FirstOrDefault(x => x.Name == inst.Function);
handler.SetDialogs(dialogs);

var handled = await handler.Handle(routing, inst, message);
var handled = await handler.Handle(routing, inst, message, onFunctionExecuting);

// For client display purpose
var response = dialogs.Last();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace BotSharp.Core.Routing;
public partial class RoutingService
{
private int _currentRecursionDepth = 0;
public async Task<bool> InvokeAgent(string agentId, List<RoleDialogModel> dialogs)
public async Task<bool> InvokeAgent(string agentId, List<RoleDialogModel> dialogs, Func<RoleDialogModel, Task> onFunctionExecuting)
{
var agentService = _services.GetRequiredService<IAgentService>();
var agent = await agentService.LoadAgent(agentId);
Expand Down Expand Up @@ -34,7 +34,8 @@ public async Task<bool> InvokeAgent(string agentId, List<RoleDialogModel> dialog
message.FunctionName = response.FunctionName;
message.FunctionArgs = response.FunctionArgs;
message.CurrentAgentId = agent.Id;
await InvokeFunction(message, dialogs);

await InvokeFunction(message, dialogs, onFunctionExecuting);
}
else
{
Expand All @@ -54,7 +55,7 @@ public async Task<bool> InvokeAgent(string agentId, List<RoleDialogModel> dialog
return true;
}

private async Task<bool> InvokeFunction(RoleDialogModel message, List<RoleDialogModel> dialogs)
private async Task<bool> InvokeFunction(RoleDialogModel message, List<RoleDialogModel> dialogs, Func<RoleDialogModel, Task>? onFunctionExecuting = null)
{
// execute function
// Save states
Expand All @@ -63,7 +64,7 @@ private async Task<bool> InvokeFunction(RoleDialogModel message, List<RoleDialog

var routing = _services.GetRequiredService<IRoutingService>();
// Call functions
await routing.InvokeFunction(message.FunctionName, message);
await routing.InvokeFunction(message.FunctionName, message, onFunctionExecuting);

// Pass execution result to LLM to get response
if (!message.StopCompletion)
Expand All @@ -88,7 +89,7 @@ private async Task<bool> InvokeFunction(RoleDialogModel message, List<RoleDialog

// Send to Next LLM
var agentId = routing.Context.GetCurrentAgentId();
await InvokeAgent(agentId, dialogs);
await InvokeAgent(agentId, dialogs, onFunctionExecuting);
}
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ namespace BotSharp.Core.Routing;

public partial class RoutingService
{
public async Task<bool> InvokeFunction(string name, RoleDialogModel message)
public async Task<bool> InvokeFunction(string name, RoleDialogModel message, Func<RoleDialogModel, Task>? onFunctionExecuting = null)
{
var function = _services.GetServices<IFunctionCallback>().FirstOrDefault(x => x.Name == name);
if (function == null)
Expand All @@ -24,6 +24,12 @@ public async Task<bool> InvokeFunction(string name, RoleDialogModel message)
.ToList();

// Before executing functions
clonedMessage.Indication = function.Indication;
if (onFunctionExecuting != null)
{
await onFunctionExecuting(clonedMessage);
}

foreach (var hook in hooks)
{
await hook.OnFunctionExecuting(clonedMessage);
Expand Down
19 changes: 4 additions & 15 deletions src/Infrastructure/BotSharp.Core/Routing/RoutingService.cs
Original file line number Diff line number Diff line change
@@ -1,19 +1,8 @@
using BotSharp.Abstraction.Agents.Models;
using BotSharp.Abstraction.Infrastructures.Enums;
using BotSharp.Abstraction.Routing.Models;
using BotSharp.Abstraction.Routing.Planning;
using BotSharp.Abstraction.Routing.Settings;
using BotSharp.Abstraction.Templating;
using BotSharp.Core.Routing.Planning;
using Fluid.Ast;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System.Diagnostics.Metrics;
using System.Drawing;
using System.Runtime.InteropServices;
using System.Security.Cryptography.X509Certificates;
using ThirdParty.Json.LitJson;
using static System.Net.Mime.MediaTypeNames;

namespace BotSharp.Core.Routing;

Expand Down Expand Up @@ -67,7 +56,7 @@ public async Task<RoleDialogModel> InstructDirect(Agent agent, RoleDialogModel m
ExecutingDirectly = true
};

var result = await handler.Handle(this, inst, message);
var result = await handler.Handle(this, inst, message, null);

var response = dialogs.Last();
response.MessageId = message.MessageId;
Expand All @@ -76,7 +65,7 @@ public async Task<RoleDialogModel> InstructDirect(Agent agent, RoleDialogModel m
return response;
}

public async Task<RoleDialogModel> InstructLoop(RoleDialogModel message, List<RoleDialogModel> dialogs)
public async Task<RoleDialogModel> InstructLoop(RoleDialogModel message, List<RoleDialogModel> dialogs, Func<RoleDialogModel, Task> onFunctionExecuting)
{
RoleDialogModel response = default;

Expand Down Expand Up @@ -134,12 +123,12 @@ await hook.OnRoutingInstructionReceived(inst, message)
if (inst.HandleDialogsByPlanner)
{
var dialogWithoutContext = planner.BeforeHandleContext(inst, message, dialogs);
response = await executor.Execute(this, inst, message, dialogWithoutContext);
response = await executor.Execute(this, inst, message, dialogWithoutContext, onFunctionExecuting);
planner.AfterHandleContext(dialogs, dialogWithoutContext);
}
else
{
response = await executor.Execute(this, inst, message, dialogs);
response = await executor.Execute(this, inst, message, dialogs, onFunctionExecuting);
}

await planner.AgentExecuted(_router, inst, response, dialogs);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using BotSharp.Abstraction.Routing;
using Newtonsoft.Json.Serialization;
using Newtonsoft.Json;

namespace BotSharp.OpenAPI.Controllers;

Expand Down Expand Up @@ -149,6 +151,15 @@ public async Task<bool> DeleteConversationMessage([FromRoute] string conversatio
return response;
}

private void SetStates(IConversationService conv, NewMessageModel input)
{
conv.States.SetState("channel", input.Channel, source: StateSource.External)
.SetState("provider", input.Provider, source: StateSource.External)
.SetState("model", input.Model, source: StateSource.External)
.SetState("temperature", input.Temperature, source: StateSource.External)
.SetState("sampling_factor", input.SamplingFactor, source: StateSource.External);
}

[HttpPost("/conversation/{agentId}/{conversationId}")]
public async Task<ChatResponseModel> SendMessage([FromRoute] string agentId,
[FromRoute] string conversationId,
Expand All @@ -165,11 +176,7 @@ public async Task<ChatResponseModel> SendMessage([FromRoute] string agentId,
routing.Context.SetMessageId(conversationId, inputMsg.MessageId);

conv.SetConversationId(conversationId, input.States);
conv.States.SetState("channel", input.Channel, source: StateSource.External)
.SetState("provider", input.Provider, source: StateSource.External)
.SetState("model", input.Model, source: StateSource.External)
.SetState("temperature", input.Temperature, source: StateSource.External)
.SetState("sampling_factor", input.SamplingFactor, source: StateSource.External);
SetStates(conv, input);

var response = new ChatResponseModel();

Expand All @@ -194,6 +201,92 @@ await conv.SendMessage(agentId, inputMsg,
return response;
}

[HttpPost("/conversation/{agentId}/{conversationId}/sse")]
public async Task SendMessageSse([FromRoute] string agentId,
[FromRoute] string conversationId,
[FromBody] NewMessageModel input)
{
var conv = _services.GetRequiredService<IConversationService>();
if (!string.IsNullOrEmpty(input.TruncateMessageId))
{
await conv.TruncateConversation(conversationId, input.TruncateMessageId);
}

var inputMsg = new RoleDialogModel(AgentRole.User, input.Text);
var routing = _services.GetRequiredService<IRoutingService>();
routing.Context.SetMessageId(conversationId, inputMsg.MessageId);

conv.SetConversationId(conversationId, input.States);
SetStates(conv, input);

var response = new ChatResponseModel();

Response.StatusCode = 200;
Response.Headers.Append(Microsoft.Net.Http.Headers.HeaderNames.ContentType, "text/event-stream");
Response.Headers.Append(Microsoft.Net.Http.Headers.HeaderNames.CacheControl, "no-cache");
Response.Headers.Append(Microsoft.Net.Http.Headers.HeaderNames.Connection, "keep-alive");

await conv.SendMessage(agentId, inputMsg,
replyMessage: input.Postback,
async msg =>
{
response.Text = !string.IsNullOrEmpty(msg.SecondaryContent) ? msg.SecondaryContent : msg.Content;
response.Function = msg.FunctionName;
response.RichContent = msg.SecondaryRichContent ?? msg.RichContent;
response.Instruction = msg.Instruction;
response.Data = msg.Data;

await OnChunkReceived(Response, msg);
},
async msg =>
{
var message = new RoleDialogModel(AgentRole.Function, msg.Content)
{
FunctionArgs = msg.FunctionArgs,
FunctionName = msg.FunctionName,
Indication = msg.Indication
};
await OnChunkReceived(Response, message);
},
async msg =>
{

});

var state = _services.GetRequiredService<IConversationStateService>();
response.States = state.GetStates();
response.MessageId = inputMsg.MessageId;
response.ConversationId = conversationId;

// await OnEventCompleted(Response);
}

private async Task OnChunkReceived(HttpResponse response, RoleDialogModel message)
{
var json = JsonConvert.SerializeObject(message, new JsonSerializerSettings
{
Formatting = Formatting.None,
ContractResolver = new CamelCasePropertyNamesContractResolver(),
NullValueHandling = NullValueHandling.Ignore,
});

var buffer = Encoding.UTF8.GetBytes($"data:{json}\n");
await response.Body.WriteAsync(buffer, 0, buffer.Length);
await Task.Delay(10);

buffer = Encoding.UTF8.GetBytes("\n");
await response.Body.WriteAsync(buffer, 0, buffer.Length);
}

private async Task OnEventCompleted(HttpResponse response)
{
var buffer = Encoding.UTF8.GetBytes("data:[DONE]\n");
await response.Body.WriteAsync(buffer, 0, buffer.Length);

buffer = Encoding.UTF8.GetBytes("\n");
await response.Body.WriteAsync(buffer, 0, buffer.Length);
}

[HttpPost("/conversation/{conversationId}/attachments")]
public IActionResult UploadAttachments([FromRoute] string conversationId,
IFormFile[] files)
Expand Down

0 comments on commit ed4f567

Please sign in to comment.