Skip to content

Commit

Permalink
Collection include for the in-memory provider
Browse files Browse the repository at this point in the history
Part of #16963
  • Loading branch information
ajcvickers committed Aug 12, 2019
1 parent 669ea67 commit 87aaf63
Show file tree
Hide file tree
Showing 16 changed files with 117 additions and 719 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ public static IServiceCollection AddEntityFrameworkInMemoryDatabase([NotNull] th
.TryAddSingleton<IInMemorySingletonOptions, InMemorySingletonOptions>()
.TryAddSingleton<IInMemoryStoreCache, InMemoryStoreCache>()
.TryAddSingleton<IInMemoryTableFactory, InMemoryTableFactory>()
.TryAddScoped<IInMemoryDatabase, InMemoryDatabase>()
.TryAddScoped<IInMemoryMaterializerFactory, InMemoryMaterializerFactory>());
.TryAddScoped<IInMemoryDatabase, InMemoryDatabase>());

builder.TryAddCoreServices();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
Expand All @@ -28,6 +27,10 @@ private static readonly MethodInfo _includeReferenceMethodInfo
= typeof(CustomShaperCompilingExpressionVisitor).GetTypeInfo()
.GetDeclaredMethod(nameof(IncludeReference));

private static readonly MethodInfo _includeCollectionMethodInfo
= typeof(CustomShaperCompilingExpressionVisitor).GetTypeInfo()
.GetDeclaredMethod(nameof(IncludeCollection));

private static void IncludeReference<TEntity, TIncludingEntity, TIncludedEntity>(
QueryContext queryContext,
TEntity entity,
Expand All @@ -36,7 +39,9 @@ private static void IncludeReference<TEntity, TIncludingEntity, TIncludedEntity>
INavigation inverseNavigation,
Action<TIncludingEntity, TIncludedEntity> fixup,
bool trackingQuery)
where TIncludingEntity : TEntity
where TIncludingEntity : class, TEntity
where TEntity : class
where TIncludedEntity : class
{
if (entity is TIncludingEntity includingEntity)
{
Expand Down Expand Up @@ -64,6 +69,48 @@ private static void IncludeReference<TEntity, TIncludingEntity, TIncludedEntity>
}
}

private static void IncludeCollection<TEntity, TIncludingEntity, TIncludedEntity>(
QueryContext queryContext,
TEntity entity,
IEnumerable<TIncludedEntity> relatedEntities,
INavigation navigation,
INavigation inverseNavigation,
Action<TIncludingEntity, TIncludedEntity> fixup,
bool trackingQuery)
where TIncludingEntity : class, TEntity
where TEntity : class
where TIncludedEntity : class
{
if (entity is TIncludingEntity includingEntity)
{
var collectionAccessor = navigation.GetCollectionAccessor();
collectionAccessor.GetOrCreate(includingEntity, forMaterialization: true);

if (trackingQuery)
{
foreach (var relatedEntity in relatedEntities)
{
collectionAccessor.Add(includingEntity, relatedEntity, forMaterialization: true);
}

queryContext.StateManager.TryGetEntry(includingEntity).SetIsLoaded(navigation);
}
else
{
foreach (var relatedEntity in relatedEntities)
{
fixup(includingEntity, relatedEntity);
if (inverseNavigation != null)
{
SetIsLoadedNoTracking(relatedEntity, inverseNavigation);
}
}

SetIsLoadedNoTracking(includingEntity, navigation);
}
}
}

private static void SetIsLoadedNoTracking(object entity, INavigation navigation)
=> ((ILazyLoader)(navigation
.DeclaringEntityType
Expand All @@ -76,10 +123,6 @@ protected override Expression VisitExtension(Expression extensionExpression)
{
if (extensionExpression is IncludeExpression includeExpression)
{
Debug.Assert(
!includeExpression.Navigation.IsCollection(),
"Only reference include should be present in tree");

var entityClrType = includeExpression.EntityExpression.Type;
var includingClrType = includeExpression.Navigation.DeclaringEntityType.ClrType;
var inverseNavigation = includeExpression.Navigation.FindInverse();
Expand All @@ -90,6 +133,22 @@ protected override Expression VisitExtension(Expression extensionExpression)
includingClrType = entityClrType;
}

if (includeExpression.Navigation.IsCollection())
{
return Expression.Call(
_includeCollectionMethodInfo.MakeGenericMethod(entityClrType, includingClrType, relatedEntityClrType),
QueryCompilationContext.QueryContextParameter,
// We don't need to visit entityExpression since it is supposed to be a parameterExpression only
includeExpression.EntityExpression,
includeExpression.NavigationExpression,
Expression.Constant(includeExpression.Navigation),
Expression.Constant(inverseNavigation, typeof(INavigation)),
Expression.Constant(
GenerateFixup(
includingClrType, relatedEntityClrType, includeExpression.Navigation, inverseNavigation).Compile()),
Expression.Constant(_tracking));
}

return Expression.Call(
_includeReferenceMethodInfo.MakeGenericMethod(entityClrType, includingClrType, relatedEntityClrType),
QueryCompilationContext.QueryContextParameter,
Expand Down
28 changes: 0 additions & 28 deletions src/EFCore.InMemory/Query/Internal/IInMemoryMaterializerFactory.cs

This file was deleted.

106 changes: 0 additions & 106 deletions src/EFCore.InMemory/Query/Internal/InMemoryMaterializerFactory.cs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,14 @@ public override Expression Visit(Expression expression)
return expression;

case MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression:
//return _selectExpression.AddCollectionProjection(
// _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(
// materializeCollectionNavigationExpression.Subquery),
// materializeCollectionNavigationExpression.Navigation, null);
throw new NotImplementedException();

var translated = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(
materializeCollectionNavigationExpression.Subquery);

return new ProjectionBindingExpression(
_queryExpression,
_queryExpression.AddToProjection(translated),
typeof(IEnumerable<>).MakeGenericType(materializeCollectionNavigationExpression.Navigation.GetTargetType().ClrType));

case MethodCallExpression methodCallExpression:
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Query;
Expand Down
20 changes: 16 additions & 4 deletions src/EFCore/Internal/ConcurrencyDetector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
using System;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.Extensions.DependencyInjection;

namespace Microsoft.EntityFrameworkCore.Internal
Expand All @@ -29,6 +27,8 @@ public class ConcurrencyDetector : IConcurrencyDetector
{
private readonly IDisposable _disposer;
private int _inCriticalSection;
private static readonly AsyncLocal<bool> _threadHasLock = new AsyncLocal<bool>();
private int _refCount;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand All @@ -48,17 +48,29 @@ public virtual IDisposable EnterCriticalSection()
{
if (Interlocked.CompareExchange(ref _inCriticalSection, 1, 0) == 1)
{
throw new InvalidOperationException(CoreStrings.ConcurrentMethodInvocation);
if (!_threadHasLock.Value)
{
throw new InvalidOperationException(CoreStrings.ConcurrentMethodInvocation);
}
}
else
{
_threadHasLock.Value = true;
}

_refCount++;
return _disposer;
}

private void ExitCriticalSection()
{
Debug.Assert(_inCriticalSection == 1, "Expected to be in a critical section");

_inCriticalSection = 0;
if (--_refCount == 0)
{
_threadHasLock.Value = false;
_inCriticalSection = 0;
}
}

private readonly struct Disposer : IDisposable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ public class InMemoryComplianceTest : ComplianceTestBase
typeof(FunkyDataQueryTestBase<>),
typeof(OptimisticConcurrencyTestBase<>),
typeof(StoreGeneratedTestBase<>),
typeof(LoadTestBase<>), // issue #16963
typeof(MusicStoreTestBase<>), // issue #16963
typeof(ConferencePlannerTestBase<>), // issue #16963
typeof(GraphUpdatesTestBase<>), // issue #16963
Expand Down
Loading

0 comments on commit 87aaf63

Please sign in to comment.