diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Helper/Base64DecoderHelper.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Helper/Base64DecoderHelper.cs index d63377faf110ec..fcc002bae27e4c 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Helper/Base64DecoderHelper.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Helper/Base64DecoderHelper.cs @@ -269,6 +269,7 @@ internal static unsafe OperationStatus DecodeFrom(TBase64Deco static OperationStatus InvalidDataFallback(TBase64Decoder decoder, ReadOnlySpan source, Span bytes, ref int bytesConsumed, ref int bytesWritten, bool isFinalBlock) { + ReadOnlySpan originalSource = source; source = source.Slice(bytesConsumed); bytes = bytes.Slice(bytesWritten); @@ -297,6 +298,7 @@ static OperationStatus InvalidDataFallback(TBase64Decoder decoder, ReadOnlySpan< } // Skip over the starting whitespace and continue. + int whitespaceConsumed = localConsumed; bytesConsumed += localConsumed; source = source.Slice(localConsumed); @@ -306,6 +308,20 @@ static OperationStatus InvalidDataFallback(TBase64Decoder decoder, ReadOnlySpan< bytesWritten += localWritten; if (status is not OperationStatus.InvalidData) { + // If we got DestinationTooSmall and have remaining non-whitespace data, + // fall back to block-wise decoding which can handle small destinations better. + if (status == OperationStatus.DestinationTooSmall && !source.IsEmpty) + { + // Check if there's non-whitespace remaining + int nonWhitespaceIdx = decoder.IndexOfAnyExceptWhiteSpace(source.Slice(localConsumed)); + if (nonWhitespaceIdx >= 0) + { + // Reset and use blockwise decoder from the start + bytesConsumed = 0; + bytesWritten = 0; + return decoder.DecodeWithWhiteSpaceBlockwiseWrapper(decoder, originalSource, bytes, ref bytesConsumed, ref bytesWritten, isFinalBlock); + } + } break; } diff --git a/src/libraries/System.Runtime/tests/System.Runtime.Extensions.Tests/System/Convert.cs b/src/libraries/System.Runtime/tests/System.Runtime.Extensions.Tests/System/Convert.cs index c3ba20a5d76d22..9e91d77757241a 100644 --- a/src/libraries/System.Runtime/tests/System.Runtime.Extensions.Tests/System/Convert.cs +++ b/src/libraries/System.Runtime/tests/System.Runtime.Extensions.Tests/System/Convert.cs @@ -6,6 +6,8 @@ using System.Linq; using System.Text; using System.Collections.Generic; +using System.Buffers.Text; +using System.Buffers; using Test.Cryptography; @@ -297,6 +299,8 @@ public static void TryFromBase64String(string encoded, byte[] expected) bool success = Convert.TryFromBase64String(encoded, actual, out int bytesWritten); Assert.False(success); Assert.Equal(0, bytesWritten); + + Assert.Equal(OperationStatus.InvalidData, Base64.DecodeFromUtf8(Encoding.UTF8.GetBytes(encoded), actual, out _, out _)); } else { @@ -307,6 +311,10 @@ public static void TryFromBase64String(string encoded, byte[] expected) Assert.True(success); Assert.Equal(expected, actual); Assert.Equal(expected.Length, bytesWritten); + + Assert.Equal(OperationStatus.Done, Base64.DecodeFromUtf8(Encoding.UTF8.GetBytes(encoded), actual, out int bytesConsumed, out bytesWritten)); + Assert.Equal(encoded.Length, bytesConsumed); + Assert.Equal(expected.Length, bytesWritten); } // Buffer too short @@ -316,6 +324,9 @@ public static void TryFromBase64String(string encoded, byte[] expected) bool success = Convert.TryFromBase64String(encoded, actual, out int bytesWritten); Assert.False(success); Assert.Equal(0, bytesWritten); + + Assert.Equal(OperationStatus.DestinationTooSmall, Base64.DecodeFromUtf8(Encoding.UTF8.GetBytes(encoded), actual, out _, out bytesWritten)); + Assert.Equal(actual.Length, bytesWritten); } // Buffer larger than needed @@ -327,6 +338,10 @@ public static void TryFromBase64String(string encoded, byte[] expected) Assert.Equal(99, actual[expected.Length]); Assert.Equal(expected, actual.Take(expected.Length)); Assert.Equal(expected.Length, bytesWritten); + + Assert.Equal(OperationStatus.Done, Base64.DecodeFromUtf8(Encoding.UTF8.GetBytes(encoded), actual, out int bytesConsumed, out bytesWritten)); + Assert.Equal(encoded.Length, bytesConsumed); + Assert.Equal(expected.Length, bytesWritten); } } }