diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs index c78b495393b..dd4c0e28c71 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs @@ -155,21 +155,22 @@ void IDisposable.Dispose() foreach (var content in input.Contents) { - if (content is FunctionCallContent callRequest) + switch (content) { - message.ToolCalls.Add( - ChatToolCall.CreateFunctionToolCall( - callRequest.CallId, - callRequest.Name, - new(JsonSerializer.SerializeToUtf8Bytes( - callRequest.Arguments, - options.GetTypeInfo(typeof(IDictionary)))))); - } - } + case ErrorContent errorContent when errorContent.ErrorCode is nameof(message.Refusal): + message.Refusal = errorContent.Message; + break; - if (input.AdditionalProperties?.TryGetValue(nameof(message.Refusal), out string? refusal) is true) - { - message.Refusal = refusal; + case FunctionCallContent callRequest: + message.ToolCalls.Add( + ChatToolCall.CreateFunctionToolCall( + callRequest.CallId, + callRequest.Name, + new(JsonSerializer.SerializeToUtf8Bytes( + callRequest.Arguments, + options.GetTypeInfo(typeof(IDictionary)))))); + break; + } } yield return message; @@ -370,7 +371,7 @@ private static async IAsyncEnumerable FromOpenAIStreamingCha // add it to this function calling item. if (refusal is not null) { - (responseUpdate.AdditionalProperties ??= [])[nameof(ChatMessageContentPart.Refusal)] = refusal.ToString(); + responseUpdate.Contents.Add(new ErrorContent(refusal.ToString()) { ErrorCode = "Refusal" }); } // Propagate additional relevant metadata. @@ -450,6 +451,12 @@ private static ChatResponse FromOpenAIChatCompletion(ChatCompletion openAIComple } } + // And add error content for any refusals, which represent errors in generating output that conforms to a provided schema. + if (openAICompletion.Refusal is string refusal) + { + returnMessage.Contents.Add(new ErrorContent(refusal) { ErrorCode = nameof(openAICompletion.Refusal) }); + } + // Wrap the content in a ChatResponse to return. var response = new ChatResponse(returnMessage) { @@ -470,11 +477,6 @@ private static ChatResponse FromOpenAIChatCompletion(ChatCompletion openAIComple (response.AdditionalProperties ??= [])[nameof(openAICompletion.ContentTokenLogProbabilities)] = contentTokenLogProbs; } - if (openAICompletion.Refusal is string refusal) - { - (response.AdditionalProperties ??= [])[nameof(openAICompletion.Refusal)] = refusal; - } - if (openAICompletion.RefusalTokenLogProbabilities is { Count: > 0 } refusalTokenLogProbs) { (response.AdditionalProperties ??= [])[nameof(openAICompletion.RefusalTokenLogProbabilities)] = refusalTokenLogProbs; diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIResponseChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIResponseChatClient.cs index 92224379c10..0c6af51bb81 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIResponseChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIResponseChatClient.cs @@ -263,6 +263,7 @@ public async IAsyncEnumerable GetStreamingResponseAsync( MessageId = lastMessageId, ModelId = modelId, ResponseId = responseId, + Role = lastRole, ConversationId = responseId, Contents = [ @@ -274,6 +275,19 @@ public async IAsyncEnumerable GetStreamingResponseAsync( ], }; break; + + case StreamingResponseRefusalDoneUpdate refusalDone: + yield return new ChatResponseUpdate + { + CreatedAt = createdAt, + MessageId = lastMessageId, + ModelId = modelId, + ResponseId = responseId, + Role = lastRole, + ConversationId = responseId, + Contents = [new ErrorContent(refusalDone.Refusal) { ErrorCode = nameof(ResponseContentPart.Refusal) }], + }; + break; } } } @@ -539,9 +553,15 @@ private static List ToAIContents(IEnumerable con foreach (ResponseContentPart part in contents) { - if (part.Kind == ResponseContentPartKind.OutputText) + switch (part.Kind) { - results.Add(new TextContent(part.Text)); + case ResponseContentPartKind.OutputText: + results.Add(new TextContent(part.Text)); + break; + + case ResponseContentPartKind.Refusal: + results.Add(new ErrorContent(part.Refusal) { ErrorCode = nameof(ResponseContentPartKind.Refusal) }); + break; } } @@ -572,6 +592,10 @@ private static List ToOpenAIResponsesContent(IList