Skip to content

Commit

Permalink
Fix to #16078 - Query/Null semantics: when checking if expression is …
Browse files Browse the repository at this point in the history
…null, just check it's constituents rather than entire expression

Problem was that during null semantics rewrite we create IS NULL calls on the operands of the comparison. If the operands themselves are complicated, we were simply comparing the entire complex expression to null. In some cases, we only need to look at constituents, e.g. a + b == null <=> a == null || b == null.

Also added other minor optimizations around null semantics:

- non_nullable_column IS NULL resolves to false,
- try to simplify expression after applying de Morgan transformations

Also fixed a bug exposed by these changes, where column nullability would be incorrect for scenarios with owned types.
  • Loading branch information
maumar committed Oct 24, 2019
1 parent 4e796ce commit 3787fe4
Show file tree
Hide file tree
Showing 15 changed files with 442 additions and 97 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
public class IsNullOptimizingExpressionVisitor : ExpressionVisitor
{
private readonly ISqlExpressionFactory _sqlExpressionFactory;

public IsNullOptimizingExpressionVisitor(ISqlExpressionFactory sqlExpressionFactory)
{
_sqlExpressionFactory = sqlExpressionFactory;
}

protected override Expression VisitExtension(Expression extensionExpression)
{
if (extensionExpression is SqlUnaryExpression sqlUnaryExpression)
{
if (sqlUnaryExpression.OperatorType == ExpressionType.Equal)
{
switch (sqlUnaryExpression.Operand)
{
case SqlUnaryExpression sqlUnaryOperand
when sqlUnaryOperand.OperatorType == ExpressionType.Convert
|| sqlUnaryOperand.OperatorType == ExpressionType.Not
|| sqlUnaryOperand.OperatorType == ExpressionType.Negate:
return (SqlExpression)Visit(_sqlExpressionFactory.IsNull(sqlUnaryOperand.Operand));

case SqlBinaryExpression sqlBinaryOperand:
var newLeft = (SqlExpression)Visit(_sqlExpressionFactory.IsNull(sqlBinaryOperand.Left));
var newRight = (SqlExpression)Visit(_sqlExpressionFactory.IsNull(sqlBinaryOperand.Right));

return sqlBinaryOperand.OperatorType == ExpressionType.Coalesce
? _sqlExpressionFactory.AndAlso(newLeft, newRight)
: _sqlExpressionFactory.OrElse(newLeft, newRight);

case ColumnExpression columnOperand
when !columnOperand.IsNullable:
return _sqlExpressionFactory.Constant(false, sqlUnaryExpression.TypeMapping);
}
}

if (sqlUnaryExpression.OperatorType == ExpressionType.Equal)
{
switch (sqlUnaryExpression.Operand)
{
case SqlUnaryExpression sqlUnaryOperand
when sqlUnaryOperand.OperatorType == ExpressionType.Convert
|| sqlUnaryOperand.OperatorType == ExpressionType.Not
|| sqlUnaryOperand.OperatorType == ExpressionType.Negate:
return (SqlExpression)Visit(_sqlExpressionFactory.IsNotNull(sqlUnaryOperand.Operand));

case SqlBinaryExpression sqlBinaryOperand:
var newLeft = (SqlExpression)Visit(_sqlExpressionFactory.IsNotNull(sqlBinaryOperand.Left));
var newRight = (SqlExpression)Visit(_sqlExpressionFactory.IsNotNull(sqlBinaryOperand.Right));

return sqlBinaryOperand.OperatorType == ExpressionType.Coalesce
? _sqlExpressionFactory.OrElse(newLeft, newRight)
: _sqlExpressionFactory.AndAlso(newLeft, newRight);

case ColumnExpression columnOperand
when !columnOperand.IsNullable:
return _sqlExpressionFactory.Constant(true, sqlUnaryExpression.TypeMapping);
}
}
}

return base.VisitExtension(extensionExpression);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,11 @@ private SqlBinaryExpression VisitSqlBinaryExpression(SqlBinaryExpression sqlBina
newRight = rightUnary.Operand;
}

// TODO: optimize this by looking at subcomponents, e.g. f(a, b) == null <=> a == null || b == null
var leftIsNull = _sqlExpressionFactory.IsNull(newLeft);
var rightIsNull = _sqlExpressionFactory.IsNull(newRight);

// doing a full null semantics rewrite - removing all nulls from truth table
// this will NOT be correct once we introduce simplified null semantics
_isNullable = false;

if (sqlBinaryExpression.OperatorType == ExpressionType.Equal)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.Storage;

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
Expand Down Expand Up @@ -60,6 +61,16 @@ protected virtual Expression VisitSqlUnaryExpression(SqlUnaryExpression sqlUnary
return SqlExpressionFactory.Constant(innerConstantNull1.Value == null, sqlUnaryExpression.TypeMapping);
}

// non_nullable_column IS NULL -> false
// non_nullable_column IS NOT NULL -> true
if ((sqlUnaryExpression.OperatorType == ExpressionType.Equal
|| sqlUnaryExpression.OperatorType == ExpressionType.NotEqual)
&& sqlUnaryExpression.Operand is ColumnExpression innerColumn
&& !innerColumn.IsNullable)
{
return SqlExpressionFactory.Constant(sqlUnaryExpression.OperatorType == ExpressionType.NotEqual, sqlUnaryExpression.TypeMapping);
}

// NULL IS NOT NULL -> false
// non_nullable_constant IS NOT NULL -> true
if (sqlUnaryExpression.OperatorType == ExpressionType.NotEqual
Expand Down Expand Up @@ -135,9 +146,13 @@ private Expression VisitNot(SqlUnaryExpression sqlUnaryExpression)
var newLeft = (SqlExpression)Visit(SqlExpressionFactory.Not(innerBinary.Left));
var newRight = (SqlExpression)Visit(SqlExpressionFactory.Not(innerBinary.Right));

return innerBinary.OperatorType == ExpressionType.AndAlso
? SqlExpressionFactory.OrElse(newLeft, newRight)
: SqlExpressionFactory.AndAlso(newLeft, newRight);
return SimplifyLogicalSqlBinaryExpression(
innerBinary.OperatorType == ExpressionType.AndAlso
? ExpressionType.OrElse
: ExpressionType.AndAlso,
newLeft,
newRight,
innerBinary.TypeMapping);
}

// those optimizations are only valid in 2-value logic
Expand Down Expand Up @@ -168,36 +183,11 @@ private Expression VisitSqlBinaryExpression(SqlBinaryExpression sqlBinaryExpress
if (sqlBinaryExpression.OperatorType == ExpressionType.AndAlso
|| sqlBinaryExpression.OperatorType == ExpressionType.OrElse)
{
// true && a -> a
// true || a -> true
// false && a -> false
// false || a -> a
if (newLeft is SqlConstantExpression newLeftConstant)
{
return sqlBinaryExpression.OperatorType == ExpressionType.AndAlso
? (bool)newLeftConstant.Value
? newRight
: newLeftConstant
: (bool)newLeftConstant.Value
? newLeftConstant
: newRight;
}
else if (newRight is SqlConstantExpression newRightConstant)
{
// a && true -> a
// a || true -> true
// a && false -> false
// a || false -> a
return sqlBinaryExpression.OperatorType == ExpressionType.AndAlso
? (bool)newRightConstant.Value
? newLeft
: newRightConstant
: (bool)newRightConstant.Value
? newRightConstant
: newLeft;
}

return sqlBinaryExpression.Update(newLeft, newRight);
return SimplifyLogicalSqlBinaryExpression(
sqlBinaryExpression.OperatorType,
newLeft,
newRight,
sqlBinaryExpression.TypeMapping);
}

// those optimizations are only valid in 2-value logic
Expand Down Expand Up @@ -227,5 +217,43 @@ private Expression VisitSqlBinaryExpression(SqlBinaryExpression sqlBinaryExpress

return sqlBinaryExpression.Update(newLeft, newRight);
}

private SqlExpression SimplifyLogicalSqlBinaryExpression(
ExpressionType operatorType,
SqlExpression newLeft,
SqlExpression newRight,
RelationalTypeMapping typeMapping)
{
// true && a -> a
// true || a -> true
// false && a -> false
// false || a -> a
if (newLeft is SqlConstantExpression newLeftConstant)
{
return operatorType == ExpressionType.AndAlso
? (bool)newLeftConstant.Value
? newRight
: newLeftConstant
: (bool)newLeftConstant.Value
? newLeftConstant
: newRight;
}
else if (newRight is SqlConstantExpression newRightConstant)
{
// a && true -> a
// a || true -> true
// a && false -> false
// a || false -> a
return operatorType == ExpressionType.AndAlso
? (bool)newRightConstant.Value
? newLeft
: newRightConstant
: (bool)newRightConstant.Value
? newRightConstant
: newLeft;
}

return SqlExpressionFactory.MakeBinary(operatorType, newLeft, newRight, typeMapping);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public override Expression Process(Expression query)
query = new NullSemanticsRewritingExpressionVisitor(SqlExpressionFactory).Visit(query);
}

query = new IsNullOptimizingExpressionVisitor(SqlExpressionFactory).Visit(query);
query = OptimizeSqlExpression(query);
query = new NullComparisonTransformingExpressionVisitor().Visit(query);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,33 @@ protected override Expression VisitExtension(Expression extensionExpression)
{
var newSelectExpression = (SelectExpression)base.VisitExtension(extensionExpression);

return newSelectExpression.Predicate is SqlConstantExpression newSelectPredicateConstant
// if predicate is optimized to true, we can simply remove it
var newPredicate = newSelectExpression.Predicate is SqlConstantExpression newSelectPredicateConstant
&& !(selectExpression.Predicate is SqlConstantExpression)
? (bool)newSelectPredicateConstant.Value
? null
: SqlExpressionFactory.Equal(
newSelectPredicateConstant,
SqlExpressionFactory.Constant(true, newSelectPredicateConstant.TypeMapping))
: newSelectExpression.Predicate;

var newHaving = newSelectExpression.Having is SqlConstantExpression newSelectHavingConstant
&& !(selectExpression.Having is SqlConstantExpression)
? (bool)newSelectHavingConstant.Value
? null
: SqlExpressionFactory.Equal(
newSelectHavingConstant,
SqlExpressionFactory.Constant(true, newSelectHavingConstant.TypeMapping))
: newSelectExpression.Having;

return newPredicate != newSelectExpression.Predicate
|| newHaving != newSelectExpression.Having
? newSelectExpression.Update(
newSelectExpression.Projection.ToList(),
newSelectExpression.Tables.ToList(),
SqlExpressionFactory.Equal(
newSelectPredicateConstant,
SqlExpressionFactory.Constant(true, newSelectPredicateConstant.TypeMapping)),
newPredicate,
newSelectExpression.GroupBy.ToList(),
newSelectExpression.Having,
newHaving,
newSelectExpression.Orderings.ToList(),
newSelectExpression.Limit,
newSelectExpression.Offset,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public class ColumnExpression : SqlExpression
internal ColumnExpression(IProperty property, TableExpressionBase table, bool nullable)
: this(
property.GetColumnName(), table, property.ClrType, property.GetRelationalTypeMapping(),
nullable || property.IsNullable || property.DeclaringEntityType.BaseType != null)
nullable || property.IsColumnNullable())
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,25 @@ public virtual void Null_semantics_with_null_check_complex()
}
}

[ConditionalFact]
public virtual void IsNull_on_complex_expression()
{
using (var ctx = CreateContext())
{
var query1 = ctx.Entities1.Where(e => -e.NullableIntA != null).ToList();
Assert.Equal(18, query1.Count);

var query2 = ctx.Entities1.Where(e => (e.NullableIntA + e.NullableIntB) == null).ToList();
Assert.Equal(15, query2.Count);

var query3 = ctx.Entities1.Where(e => (e.NullableIntA ?? e.NullableIntB) == null).ToList();
Assert.Equal(3, query3.Count);

var query4 = ctx.Entities1.Where(e => (e.NullableIntA ?? e.NullableIntB) != null).ToList();
Assert.Equal(24, query4.Count);
}
}

protected static TResult Maybe<TResult>(object caller, Func<TResult> expression)
where TResult : class
{
Expand Down
Loading

0 comments on commit 3787fe4

Please sign in to comment.