Skip to content

Commit

Permalink
Query: Change TPT to use string discriminator column (#21784)
Browse files Browse the repository at this point in the history
Part of #21509
  • Loading branch information
smitpatel authored Jul 25, 2020
1 parent bc9dc1b commit f203e1f
Show file tree
Hide file tree
Showing 12 changed files with 1,283 additions and 1,748 deletions.
54 changes: 22 additions & 32 deletions src/EFCore.Relational/Query/EntityProjectionExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Runtime.InteropServices;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Metadata;
Expand Down Expand Up @@ -44,28 +46,28 @@ public EntityProjectionExpression([NotNull] IEntityType entityType, [NotNull] Ta
/// </summary>
/// <param name="entityType"> The entity type to shape. </param>
/// <param name="propertyExpressionMap"> A dictionary of column expressions corresponding to properties of the entity type. </param>
/// <param name="entityTypeIdentifyingExpressionMap"> A dictionary of <see cref="SqlExpression"/> to identify each entity type in hierarchy. </param>
/// <param name="discriminatorExpression"> A <see cref="SqlExpression"/> to generate discriminator for each concrete entity type in hierarchy. </param>
public EntityProjectionExpression(
[NotNull] IEntityType entityType,
[NotNull] IDictionary<IProperty, ColumnExpression> propertyExpressionMap,
[CanBeNull] IReadOnlyDictionary<IEntityType, SqlExpression> entityTypeIdentifyingExpressionMap = null)
[CanBeNull] SqlExpression discriminatorExpression = null)
{
Check.NotNull(entityType, nameof(entityType));
Check.NotNull(propertyExpressionMap, nameof(propertyExpressionMap));

EntityType = entityType;
_propertyExpressionMap = propertyExpressionMap;
EntityTypeIdentifyingExpressionMap = entityTypeIdentifyingExpressionMap;
DiscriminatorExpression = discriminatorExpression;
}

/// <summary>
/// The entity type being projected out.
/// </summary>
public virtual IEntityType EntityType { get; }
/// <summary>
/// Dictionary of entity type identifying expressions.
/// A <see cref="SqlExpression"/> to generate discriminator for entity type.
/// </summary>
public virtual IReadOnlyDictionary<IEntityType, SqlExpression> EntityTypeIdentifyingExpressionMap { get; }
public virtual SqlExpression DiscriminatorExpression { get; }
/// <inheritdoc />
public sealed override ExpressionType NodeType => ExpressionType.Extension;
/// <inheritdoc />
Expand All @@ -86,21 +88,11 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
propertyExpressionMap[expression.Key] = newExpression;
}

Dictionary<IEntityType, SqlExpression> entityTypeIdentifyingExpressionMap = null;
if (EntityTypeIdentifyingExpressionMap != null)
{
entityTypeIdentifyingExpressionMap = new Dictionary<IEntityType, SqlExpression>();
foreach (var expression in EntityTypeIdentifyingExpressionMap)
{
var newExpression = (SqlExpression)visitor.Visit(expression.Value);
changed |= newExpression != expression.Value;

entityTypeIdentifyingExpressionMap[expression.Key] = newExpression;
}
}
var discriminatorExpression = (SqlExpression)visitor.Visit(DiscriminatorExpression);
changed |= discriminatorExpression != DiscriminatorExpression;

return changed
? new EntityProjectionExpression(EntityType, propertyExpressionMap, entityTypeIdentifyingExpressionMap)
? new EntityProjectionExpression(EntityType, propertyExpressionMap, discriminatorExpression)
: this;
}

Expand All @@ -116,8 +108,8 @@ public virtual EntityProjectionExpression MakeNullable()
propertyExpressionMap[expression.Key] = expression.Value.MakeNullable();
}

// We don't need to process EntityTypeIdentifyingExpressionMap because they are already nullable
return new EntityProjectionExpression(EntityType, propertyExpressionMap, EntityTypeIdentifyingExpressionMap);
// We don't need to process DiscriminatorExpression because they are already nullable
return new EntityProjectionExpression(EntityType, propertyExpressionMap, DiscriminatorExpression);
}

/// <summary>
Expand All @@ -140,21 +132,19 @@ public virtual EntityProjectionExpression UpdateEntityType([NotNull] IEntityType
}
}

Dictionary<IEntityType, SqlExpression> entityTypeIdentifyingExpressionMap = null;
if (EntityTypeIdentifyingExpressionMap != null)
var discriminatorExpression = DiscriminatorExpression;
if (DiscriminatorExpression is CaseExpression caseExpression)
{
entityTypeIdentifyingExpressionMap = new Dictionary<IEntityType, SqlExpression>();
foreach (var kvp in EntityTypeIdentifyingExpressionMap)
{
var entityType = kvp.Key;
if (entityType.IsStrictlyDerivedFrom(derivedType))
{
entityTypeIdentifyingExpressionMap[entityType] = kvp.Value;
}
}
var entityTypesToSelect = derivedType.GetDerivedTypesInclusive().Where(et => !et.IsAbstract())
.Select(et => et.ShortName()).ToHashSet();
var whenClauses = caseExpression.WhenClauses
.Where(wc => entityTypesToSelect.Contains((string)((SqlConstantExpression)wc.Result).Value))
.ToList();

discriminatorExpression = caseExpression.Update(operand: null, whenClauses, elseResult: null);
}

return new EntityProjectionExpression(derivedType, propertyExpressionMap, entityTypeIdentifyingExpressionMap);
return new EntityProjectionExpression(derivedType, propertyExpressionMap, discriminatorExpression);
}

/// <summary>
Expand Down
33 changes: 20 additions & 13 deletions src/EFCore.Relational/Query/RelationalEntityShaperExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
Expand Down Expand Up @@ -72,22 +73,28 @@ protected override LambdaExpression GenerateMaterializationCondition(IEntityType
{
// TPT
var valueBufferParameter = Parameter(typeof(ValueBuffer));
var body = entityType.IsAbstract()
? Block(Throw(Call(_createUnableToIdentifyConcreteTypeException)), Constant(null, typeof(IEntityType)))
: (Expression)Constant(entityType, typeof(IEntityType));

var concreteEntityTypes = entityType.GetDerivedTypes().Where(dt => !dt.IsAbstract()).ToArray();
for (var i = 0; i < concreteEntityTypes.Length; i++)
var discriminatorValueVariable = Variable(typeof(string), "discriminator");
var expressions = new List<Expression>
{
Assign(
discriminatorValueVariable,
valueBufferParameter.CreateValueBufferReadValueExpression(typeof(string), 0, null))
};

var derivedConcreteEntityTypes = entityType.GetDerivedTypes().Where(dt => !dt.IsAbstract()).ToArray();
var switchCases = new SwitchCase[derivedConcreteEntityTypes.Length];
for (var i = 0; i < derivedConcreteEntityTypes.Length; i++)
{
body = Condition(
Equal(
valueBufferParameter.CreateValueBufferReadValueExpression(typeof(bool?), i, property: null),
Constant(true, typeof(bool?))),
Constant(concreteEntityTypes[i], typeof(IEntityType)),
body);
var discriminatorValue = Constant(derivedConcreteEntityTypes[i].ShortName(), typeof(string));
switchCases[i] = SwitchCase(Constant(derivedConcreteEntityTypes[i], typeof(IEntityType)), discriminatorValue);
}

baseCondition = Lambda(body, valueBufferParameter);
var defaultBlock = entityType.IsAbstract()
? Block(Throw(Call(_createUnableToIdentifyConcreteTypeException)), Constant(null, typeof(IEntityType)))
: (Expression)Constant(entityType, typeof(IEntityType));

expressions.Add(Switch(discriminatorValueVariable, defaultBlock, switchCases));
baseCondition = Lambda(Block(new[] { discriminatorValueVariable }, expressions), valueBufferParameter);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,8 @@ protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression s
if (discriminatorProperty == null)
{
var selectExpression = (SelectExpression)source.QueryExpression;
var concreteEntityTypes = derivedType.GetConcreteDerivedTypesInclusive().ToList();
var discriminatorValues = derivedType.GetConcreteDerivedTypesInclusive().Where(et => !et.IsAbstract())
.Select(et => et.ShortName()).ToList();
var projectionBindingExpression = (ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression;

var projectionMember = projectionBindingExpression.ProjectionMember;
Expand All @@ -793,17 +794,38 @@ protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression s

var entityProjectionExpression = (EntityProjectionExpression)selectExpression.GetMappedProjection(projectionMember);

var predicate = entityProjectionExpression.EntityTypeIdentifyingExpressionMap
.Where(kvp => concreteEntityTypes.Contains(kvp.Key))
.Select(kvp => kvp.Value)
.Aggregate((l, r) => _sqlExpressionFactory.OrElse(l, r));
var predicate = GeneratePredicateTPT(entityProjectionExpression);

selectExpression.ApplyPredicate(predicate);
selectExpression.ReplaceProjectionMapping(
new Dictionary<ProjectionMember, Expression>
{
{ projectionMember, entityProjectionExpression.UpdateEntityType(derivedType) }
});

SqlExpression GeneratePredicateTPT(EntityProjectionExpression entityProjectionExpression)
{
if (entityProjectionExpression.DiscriminatorExpression is CaseExpression caseExpression)
{
var matchingCaseWhenClauses = caseExpression.WhenClauses
.Where(wc => discriminatorValues.Contains((string)((SqlConstantExpression)wc.Result).Value))
.ToList();

return matchingCaseWhenClauses.Count == 1
? matchingCaseWhenClauses[0].Test
: matchingCaseWhenClauses.Select(e => e.Test)
.Aggregate((l, r) => _sqlExpressionFactory.OrElse(l, r));
}

return discriminatorValues.Count == 1
? _sqlExpressionFactory.Equal(
entityProjectionExpression.DiscriminatorExpression,
_sqlExpressionFactory.Constant(discriminatorValues[0]))
: (SqlExpression)_sqlExpressionFactory.In(
entityProjectionExpression.DiscriminatorExpression,
_sqlExpressionFactory.Constant(discriminatorValues),
negated: false);
}
}
else if (!derivedType.GetRootType().GetIsDiscriminatorMappingComplete()
|| !derivedType.GetAllBaseTypesInclusiveAscending()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -775,16 +775,14 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp
if (discriminatorProperty == null)
{
// TPT
var discriminatorValues = concreteEntityTypes.Select(et => et.ShortName()).ToList();
if (entityReferenceExpression.SubqueryEntity != null)
{
var entityShaper = (EntityShaperExpression)entityReferenceExpression.SubqueryEntity.ShaperExpression;
var entityProjection = (EntityProjectionExpression)Visit(entityShaper.ValueBufferExpression);
var subSelectExpression = (SelectExpression)entityReferenceExpression.SubqueryEntity.QueryExpression;

var predicate = entityProjection.EntityTypeIdentifyingExpressionMap
.Where(kvp => concreteEntityTypes.Contains(kvp.Key))
.Select(kvp => kvp.Value)
.Aggregate((l, r) => _sqlExpressionFactory.OrElse(l, r));
var predicate = GeneratePredicateTPT(entityProjection);

subSelectExpression.ApplyPredicate(predicate);
subSelectExpression.ReplaceProjectionMapping(new Dictionary<ProjectionMember, Expression>());
Expand All @@ -802,16 +800,36 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp
var entityProjection = (EntityProjectionExpression)Visit(
entityReferenceExpression.ParameterEntity.ValueBufferExpression);

return entityProjection.EntityTypeIdentifyingExpressionMap
.Where(kvp => concreteEntityTypes.Contains(kvp.Key))
.Select(kvp => kvp.Value)
.Aggregate((l, r) => _sqlExpressionFactory.OrElse(l, r));
return GeneratePredicateTPT(entityProjection);
}

SqlExpression GeneratePredicateTPT(EntityProjectionExpression entityProjectionExpression)
{
if (entityProjectionExpression.DiscriminatorExpression is CaseExpression caseExpression)
{
var matchingCaseWhenClauses = caseExpression.WhenClauses
.Where(wc => discriminatorValues.Contains((string)((SqlConstantExpression)wc.Result).Value))
.ToList();

return matchingCaseWhenClauses.Count == 1
? matchingCaseWhenClauses[0].Test
: matchingCaseWhenClauses.Select(e => e.Test)
.Aggregate((l, r) => _sqlExpressionFactory.OrElse(l, r));
}

return discriminatorValues.Count == 1
? _sqlExpressionFactory.Equal(
entityProjectionExpression.DiscriminatorExpression,
_sqlExpressionFactory.Constant(discriminatorValues[0]))
: (SqlExpression)_sqlExpressionFactory.In(
entityProjectionExpression.DiscriminatorExpression,
_sqlExpressionFactory.Constant(discriminatorValues),
negated: false);
}
}
else
{
var discriminatorColumn = BindProperty(entityReferenceExpression, discriminatorProperty);

if (discriminatorColumn != null)
{
return concreteEntityTypes.Count == 1
Expand Down Expand Up @@ -1169,9 +1187,9 @@ private bool TryRewriteEntityEquality(ExpressionType nodeType, Expression left,

result = Visit(primaryKeyProperties1.Select(p =>
{
var comparison = Expression.Call(_objectEqualsMethodInfo,
Expression.Convert(CreatePropertyAccessExpression(nonNullEntityReference, p), typeof(object)),
Expression.Convert(Expression.Constant(null, p.ClrType.MakeNullable()), typeof(object)));
var comparison = Expression.Call(_objectEqualsMethodInfo,
Expression.Convert(CreatePropertyAccessExpression(nonNullEntityReference, p), typeof(object)),
Expression.Convert(Expression.Constant(null, p.ClrType.MakeNullable()), typeof(object)));
return nodeType == ExpressionType.Equal
? (Expression)comparison
Expand Down
Loading

0 comments on commit f203e1f

Please sign in to comment.