diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs index d0e1276462f..b0a46ba19c8 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs @@ -68,6 +68,14 @@ public async Task>> GenerateAsync(IEnumerab _embeddingClient.GenerateEmbeddingsAsync(values, openAIOptions, cancellationToken); var embeddings = (await t.ConfigureAwait(false)).Value; + UsageDetails? usage = embeddings.Usage is not null ? + new() + { + InputTokenCount = embeddings.Usage.InputTokenCount, + TotalTokenCount = embeddings.Usage.TotalTokenCount + } : + null; + return new(embeddings.Select(e => new Embedding(e.ToFloats()) { @@ -75,11 +83,7 @@ public async Task>> GenerateAsync(IEnumerab ModelId = embeddings.Model, })) { - Usage = new() - { - InputTokenCount = embeddings.Usage.InputTokenCount, - TotalTokenCount = embeddings.Usage.TotalTokenCount - }, + Usage = usage, }; } diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIEmbeddingGeneratorTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIEmbeddingGeneratorTests.cs index 0db88d499e1..d046fc05521 100644 --- a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIEmbeddingGeneratorTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIEmbeddingGeneratorTests.cs @@ -216,6 +216,91 @@ public async Task EmbeddingGenerationOptions_DoNotOverwrite_NotNullPropertiesInR } } + [Fact] + public async Task EmbeddingGenerationOptions_MissingUsage_Ignored() + { + const string Input = """ + { + "input":["hello, world!"], + "model":"text-embedding-3-small", + "encoding_format":"base64" + } + """; + + const string Output = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": "AAAAAA==" + } + ], + "model": "text-embedding-3-small" + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IEmbeddingGenerator> generator = CreateEmbeddingGenerator(httpClient, "text-embedding-3-small"); + + var response = await generator.GenerateAsync(["hello, world!"]); + Assert.NotNull(response); + Assert.Single(response); + Assert.Null(response.Usage); + + foreach (Embedding e in response) + { + Assert.Equal("text-embedding-3-small", e.ModelId); + Assert.NotNull(e.CreatedAt); + Assert.Equal(1, e.Vector.Length); + } + } + + [Fact] + public async Task EmbeddingGenerationOptions_NullUsage_Ignored() + { + const string Input = """ + { + "input":["hello, world!"], + "model":"text-embedding-3-small", + "encoding_format":"base64" + } + """; + + const string Output = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": "AAAAAA==" + } + ], + "model": "text-embedding-3-small", + "usage": null + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IEmbeddingGenerator> generator = CreateEmbeddingGenerator(httpClient, "text-embedding-3-small"); + + var response = await generator.GenerateAsync(["hello, world!"]); + Assert.NotNull(response); + Assert.Single(response); + Assert.Null(response.Usage); + + foreach (Embedding e in response) + { + Assert.Equal("text-embedding-3-small", e.ModelId); + Assert.NotNull(e.CreatedAt); + Assert.Equal(1, e.Vector.Length); + } + } + [Fact] public async Task RequestHeaders_UserAgent_ContainsMEAI() {