Skip to content

Commit 94b3adc

Browse files
committed
重构 LLM 服务,使用 IServiceProvider 替代 ILLMFactory,简化服务初始化
1 parent 71531b0 commit 94b3adc

File tree

3 files changed

+177
-7
lines changed

3 files changed

+177
-7
lines changed

TelegramSearchBot/Service/AI/LLM/GeneralLLMService.cs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
using Microsoft.EntityFrameworkCore; // For AnyAsync()
77
using Microsoft.Extensions.Logging;
88
using StackExchange.Redis;
9+
using Microsoft.Extensions.DependencyInjection;
10+
using TelegramSearchBot.Interface.AI.LLM;
911
using TelegramSearchBot.Attributes;
1012
using TelegramSearchBot.Interface;
11-
using TelegramSearchBot.Interface.AI.LLM;
1213
using TelegramSearchBot.Model;
1314
using TelegramSearchBot.Model.AI;
1415
using TelegramSearchBot.Model.Data;
@@ -22,7 +23,7 @@ public class GeneralLLMService : IService, IGeneralLLMService {
2223
private readonly OllamaService _ollamaService;
2324
private readonly GeminiService _geminiService;
2425
private readonly ILogger<GeneralLLMService> _logger;
25-
private readonly ILLMFactory _LLMFactory;
26+
private readonly IServiceProvider _serviceProvider;
2627

2728
public string ServiceName => "GeneralLLMService";
2829

@@ -40,17 +41,17 @@ public GeneralLLMService(
4041
OllamaService ollamaService,
4142
OpenAIService openAIService,
4243
GeminiService geminiService,
43-
ILLMFactory _LLMFactory
44+
IServiceProvider serviceProvider
4445
) {
4546
this.connectionMultiplexer = connectionMultiplexer;
4647
_dbContext = dbContext;
4748
_logger = logger;
49+
_serviceProvider = serviceProvider;
4850

4951
// Initialize services with default values
5052
_openAIService = openAIService;
5153
_ollamaService = ollamaService;
5254
_geminiService = geminiService;
53-
this._LLMFactory = _LLMFactory;
5455
}
5556
public async Task<List<LLMChannel>> GetChannelsAsync(string modelName) {
5657
// 2. 查询ChannelWithModel获取关联的LLMChannel
@@ -140,7 +141,8 @@ orderby s.Priority descending
140141
var redisKey = $"llm:channel:{channel.Id}:semaphore";
141142
var currentCount = await redisDb.StringGetAsync(redisKey);
142143
int count = currentCount.HasValue ? ( int ) currentCount : 0;
143-
var service = _LLMFactory.GetLLMService(channel.Provider);
144+
var llmFactory = _serviceProvider.GetRequiredService<ILLMFactory>();
145+
var service = llmFactory.GetLLMService(channel.Provider);
144146

145147
if (count < channel.Parallel) {
146148
// 获取锁并增加计数

TelegramSearchBot/Service/AI/LLM/LLMServiceFactory.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public LLMServiceFactory(
3030
{
3131
_logger = logger;
3232
_legacyService = legacyService;
33-
_extensionsAIService = (IGeneralLLMService)extensionsAIService;
33+
_extensionsAIService = extensionsAIService; // 直接赋值,因为 OpenAIExtensionsAIService 实现了 IGeneralLLMService
3434
}
3535

3636
/// <summary>

TelegramSearchBot/Service/AI/LLM/OpenAIExtensionsAIService.cs

Lines changed: 169 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ namespace TelegramSearchBot.Service.AI.LLM
3030
/// 这是一个真正的实现,使用Microsoft.Extensions.AI抽象层
3131
/// </summary>
3232
[Injectable(ServiceLifetime.Transient)]
33-
public class OpenAIExtensionsAIService : IService, ILLMService
33+
public class OpenAIExtensionsAIService : IService, ILLMService, IGeneralLLMService
3434
{
3535
public string ServiceName => "OpenAIExtensionsAIService";
3636

@@ -301,5 +301,173 @@ public async Task<bool> IsHealthyAsync(LLMChannel channel)
301301
return false;
302302
}
303303
}
304+
305+
#region IGeneralLLMService Implementation
306+
307+
/// <summary>
/// Gets the channels that can serve the given model.
/// Simplified implementation: returns every channel whose provider is OpenAI.
/// NOTE(review): the <paramref name="modelName"/> parameter is currently ignored —
/// confirm whether filtering by the channel/model mapping (as GeneralLLMService does)
/// is intended here.
/// </summary>
/// <param name="modelName">Requested model name (not used by this simplified lookup).</param>
/// <returns>All LLM channels with <see cref="LLMProvider.OpenAI"/> as provider.</returns>
public async Task<List<LLMChannel>> GetChannelsAsync(string modelName)
{
    // Single query, result returned directly (no intermediate local needed).
    return await _dbContext.LLMChannels
        .Where(channel => channel.Provider == LLMProvider.OpenAI)
        .ToListAsync();
}
318+
319+
/// <summary>
/// Processes a message and streams LLM responses (simplified implementation).
/// Resolves the chat's configured model name from GroupSettings, picks the first
/// available channel and delegates to the channel-based ExecAsync overload.
/// </summary>
/// <param name="message">Incoming message to process.</param>
/// <param name="ChatId">Chat (group) whose model configuration is used.</param>
/// <param name="cancellationToken">Cancels both the lookup and the streamed responses.</param>
public async IAsyncEnumerable<string> ExecAsync(Model.Data.Message message, long ChatId,
    [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    // Resolve the model name configured for this chat.
    var modelName = await _dbContext.GroupSettings
        .Where(s => s.GroupId == ChatId)
        .Select(s => s.LLMModelName)
        .FirstOrDefaultAsync(cancellationToken); // flow the token into the DB query

    if (string.IsNullOrEmpty(modelName))
    {
        _logger.LogWarning("未找到模型配置");
        yield break;
    }

    // Resolve the channels able to serve this model.
    var channels = await GetChannelsAsync(modelName);
    if (!channels.Any())
    {
        // Structured logging template (CA2254) instead of an interpolated string,
        // so the message template stays constant and modelName is a named property.
        _logger.LogWarning("未找到模型 {ModelName} 的可用渠道", modelName);
        yield break;
    }

    // Simplified channel selection: always use the first available channel.
    var channel = channels.First();
    await foreach (var response in ExecAsync(message, ChatId, modelName, channel, cancellationToken))
    {
        yield return response;
    }
}
352+
353+
/// <summary>
/// Streams LLM responses for a message using an explicit channel.
/// The <paramref name="service"/> argument is accepted for interface compatibility
/// but this implementation always delegates to its own channel-based ExecAsync.
/// </summary>
public async IAsyncEnumerable<string> ExecAsync(Model.Data.Message message, long ChatId, string modelName, ILLMService service, LLMChannel channel,
    [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellation = default)
{
    // Pure pass-through: forward every streamed chunk unchanged.
    await foreach (var chunk in ExecAsync(message, ChatId, modelName, channel, cancellation))
    {
        yield return chunk;
    }
}
365+
366+
/// <summary>
/// Runs an arbitrary LLM operation against the first channel available for the model.
/// The operation receives this service instance, the selected channel and the token,
/// and its streamed results are forwarded to the caller.
/// </summary>
/// <typeparam name="TResult">Element type streamed by the operation.</typeparam>
/// <param name="operation">Callback producing the result stream for a given service/channel.</param>
/// <param name="modelName">Model used to select a channel.</param>
/// <param name="cancellationToken">Cancels the streamed operation.</param>
public async IAsyncEnumerable<TResult> ExecOperationAsync<TResult>(
    Func<ILLMService, LLMChannel, CancellationToken, IAsyncEnumerable<TResult>> operation,
    string modelName,
    [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    var channels = await GetChannelsAsync(modelName);
    if (!channels.Any())
    {
        // Structured logging template (CA2254) instead of an interpolated string.
        _logger.LogWarning("未找到模型 {ModelName} 的可用渠道", modelName);
        yield break;
    }

    // Simplified channel selection: always use the first available channel.
    var channel = channels.First();
    await foreach (var result in operation(this, channel, cancellationToken))
    {
        yield return result;
    }
}
387+
388+
/// <summary>
/// Analyzes an image for a chat (simplified implementation).
/// Uses the chat's configured model, falling back to "gpt-4-vision-preview",
/// and the first channel available for that model.
/// </summary>
/// <param name="PhotoPath">Path of the image to analyze.</param>
/// <param name="ChatId">Chat (group) whose model configuration is used.</param>
/// <param name="cancellationToken">Cancels the model-name lookup.</param>
/// <exception cref="InvalidOperationException">No channel is available for the model.</exception>
public async Task<string> AnalyzeImageAsync(string PhotoPath, long ChatId, CancellationToken cancellationToken = default)
{
    var modelName = await _dbContext.GroupSettings
        .Where(s => s.GroupId == ChatId)
        .Select(s => s.LLMModelName)
        .FirstOrDefaultAsync(cancellationToken) ?? "gpt-4-vision-preview"; // flow the token into the DB query

    var channels = await GetChannelsAsync(modelName);
    if (!channels.Any())
    {
        // Specific exception type instead of bare Exception; still caught by
        // any existing catch (Exception) in callers.
        throw new InvalidOperationException($"未找到模型 {modelName} 的可用渠道");
    }

    var channel = channels.First();
    return await AnalyzeImageAsync(PhotoPath, modelName, channel);
}
407+
408+
/// <summary>
/// Analyzes an image using an explicit model and channel, exposed as a
/// single-element async stream. The <paramref name="service"/> argument is
/// accepted for interface compatibility but not used by this implementation.
/// </summary>
public async IAsyncEnumerable<string> AnalyzeImageAsync(string PhotoPath, long ChatId, string modelName, ILLMService service, LLMChannel channel,
    [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    // Delegate to the channel-based overload and yield its single result.
    yield return await AnalyzeImageAsync(PhotoPath, modelName, channel);
}
417+
418+
/// <summary>
/// Generates an embedding vector for a message's text content.
/// NOTE(review): ChatId is currently unused — confirm whether per-chat model
/// selection is intended here.
/// </summary>
/// <param name="message">Message whose Content is embedded (empty string when null).</param>
/// <param name="ChatId">Chat identifier (not used by this simplified implementation).</param>
public async Task<float[]> GenerateEmbeddingsAsync(Model.Data.Message message, long ChatId)
{
    // Delegate to the text overload; null content is treated as empty text.
    return await GenerateEmbeddingsAsync(message.Content ?? string.Empty, CancellationToken.None);
}
426+
427+
/// <summary>
/// Generates an embedding vector for the given text (simplified implementation).
/// Always uses the default embedding model and the first channel available for it.
/// </summary>
/// <param name="message">Text to embed.</param>
/// <param name="cancellationToken">Cancellation token (currently not forwarded —
/// GetChannelsAsync and the channel-based overload take no token).</param>
/// <exception cref="InvalidOperationException">No channel is available for the embedding model.</exception>
public async Task<float[]> GenerateEmbeddingsAsync(string message, CancellationToken cancellationToken = default)
{
    const string modelName = "text-embedding-ada-002"; // default embedding model

    var channels = await GetChannelsAsync(modelName);
    if (!channels.Any())
    {
        // Specific exception type instead of bare Exception; still caught by
        // any existing catch (Exception) in callers.
        throw new InvalidOperationException($"未找到嵌入模型 {modelName} 的可用渠道");
    }

    var channel = channels.First();
    return await GenerateEmbeddingsAsync(message, modelName, channel);
}
442+
443+
/// <summary>
/// Generates an embedding vector using an explicit model and channel, exposed as
/// a single-element async stream. The <paramref name="service"/> argument is
/// accepted for interface compatibility but not used by this implementation.
/// </summary>
public async IAsyncEnumerable<float[]> GenerateEmbeddingsAsync(string message, string modelName, ILLMService service, LLMChannel channel,
    [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    // Delegate to the channel-based overload and yield its single vector.
    yield return await GenerateEmbeddingsAsync(message, modelName, channel);
}
452+
453+
/// <summary>
/// Gets the available capacity for AltPhoto processing.
/// Simplified implementation: always reports a fixed capacity of 100.
/// </summary>
/// <returns>A completed task carrying the constant capacity value 100.</returns>
public Task<int> GetAltPhotoAvailableCapacityAsync() => Task.FromResult(100);
461+
462+
/// <summary>
/// Gets the available capacity for the given model.
/// Simplified implementation: always reports a fixed capacity of 1000,
/// regardless of <paramref name="modelName"/>.
/// </summary>
/// <param name="modelName">Model name (ignored by this simplified implementation).</param>
/// <returns>A completed task carrying the constant capacity value 1000.</returns>
public Task<int> GetAvailableCapacityAsync(string modelName = "gpt-3.5-turbo") => Task.FromResult(1000);
470+
471+
#endregion
304472
}
305473
}

0 commit comments

Comments
 (0)