diff --git a/src/Microsoft.EntityFrameworkCore.Specification.Tests/GraphUpdatesTestBase.cs b/src/Microsoft.EntityFrameworkCore.Specification.Tests/GraphUpdatesTestBase.cs index f4d171290a5..38f245ce98b 100644 --- a/src/Microsoft.EntityFrameworkCore.Specification.Tests/GraphUpdatesTestBase.cs +++ b/src/Microsoft.EntityFrameworkCore.Specification.Tests/GraphUpdatesTestBase.cs @@ -119,9 +119,9 @@ public virtual void Save_optional_many_to_one_dependents(ChangeMechanism changeM switch (changeMechanism) { case ChangeMechanism.Principal: - existing.Children.Add(new2a); - existing.Children.Add(new2b); - root.OptionalChildren.Add(new1); + Add(existing.Children, new2a); + Add(existing.Children, new2b); + Add(root.OptionalChildren, new1); break; case ChangeMechanism.Dependent: new2a.Parent = existing; @@ -209,9 +209,9 @@ public virtual void Save_required_many_to_one_dependents(ChangeMechanism changeM switch (changeMechanism) { case ChangeMechanism.Principal: - existing.Children.Add(new2a); - existing.Children.Add(new2b); - root.RequiredChildren.Add(new1); + Add(existing.Children, new2a); + Add(existing.Children, new2b); + Add(root.RequiredChildren, new1); break; case ChangeMechanism.Dependent: new2a.Parent = existing; @@ -276,8 +276,8 @@ public virtual void Save_removed_optional_many_to_one_dependents(ChangeMechanism removed1.Parent = null; break; case ChangeMechanism.Principal: - childCollection.Remove(removed2); - root.OptionalChildren.Remove(removed1); + Remove(childCollection, removed2); + Remove(root.OptionalChildren, removed1); break; case ChangeMechanism.FK: removed2.ParentId = null; @@ -305,11 +305,11 @@ public virtual void Save_removed_optional_many_to_one_dependents(ChangeMechanism AssertKeys(root, loadedRoot); AssertNavigations(loadedRoot); - Assert.Equal(2, loadedRoot.RequiredChildren.Count); - Assert.Equal(2, loadedRoot.RequiredChildren.First().Children.Count); + Assert.Equal(2, loadedRoot.RequiredChildren.Count()); + Assert.Equal(2, loadedRoot.RequiredChildren.First().Children.Count()); - Assert.Equal(1, loadedRoot.OptionalChildren.Count); - Assert.Equal(1, loadedRoot.OptionalChildren.First().Children.Count); + Assert.Equal(1, loadedRoot.OptionalChildren.Count()); + Assert.Equal(1, loadedRoot.OptionalChildren.First().Children.Count()); } } @@ -342,8 +342,8 @@ public virtual void Save_removed_required_many_to_one_dependents(ChangeMechanism removed1.Parent = null; break; case ChangeMechanism.Principal: - childCollection.Remove(removed2); - root.RequiredChildren.Remove(removed1); + Remove(childCollection, removed2); + Remove(root.RequiredChildren, removed1); break; case ChangeMechanism.FK: context.Entry(removed2).GetInfrastructure()[context.Entry(removed2).Property(e => e.ParentId).Metadata] = null; @@ -362,7 +362,7 @@ public virtual void Save_removed_required_many_to_one_dependents(ChangeMechanism AssertNavigations(root); - Assert.Equal(1, root.RequiredChildren.Count); + Assert.Equal(1, root.RequiredChildren.Count()); Assert.DoesNotContain(removed1Id, root.RequiredChildren.Select(e => e.Id)); Assert.Empty(context.Required1s.Where(e => e.Id == removed1Id)); @@ -1007,9 +1007,9 @@ public virtual void Save_optional_many_to_one_dependents_with_alternate_key(Chan switch (changeMechanism) { case ChangeMechanism.Principal: - existing.Children.Add(new2a); - existing.Children.Add(new2b); - root.OptionalChildrenAk.Add(new1); + Add(existing.Children, new2a); + Add(existing.Children, new2b); + Add(root.OptionalChildrenAk, new1); break; case ChangeMechanism.Dependent: new2a.Parent = existing; @@ -1097,9 +1097,9 @@ public virtual void Save_required_many_to_one_dependents_with_alternate_key(Chan switch (changeMechanism) { case ChangeMechanism.Principal: - existing.Children.Add(new2a); - existing.Children.Add(new2b); - root.RequiredChildrenAk.Add(new1); + Add(existing.Children, new2a); + Add(existing.Children, new2b); + Add(root.RequiredChildrenAk, new1); break; case ChangeMechanism.Dependent: new2a.Parent = existing; @@ -1164,8 +1164,8 @@ public virtual void Save_removed_optional_many_to_one_dependents_with_alternate_ removed1.Parent = null; break; case ChangeMechanism.Principal: - childCollection.Remove(removed2); - root.OptionalChildrenAk.Remove(removed1); + Remove(childCollection, removed2); + Remove(root.OptionalChildrenAk, removed1); break; case ChangeMechanism.FK: removed2.ParentId = null; @@ -1193,11 +1193,11 @@ public virtual void Save_removed_optional_many_to_one_dependents_with_alternate_ AssertKeys(root, loadedRoot); AssertNavigations(loadedRoot); - Assert.Equal(2, loadedRoot.RequiredChildrenAk.Count); - Assert.Equal(2, loadedRoot.RequiredChildrenAk.First().Children.Count); + Assert.Equal(2, loadedRoot.RequiredChildrenAk.Count()); + Assert.Equal(2, loadedRoot.RequiredChildrenAk.First().Children.Count()); - Assert.Equal(1, loadedRoot.OptionalChildrenAk.Count); - Assert.Equal(1, loadedRoot.OptionalChildrenAk.First().Children.Count); + Assert.Equal(1, loadedRoot.OptionalChildrenAk.Count()); + Assert.Equal(1, loadedRoot.OptionalChildrenAk.First().Children.Count()); } } @@ -1225,8 +1225,8 @@ public virtual void Save_removed_required_many_to_one_dependents_with_alternate_ removed1.Parent = null; break; case ChangeMechanism.Principal: - childCollection.Remove(removed2); - root.RequiredChildrenAk.Remove(removed1); + Remove(childCollection, removed2); + Remove(root.RequiredChildrenAk, removed1); break; default: throw new ArgumentOutOfRangeException(nameof(changeMechanism)); @@ -1251,11 +1251,11 @@ public virtual void Save_removed_required_many_to_one_dependents_with_alternate_ Assert.False(context.RequiredAk1s.Any(e => e.Id == removed1.Id)); Assert.False(context.RequiredAk2s.Any(e => e.Id == removed2.Id)); - Assert.Equal(1, loadedRoot.RequiredChildrenAk.Count); - Assert.Equal(1, loadedRoot.RequiredChildrenAk.First().Children.Count); + Assert.Equal(1, loadedRoot.RequiredChildrenAk.Count()); + Assert.Equal(1, loadedRoot.RequiredChildrenAk.First().Children.Count()); - Assert.Equal(2, loadedRoot.OptionalChildrenAk.Count); - Assert.Equal(2, loadedRoot.OptionalChildrenAk.First().Children.Count); + Assert.Equal(2, loadedRoot.OptionalChildrenAk.Count()); + Assert.Equal(2, loadedRoot.OptionalChildrenAk.First().Children.Count()); } } @@ -1888,7 +1888,7 @@ public virtual void Required_many_to_one_dependents_are_cascade_deleted() { var root = LoadFullGraph(context); - Assert.Equal(2, root.RequiredChildren.Count); + Assert.Equal(2, root.RequiredChildren.Count()); var removed = root.RequiredChildren.First(); @@ -1905,7 +1905,7 @@ public virtual void Required_many_to_one_dependents_are_cascade_deleted() Assert.Equal(EntityState.Detached, context.Entry(removed).State); Assert.True(cascadeRemoved.All(e => context.Entry(e).State == EntityState.Detached)); - Assert.Equal(1, root.RequiredChildren.Count); + Assert.Equal(1, root.RequiredChildren.Count()); Assert.DoesNotContain(removedId, root.RequiredChildren.Select(e => e.Id)); Assert.Empty(context.Required1s.Where(e => e.Id == removedId)); @@ -1916,7 +1916,7 @@ public virtual void Required_many_to_one_dependents_are_cascade_deleted() { var root = LoadFullGraph(context); - Assert.Equal(1, root.RequiredChildren.Count); + Assert.Equal(1, root.RequiredChildren.Count()); Assert.DoesNotContain(removedId, root.RequiredChildren.Select(e => e.Id)); Assert.Empty(context.Required1s.Where(e => e.Id == removedId)); @@ -1934,7 +1934,7 @@ public virtual void Optional_many_to_one_dependents_are_orphaned() { var root = LoadFullGraph(context); - Assert.Equal(2, root.OptionalChildren.Count); + Assert.Equal(2, root.OptionalChildren.Count()); var removed = root.OptionalChildren.First(); @@ -1951,7 +1951,7 @@ public virtual void Optional_many_to_one_dependents_are_orphaned() Assert.Equal(EntityState.Detached, context.Entry(removed).State); Assert.True(orphaned.All(e => context.Entry(e).State == EntityState.Unchanged)); - Assert.Equal(1, root.OptionalChildren.Count); + Assert.Equal(1, root.OptionalChildren.Count()); Assert.DoesNotContain(removedId, root.OptionalChildren.Select(e => e.Id)); Assert.Empty(context.Optional1s.Where(e => e.Id == removedId)); @@ -1962,7 +1962,7 @@ public virtual void Optional_many_to_one_dependents_are_orphaned() { var root = LoadFullGraph(context); - Assert.Equal(1, root.OptionalChildren.Count); + Assert.Equal(1, root.OptionalChildren.Count()); Assert.DoesNotContain(removedId, root.OptionalChildren.Select(e => e.Id)); Assert.Empty(context.Optional1s.Where(e => e.Id == removedId)); @@ -2100,7 +2100,7 @@ public virtual void Optional_many_to_one_dependents_with_alternate_key_are_orpha { var root = LoadFullGraph(context); - Assert.Equal(2, root.OptionalChildrenAk.Count); + Assert.Equal(2, root.OptionalChildrenAk.Count()); var removed = root.OptionalChildrenAk.First(); @@ -2117,7 +2117,7 @@ public virtual void Optional_many_to_one_dependents_with_alternate_key_are_orpha Assert.Equal(EntityState.Detached, context.Entry(removed).State); Assert.True(orphaned.All(e => context.Entry(e).State == EntityState.Unchanged)); - Assert.Equal(1, root.OptionalChildrenAk.Count); + Assert.Equal(1, root.OptionalChildrenAk.Count()); Assert.DoesNotContain(removedId, root.OptionalChildrenAk.Select(e => e.Id)); Assert.Empty(context.OptionalAk1s.Where(e => e.Id == removedId)); @@ -2128,7 +2128,7 @@ public virtual void Optional_many_to_one_dependents_with_alternate_key_are_orpha { var root = LoadFullGraph(context); - Assert.Equal(1, root.OptionalChildrenAk.Count); + Assert.Equal(1, root.OptionalChildrenAk.Count()); Assert.DoesNotContain(removedId, root.OptionalChildrenAk.Select(e => e.Id)); Assert.Empty(context.OptionalAk1s.Where(e => e.Id == removedId)); @@ -2146,7 +2146,7 @@ public virtual void Required_many_to_one_dependents_with_alternate_key_are_casca { var root = LoadFullGraph(context); - Assert.Equal(2, root.RequiredChildrenAk.Count); + Assert.Equal(2, root.RequiredChildrenAk.Count()); var removed = root.RequiredChildrenAk.First(); @@ -2163,7 +2163,7 @@ public virtual void Required_many_to_one_dependents_with_alternate_key_are_casca Assert.Equal(EntityState.Detached, context.Entry(removed).State); Assert.True(cascadeRemoved.All(e => context.Entry(e).State == EntityState.Detached)); - Assert.Equal(1, root.RequiredChildrenAk.Count); + Assert.Equal(1, root.RequiredChildrenAk.Count()); Assert.DoesNotContain(removedId, root.RequiredChildrenAk.Select(e => e.Id)); Assert.Empty(context.RequiredAk1s.Where(e => e.Id == removedId)); @@ -2174,7 +2174,7 @@ public virtual void Required_many_to_one_dependents_with_alternate_key_are_casca { var root = LoadFullGraph(context); - Assert.Equal(1, root.RequiredChildrenAk.Count); + Assert.Equal(1, root.RequiredChildrenAk.Count()); Assert.DoesNotContain(removedId, root.RequiredChildrenAk.Select(e => e.Id)); Assert.Empty(context.RequiredAk1s.Where(e => e.Id == removedId)); @@ -2332,7 +2332,7 @@ public virtual void Required_many_to_one_dependents_are_cascade_deleted_in_store Assert.Equal(EntityState.Detached, context.Entry(removed).State); - Assert.Equal(1, root.RequiredChildren.Count); + Assert.Equal(1, root.RequiredChildren.Count()); Assert.DoesNotContain(removedId, root.RequiredChildren.Select(e => e.Id)); Assert.Empty(context.Required1s.Where(e => e.Id == removedId)); @@ -2343,7 +2343,7 @@ public virtual void Required_many_to_one_dependents_are_cascade_deleted_in_store { var root = LoadFullGraph(context); - Assert.Equal(1, root.RequiredChildren.Count); + Assert.Equal(1, root.RequiredChildren.Count()); Assert.DoesNotContain(removedId, root.RequiredChildren.Select(e => e.Id)); Assert.Empty(context.Required1s.Where(e => e.Id == removedId)); @@ -2467,7 +2467,7 @@ public virtual void Required_many_to_one_dependents_with_alternate_key_are_casca Assert.Equal(EntityState.Detached, context.Entry(removed).State); - Assert.Equal(1, root.RequiredChildrenAk.Count); + Assert.Equal(1, root.RequiredChildrenAk.Count()); Assert.DoesNotContain(removedId, root.RequiredChildrenAk.Select(e => e.Id)); Assert.Empty(context.RequiredAk1s.Where(e => e.Id == removedId)); @@ -2478,7 +2478,7 @@ public virtual void Required_many_to_one_dependents_with_alternate_key_are_casca { var root = LoadFullGraph(context); - Assert.Equal(1, root.RequiredChildrenAk.Count); + Assert.Equal(1, root.RequiredChildrenAk.Count()); Assert.DoesNotContain(removedId, root.RequiredChildrenAk.Select(e => e.Id)); Assert.Empty(context.RequiredAk1s.Where(e => e.Id == removedId)); @@ -2602,7 +2602,7 @@ public virtual void Optional_many_to_one_dependents_are_orphaned_in_store() Assert.Equal(EntityState.Detached, context.Entry(removed).State); - Assert.Equal(1, root.OptionalChildren.Count); + Assert.Equal(1, root.OptionalChildren.Count()); Assert.DoesNotContain(removedId, root.OptionalChildren.Select(e => e.Id)); Assert.Empty(context.Optional1s.Where(e => e.Id == removedId)); @@ -2616,7 +2616,7 @@ public virtual void Optional_many_to_one_dependents_are_orphaned_in_store() { var root = LoadFullGraph(context); - Assert.Equal(1, root.OptionalChildren.Count); + Assert.Equal(1, root.OptionalChildren.Count()); Assert.DoesNotContain(removedId, root.OptionalChildren.Select(e => e.Id)); Assert.Empty(context.Optional1s.Where(e => e.Id == removedId)); @@ -2700,7 +2700,7 @@ public virtual void Optional_many_to_one_dependents_with_alternate_key_are_orpha Assert.Equal(EntityState.Detached, context.Entry(removed).State); - Assert.Equal(1, root.OptionalChildrenAk.Count); + Assert.Equal(1, root.OptionalChildrenAk.Count()); Assert.DoesNotContain(removedId, root.OptionalChildrenAk.Select(e => e.Id)); Assert.Empty(context.OptionalAk1s.Where(e => e.Id == removedId)); @@ -2714,7 +2714,7 @@ public virtual void Optional_many_to_one_dependents_with_alternate_key_are_orpha { var root = LoadFullGraph(context); - Assert.Equal(1, root.OptionalChildrenAk.Count); + Assert.Equal(1, root.OptionalChildrenAk.Count()); Assert.DoesNotContain(removedId, root.OptionalChildrenAk.Select(e => e.Id)); Assert.Empty(context.OptionalAk1s.Where(e => e.Id == removedId)); @@ -2779,7 +2779,7 @@ public virtual void Required_many_to_one_dependents_are_cascade_deleted_starting { root = LoadFullGraph(context); - Assert.Equal(2, root.RequiredChildren.Count); + Assert.Equal(2, root.RequiredChildren.Count()); } using (var context = CreateContext()) @@ -2807,7 +2807,7 @@ public virtual void Required_many_to_one_dependents_are_cascade_deleted_starting { root = LoadFullGraph(context); - Assert.Equal(1, root.RequiredChildren.Count); + Assert.Equal(1, root.RequiredChildren.Count()); Assert.DoesNotContain(removedId, root.RequiredChildren.Select(e => e.Id)); Assert.Empty(context.Required1s.Where(e => e.Id == removedId)); @@ -2826,7 +2826,7 @@ public virtual void Optional_many_to_one_dependents_are_orphaned_starting_detach { root = LoadFullGraph(context); - Assert.Equal(2, root.OptionalChildren.Count); + Assert.Equal(2, root.OptionalChildren.Count()); } using (var context = CreateContext()) @@ -2854,7 +2854,7 @@ public virtual void Optional_many_to_one_dependents_are_orphaned_starting_detach { root = LoadFullGraph(context); - Assert.Equal(1, root.OptionalChildren.Count); + Assert.Equal(1, root.OptionalChildren.Count()); Assert.DoesNotContain(removedId, root.OptionalChildren.Select(e => e.Id)); Assert.Empty(context.Optional1s.Where(e => e.Id == removedId)); @@ -3002,7 +3002,7 @@ public virtual void Optional_many_to_one_dependents_with_alternate_key_are_orpha { root = LoadFullGraph(context); - Assert.Equal(2, root.OptionalChildrenAk.Count); + Assert.Equal(2, root.OptionalChildrenAk.Count()); } using (var context = CreateContext()) @@ -3030,7 +3030,7 @@ public virtual void Optional_many_to_one_dependents_with_alternate_key_are_orpha { root = LoadFullGraph(context); - Assert.Equal(1, root.OptionalChildrenAk.Count); + Assert.Equal(1, root.OptionalChildrenAk.Count()); Assert.DoesNotContain(removedId, root.OptionalChildrenAk.Select(e => e.Id)); Assert.Empty(context.OptionalAk1s.Where(e => e.Id == removedId)); @@ -3049,7 +3049,7 @@ public virtual void Required_many_to_one_dependents_with_alternate_key_are_casca { root = LoadFullGraph(context); - Assert.Equal(2, root.RequiredChildrenAk.Count); + Assert.Equal(2, root.RequiredChildrenAk.Count()); } using (var context = CreateContext()) @@ -3077,7 +3077,7 @@ public virtual void Required_many_to_one_dependents_with_alternate_key_are_casca { root = LoadFullGraph(context); - Assert.Equal(1, root.RequiredChildrenAk.Count); + Assert.Equal(1, root.RequiredChildrenAk.Count()); Assert.DoesNotContain(removedId, root.RequiredChildrenAk.Select(e => e.Id)); Assert.Empty(context.RequiredAk1s.Where(e => e.Id == removedId)); @@ -3221,7 +3221,7 @@ public virtual void Required_many_to_one_dependents_are_cascade_detached_when_Ad { var root = LoadFullGraph(context); - Assert.Equal(2, root.RequiredChildren.Count); + Assert.Equal(2, root.RequiredChildren.Count()); var removed = root.RequiredChildren.First(); @@ -3232,7 +3232,7 @@ public virtual void Required_many_to_one_dependents_are_cascade_detached_when_Ad Assert.Equal(2, orphanedIds.Count); var added = new Required2(); - removed.Children.Add(added); + Add(removed.Children, added); context.ChangeTracker.DetectChanges(); Assert.Equal(EntityState.Unchanged, context.Entry(removed).State); @@ -3256,7 +3256,7 @@ public virtual void Required_many_to_one_dependents_are_cascade_detached_when_Ad { var root = LoadFullGraph(context); - Assert.Equal(1, root.RequiredChildren.Count); + Assert.Equal(1, root.RequiredChildren.Count()); Assert.DoesNotContain(removedId, root.RequiredChildren.Select(e => e.Id)); Assert.Empty(context.Required1s.Where(e => e.Id == removedId)); @@ -3360,7 +3360,7 @@ public virtual void Required_many_to_one_dependents_with_alternate_key_are_casca { var root = LoadFullGraph(context); - Assert.Equal(2, root.RequiredChildrenAk.Count); + Assert.Equal(2, root.RequiredChildrenAk.Count()); var removed = root.RequiredChildrenAk.First(); @@ -3371,7 +3371,7 @@ public virtual void Required_many_to_one_dependents_with_alternate_key_are_casca Assert.Equal(2, orphanedIds.Count); var added = new RequiredAk2(); - removed.Children.Add(added); + Add(removed.Children, added); context.ChangeTracker.DetectChanges(); Assert.Equal(EntityState.Unchanged, context.Entry(removed).State); @@ -3395,7 +3395,7 @@ public virtual void Required_many_to_one_dependents_with_alternate_key_are_casca { var root = LoadFullGraph(context); - Assert.Equal(1, root.RequiredChildrenAk.Count); + Assert.Equal(1, root.RequiredChildrenAk.Count()); Assert.DoesNotContain(removedId, root.RequiredChildrenAk.Select(e => e.Id)); Assert.Empty(context.RequiredAk1s.Where(e => e.Id == removedId)); @@ -3489,6 +3489,10 @@ public virtual void Required_non_PK_one_to_one_with_alternate_key_are_cascade_de } } + private void Add(IEnumerable collection, T item) => ((ICollection)collection).Add(item); + + private void Remove(IEnumerable collection, T item) => ((ICollection)collection).Remove(item); + public enum ChangeMechanism { Dependent, @@ -3545,8 +3549,8 @@ private static void AssertKeys(Root expected, Root actual) actual.RequiredChildren.OrderBy(e => e.Id).Select(e => e.Id)); Assert.Equal( - expected.RequiredChildren.OrderBy(e => e.Id).Select(e => e.Children.Count), - actual.RequiredChildren.OrderBy(e => e.Id).Select(e => e.Children.Count)); + expected.RequiredChildren.OrderBy(e => e.Id).Select(e => e.Children.Count()), + actual.RequiredChildren.OrderBy(e => e.Id).Select(e => e.Children.Count())); Assert.Equal( expected.RequiredChildren.OrderBy(e => e.Id).SelectMany(e => e.Children).OrderBy(e => e.Id).Select(e => e.Id), @@ -3557,8 +3561,8 @@ private static void AssertKeys(Root expected, Root actual) actual.OptionalChildren.OrderBy(e => e.Id).Select(e => e.Id)); Assert.Equal( - expected.OptionalChildren.OrderBy(e => e.Id).Select(e => e.Children.Count), - actual.OptionalChildren.OrderBy(e => e.Id).Select(e => e.Children.Count)); + expected.OptionalChildren.OrderBy(e => e.Id).Select(e => e.Children.Count()), + actual.OptionalChildren.OrderBy(e => e.Id).Select(e => e.Children.Count())); Assert.Equal( expected.OptionalChildren.OrderBy(e => e.Id).SelectMany(e => e.Children).OrderBy(e => e.Id).Select(e => e.Id), @@ -3579,8 +3583,8 @@ private static void AssertKeys(Root expected, Root actual) actual.RequiredChildrenAk.OrderBy(e => e.Id).Select(e => e.AlternateId)); Assert.Equal( - expected.RequiredChildrenAk.OrderBy(e => e.Id).Select(e => e.Children.Count), - actual.RequiredChildrenAk.OrderBy(e => e.Id).Select(e => e.Children.Count)); + expected.RequiredChildrenAk.OrderBy(e => e.Id).Select(e => e.Children.Count()), + actual.RequiredChildrenAk.OrderBy(e => e.Id).Select(e => e.Children.Count())); Assert.Equal( expected.RequiredChildrenAk.OrderBy(e => e.Id).SelectMany(e => e.Children).OrderBy(e => e.Id).Select(e => e.AlternateId), @@ -3591,8 +3595,8 @@ private static void AssertKeys(Root expected, Root actual) actual.OptionalChildrenAk.OrderBy(e => e.Id).Select(e => e.AlternateId)); Assert.Equal( - expected.OptionalChildrenAk.OrderBy(e => e.Id).Select(e => e.Children.Count), - actual.OptionalChildrenAk.OrderBy(e => e.Id).Select(e => e.Children.Count)); + expected.OptionalChildrenAk.OrderBy(e => e.Id).Select(e => e.Children.Count()), + actual.OptionalChildrenAk.OrderBy(e => e.Id).Select(e => e.Children.Count())); Assert.Equal( expected.OptionalChildrenAk.OrderBy(e => e.Id).SelectMany(e => e.Children).OrderBy(e => e.Id).Select(e => e.AlternateId), @@ -3742,14 +3746,14 @@ protected class Root public int Id { get; set; } public Guid AlternateId { get; set; } - public ICollection RequiredChildren { get; set; } = new List(); - public ICollection OptionalChildren { get; set; } = new List(); + public IEnumerable RequiredChildren { get; set; } = new List(); + public IEnumerable OptionalChildren { get; set; } = new List(); public RequiredSingle1 RequiredSingle { get; set; } public RequiredNonPkSingle1 RequiredNonPkSingle { get; set; } public OptionalSingle1 OptionalSingle { get; set; } - public ICollection RequiredChildrenAk { get; set; } = new List(); - public ICollection OptionalChildrenAk { get; set; } = new List(); + public IEnumerable RequiredChildrenAk { get; set; } = new List(); + public IEnumerable OptionalChildrenAk { get; set; } = new List(); public RequiredSingleAk1 RequiredSingleAk { get; set; } public RequiredNonPkSingleAk1 RequiredNonPkSingleAk { get; set; } public OptionalSingleAk1 OptionalSingleAk { get; set; } @@ -3762,7 +3766,7 @@ protected class Required1 public int ParentId { get; set; } public Root Parent { get; set; } - public ICollection Children { get; set; } = new List(); + public IEnumerable Children { get; set; } = new List(); } protected class Required2 @@ -3780,7 +3784,7 @@ protected class Optional1 public int? ParentId { get; set; } public Root Parent { get; set; } - public ICollection Children { get; set; } = new List(); + public IEnumerable Children { get; set; } = new List(); } protected class Optional2 @@ -3850,7 +3854,7 @@ protected class RequiredAk1 public Guid ParentId { get; set; } public Root Parent { get; set; } - public ICollection Children { get; set; } = new List(); + public IEnumerable Children { get; set; } = new List(); } protected class RequiredAk2 @@ -3870,7 +3874,7 @@ protected class OptionalAk1 public Guid? ParentId { get; set; } public Root Parent { get; set; } - public ICollection Children { get; set; } = new List(); + public IEnumerable Children { get; set; } = new List(); } protected class OptionalAk2 diff --git a/src/Microsoft.EntityFrameworkCore/EntityFrameworkQueryableExtensions.cs b/src/Microsoft.EntityFrameworkCore/EntityFrameworkQueryableExtensions.cs index 9b506fd69f2..8fca85a75bf 100644 --- a/src/Microsoft.EntityFrameworkCore/EntityFrameworkQueryableExtensions.cs +++ b/src/Microsoft.EntityFrameworkCore/EntityFrameworkQueryableExtensions.cs @@ -2189,7 +2189,7 @@ internal static readonly MethodInfo IncludeMethodInfo /// type of entity being queried (). If you wish to include additional types based on the navigation /// properties of the type being included, then chain a call to /// + /// cref="ThenInclude{TEntity, TPreviousProperty, TProperty}(IIncludableQueryable{TEntity, IEnumerable{TPreviousProperty}}, Expression{Func{TPreviousProperty, TProperty}})" /> /// after this call. /// /// @@ -2289,7 +2289,7 @@ internal static readonly MethodInfo ThenIncludeAfterReferenceMethodInfo /// A new query with the related data included. /// public static IIncludableQueryable ThenInclude( - [NotNull] this IIncludableQueryable> source, + [NotNull] this IIncludableQueryable> source, [NotNull] Expression> navigationPropertyPath) where TEntity : class => new IncludableQueryable( diff --git a/src/Microsoft.EntityFrameworkCore/Metadata/Internal/ClrCollectionAccessorFactory.cs b/src/Microsoft.EntityFrameworkCore/Metadata/Internal/ClrCollectionAccessorFactory.cs index b6feb17d0ed..78da19f4d08 100644 --- a/src/Microsoft.EntityFrameworkCore/Metadata/Internal/ClrCollectionAccessorFactory.cs +++ b/src/Microsoft.EntityFrameworkCore/Metadata/Internal/ClrCollectionAccessorFactory.cs @@ -38,7 +38,7 @@ public virtual IClrCollectionAccessor Create([NotNull] INavigation navigation) } var property = navigation.GetPropertyInfo(); - var elementType = property.PropertyType.TryGetElementType(typeof(ICollection<>)); + var elementType = property.PropertyType.TryGetElementType(typeof(IEnumerable<>)); // TODO: Only ICollections supported; add support for enumerables with add/remove methods // Issue #752 @@ -46,18 +46,18 @@ public virtual IClrCollectionAccessor Create([NotNull] INavigation navigation) { throw new InvalidOperationException( CoreStrings.NavigationBadType( - navigation.Name, navigation.DeclaringEntityType.Name, property.PropertyType.FullName, navigation.GetTargetType().Name)); + navigation.Name, navigation.DeclaringEntityType.Name, property.PropertyType.Name, navigation.GetTargetType().Name)); } if (property.PropertyType.IsArray) { throw new InvalidOperationException( - CoreStrings.NavigationArray(navigation.Name, navigation.DeclaringEntityType.Name, property.PropertyType.FullName)); + CoreStrings.NavigationArray(navigation.Name, navigation.DeclaringEntityType.DisplayName(), property.PropertyType.Name)); } if (property.GetMethod == null) { - throw new InvalidOperationException(CoreStrings.NavigationNoGetter(navigation.Name, navigation.DeclaringEntityType.Name)); + throw new InvalidOperationException(CoreStrings.NavigationNoGetter(navigation.Name, navigation.DeclaringEntityType.DisplayName())); } var boundMethod = _genericCreate.MakeGenericMethod( @@ -69,7 +69,7 @@ public virtual IClrCollectionAccessor Create([NotNull] INavigation navigation) [UsedImplicitly] private static IClrCollectionAccessor CreateGeneric(PropertyInfo property) where TEntity : class - where TCollection : class, ICollection + where TCollection : class, IEnumerable { var getterDelegate = (Func)property.GetMethod.CreateDelegate(typeof(Func)); diff --git a/src/Microsoft.EntityFrameworkCore/Metadata/Internal/ClrICollectionAccessor.cs b/src/Microsoft.EntityFrameworkCore/Metadata/Internal/ClrICollectionAccessor.cs index fc5118ca802..f8bb92b72fa 100644 --- a/src/Microsoft.EntityFrameworkCore/Metadata/Internal/ClrICollectionAccessor.cs +++ b/src/Microsoft.EntityFrameworkCore/Metadata/Internal/ClrICollectionAccessor.cs @@ -14,7 +14,7 @@ namespace Microsoft.EntityFrameworkCore.Metadata.Internal /// public class ClrICollectionAccessor : IClrCollectionAccessor where TEntity : class - where TCollection : class, ICollection + where TCollection : class, IEnumerable { private readonly string _propertyName; private readonly Func _getCollection; @@ -89,10 +89,10 @@ public virtual object Create(IEnumerable values) if (_createCollection == null) { throw new InvalidOperationException(CoreStrings.NavigationCannotCreateType( - _propertyName, typeof(TEntity).FullName, typeof(TCollection).FullName)); + _propertyName, typeof(TEntity).Name, typeof(TCollection).Name)); } - var collection = _createCollection(); + var collection = (ICollection)_createCollection(); foreach (TElement value in values) { collection.Add(value); @@ -107,25 +107,42 @@ public virtual object Create(IEnumerable values) /// public virtual object GetOrCreate(object instance) => GetOrCreateCollection(instance); - private TCollection GetOrCreateCollection(object instance) + private ICollection GetOrCreateCollection(object instance) { - var collection = _getCollection((TEntity)instance); + var collection = GetCollection(instance); if (collection == null) { if (_setCollection == null) { - throw new InvalidOperationException(CoreStrings.NavigationNoSetter(_propertyName, typeof(TEntity).FullName)); + throw new InvalidOperationException(CoreStrings.NavigationNoSetter(_propertyName, typeof(TEntity).Name)); } if (_createAndSetCollection == null) { throw new InvalidOperationException(CoreStrings.NavigationCannotCreateType( - _propertyName, typeof(TEntity).FullName, typeof(TCollection).FullName)); + _propertyName, typeof(TEntity).Name, typeof(TCollection).Name)); } - collection = _createAndSetCollection((TEntity)instance, _setCollection); + collection = (ICollection)_createAndSetCollection((TEntity)instance, _setCollection); } + + return collection; + } + + private ICollection GetCollection(object instance) + { + var enumerable = _getCollection((TEntity)instance); + var collection = enumerable as ICollection; + + if (enumerable != null + && collection == null) + { + throw new InvalidOperationException( + CoreStrings.NavigationBadType( + _propertyName, typeof(TEntity).Name, enumerable.GetType().Name, typeof(TElement).Name)); + } + return collection; } @@ -135,7 +152,7 @@ private TCollection GetOrCreateCollection(object instance) /// public virtual bool Contains(object instance, object value) { - var collection = _getCollection((TEntity)instance); + var collection = GetCollection((TEntity)instance); return (collection != null) && collection.Contains((TElement)value); } @@ -145,6 +162,6 @@ public virtual bool Contains(object instance, object value) /// directly from your code. This API may change or be removed in future releases. /// public virtual void Remove(object instance, object value) - => _getCollection((TEntity)instance)?.Remove((TElement)value); + => GetCollection((TEntity)instance)?.Remove((TElement)value); } } diff --git a/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/QueryBugsTest.cs b/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/QueryBugsTest.cs index 4e068586d58..b89fb319108 100644 --- a/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/QueryBugsTest.cs +++ b/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/QueryBugsTest.cs @@ -597,7 +597,7 @@ public void Customer_collections_materialize_properly_3758() var query4 = ctx.Customers.Select(c => c.Orders4); - Assert.Equal(CoreStrings.NavigationCannotCreateType("Orders4", typeof(Customer3758).FullName, typeof(MyInvalidCollection3758).FullName), + Assert.Equal(CoreStrings.NavigationCannotCreateType("Orders4", typeof(Customer3758).Name, typeof(MyInvalidCollection3758).Name), Assert.Throws(() => query4.ToList()).Message); } } diff --git a/test/Microsoft.EntityFrameworkCore.Tests/Metadata/Internal/ClrCollectionAccessorFactoryTest.cs b/test/Microsoft.EntityFrameworkCore.Tests/Metadata/Internal/ClrCollectionAccessorFactoryTest.cs index 0ebef57f689..3f7b9ee5bdd 100644 --- a/test/Microsoft.EntityFrameworkCore.Tests/Metadata/Internal/ClrCollectionAccessorFactoryTest.cs +++ b/test/Microsoft.EntityFrameworkCore.Tests/Metadata/Internal/ClrCollectionAccessorFactoryTest.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections; using System.Collections.Generic; using System.Linq; using System.Reflection; @@ -14,7 +15,6 @@ // ReSharper disable ConvertToAutoPropertyWhenPossible // ReSharper disable ConvertToAutoPropertyWithPrivateSetter // ReSharper disable UnusedMember.Local - namespace Microsoft.EntityFrameworkCore.Tests.Metadata.Internal { public class ClrCollectionAccessorFactoryTest @@ -30,6 +30,12 @@ public void Navigation_is_returned_if_it_implements_IClrCollectionAccessor() Assert.Same(accessorMock.Object, source.Create(navigationMock.Object)); } + [Fact] + public void Delegate_accessor_is_returned_for_IEnumerable_navigation() + { + AccessorTest("AsIEnumerable", e => e.AsIEnumerable); + } + [Fact] public void Delegate_accessor_is_returned_for_ICollection_navigation() { @@ -135,19 +141,52 @@ public void Creating_accessor_for_navigation_without_getter_throws() var navigation = CreateNavigation("WithNoGetter"); Assert.Equal( - CoreStrings.NavigationNoGetter("WithNoGetter", typeof(MyEntity).FullName), + CoreStrings.NavigationNoGetter("WithNoGetter", typeof(MyEntity).Name), Assert.Throws(() => new ClrCollectionAccessorFactory().Create(navigation)).Message); } [Fact] - public void Creating_accessor_for_enumerable_navigation_throws() + public void Add_for_enumerable_backed_by_non_collection_throws() { - var navigation = CreateNavigation("AsIEnumerable"); + Enumerable_backed_by_non_collection_throws((a, e, v) => a.Add(e, v)); + } + + [Fact] + public void AddRange_for_enumerable_backed_by_non_collection_throws() + { + Enumerable_backed_by_non_collection_throws((a, e, v) => a.AddRange(e, new[] { v })); + } + + [Fact] + public void Contains_for_enumerable_backed_by_non_collection_throws() + { + Enumerable_backed_by_non_collection_throws((a, e, v) => a.Contains(e, v)); + } + + [Fact] + public void Remove_for_enumerable_backed_by_non_collection_throws() + { + Enumerable_backed_by_non_collection_throws((a, e, v) => a.Remove(e, v)); + } + + [Fact] + public void GetOrCreate_for_enumerable_backed_by_non_collection_throws() + { + Enumerable_backed_by_non_collection_throws((a, e, v) => a.GetOrCreate(e)); + } + + private void Enumerable_backed_by_non_collection_throws(Action test) + { + var accessor = new ClrCollectionAccessorFactory().Create(CreateNavigation("AsIEnumerableNotCollection")); + + var entity = new MyEntity(); + var value = new MyOtherEntity(); + entity.InitializeCollections(); Assert.Equal( CoreStrings.NavigationBadType( - "AsIEnumerable", typeof(MyEntity).FullName, typeof(IEnumerable).FullName, typeof(MyOtherEntity).FullName), - Assert.Throws(() => new ClrCollectionAccessorFactory().Create(navigation)).Message); + "AsIEnumerableNotCollection", typeof(MyEntity).Name, typeof(MyEnumerable).Name, typeof(MyOtherEntity).Name), + Assert.Throws(() => test(accessor, entity, value)).Message); } [Fact] @@ -156,7 +195,7 @@ public void Creating_accessor_for_array_navigation_throws() var navigation = CreateNavigation("AsArray"); Assert.Equal( - CoreStrings.NavigationArray("AsArray", typeof(MyEntity).FullName, typeof(MyOtherEntity[]).FullName), + CoreStrings.NavigationArray("AsArray", typeof(MyEntity).Name, typeof(MyOtherEntity[]).Name), Assert.Throws(() => new ClrCollectionAccessorFactory().Create(navigation)).Message); } @@ -166,7 +205,7 @@ public void Initialization_for_navigation_without_setter_throws() var accessor = new ClrCollectionAccessorFactory().Create(CreateNavigation("WithNoSetter")); Assert.Equal( - CoreStrings.NavigationNoSetter("WithNoSetter", typeof(MyEntity).FullName), + CoreStrings.NavigationNoSetter("WithNoSetter", typeof(MyEntity).Name), Assert.Throws(() => accessor.Add(new MyEntity(), new MyOtherEntity())).Message); } @@ -176,7 +215,7 @@ public void Initialization_for_navigation_with_private_constructor_throws() var accessor = new ClrCollectionAccessorFactory().Create(CreateNavigation("AsMyPrivateCollection")); Assert.Equal( - CoreStrings.NavigationCannotCreateType("AsMyPrivateCollection", typeof(MyEntity).FullName, typeof(MyPrivateCollection).FullName), + CoreStrings.NavigationCannotCreateType("AsMyPrivateCollection", typeof(MyEntity).Name, typeof(MyPrivateCollection).Name), Assert.Throws(() => accessor.Add(new MyEntity(), new MyOtherEntity())).Message); } @@ -186,7 +225,7 @@ public void Initialization_for_navigation_with_internal_constructor_throws() var accessor = new ClrCollectionAccessorFactory().Create(CreateNavigation("AsMyInternalCollection")); Assert.Equal( - CoreStrings.NavigationCannotCreateType("AsMyInternalCollection", typeof(MyEntity).FullName, typeof(MyInternalCollection).FullName), + CoreStrings.NavigationCannotCreateType("AsMyInternalCollection", typeof(MyEntity).Name, typeof(MyInternalCollection).Name), Assert.Throws(() => accessor.Add(new MyEntity(), new MyOtherEntity())).Message); } @@ -196,7 +235,7 @@ public void Initialization_for_navigation_without_parameterless_constructor_thro var accessor = new ClrCollectionAccessorFactory().Create(CreateNavigation("AsMyUnavailableCollection")); Assert.Equal( - CoreStrings.NavigationCannotCreateType("AsMyUnavailableCollection", typeof(MyEntity).FullName, typeof(MyUnavailableCollection).FullName), + CoreStrings.NavigationCannotCreateType("AsMyUnavailableCollection", typeof(MyEntity).Name, typeof(MyUnavailableCollection).Name), Assert.Throws(() => accessor.Add(new MyEntity(), new MyOtherEntity())).Message); } @@ -223,6 +262,7 @@ private class MyEntity // ReSharper disable once NotAccessedField.Local private ICollection _withNoGetter; private IEnumerable _enumerable; + private IEnumerable _enumerableNotCollection; private MyOtherEntity[] _array; private MyPrivateCollection _privateCollection; private MyInternalCollection _internalCollection; @@ -237,6 +277,7 @@ public void InitializeCollections() _withNoSetter = new HashSet(); _withNoGetter = new HashSet(); _enumerable = new HashSet(); + _enumerableNotCollection = new MyEnumerable(); _array = new MyOtherEntity[0]; _privateCollection = MyPrivateCollection.Create(); _internalCollection = new MyInternalCollection(); @@ -280,6 +321,12 @@ internal IEnumerable AsIEnumerable set { _enumerable = value; } } + internal IEnumerable AsIEnumerableNotCollection + { + get { return _enumerableNotCollection; } + set { _enumerableNotCollection = value; } + } + internal MyOtherEntity[] AsArray { get { return _array; } @@ -339,5 +386,15 @@ public MyUnavailableCollection(bool _) { } } + + private class MyEnumerable : IEnumerable + { + public IEnumerator GetEnumerator() + { + throw new NotImplementedException(); + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } } }