diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/EquivalenceEvaluatorContext.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/EquivalenceEvaluatorContext.cs index 3fcb7b3d36e..166fe2bbc85 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/EquivalenceEvaluatorContext.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/EquivalenceEvaluatorContext.cs @@ -6,6 +6,8 @@ // We disable this warning because it is a false positive arising from the analyzer's lack of support for C#'s primary // constructor syntax. +using System.Collections.Generic; + namespace Microsoft.Extensions.AI.Evaluation.Quality; /// @@ -29,4 +31,8 @@ public sealed class EquivalenceEvaluatorContext(string groundTruth) : Evaluation /// the response supplied via . /// public string GroundTruth { get; } = groundTruth; + + /// + public override IReadOnlyList GetContents() + => [new TextContent(GroundTruth)]; } diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/GroundednessEvaluatorContext.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/GroundednessEvaluatorContext.cs index 32a9cf25a38..df14ae62d95 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/GroundednessEvaluatorContext.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/GroundednessEvaluatorContext.cs @@ -6,6 +6,8 @@ // We disable this warning because it is a false positive arising from the analyzer's lack of support for C#'s primary // constructor syntax. +using System.Collections.Generic; + namespace Microsoft.Extensions.AI.Evaluation.Quality; /// @@ -29,4 +31,8 @@ public sealed class GroundednessEvaluatorContext(string groundingContext) : Eval /// in the information present in the supplied . /// public string GroundingContext { get; } = groundingContext; + + /// + public override IReadOnlyList GetContents() + => [new TextContent(GroundingContext)]; } diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/RelevanceTruthAndCompletenessEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/RelevanceTruthAndCompletenessEvaluator.cs index cbacdc246dc..c6c38cf583a 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/RelevanceTruthAndCompletenessEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/RelevanceTruthAndCompletenessEvaluator.cs @@ -171,9 +171,9 @@ await JsonOutputFixer.RepairJsonAsync( result.AddDiagnosticToAllMetrics( EvaluationDiagnostic.Error( $""" - Failed to repair the following response from the model and parse scores for '{RelevanceMetricName}', '{TruthMetricName}' and '{CompletenessMetricName}'.: - {evaluationResponseText} - """)); + Failed to repair the following response from the model and parse scores for '{RelevanceMetricName}', '{TruthMetricName}' and '{CompletenessMetricName}'.: + {evaluationResponseText} + """)); } else { @@ -186,10 +186,10 @@ await JsonOutputFixer.RepairJsonAsync( result.AddDiagnosticToAllMetrics( EvaluationDiagnostic.Error( $""" - Failed to repair the following response from the model and parse scores for '{RelevanceMetricName}', '{TruthMetricName}' and '{CompletenessMetricName}'.: - {evaluationResponseText} - {ex} - """)); + Failed to repair the following response from the model and parse scores for '{RelevanceMetricName}', '{TruthMetricName}' and '{CompletenessMetricName}'.: + {evaluationResponseText} + {ex} + """)); } } } @@ -211,28 +211,28 @@ void UpdateResult() if (!string.IsNullOrWhiteSpace(evaluationResponse.ModelId)) { - 
commonMetadata["rtc-evaluation-model-used"] = evaluationResponse.ModelId!; + commonMetadata["evaluation-model-used"] = evaluationResponse.ModelId!; } if (evaluationResponse.Usage is UsageDetails usage) { if (usage.InputTokenCount is not null) { - commonMetadata["rtc-evaluation-input-tokens-used"] = $"{usage.InputTokenCount}"; + commonMetadata["evaluation-input-tokens-used"] = $"{usage.InputTokenCount}"; } if (usage.OutputTokenCount is not null) { - commonMetadata["rtc-evaluation-output-tokens-used"] = $"{usage.OutputTokenCount}"; + commonMetadata["evaluation-output-tokens-used"] = $"{usage.OutputTokenCount}"; } if (usage.TotalTokenCount is not null) { - commonMetadata["rtc-evaluation-total-tokens-used"] = $"{usage.TotalTokenCount}"; + commonMetadata["evaluation-total-tokens-used"] = $"{usage.TotalTokenCount}"; } } - commonMetadata["rtc-evaluation-duration"] = duration; + commonMetadata["evaluation-duration"] = duration; NumericMetric relevance = result.Get(RelevanceMetricName); relevance.Value = rating.Relevance; diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ReportingConfiguration.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ReportingConfiguration.cs index 3ae9f73d021..c8a5c1b2411 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ReportingConfiguration.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ReportingConfiguration.cs @@ -263,7 +263,7 @@ await ResponseCacheProvider.GetCacheAsync( private static IEnumerable GetCachingKeysForChatClient(IChatClient chatClient) { - var metadata = chatClient.GetService(); + ChatClientMetadata? metadata = chatClient.GetService(); string? providerName = metadata?.ProviderName; if (!string.IsNullOrWhiteSpace(providerName)) diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/TypeScript/components/ChatDetailsSection.tsx b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/TypeScript/components/ChatDetailsSection.tsx index d25662ca708..0958e0b560d 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/TypeScript/components/ChatDetailsSection.tsx +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/TypeScript/components/ChatDetailsSection.tsx @@ -24,7 +24,7 @@ export const ChatDetailsSection = ({ chatDetails }: { chatDetails: ChatDetails;
     onClick={() => setIsExpanded(!isExpanded)}>
         {isExpanded ? : }
-        LLM Chat Diagnostic Details
+        Diagnostic Data
{hasCacheStatus && (
{cachedTurns != totalTurns ? diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/CodeVulnerabilityEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/CodeVulnerabilityEvaluator.cs index 10475ae9dad..c708b6c2cd3 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/CodeVulnerabilityEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/CodeVulnerabilityEvaluator.cs @@ -2,9 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; -using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI.Evaluation.Safety; @@ -31,16 +31,10 @@ namespace Microsoft.Extensions.AI.Evaluation.Safety; /// will be ignored. /// /// -/// -/// Specifies the Azure AI project that should be used and credentials that should be used when this -/// communicates with the Azure AI Content Safety service to perform -/// evaluations. -/// -public sealed class CodeVulnerabilityEvaluator(ContentSafetyServiceConfiguration contentSafetyServiceConfiguration) +public sealed class CodeVulnerabilityEvaluator() : ContentSafetyEvaluator( - contentSafetyServiceConfiguration, contentSafetyServiceAnnotationTask: "code vulnerability", - evaluatorName: nameof(CodeVulnerabilityEvaluator)) + metricNames: new Dictionary { ["code_vulnerability"] = CodeVulnerabilityMetricName }) { /// /// Gets the of the returned by @@ -48,9 +42,6 @@ public sealed class CodeVulnerabilityEvaluator(ContentSafetyServiceConfiguration /// public static string CodeVulnerabilityMetricName => "Code Vulnerability"; - /// - public override IReadOnlyCollection EvaluationMetricNames => [CodeVulnerabilityMetricName]; - /// public override async ValueTask EvaluateAsync( IEnumerable messages, @@ -59,30 +50,18 @@ public override async ValueTask EvaluateAsync( IEnumerable? additionalContext = null, CancellationToken cancellationToken = default) { - const string CodeVulnerabilityContentSafetyServiceMetricName = "code_vulnerability"; + _ = Throw.IfNull(chatConfiguration); + _ = Throw.IfNull(modelResponse); EvaluationResult result = await EvaluateContentSafetyAsync( + chatConfiguration.ChatClient, messages, modelResponse, + additionalContext, contentSafetyServicePayloadFormat: ContentSafetyServicePayloadFormat.ContextCompletion.ToString(), - contentSafetyServiceMetricName: CodeVulnerabilityContentSafetyServiceMetricName, cancellationToken: cancellationToken).ConfigureAwait(false); - IEnumerable updatedMetrics = - result.Metrics.Values.Select( - metric => - { - if (metric.Name == CodeVulnerabilityContentSafetyServiceMetricName) - { - metric.Name = CodeVulnerabilityMetricName; - } - - return metric; - }); - - result = new EvaluationResult(updatedMetrics); - result.Interpret(metric => metric is BooleanMetric booleanMetric ? booleanMetric.InterpretScore() : null); return result; } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentHarmEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentHarmEvaluator.cs index ca8d187a6ed..efca908f01a 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentHarmEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentHarmEvaluator.cs @@ -2,44 +2,27 @@ // The .NET Foundation licenses this file to you under the MIT license. 
using System.Collections.Generic; -using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI.Evaluation.Safety; -#pragma warning disable S1694 // An abstract class should have both abstract and concrete methods /// /// An base class that can be used to implement s that utilize the /// Azure AI Content Safety service to evaluate responses produced by an AI model for the presence of a variety of /// harmful content such as violence, hate speech, etc. /// -/// -/// Specifies the Azure AI project that should be used and credentials that should be used when this -/// communicates with the Azure AI Content Safety service to perform evaluations. -/// -/// -/// The name of the metric that should be used when this communicates with the -/// Azure AI Content Safety service to perform evaluations. +/// +/// A dictionary containing the mapping from the names of the metrics that are used when communicating with the Azure +/// AI Content Safety to the s of the s returned by +/// this . /// -/// -/// The name of the produced by this . -/// -/// The name of the derived . -public abstract class ContentHarmEvaluator( - ContentSafetyServiceConfiguration contentSafetyServiceConfiguration, - string contentSafetyServiceMetricName, - string metricName, - string evaluatorName) - : ContentSafetyEvaluator( - contentSafetyServiceConfiguration, - contentSafetyServiceAnnotationTask: "content harm", - evaluatorName) +#pragma warning disable S1694 // An abstract class should have both abstract and concrete methods +public abstract class ContentHarmEvaluator(IDictionary metricNames) + : ContentSafetyEvaluator(contentSafetyServiceAnnotationTask: "content harm", metricNames) #pragma warning restore S1694 { - /// - public override IReadOnlyCollection EvaluationMetricNames => [metricName]; - /// public sealed override async ValueTask EvaluateAsync( IEnumerable messages, @@ -48,28 +31,21 @@ public sealed override async ValueTask EvaluateAsync( IEnumerable? additionalContext = null, CancellationToken cancellationToken = default) { + _ = Throw.IfNull(chatConfiguration); + _ = Throw.IfNull(modelResponse); + EvaluationResult result = await EvaluateContentSafetyAsync( + chatConfiguration.ChatClient, messages, modelResponse, + additionalContext, contentSafetyServicePayloadFormat: ContentSafetyServicePayloadFormat.Conversation.ToString(), - contentSafetyServiceMetricName: contentSafetyServiceMetricName, cancellationToken: cancellationToken).ConfigureAwait(false); - IEnumerable updatedMetrics = - result.Metrics.Values.Select( - metric => - { - if (metric.Name == contentSafetyServiceMetricName) - { - metric.Name = metricName; - } - - return metric; - }); + result.Interpret( + metric => metric is NumericMetric numericMetric ? numericMetric.InterpretContentHarmScore() : null); - result = new EvaluationResult(updatedMetrics); - result.Interpret(metric => metric is NumericMetric numericMetric ? numericMetric.InterpretHarmScore() : null); return result; } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyChatClient.cs new file mode 100644 index 00000000000..a29ed3b0a32 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyChatClient.cs @@ -0,0 +1,152 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+ +#pragma warning disable S3604 +// S3604: Member initializer values should not be redundant. +// We disable this warning because it is a false positive arising from the analyzer's lack of support for C#'s primary +// constructor syntax. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.AI.Evaluation.Safety; + +internal sealed class ContentSafetyChatClient : IChatClient +{ + private const string Moniker = "Azure AI Content Safety"; + + private readonly ContentSafetyService _service; + private readonly IChatClient? _originalChatClient; + private readonly ChatClientMetadata _metadata; + + public ContentSafetyChatClient( + ContentSafetyServiceConfiguration contentSafetyServiceConfiguration, + IChatClient? originalChatClient = null) + { + _service = new ContentSafetyService(contentSafetyServiceConfiguration); + _originalChatClient = originalChatClient; + + ChatClientMetadata? originalMetadata = _originalChatClient?.GetService(); + + string providerName = + $"{Moniker} (" + + $"Subscription: {contentSafetyServiceConfiguration.SubscriptionId}, " + + $"Resource Group: {contentSafetyServiceConfiguration.ResourceGroupName}, " + + $"Project: {contentSafetyServiceConfiguration.ProjectName})"; + + if (originalMetadata?.ProviderName is string originalProviderName && + !string.IsNullOrWhiteSpace(originalProviderName)) + { + providerName = $"{originalProviderName}; {providerName}"; + } + + string modelId = Moniker; + + if (originalMetadata?.DefaultModelId is string originalModelId && + !string.IsNullOrWhiteSpace(originalModelId)) + { + modelId = $"{originalModelId}; {modelId}"; + } + + _metadata = new ChatClientMetadata(providerName, originalMetadata?.ProviderUri, modelId); + } + + public async Task GetResponseAsync( + IEnumerable messages, + ChatOptions? options = null, + CancellationToken cancellationToken = default) + { + if (options is ContentSafetyChatOptions contentSafetyChatOptions) + { + Debug.Assert(messages.Any() && !messages.Skip(1).Any(), $"Expected exactly one message."); + string payload = messages.Single().Text; + + string annotationResult = + await _service.AnnotateAsync( + payload, + contentSafetyChatOptions.AnnotationTask, + contentSafetyChatOptions.EvaluatorName, + cancellationToken).ConfigureAwait(false); + + return new ChatResponse(new ChatMessage(ChatRole.Assistant, annotationResult)) + { + ModelId = Moniker + }; + } + else if (_originalChatClient is not null) + { + return await _originalChatClient.GetResponseAsync( + messages, + options, + cancellationToken).ConfigureAwait(false); + } + else + { + throw new NotSupportedException(); + } + } + + public async IAsyncEnumerable GetStreamingResponseAsync( + IEnumerable messages, + ChatOptions? 
options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (options is ContentSafetyChatOptions contentSafetyChatOptions) + { + Debug.Assert(messages.Any() && !messages.Skip(1).Any(), $"Expected exactly one message."); + string payload = messages.Single().Text; + + string annotationResult = + await _service.AnnotateAsync( + payload, + contentSafetyChatOptions.AnnotationTask, + contentSafetyChatOptions.EvaluatorName, + cancellationToken).ConfigureAwait(false); + + yield return new ChatResponseUpdate(ChatRole.Assistant, annotationResult) + { + ModelId = Moniker + }; + } + else if (_originalChatClient is not null) + { + await foreach (var update in + _originalChatClient.GetStreamingResponseAsync( + messages, + options, + cancellationToken).ConfigureAwait(false)) + { + yield return update; + } + } + else + { + throw new NotSupportedException(); + } + } + + public object? GetService(Type serviceType, object? serviceKey = null) + { + if (serviceKey is null) + { + if (serviceType == typeof(ChatClientMetadata)) + { + return _metadata; + } + else if (serviceType == typeof(ContentSafetyChatClient)) + { + return this; + } + } + + return _originalChatClient?.GetService(serviceType, serviceKey); + } + + public void Dispose() + => _originalChatClient?.Dispose(); +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyChatOptions.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyChatOptions.cs new file mode 100644 index 00000000000..741bca9f790 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyChatOptions.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#pragma warning disable S3604 +// S3604: Member initializer values should not be redundant. +// We disable this warning because it is a false positive arising from the analyzer's lack of support for C#'s primary +// constructor syntax. + +namespace Microsoft.Extensions.AI.Evaluation.Safety; + +internal sealed class ContentSafetyChatOptions(string annotationTask, string evaluatorName) : ChatOptions +{ + internal string AnnotationTask { get; } = annotationTask; + internal string EvaluatorName { get; } = evaluatorName; +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyEvaluator.cs index 252a79cf334..c85db5f4fa1 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyEvaluator.cs @@ -8,8 +8,12 @@ using System; using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI.Evaluation.Safety; @@ -18,82 +22,199 @@ namespace Microsoft.Extensions.AI.Evaluation.Safety; /// Azure AI Content Safety service to evaluate responses produced by an AI model for the presence of a variety of /// unsafe content such as protected material, vulnerable code, harmful content etc. ///
-/// -/// Specifies the Azure AI project that should be used and credentials that should be used when this -/// communicates with the Azure AI Content Safety service to perform evaluations. -/// /// -/// The name of the annotation task that should be used when this communicates -/// with the Azure AI Content Safety service to perform evaluations. +/// The name of the annotation task that should be used when communicating with the Azure AI Content Safety service to +/// perform evaluations. +/// +/// +/// A dictionary containing the mapping from the names of the metrics that are used when communicating with the Azure +/// AI Content Safety to the s of the s returned by +/// this . /// -/// The name of the derived . +#pragma warning disable S1694 // An abstract class should have both abstract and concrete methods public abstract class ContentSafetyEvaluator( - ContentSafetyServiceConfiguration contentSafetyServiceConfiguration, string contentSafetyServiceAnnotationTask, - string evaluatorName) : IEvaluator + IDictionary metricNames) : IEvaluator +#pragma warning restore S1694 { - private readonly ContentSafetyService _service = - new ContentSafetyService(contentSafetyServiceConfiguration, contentSafetyServiceAnnotationTask, evaluatorName); - /// - public abstract IReadOnlyCollection EvaluationMetricNames { get; } + public IReadOnlyCollection EvaluationMetricNames { get; } = [.. metricNames.Values]; /// - public abstract ValueTask EvaluateAsync( + public virtual ValueTask EvaluateAsync( IEnumerable messages, ChatResponse modelResponse, ChatConfiguration? chatConfiguration = null, IEnumerable? additionalContext = null, - CancellationToken cancellationToken = default); + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chatConfiguration); + + return EvaluateContentSafetyAsync( + chatConfiguration.ChatClient, + messages, + modelResponse, + additionalContext, + cancellationToken: cancellationToken); + } /// /// Evaluates the supplied using the Azure AI Content Safety Service and returns /// an containing one or more s. /// + /// + /// The that should be used to communicate with the Azure AI Content Safety Service when + /// performing evaluations. + /// /// /// The conversation history including the request that produced the supplied . /// /// The response that is to be evaluated. /// - /// Per conversation turn contextual information (beyond that which is available in ) - /// that the may need to accurately evaluate the supplied - /// . + /// Additional contextual information (beyond that which is available in ) that the + /// may need to accurately evaluate the supplied . /// /// /// An identifier that specifies the format of the payload that should be used when communicating with the Azure AI /// Content Safety service to perform evaluations. /// - /// - /// The name of the metric that should be used in the payload when communicating with the Azure AI Content Safety - /// service to perform evaluations. + /// + /// A flag that indicates whether the names of the metrics should be included in the payload + /// that is sent to the Azure AI Content Safety service when performing evaluations. /// /// /// A that can cancel the evaluation operation. /// /// An containing one or more s. - protected ValueTask EvaluateContentSafetyAsync( + protected async ValueTask EvaluateContentSafetyAsync( + IChatClient contentSafetyServiceChatClient, IEnumerable messages, ChatResponse modelResponse, - IEnumerable? additionalContext = null, + IEnumerable? 
additionalContext = null, string contentSafetyServicePayloadFormat = "HumanSystem", // ContentSafetyServicePayloadFormat.HumanSystem.ToString() - string? contentSafetyServiceMetricName = null, + bool includeMetricNamesInContentSafetyServicePayload = true, CancellationToken cancellationToken = default) { - ContentSafetyServicePayloadFormat payloadFormat = + _ = Throw.IfNull(contentSafetyServiceChatClient); + _ = Throw.IfNull(modelResponse); + + string payload; + string annotationResult; + IReadOnlyList? diagnostics; + EvaluationResult result; + Stopwatch stopwatch = Stopwatch.StartNew(); + + try + { + ContentSafetyServicePayloadFormat payloadFormat = #if NET - Enum.Parse(contentSafetyServicePayloadFormat); + Enum.Parse(contentSafetyServicePayloadFormat); #else - (ContentSafetyServicePayloadFormat)Enum.Parse( - typeof(ContentSafetyServicePayloadFormat), - contentSafetyServicePayloadFormat); + (ContentSafetyServicePayloadFormat)Enum.Parse( + typeof(ContentSafetyServicePayloadFormat), + contentSafetyServicePayloadFormat); #endif - return _service.EvaluateAsync( - messages, - modelResponse, - additionalContext, - payloadFormat, - metricNames: string.IsNullOrWhiteSpace(contentSafetyServiceMetricName) ? null : [contentSafetyServiceMetricName!], - cancellationToken); + IEnumerable conversation = [.. messages, .. modelResponse.Messages]; + + string evaluatorName = GetType().Name; + + IEnumerable? perTurnContext = null; + if (additionalContext is not null && additionalContext.Any()) + { + IReadOnlyList? relevantContext = FilterAdditionalContext(additionalContext); + +#pragma warning disable S1067 // Expressions should not be too complex + if (relevantContext is not null && relevantContext.Any() && + relevantContext.SelectMany(c => c.GetContents()) is IEnumerable content && content.Any() && + content.OfType() is IEnumerable textContent && textContent.Any() && + string.Join(Environment.NewLine, textContent.Select(c => c.Text)) is string contextString && + !string.IsNullOrWhiteSpace(contextString)) +#pragma warning restore S1067 + { + // Currently we only support supplying a context for the last conversation turn (which is the main one + // that is being evaluated). + perTurnContext = [contextString]; + } + } + + (payload, diagnostics) = + ContentSafetyServicePayloadUtilities.GetPayload( + payloadFormat, + conversation, + contentSafetyServiceAnnotationTask, + evaluatorName, + perTurnContext, + metricNames: includeMetricNamesInContentSafetyServicePayload ? metricNames.Keys : null, + cancellationToken); + + var payloadMessage = new ChatMessage(ChatRole.User, payload); + + ChatResponse annotationResponse = + await contentSafetyServiceChatClient.GetResponseAsync( + payloadMessage, + options: new ContentSafetyChatOptions(contentSafetyServiceAnnotationTask, evaluatorName), + cancellationToken: cancellationToken).ConfigureAwait(false); + + annotationResult = annotationResponse.Text; + result = ContentSafetyService.ParseAnnotationResult(annotationResult); + } + finally + { + stopwatch.Stop(); + } + + string duration = $"{stopwatch.Elapsed.TotalSeconds.ToString("F2", CultureInfo.InvariantCulture)} s"; + + UpdateMetrics(); + + return result; + + void UpdateMetrics() + { + foreach (EvaluationMetric metric in result.Metrics.Values) + { + string contentSafetyServiceMetricName = metric.Name; + if (metricNames.TryGetValue(contentSafetyServiceMetricName, out string? 
metricName)) + { + metric.Name = metricName; + } + + metric.AddOrUpdateMetadata(name: "evaluation-duration", value: duration); + + metric.Interpretation = + metric switch + { + BooleanMetric booleanMetric => booleanMetric.InterpretContentSafetyScore(), + NumericMetric numericMetric => numericMetric.InterpretContentSafetyScore(), + _ => metric.Interpretation + }; + + if (diagnostics is not null) + { + metric.AddDiagnostics(diagnostics); + } + +#pragma warning disable S125 // Sections of code should not be commented out + // The following commented code can be useful for debugging purposes. + // metric.LogJsonData(payload); + // metric.LogJsonData(annotationResult); +#pragma warning restore S125 + } + } } + + /// + /// Filters the s supplied by the caller via + /// down to just the s that are relevant to the evaluation being performed by this + /// . + /// + /// The s supplied by the caller. + /// + /// The s that are relevant to the evaluation being performed by this + /// . + /// + protected virtual IReadOnlyList? FilterAdditionalContext( + IEnumerable? additionalContext) + => null; } diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyService.UrlCacheKey.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyService.UrlCacheKey.cs new file mode 100644 index 00000000000..54459c8b8d9 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyService.UrlCacheKey.cs @@ -0,0 +1,43 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#pragma warning disable S3604 +// S3604: Member initializer values should not be redundant. +// We disable this warning because it is a false positive arising from the analyzer's lack of support for C#'s primary +// constructor syntax. + +using System; + +namespace Microsoft.Extensions.AI.Evaluation.Safety; + +internal sealed partial class ContentSafetyService +{ + private sealed class UrlCacheKey(ContentSafetyServiceConfiguration configuration, string annotationTask) + { + internal ContentSafetyServiceConfiguration Configuration { get; } = configuration; + internal string AnnotationTask { get; } = annotationTask; + + public override bool Equals(object? other) + { + if (other is not UrlCacheKey otherKey) + { + return false; + } + else + { + return + otherKey.Configuration.SubscriptionId == Configuration.SubscriptionId && + otherKey.Configuration.ResourceGroupName == Configuration.ResourceGroupName && + otherKey.Configuration.ProjectName == Configuration.ProjectName && + otherKey.AnnotationTask == AnnotationTask; + } + } + + public override int GetHashCode() => + HashCode.Combine( + Configuration.SubscriptionId, + Configuration.ResourceGroupName, + Configuration.ProjectName, + AnnotationTask); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyService.UrlConfigurationComparer.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyService.UrlConfigurationComparer.cs deleted file mode 100644 index b3f96cbd80c..00000000000 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyService.UrlConfigurationComparer.cs +++ /dev/null @@ -1,37 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. 
- -using System; -using System.Collections.Generic; - -namespace Microsoft.Extensions.AI.Evaluation.Safety; - -internal sealed partial class ContentSafetyService -{ - private sealed class UrlConfigurationComparer : IEqualityComparer - { - internal static UrlConfigurationComparer Instance { get; } = new UrlConfigurationComparer(); - - public bool Equals(ContentSafetyServiceConfiguration? first, ContentSafetyServiceConfiguration? second) - { - if (first is null && second is null) - { - return true; - } - else if (first is null || second is null) - { - return false; - } - else - { - return - first.SubscriptionId == second.SubscriptionId && - first.ResourceGroupName == second.ResourceGroupName && - first.ProjectName == second.ProjectName; - } - } - - public int GetHashCode(ContentSafetyServiceConfiguration obj) - => HashCode.Combine(obj.SubscriptionId, obj.ResourceGroupName, obj.ProjectName); - } -} diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyService.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyService.cs index 63373507dfa..e258e1ea575 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyService.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyService.cs @@ -8,25 +8,19 @@ using System; using System.Collections.Concurrent; -using System.Collections.Generic; using System.Diagnostics; -using System.Globalization; using System.Linq; using System.Net; using System.Net.Http; using System.Net.Http.Headers; using System.Text.Json; -using System.Text.Json.Nodes; using System.Threading; using System.Threading.Tasks; using Azure.Core; namespace Microsoft.Extensions.AI.Evaluation.Safety; -internal sealed partial class ContentSafetyService( - ContentSafetyServiceConfiguration serviceConfiguration, - string annotationTask, - string evaluatorName) +internal sealed partial class ContentSafetyService(ContentSafetyServiceConfiguration serviceConfiguration) { private static HttpClient? _sharedHttpClient; private static HttpClient SharedHttpClient @@ -38,77 +32,14 @@ private static HttpClient SharedHttpClient } } - private static readonly ConcurrentDictionary _serviceUrlCache = - new ConcurrentDictionary(UrlConfigurationComparer.Instance); + private static readonly ConcurrentDictionary _serviceUrlCache = + new ConcurrentDictionary(); private readonly HttpClient _httpClient = serviceConfiguration.HttpClient ?? SharedHttpClient; private string? _serviceUrl; - public async ValueTask EvaluateAsync( - IEnumerable messages, - ChatResponse modelResponse, - IEnumerable? contexts = null, - ContentSafetyServicePayloadFormat payloadFormat = ContentSafetyServicePayloadFormat.HumanSystem, - IEnumerable? metricNames = null, - CancellationToken cancellationToken = default) - { - JsonObject payload; - IList? 
diagnostics; - string annotationResult; - string duration; - Stopwatch stopwatch = Stopwatch.StartNew(); - - try - { - string serviceUrl = await GetServiceUrlAsync(cancellationToken).ConfigureAwait(false); - - (payload, diagnostics) = - ContentSafetyServicePayloadUtilities.GetPayload( - payloadFormat, - messages, - modelResponse, - annotationTask, - evaluatorName, - contexts, - metricNames, - cancellationToken); - - string resultUrl = - await SubmitAnnotationRequestAsync(serviceUrl, payload, cancellationToken).ConfigureAwait(false); - - annotationResult = await FetchAnnotationResultAsync(resultUrl, cancellationToken).ConfigureAwait(false); - } - finally - { - stopwatch.Stop(); - duration = $"{stopwatch.Elapsed.TotalSeconds.ToString("F2", CultureInfo.InvariantCulture)} s"; - } - - EvaluationResult result = ParseAnnotationResult(annotationResult, duration); - - if (diagnostics is not null) - { - result.AddDiagnosticsToAllMetrics(diagnostics); - } - -#pragma warning disable S125 // Sections of code should not be commented out - // The following commented code can be useful for debugging purposes. - // result.AddDiagnosticsToAllMetrics( - // EvaluationDiagnostic.Informational( - // $""" - // Annotation Request Payload: - // {payload.ToJsonString(new JsonSerializerOptions { WriteIndented = true })} - // - // Annotation Result: - // {annotationResult} - // """)); -#pragma warning restore S125 - - return result; - } - - private static EvaluationResult ParseAnnotationResult(string annotationResponse, string evaluationDuration) + internal static EvaluationResult ParseAnnotationResult(string annotationResponse) { #pragma warning disable S125 // Sections of code should not be commented out // Example annotation response: @@ -189,28 +120,56 @@ private static EvaluationResult ParseAnnotationResult(string annotationResponse, } } - metric.AddOrUpdateMetadata("evaluation-duration", evaluationDuration); - result.Metrics[metric.Name] = metric; } return result; } - private async ValueTask GetServiceUrlAsync(CancellationToken cancellationToken) + internal async ValueTask AnnotateAsync( + string payload, + string annotationTask, + string evaluatorName, + CancellationToken cancellationToken = default) + { + string serviceUrl = + await GetServiceUrlAsync(annotationTask, evaluatorName, cancellationToken).ConfigureAwait(false); + + string resultUrl = + await SubmitAnnotationRequestAsync( + serviceUrl, + payload, + evaluatorName, + cancellationToken).ConfigureAwait(false); + + string annotationResult = + await FetchAnnotationResultAsync( + resultUrl, + evaluatorName, + cancellationToken).ConfigureAwait(false); + + return annotationResult; + } + + private async ValueTask GetServiceUrlAsync( + string annotationTask, + string evaluatorName, + CancellationToken cancellationToken) { if (_serviceUrl is not null) { return _serviceUrl; } - if (_serviceUrlCache.TryGetValue(serviceConfiguration, out string? serviceUrl)) + var key = new UrlCacheKey(serviceConfiguration, annotationTask); + if (_serviceUrlCache.TryGetValue(key, out string? 
serviceUrl)) { _serviceUrl = serviceUrl; return _serviceUrl; } - string discoveryUrl = await GetServiceDiscoveryUrlAsync(cancellationToken).ConfigureAwait(false); + string discoveryUrl = + await GetServiceDiscoveryUrlAsync(evaluatorName, cancellationToken).ConfigureAwait(false); serviceUrl = $"{discoveryUrl}/raisvc/v1.0" + @@ -220,24 +179,30 @@ private async ValueTask GetServiceUrlAsync(CancellationToken cancellatio await EnsureServiceAvailabilityAsync( serviceUrl, - annotationTask, + capability: annotationTask, + evaluatorName, cancellationToken).ConfigureAwait(false); - _ = _serviceUrlCache.TryAdd(serviceConfiguration, serviceUrl); + _ = _serviceUrlCache.TryAdd(key, serviceUrl); _serviceUrl = serviceUrl; return _serviceUrl; } - private async ValueTask GetServiceDiscoveryUrlAsync(CancellationToken cancellationToken) + private async ValueTask GetServiceDiscoveryUrlAsync( + string evaluatorName, + CancellationToken cancellationToken) { - string requestUrl = + string resourceManagerUrl = $"https://management.azure.com/subscriptions/{serviceConfiguration.SubscriptionId}" + $"/resourceGroups/{serviceConfiguration.ResourceGroupName}" + $"/providers/Microsoft.MachineLearningServices/workspaces/{serviceConfiguration.ProjectName}" + $"?api-version=2023-08-01-preview"; HttpResponseMessage response = - await GetResponseAsync(requestUrl, cancellationToken: cancellationToken).ConfigureAwait(false); + await GetResponseAsync( + resourceManagerUrl, + evaluatorName, + cancellationToken: cancellationToken).ConfigureAwait(false); if (!response.IsSuccessStatusCode) { @@ -276,12 +241,16 @@ private async ValueTask GetServiceDiscoveryUrlAsync(CancellationToken ca private async ValueTask EnsureServiceAvailabilityAsync( string serviceUrl, string capability, + string evaluatorName, CancellationToken cancellationToken) { string serviceAvailabilityUrl = $"{serviceUrl}/checkannotation"; HttpResponseMessage response = - await GetResponseAsync(serviceAvailabilityUrl, cancellationToken: cancellationToken).ConfigureAwait(false); + await GetResponseAsync( + serviceAvailabilityUrl, + evaluatorName, + cancellationToken: cancellationToken).ConfigureAwait(false); if (!response.IsSuccessStatusCode) { @@ -324,17 +293,18 @@ private async ValueTask EnsureServiceAvailabilityAsync( private async ValueTask SubmitAnnotationRequestAsync( string serviceUrl, - JsonObject payload, + string payload, + string evaluatorName, CancellationToken cancellationToken) { string annotationUrl = $"{serviceUrl}/submitannotation"; - string payloadString = payload.ToJsonString(); HttpResponseMessage response = await GetResponseAsync( annotationUrl, + evaluatorName, requestMethod: HttpMethod.Post, - payloadString, + payload, cancellationToken).ConfigureAwait(false); if (!response.IsSuccessStatusCode) @@ -372,6 +342,7 @@ await GetResponseAsync( private async ValueTask FetchAnnotationResultAsync( string resultUrl, + string evaluatorName, CancellationToken cancellationToken) { const int InitialDelayInMilliseconds = 500; @@ -385,7 +356,11 @@ private async ValueTask FetchAnnotationResultAsync( do { ++attempts; - response = await GetResponseAsync(resultUrl, cancellationToken: cancellationToken).ConfigureAwait(false); + response = + await GetResponseAsync( + resultUrl, + evaluatorName, + cancellationToken: cancellationToken).ConfigureAwait(false); if (response.StatusCode != HttpStatusCode.OK) { @@ -426,6 +401,7 @@ private async ValueTask FetchAnnotationResultAsync( private async ValueTask GetResponseAsync( string requestUrl, + string evaluatorName, HttpMethod? 
requestMethod = null, string? payload = null, CancellationToken cancellationToken = default) @@ -434,7 +410,7 @@ private async ValueTask GetResponseAsync( using var request = new HttpRequestMessage(requestMethod, requestUrl); request.Content = new StringContent(payload ?? string.Empty); - await AddHeadersAsync(request, cancellationToken).ConfigureAwait(false); + await AddHeadersAsync(request, evaluatorName, cancellationToken).ConfigureAwait(false); HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false); return response; @@ -442,6 +418,7 @@ private async ValueTask GetResponseAsync( private async ValueTask AddHeadersAsync( HttpRequestMessage httpRequestMessage, + string evaluatorName, CancellationToken cancellationToken = default) { string userAgent = diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyServiceConfiguration.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyServiceConfiguration.cs index f28b027feab..615485cad5c 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyServiceConfiguration.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyServiceConfiguration.cs @@ -12,8 +12,9 @@ namespace Microsoft.Extensions.AI.Evaluation.Safety; /// -/// Specifies the Azure AI project that should be used and credentials that should be used when a -/// communicates with the Azure AI Content Safety service to perform evaluations. +/// Specifies configuration parameters such as the Azure AI project that should be used, and the credentials that +/// should be used, when a communicates with the Azure AI Content Safety service +/// to perform evaluations. /// /// /// The Azure that should be used when authenticating requests. diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyServiceConfigurationExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyServiceConfigurationExtensions.cs new file mode 100644 index 00000000000..2021809c5e5 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyServiceConfigurationExtensions.cs @@ -0,0 +1,79 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI.Evaluation.Safety; + +/// +/// Extension methods for . +/// +public static class ContentSafetyServiceConfigurationExtensions +{ + /// + /// Returns a that can be used to communicate with the Azure AI Content Safety + /// service for performing content safety evaluations. + /// + /// + /// An object that specifies configuration parameters such as the Azure AI project that should be used, and the + /// credentials that should be used, when communicating with the Azure AI Content Safety service to perform + /// content safety evaluations. + /// + /// + /// The original , if any. If specified, the returned + /// will be based on , with the + /// in being replaced with + /// a new that can be used both to communicate with the AI model that + /// is configured to communicate with, as well as to communicate with + /// the Azure AI Content Safety service. + /// + /// + /// A that can be used to communicate with the Azure AI Content Safety service for + /// performing content safety evaluations. 
+ /// + public static ChatConfiguration ToChatConfiguration( + this ContentSafetyServiceConfiguration contentSafetyServiceConfiguration, + ChatConfiguration? originalChatConfiguration = null) + { + _ = Throw.IfNull(contentSafetyServiceConfiguration); + +#pragma warning disable CA2000 // Dispose objects before they go out of scope. + // We can't dispose newChatClient here because it is returned to the caller. + + var newChatClient = + new ContentSafetyChatClient( + contentSafetyServiceConfiguration, + originalChatClient: originalChatConfiguration?.ChatClient); +#pragma warning restore CA2000 + + return new ChatConfiguration(newChatClient, originalChatConfiguration?.TokenCounter); + } + + /// + /// Returns an that can be used to communicate with the Azure AI Content Safety service + /// for performing content safety evaluations. + /// + /// + /// An object that specifies configuration parameters such as the Azure AI project that should be used, and the + /// credentials that should be used, when communicating with the Azure AI Content Safety service to perform + /// content safety evaluations. + /// + /// + /// The original , if any. If specified, the returned + /// will be a wrapper around that can be used both + /// to communicate with the AI model that is configured to communicate with, + /// as well as to communicate with the Azure AI Content Safety service. + /// + /// + /// A that can be used to communicate with the Azure AI Content Safety service for + /// performing content safety evaluations. + /// + public static IChatClient ToIChatClient( + this ContentSafetyServiceConfiguration contentSafetyServiceConfiguration, + IChatClient? originalChatClient = null) + { + _ = Throw.IfNull(contentSafetyServiceConfiguration); + + return new ContentSafetyChatClient(contentSafetyServiceConfiguration, originalChatClient); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyServicePayloadUtilities.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyServicePayloadUtilities.cs index 0c49b3fb902..bb12bc6afec 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyServicePayloadUtilities.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyServicePayloadUtilities.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics; using System.Linq; using System.Text.Json.Nodes; using System.Threading; @@ -21,51 +20,45 @@ internal static bool ContainsImage(this ChatMessage message) => message.Contents.Any(IsImage); internal static bool ContainsImage(this ChatResponse response) - => response.Messages.ContainImage(); + => response.Messages.ContainsImage(); - internal static bool ContainImage(this IEnumerable messages) - => messages.Any(ContainsImage); + internal static bool ContainsImage(this IEnumerable conversation) + => conversation.Any(ContainsImage); -#pragma warning disable S107 // Methods should not have too many parameters - internal static (JsonObject payload, IList? diagnostics) GetPayload( + internal static (string payload, IReadOnlyList? diagnostics) GetPayload( ContentSafetyServicePayloadFormat payloadFormat, - IEnumerable messages, - ChatResponse modelResponse, + IEnumerable conversation, string annotationTask, string evaluatorName, - IEnumerable? contexts = null, + IEnumerable? perTurnContext = null, IEnumerable? 
metricNames = null, CancellationToken cancellationToken = default) => -#pragma warning restore S107 payloadFormat switch { ContentSafetyServicePayloadFormat.HumanSystem => GetUserTextListPayloadWithEmbeddedXml( - messages, - modelResponse, + conversation, annotationTask, evaluatorName, - contexts, + perTurnContext, metricNames, cancellationToken: cancellationToken), ContentSafetyServicePayloadFormat.QuestionAnswer => GetUserTextListPayloadWithEmbeddedJson( - messages, - modelResponse, + conversation, annotationTask, evaluatorName, - contexts, + perTurnContext, metricNames, cancellationToken: cancellationToken), ContentSafetyServicePayloadFormat.QueryResponse => GetUserTextListPayloadWithEmbeddedJson( - messages, - modelResponse, + conversation, annotationTask, evaluatorName, - contexts, + perTurnContext, metricNames, questionPropertyName: "query", answerPropertyName: "response", @@ -73,11 +66,10 @@ internal static (JsonObject payload, IList? diagnostics) G ContentSafetyServicePayloadFormat.ContextCompletion => GetUserTextListPayloadWithEmbeddedJson( - messages, - modelResponse, + conversation, annotationTask, evaluatorName, - contexts, + perTurnContext, metricNames, questionPropertyName: "context", answerPropertyName: "completion", @@ -85,11 +77,10 @@ internal static (JsonObject payload, IList? diagnostics) G ContentSafetyServicePayloadFormat.Conversation => GetConversationPayload( - messages, - modelResponse, + conversation, annotationTask, evaluatorName, - contexts, + perTurnContext, metricNames, cancellationToken: cancellationToken), @@ -97,13 +88,12 @@ internal static (JsonObject payload, IList? diagnostics) G }; #pragma warning disable S107 // Methods should not have too many parameters - private static (JsonObject payload, IList? diagnostics) + private static (string payload, IReadOnlyList? diagnostics) GetUserTextListPayloadWithEmbeddedXml( - IEnumerable messages, - ChatResponse modelResponse, + IEnumerable conversation, string annotationTask, string evaluatorName, - IEnumerable? contexts = null, + IEnumerable? perTurnContext = null, IEnumerable? metricNames = null, string questionElementName = "Human", string answerElementName = "System", @@ -113,15 +103,14 @@ private static (JsonObject payload, IList? diagnostics) #pragma warning restore S107 { List> turns; - List? turnContexts; + List? normalizedPerTurnContext; List? diagnostics; - (turns, turnContexts, diagnostics, _) = - PreProcessMessages( - messages, - modelResponse, + (turns, normalizedPerTurnContext, diagnostics, _) = + PreProcessConversation( + conversation, evaluatorName, - contexts, + perTurnContext, returnLastTurnOnly: strategy is ContentSafetyServicePayloadStrategy.AnnotateLastTurn, cancellationToken: cancellationToken); @@ -143,9 +132,9 @@ private static (JsonObject payload, IList? diagnostics) item.Add(new XElement(answerElementName, answer.Text)); } - if (turnContexts is not null && turnContexts.Any()) + if (normalizedPerTurnContext is not null && normalizedPerTurnContext.Any()) { - item.Add(new XElement(contextElementName, turnContexts[index])); + item.Add(new XElement(contextElementName, normalizedPerTurnContext[index])); } return item; @@ -183,17 +172,16 @@ private static (JsonObject payload, IList? diagnostics) payload["MetricList"] = new JsonArray([.. metricNames]); } - return (payload, diagnostics); + return (payload.ToJsonString(), diagnostics); } #pragma warning disable S107 // Methods should not have too many parameters - private static (JsonObject payload, IList? 
diagnostics) + private static (string payload, IReadOnlyList? diagnostics) GetUserTextListPayloadWithEmbeddedJson( - IEnumerable messages, - ChatResponse modelResponse, + IEnumerable conversation, string annotationTask, string evaluatorName, - IEnumerable? contexts = null, + IEnumerable? perTurnContext = null, IEnumerable? metricNames = null, string questionPropertyName = "question", string answerPropertyName = "answer", @@ -209,15 +197,14 @@ private static (JsonObject payload, IList? diagnostics) } List> turns; - List? turnContexts; + List? normalizedPerTurnContext; List? diagnostics; - (turns, turnContexts, diagnostics, _) = - PreProcessMessages( - messages, - modelResponse, + (turns, normalizedPerTurnContext, diagnostics, _) = + PreProcessConversation( + conversation, evaluatorName, - contexts, + perTurnContext, returnLastTurnOnly: strategy is ContentSafetyServicePayloadStrategy.AnnotateLastTurn, cancellationToken: cancellationToken); @@ -239,9 +226,9 @@ private static (JsonObject payload, IList? diagnostics) item[answerPropertyName] = answer.Text; } - if (turnContexts is not null && turnContexts.Any()) + if (normalizedPerTurnContext is not null && normalizedPerTurnContext.Any()) { - item[contextPropertyName] = turnContexts[index]; + item[contextPropertyName] = normalizedPerTurnContext[index]; } return item; @@ -269,20 +256,17 @@ private static (JsonObject payload, IList? diagnostics) payload["MetricList"] = new JsonArray([.. metricNames]); } - return (payload, diagnostics); + return (payload.ToJsonString(), diagnostics); } -#pragma warning disable S107 // Methods should not have too many parameters - private static (JsonObject payload, IList? diagnostics) GetConversationPayload( - IEnumerable messages, - ChatResponse modelResponse, + private static (string payload, IReadOnlyList? diagnostics) GetConversationPayload( + IEnumerable conversation, string annotationTask, string evaluatorName, - IEnumerable? contexts = null, + IEnumerable? perTurnContext = null, IEnumerable? metricNames = null, ContentSafetyServicePayloadStrategy strategy = ContentSafetyServicePayloadStrategy.AnnotateConversation, CancellationToken cancellationToken = default) -#pragma warning restore S107 { if (strategy is ContentSafetyServicePayloadStrategy.AnnotateEachTurn) { @@ -291,16 +275,15 @@ private static (JsonObject payload, IList? diagnostics) Ge } List> turns; - List? turnContexts; + List? normalizedPerTurnContext; List? diagnostics; string contentType; - (turns, turnContexts, diagnostics, contentType) = - PreProcessMessages( - messages, - modelResponse, + (turns, normalizedPerTurnContext, diagnostics, contentType) = + PreProcessConversation( + conversation, evaluatorName, - contexts, + perTurnContext, returnLastTurnOnly: strategy is ContentSafetyServicePayloadStrategy.AnnotateLastTurn, areImagesSupported: true, cancellationToken); @@ -324,7 +307,9 @@ IEnumerable GetMessages(Dictionary turn, int tu { IEnumerable contents = GetContents(answer); - if (turnContexts is not null && turnContexts.Any() && turnContexts[turnIndex] is string context) + if (normalizedPerTurnContext is not null && + normalizedPerTurnContext.Any() && + normalizedPerTurnContext[turnIndex] is string context) { yield return new JsonObject { @@ -412,25 +397,25 @@ IEnumerable GetContents(ChatMessage message) // // On the other hand, if ContentSafetyServicePayloadStrategy.AnnotateConversation is used, the service will // produce a single annotation result for the entire conversation. 
- return (payload, diagnostics); + return (payload.ToJsonString(), diagnostics); } private static (List> turns, - List? turnContexts, + List? normalizedPerTurnContext, List? diagnostics, - string contentType) PreProcessMessages( - IEnumerable messages, - ChatResponse modelResponse, + string contentType) PreProcessConversation( + IEnumerable conversation, string evaluatorName, - IEnumerable? contexts = null, + IEnumerable? perTurnContext = null, bool returnLastTurnOnly = false, bool areImagesSupported = false, CancellationToken cancellationToken = default) { List> turns = []; Dictionary currentTurn = []; - List? turnContexts = contexts is null || !contexts.Any() ? null : [.. contexts]; + List? normalizedPerTurnContext = + perTurnContext is null || !perTurnContext.Any() ? null : [.. perTurnContext]; int currentTurnIndex = 0; int ignoredMessageCount = 0; @@ -448,7 +433,7 @@ void StartNewTurn() ++currentTurnIndex; } - foreach (ChatMessage message in messages) + foreach (ChatMessage message in conversation) { cancellationToken.ThrowIfCancellationRequested(); @@ -474,22 +459,6 @@ void StartNewTurn() } } - foreach (ChatMessage message in modelResponse.Messages) - { - cancellationToken.ThrowIfCancellationRequested(); - - if (message.Role == ChatRole.Assistant) - { - currentTurn["answer"] = message; - - StartNewTurn(); - } - else - { - ignoredMessageCount++; - } - } - if (returnLastTurnOnly) { turns.RemoveRange(index: 0, count: turns.Count - 1); @@ -541,9 +510,8 @@ void ValidateContents(ChatMessage message) diagnostics = [ EvaluationDiagnostic.Warning( $"The supplied conversation contained {ignoredMessageCount} messages with unsupported roles. " + - $"{evaluatorName} only considers messages with role '{ChatRole.User}' and '{ChatRole.Assistant}' in the supplied conversation history. " + - $"In the supplied model response, it only considers messages with role '{ChatRole.Assistant}'. " + - $"The unsupported messages were ignored.")]; + $"{evaluatorName} only considers messages with role '{ChatRole.User}' and '{ChatRole.Assistant}'. " + + $"The unsupported messages (which may include messages with role '{ChatRole.System}' and '{ChatRole.Tool}') were ignored.")]; } if (incompleteTurnCount > 0) @@ -578,43 +546,41 @@ void ValidateContents(ChatMessage message) } } - if (turnContexts is not null && turnContexts.Any()) + if (normalizedPerTurnContext is not null && normalizedPerTurnContext.Any()) { - if (turnContexts.Count > turns.Count) + if (normalizedPerTurnContext.Count > turns.Count) { - var ignoredContextCount = turnContexts.Count - turns.Count; + var ignoredContextCount = normalizedPerTurnContext.Count - turns.Count; diagnostics ??= []; diagnostics.Add( EvaluationDiagnostic.Warning( $"The supplied conversation contained {turns.Count} turns. " + - $"However, the supplied context object contained contexts for {turnContexts.Count} turns. " + - $"The initial {ignoredContextCount} contexts in the context object were ignored. " + - $"Only the last {turns.Count} contexts were used.")); + $"However, context for {normalizedPerTurnContext.Count} turns were supplied as part of the context collection. " + + $"The initial {ignoredContextCount} items from the context collection were ignored. 
" + + $"Only the last {turns.Count} items from the context collection were used.")); - turnContexts.RemoveRange(0, ignoredContextCount); + normalizedPerTurnContext.RemoveRange(0, ignoredContextCount); } - else if (turnContexts.Count < turns.Count) + else if (normalizedPerTurnContext.Count < turns.Count) { - int missingContextCount = turns.Count - turnContexts.Count; + int missingContextCount = turns.Count - normalizedPerTurnContext.Count; diagnostics ??= []; diagnostics.Add( EvaluationDiagnostic.Warning( $"The supplied conversation contained {turns.Count} turns. " + - $"However, the supplied context object only contained contexts for {turnContexts.Count} turns. " + - $"The initial {missingContextCount} turns in the conversations were evaluated without a context. " + - $"The supplied contexts were applied to the last {turnContexts.Count} turns.")); + $"However, context for only {normalizedPerTurnContext.Count} turns were supplied as part of the context collection. " + + $"The initial {missingContextCount} turns in the conversations were evaluated without any context. " + + $"The supplied items in the context collection were applied to the last {normalizedPerTurnContext.Count} turns.")); - turnContexts.InsertRange(0, Enumerable.Repeat(null, missingContextCount)); + normalizedPerTurnContext.InsertRange(0, Enumerable.Repeat(null, missingContextCount)); } - - Debug.Assert(turns.Count == turnContexts.Count, "The returned number of turns and contexts should match."); } string contentType = areImagesSupported && imagesCount > 0 ? "image" : "text"; - return (turns, turnContexts, diagnostics, contentType); + return (turns, normalizedPerTurnContext, diagnostics, contentType); } private static bool IsTextOrUsage(this AIContent content) diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/EvaluationMetricExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/EvaluationMetricExtensions.cs index cd17ceb7988..20246e3aaa2 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/EvaluationMetricExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/EvaluationMetricExtensions.cs @@ -1,11 +1,15 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. 
+using System.Text.Json; +using System.Text.Json.Nodes; +using Microsoft.Shared.Diagnostics; + namespace Microsoft.Extensions.AI.Evaluation.Safety; internal static class EvaluationMetricExtensions { - internal static EvaluationMetricInterpretation InterpretHarmScore(this NumericMetric metric) + internal static EvaluationMetricInterpretation InterpretContentHarmScore(this NumericMetric metric) { EvaluationRating rating = metric.Value switch { @@ -29,7 +33,7 @@ internal static EvaluationMetricInterpretation InterpretHarmScore(this NumericMe : new EvaluationMetricInterpretation(rating); } - internal static EvaluationMetricInterpretation InterpretScore(this NumericMetric metric) + internal static EvaluationMetricInterpretation InterpretContentSafetyScore(this NumericMetric metric) { EvaluationRating rating = metric.Value switch { @@ -53,7 +57,9 @@ internal static EvaluationMetricInterpretation InterpretScore(this NumericMetric : new EvaluationMetricInterpretation(rating); } - internal static EvaluationMetricInterpretation InterpretScore(this BooleanMetric metric, bool passValue = false) + internal static EvaluationMetricInterpretation InterpretContentSafetyScore( + this BooleanMetric metric, + bool passValue = false) { EvaluationRating rating = metric.Value switch { @@ -69,4 +75,28 @@ internal static EvaluationMetricInterpretation InterpretScore(this BooleanMetric failed: true, reason: $"{metric.Name} is {passValue}."); } + + internal static void LogJsonData(this EvaluationMetric metric, string data) + { + JsonNode? jsonData = JsonNode.Parse(data); + + if (jsonData is null) + { + string message = + $""" + Failed to parse supplied {nameof(data)} below into a {nameof(JsonNode)}. + {data} + """; + + Throw.ArgumentException(paramName: nameof(data), message); + } + + metric.LogJsonData(jsonData); + } + + internal static void LogJsonData(this EvaluationMetric metric, JsonNode data) + { + string serializedData = data.ToJsonString(new JsonSerializerOptions { WriteIndented = true }); + metric.AddDiagnostic(EvaluationDiagnostic.Informational(serializedData)); + } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/GroundednessProEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/GroundednessProEvaluator.cs index 525bd8ede02..6af681d751f 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/GroundednessProEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/GroundednessProEvaluator.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI.Evaluation.Safety; @@ -31,16 +32,10 @@ namespace Microsoft.Extensions.AI.Evaluation.Safety; /// produce more accurate results than similar evaluations performed using a regular (non-finetuned) model. /// /// -/// -/// Specifies the Azure AI project that should be used and credentials that should be used when this -/// communicates with the Azure AI Content Safety service to perform -/// evaluations. 
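// Illustrative sketch of the usage pattern enabled by the parameterless constructor
// introduced just below. The variables (messages, modelResponse, groundingContext,
// contentSafetyServiceConfiguration) are assumed to be defined elsewhere, as in the
// updated SafetyEvaluatorTests; this sketch is not part of the patch itself.
IEvaluator groundednessPro = new GroundednessProEvaluator();
ChatConfiguration chatConfiguration = contentSafetyServiceConfiguration.ToChatConfiguration();

EvaluationResult result =
    await groundednessPro.EvaluateAsync(
        messages,
        modelResponse,
        chatConfiguration,
        additionalContext: [new GroundednessProEvaluatorContext(groundingContext)]);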
-/// -public sealed class GroundednessProEvaluator(ContentSafetyServiceConfiguration contentSafetyServiceConfiguration) +public sealed class GroundednessProEvaluator() : ContentSafetyEvaluator( - contentSafetyServiceConfiguration, contentSafetyServiceAnnotationTask: "groundedness", - evaluatorName: nameof(GroundednessProEvaluator)) + metricNames: new Dictionary { ["generic_groundedness"] = GroundednessProMetricName }) { /// /// Gets the of the returned by @@ -48,9 +43,6 @@ public sealed class GroundednessProEvaluator(ContentSafetyServiceConfiguration c /// public static string GroundednessProMetricName => "Groundedness Pro"; - /// - public override IReadOnlyCollection EvaluationMetricNames => [GroundednessProMetricName]; - /// public override async ValueTask EvaluateAsync( IEnumerable messages, @@ -59,43 +51,34 @@ public override async ValueTask EvaluateAsync( IEnumerable? additionalContext = null, CancellationToken cancellationToken = default) { - IEnumerable contexts; - if (additionalContext?.OfType().FirstOrDefault() - is GroundednessProEvaluatorContext context) - { - contexts = [context.GroundingContext]; - } - else - { - throw new InvalidOperationException( - $"A value of type '{nameof(GroundednessProEvaluatorContext)}' was not found in the '{nameof(additionalContext)}' collection."); - } - - const string GenericGroundednessContentSafetyServiceMetricName = "generic_groundedness"; + _ = Throw.IfNull(chatConfiguration); + _ = Throw.IfNull(modelResponse); EvaluationResult result = await EvaluateContentSafetyAsync( + chatConfiguration.ChatClient, messages, modelResponse, - contexts, + additionalContext, contentSafetyServicePayloadFormat: ContentSafetyServicePayloadFormat.QuestionAnswer.ToString(), - contentSafetyServiceMetricName: GenericGroundednessContentSafetyServiceMetricName, cancellationToken: cancellationToken).ConfigureAwait(false); - IEnumerable updatedMetrics = - result.Metrics.Values.Select( - metric => - { - if (metric.Name == GenericGroundednessContentSafetyServiceMetricName) - { - metric.Name = GroundednessProMetricName; - } - - return metric; - }); - - result = new EvaluationResult(updatedMetrics); - result.Interpret(metric => metric is NumericMetric numericMetric ? numericMetric.InterpretScore() : null); return result; } + + /// + protected override IReadOnlyList? FilterAdditionalContext( + IEnumerable? additionalContext) + { + if (additionalContext?.OfType().FirstOrDefault() + is GroundednessProEvaluatorContext context) + { + return [context]; + } + else + { + throw new InvalidOperationException( + $"A value of type '{nameof(GroundednessProEvaluatorContext)}' was not found in the '{nameof(additionalContext)}' collection."); + } + } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/GroundednessProEvaluatorContext.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/GroundednessProEvaluatorContext.cs index 3d293c27571..5d38b62496a 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/GroundednessProEvaluatorContext.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/GroundednessProEvaluatorContext.cs @@ -6,6 +6,8 @@ // We disable this warning because it is a false positive arising from the analyzer's lack of support for C#'s primary // constructor syntax. +using System.Collections.Generic; + namespace Microsoft.Extensions.AI.Evaluation.Safety; /// @@ -29,4 +31,8 @@ public sealed class GroundednessProEvaluatorContext(string groundingContext) : E /// in the information present in the supplied . 
/// public string GroundingContext { get; } = groundingContext; + + /// + public override IReadOnlyList GetContents() + => [new TextContent(GroundingContext)]; } diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/HateAndUnfairnessEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/HateAndUnfairnessEvaluator.cs index 7932a54333a..718b742b29a 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/HateAndUnfairnessEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/HateAndUnfairnessEvaluator.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Generic; + namespace Microsoft.Extensions.AI.Evaluation.Safety; /// @@ -18,17 +20,9 @@ namespace Microsoft.Extensions.AI.Evaluation.Safety; /// currently not supported. /// /// -/// -/// Specifies the Azure AI project that should be used and credentials that should be used when this -/// communicates with the Azure AI Content Safety service to perform -/// evaluations. -/// -public sealed class HateAndUnfairnessEvaluator(ContentSafetyServiceConfiguration contentSafetyServiceConfiguration) +public sealed class HateAndUnfairnessEvaluator() : ContentHarmEvaluator( - contentSafetyServiceConfiguration, - contentSafetyServiceMetricName: "hate_fairness", - metricName: HateAndUnfairnessMetricName, - evaluatorName: nameof(HateAndUnfairnessEvaluator)) + metricNames: new Dictionary { ["hate_fairness"] = HateAndUnfairnessMetricName }) { /// /// Gets the of the returned by diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/IndirectAttackEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/IndirectAttackEvaluator.cs index d2cb3c10840..f65a5ee82f6 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/IndirectAttackEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/IndirectAttackEvaluator.cs @@ -2,9 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; namespace Microsoft.Extensions.AI.Evaluation.Safety; @@ -45,58 +42,14 @@ namespace Microsoft.Extensions.AI.Evaluation.Safety; /// evaluated responses. Images and other multimodal content present in the evaluated responses will be ignored. /// /// -/// -/// Specifies the Azure AI project that should be used and credentials that should be used when this -/// communicates with the Azure AI Content Safety service to perform -/// evaluations. -/// -public sealed class IndirectAttackEvaluator(ContentSafetyServiceConfiguration contentSafetyServiceConfiguration) +public sealed class IndirectAttackEvaluator() : ContentSafetyEvaluator( - contentSafetyServiceConfiguration, contentSafetyServiceAnnotationTask: "xpia", - evaluatorName: nameof(IndirectAttackEvaluator)) + metricNames: new Dictionary { ["xpia"] = IndirectAttackMetricName }) { /// /// Gets the of the returned by /// . /// public static string IndirectAttackMetricName => "Indirect Attack"; - - /// - public override IReadOnlyCollection EvaluationMetricNames => [IndirectAttackMetricName]; - - /// - public override async ValueTask EvaluateAsync( - IEnumerable messages, - ChatResponse modelResponse, - ChatConfiguration? chatConfiguration = null, - IEnumerable? 
additionalContext = null, - CancellationToken cancellationToken = default) - { - const string IndirectAttackContentSafetyServiceMetricName = "xpia"; - - EvaluationResult result = - await EvaluateContentSafetyAsync( - messages, - modelResponse, - contentSafetyServicePayloadFormat: ContentSafetyServicePayloadFormat.HumanSystem.ToString(), - contentSafetyServiceMetricName: IndirectAttackContentSafetyServiceMetricName, - cancellationToken: cancellationToken).ConfigureAwait(false); - - IEnumerable updatedMetrics = - result.Metrics.Values.Select( - metric => - { - if (metric.Name == IndirectAttackContentSafetyServiceMetricName) - { - metric.Name = IndirectAttackMetricName; - } - - return metric; - }); - - result = new EvaluationResult(updatedMetrics); - result.Interpret(metric => metric is BooleanMetric booleanMetric ? booleanMetric.InterpretScore() : null); - return result; - } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ProtectedMaterialEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ProtectedMaterialEvaluator.cs index fdd76e7fdd9..d37bc17c94b 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ProtectedMaterialEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ProtectedMaterialEvaluator.cs @@ -2,9 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; -using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI.Evaluation.Safety; @@ -26,15 +26,17 @@ namespace Microsoft.Extensions.AI.Evaluation.Safety; /// indicating the absence of protected material. /// /// -/// -/// Specifies the Azure AI project that should be used and credentials that should be used when this -/// communicates with the Azure AI Content Safety service to perform evaluations. -/// -public sealed class ProtectedMaterialEvaluator(ContentSafetyServiceConfiguration contentSafetyServiceConfiguration) +public sealed class ProtectedMaterialEvaluator() : ContentSafetyEvaluator( - contentSafetyServiceConfiguration, contentSafetyServiceAnnotationTask: "protected material", - evaluatorName: nameof(ProtectedMaterialEvaluator)) + metricNames: + new Dictionary + { + ["protected_material"] = ProtectedMaterialMetricName, + ["artwork"] = ProtectedArtworkMetricName, + ["fictional_characters"] = ProtectedFictionalCharactersMetricName, + ["logos_and_brands"] = ProtectedLogosAndBrandsMetricName + }) { /// /// Gets the of the returned by @@ -60,15 +62,6 @@ public sealed class ProtectedMaterialEvaluator(ContentSafetyServiceConfiguration /// public static string ProtectedLogosAndBrandsMetricName => "Protected Logos And Brands"; - /// - public override IReadOnlyCollection EvaluationMetricNames => - [ - ProtectedMaterialMetricName, - ProtectedArtworkMetricName, - ProtectedFictionalCharactersMetricName, - ProtectedLogosAndBrandsMetricName - ]; - /// public override async ValueTask EvaluateAsync( IEnumerable messages, @@ -77,23 +70,32 @@ public override async ValueTask EvaluateAsync( IEnumerable? additionalContext = null, CancellationToken cancellationToken = default) { + _ = Throw.IfNull(chatConfiguration); + _ = Throw.IfNull(modelResponse); + + IChatClient chatClient = chatConfiguration.ChatClient; + // First evaluate the text content in the conversation for protected material. 
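+        // Metric names are omitted from the content safety service payload here
+        // (includeMetricNamesInContentSafetyServicePayload: false); the base
+        // ContentSafetyEvaluator maps the metric names returned by the service back to the
+        // public metric names using the dictionary supplied to the constructor above.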
EvaluationResult result = await EvaluateContentSafetyAsync( + chatClient, messages, modelResponse, contentSafetyServicePayloadFormat: ContentSafetyServicePayloadFormat.HumanSystem.ToString(), + includeMetricNamesInContentSafetyServicePayload: false, cancellationToken: cancellationToken).ConfigureAwait(false); // If images are present in the conversation, do a second evaluation for protected material in images. // The content safety service does not support evaluating both text and images in the same request currently. - if (messages.ContainImage() || modelResponse.ContainsImage()) + if (messages.ContainsImage() || modelResponse.ContainsImage()) { EvaluationResult imageResult = await EvaluateContentSafetyAsync( + chatClient, messages, modelResponse, contentSafetyServicePayloadFormat: ContentSafetyServicePayloadFormat.Conversation.ToString(), + includeMetricNamesInContentSafetyServicePayload: false, cancellationToken: cancellationToken).ConfigureAwait(false); foreach (EvaluationMetric imageMetric in imageResult.Metrics.Values) @@ -102,31 +104,6 @@ await EvaluateContentSafetyAsync( } } - IEnumerable updatedMetrics = - result.Metrics.Values.Select( - metric => - { - switch (metric.Name) - { - case "protected_material": - metric.Name = ProtectedMaterialMetricName; - return metric; - case "artwork": - metric.Name = ProtectedArtworkMetricName; - return metric; - case "fictional_characters": - metric.Name = ProtectedFictionalCharactersMetricName; - return metric; - case "logos_and_brands": - metric.Name = ProtectedLogosAndBrandsMetricName; - return metric; - default: - return metric; - } - }); - - result = new EvaluationResult(updatedMetrics); - result.Interpret(metric => metric is BooleanMetric booleanMetric ? booleanMetric.InterpretScore() : null); return result; } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/SelfHarmEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/SelfHarmEvaluator.cs index 60177b9a1d9..5946bbf0a7b 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/SelfHarmEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/SelfHarmEvaluator.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Generic; + namespace Microsoft.Extensions.AI.Evaluation.Safety; /// @@ -18,17 +20,8 @@ namespace Microsoft.Extensions.AI.Evaluation.Safety; /// currently not supported. /// /// -/// -/// Specifies the Azure AI project that should be used and credentials that should be used when this -/// communicates with the Azure AI Content Safety service to perform -/// evaluations. 
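// Under the simplified pattern used by the content harm evaluators in this change, a
// derived evaluator only supplies a mapping from the content safety service's metric name
// to its public metric name. The evaluator below is hypothetical and is shown only to
// illustrate the shape of that pattern:
public sealed class ExampleHarmEvaluator()
    : ContentHarmEvaluator(
        metricNames: new Dictionary<string, string> { ["example_harm"] = ExampleHarmMetricName })
{
    public static string ExampleHarmMetricName => "Example Harm";
}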
-/// -public sealed class SelfHarmEvaluator(ContentSafetyServiceConfiguration contentSafetyServiceConfiguration) - : ContentHarmEvaluator( - contentSafetyServiceConfiguration, - contentSafetyServiceMetricName: "self_harm", - metricName: SelfHarmMetricName, - evaluatorName: nameof(SelfHarmEvaluator)) +public sealed class SelfHarmEvaluator() + : ContentHarmEvaluator(metricNames: new Dictionary { ["self_harm"] = SelfHarmMetricName }) { /// /// Gets the of the returned by diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/SexualEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/SexualEvaluator.cs index 7e74e012374..bd5445ddd86 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/SexualEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/SexualEvaluator.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Generic; + namespace Microsoft.Extensions.AI.Evaluation.Safety; /// @@ -18,17 +20,8 @@ namespace Microsoft.Extensions.AI.Evaluation.Safety; /// currently not supported. /// /// -/// -/// Specifies the Azure AI project that should be used and credentials that should be used when this -/// communicates with the Azure AI Content Safety service to perform -/// evaluations. -/// -public sealed class SexualEvaluator(ContentSafetyServiceConfiguration contentSafetyServiceConfiguration) - : ContentHarmEvaluator( - contentSafetyServiceConfiguration, - contentSafetyServiceMetricName: "sexual", - metricName: SexualMetricName, - evaluatorName: nameof(SexualEvaluator)) +public sealed class SexualEvaluator() + : ContentHarmEvaluator(metricNames: new Dictionary { ["sexual"] = SexualMetricName }) { /// /// Gets the of the returned by diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/UngroundedAttributesEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/UngroundedAttributesEvaluator.cs index 73b3a2e8d93..79a5deb4888 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/UngroundedAttributesEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/UngroundedAttributesEvaluator.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI.Evaluation.Safety; @@ -34,16 +35,11 @@ namespace Microsoft.Extensions.AI.Evaluation.Safety; /// produce more accurate results than similar evaluations performed using a regular (non-finetuned) model. /// /// -/// -/// Specifies the Azure AI project that should be used and credentials that should be used when this -/// communicates with the Azure AI Content Safety service to perform -/// evaluations. 
-/// -public sealed class UngroundedAttributesEvaluator(ContentSafetyServiceConfiguration contentSafetyServiceConfiguration) +public sealed class UngroundedAttributesEvaluator() : ContentSafetyEvaluator( - contentSafetyServiceConfiguration, contentSafetyServiceAnnotationTask: "inference sensitive attributes", - evaluatorName: nameof(UngroundedAttributesEvaluator)) + metricNames: + new Dictionary { ["inference_sensitive_attributes"] = UngroundedAttributesMetricName }) { /// /// Gets the of the returned by @@ -51,9 +47,6 @@ public sealed class UngroundedAttributesEvaluator(ContentSafetyServiceConfigurat /// public static string UngroundedAttributesMetricName => "Ungrounded Attributes"; - /// - public override IReadOnlyCollection EvaluationMetricNames => [UngroundedAttributesMetricName]; - /// public override async ValueTask EvaluateAsync( IEnumerable messages, @@ -62,43 +55,34 @@ public override async ValueTask EvaluateAsync( IEnumerable? additionalContext = null, CancellationToken cancellationToken = default) { - IEnumerable contexts; - if (additionalContext?.OfType().FirstOrDefault() - is UngroundedAttributesEvaluatorContext context) - { - contexts = [context.GroundingContext]; - } - else - { - throw new InvalidOperationException( - $"A value of type '{nameof(UngroundedAttributesEvaluatorContext)}' was not found in the '{nameof(additionalContext)}' collection."); - } - - const string UngroundedAttributesContentSafetyServiceMetricName = "inference_sensitive_attributes"; + _ = Throw.IfNull(chatConfiguration); + _ = Throw.IfNull(modelResponse); EvaluationResult result = await EvaluateContentSafetyAsync( + chatConfiguration.ChatClient, messages, modelResponse, - contexts, + additionalContext, contentSafetyServicePayloadFormat: ContentSafetyServicePayloadFormat.QueryResponse.ToString(), - contentSafetyServiceMetricName: UngroundedAttributesContentSafetyServiceMetricName, cancellationToken: cancellationToken).ConfigureAwait(false); - IEnumerable updatedMetrics = - result.Metrics.Values.Select( - metric => - { - if (metric.Name == UngroundedAttributesContentSafetyServiceMetricName) - { - metric.Name = UngroundedAttributesMetricName; - } - - return metric; - }); - - result = new EvaluationResult(updatedMetrics); - result.Interpret(metric => metric is BooleanMetric booleanMetric ? booleanMetric.InterpretScore() : null); return result; } + + /// + protected override IReadOnlyList? FilterAdditionalContext( + IEnumerable? additionalContext) + { + if (additionalContext?.OfType().FirstOrDefault() + is UngroundedAttributesEvaluatorContext context) + { + return [context]; + } + else + { + throw new InvalidOperationException( + $"A value of type '{nameof(UngroundedAttributesEvaluatorContext)}' was not found in the '{nameof(additionalContext)}' collection."); + } + } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/UngroundedAttributesEvaluatorContext.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/UngroundedAttributesEvaluatorContext.cs index f9ae1295676..59f553b3150 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/UngroundedAttributesEvaluatorContext.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/UngroundedAttributesEvaluatorContext.cs @@ -6,6 +6,8 @@ // We disable this warning because it is a false positive arising from the analyzer's lack of support for C#'s primary // constructor syntax. 
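// The GetContents overrides added to the context classes in this change each return a
// single TextContent. Returning a list also leaves room for richer, multimodal contexts;
// the context below is hypothetical and is shown only as an illustration, assuming the
// usual Microsoft.Extensions.AI usings and the UriContent type from
// Microsoft.Extensions.AI.Abstractions:
public sealed class ReferenceImageEvaluatorContext(string caption, Uri imageUri) : EvaluationContext
{
    public string Caption { get; } = caption;
    public Uri ImageUri { get; } = imageUri;

    public override IReadOnlyList<AIContent> GetContents()
        => [new TextContent(Caption), new UriContent(ImageUri, "image/png")];
}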
+using System.Collections.Generic; + namespace Microsoft.Extensions.AI.Evaluation.Safety; /// @@ -31,4 +33,8 @@ public sealed class UngroundedAttributesEvaluatorContext(string groundingContext /// whether the response contains information about the protected class or emotional state of a person. /// public string GroundingContext { get; } = groundingContext; + + /// + public override IReadOnlyList GetContents() + => [new TextContent(GroundingContext)]; } diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ViolenceEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ViolenceEvaluator.cs index d80e6a52f1e..99928ff8184 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ViolenceEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ViolenceEvaluator.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Generic; + namespace Microsoft.Extensions.AI.Evaluation.Safety; /// @@ -18,17 +20,8 @@ namespace Microsoft.Extensions.AI.Evaluation.Safety; /// currently not supported. /// /// -/// -/// Specifies the Azure AI project that should be used and credentials that should be used when this -/// communicates with the Azure AI Content Safety service to perform -/// evaluations. -/// -public sealed class ViolenceEvaluator(ContentSafetyServiceConfiguration contentSafetyServiceConfiguration) - : ContentHarmEvaluator( - contentSafetyServiceConfiguration, - contentSafetyServiceMetricName: "violence", - metricName: ViolenceMetricName, - evaluatorName: nameof(ViolenceEvaluator)) +public sealed class ViolenceEvaluator() + : ContentHarmEvaluator(metricNames: new Dictionary { ["violence"] = ViolenceMetricName }) { /// /// Gets the of the returned by diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationContext.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationContext.cs index ca5ccbab4cc..a9342922415 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationContext.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationContext.cs @@ -1,10 +1,37 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Generic; + namespace Microsoft.Extensions.AI.Evaluation; /// /// A base class that represents additional contextual information (beyond that which is available in the conversation /// history) that an may need to accurately evaluate a supplied response. /// -public class EvaluationContext; +#pragma warning disable S1694 // An abstract class should have both abstract and concrete methods +public abstract class EvaluationContext +#pragma warning restore S1694 +{ + /// + /// Returns a list of objects that include all the information present in this + /// . + /// + /// + /// + /// This function allows us to decompose the information present in an into + /// objects for text, or objects for + /// images, and other similar objects for other modalities such as audio and video in the + /// future. + /// + /// + /// For simple s that only contain text, this function can return a single + /// object that includes the contained text. + /// + /// + /// + /// A list of objects that include all the information present in this + /// . 
+ /// + public abstract IReadOnlyList GetContents(); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/SafetyEvaluatorTests.cs b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/SafetyEvaluatorTests.cs index ed8a04a2bdd..270c091ecb2 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/SafetyEvaluatorTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/SafetyEvaluatorTests.cs @@ -33,8 +33,8 @@ static SafetyEvaluatorTests() ResponseFormat = ChatResponseFormat.Text }; - ChatConfiguration chatConfiguration = Setup.CreateChatConfiguration(); - ChatClientMetadata? clientMetadata = chatConfiguration.ChatClient.GetService(); + ChatConfiguration llmChatConfiguration = Setup.CreateChatConfiguration(); + ChatClientMetadata? clientMetadata = llmChatConfiguration.ChatClient.GetService(); string version = $"Product Version: {Constants.Version}"; string date = $"Date: {DateTime.UtcNow:dddd, dd MMMM yyyy}"; @@ -53,14 +53,17 @@ static SafetyEvaluatorTests() resourceGroupName: Settings.Current.AzureResourceGroupName, projectName: Settings.Current.AzureAIProjectName); - IEvaluator hateAndUnfairnessEvaluator = new HateAndUnfairnessEvaluator(contentSafetyServiceConfiguration); - IEvaluator selfHarmEvaluator = new SelfHarmEvaluator(contentSafetyServiceConfiguration); - IEvaluator sexualEvaluator = new SexualEvaluator(contentSafetyServiceConfiguration); - IEvaluator violenceEvaluator = new ViolenceEvaluator(contentSafetyServiceConfiguration); - IEvaluator protectedMaterialEvaluator = new ProtectedMaterialEvaluator(contentSafetyServiceConfiguration); - IEvaluator groundednessProEvaluator = new GroundednessProEvaluator(contentSafetyServiceConfiguration); - IEvaluator ungroundedAttributesEvaluator = new UngroundedAttributesEvaluator(contentSafetyServiceConfiguration); - IEvaluator indirectAttackEvaluator = new IndirectAttackEvaluator(contentSafetyServiceConfiguration); + ChatConfiguration contentSafetyChatConfiguration = + contentSafetyServiceConfiguration.ToChatConfiguration(llmChatConfiguration); + + IEvaluator hateAndUnfairnessEvaluator = new HateAndUnfairnessEvaluator(); + IEvaluator selfHarmEvaluator = new SelfHarmEvaluator(); + IEvaluator sexualEvaluator = new SexualEvaluator(); + IEvaluator violenceEvaluator = new ViolenceEvaluator(); + IEvaluator protectedMaterialEvaluator = new ProtectedMaterialEvaluator(); + IEvaluator groundednessProEvaluator = new GroundednessProEvaluator(); + IEvaluator ungroundedAttributesEvaluator = new UngroundedAttributesEvaluator(); + IEvaluator indirectAttackEvaluator = new IndirectAttackEvaluator(); _contentSafetyReportingConfiguration = DiskBasedReportingConfiguration.Create( @@ -72,10 +75,13 @@ static SafetyEvaluatorTests() groundednessProEvaluator, ungroundedAttributesEvaluator, indirectAttackEvaluator], - chatConfiguration: chatConfiguration, + chatConfiguration: contentSafetyChatConfiguration, executionName: Constants.Version, tags: [version, date, projectName, testClass, provider, model, temperature, usesContext]); + ChatConfiguration contentSafetyChatConfigurationWithoutLLM = + contentSafetyServiceConfiguration.ToChatConfiguration(); + _imageContentSafetyReportingConfiguration = DiskBasedReportingConfiguration.Create( storageRootPath: Settings.Current.StorageRootPath, @@ -84,15 +90,17 @@ static SafetyEvaluatorTests() violenceEvaluator, protectedMaterialEvaluator, indirectAttackEvaluator], + chatConfiguration: contentSafetyChatConfigurationWithoutLLM, 
executionName: Constants.Version, tags: [version, date, projectName, testClass, provider, model, temperature]); - IEvaluator codeVulnerabilityEvaluator = new CodeVulnerabilityEvaluator(contentSafetyServiceConfiguration); + IEvaluator codeVulnerabilityEvaluator = new CodeVulnerabilityEvaluator(); _codeVulnerabilityReportingConfiguration = DiskBasedReportingConfiguration.Create( storageRootPath: Settings.Current.StorageRootPath, evaluators: [codeVulnerabilityEvaluator], + chatConfiguration: contentSafetyChatConfigurationWithoutLLM, executionName: Constants.Version, tags: [version, date, projectName, testClass, provider, model, temperature]); } @@ -130,13 +138,14 @@ Mars is approximately 34 million miles from Earth at its closest approach (oppos At its furthest point (conjunction), Mars is about 250 million miles from Earth. The distance varies due to the elliptical orbits of both planets. """; - IEnumerable contexts = + + IEnumerable additionalContext = [ new GroundednessProEvaluatorContext(groundingContext), new UngroundedAttributesEvaluatorContext(groundingContext) ]; - EvaluationResult result = await scenarioRun.EvaluateAsync(messages, response, contexts); + EvaluationResult result = await scenarioRun.EvaluateAsync(messages, response, additionalContext); Assert.False( result.ContainsDiagnostics(d => d.Severity is EvaluationDiagnosticSeverity.Error), @@ -175,6 +184,9 @@ Keep your responses concise staying under 100 words as much as possible. ChatResponse response2 = await chatClient.GetResponseAsync(messages, _chatOptions); + // At the moment, the GroundednessProEvaluator only supports evaluating the last turn of the conversation. We + // include context that is relevant to both turns as part of the string above. However, only the included + // context relevant to the last (second) turn matters for the evaluation. string groundingContext = """ Mercury's distance from Earth varies due to their elliptical orbits. @@ -186,15 +198,13 @@ At its closest (opposition), Jupiter is about 365 million miles away. At its furthest (conjunction), it can be approximately 601 million miles away. """; - // At the moment, the GroundednessProEvaluator only supports evaluating the last turn of the conversation. - // We include context for the first turn below, however, this is essentially redundant at the moment. - IEnumerable contexts = + IEnumerable additionalContext = [ new GroundednessProEvaluatorContext(groundingContext), new UngroundedAttributesEvaluatorContext(groundingContext) ]; - EvaluationResult result = await scenarioRun.EvaluateAsync(messages, response2, contexts); + EvaluationResult result = await scenarioRun.EvaluateAsync(messages, response2, additionalContext); Assert.False( result.ContainsDiagnostics(d => d.Severity is EvaluationDiagnosticSeverity.Error),