Added support for the new function calling capability in Chat Completions API #300

Merged · 7 commits · Jun 18, 2023
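In short, this PR lets a chat request declare callable functions and surfaces the model's function_call in the response. A minimal sketch of the new surface, pieced together from the diffs below (run inside the playground where sdk is an IOpenAIService; error and Successful checks omitted):

var fn = new FunctionDefinitionBuilder("get_current_weather", "Get the current weather")
    .AddParameter("location", "string", "The city and state, e.g. San Francisco, CA")
    .Build();

var response = await sdk.ChatCompletion.CreateCompletion(new ChatCompletionCreateRequest
{
    Messages = new List<ChatMessage> { ChatMessage.FromUser("What's the weather like in Boston?") },
    Functions = new List<FunctionDefinition> { fn },
    Model = Models.ChatGpt3_5Turbo_0613
});

// When the model decides to call a function, FunctionCall is set (and Content is typically null).
var message = response.Choices.First().Message;
if (message.FunctionCall != null)
{
    Console.WriteLine($"{message.FunctionCall.Name}({message.FunctionCall.Arguments})");
}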
6 changes: 4 additions & 2 deletions OpenAI.Playground/Program.cs
@@ -41,12 +41,14 @@
// | /|\ | /\ ___\o \o | o/ o/__ /\ | /|\ |
// | / \ / \ | \ /) | ( \ /o\ / ) | (\ / | / \ / \ |
// |-----------------------------------------------------------------------|

await ChatCompletionTestHelper.RunSimpleChatCompletionTest(sdk);
await ChatCompletionTestHelper.RunSimpleCompletionStreamTest(sdk);
await ChatCompletionTestHelper.RunChatFunctionCallTest(sdk);

// Whisper
await AudioTestHelper.RunSimpleAudioCreateTranscriptionTest(sdk);
await AudioTestHelper.RunSimpleAudioCreateTranslationTest(sdk);
//await AudioTestHelper.RunSimpleAudioCreateTranscriptionTest(sdk);
//await AudioTestHelper.RunSimpleAudioCreateTranslationTest(sdk);

//await ModelTestHelper.FetchModelsTest(sdk);
//await EditTestHelper.RunSimpleEditCreateTest(sdk);
82 changes: 82 additions & 0 deletions OpenAI.Playground/TestHelpers/ChatCompletionTestHelper.cs
@@ -92,4 +92,86 @@ public static async Task RunSimpleCompletionStreamTest(IOpenAIService sdk)
throw;
}
}

public static async Task RunChatFunctionCallTest(IOpenAIService sdk)
{
ConsoleExtensions.WriteLine("Chat Function Call Testing is starting:", ConsoleColor.Cyan);

// example taken from:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_call_functions_with_chat_models.ipynb

var fn1 = new FunctionDefinitionBuilder("get_current_weather", "Get the current weather")
.AddParameter("location", "string", "The city and state, e.g. San Francisco, CA")
.AddParameter("format", "string", "The temperature unit to use. Infer this from the user's location.",
@enum: new List<string> { "celsius", "fahrenheit" })
.Build();

var fn2 = new FunctionDefinitionBuilder("get_n_day_weather_forecast", "Get an N-day weather forecast")
.AddParameter("location", "string", "The city and state, e.g. San Francisco, CA")
.AddParameter("format", "string", "The temperature unit to use. Infer this from the user's location.",
@enum: new List<string> { "celsius", "fahrenheit" })
.AddParameter("num_days", "integer", "The number of days to forecast")
.Build();

try
{
ConsoleExtensions.WriteLine("Chat Function Call Test:", ConsoleColor.DarkCyan);
var completionResults = sdk.ChatCompletion.CreateCompletionAsStream(new ChatCompletionCreateRequest
{
Messages = new List<ChatMessage>
{
ChatMessage.FromSystem("Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."),
ChatMessage.FromUser("Give me a weather report for Chicago, USA, for the next 5 days."),
},
Functions = new List<FunctionDefinition> { fn1, fn2 },
// optionally, to force a specific function:
// FunctionCall = new Dictionary<string, string> { { "name", "get_current_weather" } },
MaxTokens = 50,
Model = Models.ChatGpt3_5Turbo_0613
});

/* expected output along the lines of:

Message:
Function call: get_n_day_weather_forecast
location: Chicago, USA
format: celsius
num_days: 5
*/

await foreach (var completionResult in completionResults)
{
if (completionResult.Successful)
{
var choice = completionResult.Choices.First();
Console.WriteLine($"Message: {choice.Message.Content}");

var fn = choice.Message.FunctionCall;
if (fn != null)
{
Console.WriteLine($"Function call: {fn.Name}");
foreach (var entry in fn.ParseArguments())
{
Console.WriteLine($" {entry.Key}: {entry.Value}");
}
}
}
else
{
if (completionResult.Error == null)
{
throw new Exception("Unknown Error");
}

Console.WriteLine($"{completionResult.Error.Code}: {completionResult.Error.Message}");
}
}
}
catch (Exception e)
{
Console.WriteLine(e);
throw;
}
}

}
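A consumer would typically go one step further and dispatch the parsed call to local code. A rough sketch continuing from the loop above — GetCurrentWeather is a hypothetical local helper, and the dictionary shape of ParseArguments() is inferred from the foreach above; neither is part of this PR:

// Hypothetical dispatch of the assembled function call.
var args = fn.ParseArguments();
var result = fn.Name switch
{
    "get_current_weather" => GetCurrentWeather(args["location"]?.ToString()), // local helper (not in this PR)
    _ => throw new NotSupportedException($"No handler for function: {fn.Name}")
};
Console.WriteLine(result);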
79 changes: 75 additions & 4 deletions OpenAI.SDK/Managers/OpenAIChatCompletions.cs
@@ -1,9 +1,9 @@
using System.Runtime.CompilerServices;
using System.Text.Json;
using OpenAI.Extensions;
using OpenAI.Extensions;
using OpenAI.Interfaces;
using OpenAI.ObjectModels.RequestModels;
using OpenAI.ObjectModels.ResponseModels;
using System.Runtime.CompilerServices;
using System.Text.Json;

namespace OpenAI.Managers;

@@ -20,6 +20,9 @@ public async Task<ChatCompletionCreateResponse> CreateCompletion(ChatCompletionC
public async IAsyncEnumerable<ChatCompletionCreateResponse> CreateCompletionAsStream(ChatCompletionCreateRequest chatCompletionCreateRequest, string? modelId = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
// Helper data in case we need to reassemble a multi-packet response
ReassemblyContext ctx = new();

// Mark the request as streaming
chatCompletionCreateRequest.Stream = true;

@@ -29,6 +32,7 @@ public async IAsyncEnumerable<ChatCompletionCreateResponse> CreateCompletionAsSt
using var response = _httpClient.PostAsStreamAsync(_endpointProvider.ChatCompletionCreate(), chatCompletionCreateRequest, cancellationToken);
await using var stream = await response.Content.ReadAsStreamAsync(cancellationToken);
using var reader = new StreamReader(stream);

// Continuously read the stream until the end of it
while (!reader.EndOfStream)
{
@@ -66,8 +70,75 @@ public async IAsyncEnumerable<ChatCompletionCreateResponse> CreateCompletionAsSt

if (null != block)
{
yield return block;
ctx.Process(block);

if (!ctx.IsFnAssemblyActive)
{
yield return block;
}
}
}
}

/// <summary>
/// This helper class attempts to reassemble a function call response
/// that was split up across several streamed chunks.
/// Note that this only works for the first message in each response,
/// and ignores the others; if OpenAI ever changes their response format
/// this will need to be adjusted.
/// </summary>
private class ReassemblyContext
{
private FunctionCall? FnCall = null;

public bool IsFnAssemblyActive => FnCall != null;

/// <summary>
/// Detects if a response block is a part of a multi-chunk
/// streamed function call response. As long as that's true,
/// it keeps accumulating block contents, and once function call
/// streaming is done, it produces the assembled results in the final block.
/// </summary>
/// <param name="block"></param>
public void Process(ChatCompletionCreateResponse block)
{
var firstChoice = block.Choices?.FirstOrDefault();
if (firstChoice == null) { return; } // nothing to process

var isStreamingFnCall = IsStreamingFunctionCall();

// If we're not yet assembling, and we just got a streaming block that has a function_call segment,
// this is the beginning of a function call assembly.
// We're going to steal the partial message and squirrel it away for the time being.
if (!IsFnAssemblyActive && isStreamingFnCall)
{
FnCall = firstChoice.Message.FunctionCall;
firstChoice.Message.FunctionCall = null;
}

// As long as we're assembling, keep on appending those args
if (IsFnAssemblyActive)
{
FnCall!.Arguments += ExtractArgsSoFar(); // IsFnAssemblyActive guarantees FnCall is non-null here
}

// If we were assembling and it just finished, fill this block with the info we've assembled, and we're done.
if (IsFnAssemblyActive && !isStreamingFnCall)
{
firstChoice.Message ??= ChatMessage.FromAssistant(""); // ensure there is a message to carry the assembled call
firstChoice.Message.FunctionCall = FnCall;
FnCall = null;
}

// Returns true if we're actively streaming, and also have a partial function call in the response
bool IsStreamingFunctionCall() =>
firstChoice.FinishReason == null && // actively streaming, and
firstChoice.Message?.FunctionCall != null; // have a function call

string ExtractArgsSoFar() =>
block.Choices?.FirstOrDefault()?.Message?.FunctionCall?.Arguments ?? "";
}
}
}
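To see why the reassembly is needed: when streaming, the model sends the function name in the first chunk and the arguments JSON in fragments across subsequent chunks, only setting FinishReason at the end. The fragment boundaries below are illustrative, not the exact wire chunking:

// Illustrative argument fragments for a single streamed function call;
// ReassemblyContext concatenates them (FnCall.Arguments += ...) and yields
// one complete block once FinishReason arrives.
var fragments = new[] { "{\"location\": \"Chi", "cago, USA\", \"num", "_days\": 5}" };
Console.WriteLine(string.Concat(fragments)); // {"location": "Chicago, USA", "num_days": 5}

The net effect is that callers never see partial fragments: each yielded block carries either ordinary content deltas or one fully assembled function call.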
48 changes: 42 additions & 6 deletions OpenAI.SDK/ObjectModels/Models.cs
@@ -102,6 +102,13 @@ public enum Subject
/// </summary>
public static string Gpt_4 => "gpt-4";

/// <summary>
/// Same capabilities as the base gpt-4 model but with 4x the context length. Will be updated with our latest model
/// iteration.
/// 32,768 tokens Up to Sep 2021
/// </summary>
public static string Gpt_4_32k => "gpt-4-32k";

/// <summary>
/// Snapshot of gpt-4 from March 14th 2023. Unlike gpt-4, this model will not receive updates, and will only be
/// supported for a three month period ending on June 14th 2023.
@@ -110,18 +117,26 @@ public enum Subject
public static string Gpt_4_0314 => "gpt-4-0314";

/// <summary>
/// Same capabilities as the base gpt-4 model but with 4x the context length. Will be updated with our latest model
/// iteration.
/// Snapshot of gpt-4-32k from March 14th 2023. Unlike gpt-4-32k, this model will not receive updates, and will only be
/// supported for a three month period ending on June 14th 2023.
/// 32,768 tokens Up to Sep 2021
/// </summary>
public static string Gpt_4_32k => "gpt-4-32k";
public static string Gpt_4_32k_0314 => "gpt-4-32k-0314";

/// <summary>
/// Snapshot of gpt-4-32k from March 14th 2023. Unlike gpt-4-32k, this model will not receive updates, and will only be
/// supported for a three month period ending on June 14th 2023.
/// Snapshot of gpt-4 from June 13th 2023 with function calling data. Unlike gpt-4, this model will not receive updates,
/// and will be deprecated 3 months after a new version is released.
/// 8,192 tokens Up to Sep 2021
/// </summary>
public static string Gpt_4_0613 => "gpt-4-0613";

/// <summary>
/// Snapshot of gpt-4-32k from June 13th 2023. Unlike gpt-4-32k, this model will not receive updates,
/// and will be deprecated 3 months after a new version is released.
/// 32,768 tokens Up to Sep 2021
/// </summary>
public static string Gpt_4_32k_0314 => "gpt-4-32k-0314";
public static string Gpt_4_32k_0613 => "gpt-4-32k-0613";


public static string Ada => "ada";
public static string Babbage => "babbage";
@@ -175,13 +190,34 @@ public enum Subject
/// </summary>
public static string ChatGpt3_5Turbo => "gpt-3.5-turbo";

/// <summary>
/// Same capabilities as the standard gpt-3.5-turbo model but with 4 times the context.
/// 16,384 tokens Up to Sep 2021
/// </summary>
public static string ChatGpt3_5Turbo_16k => "gpt-3.5-turbo-16k";

/// <summary>
/// Snapshot of gpt-3.5-turbo from March 1st 2023. Unlike gpt-3.5-turbo, this model will not receive updates, and will
/// only be supported for a three month period ending on June 1st 2023.
/// 4,096 tokens Up to Sep 2021
/// </summary>
public static string ChatGpt3_5Turbo0301 => "gpt-3.5-turbo-0301";

/// <summary>
/// Snapshot of gpt-3.5-turbo from June 13th 2023 with function calling data. Unlike gpt-3.5-turbo,
/// this model will not receive updates, and will be deprecated 3 months after a new version is released.
/// 4,096 tokens Up to Sep 2021
/// </summary>
public static string ChatGpt3_5Turbo_0613 => "gpt-3.5-turbo-0613";

/// <summary>
/// Snapshot of gpt-3.5-turbo-16k from June 13th 2023 with function calling data. Unlike gpt-3.5-turbo-16k,
/// this model will not receive updates, and will be deprecated 3 months after a new version is released.
/// 16,384 tokens Up to Sep 2021
/// </summary>
public static string ChatGpt3_5Turbo_16k_0613 => "gpt-3.5-turbo-16k-0613";


public static string WhisperV1 => "whisper-1";

OpenAI.SDK/ObjectModels/RequestModels/ChatCompletionCreateRequest.cs
@@ -16,6 +16,12 @@ public class ChatCompletionCreateRequest : IModelValidate, IOpenAiModels.ITemper
[JsonPropertyName("messages")]
public IList<ChatMessage> Messages { get; set; }

/// <summary>
/// A list of functions the model may generate JSON inputs for.
/// </summary>
[JsonPropertyName("functions")]
public IList<FunctionDefinition>? Functions { get; set; }

/// <summary>
/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
/// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are
@@ -134,4 +140,18 @@ public IEnumerable<ValidationResult> Validate()
/// </summary>
[JsonPropertyName("user")]
public string User { get; set; }


/// <summary>
/// String or object. Controls how the model responds to function calls.
/// "none" means the model does not call a function, and responds to the end-user.
/// "auto" means the model can pick between responding to the end-user or calling a function.
/// "none" is the default when no functions are present. "auto" is the default if functions are present.
/// Specifying a particular function via {"name": "my_function"} forces the model to call that function.
/// (Note: in C# specify that as:
/// FunctionCall = new Dictionary&lt;string, string&gt; { { "name", "my_function" } }
/// ).
/// </summary>
[JsonPropertyName("function_call")]
public object? FunctionCall { get; set; }
}
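For reference, a sketch of the three accepted shapes for this property (the string forms serialize as plain JSON strings under System.Text.Json defaults; get_current_weather is just an example name):

var request = new ChatCompletionCreateRequest();
request.FunctionCall = "none"; // never call a function; respond to the user directly
request.FunctionCall = "auto"; // let the model decide (default when Functions is set)
request.FunctionCall = new Dictionary<string, string> { { "name", "get_current_weather" } }; // force this function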