diff --git a/OpenAI.SDK/Betalgo.Ranul.OpenAI.csproj b/OpenAI.SDK/Betalgo.Ranul.OpenAI.csproj index 57d627f8..d4c37800 100644 --- a/OpenAI.SDK/Betalgo.Ranul.OpenAI.csproj +++ b/OpenAI.SDK/Betalgo.Ranul.OpenAI.csproj @@ -74,4 +74,8 @@ + + + + \ No newline at end of file diff --git a/OpenAI.SDK/Managers/OpenAIChatClient.cs b/OpenAI.SDK/Managers/OpenAIChatClient.cs new file mode 100644 index 00000000..57ba3117 --- /dev/null +++ b/OpenAI.SDK/Managers/OpenAIChatClient.cs @@ -0,0 +1,390 @@ +using System.Runtime.CompilerServices; +using System.Text.Json; +using Betalgo.Ranul.OpenAI.ObjectModels; +using Betalgo.Ranul.OpenAI.ObjectModels.RequestModels; +using Betalgo.Ranul.OpenAI.ObjectModels.ResponseModels; +using Betalgo.Ranul.OpenAI.ObjectModels.SharedModels; +using Microsoft.Extensions.AI; +using ChatMessage = Microsoft.Extensions.AI.ChatMessage; + +namespace Betalgo.Ranul.OpenAI.Managers; + +public partial class OpenAIService : IChatClient +{ + private ChatClientMetadata? _chatMetadata; + + /// + ChatClientMetadata IChatClient.Metadata => _chatMetadata ??= new(nameof(OpenAIService), _httpClient.BaseAddress, _defaultModelId); + + /// + TService? IChatClient.GetService(object? key) where TService : class + { + return this as TService; + } + + /// + void IDisposable.Dispose() + { + } + + /// + async Task IChatClient.CompleteAsync(IList chatMessages, ChatOptions? options, CancellationToken cancellationToken) + { + var request = CreateRequest(chatMessages, options); + + var response = await ChatCompletion.CreateCompletion(request, options?.ModelId, cancellationToken); + ThrowIfNotSuccessful(response); + + string? finishReason = null; + List responseMessages = []; + foreach (var choice in response.Choices) + { + finishReason ??= choice.FinishReason; + + ChatMessage m = new() + { + Role = new(choice.Message.Role), + AuthorName = choice.Message.Name, + RawRepresentation = choice + }; + + PopulateContents(choice.Message, m.Contents); + + if (response.ServiceTier is string serviceTier) + { + (m.AdditionalProperties ??= [])[nameof(response.ServiceTier)] = serviceTier; + } + + if (response.SystemFingerPrint is string fingerprint) + { + (m.AdditionalProperties ??= [])[nameof(response.SystemFingerPrint)] = fingerprint; + } + + responseMessages.Add(m); + } + + return new(responseMessages) + { + CreatedAt = response.CreatedAt, + CompletionId = response.Id, + FinishReason = finishReason is not null ? new(finishReason) : null, + ModelId = response.Model, + RawRepresentation = response, + Usage = response.Usage is { } usage ? GetUsageDetails(usage) : null + }; + } + + /// + async IAsyncEnumerable IChatClient.CompleteStreamingAsync(IList chatMessages, ChatOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken) + { + var request = CreateRequest(chatMessages, options); + + await foreach (var response in ChatCompletion.CreateCompletionAsStream(request, options?.ModelId, cancellationToken: cancellationToken)) + { + ThrowIfNotSuccessful(response); + + foreach (var choice in response.Choices) + { + StreamingChatCompletionUpdate update = new() + { + AuthorName = choice.Delta.Name, + CompletionId = response.Id, + CreatedAt = response.CreatedAt, + FinishReason = choice.FinishReason is not null ? new(choice.FinishReason) : null, + ModelId = response.Model, + RawRepresentation = response, + Role = choice.Delta.Role is not null ? new(choice.Delta.Role) : null + }; + + if (choice.Index is not null) + { + update.ChoiceIndex = choice.Index.Value; + } + + if (response.ServiceTier is string serviceTier) + { + (update.AdditionalProperties ??= [])[nameof(response.ServiceTier)] = serviceTier; + } + + if (response.SystemFingerPrint is string fingerprint) + { + (update.AdditionalProperties ??= [])[nameof(response.SystemFingerPrint)] = fingerprint; + } + + PopulateContents(choice.Delta, update.Contents); + + yield return update; + + if (response.Usage is { } usage) + { + yield return new() + { + AuthorName = choice.Delta.Name, + CompletionId = response.Id, + Contents = [new UsageContent(GetUsageDetails(usage))], + CreatedAt = response.CreatedAt, + FinishReason = choice.FinishReason is not null ? new(choice.FinishReason) : null, + ModelId = response.Model, + Role = choice.Delta.Role is not null ? new(choice.Delta.Role) : null + }; + } + } + } + } + + private static void ThrowIfNotSuccessful(ChatCompletionCreateResponse response) + { + if (!response.Successful) + { + throw new InvalidOperationException(response.Error is { } error ? $"{response.Error.Code}: {response.Error.Message}" : "Betalgo.Ranul Unknown error"); + } + } + + private ChatCompletionCreateRequest CreateRequest(IList chatMessages, ChatOptions? options) + { + ChatCompletionCreateRequest request = new() + { + Model = options?.ModelId ?? _defaultModelId + }; + + if (options is not null) + { + // Strongly-typed properties from options + request.MaxCompletionTokens = options.MaxOutputTokens; + request.Temperature = options.Temperature; + request.TopP = options.TopP; + request.FrequencyPenalty = options.FrequencyPenalty; + request.PresencePenalty = options.PresencePenalty; + request.StopAsList = options.StopSequences; + + // Non-strongly-typed properties from additional properties + request.LogitBias = options.AdditionalProperties?.TryGetValue(nameof(request.LogitBias), out var logitBias) is true ? logitBias : null; + request.LogProbs = options.AdditionalProperties?.TryGetValue(nameof(request.LogProbs), out bool logProbs) is true ? logProbs : null; + request.N = options.AdditionalProperties?.TryGetValue(nameof(request.N), out int n) is true ? n : null; + request.ParallelToolCalls = options.AdditionalProperties?.TryGetValue(nameof(request.ParallelToolCalls), out bool parallelToolCalls) is true ? parallelToolCalls : null; + request.Seed = options.AdditionalProperties?.TryGetValue(nameof(request.Seed), out int seed) is true ? seed : null; + request.ServiceTier = options.AdditionalProperties?.TryGetValue(nameof(request.ServiceTier), out string? serviceTier) is true ? serviceTier : null!; + request.User = options.AdditionalProperties?.TryGetValue(nameof(request.User), out string? user) is true ? user : null!; + request.TopLogprobs = options.AdditionalProperties?.TryGetValue(nameof(request.TopLogprobs), out int topLogprobs) is true ? topLogprobs : null; + + // Response format + switch (options.ResponseFormat) + { + case ChatResponseFormatText: + request.ResponseFormat = new() { Type = StaticValues.CompletionStatics.ResponseFormat.Text }; + break; + + case ChatResponseFormatJson { Schema: not null } json: + request.ResponseFormat = new() + { + Type = StaticValues.CompletionStatics.ResponseFormat.JsonSchema, + JsonSchema = new() + { + Name = json.SchemaName ?? "JsonSchema", + Schema = JsonSerializer.Deserialize(json.Schema), + Description = json.SchemaDescription + } + }; + break; + + case ChatResponseFormatJson: + request.ResponseFormat = new() { Type = StaticValues.CompletionStatics.ResponseFormat.Json }; + break; + } + + // Tools + request.Tools = options.Tools + ?.OfType() + .Select(f => + { + return ToolDefinition.DefineFunction(new() + { + Name = f.Metadata.Name, + Description = f.Metadata.Description, + Parameters = CreateParameters(f) + }); + }) + .ToList() is { Count: > 0 } tools + ? tools + : null; + if (request.Tools is not null) + { + request.ToolChoice = options.ToolMode is RequiredChatToolMode r ? new() + { + Type = StaticValues.CompletionStatics.ToolChoiceType.Required, + Function = r.RequiredFunctionName is null ? null : new ToolChoice.FunctionTool() { Name = r.RequiredFunctionName } + } : + options.ToolMode is AutoChatToolMode ? new() { Type = StaticValues.CompletionStatics.ToolChoiceType.Auto } : new ToolChoice() { Type = StaticValues.CompletionStatics.ToolChoiceType.None }; + } + } + + // Messages + request.Messages = []; + foreach (var message in chatMessages) + { + foreach (var content in message.Contents) + { + switch (content) + { + case TextContent tc: + request.Messages.Add(new() + { + Content = tc.Text, + Name = message.AuthorName, + Role = message.Role.ToString() + }); + break; + + case ImageContent ic: + request.Messages.Add(new() + { + Contents = + [ + new() + { + Type = "image_url", + ImageUrl = new() + { + Url = ic.Uri, + Detail = ic.AdditionalProperties?.TryGetValue(nameof(MessageImageUrl.Detail), out string? detail) is true ? detail : null + } + } + ], + Name = message.AuthorName, + Role = message.Role.ToString() + }); + break; + + case FunctionResultContent frc: + request.Messages.Add(new() + { + ToolCallId = frc.CallId, + Content = frc.Result?.ToString(), + Name = message.AuthorName, + Role = message.Role.ToString() + }); + break; + } + } + + var functionCallContents = message.Contents.OfType().ToArray(); + if (functionCallContents.Length > 0) + { + request.Messages.Add(new() + { + Name = message.AuthorName, + Role = message.Role.ToString(), + ToolCalls = functionCallContents.Select(fcc => new ToolCall() + { + Type = "function", + Id = fcc.CallId, + FunctionCall = new() + { + Name = fcc.Name, + Arguments = JsonSerializer.Serialize(fcc.Arguments) + } + }) + .ToList() + }); + } + } + + return request; + } + + private static PropertyDefinition CreateParameters(AIFunction f) + { + List required = []; + Dictionary properties = []; + + var parameters = f.Metadata.Parameters; + + foreach (var parameter in parameters) + { + properties.Add(parameter.Name, parameter.Schema is JsonElement e ? e.Deserialize()! : PropertyDefinition.DefineObject(null, null, null, null, null)); + + if (parameter.IsRequired) + { + required.Add(parameter.Name); + } + } + + return PropertyDefinition.DefineObject(properties, required, null, null, null); + } + + private static void PopulateContents(ObjectModels.RequestModels.ChatMessage source, IList destination) + { + if (source.Content is not null) + { + destination.Add(new TextContent(source.Content)); + } + + if (source.Contents is { } contents) + { + foreach (var content in contents) + { + if (content.Text is string text) + { + destination.Add(new TextContent(text)); + } + + if (content.ImageUrl is { } url) + { + destination.Add(new ImageContent(url.Url)); + } + } + } + + if (source.ToolCalls is { } toolCalls) + { + foreach (var tc in toolCalls) + { + destination.Add(new FunctionCallContent(tc.Id ?? string.Empty, tc.FunctionCall?.Name ?? string.Empty, tc.FunctionCall?.Arguments is string a ? JsonSerializer.Deserialize>(a) : null)); + } + } + } + + private static UsageDetails GetUsageDetails(UsageResponse usage) + { + var details = new UsageDetails() + { + InputTokenCount = usage.PromptTokens, + OutputTokenCount = usage.CompletionTokens, + TotalTokenCount = usage.TotalTokens + }; + + if (usage.PromptTokensDetails is { } promptDetails) + { + Dictionary d = new(StringComparer.OrdinalIgnoreCase); + (details.AdditionalProperties ??= [])[nameof(usage.PromptTokensDetails)] = d; + + if (promptDetails.CachedTokens is int cachedTokens) + { + d[nameof(promptDetails.CachedTokens)] = cachedTokens; + } + + if (promptDetails.AudioTokens is int audioTokens) + { + d[nameof(promptDetails.AudioTokens)] = audioTokens; + } + } + + if (usage.CompletionTokensDetails is { } completionDetails) + { + Dictionary d = new(StringComparer.OrdinalIgnoreCase); + (details.AdditionalProperties ??= [])[nameof(usage.CompletionTokensDetails)] = d; + + if (completionDetails.ReasoningTokens is int reasoningTokens) + { + d[nameof(completionDetails.ReasoningTokens)] = reasoningTokens; + } + + if (completionDetails.AudioTokens is int audioTokens) + { + d[nameof(promptDetails.AudioTokens)] = audioTokens; + } + } + + return details; + } +} \ No newline at end of file diff --git a/OpenAI.SDK/Managers/OpenAIEmbeddingGenerator.cs b/OpenAI.SDK/Managers/OpenAIEmbeddingGenerator.cs new file mode 100644 index 00000000..42b65899 --- /dev/null +++ b/OpenAI.SDK/Managers/OpenAIEmbeddingGenerator.cs @@ -0,0 +1,41 @@ +using Microsoft.Extensions.AI; + +namespace Betalgo.Ranul.OpenAI.Managers; + +public partial class OpenAIService : IEmbeddingGenerator> +{ + private EmbeddingGeneratorMetadata? _embeddingMetadata; + + EmbeddingGeneratorMetadata IEmbeddingGenerator>.Metadata => + _embeddingMetadata ??= new(nameof(OpenAIService), _httpClient.BaseAddress, _defaultModelId); + + TService? IEmbeddingGenerator>.GetService(object? key) where TService : class => + this as TService; + + async Task>> IEmbeddingGenerator>.GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken) + { + var response = await this.Embeddings.CreateEmbedding(new() + { + Model = options?.ModelId ?? _defaultModelId, + Dimensions = options?.Dimensions, + InputAsList = values.ToList(), + }, cancellationToken); + + if (!response.Successful) + { + throw new InvalidOperationException(response.Error is { } error ? + $"{response.Error.Code}: {response.Error.Message}" : + "Unknown error"); + } + + return new(response.Data.Select(e => new Embedding(e.Embedding.Select(d => (float)d).ToArray()) { ModelId = response.Model })) + { + Usage = response.Usage is { } usage ? new() + { + InputTokenCount = usage.PromptTokens, + OutputTokenCount = usage.CompletionTokens, + TotalTokenCount = usage.TotalTokens, + } : null, + }; + } +} \ No newline at end of file