Skip to content

Commit

Permalink
Make EntityEntryGraphIterator publicly usable (#28459)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajcvickers committed Jul 22, 2022
1 parent b0bc172 commit 5150fec
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 41 deletions.
23 changes: 6 additions & 17 deletions src/EFCore/ChangeTracking/EntityEntryGraphNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,26 +89,15 @@ InternalEntityEntry IInfrastructure<InternalEntityEntry>.Instance
=> _entry;

/// <summary>
/// Creates a new node for the entity that is being traversed next in the graph.
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
/// <param name="currentNode">The node that the entity is being traversed from.</param>
/// <param name="internalEntityEntry">
/// The internal entry tracking information about the entity being traversed to.
/// </param>
/// <param name="reachedVia">The navigation property that is being traversed to reach the new node.</param>
/// <returns>The newly created node.</returns>
[EntityFrameworkInternal]
public virtual EntityEntryGraphNode CreateNode(
EntityEntryGraphNode currentNode,
InternalEntityEntry internalEntityEntry,
INavigationBase reachedVia)
{
Check.NotNull(currentNode, nameof(currentNode));
Check.NotNull(internalEntityEntry, nameof(internalEntityEntry));
Check.NotNull(reachedVia, nameof(reachedVia));

return new EntityEntryGraphNode(
internalEntityEntry,
currentNode.Entry.GetInfrastructure(),
reachedVia);
}
=> new(internalEntityEntry, currentNode.Entry.GetInfrastructure(), reachedVia);
}
36 changes: 22 additions & 14 deletions src/EFCore/ChangeTracking/EntityEntryGraphNode`.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,33 +32,41 @@ public EntityEntryGraphNode(
NodeState = state;
}

/// <summary>
/// Creates a new node in the entity graph.
/// </summary>
/// <param name="entry">The entry for the entity represented by this node.</param>
/// <param name="state">A state object that will be available when processing each node.</param>
/// <param name="sourceEntry">The entry from which this node was reached, or <see langword="null" /> if this is the root node.</param>
/// <param name="inboundNavigation">The navigation from the source node to this node, or <see langword="null" /> if this is the root node.</param>
public EntityEntryGraphNode(
EntityEntry entry,
TState state,
EntityEntry? sourceEntry,
INavigationBase? inboundNavigation)
: this(entry.GetInfrastructure(), state, sourceEntry?.GetInfrastructure(), inboundNavigation)
{
}

/// <summary>
/// Gets or sets state that will be available to all nodes that are visited after this node.
/// </summary>
public virtual TState NodeState { get; set; }

/// <summary>
/// Creates a new node for the entity that is being traversed next in the graph.
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
/// <param name="currentNode">The node that the entity is being traversed from.</param>
/// <param name="internalEntityEntry">
/// The internal entry tracking information about the entity being traversed to.
/// </param>
/// <param name="reachedVia">The navigation property that is being traversed to reach the new node.</param>
/// <returns>The newly created node.</returns>
[EntityFrameworkInternal]
public override EntityEntryGraphNode CreateNode(
EntityEntryGraphNode currentNode,
InternalEntityEntry internalEntityEntry,
INavigationBase reachedVia)
{
Check.NotNull(currentNode, nameof(currentNode));
Check.NotNull(internalEntityEntry, nameof(internalEntityEntry));
Check.NotNull(reachedVia, nameof(reachedVia));

return new EntityEntryGraphNode<TState>(
=> new EntityEntryGraphNode<TState>(
internalEntityEntry,
((EntityEntryGraphNode<TState>)currentNode).NodeState,
currentNode.Entry.GetInfrastructure(),
reachedVia);
}
}
92 changes: 82 additions & 10 deletions test/EFCore.Tests/ChangeTracking/TrackGraphTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,88 @@ protected override IList<string> TrackGraph(DbContext context, object root, Acti

return traversal;
}

[ConditionalTheory] // Issue #26461
[InlineData(false)]
[InlineData(true)]
public async Task Can_iterate_over_graph_using_public_surface(bool async)
{
using var context = new EarlyLearningCenter(GetType().Name);
var category = new Category
{
Id = 1,
Products = new List<Product>
{
new()
{
Id = 1,
CategoryId = 1,
Details = new ProductDetails { Id = 1 }
},
new()
{
Id = 2,
CategoryId = 1,
Details = new ProductDetails { Id = 2 }
},
new()
{
Id = 3,
CategoryId = 1,
Details = new ProductDetails { Id = 3 }
}
}
};

var rootEntry = context.Attach(category);

var graphIterator = context.GetService<IEntityEntryGraphIterator>();

var visited = new HashSet<object>();
var traversal = new List<string>();

bool Callback(EntityEntryGraphNode<HashSet<object>> node)
{
if (node.NodeState.Contains(node.Entry.Entity))
{
return false;
}

node.NodeState.Add(node.Entry.Entity);

traversal.Add(NodeString(node));

return true;
}

if (async)
{
await graphIterator.TraverseGraphAsync(
new EntityEntryGraphNode<HashSet<object>>(rootEntry, visited, null, null),
(node, _) => Task.FromResult(Callback(node)));
}
else
{
graphIterator.TraverseGraph(
new EntityEntryGraphNode<HashSet<object>>(rootEntry, visited, null, null),
Callback);
}

Assert.Equal(
new List<string>
{
"<None> -----> Category:1",
"Category:1 ---Products--> Product:1",
"Product:1 ---Details--> ProductDetails:1",
"Category:1 ---Products--> Product:2",
"Product:2 ---Details--> ProductDetails:2",
"Category:1 ---Products--> Product:3",
"Product:3 ---Details--> ProductDetails:3"
},
traversal);

Assert.Equal(7, visited.Count);
}
}

public class TrackGraphTestWithState : TrackGraphTestBase
Expand Down Expand Up @@ -1152,16 +1234,6 @@ public void TrackGraph_overload_can_visit_an_already_attached_graph()
Assert.Equal(7, visited.Count);
}

private static void AssertValuesSaved(int id, int someInt, string someString)
{
using var context = new TheShadows();
var entry = context.Entry(context.Set<Dark>().Single(e => EF.Property<int>(e, "Id") == id));

Assert.Equal(id, entry.Property<int>("Id").CurrentValue);
Assert.Equal(someInt, entry.Property<int>("SomeInt").CurrentValue);
Assert.Equal(someString, entry.Property<string>("SomeString").CurrentValue);
}

private class TheShadows : DbContext
{
protected internal override void OnModelCreating(ModelBuilder modelBuilder)
Expand Down

0 comments on commit 5150fec

Please sign in to comment.