Skip to content

Commit

Permalink
Fix to #16078 - Query/Null semantics: when checking if expression is …
Browse files Browse the repository at this point in the history
…null, just check it's constituents rather than entire expression

Problem was that during null semantics rewrite we create IS NULL calls on the operands of the comparison. If the operands themselves are complicated, we were simply comparing the entire complex expression to null. In some cases, we only need to look at constituents, e.g. a + b == null <=> a == null || b == null.

Also added other minor optimizations around null semantics:

- non_nullable_column IS NULL resolves to false,
- try to simplify expression after applying de Morgan transformations

Also fixed a bug exposed by these changes, where column nullability would be incorrect for scenarios with owned types.
  • Loading branch information
maumar committed Oct 24, 2019
1 parent b52d5fc commit baf9d53
Show file tree
Hide file tree
Showing 12 changed files with 374 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,13 @@ private SqlBinaryExpression VisitSqlBinaryExpression(SqlBinaryExpression sqlBina
newRight = rightUnary.Operand;
}

// TODO: optimize this by looking at subcomponents, e.g. f(a, b) == null <=> a == null || b == null
var leftIsNull = _sqlExpressionFactory.IsNull(newLeft);
var rightIsNull = _sqlExpressionFactory.IsNull(newRight);
var isNullOptimizer = new IsNullOptimizingExpressionVisitor(_sqlExpressionFactory);

var leftIsNull = (SqlExpression)isNullOptimizer.Visit(_sqlExpressionFactory.IsNull(newLeft));
var rightIsNull = (SqlExpression)isNullOptimizer.Visit(_sqlExpressionFactory.IsNull(newRight));

// doing a full null semantics rewrite - removing all nulls from truth table
// this will NOT be correct once we introduce simplified null semantics
_isNullable = false;

if (sqlBinaryExpression.OperatorType == ExpressionType.Equal)
Expand Down Expand Up @@ -335,6 +337,50 @@ private SqlBinaryExpression VisitSqlBinaryExpression(SqlBinaryExpression sqlBina
return sqlBinaryExpression.Update(newLeft, newRight);
}

private class IsNullOptimizingExpressionVisitor : ExpressionVisitor
{
private readonly ISqlExpressionFactory _sqlExpressionFactory;

public IsNullOptimizingExpressionVisitor(ISqlExpressionFactory sqlExpressionFactory)
{
_sqlExpressionFactory = sqlExpressionFactory;
}

protected override Expression VisitExtension(Expression extensionExpression)
{
if (extensionExpression is SqlUnaryExpression sqlUnaryExpression
&& sqlUnaryExpression.OperatorType == ExpressionType.Equal)
{
if (sqlUnaryExpression.Operand is SqlUnaryExpression sqlUnaryOperand
&& (sqlUnaryOperand.OperatorType == ExpressionType.Convert
|| sqlUnaryOperand.OperatorType == ExpressionType.Not
|| sqlUnaryOperand.OperatorType == ExpressionType.Negate))
{
return (SqlExpression)Visit(_sqlExpressionFactory.IsNull(sqlUnaryOperand.Operand));
}

if (sqlUnaryExpression.Operand is SqlBinaryExpression sqlBinaryOperand)
{
var newLeft = (SqlExpression)Visit(_sqlExpressionFactory.IsNull(sqlBinaryOperand.Left));
var newRight = (SqlExpression)Visit(_sqlExpressionFactory.IsNull(sqlBinaryOperand.Right));

// (a ?? b) == null <=> a == null && b == null
return sqlBinaryOperand.OperatorType == ExpressionType.Coalesce
? _sqlExpressionFactory.AndAlso(newLeft, newRight)
: _sqlExpressionFactory.OrElse(newLeft, newRight);
}

if (sqlUnaryExpression.Operand is ColumnExpression columnOperand
&& !columnOperand.IsNullable)
{
return _sqlExpressionFactory.Constant(false, sqlUnaryExpression.TypeMapping);
}
}

return base.VisitExtension(extensionExpression);
}
}

private List<ColumnExpression> FindNonNullableColumns(SqlExpression sqlExpression)
{
var result = new List<ColumnExpression>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.Storage;

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
Expand Down Expand Up @@ -60,6 +61,15 @@ protected virtual Expression VisitSqlUnaryExpression(SqlUnaryExpression sqlUnary
return SqlExpressionFactory.Constant(innerConstantNull1.Value == null, sqlUnaryExpression.TypeMapping);
}

// non_nullable_column IS NULL -> false
// non_nullable_column IS NOT NULL -> true
if ((sqlUnaryExpression.OperatorType == ExpressionType.Equal || sqlUnaryExpression.OperatorType == ExpressionType.NotEqual)
&& sqlUnaryExpression.Operand is ColumnExpression innerColumn
&& !innerColumn.IsNullable)
{
return SqlExpressionFactory.Constant(sqlUnaryExpression.OperatorType == ExpressionType.NotEqual, sqlUnaryExpression.TypeMapping);
}

// NULL IS NOT NULL -> false
// non_nullable_constant IS NOT NULL -> true
if (sqlUnaryExpression.OperatorType == ExpressionType.NotEqual
Expand Down Expand Up @@ -135,9 +145,13 @@ private Expression VisitNot(SqlUnaryExpression sqlUnaryExpression)
var newLeft = (SqlExpression)Visit(SqlExpressionFactory.Not(innerBinary.Left));
var newRight = (SqlExpression)Visit(SqlExpressionFactory.Not(innerBinary.Right));

return innerBinary.OperatorType == ExpressionType.AndAlso
? SqlExpressionFactory.OrElse(newLeft, newRight)
: SqlExpressionFactory.AndAlso(newLeft, newRight);
return CreateSqlBinaryEqualityExpression(
innerBinary.OperatorType == ExpressionType.AndAlso
? ExpressionType.OrElse
: ExpressionType.AndAlso,
newLeft,
newRight,
innerBinary.TypeMapping);
}

// those optimizations are only valid in 2-value logic
Expand Down Expand Up @@ -168,36 +182,11 @@ private Expression VisitSqlBinaryExpression(SqlBinaryExpression sqlBinaryExpress
if (sqlBinaryExpression.OperatorType == ExpressionType.AndAlso
|| sqlBinaryExpression.OperatorType == ExpressionType.OrElse)
{
// true && a -> a
// true || a -> true
// false && a -> false
// false || a -> a
if (newLeft is SqlConstantExpression newLeftConstant)
{
return sqlBinaryExpression.OperatorType == ExpressionType.AndAlso
? (bool)newLeftConstant.Value
? newRight
: newLeftConstant
: (bool)newLeftConstant.Value
? newLeftConstant
: newRight;
}
else if (newRight is SqlConstantExpression newRightConstant)
{
// a && true -> a
// a || true -> true
// a && false -> false
// a || false -> a
return sqlBinaryExpression.OperatorType == ExpressionType.AndAlso
? (bool)newRightConstant.Value
? newLeft
: newRightConstant
: (bool)newRightConstant.Value
? newRightConstant
: newLeft;
}

return sqlBinaryExpression.Update(newLeft, newRight);
return CreateSqlBinaryEqualityExpression(
sqlBinaryExpression.OperatorType,
newLeft,
newRight,
sqlBinaryExpression.TypeMapping);
}

// those optimizations are only valid in 2-value logic
Expand Down Expand Up @@ -227,5 +216,43 @@ private Expression VisitSqlBinaryExpression(SqlBinaryExpression sqlBinaryExpress

return sqlBinaryExpression.Update(newLeft, newRight);
}

private SqlExpression CreateSqlBinaryEqualityExpression(
ExpressionType operatorType,
SqlExpression newLeft,
SqlExpression newRight,
RelationalTypeMapping typeMapping)
{
// true && a -> a
// true || a -> true
// false && a -> false
// false || a -> a
if (newLeft is SqlConstantExpression newLeftConstant)
{
return operatorType == ExpressionType.AndAlso
? (bool)newLeftConstant.Value
? newRight
: newLeftConstant
: (bool)newLeftConstant.Value
? newLeftConstant
: newRight;
}
else if (newRight is SqlConstantExpression newRightConstant)
{
// a && true -> a
// a || true -> true
// a && false -> false
// a || false -> a
return operatorType == ExpressionType.AndAlso
? (bool)newRightConstant.Value
? newLeft
: newRightConstant
: (bool)newRightConstant.Value
? newRightConstant
: newLeft;
}

return SqlExpressionFactory.MakeBinary(operatorType, newLeft, newRight, typeMapping);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,36 @@ protected override Expression VisitExtension(Expression extensionExpression)
{
var newSelectExpression = (SelectExpression)base.VisitExtension(extensionExpression);

return newSelectExpression.Predicate is SqlConstantExpression newSelectPredicateConstant
// if predicate is optimized to true, we can simply remove it
var newPredicate = newSelectExpression.Predicate is SqlConstantExpression newSelectPredicateConstant
&& !(selectExpression.Predicate is SqlConstantExpression)
? newSelectExpression.Update(
? (bool)newSelectPredicateConstant.Value
? null
: SqlExpressionFactory.Equal(
newSelectPredicateConstant,
SqlExpressionFactory.Constant(true, newSelectPredicateConstant.TypeMapping))
: newSelectExpression.Predicate;

var newHaving = newSelectExpression.Having is SqlConstantExpression newSelectHavingConstant
&& !(selectExpression.Having is SqlConstantExpression)
? (bool)newSelectHavingConstant.Value
? null
: SqlExpressionFactory.Equal(
newSelectHavingConstant,
SqlExpressionFactory.Constant(true, newSelectHavingConstant.TypeMapping))
: newSelectExpression.Having;

return newSelectExpression.Update(
newSelectExpression.Projection.ToList(),
newSelectExpression.Tables.ToList(),
SqlExpressionFactory.Equal(
newSelectPredicateConstant,
SqlExpressionFactory.Constant(true, newSelectPredicateConstant.TypeMapping)),
newPredicate,
newSelectExpression.GroupBy.ToList(),
newSelectExpression.Having,
newHaving,
newSelectExpression.Orderings.ToList(),
newSelectExpression.Limit,
newSelectExpression.Offset,
newSelectExpression.IsDistinct,
newSelectExpression.Alias)
: newSelectExpression;
newSelectExpression.Alias);
}

return base.VisitExtension(extensionExpression);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public class ColumnExpression : SqlExpression
internal ColumnExpression(IProperty property, TableExpressionBase table, bool nullable)
: this(
property.GetColumnName(), table, property.ClrType, property.GetRelationalTypeMapping(),
nullable || property.IsNullable || property.DeclaringEntityType.BaseType != null)
nullable || property.IsColumnNullable())
{
}

Expand Down
101 changes: 101 additions & 0 deletions test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8006,6 +8006,107 @@ public virtual Task Group_by_with_aggregate_max_on_entity_type(bool isAsync)
})));
}

[ConditionalTheory(Skip = "issue #18492")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Group_by_on_StartsWith_with_null_parameter_as_argument(bool isAsync)
{
var prm = (string)null;

return AssertQueryScalar(
isAsync,
ss => ss.Set<Gear>().GroupBy(g => g.FullName.StartsWith(prm)).Select(g => g.Key));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Group_by_with_having_StartsWith_with_null_parameter_as_argument(bool isAsync)
{
var prm = (string)null;

return AssertQuery(
isAsync,
ss => ss.Set<Gear>().GroupBy(g => g.FullName).Where(g => g.Key.StartsWith(prm)).Select(g => g.Key),
ss => ss.Set<Gear>().GroupBy(g => g.FullName).Where(g => false).Select(g => g.Key));
}

[ConditionalTheory(Skip = "issue #18492")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_StartsWith_with_null_parameter_as_argument(bool isAsync)
{
var prm = (string)null;

return AssertQueryScalar(
isAsync,
ss => ss.Set<Gear>().Select(g => g.FullName.StartsWith(prm)),
ss => ss.Set<Gear>().Select(g => false));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_null_parameter_is_not_null(bool isAsync)
{
var prm = (string)null;

return AssertQueryScalar(
isAsync,
ss => ss.Set<Gear>().Select(g => prm != null),
ss => ss.Set<Gear>().Select(g => false));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_null_parameter_is_not_null(bool isAsync)
{
var prm = (string)null;

return AssertQuery(
isAsync,
ss => ss.Set<Gear>().Where(g => prm != null),
ss => ss.Set<Gear>().Where(g => false));
}

[ConditionalTheory(Skip = "issue #18492")]
[MemberData(nameof(IsAsyncData))]
public virtual Task OrderBy_StartsWith_with_null_parameter_as_argument(bool isAsync)
{
var prm = (string)null;

return AssertQuery(
isAsync,
ss => ss.Set<Gear>().OrderBy(g => g.FullName.StartsWith(prm)).ThenBy(g => g.Nickname),
ss => ss.Set<Gear>().OrderBy(g => false).ThenBy(g => g.Nickname),
assertOrder: true);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Where_with_enum_flags_parameter(bool isAsync)
{
MilitaryRank? rank = MilitaryRank.Private;

await AssertQuery(
isAsync,
ss => ss.Set<Gear>().Where(g => (g.Rank & rank) == rank));

rank = null;

await AssertQuery(
isAsync,
ss => ss.Set<Gear>().Where(g => (g.Rank & rank) == rank));

rank = MilitaryRank.Corporal;

await AssertQuery(
isAsync,
ss => ss.Set<Gear>().Where(g => (g.Rank | rank) != rank));

rank = null;

await AssertQuery(
isAsync,
ss => ss.Set<Gear>().Where(g => (g.Rank | rank) != rank));
}

protected async Task AssertTranslationFailed(Func<Task> testCode)
{
Assert.Contains(
Expand Down
Loading

0 comments on commit baf9d53

Please sign in to comment.