Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ public interface IRealtimeHub

IRealTimeCompletion Completer { get; }

Task ConnectToModel(Func<string, Task>? responseToUser = null, Func<string, Task>? init = null);
Task ConnectToModel(Func<string, Task>? responseToUser = null, Func<string, Task>? init = null, List<MessageState>? initStates = null);
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using BotSharp.Abstraction.Functions.Models;
using BotSharp.Abstraction.Hooks;
using BotSharp.Abstraction.Models;
using BotSharp.Abstraction.Options;
using BotSharp.Core.Infrastructures;

Expand All @@ -22,10 +23,10 @@ public RealtimeHub(IServiceProvider services, ILogger<RealtimeHub> logger)
_logger = logger;
}

public async Task ConnectToModel(Func<string, Task>? responseToUser = null, Func<string, Task>? init = null)
public async Task ConnectToModel(Func<string, Task>? responseToUser = null, Func<string, Task>? init = null, List<MessageState>? initStates = null)
{
var convService = _services.GetRequiredService<IConversationService>();
convService.SetConversationId(_conn.ConversationId, []);
convService.SetConversationId(_conn.ConversationId, initStates ?? []);
var conversation = await convService.GetConversation(_conn.ConversationId);

var routing = _services.GetRequiredService<IRoutingService>();
Expand Down
29 changes: 22 additions & 7 deletions src/Plugins/BotSharp.Plugin.ChatHub/ChatStreamMiddleware.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using BotSharp.Abstraction.Models;
using BotSharp.Abstraction.Realtime.Models.Session;
using BotSharp.Core.Session;
using Microsoft.AspNetCore.Http;
Expand Down Expand Up @@ -66,6 +67,7 @@ private async Task HandleWebSocket(IServiceProvider services, string agentId, st

// load conversation and state
var convService = services.GetRequiredService<IConversationService>();
var state = services.GetRequiredService<IConversationStateService>();
convService.SetConversationId(conversationId, []);
await convService.GetConversationRecordOrCreateNew(agentId);

Expand All @@ -80,7 +82,8 @@ private async Task HandleWebSocket(IServiceProvider services, string agentId, st
var (eventType, data) = MapEvents(conn, receivedText);
if (eventType == "start")
{
await ConnectToModel(hub, webSocket);
var states = InitStates(data);
await ConnectToModel(hub, webSocket, states);
}
else if (eventType == "media")
{
Expand All @@ -96,34 +99,33 @@ private async Task HandleWebSocket(IServiceProvider services, string agentId, st
}
}

convService.SaveStates();
await _session.DisconnectAsync();
_session.Dispose();
}

private async Task ConnectToModel(IRealtimeHub hub, WebSocket webSocket)
private async Task ConnectToModel(IRealtimeHub hub, WebSocket webSocket, List<MessageState>? states = null)
{
await hub.ConnectToModel(async data =>
await hub.ConnectToModel(responseToUser: async data =>
{
if (_session != null)
{
await _session.SendEventAsync(data);
}
});
}, initStates: states);
}

private (string, string) MapEvents(RealtimeHubConnection conn, string receivedText)
{
var response = JsonSerializer.Deserialize<ChatStreamEventResponse>(receivedText);
string data = string.Empty;
var data = response?.Body?.Payload ?? string.Empty;

switch (response.Event)
{
case "start":
conn.ResetStreamState();
break;
case "media":
var mediaResponse = JsonSerializer.Deserialize<ChatStreamMediaEventResponse>(receivedText);
data = mediaResponse?.Body?.Payload ?? string.Empty;
break;
case "disconnect":
break;
Expand Down Expand Up @@ -154,4 +156,17 @@ private void InitEvents(RealtimeHubConnection conn)
@event = "clear"
});
}

private List<MessageState> InitStates(string data)
{
try
{
var states = JsonSerializer.Deserialize<List<MessageState>>(data, BotSharpOptions.defaultJsonOptions);
return states ?? [];
}
catch
{
return [];
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,12 @@ internal class ChatStreamEventResponse
{
[JsonPropertyName("event")]
public string Event { get; set; }
}

internal class ChatStreamMediaEventResponse : ChatStreamEventResponse
{
[JsonPropertyName("body")]
public MediaEventResponseBody Body { get; set; }
public ChatStreamEventResponseBody Body { get; set; }
}

internal class MediaEventResponseBody
internal class ChatStreamEventResponseBody
{
[JsonPropertyName("payload")]
public string Payload { get; set; }
Expand Down
Loading