diff --git a/src/libraries/System.Net.Requests/src/System/Net/HttpWebResponse.cs b/src/libraries/System.Net.Requests/src/System/Net/HttpWebResponse.cs index 7b0e9b90681fc9..96bdffff721763 100644 --- a/src/libraries/System.Net.Requests/src/System/Net/HttpWebResponse.cs +++ b/src/libraries/System.Net.Requests/src/System/Net/HttpWebResponse.cs @@ -346,7 +346,7 @@ public override Stream GetResponseStream() return contentStream; } - return new TruncatedReadStream(contentStream, maxErrorResponseLength); + return new TruncatedReadStream(contentStream, (long)maxErrorResponseLength * 1024); } return Stream.Null; @@ -381,8 +381,9 @@ private void CheckDisposed() private static string GetHeaderValueAsString(IEnumerable values) => string.Join(", ", values); - internal sealed class TruncatedReadStream(Stream innerStream, int maxSize) : Stream + internal sealed class TruncatedReadStream(Stream innerStream, long maxSize) : Stream { + private long _maxRemainingLength = maxSize; public override bool CanRead => true; public override bool CanSeek => false; public override bool CanWrite => false; @@ -399,8 +400,8 @@ public override int Read(byte[] buffer, int offset, int count) public override int Read(Span buffer) { - int readBytes = innerStream.Read(buffer.Slice(0, Math.Min(buffer.Length, maxSize))); - maxSize -= readBytes; + int readBytes = innerStream.Read(buffer.Slice(0, (int)Math.Min(buffer.Length, _maxRemainingLength))); + _maxRemainingLength -= readBytes; return readBytes; } @@ -411,9 +412,9 @@ public override Task ReadAsync(byte[] buffer, int offset, int count, Cancel public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) { - int readBytes = await innerStream.ReadAsync(buffer.Slice(0, Math.Min(buffer.Length, maxSize)), cancellationToken) + int readBytes = await innerStream.ReadAsync(buffer.Slice(0, (int)Math.Min(buffer.Length, _maxRemainingLength)), cancellationToken) .ConfigureAwait(false); - maxSize -= readBytes; + _maxRemainingLength -= readBytes; return readBytes; } diff --git a/src/libraries/System.Net.Requests/tests/HttpWebRequestTest.cs b/src/libraries/System.Net.Requests/tests/HttpWebRequestTest.cs index 5e91ba7f732efd..2dcbc910bff4bc 100644 --- a/src/libraries/System.Net.Requests/tests/HttpWebRequestTest.cs +++ b/src/libraries/System.Net.Requests/tests/HttpWebRequestTest.cs @@ -2318,33 +2318,70 @@ await server.AcceptConnectionAsync( [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] public async Task SendHttpRequest_WhenDefaultMaximumErrorResponseLengthSet_Success() { - await RemoteExecutor.Invoke(async (async) => + await RemoteExecutor.Invoke(async isAsync => { TaskCompletionSource tcs = new TaskCompletionSource(); await LoopbackServer.CreateClientAndServerAsync( - async (uri) => + async uri => { HttpWebRequest request = WebRequest.CreateHttp(uri); - HttpWebRequest.DefaultMaximumErrorResponseLength = 5; - var exception = - await Assert.ThrowsAsync(() => bool.Parse(async) ? request.GetResponseAsync() : Task.Run(() => request.GetResponse())); + HttpWebRequest.DefaultMaximumErrorResponseLength = 1; // 1 KB + WebException exception = + await Assert.ThrowsAsync(() => bool.Parse(isAsync) ? request.GetResponseAsync() : Task.Run(() => request.GetResponse())); tcs.SetResult(); Assert.NotNull(exception.Response); - using (var responseStream = exception.Response.GetResponseStream()) + using (Stream responseStream = exception.Response.GetResponseStream()) { - var buffer = new byte[10]; - int readLen = responseStream.Read(buffer, 0, buffer.Length); - Assert.Equal(5, readLen); - Assert.Equal(new string('a', 5), Encoding.UTF8.GetString(buffer[0..readLen])); - Assert.Equal(0, responseStream.Read(buffer)); + byte[] buffer = new byte[10 * 1024]; + int totalReadLen = 0; + int readLen = 0; + while ((readLen = responseStream.Read(buffer, readLen, buffer.Length - readLen)) > 0) + { + totalReadLen += readLen; + } + + Assert.Equal(1024, totalReadLen); + Assert.Equal(new string('a', 1024), Encoding.UTF8.GetString(buffer[0..totalReadLen])); } }, - async (server) => + async server => + { + await server.AcceptConnectionAsync( + async connection => + { + await connection.SendResponseAsync(statusCode: HttpStatusCode.BadRequest, content: new string('a', 10 * 1024)); + await tcs.Task; + }); + }); + }, IsAsync.ToString()).DisposeAsync(); + } + + [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + public async Task SendHttpRequest_WhenDefaultMaximumErrorResponseLengthSetToIntMax_DoesNotThrow() + { + await RemoteExecutor.Invoke(async isAsync => + { + TaskCompletionSource tcs = new TaskCompletionSource(); + await LoopbackServer.CreateClientAndServerAsync( + async uri => + { + HttpWebRequest request = WebRequest.CreateHttp(uri); + HttpWebRequest.DefaultMaximumErrorResponseLength = int.MaxValue; // KB + WebException exception = + await Assert.ThrowsAsync(() => bool.Parse(isAsync) ? request.GetResponseAsync() : Task.Run(() => request.GetResponse())); + tcs.SetResult(); + Assert.NotNull(exception.Response); + using (Stream responseStream = exception.Response.GetResponseStream()) + { + Assert.Equal(1, await responseStream.ReadAsync(new byte[1])); + } + }, + async server => { await server.AcceptConnectionAsync( async connection => { - await connection.SendResponseAsync(statusCode: HttpStatusCode.BadRequest, content: new string('a', 10)); + await connection.SendResponseAsync(statusCode: HttpStatusCode.BadRequest, content: new string('a', 10 * 1024)); await tcs.Task; }); });