Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize state #391

Merged
merged 1 commit into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public interface IConversationStateService
IConversationStateService SetState<T>(string name, T value, bool isNeedVersion = true,
int activeRounds = -1, string valueType = StateDataType.String, string source = StateSource.User, bool readOnly = false);
void SaveStateByArgs(JsonDocument args);
bool RemoveState(string name);
void CleanStates();
void Save();
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,16 @@ public class ConversationStateService : IConversationStateService, IDisposable
{
private readonly ILogger _logger;
private readonly IServiceProvider _services;
private ConversationState _states;
private string _conversationId;
private readonly IBotSharpRepository _db;
private string _conversationId;
/// <summary>
/// States in the current round of conversation
/// </summary>
private ConversationState _curStates;
/// <summary>
/// States in the previous rounds of conversation
/// </summary>
private ConversationState _historyStates;

public ConversationStateService(ILogger<ConversationStateService> logger,
IServiceProvider services,
Expand All @@ -21,7 +28,8 @@ public ConversationStateService(ILogger<ConversationStateService> logger,
_logger = logger;
_services = services;
_db = db;
_states = new ConversationState();
_curStates = new ConversationState();
_historyStates = new ConversationState();
}

public string GetConversationId() => _conversationId;
Expand All @@ -48,7 +56,7 @@ public IConversationStateService SetState<T>(string name, T value, bool isNeedVe
var curActiveRounds = activeRounds > 0 ? activeRounds : -1;
int? preActiveRounds = null;

if (ContainsState(name) && _states.TryGetValue(name, out var pair))
if (ContainsState(name) && _curStates.TryGetValue(name, out var pair))
{
var lastNode = pair?.Values?.LastOrDefault();
preActiveRounds = lastNode?.ActiveRounds;
Expand Down Expand Up @@ -96,14 +104,14 @@ public IConversationStateService SetState<T>(string name, T value, bool isNeedVe
UpdateTime = DateTime.UtcNow,
};

if (!isNeedVersion || !_states.ContainsKey(name))
if (!isNeedVersion || !_curStates.ContainsKey(name))
{
newPair.Values = new List<StateValue> { newValue };
_states[name] = newPair;
_curStates[name] = newPair;
}
else
{
_states[name].Values.Add(newValue);
_curStates[name].Values.Add(newValue);
}

return this;
Expand All @@ -115,56 +123,68 @@ public Dictionary<string, string> Load(string conversationId)

var routingCtx = _services.GetRequiredService<IRoutingContext>();
var curMsgId = routingCtx.MessageId;
_states = _db.GetConversationStates(_conversationId);

_historyStates = _db.GetConversationStates(_conversationId);
var dialogs = _db.GetConversationDialogs(_conversationId);
var userDialogs = dialogs.Where(x => x.MetaData?.Role == AgentRole.User || x.MetaData?.Role == UserRole.Client)
.OrderBy(x => x.MetaData?.CreateTime)
.ToList();

var curMsgIndex = userDialogs.FindIndex(x => !string.IsNullOrEmpty(curMsgId) && x.MetaData?.MessageId == curMsgId);
curMsgIndex = curMsgIndex < 0 ? userDialogs.Count() : curMsgIndex;
var curStates = new Dictionary<string, string>();

if (!_states.IsNullOrEmpty())
var endNodes = new Dictionary<string, string>();
if (_historyStates.IsNullOrEmpty()) return endNodes;

foreach (var state in _historyStates)
{
foreach (var state in _states)
var key = state.Key;
var value = state.Value;
var leafNode = value.Values.LastOrDefault();
if (leafNode == null) continue;

_curStates[key] = new StateKeyValue
{
var value = state.Value?.Values?.LastOrDefault();
if (value == null || !value.Active) continue;
Key = key,
Versioning = value.Versioning,
Readonly = value.Readonly,
Values = new List<StateValue> { leafNode }
};

if (!leafNode.Active) continue;

if (value.ActiveRounds > 0)
// Handle state active rounds
if (leafNode.ActiveRounds > 0)
{
var stateMsgIndex = userDialogs.FindIndex(x => !string.IsNullOrEmpty(x.MetaData?.MessageId) && x.MetaData.MessageId == leafNode.MessageId);
if (stateMsgIndex >= 0 && curMsgIndex - stateMsgIndex >= leafNode.ActiveRounds)
{
var stateMsgIndex = userDialogs.FindIndex(x => !string.IsNullOrEmpty(x.MetaData?.MessageId) && x.MetaData.MessageId == value.MessageId);
if (stateMsgIndex >= 0 && curMsgIndex - stateMsgIndex >= value.ActiveRounds)
_curStates[key].Values.Add(new StateValue
{
state.Value.Values.Add(new StateValue
{
Data = value.Data,
MessageId = curMsgId,
Active = false,
ActiveRounds = value.ActiveRounds,
DataType = value.DataType,
Source = value.Source,
UpdateTime = DateTime.UtcNow
});
continue;
}
Data = leafNode.Data,
MessageId = curMsgId,
Active = false,
ActiveRounds = leafNode.ActiveRounds,
DataType = leafNode.DataType,
Source = leafNode.Source,
UpdateTime = DateTime.UtcNow
});
continue;
}

var data = value.Data ?? string.Empty;
curStates[state.Key] = data;
_logger.LogInformation($"[STATE] {state.Key} : {data}");
}

var data = leafNode.Data ?? string.Empty;
endNodes[state.Key] = data;
_logger.LogInformation($"[STATE] {key} : {data}");
}

_logger.LogInformation($"Loaded conversation states: {_conversationId}");
var hooks = _services.GetServices<IConversationHook>();
foreach (var hook in hooks)
{
hook.OnStateLoaded(_states).Wait();
hook.OnStateLoaded(_curStates).Wait();
}

return curStates;
return endNodes;
}

public void Save()
Expand All @@ -176,63 +196,107 @@ public void Save()

var states = new List<StateKeyValue>();

foreach (var dic in _states)
foreach (var pair in _curStates)
{
states.Add(dic.Value);
var key = pair.Key;
var curValue = pair.Value;

if (!_historyStates.TryGetValue(key, out var historyValue)
|| historyValue == null
|| historyValue.Values.IsNullOrEmpty()
|| !curValue.Versioning)
{
states.Add(curValue);
}
else
{
var historyValues = historyValue.Values.Take(historyValue.Values.Count - 1).ToList();
var newValues = historyValues.Concat(curValue.Values).ToList();
var updatedNode = new StateKeyValue
{
Key = pair.Key,
Versioning = curValue.Versioning,
Readonly = curValue.Readonly,
Values = newValues
};
states.Add(updatedNode);
}
}

_db.UpdateConversationStates(_conversationId, states);
_logger.LogInformation($"Saved states of conversation {_conversationId}");
}

public bool RemoveState(string name)
{
if (!ContainsState(name)) return false;

var routingCtx = _services.GetRequiredService<IRoutingContext>();
var leafNode = _curStates[name].Values?.LastOrDefault();
if (leafNode == null) return false;

_curStates[name].Values.Add(new StateValue
{
Data = leafNode.Data,
MessageId = routingCtx.MessageId,
Active = false,
ActiveRounds = leafNode.ActiveRounds,
DataType = leafNode.DataType,
Source = leafNode.Source,
UpdateTime = DateTime.UtcNow
});

return true;
}

public void CleanStates()
{
var routingCtx = _services.GetRequiredService<IRoutingContext>();
var curMsgId = routingCtx.MessageId;
var utcNow = DateTime.UtcNow;

foreach (var key in _states.Keys)
foreach (var key in _curStates.Keys)
{
var value = _states[key];
var value = _curStates[key];
if (value == null || !value.Versioning || value.Values.IsNullOrEmpty()) continue;

var lastValue = value.Values.LastOrDefault();
if (lastValue == null || !lastValue.Active) continue;
var leafNode = value.Values.LastOrDefault();
if (leafNode == null || !leafNode.Active) continue;

value.Values.Add(new StateValue
{
Data = lastValue.Data,
Data = leafNode.Data,
MessageId = curMsgId,
Active = false,
ActiveRounds = lastValue.ActiveRounds,
DataType = lastValue.DataType,
Source = lastValue.Source,
ActiveRounds = leafNode.ActiveRounds,
DataType = leafNode.DataType,
Source = leafNode.Source,
UpdateTime = utcNow
});
}
}

public Dictionary<string, string> GetStates()
{
var curStates = new Dictionary<string, string>();
foreach (var state in _states)
var endNodes = new Dictionary<string, string>();
foreach (var state in _curStates)
{
var value = state.Value?.Values?.LastOrDefault();
if (value == null || !value.Active) continue;

curStates[state.Key] = value.Data ?? string.Empty;
endNodes[state.Key] = value.Data ?? string.Empty;
}
return curStates;
return endNodes;
}

public string GetState(string name, string defaultValue = "")
{
if (!_states.ContainsKey(name) || _states[name].Values.IsNullOrEmpty() || !_states[name].Values.Last().Active)
if (!_curStates.ContainsKey(name) || _curStates[name].Values.IsNullOrEmpty() || !_curStates[name].Values.Last().Active)
{
return defaultValue;
}

return _states[name].Values.Last().Data;
return _curStates[name].Values.Last().Data;
}

public void Dispose()
Expand All @@ -242,10 +306,10 @@ public void Dispose()

public bool ContainsState(string name)
{
return _states.ContainsKey(name)
&& !_states[name].Values.IsNullOrEmpty()
&& _states[name].Values.LastOrDefault()?.Active == true
&& !string.IsNullOrEmpty(_states[name].Values.Last().Data);
return _curStates.ContainsKey(name)
&& !_curStates[name].Values.IsNullOrEmpty()
&& _curStates[name].Values.LastOrDefault()?.Active == true
&& !string.IsNullOrEmpty(_curStates[name].Values.Last().Data);
}

public void SaveStateByArgs(JsonDocument args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public async Task<bool> Execute(RoleDialogModel message)
Message = new ButtonTemplateMessage
{
Text = "Please select a pizza type",
Buttons = pizzaTypes.Select(x => new ButtonElement
Buttons = pizzaTypes.Select(x => new ElementButton
{
Type = "text",
Title = x,
Expand Down