Skip to content

Commit

Permalink
Merge pull request #459 from hchen2020/master
Browse files Browse the repository at this point in the history
Fix llm selection bug.
  • Loading branch information
Oceania2018 authored May 16, 2024
2 parents 19ceb94 + 2c757ac commit 7b87046
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ public interface ILlmProviderService
{
LlmModelSetting GetSetting(string provider, string model);
List<string> GetProviders();
LlmModelSetting GetProviderModel(string provider, string id, bool multiModal = false);
LlmModelSetting GetProviderModel(string provider, string id, bool? multiModal = null);
List<LlmModelSetting> GetProviderModels(string provider);
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public static IChatCompletion GetChatCompletion(IServiceProvider services,
string? provider = null,
string? model = null,
string? modelId = null,
bool multiModal = false,
bool? multiModal = null,
AgentLlmConfig? agentConfig = null)
{
var completions = services.GetServices<IChatCompletion>();
Expand All @@ -59,7 +59,7 @@ private static (string, string) GetProviderAndModel(IServiceProvider services,
string? provider = null,
string? model = null,
string? modelId = null,
bool multiModal = false,
bool? multiModal = null,
AgentLlmConfig? agentConfig = null)
{
var agentSetting = services.GetRequiredService<AgentSettings>();
Expand All @@ -82,7 +82,7 @@ private static (string, string) GetProviderAndModel(IServiceProvider services,
{
var modelIdentity = state.ContainsState("model_id") ? state.GetState("model_id") : modelId;
var llmProviderService = services.GetRequiredService<ILlmProviderService>();
model = llmProviderService.GetProviderModel(provider, modelIdentity, multiModal)?.Name;
model = llmProviderService.GetProviderModel(provider, modelIdentity, multiModal: multiModal)?.Name;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,15 @@ public List<LlmModelSetting> GetProviderModels(string provider)
?.Models ?? new List<LlmModelSetting>();
}

public LlmModelSetting GetProviderModel(string provider, string id, bool multiModal = false)
public LlmModelSetting GetProviderModel(string provider, string id, bool? multiModal = null)
{
var models = GetProviderModels(provider)
.Where(x => x.Id == id && x.MultiModal == multiModal)
.ToList();
.Where(x => x.Id == id);

if (multiModal.HasValue)
{
models = models.Where(x => x.MultiModal == multiModal);
}

var random = new Random();
var index = random.Next(0, models.Count());
Expand Down

0 comments on commit 7b87046

Please sign in to comment.