Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NRBF] Fixes and fuzzing improvements #110194

Merged
merged 3 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 215 additions & 54 deletions src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/NrbfDecoderFuzzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

using System.Buffers;
using System.Formats.Nrbf;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.Serialization;
using System.Text;

Expand Down Expand Up @@ -38,71 +40,27 @@ private static void Test(Span<byte> testSpan, Stream stream)
{
if (NrbfDecoder.StartsWithPayloadHeader(testSpan))
{
HashSet<SerializationRecordId> visited = new();
Queue<SerializationRecord> queue = new();
try
{
SerializationRecord record = NrbfDecoder.Decode(stream, out IReadOnlyDictionary<SerializationRecordId, SerializationRecord> recordMap);
switch (record.RecordType)

Assert.Equal(true, recordMap.ContainsKey(record.Id)); // make sure the loop below includes it
foreach (SerializationRecord fromMap in recordMap.Values)
{
case SerializationRecordType.ArraySingleObject:
SZArrayRecord<object?> arrayObj = (SZArrayRecord<object?>)record;
object?[] objArray = arrayObj.GetArray();
Assert.Equal(arrayObj.Length, objArray.Length);
Assert.Equal(1, arrayObj.Rank);
break;
case SerializationRecordType.ArraySingleString:
SZArrayRecord<string?> arrayString = (SZArrayRecord<string?>)record;
string?[] array = arrayString.GetArray();
Assert.Equal(arrayString.Length, array.Length);
Assert.Equal(1, arrayString.Rank);
Assert.Equal(true, arrayString.TypeNameMatches(typeof(string[])));
break;
case SerializationRecordType.ArraySinglePrimitive:
case SerializationRecordType.BinaryArray:
ArrayRecord arrayBinary = (ArrayRecord)record;
Assert.NotNull(arrayBinary.TypeName);
break;
case SerializationRecordType.BinaryObjectString:
_ = ((PrimitiveTypeRecord<string>)record).Value;
break;
case SerializationRecordType.ClassWithId:
case SerializationRecordType.ClassWithMembersAndTypes:
case SerializationRecordType.SystemClassWithMembersAndTypes:
ClassRecord classRecord = (ClassRecord)record;
Assert.NotNull(classRecord.TypeName);

foreach (string name in classRecord.MemberNames)
{
Assert.Equal(true, classRecord.HasMember(name));
}
break;
case SerializationRecordType.MemberPrimitiveTyped:
PrimitiveTypeRecord primitiveType = (PrimitiveTypeRecord)record;
Assert.NotNull(primitiveType.Value);
break;
case SerializationRecordType.MemberReference:
Assert.NotNull(record.TypeName);
break;
case SerializationRecordType.BinaryLibrary:
Assert.Equal(false, record.Id.Equals(default));
break;
case SerializationRecordType.ObjectNull:
case SerializationRecordType.ObjectNullMultiple:
case SerializationRecordType.ObjectNullMultiple256:
Assert.Equal(default, record.Id);
break;
case SerializationRecordType.MessageEnd:
case SerializationRecordType.SerializedStreamHeader:
// case SerializationRecordType.ClassWithMembers: will cause NotSupportedException
// case SerializationRecordType.SystemClassWithMembers: will cause NotSupportedException
default:
throw new Exception("Unexpected RecordType");
visited.Add(fromMap.Id);
queue.Enqueue(fromMap);
}
}
catch (SerializationException) { /* Reading from the stream encountered invalid NRBF data.*/ }
catch (NotSupportedException) { /* Reading from the stream encountered unsupported records */ }
catch (DecoderFallbackException) { /* Reading from the stream encountered an invalid UTF8 sequence. */ }
catch (EndOfStreamException) { /* The end of the stream was reached before reading SerializationRecordType.MessageEnd record. */ }
catch (IOException) { /* An I/O error occurred. */ }

// Lets consume it outside of the try/catch block to not swallow any exceptions by accident.
Consume(visited, queue);
}
else
{
Expand All @@ -117,6 +75,209 @@ private static void Test(Span<byte> testSpan, Stream stream)
}
}

private static void Consume(HashSet<SerializationRecordId> visited, Queue<SerializationRecord> queue)
{
while (queue.Count > 0)
{
SerializationRecord serializationRecord = queue.Dequeue();

if (serializationRecord is PrimitiveTypeRecord primitiveTypeRecord)
{
ConsumePrimitiveValue(primitiveTypeRecord.Value);
}
else if (serializationRecord is ClassRecord classRecord)
{
foreach (string memberName in classRecord.MemberNames)
{
ConsumePrimitiveValue(memberName);

Assert.Equal(true, classRecord.HasMember(memberName));

object? rawValue;

try
{
rawValue = classRecord.GetRawValue(memberName);
}
catch (SerializationException ex) when (ex.Message == "Invalid member reference.")
{
// It was a reference to a non-existing record, just continue.
continue;
}

if (rawValue is not null)
{
if (rawValue is SerializationRecord nestedRecord)
{
TryEnqueue(nestedRecord);
}
else
{
ConsumePrimitiveValue(rawValue);
}
}
}
}
else if (serializationRecord is ArrayRecord arrayRecord)
{
Type? type;

try
{
// THIS IS VERY BAD IDEA FOR ANY KIND OF PRODUCT CODE!!
// IT'S USED ONLY FOR THE PURPOSE OF TESTING, DO NOT COPY IT.
type = Type.GetType(arrayRecord.TypeName.AssemblyQualifiedName, throwOnError: false);
if (type is null)
{
continue;
}
}
catch (Exception) // throwOnError passed to GetType does not prevent from all kinds of exceptions
{
// It was some type made up by the Fuzzer.
// Since it's currently impossible to get the array without providing the type,
// we just bail here (in the future we may add an enumerator to ArrayRecord).
continue;
}

Array? array;
try
{
array = arrayRecord.GetArray(type);
}
catch (SerializationException ex) when (ex.Message == "Invalid member reference.")
{
// It contained a reference to a non-existing record, just continue.
continue;
}

ReadOnlySpan<int> lengths = arrayRecord.Lengths;
long totalElementsCount = 1;
for (int i = 0; i < arrayRecord.Rank; i++)
{
Assert.Equal(lengths[i], array.GetLength(i));
totalElementsCount *= lengths[i];
}

// This array contains indices that are used to get values of multi-dimensional array.
// At the beginning, all values are set to 0, so we start from the first element.
int[] indices = new int[arrayRecord.Rank];

long flatIndex = 0;
for (; flatIndex < totalElementsCount; flatIndex++)
{
object? rawValue = array.GetValue(indices);
if (rawValue is not null)
{
if (rawValue is SerializationRecord record)
{
TryEnqueue(record);
}
else
{
ConsumePrimitiveValue(rawValue);
}
}

// The loop below is responsible for incrementing the multi-dimensional indices.
// It finds the dimension and then performs an increment.
int dimension = indices.Length - 1;
while (dimension >= 0)
{
indices[dimension]++;
if (indices[dimension] < lengths[dimension])
{
break;
}
indices[dimension] = 0;
dimension--;
}
}

// We track the flat index to ensure that we have enumerated over all elements.
Assert.Equal(totalElementsCount, flatIndex);
}
else
{
// The map may currently contain it (it may change in the future)
Assert.Equal(SerializationRecordType.BinaryLibrary, serializationRecord.RecordType);
}
}

void TryEnqueue(SerializationRecord record)
{
if (visited.Add(record.Id)) // avoid unbounded recursion
{
queue.Enqueue(record);
}
}
}

[MethodImpl(MethodImplOptions.NoInlining)]
private static void ConsumePrimitiveValue(object value)
{
if (value is string text)
Assert.Equal(text, text.ToString()); // we want to touch all elements to see if memory is not corrupted
else if (value is bool boolean)
Assert.Equal(true, Unsafe.BitCast<bool, byte>(boolean) is 1 or 0); // other values are illegal!!
else if (value is sbyte @sbyte)
TestNumber(@sbyte);
else if (value is byte @byte)
TestNumber(@byte);
else if (value is char character)
TestNumber(character);
else if (value is short @short)
TestNumber(@short);
else if (value is ushort @ushort)
TestNumber(@ushort);
else if (value is int integer)
TestNumber(integer);
else if (value is uint @uint)
TestNumber(@uint);
else if (value is long @long)
TestNumber(@long);
else if (value is ulong @ulong)
TestNumber(@ulong);
else if (value is float @float)
{
if (!float.IsNaN(@float) && !float.IsInfinity(@float))
{
TestNumber(@float);
}
}
else if (value is double @double)
{
if (!double.IsNaN(@double) && !double.IsInfinity(@double))
{
TestNumber(@double);
}
}
else if (value is decimal @decimal)
TestNumber(@decimal);
else if (value is nint @nint)
TestNumber(@nint);
else if (value is nuint @nuint)
TestNumber(@nuint);
else if (value is DateTime datetime)
Assert.Equal(true, datetime >= DateTime.MinValue && datetime <= DateTime.MaxValue);
else if (value is TimeSpan timeSpan)
Assert.Equal(true, timeSpan >= TimeSpan.MinValue && timeSpan <= TimeSpan.MaxValue);
else
throw new InvalidOperationException();

static void TestNumber<T>(T value) where T : IComparable<T>, IMinMaxValue<T>
{
if (value.CompareTo(T.MinValue) < 0)
{
throw new Exception($"Expected {value} to be more or equal {T.MinValue}, {value.CompareTo(T.MinValue)}.");
}
if (value.CompareTo(T.MaxValue) > 0)
{
throw new Exception($"Expected {value} to be less or equal {T.MaxValue}, {value.CompareTo(T.MaxValue)}.");
}
}
}

private sealed class NonSeekableStream : MemoryStream
{
public NonSeekableStream(byte[] buffer) : base(buffer) { }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ namespace System.Formats.Nrbf
public abstract partial class ArrayRecord : System.Formats.Nrbf.SerializationRecord
{
internal ArrayRecord() { }
public virtual long FlattenedLength { get { throw null; } }
public override System.Formats.Nrbf.SerializationRecordId Id { get { throw null; } }
public abstract System.ReadOnlySpan<int> Lengths { get; }
public int Rank { get { throw null; } }
Expand Down
2 changes: 1 addition & 1 deletion src/libraries/System.Formats.Nrbf/src/PACKAGE.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ There are more than a dozen different serialization [record types](https://learn
- `PrimitiveTypeRecord<T>` derives from the non-generic [PrimitiveTypeRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.primitivetyperecord), which also exposes a [Value](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.primitivetyperecord.value) property. But on the base class, the value is returned as `object` (which introduces boxing for value types).
- [ClassRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.classrecord): describes all `class` and `struct` besides the aforementioned primitive types.
- [ArrayRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.arrayrecord): describes all array records, including jagged and multi-dimensional arrays.
- [`SZArrayRecord<T>`](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.szarrayrecord-1): describes single-dimensional, zero-indexed array records, where `T` can be either a primitive type or a `ClassRecord`.
- [`SZArrayRecord<T>`](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.szarrayrecord-1): describes single-dimensional, zero-indexed array records, where `T` can be either a primitive type or a `SerializationRecord`.

```csharp
SerializationRecord rootObject = NrbfDecoder.Decode(payload); // payload is a Stream
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ internal enum AllowedRecordTypes : uint
ArraySingleString = 1 << SerializationRecordType.ArraySingleString,

Nulls = ObjectNull | ObjectNullMultiple256 | ObjectNullMultiple,
Arrays = ArraySingleObject | ArraySinglePrimitive | ArraySingleString | BinaryArray,

/// <summary>
/// Any .NET object (a primitive, a reference type, a reference or single null).
/// </summary>
AnyObject = MemberPrimitiveTyped
| ArraySingleObject | ArraySinglePrimitive | ArraySingleString | BinaryArray
| Arrays
| ClassWithId | ClassWithMembersAndTypes | SystemClassWithMembersAndTypes
| BinaryObjectString
| MemberReference
Expand Down
Loading
Loading