Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

25 changes: 25 additions & 0 deletions src/Infrastructure/BotSharp.Abstraction/Hooks/HookProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using BotSharp.Abstraction.Conversations;
using Microsoft.Extensions.DependencyInjection;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace BotSharp.Abstraction.Hooks
{
public static class HookProvider
{
public static List<T> GetHooks<T>(this IServiceProvider services, string agentId) where T : IHookBase
{
var hooks = services.GetServices<T>().Where(p => p.IsMatch(agentId));
return hooks.ToList();
}

public static List<T> GetHooksOrderByPriority<T>(this IServiceProvider services, string agentId) where T: IConversationHook
{
var hooks = services.GetServices<T>().Where(p => p.IsMatch(agentId));
return hooks.OrderBy(p => p.Priority).ToList();
}
}
}
2 changes: 1 addition & 1 deletion src/Infrastructure/BotSharp.Abstraction/Hooks/IHookBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ public interface IHookBase
/// Agent Id
/// </summary>
string SelfId => string.Empty;
bool IsMatch(string id) => string.IsNullOrEmpty(SelfId) || SelfId == id;
bool IsMatch(string agentId) => string.IsNullOrEmpty(SelfId) || SelfId == agentId;
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using BotSharp.Abstraction.Functions.Models;
using BotSharp.Abstraction.Hooks;
using BotSharp.Abstraction.Options;
using BotSharp.Core.Infrastructures;

Expand All @@ -23,7 +24,6 @@ public RealtimeHub(IServiceProvider services, ILogger<RealtimeHub> logger)

public async Task ConnectToModel(Func<string, Task>? responseToUser = null, Func<string, Task>? init = null)
{
var hookProvider = _services.GetService<ConversationHookProvider>();
var convService = _services.GetRequiredService<IConversationService>();
convService.SetConversationId(_conn.ConversationId, []);
var conversation = await convService.GetConversation(_conn.ConversationId);
Expand Down Expand Up @@ -105,7 +105,8 @@ await HookEmitter.Emit<IRoutingHook>(_services, async hook => await hook.OnRouti
dialogs.Add(message);
storage.Append(_conn.ConversationId, message);

foreach (var hook in hookProvider?.HooksOrderByPriority ?? [])
var hooks = _services.GetHooksOrderByPriority<IConversationHook>(_conn.CurrentAgentId);
foreach (var hook in hooks)
{
hook.SetAgent(agent)
.SetConversation(conversation);
Expand All @@ -126,7 +127,8 @@ await HookEmitter.Emit<IRoutingHook>(_services, async hook => await hook.OnRouti
storage.Append(_conn.ConversationId, message);
routing.Context.SetMessageId(_conn.ConversationId, message.MessageId);

foreach (var hook in hookProvider?.HooksOrderByPriority ?? [])
var hooks = _services.GetHooksOrderByPriority<IConversationHook>(_conn.CurrentAgentId);
foreach (var hook in hooks)
{
hook.SetAgent(agent)
.SetConversation(conversation);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using BotSharp.Abstraction.Hooks;
using BotSharp.Abstraction.Infrastructures.Enums;
using BotSharp.Abstraction.Messaging;
using BotSharp.Abstraction.Messaging.Models.RichContent;
Expand Down Expand Up @@ -29,7 +30,6 @@ public async Task<bool> SendMessage(string agentId,
var dialogs = conv.GetDialogHistory();

var statistics = _services.GetRequiredService<ITokenStatistics>();
var hookProvider = _services.GetRequiredService<ConversationHookProvider>();

RoleDialogModel response = message;
bool stopCompletion = false;
Expand All @@ -44,7 +44,8 @@ public async Task<bool> SendMessage(string agentId,
message.Payload = replyMessage.Payload;
}

foreach (var hook in hookProvider.HooksOrderByPriority)
var hooks = _services.GetHooksOrderByPriority<IConversationHook>(message.CurrentAgentId);
foreach (var hook in hooks)
{
hook.SetAgent(agent)
.SetConversation(conversation);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using BotSharp.Abstraction.Hooks;
using BotSharp.Abstraction.Infrastructures.Enums;

namespace BotSharp.Core.Conversations.Services;
Expand Down Expand Up @@ -31,9 +32,7 @@ public async Task UpdateBreakpoint(bool resetStates = false, string? reason = nu
states.CleanStates(excludedStates);
}

var hooks = _services
.GetRequiredService<ConversationHookProvider>()
.HooksOrderByPriority;
var hooks = _services.GetHooksOrderByPriority<IConversationHook>(routingCtx.GetCurrentAgentId());

// Before executing functions
foreach (var hook in hooks)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using BotSharp.Abstraction.Conversations.Enums;
using BotSharp.Abstraction.Hooks;
using BotSharp.Abstraction.Models;

namespace BotSharp.Core.Conversations.Services;
Expand Down Expand Up @@ -116,7 +117,7 @@ public async Task<Conversation> NewConversation(Conversation sess)

db.CreateNewConversation(record);

var hooks = _services.GetServices<IConversationHook>();
var hooks = _services.GetHooks<IConversationHook>(record.AgentId);

foreach (var hook in hooks)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.
******************************************************************************/

using BotSharp.Abstraction.Conversations.Enums;
using BotSharp.Abstraction.Hooks;
using BotSharp.Abstraction.Options;
using BotSharp.Abstraction.SideCar;

Expand All @@ -28,6 +29,7 @@ public class ConversationStateService : IConversationStateService
private readonly ILogger _logger;
private readonly IServiceProvider _services;
private readonly IBotSharpRepository _db;
private readonly IRoutingContext _routingContext;
private readonly IConversationSideCar? _sidecar;
private string _conversationId;
/// <summary>
Expand All @@ -42,10 +44,12 @@ public class ConversationStateService : IConversationStateService
public ConversationStateService(
IServiceProvider services,
IBotSharpRepository db,
IRoutingContext routingContext,
ILogger<ConversationStateService> logger)
{
_services = services;
_db = db;
_routingContext = routingContext;
_logger = logger;
_curStates = new ConversationState();
_historyStates = new ConversationState();
Expand Down Expand Up @@ -87,7 +91,6 @@ public IConversationStateService SetState<T>(string name, T value, bool isNeedVe
}

_logger.LogDebug($"[STATE] {name} = {value}");
var routingCtx = _services.GetRequiredService<IRoutingContext>();

var isNoChange = ContainsState(name)
&& preValue == currentValue
Expand All @@ -98,15 +101,15 @@ public IConversationStateService SetState<T>(string name, T value, bool isNeedVe
&& prevLeafNode?.Active == curActive
&& pair?.Readonly == readOnly;

var hooks = _services.GetServices<IConversationHook>();
var hooks = _services.GetHooks<IConversationHook>(_routingContext.GetCurrentAgentId());
if (!ContainsState(name) || preValue != currentValue || prevLeafNode?.ActiveRounds != curActiveRounds)
{
foreach (var hook in hooks)
{
hook.OnStateChanged(new StateChangeModel
{
ConversationId = _conversationId,
MessageId = routingCtx.MessageId,
MessageId = _routingContext.MessageId,
Name = name,
BeforeValue = preValue,
BeforeActiveRounds = prevLeafNode?.ActiveRounds,
Expand All @@ -129,7 +132,7 @@ public IConversationStateService SetState<T>(string name, T value, bool isNeedVe
var newValue = new StateValue
{
Data = currentValue,
MessageId = routingCtx.MessageId,
MessageId = _routingContext.MessageId,
Active = curActive,
ActiveRounds = curActiveRounds,
DataType = valueType,
Expand Down Expand Up @@ -171,8 +174,7 @@ public Dictionary<string, string> Load(string conversationId, bool isReadOnly =
return endNodes;
}

var routingCtx = _services.GetRequiredService<IRoutingContext>();
var curMsgId = routingCtx.MessageId;
var curMsgId = _routingContext.MessageId;
var dialogs = _db.GetConversationDialogs(conversationId);
var userDialogs = dialogs.Where(x => x.MetaData?.Role == AgentRole.User)
.GroupBy(x => x.MetaData?.MessageId)
Expand Down Expand Up @@ -225,7 +227,7 @@ public Dictionary<string, string> Load(string conversationId, bool isReadOnly =
}

_logger.LogInformation($"Loaded conversation states: {conversationId}");
var hooks = _services.GetServices<IConversationHook>();
var hooks = _services.GetHooks<IConversationHook>(_routingContext.GetCurrentAgentId());
foreach (var hook in hooks)
{
hook.OnStateLoaded(_curStates).Wait();
Expand Down Expand Up @@ -277,29 +279,28 @@ public bool RemoveState(string name)
{
if (!ContainsState(name)) return false;

var routingCtx = _services.GetRequiredService<IRoutingContext>();
var value = _curStates[name];
var leafNode = value?.Values?.LastOrDefault();
if (value == null || !value.Versioning || leafNode == null) return false;

_curStates[name].Values.Add(new StateValue
{
Data = leafNode.Data,
MessageId = routingCtx.MessageId,
MessageId = _routingContext.MessageId,
Active = false,
ActiveRounds = leafNode.ActiveRounds,
DataType = leafNode.DataType,
Source = leafNode.Source,
UpdateTime = DateTime.UtcNow
});

var hooks = _services.GetServices<IConversationHook>();
var hooks = _services.GetHooks<IConversationHook>(_routingContext.GetCurrentAgentId());
foreach (var hook in hooks)
{
hook.OnStateChanged(new StateChangeModel
{
ConversationId = _conversationId,
MessageId = routingCtx.MessageId,
MessageId = _routingContext.MessageId,
Name = name,
BeforeValue = leafNode.Data,
BeforeActiveRounds = leafNode.ActiveRounds,
Expand All @@ -316,8 +317,7 @@ public bool RemoveState(string name)

public void CleanStates(params string[] excludedStates)
{
var routingCtx = _services.GetRequiredService<IRoutingContext>();
var curMsgId = routingCtx.MessageId;
var curMsgId = _routingContext.MessageId;
var utcNow = DateTime.UtcNow;

foreach (var key in _curStates.Keys)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public static HookEmittedResult Emit<T>(IServiceProvider services, Action<T> act
{
var logger = services.GetRequiredService<ILogger<T>>();
var result = new HookEmittedResult();
var hooks = services.GetServices<T>().Where(p => p.IsMatch(agentId));
var hooks = services.GetHooks<T>(agentId);
option = option ?? new();

foreach (var hook in hooks)
Expand Down Expand Up @@ -40,7 +40,7 @@ public static async Task<HookEmittedResult> Emit<T>(IServiceProvider services, F
{
var logger = services.GetRequiredService<ILogger<T>>();
var result = new HookEmittedResult();
var hooks = services.GetServices<T>().Where(p => p.IsMatch(agentId));
var hooks = services.GetHooks<T>(agentId);
option = option ?? new();

foreach (var hook in hooks)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using BotSharp.Abstraction.Hooks;
using BotSharp.Abstraction.Instructs;
using BotSharp.Abstraction.Instructs.Models;
using BotSharp.Abstraction.MLTasks;
Expand All @@ -23,7 +24,7 @@ public async Task<InstructResult> Execute(string agentId, RoleDialogModel messag
}

// Trigger before completion hooks
var hooks = _services.GetServices<IInstructHook>().Where(p => p.IsMatch(agentId));
var hooks = _services.GetHooks<IInstructHook>(agentId);
foreach (var hook in hooks)
{
await hook.BeforeCompletion(agent, message);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using BotSharp.Abstraction.Functions;
using BotSharp.Abstraction.Hooks;

namespace BotSharp.Core.Routing.Functions;

Expand All @@ -15,9 +16,7 @@ public HumanInterventionNeededFn(IServiceProvider services)

public async Task<bool> Execute(RoleDialogModel message)
{
var hooks = _services
.GetRequiredService<ConversationHookProvider>()
.HooksOrderByPriority;
var hooks = _services.GetHooksOrderByPriority<IConversationHook>(message.CurrentAgentId);

foreach (var hook in hooks)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using BotSharp.Abstraction.Functions;
using BotSharp.Abstraction.Hooks;
using BotSharp.Abstraction.Templating;
using BotSharp.Core.Routing.Executor;

Expand Down Expand Up @@ -29,10 +30,6 @@ public async Task<bool> InvokeFunction(string name, RoleDialogModel message)
var clonedMessage = RoleDialogModel.From(message);
clonedMessage.FunctionName = name;

var hooks = _services
.GetRequiredService<ConversationHookProvider>()
.HooksOrderByPriority;

var progressService = _services.GetService<IConversationProgressService>();

clonedMessage.Indication = await funcExecutor.GetIndicatorAsync(message);
Expand All @@ -41,7 +38,8 @@ public async Task<bool> InvokeFunction(string name, RoleDialogModel message)
{
await progressService.OnFunctionExecuting(clonedMessage);
}


var hooks = _services.GetHooksOrderByPriority<IConversationHook>(clonedMessage.CurrentAgentId);
foreach (var hook in hooks)
{
hook.SetAgent(agent);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Anthropic.SDK.Common;
using BotSharp.Abstraction.Conversations;
using BotSharp.Abstraction.Hooks;
using BotSharp.Abstraction.MLTasks.Settings;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
Expand Down Expand Up @@ -29,7 +30,7 @@ public ChatCompletionProvider(AnthropicSettings settings,

public async Task<RoleDialogModel> GetChatCompletions(Agent agent, List<RoleDialogModel> conversations)
{
var contentHooks = _services.GetServices<IContentGeneratingHook>().ToList();
var contentHooks = _services.GetHooks<IContentGeneratingHook>(agent.Id);

// Before chat completion hook
foreach (var hook in contentHooks)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Azure;
using BotSharp.Abstraction.Files.Utilities;
using BotSharp.Abstraction.Hooks;
using OpenAI.Chat;
using System.ClientModel;

Expand Down Expand Up @@ -29,7 +30,7 @@ public ChatCompletionProvider(

public async Task<RoleDialogModel> GetChatCompletions(Agent agent, List<RoleDialogModel> conversations)
{
var contentHooks = _services.GetServices<IContentGeneratingHook>().ToList();
var contentHooks = _services.GetHooks<IContentGeneratingHook>(agent.Id);

// Before chat completion hook
foreach (var hook in contentHooks)
Expand Down Expand Up @@ -128,7 +129,7 @@ public async Task<bool> GetChatCompletionsAsync(Agent agent,
Func<RoleDialogModel, Task> onMessageReceived,
Func<RoleDialogModel, Task> onFunctionExecuting)
{
var hooks = _services.GetServices<IContentGeneratingHook>().ToList();
var hooks = _services.GetHooks<IContentGeneratingHook>(agent.Id);

// Before chat completion hook
foreach (var hook in hooks)
Expand Down
Loading
Loading