Skip to content

Commit

Permalink
[NRBF] Fixes and fuzzing improvements (dotnet#110194)
Browse files Browse the repository at this point in the history
* Simplify array handling to fix issues with jagged and abstract array types

Jagged arrays in the payload can contain cycles. In that scenario, no value is correct for `ArrayRecord.FlattenedLength`, and `ArrayRecord.GetArray` does not have enough context to know how to handle the cycles. To address these issues, jagged array handling is simplified so that calling code can handle the cycles in the most appropriate way for the application.

Single-dimension arrays can be represented in the payload using abstract types such as `IComparable[]` where the concrete element type is not known. When the concrete element type is known, `ArrayRecord.GetArray` could return either `SZArrayRecord<ClassRecord>` or `SZArrayRecord<object>`; without a concrete type, we need to return something that represents the element abstractly.

1. `ArrayRecord.FlattenedLength` is removed from the API
2. `ArrayRecord.GetArray` now returns `ArrayRecord[]` for jagged arrays instead of trying to populate them
3. `ArrayRecord.GetArray` now returns `SZArrayRecord<SerializationRecord>` for single-dimension arrays instead of either `SZArrayRecord<ClassRecord>` or `SZArrayRecord<object>`

* extend the Fuzzer to consume all possible data exposed by the NrbfDecoder
  • Loading branch information
adamsitnik authored and mikelle-rogers committed Dec 4, 2024
1 parent 54dd763 commit 82e1216
Show file tree
Hide file tree
Showing 35 changed files with 1,855 additions and 878 deletions.
269 changes: 215 additions & 54 deletions src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/NrbfDecoderFuzzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

using System.Buffers;
using System.Formats.Nrbf;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.Serialization;
using System.Text;

Expand Down Expand Up @@ -38,71 +40,27 @@ private static void Test(Span<byte> testSpan, Stream stream)
{
if (NrbfDecoder.StartsWithPayloadHeader(testSpan))
{
HashSet<SerializationRecordId> visited = new();
Queue<SerializationRecord> queue = new();
try
{
SerializationRecord record = NrbfDecoder.Decode(stream, out IReadOnlyDictionary<SerializationRecordId, SerializationRecord> recordMap);
switch (record.RecordType)

Assert.Equal(true, recordMap.ContainsKey(record.Id)); // make sure the loop below includes it
foreach (SerializationRecord fromMap in recordMap.Values)
{
case SerializationRecordType.ArraySingleObject:
SZArrayRecord<object?> arrayObj = (SZArrayRecord<object?>)record;
object?[] objArray = arrayObj.GetArray();
Assert.Equal(arrayObj.Length, objArray.Length);
Assert.Equal(1, arrayObj.Rank);
break;
case SerializationRecordType.ArraySingleString:
SZArrayRecord<string?> arrayString = (SZArrayRecord<string?>)record;
string?[] array = arrayString.GetArray();
Assert.Equal(arrayString.Length, array.Length);
Assert.Equal(1, arrayString.Rank);
Assert.Equal(true, arrayString.TypeNameMatches(typeof(string[])));
break;
case SerializationRecordType.ArraySinglePrimitive:
case SerializationRecordType.BinaryArray:
ArrayRecord arrayBinary = (ArrayRecord)record;
Assert.NotNull(arrayBinary.TypeName);
break;
case SerializationRecordType.BinaryObjectString:
_ = ((PrimitiveTypeRecord<string>)record).Value;
break;
case SerializationRecordType.ClassWithId:
case SerializationRecordType.ClassWithMembersAndTypes:
case SerializationRecordType.SystemClassWithMembersAndTypes:
ClassRecord classRecord = (ClassRecord)record;
Assert.NotNull(classRecord.TypeName);

foreach (string name in classRecord.MemberNames)
{
Assert.Equal(true, classRecord.HasMember(name));
}
break;
case SerializationRecordType.MemberPrimitiveTyped:
PrimitiveTypeRecord primitiveType = (PrimitiveTypeRecord)record;
Assert.NotNull(primitiveType.Value);
break;
case SerializationRecordType.MemberReference:
Assert.NotNull(record.TypeName);
break;
case SerializationRecordType.BinaryLibrary:
Assert.Equal(false, record.Id.Equals(default));
break;
case SerializationRecordType.ObjectNull:
case SerializationRecordType.ObjectNullMultiple:
case SerializationRecordType.ObjectNullMultiple256:
Assert.Equal(default, record.Id);
break;
case SerializationRecordType.MessageEnd:
case SerializationRecordType.SerializedStreamHeader:
// case SerializationRecordType.ClassWithMembers: will cause NotSupportedException
// case SerializationRecordType.SystemClassWithMembers: will cause NotSupportedException
default:
throw new Exception("Unexpected RecordType");
visited.Add(fromMap.Id);
queue.Enqueue(fromMap);
}
}
catch (SerializationException) { /* Reading from the stream encountered invalid NRBF data.*/ }
catch (NotSupportedException) { /* Reading from the stream encountered unsupported records */ }
catch (DecoderFallbackException) { /* Reading from the stream encountered an invalid UTF8 sequence. */ }
catch (EndOfStreamException) { /* The end of the stream was reached before reading SerializationRecordType.MessageEnd record. */ }
catch (IOException) { /* An I/O error occurred. */ }

// Lets consume it outside of the try/catch block to not swallow any exceptions by accident.
Consume(visited, queue);
}
else
{
Expand All @@ -117,6 +75,209 @@ private static void Test(Span<byte> testSpan, Stream stream)
}
}

private static void Consume(HashSet<SerializationRecordId> visited, Queue<SerializationRecord> queue)
{
while (queue.Count > 0)
{
SerializationRecord serializationRecord = queue.Dequeue();

if (serializationRecord is PrimitiveTypeRecord primitiveTypeRecord)
{
ConsumePrimitiveValue(primitiveTypeRecord.Value);
}
else if (serializationRecord is ClassRecord classRecord)
{
foreach (string memberName in classRecord.MemberNames)
{
ConsumePrimitiveValue(memberName);

Assert.Equal(true, classRecord.HasMember(memberName));

object? rawValue;

try
{
rawValue = classRecord.GetRawValue(memberName);
}
catch (SerializationException ex) when (ex.Message == "Invalid member reference.")
{
// It was a reference to a non-existing record, just continue.
continue;
}

if (rawValue is not null)
{
if (rawValue is SerializationRecord nestedRecord)
{
TryEnqueue(nestedRecord);
}
else
{
ConsumePrimitiveValue(rawValue);
}
}
}
}
else if (serializationRecord is ArrayRecord arrayRecord)
{
Type? type;

try
{
// THIS IS VERY BAD IDEA FOR ANY KIND OF PRODUCT CODE!!
// IT'S USED ONLY FOR THE PURPOSE OF TESTING, DO NOT COPY IT.
type = Type.GetType(arrayRecord.TypeName.AssemblyQualifiedName, throwOnError: false);
if (type is null)
{
continue;
}
}
catch (Exception) // throwOnError passed to GetType does not prevent from all kinds of exceptions
{
// It was some type made up by the Fuzzer.
// Since it's currently impossible to get the array without providing the type,
// we just bail here (in the future we may add an enumerator to ArrayRecord).
continue;
}

Array? array;
try
{
array = arrayRecord.GetArray(type);
}
catch (SerializationException ex) when (ex.Message == "Invalid member reference.")
{
// It contained a reference to a non-existing record, just continue.
continue;
}

ReadOnlySpan<int> lengths = arrayRecord.Lengths;
long totalElementsCount = 1;
for (int i = 0; i < arrayRecord.Rank; i++)
{
Assert.Equal(lengths[i], array.GetLength(i));
totalElementsCount *= lengths[i];
}

// This array contains indices that are used to get values of multi-dimensional array.
// At the beginning, all values are set to 0, so we start from the first element.
int[] indices = new int[arrayRecord.Rank];

long flatIndex = 0;
for (; flatIndex < totalElementsCount; flatIndex++)
{
object? rawValue = array.GetValue(indices);
if (rawValue is not null)
{
if (rawValue is SerializationRecord record)
{
TryEnqueue(record);
}
else
{
ConsumePrimitiveValue(rawValue);
}
}

// The loop below is responsible for incrementing the multi-dimensional indices.
// It finds the dimension and then performs an increment.
int dimension = indices.Length - 1;
while (dimension >= 0)
{
indices[dimension]++;
if (indices[dimension] < lengths[dimension])
{
break;
}
indices[dimension] = 0;
dimension--;
}
}

// We track the flat index to ensure that we have enumerated over all elements.
Assert.Equal(totalElementsCount, flatIndex);
}
else
{
// The map may currently contain it (it may change in the future)
Assert.Equal(SerializationRecordType.BinaryLibrary, serializationRecord.RecordType);
}
}

void TryEnqueue(SerializationRecord record)
{
if (visited.Add(record.Id)) // avoid unbounded recursion
{
queue.Enqueue(record);
}
}
}

[MethodImpl(MethodImplOptions.NoInlining)]
private static void ConsumePrimitiveValue(object value)
{
if (value is string text)
Assert.Equal(text, text.ToString()); // we want to touch all elements to see if memory is not corrupted
else if (value is bool boolean)
Assert.Equal(true, Unsafe.BitCast<bool, byte>(boolean) is 1 or 0); // other values are illegal!!
else if (value is sbyte @sbyte)
TestNumber(@sbyte);
else if (value is byte @byte)
TestNumber(@byte);
else if (value is char character)
TestNumber(character);
else if (value is short @short)
TestNumber(@short);
else if (value is ushort @ushort)
TestNumber(@ushort);
else if (value is int integer)
TestNumber(integer);
else if (value is uint @uint)
TestNumber(@uint);
else if (value is long @long)
TestNumber(@long);
else if (value is ulong @ulong)
TestNumber(@ulong);
else if (value is float @float)
{
if (!float.IsNaN(@float) && !float.IsInfinity(@float))
{
TestNumber(@float);
}
}
else if (value is double @double)
{
if (!double.IsNaN(@double) && !double.IsInfinity(@double))
{
TestNumber(@double);
}
}
else if (value is decimal @decimal)
TestNumber(@decimal);
else if (value is nint @nint)
TestNumber(@nint);
else if (value is nuint @nuint)
TestNumber(@nuint);
else if (value is DateTime datetime)
Assert.Equal(true, datetime >= DateTime.MinValue && datetime <= DateTime.MaxValue);
else if (value is TimeSpan timeSpan)
Assert.Equal(true, timeSpan >= TimeSpan.MinValue && timeSpan <= TimeSpan.MaxValue);
else
throw new InvalidOperationException();

static void TestNumber<T>(T value) where T : IComparable<T>, IMinMaxValue<T>
{
if (value.CompareTo(T.MinValue) < 0)
{
throw new Exception($"Expected {value} to be more or equal {T.MinValue}, {value.CompareTo(T.MinValue)}.");
}
if (value.CompareTo(T.MaxValue) > 0)
{
throw new Exception($"Expected {value} to be less or equal {T.MaxValue}, {value.CompareTo(T.MaxValue)}.");
}
}
}

private sealed class NonSeekableStream : MemoryStream
{
public NonSeekableStream(byte[] buffer) : base(buffer) { }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ namespace System.Formats.Nrbf
public abstract partial class ArrayRecord : System.Formats.Nrbf.SerializationRecord
{
internal ArrayRecord() { }
public virtual long FlattenedLength { get { throw null; } }
public override System.Formats.Nrbf.SerializationRecordId Id { get { throw null; } }
public abstract System.ReadOnlySpan<int> Lengths { get; }
public int Rank { get { throw null; } }
Expand Down
2 changes: 1 addition & 1 deletion src/libraries/System.Formats.Nrbf/src/PACKAGE.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ There are more than a dozen different serialization [record types](https://learn
- `PrimitiveTypeRecord<T>` derives from the non-generic [PrimitiveTypeRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.primitivetyperecord), which also exposes a [Value](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.primitivetyperecord.value) property. But on the base class, the value is returned as `object` (which introduces boxing for value types).
- [ClassRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.classrecord): describes all `class` and `struct` besides the aforementioned primitive types.
- [ArrayRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.arrayrecord): describes all array records, including jagged and multi-dimensional arrays.
- [`SZArrayRecord<T>`](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.szarrayrecord-1): describes single-dimensional, zero-indexed array records, where `T` can be either a primitive type or a `ClassRecord`.
- [`SZArrayRecord<T>`](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.szarrayrecord-1): describes single-dimensional, zero-indexed array records, where `T` can be either a primitive type or a `SerializationRecord`.

```csharp
SerializationRecord rootObject = NrbfDecoder.Decode(payload); // payload is a Stream
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ internal enum AllowedRecordTypes : uint
ArraySingleString = 1 << SerializationRecordType.ArraySingleString,

Nulls = ObjectNull | ObjectNullMultiple256 | ObjectNullMultiple,
Arrays = ArraySingleObject | ArraySinglePrimitive | ArraySingleString | BinaryArray,

/// <summary>
/// Any .NET object (a primitive, a reference type, a reference or single null).
/// </summary>
AnyObject = MemberPrimitiveTyped
| ArraySingleObject | ArraySinglePrimitive | ArraySingleString | BinaryArray
| Arrays
| ClassWithId | ClassWithMembersAndTypes | SystemClassWithMembersAndTypes
| BinaryObjectString
| MemberReference
Expand Down
Loading

0 comments on commit 82e1216

Please sign in to comment.