Skip to content

Commit

Permalink
Merge pull request #558 from SebastianStehle/streaming-fixes
Browse files Browse the repository at this point in the history
Streaming fixes.
  • Loading branch information
kayhantolga authored May 18, 2024
2 parents 85fb09d + 62297aa commit 05d8db0
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 24 deletions.
1 change: 1 addition & 0 deletions OpenAI.Playground/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
// Tools
//await ChatCompletionTestHelper.RunChatFunctionCallTest(sdk);
//await ChatCompletionTestHelper.RunChatFunctionCallTestAsStream(sdk);
//await ChatCompletionTestHelper.RunSimpleCompletionStreamWithUsageTest(sdk);
//await BatchTestHelper.RunBatchOperationsTest(sdk);

// Whisper
Expand Down
129 changes: 124 additions & 5 deletions OpenAI.Playground/TestHelpers/ChatCompletionTestHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,65 @@ public static async Task RunSimpleCompletionStreamTest(IOpenAIService sdk)
}
}

/// <summary>
///     Demonstrates a streamed chat completion that also requests token-usage
///     reporting via <c>StreamOptions.IncludeUsage</c>. When enabled, the API
///     emits one extra final chunk whose <c>Usage</c> is populated (and whose
///     choice list is empty), so that chunk must be handled before reading
///     <c>Choices</c>.
/// </summary>
/// <param name="sdk">The OpenAI service instance used to issue the request.</param>
public static async Task RunSimpleCompletionStreamWithUsageTest(IOpenAIService sdk)
{
    ConsoleExtensions.WriteLine("Chat Completion Stream Testing is starting:", ConsoleColor.Cyan);
    try
    {
        ConsoleExtensions.WriteLine("Chat Completion Stream Test:", ConsoleColor.DarkCyan);
        var completionResult = sdk.ChatCompletion.CreateCompletionAsStream(new ChatCompletionCreateRequest
        {
            Messages = new List<ChatMessage>
            {
                new(StaticValues.ChatMessageRoles.System, "You are a helpful assistant."),
                new(StaticValues.ChatMessageRoles.User, "Who won the world series in 2020?"),
                // BUGFIX: this canned answer is the model's reply, so it must carry the
                // assistant role; it was previously sent as a second system message,
                // which breaks the few-shot pattern from the OpenAI chat example.
                new(StaticValues.ChatMessageRoles.Assistant, "The Los Angeles Dodgers won the World Series in 2020."),
                new(StaticValues.ChatMessageRoles.User, "Tell me a story about The Los Angeles Dodgers")
            },
            StreamOptions = new StreamOptions
            {
                // Ask the API to append a final chunk containing token usage.
                IncludeUsage = true,
            },
            MaxTokens = 150,
            Model = Models.Gpt_3_5_Turbo
        });

        await foreach (var completion in completionResult)
        {
            if (completion.Successful)
            {
                // The usage chunk arrives last and has no choices, so check
                // Usage first instead of unconditionally reading Choices.
                if (completion.Usage != null)
                {
                    Console.WriteLine();
                    Console.WriteLine();
                    Console.WriteLine($"Usage: {completion.Usage.TotalTokens}");
                }
                else
                {
                    Console.Write(completion.Choices.First().Message.Content);
                }
            }
            else
            {
                if (completion.Error == null)
                {
                    throw new Exception("Unknown Error");
                }

                Console.WriteLine($"{completion.Error.Code}: {completion.Error.Message}");
            }
        }

        Console.WriteLine();
        Console.WriteLine("Complete");
    }
    catch (Exception e)
    {
        // Surface the failure to the console, then rethrow so callers see it too.
        Console.WriteLine(e);
        throw;
    }
}

public static async Task RunChatFunctionCallTest(IOpenAIService sdk)
{
ConsoleExtensions.WriteLine("Chat Tool Functions Call Testing is starting:", ConsoleColor.Cyan);
Expand Down Expand Up @@ -124,21 +183,24 @@ public static async Task RunChatFunctionCallTest(IOpenAIService sdk)
try
{
ConsoleExtensions.WriteLine("Chat Function Call Test:", ConsoleColor.DarkCyan);
var completionResult = await sdk.ChatCompletion.CreateCompletion(new ChatCompletionCreateRequest

var request = new ChatCompletionCreateRequest
{
Messages = new List<ChatMessage>
{
ChatMessage.FromSystem("Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."),
ChatMessage.FromUser("Give me a weather report for Chicago, USA, for the next 5 days.")
},
Tools = new List<ToolDefinition> { ToolDefinition.DefineFunction(fn1), ToolDefinition.DefineFunction(fn2) ,ToolDefinition.DefineFunction(fn3) ,ToolDefinition.DefineFunction(fn4) },
Tools = new List<ToolDefinition> { ToolDefinition.DefineFunction(fn1), ToolDefinition.DefineFunction(fn2), ToolDefinition.DefineFunction(fn3), ToolDefinition.DefineFunction(fn4) },
// optionally, to force a specific function:
//ToolChoice = ToolChoice.FunctionChoice("get_current_weather"),
// or auto tool choice:
//ToolChoice = ToolChoice.Auto,
MaxTokens = 50,
Model = Models.Gpt_3_5_Turbo
});
};

var completionResult = await sdk.ChatCompletion.CreateCompletion(request);

/* expected output along the lines of:
Expand All @@ -158,6 +220,8 @@ public static async Task RunChatFunctionCallTest(IOpenAIService sdk)
var tools = choice.Message.ToolCalls;
if (tools != null)
{
request.Messages.Add(choice.Message);

Console.WriteLine($"Tools: {tools.Count}");
foreach (var toolCall in tools)
{
Expand All @@ -171,6 +235,11 @@ public static async Task RunChatFunctionCallTest(IOpenAIService sdk)
{
Console.WriteLine($" {entry.Key}: {entry.Value}");
}

if (fn.Name == "get_n_day_weather_forecast")
{
request.Messages.Add(ChatMessage.FromTool("10 Degrees", toolCall.Id!));
}
}
}
}
Expand All @@ -184,6 +253,22 @@ public static async Task RunChatFunctionCallTest(IOpenAIService sdk)

Console.WriteLine($"{completionResult.Error.Code}: {completionResult.Error.Message}");
}

var completionResultAfterTool = await sdk.ChatCompletion.CreateCompletion(request);

if (completionResultAfterTool.Successful)
{
Console.WriteLine(completionResultAfterTool.Choices.First().Message.Content);
}
else
{
if (completionResultAfterTool.Error == null)
{
throw new Exception("Unknown Error");
}

Console.WriteLine($"{completionResultAfterTool.Error.Code}: {completionResultAfterTool.Error.Message}");
}
}
catch (Exception e)
{
Expand Down Expand Up @@ -228,7 +313,8 @@ public static async Task RunChatFunctionCallTestAsStream(IOpenAIService sdk)
try
{
ConsoleExtensions.WriteLine("Chat Function Call Test:", ConsoleColor.DarkCyan);
var completionResults = sdk.ChatCompletion.CreateCompletionAsStream(new ChatCompletionCreateRequest

var request = new ChatCompletionCreateRequest
{
Messages = new List<ChatMessage>
{
Expand All @@ -246,7 +332,9 @@ public static async Task RunChatFunctionCallTestAsStream(IOpenAIService sdk)
ToolChoice = ToolChoice.Auto,
//MaxTokens = 50,
Model = Models.Gpt_4_1106_preview
});
};

var completionResults = sdk.ChatCompletion.CreateCompletionAsStream(request);

/* when testing weather forecasts, expected output should be along the lines of:
Expand Down Expand Up @@ -274,6 +362,8 @@ public static async Task RunChatFunctionCallTestAsStream(IOpenAIService sdk)
var tools = choice.Message.ToolCalls;
if (tools != null)
{
request.Messages.Add(choice.Message);

Console.WriteLine($"Tools: {tools.Count}");
for (int i = 0; i < tools.Count; i++)
{
Expand Down Expand Up @@ -313,6 +403,16 @@ public static async Task RunChatFunctionCallTestAsStream(IOpenAIService sdk)
// ignore
}
}

if (fn.Name == "google_search")
{
request.Messages.Add(ChatMessage.FromTool("Tom", toolCall.Id!));
}

if (fn.Name == "getURL")
{
request.Messages.Add(ChatMessage.FromTool("News", toolCall.Id!));
}
}
}
}
Expand All @@ -327,6 +427,25 @@ public static async Task RunChatFunctionCallTestAsStream(IOpenAIService sdk)
Console.WriteLine($"{completionResult.Error.Code}: {completionResult.Error.Message}");
}
}

var completionResultsAfterTool = sdk.ChatCompletion.CreateCompletionAsStream(request);

await foreach (var completion in completionResultsAfterTool)
{
if (completion.Successful)
{
Console.Write(completion.Choices.First().Message.Content);
}
else
{
if (completion.Error == null)
{
throw new Exception("Unknown Error");
}

Console.WriteLine($"{completion.Error.Code}: {completion.Error.Message}");
}
}
}
catch (Exception e)
{
Expand Down
12 changes: 9 additions & 3 deletions OpenAI.SDK/Extensions/HttpclientExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ public static async Task<string> PostFileAndReadAsStringAsync(this HttpClient cl
return await HandleResponseContent<TResponse>(response, cancellationToken);
}

private static async Task<TResponse> HandleResponseContent<TResponse>(this HttpResponseMessage response, CancellationToken cancellationToken) where TResponse : BaseResponse, new()
public static async Task<TResponse> HandleResponseContent<TResponse>(this HttpResponseMessage response, CancellationToken cancellationToken) where TResponse : BaseResponse, new()
{
TResponse result;

Expand All @@ -144,7 +144,14 @@ public static async Task<string> PostFileAndReadAsStringAsync(this HttpClient cl
}

result.HttpStatusCode = response.StatusCode;
result.HeaderValues = new()
result.HeaderValues = response.ParseHeaders();

return result;
}

public static ResponseHeaderValues ParseHeaders(this HttpResponseMessage response)
{
return new ResponseHeaderValues()
{
Date = response.Headers.Date,
Connection = response.Headers.Connection?.ToString(),
Expand Down Expand Up @@ -181,7 +188,6 @@ public static async Task<string> PostFileAndReadAsStringAsync(this HttpClient cl
Version = response.Headers.GetHeaderValue("openai-version")
}
};
return result;
}

#if NETSTANDARD2_0
Expand Down
47 changes: 32 additions & 15 deletions OpenAI.SDK/Managers/OpenAIChatCompletions.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.CompilerServices;
using System.Text.Json;
using OpenAI.Extensions;
using OpenAI.Interfaces;
Expand Down Expand Up @@ -32,6 +31,17 @@ public async IAsyncEnumerable<ChatCompletionCreateResponse> CreateCompletionAsSt
chatCompletionCreateRequest.ProcessModelId(modelId, _defaultModelId);

using var response = _httpClient.PostAsStreamAsync(_endpointProvider.ChatCompletionCreate(), chatCompletionCreateRequest, cancellationToken);

if (!response.IsSuccessStatusCode)
{
yield return await response.HandleResponseContent<ChatCompletionCreateResponse>(cancellationToken);
yield break;
}

// Ensure that we parse headers only once to improve performance a little bit.
var httpStatusCode = response.StatusCode;
var headerValues = response.ParseHeaders();

await using var stream = await response.Content.ReadAsStreamAsync(cancellationToken);
using var reader = new StreamReader(stream);

Expand All @@ -41,12 +51,12 @@ public async IAsyncEnumerable<ChatCompletionCreateResponse> CreateCompletionAsSt
cancellationToken.ThrowIfCancellationRequested();

var line = await reader.ReadLineAsync();

// Break the loop if we have reached the end of the stream
if (line == null)
{
break;
}

// Skip empty lines
if (string.IsNullOrEmpty(line))
{
Expand Down Expand Up @@ -87,6 +97,8 @@ public async IAsyncEnumerable<ChatCompletionCreateResponse> CreateCompletionAsSt

if (!ctx.IsFnAssemblyActive)
{
block.HttpStatusCode = httpStatusCode;
block.HeaderValues = headerValues;
yield return block;
}
}
Expand Down Expand Up @@ -151,13 +163,14 @@ public void Process(ChatCompletionCreateResponse block)
if (tcMetadata.index > -1)
{
//Handles just ToolCall type == "function"
using var argumentsList = ExtractArgsSoFar().GetEnumerator();
using var argumentsList = ExtractArgsSoFar()
.GetEnumerator();
var existItems = argumentsList.MoveNext();

if (existItems)
{
//toolcall item must exists as added in previous steps, otherwise First() will raise an InvalidOperationException
var tc = _deltaFnCallList!.Where(t => t.Index == tcMetadata.index).First();
var tc = _deltaFnCallList!.First(t => t.Index == tcMetadata.index);
tc.FunctionCall!.Arguments += argumentsList.Current;
argumentsList.MoveNext();
}
Expand All @@ -168,6 +181,9 @@ public void Process(ChatCompletionCreateResponse block)
if (IsFnAssemblyActive && isStreamingFnCallEnd)
{
firstChoice.Message ??= ChatMessage.FromAssistant(""); // just in case? not sure it's needed
// TODO When more than one function call is in a single index, OpenAI only returns the role delta at the beginning, which causes an issue.
// TODO The current solution addresses this problem, but we need to fix it by using the role of the index.
firstChoice.Message.Role ??= "assistant";
firstChoice.Message.ToolCalls = new List<ToolCall>(_deltaFnCallList);
_deltaFnCallList.Clear();
}
Expand All @@ -176,17 +192,18 @@ public void Process(ChatCompletionCreateResponse block)
bool IsStreamingFunctionCall()
{
return firstChoice.FinishReason == null && // actively streaming, is a tool call main item, and have a function call
firstChoice.Message?.ToolCalls?.Count > 0 &&
(firstChoice.Message?.ToolCalls.Any(t => t.FunctionCall != null
&& !string.IsNullOrEmpty(t.Id)
&& t.Type == StaticValues.CompletionStatics.ToolType.Function) ?? false);
firstChoice.Message?.ToolCalls?.Count > 0 && (firstChoice.Message?.ToolCalls.Any(t => t.FunctionCall != null && !string.IsNullOrEmpty(t.Id) && t.Type == StaticValues.CompletionStatics.ToolType.Function) ?? false);
}

(int index, string? id, string? type) GetToolCallMetadata()
{
var tc = block.Choices?.FirstOrDefault()?.Message?.ToolCalls?
.Where(t => t.FunctionCall != null)
.Select(t => t).FirstOrDefault();
var tc = block.Choices
?.FirstOrDefault()
?.Message
?.ToolCalls
?.Where(t => t.FunctionCall != null)
.Select(t => t)
.FirstOrDefault();

return tc switch
{
Expand All @@ -197,12 +214,12 @@ bool IsStreamingFunctionCall()

IEnumerable<string> ExtractArgsSoFar()
{
var toolCalls = block.Choices?.FirstOrDefault()?.Message?.ToolCalls;
var toolCalls = block.Choices?.FirstOrDefault()
?.Message?.ToolCalls;

if (toolCalls != null)
{
var functionCallList = toolCalls
.Where(t => t.FunctionCall != null)
var functionCallList = toolCalls.Where(t => t.FunctionCall != null)
.Select(t => t.FunctionCall);

foreach (var functionCall in functionCallList)
Expand Down
Loading

0 comments on commit 05d8db0

Please sign in to comment.