Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,33 @@ public sealed record class AIJsonSchemaCreateOptions
public Func<ParameterInfo, bool>? IncludeParameter { get; init; }

/// <summary>
/// Gets a value indicating whether to include the type keyword in inferred schemas for .NET enums.
/// Gets a <see cref="AIJsonSchemaTransformOptions"/> governing transformations on the JSON schema after it has been generated.
/// </summary>
public AIJsonSchemaTransformOptions? TransformOptions { get; init; }

/// <summary>
/// Gets a value indicating whether to include the type keyword in created schemas for .NET enums.
/// </summary>
[Obsolete("This property has been deprecated.")]
[System.ComponentModel.EditorBrowsable(System.ComponentModel.EditorBrowsableState.Never)]
public bool IncludeTypeInEnumSchemas { get; init; } = true;

/// <summary>
/// Gets a value indicating whether to generate schemas with the additionalProperties set to false for .NET objects.
/// </summary>
public bool DisallowAdditionalProperties { get; init; } = true;
[Obsolete("This property has been deprecated. Use the equivalent property in TransformOptions instead.")]
[System.ComponentModel.EditorBrowsable(System.ComponentModel.EditorBrowsableState.Never)]
public bool DisallowAdditionalProperties { get; init; }

/// <summary>
/// Gets a value indicating whether to include the $schema keyword in inferred schemas.
/// Gets a value indicating whether to include the $schema keyword in created schemas.
/// </summary>
public bool IncludeSchemaKeyword { get; init; }

/// <summary>
/// Gets a value indicating whether to mark all properties as required in the schema.
/// </summary>
[Obsolete("This property has been deprecated. Use the equivalent property in TransformOptions instead.")]
[System.ComponentModel.EditorBrowsable(System.ComponentModel.EditorBrowsableState.Never)]
public bool RequireAllProperties { get; init; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ public sealed record class AIJsonSchemaTransformOptions
/// </summary>
public bool UseNullableKeyword { get; init; }

/// <summary>
/// Gets a value indicating whether to move the default keyword to the description field in the schema.
/// </summary>
public bool MoveDefaultKeywordToDescription { get; init; }

/// <summary>
/// Gets the default options instance.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,11 @@ private static JsonSerializerOptions CreateDefaultOptions()
[JsonSerializable(typeof(AIFunctionArguments))]
[EditorBrowsable(EditorBrowsableState.Never)] // Never use JsonContext directly, use DefaultOptions instead.
private sealed partial class JsonContext : JsonSerializerContext;

[JsonSourceGenerationOptions(JsonSerializerDefaults.Web,
UseStringEnumConverter = true,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = false)]
[JsonSerializable(typeof(JsonNode))]
private sealed partial class JsonContextNoIndentation : JsonSerializerContext;
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System.ComponentModel;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text.Json;
Expand Down Expand Up @@ -40,7 +39,7 @@ public static partial class AIJsonUtilities
private const string DefaultPropertyName = "default";
private const string RefPropertyName = "$ref";

/// <summary>The uri used when populating the $schema keyword in inferred schemas.</summary>
/// <summary>The uri used when populating the $schema keyword in created schemas.</summary>
private const string SchemaKeywordUri = "https://json-schema.org/draft/2020-12/schema";

// List of keywords used by JsonSchemaExporter but explicitly disallowed by some AI vendors.
Expand All @@ -54,7 +53,7 @@ public static partial class AIJsonUtilities
/// <param name="title">The title keyword used by the method schema.</param>
/// <param name="description">The description keyword used by the method schema.</param>
/// <param name="serializerOptions">The options used to extract the schema from the specified type.</param>
/// <param name="inferenceOptions">The options controlling schema inference.</param>
/// <param name="inferenceOptions">The options controlling schema creation.</param>
/// <returns>A JSON schema document encoded as a <see cref="JsonElement"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="method"/> is <see langword="null"/>.</exception>
public static JsonElement CreateFunctionJsonSchema(
Expand Down Expand Up @@ -106,13 +105,13 @@ public static JsonElement CreateFunctionJsonSchema(
inferenceOptions);

parameterSchemas.Add(parameter.Name, parameterSchema);
if (!parameter.IsOptional || inferenceOptions.RequireAllProperties)
if (!parameter.IsOptional)
{
(requiredProperties ??= []).Add((JsonNode)parameter.Name);
}
}

JsonObject schema = new();
JsonNode schema = new JsonObject();
if (inferenceOptions.IncludeSchemaKeyword)
{
schema[SchemaPropertyName] = SchemaKeywordUri;
Expand All @@ -136,7 +135,13 @@ public static JsonElement CreateFunctionJsonSchema(
schema[RequiredPropertyName] = requiredProperties;
}

return JsonSerializer.SerializeToElement(schema, JsonContext.Default.JsonNode);
// Finally, apply any schema transformations if specified.
if (inferenceOptions.TransformOptions is { } options)
{
schema = TransformSchema(schema, options);
}

return JsonSerializer.SerializeToElement(schema, JsonContextNoIndentation.Default.JsonNode);
}

/// <summary>Creates a JSON schema for the specified type.</summary>
Expand All @@ -145,7 +150,7 @@ public static JsonElement CreateFunctionJsonSchema(
/// <param name="hasDefaultValue"><see langword="true"/> if the parameter is optional; otherwise, <see langword="false"/>.</param>
/// <param name="defaultValue">The default value of the optional parameter, if applicable.</param>
/// <param name="serializerOptions">The options used to extract the schema from the specified type.</param>
/// <param name="inferenceOptions">The options controlling schema inference.</param>
/// <param name="inferenceOptions">The options controlling schema creation.</param>
/// <returns>A <see cref="JsonElement"/> representing the schema.</returns>
public static JsonElement CreateJsonSchema(
Type? type,
Expand All @@ -158,7 +163,14 @@ public static JsonElement CreateJsonSchema(
serializerOptions ??= DefaultOptions;
inferenceOptions ??= AIJsonSchemaCreateOptions.Default;
JsonNode schema = CreateJsonSchemaCore(type, parameterName: null, description, hasDefaultValue, defaultValue, serializerOptions, inferenceOptions);
return JsonSerializer.SerializeToElement(schema, JsonContext.Default.JsonNode);

// Finally, apply any schema transformations if specified.
if (inferenceOptions.TransformOptions is { } options)
{
schema = TransformSchema(schema, options);
}

return JsonSerializer.SerializeToElement(schema, JsonContextNoIndentation.Default.JsonNode);
}

/// <summary>Gets the default JSON schema to be used by types or functions.</summary>
Expand Down Expand Up @@ -203,25 +215,11 @@ private static JsonNode CreateJsonSchemaCore(

if (hasDefaultValue)
{
if (inferenceOptions.RequireAllProperties)
{
// Default values are only used in the context of optional parameters.
// Do not include a default keyword (since certain AI vendors don't support it)
// and instead embed its JSON in the description as a hint to the LLM.
string defaultValueJson = defaultValue is not null
? JsonSerializer.Serialize(defaultValue, serializerOptions.GetTypeInfo(defaultValue.GetType()))
: "null";

description = CreateDescriptionWithDefaultValue(description, defaultValueJson);
}
else
{
JsonNode? defaultValueNode = defaultValue is not null
? JsonSerializer.SerializeToNode(defaultValue, serializerOptions.GetTypeInfo(defaultValue.GetType()))
: null;
JsonNode? defaultValueNode = defaultValue is not null
? JsonSerializer.SerializeToNode(defaultValue, serializerOptions.GetTypeInfo(defaultValue.GetType()))
: null;

(schemaObj ??= [])[DefaultPropertyName] = defaultValueNode;
}
(schemaObj ??= [])[DefaultPropertyName] = defaultValueNode;
}

if (description is not null)
Expand Down Expand Up @@ -271,41 +269,11 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext schemaExporterContext, Js
}

// Include the type keyword in enum types
if (inferenceOptions.IncludeTypeInEnumSchemas && ctx.TypeInfo.Type.IsEnum && objSchema.ContainsKey(EnumPropertyName) && !objSchema.ContainsKey(TypePropertyName))
if (ctx.TypeInfo.Type.IsEnum && objSchema.ContainsKey(EnumPropertyName) && !objSchema.ContainsKey(TypePropertyName))
{
objSchema.InsertAtStart(TypePropertyName, "string");
}

// Disallow additional properties in object schemas
if (inferenceOptions.DisallowAdditionalProperties &&
objSchema.ContainsKey(PropertiesPropertyName) &&
!objSchema.ContainsKey(AdditionalPropertiesPropertyName))
{
objSchema.Add(AdditionalPropertiesPropertyName, (JsonNode)false);
}

// Mark all properties as required
if (inferenceOptions.RequireAllProperties &&
objSchema.TryGetPropertyValue(PropertiesPropertyName, out JsonNode? properties) &&
properties is JsonObject propertiesObj)
{
_ = objSchema.TryGetPropertyValue(RequiredPropertyName, out JsonNode? required);
if (required is not JsonArray { } requiredArray || requiredArray.Count != propertiesObj.Count)
{
requiredArray = [.. propertiesObj.Select(prop => (JsonNode)prop.Key)];
objSchema[RequiredPropertyName] = requiredArray;
}
}

// Strip default keywords and embed in description where required
if (inferenceOptions.RequireAllProperties &&
objSchema.TryGetPropertyValue(DefaultPropertyName, out JsonNode? defaultValue))
{
_ = objSchema.Remove(DefaultPropertyName);
string defaultValueJson = defaultValue?.ToJsonString() ?? "null";
localDescription = CreateDescriptionWithDefaultValue(localDescription, defaultValueJson);
}

// Filter potentially disallowed keywords.
foreach (string keyword in _schemaKeywordsDisallowedByAIVendors)
{
Expand All @@ -328,20 +296,8 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext schemaExporterContext, Js

if (ctx.Path.IsEmpty && hasDefaultValue)
{
// Add root-level default value metadata
if (inferenceOptions.RequireAllProperties)
{
// Default values are only used in the context of optional parameters.
// Do not include a default keyword (since certain AI vendors don't support it)
// and instead embed its JSON in the description as a hint to the LLM.
string defaultValueJson = JsonSerializer.Serialize(defaultValue, ctx.TypeInfo);
localDescription = CreateDescriptionWithDefaultValue(localDescription, defaultValueJson);
}
else
{
JsonNode? defaultValueNode = JsonSerializer.SerializeToNode(defaultValue, ctx.TypeInfo);
ConvertSchemaToObject(ref schema)[DefaultPropertyName] = defaultValueNode;
}
JsonNode? defaultValueNode = JsonSerializer.SerializeToNode(defaultValue, ctx.TypeInfo);
ConvertSchemaToObject(ref schema)[DefaultPropertyName] = defaultValueNode;
}

if (localDescription is not null)
Expand Down Expand Up @@ -423,7 +379,7 @@ private static void InsertAtStart(this JsonObject jsonObject, string key, JsonNo
jsonObject.Insert(0, key, value);
#else
jsonObject.Remove(key);
var copiedEntries = jsonObject.ToArray();
var copiedEntries = System.Linq.Enumerable.ToArray(jsonObject);
jsonObject.Clear();

jsonObject.Add(key, value);
Expand All @@ -434,13 +390,6 @@ private static void InsertAtStart(this JsonObject jsonObject, string key, JsonNo
#endif
}

private static string CreateDescriptionWithDefaultValue(string? existingDescription, string defaultValueJson)
{
return existingDescription is null
? $"Default value: {defaultValueJson}"
: $"{existingDescription} (Default value: {defaultValueJson})";
}

private static JsonElement ParseJsonElement(ReadOnlySpan<byte> utf8Json)
{
Utf8JsonReader reader = new(utf8Json);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,14 @@ public static JsonElement TransformSchema(JsonElement schema, AIJsonSchemaTransf
}

JsonNode? nodeSchema = JsonSerializer.SerializeToNode(schema, JsonContext.Default.JsonElement);
JsonNode transformedSchema = TransformSchema(nodeSchema, transformOptions);
return JsonSerializer.SerializeToElement(transformedSchema, JsonContextNoIndentation.Default.JsonNode);
}

private static JsonNode TransformSchema(JsonNode? schema, AIJsonSchemaTransformOptions transformOptions)
{
List<string>? path = transformOptions.TransformSchemaNode is not null ? [] : null;
JsonNode transformedSchema = TransformSchemaCore(nodeSchema, transformOptions, path);
return JsonSerializer.Deserialize(transformedSchema, JsonContext.Default.JsonElement);
return TransformSchemaCore(schema, transformOptions, path);
}

private static JsonNode TransformSchemaCore(JsonNode? schema, AIJsonSchemaTransformOptions transformOptions, List<string>? path)
Expand Down Expand Up @@ -169,6 +174,18 @@ private static JsonNode TransformSchemaCore(JsonNode? schema, AIJsonSchemaTransf
}
}

if (transformOptions.MoveDefaultKeywordToDescription &&
schemaObj.TryGetPropertyValue(DefaultPropertyName, out JsonNode? defaultSchema))
{
string? description = schemaObj.TryGetPropertyValue(DescriptionPropertyName, out JsonNode? descriptionSchema) ? descriptionSchema?.GetValue<string>() : null;
string defaultValueJson = JsonSerializer.Serialize(defaultSchema, JsonContextNoIndentation.Default.JsonNode!);
description = description is null
? $"Default value: {defaultValueJson}"
: $"{description} (Default value: {defaultValueJson})";
schemaObj[DescriptionPropertyName] = description;
_ = schemaObj.Remove(DefaultPropertyName);
}

break;

default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ internal sealed class AzureAIInferenceChatClient : IChatClient
{
RequireAllProperties = true,
DisallowAdditionalProperties = true,
ConvertBooleanSchemas = true
ConvertBooleanSchemas = true,
MoveDefaultKeywordToDescription = true,
});

/// <summary>Metadata about the client.</summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ internal sealed partial class OpenAIChatClient : IChatClient
{
RequireAllProperties = true,
DisallowAdditionalProperties = true,
ConvertBooleanSchemas = true
ConvertBooleanSchemas = true,
MoveDefaultKeywordToDescription = true,
});

/// <summary>Gets the default OpenAI endpoint.</summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@ public static partial class ChatClientStructuredOutputExtensions
private static readonly AIJsonSchemaCreateOptions _inferenceOptions = new()
{
IncludeSchemaKeyword = true,
DisallowAdditionalProperties = true,
IncludeTypeInEnumSchemas = true,
RequireAllProperties = true,
TransformOptions = new AIJsonSchemaTransformOptions
{
DisallowAdditionalProperties = true,
RequireAllProperties = true,
MoveDefaultKeywordToDescription = true,
},
};

/// <summary>Sends chat messages, requesting a response matching the type <typeparamref name="T"/>.</summary>
Expand Down
Loading
Loading