-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Reference Include for in-memory provider
Part of #16963
- Loading branch information
1 parent
fb61241
commit dfd323b
Showing
8 changed files
with
764 additions
and
18 deletions.
There are no files selected for viewing
161 changes: 161 additions & 0 deletions
161
src/EFCore.InMemory/Query/Internal/CustomShaperCompilingExpressionVisitor.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
// Copyright (c) .NET Foundation. All rights reserved. | ||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Diagnostics; | ||
using System.Linq; | ||
using System.Linq.Expressions; | ||
using System.Reflection; | ||
using Microsoft.EntityFrameworkCore.Infrastructure; | ||
using Microsoft.EntityFrameworkCore.Metadata; | ||
using Microsoft.EntityFrameworkCore.Query; | ||
|
||
namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal | ||
{ | ||
public partial class InMemoryShapedQueryCompilingExpressionVisitor | ||
{ | ||
private class CustomShaperCompilingExpressionVisitor : ExpressionVisitor | ||
{ | ||
private readonly bool _tracking; | ||
|
||
public CustomShaperCompilingExpressionVisitor(bool tracking) | ||
{ | ||
_tracking = tracking; | ||
} | ||
|
||
private static readonly MethodInfo _includeReferenceMethodInfo | ||
= typeof(CustomShaperCompilingExpressionVisitor).GetTypeInfo() | ||
.GetDeclaredMethod(nameof(IncludeReference)); | ||
|
||
private static void IncludeReference<TEntity, TIncludingEntity, TIncludedEntity>( | ||
QueryContext queryContext, | ||
TEntity entity, | ||
TIncludedEntity relatedEntity, | ||
INavigation navigation, | ||
INavigation inverseNavigation, | ||
Action<TIncludingEntity, TIncludedEntity> fixup, | ||
bool trackingQuery) | ||
where TIncludingEntity : TEntity | ||
{ | ||
if (entity is TIncludingEntity includingEntity) | ||
{ | ||
if (trackingQuery) | ||
{ | ||
// For non-null relatedEntity StateManager will set the flag | ||
if (relatedEntity == null) | ||
{ | ||
queryContext.StateManager.TryGetEntry(includingEntity).SetIsLoaded(navigation); | ||
} | ||
} | ||
else | ||
{ | ||
SetIsLoadedNoTracking(includingEntity, navigation); | ||
if (relatedEntity != null) | ||
{ | ||
fixup(includingEntity, relatedEntity); | ||
if (inverseNavigation != null | ||
&& !inverseNavigation.IsCollection()) | ||
{ | ||
SetIsLoadedNoTracking(relatedEntity, inverseNavigation); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
private static void SetIsLoadedNoTracking(object entity, INavigation navigation) | ||
=> ((ILazyLoader)(navigation | ||
.DeclaringEntityType | ||
.GetServiceProperties() | ||
.FirstOrDefault(p => p.ClrType == typeof(ILazyLoader))) | ||
?.GetGetter().GetClrValue(entity)) | ||
?.SetLoaded(entity, navigation.Name); | ||
|
||
protected override Expression VisitExtension(Expression extensionExpression) | ||
{ | ||
if (extensionExpression is IncludeExpression includeExpression) | ||
{ | ||
Debug.Assert( | ||
!includeExpression.Navigation.IsCollection(), | ||
"Only reference include should be present in tree"); | ||
|
||
var entityClrType = includeExpression.EntityExpression.Type; | ||
var includingClrType = includeExpression.Navigation.DeclaringEntityType.ClrType; | ||
var inverseNavigation = includeExpression.Navigation.FindInverse(); | ||
var relatedEntityClrType = includeExpression.Navigation.GetTargetType().ClrType; | ||
if (includingClrType != entityClrType | ||
&& includingClrType.IsAssignableFrom(entityClrType)) | ||
{ | ||
includingClrType = entityClrType; | ||
} | ||
|
||
return Expression.Call( | ||
_includeReferenceMethodInfo.MakeGenericMethod(entityClrType, includingClrType, relatedEntityClrType), | ||
QueryCompilationContext.QueryContextParameter, | ||
// We don't need to visit entityExpression since it is supposed to be a parameterExpression only | ||
includeExpression.EntityExpression, | ||
includeExpression.NavigationExpression, | ||
Expression.Constant(includeExpression.Navigation), | ||
Expression.Constant(inverseNavigation, typeof(INavigation)), | ||
Expression.Constant( | ||
GenerateFixup( | ||
includingClrType, relatedEntityClrType, includeExpression.Navigation, inverseNavigation).Compile()), | ||
Expression.Constant(_tracking)); | ||
} | ||
|
||
return base.VisitExtension(extensionExpression); | ||
} | ||
|
||
private static LambdaExpression GenerateFixup( | ||
Type entityType, | ||
Type relatedEntityType, | ||
INavigation navigation, | ||
INavigation inverseNavigation) | ||
{ | ||
var entityParameter = Expression.Parameter(entityType); | ||
var relatedEntityParameter = Expression.Parameter(relatedEntityType); | ||
var expressions = new List<Expression> | ||
{ | ||
navigation.IsCollection() | ||
? AddToCollectionNavigation(entityParameter, relatedEntityParameter, navigation) | ||
: AssignReferenceNavigation(entityParameter, relatedEntityParameter, navigation) | ||
}; | ||
|
||
if (inverseNavigation != null) | ||
{ | ||
expressions.Add( | ||
inverseNavigation.IsCollection() | ||
? AddToCollectionNavigation(relatedEntityParameter, entityParameter, inverseNavigation) | ||
: AssignReferenceNavigation(relatedEntityParameter, entityParameter, inverseNavigation)); | ||
|
||
} | ||
|
||
return Expression.Lambda(Expression.Block(typeof(void), expressions), entityParameter, relatedEntityParameter); | ||
} | ||
|
||
private static Expression AssignReferenceNavigation( | ||
ParameterExpression entity, | ||
ParameterExpression relatedEntity, | ||
INavigation navigation) | ||
{ | ||
return entity.MakeMemberAccess(navigation.GetMemberInfo(forMaterialization: true, forSet: true)).Assign(relatedEntity); | ||
} | ||
|
||
private static Expression AddToCollectionNavigation( | ||
ParameterExpression entity, | ||
ParameterExpression relatedEntity, | ||
INavigation navigation) | ||
=> Expression.Call( | ||
Expression.Constant(navigation.GetCollectionAccessor()), | ||
_collectionAccessorAddMethodInfo, | ||
entity, | ||
relatedEntity, | ||
Expression.Constant(true)); | ||
|
||
private static readonly MethodInfo _collectionAccessorAddMethodInfo | ||
= typeof(IClrCollectionAccessor).GetTypeInfo() | ||
.GetDeclaredMethod(nameof(IClrCollectionAccessor.Add)); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
101 changes: 101 additions & 0 deletions
101
src/EFCore.InMemory/Query/Internal/ShaperExpressionProcessingExpressionVisitor.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
// Copyright (c) .NET Foundation. All rights reserved. | ||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. | ||
|
||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Linq.Expressions; | ||
using Microsoft.EntityFrameworkCore.Query; | ||
|
||
namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal | ||
{ | ||
public class ShaperExpressionProcessingExpressionVisitor : ExpressionVisitor | ||
{ | ||
private readonly InMemoryQueryExpression _queryExpression; | ||
|
||
private readonly IDictionary<Expression, ParameterExpression> _mapping = new Dictionary<Expression, ParameterExpression>(); | ||
private readonly List<ParameterExpression> _variables = new List<ParameterExpression>(); | ||
private readonly List<Expression> _expressions = new List<Expression>(); | ||
|
||
public ShaperExpressionProcessingExpressionVisitor( | ||
InMemoryQueryExpression queryExpression) | ||
{ | ||
_queryExpression = queryExpression; | ||
} | ||
|
||
public virtual Expression Inject(Expression expression) | ||
{ | ||
var result = Visit(expression); | ||
|
||
if (_expressions.All(e => e.NodeType == ExpressionType.Assign)) | ||
{ | ||
result = new ReplacingExpressionVisitor(_expressions.Cast<BinaryExpression>() | ||
.ToDictionary(e => e.Left, e => e.Right)).Visit(result); | ||
} | ||
else | ||
{ | ||
_expressions.Add(result); | ||
result = Expression.Block(_variables, _expressions); | ||
} | ||
|
||
return ConvertToLambda(result, Expression.Parameter(result.Type, "result")); | ||
} | ||
|
||
private LambdaExpression ConvertToLambda(Expression result, ParameterExpression resultParameter) | ||
=> Expression.Lambda( | ||
result, | ||
QueryCompilationContext.QueryContextParameter, | ||
_queryExpression.ValueBufferParameter); | ||
|
||
protected override Expression VisitExtension(Expression extensionExpression) | ||
{ | ||
switch (extensionExpression) | ||
{ | ||
case EntityShaperExpression entityShaperExpression: | ||
{ | ||
var key = GenerateKey((ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression); | ||
if (!_mapping.TryGetValue(key, out var variable)) | ||
{ | ||
variable = Expression.Parameter(entityShaperExpression.EntityType.ClrType); | ||
_variables.Add(variable); | ||
_expressions.Add(Expression.Assign(variable, entityShaperExpression)); | ||
_mapping[key] = variable; | ||
} | ||
|
||
return variable; | ||
} | ||
|
||
case ProjectionBindingExpression projectionBindingExpression: | ||
{ | ||
var key = GenerateKey(projectionBindingExpression); | ||
if (!_mapping.TryGetValue(key, out var variable)) | ||
{ | ||
variable = Expression.Parameter(projectionBindingExpression.Type); | ||
_variables.Add(variable); | ||
_expressions.Add(Expression.Assign(variable, projectionBindingExpression)); | ||
_mapping[key] = variable; | ||
} | ||
|
||
return variable; | ||
} | ||
|
||
case IncludeExpression includeExpression: | ||
{ | ||
var entity = Visit(includeExpression.EntityExpression); | ||
_expressions.Add( | ||
includeExpression.Update( | ||
entity, | ||
Visit(includeExpression.NavigationExpression))); | ||
|
||
return entity; | ||
} | ||
} | ||
|
||
return base.VisitExtension(extensionExpression); | ||
} | ||
|
||
private Expression GenerateKey(ProjectionBindingExpression projectionBindingExpression) | ||
=> projectionBindingExpression.ProjectionMember != null | ||
? _queryExpression.GetMappedProjection(projectionBindingExpression.ProjectionMember) | ||
: projectionBindingExpression; | ||
} | ||
} |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.