Skip to content

Commit

Permalink
Add internal Encoding.TryGetBytes (#84609)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub authored Apr 11, 2023
1 parent aa27a07 commit 19ff978
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,31 @@ public override unsafe int GetBytes(ReadOnlySpan<char> chars, Span<byte> bytes)
}
}

// TODO https://github.com/dotnet/runtime/issues/84425: Make this public.
/// <summary>Encodes into a span of bytes a set of characters from the specified read-only span if the destination is large enough.</summary>
/// <param name="chars">The span containing the set of characters to encode.</param>
/// <param name="bytes">The byte span to hold the encoded bytes.</param>
/// <param name="bytesWritten">Upon successful completion of the operation, the number of bytes encoded into <paramref name="bytes"/>.</param>
/// <returns><see langword="true"/> if all of the characters were encoded into the destination; <see langword="false"/> if the destination was too small to contain all the encoded bytes.</returns>
internal override unsafe bool TryGetBytes(ReadOnlySpan<char> chars, Span<byte> bytes, out int bytesWritten)
{
fixed (char* charsPtr = &MemoryMarshal.GetReference(chars))
fixed (byte* bytesPtr = &MemoryMarshal.GetReference(bytes))
{
int written = GetBytesCommon(charsPtr, chars.Length, bytesPtr, bytes.Length, throwForDestinationOverflow: false);
if (written >= 0)
{
bytesWritten = written;
return true;
}

bytesWritten = 0;
return false;
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private unsafe int GetBytesCommon(char* pChars, int charCount, byte* pBytes, int byteCount)
private unsafe int GetBytesCommon(char* pChars, int charCount, byte* pBytes, int byteCount, bool throwForDestinationOverflow = true)
{
// Common helper method for all non-EncoderNLS entry points to GetBytes.
// A modification of this method should be copied in to each of the supported encodings: ASCII, UTF8, UTF16, UTF32.
Expand All @@ -347,7 +370,7 @@ private unsafe int GetBytesCommon(char* pChars, int charCount, byte* pBytes, int
{
// Simple narrowing conversion couldn't operate on entire buffer - invoke fallback.

return GetBytesWithFallback(pChars, charCount, pBytes, byteCount, charsConsumed, bytesWritten);
return GetBytesWithFallback(pChars, charCount, pBytes, byteCount, charsConsumed, bytesWritten, throwForDestinationOverflow);
}
}

Expand All @@ -360,7 +383,7 @@ private protected sealed override unsafe int GetBytesFast(char* pChars, int char
return bytesWritten;
}

private protected sealed override unsafe int GetBytesWithFallback(ReadOnlySpan<char> chars, int originalCharsLength, Span<byte> bytes, int originalBytesLength, EncoderNLS? encoder)
private protected sealed override unsafe int GetBytesWithFallback(ReadOnlySpan<char> chars, int originalCharsLength, Span<byte> bytes, int originalBytesLength, EncoderNLS? encoder, bool throwForDestinationOverflow = true)
{
// We special-case EncoderReplacementFallback if it's telling us to write a single ASCII char,
// since we believe this to be relatively common and we can handle it more efficiently than
Expand Down Expand Up @@ -406,7 +429,7 @@ private protected sealed override unsafe int GetBytesWithFallback(ReadOnlySpan<c
}
else
{
return base.GetBytesWithFallback(chars, originalCharsLength, bytes, originalBytesLength, encoder);
return base.GetBytesWithFallback(chars, originalCharsLength, bytes, originalBytesLength, encoder, throwForDestinationOverflow);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ private protected virtual unsafe int GetBytesFast(char* pChars, int charsLength,
/// If the destination buffer is not large enough to hold the entirety of the transcoded data.
/// </exception>
[MethodImpl(MethodImplOptions.NoInlining)]
private protected unsafe int GetBytesWithFallback(char* pOriginalChars, int originalCharCount, byte* pOriginalBytes, int originalByteCount, int charsConsumedSoFar, int bytesWrittenSoFar)
private protected unsafe int GetBytesWithFallback(char* pOriginalChars, int originalCharCount, byte* pOriginalBytes, int originalByteCount, int charsConsumedSoFar, int bytesWrittenSoFar, bool throwForDestinationOverflow = true)
{
// This is a stub method that's marked "no-inlining" so that it we don't stack-spill spans
// into our immediate caller. Doing so increases the method prolog in what's supposed to
Expand All @@ -499,7 +499,8 @@ private protected unsafe int GetBytesWithFallback(char* pOriginalChars, int orig
originalCharsLength: originalCharCount,
bytes: new Span<byte>(pOriginalBytes, originalByteCount).Slice(bytesWrittenSoFar),
originalBytesLength: originalByteCount,
encoder: null);
encoder: null,
throwForDestinationOverflow);
}

/// <summary>
Expand All @@ -508,7 +509,7 @@ private protected unsafe int GetBytesWithFallback(char* pOriginalChars, int orig
/// and <paramref name="bytesWrittenSoFar"/> signal where in the provided buffers the fallback loop
/// should begin operating. The behavior of this method is to drain any leftover data in the
/// <see cref="EncoderNLS"/> instance, then to invoke the <see cref="GetBytesFast"/> virtual method
/// after data has been drained, then to call <see cref="GetBytesWithFallback(ReadOnlySpan{char}, int, Span{byte}, int, EncoderNLS)"/>.
/// after data has been drained, then to call <see cref="GetBytesWithFallback(ReadOnlySpan{char}, int, Span{byte}, int, EncoderNLS, bool)"/>.
/// </summary>
/// <returns>
/// The total number of bytes written to <paramref name="pOriginalBytes"/>, including <paramref name="bytesWrittenSoFar"/>.
Expand Down Expand Up @@ -582,7 +583,7 @@ private unsafe int GetBytesWithFallback(char* pOriginalChars, int originalCharCo
/// implementation, deferring to the base implementation if needed. This method calls <see cref="ThrowBytesOverflow"/>
/// if necessary.
/// </remarks>
private protected virtual unsafe int GetBytesWithFallback(ReadOnlySpan<char> chars, int originalCharsLength, Span<byte> bytes, int originalBytesLength, EncoderNLS? encoder)
private protected virtual unsafe int GetBytesWithFallback(ReadOnlySpan<char> chars, int originalCharsLength, Span<byte> bytes, int originalBytesLength, EncoderNLS? encoder, bool throwForDestinationOverflow = true)
{
Debug.Assert(!chars.IsEmpty, "Caller shouldn't invoke this method with an empty input buffer.");
Debug.Assert(originalCharsLength >= 0, "Caller provided invalid parameter.");
Expand Down Expand Up @@ -678,8 +679,15 @@ private protected virtual unsafe int GetBytesWithFallback(ReadOnlySpan<char> cha
// The line below will also throw if the encoder couldn't make any progress at all
// because the output buffer wasn't large enough to contain the result of even
// a single scalar conversion or fallback.

ThrowBytesOverflow(encoder, nothingEncoded: bytes.Length == originalBytesLength);
if (throwForDestinationOverflow)
{
ThrowBytesOverflow(encoder, nothingEncoded: bytes.Length == originalBytesLength);
}
else
{
Debug.Assert(encoder is null);
return -1;
}
}

// If an EncoderNLS instance is active, update its "total consumed character count" value.
Expand Down
19 changes: 19 additions & 0 deletions src/libraries/System.Private.CoreLib/src/System/Text/Encoding.cs
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,25 @@ public virtual unsafe int GetBytes(ReadOnlySpan<char> chars, Span<byte> bytes)
}
}

// TODO https://github.com/dotnet/runtime/issues/84425: Make this public.
/// <summary>Encodes into a span of bytes a set of characters from the specified read-only span if the destination is large enough.</summary>
/// <param name="chars">The span containing the set of characters to encode.</param>
/// <param name="bytes">The byte span to hold the encoded bytes.</param>
/// <param name="bytesWritten">Upon successful completion of the operation, the number of bytes encoded into <paramref name="bytes"/>.</param>
/// <returns><see langword="true"/> if all of the characters were encoded into the destination; <see langword="false"/> if the destination was too small to contain all the encoded bytes.</returns>
internal virtual bool TryGetBytes(ReadOnlySpan<char> chars, Span<byte> bytes, out int bytesWritten)
{
int required = GetByteCount(chars);
if (required <= bytes.Length)
{
bytesWritten = GetBytes(chars, bytes);
return true;
}

bytesWritten = 0;
return false;
}

// Returns the number of characters produced by decoding the given byte
// array.
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,29 @@ public override unsafe int GetBytes(ReadOnlySpan<char> chars, Span<byte> bytes)
}
}

// TODO https://github.com/dotnet/runtime/issues/84425: Make this public.
/// <summary>Encodes into a span of bytes a set of characters from the specified read-only span if the destination is large enough.</summary>
/// <param name="chars">The span containing the set of characters to encode.</param>
/// <param name="bytes">The byte span to hold the encoded bytes.</param>
/// <param name="bytesWritten">Upon successful completion of the operation, the number of bytes encoded into <paramref name="bytes"/>.</param>
/// <returns><see langword="true"/> if all of the characters were encoded into the destination; <see langword="false"/> if the destination was too small to contain all the encoded bytes.</returns>
internal override unsafe bool TryGetBytes(ReadOnlySpan<char> chars, Span<byte> bytes, out int bytesWritten)
{
fixed (char* charsPtr = &MemoryMarshal.GetReference(chars))
fixed (byte* bytesPtr = &MemoryMarshal.GetReference(bytes))
{
int written = GetBytesCommon(charsPtr, chars.Length, bytesPtr, bytes.Length, throwForDestinationOverflow: false);
if (written >= 0)
{
bytesWritten = written;
return true;
}

bytesWritten = 0;
return false;
}
}

public override unsafe int GetBytes(string s, int charIndex, int charCount, byte[] bytes, int byteIndex)
{
if (s is null || bytes is null)
Expand Down Expand Up @@ -269,7 +292,7 @@ public override unsafe int GetBytes(string s, int charIndex, int charCount, byte


[MethodImpl(MethodImplOptions.AggressiveInlining)]
private unsafe int GetBytesCommon(char* pChars, int charCount, byte* pBytes, int byteCount)
private unsafe int GetBytesCommon(char* pChars, int charCount, byte* pBytes, int byteCount, bool throwForDestinationOverflow = true)
{
// Common helper method for all non-EncoderNLS entry points to GetBytes.
// A modification of this method should be copied in to each of the supported encodings: ASCII, UTF8, UTF16, UTF32.
Expand All @@ -293,7 +316,7 @@ private unsafe int GetBytesCommon(char* pChars, int charCount, byte* pBytes, int
{
// Simple narrowing conversion couldn't operate on entire buffer - invoke fallback.

return GetBytesWithFallback(pChars, charCount, pBytes, byteCount, charsConsumed, bytesWritten);
return GetBytesWithFallback(pChars, charCount, pBytes, byteCount, charsConsumed, bytesWritten, throwForDestinationOverflow);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,13 @@ private unsafe string GetStringForSmallInput(byte[] bytes)

return new string(new ReadOnlySpan<char>(ref *pDestination, charsWritten)); // this overload of ROS ctor doesn't validate length
}

// TODO https://github.com/dotnet/runtime/issues/84425: Make this public.
// TODO: Make this [Intrinsic] and handle JIT-time UTF8 encoding of literal `chars`.
internal override unsafe bool TryGetBytes(ReadOnlySpan<char> chars, Span<byte> bytes, out int bytesWritten)
{
return base.TryGetBytes(chars, bytes, out bytesWritten);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,31 @@ public override unsafe int GetBytes(ReadOnlySpan<char> chars, Span<byte> bytes)
}
}

// TODO https://github.com/dotnet/runtime/issues/84425: Make this public.
/// <summary>Encodes into a span of bytes a set of characters from the specified read-only span if the destination is large enough.</summary>
/// <param name="chars">The span containing the set of characters to encode.</param>
/// <param name="bytes">The byte span to hold the encoded bytes.</param>
/// <param name="bytesWritten">Upon successful completion of the operation, the number of bytes encoded into <paramref name="bytes"/>.</param>
/// <returns><see langword="true"/> if all of the characters were encoded into the destination; <see langword="false"/> if the destination was too small to contain all the encoded bytes.</returns>
internal override unsafe bool TryGetBytes(ReadOnlySpan<char> chars, Span<byte> bytes, out int bytesWritten)
{
fixed (char* charsPtr = &MemoryMarshal.GetReference(chars))
fixed (byte* bytesPtr = &MemoryMarshal.GetReference(bytes))
{
int written = GetBytesCommon(charsPtr, chars.Length, bytesPtr, bytes.Length, throwForDestinationOverflow: false);
if (written >= 0)
{
bytesWritten = written;
return true;
}

bytesWritten = 0;
return false;
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private unsafe int GetBytesCommon(char* pChars, int charCount, byte* pBytes, int byteCount)
private unsafe int GetBytesCommon(char* pChars, int charCount, byte* pBytes, int byteCount, bool throwForDestinationOverflow = true)
{
// Common helper method for all non-EncoderNLS entry points to GetBytes.
// A modification of this method should be copied in to each of the supported encodings: ASCII, UTF8, UTF16, UTF32.
Expand All @@ -394,7 +417,7 @@ private unsafe int GetBytesCommon(char* pChars, int charCount, byte* pBytes, int
{
// Simple narrowing conversion couldn't operate on entire buffer - invoke fallback.

return GetBytesWithFallback(pChars, charCount, pBytes, byteCount, charsConsumed, bytesWritten);
return GetBytesWithFallback(pChars, charCount, pBytes, byteCount, charsConsumed, bytesWritten, throwForDestinationOverflow);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ public bool AppendFormatted<T>(T value, int alignment, string? format)
/// <param name="value">The span to write.</param>
public bool AppendFormatted(scoped ReadOnlySpan<char> value)
{
if (FromUtf16(value, _destination.Slice(_pos), out _, out int bytesWritten) == OperationStatus.Done)
if (Encoding.UTF8.TryGetBytes(value, _destination.Slice(_pos), out int bytesWritten))
{
_pos += bytesWritten;
return true;
Expand Down

0 comments on commit 19ff978

Please sign in to comment.