Summary of this patch:

  * XmlStreamNodeWriter.UnsafeGetUnicodeChars becomes static. The per-char loop that
    split each char into two bytes is replaced by a span copy when
    BitConverter.IsLittleEndian is true, and by BinaryPrimitives.ReverseEndianness
    otherwise. The return value is still charCount * 2.
  * XmlStreamNodeWriter.UnsafeGetUTF8Length no longer pre-scans for chars >= 0x80;
    it passes the whole input to Encoding.GetByteCount.
  * XmlStreamNodeWriter.UnsafeGetUTF8Chars keeps the manual ASCII copy loop only when
    charCount < 32. On the first char >= 0x80 the remaining chars are handed to
    Encoding.GetBytes; inputs of 32 or more chars go to Encoding.GetBytes directly.
  * XmlDictionaryWriterTest gains XmlBaseWriter_WriteString, which writes all-ASCII
    strings and strings whose last char is U+00E4 at lengths 7 through 258 with the
    binary writer and compares the record type, length prefix and payload bytes.

diff --git a/src/libraries/System.Private.DataContractSerialization/src/System/Xml/XmlStreamNodeWriter.cs b/src/libraries/System.Private.DataContractSerialization/src/System/Xml/XmlStreamNodeWriter.cs
index a0efa808691c0..ea629d725ec48 100644
--- a/src/libraries/System.Private.DataContractSerialization/src/System/Xml/XmlStreamNodeWriter.cs
+++ b/src/libraries/System.Private.DataContractSerialization/src/System/Xml/XmlStreamNodeWriter.cs
@@ -1,8 +1,10 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System.Buffers.Binary;
 using System.IO;
 using System.Text;
+using System.Runtime.InteropServices;
 using System.Runtime.Serialization;
 using System.Threading.Tasks;
 using System.Diagnostics;
@@ -330,34 +332,26 @@ protected unsafe void UnsafeWriteUnicodeChars(char* chars, int charCount)
             }
         }
 
-        protected unsafe int UnsafeGetUnicodeChars(char* chars, int charCount, byte[] buffer, int offset)
+        protected static unsafe int UnsafeGetUnicodeChars(char* chars, int charCount, byte[] buffer, int offset)
         {
-            char* charsMax = chars + charCount;
-            while (chars < charsMax)
+            if (BitConverter.IsLittleEndian)
             {
-                char value = *chars++;
-                buffer[offset++] = (byte)value;
-                value >>= 8;
-                buffer[offset++] = (byte)value;
+                new ReadOnlySpan<char>(chars, charCount)
+                    .CopyTo(MemoryMarshal.Cast<byte, char>(buffer.AsSpan(offset)));
             }
+            else
+            {
+                BinaryPrimitives.ReverseEndianness(new ReadOnlySpan<short>(chars, charCount),
+                    MemoryMarshal.Cast<byte, short>(buffer.AsSpan(offset)));
+            }
+
             return charCount * 2;
         }
 
         protected unsafe int UnsafeGetUTF8Length(char* chars, int charCount)
         {
-            char* charsMax = chars + charCount;
-            while (chars < charsMax)
-            {
-                if (*chars >= 0x80)
-                    break;
-
-                chars++;
-            }
-
-            if (chars == charsMax)
-                return charCount;
-
-            return (int)(chars - (charsMax - charCount)) + (_encoding ?? DataContractSerializer.ValidatingUTF8).GetByteCount(chars, (int)(charsMax - chars));
+            // Length will always be at least ( 128 / maxBytesPerChar) = 42
+            return (_encoding ?? DataContractSerializer.ValidatingUTF8).GetByteCount(chars, charCount);
         }
 
         protected unsafe int UnsafeGetUTF8Chars(char* chars, int charCount, byte[] buffer, int offset)
@@ -366,39 +360,32 @@ protected unsafe int UnsafeGetUTF8Chars(char* chars, int charCount, byte[] buffe
             {
                 fixed (byte* _bytes = &buffer[offset])
                 {
-                    byte* bytes = _bytes;
-                    byte* bytesMax = &bytes[buffer.Length - offset];
-                    char* charsMax = &chars[charCount];
-
-                    while (true)
+                    // Fast path for small strings, use Encoding.GetBytes for larger strings since it is faster when vectorization is possible
+                    if ((uint)charCount < 32)
                     {
+                        byte* bytes = _bytes;
+                        char* charsMax = &chars[charCount];
+
                         while (chars < charsMax)
                         {
                             char t = *chars;
                             if (t >= 0x80)
-                                break;
+                                goto NonAscii;
 
                             *bytes = (byte)t;
                             bytes++;
                             chars++;
                         }
+                        return charCount;
 
-                        if (chars >= charsMax)
-                            break;
-
-                        char* charsStart = chars;
-                        while (chars < charsMax && *chars >= 0x80)
-                        {
-                            chars++;
-                        }
-
-                        bytes += (_encoding ?? DataContractSerializer.ValidatingUTF8).GetBytes(charsStart, (int)(chars - charsStart), bytes, (int)(bytesMax - bytes));
-
-                        if (chars >= charsMax)
-                            break;
+                    NonAscii:
+                        byte* bytesMax = _bytes + buffer.Length - offset;
+                        return (int)(bytes - _bytes) + (_encoding ?? DataContractSerializer.ValidatingUTF8).GetBytes(chars, (int)(charsMax - chars), bytes, (int)(bytesMax - bytes));
+                    }
+                    else
+                    {
+                        return (_encoding ?? DataContractSerializer.ValidatingUTF8).GetBytes(chars, charCount, _bytes, buffer.Length - offset);
                     }
-
-                    return (int)(bytes - _bytes);
                 }
             }
             return 0;
diff --git a/src/libraries/System.Runtime.Serialization.Xml/tests/XmlDictionaryWriterTest.cs b/src/libraries/System.Runtime.Serialization.Xml/tests/XmlDictionaryWriterTest.cs
index 3ad4d32400e37..b3b5a8495cf5b 100644
--- a/src/libraries/System.Runtime.Serialization.Xml/tests/XmlDictionaryWriterTest.cs
+++ b/src/libraries/System.Runtime.Serialization.Xml/tests/XmlDictionaryWriterTest.cs
@@ -494,6 +494,71 @@ void AssertBytesWritten(Action action, XmlBinaryNodeType no
         }
     }
 
+    [Fact]
+    public static void XmlBaseWriter_WriteString()
+    {
+        const byte Chars8Text = 152;
+        const byte Chars16Text = 154;
+        MemoryStream ms = new MemoryStream();
+        XmlDictionaryWriter writer = (XmlDictionaryWriter)XmlDictionaryWriter.CreateBinaryWriter(ms);
+        writer.WriteStartElement("root");
+
+        int[] lengths = new[] { 7, 8, 9, 15, 16, 17, 31, 32, 36, 258 };
+        byte[] buffer = new byte[lengths.Max() + 1];
+
+        foreach (var length in lengths)
+        {
+            string allAscii = string.Create(length, null, (Span<char> chars, object _) =>
+            {
+                for (int i = 0; i < chars.Length; ++i)
+                    chars[i] = (char)(i % 128);
+            });
+            string multiByteLast = string.Create(length, null, (Span<char> chars, object _) =>
+            {
+                for (int i = 0; i < chars.Length; ++i)
+                    chars[i] = (char)(i % 128);
+                chars[^1] = '\u00E4'; // 'ä' - Latin Small Letter a with Diaeresis. Latin-1 Supplement.
+            });
+
+            int numBytes = Encoding.UTF8.GetBytes(allAscii, buffer);
+            Assert.True(numBytes == length, "Test setup wrong - allAscii");
+            ValidateWriteText(ms, writer, allAscii, expected: buffer.AsSpan(0, numBytes));
+
+            numBytes = Encoding.UTF8.GetBytes(multiByteLast, buffer);
+            Assert.True(numBytes == length + 1, "Test setup wrong - multiByte");
+            ValidateWriteText(ms, writer, multiByteLast, expected: buffer.AsSpan(0, numBytes));
+        }
+
+        static void ValidateWriteText(MemoryStream ms, XmlDictionaryWriter writer, string text, ReadOnlySpan<byte> expected)
+        {
+            writer.Flush();
+            ms.Seek(0, SeekOrigin.Begin);
+            ms.SetLength(0);
+            writer.WriteString(text);
+            writer.Flush();
+
+            ms.TryGetBuffer(out ArraySegment<byte> arraySegment);
+            ReadOnlySpan<byte> buffer = arraySegment;
+
+            if (expected.Length <= byte.MaxValue)
+            {
+                Assert.Equal(Chars8Text, buffer[0]);
+                Assert.Equal(expected.Length, buffer[1]);
+                buffer = buffer.Slice(2);
+            }
+            else if (expected.Length <= ushort.MaxValue)
+            {
+                Assert.Equal(Chars16Text, buffer[0]);
+                Assert.Equal(expected.Length, (int)(buffer[1]) | ((int)buffer[2] << 8));
+                buffer = buffer.Slice(3);
+            }
+            else
+                Assert.Fail("test use to long length");
+
+            AssertExtensions.SequenceEqual(expected, buffer);
+        }
+    }
+
     private static bool ReadTest(MemoryStream ms, Encoding encoding, ReaderWriterFactory.ReaderWriterType rwType, byte[] byteArray)
     {
         ms.Position = 0;