diff --git a/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs index de94eebb9b..96fb88747d 100644 --- a/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs +++ b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs @@ -110,7 +110,7 @@ bool TryAddStandard(object obj, int depth) return false; } - AddToDepth(objectsByDepth, depth, obj); + TryAddToHashSet(objectsByDepth, depth, obj); return true; } @@ -133,7 +133,7 @@ public ObjectGraph DiscoverNestedObjectGraph(object rootObject, CancellationToke if (visitedObjects.Add(rootObject)) { - AddToDepth(objectsByDepth, 0, rootObject); + TryAddToHashSet(objectsByDepth, 0, rootObject); DiscoverNestedObjects(rootObject, objectsByDepth, visitedObjects, currentDepth: 1, cancellationToken); } @@ -148,7 +148,7 @@ public ObjectGraph DiscoverNestedObjectGraph(object rootObject, CancellationToke /// The test context to discover objects from. /// Cancellation token for the operation. /// The tracked objects dictionary (same as testContext.TrackedObjects). - public ConcurrentDictionary> DiscoverAndTrackObjects(TestContext testContext, CancellationToken cancellationToken = default) + public Dictionary> DiscoverAndTrackObjects(TestContext testContext, CancellationToken cancellationToken = default) { var visitedObjects = testContext.TrackedObjects; @@ -188,7 +188,7 @@ bool TryAddStandard(object value, int depth) return false; } - AddToDepth(objectsByDepth, depth, value); + TryAddToHashSet(objectsByDepth, depth, value); return true; } @@ -212,7 +212,7 @@ void Recurse(object value, int depth) /// private void DiscoverNestedObjectsForTracking( object obj, - ConcurrentDictionary> visitedObjects, + Dictionary> visitedObjects, int currentDepth, CancellationToken cancellationToken) { @@ -262,24 +262,12 @@ private static bool ShouldSkipType(Type type) } /// - /// Adds an object to the specified depth level. - /// - private static void AddToDepth(Dictionary> objectsByDepth, int depth, object obj) - { - var hashSet = objectsByDepth.GetOrAdd(depth, _ => new HashSet(ReferenceComparer)); - hashSet.Add(obj); - } - - /// - /// Thread-safe add to HashSet at specified depth. Returns true if added (not duplicate). + /// Add to HashSet at specified depth. Returns true if added (not duplicate). /// - private static bool TryAddToHashSet(ConcurrentDictionary> dict, int depth, object obj) + private static bool TryAddToHashSet(Dictionary> dict, int depth, object obj) { var hashSet = dict.GetOrAdd(depth, _ => new HashSet(ReferenceComparer)); - lock (hashSet) - { - return hashSet.Add(obj); - } + return hashSet.Add(obj); } #region Consolidated Traversal Methods (DRY) diff --git a/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs b/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs index 01eb24af48..b4b67880b5 100644 --- a/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs +++ b/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs @@ -64,7 +64,7 @@ internal interface IObjectGraphDiscoverer /// This method modifies testContext.TrackedObjects directly. For pure query operations, /// use instead. /// - ConcurrentDictionary> DiscoverAndTrackObjects(TestContext testContext, CancellationToken cancellationToken = default); + Dictionary> DiscoverAndTrackObjects(TestContext testContext, CancellationToken cancellationToken = default); } /// diff --git a/TUnit.Core/TestContext.cs b/TUnit.Core/TestContext.cs index d28177f42e..e38cd07699 100644 --- a/TUnit.Core/TestContext.cs +++ b/TUnit.Core/TestContext.cs @@ -222,20 +222,11 @@ internal void InvalidateEventReceiverCaches() CachedClassInstance = null; } - internal ConcurrentDictionary ObjectBag => _testBuilderContext.StateBag; - internal AbstractExecutableTest InternalExecutableTest { get; set; } = null!; - private ConcurrentDictionary>? _trackedObjects; - - /// - /// Thread-safe lazy initialization of TrackedObjects using LazyInitializer - /// to prevent race conditions when multiple threads access this property simultaneously. - /// - internal ConcurrentDictionary> TrackedObjects => - LazyInitializer.EnsureInitialized(ref _trackedObjects)!; + internal Dictionary> TrackedObjects { get; } = new(); /// /// Sets the output captured during test building phase. diff --git a/TUnit.Core/Tracking/ObjectTracker.cs b/TUnit.Core/Tracking/ObjectTracker.cs index ed85d31eba..88290266c4 100644 --- a/TUnit.Core/Tracking/ObjectTracker.cs +++ b/TUnit.Core/Tracking/ObjectTracker.cs @@ -53,9 +53,9 @@ private static Counter GetOrCreateCounter(object obj) => /// Thread-safe: locks each HashSet while copying. /// Pre-calculates capacity to avoid HashSet resizing during population. /// - private static ISet FlattenTrackedObjects(ConcurrentDictionary> trackedObjects) + private static ISet FlattenTrackedObjects(Dictionary> trackedObjects) { - if (trackedObjects.IsEmpty) + if (trackedObjects.Count == 0) { return ImmutableHashSet.Empty; } @@ -64,12 +64,9 @@ private static ISet FlattenTrackedObjects(ConcurrentDictionary /// The test context to get trackable objects from. /// Optional cancellation token for long-running discovery. - public ConcurrentDictionary> GetTrackableObjects(TestContext testContext, CancellationToken cancellationToken = default) + public Dictionary> GetTrackableObjects(TestContext testContext, CancellationToken cancellationToken = default) { // OCP-compliant: Use the interface method directly instead of type-checking return _discoverer.DiscoverAndTrackObjects(testContext, cancellationToken); diff --git a/TUnit.Engine/Services/ObjectLifecycleService.cs b/TUnit.Engine/Services/ObjectLifecycleService.cs index 34269e62bc..1fd081dbe3 100644 --- a/TUnit.Engine/Services/ObjectLifecycleService.cs +++ b/TUnit.Engine/Services/ObjectLifecycleService.cs @@ -235,17 +235,9 @@ private async Task InitializeTrackedObjectsAsync(TestContext testContext, Cancel continue; } - // Copy to array under lock to prevent concurrent modification - object[] objectsCopy; - lock (objectsAtLevel) - { - objectsCopy = new object[objectsAtLevel.Count]; - objectsAtLevel.CopyTo(objectsCopy); - } - // Initialize all objects at this level in parallel - var tasks = new List(objectsCopy.Length); - foreach (var obj in objectsCopy) + var tasks = new List(objectsAtLevel.Count); + foreach (var obj in objectsAtLevel) { tasks.Add(InitializeObjectWithNestedAsync(obj, cancellationToken)); }