Skip to content

Commit

Permalink
Query: Add support for owned reference navigation
Browse files Browse the repository at this point in the history
Part of #15285
  • Loading branch information
smitpatel committed Jul 17, 2019
1 parent cd44254 commit b9857de
Show file tree
Hide file tree
Showing 16 changed files with 513 additions and 751 deletions.
30 changes: 30 additions & 0 deletions src/EFCore.Relational/Query/Pipeline/EntityProjectionExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Query.Pipeline;
using Microsoft.EntityFrameworkCore.Relational.Query.Pipeline.SqlExpressions;

namespace Microsoft.EntityFrameworkCore.Relational.Query.Pipeline
Expand All @@ -15,6 +16,9 @@ public class EntityProjectionExpression : Expression
{
private readonly IDictionary<IProperty, ColumnExpression> _propertyExpressionsCache
= new Dictionary<IProperty, ColumnExpression>();
private readonly IDictionary<INavigation, EntityShaperExpression> _navigationExpressionsCache
= new Dictionary<INavigation, EntityShaperExpression>();

private readonly TableExpressionBase _innerTable;
private readonly bool _nullable;

Expand Down Expand Up @@ -107,5 +111,31 @@ public ColumnExpression BindProperty(IProperty property)

return expression;
}

public void AddNavigationBinding(INavigation navigation, EntityShaperExpression entityShaper)
{
if (!EntityType.GetTypesInHierarchy().Contains(navigation.DeclaringEntityType))
{
throw new InvalidOperationException(
$"Called EntityProjectionExpression.AddNavigationBinding() with incorrect INavigation. " +
$"EntityType:{EntityType.DisplayName()}, Property:{navigation.Name}");
}

_navigationExpressionsCache[navigation] = entityShaper;
}

public EntityShaperExpression BindNavigation(INavigation navigation)
{
if (!EntityType.GetTypesInHierarchy().Contains(navigation.DeclaringEntityType))
{
throw new InvalidOperationException(
$"Called EntityProjectionExpression.BindNavigation() with incorrect INavigation. " +
$"EntityType:{EntityType.DisplayName()}, Property:{navigation.Name}");
}

return _navigationExpressionsCache.TryGetValue(navigation, out var expression)
? expression
: null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -160,21 +160,26 @@ protected override Expression VisitExtension(Expression extensionExpression)
{
if (extensionExpression is EntityShaperExpression entityShaperExpression)
{
var projectionBindingExpression = (ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression;
VerifySelectExpression(projectionBindingExpression);

if (_clientEval)
EntityProjectionExpression entityProjectionExpression;
if (entityShaperExpression.ValueBufferExpression is ProjectionBindingExpression projectionBindingExpression)
{
var entityProjection = (EntityProjectionExpression)_selectExpression.GetMappedProjection(
VerifySelectExpression(projectionBindingExpression);
entityProjectionExpression = (EntityProjectionExpression)_selectExpression.GetMappedProjection(
projectionBindingExpression.ProjectionMember);
}
else
{
entityProjectionExpression = (EntityProjectionExpression)entityShaperExpression.ValueBufferExpression;
}

if (_clientEval)
{
return entityShaperExpression.Update(
new ProjectionBindingExpression(_selectExpression, _selectExpression.AddToProjection(entityProjection)));
new ProjectionBindingExpression(_selectExpression, _selectExpression.AddToProjection(entityProjectionExpression)));
}
else
{
_projectionMapping[_projectionMembers.Peek()]
= _selectExpression.GetMappedProjection(projectionBindingExpression.ProjectionMember);
_projectionMapping[_projectionMembers.Peek()] = entityProjectionExpression;

return entityShaperExpression.Update(
new ProjectionBindingExpression(_selectExpression, _projectionMembers.Peek(), typeof(ValueBuffer)));
Expand All @@ -183,12 +188,9 @@ protected override Expression VisitExtension(Expression extensionExpression)

if (extensionExpression is IncludeExpression includeExpression)
{
// TODO: handle owned navigations
return includeExpression.Navigation.ForeignKey.IsOwnership
? Visit(includeExpression.EntityExpression)
: _clientEval
? base.VisitExtension(includeExpression)
: null;
return _clientEval
? base.VisitExtension(includeExpression)
: null;
}

throw new InvalidOperationException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@
using Microsoft.EntityFrameworkCore.Relational.Query.Pipeline.SqlExpressions;
using System.Diagnostics;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Query.Expressions.Internal;

namespace Microsoft.EntityFrameworkCore.Relational.Query.Pipeline
{
public class RelationalQueryableMethodTranslatingExpressionVisitor : QueryableMethodTranslatingExpressionVisitor
{
private readonly RelationalSqlTranslatingExpressionVisitor _sqlTranslator;
private readonly WeakEntityExpandingExpressionVisitor _weakEntityExpandingExpressionVisitor;
private readonly RelationalProjectionBindingExpressionVisitor _projectionBindingExpressionVisitor;
private readonly IModel _model;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
Expand All @@ -31,6 +34,7 @@ public RelationalQueryableMethodTranslatingExpressionVisitor(
: base(subquery: false)
{
_sqlTranslator = relationalSqlTranslatingExpressionVisitorFactory.Create(model, this);
_weakEntityExpandingExpressionVisitor = new WeakEntityExpandingExpressionVisitor(_sqlTranslator, sqlExpressionFactory);
_projectionBindingExpressionVisitor = new RelationalProjectionBindingExpressionVisitor(this, _sqlTranslator);
_model = model;
_sqlExpressionFactory = sqlExpressionFactory;
Expand All @@ -39,11 +43,13 @@ public RelationalQueryableMethodTranslatingExpressionVisitor(
private RelationalQueryableMethodTranslatingExpressionVisitor(
IModel model,
RelationalSqlTranslatingExpressionVisitor sqlTranslator,
WeakEntityExpandingExpressionVisitor weakEntityExpandingExpressionVisitor,
ISqlExpressionFactory sqlExpressionFactory)
: base(subquery: true)
{
_model = model;
_sqlTranslator = sqlTranslator;
_weakEntityExpandingExpressionVisitor = weakEntityExpandingExpressionVisitor;
_projectionBindingExpressionVisitor = new RelationalProjectionBindingExpressionVisitor(this, sqlTranslator);
_sqlExpressionFactory = sqlExpressionFactory;
}
Expand All @@ -66,6 +72,7 @@ public override ShapedQueryExpression TranslateSubquery(Expression expression)
return (ShapedQueryExpression)new RelationalQueryableMethodTranslatingExpressionVisitor(
_model,
_sqlTranslator,
_weakEntityExpandingExpressionVisitor,
_sqlExpressionFactory).Visit(expression);

}
Expand All @@ -86,7 +93,7 @@ private ShapedQueryExpression CreateShapedQueryExpression(Type elementType, stri
return CreateShapedQueryExpression(entityType, queryExpression);
}

private ShapedQueryExpression CreateShapedQueryExpression(IEntityType entityType, SelectExpression selectExpression)
private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType entityType, SelectExpression selectExpression)
{
return new ShapedQueryExpression(
selectExpression,
Expand Down Expand Up @@ -275,7 +282,7 @@ protected override ShapedQueryExpression TranslateGroupBy(
var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.PrepareForAggregate();

var remappedKeySelector = RemapLambdaBody(source.ShaperExpression, keySelector);
var remappedKeySelector = RemapLambdaBody(source, keySelector);

var translatedKey = TranslateGroupingKey(remappedKeySelector)
?? (remappedKeySelector is ConstantExpression ? remappedKeySelector : null);
Expand Down Expand Up @@ -476,8 +483,8 @@ private SqlBinaryExpression CreateJoinPredicate(
ShapedQueryExpression inner,
LambdaExpression innerKeySelector)
{
var outerKey = RemapLambdaBody(outer.ShaperExpression, outerKeySelector);
var innerKey = RemapLambdaBody(inner.ShaperExpression, innerKeySelector);
var outerKey = RemapLambdaBody(outer, outerKeySelector);
var innerKey = RemapLambdaBody(inner, innerKeySelector);

if (outerKey is NewExpression outerNew)
{
Expand Down Expand Up @@ -517,12 +524,9 @@ private SqlBinaryExpression CreateJoinPredicate(
var left = TranslateExpression(outerKey);
var right = TranslateExpression(innerKey);

if (left != null && right != null)
{
return _sqlExpressionFactory.Equal(left, right);
}

return null;
return left != null && right != null
? _sqlExpressionFactory.Equal(left, right)
: null;
}

protected override ShapedQueryExpression TranslateLastOrDefault(
Expand Down Expand Up @@ -687,7 +691,7 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s
selectExpression.PushdownIntoSubquery();
}

var newSelectorBody = ReplacingExpressionVisitor.Replace(selector.Parameters.Single(), source.ShaperExpression, selector.Body);
var newSelectorBody = RemapLambdaBody(source, selector);
source.ShaperExpression = _projectionBindingExpressionVisitor.Translate(selectExpression, newSelectorBody);

return source;
Expand Down Expand Up @@ -877,14 +881,156 @@ private SqlExpression TranslateExpression(Expression expression)
private SqlExpression TranslateLambdaExpression(
ShapedQueryExpression shapedQueryExpression, LambdaExpression lambdaExpression)
{
var lambdaBody = RemapLambdaBody(shapedQueryExpression.ShaperExpression, lambdaExpression);
var lambdaBody = RemapLambdaBody(shapedQueryExpression, lambdaExpression);

return TranslateExpression(lambdaBody);
}

private Expression RemapLambdaBody(Expression shaperBody, LambdaExpression lambdaExpression)
private Expression RemapLambdaBody(ShapedQueryExpression shapedQueryExpression, LambdaExpression lambdaExpression)
{
var lambdaBody = ReplacingExpressionVisitor.Replace(
lambdaExpression.Parameters.Single(), shapedQueryExpression.ShaperExpression, lambdaExpression.Body);

var selectExpression = (SelectExpression)shapedQueryExpression.QueryExpression;
lambdaBody = _weakEntityExpandingExpressionVisitor.Expand(selectExpression, lambdaBody);

return lambdaBody;
}

public class WeakEntityExpandingExpressionVisitor : ExpressionVisitor
{
return ReplacingExpressionVisitor.Replace(lambdaExpression.Parameters.Single(), shaperBody, lambdaExpression.Body);
private SelectExpression _selectExpression;
private readonly RelationalSqlTranslatingExpressionVisitor _sqlTranslator;
private readonly ISqlExpressionFactory _sqlExpressionFactory;

public WeakEntityExpandingExpressionVisitor(
RelationalSqlTranslatingExpressionVisitor sqlTranslator, ISqlExpressionFactory sqlExpressionFactory)
{
_sqlTranslator = sqlTranslator;
_sqlExpressionFactory = sqlExpressionFactory;
}

public Expression Expand(SelectExpression selectExpression, Expression lambdaBody)
{
_selectExpression = selectExpression;

return Visit(lambdaBody);
}

protected override Expression VisitMember(MemberExpression memberExpression)
{
var innerExpression = Visit(memberExpression.Expression);

if (innerExpression is EntityShaperExpression
|| (innerExpression is UnaryExpression innerUnaryExpression
&& innerUnaryExpression.NodeType == ExpressionType.Convert
&& innerUnaryExpression.Operand is EntityShaperExpression))
{
var collectionNavigation = Expand(innerExpression, MemberIdentity.Create(memberExpression.Member));
if (collectionNavigation != null)
{
return collectionNavigation;
}
}

return memberExpression.Update(innerExpression);
}

protected override Expression VisitExtension(Expression extensionExpression)
=> extensionExpression is EntityShaperExpression
? extensionExpression
: base.VisitExtension(extensionExpression);

private Expression Expand(Expression source, MemberIdentity member)
{
Type convertedType = null;
if (source is UnaryExpression unaryExpression
&& unaryExpression.NodeType == ExpressionType.Convert)
{
source = unaryExpression.Operand;
if (unaryExpression.Type != typeof(object))
{
convertedType = unaryExpression.Type;
}
}

if (source is EntityShaperExpression entityShaperExpression)
{
var entityType = entityShaperExpression.EntityType;
if (convertedType != null)
{
entityType = entityType.RootType().GetDerivedTypesInclusive()
.FirstOrDefault(et => et.ClrType == convertedType);

if (entityType == null)
{
return null;
}
}

var navigation = member.MemberInfo != null
? entityType.FindNavigation(member.MemberInfo)
: entityType.FindNavigation(member.Name);

if (navigation != null)
{
if (navigation.IsCollection())
{
return CreateShapedQueryExpression(
navigation.GetTargetType(),
_sqlExpressionFactory.Select(navigation.GetTargetType()));
}

var entityProjectionExpression = (EntityProjectionExpression)
(entityShaperExpression.ValueBufferExpression is ProjectionBindingExpression projectionBindingExpression
? _selectExpression.GetMappedProjection(projectionBindingExpression.ProjectionMember)
: entityShaperExpression.ValueBufferExpression);

var innerShaper = entityProjectionExpression.BindNavigation(navigation);
if (innerShaper == null)
{
var targetEntityType = navigation.GetTargetType();
var innerSelectExpression = _sqlExpressionFactory.Select(targetEntityType);
var innerShapedQuery = CreateShapedQueryExpression(targetEntityType, innerSelectExpression);

var makeNullable = navigation.ForeignKey.PrincipalKey.Properties
.Concat(navigation.ForeignKey.Properties)
.Select(p => p.ClrType)
.Any(t => t.IsNullableType());

var outerKey = CreateKeyAccessExpression(
entityShaperExpression, navigation.ForeignKey.PrincipalKey.Properties, makeNullable);
var innerKey = CreateKeyAccessExpression(
innerShapedQuery.ShaperExpression, navigation.ForeignKey.Properties, makeNullable);

var joinPredicate = _sqlTranslator.Translate(Expression.Equal(outerKey, innerKey));
_selectExpression.AddLeftJoin(innerSelectExpression, joinPredicate, null);
var leftJoinTable = ((LeftJoinExpression)_selectExpression.Tables.Last()).Table;
innerShaper = new EntityShaperExpression(targetEntityType,
new EntityProjectionExpression(targetEntityType, leftJoinTable, true),
true);
entityProjectionExpression.AddNavigationBinding(navigation, innerShaper);
}

return innerShaper;
}
}

return null;
}

public static Expression CreateKeyAccessExpression(
Expression target, IReadOnlyList<IProperty> properties, bool makeNullable = false)
=> properties.Count == 1
? target.CreateEFPropertyExpression(properties[0], makeNullable)
: Expression.New(
AnonymousObject.AnonymousObjectCtor,
Expression.NewArrayInit(
typeof(object),
properties
.Select(p => Expression.Convert(target.CreateEFPropertyExpression(p, makeNullable), typeof(object)))
.Cast<Expression>()
.ToArray()));
}

private ShapedQueryExpression AggregateResultShaper(
Expand Down
Loading

0 comments on commit b9857de

Please sign in to comment.