diff --git a/src/libraries/System.Private.CoreLib/src/System/Convert.cs b/src/libraries/System.Private.CoreLib/src/System/Convert.cs index f217a5279eb17..ba85c62448b56 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Convert.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Convert.cs @@ -1,13 +1,14 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; +using System.Buffers.Text; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; -using System.Buffers; -using System.Buffers.Text; +using System.Runtime.Intrinsics; using System.Text; namespace System @@ -104,6 +105,7 @@ public static partial class Convert internal const string Base64Table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/="; private const int Base64LineBreakPosition = 76; + private const int Base64VectorizationLengthThreshold = 16; #if DEBUG static Convert() @@ -2324,7 +2326,7 @@ public static string ToBase64String(byte[] inArray, int offset, int length, Base public static string ToBase64String(ReadOnlySpan bytes, Base64FormattingOptions options = Base64FormattingOptions.None) { - if (options < Base64FormattingOptions.None || options > Base64FormattingOptions.InsertLineBreaks) + if ((uint)options > (uint)Base64FormattingOptions.InsertLineBreaks) { throw new ArgumentException(SR.Format(SR.Arg_EnumIllegalVal, (int)options), nameof(options)); } @@ -2337,36 +2339,22 @@ public static string ToBase64String(ReadOnlySpan bytes, Base64FormattingOp bool insertLineBreaks = (options == Base64FormattingOptions.InsertLineBreaks); int outputLength = ToBase64_CalculateAndValidateOutputLength(bytes.Length, insertLineBreaks); - if (!insertLineBreaks && bytes.Length >= 64) - { - // For large inputs it's faster to allocate a temp buffer and call UTF8 version - // which is then extended to UTF8 via Latin1.GetString (base64 is always ASCI) - [MethodImpl(MethodImplOptions.NoInlining)] - static string ToBase64StringLargeInputs(ReadOnlySpan data, int outputLen) - { - byte[]? rentedBytes = null; - Span utf8buffer = outputLen <= 256 ? stackalloc byte[256] : (rentedBytes = ArrayPool.Shared.Rent(outputLen)); - OperationStatus status = Base64.EncodeToUtf8(data, utf8buffer, out int _, out int bytesWritten); - Debug.Assert(status == OperationStatus.Done && bytesWritten == outputLen); - string result = Encoding.Latin1.GetString(utf8buffer.Slice(0, outputLen)); - if (rentedBytes != null) - { - ArrayPool.Shared.Return(rentedBytes); - } - return result; - } - return ToBase64StringLargeInputs(bytes, outputLength); - } - string result = string.FastAllocateString(outputLength); - unsafe + if (!insertLineBreaks && bytes.Length >= Base64VectorizationLengthThreshold) { - fixed (byte* bytesPtr = &MemoryMarshal.GetReference(bytes)) - fixed (char* charsPtr = result) + ToBase64CharsLargeNoLineBreaks(bytes, new Span(ref result.GetRawStringData(), result.Length), result.Length); + } + else + { + unsafe { - int charsWritten = ConvertToBase64Array(charsPtr, bytesPtr, 0, bytes.Length, insertLineBreaks); - Debug.Assert(result.Length == charsWritten, $"Expected {result.Length} == {charsWritten}"); + fixed (byte* bytesPtr = &MemoryMarshal.GetReference(bytes)) + fixed (char* charsPtr = result) + { + int charsWritten = ConvertToBase64Array(charsPtr, bytesPtr, 0, bytes.Length, insertLineBreaks); + Debug.Assert(result.Length == charsWritten, $"Expected {result.Length} == {charsWritten}"); + } } } @@ -2389,50 +2377,47 @@ public static unsafe int ToBase64CharArray(byte[] inArray, int offsetIn, int len throw new ArgumentOutOfRangeException(nameof(offsetIn), SR.ArgumentOutOfRange_GenericPositive); if (offsetOut < 0) throw new ArgumentOutOfRangeException(nameof(offsetOut), SR.ArgumentOutOfRange_GenericPositive); - if (options < Base64FormattingOptions.None || options > Base64FormattingOptions.InsertLineBreaks) - { throw new ArgumentException(SR.Format(SR.Arg_EnumIllegalVal, (int)options), nameof(options)); - } - int retVal; + int inArrayLength = inArray.Length; - int inArrayLength; - int outArrayLength; - int numElementsToCopy; - - inArrayLength = inArray.Length; - - if (offsetIn > (int)(inArrayLength - length)) + if (offsetIn > (inArrayLength - length)) throw new ArgumentOutOfRangeException(nameof(offsetIn), SR.ArgumentOutOfRange_OffsetLength); if (inArrayLength == 0) return 0; - bool insertLineBreaks = (options == Base64FormattingOptions.InsertLineBreaks); // This is the maximally required length that must be available in the char array - outArrayLength = outArray.Length; + int outArrayLength = outArray.Length; // Length of the char buffer required - numElementsToCopy = ToBase64_CalculateAndValidateOutputLength(length, insertLineBreaks); + bool insertLineBreaks = options == Base64FormattingOptions.InsertLineBreaks; + int charLengthRequired = ToBase64_CalculateAndValidateOutputLength(length, insertLineBreaks); - if (offsetOut > (int)(outArrayLength - numElementsToCopy)) + if (offsetOut > outArrayLength - charLengthRequired) throw new ArgumentOutOfRangeException(nameof(offsetOut), SR.ArgumentOutOfRange_OffsetOut); - fixed (char* outChars = &outArray[offsetOut]) + if (!insertLineBreaks && length >= Base64VectorizationLengthThreshold) { + ToBase64CharsLargeNoLineBreaks(new ReadOnlySpan(inArray, offsetIn, length), outArray.AsSpan(offsetOut), charLengthRequired); + } + else + { + fixed (char* outChars = &outArray[offsetOut]) fixed (byte* inData = &inArray[0]) { - retVal = ConvertToBase64Array(outChars, inData, offsetIn, length, insertLineBreaks); + int converted = ConvertToBase64Array(outChars, inData, offsetIn, length, insertLineBreaks); + Debug.Assert(converted == charLengthRequired); } } - return retVal; + return charLengthRequired; } public static unsafe bool TryToBase64Chars(ReadOnlySpan bytes, Span chars, out int charsWritten, Base64FormattingOptions options = Base64FormattingOptions.None) { - if (options < Base64FormattingOptions.None || options > Base64FormattingOptions.InsertLineBreaks) + if ((uint)options > (uint)Base64FormattingOptions.InsertLineBreaks) { throw new ArgumentException(SR.Format(SR.Arg_EnumIllegalVal, (int)options), nameof(options)); } @@ -2443,7 +2428,7 @@ public static unsafe bool TryToBase64Chars(ReadOnlySpan bytes, Span return true; } - bool insertLineBreaks = (options == Base64FormattingOptions.InsertLineBreaks); + bool insertLineBreaks = options == Base64FormattingOptions.InsertLineBreaks; int charLengthRequired = ToBase64_CalculateAndValidateOutputLength(bytes.Length, insertLineBreaks); if (charLengthRequired > chars.Length) @@ -2452,12 +2437,101 @@ public static unsafe bool TryToBase64Chars(ReadOnlySpan bytes, Span return false; } - fixed (char* outChars = &MemoryMarshal.GetReference(chars)) - fixed (byte* inData = &MemoryMarshal.GetReference(bytes)) + if (!insertLineBreaks && bytes.Length >= Base64VectorizationLengthThreshold) { - charsWritten = ConvertToBase64Array(outChars, inData, 0, bytes.Length, insertLineBreaks); - return true; + ToBase64CharsLargeNoLineBreaks(bytes, chars, charLengthRequired); + } + else + { + fixed (char* outChars = &MemoryMarshal.GetReference(chars)) + fixed (byte* inData = &MemoryMarshal.GetReference(bytes)) + { + int converted = ConvertToBase64Array(outChars, inData, 0, bytes.Length, insertLineBreaks); + Debug.Assert(converted == charLengthRequired); + } } + + charsWritten = charLengthRequired; + return true; + } + + /// Base64 encodes the bytes from into . + /// The bytes to encode. + /// The destination buffer large enough to handle the encoded chars. + /// The pre-calculated, exact number of chars that will be written. + private static unsafe void ToBase64CharsLargeNoLineBreaks(ReadOnlySpan bytes, Span chars, int charLengthRequired) + { + // For large enough inputs, it's beneficial to use the vectorized UTF8-based Base64 encoding + // and then widen the resulting bytes into chars. + Debug.Assert(bytes.Length >= Base64VectorizationLengthThreshold); + Debug.Assert(chars.Length >= charLengthRequired); + Debug.Assert(charLengthRequired % 4 == 0); + + // Base64-encode the bytes directly into the destination char buffer (reinterpreted as a byte buffer). + OperationStatus status = Base64.EncodeToUtf8(bytes, MemoryMarshal.AsBytes(chars), out _, out int bytesWritten); + Debug.Assert(status == OperationStatus.Done && charLengthRequired == bytesWritten); + + // Now widen the ASCII bytes in-place to chars (if the vectorized ASCIIUtility.WidenAsciiToUtf16 is ever updated + // to support in-place updates, it should be used here instead). Since the base64 bytes are all valid ASCII, the byte + // data is guaranteed to be 1/2 as long as the char data, and we can widen in-place. + ref ushort dest = ref Unsafe.As(ref MemoryMarshal.GetReference(chars)); + ref byte src = ref Unsafe.As(ref dest); + ref byte srcBeginning = ref src; + + // We process the bytes/chars from right to left to avoid overwriting the remaining unprocessed data. + // The refs start out pointing just past the end of the data, and each iteration of a loop bumps + // the refs back the apropriate amount and performs the copy/widening. + dest = ref Unsafe.Add(ref dest, charLengthRequired); + src = ref Unsafe.Add(ref src, charLengthRequired); + + // Handle 32 bytes at a time. + if (Vector256.IsHardwareAccelerated) + { + ref byte srcBeginningPlus31 = ref Unsafe.Add(ref srcBeginning, 31); + while (Unsafe.IsAddressGreaterThan(ref src, ref srcBeginningPlus31)) + { + src = ref Unsafe.Subtract(ref src, 32); + dest = ref Unsafe.Subtract(ref dest, 32); + + (Vector256 utf16Lower, Vector256 utf16Upper) = Vector256.Widen(Vector256.LoadUnsafe(ref src)); + + utf16Lower.StoreUnsafe(ref dest); + utf16Upper.StoreUnsafe(ref dest, 16); + } + } + + // Handle 16 bytes at a time. + if (Vector128.IsHardwareAccelerated) + { + ref byte srcBeginningPlus15 = ref Unsafe.Add(ref srcBeginning, 15); + while (Unsafe.IsAddressGreaterThan(ref src, ref srcBeginningPlus15)) + { + src = ref Unsafe.Subtract(ref src, 16); + dest = ref Unsafe.Subtract(ref dest, 16); + + (Vector128 utf16Lower, Vector128 utf16Upper) = Vector128.Widen(Vector128.LoadUnsafe(ref src)); + + utf16Lower.StoreUnsafe(ref dest); + utf16Upper.StoreUnsafe(ref dest, 8); + } + } + + // Handle 4 bytes at a time. + ref byte srcBeginningPlus3 = ref Unsafe.Add(ref srcBeginning, 3); + while (Unsafe.IsAddressGreaterThan(ref src, ref srcBeginningPlus3)) + { + dest = ref Unsafe.Subtract(ref dest, 4); + src = ref Unsafe.Subtract(ref src, 4); + ASCIIUtility.WidenFourAsciiBytesToUtf16AndWriteToBuffer(ref Unsafe.As(ref dest), Unsafe.ReadUnaligned(ref src)); + } + + // The length produced by Base64 encoding is always a multiple of 4, so we don't need to handle + // 1 byte at a time as is common in other vectorized operations, as nothing will remain after + // the 4-byte loop. + + Debug.Assert(Unsafe.AreSame(ref srcBeginning, ref src)); + Debug.Assert(Unsafe.AreSame(ref srcBeginning, ref Unsafe.As(ref dest)), + "The two references should have ended up exactly at the beginning"); } private static unsafe int ConvertToBase64Array(char* outChars, byte* inData, int offset, int length, bool insertLineBreaks) diff --git a/src/libraries/System.Runtime.Extensions/tests/System/Convert.cs b/src/libraries/System.Runtime.Extensions/tests/System/Convert.cs index 25f7b0f145aa7..9083367976089 100644 --- a/src/libraries/System.Runtime.Extensions/tests/System/Convert.cs +++ b/src/libraries/System.Runtime.Extensions/tests/System/Convert.cs @@ -29,7 +29,7 @@ public static void ToBase64CharArrayTest() int length = Convert.ToBase64CharArray(barray, 0, barray.Length, carray, 0, Base64FormattingOptions.InsertLineBreaks); int length2 = Convert.ToBase64CharArray(barray, 0, barray.Length, carray, 0, Base64FormattingOptions.None); Assert.Equal(352, length); - Assert.Equal(352, length); + Assert.Equal(344, length2); } [Fact] @@ -46,6 +46,37 @@ public static void ToBase64StringTest() Assert.Equal(subset, Convert.FromBase64String(s3)); } + [Fact] + public static void Base64_AllMethodsRoundtripConsistently() + { + var r = new Random(42); + for (int length = 0; length < 128; length++) + { + var original = new byte[length]; + r.NextBytes(original); + + string encodedString = Convert.ToBase64String(original); + + char[] encodedArray = new char[encodedString.Length]; + int charsWritten = Convert.ToBase64CharArray(original, 0, original.Length, encodedArray, 0); + Assert.Equal(encodedArray.Length, charsWritten); + AssertExtensions.SequenceEqual(encodedString, encodedArray); + + char[] encodedSpan = new char[encodedString.Length]; + Assert.True(Convert.TryToBase64Chars(original, encodedSpan, out charsWritten)); + Assert.Equal(encodedSpan.Length, charsWritten); + AssertExtensions.SequenceEqual(encodedString, encodedSpan); + + AssertExtensions.SequenceEqual(original, Convert.FromBase64String(encodedString)); + AssertExtensions.SequenceEqual(original, Convert.FromBase64CharArray(encodedArray, 0, encodedArray.Length)); + + byte[] actualBytes = new byte[original.Length]; + Assert.True(Convert.TryFromBase64Chars(encodedSpan, actualBytes, out int bytesWritten)); + Assert.Equal(original.Length, bytesWritten); + AssertExtensions.SequenceEqual(original, actualBytes); + } + } + [Fact] public void ToBooleanTests() {