From d2d28be378a37cba7f9249b525b085956f99d472 Mon Sep 17 00:00:00 2001 From: Jicheng Lu <103353@smsassist.com> Date: Wed, 3 Apr 2024 12:12:31 -0500 Subject: [PATCH] optimize state --- .../IConversationStateService.cs | 1 + .../Services/ConversationStateService.cs | 174 ++++++++++++------ .../Functions/GetPizzaTypesFn.cs | 2 +- 3 files changed, 121 insertions(+), 56 deletions(-) diff --git a/src/Infrastructure/BotSharp.Abstraction/Conversations/IConversationStateService.cs b/src/Infrastructure/BotSharp.Abstraction/Conversations/IConversationStateService.cs index 4b28c1d3b..655e657a6 100644 --- a/src/Infrastructure/BotSharp.Abstraction/Conversations/IConversationStateService.cs +++ b/src/Infrastructure/BotSharp.Abstraction/Conversations/IConversationStateService.cs @@ -16,6 +16,7 @@ public interface IConversationStateService IConversationStateService SetState(string name, T value, bool isNeedVersion = true, int activeRounds = -1, string valueType = StateDataType.String, string source = StateSource.User, bool readOnly = false); void SaveStateByArgs(JsonDocument args); + bool RemoveState(string name); void CleanStates(); void Save(); } diff --git a/src/Infrastructure/BotSharp.Core/Conversations/Services/ConversationStateService.cs b/src/Infrastructure/BotSharp.Core/Conversations/Services/ConversationStateService.cs index b87a8fd72..5f632bd68 100644 --- a/src/Infrastructure/BotSharp.Core/Conversations/Services/ConversationStateService.cs +++ b/src/Infrastructure/BotSharp.Core/Conversations/Services/ConversationStateService.cs @@ -10,9 +10,16 @@ public class ConversationStateService : IConversationStateService, IDisposable { private readonly ILogger _logger; private readonly IServiceProvider _services; - private ConversationState _states; - private string _conversationId; private readonly IBotSharpRepository _db; + private string _conversationId; + /// + /// States in the current round of conversation + /// + private ConversationState _curStates; + /// + /// States in the previous rounds of conversation + /// + private ConversationState _historyStates; public ConversationStateService(ILogger logger, IServiceProvider services, @@ -21,7 +28,8 @@ public ConversationStateService(ILogger logger, _logger = logger; _services = services; _db = db; - _states = new ConversationState(); + _curStates = new ConversationState(); + _historyStates = new ConversationState(); } public string GetConversationId() => _conversationId; @@ -48,7 +56,7 @@ public IConversationStateService SetState(string name, T value, bool isNeedVe var curActiveRounds = activeRounds > 0 ? activeRounds : -1; int? preActiveRounds = null; - if (ContainsState(name) && _states.TryGetValue(name, out var pair)) + if (ContainsState(name) && _curStates.TryGetValue(name, out var pair)) { var lastNode = pair?.Values?.LastOrDefault(); preActiveRounds = lastNode?.ActiveRounds; @@ -96,14 +104,14 @@ public IConversationStateService SetState(string name, T value, bool isNeedVe UpdateTime = DateTime.UtcNow, }; - if (!isNeedVersion || !_states.ContainsKey(name)) + if (!isNeedVersion || !_curStates.ContainsKey(name)) { newPair.Values = new List { newValue }; - _states[name] = newPair; + _curStates[name] = newPair; } else { - _states[name].Values.Add(newValue); + _curStates[name].Values.Add(newValue); } return this; @@ -115,56 +123,68 @@ public Dictionary Load(string conversationId) var routingCtx = _services.GetRequiredService(); var curMsgId = routingCtx.MessageId; - _states = _db.GetConversationStates(_conversationId); + + _historyStates = _db.GetConversationStates(_conversationId); var dialogs = _db.GetConversationDialogs(_conversationId); var userDialogs = dialogs.Where(x => x.MetaData?.Role == AgentRole.User || x.MetaData?.Role == UserRole.Client) .OrderBy(x => x.MetaData?.CreateTime) .ToList(); - var curMsgIndex = userDialogs.FindIndex(x => !string.IsNullOrEmpty(curMsgId) && x.MetaData?.MessageId == curMsgId); curMsgIndex = curMsgIndex < 0 ? userDialogs.Count() : curMsgIndex; - var curStates = new Dictionary(); - if (!_states.IsNullOrEmpty()) + var endNodes = new Dictionary(); + if (_historyStates.IsNullOrEmpty()) return endNodes; + + foreach (var state in _historyStates) { - foreach (var state in _states) + var key = state.Key; + var value = state.Value; + var leafNode = value.Values.LastOrDefault(); + if (leafNode == null) continue; + + _curStates[key] = new StateKeyValue { - var value = state.Value?.Values?.LastOrDefault(); - if (value == null || !value.Active) continue; + Key = key, + Versioning = value.Versioning, + Readonly = value.Readonly, + Values = new List { leafNode } + }; + + if (!leafNode.Active) continue; - if (value.ActiveRounds > 0) + // Handle state active rounds + if (leafNode.ActiveRounds > 0) + { + var stateMsgIndex = userDialogs.FindIndex(x => !string.IsNullOrEmpty(x.MetaData?.MessageId) && x.MetaData.MessageId == leafNode.MessageId); + if (stateMsgIndex >= 0 && curMsgIndex - stateMsgIndex >= leafNode.ActiveRounds) { - var stateMsgIndex = userDialogs.FindIndex(x => !string.IsNullOrEmpty(x.MetaData?.MessageId) && x.MetaData.MessageId == value.MessageId); - if (stateMsgIndex >= 0 && curMsgIndex - stateMsgIndex >= value.ActiveRounds) + _curStates[key].Values.Add(new StateValue { - state.Value.Values.Add(new StateValue - { - Data = value.Data, - MessageId = curMsgId, - Active = false, - ActiveRounds = value.ActiveRounds, - DataType = value.DataType, - Source = value.Source, - UpdateTime = DateTime.UtcNow - }); - continue; - } + Data = leafNode.Data, + MessageId = curMsgId, + Active = false, + ActiveRounds = leafNode.ActiveRounds, + DataType = leafNode.DataType, + Source = leafNode.Source, + UpdateTime = DateTime.UtcNow + }); + continue; } - - var data = value.Data ?? string.Empty; - curStates[state.Key] = data; - _logger.LogInformation($"[STATE] {state.Key} : {data}"); } + + var data = leafNode.Data ?? string.Empty; + endNodes[state.Key] = data; + _logger.LogInformation($"[STATE] {key} : {data}"); } _logger.LogInformation($"Loaded conversation states: {_conversationId}"); var hooks = _services.GetServices(); foreach (var hook in hooks) { - hook.OnStateLoaded(_states).Wait(); + hook.OnStateLoaded(_curStates).Wait(); } - return curStates; + return endNodes; } public void Save() @@ -176,37 +196,81 @@ public void Save() var states = new List(); - foreach (var dic in _states) + foreach (var pair in _curStates) { - states.Add(dic.Value); + var key = pair.Key; + var curValue = pair.Value; + + if (!_historyStates.TryGetValue(key, out var historyValue) + || historyValue == null + || historyValue.Values.IsNullOrEmpty() + || !curValue.Versioning) + { + states.Add(curValue); + } + else + { + var historyValues = historyValue.Values.Take(historyValue.Values.Count - 1).ToList(); + var newValues = historyValues.Concat(curValue.Values).ToList(); + var updatedNode = new StateKeyValue + { + Key = pair.Key, + Versioning = curValue.Versioning, + Readonly = curValue.Readonly, + Values = newValues + }; + states.Add(updatedNode); + } } _db.UpdateConversationStates(_conversationId, states); _logger.LogInformation($"Saved states of conversation {_conversationId}"); } + public bool RemoveState(string name) + { + if (!ContainsState(name)) return false; + + var routingCtx = _services.GetRequiredService(); + var leafNode = _curStates[name].Values?.LastOrDefault(); + if (leafNode == null) return false; + + _curStates[name].Values.Add(new StateValue + { + Data = leafNode.Data, + MessageId = routingCtx.MessageId, + Active = false, + ActiveRounds = leafNode.ActiveRounds, + DataType = leafNode.DataType, + Source = leafNode.Source, + UpdateTime = DateTime.UtcNow + }); + + return true; + } + public void CleanStates() { var routingCtx = _services.GetRequiredService(); var curMsgId = routingCtx.MessageId; var utcNow = DateTime.UtcNow; - foreach (var key in _states.Keys) + foreach (var key in _curStates.Keys) { - var value = _states[key]; + var value = _curStates[key]; if (value == null || !value.Versioning || value.Values.IsNullOrEmpty()) continue; - var lastValue = value.Values.LastOrDefault(); - if (lastValue == null || !lastValue.Active) continue; + var leafNode = value.Values.LastOrDefault(); + if (leafNode == null || !leafNode.Active) continue; value.Values.Add(new StateValue { - Data = lastValue.Data, + Data = leafNode.Data, MessageId = curMsgId, Active = false, - ActiveRounds = lastValue.ActiveRounds, - DataType = lastValue.DataType, - Source = lastValue.Source, + ActiveRounds = leafNode.ActiveRounds, + DataType = leafNode.DataType, + Source = leafNode.Source, UpdateTime = utcNow }); } @@ -214,25 +278,25 @@ public void CleanStates() public Dictionary GetStates() { - var curStates = new Dictionary(); - foreach (var state in _states) + var endNodes = new Dictionary(); + foreach (var state in _curStates) { var value = state.Value?.Values?.LastOrDefault(); if (value == null || !value.Active) continue; - curStates[state.Key] = value.Data ?? string.Empty; + endNodes[state.Key] = value.Data ?? string.Empty; } - return curStates; + return endNodes; } public string GetState(string name, string defaultValue = "") { - if (!_states.ContainsKey(name) || _states[name].Values.IsNullOrEmpty() || !_states[name].Values.Last().Active) + if (!_curStates.ContainsKey(name) || _curStates[name].Values.IsNullOrEmpty() || !_curStates[name].Values.Last().Active) { return defaultValue; } - return _states[name].Values.Last().Data; + return _curStates[name].Values.Last().Data; } public void Dispose() @@ -242,10 +306,10 @@ public void Dispose() public bool ContainsState(string name) { - return _states.ContainsKey(name) - && !_states[name].Values.IsNullOrEmpty() - && _states[name].Values.LastOrDefault()?.Active == true - && !string.IsNullOrEmpty(_states[name].Values.Last().Data); + return _curStates.ContainsKey(name) + && !_curStates[name].Values.IsNullOrEmpty() + && _curStates[name].Values.LastOrDefault()?.Active == true + && !string.IsNullOrEmpty(_curStates[name].Values.Last().Data); } public void SaveStateByArgs(JsonDocument args) diff --git a/tests/BotSharp.Plugin.PizzaBot/Functions/GetPizzaTypesFn.cs b/tests/BotSharp.Plugin.PizzaBot/Functions/GetPizzaTypesFn.cs index 1842f5573..9a239fd23 100644 --- a/tests/BotSharp.Plugin.PizzaBot/Functions/GetPizzaTypesFn.cs +++ b/tests/BotSharp.Plugin.PizzaBot/Functions/GetPizzaTypesFn.cs @@ -38,7 +38,7 @@ public async Task Execute(RoleDialogModel message) Message = new ButtonTemplateMessage { Text = "Please select a pizza type", - Buttons = pizzaTypes.Select(x => new ButtonElement + Buttons = pizzaTypes.Select(x => new ElementButton { Type = "text", Title = x,