Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow navigation properties to be typed as IEnumerable so long as backed by ICollection #5772

Merged
merged 1 commit into from
Jun 17, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -2189,7 +2189,7 @@ internal static readonly MethodInfo IncludeMethodInfo
/// type of entity being queried (<typeparamref name="TEntity" />). If you wish to include additional types based on the navigation
/// properties of the type being included, then chain a call to
/// <see
/// cref="ThenInclude{TEntity, TPreviousProperty, TProperty}(IIncludableQueryable{TEntity, ICollection{TPreviousProperty}}, Expression{Func{TPreviousProperty, TProperty}})" />
/// cref="ThenInclude{TEntity, TPreviousProperty, TProperty}(IIncludableQueryable{TEntity, IEnumerable{TPreviousProperty}}, Expression{Func{TPreviousProperty, TProperty}})" />
/// after this call.
/// </summary>
/// <example>
Expand Down Expand Up @@ -2289,7 +2289,7 @@ internal static readonly MethodInfo ThenIncludeAfterReferenceMethodInfo
/// A new query with the related data included.
/// </returns>
public static IIncludableQueryable<TEntity, TProperty> ThenInclude<TEntity, TPreviousProperty, TProperty>(
[NotNull] this IIncludableQueryable<TEntity, ICollection<TPreviousProperty>> source,
[NotNull] this IIncludableQueryable<TEntity, IEnumerable<TPreviousProperty>> source,
[NotNull] Expression<Func<TPreviousProperty, TProperty>> navigationPropertyPath)
where TEntity : class
=> new IncludableQueryable<TEntity, TProperty>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,26 +38,26 @@ public virtual IClrCollectionAccessor Create([NotNull] INavigation navigation)
}

var property = navigation.GetPropertyInfo();
var elementType = property.PropertyType.TryGetElementType(typeof(ICollection<>));
var elementType = property.PropertyType.TryGetElementType(typeof(IEnumerable<>));

// TODO: Only ICollections supported; add support for enumerables with add/remove methods
// Issue #752
if (elementType == null)
{
throw new InvalidOperationException(
CoreStrings.NavigationBadType(
navigation.Name, navigation.DeclaringEntityType.Name, property.PropertyType.FullName, navigation.GetTargetType().Name));
navigation.Name, navigation.DeclaringEntityType.Name, property.PropertyType.Name, navigation.GetTargetType().Name));
}

if (property.PropertyType.IsArray)
{
throw new InvalidOperationException(
CoreStrings.NavigationArray(navigation.Name, navigation.DeclaringEntityType.Name, property.PropertyType.FullName));
CoreStrings.NavigationArray(navigation.Name, navigation.DeclaringEntityType.DisplayName(), property.PropertyType.Name));
}

if (property.GetMethod == null)
{
throw new InvalidOperationException(CoreStrings.NavigationNoGetter(navigation.Name, navigation.DeclaringEntityType.Name));
throw new InvalidOperationException(CoreStrings.NavigationNoGetter(navigation.Name, navigation.DeclaringEntityType.DisplayName()));
}

var boundMethod = _genericCreate.MakeGenericMethod(
Expand All @@ -69,7 +69,7 @@ public virtual IClrCollectionAccessor Create([NotNull] INavigation navigation)
[UsedImplicitly]
private static IClrCollectionAccessor CreateGeneric<TEntity, TCollection, TElement>(PropertyInfo property)
where TEntity : class
where TCollection : class, ICollection<TElement>
where TCollection : class, IEnumerable<TElement>
{
var getterDelegate = (Func<TEntity, TCollection>)property.GetMethod.CreateDelegate(typeof(Func<TEntity, TCollection>));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace Microsoft.EntityFrameworkCore.Metadata.Internal
/// </summary>
public class ClrICollectionAccessor<TEntity, TCollection, TElement> : IClrCollectionAccessor
where TEntity : class
where TCollection : class, ICollection<TElement>
where TCollection : class, IEnumerable<TElement>
{
private readonly string _propertyName;
private readonly Func<TEntity, TCollection> _getCollection;
Expand Down Expand Up @@ -89,10 +89,10 @@ public virtual object Create(IEnumerable<object> values)
if (_createCollection == null)
{
throw new InvalidOperationException(CoreStrings.NavigationCannotCreateType(
_propertyName, typeof(TEntity).FullName, typeof(TCollection).FullName));
_propertyName, typeof(TEntity).Name, typeof(TCollection).Name));
}

var collection = _createCollection();
var collection = (ICollection<TElement>)_createCollection();
foreach (TElement value in values)
{
collection.Add(value);
Expand All @@ -107,25 +107,42 @@ public virtual object Create(IEnumerable<object> values)
/// </summary>
public virtual object GetOrCreate(object instance) => GetOrCreateCollection(instance);

private TCollection GetOrCreateCollection(object instance)
private ICollection<TElement> GetOrCreateCollection(object instance)
{
var collection = _getCollection((TEntity)instance);
var collection = GetCollection(instance);

if (collection == null)
{
if (_setCollection == null)
{
throw new InvalidOperationException(CoreStrings.NavigationNoSetter(_propertyName, typeof(TEntity).FullName));
throw new InvalidOperationException(CoreStrings.NavigationNoSetter(_propertyName, typeof(TEntity).Name));
}

if (_createAndSetCollection == null)
{
throw new InvalidOperationException(CoreStrings.NavigationCannotCreateType(
_propertyName, typeof(TEntity).FullName, typeof(TCollection).FullName));
_propertyName, typeof(TEntity).Name, typeof(TCollection).Name));
}

collection = _createAndSetCollection((TEntity)instance, _setCollection);
collection = (ICollection<TElement>)_createAndSetCollection((TEntity)instance, _setCollection);
}

return collection;
}

private ICollection<TElement> GetCollection(object instance)
{
var enumerable = _getCollection((TEntity)instance);
var collection = enumerable as ICollection<TElement>;

if (enumerable != null
&& collection == null)
{
throw new InvalidOperationException(
CoreStrings.NavigationBadType(
_propertyName, typeof(TEntity).Name, enumerable.GetType().Name, typeof(TElement).Name));
}

return collection;
}

Expand All @@ -135,7 +152,7 @@ private TCollection GetOrCreateCollection(object instance)
/// </summary>
public virtual bool Contains(object instance, object value)
{
var collection = _getCollection((TEntity)instance);
var collection = GetCollection((TEntity)instance);

return (collection != null) && collection.Contains((TElement)value);
}
Expand All @@ -145,6 +162,6 @@ public virtual bool Contains(object instance, object value)
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
public virtual void Remove(object instance, object value)
=> _getCollection((TEntity)instance)?.Remove((TElement)value);
=> GetCollection((TEntity)instance)?.Remove((TElement)value);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ public void Customer_collections_materialize_properly_3758()

var query4 = ctx.Customers.Select(c => c.Orders4);

Assert.Equal(CoreStrings.NavigationCannotCreateType("Orders4", typeof(Customer3758).FullName, typeof(MyInvalidCollection3758<Order3758>).FullName),
Assert.Equal(CoreStrings.NavigationCannotCreateType("Orders4", typeof(Customer3758).Name, typeof(MyInvalidCollection3758<Order3758>).Name),
Assert.Throws<InvalidOperationException>(() => query4.ToList()).Message);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
Expand All @@ -14,7 +15,6 @@
// ReSharper disable ConvertToAutoPropertyWhenPossible
// ReSharper disable ConvertToAutoPropertyWithPrivateSetter
// ReSharper disable UnusedMember.Local

namespace Microsoft.EntityFrameworkCore.Tests.Metadata.Internal
{
public class ClrCollectionAccessorFactoryTest
Expand All @@ -30,6 +30,12 @@ public void Navigation_is_returned_if_it_implements_IClrCollectionAccessor()
Assert.Same(accessorMock.Object, source.Create(navigationMock.Object));
}

[Fact]
public void Delegate_accessor_is_returned_for_IEnumerable_navigation()
{
AccessorTest("AsIEnumerable", e => e.AsIEnumerable);
}

[Fact]
public void Delegate_accessor_is_returned_for_ICollection_navigation()
{
Expand Down Expand Up @@ -135,19 +141,52 @@ public void Creating_accessor_for_navigation_without_getter_throws()
var navigation = CreateNavigation("WithNoGetter");

Assert.Equal(
CoreStrings.NavigationNoGetter("WithNoGetter", typeof(MyEntity).FullName),
CoreStrings.NavigationNoGetter("WithNoGetter", typeof(MyEntity).Name),
Assert.Throws<InvalidOperationException>(() => new ClrCollectionAccessorFactory().Create(navigation)).Message);
}

[Fact]
public void Creating_accessor_for_enumerable_navigation_throws()
public void Add_for_enumerable_backed_by_non_collection_throws()
{
var navigation = CreateNavigation("AsIEnumerable");
Enumerable_backed_by_non_collection_throws((a, e, v) => a.Add(e, v));
}

[Fact]
public void AddRange_for_enumerable_backed_by_non_collection_throws()
{
Enumerable_backed_by_non_collection_throws((a, e, v) => a.AddRange(e, new[] { v }));
}

[Fact]
public void Contains_for_enumerable_backed_by_non_collection_throws()
{
Enumerable_backed_by_non_collection_throws((a, e, v) => a.Contains(e, v));
}

[Fact]
public void Remove_for_enumerable_backed_by_non_collection_throws()
{
Enumerable_backed_by_non_collection_throws((a, e, v) => a.Remove(e, v));
}

[Fact]
public void GetOrCreate_for_enumerable_backed_by_non_collection_throws()
{
Enumerable_backed_by_non_collection_throws((a, e, v) => a.GetOrCreate(e));
}

private void Enumerable_backed_by_non_collection_throws(Action<IClrCollectionAccessor, MyEntity, MyOtherEntity> test)
{
var accessor = new ClrCollectionAccessorFactory().Create(CreateNavigation("AsIEnumerableNotCollection"));

var entity = new MyEntity();
var value = new MyOtherEntity();
entity.InitializeCollections();

Assert.Equal(
CoreStrings.NavigationBadType(
"AsIEnumerable", typeof(MyEntity).FullName, typeof(IEnumerable<MyOtherEntity>).FullName, typeof(MyOtherEntity).FullName),
Assert.Throws<InvalidOperationException>(() => new ClrCollectionAccessorFactory().Create(navigation)).Message);
"AsIEnumerableNotCollection", typeof(MyEntity).Name, typeof(MyEnumerable).Name, typeof(MyOtherEntity).Name),
Assert.Throws<InvalidOperationException>(() => test(accessor, entity, value)).Message);
}

[Fact]
Expand All @@ -156,7 +195,7 @@ public void Creating_accessor_for_array_navigation_throws()
var navigation = CreateNavigation("AsArray");

Assert.Equal(
CoreStrings.NavigationArray("AsArray", typeof(MyEntity).FullName, typeof(MyOtherEntity[]).FullName),
CoreStrings.NavigationArray("AsArray", typeof(MyEntity).Name, typeof(MyOtherEntity[]).Name),
Assert.Throws<InvalidOperationException>(() => new ClrCollectionAccessorFactory().Create(navigation)).Message);
}

Expand All @@ -166,7 +205,7 @@ public void Initialization_for_navigation_without_setter_throws()
var accessor = new ClrCollectionAccessorFactory().Create(CreateNavigation("WithNoSetter"));

Assert.Equal(
CoreStrings.NavigationNoSetter("WithNoSetter", typeof(MyEntity).FullName),
CoreStrings.NavigationNoSetter("WithNoSetter", typeof(MyEntity).Name),
Assert.Throws<InvalidOperationException>(() => accessor.Add(new MyEntity(), new MyOtherEntity())).Message);
}

Expand All @@ -176,7 +215,7 @@ public void Initialization_for_navigation_with_private_constructor_throws()
var accessor = new ClrCollectionAccessorFactory().Create(CreateNavigation("AsMyPrivateCollection"));

Assert.Equal(
CoreStrings.NavigationCannotCreateType("AsMyPrivateCollection", typeof(MyEntity).FullName, typeof(MyPrivateCollection).FullName),
CoreStrings.NavigationCannotCreateType("AsMyPrivateCollection", typeof(MyEntity).Name, typeof(MyPrivateCollection).Name),
Assert.Throws<InvalidOperationException>(() => accessor.Add(new MyEntity(), new MyOtherEntity())).Message);
}

Expand All @@ -186,7 +225,7 @@ public void Initialization_for_navigation_with_internal_constructor_throws()
var accessor = new ClrCollectionAccessorFactory().Create(CreateNavigation("AsMyInternalCollection"));

Assert.Equal(
CoreStrings.NavigationCannotCreateType("AsMyInternalCollection", typeof(MyEntity).FullName, typeof(MyInternalCollection).FullName),
CoreStrings.NavigationCannotCreateType("AsMyInternalCollection", typeof(MyEntity).Name, typeof(MyInternalCollection).Name),
Assert.Throws<InvalidOperationException>(() => accessor.Add(new MyEntity(), new MyOtherEntity())).Message);
}

Expand All @@ -196,7 +235,7 @@ public void Initialization_for_navigation_without_parameterless_constructor_thro
var accessor = new ClrCollectionAccessorFactory().Create(CreateNavigation("AsMyUnavailableCollection"));

Assert.Equal(
CoreStrings.NavigationCannotCreateType("AsMyUnavailableCollection", typeof(MyEntity).FullName, typeof(MyUnavailableCollection).FullName),
CoreStrings.NavigationCannotCreateType("AsMyUnavailableCollection", typeof(MyEntity).Name, typeof(MyUnavailableCollection).Name),
Assert.Throws<InvalidOperationException>(() => accessor.Add(new MyEntity(), new MyOtherEntity())).Message);
}

Expand All @@ -223,6 +262,7 @@ private class MyEntity
// ReSharper disable once NotAccessedField.Local
private ICollection<MyOtherEntity> _withNoGetter;
private IEnumerable<MyOtherEntity> _enumerable;
private IEnumerable<MyOtherEntity> _enumerableNotCollection;
private MyOtherEntity[] _array;
private MyPrivateCollection _privateCollection;
private MyInternalCollection _internalCollection;
Expand All @@ -237,6 +277,7 @@ public void InitializeCollections()
_withNoSetter = new HashSet<MyOtherEntity>();
_withNoGetter = new HashSet<MyOtherEntity>();
_enumerable = new HashSet<MyOtherEntity>();
_enumerableNotCollection = new MyEnumerable();
_array = new MyOtherEntity[0];
_privateCollection = MyPrivateCollection.Create();
_internalCollection = new MyInternalCollection();
Expand Down Expand Up @@ -280,6 +321,12 @@ internal IEnumerable<MyOtherEntity> AsIEnumerable
set { _enumerable = value; }
}

internal IEnumerable<MyOtherEntity> AsIEnumerableNotCollection
{
get { return _enumerableNotCollection; }
set { _enumerableNotCollection = value; }
}

internal MyOtherEntity[] AsArray
{
get { return _array; }
Expand Down Expand Up @@ -339,5 +386,15 @@ public MyUnavailableCollection(bool _)
{
}
}

private class MyEnumerable : IEnumerable<MyOtherEntity>
{
public IEnumerator<MyOtherEntity> GetEnumerator()
{
throw new NotImplementedException();
}

IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
}
}