Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Neo VM Optimization] Detect reference counter miscalculation. #3325

Closed
wants to merge 16 commits into from
79 changes: 79 additions & 0 deletions src/Neo.VM/CompoundTypeReferenceCountChecker.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright (C) 2015-2024 The Neo Project.
//
// CompoundTypeReferenceCountChecker.cs file belongs to the neo project and is free
// software distributed under the MIT software license, see the
// accompanying file LICENSE in the main directory of the
// repository or http://www.opensource.org/licenses/mit-license.php
// for more details.
//
// Redistribution and use in source and binary forms with or without
// modifications are permitted.

using Neo.VM.Types;
using System;
using System.Collections.Generic;

namespace Neo.VM;

internal class CompoundTypeReferenceCountChecker(int maxItems = 2048)
{
public void CheckCompoundType(CompoundType rootItem)
{
if (rootItem is null) throw new ArgumentNullException(nameof(rootItem));

var visited = new HashSet<StackItem>(ReferenceEqualityComparer.Instance);
var itemCount = TraverseCompoundType(rootItem, visited, 0);

if (itemCount > maxItems)
{
throw new InvalidOperationException($"Exceeded maximum of {maxItems} items.");
}
}

private int TraverseCompoundType(CompoundType rootItem, HashSet<StackItem> visited, int itemCount)
{
var stack = new Stack<CompoundType>();
stack.Push(rootItem);
visited.Add(rootItem);
itemCount++;

while (stack.Count > 0)
{
var currentCompound = stack.Pop();

foreach (var subItem in currentCompound.SubItems)
{
// if a compound type item has reference counter assigned
// Then its subitem is referred.
if (subItem is CompoundType compoundType)
{
// If a compound type has no reference counter
// Then this compound type is problematic
if (compoundType.ReferenceCounter == null)
{
throw new InvalidOperationException("Invalid stackitem being pushed.");
}

// Check if this subItem has been visited already
if (!visited.Add(compoundType))
{
continue;
}

// Add the subItem to the stack and increment the itemCount
stack.Push(compoundType);
itemCount++;

// Check if the itemCount exceeds the maximum allowed items
if (itemCount > maxItems)
{
throw new InvalidOperationException($"Exceeded maximum of {maxItems} items.");
}

}
}
}

return itemCount;
}
}
10 changes: 10 additions & 0 deletions src/Neo.VM/EvaluationStack.cs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ public void Push(StackItem item)
{
innerList.Add(item);
referenceCounter.AddStackReference(item);
if (item is CompoundType compoundType)
CheckCompoundType(compoundType);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand Down Expand Up @@ -164,5 +166,13 @@ public override string ToString()
{
return $"[{string.Join(", ", innerList.Select(p => $"{p.Type}({p})"))}]";
}

private static void CheckCompoundType(CompoundType rootItem, int maxItems = 2048)
{
if (rootItem is null)
throw new ArgumentNullException();
var checker = new CompoundTypeReferenceCountChecker(maxItems: 2048);
checker.CheckCompoundType(rootItem);
}
}
}
7 changes: 4 additions & 3 deletions src/Neo.VM/ExecutionEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,11 @@ protected internal void ExecuteNext()
PreExecuteInstruction(instruction);
#if VMPERF
Console.WriteLine("op:["
+ this.CurrentContext.InstructionPointer.ToString("X04")
+ CurrentContext.InstructionPointer.ToString("X04")
+ "]"
+ this.CurrentContext.CurrentInstruction?.OpCode
+ CurrentContext.CurrentInstruction?.OpCode
+ " "
+ this.CurrentContext.EvaluationStack);
+ CurrentContext.EvaluationStack);
#endif
try
{
Expand All @@ -157,6 +157,7 @@ protected internal void ExecuteNext()
}
catch (Exception e)
{
Console.WriteLine(e);
OnFault(e);
}
}
Expand Down
1 change: 1 addition & 0 deletions src/Neo.VM/ReferenceCounter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ internal int CheckZeroReferred()
subitem.ObjectReferences!.Remove(compound);
}
}

item.Cleanup();
}
var nodeToRemove = node;
Expand Down
2 changes: 1 addition & 1 deletion src/Neo.VM/Types/CompoundType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public abstract class CompoundType : StackItem
/// <summary>
/// The reference counter used to count the items in the VM object.
/// </summary>
protected readonly ReferenceCounter? ReferenceCounter;
protected internal readonly ReferenceCounter? ReferenceCounter;

/// <summary>
/// Create a new <see cref="CompoundType"/> with the specified reference counter.
Expand Down
10 changes: 5 additions & 5 deletions src/Neo.VM/Types/StackItem.Vertex.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ internal class ObjectReferenceEntry
public ObjectReferenceEntry(StackItem item) => Item = item;
}

internal int StackReferences = 0;
internal Dictionary<CompoundType, ObjectReferenceEntry>? ObjectReferences;
internal int DFN = -1;
internal int LowLink = 0;
internal bool OnStack = false;
internal int StackReferences { get; set; } = 0;
internal Dictionary<CompoundType, ObjectReferenceEntry>? ObjectReferences { get; set; }
internal int DFN { get; set; } = -1;
internal int LowLink { get; set; } = 0;
internal bool OnStack { get; set; } = false;

internal IEnumerable<StackItem> Successors => ObjectReferences?.Values.Where(p => p.References > 0).Select(p => p.Item) ?? System.Array.Empty<StackItem>();

Expand Down
18 changes: 18 additions & 0 deletions tests/Neo.VM.Tests/UT_EvaluationStack.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
using System;
using System.Collections;
using System.Linq;
using Array = Neo.VM.Types.Array;

namespace Neo.Test
{
Expand Down Expand Up @@ -220,5 +221,22 @@ public void TestPrintInvalidUTF8()
stack.Insert(0, "4CC95219999D421243C8161E3FC0F4290C067845".FromHexString());
Assert.AreEqual("[ByteString(\"Base64: TMlSGZmdQhJDyBYeP8D0KQwGeEU=\")]", stack.ToString());
}

[TestMethod]
shargon marked this conversation as resolved.
Show resolved Hide resolved
public void TestInvalidReferenceStackItem()
{
var stack = new EvaluationStack(new ReferenceCounter());
var arr = new Array();
var arr2 = new Array();

for (var i = 0; i < 10; i++)
{
arr2.Add(i);
}

arr.Add(arr2);
Assert.ThrowsException<InvalidOperationException>(() => stack.Push(arr));
}

}
}