Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refine conversation states #245

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public interface IConversationStateService
string GetState(string name, string defaultValue = "");
bool ContainsState(string name);
ConversationState GetStates();
IConversationStateService SetState<T>(string name, T value);
IConversationStateService SetState<T>(string name, T value, bool isNeedVersion = true);
void SaveStateByArgs(JsonDocument args);
void CleanState();
void Save();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
namespace BotSharp.Abstraction.Conversations.Models;

public class ConversationHistoryState : Dictionary<string, List<HistoryStateValue>>
{
public ConversationHistoryState()
{

}

public ConversationHistoryState(List<HistoryStateKeyValue> pairs)
{
foreach (var pair in pairs)
{
this[pair.Key] = pair.Values;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
namespace BotSharp.Abstraction.Conversations.Models;

public class HistoryStateKeyValue
{
public string Key { get; set; }
public List<HistoryStateValue> Values { get; set; } = new List<HistoryStateValue>();

public HistoryStateKeyValue()
{

}

public HistoryStateKeyValue(string key, List<HistoryStateValue> values)
{
Key = key;
Values = values;
}
}

public class HistoryStateValue
{
public string Data { get; set; }
public DateTime UpdateTime { get; set; }

public HistoryStateValue()
{

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ public interface IBotSharpRepository
List<DialogElement> GetConversationDialogs(string conversationId);
void UpdateConversationDialogElements(string conversationId, List<DialogContentUpdateModel> updateElements);
void AppendConversationDialogs(string conversationId, List<DialogElement> dialogs);
List<StateKeyValue> GetConversationStates(string conversationId);
void UpdateConversationStates(string conversationId, List<StateKeyValue> states);
List<HistoryStateKeyValue> GetConversationStates(string conversationId);
void UpdateConversationStates(string conversationId, List<HistoryStateKeyValue> states);
void UpdateConversationStatus(string conversationId, string status);
Conversation GetConversation(string conversationId);
List<Conversation> GetConversations(ConversationFilter filter);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using BotSharp.Abstraction.Repositories;
using System.IO;

namespace BotSharp.Core.Conversations.Services;

Expand All @@ -11,26 +10,33 @@ public class ConversationStateService : IConversationStateService, IDisposable
private readonly ILogger _logger;
private readonly IServiceProvider _services;
private ConversationState _states;
private BotSharpDatabaseSettings _dbSettings;
private ConversationHistoryState _historyStates;
private string _conversationId;
private readonly IBotSharpRepository _db;
private List<StateKeyValue> _savedStates;

public ConversationStateService(ILogger<ConversationStateService> logger,
IServiceProvider services,
BotSharpDatabaseSettings dbSettings,
IServiceProvider services,
IBotSharpRepository db)
{
_logger = logger;
_services = services;
_dbSettings = dbSettings;
_db = db;
_states = new ConversationState();
_historyStates = new ConversationHistoryState();
}

public string GetConversationId() => _conversationId;

public IConversationStateService SetState<T>(string name, T value)

/// <summary>
/// Set conversation state
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="name"></param>
/// <param name="value"></param>
/// <param name="isNeedVersion">whether the state is related to message or not</param>
/// <returns></returns>
public IConversationStateService SetState<T>(string name, T value, bool isNeedVersion = true)
{
if (value == null)
{
Expand All @@ -40,6 +46,7 @@ public IConversationStateService SetState<T>(string name, T value)
var currentValue = value.ToString();
var hooks = _services.GetServices<IConversationHook>();
string preValue = _states.ContainsKey(name) ? _states[name] : "";

if (!_states.ContainsKey(name) || _states[name] != currentValue)
{
_states[name] = currentValue;
Expand All @@ -48,6 +55,21 @@ public IConversationStateService SetState<T>(string name, T value)
{
hook.OnStateChanged(name, preValue, currentValue).Wait();
}

var historyStateValue = new HistoryStateValue
{
Data = currentValue,
UpdateTime = DateTime.UtcNow
};

if (!_historyStates.ContainsKey(name) || !isNeedVersion)
{
_historyStates[name] = new List<HistoryStateValue> { historyStateValue };
}
else
{
_historyStates[name].Add(historyStateValue);
}
}

return this;
Expand All @@ -57,14 +79,16 @@ public ConversationState Load(string conversationId)
{
_conversationId = conversationId;

_savedStates = _db.GetConversationStates(_conversationId).ToList();
var savedStates = _db.GetConversationStates(_conversationId).ToList();
_historyStates = new ConversationHistoryState(savedStates);

if (!_savedStates.IsNullOrEmpty())
if (!savedStates.IsNullOrEmpty())
{
foreach (var data in _savedStates)
foreach (var state in savedStates)
{
_states[data.Key] = data.Value;
_logger.LogInformation($"[STATE] {data.Key} : {data.Value}");
var value = state.Values.LastOrDefault()?.Data ?? string.Empty;
_states[state.Key] = value;
_logger.LogInformation($"[STATE] {state.Key} : {value}");
}
}

Expand All @@ -85,24 +109,23 @@ public void Save()
return;
}

var states = new List<StateKeyValue>();
var historyStates = new List<HistoryStateKeyValue>();

foreach (var dic in _states)
foreach (var dic in _historyStates)
{
states.Add(new StateKeyValue(dic.Key, dic.Value));
historyStates.Add(new HistoryStateKeyValue(dic.Key, dic.Value));
}

_db.UpdateConversationStates(_conversationId, states);
_logger.LogInformation($"Saved state {_conversationId}");
_db.UpdateConversationStates(_conversationId, historyStates);
_logger.LogInformation($"Saved states of conversation {_conversationId}");
}

public void CleanState()
{
//File.Delete(_file);

}

public ConversationState GetStates()
=> _states;
public ConversationState GetStates() => _states;

public string GetState(string name, string defaultValue = "")
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ public void AddToken(TokenStatsModel stats)
// Accumulated Token
var stat = _services.GetRequiredService<IConversationStateService>();
var inputCount = int.Parse(stat.GetState("prompt_total", "0"));
stat.SetState("prompt_total", stats.PromptCount + inputCount);
stat.SetState("prompt_total", stats.PromptCount + inputCount, false);
var outputCount = int.Parse(stat.GetState("completion_total", "0"));
stat.SetState("completion_total", stats.CompletionCount + outputCount);
stat.SetState("completion_total", stats.CompletionCount + outputCount, false);

// Total cost
var total_cost = float.Parse(stat.GetState("llm_total_cost", "0"));
total_cost += Cost;
stat.SetState("llm_total_cost", total_cost);
stat.SetState("llm_total_cost", total_cost, false);
}

public void PrintStatistics()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ public void UpdateConversationDialogElements(string conversationId, List<DialogC
throw new NotImplementedException();
}

public List<StateKeyValue> GetConversationStates(string conversationId)
public List<HistoryStateKeyValue> GetConversationStates(string conversationId)
{
throw new NotImplementedException();
}
Expand All @@ -165,7 +165,7 @@ public void UpdateConversationTitle(string conversationId, string title)
{
throw new NotImplementedException();
}
public void UpdateConversationStates(string conversationId, List<StateKeyValue> states)
public void UpdateConversationStates(string conversationId, List<HistoryStateKeyValue> states)
{
throw new NotImplementedException();
}
Expand Down
53 changes: 27 additions & 26 deletions src/Infrastructure/BotSharp.Core/Repository/FileRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ public void CreateNewConversation(Conversation conversation)
var stateDir = Path.Combine(dir, "state.dict");
if (!File.Exists(stateDir))
{
File.WriteAllText(stateDir, string.Empty);
File.WriteAllText(stateDir, "[]");
}
}

Expand Down Expand Up @@ -730,33 +730,31 @@ public void UpdateConversationTitle(string conversationId, string title)
}
}
}
public List<StateKeyValue> GetConversationStates(string conversationId)
public List<HistoryStateKeyValue> GetConversationStates(string conversationId)
{
var curStates = new List<StateKeyValue>();
var curStates = new List<HistoryStateKeyValue>();
var convDir = FindConversationDirectory(conversationId);
if (!string.IsNullOrEmpty(convDir))
{
var stateDir = Path.Combine(convDir, "state.dict");
curStates = CollectConversationStates(stateDir);
var stateFile = Path.Combine(convDir, "state.dict");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we rename to state.json?

curStates = CollectConversationStates(stateFile);
}

return curStates;
}

public void UpdateConversationStates(string conversationId, List<StateKeyValue> states)
public void UpdateConversationStates(string conversationId, List<HistoryStateKeyValue> states)
{
var localStates = new List<string>();
if (states.IsNullOrEmpty()) return;

var convDir = FindConversationDirectory(conversationId);
if (!string.IsNullOrEmpty(convDir))
{
var stateDir = Path.Combine(convDir, "state.dict");
if (File.Exists(stateDir))
var stateFile = Path.Combine(convDir, "state.dict");
if (File.Exists(stateFile))
{
foreach (var data in states)
{
localStates.Add($"{data.Key}={data.Value}");
}
File.WriteAllLines(stateDir, localStates);
var stateStr = JsonSerializer.Serialize(states, _options);
File.WriteAllText(stateFile, stateStr);
}
}
}
Expand Down Expand Up @@ -796,8 +794,13 @@ public Conversation GetConversation(string conversationId)
var stateFile = Path.Combine(convDir, "state.dict");
if (record != null)
{
var states = CollectConversationStates(stateFile);
record.States = new ConversationState(states);
var historyStates = CollectConversationStates(stateFile);
var recentStates = historyStates.Select(x => new StateKeyValue
{
Key = x.Key,
Value = x.Values.LastOrDefault()?.Data ?? string.Empty
}).ToList();
record.States = new ConversationState(recentStates);
}

return record;
Expand Down Expand Up @@ -1082,18 +1085,16 @@ private List<string> ParseDialogElements(List<DialogElement> dialogs)
return dialogTexts;
}

private List<StateKeyValue> CollectConversationStates(string stateDir)
private List<HistoryStateKeyValue> CollectConversationStates(string stateFile)
{
var states = new List<StateKeyValue>();
if (!File.Exists(stateDir)) return states;
var states = new List<HistoryStateKeyValue>();
if (!File.Exists(stateFile)) return states;

var dict = File.ReadAllLines(stateDir);
foreach (var line in dict)
{
var data = line.Split('=');
states.Add(new StateKeyValue(data[0], data[1]));
}
return states;
var stateStr = File.ReadAllText(stateFile);
if (string.IsNullOrEmpty(stateStr)) return states;

states = JsonSerializer.Deserialize<List<HistoryStateKeyValue>>(stateStr, _options);
return states ?? new List<HistoryStateKeyValue>();
}

private int GetNextLlmCompletionLogIndex(string logDir, string id)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using BotSharp.Abstraction.Conversations.Models;

namespace BotSharp.Plugin.MongoStorage.Collections;

public class ConversationDocument : MongoBase
Expand All @@ -9,7 +7,6 @@ public class ConversationDocument : MongoBase
public string Title { get; set; }
public string Channel { get; set; }
public string Status { get; set; }
public List<StateKeyValue> States { get; set; }
public DateTime CreatedTime { get; set; }
public DateTime UpdatedTime { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using BotSharp.Plugin.MongoStorage.Models;

namespace BotSharp.Plugin.MongoStorage.Collections;

public class ConversationStateDocument : MongoBase
{
public string ConversationId { get; set; }
public List<StateMongoElement> States { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ namespace BotSharp.Plugin.MongoStorage.Collections;
public class LlmCompletionLogDocument : MongoBase
{
public string ConversationId { get; set; }
public List<PromptLogElement> Logs { get; set; }
public List<PromptLogMongoElement> Logs { get; set; }
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
namespace BotSharp.Plugin.MongoStorage.Models;

public class PromptLogElement
public class PromptLogMongoElement
{
public string MessageId { get; set; }
public string AgentId { get; set; }
Expand Down
Loading