Skip to content

Commit ea3bf3e

Browse files
authored
Merge pull request #187 from hchen2020/master
Use text completion for instruct mode.
2 parents cf98805 + 8c60e6a commit ea3bf3e

File tree

15 files changed

+86
-153
lines changed

15 files changed

+86
-153
lines changed

src/Infrastructure/BotSharp.Abstraction/Instructs/IInstructService.cs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,5 @@ namespace BotSharp.Abstraction.Instructs;
44

55
public interface IInstructService
66
{
7-
Task<InstructResult> ExecuteInstruction(Agent agent,
8-
RoleDialogModel message,
9-
Func<RoleDialogModel, Task> onMessageReceived,
10-
Func<RoleDialogModel, Task> onFunctionExecuting,
11-
Func<RoleDialogModel, Task> onFunctionExecuted);
7+
Task<InstructResult> Execute(Agent agent, RoleDialogModel message);
128
}

src/Infrastructure/BotSharp.Abstraction/Instructs/Models/InstructResult.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,5 @@ namespace BotSharp.Abstraction.Instructs.Models;
33
public class InstructResult
44
{
55
public string Text { get; set; }
6-
public string Function { get; set; }
76
public object Data { get; set; }
87
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
namespace BotSharp.Abstraction.MLTasks.Settings;
2+
3+
public class ChatCompletionSetting
4+
{
5+
public string Provider { get; set; }
6+
public string Model { get; set; }
7+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
namespace BotSharp.Abstraction.MLTasks.Settings;
2+
3+
public class TextCompletionSetting
4+
{
5+
public string Provider { get; set; }
6+
public string Model { get; set; }
7+
}

src/Infrastructure/BotSharp.Core/BotSharpServiceCollectionExtensions.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
using BotSharp.Abstraction.Evaluations;
1717
using BotSharp.Core.Evaluatings;
1818
using BotSharp.Core.Evaluations;
19+
using BotSharp.Abstraction.MLTasks.Settings;
1920

2021
namespace BotSharp.Core;
2122

@@ -48,6 +49,14 @@ public static IServiceCollection AddBotSharp(this IServiceCollection services, I
4849
config.Bind("Database", myDatabaseSettings);
4950
services.AddSingleton((IServiceProvider x) => myDatabaseSettings);
5051

52+
var textCompletionSettings = new TextCompletionSetting();
53+
config.Bind("TextCompletion", textCompletionSettings);
54+
services.AddSingleton((IServiceProvider x) => textCompletionSettings);
55+
56+
var chatCompletionSettings = new ChatCompletionSetting();
57+
config.Bind("ChatCompletion", chatCompletionSettings);
58+
services.AddSingleton((IServiceProvider x) => chatCompletionSettings);
59+
5160
RegisterPlugins(services, config);
5261

5362
// Register template render

src/Infrastructure/BotSharp.Core/Infrastructures/CompletionProvider.cs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,25 @@
11
using BotSharp.Abstraction.MLTasks;
2+
using BotSharp.Abstraction.MLTasks.Settings;
23

34
namespace BotSharp.Core.Infrastructures;
45

56
public class CompletionProvider
67
{
78
public static IChatCompletion GetChatCompletion(IServiceProvider services, string? provider = null, string? model = null)
89
{
10+
var settings = services.GetRequiredService<ChatCompletionSetting>();
911
var completions = services.GetServices<IChatCompletion>();
1012

1113
var state = services.GetRequiredService<IConversationStateService>();
1214

1315
if (string.IsNullOrEmpty(provider))
1416
{
15-
provider = state.GetState("provider", "azure-openai");
17+
provider = state.GetState("provider", settings.Provider ?? "azure-openai");
1618
}
1719

1820
if (string.IsNullOrEmpty(model))
1921
{
20-
model = state.GetState("model", "gpt-3.5-turbo");
22+
model = state.GetState("model", settings.Model ?? "gpt-3.5-turbo");
2123
}
2224

2325
var completer = completions.FirstOrDefault(x => x.Provider == provider);
@@ -34,18 +36,19 @@ public static IChatCompletion GetChatCompletion(IServiceProvider services, strin
3436

3537
public static ITextCompletion GetTextCompletion(IServiceProvider services, string? provider = null, string? model = null)
3638
{
39+
var settings = services.GetRequiredService<TextCompletionSetting>();
3740
var completions = services.GetServices<ITextCompletion>();
3841

3942
var state = services.GetRequiredService<IConversationStateService>();
4043

4144
if (string.IsNullOrEmpty(provider))
4245
{
43-
provider = state.GetState("provider", "azure-openai");
46+
provider = state.GetState("provider", settings.Provider ?? "azure-openai");
4447
}
4548

4649
if (string.IsNullOrEmpty(model))
4750
{
48-
model = state.GetState("model", "gpt-3.5-turbo");
51+
model = state.GetState("model", settings.Model ?? "gpt-3.5-turbo");
4952
}
5053

5154
var completer = completions.FirstOrDefault(x => x.Provider == provider);
Lines changed: 8 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
1-
using BotSharp.Abstraction.Agents.Enums;
21
using BotSharp.Abstraction.Agents.Models;
32
using BotSharp.Abstraction.Instructs;
43
using BotSharp.Abstraction.Instructs.Models;
5-
using BotSharp.Abstraction.MLTasks;
6-
using BotSharp.Abstraction.Templating;
7-
using System.IO;
84

95
namespace BotSharp.Core.Instructs;
106

@@ -19,43 +15,22 @@ public InstructService(IServiceProvider services, ILogger<InstructService> logge
1915
_logger = logger;
2016
}
2117

22-
public async Task<InstructResult> ExecuteInstruction(Agent agent,
23-
RoleDialogModel message,
24-
Func<RoleDialogModel, Task> onMessageReceived,
25-
Func<RoleDialogModel, Task> onFunctionExecuting,
26-
Func<RoleDialogModel, Task> onFunctionExecuted)
18+
public async Task<InstructResult> Execute(Agent agent, RoleDialogModel message)
2719
{
28-
var response = new InstructResult();
29-
30-
var wholeDialogs = new List<RoleDialogModel>
31-
{
32-
message
33-
};
34-
3520
// Trigger before completion hooks
3621
var hooks = _services.GetServices<IInstructHook>();
3722
foreach (var hook in hooks)
3823
{
3924
await hook.BeforeCompletion(message);
4025
}
4126

42-
await ExecuteInstructionRecursively(agent,
43-
wholeDialogs,
44-
async msg =>
45-
{
46-
response.Text = msg.Content;
47-
await onMessageReceived(msg);
48-
},
49-
async fn =>
50-
{
51-
response.Function = fn.FunctionName;
52-
await onFunctionExecuting(fn);
53-
},
54-
async fn =>
55-
{
56-
response.Data = fn.Data;
57-
await onFunctionExecuted(fn);
58-
});
27+
var completer = CompletionProvider.GetTextCompletion(_services);
28+
29+
var result = await completer.GetCompletion(agent.Instruction);
30+
var response = new InstructResult
31+
{
32+
Text = result
33+
};
5934

6035
foreach (var hook in hooks)
6136
{
@@ -64,64 +39,4 @@ await ExecuteInstructionRecursively(agent,
6439

6540
return response;
6641
}
67-
68-
private async Task<bool> ExecuteInstructionRecursively(Agent agent,
69-
List<RoleDialogModel> wholeDialogs,
70-
Func<RoleDialogModel, Task> onMessageReceived,
71-
Func<RoleDialogModel, Task> onFunctionExecuting,
72-
Func<RoleDialogModel, Task> onFunctionExecuted)
73-
{
74-
var chatCompletion = CompletionProvider.GetChatCompletion(_services);
75-
76-
var result = await chatCompletion.GetChatCompletionsAsync(agent, wholeDialogs, async msg =>
77-
{
78-
await onMessageReceived(msg);
79-
80-
wholeDialogs.Add(msg);
81-
}, async fn =>
82-
{
83-
var preAgentId = agent.Id;
84-
85-
await HandleFunctionMessage(fn, onFunctionExecuting, onFunctionExecuted);
86-
87-
// Function executed has exception
88-
if (fn.Content == null || fn.StopCompletion)
89-
{
90-
await onMessageReceived(new RoleDialogModel(AgentRole.Assistant, fn.Content));
91-
return;
92-
}
93-
94-
fn.Content = fn.FunctionArgs.Replace("\r", " ").Replace("\n", " ").Trim() + " => " + fn.Content;
95-
96-
// Find response template
97-
var templateService = _services.GetRequiredService<IResponseTemplateService>();
98-
var response = await templateService.RenderFunctionResponse(agent.Id, fn);
99-
if (!string.IsNullOrEmpty(response))
100-
{
101-
await onMessageReceived(new RoleDialogModel(AgentRole.Assistant, response));
102-
return;
103-
}
104-
105-
// After function is executed, pass the result to LLM to get a natural response
106-
wholeDialogs.Add(fn);
107-
108-
await ExecuteInstructionRecursively(agent,
109-
wholeDialogs,
110-
onMessageReceived,
111-
onFunctionExecuting,
112-
onFunctionExecuted);
113-
});
114-
115-
return result;
116-
}
117-
118-
private async Task HandleFunctionMessage(RoleDialogModel msg,
119-
Func<RoleDialogModel, Task> onFunctionExecuting,
120-
Func<RoleDialogModel, Task> onFunctionExecuted)
121-
{
122-
// Call functions
123-
await onFunctionExecuting(msg);
124-
await CallFunctions(msg);
125-
await onFunctionExecuted(msg);
126-
}
12742
}

src/Infrastructure/BotSharp.OpenAPI/Controllers/InstructModeController.cs

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using BotSharp.Abstraction.Conversations.Models;
55
using BotSharp.Abstraction.Instructs;
66
using BotSharp.Abstraction.Instructs.Models;
7+
using BotSharp.Abstraction.Templating;
78
using BotSharp.Core.Infrastructures;
89
using BotSharp.OpenAPI.ViewModels.Instructs;
910

@@ -24,34 +25,37 @@ public InstructModeController(IServiceProvider services)
2425
public async Task<InstructResult> InstructCompletion([FromRoute] string agentId,
2526
[FromBody] InstructMessageModel input)
2627
{
27-
var instructor = _services.GetRequiredService<IInstructService>();
28+
var state = _services.GetRequiredService<IConversationStateService>();
29+
input.States.ForEach(x => state.SetState(x.Split('=')[0], x.Split('=')[1]));
30+
state.SetState("provider", input.Provider)
31+
.SetState("model", input.Model)
32+
.SetState("input_text", input.Text);
33+
2834
var agentService = _services.GetRequiredService<IAgentService>();
2935
Agent agent = await agentService.LoadAgent(agentId);
3036

3137
// switch to different instruction template
3238
if (!string.IsNullOrEmpty(input.Template))
3339
{
34-
agent.Instruction = agent.Templates.First(x => x.Name == input.Template).Content;
40+
var template = agent.Templates.First(x => x.Name == input.Template).Content;
41+
var render = _services.GetRequiredService<ITemplateRender>();
42+
var dict = new Dictionary<string, object>();
43+
state.GetStates().Select(x => dict[x.Key] = x.Value).ToArray();
44+
var prompt = render.Render(template, dict);
45+
agent.Instruction = prompt;
3546
}
3647

37-
var conv = _services.GetRequiredService<IConversationService>();
38-
input.States.ForEach(x => conv.States.SetState(x.Split('=')[0], x.Split('=')[1]));
39-
conv.States.SetState("provider", input.Provider)
40-
.SetState("model", input.Model);
41-
42-
return await instructor.ExecuteInstruction(agent,
43-
new RoleDialogModel(AgentRole.User, input.Text),
44-
fn => Task.CompletedTask,
45-
fn => Task.CompletedTask,
46-
fn => Task.CompletedTask);
48+
var instructor = _services.GetRequiredService<IInstructService>();
49+
return await instructor.Execute(agent,
50+
new RoleDialogModel(AgentRole.User, input.Text));
4751
}
4852

4953
[HttpPost("/instruct/text-completion")]
5054
public async Task<string> TextCompletion([FromBody] IncomingMessageModel input)
5155
{
52-
var conv = _services.GetRequiredService<IConversationService>();
53-
input.States.ForEach(x => conv.States.SetState(x.Split('=')[0], x.Split('=')[1]));
54-
conv.States.SetState("provider", input.Provider)
56+
var state = _services.GetRequiredService<IConversationStateService>();
57+
input.States.ForEach(x => state.SetState(x.Split('=')[0], x.Split('=')[1]));
58+
state.SetState("provider", input.Provider)
5559
.SetState("model", input.Model);
5660

5761
var textCompletion = CompletionProvider.GetTextCompletion(_services);

src/Plugins/BotSharp.Plugin.AzureOpenAI/AzureOpenAiPlugin.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public void RegisterDI(IServiceCollection services, IConfiguration config)
2424
config.Bind("AzureOpenAi", settings);
2525
services.AddSingleton(x =>
2626
{
27-
Console.WriteLine($"Loaded AzureOpenAi settings: {settings.DeploymentModel} ({settings.Endpoint}) {settings.ApiKey.SubstringMax(4)}");
27+
Console.WriteLine($"Loaded AzureOpenAi settings: ({settings.Endpoint}) {settings.ApiKey.SubstringMax(4)}");
2828
return settings;
2929
});
3030

src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/ChatCompletionProvider.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ public RoleDialogModel GetChatCompletions(Agent agent, List<RoleDialogModel> con
4343
Task.WaitAll(hooks.Select(hook =>
4444
hook.BeforeGenerating(agent, conversations)).ToArray());
4545

46-
var (client, deploymentModel) = ProviderHelper.GetClient(_model, _settings);
46+
var client = ProviderHelper.GetClient(_model, _settings);
4747
var chatCompletionsOptions = PrepareOptions(agent, conversations);
4848

49-
var response = client.GetChatCompletions(deploymentModel, chatCompletionsOptions);
49+
var response = client.GetChatCompletions(_model, chatCompletionsOptions);
5050
var choice = response.Value.Choices[0];
5151
var message = choice.Message;
5252

@@ -96,10 +96,10 @@ public async Task<bool> GetChatCompletionsAsync(Agent agent,
9696
Task.WaitAll(hooks.Select(hook =>
9797
hook.BeforeGenerating(agent, conversations)).ToArray());
9898

99-
var (client, deploymentModel) = ProviderHelper.GetClient(_model, _settings);
99+
var client = ProviderHelper.GetClient(_model, _settings);
100100
var chatCompletionsOptions = PrepareOptions(agent, conversations);
101101

102-
var response = await client.GetChatCompletionsAsync(deploymentModel, chatCompletionsOptions);
102+
var response = await client.GetChatCompletionsAsync(_model, chatCompletionsOptions);
103103
var choice = response.Value.Choices[0];
104104
var message = choice.Message;
105105

@@ -148,10 +148,10 @@ public async Task<bool> GetChatCompletionsAsync(Agent agent,
148148

149149
public async Task<bool> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations, Func<RoleDialogModel, Task> onMessageReceived)
150150
{
151-
var client = new OpenAIClient(new Uri(_settings.Endpoint), new AzureKeyCredential(_settings.ApiKey));
151+
var client = ProviderHelper.GetClient(_model, _settings);
152152
var chatCompletionsOptions = PrepareOptions(agent, conversations);
153153

154-
var response = await client.GetChatCompletionsStreamingAsync(_settings.DeploymentModel.ChatCompletionModel, chatCompletionsOptions);
154+
var response = await client.GetChatCompletionsStreamingAsync(_model, chatCompletionsOptions);
155155
using StreamingChatCompletions streaming = response.Value;
156156

157157
string output = "";

0 commit comments

Comments
 (0)