Skip to content

Commit

Permalink
Query: Translate GetType == type on hierarchy
Browse files Browse the repository at this point in the history
Resolves #13424
  • Loading branch information
smitpatel committed May 12, 2022
1 parent cf92261 commit c7e1af5
Show file tree
Hide file tree
Showing 14 changed files with 960 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ private static readonly MethodInfo StringEqualsWithStringComparison
private static readonly MethodInfo StringEqualsWithStringComparisonStatic
= typeof(string).GetRuntimeMethod(nameof(string.Equals), new[] { typeof(string), typeof(string), typeof(StringComparison) });

private static readonly MethodInfo GetTypeMethodInfo = typeof(object).GetTypeInfo().GetDeclaredMethod(nameof(object.GetType))!;

private readonly QueryCompilationContext _queryCompilationContext;
private readonly IModel _model;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
Expand Down Expand Up @@ -145,6 +147,86 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
ifFalse));
}

if (binaryExpression.NodeType == ExpressionType.Equal || binaryExpression.NodeType == ExpressionType.NotEqual
&& binaryExpression.Left.Type == typeof(Type))
{
if (IsGetTypeMethodCall(binaryExpression.Left, out var entityReference1)
&& IsTypeConstant(binaryExpression.Right, out var type1))
{
return ProcessGetType(entityReference1!, type1!, binaryExpression.NodeType == ExpressionType.Equal);
}

if (IsGetTypeMethodCall(binaryExpression.Right, out var entityReference2)
&& IsTypeConstant(binaryExpression.Left, out var type2))
{
return ProcessGetType(entityReference2!, type2!, binaryExpression.NodeType == ExpressionType.Equal);
}

Expression ProcessGetType(EntityReferenceExpression entityReferenceExpression, Type comparisonType, bool match)
{
var entityType = entityReferenceExpression.EntityType;
if (entityType.GetAllBaseTypes().Any(e => e.ClrType == comparisonType))
{
// EntitySet will never contain a type of base type
return _sqlExpressionFactory.Constant(!match);
}

var derivedType = entityType.GetDerivedTypesInclusive().SingleOrDefault(et => et.ClrType == comparisonType);
// If no derived type matches then fail the translation
if (derivedType != null)
{
// If the derived type is abstract type then predicate will always be false
if (derivedType.IsAbstract())
{
return _sqlExpressionFactory.Constant(!match);
}

// Or add predicate for matching that particular type discriminator value
// All hierarchies have discriminator property
var discriminatorProperty = entityType.FindDiscriminatorProperty()!;
if (TryBindMember(entityReferenceExpression, MemberIdentity.Create(entityType.GetDiscriminatorPropertyName()))
is SqlExpression discriminatorColumn)
{
return match
? _sqlExpressionFactory.Equal(
discriminatorColumn,
_sqlExpressionFactory.Constant(derivedType.GetDiscriminatorValue()))
: _sqlExpressionFactory.NotEqual(
discriminatorColumn,
_sqlExpressionFactory.Constant(derivedType.GetDiscriminatorValue()));
}
}

return QueryCompilationContext.NotTranslatedExpression;
}

bool IsGetTypeMethodCall(Expression expression, out EntityReferenceExpression entityReferenceExpression)
{
entityReferenceExpression = null;
if (expression is not MethodCallExpression methodCallExpression
|| methodCallExpression.Method != GetTypeMethodInfo)
{
return false;
}

entityReferenceExpression = Visit(methodCallExpression.Object) as EntityReferenceExpression;
return entityReferenceExpression != null;
}

static bool IsTypeConstant(Expression expression, out Type type)
{
type = null;
if (expression is not UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression
|| unaryExpression.Operand is not ConstantExpression constantExpression)
{
return false;
}

type = constantExpression.Value as Type;
return type != null;
}
}

var left = TryRemoveImplicitConvert(binaryExpression.Left);
var right = TryRemoveImplicitConvert(binaryExpression.Right);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ public class InMemoryExpressionTranslatingExpressionVisitor : ExpressionVisitor
private static readonly MethodInfo InMemoryLikeMethodInfo =
typeof(InMemoryExpressionTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(InMemoryLike))!;

private static readonly MethodInfo GetTypeMethodInfo = typeof(object).GetTypeInfo().GetDeclaredMethod(nameof(object.GetType))!;

// Regex special chars defined here:
// https://msdn.microsoft.com/en-us/library/4edbef7e(v=vs.110).aspx
private static readonly char[] RegexSpecialChars
Expand Down Expand Up @@ -205,6 +207,84 @@ static Expression RemoveConvert(Expression e)
}
}

if (binaryExpression.NodeType == ExpressionType.Equal || binaryExpression.NodeType == ExpressionType.NotEqual
&& binaryExpression.Left.Type == typeof(Type))
{
if (IsGetTypeMethodCall(binaryExpression.Left, out var entityReference1)
&& IsTypeConstant(binaryExpression.Right, out var type1))
{
return ProcessGetType(entityReference1!, type1!, binaryExpression.NodeType == ExpressionType.Equal);
}

if (IsGetTypeMethodCall(binaryExpression.Right, out var entityReference2)
&& IsTypeConstant(binaryExpression.Left, out var type2))
{
return ProcessGetType(entityReference2!, type2!, binaryExpression.NodeType == ExpressionType.Equal);
}

Expression ProcessGetType(EntityReferenceExpression entityReferenceExpression, Type comparisonType, bool match)
{
var entityType = entityReferenceExpression.EntityType;
if (entityType.GetAllBaseTypes().Any(e => e.ClrType == comparisonType))
{
// EntitySet will never contain a type of base type
return Expression.Constant(!match);
}

var derivedType = entityType.GetDerivedTypesInclusive().SingleOrDefault(et => et.ClrType == comparisonType);
// If no derived type matches then fail the translation
if (derivedType != null)
{
// If the derived type is abstract type then predicate will always be false
if (derivedType.IsAbstract())
{
return Expression.Constant(!match);
}

// Or add predicate for matching that particular type discriminator value
// All hierarchies have discriminator property
var discriminatorProperty = entityType.FindDiscriminatorProperty()!;
var boundProperty = BindProperty(entityReferenceExpression, discriminatorProperty, discriminatorProperty.ClrType);
// KeyValueComparer is not null at runtime
var valueComparer = discriminatorProperty.GetKeyValueComparer();

var result = valueComparer.ExtractEqualsBody(
boundProperty!,
Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType));

return match ? result : Expression.Not(result);
}

return QueryCompilationContext.NotTranslatedExpression;
}

bool IsGetTypeMethodCall(Expression expression, out EntityReferenceExpression? entityReferenceExpression)
{
entityReferenceExpression = null;
if (expression is not MethodCallExpression methodCallExpression
|| methodCallExpression.Method != GetTypeMethodInfo)
{
return false;
}

entityReferenceExpression = Visit(methodCallExpression.Object) as EntityReferenceExpression;
return entityReferenceExpression != null;
}

static bool IsTypeConstant(Expression expression, out Type? type)
{
type = null;
if (expression is not UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression
|| unaryExpression.Operand is not ConstantExpression constantExpression)
{
return false;
}

type = constantExpression.Value as Type;
return type != null;
}
}

var newLeft = Visit(binaryExpression.Left);
var newRight = Visit(binaryExpression.Right);

Expand Down
172 changes: 162 additions & 10 deletions src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ private static readonly MethodInfo StringEqualsWithStringComparisonStatic
private static readonly MethodInfo ObjectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) })!;

private static readonly MethodInfo GetTypeMethodInfo = typeof(object).GetTypeInfo().GetDeclaredMethod(nameof(object.GetType))!;

private readonly QueryCompilationContext _queryCompilationContext;
private readonly IModel _model;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
Expand Down Expand Up @@ -290,6 +292,156 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
return Visit(ConvertObjectArrayEqualityComparison(binaryExpression.Left, binaryExpression.Right));
}

if (binaryExpression.NodeType == ExpressionType.Equal || binaryExpression.NodeType == ExpressionType.NotEqual
&& binaryExpression.Left.Type == typeof(Type))
{
if (IsGetTypeMethodCall(binaryExpression.Left, out var entityReference1)
&& IsTypeConstant(binaryExpression.Right, out var type1))
{
return ProcessGetType(entityReference1!, type1!, binaryExpression.NodeType == ExpressionType.Equal);
}

if (IsGetTypeMethodCall(binaryExpression.Right, out var entityReference2)
&& IsTypeConstant(binaryExpression.Left, out var type2))
{
return ProcessGetType(entityReference2!, type2!, binaryExpression.NodeType == ExpressionType.Equal);
}

Expression ProcessGetType(EntityReferenceExpression entityReferenceExpression, Type comparisonType, bool match)
{
var entityType = entityReferenceExpression.EntityType;
if (entityType.GetAllBaseTypes().Any(e => e.ClrType == comparisonType))
{
// EntitySet will never contain a type of base type
return _sqlExpressionFactory.Constant(!match);
}

var derivedType = entityType.GetDerivedTypesInclusive().SingleOrDefault(et => et.ClrType == comparisonType);
// If no derived type matches then fail the translation
if (derivedType != null)
{
// If the derived type is abstract type then predicate will always be false
if (derivedType.IsAbstract())
{
return _sqlExpressionFactory.Constant(!match);
}

// Or add predicate for matching that particular type discriminator value
var discriminatorProperty = entityType.FindDiscriminatorProperty();
if (discriminatorProperty == null)
{
// TPT or TPC
var discriminatorValue = derivedType.ShortName();
if (entityReferenceExpression.SubqueryEntity != null)
{
var entityShaper = (EntityShaperExpression)entityReferenceExpression.SubqueryEntity.ShaperExpression;
var entityProjection = (EntityProjectionExpression)Visit(entityShaper.ValueBufferExpression);
var subSelectExpression = (SelectExpression)entityReferenceExpression.SubqueryEntity.QueryExpression;

var predicate = GeneratePredicateTpt(entityProjection);

subSelectExpression.ApplyPredicate(predicate);
subSelectExpression.ReplaceProjection(new List<Expression>());
subSelectExpression.ApplyProjection();
if (subSelectExpression.Limit == null
&& subSelectExpression.Offset == null)
{
subSelectExpression.ClearOrdering();
}

return _sqlExpressionFactory.Exists(subSelectExpression, false);
}

if (entityReferenceExpression.ParameterEntity != null)
{
var entityProjection = (EntityProjectionExpression)Visit(
entityReferenceExpression.ParameterEntity.ValueBufferExpression);

return GeneratePredicateTpt(entityProjection);
}

SqlExpression GeneratePredicateTpt(EntityProjectionExpression entityProjectionExpression)
{
if (entityProjectionExpression.DiscriminatorExpression is CaseExpression caseExpression)
{
// TPT case
// Most root type doesn't have matching case
// All derived types needs to be excluded
var derivedTypeValues = derivedType.GetDerivedTypes().Where(e => !e.IsAbstract()).Select(e => e.ShortName()).ToList();
var predicates = new List<SqlExpression>();
foreach (var caseWhenClause in caseExpression.WhenClauses)
{
var value = (string)((SqlConstantExpression)caseWhenClause.Result).Value!;
if (value == discriminatorValue)
{
predicates.Add(caseWhenClause.Test);
}
else if (derivedTypeValues.Contains(value))
{
predicates.Add(_sqlExpressionFactory.Not(caseWhenClause.Test));
}
}

var result = predicates.Aggregate((a, b) => _sqlExpressionFactory.AndAlso(a, b));

return match ? result : _sqlExpressionFactory.Not(result);
}

return match
? _sqlExpressionFactory.Equal(
entityProjectionExpression.DiscriminatorExpression!,
_sqlExpressionFactory.Constant(discriminatorValue))
: _sqlExpressionFactory.NotEqual(
entityProjectionExpression.DiscriminatorExpression!,
_sqlExpressionFactory.Constant(discriminatorValue));
}
}
else
{
var discriminatorColumn = BindProperty(entityReferenceExpression, discriminatorProperty);
if (discriminatorColumn != null)
{
return match
? _sqlExpressionFactory.Equal(
discriminatorColumn,
_sqlExpressionFactory.Constant(derivedType.GetDiscriminatorValue()))
: _sqlExpressionFactory.NotEqual(
discriminatorColumn,
_sqlExpressionFactory.Constant(derivedType.GetDiscriminatorValue()));
}
}
}

return QueryCompilationContext.NotTranslatedExpression;
}

bool IsGetTypeMethodCall(Expression expression, out EntityReferenceExpression? entityReferenceExpression)
{
entityReferenceExpression = null;
if (expression is not MethodCallExpression methodCallExpression
|| methodCallExpression.Method != GetTypeMethodInfo)
{
return false;
}

entityReferenceExpression = Visit(methodCallExpression.Object) as EntityReferenceExpression;
return entityReferenceExpression != null;
}

static bool IsTypeConstant(Expression expression, out Type? type)
{
type = null;
if (expression is not UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression
|| unaryExpression.Operand is not ConstantExpression constantExpression)
{
return false;
}

type = constantExpression.Value as Type;
return type != null;
}
}

var left = TryRemoveImplicitConvert(binaryExpression.Left);
var right = TryRemoveImplicitConvert(binaryExpression.Right);

Expand Down Expand Up @@ -1192,7 +1344,7 @@ private bool ProcessOrderByThenBy(
{
enumerableExpression.ApplyOrdering(orderingExpression);
}

return true;
}

Expand Down Expand Up @@ -1702,19 +1854,19 @@ public Expression Convert(Type type)
}

private sealed class SqlTypeMappingVerifyingExpressionVisitor : ExpressionVisitor
{
protected override Expression VisitExtension(Expression extensionExpression)
{
if (extensionExpression is SqlExpression sqlExpression
&& extensionExpression is not SqlFragmentExpression)
protected override Expression VisitExtension(Expression extensionExpression)
{
if (sqlExpression.TypeMapping == null)
if (extensionExpression is SqlExpression sqlExpression
&& extensionExpression is not SqlFragmentExpression)
{
throw new InvalidOperationException(RelationalStrings.NullTypeMappingInSqlTree(sqlExpression.Print()));
if (sqlExpression.TypeMapping == null)
{
throw new InvalidOperationException(RelationalStrings.NullTypeMappingInSqlTree(sqlExpression.Print()));
}
}
}

return base.VisitExtension(extensionExpression);
return base.VisitExtension(extensionExpression);
}
}
}
}
Loading

0 comments on commit c7e1af5

Please sign in to comment.