Skip to content

Commit

Permalink
ClientModel: Move buffering into the transport (#41772)
Browse files Browse the repository at this point in the history
* Move buffering into the transport

* update

* nit

* pr fb

* move network timeout value initialization to transport

* Update contract for Response.Content

* pr fb

* Add exception if stream position is not 0

* bug fix

* pr fb

* Add CreateAsync factory method to ClientResultException
  • Loading branch information
annelo-msft authored Feb 13, 2024
1 parent a4d19fa commit 88b55fd
Show file tree
Hide file tree
Showing 16 changed files with 469 additions and 363 deletions.
11 changes: 4 additions & 7 deletions sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public ClientResultException(System.ClientModel.Primitives.PipelineResponse resp
protected ClientResultException(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { }
public ClientResultException(string message, System.ClientModel.Primitives.PipelineResponse? response = null, System.Exception? innerException = null) { }
public int Status { get { throw null; } protected set { } }
public static System.Threading.Tasks.Task<System.ClientModel.ClientResultException> CreateAsync(System.ClientModel.Primitives.PipelineResponse response, System.Exception? innerException = null) { throw null; }
public override void GetObjectData(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { }
public System.ClientModel.Primitives.PipelineResponse? GetRawResponse() { throw null; }
}
Expand Down Expand Up @@ -205,14 +206,16 @@ protected PipelineRequestHeaders() { }
public abstract partial class PipelineResponse : System.IDisposable
{
protected PipelineResponse() { }
public virtual System.BinaryData Content { get { throw null; } }
public abstract System.BinaryData Content { get; }
public abstract System.IO.Stream? ContentStream { get; set; }
public System.ClientModel.Primitives.PipelineResponseHeaders Headers { get { throw null; } }
public virtual bool IsError { get { throw null; } }
public abstract string ReasonPhrase { get; }
public abstract int Status { get; }
public abstract void Dispose();
protected abstract System.ClientModel.Primitives.PipelineResponseHeaders GetHeadersCore();
public abstract System.BinaryData ReadContent(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken));
public abstract System.Threading.Tasks.ValueTask<System.BinaryData> ReadContentAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken));
protected virtual void SetIsErrorCore(bool isError) { }
}
public abstract partial class PipelineResponseHeaders : System.Collections.Generic.IEnumerable<System.Collections.Generic.KeyValuePair<string, string>>, System.Collections.IEnumerable
Expand Down Expand Up @@ -246,10 +249,4 @@ protected void AssertNotFrozen() { }
public virtual void Freeze() { }
public void SetHeader(string name, string value) { }
}
public partial class ResponseBufferingPolicy : System.ClientModel.Primitives.PipelinePolicy
{
public ResponseBufferingPolicy() { }
public sealed override void Process(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList<System.ClientModel.Primitives.PipelinePolicy> pipeline, int currentIndex) { }
public sealed override System.Threading.Tasks.ValueTask ProcessAsync(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList<System.ClientModel.Primitives.PipelinePolicy> pipeline, int currentIndex) { throw null; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public ClientResultException(System.ClientModel.Primitives.PipelineResponse resp
protected ClientResultException(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { }
public ClientResultException(string message, System.ClientModel.Primitives.PipelineResponse? response = null, System.Exception? innerException = null) { }
public int Status { get { throw null; } protected set { } }
public static System.Threading.Tasks.Task<System.ClientModel.ClientResultException> CreateAsync(System.ClientModel.Primitives.PipelineResponse response, System.Exception? innerException = null) { throw null; }
public override void GetObjectData(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { }
public System.ClientModel.Primitives.PipelineResponse? GetRawResponse() { throw null; }
}
Expand Down Expand Up @@ -204,14 +205,16 @@ protected PipelineRequestHeaders() { }
public abstract partial class PipelineResponse : System.IDisposable
{
protected PipelineResponse() { }
public virtual System.BinaryData Content { get { throw null; } }
public abstract System.BinaryData Content { get; }
public abstract System.IO.Stream? ContentStream { get; set; }
public System.ClientModel.Primitives.PipelineResponseHeaders Headers { get { throw null; } }
public virtual bool IsError { get { throw null; } }
public abstract string ReasonPhrase { get; }
public abstract int Status { get; }
public abstract void Dispose();
protected abstract System.ClientModel.Primitives.PipelineResponseHeaders GetHeadersCore();
public abstract System.BinaryData ReadContent(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken));
public abstract System.Threading.Tasks.ValueTask<System.BinaryData> ReadContentAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken));
protected virtual void SetIsErrorCore(bool isError) { }
}
public abstract partial class PipelineResponseHeaders : System.Collections.Generic.IEnumerable<System.Collections.Generic.KeyValuePair<string, string>>, System.Collections.IEnumerable
Expand Down Expand Up @@ -245,10 +248,4 @@ protected void AssertNotFrozen() { }
public virtual void Freeze() { }
public void SetHeader(string name, string value) { }
}
public partial class ResponseBufferingPolicy : System.ClientModel.Primitives.PipelinePolicy
{
public ResponseBufferingPolicy() { }
public sealed override void Process(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList<System.ClientModel.Primitives.PipelinePolicy> pipeline, int currentIndex) { }
public sealed override System.Threading.Tasks.ValueTask ProcessAsync(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList<System.ClientModel.Primitives.PipelinePolicy> pipeline, int currentIndex) { throw null; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Globalization;
using System.Runtime.Serialization;
using System.Text;
using System.Threading.Tasks;

namespace System.ClientModel;

Expand All @@ -17,6 +18,12 @@ public class ClientResultException : Exception, ISerializable
private readonly PipelineResponse? _response;
private int _status;

public static async Task<ClientResultException> CreateAsync(PipelineResponse response, Exception? innerException = default)
{
string message = await CreateMessageAsync(response).ConfigureAwait(false);
return new ClientResultException(message, response, innerException);
}

/// <summary>
/// Gets the HTTP status code of the response. Returns. <code>0</code> if response was not received.
/// </summary>
Expand Down Expand Up @@ -66,8 +73,21 @@ public override void GetObjectData(SerializationInfo info, StreamingContext cont
public PipelineResponse? GetRawResponse() => _response;

private static string CreateMessage(PipelineResponse response)
=> CreateMessageSyncOrAsync(response, async: false).EnsureCompleted();

private static async ValueTask<string> CreateMessageAsync(PipelineResponse response)
=> await CreateMessageSyncOrAsync(response, async: true).ConfigureAwait(false);

private static async ValueTask<string> CreateMessageSyncOrAsync(PipelineResponse response, bool async)
{
response.BufferContent();
if (async)
{
await response.ReadContentAsync().ConfigureAwait(false);
}
else
{
response.ReadContent();
}

StringBuilder messageBuilder = new();

Expand Down
27 changes: 24 additions & 3 deletions sdk/core/System.ClientModel/src/Internal/CancellationHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,32 @@ private static void ThrowOperationCanceledException(Exception? innerException, C

/// <summary>Throws a cancellation exception if cancellation has been requested via <paramref name="cancellationToken"/>.</summary>
/// <param name="cancellationToken">The token to check for a cancellation request.</param>
internal static void ThrowIfCancellationRequested(CancellationToken cancellationToken)
/// <param name="innerException">The inner exception to wrap. May be null.</param>
internal static void ThrowIfCancellationRequested(CancellationToken cancellationToken, Exception? innerException = default)
{
if (cancellationToken.IsCancellationRequested)
{
ThrowOperationCanceledException(innerException: null, cancellationToken);
ThrowOperationCanceledException(innerException, cancellationToken);
}
}

/// <summary>Throws a cancellation exception if cancellation has been requested via <paramref name="messageToken"/> or <paramref name="timeoutToken"/>.</summary>
/// <param name="messageToken">The user-provided token.</param>
/// <param name="timeoutToken">The linked token that is cancelled on timeout provided token.</param>
/// <param name="innerException">The inner exception to use.</param>
/// <param name="timeout">The timeout used for the operation.</param>
#pragma warning disable CA1068 // Cancellation token has to be the last parameter
internal static void ThrowIfCancellationRequestedOrTimeout(CancellationToken messageToken, CancellationToken timeoutToken, Exception? innerException, TimeSpan timeout)
#pragma warning restore CA1068
{
ThrowIfCancellationRequested(messageToken, innerException);

if (timeoutToken.IsCancellationRequested)
{
throw CreateOperationCanceledException(
innerException,
timeoutToken,
$"The operation was cancelled because it exceeded the configured timeout of {timeout:g}. ");
}
}
}
}
17 changes: 8 additions & 9 deletions sdk/core/System.ClientModel/src/Internal/ReadTimeoutStream.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.ClientModel.Primitives;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -60,26 +59,26 @@ public override void Flush()

public override int Read(byte[] buffer, int offset, int count)
{
var source = StartTimeout(default, out bool dispose);
CancellationTokenSource source = StartTimeout(default, out bool dispose);
try
{
return _stream.Read(buffer, offset, count);
}
// We dispose stream on timeout so catch and check if cancellation token was cancelled
catch (IOException ex)
{
ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout);
CancellationHelper.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout);
throw;
}
// We dispose stream on timeout so catch and check if cancellation token was cancelled
catch (ObjectDisposedException ex)
{
ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout);
CancellationHelper.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout);
throw;
}
catch (OperationCanceledException ex)
{
ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout);
CancellationHelper.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout);
throw;
}
finally
Expand All @@ -90,7 +89,7 @@ public override int Read(byte[] buffer, int offset, int count)

public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
var source = StartTimeout(cancellationToken, out bool dispose);
CancellationTokenSource source = StartTimeout(cancellationToken, out bool dispose);
try
{
#pragma warning disable CA1835 // ReadAsync(Memory<>) overload is not available in all targets
Expand All @@ -100,18 +99,18 @@ public override async Task<int> ReadAsync(byte[] buffer, int offset, int count,
// We dispose stream on timeout so catch and check if cancellation token was cancelled
catch (IOException ex)
{
ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout);
CancellationHelper.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout);
throw;
}
// We dispose stream on timeout so catch and check if cancellation token was cancelled
catch (ObjectDisposedException ex)
{
ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout);
CancellationHelper.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout);
throw;
}
catch (OperationCanceledException ex)
{
ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout);
CancellationHelper.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout);
throw;
}
finally
Expand Down
44 changes: 44 additions & 0 deletions sdk/core/System.ClientModel/src/Internal/StreamExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ namespace System.ClientModel.Internal;

internal static class StreamExtensions
{
// Same value as Stream.CopyTo uses by default
private const int DefaultCopyBufferSize = 81920;

public static async Task WriteAsync(this Stream stream, ReadOnlyMemory<byte> buffer, CancellationToken cancellation = default)
{
Argument.AssertNotNull(stream, nameof(stream));
Expand Down Expand Up @@ -86,4 +89,45 @@ public static async Task WriteAsync(this Stream stream, ReadOnlySequence<byte> b
ArrayPool<byte>.Shared.Return(array);
}
}

public static async Task CopyToAsync(this Stream source, Stream destination, CancellationToken cancellationToken)
{
byte[] buffer = ArrayPool<byte>.Shared.Rent(DefaultCopyBufferSize);

try
{
while (true)
{
#pragma warning disable CA1835 // ReadAsync(Memory<>) overload is not available in all targets
int bytesRead = await source.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false);
#pragma warning restore // ReadAsync(Memory<>) overload is not available in all targets
if (bytesRead == 0)
break;
await destination.WriteAsync(new ReadOnlyMemory<byte>(buffer, 0, bytesRead), cancellationToken).ConfigureAwait(false);
}
}
finally
{
ArrayPool<byte>.Shared.Return(buffer);
}
}

public static void CopyTo(this Stream source, Stream destination, CancellationToken cancellationToken)
{
byte[] buffer = ArrayPool<byte>.Shared.Rent(DefaultCopyBufferSize);

try
{
int read;
while ((read = source.Read(buffer, 0, buffer.Length)) != 0)
{
cancellationToken.ThrowIfCancellationRequested();
destination.Write(buffer, 0, read);
}
}
finally
{
ArrayPool<byte>.Shared.Return(buffer);
}
}
}
Loading

0 comments on commit 88b55fd

Please sign in to comment.