From 9bbea3bece958d7a8d0c00d6c85fbbc3366c57c5 Mon Sep 17 00:00:00 2001 From: Jicheng Lu <103353@smsassist.com> Date: Fri, 6 Dec 2024 00:57:46 -0600 Subject: [PATCH] add common agent hook --- .../Agents/AgentHookBase.cs | 73 ---------------- .../BotSharp.Core/Agents/AgentPlugin.cs | 2 + .../Agents/Hooks/CommonAgentHook.cs | 83 +++++++++++++++++++ 3 files changed, 85 insertions(+), 73 deletions(-) create mode 100644 src/Infrastructure/BotSharp.Core/Agents/Hooks/CommonAgentHook.cs diff --git a/src/Infrastructure/BotSharp.Abstraction/Agents/AgentHookBase.cs b/src/Infrastructure/BotSharp.Abstraction/Agents/AgentHookBase.cs index 867d9d78a..a3746ca8a 100644 --- a/src/Infrastructure/BotSharp.Abstraction/Agents/AgentHookBase.cs +++ b/src/Infrastructure/BotSharp.Abstraction/Agents/AgentHookBase.cs @@ -1,10 +1,5 @@ using BotSharp.Abstraction.Agents.Settings; -using BotSharp.Abstraction.Conversations; using BotSharp.Abstraction.Functions.Models; -using BotSharp.Abstraction.Repositories; -using BotSharp.Abstraction.Routing; -using Microsoft.Extensions.DependencyInjection; -using System.Data; namespace BotSharp.Abstraction.Agents; @@ -60,73 +55,5 @@ public virtual void OnAgentLoaded(Agent agent) public virtual void OnAgentUtilityLoaded(Agent agent) { - if (agent.Type == AgentType.Routing) return; - - var conv = _services.GetRequiredService(); - var isConvMode = conv.IsConversationMode(); - if (!isConvMode) return; - - agent.Functions ??= []; - agent.Utilities ??= []; - - var (functions, templates) = GetUtilityContent(agent); - - foreach (var fn in functions) - { - if (!agent.Functions.Any(x => x.Name.Equals(fn.Name, StringComparison.OrdinalIgnoreCase))) - { - agent.Functions.Add(fn); - } - } - - foreach (var prompt in templates) - { - agent.Instruction += $"\r\n\r\n{prompt}\r\n\r\n"; - } - } - - private (IEnumerable, IEnumerable) GetUtilityContent(Agent agent) - { - var db = _services.GetRequiredService(); - var (functionNames, templateNames) = GetUniqueContent(agent.Utilities); - - if (agent.MergeUtility) - { - var routing = _services.GetRequiredService(); - var entryAgentId = routing.EntryAgentId; - if (!string.IsNullOrEmpty(entryAgentId)) - { - var entryAgent = db.GetAgent(entryAgentId); - var (fns, tps) = GetUniqueContent(entryAgent?.Utilities); - functionNames = functionNames.Concat(fns).Distinct().ToList(); - templateNames = templateNames.Concat(tps).Distinct().ToList(); - } - } - - var ua = db.GetAgent(BuiltInAgentId.UtilityAssistant); - var functions = ua?.Functions?.Where(x => functionNames.Contains(x.Name, StringComparer.OrdinalIgnoreCase))?.ToList() ?? []; - var templates = ua?.Templates?.Where(x => templateNames.Contains(x.Name, StringComparer.OrdinalIgnoreCase))?.Select(x => x.Content)?.ToList() ?? []; - return (functions, templates); - } - - private (IEnumerable, IEnumerable) GetUniqueContent(IEnumerable? utilities) - { - if (utilities.IsNullOrEmpty()) - { - return ([], []); - } - - var prefix = "util-"; - utilities = utilities?.Where(x => !string.IsNullOrEmpty(x.Name) && !x.Disabled)?.ToList() ?? []; - var functionNames = utilities.SelectMany(x => x.Functions) - .Where(x => !string.IsNullOrEmpty(x.Name) && x.Name.StartsWith(prefix)) - .Select(x => x.Name) - .Distinct().ToList(); - var templateNames = utilities.SelectMany(x => x.Templates) - .Where(x => !string.IsNullOrEmpty(x.Name) && x.Name.StartsWith(prefix)) - .Select(x => x.Name) - .Distinct().ToList(); - - return (functionNames, templateNames); } } diff --git a/src/Infrastructure/BotSharp.Core/Agents/AgentPlugin.cs b/src/Infrastructure/BotSharp.Core/Agents/AgentPlugin.cs index cfbc541f2..51b0da05e 100644 --- a/src/Infrastructure/BotSharp.Core/Agents/AgentPlugin.cs +++ b/src/Infrastructure/BotSharp.Core/Agents/AgentPlugin.cs @@ -3,6 +3,7 @@ using BotSharp.Abstraction.Settings; using BotSharp.Abstraction.Templating; using BotSharp.Abstraction.Users.Enums; +using BotSharp.Core.Agents.Hooks; using Microsoft.Extensions.Configuration; namespace BotSharp.Core.Agents; @@ -30,6 +31,7 @@ public void RegisterDI(IServiceCollection services, IConfiguration config) { services.AddScoped(); services.AddScoped(); + services.AddScoped(); services.AddScoped(provider => { diff --git a/src/Infrastructure/BotSharp.Core/Agents/Hooks/CommonAgentHook.cs b/src/Infrastructure/BotSharp.Core/Agents/Hooks/CommonAgentHook.cs new file mode 100644 index 000000000..78420fad7 --- /dev/null +++ b/src/Infrastructure/BotSharp.Core/Agents/Hooks/CommonAgentHook.cs @@ -0,0 +1,83 @@ +namespace BotSharp.Core.Agents.Hooks; + +public class CommonAgentHook : AgentHookBase +{ + public override string SelfId => string.Empty; + + public CommonAgentHook(IServiceProvider services, AgentSettings settings) + : base(services, settings) + { + } + + public override void OnAgentUtilityLoaded(Agent agent) + { + if (agent.Type == AgentType.Routing) return; + + var conv = _services.GetRequiredService(); + var isConvMode = conv.IsConversationMode(); + if (!isConvMode) return; + + agent.Functions ??= []; + agent.Utilities ??= []; + + var (functions, templates) = GetUtilityContent(agent); + + foreach (var fn in functions) + { + if (!agent.Functions.Any(x => x.Name.Equals(fn.Name, StringComparison.OrdinalIgnoreCase))) + { + agent.Functions.Add(fn); + } + } + + foreach (var prompt in templates) + { + agent.Instruction += $"\r\n\r\n{prompt}\r\n\r\n"; + } + } + + private (IEnumerable, IEnumerable) GetUtilityContent(Agent agent) + { + var db = _services.GetRequiredService(); + var (functionNames, templateNames) = GetUniqueContent(agent.Utilities); + + if (agent.MergeUtility) + { + var routing = _services.GetRequiredService(); + var entryAgentId = routing.EntryAgentId; + if (!string.IsNullOrEmpty(entryAgentId)) + { + var entryAgent = db.GetAgent(entryAgentId); + var (fns, tps) = GetUniqueContent(entryAgent?.Utilities); + functionNames = functionNames.Concat(fns).Distinct().ToList(); + templateNames = templateNames.Concat(tps).Distinct().ToList(); + } + } + + var ua = db.GetAgent(BuiltInAgentId.UtilityAssistant); + var functions = ua?.Functions?.Where(x => functionNames.Contains(x.Name, StringComparer.OrdinalIgnoreCase))?.ToList() ?? []; + var templates = ua?.Templates?.Where(x => templateNames.Contains(x.Name, StringComparer.OrdinalIgnoreCase))?.Select(x => x.Content)?.ToList() ?? []; + return (functions, templates); + } + + private (IEnumerable, IEnumerable) GetUniqueContent(IEnumerable? utilities) + { + if (utilities.IsNullOrEmpty()) + { + return ([], []); + } + + var prefix = "util-"; + utilities = utilities?.Where(x => !string.IsNullOrEmpty(x.Name) && !x.Disabled)?.ToList() ?? []; + var functionNames = utilities.SelectMany(x => x.Functions) + .Where(x => !string.IsNullOrEmpty(x.Name) && x.Name.StartsWith(prefix)) + .Select(x => x.Name) + .Distinct().ToList(); + var templateNames = utilities.SelectMany(x => x.Templates) + .Where(x => !string.IsNullOrEmpty(x.Name) && x.Name.StartsWith(prefix)) + .Select(x => x.Name) + .Distinct().ToList(); + + return (functionNames, templateNames); + } +}