Skip to content

Commit c7719ca

Browse files
adamsitnik authored and
jtschuster committed
[NRBF] Comments and bug fixes from internal code review (dotnet#107735)
* copy comments and asserts from Levi's internal code review
* apply Levi's suggestion: don't store Array.MaxLength as a const, as it may change in the future
* add missing and fix some of the existing comments
* first bug fix: SerializationRecord.TypeNameMatches should throw ArgumentNullException for null Type argument
* second bug fix: SerializationRecord.TypeNameMatches should know the difference between SZArray and single-dimension, non-zero offset arrays (example: int[] and int[*])
* third bug fix: don't cast bytes to booleans
* fourth bug fix: don't cast bytes to DateTimes
* add one test case that I forgot in the previous PR
1 parent 0ffdb10 commit c7719ca

28 files changed

+392
-21
lines changed

src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs

+3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
namespace System.Formats.Nrbf;
55

6+
// See [MS-NRBF] Sec. 2.7 for more information.
7+
// https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/ca3ad2bc-777b-413a-a72a-9ba6ced76bc3
8+
69
[Flags]
710
internal enum AllowedRecordTypes : uint
811
{

src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayInfo.cs

+6-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@ namespace System.Formats.Nrbf;
1616
[DebuggerDisplay("{ArrayType}, rank={Rank}")]
1717
internal readonly struct ArrayInfo
1818
{
19-
internal const int MaxArrayLength = 2147483591; // Array.MaxLength
19+
#if NET8_0_OR_GREATER
20+
internal static int MaxArrayLength => Array.MaxLength; // dynamic lookup in case the value changes in a future runtime
21+
#else
22+
internal const int MaxArrayLength = 2147483591; // hardcode legacy Array.MaxLength for downlevel runtimes
23+
#endif
2024

2125
internal ArrayInfo(SerializationRecordId id, long totalElementsCount, BinaryArrayType arrayType = BinaryArrayType.Single, int rank = 1)
2226
{
@@ -47,7 +51,7 @@ internal static int ParseValidArrayLength(BinaryReader reader)
4751
{
4852
int length = reader.ReadInt32();
4953

50-
if (length is < 0 or > MaxArrayLength)
54+
if (length < 0 || length > MaxArrayLength)
5155
{
5256
ThrowHelper.ThrowInvalidValue(length);
5357
}

src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayOfClassesRecord.cs

+4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Collections.Generic;
55
using System.Reflection.Metadata;
66
using System.Formats.Nrbf.Utils;
7+
using System.Diagnostics;
78

89
namespace System.Formats.Nrbf;
910

@@ -54,6 +55,7 @@ public override TypeName TypeName
5455
}
5556

5657
int nullCount = ((NullsRecord)actual).NullCount;
58+
Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount.");
5759
do
5860
{
5961
result[resultIndex++] = null;
@@ -63,6 +65,8 @@ public override TypeName TypeName
6365
}
6466
}
6567

68+
Debug.Assert(resultIndex == result.Length, "We should have traversed the entirety of the newly created array.");
69+
6670
return result;
6771
}
6872

src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs

+6-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.IO;
66
using System.Reflection.Metadata;
77
using System.Formats.Nrbf.Utils;
8+
using System.Diagnostics;
89

910
namespace System.Formats.Nrbf;
1011

@@ -33,13 +34,15 @@ public override TypeName TypeName
3334
{
3435
object?[] values = new object?[Length];
3536

36-
for (int recordIndex = 0, valueIndex = 0; recordIndex < Records.Count; recordIndex++)
37+
int valueIndex = 0;
38+
for (int recordIndex = 0; recordIndex < Records.Count; recordIndex++)
3739
{
3840
SerializationRecord record = Records[recordIndex];
3941

4042
int nullCount = record is NullsRecord nullsRecord ? nullsRecord.NullCount : 0;
4143
if (nullCount == 0)
4244
{
45+
// "new object[] { <SELF> }" is special-cased because it allows for storing a reference to itself.
4346
values[valueIndex++] = record is MemberReferenceRecord referenceRecord && referenceRecord.Reference.Equals(Id)
4447
? values // a reference to self, and a way to get StackOverflow exception ;)
4548
: record.GetValue();
@@ -59,6 +62,8 @@ public override TypeName TypeName
5962
while (nullCount > 0);
6063
}
6164

65+
Debug.Assert(valueIndex == values.Length, "We should have traversed the entirety of the newly created array.");
66+
6267
return values;
6368
}
6469

src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs

+71-3
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,32 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
5353
return (List<T>)(object)DecodeDecimals(reader, count);
5454
}
5555

56+
// char[] has a unique representation in NRBF streams. Typical strings are transcoded
57+
// to UTF-8 and prefixed with the number of bytes in the UTF-8 representation. char[]
58+
// is also serialized as UTF-8, but it is instead prefixed with the number of chars
59+
// in the UTF-16 representation, not the number of bytes in the UTF-8 representation.
60+
// This number doesn't directly precede the UTF-8 contents in the NRBF stream; it's
61+
// instead contained within the ArrayInfo structure (passed to this method as the
62+
// 'count' argument).
63+
//
64+
// The practical consequence of this is that we don't actually know how many UTF-8
65+
// bytes we need to consume in order to ensure we've read 'count' chars. We know that
66+
// an n-length UTF-16 string turns into somewhere between [n .. 3n] UTF-8 bytes.
67+
// The best we can do is that when reading an n-element char[], we'll ensure that
68+
// there are at least n bytes remaining in the input stream. We'll still need to
69+
// account for the fact that, even with this check, we might hit EOF before fully populating
70+
// the char[]. But from a safety perspective, it does appropriately limit our
71+
// allocations to be proportional to the amount of data present in the input stream,
72+
// which is a sufficient defense against DoS.
73+
5674
long requiredBytes = count;
57-
if (typeof(T) != typeof(char)) // the input is UTF8
75+
if (typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan))
76+
{
77+
// We can't assume DateTime as represented by the runtime is 8 bytes.
78+
// The only assumption we can make is that it's 8 bytes on the wire.
79+
requiredBytes *= 8;
80+
}
81+
else if (typeof(T) != typeof(char))
5882
{
5983
requiredBytes *= Unsafe.SizeOf<T>();
6084
}
@@ -79,6 +103,10 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
79103
{
80104
return (T[])(object)reader.ParseChars(count);
81105
}
106+
else if (typeof(T) == typeof(TimeSpan) || typeof(T) == typeof(DateTime))
107+
{
108+
return DecodeTime(reader, count);
109+
}
82110

83111
// It's safe to pre-allocate, as we have ensured there are enough bytes in the stream.
84112
T[] result = new T[count];
@@ -130,8 +158,7 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
130158
}
131159
#endif
132160
}
133-
else if (typeof(T) == typeof(long) || typeof(T) == typeof(ulong) || typeof(T) == typeof(double)
134-
|| typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan))
161+
else if (typeof(T) == typeof(long) || typeof(T) == typeof(ulong) || typeof(T) == typeof(double))
135162
{
136163
Span<long> span = MemoryMarshal.Cast<T, long>(result);
137164
#if NET
@@ -145,6 +172,21 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
145172
}
146173
}
147174

175+
if (typeof(T) == typeof(bool))
176+
{
177+
// See DontCastBytesToBooleans test to see what could go wrong.
178+
bool[] booleans = (bool[])(object)result;
179+
resultAsBytes = MemoryMarshal.AsBytes<T>(result);
180+
for (int i = 0; i < booleans.Length; i++)
181+
{
182+
// We don't use the bool array to get the value, as an optimizing compiler or JIT could elide this.
183+
if (resultAsBytes[i] != 0) // it can be any byte different than 0
184+
{
185+
booleans[i] = true; // set it to 1 in explicit way
186+
}
187+
}
188+
}
189+
148190
return result;
149191
}
150192

@@ -158,8 +200,34 @@ private static List<decimal> DecodeDecimals(BinaryReader reader, int count)
158200
return values;
159201
}
160202

203+
private static T[] DecodeTime(BinaryReader reader, int count)
204+
{
205+
T[] values = new T[count];
206+
for (int i = 0; i < values.Length; i++)
207+
{
208+
if (typeof(T) == typeof(DateTime))
209+
{
210+
values[i] = (T)(object)Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64());
211+
}
212+
else if (typeof(T) == typeof(TimeSpan))
213+
{
214+
values[i] = (T)(object)new TimeSpan(reader.ReadInt64());
215+
}
216+
else
217+
{
218+
throw new InvalidOperationException();
219+
}
220+
}
221+
222+
return values;
223+
}
224+
161225
private static List<T> DecodeFromNonSeekableStream(BinaryReader reader, int count)
162226
{
227+
// The count arg could originate from untrusted input, so we shouldn't
228+
// pass it as-is to the ctor's capacity arg. We'll instead rely on
229+
// List<T>.Add's O(1) amortization to keep the entire loop O(count).
230+
163231
List<T> values = new List<T>(Math.Min(count, 4));
164232
for (int i = 0; i < count; i++)
165233
{

src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs

+6-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.IO;
66
using System.Reflection.Metadata;
77
using System.Formats.Nrbf.Utils;
8+
using System.Diagnostics;
89

910
namespace System.Formats.Nrbf;
1011

@@ -47,7 +48,8 @@ internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetA
4748
{
4849
string?[] values = new string?[Length];
4950

50-
for (int recordIndex = 0, valueIndex = 0; recordIndex < Records.Count; recordIndex++)
51+
int valueIndex = 0;
52+
for (int recordIndex = 0; recordIndex < Records.Count; recordIndex++)
5153
{
5254
SerializationRecord record = Records[recordIndex];
5355

@@ -73,6 +75,7 @@ record = memberReference.GetReferencedRecord();
7375
}
7476

7577
int nullCount = ((NullsRecord)record).NullCount;
78+
Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount.");
7679
do
7780
{
7881
values[valueIndex++] = null;
@@ -81,6 +84,8 @@ record = memberReference.GetReferencedRecord();
8184
while (nullCount > 0);
8285
}
8386

87+
Debug.Assert(valueIndex == values.Length, "We should have traversed the entirety of the newly created array.");
88+
8489
return values;
8590
}
8691
}

src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs

+21-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.IO;
77
using System.Reflection.Metadata;
88
using System.Formats.Nrbf.Utils;
9+
using System.Diagnostics;
910

1011
namespace System.Formats.Nrbf;
1112

@@ -84,6 +85,10 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls)
8485
case SerializationRecordType.ArraySinglePrimitive:
8586
case SerializationRecordType.ArraySingleObject:
8687
case SerializationRecordType.ArraySingleString:
88+
89+
// Recursion depth is bounded by the depth of arrayType, which is
90+
// a trustworthy Type instance. Don't need to worry about stack overflow.
91+
8792
ArrayRecord nestedArrayRecord = (ArrayRecord)record;
8893
Array nestedArray = nestedArrayRecord.GetArray(actualElementType, allowNulls);
8994
array.SetValue(nestedArray, resultIndex++);
@@ -97,6 +102,7 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls)
97102
}
98103

99104
int nullCount = ((NullsRecord)item).NullCount;
105+
Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount.");
100106
do
101107
{
102108
array.SetValue(null, resultIndex++);
@@ -110,6 +116,8 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls)
110116
}
111117
}
112118

119+
Debug.Assert(resultIndex == array.Length, "We should have traversed the entirety of the newly created array.");
120+
113121
return array;
114122
}
115123

@@ -122,6 +130,7 @@ internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, Pay
122130
bool isRectangular = arrayType is BinaryArrayType.Rectangular;
123131

124132
// It is an arbitrary limit in the current CoreCLR type loader.
133+
// Don't change this value without reviewing the loop a few lines below.
125134
const int MaxSupportedArrayRank = 32;
126135

127136
if (rank < 1 || rank > MaxSupportedArrayRank
@@ -132,18 +141,26 @@ internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, Pay
132141
}
133142

134143
int[] lengths = new int[rank]; // adversary-controlled, but acceptable since upper limit of 32
135-
long totalElementCount = 1;
144+
long totalElementCount = 1; // to avoid integer overflow during the multiplication below
136145
for (int i = 0; i < lengths.Length; i++)
137146
{
138147
lengths[i] = ArrayInfo.ParseValidArrayLength(reader);
139148
totalElementCount *= lengths[i];
140149

150+
// n.b. This forbids "new T[Array.MaxLength, Array.MaxLength, Array.MaxLength, ..., 0]"
151+
// but allows "new T[0, Array.MaxLength, Array.MaxLength, Array.MaxLength, ...]". But
152+
// that's the same behavior that newarr and Array.CreateInstance exhibit, so at least
153+
// we're consistent.
154+
141155
if (totalElementCount > ArrayInfo.MaxArrayLength)
142156
{
143157
ThrowHelper.ThrowInvalidValue(lengths[i]); // max array size exceeded
144158
}
145159
}
146160

161+
// Per BinaryReaderExtensions.ReadArrayType, we do not support nonzero offsets, so
162+
// we don't need to read the NRBF stream 'LowerBounds' field here.
163+
147164
MemberTypeInfo memberTypeInfo = MemberTypeInfo.Decode(reader, 1, options, recordMap);
148165
ArrayInfo arrayInfo = new(objectId, totalElementCount, arrayType, rank);
149166

@@ -186,6 +203,9 @@ private static Type MapElementType(Type arrayType, out bool isClassRecord)
186203
Type elementType = arrayType;
187204
int arrayNestingDepth = 0;
188205

206+
// Loop iteration counts are bound by the nesting depth of arrayType,
207+
// which is a trustworthy input. No DoS concerns.
208+
189209
while (elementType.IsArray)
190210
{
191211
elementType = elementType.GetElementType()!;

src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassInfo.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ internal static ClassInfo Decode(BinaryReader reader)
5050

5151
// Use Dictionary instead of List so that searching for member IDs by name
5252
// is O(n) instead of O(m * n), where m = memberCount and n = memberNameLength,
53-
// in degenerate cases.
53+
// in degenerate cases. Since memberCount may be hostile, don't allow it to be
54+
// used as the initial capacity in the collection instance.
5455
Dictionary<string, int> memberNames = new(StringComparer.Ordinal);
5556
for (int i = 0; i < memberCount; i++)
5657
{

src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassTypeInfo.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
namespace System.Formats.Nrbf;
1010

1111
/// <summary>
12-
/// Identifies a class by it's name and library id.
12+
/// Identifies a class by its name and library id.
1313
/// </summary>
1414
/// <remarks>
1515
/// ClassTypeInfo structures are described in <see href="https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/844b24dd-9f82-426e-9b98-05334307a239">[MS-NRBF] 2.1.1.8</see>.

src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ internal bool ShouldBeRepresentedAsArrayOfClassRecords()
110110
{
111111
// This library tries to minimize the number of concepts the users need to learn to use it.
112112
// Since SZArrays are most common, it provides an SZArrayRecord<T> abstraction.
113-
// Every other array (jagged, multi-dimensional etc) is represented using SZArrayRecord.
113+
// Every other array (jagged, multi-dimensional etc) is represented using ArrayRecord.
114114
// The goal of this method is to determine whether given array can be represented as SZArrayRecord<ClassRecord>.
115115

116116
(BinaryType binaryType, object? additionalInfo) = Infos[0];

src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NextInfo.cs

+1-3
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,5 @@ internal NextInfo(AllowedRecordTypes allowed, SerializationRecord parent,
2727
internal PrimitiveType PrimitiveType { get; }
2828

2929
internal NextInfo With(AllowedRecordTypes allowed, PrimitiveType primitiveType)
30-
=> allowed == Allowed && primitiveType == PrimitiveType
31-
? this // previous record was of the same type
32-
: new(allowed, Parent, Stack, primitiveType);
30+
=> new(allowed, Parent, Stack, primitiveType);
3331
}

src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public static class NrbfDecoder
2222
// The header consists of:
2323
// - a byte that describes the record type (SerializationRecordType.SerializedStreamHeader)
2424
// - four 32 bit integers:
25-
// - root Id (every value is valid)
25+
// - root Id (every value except 0 is valid)
2626
// - header Id (value is ignored)
2727
// - major version, it has to be equal to 1.
2828
// - minor version, it has to be equal to 0.

src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/PayloadOptions.cs

+7
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,17 @@ public PayloadOptions() { }
2525
/// </summary>
2626
/// <value><see langword="true" /> if truncated type names should be reassembled; otherwise, <see langword="false" />.</value>
2727
/// <remarks>
28+
/// <para>
2829
/// Example:
2930
/// TypeName: "Namespace.TypeName`1[[Namespace.GenericArgName"
3031
/// LibraryName: "AssemblyName]]"
3132
/// Is combined into "Namespace.TypeName`1[[Namespace.GenericArgName, AssemblyName]]"
33+
/// </para>
34+
/// <para>
35+
/// Setting this to <see langword="true" /> can render <see cref="NrbfDecoder"/> susceptible to Denial of Service
36+
/// attacks when parsing or handling malicious input.
37+
/// </para>
38+
/// <para>The default value is <see langword="false" />.</para>
3239
/// </remarks>
3340
public bool UndoTruncatedTypeNames { get; set; }
3441
}

0 commit comments

Comments
 (0)