From 57122f7833cd73ceee1ce24289fe41a5d6497c63 Mon Sep 17 00:00:00 2001 From: Haiping Chen Date: Thu, 16 May 2024 16:32:58 -0500 Subject: [PATCH] Fix llm selection bug. --- .../MLTasks/ILlmProviderService.cs | 2 +- .../Infrastructures/CompletionProvider.cs | 6 +++--- .../Infrastructures/LlmProviderService.cs | 10 +++++++--- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/Infrastructure/BotSharp.Abstraction/MLTasks/ILlmProviderService.cs b/src/Infrastructure/BotSharp.Abstraction/MLTasks/ILlmProviderService.cs index 120304d2c..20762fe0f 100644 --- a/src/Infrastructure/BotSharp.Abstraction/MLTasks/ILlmProviderService.cs +++ b/src/Infrastructure/BotSharp.Abstraction/MLTasks/ILlmProviderService.cs @@ -6,6 +6,6 @@ public interface ILlmProviderService { LlmModelSetting GetSetting(string provider, string model); List GetProviders(); - LlmModelSetting GetProviderModel(string provider, string id, bool multiModal = false); + LlmModelSetting GetProviderModel(string provider, string id, bool? multiModal = null); List GetProviderModels(string provider); } diff --git a/src/Infrastructure/BotSharp.Core/Infrastructures/CompletionProvider.cs b/src/Infrastructure/BotSharp.Core/Infrastructures/CompletionProvider.cs index ace1e6642..4655b1de7 100644 --- a/src/Infrastructure/BotSharp.Core/Infrastructures/CompletionProvider.cs +++ b/src/Infrastructure/BotSharp.Core/Infrastructures/CompletionProvider.cs @@ -36,7 +36,7 @@ public static IChatCompletion GetChatCompletion(IServiceProvider services, string? provider = null, string? model = null, string? modelId = null, - bool multiModal = false, + bool? multiModal = null, AgentLlmConfig? agentConfig = null) { var completions = services.GetServices(); @@ -59,7 +59,7 @@ private static (string, string) GetProviderAndModel(IServiceProvider services, string? provider = null, string? model = null, string? modelId = null, - bool multiModal = false, + bool? multiModal = null, AgentLlmConfig? agentConfig = null) { var agentSetting = services.GetRequiredService(); @@ -82,7 +82,7 @@ private static (string, string) GetProviderAndModel(IServiceProvider services, { var modelIdentity = state.ContainsState("model_id") ? state.GetState("model_id") : modelId; var llmProviderService = services.GetRequiredService(); - model = llmProviderService.GetProviderModel(provider, modelIdentity, multiModal)?.Name; + model = llmProviderService.GetProviderModel(provider, modelIdentity, multiModal: multiModal)?.Name; } } diff --git a/src/Infrastructure/BotSharp.Core/Infrastructures/LlmProviderService.cs b/src/Infrastructure/BotSharp.Core/Infrastructures/LlmProviderService.cs index 7d92a6877..8320bdb73 100644 --- a/src/Infrastructure/BotSharp.Core/Infrastructures/LlmProviderService.cs +++ b/src/Infrastructure/BotSharp.Core/Infrastructures/LlmProviderService.cs @@ -44,11 +44,15 @@ public List GetProviderModels(string provider) ?.Models ?? new List(); } - public LlmModelSetting GetProviderModel(string provider, string id, bool multiModal = false) + public LlmModelSetting GetProviderModel(string provider, string id, bool? multiModal = null) { var models = GetProviderModels(provider) - .Where(x => x.Id == id && x.MultiModal == multiModal) - .ToList(); + .Where(x => x.Id == id); + + if (multiModal.HasValue) + { + models = models.Where(x => x.MultiModal == multiModal); + } var random = new Random(); var index = random.Next(0, models.Count());