Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

minor change #623

Merged
merged 1 commit into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 10 additions & 25 deletions src/Plugins/BotSharp.Plugin.Planner/Functions/PrimaryStagePlanFn.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
using BotSharp.Abstraction.Conversations.Models;
using BotSharp.Abstraction.Functions;
using BotSharp.Abstraction.Knowledges;
using BotSharp.Abstraction.Knowledges.Models;
using BotSharp.Abstraction.Routing;
using BotSharp.Plugin.Planner.TwoStaging.Models;

namespace BotSharp.Plugin.Planner.Functions;
Expand All @@ -22,7 +18,6 @@ public PrimaryStagePlanFn(IServiceProvider services, ILogger<PrimaryStagePlanFn>

public async Task<bool> Execute(RoleDialogModel message)
{
// Debug
var agentService = _services.GetRequiredService<IAgentService>();
var state = _services.GetRequiredService<IConversationStateService>();
var knowledgeService = _services.GetRequiredService<IKnowledgeService>();
Expand All @@ -48,9 +43,9 @@ public async Task<bool> Execute(RoleDialogModel message)
knowledges.Add(retrievalMessage.Content);
}

// Send knowledge to AI to refine and summarize the primary planning
// Get first stage planning prompt
var currentAgent = await agentService.LoadAgent(message.CurrentAgentId);
var firstPlanningPrompt = await GetFirstStagePlanPrompt(task, knowledges);
var firstPlanningPrompt = await GetFirstStagePlanPrompt(task.Requirements, knowledges);
var plannerAgent = new Agent
{
Id = BuiltInAgentId.Planner,
Expand All @@ -65,21 +60,22 @@ public async Task<bool> Execute(RoleDialogModel message)
return true;
}

private async Task<string> GetFirstStagePlanPrompt(PrimaryRequirementRequest task, List<string> relevantKnowledges)
private async Task<string> GetFirstStagePlanPrompt(string taskDescription, List<string> relevantKnowledges)
{
var agentService = _services.GetRequiredService<IAgentService>();
var render = _services.GetRequiredService<ITemplateRender>();
var knowledgeHooks = _services.GetServices<IKnowledgeHook>();

var aiAssistant = await agentService.GetAgent(BuiltInAgentId.Planner);
var template = aiAssistant.Templates.FirstOrDefault(x => x.Name == "two_stage.1st.plan")?.Content ?? string.Empty;
var responseFormat = JsonSerializer.Serialize(new FirstStagePlan
{
Parameters = [JsonDocument.Parse("{}")],
Results = [""]
Parameters = [ JsonDocument.Parse("{}") ],
Results = [ string.Empty ]
});

// Get global knowledges
var globalKnowledges = new List<string>();
var knowledgeHooks = _services.GetServices<IKnowledgeHook>();
foreach (var hook in knowledgeHooks)
{
var k = await hook.GetGlobalKnowledges();
Expand All @@ -88,7 +84,7 @@ private async Task<string> GetFirstStagePlanPrompt(PrimaryRequirementRequest tas

return render.Render(template, new Dictionary<string, object>
{
{ "task_description", task.Requirements },
{ "task_description", taskDescription },
{ "global_knowledges", globalKnowledges },
{ "relevant_knowledges", relevantKnowledges },
{ "response_format", responseFormat }
Expand All @@ -100,19 +96,8 @@ private async Task<RoleDialogModel> GetAiResponse(Agent plannerAgent)
var conv = _services.GetRequiredService<IConversationService>();
var wholeDialogs = conv.GetDialogHistory();

//add "test" to wholeDialogs' last element
if(plannerAgent.Name == "planner_summary")
{
//add "test" to wholeDialogs' last element in a new paragraph
wholeDialogs.Last().Content += "\n\nIf the table structure didn't mention auto incremental, the data field id needs to insert id manually and you need to use max(id) instead of LAST_INSERT_ID function.\nFor example, you should use SET @id = select max(id) from table;";
wholeDialogs.Last().Content += "\n\nTry if you can generate a single query to fulfill the needs";
}

if (plannerAgent.Name == "planning_1st")
{
//add "test" to wholeDialogs' last element in a new paragraph
wholeDialogs.Last().Content += "\n\nYou must analyze the table description to infer the table relations.";
}
// Append text
wholeDialogs.Last().Content += "\n\nYou must analyze the table description to infer the table relations.";

var completion = CompletionProvider.GetChatCompletion(_services,
provider: plannerAgent.LlmConfig.Provider,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public SecondaryStagePlanFn(IServiceProvider services, ILogger<SecondaryStagePla
public async Task<bool> Execute(RoleDialogModel message)
{
var fn = _services.GetRequiredService<IRoutingService>();
var agentService = _services.GetRequiredService<IAgentService>();
var knowledgeService = _services.GetRequiredService<IKnowledgeService>();
var knowledgeSettings = _services.GetRequiredService<KnowledgeBaseSettings>();
var collectionName = knowledgeSettings.Default.CollectionName ?? KnowledgeCollectionName.BotSharp;
Expand All @@ -33,6 +34,7 @@ public async Task<bool> Execute(RoleDialogModel message)
var taskSecondary = JsonSerializer.Deserialize<SecondaryBreakdownTask>(msgSecondary.FunctionArgs);
var items = msgSecondary.Content.JsonArrayContent<FirstStagePlan>();

// Search knowledgebase
foreach (var item in items)
{
if (!item.NeedAdditionalInformation) continue;
Expand All @@ -44,17 +46,15 @@ public async Task<bool> Execute(RoleDialogModel message)
message.Content += string.Join("\r\n\r\n=====\r\n", knowledges.Select(x => x.ToQuestionAnswer()));
}

// load agent
var agentService = _services.GetRequiredService<IAgentService>();
// Get second stage planning prompt
var currentAgent = await agentService.LoadAgent(message.CurrentAgentId);

var secondPlanningPrompt = await GetSecondStagePlanPrompt(taskSecondary, message);
var secondPlanningPrompt = await GetSecondStagePlanPrompt(taskSecondary.TaskDescription, message);
_logger.LogInformation(secondPlanningPrompt);

var plannerAgent = new Agent
{
Id = string.Empty,
Name = "test",
Name = "planning_2nd",
Instruction = secondPlanningPrompt,
TemplateDict = new Dictionary<string, object>(),
LlmConfig = currentAgent.LlmConfig
Expand All @@ -65,7 +65,8 @@ public async Task<bool> Execute(RoleDialogModel message)
_logger.LogInformation(response.Content);
return true;
}
private async Task<string> GetSecondStagePlanPrompt(SecondaryBreakdownTask task, RoleDialogModel message)

private async Task<string> GetSecondStagePlanPrompt(string taskDescription, RoleDialogModel message)
{
var agentService = _services.GetRequiredService<IAgentService>();
var render = _services.GetRequiredService<ITemplateRender>();
Expand All @@ -75,17 +76,18 @@ private async Task<string> GetSecondStagePlanPrompt(SecondaryBreakdownTask task,
var responseFormat = JsonSerializer.Serialize(new SecondStagePlan
{
Tool = "tool name if task solution provided",
Parameters = new JsonDocument[] { JsonDocument.Parse("{}") },
Results = new string[] { "" }
Parameters = [ JsonDocument.Parse("{}") ],
Results = [ string.Empty ]
});

return render.Render(template, new Dictionary<string, object>
{
{ "task_description", task.TaskDescription },
{ "task_description", taskDescription },
{ "primary_plan", new[]{ message.Content } },
{ "response_format", responseFormat }
});
}

private async Task<RoleDialogModel> GetAiResponse(Agent plannerAgent)
{
var conv = _services.GetRequiredService<IConversationService>();
Expand Down
34 changes: 11 additions & 23 deletions src/Plugins/BotSharp.Plugin.Planner/Functions/SummaryPlanFn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,26 @@ public async Task<bool> Execute(RoleDialogModel message)
var currentAgent = await agentService.LoadAgent(message.CurrentAgentId);
state.SetState("max_tokens", "4096");

var task = state.GetState("requirement_detail");
var taskRequirement = state.GetState("requirement_detail");

// Get DDL
// Get table names
var steps = message.Content.JsonArrayContent<SecondStagePlan>();

// Get all the related tables
var allTables = new List<string>();
foreach (var step in steps)
{
allTables.AddRange(step.Tables);
}
message.Data = allTables.Distinct().ToList();

// Get table DDL and stores in content
// Get table DDL statements
var msgCopy = RoleDialogModel.From(message);
await fn.InvokeFunction("get_table_definition", msgCopy);
var ddlStatements = msgCopy.Content;
var relevantKnowledge = message.Content;
message.Data = null;

// Summarize and generate query
var summaryPlanPrompt = await GetPlanSummaryPrompt(task, message.Content, ddlStatements);
var summaryPlanPrompt = await GetPlanSummaryPrompt(taskRequirement, relevantKnowledge, ddlStatements);
_logger.LogInformation($"Summary plan prompt:\r\n{summaryPlanPrompt}");

var plannerAgent = new Agent
Expand All @@ -64,9 +63,8 @@ public async Task<bool> Execute(RoleDialogModel message)
return true;
}

private async Task<string> GetPlanSummaryPrompt(string task, string knowledge, string ddlStatement)
private async Task<string> GetPlanSummaryPrompt(string taskDescription, string relevantKnowledge, string ddlStatement)
{
// save to knowledge base
var agentService = _services.GetRequiredService<IAgentService>();
var render = _services.GetRequiredService<ITemplateRender>();

Expand All @@ -81,8 +79,8 @@ private async Task<string> GetPlanSummaryPrompt(string task, string knowledge, s
return render.Render(template, new Dictionary<string, object>
{
{ "table_structure", ddlStatement },
{ "task_description", task },
{ "relevant_knowledges", knowledge },
{ "task_description", taskDescription },
{ "relevant_knowledges", relevantKnowledge },
{ "response_format", responseFormat }
});
}
Expand All @@ -91,19 +89,9 @@ private async Task<RoleDialogModel> GetAiResponse(Agent plannerAgent)
var conv = _services.GetRequiredService<IConversationService>();
var wholeDialogs = conv.GetDialogHistory();

// Add "test" to wholeDialogs' last element
if (plannerAgent.Name == "planner_summary")
{
// Add "test" to wholeDialogs' last element in a new paragraph
wholeDialogs.Last().Content += "\n\nIf the table structure didn't mention auto incremental, the data field id needs to insert id manually and you need to use max(id) instead of LAST_INSERT_ID function.\nFor example, you should use SET @id = select max(id) from table;";
wholeDialogs.Last().Content += "\n\nTry if you can generate a single query to fulfill the needs";
}

if (plannerAgent.Name == "planning_1st")
{
// Add "test" to wholeDialogs' last element in a new paragraph
wholeDialogs.Last().Content += "\n\nYou must analyze the table description to infer the table relations.";
}
// Append text
wholeDialogs.Last().Content += "\n\nIf the table structure didn't mention auto incremental, the data field id needs to insert id manually and you need to use max(id) instead of LAST_INSERT_ID function.\nFor example, you should use SET @id = select max(id) from table;";
wholeDialogs.Last().Content += "\n\nTry if you can generate a single query to fulfill the needs";

var completion = CompletionProvider.GetChatCompletion(_services,
provider: plannerAgent.LlmConfig.Provider,
Expand Down