Skip to content

Commit

Permalink
Query: Translate GetType == type on hierarchy (#28011)
Browse files Browse the repository at this point in the history
Resolves #13424
  • Loading branch information
smitpatel authored May 18, 2022
1 parent 6de40fe commit d69ba3c
Show file tree
Hide file tree
Showing 17 changed files with 1,085 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ private static readonly MethodInfo StringEqualsWithStringComparison
private static readonly MethodInfo StringEqualsWithStringComparisonStatic
= typeof(string).GetRuntimeMethod(nameof(string.Equals), new[] { typeof(string), typeof(string), typeof(StringComparison) });

private static readonly MethodInfo GetTypeMethodInfo = typeof(object).GetTypeInfo().GetDeclaredMethod(nameof(object.GetType))!;

private readonly QueryCompilationContext _queryCompilationContext;
private readonly IModel _model;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
Expand Down Expand Up @@ -145,6 +147,22 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
ifFalse));
}

if (binaryExpression.NodeType == ExpressionType.Equal || binaryExpression.NodeType == ExpressionType.NotEqual
&& binaryExpression.Left.Type == typeof(Type))
{
if (IsGetTypeMethodCall(binaryExpression.Left, out var entityReference1)
&& IsTypeConstant(binaryExpression.Right, out var type1))
{
return ProcessGetType(entityReference1!, type1!, binaryExpression.NodeType == ExpressionType.Equal);
}

if (IsGetTypeMethodCall(binaryExpression.Right, out var entityReference2)
&& IsTypeConstant(binaryExpression.Left, out var type2))
{
return ProcessGetType(entityReference2!, type2!, binaryExpression.NodeType == ExpressionType.Equal);
}
}

var left = TryRemoveImplicitConvert(binaryExpression.Left);
var right = TryRemoveImplicitConvert(binaryExpression.Right);

Expand Down Expand Up @@ -199,6 +217,77 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
sqlRight,
null);

Expression ProcessGetType(EntityReferenceExpression entityReferenceExpression, Type comparisonType, bool match)
{
var entityType = entityReferenceExpression.EntityType;

if (entityType.BaseType == null
&& !entityType.GetDirectlyDerivedTypes().Any())
{
// No hierarchy
return _sqlExpressionFactory.Constant((entityType.ClrType == comparisonType) == match);
}

if (entityType.GetAllBaseTypes().Any(e => e.ClrType == comparisonType))
{
// EntitySet will never contain a type of base type
return _sqlExpressionFactory.Constant(!match);
}

var derivedType = entityType.GetDerivedTypesInclusive().SingleOrDefault(et => et.ClrType == comparisonType);
// If no derived type matches then fail the translation
if (derivedType != null)
{
// If the derived type is abstract type then predicate will always be false
if (derivedType.IsAbstract())
{
return _sqlExpressionFactory.Constant(!match);
}

// Or add predicate for matching that particular type discriminator value
// All hierarchies have discriminator property
if (TryBindMember(entityReferenceExpression, MemberIdentity.Create(entityType.GetDiscriminatorPropertyName()))
is SqlExpression discriminatorColumn)
{
return match
? _sqlExpressionFactory.Equal(
discriminatorColumn,
_sqlExpressionFactory.Constant(derivedType.GetDiscriminatorValue()))
: _sqlExpressionFactory.NotEqual(
discriminatorColumn,
_sqlExpressionFactory.Constant(derivedType.GetDiscriminatorValue()));
}
}

return QueryCompilationContext.NotTranslatedExpression;
}

bool IsGetTypeMethodCall(Expression expression, out EntityReferenceExpression entityReferenceExpression)
{
entityReferenceExpression = null;
if (expression is not MethodCallExpression methodCallExpression
|| methodCallExpression.Method != GetTypeMethodInfo)
{
return false;
}

entityReferenceExpression = Visit(methodCallExpression.Object) as EntityReferenceExpression;
return entityReferenceExpression != null;
}

static bool IsTypeConstant(Expression expression, out Type type)
{
type = null;
if (expression is not UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression
|| unaryExpression.Operand is not ConstantExpression constantExpression)
{
return false;
}

type = constantExpression.Value as Type;
return type != null;
}

static bool TryUnwrapConvertToObject(Expression expression, out Expression operand)
{
if (expression is UnaryExpression convertExpression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ public class InMemoryExpressionTranslatingExpressionVisitor : ExpressionVisitor
private static readonly MethodInfo InMemoryLikeMethodInfo =
typeof(InMemoryExpressionTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(InMemoryLike))!;

private static readonly MethodInfo GetTypeMethodInfo = typeof(object).GetTypeInfo().GetDeclaredMethod(nameof(object.GetType))!;

// Regex special chars defined here:
// https://msdn.microsoft.com/en-us/library/4edbef7e(v=vs.110).aspx
private static readonly char[] RegexSpecialChars
Expand Down Expand Up @@ -205,6 +207,22 @@ static Expression RemoveConvert(Expression e)
}
}

if (binaryExpression.NodeType == ExpressionType.Equal || binaryExpression.NodeType == ExpressionType.NotEqual
&& binaryExpression.Left.Type == typeof(Type))
{
if (IsGetTypeMethodCall(binaryExpression.Left, out var entityReference1)
&& IsTypeConstant(binaryExpression.Right, out var type1))
{
return ProcessGetType(entityReference1!, type1!, binaryExpression.NodeType == ExpressionType.Equal);
}

if (IsGetTypeMethodCall(binaryExpression.Right, out var entityReference2)
&& IsTypeConstant(binaryExpression.Left, out var type2))
{
return ProcessGetType(entityReference2!, type2!, binaryExpression.NodeType == ExpressionType.Equal);
}
}

var newLeft = Visit(binaryExpression.Left);
var newRight = Visit(binaryExpression.Right);

Expand Down Expand Up @@ -317,6 +335,75 @@ static Expression RemoveConvert(Expression e)
binaryExpression.IsLiftedToNull,
binaryExpression.Method,
binaryExpression.Conversion);

Expression ProcessGetType(EntityReferenceExpression entityReferenceExpression, Type comparisonType, bool match)
{
var entityType = entityReferenceExpression.EntityType;

if (entityType.BaseType == null
&& !entityType.GetDirectlyDerivedTypes().Any())
{
// No hierarchy
return Expression.Constant((entityType.ClrType == comparisonType) == match);
}
if (entityType.GetAllBaseTypes().Any(e => e.ClrType == comparisonType))
{
// EntitySet will never contain a type of base type
return Expression.Constant(!match);
}

var derivedType = entityType.GetDerivedTypesInclusive().SingleOrDefault(et => et.ClrType == comparisonType);
// If no derived type matches then fail the translation
if (derivedType != null)
{
// If the derived type is abstract type then predicate will always be false
if (derivedType.IsAbstract())
{
return Expression.Constant(!match);
}

// Or add predicate for matching that particular type discriminator value
// All hierarchies have discriminator property
var discriminatorProperty = entityType.FindDiscriminatorProperty()!;
var boundProperty = BindProperty(entityReferenceExpression, discriminatorProperty, discriminatorProperty.ClrType);
// KeyValueComparer is not null at runtime
var valueComparer = discriminatorProperty.GetKeyValueComparer();

var result = valueComparer.ExtractEqualsBody(
boundProperty!,
Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType));

return match ? result : Expression.Not(result);
}

return QueryCompilationContext.NotTranslatedExpression;
}

bool IsGetTypeMethodCall(Expression expression, out EntityReferenceExpression? entityReferenceExpression)
{
entityReferenceExpression = null;
if (expression is not MethodCallExpression methodCallExpression
|| methodCallExpression.Method != GetTypeMethodInfo)
{
return false;
}

entityReferenceExpression = Visit(methodCallExpression.Object) as EntityReferenceExpression;
return entityReferenceExpression != null;
}

static bool IsTypeConstant(Expression expression, out Type? type)
{
type = null;
if (expression is not UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression
|| unaryExpression.Operand is not ConstantExpression constantExpression)
{
return false;
}

type = constantExpression.Value as Type;
return type != null;
}
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ private static readonly MethodInfo StringEqualsWithStringComparisonStatic
private static readonly MethodInfo ObjectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) })!;

private static readonly MethodInfo GetTypeMethodInfo = typeof(object).GetTypeInfo().GetDeclaredMethod(nameof(object.GetType))!;

private readonly QueryCompilationContext _queryCompilationContext;
private readonly IModel _model;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
Expand Down Expand Up @@ -290,6 +292,22 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
return Visit(ConvertObjectArrayEqualityComparison(binaryExpression.Left, binaryExpression.Right));
}

if (binaryExpression.NodeType == ExpressionType.Equal || binaryExpression.NodeType == ExpressionType.NotEqual
&& binaryExpression.Left.Type == typeof(Type))
{
if (IsGetTypeMethodCall(binaryExpression.Left, out var entityReference1)
&& IsTypeConstant(binaryExpression.Right, out var type1))
{
return ProcessGetType(entityReference1!, type1!, binaryExpression.NodeType == ExpressionType.Equal);
}

if (IsGetTypeMethodCall(binaryExpression.Right, out var entityReference2)
&& IsTypeConstant(binaryExpression.Left, out var type2))
{
return ProcessGetType(entityReference2!, type2!, binaryExpression.NodeType == ExpressionType.Equal);
}
}

var left = TryRemoveImplicitConvert(binaryExpression.Left);
var right = TryRemoveImplicitConvert(binaryExpression.Right);

Expand Down Expand Up @@ -389,6 +407,148 @@ static Expression RemoveConvert(Expression e)
null)
?? QueryCompilationContext.NotTranslatedExpression;

Expression ProcessGetType(EntityReferenceExpression entityReferenceExpression, Type comparisonType, bool match)
{
var entityType = entityReferenceExpression.EntityType;

if (entityType.BaseType == null
&& !entityType.GetDirectlyDerivedTypes().Any())
{
// No hierarchy
return _sqlExpressionFactory.Constant((entityType.ClrType == comparisonType) == match);
}

if (entityType.GetAllBaseTypes().Any(e => e.ClrType == comparisonType))
{
// EntitySet will never contain a type of base type
return _sqlExpressionFactory.Constant(!match);
}

var derivedType = entityType.GetDerivedTypesInclusive().SingleOrDefault(et => et.ClrType == comparisonType);
// If no derived type matches then fail the translation
if (derivedType != null)
{
// If the derived type is abstract type then predicate will always be false
if (derivedType.IsAbstract())
{
return _sqlExpressionFactory.Constant(!match);
}

// Or add predicate for matching that particular type discriminator value
var discriminatorProperty = entityType.FindDiscriminatorProperty();
if (discriminatorProperty == null)
{
// TPT or TPC
var discriminatorValue = derivedType.ShortName();
if (entityReferenceExpression.SubqueryEntity != null)
{
var entityShaper = (EntityShaperExpression)entityReferenceExpression.SubqueryEntity.ShaperExpression;
var entityProjection = (EntityProjectionExpression)Visit(entityShaper.ValueBufferExpression);
var subSelectExpression = (SelectExpression)entityReferenceExpression.SubqueryEntity.QueryExpression;

var predicate = GeneratePredicateTpt(entityProjection);

subSelectExpression.ApplyPredicate(predicate);
subSelectExpression.ReplaceProjection(new List<Expression>());
subSelectExpression.ApplyProjection();
if (subSelectExpression.Limit == null
&& subSelectExpression.Offset == null)
{
subSelectExpression.ClearOrdering();
}

return _sqlExpressionFactory.Exists(subSelectExpression, false);
}

if (entityReferenceExpression.ParameterEntity != null)
{
var entityProjection = (EntityProjectionExpression)Visit(
entityReferenceExpression.ParameterEntity.ValueBufferExpression);

return GeneratePredicateTpt(entityProjection);
}

SqlExpression GeneratePredicateTpt(EntityProjectionExpression entityProjectionExpression)
{
if (entityProjectionExpression.DiscriminatorExpression is CaseExpression caseExpression)
{
// TPT case
// Most root type doesn't have matching case
// All derived types needs to be excluded
var derivedTypeValues = derivedType.GetDerivedTypes().Where(e => !e.IsAbstract()).Select(e => e.ShortName()).ToList();
var predicates = new List<SqlExpression>();
foreach (var caseWhenClause in caseExpression.WhenClauses)
{
var value = (string)((SqlConstantExpression)caseWhenClause.Result).Value!;
if (value == discriminatorValue)
{
predicates.Add(caseWhenClause.Test);
}
else if (derivedTypeValues.Contains(value))
{
predicates.Add(_sqlExpressionFactory.Not(caseWhenClause.Test));
}
}

var result = predicates.Aggregate((a, b) => _sqlExpressionFactory.AndAlso(a, b));

return match ? result : _sqlExpressionFactory.Not(result);
}

return match
? _sqlExpressionFactory.Equal(
entityProjectionExpression.DiscriminatorExpression!,
_sqlExpressionFactory.Constant(discriminatorValue))
: _sqlExpressionFactory.NotEqual(
entityProjectionExpression.DiscriminatorExpression!,
_sqlExpressionFactory.Constant(discriminatorValue));
}
}
else
{
var discriminatorColumn = BindProperty(entityReferenceExpression, discriminatorProperty);
if (discriminatorColumn != null)
{
return match
? _sqlExpressionFactory.Equal(
discriminatorColumn,
_sqlExpressionFactory.Constant(derivedType.GetDiscriminatorValue()))
: _sqlExpressionFactory.NotEqual(
discriminatorColumn,
_sqlExpressionFactory.Constant(derivedType.GetDiscriminatorValue()));
}
}
}

return QueryCompilationContext.NotTranslatedExpression;
}

bool IsGetTypeMethodCall(Expression expression, out EntityReferenceExpression? entityReferenceExpression)
{
entityReferenceExpression = null;
if (expression is not MethodCallExpression methodCallExpression
|| methodCallExpression.Method != GetTypeMethodInfo)
{
return false;
}

entityReferenceExpression = Visit(methodCallExpression.Object) as EntityReferenceExpression;
return entityReferenceExpression != null;
}

static bool IsTypeConstant(Expression expression, out Type? type)
{
type = null;
if (expression is not UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression
|| unaryExpression.Operand is not ConstantExpression constantExpression)
{
return false;
}

type = constantExpression.Value as Type;
return type != null;
}

static bool TryUnwrapConvertToObject(Expression expression, out Expression? operand)
{
if (expression is UnaryExpression convertExpression
Expand Down
Loading

0 comments on commit d69ba3c

Please sign in to comment.