Skip to content

Commit

Permalink
finally complete, but could be better
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenHodgson committed Nov 7, 2023
1 parent 208cc9f commit e8c70ac
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 25 deletions.
10 changes: 6 additions & 4 deletions OpenAI-DotNet-Tests/TestFixture_09_FineTuning.cs
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,10 @@ public async Task Test_01_CreateFineTuneJob()
Assert.IsNotNull(OpenAIClient.FineTuningEndpoint);
var fileData = await CreateTestTrainingDataAsync();
var request = new CreateFineTuneJobRequest(Model.GPT3_5_Turbo, fileData);
var fineTuneResponse = await OpenAIClient.FineTuningEndpoint.CreateJobAsync(request);
var job = await OpenAIClient.FineTuningEndpoint.CreateJobAsync(request);

Assert.IsNotNull(fineTuneResponse);
var result = await OpenAIClient.FilesEndpoint.DeleteFileAsync(fileData);
Assert.IsTrue(result);
Assert.IsNotNull(job);
Console.WriteLine($"Started {job.Id} | Status: {job.Status}");
}

[Test]
Expand Down Expand Up @@ -177,6 +176,9 @@ public async Task Test_05_CancelFineTuneJob()
Assert.IsNotNull(result);
Assert.IsTrue(result);
Console.WriteLine($"{job.Id} -> cancelled");
result = await OpenAIClient.FilesEndpoint.DeleteFileAsync(job.TrainingFile);
Assert.IsTrue(result);
Console.WriteLine($"{job.TrainingFile} -> deleted");
}
}
}
Expand Down
125 changes: 125 additions & 0 deletions OpenAI-DotNet/Extensions/CustomEnumConverter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text.Encodings.Web;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace OpenAI.Extensions
{
/// <summary>
/// Serializes enum values of type <typeparamref name="T"/> as JSON strings whose text is
/// produced by a <see cref="JsonNamingPolicy"/>, and deserializes such strings back to
/// enum values.
/// </summary>
/// <remarks>
/// Formatted names are cached in both directions. The cache is bounded by
/// <see cref="NameCacheLimit"/>; values beyond the bound are still written (they are simply
/// formatted on every call), but <see cref="Read"/> only resolves names that are present
/// in the cache.
/// </remarks>
/// <typeparam name="T">The enum type handled by this converter.</typeparam>
internal sealed class CustomEnumConverter<T> : JsonConverter<T> where T : Enum
{
    // Upper bound on the number of cached names. It keeps memory bounded when many distinct
    // flag combinations are serialized; it must not limit which values can be written.
    private const int NameCacheLimit = 64;

    // Separator that Enum.ToString() places between the names of a flags combination.
    private const string ValueSeparator = ", ";

    private readonly JsonNamingPolicy namingPolicy;

    // Concurrent collections: Write() may add entries after construction, and a converter
    // instance can be reached from several threads through the options that own it.
    private readonly ConcurrentDictionary<string, T> readCache = new();
    private readonly ConcurrentDictionary<T, JsonEncodedText> writeCache = new();

    /// <summary>
    /// Creates the converter and pre-populates the name caches.
    /// </summary>
    /// <param name="namingPolicy">Policy applied to each enum member name.</param>
    /// <param name="options">Serializer options; its encoder is used to pre-encode names.</param>
    /// <param name="knownValues">
    /// Optional boxed <typeparamref name="T"/> values (for example flag combinations) that are
    /// cached before the declared members. May be <c>null</c>.
    /// </param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="namingPolicy"/> or <paramref name="options"/> is <c>null</c>.
    /// </exception>
    public CustomEnumConverter(JsonNamingPolicy namingPolicy, JsonSerializerOptions options, object[] knownValues)
    {
        this.namingPolicy = namingPolicy ?? throw new ArgumentNullException(nameof(namingPolicy));

        if (options == null)
        {
            throw new ArgumentNullException(nameof(options));
        }

        var encoder = options.Encoder;

        // Caller-supplied values go first so they are not crowded out by declared members.
        if (knownValues != null)
        {
            foreach (var known in knownValues)
            {
                if (!TryCache((T)known))
                {
                    return;
                }
            }
        }

        foreach (T declared in Enum.GetValues(typeof(T)))
        {
            if (!TryCache(declared))
            {
                return;
            }
        }

        // Adds the value to both caches unless the cache bound has been reached.
        bool TryCache(T value)
        {
            if (readCache.Count >= NameCacheLimit)
            {
                return false;
            }

            FormatAndAddToCaches(value, encoder);
            return true;
        }
    }

    /// <summary>
    /// Reads a JSON string and maps it to the cached enum value with that formatted name.
    /// </summary>
    /// <exception cref="JsonException">
    /// The current token is not a string, or the string is not a cached name.
    /// </exception>
    public override T Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
    {
        if (reader.TokenType != JsonTokenType.String)
        {
            throw new JsonException($"Expected a JSON string for {typeof(T).Name} but found token {reader.TokenType}.");
        }

        var json = reader.GetString();

        if (json == null || !readCache.TryGetValue(json, out var value))
        {
            throw new JsonException($"\"{json}\" is not a recognized {typeof(T).Name} value.");
        }

        return value;
    }

    /// <summary>
    /// Writes <paramref name="value"/> as a JSON string using the naming policy.
    /// </summary>
    public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options)
    {
        if (writeCache.TryGetValue(value, out var cached))
        {
            writer.WriteStringValue(cached);
            return;
        }

        var (name, encoded) = FormatEnumValue(value.ToString(), namingPolicy, options.Encoder);

        // Only cache while under the bound; a full cache must never prevent serialization.
        if (writeCache.Count < NameCacheLimit)
        {
            readCache[name] = value;
            writeCache[value] = encoded;
        }

        writer.WriteStringValue(encoded);
    }

    /// <summary>
    /// Formats <paramref name="value"/> and stores it in both caches.
    /// </summary>
    /// <returns>The pre-encoded JSON text for the value.</returns>
    private JsonEncodedText FormatAndAddToCaches(T value, JavaScriptEncoder encoder)
    {
        var (name, encoded) = FormatEnumValue(value.ToString(), namingPolicy, encoder);
        readCache[name] = value;
        writeCache[value] = encoded;
        return encoded;
    }

    /// <summary>
    /// Applies the naming policy to an enum's string form. A flags combination
    /// ("A, B") is converted name by name and re-joined with the same separator.
    /// </summary>
    /// <returns>The converted name and its pre-encoded JSON form.</returns>
    private static ValueTuple<string, JsonEncodedText> FormatEnumValue(string value, JsonNamingPolicy namingPolicy, JavaScriptEncoder encoder)
    {
        string converted;

        if (!value.Contains(ValueSeparator))
        {
            converted = namingPolicy.ConvertName(value);
        }
        else
        {
            // todo: optimize implementation here by leveraging https://github.com/dotnet/runtime/issues/934.
            var enumValues = value.Split(ValueSeparator);

            for (var i = 0; i < enumValues.Length; i++)
            {
                enumValues[i] = namingPolicy.ConvertName(enumValues[i]);
            }

            converted = string.Join(ValueSeparator, enumValues);
        }

        return (converted, JsonEncodedText.Encode(converted, encoder));
    }
}
}
29 changes: 29 additions & 0 deletions OpenAI-DotNet/Extensions/CustomEnumConverterFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using System;
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace OpenAI.Extensions
{
/// <summary>
/// Converter factory that handles every enum type by creating a closed
/// <c>CustomEnumConverter&lt;T&gt;</c> configured with a <c>SnakeCaseNamingPolicy</c>.
/// </summary>
internal sealed class CustomEnumConverterFactory : JsonConverterFactory
{
    /// <summary>
    /// Reports whether <paramref name="typeToConvert"/> is an enum type.
    /// </summary>
    public override bool CanConvert(Type typeToConvert)
    {
        return typeToConvert.IsEnum;
    }

    /// <summary>
    /// Builds the converter for <paramref name="typeToConvert"/> via reflection, because the
    /// generic argument is only known at run time.
    /// </summary>
    public override JsonConverter CreateConverter(Type typeToConvert, JsonSerializerOptions options)
    {
        // BindingFlags is given one extra flag combination to hand to the converter;
        // every other enum passes null.
        var seedValues = typeToConvert == typeof(BindingFlags)
            ? new object[] { BindingFlags.CreateInstance | BindingFlags.DeclaredOnly }
            : null;

        var converterType = typeof(CustomEnumConverter<>).MakeGenericType(typeToConvert);
        var constructorArguments = new object[] { new SnakeCaseNamingPolicy(), options, seedValues };

        var instance = Activator.CreateInstance(
            converterType,
            BindingFlags.Instance | BindingFlags.Public,
            binder: null,
            args: constructorArguments,
            culture: null);

        return (JsonConverter)instance!;
    }
}
}
8 changes: 1 addition & 7 deletions OpenAI-DotNet/Extensions/SnakeCaseNamingPolicy.cs
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
using System.Text.Json;
using System.Text.Json.Serialization;

namespace OpenAI.Extensions
{
public class SnakeCaseNamingPolicy : JsonNamingPolicy
internal sealed class SnakeCaseNamingPolicy : JsonNamingPolicy
{
public override string ConvertName(string name)
=> StringExtensions.ToSnakeCase(name);
}

public class JsonStringSnakeEnumConverter : JsonStringEnumConverter
{
public JsonStringSnakeEnumConverter() : base(new SnakeCaseNamingPolicy()) { }
}
}
14 changes: 1 addition & 13 deletions OpenAI-DotNet/FineTuning/FineTuneJob.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,7 @@ public sealed class FineTuneJob

[JsonInclude]
[JsonPropertyName("status")]
public string JobStatus { get; private set; }

[JsonIgnore]
public JobStatus Status => JobStatus switch
{
"validating_files" => FineTuning.JobStatus.ValidatingFiles,
"queued" => FineTuning.JobStatus.Queued,
"running" => FineTuning.JobStatus.Running,
"succeeded" => FineTuning.JobStatus.Succeeded,
"cancelled" => FineTuning.JobStatus.Cancelled,
"failed" => FineTuning.JobStatus.Failed,
_ => throw new ArgumentOutOfRangeException($"{nameof(JobStatus)}: {JobStatus}")
};
public JobStatus Status { get; private set; }

[JsonInclude]
[JsonPropertyName("validation_file")]
Expand Down
2 changes: 1 addition & 1 deletion OpenAI-DotNet/OpenAIClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ private HttpClient SetupClient(HttpClient client = null)
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
Converters =
{
new JsonStringEnumConverter(new SnakeCaseNamingPolicy())
new CustomEnumConverterFactory()
}
};

Expand Down

0 comments on commit e8c70ac

Please sign in to comment.