Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better handling of different collection types in fix-up #16500

Merged
merged 1 commit into from
Jul 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 68 additions & 8 deletions src/EFCore/Metadata/Internal/ClrCollectionAccessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
Expand All @@ -18,6 +19,7 @@ namespace Microsoft.EntityFrameworkCore.Metadata.Internal
public class ClrICollectionAccessor<TEntity, TCollection, TElement> : IClrCollectionAccessor
where TEntity : class
where TCollection : class, IEnumerable<TElement>
where TElement : class
{
private readonly string _propertyName;
private readonly Func<TEntity, TCollection> _getCollection;
Expand Down Expand Up @@ -64,7 +66,7 @@ public virtual bool Add(object instance, object value)
var collection = GetOrCreateCollection(instance);
var element = (TElement)value;

if (!collection.Contains(element))
if (!Contains(collection, value))
{
collection.Add(element);

Expand All @@ -86,7 +88,7 @@ public virtual void AddRange(object instance, IEnumerable<object> values)

foreach (TElement value in values)
{
if (!collection.Contains(value))
if (!Contains(collection, value))
{
collection.Add(value);
}
Expand Down Expand Up @@ -186,11 +188,7 @@ private ICollection<TElement> GetCollection(object instance)
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual bool Contains(object instance, object value)
{
var collection = GetCollection((TEntity)instance);

return (collection?.Contains((TElement)value) == true);
}
=> Contains(GetCollection((TEntity)instance), value);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand All @@ -199,6 +197,68 @@ public virtual bool Contains(object instance, object value)
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual bool Remove(object instance, object value)
=> GetCollection((TEntity)instance)?.Remove((TElement)value) ?? false;
{
var collection = GetCollection((TEntity)instance);

switch (collection)
{
case List<TElement> list:
for (var i = 0; i < list.Count; i++)
{
if (ReferenceEquals(list[i], value))
{
list.RemoveAt(i);
return true;
}
}
return false;
case Collection<TElement> concreteCollection:
for (var i = 0; i < concreteCollection.Count; i++)
{
if (ReferenceEquals(concreteCollection[i], value))
{
concreteCollection.RemoveAt(i);
return true;
}
}
return false;
case SortedSet<TElement> sortedSet:
return sortedSet.TryGetValue((TElement)value, out var found)
&& ReferenceEquals(found, value)
&& sortedSet.Remove(found);
default:
return collection?.Remove((TElement)value) ?? false;
}
}

private static bool Contains(ICollection<TElement> collection, object value)
{
switch (collection)
{
case List<TElement> list:
foreach (var element in list)
{
if (ReferenceEquals(element, value))
{
return true;
}
}
return false;
case Collection<TElement> concreteCollection:
for (var i = 0; i < concreteCollection.Count; i++)
{
if (ReferenceEquals(concreteCollection[i], value))
{
return true;
}
}
return false;
case SortedSet<TElement> sortedSet:
return sortedSet.TryGetValue((TElement)value, out var found)
&& ReferenceEquals(found, value);
default:
return collection?.Contains((TElement)value) == true;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ public virtual IClrCollectionAccessor Create(
private static IClrCollectionAccessor CreateGeneric<TEntity, TCollection, TElement>(INavigation navigation, MemberInfo memberInfo)
where TEntity : class
where TCollection : class, IEnumerable<TElement>
where TElement : class
{
var entityParameter = Expression.Parameter(typeof(TEntity), "entity");
var valueParameter = Expression.Parameter(typeof(TCollection), "collection");
Expand Down
117 changes: 117 additions & 0 deletions test/EFCore.Tests/ChangeTracking/Internal/FixupTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2633,6 +2633,123 @@ public void Navigation_fixup_is_non_destructive_to_existing_graphs()
}
}

[ConditionalFact]
public void Comparable_entities_that_comply_are_tracked_correctly()
{
using (var context = new ComparableEntitiesContext("ComparableEntities"))
{
var level2a = new Level2
{
Name = "Foo"
};

var level2b = new Level2
{
Name = "Bar"
};

var level1 = new Level1
{
Children =
{
level2a, level2b,
},
};

context.Add(level1);
context.SaveChanges();

Assert.Equal(3, context.ChangeTracker.Entries().Count());
Assert.Equal(EntityState.Unchanged, context.Entry(level1).State);
Assert.Equal(EntityState.Unchanged, context.Entry(level2a).State);
Assert.Equal(EntityState.Unchanged, context.Entry(level2b).State);

Assert.Equal(2, level1.Children.Count);
Assert.Contains(level2a, level1.Children);
Assert.Contains(level2b, level1.Children);

level1.Children.Clear();

Assert.Equal(3, context.ChangeTracker.Entries().Count());
Assert.Equal(EntityState.Unchanged, context.Entry(level1).State);
Assert.Equal(EntityState.Deleted, context.Entry(level2a).State);
Assert.Equal(EntityState.Deleted, context.Entry(level2b).State);

Assert.Equal(0, level1.Children.Count);

var level2c = new Level2
{
Name = "Foo"
};

var level2d = new Level2
{
Name = "Quz"
};

level1.Children.Add(level2c);
level1.Children.Add(level2d);

Assert.Equal(2, level1.Children.Count);
Assert.Contains(level2c, level1.Children);
Assert.Contains(level2d, level1.Children);

Assert.Equal(5, context.ChangeTracker.Entries().Count());
Assert.Equal(EntityState.Unchanged, context.Entry(level1).State);
Assert.Equal(EntityState.Deleted, context.Entry(level2a).State);
Assert.Equal(EntityState.Deleted, context.Entry(level2b).State);
Assert.Equal(EntityState.Added, context.Entry(level2c).State);
Assert.Equal(EntityState.Added, context.Entry(level2d).State);

context.SaveChanges();

Assert.Equal(2, level1.Children.Count);
Assert.Contains(level2c, level1.Children);
Assert.Contains(level2d, level1.Children);

Assert.Equal(3, context.ChangeTracker.Entries().Count());
Assert.Equal(EntityState.Unchanged, context.Entry(level1).State);
Assert.Equal(EntityState.Unchanged, context.Entry(level2c).State);
Assert.Equal(EntityState.Unchanged, context.Entry(level2d).State);
}
}

private class Level1
{
public int Id { get; set; }

[Required]
public ICollection<Level2> Children { get; set; } = new SortedSet<Level2>();
}

private class Level2 : IComparable<Level2>
{
public int Id { get; set; }
public string Name { get; set; }

public Level1 Level1 { get; set; }
public int Level1Id { get; set; }

public int CompareTo(Level2 other)
=> StringComparer.InvariantCultureIgnoreCase.Compare(Name, other.Name);
}

private class ComparableEntitiesContext : DbContext
{
private readonly string _databaseName;

public ComparableEntitiesContext(string databaseName)
{
_databaseName = databaseName;
}

protected internal override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
=> optionsBuilder.UseInMemoryDatabase(_databaseName);

public DbSet<Level1> Level1s { get; set; }
public DbSet<Level2> Level2s { get; set; }
}

protected virtual void AssertAllFixedUp(DbContext context)
{
foreach (var entry in context.ChangeTracker.Entries<Product>())
Expand Down
Loading