diff --git a/src/Shared/runtime/Http2/Hpack/HPackEncoder.cs b/src/Shared/runtime/Http2/Hpack/HPackEncoder.cs index d2fbc52232a9..4c3ac2952704 100644 --- a/src/Shared/runtime/Http2/Hpack/HPackEncoder.cs +++ b/src/Shared/runtime/Http2/Hpack/HPackEncoder.cs @@ -4,6 +4,7 @@ #nullable enable using System.Collections.Generic; using System.Diagnostics; +using System.Text; namespace System.Net.Http.HPack { @@ -96,7 +97,7 @@ public static bool EncodeLiteralHeaderFieldWithoutIndexing(int index, string val if (IntegerEncoder.Encode(index, 4, destination, out int indexLength)) { Debug.Assert(indexLength >= 1); - if (EncodeStringLiteral(value, destination.Slice(indexLength), out int nameLength)) + if (EncodeStringLiteral(value, valueEncoding: null, destination.Slice(indexLength), out int nameLength)) { bytesWritten = indexLength + nameLength; return true; @@ -128,7 +129,7 @@ public static bool EncodeLiteralHeaderFieldNeverIndexing(int index, string value if (IntegerEncoder.Encode(index, 4, destination, out int indexLength)) { Debug.Assert(indexLength >= 1); - if (EncodeStringLiteral(value, destination.Slice(indexLength), out int nameLength)) + if (EncodeStringLiteral(value, valueEncoding: null, destination.Slice(indexLength), out int nameLength)) { bytesWritten = indexLength + nameLength; return true; @@ -160,7 +161,7 @@ public static bool EncodeLiteralHeaderFieldIndexing(int index, string value, Spa if (IntegerEncoder.Encode(index, 6, destination, out int indexLength)) { Debug.Assert(indexLength >= 1); - if (EncodeStringLiteral(value, destination.Slice(indexLength), out int nameLength)) + if (EncodeStringLiteral(value, valueEncoding: null, destination.Slice(indexLength), out int nameLength)) { bytesWritten = indexLength + nameLength; return true; @@ -276,7 +277,7 @@ private static bool EncodeLiteralHeaderNewNameCore(byte mask, string name, strin { destination[0] = mask; if (EncodeLiteralHeaderName(name, destination.Slice(1), out int nameLength) && - EncodeStringLiteral(value, destination.Slice(1 + nameLength), out int valueLength)) + EncodeStringLiteral(value, valueEncoding: null, destination.Slice(1 + nameLength), out int valueLength)) { bytesWritten = 1 + nameLength + valueLength; return true; @@ -289,6 +290,11 @@ private static bool EncodeLiteralHeaderNewNameCore(byte mask, string name, strin /// Encodes a "Literal Header Field without Indexing - New Name". public static bool EncodeLiteralHeaderFieldWithoutIndexingNewName(string name, ReadOnlySpan values, string separator, Span destination, out int bytesWritten) + { + return EncodeLiteralHeaderFieldWithoutIndexingNewName(name, values, separator, valueEncoding: null, destination, out bytesWritten); + } + + public static bool EncodeLiteralHeaderFieldWithoutIndexingNewName(string name, ReadOnlySpan values, string separator, Encoding? valueEncoding, Span destination, out int bytesWritten) { // From https://tools.ietf.org/html/rfc7541#section-6.2.2 // ------------------------------------------------------ @@ -309,7 +315,7 @@ public static bool EncodeLiteralHeaderFieldWithoutIndexingNewName(string name, R { destination[0] = 0; if (EncodeLiteralHeaderName(name, destination.Slice(1), out int nameLength) && - EncodeStringLiterals(values, separator, destination.Slice(1 + nameLength), out int valueLength)) + EncodeStringLiterals(values, separator, valueEncoding, destination.Slice(1 + nameLength), out int valueLength)) { bytesWritten = 1 + nameLength + valueLength; return true; @@ -395,27 +401,20 @@ private static bool EncodeLiteralHeaderName(string value, Span destination return false; } - private static bool EncodeStringLiteralValue(string value, Span destination, out int bytesWritten) + private static void EncodeValueStringPart(string value, Span destination) { - if (value.Length <= destination.Length) + Debug.Assert(destination.Length >= value.Length); + + for (int i = 0; i < value.Length; i++) { - for (int i = 0; i < value.Length; i++) + char c = value[i]; + if ((c & 0xFF80) != 0) { - char c = value[i]; - if ((c & 0xFF80) != 0) - { - throw new HttpRequestException(SR.net_http_request_invalid_char_encoding); - } - - destination[i] = (byte)c; + throw new HttpRequestException(SR.net_http_request_invalid_char_encoding); } - bytesWritten = value.Length; - return true; + destination[i] = (byte)c; } - - bytesWritten = 0; - return false; } public static bool EncodeStringLiteral(ReadOnlySpan value, Span destination, out int bytesWritten) @@ -453,6 +452,11 @@ public static bool EncodeStringLiteral(ReadOnlySpan value, Span dest } public static bool EncodeStringLiteral(string value, Span destination, out int bytesWritten) + { + return EncodeStringLiteral(value, valueEncoding: null, destination, out bytesWritten); + } + + public static bool EncodeStringLiteral(string value, Encoding? valueEncoding, Span destination, out int bytesWritten) { // From https://tools.ietf.org/html/rfc7541#section-5.2 // ------------------------------------------------------ @@ -466,13 +470,28 @@ public static bool EncodeStringLiteral(string value, Span destination, out if (destination.Length != 0) { destination[0] = 0; // TODO: Use Huffman encoding - if (IntegerEncoder.Encode(value.Length, 7, destination, out int integerLength)) + + int encodedStringLength = valueEncoding is null || ReferenceEquals(valueEncoding, Encoding.Latin1) + ? value.Length + : valueEncoding.GetByteCount(value); + + if (IntegerEncoder.Encode(encodedStringLength, 7, destination, out int integerLength)) { Debug.Assert(integerLength >= 1); - - if (EncodeStringLiteralValue(value, destination.Slice(integerLength), out int valueLength)) + destination = destination.Slice(integerLength); + if (encodedStringLength <= destination.Length) { - bytesWritten = integerLength + valueLength; + if (valueEncoding is null) + { + EncodeValueStringPart(value, destination); + } + else + { + int written = valueEncoding.GetBytes(value, destination); + Debug.Assert(written == encodedStringLength); + } + + bytesWritten = integerLength + encodedStringLength; return true; } } @@ -502,56 +521,87 @@ public static bool EncodeDynamicTableSizeUpdate(int value, Span destinatio } public static bool EncodeStringLiterals(ReadOnlySpan values, string? separator, Span destination, out int bytesWritten) + { + return EncodeStringLiterals(values, separator, valueEncoding: null, destination, out bytesWritten); + } + + public static bool EncodeStringLiterals(ReadOnlySpan values, string? separator, Encoding? valueEncoding, Span destination, out int bytesWritten) { bytesWritten = 0; if (values.Length == 0) { - return EncodeStringLiteral("", destination, out bytesWritten); + return EncodeStringLiteral("", valueEncoding: null, destination, out bytesWritten); } else if (values.Length == 1) { - return EncodeStringLiteral(values[0], destination, out bytesWritten); + return EncodeStringLiteral(values[0], valueEncoding, destination, out bytesWritten); } if (destination.Length != 0) { - int valueLength = 0; + Debug.Assert(separator != null); + int valueLength; // Calculate length of all parts and separators. - foreach (string part in values) + if (valueEncoding is null || ReferenceEquals(valueEncoding, Encoding.Latin1)) { - valueLength = checked((int)(valueLength + part.Length)); + valueLength = checked((int)(values.Length - 1) * separator.Length); + foreach (string part in values) + { + valueLength = checked((int)(valueLength + part.Length)); + } + } + else + { + valueLength = checked((int)(values.Length - 1) * valueEncoding.GetByteCount(separator)); + foreach (string part in values) + { + valueLength = checked((int)(valueLength + valueEncoding.GetByteCount(part))); + } } - - Debug.Assert(separator != null); - valueLength = checked((int)(valueLength + (values.Length - 1) * separator.Length)); destination[0] = 0; if (IntegerEncoder.Encode(valueLength, 7, destination, out int integerLength)) { Debug.Assert(integerLength >= 1); - - int encodedLength = 0; - for (int j = 0; j < values.Length; j++) + destination = destination.Slice(integerLength); + if (destination.Length >= valueLength) { - if (j != 0 && !EncodeStringLiteralValue(separator, destination.Slice(integerLength), out encodedLength)) + if (valueEncoding is null) { - return false; + string value = values[0]; + EncodeValueStringPart(value, destination); + destination = destination.Slice(value.Length); + + for (int i = 1; i < values.Length; i++) + { + EncodeValueStringPart(separator, destination); + destination = destination.Slice(separator.Length); + + value = values[i]; + EncodeValueStringPart(value, destination); + destination = destination.Slice(value.Length); + } } + else + { + int written = valueEncoding.GetBytes(values[0], destination); + destination = destination.Slice(written); - integerLength += encodedLength; + for (int i = 1; i < values.Length; i++) + { + written = valueEncoding.GetBytes(separator, destination); + destination = destination.Slice(written); - if (!EncodeStringLiteralValue(values[j], destination.Slice(integerLength), out encodedLength)) - { - return false; + written = valueEncoding.GetBytes(values[i], destination); + destination = destination.Slice(written); + } } - integerLength += encodedLength; + bytesWritten = integerLength + valueLength; + return true; } - - bytesWritten = integerLength; - return true; } } diff --git a/src/Shared/runtime/Http3/QPack/QPackEncoder.cs b/src/Shared/runtime/Http3/QPack/QPackEncoder.cs index be43dc3bc716..68e04ed2d4ce 100644 --- a/src/Shared/runtime/Http3/QPack/QPackEncoder.cs +++ b/src/Shared/runtime/Http3/QPack/QPackEncoder.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Net.Http.HPack; +using System.Text; namespace System.Net.Http.QPack { @@ -59,6 +60,11 @@ public static byte[] EncodeStaticIndexedHeaderFieldToArray(int index) // - T is constant 1 here, indicating a static table reference. // - H is constant 0 here, as we do not yet perform Huffman coding. public static bool EncodeLiteralHeaderFieldWithStaticNameReference(int index, string value, Span destination, out int bytesWritten) + { + return EncodeLiteralHeaderFieldWithStaticNameReference(index, value, valueEncoding: null, destination, out bytesWritten); + } + + public static bool EncodeLiteralHeaderFieldWithStaticNameReference(int index, string value, Encoding? valueEncoding, Span destination, out int bytesWritten) { // Requires at least two bytes (one for name reference header, one for value length) if (destination.Length >= 2) @@ -68,7 +74,7 @@ public static bool EncodeLiteralHeaderFieldWithStaticNameReference(int index, st { destination = destination.Slice(headerBytesWritten); - if (EncodeValueString(value, destination, out int valueBytesWritten)) + if (EncodeValueString(value, valueEncoding, destination, out int valueBytesWritten)) { bytesWritten = headerBytesWritten + valueBytesWritten; return true; @@ -81,7 +87,7 @@ public static bool EncodeLiteralHeaderFieldWithStaticNameReference(int index, st } /// - /// Encodes just the name part of a Literal Header Field With Static Name Reference. Must call after to encode the header's value. + /// Encodes just the name part of a Literal Header Field With Static Name Reference. Must call after to encode the header's value. /// public static byte[] EncodeLiteralHeaderFieldWithStaticNameReferenceToArray(int index) { @@ -119,7 +125,12 @@ public static byte[] EncodeLiteralHeaderFieldWithStaticNameReferenceToArray(int // - H is constant 0 here, as we do not yet perform Huffman coding. public static bool EncodeLiteralHeaderFieldWithoutNameReference(string name, string value, Span destination, out int bytesWritten) { - if (EncodeNameString(name, destination, out int nameLength) && EncodeValueString(value, destination.Slice(nameLength), out int valueLength)) + return EncodeLiteralHeaderFieldWithoutNameReference(name, value, valueEncoding: null, destination, out bytesWritten); + } + + public static bool EncodeLiteralHeaderFieldWithoutNameReference(string name, string value, Encoding? valueEncoding, Span destination, out int bytesWritten) + { + if (EncodeNameString(name, destination, out int nameLength) && EncodeValueString(value, valueEncoding, destination.Slice(nameLength), out int valueLength)) { bytesWritten = nameLength + valueLength; return true; @@ -136,7 +147,12 @@ public static bool EncodeLiteralHeaderFieldWithoutNameReference(string name, str /// public static bool EncodeLiteralHeaderFieldWithoutNameReference(string name, ReadOnlySpan values, string valueSeparator, Span destination, out int bytesWritten) { - if (EncodeNameString(name, destination, out int nameLength) && EncodeValueString(values, valueSeparator, destination.Slice(nameLength), out int valueLength)) + return EncodeLiteralHeaderFieldWithoutNameReference(name, values, valueSeparator, valueEncoding: null, destination, out bytesWritten); + } + + public static bool EncodeLiteralHeaderFieldWithoutNameReference(string name, ReadOnlySpan values, string valueSeparator, Encoding? valueEncoding, Span destination, out int bytesWritten) + { + if (EncodeNameString(name, destination, out int nameLength) && EncodeValueString(values, valueSeparator, valueEncoding, destination.Slice(nameLength), out int valueLength)) { bytesWritten = nameLength + valueLength; return true; @@ -147,7 +163,7 @@ public static bool EncodeLiteralHeaderFieldWithoutNameReference(string name, Rea } /// - /// Encodes just the value part of a Literawl Header Field Without Static Name Reference. Must call after to encode the header's value. + /// Encodes just the value part of a Literawl Header Field Without Static Name Reference. Must call after to encode the header's value. /// public static byte[] EncodeLiteralHeaderFieldWithoutNameReferenceToArray(string name) { @@ -169,19 +185,32 @@ public static byte[] EncodeLiteralHeaderFieldWithoutNameReferenceToArray(string return temp.Slice(0, bytesWritten).ToArray(); } - private static bool EncodeValueString(string s, Span buffer, out int length) + private static bool EncodeValueString(string s, Encoding? valueEncoding, Span buffer, out int length) { if (buffer.Length != 0) { buffer[0] = 0; - if (IntegerEncoder.Encode(s.Length, 7, buffer, out int nameLength)) + + int encodedStringLength = valueEncoding is null || ReferenceEquals(valueEncoding, Encoding.Latin1) + ? s.Length + : valueEncoding.GetByteCount(s); + + if (IntegerEncoder.Encode(encodedStringLength, 7, buffer, out int nameLength)) { buffer = buffer.Slice(nameLength); - if (buffer.Length >= s.Length) + if (buffer.Length >= encodedStringLength) { - EncodeValueStringPart(s, buffer); + if (valueEncoding is null) + { + EncodeValueStringPart(s, buffer); + } + else + { + int written = valueEncoding.GetBytes(s, buffer); + Debug.Assert(written == encodedStringLength); + } - length = nameLength + s.Length; + length = nameLength + encodedStringLength; return true; } } @@ -195,25 +224,42 @@ private static bool EncodeValueString(string s, Span buffer, out int lengt /// Encodes a value by concatenating a collection of strings, separated by a separator string. /// public static bool EncodeValueString(ReadOnlySpan values, string? separator, Span buffer, out int length) + { + return EncodeValueString(values, separator, valueEncoding: null, buffer, out length); + } + + public static bool EncodeValueString(ReadOnlySpan values, string? separator, Encoding? valueEncoding, Span buffer, out int length) { if (values.Length == 1) { - return EncodeValueString(values[0], buffer, out length); + return EncodeValueString(values[0], valueEncoding, buffer, out length); } if (values.Length == 0) { // TODO: this will be called with a string array from HttpHeaderCollection. Can we ever get a 0-length array from that? Assert if not. - return EncodeValueString(string.Empty, buffer, out length); + return EncodeValueString(string.Empty, valueEncoding: null, buffer, out length); } if (buffer.Length > 0) { Debug.Assert(separator != null); - int valueLength = separator.Length * (values.Length - 1); - for (int i = 0; i < values.Length; ++i) + int valueLength; + if (valueEncoding is null || ReferenceEquals(valueEncoding, Encoding.Latin1)) + { + valueLength = separator.Length * (values.Length - 1); + foreach (string part in values) + { + valueLength += part.Length; + } + } + else { - valueLength += values[i].Length; + valueLength = valueEncoding.GetByteCount(separator) * (values.Length - 1); + foreach (string part in values) + { + valueLength += valueEncoding.GetByteCount(part); + } } buffer[0] = 0; @@ -222,18 +268,35 @@ public static bool EncodeValueString(ReadOnlySpan values, string? separa buffer = buffer.Slice(nameLength); if (buffer.Length >= valueLength) { - string value = values[0]; - EncodeValueStringPart(value, buffer); - buffer = buffer.Slice(value.Length); - - for (int i = 1; i < values.Length; ++i) + if (valueEncoding is null) { - EncodeValueStringPart(separator, buffer); - buffer = buffer.Slice(separator.Length); - - value = values[i]; + string value = values[0]; EncodeValueStringPart(value, buffer); buffer = buffer.Slice(value.Length); + + for (int i = 1; i < values.Length; i++) + { + EncodeValueStringPart(separator, buffer); + buffer = buffer.Slice(separator.Length); + + value = values[i]; + EncodeValueStringPart(value, buffer); + buffer = buffer.Slice(value.Length); + } + } + else + { + int written = valueEncoding.GetBytes(values[0], buffer); + buffer = buffer.Slice(written); + + for (int i = 1; i < values.Length; i++) + { + written = valueEncoding.GetBytes(separator, buffer); + buffer = buffer.Slice(written); + + written = valueEncoding.GetBytes(values[i], buffer); + buffer = buffer.Slice(written); + } } length = nameLength + valueLength;