diff --git a/src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs b/src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs index 8e12cf7c3712f..d7a6e01a72352 100644 --- a/src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs +++ b/src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs @@ -11,6 +11,7 @@ public abstract partial class ArrayRecord : System.Formats.Nrbf.SerializationRec internal ArrayRecord() { } public override System.Formats.Nrbf.SerializationRecordId Id { get { throw null; } } public abstract System.ReadOnlySpan Lengths { get; } + public virtual long FlattenedLength { get; } public int Rank { get { throw null; } } [System.Diagnostics.CodeAnalysis.RequiresDynamicCode("The code for an array of the specified type might not be available.")] public System.Array GetArray(System.Type expectedArrayType, bool allowNulls = true) { throw null; } diff --git a/src/libraries/System.Formats.Nrbf/src/Resources/Strings.resx b/src/libraries/System.Formats.Nrbf/src/Resources/Strings.resx index 6b9ffc6da372b..c6085fff72398 100644 --- a/src/libraries/System.Formats.Nrbf/src/Resources/Strings.resx +++ b/src/libraries/System.Formats.Nrbf/src/Resources/Strings.resx @@ -126,9 +126,6 @@ Unexpected Null Record count. - - The serialized array length ({0}) was larger than the configured limit {1}. - {0} Record Type is not supported by design. @@ -136,16 +133,16 @@ Invalid member reference. - Invalid type name: `{0}`. + Invalid type name. Expected the array to be of type {0}, but its element type was {1}. - Invalid type or assembly name: `{0},{1}`. + Invalid type or assembly name. - Duplicate member name: `{0}`. + Duplicate member name. Stream does not support seeking. @@ -160,7 +157,7 @@ Only arrays with zero offsets are supported. - Invalid assembly name: `{0}`. + Invalid assembly name. Invalid format. diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayInfo.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayInfo.cs index 40468c3f8bb3d..da03a459f35aa 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayInfo.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayInfo.cs @@ -25,14 +25,14 @@ internal readonly struct ArrayInfo internal ArrayInfo(SerializationRecordId id, long totalElementsCount, BinaryArrayType arrayType = BinaryArrayType.Single, int rank = 1) { Id = id; - TotalElementsCount = totalElementsCount; + FlattenedLength = totalElementsCount; ArrayType = arrayType; Rank = rank; } internal SerializationRecordId Id { get; } - internal long TotalElementsCount { get; } + internal long FlattenedLength { get; } internal BinaryArrayType ArrayType { get; } @@ -40,8 +40,8 @@ internal ArrayInfo(SerializationRecordId id, long totalElementsCount, BinaryArra internal int GetSZArrayLength() { - Debug.Assert(TotalElementsCount <= MaxArrayLength); - return (int)TotalElementsCount; + Debug.Assert(FlattenedLength <= MaxArrayLength); + return (int)FlattenedLength; } internal static ArrayInfo Decode(BinaryReader reader) diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs index ddfd91a29fb1a..237b7b72a2719 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs @@ -18,7 +18,7 @@ public abstract class ArrayRecord : SerializationRecord private protected ArrayRecord(ArrayInfo arrayInfo) { ArrayInfo = arrayInfo; - ValuesToRead = arrayInfo.TotalElementsCount; + ValuesToRead = arrayInfo.FlattenedLength; } /// @@ -27,6 +27,12 @@ private protected ArrayRecord(ArrayInfo arrayInfo) /// A buffer of integers that represent the number of elements in every dimension. public abstract ReadOnlySpan Lengths { get; } + /// + /// When overridden in a derived class, gets the total number of all elements in every dimension. + /// + /// A number that represent the total number of all elements in every dimension. + public virtual long FlattenedLength => ArrayInfo.FlattenedLength; + /// /// Gets the rank of the array. /// @@ -44,7 +50,12 @@ private protected ArrayRecord(ArrayInfo arrayInfo) internal long ValuesToRead { get; private protected set; } - private protected ArrayInfo ArrayInfo { get; } + internal ArrayInfo ArrayInfo { get; } + + internal bool IsJagged + => ArrayInfo.ArrayType == BinaryArrayType.Jagged + // It is possible to have binary array records have an element type of array without being marked as jagged. + || TypeName.GetElementType().IsArray; /// /// Allocates an array and fills it with the data provided in the serialized records (in case of primitive types like or ) or the serialized records themselves. diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs index cc3210a3e32b7..41b1f73f03550 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs @@ -28,12 +28,15 @@ internal sealed class BinaryArrayRecord : ArrayRecord ]; private TypeName? _typeName; + private long _totalElementsCount; private BinaryArrayRecord(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo) : base(arrayInfo) { MemberTypeInfo = memberTypeInfo; Values = []; + // We need to parse all elements of the jagged array to obtain total elements count. + _totalElementsCount = -1; } public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray; @@ -41,6 +44,22 @@ private BinaryArrayRecord(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo) /// public override ReadOnlySpan Lengths => new int[1] { Length }; + /// + public override long FlattenedLength + { + get + { + if (_totalElementsCount < 0) + { + _totalElementsCount = IsJagged + ? GetJaggedArrayFlattenedLength(this) + : ArrayInfo.FlattenedLength; + } + + return _totalElementsCount; + } + } + public override TypeName TypeName => _typeName ??= MemberTypeInfo.GetArrayTypeName(ArrayInfo); @@ -174,6 +193,65 @@ internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, Pay : new BinaryArrayRecord(arrayInfo, memberTypeInfo); } + private static long GetJaggedArrayFlattenedLength(BinaryArrayRecord jaggedArrayRecord) + { + long result = 0; + Queue? jaggedArrayRecords = null; + + do + { + if (jaggedArrayRecords is not null) + { + jaggedArrayRecord = jaggedArrayRecords.Dequeue(); + } + + Debug.Assert(jaggedArrayRecord.IsJagged); + + // In theory somebody could create a payload that would represent + // a very nested array with total elements count > long.MaxValue. + // That is why this method is using checked arithmetic. + result = checked(result + jaggedArrayRecord.Length); // count the arrays themselves + + foreach (object value in jaggedArrayRecord.Values) + { + if (value is not SerializationRecord record) + { + continue; + } + + if (record.RecordType == SerializationRecordType.MemberReference) + { + record = ((MemberReferenceRecord)record).GetReferencedRecord(); + } + + switch (record.RecordType) + { + case SerializationRecordType.ArraySinglePrimitive: + case SerializationRecordType.ArraySingleObject: + case SerializationRecordType.ArraySingleString: + case SerializationRecordType.BinaryArray: + ArrayRecord nestedArrayRecord = (ArrayRecord)record; + if (nestedArrayRecord.IsJagged) + { + (jaggedArrayRecords ??= new()).Enqueue((BinaryArrayRecord)nestedArrayRecord); + } + else + { + // Don't call nestedArrayRecord.FlattenedLength to avoid any potential recursion, + // just call nestedArrayRecord.ArrayInfo.FlattenedLength that returns pre-computed value. + result = checked(result + nestedArrayRecord.ArrayInfo.FlattenedLength); + } + break; + default: + break; + } + } + } + while (jaggedArrayRecords is not null && jaggedArrayRecords.Count > 0); + + return result; + } + private protected override void AddValue(object value) => Values.Add(value); internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryLibraryRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryLibraryRecord.cs index 7318052610e1b..b723d8083e4a9 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryLibraryRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryLibraryRecord.cs @@ -50,7 +50,7 @@ internal static BinaryLibraryRecord Decode(BinaryReader reader, PayloadOptions o } else if (!options.UndoTruncatedTypeNames) { - ThrowHelper.ThrowInvalidAssemblyName(rawName); + ThrowHelper.ThrowInvalidAssemblyName(); } return new BinaryLibraryRecord(id, rawName); diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassInfo.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassInfo.cs index 01645dcb51213..a1cb7b47fb5ae 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassInfo.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassInfo.cs @@ -71,7 +71,7 @@ internal static ClassInfo Decode(BinaryReader reader) continue; } #endif - throw new SerializationException(SR.Format(SR.Serialization_DuplicateMemberName, memberName)); + ThrowHelper.ThrowDuplicateMemberName(); } return new ClassInfo(id, typeName, memberNames); diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RectangularArrayRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RectangularArrayRecord.cs index b5b15e71aecba..f64dde36163d6 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RectangularArrayRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RectangularArrayRecord.cs @@ -193,11 +193,11 @@ internal static RectangularArrayRecord Create(BinaryReader reader, ArrayInfo arr // to encountering an EOF if we realize later that we actually need to read more bytes in // order to fully populate the char[,,,...] array. Any such allocation is still linearly // proportional to the length of the incoming payload, so it's not a DoS vector. - // The multiplication below is guaranteed not to overflow because TotalElementsCount is bounded - // to <= uint.MaxValue (see BinaryArrayRecord.Decode) and sizeOfSingleValue is at most 8. - Debug.Assert(arrayInfo.TotalElementsCount >= 0 && arrayInfo.TotalElementsCount <= long.MaxValue / sizeOfSingleValue); + // The multiplication below is guaranteed not to overflow because FlattenedLength is bounded + // to <= Array.MaxLength (see BinaryArrayRecord.Decode) and sizeOfSingleValue is at most 8. + Debug.Assert(arrayInfo.FlattenedLength >= 0 && arrayInfo.FlattenedLength <= long.MaxValue / sizeOfSingleValue); - long size = arrayInfo.TotalElementsCount * sizeOfSingleValue; + long size = arrayInfo.FlattenedLength * sizeOfSingleValue; bool? isDataAvailable = reader.IsDataAvailable(size); if (isDataAvailable.HasValue) { diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/ThrowHelper.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/ThrowHelper.cs index 55febf77533f9..ac8c861e5d199 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/ThrowHelper.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/ThrowHelper.cs @@ -6,28 +6,30 @@ namespace System.Formats.Nrbf.Utils; +// The exception messages do not contain member/type/assembly names on purpose, +// as it's most likely corrupted/tampered/malicious data. internal static class ThrowHelper { - internal static void ThrowInvalidValue(object value) + internal static void ThrowDuplicateMemberName() + => throw new SerializationException(SR.Serialization_DuplicateMemberName); + + internal static void ThrowInvalidValue(int value) => throw new SerializationException(SR.Format(SR.Serialization_InvalidValue, value)); internal static void ThrowInvalidReference() => throw new SerializationException(SR.Serialization_InvalidReference); - internal static void ThrowInvalidTypeName(string name) - => throw new SerializationException(SR.Format(SR.Serialization_InvalidTypeName, name)); + internal static void ThrowInvalidTypeName() + => throw new SerializationException(SR.Serialization_InvalidTypeName); internal static void ThrowUnexpectedNullRecordCount() => throw new SerializationException(SR.Serialization_UnexpectedNullRecordCount); - internal static void ThrowMaxArrayLength(long limit, long actual) - => throw new SerializationException(SR.Format(SR.Serialization_MaxArrayLength, actual, limit)); - internal static void ThrowArrayContainedNulls() => throw new SerializationException(SR.Serialization_ArrayContainedNulls); - internal static void ThrowInvalidAssemblyName(string rawName) - => throw new SerializationException(SR.Format(SR.Serialization_InvalidAssemblyName, rawName)); + internal static void ThrowInvalidAssemblyName() + => throw new SerializationException(SR.Serialization_InvalidAssemblyName); internal static void ThrowInvalidFormat() => throw new SerializationException(SR.Serialization_InvalidFormat); diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/TypeNameHelpers.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/TypeNameHelpers.cs index f9a5dc0ef4385..a2fba1b52ecbc 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/TypeNameHelpers.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/TypeNameHelpers.cs @@ -135,7 +135,7 @@ internal static TypeName ParseNonSystemClassRecordTypeName(this string rawName, if (typeName is null) { - throw new SerializationException(SR.Format(SR.Serialization_InvalidTypeOrAssemblyName, rawName, libraryRecord.RawLibraryName)); + throw new SerializationException(SR.Serialization_InvalidTypeOrAssemblyName); } if (typeName.AssemblyName is null) @@ -183,7 +183,7 @@ private static TypeName With(this TypeName typeName, AssemblyNameInfo assemblyNa else { // BinaryFormatter can not serialize pointers or references. - ThrowHelper.ThrowInvalidTypeName(typeName.FullName); + ThrowHelper.ThrowInvalidTypeName(); } } diff --git a/src/libraries/System.Formats.Nrbf/tests/ArraySinglePrimitiveRecordTests.cs b/src/libraries/System.Formats.Nrbf/tests/ArraySinglePrimitiveRecordTests.cs index 17155b8ba636d..49d523088a89f 100644 --- a/src/libraries/System.Formats.Nrbf/tests/ArraySinglePrimitiveRecordTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/ArraySinglePrimitiveRecordTests.cs @@ -141,6 +141,7 @@ private void Test(int size, bool canSeek) SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(stream); Assert.Equal(size, arrayRecord.Length); + Assert.Equal(size, arrayRecord.FlattenedLength); T?[] output = arrayRecord.GetArray(); Assert.Equal(input, output); Assert.Same(output, arrayRecord.GetArray()); diff --git a/src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs b/src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs index a72c3227c1eec..8bb844ff76a58 100644 --- a/src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs @@ -1,4 +1,5 @@ using System.Formats.Nrbf.Utils; +using System.IO; using System.Linq; using Xunit; @@ -6,29 +7,91 @@ namespace System.Formats.Nrbf.Tests; public class JaggedArraysTests : ReadTests { - [Fact] - public void CanReadJaggedArraysOfPrimitiveTypes_2D() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void CanReadJaggedArraysOfPrimitiveTypes_2D(bool useReferences) { int[][] input = new int[7][]; + int[] same = [1, 2, 3]; for (int i = 0; i < input.Length; i++) { - input[i] = [i, i, i]; + input[i] = useReferences + ? same // reuse the same object (represented as a single record that is referenced multiple times) + : [i, i, i]; // create new array } var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); Verify(input, arrayRecord); Assert.Equal(input, arrayRecord.GetArray(input.GetType())); + Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength); + } + + [Theory] + [InlineData(1)] // SerializationRecordType.ObjectNull + [InlineData(200)] // SerializationRecordType.ObjectNullMultiple256 + [InlineData(10_000)] // SerializationRecordType.ObjectNullMultiple + public void FlattenedLengthIncludesNullArrays(int nullCount) + { + int[][] input = new int[nullCount][]; + + var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); + + Verify(input, arrayRecord); + Assert.Equal(input, arrayRecord.GetArray(input.GetType())); + Assert.Equal(nullCount, arrayRecord.FlattenedLength); + } + + [Fact] + public void ItIsPossibleToHaveBinaryArrayRecordsHaveAnElementTypeOfArrayWithoutBeingMarkedAsJagged() + { + int[][][] input = new int[3][][]; + long totalElementsCount = 0; + for (int i = 0; i < input.Length; i++) + { + input[i] = new int[4][]; + totalElementsCount++; // count the arrays themselves + + for (int j = 0; j < input[i].Length; j++) + { + input[i][j] = [i, j, 0, 1, 2]; + totalElementsCount += input[i][j].Length; + totalElementsCount++; // count the arrays themselves + } + } + + byte[] serialized = Serialize(input).ToArray(); + const int ArrayTypeByteIndex = + sizeof(byte) + sizeof(int) * 4 + // stream header + sizeof(byte) + // SerializationRecordType.BinaryArray + sizeof(int); // SerializationRecordId + + Assert.Equal((byte)BinaryArrayType.Jagged, serialized[ArrayTypeByteIndex]); + + // change the reported array type + serialized[ArrayTypeByteIndex] = (byte)BinaryArrayType.Single; + + var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(new MemoryStream(serialized)); + + Verify(input, arrayRecord); + Assert.Equal(input, arrayRecord.GetArray(input.GetType())); + Assert.Equal(3 + 3 * 4 + 3 * 4 * 5, totalElementsCount); + Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength); } [Fact] public void CanReadJaggedArraysOfPrimitiveTypes_3D() { int[][][] input = new int[7][][]; + long totalElementsCount = 0; for (int i = 0; i < input.Length; i++) { + totalElementsCount++; // count the arrays themselves input[i] = new int[1][]; + totalElementsCount++; // count the arrays themselves input[i][0] = [i, i, i]; + totalElementsCount += input[i][0].Length; } var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); @@ -36,6 +99,8 @@ public void CanReadJaggedArraysOfPrimitiveTypes_3D() Verify(input, arrayRecord); Assert.Equal(input, arrayRecord.GetArray(input.GetType())); Assert.Equal(1, arrayRecord.Rank); + Assert.Equal(7 + 7 * 1 + 7 * 1 * 3, totalElementsCount); + Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength); } [Fact] @@ -60,6 +125,7 @@ public void CanReadJaggedArrayOfRectangularArrays() Verify(input, arrayRecord); Assert.Equal(input, arrayRecord.GetArray(input.GetType())); Assert.Equal(1, arrayRecord.Rank); + Assert.Equal(input.Length + input.Length * 3 * 3, arrayRecord.FlattenedLength); } [Fact] @@ -75,6 +141,7 @@ public void CanReadJaggedArraysOfStrings() Verify(input, arrayRecord); Assert.Equal(input, arrayRecord.GetArray(input.GetType())); + Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength); } [Fact] @@ -90,6 +157,7 @@ public void CanReadJaggedArraysOfObjects() Verify(input, arrayRecord); Assert.Equal(input, arrayRecord.GetArray(input.GetType())); + Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength); } [Serializable] @@ -102,14 +170,18 @@ public class ComplexType public void CanReadJaggedArraysOfComplexTypes() { ComplexType[][] input = new ComplexType[3][]; + long totalElementsCount = 0; for (int i = 0; i < input.Length; i++) { input[i] = Enumerable.Range(0, i + 1).Select(j => new ComplexType { SomeField = j }).ToArray(); + totalElementsCount += input[i].Length; + totalElementsCount++; // count the arrays themselves } var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); Verify(input, arrayRecord); + Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength); var output = (ClassRecord?[][])arrayRecord.GetArray(input.GetType()); for (int i = 0; i < input.Length; i++) { diff --git a/src/libraries/System.Formats.Nrbf/tests/RectangularArraysTests.cs b/src/libraries/System.Formats.Nrbf/tests/RectangularArraysTests.cs index 25e7bb5a4d533..3191d57ba807c 100644 --- a/src/libraries/System.Formats.Nrbf/tests/RectangularArraysTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/RectangularArraysTests.cs @@ -223,10 +223,13 @@ public void CanReadRectangularArraysOfComplexTypes_3D() internal static void Verify(Array input, ArrayRecord arrayRecord) { Assert.Equal(input.Rank, arrayRecord.Lengths.Length); + long totalElementsCount = 1; for (int i = 0; i < input.Rank; i++) { Assert.Equal(input.GetLength(i), arrayRecord.Lengths[i]); + totalElementsCount *= input.GetLength(i); } + Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength); Assert.Equal(input.GetType().FullName, arrayRecord.TypeName.FullName); Assert.Equal(input.GetType().GetAssemblyNameIncludingTypeForwards(), arrayRecord.TypeName.AssemblyName!.FullName); }