diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContent.cs index 6895d4c1e42..6562b7bcc42 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContent.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContent.cs @@ -11,6 +11,7 @@ namespace Microsoft.Extensions.AI; [JsonDerivedType(typeof(FunctionCallContent), typeDiscriminator: "functionCall")] [JsonDerivedType(typeof(FunctionResultContent), typeDiscriminator: "functionResult")] [JsonDerivedType(typeof(TextContent), typeDiscriminator: "text")] +[JsonDerivedType(typeof(UriContent), typeDiscriminator: "uri")] [JsonDerivedType(typeof(UsageContent), typeDiscriminator: "usage")] public class AIContent { diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataContent.cs index 041d33a9704..dc0c5db9289 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataContent.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataContent.cs @@ -8,17 +8,17 @@ using Microsoft.Shared.Diagnostics; #pragma warning disable S3996 // URI properties should not be strings +#pragma warning disable CA1054 // URI-like parameters should not be strings #pragma warning disable CA1056 // URI-like properties should not be strings namespace Microsoft.Extensions.AI; /// -/// Represents data content, such as an image or audio. +/// Represents binary content with an associated media type (also known as MIME type). /// /// /// -/// The represented content may either be the actual bytes stored in this instance, or it may -/// be a URI that references the location of the content. +/// The content represents in-memory data. For references to data at a remote URI, use instead. /// /// /// always returns a valid URI string, even if the instance was constructed from @@ -32,20 +32,27 @@ public class DataContent : AIContent // Ideally DataContent would be based in terms of Uri. However, Uri has a length limitation that makes it prohibitive // for the kinds of data URIs necessary to support here. As such, this type is based in strings. + /// Parsed data URI information. + private readonly DataUriParser.DataUri? _dataUri; + /// The string-based representation of the URI, including any data in the instance. private string? _uri; /// The data, lazily initialized if the data is provided in a data URI. private ReadOnlyMemory? _data; - /// Parsed data URI information. - private DataUriParser.DataUri? _dataUri; - /// /// Initializes a new instance of the class. /// - /// The URI of the content. This can be a data URI. - /// The media type (also known as MIME type) represented by the content. + /// The data URI containing the content. + /// + /// The media type (also known as MIME type) represented by the content. If not provided, + /// it must be provided as part of the . + /// + /// is . + /// is not a data URI. + /// did not contain a media type and was not supplied. + /// is an invalid media type. public DataContent(Uri uri, string? mediaType = null) : this(Throw.IfNull(uri).ToString(), mediaType) { @@ -54,42 +61,48 @@ public DataContent(Uri uri, string? mediaType = null) /// /// Initializes a new instance of the class. /// - /// The URI of the content. This can be a data URI. + /// The data URI containing the content. /// The media type (also known as MIME type) represented by the content. + /// is . + /// is not a data URI. + /// did not contain a media type and was not supplied. + /// is an invalid media type. [JsonConstructor] public DataContent([StringSyntax(StringSyntaxAttribute.Uri)] string uri, string? mediaType = null) { _uri = Throw.IfNullOrWhitespace(uri); - ValidateMediaType(ref mediaType); - MediaType = mediaType; - - if (uri.StartsWith(DataUriParser.Scheme, StringComparison.OrdinalIgnoreCase)) + if (!uri.StartsWith(DataUriParser.Scheme, StringComparison.OrdinalIgnoreCase)) { - _dataUri = DataUriParser.Parse(uri.AsMemory()); + Throw.ArgumentException(nameof(uri), "The provided URI is not a data URI."); + } - // If the data URI contains a media type that's different from a non-null media type - // explicitly provided, prefer the one explicitly provided as an override. - if (MediaType is not null) - { - if (MediaType != _dataUri.MediaType) - { - // Extract the bytes from the data URI and null out the uri. - // Then we'll lazily recreate it later if needed based on the updated media type. - _data = _dataUri.ToByteArray(); - _dataUri = null; - _uri = null; - } - } - else + _dataUri = DataUriParser.Parse(uri.AsMemory()); + + if (mediaType is null) + { + mediaType = _dataUri.MediaType; + if (mediaType is null) { - MediaType = _dataUri.MediaType; + Throw.ArgumentNullException(nameof(mediaType), $"{nameof(uri)} did not contain a media type, and {nameof(mediaType)} was not provided."); } } - else if (!System.Uri.TryCreate(uri, UriKind.Absolute, out _)) + else { - throw new UriFormatException("The URI is not well-formed."); + if (mediaType != _dataUri.MediaType) + { + // If the data URI contains a media type that's different from a non-null media type + // explicitly provided, prefer the one explicitly provided as an override. + + // Extract the bytes from the data URI and null out the uri. + // Then we'll lazily recreate it later if needed based on the updated media type. + _data = _dataUri.ToByteArray(); + _dataUri = null; + _uri = null; + } } + + MediaType = DataUriParser.ThrowIfInvalidMediaType(mediaType); } /// @@ -97,32 +110,29 @@ public DataContent([StringSyntax(StringSyntaxAttribute.Uri)] string uri, string? /// /// The byte contents. /// The media type (also known as MIME type) represented by the content. - public DataContent(ReadOnlyMemory data, string? mediaType = null) + /// is null. + /// is empty or composed entirely of whitespace. + public DataContent(ReadOnlyMemory data, string mediaType) { - ValidateMediaType(ref mediaType); - MediaType = mediaType; + MediaType = DataUriParser.ThrowIfInvalidMediaType(mediaType); _data = data; } /// - /// Determines whether the has the specified prefix. + /// Determines whether the 's top-level type matches the specified . /// - /// The media type prefix. - /// if the has the specified prefix, otherwise . - public bool MediaTypeStartsWith(string prefix) - => MediaType?.StartsWith(prefix, StringComparison.OrdinalIgnoreCase) is true; - - /// Sets to null if it's empty or composed entirely of whitespace. - private static void ValidateMediaType(ref string? mediaType) - { - if (!DataUriParser.IsValidMediaType(mediaType.AsSpan(), ref mediaType)) - { - Throw.ArgumentException(nameof(mediaType), "Invalid media type."); - } - } + /// The type to compare against . + /// if the type portion of matches the specified value; otherwise, false. + /// + /// A media type is primarily composed of two parts, a "type" and a "subtype", separated by a slash ("/"). + /// The type portion is also referred to as the "top-level type"; for example, + /// "image/png" has a top-level type of "image". compares + /// the specified against the type portion of . + /// + public bool HasTopLevelMediaType(string topLevelType) => DataUriParser.HasTopLevelMediaType(MediaType, topLevelType); - /// Gets the URI for this . + /// Gets the data URI for this . /// /// The returned URI is always a valid URI string, even if the instance was constructed from a /// or from a . In the case of a , this property returns a data URI containing @@ -137,8 +147,8 @@ public string Uri { if (_dataUri is null) { - Debug.Assert(Data is not null, "Expected Data to be initialized."); - _uri = string.Concat("data:", MediaType, ";base64,", Convert.ToBase64String(Data.GetValueOrDefault() + Debug.Assert(_data is not null, "Expected _data to be initialized."); + _uri = string.Concat("data:", MediaType, ";base64,", Convert.ToBase64String(_data.GetValueOrDefault() #if NET .Span)); #else @@ -167,10 +177,9 @@ public string Uri /// If the media type was explicitly specified, this property returns that value. /// If the media type was not explicitly specified, but a data URI was supplied and that data URI contained a non-default /// media type, that media type is returned. - /// Otherwise, this property returns null. /// - [JsonPropertyOrder(1)] - public string? MediaType { get; private set; } + [JsonIgnore] + public string MediaType { get; } /// Gets the data represented by this instance. /// @@ -181,16 +190,18 @@ public string Uri /// no attempt is made to retrieve the data from that URI. /// [JsonIgnore] - public ReadOnlyMemory? Data + public ReadOnlyMemory Data { get { - if (_dataUri is not null) + if (_data is null) { - _data ??= _dataUri.ToByteArray(); + Debug.Assert(_dataUri is not null, "Expected dataUri to be initialized."); + _data = _dataUri!.ToByteArray(); } - return _data; + Debug.Assert(_data is not null, "Expected data to be initialized."); + return _data.GetValueOrDefault(); } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataUriParser.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataUriParser.cs index 5cb33d1a55c..cff25e9c30b 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataUriParser.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataUriParser.cs @@ -5,10 +5,14 @@ #if NET8_0_OR_GREATER using System.Buffers.Text; #endif -using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Net; using System.Net.Http.Headers; +using System.Runtime.CompilerServices; using System.Text; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable CA1307 // Specify StringComparison for clarity namespace Microsoft.Extensions.AI; @@ -55,8 +59,9 @@ public static DataUri Parse(ReadOnlyMemory dataUri) } // Validate the media type, if present. + ReadOnlySpan span = metadata.Span.Trim(); string? mediaType = null; - if (!IsValidMediaType(metadata.Span.Trim(), ref mediaType)) + if (!span.IsEmpty && !IsValidMediaType(span, ref mediaType)) { throw new UriFormatException("Invalid data URI format: the media type is not a valid."); } @@ -64,20 +69,25 @@ public static DataUri Parse(ReadOnlyMemory dataUri) return new DataUri(data, isBase64, mediaType); } - /// Validates that a media type is valid, and if successful, ensures we have it as a string. - public static bool IsValidMediaType(ReadOnlySpan mediaTypeSpan, ref string? mediaType) + public static string ThrowIfInvalidMediaType( + string mediaType, [CallerArgumentExpression(nameof(mediaType))] string parameterName = "") { - Debug.Assert( - mediaType is null || mediaTypeSpan.Equals(mediaType.AsSpan(), StringComparison.Ordinal), - "mediaType string should either be null or the same as the span"); + _ = Throw.IfNullOrWhitespace(mediaType, parameterName); - // If the media type is empty or all whitespace, normalize it to null. - if (mediaTypeSpan.IsWhiteSpace()) + if (!IsValidMediaType(mediaType)) { - mediaType = null; - return true; + Throw.ArgumentException(parameterName, $"An invalid media type was specified: '{mediaType}'"); } + return mediaType; + } + + public static bool IsValidMediaType(string mediaType) => + IsValidMediaType(mediaType.AsSpan(), ref mediaType); + + /// Validates that a media type is valid, and if successful, ensures we have it as a string. + public static bool IsValidMediaType(ReadOnlySpan mediaTypeSpan, [NotNull] ref string? mediaType) + { // For common media types, we can avoid both allocating a string for the span and avoid parsing overheads. string? knownType = mediaTypeSpan switch { @@ -108,7 +118,7 @@ public static bool IsValidMediaType(ReadOnlySpan mediaTypeSpan, ref string }; if (knownType is not null) { - mediaType ??= knownType; + mediaType = knownType; return true; } @@ -117,6 +127,16 @@ public static bool IsValidMediaType(ReadOnlySpan mediaTypeSpan, ref string return MediaTypeHeaderValue.TryParse(mediaType, out _); } + public static bool HasTopLevelMediaType(string mediaType, string topLevelMediaType) + { + int slashIndex = mediaType.IndexOf('/'); + + ReadOnlySpan span = slashIndex < 0 ? mediaType.AsSpan() : mediaType.AsSpan(0, slashIndex); + span = span.Trim(); + + return span.Equals(topLevelMediaType.AsSpan(), StringComparison.OrdinalIgnoreCase); + } + /// Test whether the value is a base64 string without whitespace. private static bool IsValidBase64Data(ReadOnlySpan value) { diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/UriContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/UriContent.cs new file mode 100644 index 00000000000..7beaa40efdf --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/UriContent.cs @@ -0,0 +1,92 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics; +using System.Text.Json.Serialization; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Represents a URL, typically to hosted content such as an image, audio, or video. +/// +/// +/// This class is intended for use with HTTP or HTTPS URIs that reference hosted content. +/// For data URIs, use instead. +/// +[DebuggerDisplay("{DebuggerDisplay,nq}")] +public class UriContent : AIContent +{ + /// The URI represented. + private Uri _uri; + + /// The MIME type of the data at the referenced URI. + private string _mediaType; + + /// Initializes a new instance of the class. + /// The URI to the represented content. + /// The media type (also known as MIME type) represented by the content. + /// is . + /// is . + /// is an invalid media type. + /// is an invalid URL. + /// + /// A media type must be specified, so that consumers know what to do with the content. + /// If an exact media type is not known, but the category (e.g. image) is known, a wildcard + /// may be used (e.g. "image/*"). + /// + public UriContent(string uri, string mediaType) + : this(new Uri(Throw.IfNull(uri)), mediaType) + { + } + + /// Initializes a new instance of the class. + /// The URI to the represented content. + /// The media type (also known as MIME type) represented by the content. + /// is . + /// is . + /// is an invalid media type. + /// + /// A media type must be specified, so that consumers know what to do with the content. + /// If an exact media type is not known, but the category (e.g. image) is known, a wildcard + /// may be used (e.g. "image/*"). + /// + [JsonConstructor] + public UriContent(Uri uri, string mediaType) + { + _uri = Throw.IfNull(uri); + _mediaType = DataUriParser.ThrowIfInvalidMediaType(mediaType); + } + + /// Gets or sets the for this content. + public Uri Uri + { + get => _uri; + set => _uri = Throw.IfNull(value); + } + + /// Gets or sets the media type (also known as MIME type) for this content. + public string MediaType + { + get => _mediaType; + set => _mediaType = DataUriParser.ThrowIfInvalidMediaType(value); + } + + /// + /// Determines whether the 's top-level type matches the specified . + /// + /// The type to compare against . + /// if the type portion of matches the specified value; otherwise, false. + /// + /// A media type is primarily composed of two parts, a "type" and a "subtype", separated by a slash ("/"). + /// The type portion is also referred to as the "top-level type"; for example, + /// "image/png" has a top-level type of "image". compares + /// the specified against the type portion of . + /// + public bool HasTopLevelMediaType(string topLevelType) => DataUriParser.HasTopLevelMediaType(MediaType, topLevelType); + + /// Gets a string representing this instance to display in the debugger. + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private string DebuggerDisplay => $"Uri = {_uri}"; +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs index 35d8260e406..d8ed6967d71 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs @@ -17,8 +17,6 @@ namespace Microsoft.Extensions.AI; public static class EmbeddingGeneratorExtensions { /// Asks the for an object of type . - /// The type from which embeddings will be generated. - /// The numeric type of the embedding data. /// The type of the object to be retrieved. /// The generator. /// An optional key that can be used to help identify the target service. @@ -28,9 +26,8 @@ public static class EmbeddingGeneratorExtensions /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the /// , including itself or any services it might be wrapping. /// - public static TService? GetService( - this IEmbeddingGenerator generator, object? serviceKey = null) - where TEmbedding : Embedding + public static TService? GetService( + this IEmbeddingGenerator generator, object? serviceKey = null) { _ = Throw.IfNull(generator); @@ -41,8 +38,6 @@ public static class EmbeddingGeneratorExtensions /// Asks the for an object of the specified type /// and throws an exception if one isn't available. /// - /// The type from which embeddings will be generated. - /// The numeric type of the embedding data. /// The generator. /// The type of object being requested. /// An optional key that can be used to help identify the target service. @@ -54,9 +49,8 @@ public static class EmbeddingGeneratorExtensions /// The purpose of this method is to allow for the retrieval of services that are required to be provided by the /// , including itself or any services it might be wrapping. /// - public static object GetRequiredService( - this IEmbeddingGenerator generator, Type serviceType, object? serviceKey = null) - where TEmbedding : Embedding + public static object GetRequiredService( + this IEmbeddingGenerator generator, Type serviceType, object? serviceKey = null) { _ = Throw.IfNull(generator); _ = Throw.IfNull(serviceType); @@ -70,8 +64,6 @@ public static object GetRequiredService( /// Asks the for an object of type /// and throws an exception if one isn't available. /// - /// The type from which embeddings will be generated. - /// The numeric type of the embedding data. /// The type of the object to be retrieved. /// The generator. /// An optional key that can be used to help identify the target service. @@ -82,9 +74,8 @@ public static object GetRequiredService( /// The purpose of this method is to allow for the retrieval of strongly typed services that are required to be provided by the /// , including itself or any services it might be wrapping. /// - public static TService GetRequiredService( - this IEmbeddingGenerator generator, object? serviceKey = null) - where TEmbedding : Embedding + public static TService GetRequiredService( + this IEmbeddingGenerator generator, object? serviceKey = null) { _ = Throw.IfNull(generator); @@ -96,42 +87,6 @@ public static TService GetRequiredService( return service; } - // The following overloads exist purely to work around the lack of partial generic type inference. - // Given an IEmbeddingGenerator generator, to call GetService with TService, you still need - // to re-specify both TInput and TEmbedding, e.g. generator.GetService, TService>. - // The case of string/Embedding is by far the most common case today, so this overload exists as an - // accelerator to allow it to be written simply as generator.GetService. - - /// Asks the for an object of type . - /// The type of the object to be retrieved. - /// The generator. - /// An optional key that can be used to help identify the target service. - /// The found object, otherwise . - /// is . - /// - /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the - /// , including itself or any services it might be wrapping. - /// - public static TService? GetService(this IEmbeddingGenerator> generator, object? serviceKey = null) => - GetService, TService>(generator, serviceKey); - - /// - /// Asks the for an object of type - /// and throws an exception if one isn't available. - /// - /// The type of the object to be retrieved. - /// The generator. - /// An optional key that can be used to help identify the target service. - /// The found object. - /// is . - /// No service of the requested type for the specified key is available. - /// - /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the - /// , including itself or any services it might be wrapping. - /// - public static TService GetRequiredService(this IEmbeddingGenerator> generator, object? serviceKey = null) => - GetRequiredService, TService>(generator, serviceKey); - /// Generates an embedding vector from the specified . /// The type from which embeddings will be generated. /// The numeric type of the embedding data. diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs index 59fcc9e2393..4f8174b6874 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs @@ -2,42 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; namespace Microsoft.Extensions.AI; /// Represents a generator of embeddings. -/// The type from which embeddings will be generated. -/// The type of embeddings to generate. /// -/// -/// Unless otherwise specified, all members of are thread-safe for concurrent use. -/// It is expected that all implementations of support being used by multiple requests concurrently. -/// Instances must not be disposed of while the instance is still in use. -/// -/// -/// However, implementations of may mutate the arguments supplied to -/// , such as by configuring the options instance. Thus, consumers of the interface either should -/// avoid using shared instances of these arguments for concurrent invocations or should otherwise ensure by construction that -/// no instances are used which might employ such mutation. -/// +/// This base interface is used to allow for embedding generators to be stored in a non-generic manner. +/// To use the generator to create embeddings, instances typed as this base interface first need to be +/// cast to the generic interface . /// -public interface IEmbeddingGenerator : IDisposable - where TEmbedding : Embedding +public interface IEmbeddingGenerator : IDisposable { - /// Generates embeddings for each of the supplied . - /// The sequence of values for which to generate embeddings. - /// The embedding generation options with which to configure the request. - /// The to monitor for cancellation requests. The default is . - /// The generated embeddings. - /// is . - Task> GenerateAsync( - IEnumerable values, - EmbeddingGenerationOptions? options = null, - CancellationToken cancellationToken = default); - /// Asks the for an object of the specified type . /// The type of object being requested. /// An optional key that can be used to help identify the target service. diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator{TInput,TEmbedding}.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator{TInput,TEmbedding}.cs new file mode 100644 index 00000000000..ff3910ae737 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator{TInput,TEmbedding}.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.AI; + +/// Represents a generator of embeddings. +/// The type from which embeddings will be generated. +/// The type of embeddings to generate. +/// +/// +/// Unless otherwise specified, all members of are thread-safe for concurrent use. +/// It is expected that all implementations of support being used by multiple requests concurrently. +/// Instances must not be disposed of while the instance is still in use. +/// +/// +/// However, implementations of may mutate the arguments supplied to +/// , such as by configuring the options instance. Thus, consumers of the interface either should +/// avoid using shared instances of these arguments for concurrent invocations or should otherwise ensure by construction that +/// no instances are used which might employ such mutation. +/// +/// +public interface IEmbeddingGenerator : IEmbeddingGenerator + where TEmbedding : Embedding +{ + /// Generates embeddings for each of the supplied . + /// The sequence of values for which to generate embeddings. + /// The embedding generation options with which to configure the request. + /// The to monitor for cancellation requests. The default is . + /// The generated embeddings. + /// is . + Task> GenerateAsync( + IEnumerable values, + EmbeddingGenerationOptions? options = null, + CancellationToken cancellationToken = default); +} diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs index db03a62f2a9..ed2cc991e8c 100644 --- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs @@ -490,42 +490,34 @@ private static List GetContentParts(IList con parts.Add(new ChatMessageTextContentItem(textContent.Text)); break; - case DataContent dataContent when dataContent.MediaTypeStartsWith("image/"): - if (dataContent.Data.HasValue) - { - parts.Add(new ChatMessageImageContentItem(BinaryData.FromBytes(dataContent.Data.Value), dataContent.MediaType)); - } - else if (dataContent.Uri is string uri) - { - parts.Add(new ChatMessageImageContentItem(new Uri(uri))); - } + case UriContent uriContent when uriContent.HasTopLevelMediaType("image"): + parts.Add(new ChatMessageImageContentItem(uriContent.Uri)); + break; + case DataContent dataContent when dataContent.HasTopLevelMediaType("image"): + parts.Add(new ChatMessageImageContentItem(BinaryData.FromBytes(dataContent.Data), dataContent.MediaType)); break; - case DataContent dataContent when dataContent.MediaTypeStartsWith("audio/"): - if (dataContent.Data.HasValue) - { - AudioContentFormat format; - if (dataContent.MediaTypeStartsWith("audio/mpeg")) - { - format = AudioContentFormat.Mp3; - } - else if (dataContent.MediaTypeStartsWith("audio/wav")) - { - format = AudioContentFormat.Wav; - } - else - { - break; - } + case UriContent uriContent when uriContent.HasTopLevelMediaType("audio"): + parts.Add(new ChatMessageAudioContentItem(uriContent.Uri)); + break; - parts.Add(new ChatMessageAudioContentItem(BinaryData.FromBytes(dataContent.Data.Value), format)); + case DataContent dataContent when dataContent.HasTopLevelMediaType("audio"): + AudioContentFormat format; + if (dataContent.MediaType.Equals("audio/mpeg", StringComparison.OrdinalIgnoreCase)) + { + format = AudioContentFormat.Mp3; + } + else if (dataContent.MediaType.Equals("audio/wav", StringComparison.OrdinalIgnoreCase)) + { + format = AudioContentFormat.Wav; } - else if (dataContent.Uri is string uri) + else { - parts.Add(new ChatMessageAudioContentItem(new Uri(uri))); + break; } + parts.Add(new ChatMessageAudioContentItem(BinaryData.FromBytes(dataContent.Data), format)); break; } } diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs index c0f4b2f4636..5cadc200869 100644 --- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs @@ -73,7 +73,7 @@ public AzureAIInferenceEmbeddingGenerator( } /// - object? IEmbeddingGenerator>.GetService(Type serviceType, object? serviceKey) + object? IEmbeddingGenerator.GetService(Type serviceType, object? serviceKey) { _ = Throw.IfNull(serviceType); diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs index ed1448c8b69..0af538b9802 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs @@ -392,10 +392,10 @@ private IEnumerable ToOllamaChatRequestMessages(ChatMe OllamaChatRequestMessage? currentTextMessage = null; foreach (var item in content.Contents) { - if (item is DataContent dataContent && dataContent.MediaTypeStartsWith("image/") && dataContent.Data.HasValue) + if (item is DataContent dataContent && dataContent.HasTopLevelMediaType("image")) { IList images = currentTextMessage?.Images ?? []; - images.Add(Convert.ToBase64String(dataContent.Data.Value + images.Add(Convert.ToBase64String(dataContent.Data #if NET .Span)); #else diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs index 6056753dd26..0b63491ddc2 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs @@ -61,7 +61,7 @@ public OllamaEmbeddingGenerator(Uri endpoint, string? modelId = null, HttpClient } /// - object? IEmbeddingGenerator>.GetService(Type serviceType, object? serviceKey) + object? IEmbeddingGenerator.GetService(Type serviceType, object? serviceKey) { _ = Throw.IfNull(serviceType); diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs index 1e5afb6d529..9aaad72ec3b 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs @@ -299,7 +299,7 @@ strictObj is bool strictValue ? messageContents.Add(MessageContent.FromText(tc.Text)); break; - case DataContent dc when dc.MediaTypeStartsWith("image/"): + case DataContent dc when dc.HasTopLevelMediaType("image"): messageContents.Add(MessageContent.FromImageUri(new(dc.Uri))); break; diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs index 7cf0be18fb0..8ae8a32b898 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs @@ -125,7 +125,7 @@ void IDisposable.Dispose() } /// - object? IEmbeddingGenerator>.GetService(Type serviceType, object? serviceKey) + object? IEmbeddingGenerator.GetService(Type serviceType, object? serviceKey) { _ = Throw.IfNull(serviceType); diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs index 59727d38f00..fdee45ea96d 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs @@ -566,15 +566,14 @@ private static ChatRole FromOpenAIChatRole(ChatMessageRole role) => } else if (contentPart.Kind == ChatMessageContentPartKind.Image) { - DataContent? imageContent; - aiContent = imageContent = - contentPart.ImageUri is not null ? new DataContent(contentPart.ImageUri, contentPart.ImageBytesMediaType) : + aiContent = + contentPart.ImageUri is not null ? new UriContent(contentPart.ImageUri, "image/*") : contentPart.ImageBytes is not null ? new DataContent(contentPart.ImageBytes.ToMemory(), contentPart.ImageBytesMediaType) : null; - if (imageContent is not null && contentPart.ImageDetailLevel?.ToString() is string detail) + if (aiContent is not null && contentPart.ImageDetailLevel?.ToString() is string detail) { - (imageContent.AdditionalProperties ??= [])[nameof(contentPart.ImageDetailLevel)] = detail; + (aiContent.AdditionalProperties ??= [])[nameof(contentPart.ImageDetailLevel)] = detail; } } diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatMessage.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatMessage.cs index c051c208f1e..8d9195b0953 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatMessage.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatMessage.cs @@ -205,11 +205,11 @@ private static List FromOpenAIChatContent(IList ToOpenAIChatContent(IList parts.Add(ChatMessageContentPart.CreateTextPart(textContent.Text)); break; - case DataContent dataContent when dataContent.MediaTypeStartsWith("image/"): - if (dataContent.Data.HasValue) - { - parts.Add(ChatMessageContentPart.CreateImagePart(BinaryData.FromBytes(dataContent.Data.Value), dataContent.MediaType)); - } - else if (dataContent.Uri is string uri) - { - parts.Add(ChatMessageContentPart.CreateImagePart(new Uri(uri))); - } + case UriContent uriContent when uriContent.HasTopLevelMediaType("image"): + parts.Add(ChatMessageContentPart.CreateImagePart(uriContent.Uri)); + break; + case DataContent dataContent when dataContent.HasTopLevelMediaType("image"): + parts.Add(ChatMessageContentPart.CreateImagePart(BinaryData.FromBytes(dataContent.Data), dataContent.MediaType)); break; - case DataContent dataContent when dataContent.MediaTypeStartsWith("audio/") && dataContent.Data.HasValue: - var audioData = BinaryData.FromBytes(dataContent.Data.Value); - if (dataContent.MediaTypeStartsWith("audio/mpeg")) + case DataContent dataContent when dataContent.HasTopLevelMediaType("audio"): + var audioData = BinaryData.FromBytes(dataContent.Data); + if (dataContent.MediaType.Equals("audio/mpeg", StringComparison.OrdinalIgnoreCase)) { parts.Add(ChatMessageContentPart.CreateInputAudioPart(audioData, ChatInputAudioFormat.Mp3)); } - else if (dataContent.MediaTypeStartsWith("audio/wav")) + else if (dataContent.MediaType.Equals("audio/wav", StringComparison.OrdinalIgnoreCase)) { parts.Add(ChatMessageContentPart.CreateInputAudioPart(audioData, ChatInputAudioFormat.Wav)); } diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs index b84e8ac6e60..ebc7e3d26af 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs @@ -53,6 +53,8 @@ public static EmbeddingGeneratorBuilder AddEmbeddingGenerato var builder = new EmbeddingGeneratorBuilder(innerGeneratorFactory); serviceCollection.Add(new ServiceDescriptor(typeof(IEmbeddingGenerator), builder.Build, lifetime)); + serviceCollection.Add(new ServiceDescriptor(typeof(IEmbeddingGenerator), + static services => services.GetRequiredService>(), lifetime)); return builder; } @@ -103,6 +105,8 @@ public static EmbeddingGeneratorBuilder AddKeyedEmbeddingGen var builder = new EmbeddingGeneratorBuilder(innerGeneratorFactory); serviceCollection.Add(new ServiceDescriptor(typeof(IEmbeddingGenerator), serviceKey, factory: (services, serviceKey) => builder.Build(services), lifetime)); + serviceCollection.Add(new ServiceDescriptor(typeof(IEmbeddingGenerator), serviceKey, + static (services, serviceKey) => services.GetRequiredKeyedService>(serviceKey), lifetime)); return builder; } } diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs index 24770df1052..90553ca5411 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs @@ -52,7 +52,7 @@ public override async Task> GenerateAsync(IEnume { if (_logger.IsEnabled(LogLevel.Trace)) { - LogInvokedSensitive(AsJson(values), AsJson(options), AsJson(this.GetService())); + LogInvokedSensitive(AsJson(values), AsJson(options), AsJson(this.GetService())); } else { diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs index 3fd92a103aa..26ead720a1c 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs @@ -50,7 +50,7 @@ public OpenTelemetryEmbeddingGenerator(IEmbeddingGenerator i { Debug.Assert(innerGenerator is not null, "Should have been validated by the base ctor."); - if (innerGenerator!.GetService() is EmbeddingGeneratorMetadata metadata) + if (innerGenerator!.GetService() is EmbeddingGeneratorMetadata metadata) { _system = metadata.ProviderName; _modelId = metadata.ModelId; diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs index 7174d2a70c8..c449f064255 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs @@ -141,8 +141,8 @@ public void Text_ConcatsAllTextContent() { ChatMessage message = new(ChatRole.User, [ - new DataContent("http://localhost/audio"), - new DataContent("http://localhost/image"), + new DataContent("data:text/image;base64,aGVsbG8="), + new DataContent("data:text/plain;base64,aGVsbG8="), new FunctionCallContent("callId1", "fc1"), new TextContent("text-1"), new TextContent("text-2"), @@ -240,7 +240,7 @@ public void ItCanBeSerializeAndDeserialized() { AdditionalProperties = new() { ["metadata-key-1"] = "metadata-value-1" } }, - new DataContent(new Uri("https://fake-random-test-host:123"), "mime-type/2") + new DataContent(new Uri("data:text/plain;base64,aGVsbG8="), "mime-type/2") { AdditionalProperties = new() { ["metadata-key-2"] = "metadata-value-2" } }, @@ -286,7 +286,7 @@ public void ItCanBeSerializeAndDeserialized() var dataContent = deserializedMessage.Contents[1] as DataContent; Assert.NotNull(dataContent); - Assert.Equal("https://fake-random-test-host:123/", dataContent.Uri); + Assert.Equal("data:mime-type/2;base64,aGVsbG8=", dataContent.Uri); Assert.Equal("mime-type/2", dataContent.MediaType); Assert.NotNull(dataContent.AdditionalProperties); Assert.Single(dataContent.AdditionalProperties); @@ -294,7 +294,7 @@ public void ItCanBeSerializeAndDeserialized() dataContent = deserializedMessage.Contents[2] as DataContent; Assert.NotNull(dataContent); - Assert.True(dataContent.Data!.Value.Span.SequenceEqual(new BinaryData(new[] { 1, 2, 3 }, TestJsonSerializerContext.Default.Options))); + Assert.True(dataContent.Data.Span.SequenceEqual(new BinaryData(new[] { 1, 2, 3 }, TestJsonSerializerContext.Default.Options))); Assert.Equal("mime-type/3", dataContent.MediaType); Assert.NotNull(dataContent.AdditionalProperties); Assert.Single(dataContent.AdditionalProperties); diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs index 454c3c3cad3..00e074ab276 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs @@ -124,7 +124,7 @@ void AddGap() { for (int i = 0; i < gapLength; i++) { - updates.Add(new() { Contents = [new DataContent("https://uri", mediaType: "image/png")] }); + updates.Add(new() { Contents = [new DataContent("data:image/png;base64,aGVsbG8=")] }); } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateTests.cs index 7e5ff6b1e84..cc406929aa1 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateTests.cs @@ -84,8 +84,8 @@ public void Text_Get_UsesAllTextContent() Role = ChatRole.User, Contents = [ - new DataContent("http://localhost/audio"), - new DataContent("http://localhost/image"), + new DataContent("data:image/audio;base64,aGVsbG8="), + new DataContent("data:image/image;base64,aGVsbG8="), new FunctionCallContent("callId1", "fc1"), new TextContent("text-1"), new TextContent("text-2"), @@ -114,9 +114,9 @@ public void JsonSerialization_Roundtrips() Contents = [ new TextContent("text-1"), - new DataContent("http://localhost/image"), + new DataContent("data:image/png;base64,aGVsbG8="), new FunctionCallContent("callId1", "fc1"), - new DataContent("data"u8.ToArray()), + new DataContent("data"u8.ToArray(), "text/plain"), new TextContent("text-2"), ], RawRepresentation = new object(), @@ -137,13 +137,13 @@ public void JsonSerialization_Roundtrips() Assert.Equal("text-1", ((TextContent)result.Contents[0]).Text); Assert.IsType(result.Contents[1]); - Assert.Equal("http://localhost/image", ((DataContent)result.Contents[1]).Uri); + Assert.Equal("data:image/png;base64,aGVsbG8=", ((DataContent)result.Contents[1]).Uri); Assert.IsType(result.Contents[2]); Assert.Equal("fc1", ((FunctionCallContent)result.Contents[2]).Name); Assert.IsType(result.Contents[3]); - Assert.Equal("data"u8.ToArray(), ((DataContent)result.Contents[3]).Data?.ToArray()); + Assert.Equal("data"u8.ToArray(), ((DataContent)result.Contents[3]).Data.ToArray()); Assert.IsType(result.Contents[4]); Assert.Equal("text-2", ((TextContent)result.Contents[4]).Text); diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests.cs index dfa28373d48..83f09c66889 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests.cs @@ -13,10 +13,16 @@ public sealed class DataContentTests // Invalid URI [InlineData("", typeof(ArgumentException))] - [InlineData("invalid", typeof(UriFormatException))] + [InlineData("invalid", typeof(ArgumentException))] + [InlineData("data", typeof(ArgumentException))] + + // Not a data URI + [InlineData("http://localhost/blah.png", typeof(ArgumentException))] + [InlineData("https://localhost/blah.png", typeof(ArgumentException))] + [InlineData("ftp://localhost/blah.png", typeof(ArgumentException))] + [InlineData("a://localhost/blah.png", typeof(ArgumentException))] // Format errors - [InlineData("data", typeof(UriFormatException))] // data missing colon [InlineData("data:", typeof(UriFormatException))] // data missing comma [InlineData("data:something,", typeof(UriFormatException))] // mime type without subtype [InlineData("data:something;else,data", typeof(UriFormatException))] // mime type without subtype @@ -48,7 +54,7 @@ public void Ctor_InvalidUri_Throws(string path, Type exception) [InlineData("type/subtype;key=value;another=")] public void Ctor_InvalidMediaType_Throws(string type) { - Assert.Throws("mediaType", () => new DataContent("http://localhost/test", type)); + Assert.Throws("mediaType", () => new DataContent("data:image/png;base64,aGVsbG8=", type)); } [Theory] @@ -58,7 +64,7 @@ public void Ctor_InvalidMediaType_Throws(string type) [InlineData("type/subtype;key=value;another=value;yet_another=value")] public void Ctor_ValidMediaType_Roundtrips(string mediaType) { - var content = new DataContent("http://localhost/test", mediaType); + var content = new DataContent("data:image/png;base64,aGVsbG8=", mediaType); Assert.Equal(mediaType, content.MediaType); content = new DataContent("data:,", mediaType); @@ -82,43 +88,25 @@ public void Ctor_NoMediaType_Roundtrips() { DataContent content; - foreach (string url in new[] { "http://localhost/test", "about:something", "file://c:\\path" }) - { - content = new DataContent(url); - Assert.Equal(url, content.Uri); - Assert.Null(content.MediaType); - Assert.Null(content.Data); - } - - content = new DataContent("data:,something"); - Assert.Equal("data:,something", content.Uri); - Assert.Null(content.MediaType); - Assert.Equal("something"u8.ToArray(), content.Data!.Value.ToArray()); - - content = new DataContent("data:,Hello+%3C%3E"); - Assert.Equal("data:,Hello+%3C%3E", content.Uri); - Assert.Null(content.MediaType); - Assert.Equal("Hello <>"u8.ToArray(), content.Data!.Value.ToArray()); + content = new DataContent("data:image/png;base64,aGVsbG8="); + Assert.Equal("data:image/png;base64,aGVsbG8=", content.Uri); + Assert.Equal("image/png", content.MediaType); + + content = new DataContent(new Uri("data:image/png;base64,aGVsbG8=")); + Assert.Equal("data:image/png;base64,aGVsbG8=", content.Uri); + Assert.Equal("image/png", content.MediaType); } [Fact] public void Serialize_MatchesExpectedJson() { Assert.Equal( - """{"uri":"data:,"}""", - JsonSerializer.Serialize(new DataContent("data:,"), TestJsonSerializerContext.Default.Options)); - - Assert.Equal( - """{"uri":"http://localhost/"}""", - JsonSerializer.Serialize(new DataContent(new Uri("http://localhost/")), TestJsonSerializerContext.Default.Options)); - - Assert.Equal( - """{"uri":"data:application/octet-stream;base64,AQIDBA==","mediaType":"application/octet-stream"}""", + """{"uri":"data:application/octet-stream;base64,AQIDBA=="}""", JsonSerializer.Serialize(new DataContent( uri: "data:application/octet-stream;base64,AQIDBA=="), TestJsonSerializerContext.Default.Options)); Assert.Equal( - """{"uri":"data:application/octet-stream;base64,AQIDBA==","mediaType":"application/octet-stream"}""", + """{"uri":"data:application/octet-stream;base64,AQIDBA=="}""", JsonSerializer.Serialize(new DataContent( new ReadOnlyMemory([0x01, 0x02, 0x03, 0x04]), "application/octet-stream"), TestJsonSerializerContext.Default.Options)); @@ -136,53 +124,43 @@ public void Deserialize_MissingUriString_Throws(string json) public void Deserialize_MatchesExpectedData() { // Data + MimeType only - var content = JsonSerializer.Deserialize("""{"mediaType":"application/octet-stream","uri":"data:;base64,AQIDBA=="}""", TestJsonSerializerContext.Default.Options)!; + var content = JsonSerializer.Deserialize("""{"uri":"data:application/octet-stream;base64,AQIDBA=="}""", TestJsonSerializerContext.Default.Options)!; Assert.Equal("data:application/octet-stream;base64,AQIDBA==", content.Uri); - Assert.NotNull(content.Data); - Assert.Equal([0x01, 0x02, 0x03, 0x04], content.Data.Value.ToArray()); + Assert.Equal([0x01, 0x02, 0x03, 0x04], content.Data.ToArray()); Assert.Equal("application/octet-stream", content.MediaType); // Uri referenced content-only - content = JsonSerializer.Deserialize("""{"mediaType":"application/octet-stream","uri":"http://localhost/"}""", TestJsonSerializerContext.Default.Options)!; + content = JsonSerializer.Deserialize("""{"uri":"data:application/octet-stream;base64,AQIDBA=="}""", TestJsonSerializerContext.Default.Options)!; - Assert.Null(content.Data); - Assert.Equal("http://localhost/", content.Uri); + Assert.Equal("data:application/octet-stream;base64,AQIDBA==", content.Uri); Assert.Equal("application/octet-stream", content.MediaType); // Using extra metadata content = JsonSerializer.Deserialize(""" { - "uri": "data:;base64,AQIDBA==", + "uri": "data:audio/wav;base64,AQIDBA==", "modelId": "gpt-4", "additionalProperties": { "key": "value" - }, - "mediaType": "text/plain" + } } """, TestJsonSerializerContext.Default.Options)!; - Assert.Equal("data:text/plain;base64,AQIDBA==", content.Uri); - Assert.NotNull(content.Data); - Assert.Equal([0x01, 0x02, 0x03, 0x04], content.Data.Value.ToArray()); - Assert.Equal("text/plain", content.MediaType); + Assert.Equal("data:audio/wav;base64,AQIDBA==", content.Uri); + Assert.Equal([0x01, 0x02, 0x03, 0x04], content.Data.ToArray()); + Assert.Equal("audio/wav", content.MediaType); Assert.Equal("value", content.AdditionalProperties!["key"]!.ToString()); } [Theory] [InlineData( - """{"uri": "data:;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=","mediaType": "text/plain"}""", - """{"uri":"data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=","mediaType":"text/plain"}""")] - [InlineData( - """{"uri": "data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=","mediaType": "text/plain"}""", - """{"uri":"data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=","mediaType":"text/plain"}""")] + """{"uri": "data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8="}""", + """{"uri":"data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8="}""")] [InlineData( // Does not support non-readable content """{"uri": "data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=", "unexpected": true}""", - """{"uri":"data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=","mediaType":"text/plain"}""")] - [InlineData( // Uri comes before mimetype - """{"mediaType": "text/plain", "uri": "http://localhost/" }""", - """{"uri":"http://localhost/","mediaType":"text/plain"}""")] + """{"uri":"data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8="}""")] public void Serialize_Deserialize_Roundtrips(string serialized, string expectedToString) { var content = JsonSerializer.Deserialize(serialized, TestJsonSerializerContext.Default.Options)!; @@ -222,30 +200,28 @@ public void MediaType_Roundtrips(string mediaType) } [Theory] - [InlineData("image/gif", "image/")] + [InlineData("image/gif", "image")] [InlineData("IMAGE/JPEG", "image")] - [InlineData("image/vnd.microsoft.icon", "ima")] - [InlineData("image/svg+xml", "IMAGE/")] + [InlineData("image/vnd.microsoft.icon", "imAge")] + [InlineData("image/svg+xml", "IMAGE")] [InlineData("image/nonexistentimagemimetype", "IMAGE")] - [InlineData("audio/mpeg", "aUdIo/")] - [InlineData("application/json", "")] - [InlineData("application/pdf", "application/pdf")] - public void HasMediaTypePrefix_ReturnsTrue(string? mediaType, string prefix) + [InlineData("audio/mpeg", "aUdIo")] + public void HasMediaTypePrefix_ReturnsTrue(string mediaType, string prefix) { - var content = new DataContent("http://localhost/image.png", mediaType); - Assert.True(content.MediaTypeStartsWith(prefix)); + var content = new DataContent("data:application/octet-stream;base64,AQIDBA==", mediaType); + Assert.True(content.HasTopLevelMediaType(prefix)); } [Theory] - [InlineData("audio/mpeg", "image/")] + [InlineData("audio/mpeg", "audio/")] + [InlineData("audio/mpeg", "image")] + [InlineData("audio/mpeg", "audio/mpeg")] [InlineData("text/css", "text/csv")] + [InlineData("text/css", "/csv")] [InlineData("application/json", "application/json!")] - [InlineData("", "")] // The media type will get normalized to null - [InlineData(null, "image/")] - [InlineData(null, "")] - public void HasMediaTypePrefix_ReturnsFalse(string? mediaType, string prefix) + public void HasMediaTypePrefix_ReturnsFalse(string mediaType, string prefix) { - var content = new DataContent("http://localhost/image.png", mediaType); - Assert.False(content.MediaTypeStartsWith(prefix)); + var content = new DataContent("data:application/octet-stream;base64,AQIDBA==", mediaType); + Assert.False(content.HasTopLevelMediaType(prefix)); } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/UriContentTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/UriContentTests.cs new file mode 100644 index 00000000000..8b4e8c6665d --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/UriContentTests.cs @@ -0,0 +1,130 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public sealed class UriContentTests +{ + [Fact] + public void Ctor_InvalidUriMediaType_Throws() + { + Assert.Throws("uri", () => new UriContent((string)null!, "image/png")); + Assert.Throws("uri", () => new UriContent((Uri)null!, "image/png")); + Assert.Throws(() => new UriContent("notauri", "image/png")); + + Assert.Throws("mediaType", () => new UriContent("data:image/png;base64,aGVsbG8=", null!)); + Assert.Throws("mediaType", () => new UriContent("data:image/png;base64,aGVsbG8=", "")); + Assert.Throws("mediaType", () => new UriContent("data:image/png;base64,aGVsbG8=", "image")); + + Assert.Throws("mediaType", () => new UriContent(new Uri("data:image/png;base64,aGVsbG8="), null!)); + Assert.Throws("mediaType", () => new UriContent(new Uri("data:image/png;base64,aGVsbG8="), "")); + Assert.Throws("mediaType", () => new UriContent(new Uri("data:image/png;base64,aGVsbG8="), "audio")); + + UriContent c = new("http://localhost/something", "image/png"); + Assert.Throws("value", () => c.Uri = null!); + } + + [Theory] + [InlineData("type")] + [InlineData("type//subtype")] + [InlineData("type/subtype/")] + [InlineData("type/subtype;key=")] + [InlineData("type/subtype;=value")] + [InlineData("type/subtype;key=value;another=")] + public void Ctor_InvalidMediaType_Throws(string type) + { + Assert.Throws("mediaType", () => new UriContent("http://localhost/something", type)); + + UriContent c = new("http://localhost/something", "image/png"); + Assert.Throws("value", () => c.MediaType = type); + Assert.Throws("value", () => c.MediaType = null!); + } + + [Theory] + [InlineData("type/subtype")] + [InlineData("type/subtype;key=value")] + [InlineData("type/subtype;key=value;another=value")] + [InlineData("type/subtype;key=value;another=value;yet_another=value")] + public void Ctor_ValidMediaType_Roundtrips(string mediaType) + { + var content = new UriContent("http://localhost/something", mediaType); + Assert.Equal(mediaType, content.MediaType); + + content.MediaType = "image/png"; + Assert.Equal("image/png", content.MediaType); + + content.MediaType = mediaType; + Assert.Equal(mediaType, content.MediaType); + } + + [Fact] + public void Serialize_MatchesExpectedJson() + { + Assert.Equal( + """{"uri":"http://localhost/something","mediaType":"image/png"}""", + JsonSerializer.Serialize( + new UriContent("http://localhost/something", "image/png"), + TestJsonSerializerContext.Default.Options)); + } + + [Theory] + [InlineData("application/json")] + [InlineData("application/octet-stream")] + [InlineData("application/pdf")] + [InlineData("application/xml")] + [InlineData("audio/mpeg")] + [InlineData("audio/ogg")] + [InlineData("audio/wav")] + [InlineData("image/apng")] + [InlineData("image/avif")] + [InlineData("image/bmp")] + [InlineData("image/gif")] + [InlineData("image/jpeg")] + [InlineData("image/png")] + [InlineData("image/svg+xml")] + [InlineData("image/tiff")] + [InlineData("image/webp")] + [InlineData("text/css")] + [InlineData("text/csv")] + [InlineData("text/html")] + [InlineData("text/javascript")] + [InlineData("text/plain")] + [InlineData("text/plain;charset=UTF-8")] + [InlineData("text/xml")] + [InlineData("custom/mediatypethatdoesntexists")] + public void MediaType_Roundtrips(string mediaType) + { + UriContent c = new("http://localhost", mediaType); + Assert.Equal(mediaType, c.MediaType); + } + + [Theory] + [InlineData("image/gif", "image")] + [InlineData("IMAGE/JPEG", "image")] + [InlineData("image/vnd.microsoft.icon", "imAge")] + [InlineData("image/svg+xml", "IMAGE")] + [InlineData("image/nonexistentimagemimetype", "IMAGE")] + [InlineData("audio/mpeg", "aUdIo")] + public void HasMediaTypePrefix_ReturnsTrue(string mediaType, string prefix) + { + var content = new UriContent("http://localhost", mediaType); + Assert.True(content.HasTopLevelMediaType(prefix)); + } + + [Theory] + [InlineData("audio/mpeg", "audio/")] + [InlineData("audio/mpeg", "image")] + [InlineData("audio/mpeg", "audio/mpeg")] + [InlineData("text/css", "text/csv")] + [InlineData("text/css", "/csv")] + [InlineData("application/json", "application/json!")] + public void HasMediaTypePrefix_ReturnsFalse(string mediaType, string prefix) + { + var content = new UriContent("http://localhost", mediaType); + Assert.False(content.HasTopLevelMediaType(prefix)); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs index fe4af33cf23..dfe970b23ca 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs @@ -14,15 +14,12 @@ public class EmbeddingGeneratorExtensionsTests public void GetService_InvalidArgs_Throws() { Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetService(null!)); - Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetService, object>(null!)); } [Fact] public void GetRequiredService_InvalidArgs_Throws() { Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetRequiredService(null!)); - Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetRequiredService>(null!, typeof(string))); - Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetRequiredService, object>(null!)); using var generator = new TestEmbeddingGenerator(); Assert.Throws("serviceType", () => generator.GetRequiredService(null!)); @@ -51,41 +48,31 @@ public void GetService_ValidService_Returned() Assert.Equal("null key", generator.GetService(typeof(string))); Assert.Equal("null key", generator.GetService()); - Assert.Equal("null key", generator.GetService, string>()); Assert.Equal("non-null key", generator.GetService(typeof(string), "key")); Assert.Equal("non-null key", generator.GetService("key")); - Assert.Equal("non-null key", generator.GetService, string>("key")); Assert.Null(generator.GetService(typeof(object))); Assert.Null(generator.GetService()); - Assert.Null(generator.GetService, object>()); Assert.Null(generator.GetService(typeof(object), "key")); Assert.Null(generator.GetService("key")); - Assert.Null(generator.GetService, object>("key")); Assert.Null(generator.GetService()); - Assert.Null(generator.GetService, int?>()); Assert.Equal("null key", generator.GetRequiredService(typeof(string))); Assert.Equal("null key", generator.GetRequiredService()); - Assert.Equal("null key", generator.GetRequiredService, string>()); Assert.Equal("non-null key", generator.GetRequiredService(typeof(string), "key")); Assert.Equal("non-null key", generator.GetRequiredService("key")); - Assert.Equal("non-null key", generator.GetRequiredService, string>("key")); Assert.Throws(() => generator.GetRequiredService(typeof(object))); Assert.Throws(() => generator.GetRequiredService()); - Assert.Throws(() => generator.GetRequiredService, object>()); Assert.Throws(() => generator.GetRequiredService(typeof(object), "key")); Assert.Throws(() => generator.GetRequiredService("key")); - Assert.Throws(() => generator.GetRequiredService, object>("key")); Assert.Throws(() => generator.GetRequiredService()); - Assert.Throws(() => generator.GetRequiredService, int?>()); } [Fact] diff --git a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs index b8a68c913ed..d0167c8778b 100644 --- a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs @@ -615,7 +615,7 @@ public async Task MultipleContent_NonStreaming() Assert.NotNull(await client.GetResponseAsync([new(ChatRole.User, [ new TextContent("Describe this picture."), - new DataContent("http://dot.net/someimage.png", mediaType: "image/png"), + new UriContent("http://dot.net/someimage.png", mediaType: "image/*"), ])])); } diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs index 0b8aca0785e..0a499ab644d 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs @@ -149,7 +149,7 @@ public async Task FailureUsage_NullJson() [Fact] public async Task FailureUsage_NoJsonInResponse() { - var expectedResponse = new ChatResponse(new ChatMessage(ChatRole.Assistant, [new DataContent("https://example.com")])); + var expectedResponse = new ChatResponse(new ChatMessage(ChatRole.Assistant, [new UriContent("https://example.com", "image/*")])); using var client = new TestChatClient { GetResponseAsyncCallback = (messages, options, cancellationToken) => Task.FromResult(expectedResponse), diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DependencyInjectionPatterns.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DependencyInjectionPatterns.cs index c2f288165cb..3ff20afaad1 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DependencyInjectionPatterns.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DependencyInjectionPatterns.cs @@ -167,7 +167,8 @@ public void AddEmbeddingGenerator_RegistersExpectedLifetime(ServiceLifetime? lif ? sc.AddEmbeddingGenerator(services => new TestEmbeddingGenerator(), lifetime.Value) : sc.AddEmbeddingGenerator(services => new TestEmbeddingGenerator()); - ServiceDescriptor sd = Assert.Single(sc); + Assert.Equal(2, sc.Count); + ServiceDescriptor sd = sc[0]; Assert.Equal(typeof(IEmbeddingGenerator>), sd.ServiceType); Assert.False(sd.IsKeyedService); Assert.Null(sd.ImplementationInstance); @@ -176,6 +177,28 @@ public void AddEmbeddingGenerator_RegistersExpectedLifetime(ServiceLifetime? lif Assert.Equal(expectedLifetime, sd.Lifetime); } + [Theory] + [InlineData(null)] + [InlineData(ServiceLifetime.Singleton)] + [InlineData(ServiceLifetime.Scoped)] + [InlineData(ServiceLifetime.Transient)] + public void AddEmbeddingGenerator_RegistersNonGeneric(ServiceLifetime? lifetime) + { + ServiceCollection sc = new(); + ServiceLifetime expectedLifetime = lifetime ?? ServiceLifetime.Singleton; + var builder = lifetime.HasValue + ? sc.AddEmbeddingGenerator(services => new TestEmbeddingGenerator(), lifetime.Value) + : sc.AddEmbeddingGenerator(services => new TestEmbeddingGenerator()); + IServiceProvider sp = sc.BuildServiceProvider(); + + IEmbeddingGenerator>? g = sp.GetService>>(); + IEmbeddingGenerator? ng = sp.GetService(); + + Assert.NotNull(g); + Assert.NotNull(ng); + Assert.Equal(lifetime != ServiceLifetime.Transient, ReferenceEquals(g, ng)); + } + [Theory] [InlineData(null)] [InlineData(ServiceLifetime.Singleton)] @@ -189,7 +212,8 @@ public void AddKeyedEmbeddingGenerator_RegistersExpectedLifetime(ServiceLifetime ? sc.AddKeyedEmbeddingGenerator("key", services => new TestEmbeddingGenerator(), lifetime.Value) : sc.AddKeyedEmbeddingGenerator("key", services => new TestEmbeddingGenerator()); - ServiceDescriptor sd = Assert.Single(sc); + Assert.Equal(2, sc.Count); + ServiceDescriptor sd = sc[0]; Assert.Equal(typeof(IEmbeddingGenerator>), sd.ServiceType); Assert.True(sd.IsKeyedService); Assert.Equal("key", sd.ServiceKey); @@ -199,6 +223,28 @@ public void AddKeyedEmbeddingGenerator_RegistersExpectedLifetime(ServiceLifetime Assert.Equal(expectedLifetime, sd.Lifetime); } + [Theory] + [InlineData(null)] + [InlineData(ServiceLifetime.Singleton)] + [InlineData(ServiceLifetime.Scoped)] + [InlineData(ServiceLifetime.Transient)] + public void AddKeyedEmbeddingGenerator_RegistersNonGeneric(ServiceLifetime? lifetime) + { + ServiceCollection sc = new(); + ServiceLifetime expectedLifetime = lifetime ?? ServiceLifetime.Singleton; + var builder = lifetime.HasValue + ? sc.AddKeyedEmbeddingGenerator("key", services => new TestEmbeddingGenerator(), lifetime.Value) + : sc.AddKeyedEmbeddingGenerator("key", services => new TestEmbeddingGenerator()); + IServiceProvider sp = sc.BuildServiceProvider(); + + IEmbeddingGenerator>? g = sp.GetKeyedService>>("key"); + IEmbeddingGenerator? ng = sp.GetKeyedService("key"); + + Assert.NotNull(g); + Assert.NotNull(ng); + Assert.Equal(lifetime != ServiceLifetime.Transient, ReferenceEquals(g, ng)); + } + public class SingletonMiddleware(IChatClient inner, IServiceProvider services) : DelegatingChatClient(inner) { public new IChatClient InnerClient => base.InnerClient;