From 9b5fa3fa2ad29a7c7cdd4c56b0a1b0b63fd007c6 Mon Sep 17 00:00:00 2001 From: Fabian Grewing Date: Mon, 22 May 2023 14:09:23 +0200 Subject: [PATCH] Call Complete on SubchannelCallTracker only after HttpContent has been disposed --- .../Balancer/Internal/BalancerHttpHandler.cs | 17 +++-- .../Balancer/Internal/HttpContentWrapper.cs | 71 +++++++++++++++++++ 2 files changed, 82 insertions(+), 6 deletions(-) create mode 100644 src/Grpc.Net.Client/Balancer/Internal/HttpContentWrapper.cs diff --git a/src/Grpc.Net.Client/Balancer/Internal/BalancerHttpHandler.cs b/src/Grpc.Net.Client/Balancer/Internal/BalancerHttpHandler.cs index b83554334..22f7debf7 100644 --- a/src/Grpc.Net.Client/Balancer/Internal/BalancerHttpHandler.cs +++ b/src/Grpc.Net.Client/Balancer/Internal/BalancerHttpHandler.cs @@ -149,13 +149,18 @@ protected override async Task SendAsync( { var responseMessage = await responseMessageTask.ConfigureAwait(false); - // TODO(JamesNK): This doesn't take into account long running streams. - // If there is response content then we need to wait until it is read to the end - // or the request is disposed. - result.SubchannelCallTracker?.Complete(new CompletionContext + if (result.SubchannelCallTracker is not null) { - Address = address - }); + if (responseMessage.Content is not null) + { + responseMessage.Content = new HttpContentWrapper(responseMessage.Content, + () => result.SubchannelCallTracker.Complete(new CompletionContext { Address = address })); + } + else + { + result.SubchannelCallTracker.Complete(new CompletionContext { Address = address }); + } + } return responseMessage; } diff --git a/src/Grpc.Net.Client/Balancer/Internal/HttpContentWrapper.cs b/src/Grpc.Net.Client/Balancer/Internal/HttpContentWrapper.cs new file mode 100644 index 000000000..17b8f2789 --- /dev/null +++ b/src/Grpc.Net.Client/Balancer/Internal/HttpContentWrapper.cs @@ -0,0 +1,71 @@ +using System.Net; + +namespace Grpc.Net.Client.Balancer.Internal; + +internal sealed class HttpContentWrapper : HttpContent +{ + private readonly HttpContent _inner; + private readonly Action _disposeAction; + private bool _disposed; + + public HttpContentWrapper(HttpContent inner, Action disposeAction) + { + _inner = inner; + _disposeAction = disposeAction; + + foreach (var kvp in inner.Headers) + { + Headers.TryAddWithoutValidation(kvp.Key, kvp.Value.ToArray()); + } + } + +#if NET5_0_OR_GREATER + + protected override void SerializeToStream(Stream stream, TransportContext? context, CancellationToken cancellationToken) + { + using var content = _inner.ReadAsStream(cancellationToken); + content.CopyTo(stream); + } + + protected override async Task SerializeToStreamAsync(Stream stream, TransportContext? context, CancellationToken cancellationToken) + { + var content = await _inner.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + await using (content.ConfigureAwait(false)) + { + await content.CopyToAsync(stream, cancellationToken).ConfigureAwait(false); + } + } + +#endif + + protected override async Task SerializeToStreamAsync(Stream stream, TransportContext? context) + { + var content = await _inner.ReadAsStreamAsync().ConfigureAwait(false); +#if NET5_0_OR_GREATER + await using (content.ConfigureAwait(false)) +#else + using (content) +#endif + { + await content.CopyToAsync(stream).ConfigureAwait(false); + } + } + + protected override bool TryComputeLength(out long length) + { + length = 0; + return false; + } + + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); + + if (disposing && !_disposed) + { + _disposeAction(); + _inner.Dispose(); + _disposed = true; + } + } +}