Skip to content

Commit

Permalink
Query: Move Property/Navigation binding to EntityProjection directly
Browse files Browse the repository at this point in the history
Support for Owned entities
Part of #15285

Also added SQL Assertion on Sqlite for LongCount
  • Loading branch information
smitpatel committed Jul 3, 2019
1 parent 6d1503d commit 3d44d35
Show file tree
Hide file tree
Showing 11 changed files with 169 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,25 +83,22 @@ protected override Expression VisitMember(MemberExpression memberExpression)

private bool TryBindProperty(Expression source, MemberIdentity member, out SqlExpression expression)
{
if (source is EntityShaperExpression entityShaperExpression)
if (source is EntityProjectionExpression entityProjectionExpression)
{
var projectionBindingExpression = (ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression;
var selectExpression = ((SelectExpression)projectionBindingExpression.QueryExpression);

var entityType = entityShaperExpression.EntityType;
var entityType = entityProjectionExpression.EntityType;
var property = member.MemberInfo != null
? entityType.FindProperty(member.MemberInfo)
: entityType.FindProperty(member.Name);
if (property != null)
{
expression = selectExpression.BindProperty(property, projectionBindingExpression);
expression = entityProjectionExpression.BindProperty(property);
return true;
}

var navigation = member.MemberInfo != null
? entityType.FindNavigation(member.MemberInfo)
: entityType.FindNavigation(member.Name);
expression = selectExpression.BindNavigation(navigation, projectionBindingExpression);
expression = entityProjectionExpression.BindNavigation(navigation);
return true;
}
else if (source is ObjectAccessExpression objectAccessExpression)
Expand Down Expand Up @@ -131,7 +128,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
{
if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName))
{
if (!TryBindProperty(source, MemberIdentity.Create(propertyName), out var result))
if (!TryBindProperty(Visit(source), MemberIdentity.Create(propertyName), out var result))
{
throw new InvalidOperationException();
}
Expand Down Expand Up @@ -260,6 +257,11 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
{
var operand = Visit(unaryExpression.Operand);

if (operand is EntityProjectionExpression)
{
return unaryExpression.Update(operand);
}

if (TranslationFailed(unaryExpression.Operand, operand))
{
return null;
Expand Down Expand Up @@ -319,24 +321,25 @@ protected override Expression VisitParameter(ParameterExpression parameterExpres

protected override Expression VisitExtension(Expression extensionExpression)
{
if (extensionExpression is EntityShaperExpression)
switch (extensionExpression)
{
return extensionExpression;
}
case EntityProjectionExpression _:
case SqlExpression _:
return extensionExpression;

if (extensionExpression is ProjectionBindingExpression projectionBindingExpression)
{
var selectExpression = (SelectExpression)projectionBindingExpression.QueryExpression;
case EntityShaperExpression entityShaperExpression:
return Visit(entityShaperExpression.ValueBufferExpression);

return selectExpression.GetMappedProjection(projectionBindingExpression.ProjectionMember);
}
case ProjectionBindingExpression projectionBindingExpression:
var selectExpression = (SelectExpression)projectionBindingExpression.QueryExpression;
return selectExpression.GetMappedProjection(projectionBindingExpression.ProjectionMember);

if (extensionExpression is NullConditionalExpression nullConditionalExpression)
{
return Visit(nullConditionalExpression.AccessOperation);
}
case NullConditionalExpression nullConditionalExpression:
return Visit(nullConditionalExpression.AccessOperation);

return base.VisitExtension(extensionExpression);
default:
return null;
}
}

[DebuggerStepThrough]
Expand Down
21 changes: 11 additions & 10 deletions src/EFCore.Cosmos/Query/Pipeline/EntityProjectionExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,38 @@ private readonly IDictionary<IProperty, KeyAccessExpression> _propertyExpression
= new Dictionary<IProperty, KeyAccessExpression>();
private readonly IDictionary<INavigation, ObjectAccessExpression> _navigationExpressionsCache
= new Dictionary<INavigation, ObjectAccessExpression>();
private readonly IEntityType _entityType;

public EntityProjectionExpression(IEntityType entityType, RootReferenceExpression accessExpression, string alias)
{
_entityType = entityType;
EntityType = entityType;
AccessExpression = accessExpression;
Alias = alias;
}

public override ExpressionType NodeType => ExpressionType.Extension;
public override Type Type => _entityType.ClrType;
public override Type Type => EntityType.ClrType;

public string Alias { get; }

public RootReferenceExpression AccessExpression { get; }

public IEntityType EntityType { get; }

protected override Expression VisitChildren(ExpressionVisitor visitor)
{
var accessExpression = (RootReferenceExpression)visitor.Visit(AccessExpression);

return accessExpression != AccessExpression
? new EntityProjectionExpression(_entityType, accessExpression, Alias)
? new EntityProjectionExpression(EntityType, accessExpression, Alias)
: this;
}

public KeyAccessExpression GetProperty(IProperty property)
public KeyAccessExpression BindProperty(IProperty property)
{
if (!_entityType.GetTypesInHierarchy().Contains(property.DeclaringEntityType))
if (!EntityType.GetTypesInHierarchy().Contains(property.DeclaringEntityType))
{
throw new InvalidOperationException(
$"Called EntityProjectionExpression.GetProperty() with incorrect IProperty. EntityType:{_entityType.DisplayName()}, Property:{property.Name}");
$"Called EntityProjectionExpression.GetProperty() with incorrect IProperty. EntityType:{EntityType.DisplayName()}, Property:{property.Name}");
}

if (!_propertyExpressionsCache.TryGetValue(property, out var expression))
Expand All @@ -58,12 +59,12 @@ public KeyAccessExpression GetProperty(IProperty property)
return expression;
}

public ObjectAccessExpression GetNavigation(INavigation navigation)
public ObjectAccessExpression BindNavigation(INavigation navigation)
{
if (!_entityType.GetTypesInHierarchy().Contains(navigation.DeclaringEntityType))
if (!EntityType.GetTypesInHierarchy().Contains(navigation.DeclaringEntityType))
{
throw new InvalidOperationException(
$"Called EntityProjectionExpression.GetNavigation() with incorrect INavigation. EntityType:{_entityType.DisplayName()}, Navigation:{navigation.Name}");
$"Called EntityProjectionExpression.GetNavigation() with incorrect INavigation. EntityType:{EntityType.DisplayName()}, Navigation:{navigation.Name}");
}

if (!_navigationExpressionsCache.TryGetValue(navigation, out var expression))
Expand Down
8 changes: 0 additions & 8 deletions src/EFCore.Cosmos/Query/Pipeline/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,6 @@ public void ReverseOrderings()
}
}

public SqlExpression BindProperty(IProperty property, ProjectionBindingExpression projectionBindingExpression)
=> ((EntityProjectionExpression)_projectionMapping[projectionBindingExpression.ProjectionMember])
.GetProperty(property);

public SqlExpression BindNavigation(INavigation navigation, ProjectionBindingExpression projectionBindingExpression)
=> ((EntityProjectionExpression)_projectionMapping[projectionBindingExpression.ProjectionMember])
.GetNavigation(navigation);

public override Type Type => typeof(JObject);
public override ExpressionType NodeType => ExpressionType.Extension;

Expand Down
4 changes: 2 additions & 2 deletions src/EFCore.Cosmos/Query/Pipeline/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ private void AddDiscriminator(SelectExpression selectExpression, IEntityType ent
if (concreteEntityType.GetDiscriminatorProperty() != null)
{
var discriminatorColumn = ((EntityProjectionExpression)selectExpression.GetMappedProjection(new ProjectionMember()))
.GetProperty(concreteEntityType.GetDiscriminatorProperty());
.BindProperty(concreteEntityType.GetDiscriminatorProperty());

selectExpression.ApplyPredicate(
Equal(discriminatorColumn, Constant(concreteEntityType.GetDiscriminatorValue())));
Expand All @@ -377,7 +377,7 @@ private void AddDiscriminator(SelectExpression selectExpression, IEntityType ent
else
{
var discriminatorColumn = ((EntityProjectionExpression)selectExpression.GetMappedProjection(new ProjectionMember()))
.GetProperty(concreteEntityTypes[0].GetDiscriminatorProperty());
.BindProperty(concreteEntityTypes[0].GetDiscriminatorProperty());

selectExpression.ApplyPredicate(
In(discriminatorColumn, Constant(concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList()), negated: false));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public EntityProjectionExpression UpdateEntityType(IEntityType derivedType)
public override ExpressionType NodeType => ExpressionType.Extension;
public override Type Type => EntityType.ClrType;

public ColumnExpression GetProperty(IProperty property)
public ColumnExpression BindProperty(IProperty property)
{
if (!EntityType.GetTypesInHierarchy().Contains(property.DeclaringEntityType))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -622,8 +622,9 @@ protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression s
var selectExpression = (SelectExpression)source.QueryExpression;
var concreteEntityTypes = derivedType.GetConcreteDerivedTypesInclusive().ToList();
var projectionBindingExpression = (ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression;
var discriminatorColumn = selectExpression
.BindProperty(projectionBindingExpression, entityType.GetDiscriminatorProperty());
var entityProjectionExpression = (EntityProjectionExpression)selectExpression.GetMappedProjection(
projectionBindingExpression.ProjectionMember);
var discriminatorColumn = entityProjectionExpression.BindProperty(entityType.GetDiscriminatorProperty());

var predicate = concreteEntityTypes.Count == 1
? _sqlExpressionFactory.Equal(discriminatorColumn,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,16 +167,16 @@ protected override Expression VisitExtension(Expression node)

protected override Expression VisitMember(MemberExpression memberExpression)
{
if (memberExpression.Expression is EntityShaperExpression
|| (memberExpression.Expression is UnaryExpression innerUnaryExpression
var innerExpression = Visit(memberExpression.Expression);

if (innerExpression is EntityProjectionExpression
|| (innerExpression is UnaryExpression innerUnaryExpression
&& innerUnaryExpression.NodeType == ExpressionType.Convert
&& innerUnaryExpression.Operand is EntityShaperExpression))
&& innerUnaryExpression.Operand is EntityProjectionExpression))
{
return BindProperty(memberExpression.Expression, memberExpression.Member.GetSimpleMemberName());
return BindProperty(innerExpression, memberExpression.Member.GetSimpleMemberName());
}

var innerExpression = Visit(memberExpression.Expression);

return TranslationFailed(memberExpression.Expression, innerExpression)
? null
: _memberTranslatorProvider.Translate((SqlExpression)innerExpression, memberExpression.Member, memberExpression.Type);
Expand All @@ -195,9 +195,9 @@ private SqlExpression BindProperty(Expression source, string propertyName)
}
}

if (source is EntityShaperExpression entityShaper)
if (source is EntityProjectionExpression entityProjection)
{
var entityType = entityShaper.EntityType;
var entityType = entityProjection.EntityType;
if (convertedType != null)
{
entityType = entityType.RootType().GetDerivedTypesInclusive()
Expand All @@ -209,25 +209,23 @@ private SqlExpression BindProperty(Expression source, string propertyName)
}
}

return BindProperty(entityShaper, entityType.FindProperty(propertyName));
return BindProperty(entityProjection, entityType.FindProperty(propertyName));
}

throw new InvalidOperationException();
}

private SqlExpression BindProperty(EntityShaperExpression entityShaperExpression, IProperty property)
private SqlExpression BindProperty(EntityProjectionExpression entityProjectionExpression, IProperty property)
{
var projectionBindingExpression = (ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression;
return ((SelectExpression)projectionBindingExpression.QueryExpression)
.BindProperty(projectionBindingExpression, property);
return entityProjectionExpression.BindProperty(property);
}

protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExpression)
{
if (typeBinaryExpression.NodeType == ExpressionType.TypeIs
&& typeBinaryExpression.Expression is EntityShaperExpression entityShaperExpression)
&& Visit(typeBinaryExpression.Expression) is EntityProjectionExpression entityProjectionExpression)
{
var entityType = entityShaperExpression.EntityType;
var entityType = entityProjectionExpression.EntityType;
if (entityType.GetAllBaseTypesInclusive().Any(et => et.ClrType == typeBinaryExpression.TypeOperand))
{
return _sqlExpressionFactory.Constant(true);
Expand All @@ -237,7 +235,7 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp
if (derivedType != null)
{
var concreteEntityTypes = derivedType.GetConcreteDerivedTypesInclusive().ToList();
var discriminatorColumn = BindProperty(entityShaperExpression, entityType.GetDiscriminatorProperty());
var discriminatorColumn = BindProperty(entityProjectionExpression, entityType.GetDiscriminatorProperty());

return concreteEntityTypes.Count == 1
? _sqlExpressionFactory.Equal(discriminatorColumn,
Expand Down Expand Up @@ -277,7 +275,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
// EF.Property case
if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName))
{
return BindProperty(source, propertyName);
return BindProperty(Visit(source), propertyName);
}

// GroupBy Aggregate case
Expand Down Expand Up @@ -459,15 +457,30 @@ protected override Expression VisitParameter(ParameterExpression parameterExpres
=> new SqlParameterExpression(parameterExpression, null);

protected override Expression VisitExtension(Expression extensionExpression)
=> extensionExpression switch
{
switch (extensionExpression)
{
EntityShaperExpression e => e,
SqlExpression e => e,
NullConditionalExpression e => Visit(e.AccessOperation),
CorrelationPredicateExpression e => Visit(e.EqualExpression),
ProjectionBindingExpression e => ((SelectExpression)e.QueryExpression).GetMappedProjection(e.ProjectionMember),
_ => null
};
case EntityProjectionExpression _:
case SqlExpression _:
return extensionExpression;

case NullConditionalExpression nullConditionalExpression:
return Visit(nullConditionalExpression.AccessOperation);

case EntityShaperExpression entityShaperExpression:
return Visit(entityShaperExpression.ValueBufferExpression);

case CorrelationPredicateExpression correlationPredicateExpression:
return Visit(correlationPredicateExpression.EqualExpression);

case ProjectionBindingExpression projectionBindingExpression:
var selectExpression = (SelectExpression)projectionBindingExpression.QueryExpression;
return selectExpression.GetMappedProjection(projectionBindingExpression.ProjectionMember);

default:
return null;
}
}

protected override Expression VisitConditional(ConditionalExpression conditionalExpression)
{
Expand All @@ -494,16 +507,19 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
{
var operand = Visit(unaryExpression.Operand);

if (operand is EntityProjectionExpression)
{
return unaryExpression.Update(operand);
}

if (TranslationFailed(unaryExpression.Operand, operand))
{
return null;
}

var sqlOperand = (SqlExpression)operand;

switch (unaryExpression.NodeType)
{

case ExpressionType.Not:
return _sqlExpressionFactory.Not(sqlOperand);

Expand Down
4 changes: 2 additions & 2 deletions src/EFCore.Relational/Query/Pipeline/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ private void AddDiscriminator(SelectExpression selectExpression, IEntityType ent
}

var discriminatorColumn = ((EntityProjectionExpression)selectExpression.GetMappedProjection(new ProjectionMember()))
.GetProperty(concreteEntityType.GetDiscriminatorProperty());
.BindProperty(concreteEntityType.GetDiscriminatorProperty());

selectExpression.ApplyPredicate(
Equal(discriminatorColumn, Constant(concreteEntityType.GetDiscriminatorValue())));
Expand All @@ -558,7 +558,7 @@ private void AddDiscriminator(SelectExpression selectExpression, IEntityType ent
else
{
var discriminatorColumn = ((EntityProjectionExpression)selectExpression.GetMappedProjection(new ProjectionMember()))
.GetProperty(concreteEntityTypes[0].GetDiscriminatorProperty());
.BindProperty(concreteEntityTypes[0].GetDiscriminatorProperty());

selectExpression.ApplyPredicate(
In(discriminatorColumn, Constant(concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList()), negated: false));
Expand Down
Loading

0 comments on commit 3d44d35

Please sign in to comment.