Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize Convert.ToBase64CharArray and TryToBase64Chars #73320

Merged
merged 3 commits into from
Aug 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 127 additions & 53 deletions src/libraries/System.Private.CoreLib/src/System/Convert.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Buffers;
using System.Buffers.Text;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Buffers;
using System.Buffers.Text;
using System.Runtime.Intrinsics;
using System.Text;

namespace System
Expand Down Expand Up @@ -104,6 +105,7 @@ public static partial class Convert
internal const string Base64Table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=";

private const int Base64LineBreakPosition = 76;
private const int Base64VectorizationLengthThreshold = 16;

#if DEBUG
static Convert()
Expand Down Expand Up @@ -2324,7 +2326,7 @@ public static string ToBase64String(byte[] inArray, int offset, int length, Base

public static string ToBase64String(ReadOnlySpan<byte> bytes, Base64FormattingOptions options = Base64FormattingOptions.None)
{
if (options < Base64FormattingOptions.None || options > Base64FormattingOptions.InsertLineBreaks)
if ((uint)options > (uint)Base64FormattingOptions.InsertLineBreaks)
{
throw new ArgumentException(SR.Format(SR.Arg_EnumIllegalVal, (int)options), nameof(options));
}
Expand All @@ -2337,36 +2339,22 @@ public static string ToBase64String(ReadOnlySpan<byte> bytes, Base64FormattingOp
bool insertLineBreaks = (options == Base64FormattingOptions.InsertLineBreaks);
int outputLength = ToBase64_CalculateAndValidateOutputLength(bytes.Length, insertLineBreaks);

if (!insertLineBreaks && bytes.Length >= 64)
{
// For large inputs it's faster to allocate a temp buffer and call the UTF8 version,
// whose output is then widened to UTF16 via Latin1.GetString (base64 is always ASCII)
[MethodImpl(MethodImplOptions.NoInlining)]
static string ToBase64StringLargeInputs(ReadOnlySpan<byte> data, int outputLen)
{
byte[]? rentedBytes = null;
Span<byte> utf8buffer = outputLen <= 256 ? stackalloc byte[256] : (rentedBytes = ArrayPool<byte>.Shared.Rent(outputLen));
OperationStatus status = Base64.EncodeToUtf8(data, utf8buffer, out int _, out int bytesWritten);
Debug.Assert(status == OperationStatus.Done && bytesWritten == outputLen);
string result = Encoding.Latin1.GetString(utf8buffer.Slice(0, outputLen));
if (rentedBytes != null)
{
ArrayPool<byte>.Shared.Return(rentedBytes);
}
return result;
}
return ToBase64StringLargeInputs(bytes, outputLength);
}

string result = string.FastAllocateString(outputLength);

unsafe
if (!insertLineBreaks && bytes.Length >= Base64VectorizationLengthThreshold)
{
fixed (byte* bytesPtr = &MemoryMarshal.GetReference(bytes))
fixed (char* charsPtr = result)
ToBase64CharsLargeNoLineBreaks(bytes, new Span<char>(ref result.GetRawStringData(), result.Length), result.Length);
}
else
{
unsafe
{
int charsWritten = ConvertToBase64Array(charsPtr, bytesPtr, 0, bytes.Length, insertLineBreaks);
Debug.Assert(result.Length == charsWritten, $"Expected {result.Length} == {charsWritten}");
fixed (byte* bytesPtr = &MemoryMarshal.GetReference(bytes))
fixed (char* charsPtr = result)
{
int charsWritten = ConvertToBase64Array(charsPtr, bytesPtr, 0, bytes.Length, insertLineBreaks);
Debug.Assert(result.Length == charsWritten, $"Expected {result.Length} == {charsWritten}");
}
}
}

Expand All @@ -2389,50 +2377,47 @@ public static unsafe int ToBase64CharArray(byte[] inArray, int offsetIn, int len
throw new ArgumentOutOfRangeException(nameof(offsetIn), SR.ArgumentOutOfRange_GenericPositive);
if (offsetOut < 0)
throw new ArgumentOutOfRangeException(nameof(offsetOut), SR.ArgumentOutOfRange_GenericPositive);

if (options < Base64FormattingOptions.None || options > Base64FormattingOptions.InsertLineBreaks)
{
throw new ArgumentException(SR.Format(SR.Arg_EnumIllegalVal, (int)options), nameof(options));
}

int retVal;
int inArrayLength = inArray.Length;

int inArrayLength;
int outArrayLength;
int numElementsToCopy;

inArrayLength = inArray.Length;

if (offsetIn > (int)(inArrayLength - length))
if (offsetIn > (inArrayLength - length))
throw new ArgumentOutOfRangeException(nameof(offsetIn), SR.ArgumentOutOfRange_OffsetLength);

if (inArrayLength == 0)
return 0;

bool insertLineBreaks = (options == Base64FormattingOptions.InsertLineBreaks);
// This is the maximally required length that must be available in the char array
outArrayLength = outArray.Length;
int outArrayLength = outArray.Length;

// Length of the char buffer required
numElementsToCopy = ToBase64_CalculateAndValidateOutputLength(length, insertLineBreaks);
bool insertLineBreaks = options == Base64FormattingOptions.InsertLineBreaks;
int charLengthRequired = ToBase64_CalculateAndValidateOutputLength(length, insertLineBreaks);

if (offsetOut > (int)(outArrayLength - numElementsToCopy))
if (offsetOut > outArrayLength - charLengthRequired)
throw new ArgumentOutOfRangeException(nameof(offsetOut), SR.ArgumentOutOfRange_OffsetOut);

fixed (char* outChars = &outArray[offsetOut])
if (!insertLineBreaks && length >= Base64VectorizationLengthThreshold)
{
ToBase64CharsLargeNoLineBreaks(new ReadOnlySpan<byte>(inArray, offsetIn, length), outArray.AsSpan(offsetOut), charLengthRequired);
}
else
{
fixed (char* outChars = &outArray[offsetOut])
fixed (byte* inData = &inArray[0])
{
retVal = ConvertToBase64Array(outChars, inData, offsetIn, length, insertLineBreaks);
int converted = ConvertToBase64Array(outChars, inData, offsetIn, length, insertLineBreaks);
Debug.Assert(converted == charLengthRequired);
}
}

return retVal;
return charLengthRequired;
}

public static unsafe bool TryToBase64Chars(ReadOnlySpan<byte> bytes, Span<char> chars, out int charsWritten, Base64FormattingOptions options = Base64FormattingOptions.None)
{
if (options < Base64FormattingOptions.None || options > Base64FormattingOptions.InsertLineBreaks)
if ((uint)options > (uint)Base64FormattingOptions.InsertLineBreaks)
{
throw new ArgumentException(SR.Format(SR.Arg_EnumIllegalVal, (int)options), nameof(options));
}
Expand All @@ -2443,7 +2428,7 @@ public static unsafe bool TryToBase64Chars(ReadOnlySpan<byte> bytes, Span<char>
return true;
}

bool insertLineBreaks = (options == Base64FormattingOptions.InsertLineBreaks);
bool insertLineBreaks = options == Base64FormattingOptions.InsertLineBreaks;

int charLengthRequired = ToBase64_CalculateAndValidateOutputLength(bytes.Length, insertLineBreaks);
if (charLengthRequired > chars.Length)
Expand All @@ -2452,12 +2437,101 @@ public static unsafe bool TryToBase64Chars(ReadOnlySpan<byte> bytes, Span<char>
return false;
}

fixed (char* outChars = &MemoryMarshal.GetReference(chars))
fixed (byte* inData = &MemoryMarshal.GetReference(bytes))
if (!insertLineBreaks && bytes.Length >= Base64VectorizationLengthThreshold)
{
charsWritten = ConvertToBase64Array(outChars, inData, 0, bytes.Length, insertLineBreaks);
return true;
ToBase64CharsLargeNoLineBreaks(bytes, chars, charLengthRequired);
}
else
{
fixed (char* outChars = &MemoryMarshal.GetReference(chars))
fixed (byte* inData = &MemoryMarshal.GetReference(bytes))
{
int converted = ConvertToBase64Array(outChars, inData, 0, bytes.Length, insertLineBreaks);
Debug.Assert(converted == charLengthRequired);
}
}

charsWritten = charLengthRequired;
return true;
}

/// <summary>Base64 encodes the bytes from <paramref name="bytes"/> into <paramref name="chars"/>.</summary>
/// <param name="bytes">The bytes to encode.</param>
/// <param name="chars">The destination buffer large enough to handle the encoded chars.</param>
/// <param name="charLengthRequired">The pre-calculated, exact number of chars that will be written.</param>
/// <remarks>
/// Works in two phases: the UTF8 Base64 encoder writes its output into the first half of the
/// storage backing <paramref name="chars"/>, and those bytes are then widened to chars in place,
/// walking from the end of the data towards the beginning.
/// </remarks>
private static unsafe void ToBase64CharsLargeNoLineBreaks(ReadOnlySpan<byte> bytes, Span<char> chars, int charLengthRequired)
{
// For large enough inputs, it's beneficial to use the vectorized UTF8-based Base64 encoding
// and then widen the resulting bytes into chars.
Debug.Assert(bytes.Length >= Base64VectorizationLengthThreshold);
Debug.Assert(chars.Length >= charLengthRequired);
Debug.Assert(charLengthRequired % 4 == 0);

// Base64-encode the bytes directly into the destination char buffer (reinterpreted as a byte buffer).
OperationStatus status = Base64.EncodeToUtf8(bytes, MemoryMarshal.AsBytes(chars), out _, out int bytesWritten);
Debug.Assert(status == OperationStatus.Done && charLengthRequired == bytesWritten);

// Now widen the ASCII bytes in-place to chars (if the vectorized ASCIIUtility.WidenAsciiToUtf16 is ever updated
// to support in-place updates, it should be used here instead). Since the base64 bytes are all valid ASCII, the byte
// data is guaranteed to be 1/2 as long as the char data, and we can widen in-place.
// dest and src alias the same starting address; dest is typed as ushort and src as byte, so
// the same element count moves dest twice as many bytes as it moves src.
ref ushort dest = ref Unsafe.As<char, ushort>(ref MemoryMarshal.GetReference(chars));
ref byte src = ref Unsafe.As<ushort, byte>(ref dest);
ref byte srcBeginning = ref src;

// We process the bytes/chars from right to left to avoid overwriting the remaining unprocessed data.
// The refs start out pointing just past the end of the data, and each iteration of a loop bumps
// the refs back the appropriate amount and performs the copy/widening.
dest = ref Unsafe.Add(ref dest, charLengthRequired);
src = ref Unsafe.Add(ref src, charLengthRequired);

// Handle 32 bytes at a time.
if (Vector256.IsHardwareAccelerated)
{
// src being past (srcBeginning + 31) means at least 32 unprocessed bytes remain.
ref byte srcBeginningPlus31 = ref Unsafe.Add(ref srcBeginning, 31);
while (Unsafe.IsAddressGreaterThan(ref src, ref srcBeginningPlus31))
{
// Step back 32 bytes in the source and 32 chars in the destination.
src = ref Unsafe.Subtract(ref src, 32);
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
dest = ref Unsafe.Subtract(ref dest, 32);

// The full 32-byte load completes before either store, so the stores cannot corrupt this chunk.
(Vector256<ushort> utf16Lower, Vector256<ushort> utf16Upper) = Vector256.Widen(Vector256.LoadUnsafe(ref src));

// Each widened half holds 16 chars; the upper half is stored 16 chars past dest.
utf16Lower.StoreUnsafe(ref dest);
utf16Upper.StoreUnsafe(ref dest, 16);
}
}

// Handle 16 bytes at a time.
if (Vector128.IsHardwareAccelerated)
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
{
// Same scheme as the 32-byte loop, with 16 bytes read and two 8-char halves written per iteration.
ref byte srcBeginningPlus15 = ref Unsafe.Add(ref srcBeginning, 15);
while (Unsafe.IsAddressGreaterThan(ref src, ref srcBeginningPlus15))
{
src = ref Unsafe.Subtract(ref src, 16);
dest = ref Unsafe.Subtract(ref dest, 16);

(Vector128<ushort> utf16Lower, Vector128<ushort> utf16Upper) = Vector128.Widen(Vector128.LoadUnsafe(ref src));

utf16Lower.StoreUnsafe(ref dest);
utf16Upper.StoreUnsafe(ref dest, 8);
}
}

// Handle 4 bytes at a time.
ref byte srcBeginningPlus3 = ref Unsafe.Add(ref srcBeginning, 3);
while (Unsafe.IsAddressGreaterThan(ref src, ref srcBeginningPlus3))
{
dest = ref Unsafe.Subtract(ref dest, 4);
src = ref Unsafe.Subtract(ref src, 4);
// The four source bytes are read as a single uint before the helper writes the four chars.
ASCIIUtility.WidenFourAsciiBytesToUtf16AndWriteToBuffer(ref Unsafe.As<ushort, char>(ref dest), Unsafe.ReadUnaligned<uint>(ref src));
}

// The length produced by Base64 encoding is always a multiple of 4, so we don't need to handle
// 1 byte at a time as is common in other vectorized operations, as nothing will remain after
// the 4-byte loop.

Debug.Assert(Unsafe.AreSame(ref srcBeginning, ref src));
Debug.Assert(Unsafe.AreSame(ref srcBeginning, ref Unsafe.As<ushort, byte>(ref dest)),
"The two references should have ended up exactly at the beginning");
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
}

private static unsafe int ConvertToBase64Array(char* outChars, byte* inData, int offset, int length, bool insertLineBreaks)
Expand Down
33 changes: 32 additions & 1 deletion src/libraries/System.Runtime.Extensions/tests/System/Convert.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public static void ToBase64CharArrayTest()
int length = Convert.ToBase64CharArray(barray, 0, barray.Length, carray, 0, Base64FormattingOptions.InsertLineBreaks);
int length2 = Convert.ToBase64CharArray(barray, 0, barray.Length, carray, 0, Base64FormattingOptions.None);
Assert.Equal(352, length);
Assert.Equal(352, length);
Assert.Equal(344, length2);
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
}

[Fact]
Expand All @@ -46,6 +46,37 @@ public static void ToBase64StringTest()
Assert.Equal(subset, Convert.FromBase64String(s3));
}

[Fact]
public static void Base64_AllMethodsRoundtripConsistently()
{
    // A fixed seed keeps the generated payloads identical from run to run.
    Random rng = new Random(42);

    for (int size = 0; size < 128; size++)
    {
        byte[] input = new byte[size];
        rng.NextBytes(input);

        // The string-returning overload acts as the reference encoding for the other APIs.
        string expectedText = Convert.ToBase64String(input);

        // Encode through the array-based API and compare against the reference.
        char[] viaArray = new char[expectedText.Length];
        int written = Convert.ToBase64CharArray(input, 0, input.Length, viaArray, 0);
        Assert.Equal(viaArray.Length, written);
        AssertExtensions.SequenceEqual<char>(expectedText, viaArray);

        // Encode through the span-based API and compare against the reference.
        char[] viaSpan = new char[expectedText.Length];
        Assert.True(Convert.TryToBase64Chars(input, viaSpan, out written));
        Assert.Equal(viaSpan.Length, written);
        AssertExtensions.SequenceEqual<char>(expectedText, viaSpan);

        // Every decoder must reproduce the original bytes.
        AssertExtensions.SequenceEqual(input, Convert.FromBase64String(expectedText));
        AssertExtensions.SequenceEqual(input, Convert.FromBase64CharArray(viaArray, 0, viaArray.Length));

        byte[] decoded = new byte[input.Length];
        Assert.True(Convert.TryFromBase64Chars(viaSpan, decoded, out int decodedCount));
        Assert.Equal(input.Length, decodedCount);
        AssertExtensions.SequenceEqual(input, decoded);
    }
}

[Fact]
public void ToBooleanTests()
{
Expand Down