Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Translate GetType == type on hierarchy #28011

Merged
1 commit merged into from
May 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ private static readonly MethodInfo StringEqualsWithStringComparison
private static readonly MethodInfo StringEqualsWithStringComparisonStatic
= typeof(string).GetRuntimeMethod(nameof(string.Equals), new[] { typeof(string), typeof(string), typeof(StringComparison) });

private static readonly MethodInfo GetTypeMethodInfo = typeof(object).GetTypeInfo().GetDeclaredMethod(nameof(object.GetType))!;

private readonly QueryCompilationContext _queryCompilationContext;
private readonly IModel _model;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
Expand Down Expand Up @@ -145,6 +147,22 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
ifFalse));
}

if (binaryExpression.NodeType == ExpressionType.Equal || binaryExpression.NodeType == ExpressionType.NotEqual
&& binaryExpression.Left.Type == typeof(Type))
{
if (IsGetTypeMethodCall(binaryExpression.Left, out var entityReference1)
&& IsTypeConstant(binaryExpression.Right, out var type1))
{
return ProcessGetType(entityReference1!, type1!, binaryExpression.NodeType == ExpressionType.Equal);
}

if (IsGetTypeMethodCall(binaryExpression.Right, out var entityReference2)
&& IsTypeConstant(binaryExpression.Left, out var type2))
{
return ProcessGetType(entityReference2!, type2!, binaryExpression.NodeType == ExpressionType.Equal);
}
}

var left = TryRemoveImplicitConvert(binaryExpression.Left);
var right = TryRemoveImplicitConvert(binaryExpression.Right);

Expand Down Expand Up @@ -199,6 +217,77 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
sqlRight,
null);

Expression ProcessGetType(EntityReferenceExpression entityReferenceExpression, Type comparisonType, bool match)
{
var entityType = entityReferenceExpression.EntityType;

if (entityType.BaseType == null
&& !entityType.GetDirectlyDerivedTypes().Any())
{
// No hierarchy
return _sqlExpressionFactory.Constant((entityType.ClrType == comparisonType) == match);
}

if (entityType.GetAllBaseTypes().Any(e => e.ClrType == comparisonType))
{
// EntitySet will never contain a type of base type
return _sqlExpressionFactory.Constant(!match);
}

var derivedType = entityType.GetDerivedTypesInclusive().SingleOrDefault(et => et.ClrType == comparisonType);
// If no derived type matches then fail the translation
if (derivedType != null)
{
// If the derived type is abstract type then predicate will always be false
if (derivedType.IsAbstract())
{
return _sqlExpressionFactory.Constant(!match);
}

// Or add predicate for matching that particular type discriminator value
// All hierarchies have discriminator property
if (TryBindMember(entityReferenceExpression, MemberIdentity.Create(entityType.GetDiscriminatorPropertyName()))
is SqlExpression discriminatorColumn)
{
return match
? _sqlExpressionFactory.Equal(
discriminatorColumn,
_sqlExpressionFactory.Constant(derivedType.GetDiscriminatorValue()))
: _sqlExpressionFactory.NotEqual(
discriminatorColumn,
_sqlExpressionFactory.Constant(derivedType.GetDiscriminatorValue()));
}
}

return QueryCompilationContext.NotTranslatedExpression;
}

bool IsGetTypeMethodCall(Expression expression, out EntityReferenceExpression entityReferenceExpression)
{
entityReferenceExpression = null;
if (expression is not MethodCallExpression methodCallExpression
|| methodCallExpression.Method != GetTypeMethodInfo)
{
return false;
}

entityReferenceExpression = Visit(methodCallExpression.Object) as EntityReferenceExpression;
return entityReferenceExpression != null;
}

static bool IsTypeConstant(Expression expression, out Type type)
{
type = null;
if (expression is not UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression
|| unaryExpression.Operand is not ConstantExpression constantExpression)
{
return false;
}

type = constantExpression.Value as Type;
return type != null;
}

static bool TryUnwrapConvertToObject(Expression expression, out Expression operand)
{
if (expression is UnaryExpression convertExpression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ public class InMemoryExpressionTranslatingExpressionVisitor : ExpressionVisitor
private static readonly MethodInfo InMemoryLikeMethodInfo =
typeof(InMemoryExpressionTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(InMemoryLike))!;

private static readonly MethodInfo GetTypeMethodInfo = typeof(object).GetTypeInfo().GetDeclaredMethod(nameof(object.GetType))!;

// Regex special chars defined here:
// https://msdn.microsoft.com/en-us/library/4edbef7e(v=vs.110).aspx
private static readonly char[] RegexSpecialChars
Expand Down Expand Up @@ -205,6 +207,22 @@ static Expression RemoveConvert(Expression e)
}
}

if (binaryExpression.NodeType == ExpressionType.Equal || binaryExpression.NodeType == ExpressionType.NotEqual
&& binaryExpression.Left.Type == typeof(Type))
{
if (IsGetTypeMethodCall(binaryExpression.Left, out var entityReference1)
&& IsTypeConstant(binaryExpression.Right, out var type1))
{
return ProcessGetType(entityReference1!, type1!, binaryExpression.NodeType == ExpressionType.Equal);
}

if (IsGetTypeMethodCall(binaryExpression.Right, out var entityReference2)
&& IsTypeConstant(binaryExpression.Left, out var type2))
{
return ProcessGetType(entityReference2!, type2!, binaryExpression.NodeType == ExpressionType.Equal);
}
}

var newLeft = Visit(binaryExpression.Left);
var newRight = Visit(binaryExpression.Right);

Expand Down Expand Up @@ -317,6 +335,75 @@ static Expression RemoveConvert(Expression e)
binaryExpression.IsLiftedToNull,
binaryExpression.Method,
binaryExpression.Conversion);

Expression ProcessGetType(EntityReferenceExpression entityReferenceExpression, Type comparisonType, bool match)
{
var entityType = entityReferenceExpression.EntityType;

if (entityType.BaseType == null
&& !entityType.GetDirectlyDerivedTypes().Any())
{
// No hierarchy
return Expression.Constant((entityType.ClrType == comparisonType) == match);
}
if (entityType.GetAllBaseTypes().Any(e => e.ClrType == comparisonType))
{
// EntitySet will never contain a type of base type
return Expression.Constant(!match);
}

var derivedType = entityType.GetDerivedTypesInclusive().SingleOrDefault(et => et.ClrType == comparisonType);
// If no derived type matches then fail the translation
if (derivedType != null)
{
// If the derived type is abstract type then predicate will always be false
if (derivedType.IsAbstract())
{
return Expression.Constant(!match);
}

// Or add predicate for matching that particular type discriminator value
// All hierarchies have discriminator property
var discriminatorProperty = entityType.FindDiscriminatorProperty()!;
var boundProperty = BindProperty(entityReferenceExpression, discriminatorProperty, discriminatorProperty.ClrType);
// KeyValueComparer is not null at runtime
var valueComparer = discriminatorProperty.GetKeyValueComparer();

var result = valueComparer.ExtractEqualsBody(
boundProperty!,
Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType));

return match ? result : Expression.Not(result);
}

return QueryCompilationContext.NotTranslatedExpression;
}

bool IsGetTypeMethodCall(Expression expression, out EntityReferenceExpression? entityReferenceExpression)
{
entityReferenceExpression = null;
if (expression is not MethodCallExpression methodCallExpression
|| methodCallExpression.Method != GetTypeMethodInfo)
{
return false;
}

entityReferenceExpression = Visit(methodCallExpression.Object) as EntityReferenceExpression;
return entityReferenceExpression != null;
}

static bool IsTypeConstant(Expression expression, out Type? type)
{
type = null;
if (expression is not UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression
|| unaryExpression.Operand is not ConstantExpression constantExpression)
{
return false;
}

type = constantExpression.Value as Type;
return type != null;
}
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ private static readonly MethodInfo StringEqualsWithStringComparisonStatic
private static readonly MethodInfo ObjectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) })!;

private static readonly MethodInfo GetTypeMethodInfo = typeof(object).GetTypeInfo().GetDeclaredMethod(nameof(object.GetType))!;

private readonly QueryCompilationContext _queryCompilationContext;
private readonly IModel _model;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
Expand Down Expand Up @@ -290,6 +292,22 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
return Visit(ConvertObjectArrayEqualityComparison(binaryExpression.Left, binaryExpression.Right));
}

if (binaryExpression.NodeType == ExpressionType.Equal || binaryExpression.NodeType == ExpressionType.NotEqual
&& binaryExpression.Left.Type == typeof(Type))
{
if (IsGetTypeMethodCall(binaryExpression.Left, out var entityReference1)
&& IsTypeConstant(binaryExpression.Right, out var type1))
{
return ProcessGetType(entityReference1!, type1!, binaryExpression.NodeType == ExpressionType.Equal);
}

if (IsGetTypeMethodCall(binaryExpression.Right, out var entityReference2)
&& IsTypeConstant(binaryExpression.Left, out var type2))
{
return ProcessGetType(entityReference2!, type2!, binaryExpression.NodeType == ExpressionType.Equal);
}
}

var left = TryRemoveImplicitConvert(binaryExpression.Left);
var right = TryRemoveImplicitConvert(binaryExpression.Right);

Expand Down Expand Up @@ -389,6 +407,148 @@ static Expression RemoveConvert(Expression e)
null)
?? QueryCompilationContext.NotTranslatedExpression;

Expression ProcessGetType(EntityReferenceExpression entityReferenceExpression, Type comparisonType, bool match)
{
var entityType = entityReferenceExpression.EntityType;

if (entityType.BaseType == null
&& !entityType.GetDirectlyDerivedTypes().Any())
{
// No hierarchy
return _sqlExpressionFactory.Constant((entityType.ClrType == comparisonType) == match);
}

if (entityType.GetAllBaseTypes().Any(e => e.ClrType == comparisonType))
{
// EntitySet will never contain a type of base type
return _sqlExpressionFactory.Constant(!match);
}

var derivedType = entityType.GetDerivedTypesInclusive().SingleOrDefault(et => et.ClrType == comparisonType);
// If no derived type matches then fail the translation
if (derivedType != null)
{
// If the derived type is abstract type then predicate will always be false
if (derivedType.IsAbstract())
{
return _sqlExpressionFactory.Constant(!match);
}

// Or add predicate for matching that particular type discriminator value
var discriminatorProperty = entityType.FindDiscriminatorProperty();
if (discriminatorProperty == null)
{
// TPT or TPC
var discriminatorValue = derivedType.ShortName();
if (entityReferenceExpression.SubqueryEntity != null)
{
var entityShaper = (EntityShaperExpression)entityReferenceExpression.SubqueryEntity.ShaperExpression;
var entityProjection = (EntityProjectionExpression)Visit(entityShaper.ValueBufferExpression);
var subSelectExpression = (SelectExpression)entityReferenceExpression.SubqueryEntity.QueryExpression;

var predicate = GeneratePredicateTpt(entityProjection);

subSelectExpression.ApplyPredicate(predicate);
subSelectExpression.ReplaceProjection(new List<Expression>());
subSelectExpression.ApplyProjection();
if (subSelectExpression.Limit == null
&& subSelectExpression.Offset == null)
{
subSelectExpression.ClearOrdering();
}

return _sqlExpressionFactory.Exists(subSelectExpression, false);
}

if (entityReferenceExpression.ParameterEntity != null)
{
var entityProjection = (EntityProjectionExpression)Visit(
entityReferenceExpression.ParameterEntity.ValueBufferExpression);

return GeneratePredicateTpt(entityProjection);
}

SqlExpression GeneratePredicateTpt(EntityProjectionExpression entityProjectionExpression)
{
if (entityProjectionExpression.DiscriminatorExpression is CaseExpression caseExpression)
{
// TPT case
// Most root type doesn't have matching case
// All derived types needs to be excluded
var derivedTypeValues = derivedType.GetDerivedTypes().Where(e => !e.IsAbstract()).Select(e => e.ShortName()).ToList();
var predicates = new List<SqlExpression>();
foreach (var caseWhenClause in caseExpression.WhenClauses)
{
var value = (string)((SqlConstantExpression)caseWhenClause.Result).Value!;
if (value == discriminatorValue)
{
predicates.Add(caseWhenClause.Test);
}
else if (derivedTypeValues.Contains(value))
{
predicates.Add(_sqlExpressionFactory.Not(caseWhenClause.Test));
}
}

var result = predicates.Aggregate((a, b) => _sqlExpressionFactory.AndAlso(a, b));

return match ? result : _sqlExpressionFactory.Not(result);
}

return match
? _sqlExpressionFactory.Equal(
entityProjectionExpression.DiscriminatorExpression!,
_sqlExpressionFactory.Constant(discriminatorValue))
: _sqlExpressionFactory.NotEqual(
entityProjectionExpression.DiscriminatorExpression!,
_sqlExpressionFactory.Constant(discriminatorValue));
}
}
else
{
var discriminatorColumn = BindProperty(entityReferenceExpression, discriminatorProperty);
if (discriminatorColumn != null)
{
return match
? _sqlExpressionFactory.Equal(
discriminatorColumn,
_sqlExpressionFactory.Constant(derivedType.GetDiscriminatorValue()))
: _sqlExpressionFactory.NotEqual(
discriminatorColumn,
_sqlExpressionFactory.Constant(derivedType.GetDiscriminatorValue()));
}
}
}

return QueryCompilationContext.NotTranslatedExpression;
}

bool IsGetTypeMethodCall(Expression expression, out EntityReferenceExpression? entityReferenceExpression)
{
entityReferenceExpression = null;
if (expression is not MethodCallExpression methodCallExpression
|| methodCallExpression.Method != GetTypeMethodInfo)
{
return false;
}

entityReferenceExpression = Visit(methodCallExpression.Object) as EntityReferenceExpression;
return entityReferenceExpression != null;
}

static bool IsTypeConstant(Expression expression, out Type? type)
{
type = null;
if (expression is not UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression
|| unaryExpression.Operand is not ConstantExpression constantExpression)
{
return false;
}

type = constantExpression.Value as Type;
return type != null;
}

static bool TryUnwrapConvertToObject(Expression expression, out Expression? operand)
{
if (expression is UnaryExpression convertExpression
Expand Down
Loading