diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/Microsoft.Extensions.DataIngestion.csproj b/src/Libraries/Microsoft.Extensions.DataIngestion/Microsoft.Extensions.DataIngestion.csproj index a3e3b4e7a9a..bab4b509a36 100644 --- a/src/Libraries/Microsoft.Extensions.DataIngestion/Microsoft.Extensions.DataIngestion.csproj +++ b/src/Libraries/Microsoft.Extensions.DataIngestion/Microsoft.Extensions.DataIngestion.csproj @@ -18,6 +18,7 @@ + diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/ClassificationEnricher.cs b/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/ClassificationEnricher.cs new file mode 100644 index 00000000000..e1cb1ca7438 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/ClassificationEnricher.cs @@ -0,0 +1,132 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.DataIngestion; + +/// +/// Enriches document chunks with a classification label based on their content. +/// +/// This class uses a chat-based language model to analyze the content of document chunks and assign a +/// single, most relevant classification label. The classification is performed using a predefined set of classes, with +/// an optional fallback class for cases where no suitable classification can be determined. +public sealed class ClassificationEnricher : IngestionChunkProcessor +{ + private readonly IChatClient _chatClient; + private readonly ChatOptions? _chatOptions; + private readonly FrozenSet _predefinedClasses; + private readonly ChatMessage _systemPrompt; + + /// + /// Initializes a new instance of the class. + /// + /// The chat client used for classification. + /// The set of predefined classification classes. + /// Options for the chat client. + /// The fallback class to use when no suitable classification is found. When not provided, it defaults to "Unknown". + public ClassificationEnricher(IChatClient chatClient, ReadOnlySpan predefinedClasses, + ChatOptions? chatOptions = null, string? fallbackClass = null) + { + _chatClient = Throw.IfNull(chatClient); + _chatOptions = chatOptions; + if (string.IsNullOrWhiteSpace(fallbackClass)) + { + fallbackClass = "Unknown"; + } + + _predefinedClasses = CreatePredefinedSet(predefinedClasses, fallbackClass!); + _systemPrompt = CreateSystemPrompt(predefinedClasses, fallbackClass!); + } + + /// + /// Gets the metadata key used to store the classification. + /// + public static string MetadataKey => "classification"; + + /// + public override async IAsyncEnumerable> ProcessAsync(IAsyncEnumerable> chunks, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chunks); + + await foreach (IngestionChunk chunk in chunks.WithCancellation(cancellationToken)) + { + var response = await _chatClient.GetResponseAsync( + [ + _systemPrompt, + new(ChatRole.User, chunk.Content) + ], _chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false); + + chunk.Metadata[MetadataKey] = _predefinedClasses.Contains(response.Text) + ? response.Text + : throw new InvalidOperationException($"Classification returned an unexpected class: '{response.Text}'."); + + yield return chunk; + } + } + + private static FrozenSet CreatePredefinedSet(ReadOnlySpan predefinedClasses, string fallbackClass) + { + if (predefinedClasses.Length == 0) + { + Throw.ArgumentException(nameof(predefinedClasses), "Predefined classes must be provided."); + } + + HashSet predefinedClassesSet = new(StringComparer.Ordinal) { fallbackClass }; + foreach (string predefinedClass in predefinedClasses) + { +#if NET + if (predefinedClass.Contains(',', StringComparison.Ordinal)) +#else + if (predefinedClass.IndexOf(',') >= 0) +#endif + { + Throw.ArgumentException(nameof(predefinedClasses), $"Predefined class '{predefinedClass}' must not contain ',' character."); + } + + if (!predefinedClassesSet.Add(predefinedClass)) + { + if (predefinedClass.Equals(fallbackClass, StringComparison.Ordinal)) + { + Throw.ArgumentException(nameof(predefinedClasses), $"Fallback class '{fallbackClass}' must not be one of the predefined classes."); + } + + Throw.ArgumentException(nameof(predefinedClasses), $"Duplicate class found: '{predefinedClass}'."); + } + } + + return predefinedClassesSet.ToFrozenSet(); + } + + private static ChatMessage CreateSystemPrompt(ReadOnlySpan predefinedClasses, string fallbackClass) + { + StringBuilder sb = new("You are a classification expert. Analyze the given text and assign a single, most relevant class. Use only the following predefined classes: "); + +#if NET9_0_OR_GREATER + sb.AppendJoin(", ", predefinedClasses!); +#else +#pragma warning disable IDE0058 // Expression value is never used + for (int i = 0; i < predefinedClasses.Length; i++) + { + sb.Append(predefinedClasses[i]); + if (i < predefinedClasses.Length - 1) + { + sb.Append(", "); + } + } +#endif + sb.Append(" and return ").Append(fallbackClass).Append(" when unable to classify."); +#pragma warning restore IDE0058 // Expression value is never used + + return new(ChatRole.System, sb.ToString()); + } +} diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/ImageAlternativeTextEnricher.cs b/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/ImageAlternativeTextEnricher.cs new file mode 100644 index 00000000000..5f68552cc3f --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/ImageAlternativeTextEnricher.cs @@ -0,0 +1,74 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.DataIngestion; + +/// +/// Enriches elements with alternative text using an AI service, +/// so the generated embeddings can include the image content information. +/// +public sealed class ImageAlternativeTextEnricher : IngestionDocumentProcessor +{ + private readonly IChatClient _chatClient; + private readonly ChatOptions? _chatOptions; + private readonly ChatMessage _systemPrompt; + + /// + /// Initializes a new instance of the class. + /// + /// The chat client used to get responses for generating alternative text. + /// Options for the chat client. + public ImageAlternativeTextEnricher(IChatClient chatClient, ChatOptions? chatOptions = null) + { + _chatClient = Throw.IfNull(chatClient); + _chatOptions = chatOptions; + _systemPrompt = new(ChatRole.System, "Write a detailed alternative text for this image with less than 50 words."); + } + + /// + public override async Task ProcessAsync(IngestionDocument document, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(document); + + foreach (var element in document.EnumerateContent()) + { + if (element is IngestionDocumentImage image) + { + await ProcessAsync(image, cancellationToken).ConfigureAwait(false); + } + else if (element is IngestionDocumentTable table) + { + foreach (var cell in table.Cells) + { + if (cell is IngestionDocumentImage cellImage) + { + await ProcessAsync(cellImage, cancellationToken).ConfigureAwait(false); + } + } + } + } + + return document; + } + + private async Task ProcessAsync(IngestionDocumentImage image, CancellationToken cancellationToken) + { + if (image.Content.HasValue && !string.IsNullOrEmpty(image.MediaType) + && string.IsNullOrEmpty(image.AlternativeText)) + { + var response = await _chatClient.GetResponseAsync( + [ + _systemPrompt, + new(ChatRole.User, [new DataContent(image.Content.Value, image.MediaType!)]) + ], _chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false); + + image.AlternativeText = response.Text; + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/KeywordEnricher.cs b/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/KeywordEnricher.cs new file mode 100644 index 00000000000..56a305e2a87 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/KeywordEnricher.cs @@ -0,0 +1,160 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.DataIngestion; + +/// +/// Enriches chunks with keyword extraction using an AI chat model. +/// +/// +/// It adds "keywords" metadata to each chunk. It's an array of strings representing the extracted keywords. +/// +public sealed class KeywordEnricher : IngestionChunkProcessor +{ + private const int DefaultMaxKeywords = 5; +#if NET + private static readonly System.Buffers.SearchValues _illegalCharacters = System.Buffers.SearchValues.Create([';', ',']); +#else + private static readonly char[] _illegalCharacters = [';', ',']; +#endif + private readonly IChatClient _chatClient; + private readonly ChatOptions? _chatOptions; + private readonly FrozenSet? _predefinedKeywords; + private readonly ChatMessage _systemPrompt; + + /// + /// Initializes a new instance of the class. + /// + /// The chat client used for keyword extraction. + /// The set of predefined keywords for extraction. + /// Options for the chat client. + /// The maximum number of keywords to extract. When not provided, it defaults to 5. + /// The confidence threshold for keyword inclusion. When not provided, it defaults to 0.7. + /// + /// If no predefined keywords are provided, the model will extract keywords based on the content alone. + /// Such results may vary more significantly between different AI models. + /// + public KeywordEnricher(IChatClient chatClient, ReadOnlySpan predefinedKeywords, + ChatOptions? chatOptions = null, int? maxKeywords = null, double? confidenceThreshold = null) + { + _chatClient = Throw.IfNull(chatClient); + _chatOptions = chatOptions; + _predefinedKeywords = CreatePredfinedKeywords(predefinedKeywords); + + double threshold = confidenceThreshold.HasValue + ? Throw.IfOutOfRange(confidenceThreshold.Value, 0.0, 1.0, nameof(confidenceThreshold)) + : 0.7; + int keywordsCount = maxKeywords.HasValue + ? Throw.IfLessThanOrEqual(maxKeywords.Value, 0, nameof(maxKeywords)) + : DefaultMaxKeywords; + _systemPrompt = CreateSystemPrompt(keywordsCount, predefinedKeywords, threshold); + } + + /// + /// Gets the metadata key used to store the keywords. + /// + public static string MetadataKey => "keywords"; + + /// + public override async IAsyncEnumerable> ProcessAsync(IAsyncEnumerable> chunks, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chunks); + + await foreach (IngestionChunk chunk in chunks.WithCancellation(cancellationToken)) + { + // Structured response is not used here because it's not part of Microsoft.Extensions.AI.Abstractions. + var response = await _chatClient.GetResponseAsync( + [ + _systemPrompt, + new(ChatRole.User, chunk.Content) + ], _chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false); + +#pragma warning disable EA0009 // Use 'System.MemoryExtensions.Split' for improved performance + string[] keywords = response.Text.Split(';'); + if (_predefinedKeywords is not null) + { + foreach (var keyword in keywords) + { + if (!_predefinedKeywords.Contains(keyword)) + { + throw new InvalidOperationException($"The extracted keyword '{keyword}' is not in the predefined keywords list."); + } + } + } + + chunk.Metadata[MetadataKey] = keywords; + + yield return chunk; + } + } + + private static FrozenSet? CreatePredfinedKeywords(ReadOnlySpan predefinedKeywords) + { + if (predefinedKeywords.Length == 0) + { + return null; + } + + HashSet result = new(StringComparer.Ordinal); + foreach (string keyword in predefinedKeywords) + { +#if NET + if (keyword.AsSpan().ContainsAny(_illegalCharacters)) +#else + if (keyword.IndexOfAny(_illegalCharacters) >= 0) +#endif + { + Throw.ArgumentException(nameof(predefinedKeywords), $"Predefined keyword '{keyword}' contains an invalid character (';' or ',')."); + } + + if (!result.Add(keyword)) + { + Throw.ArgumentException(nameof(predefinedKeywords), $"Duplicate keyword found: '{keyword}'"); + } + } + + return result.ToFrozenSet(StringComparer.Ordinal); + } + + private static ChatMessage CreateSystemPrompt(int maxKeywords, ReadOnlySpan predefinedKeywords, double confidenceThreshold) + { + StringBuilder sb = new($"You are a keyword extraction expert. Analyze the given text and extract up to {maxKeywords} most relevant keywords. "); + + if (predefinedKeywords.Length > 0) + { +#pragma warning disable IDE0058 // Expression value is never used + sb.Append("Focus on extracting keywords from the following predefined list: "); +#if NET9_0_OR_GREATER + sb.AppendJoin(", ", predefinedKeywords!); +#else + for (int i = 0; i < predefinedKeywords.Length; i++) + { + sb.Append(predefinedKeywords[i]); + if (i < predefinedKeywords.Length - 1) + { + sb.Append(", "); + } + } +#endif + + sb.Append(". "); + } + + sb.Append("Exclude keywords with confidence score below ").Append(confidenceThreshold).Append('.'); + sb.Append(" Return just the keywords separated with ';'."); +#pragma warning restore IDE0058 // Expression value is never used + + return new(ChatRole.System, sb.ToString()); + } +} diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/SentimentEnricher.cs b/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/SentimentEnricher.cs new file mode 100644 index 00000000000..4873842d1c4 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/SentimentEnricher.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.DataIngestion; + +/// +/// Enriches chunks with sentiment analysis using an AI chat model. +/// +/// +/// It adds "sentiment" metadata to each chunk. It can be Positive, Negative, Neutral or Unknown when confidence score is below the threshold. +/// +public sealed class SentimentEnricher : IngestionChunkProcessor +{ + private readonly IChatClient _chatClient; + private readonly ChatOptions? _chatOptions; + private readonly FrozenSet _validSentiments = +#if NET9_0_OR_GREATER + FrozenSet.Create(StringComparer.Ordinal, "Positive", "Negative", "Neutral", "Unknown"); +#else + new string[] { "Positive", "Negative", "Neutral", "Unknown" }.ToFrozenSet(StringComparer.Ordinal); +#endif + private readonly ChatMessage _systemPrompt; + + /// + /// Initializes a new instance of the class. + /// + /// The chat client used for sentiment analysis. + /// Options for the chat client. + /// The confidence threshold for sentiment determination. When not provided, it defaults to 0.7. + public SentimentEnricher(IChatClient chatClient, ChatOptions? chatOptions = null, double? confidenceThreshold = null) + { + _chatClient = Throw.IfNull(chatClient); + _chatOptions = chatOptions; + + double threshold = confidenceThreshold.HasValue ? Throw.IfOutOfRange(confidenceThreshold.Value, 0.0, 1.0, nameof(confidenceThreshold)) : 0.7; + + string prompt = $""" + You are a sentiment analysis expert. Analyze the sentiment of the given text and return Positive/Negative/Neutral or + Unknown when confidence score is below {threshold}. Return just the value of the sentiment. + """; + _systemPrompt = new(ChatRole.System, prompt); + } + + /// + /// Gets the metadata key used to store the sentiment. + /// + public static string MetadataKey => "sentiment"; + + /// + public override async IAsyncEnumerable> ProcessAsync(IAsyncEnumerable> chunks, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chunks); + + await foreach (var chunk in chunks.WithCancellation(cancellationToken)) + { + var response = await _chatClient.GetResponseAsync( + [ + _systemPrompt, + new(ChatRole.User, chunk.Content) + ], _chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false); + + if (!_validSentiments.Contains(response.Text)) + { + throw new InvalidOperationException($"Invalid sentiment response: '{response.Text}'."); + } + + chunk.Metadata[MetadataKey] = response.Text; + + yield return chunk; + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/SummaryEnricher.cs b/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/SummaryEnricher.cs new file mode 100644 index 00000000000..f91b9809b05 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/SummaryEnricher.cs @@ -0,0 +1,65 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.DataIngestion; + +/// +/// Enriches chunks with summary text using an AI chat model. +/// +/// +/// It adds "summary" text metadata to each chunk. +/// +public sealed class SummaryEnricher : IngestionChunkProcessor +{ + private readonly IChatClient _chatClient; + private readonly ChatOptions? _chatOptions; + private readonly ChatMessage _systemPrompt; + + /// + /// Initializes a new instance of the class. + /// + /// The chat client used for summary generation. + /// Options for the chat client. + /// The maximum number of words for the summary. When not provided, it defaults to 100. + public SummaryEnricher(IChatClient chatClient, ChatOptions? chatOptions = null, int? maxWordCount = null) + { + _chatClient = Throw.IfNull(chatClient); + _chatOptions = chatOptions; + + int wordCount = maxWordCount.HasValue ? Throw.IfLessThanOrEqual(maxWordCount.Value, 0, nameof(maxWordCount)) : 100; + _systemPrompt = new(ChatRole.System, $"Write a summary text for this text with no more than {wordCount} words. Return just the summary."); + } + + /// + /// Gets the metadata key used to store the summary. + /// + public static string MetadataKey => "summary"; + + /// + public override async IAsyncEnumerable> ProcessAsync(IAsyncEnumerable> chunks, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chunks); + + await foreach (var chunk in chunks.WithCancellation(cancellationToken)) + { + var response = await _chatClient.GetResponseAsync( + [ + _systemPrompt, + new(ChatRole.User, chunk.Content) + ], _chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false); + + chunk.Metadata[MetadataKey] = response.Text; + + yield return chunk; + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Microsoft.Extensions.DataIngestion.Tests.csproj b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Microsoft.Extensions.DataIngestion.Tests.csproj index d00b7b652e6..b5ff0659d57 100644 --- a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Microsoft.Extensions.DataIngestion.Tests.csproj +++ b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Microsoft.Extensions.DataIngestion.Tests.csproj @@ -3,8 +3,8 @@ $(NoWarn);S3967 - - $(NoWarn);RT0002 + + $(NoWarn);CA1063 x64 @@ -20,8 +20,9 @@ + - + diff --git a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/AlternativeTextEnricherTests.cs b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/AlternativeTextEnricherTests.cs new file mode 100644 index 00000000000..cc59db3f389 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/AlternativeTextEnricherTests.cs @@ -0,0 +1,110 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Xunit; + +namespace Microsoft.Extensions.DataIngestion.Processors.Tests; + +public class AlternativeTextEnricherTests +{ + [Fact] + public void ThrowsOnNullChatClient() + { + Assert.Throws("chatClient", () => new ImageAlternativeTextEnricher(null!)); + } + + [Fact] + public async Task ThrowsOnNullDocument() + { + using TestChatClient chatClient = new(); + + ImageAlternativeTextEnricher sut = new(chatClient); + + await Assert.ThrowsAsync("document", async () => await sut.ProcessAsync(null!)); + } + + [Fact] + public async Task CanGenerateImageAltText() + { + const string PreExistingAltText = "Pre-existing alt text"; + ReadOnlyMemory imageContent = new byte[256]; + + int counter = 0; + string[] descriptions = { "First alt text", "Second alt text" }; + using TestChatClient chatClient = new() + { + GetResponseAsyncCallback = (messages, options, cancellationToken) => + { + var materializedMessages = messages.ToArray(); + + Assert.Equal(2, materializedMessages.Length); + Assert.Equal(ChatRole.System, materializedMessages[0].Role); + Assert.Equal(ChatRole.User, materializedMessages[1].Role); + var content = Assert.Single(materializedMessages[1].Contents); + DataContent dataContent = Assert.IsType(content); + Assert.Equal("image/png", dataContent.MediaType); + Assert.Equal(imageContent.ToArray(), dataContent.Data.ToArray()); + + return Task.FromResult(new ChatResponse(new[] + { + new ChatMessage(ChatRole.Assistant, descriptions[counter++]) + })); + } + }; + ImageAlternativeTextEnricher sut = new(chatClient); + + IngestionDocumentImage documentImage = new($"![](nonExisting.png)") + { + AlternativeText = null, + Content = imageContent, + MediaType = "image/png" + }; + + IngestionDocumentImage tableCell = new($"![](another.png)") + { + AlternativeText = null, + Content = imageContent, + MediaType = "image/png" + }; + + IngestionDocumentImage imageWithAltText = new($"![](noChangesNeeded.png)") + { + AlternativeText = PreExistingAltText, + Content = imageContent, + MediaType = "image/png" + }; + + IngestionDocumentImage imageWithNoContent = new($"![](noImage.png)") + { + AlternativeText = null, + Content = default, + MediaType = "image/png" + }; + + IngestionDocument document = new("withImage") + { + Sections = + { + new IngestionDocumentSection + { + Elements = + { + documentImage, + new IngestionDocumentTable("nvm", new[,] { { tableCell } }) + } + } + } + }; + + await sut.ProcessAsync(document); + + Assert.Equal(descriptions[0], documentImage.AlternativeText); + Assert.Equal(descriptions[1], tableCell.AlternativeText); + Assert.Same(PreExistingAltText, imageWithAltText.AlternativeText); + Assert.Null(imageWithNoContent.AlternativeText); + } +} diff --git a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/ClassificationEnricherTests.cs b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/ClassificationEnricherTests.cs new file mode 100644 index 00000000000..3f890969262 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/ClassificationEnricherTests.cs @@ -0,0 +1,134 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Xunit; + +namespace Microsoft.Extensions.DataIngestion.Processors.Tests; + +public class ClassificationEnricherTests +{ + private static readonly IngestionDocument _document = new("test"); + + [Fact] + public void ThrowsOnNullChatClient() + { + Assert.Throws("chatClient", () => new ClassificationEnricher(null!, predefinedClasses: ["some"])); + } + + [Fact] + public void ThrowsOnEmptyPredefinedClasses() + { + Assert.Throws("predefinedClasses", () => new ClassificationEnricher(new TestChatClient(), predefinedClasses: [])); + } + + [Fact] + public void ThrowsOnDuplicatePredefinedClasses() + { + Assert.Throws("predefinedClasses", () => new ClassificationEnricher(new TestChatClient(), predefinedClasses: ["same", "same"])); + } + + [Fact] + public void ThrowsOnPredefinedClassesContainingFallback() + { + Assert.Throws("predefinedClasses", () => new ClassificationEnricher(new TestChatClient(), predefinedClasses: ["same", "Unknown"])); + } + + [Fact] + public void ThrowsOnFallbackInPredefinedClasses() + { + Assert.Throws("predefinedClasses", () => new ClassificationEnricher(new TestChatClient(), predefinedClasses: ["some"], fallbackClass: "some")); + } + + [Fact] + public void ThrowsOnPredefinedClassesContainingComma() + { + Assert.Throws("predefinedClasses", () => new ClassificationEnricher(new TestChatClient(), predefinedClasses: ["n,t"])); + } + + [Fact] + public async Task ThrowsOnNullChunks() + { + using TestChatClient chatClient = new(); + ClassificationEnricher sut = new(chatClient, predefinedClasses: ["some"]); + + await Assert.ThrowsAsync("chunks", async () => + { + await foreach (var _ in sut.ProcessAsync(null!)) + { + // No-op + } + }); + } + + [Fact] + public async Task CanClassify() + { + int counter = 0; + string[] classes = ["AI", "Animals", "UFO"]; + using TestChatClient chatClient = new() + { + GetResponseAsyncCallback = (messages, options, cancellationToken) => + { + var materializedMessages = messages.ToArray(); + + Assert.Equal(2, materializedMessages.Length); + Assert.Equal(ChatRole.System, materializedMessages[0].Role); + Assert.Equal(ChatRole.User, materializedMessages[1].Role); + + return Task.FromResult(new ChatResponse(new[] + { + new ChatMessage(ChatRole.Assistant, classes[counter++]) + })); + } + }; + ClassificationEnricher sut = new(chatClient, ["AI", "Animals", "Sports"], fallbackClass: "UFO"); + + IReadOnlyList> got = await sut.ProcessAsync(CreateChunks().ToAsyncEnumerable()).ToListAsync(); + + Assert.Equal(3, got.Count); + Assert.Equal("AI", got[0].Metadata[ClassificationEnricher.MetadataKey]); + Assert.Equal("Animals", got[1].Metadata[ClassificationEnricher.MetadataKey]); + Assert.Equal("UFO", got[2].Metadata[ClassificationEnricher.MetadataKey]); + } + + [Fact] + public async Task ThrowsOnInvalidResponse() + { + using TestChatClient chatClient = new() + { + GetResponseAsyncCallback = (messages, options, cancellationToken) => + { + return Task.FromResult(new ChatResponse(new[] + { + new ChatMessage(ChatRole.Assistant, "Unexpected result!") + })); + } + }; + + ClassificationEnricher sut = new(chatClient, ["AI", "Animals", "Sports"]); + var input = CreateChunks().ToAsyncEnumerable(); + + await Assert.ThrowsAsync(async () => + { + await foreach (var _ in sut.ProcessAsync(input)) + { + // No-op + } + }); + } + + private static List> CreateChunks() => + [ + new(".NET developers need to integrate and interact with a growing variety of artificial intelligence (AI) services in their apps. " + + "The Microsoft.Extensions.AI libraries provide a unified approach for representing generative AI components, and enable seamless" + + " integration and interoperability with various AI services.", _document), + new ("Rabbits are small mammals in the family Leporidae of the order Lagomorpha (along with the hare and the pika)." + + "They are herbivorous animals and are known for their long ears, large hind legs, and short fluffy tails.", _document), + new("This text does not belong to any category.", _document), + ]; +} diff --git a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/KeywordEnricherTests.cs b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/KeywordEnricherTests.cs new file mode 100644 index 00000000000..0f11cd7d46b --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/KeywordEnricherTests.cs @@ -0,0 +1,132 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Xunit; + +namespace Microsoft.Extensions.DataIngestion.Processors.Tests; + +public class KeywordEnricherTests +{ + private static readonly IngestionDocument _document = new("test"); + + [Fact] + public void ThrowsOnNullChatClient() + { + Assert.Throws("chatClient", () => new KeywordEnricher(null!, predefinedKeywords: null, confidenceThreshold: 0.5)); + } + + [Theory] + [InlineData(-0.1)] + [InlineData(1.1)] + public void ThrowsOnInvalidThreshold(double threshold) + { + Assert.Throws("confidenceThreshold", () => new KeywordEnricher(new TestChatClient(), predefinedKeywords: null, confidenceThreshold: threshold)); + } + + [Theory] + [InlineData(0)] + [InlineData(-1)] + public void ThrowsOnInvalidMaxKeywords(int keywordCount) + { + Assert.Throws("maxKeywords", () => new KeywordEnricher(new TestChatClient(), predefinedKeywords: null, maxKeywords: keywordCount)); + } + + [Fact] + public void ThrowsOnDuplicateKeywords() + { + Assert.Throws("predefinedKeywords", () => new KeywordEnricher(new TestChatClient(), predefinedKeywords: ["same", "same"], confidenceThreshold: 0.5)); + } + + [Theory] + [InlineData(',')] + [InlineData(';')] + public void ThrowsOnIllegalCharacters(char illegal) + { + Assert.Throws("predefinedKeywords", () => new KeywordEnricher(new TestChatClient(), predefinedKeywords: [$"n{illegal}t"])); + } + + [Fact] + public async Task ThrowsOnNullChunks() + { + using TestChatClient chatClient = new(); + KeywordEnricher sut = new(chatClient, predefinedKeywords: null, confidenceThreshold: 0.5); + + await Assert.ThrowsAsync("chunks", async () => + { + await foreach (var _ in sut.ProcessAsync(null!)) + { + // No-op + } + }); + } + + [Theory] + [InlineData] + [InlineData("AI", "MEAI", "Animals", "Rabbits")] + public async Task CanExtractKeywords(params string[] predefined) + { + int counter = 0; + string[] keywords = { "AI;MEAI", "Animals;Rabbits" }; + using TestChatClient chatClient = new() + { + GetResponseAsyncCallback = (messages, options, cancellationToken) => + { + var materializedMessages = messages.ToArray(); + + Assert.Equal(2, materializedMessages.Length); + Assert.Equal(ChatRole.System, materializedMessages[0].Role); + Assert.Equal(ChatRole.User, materializedMessages[1].Role); + + return Task.FromResult(new ChatResponse(new[] + { + new ChatMessage(ChatRole.Assistant, keywords[counter++]) + })); + } + }; + + KeywordEnricher sut = new(chatClient, predefinedKeywords: predefined, confidenceThreshold: 0.5); + var chunks = CreateChunks().ToAsyncEnumerable(); + + IReadOnlyList> got = await sut.ProcessAsync(chunks).ToListAsync(); + + Assert.Equal(["AI", "MEAI"], (string[])got[0].Metadata[KeywordEnricher.MetadataKey]); + Assert.Equal(["Animals", "Rabbits"], (string[])got[1].Metadata[KeywordEnricher.MetadataKey]); + } + + [Fact] + public async Task ThrowsOnInvalidResponse() + { + using TestChatClient chatClient = new() + { + GetResponseAsyncCallback = (messages, options, cancellationToken) => + { + return Task.FromResult(new ChatResponse(new[] + { + new ChatMessage(ChatRole.Assistant, "Unexpected result!") + })); + } + }; + + KeywordEnricher sut = new(chatClient, ["some"]); + var input = CreateChunks().ToAsyncEnumerable(); + + await Assert.ThrowsAsync(async () => + { + await foreach (var _ in sut.ProcessAsync(input)) + { + // No-op + } + }); + } + + private static List> CreateChunks() => + [ + new("The Microsoft.Extensions.AI libraries provide a unified approach for representing generative AI components", _document), + new("Rabbits are great pets. They are friendly and make excellent companions.", _document) + ]; +} diff --git a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/SentimentEnricherTests.cs b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/SentimentEnricherTests.cs new file mode 100644 index 00000000000..166b3c05959 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/SentimentEnricherTests.cs @@ -0,0 +1,113 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Xunit; + +namespace Microsoft.Extensions.DataIngestion.Processors.Tests; + +public class SentimentEnricherTests +{ + private static readonly IngestionDocument _document = new("test"); + + [Fact] + public void ThrowsOnNullChatClient() + { + Assert.Throws("chatClient", () => new SentimentEnricher(null!)); + } + + [Theory] + [InlineData(-0.1)] + [InlineData(1.1)] + public void ThrowsOnInvalidThreshold(double threshold) + { + Assert.Throws("confidenceThreshold", () => new SentimentEnricher(new TestChatClient(), confidenceThreshold: threshold)); + } + + [Fact] + public async Task ThrowsOnNullChunks() + { + using TestChatClient chatClient = new(); + SentimentEnricher sut = new(chatClient); + + await Assert.ThrowsAsync("chunks", async () => + { + await foreach (var _ in sut.ProcessAsync(null!)) + { + // No-op + } + }); + } + + [Fact] + public async Task CanProvideSentiment() + { + int counter = 0; + string[] sentiments = { "Positive", "Negative", "Neutral", "Unknown" }; + using TestChatClient chatClient = new() + { + GetResponseAsyncCallback = (messages, options, cancellationToken) => + { + var materializedMessages = messages.ToArray(); + + Assert.Equal(2, materializedMessages.Length); + Assert.Equal(ChatRole.System, materializedMessages[0].Role); + Assert.Equal(ChatRole.User, materializedMessages[1].Role); + + return Task.FromResult(new ChatResponse(new[] + { + new ChatMessage(ChatRole.Assistant, sentiments[counter++]) + })); + } + }; + SentimentEnricher sut = new(chatClient); + var input = CreateChunks().ToAsyncEnumerable(); + + var chunks = await sut.ProcessAsync(input).ToListAsync(); + + Assert.Equal(4, chunks.Count); + + Assert.Equal("Positive", chunks[0].Metadata[SentimentEnricher.MetadataKey]); + Assert.Equal("Negative", chunks[1].Metadata[SentimentEnricher.MetadataKey]); + Assert.Equal("Neutral", chunks[2].Metadata[SentimentEnricher.MetadataKey]); + Assert.Equal("Unknown", chunks[3].Metadata[SentimentEnricher.MetadataKey]); + } + + [Fact] + public async Task ThrowsOnInvalidResponse() + { + using TestChatClient chatClient = new() + { + GetResponseAsyncCallback = (messages, options, cancellationToken) => + { + return Task.FromResult(new ChatResponse(new[] + { + new ChatMessage(ChatRole.Assistant, "Unexpected result!") + })); + } + }; + + SentimentEnricher sut = new(chatClient); + var input = CreateChunks().ToAsyncEnumerable(); + + await Assert.ThrowsAsync(async () => + { + await foreach (var _ in sut.ProcessAsync(input)) + { + // No-op + } + }); + } + + private static List> CreateChunks() => + [ + new("I love programming! It's so much fun and rewarding.", _document), + new("I hate bugs. They are so frustrating and time-consuming.", _document), + new("The weather is okay, not too bad but not great either.", _document), + new("I hate you. I am sorry, I actually don't. I am not sure myself what my feelings are.", _document) + ]; +} diff --git a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/SummaryEnricherTests.cs b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/SummaryEnricherTests.cs new file mode 100644 index 00000000000..6fda37004d3 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/SummaryEnricherTests.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Xunit; + +namespace Microsoft.Extensions.DataIngestion.Processors.Tests; + +public class SummaryEnricherTests +{ + private static readonly IngestionDocument _document = new("test"); + + [Fact] + public void ThrowsOnNullChatClient() + { + Assert.Throws("chatClient", () => new SummaryEnricher(null!)); + } + + [Theory] + [InlineData(0)] + [InlineData(-1)] + public void ThrowsOnInvalidMaxKeywords(int wordCount) + { + Assert.Throws("maxWordCount", () => new SummaryEnricher(new TestChatClient(), maxWordCount: wordCount)); + } + + [Fact] + public async Task ThrowsOnNullChunks() + { + using TestChatClient chatClient = new(); + SummaryEnricher sut = new(chatClient); + + await Assert.ThrowsAsync("chunks", async () => + { + await foreach (var _ in sut.ProcessAsync(null!)) + { + // No-op + } + }); + } + + [Fact] + public async Task CanProvideSummary() + { + int counter = 0; + string[] summaries = { "First summary.", "Second summary." }; + using TestChatClient chatClient = new() + { + GetResponseAsyncCallback = (messages, options, cancellationToken) => + { + var materializedMessages = messages.ToArray(); + + Assert.Equal(2, materializedMessages.Length); + Assert.Equal(ChatRole.System, materializedMessages[0].Role); + Assert.Equal(ChatRole.User, materializedMessages[1].Role); + + return Task.FromResult(new ChatResponse(new[] + { + new ChatMessage(ChatRole.Assistant, summaries[counter++]) + })); + } + }; + SummaryEnricher sut = new(chatClient); + var input = CreateChunks().ToAsyncEnumerable(); + + var chunks = await sut.ProcessAsync(input).ToListAsync(); + + Assert.Equal(2, chunks.Count); + Assert.Equal(summaries[0], (string)chunks[0].Metadata[SummaryEnricher.MetadataKey]!); + Assert.Equal(summaries[1], (string)chunks[1].Metadata[SummaryEnricher.MetadataKey]!); + } + + private static List> CreateChunks() => + [ + new("I love programming! It's so much fun and rewarding.", _document), + new("I hate bugs. They are so frustrating and time-consuming.", _document) + ]; +} diff --git a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Utils/IAsyncEnumerableExtensions.cs b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Utils/IAsyncEnumerableExtensions.cs index 60120dded5d..bb30b585233 100644 --- a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Utils/IAsyncEnumerableExtensions.cs +++ b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Utils/IAsyncEnumerableExtensions.cs @@ -49,4 +49,15 @@ internal static async ValueTask SingleAsync(this IAsyncEnumerable sourc ? result : throw new InvalidOperationException(); } + + internal static async ValueTask> ToListAsync(this IAsyncEnumerable source) + { + List list = []; + await foreach (var item in source) + { + list.Add(item); + } + + return list; + } }