Skip to content

Commit

Permalink
Added CancellationToken parameters to the most complete convenience o…
Browse files Browse the repository at this point in the history
…verloads. (#53)

* added CancellationToken parameters to the ChatClient

* cleaned up RO extension

* added CTs to AudioClient

* added CTs to EmbeddingClient

* added CTs to FileClient

* added CTs to ModeriationClient

* added CTs to ImageClient

* added CTs to VectorStoreClient

* added CTs to AssistantClient

* fixed test sources as the callsite become ambiguous

* fixed formatting in one file

* PR feedback

* removed back the special CT overload
  • Loading branch information
KrzysztofCwalina authored Jun 14, 2024
1 parent b1fa082 commit 19a65a0
Show file tree
Hide file tree
Showing 11 changed files with 401 additions and 237 deletions.
252 changes: 157 additions & 95 deletions src/Custom/Assistants/AssistantClient.cs

Large diffs are not rendered by default.

31 changes: 19 additions & 12 deletions src/Custom/Audio/AudioClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.ClientModel;
using System.ClientModel.Primitives;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

namespace OpenAI.Audio;
Expand Down Expand Up @@ -85,16 +86,17 @@ protected internal AudioClient(ClientPipeline pipeline, string model, Uri endpoi
/// <param name="text"> The text for the voice to speak. </param>
/// <param name="voice"> The voice to use. </param>
/// <param name="options"> Additional options to tailor the text-to-speech request. </param>
/// <param name="cancellationToken">A token that can be used to cancel this method call.</param>
/// <returns> The generated audio in the specified output format. </returns>
public virtual async Task<ClientResult<BinaryData>> GenerateSpeechFromTextAsync(string text, GeneratedSpeechVoice voice, SpeechGenerationOptions options = null)
public virtual async Task<ClientResult<BinaryData>> GenerateSpeechFromTextAsync(string text, GeneratedSpeechVoice voice, SpeechGenerationOptions options = null, CancellationToken cancellationToken = default)
{
Argument.AssertNotNull(text, nameof(text));

options ??= new();
CreateSpeechGenerationOptions(text, voice, ref options);

using BinaryContent content = options.ToBinaryContent();
ClientResult result = await GenerateSpeechFromTextAsync(content, null).ConfigureAwait(false);
ClientResult result = await GenerateSpeechFromTextAsync(content, cancellationToken.ToRequestOptions()).ConfigureAwait(false);
return ClientResult.FromValue(result.GetRawResponse().Content, result.GetRawResponse());
}

Expand All @@ -108,16 +110,17 @@ public virtual async Task<ClientResult<BinaryData>> GenerateSpeechFromTextAsync(
/// <param name="text"> The text for the voice to speak. </param>
/// <param name="voice"> The voice to use. </param>
/// <param name="options"> Additional options to tailor the text-to-speech request. </param>
/// <param name="cancellationToken">A token that can be used to cancel this method call.</param>
/// <returns> The generated audio in the specified output format. </returns>
public virtual ClientResult<BinaryData> GenerateSpeechFromText(string text, GeneratedSpeechVoice voice, SpeechGenerationOptions options = null)
public virtual ClientResult<BinaryData> GenerateSpeechFromText(string text, GeneratedSpeechVoice voice, SpeechGenerationOptions options = null, CancellationToken cancellationToken = default)
{
Argument.AssertNotNull(text, nameof(text));

options ??= new();
CreateSpeechGenerationOptions(text, voice, ref options);

using BinaryContent content = options.ToBinaryContent();
ClientResult result = GenerateSpeechFromText(content, (RequestOptions)null);
ClientResult result = GenerateSpeechFromText(content, cancellationToken.ToRequestOptions()); ;
return ClientResult.FromValue(result.GetRawResponse().Content, result.GetRawResponse());
}

Expand All @@ -135,10 +138,11 @@ public virtual ClientResult<BinaryData> GenerateSpeechFromText(string text, Gene
/// not match.
/// </param>
/// <param name="options"> Additional options to tailor the audio transcription request. </param>
/// <param name="cancellationToken">A token that can be used to cancel this method call.</param>
/// <exception cref="ArgumentNullException"> <paramref name="audio"/> or <paramref name="audioFilename"/> is null. </exception>
/// <exception cref="ArgumentException"> <paramref name="audioFilename"/> is an empty string, and was expected to be non-empty. </exception>
/// <returns> The audio transcription. </returns>
public virtual async Task<ClientResult<AudioTranscription>> TranscribeAudioAsync(Stream audio, string audioFilename, AudioTranscriptionOptions options = null)
public virtual async Task<ClientResult<AudioTranscription>> TranscribeAudioAsync(Stream audio, string audioFilename, AudioTranscriptionOptions options = null, CancellationToken cancellationToken = default)
{
Argument.AssertNotNull(audio, nameof(audio));
Argument.AssertNotNullOrEmpty(audioFilename, nameof(audioFilename));
Expand All @@ -147,7 +151,7 @@ public virtual async Task<ClientResult<AudioTranscription>> TranscribeAudioAsync
CreateAudioTranscriptionOptions(audio, audioFilename, ref options);

using MultipartFormDataBinaryContent content = options.ToMultipartContent(audio, audioFilename);
ClientResult result = await TranscribeAudioAsync(content, content.ContentType).ConfigureAwait(false);
ClientResult result = await TranscribeAudioAsync(content, content.ContentType, cancellationToken.ToRequestOptions()).ConfigureAwait(false);
return ClientResult.FromValue(AudioTranscription.FromResponse(result.GetRawResponse()), result.GetRawResponse());
}

Expand All @@ -161,10 +165,11 @@ public virtual async Task<ClientResult<AudioTranscription>> TranscribeAudioAsync
/// not match.
/// </param>
/// <param name="options"> Additional options to tailor the audio transcription request. </param>
/// <param name="cancellationToken">A token that can be used to cancel this method call.</param>
/// <exception cref="ArgumentNullException"> <paramref name="audio"/> or <paramref name="audioFilename"/> is null. </exception>
/// <exception cref="ArgumentException"> <paramref name="audioFilename"/> is an empty string, and was expected to be non-empty. </exception>
/// <returns> The audio transcription. </returns>
public virtual ClientResult<AudioTranscription> TranscribeAudio(Stream audio, string audioFilename, AudioTranscriptionOptions options = null)
public virtual ClientResult<AudioTranscription> TranscribeAudio(Stream audio, string audioFilename, AudioTranscriptionOptions options = null, CancellationToken cancellationToken = default)
{
Argument.AssertNotNull(audio, nameof(audio));
Argument.AssertNotNullOrEmpty(audioFilename, nameof(audioFilename));
Expand All @@ -173,7 +178,7 @@ public virtual ClientResult<AudioTranscription> TranscribeAudio(Stream audio, st
CreateAudioTranscriptionOptions(audio, audioFilename, ref options);

using MultipartFormDataBinaryContent content = options.ToMultipartContent(audio, audioFilename);
ClientResult result = TranscribeAudio(content, content.ContentType);
ClientResult result = TranscribeAudio(content, content.ContentType, cancellationToken.ToRequestOptions());
return ClientResult.FromValue(AudioTranscription.FromResponse(result.GetRawResponse()), result.GetRawResponse());
}

Expand Down Expand Up @@ -229,10 +234,11 @@ public virtual ClientResult<AudioTranscription> TranscribeAudio(string audioFile
/// not match.
/// </param>
/// <param name="options"> Additional options to tailor the audio translation request. </param>
/// <param name="cancellationToken">A token that can be used to cancel this method call.</param>
/// <exception cref="ArgumentNullException"> <paramref name="audio"/> or <paramref name="audioFilename"/> is null. </exception>
/// <exception cref="ArgumentException"> <paramref name="audioFilename"/> is an empty string, and was expected to be non-empty. </exception>
/// <returns> The audio translation. </returns>
public virtual async Task<ClientResult<AudioTranslation>> TranslateAudioAsync(Stream audio, string audioFilename, AudioTranslationOptions options = null)
public virtual async Task<ClientResult<AudioTranslation>> TranslateAudioAsync(Stream audio, string audioFilename, AudioTranslationOptions options = null, CancellationToken cancellationToken = default)
{
Argument.AssertNotNull(audio, nameof(audio));
Argument.AssertNotNullOrEmpty(audioFilename, nameof(audioFilename));
Expand All @@ -241,7 +247,7 @@ public virtual async Task<ClientResult<AudioTranslation>> TranslateAudioAsync(St
CreateAudioTranslationOptions(audio, audioFilename, ref options);

using MultipartFormDataBinaryContent content = options.ToMultipartContent(audio, audioFilename);
ClientResult result = await TranslateAudioAsync(content, content.ContentType).ConfigureAwait(false);
ClientResult result = await TranslateAudioAsync(content, content.ContentType, cancellationToken.ToRequestOptions()).ConfigureAwait(false);
return ClientResult.FromValue(AudioTranslation.FromResponse(result.GetRawResponse()), result.GetRawResponse());
}

Expand All @@ -253,10 +259,11 @@ public virtual async Task<ClientResult<AudioTranslation>> TranslateAudioAsync(St
/// not match.
/// </param>
/// <param name="options"> Additional options to tailor the audio translation request. </param>
/// <param name="cancellationToken">A token that can be used to cancel this method call.</param>
/// <exception cref="ArgumentNullException"> <paramref name="audio"/> or <paramref name="audioFilename"/> is null. </exception>
/// <exception cref="ArgumentException"> <paramref name="audioFilename"/> is an empty string, and was expected to be non-empty. </exception>
/// <returns> The audio translation. </returns>
public virtual ClientResult<AudioTranslation> TranslateAudio(Stream audio, string audioFilename, AudioTranslationOptions options = null)
public virtual ClientResult<AudioTranslation> TranslateAudio(Stream audio, string audioFilename, AudioTranslationOptions options = null, CancellationToken cancellationToken = default)
{
Argument.AssertNotNull(audio, nameof(audio));
Argument.AssertNotNullOrEmpty(audioFilename, nameof(audioFilename));
Expand All @@ -265,7 +272,7 @@ public virtual ClientResult<AudioTranslation> TranslateAudio(Stream audio, strin
CreateAudioTranslationOptions(audio, audioFilename, ref options);

using MultipartFormDataBinaryContent content = options.ToMultipartContent(audio, audioFilename);
ClientResult result = TranslateAudio(content, content.ContentType);
ClientResult result = TranslateAudio(content, content.ContentType, cancellationToken.ToRequestOptions());
return ClientResult.FromValue(AudioTranslation.FromResponse(result.GetRawResponse()), result.GetRawResponse());
}

Expand Down
25 changes: 15 additions & 10 deletions src/Custom/Chat/ChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace OpenAI.Chat;
Expand Down Expand Up @@ -68,16 +69,18 @@ protected internal ChatClient(ClientPipeline pipeline, string model, Uri endpoin
/// </summary>
/// <param name="messages"> The messages to provide as input and history for chat completion. </param>
/// <param name="options"> Additional options for the chat completion request. </param>
/// <param name="cancellationToken">A token that can be used to cancel this method call.</param>
/// <returns> A result for a single chat completion. </returns>
public virtual async Task<ClientResult<ChatCompletion>> CompleteChatAsync(IEnumerable<ChatMessage> messages, ChatCompletionOptions options = null)
public virtual async Task<ClientResult<ChatCompletion>> CompleteChatAsync(IEnumerable<ChatMessage> messages, ChatCompletionOptions options = null, CancellationToken cancellationToken = default)
{
Argument.AssertNotNullOrEmpty(messages, nameof(messages));

options ??= new();
CreateChatCompletionOptions(messages, ref options);

using BinaryContent content = options.ToBinaryContent();
ClientResult result = await CompleteChatAsync(content, null).ConfigureAwait(false);

ClientResult result = await CompleteChatAsync(content, cancellationToken.ToRequestOptions()).ConfigureAwait(false);
return ClientResult.FromValue(ChatCompletion.FromResponse(result.GetRawResponse()), result.GetRawResponse());
}

Expand All @@ -94,16 +97,17 @@ public virtual async Task<ClientResult<ChatCompletion>> CompleteChatAsync(params
/// </summary>
/// <param name="messages"> The messages to provide as input and history for chat completion. </param>
/// <param name="options"> Additional options for the chat completion request. </param>
/// <param name="cancellationToken">A token that can be used to cancel this method call.</param>
/// <returns> A result for a single chat completion. </returns>
public virtual ClientResult<ChatCompletion> CompleteChat(IEnumerable<ChatMessage> messages, ChatCompletionOptions options = null)
public virtual ClientResult<ChatCompletion> CompleteChat(IEnumerable<ChatMessage> messages, ChatCompletionOptions options = null, CancellationToken cancellationToken = default)
{
Argument.AssertNotNullOrEmpty(messages, nameof(messages));

options ??= new();
CreateChatCompletionOptions(messages, ref options);

using BinaryContent content = options.ToBinaryContent();
ClientResult result = CompleteChat(content, null);
ClientResult result = CompleteChat(content, cancellationToken.ToRequestOptions());
return ClientResult.FromValue(ChatCompletion.FromResponse(result.GetRawResponse()), result.GetRawResponse());

}
Expand All @@ -126,18 +130,19 @@ public virtual ClientResult<ChatCompletion> CompleteChat(params ChatMessage[] me
/// </remarks>
/// <param name="messages"> The messages to provide as input for chat completion. </param>
/// <param name="options"> Additional options for the chat completion request. </param>
/// <param name="cancellationToken">A token that can be used to cancel this method call.</param>
/// <returns> A streaming result with incremental chat completion updates. </returns>
public virtual AsyncResultCollection<StreamingChatCompletionUpdate> CompleteChatStreamingAsync(IEnumerable<ChatMessage> messages, ChatCompletionOptions options = null)
public virtual AsyncResultCollection<StreamingChatCompletionUpdate> CompleteChatStreamingAsync(IEnumerable<ChatMessage> messages, ChatCompletionOptions options = null, CancellationToken cancellationToken = default)
{
Argument.AssertNotNull(messages, nameof(messages));

options ??= new();
CreateChatCompletionOptions(messages, ref options, stream: true);

using BinaryContent content = options.ToBinaryContent();
RequestOptions requestOptions = new() { BufferResponse = false };

async Task<ClientResult> getResultAsync() =>
await CompleteChatAsync(content, requestOptions).ConfigureAwait(false);
await CompleteChatAsync(content, cancellationToken.ToRequestOptions(streaming: true)).ConfigureAwait(false);
return new AsyncStreamingChatCompletionUpdateCollection(getResultAsync);
}

Expand All @@ -164,17 +169,17 @@ public virtual AsyncResultCollection<StreamingChatCompletionUpdate> CompleteChat
/// </remarks>
/// <param name="messages"> The messages to provide as input for chat completion. </param>
/// <param name="options"> Additional options for the chat completion request. </param>
/// <param name="cancellationToken">A token that can be used to cancel this method call.</param>
/// <returns> A streaming result with incremental chat completion updates. </returns>
public virtual ResultCollection<StreamingChatCompletionUpdate> CompleteChatStreaming(IEnumerable<ChatMessage> messages, ChatCompletionOptions options = null)
public virtual ResultCollection<StreamingChatCompletionUpdate> CompleteChatStreaming(IEnumerable<ChatMessage> messages, ChatCompletionOptions options = null, CancellationToken cancellationToken = default)
{
Argument.AssertNotNull(messages, nameof(messages));

options ??= new();
CreateChatCompletionOptions(messages, ref options, stream: true);

using BinaryContent content = options.ToBinaryContent();
RequestOptions requestOptions = new() { BufferResponse = false };
ClientResult getResult() => CompleteChat(content, requestOptions);
ClientResult getResult() => CompleteChat(content, cancellationToken.ToRequestOptions(streaming: true));
return new StreamingChatCompletionUpdateCollection(getResult);
}

Expand Down
22 changes: 22 additions & 0 deletions src/Custom/Common/CancellationTokenExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using System.ClientModel.Primitives;
using System.Threading;

internal static class CancellationTokenExtensions
{
public static RequestOptions ToRequestOptions(this CancellationToken cancellationToken, bool streaming = false)
{
if (cancellationToken == default)
{
if (!streaming) return null;
return StreamRequestOptions;
}

return new RequestOptions() {
CancellationToken = cancellationToken,
BufferResponse = !streaming,
};
}

private static RequestOptions StreamRequestOptions => _streamRequestOptions ??= new() { BufferResponse = false };
private static RequestOptions _streamRequestOptions;
}
Loading

0 comments on commit 19a65a0

Please sign in to comment.