Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix to #16078 - Query/Null semantics: when checking if expression is null, just check it's constituents rather than entire expression #18560

Merged
merged 1 commit into from
Oct 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,11 @@ private SqlBinaryExpression VisitSqlBinaryExpression(SqlBinaryExpression sqlBina
newRight = rightUnary.Operand;
}

// TODO: optimize this by looking at subcomponents, e.g. f(a, b) == null <=> a == null || b == null
var leftIsNull = _sqlExpressionFactory.IsNull(newLeft);
var rightIsNull = _sqlExpressionFactory.IsNull(newRight);

// doing a full null semantics rewrite - removing all nulls from truth table
// this will NOT be correct once we introduce simplified null semantics
_isNullable = false;

if (sqlBinaryExpression.OperatorType == ExpressionType.Equal)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.Storage;

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
Expand Down Expand Up @@ -47,42 +48,62 @@ protected override Expression VisitExtension(Expression extensionExpression)

protected virtual Expression VisitSqlUnaryExpression(SqlUnaryExpression sqlUnaryExpression)
{
if (sqlUnaryExpression.OperatorType == ExpressionType.Not)
switch (sqlUnaryExpression.OperatorType)
{
return VisitNot(sqlUnaryExpression);
}
case ExpressionType.Not:
return VisitNot(sqlUnaryExpression);

// NULL IS NULL -> true
// non_nullable_constant IS NULL -> false
if (sqlUnaryExpression.OperatorType == ExpressionType.Equal
&& sqlUnaryExpression.Operand is SqlConstantExpression innerConstantNull1)
{
return SqlExpressionFactory.Constant(innerConstantNull1.Value == null, sqlUnaryExpression.TypeMapping);
}
case ExpressionType.Equal:
switch (sqlUnaryExpression.Operand)
{
case SqlConstantExpression constantOperand:
return SqlExpressionFactory.Constant(constantOperand.Value == null, sqlUnaryExpression.TypeMapping);

// NULL IS NOT NULL -> false
// non_nullable_constant IS NOT NULL -> true
if (sqlUnaryExpression.OperatorType == ExpressionType.NotEqual
&& sqlUnaryExpression.Operand is SqlConstantExpression innerConstantNull2)
{
return SqlExpressionFactory.Constant(innerConstantNull2.Value != null, sqlUnaryExpression.TypeMapping);
}
case ColumnExpression columnOperand
when !columnOperand.IsNullable:
return SqlExpressionFactory.Constant(false, sqlUnaryExpression.TypeMapping);

if (sqlUnaryExpression.Operand is SqlUnaryExpression innerUnary)
{
// (!a) IS NULL <==> a IS NULL
if (sqlUnaryExpression.OperatorType == ExpressionType.Equal
&& innerUnary.OperatorType == ExpressionType.Not)
{
return Visit(SqlExpressionFactory.IsNull(innerUnary.Operand));
}
case SqlUnaryExpression sqlUnaryOperand
when sqlUnaryOperand.OperatorType == ExpressionType.Convert
|| sqlUnaryOperand.OperatorType == ExpressionType.Not
|| sqlUnaryOperand.OperatorType == ExpressionType.Negate:
return (SqlExpression)Visit(SqlExpressionFactory.IsNull(sqlUnaryOperand.Operand));

// (!a) IS NOT NULL <==> a IS NOT NULL
if (sqlUnaryExpression.OperatorType == ExpressionType.NotEqual
&& innerUnary.OperatorType == ExpressionType.Not)
{
return Visit(SqlExpressionFactory.IsNotNull(innerUnary.Operand));
}
case SqlBinaryExpression sqlBinaryOperand:
var newLeft = (SqlExpression)Visit(SqlExpressionFactory.IsNull(sqlBinaryOperand.Left));
var newRight = (SqlExpression)Visit(SqlExpressionFactory.IsNull(sqlBinaryOperand.Right));

return sqlBinaryOperand.OperatorType == ExpressionType.Coalesce
? SimplifyLogicalSqlBinaryExpression(ExpressionType.AndAlso, newLeft, newRight, sqlBinaryOperand.TypeMapping)
: SimplifyLogicalSqlBinaryExpression(ExpressionType.OrElse, newLeft, newRight, sqlBinaryOperand.TypeMapping);
}
break;

case ExpressionType.NotEqual:
switch (sqlUnaryExpression.Operand)
{
case SqlConstantExpression constantOperand:
return SqlExpressionFactory.Constant(constantOperand.Value != null, sqlUnaryExpression.TypeMapping);

case ColumnExpression columnOperand
when !columnOperand.IsNullable:
return SqlExpressionFactory.Constant(true, sqlUnaryExpression.TypeMapping);

case SqlUnaryExpression sqlUnaryOperand
when sqlUnaryOperand.OperatorType == ExpressionType.Convert
|| sqlUnaryOperand.OperatorType == ExpressionType.Not
|| sqlUnaryOperand.OperatorType == ExpressionType.Negate:
return (SqlExpression)Visit(SqlExpressionFactory.IsNotNull(sqlUnaryOperand.Operand));

case SqlBinaryExpression sqlBinaryOperand:
var newLeft = (SqlExpression)Visit(SqlExpressionFactory.IsNotNull(sqlBinaryOperand.Left));
var newRight = (SqlExpression)Visit(SqlExpressionFactory.IsNotNull(sqlBinaryOperand.Right));

return sqlBinaryOperand.OperatorType == ExpressionType.Coalesce
? SimplifyLogicalSqlBinaryExpression(ExpressionType.OrElse, newLeft, newRight, sqlBinaryOperand.TypeMapping)
: SimplifyLogicalSqlBinaryExpression(ExpressionType.AndAlso, newLeft, newRight, sqlBinaryOperand.TypeMapping);
}
break;
}

var newOperand = (SqlExpression)Visit(sqlUnaryExpression.Operand);
Expand Down Expand Up @@ -135,9 +156,13 @@ private Expression VisitNot(SqlUnaryExpression sqlUnaryExpression)
var newLeft = (SqlExpression)Visit(SqlExpressionFactory.Not(innerBinary.Left));
var newRight = (SqlExpression)Visit(SqlExpressionFactory.Not(innerBinary.Right));

return innerBinary.OperatorType == ExpressionType.AndAlso
? SqlExpressionFactory.OrElse(newLeft, newRight)
: SqlExpressionFactory.AndAlso(newLeft, newRight);
return SimplifyLogicalSqlBinaryExpression(
innerBinary.OperatorType == ExpressionType.AndAlso
maumar marked this conversation as resolved.
Show resolved Hide resolved
? ExpressionType.OrElse
: ExpressionType.AndAlso,
newLeft,
newRight,
innerBinary.TypeMapping);
}

// those optimizations are only valid in 2-value logic
Expand Down Expand Up @@ -168,36 +193,11 @@ private Expression VisitSqlBinaryExpression(SqlBinaryExpression sqlBinaryExpress
if (sqlBinaryExpression.OperatorType == ExpressionType.AndAlso
|| sqlBinaryExpression.OperatorType == ExpressionType.OrElse)
{
// true && a -> a
// true || a -> true
// false && a -> false
// false || a -> a
if (newLeft is SqlConstantExpression newLeftConstant)
{
return sqlBinaryExpression.OperatorType == ExpressionType.AndAlso
? (bool)newLeftConstant.Value
? newRight
: newLeftConstant
: (bool)newLeftConstant.Value
? newLeftConstant
: newRight;
}
else if (newRight is SqlConstantExpression newRightConstant)
{
// a && true -> a
// a || true -> true
// a && false -> false
// a || false -> a
return sqlBinaryExpression.OperatorType == ExpressionType.AndAlso
? (bool)newRightConstant.Value
? newLeft
: newRightConstant
: (bool)newRightConstant.Value
? newRightConstant
: newLeft;
}

return sqlBinaryExpression.Update(newLeft, newRight);
return SimplifyLogicalSqlBinaryExpression(
sqlBinaryExpression.OperatorType,
newLeft,
newRight,
sqlBinaryExpression.TypeMapping);
}

// those optimizations are only valid in 2-value logic
Expand Down Expand Up @@ -227,5 +227,43 @@ private Expression VisitSqlBinaryExpression(SqlBinaryExpression sqlBinaryExpress

return sqlBinaryExpression.Update(newLeft, newRight);
}

private SqlExpression SimplifyLogicalSqlBinaryExpression(
ExpressionType operatorType,
SqlExpression newLeft,
SqlExpression newRight,
RelationalTypeMapping typeMapping)
{
// true && a -> a
// true || a -> true
// false && a -> false
// false || a -> a
if (newLeft is SqlConstantExpression newLeftConstant)
{
return operatorType == ExpressionType.AndAlso
? (bool)newLeftConstant.Value
? newRight
: newLeftConstant
: (bool)newLeftConstant.Value
? newLeftConstant
: newRight;
}
else if (newRight is SqlConstantExpression newRightConstant)
{
// a && true -> a
// a || true -> true
// a && false -> false
// a || false -> a
return operatorType == ExpressionType.AndAlso
? (bool)newRightConstant.Value
? newLeft
: newRightConstant
: (bool)newRightConstant.Value
? newRightConstant
: newLeft;
}

return SqlExpressionFactory.MakeBinary(operatorType, newLeft, newRight, typeMapping);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,33 @@ protected override Expression VisitExtension(Expression extensionExpression)
{
var newSelectExpression = (SelectExpression)base.VisitExtension(extensionExpression);

return newSelectExpression.Predicate is SqlConstantExpression newSelectPredicateConstant
// if predicate is optimized to true, we can simply remove it
var newPredicate = newSelectExpression.Predicate is SqlConstantExpression newSelectPredicateConstant
&& !(selectExpression.Predicate is SqlConstantExpression)
? (bool)newSelectPredicateConstant.Value
? null
: SqlExpressionFactory.Equal(
roji marked this conversation as resolved.
Show resolved Hide resolved
newSelectPredicateConstant,
SqlExpressionFactory.Constant(true, newSelectPredicateConstant.TypeMapping))
: newSelectExpression.Predicate;

var newHaving = newSelectExpression.Having is SqlConstantExpression newSelectHavingConstant
&& !(selectExpression.Having is SqlConstantExpression)
? (bool)newSelectHavingConstant.Value
? null
: SqlExpressionFactory.Equal(
newSelectHavingConstant,
SqlExpressionFactory.Constant(true, newSelectHavingConstant.TypeMapping))
: newSelectExpression.Having;

return newPredicate != newSelectExpression.Predicate
|| newHaving != newSelectExpression.Having
? newSelectExpression.Update(
newSelectExpression.Projection.ToList(),
newSelectExpression.Tables.ToList(),
SqlExpressionFactory.Equal(
newSelectPredicateConstant,
SqlExpressionFactory.Constant(true, newSelectPredicateConstant.TypeMapping)),
newPredicate,
newSelectExpression.GroupBy.ToList(),
newSelectExpression.Having,
newHaving,
newSelectExpression.Orderings.ToList(),
newSelectExpression.Limit,
newSelectExpression.Offset,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public class ColumnExpression : SqlExpression
internal ColumnExpression(IProperty property, TableExpressionBase table, bool nullable)
: this(
property.GetColumnName(), table, property.ClrType, property.GetRelationalTypeMapping(),
nullable || property.IsNullable || property.DeclaringEntityType.BaseType != null)
nullable || property.IsColumnNullable())
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,25 @@ public virtual void Null_semantics_with_null_check_complex()
}
}

[ConditionalFact]
public virtual void IsNull_on_complex_expression()
{
using (var ctx = CreateContext())
{
var query1 = ctx.Entities1.Where(e => -e.NullableIntA != null).ToList();
Assert.Equal(18, query1.Count);

var query2 = ctx.Entities1.Where(e => (e.NullableIntA + e.NullableIntB) == null).ToList();
Assert.Equal(15, query2.Count);

var query3 = ctx.Entities1.Where(e => (e.NullableIntA ?? e.NullableIntB) == null).ToList();
Assert.Equal(3, query3.Count);

var query4 = ctx.Entities1.Where(e => (e.NullableIntA ?? e.NullableIntB) != null).ToList();
Assert.Equal(24, query4.Count);
}
}

protected static TResult Maybe<TResult>(object caller, Func<TResult> expression)
where TResult : class
{
Expand Down
101 changes: 101 additions & 0 deletions test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8006,6 +8006,107 @@ public virtual Task Group_by_with_aggregate_max_on_entity_type(bool isAsync)
})));
}

[ConditionalTheory(Skip = "issue #18492")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Group_by_on_StartsWith_with_null_parameter_as_argument(bool isAsync)
{
var prm = (string)null;

return AssertQueryScalar(
isAsync,
ss => ss.Set<Gear>().GroupBy(g => g.FullName.StartsWith(prm)).Select(g => g.Key));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Group_by_with_having_StartsWith_with_null_parameter_as_argument(bool isAsync)
{
var prm = (string)null;

return AssertQuery(
isAsync,
ss => ss.Set<Gear>().GroupBy(g => g.FullName).Where(g => g.Key.StartsWith(prm)).Select(g => g.Key),
ss => ss.Set<Gear>().GroupBy(g => g.FullName).Where(g => false).Select(g => g.Key));
}

[ConditionalTheory(Skip = "issue #18492")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_StartsWith_with_null_parameter_as_argument(bool isAsync)
{
var prm = (string)null;

return AssertQueryScalar(
isAsync,
ss => ss.Set<Gear>().Select(g => g.FullName.StartsWith(prm)),
ss => ss.Set<Gear>().Select(g => false));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_null_parameter_is_not_null(bool isAsync)
{
var prm = (string)null;

return AssertQueryScalar(
isAsync,
ss => ss.Set<Gear>().Select(g => prm != null),
ss => ss.Set<Gear>().Select(g => false));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_null_parameter_is_not_null(bool isAsync)
{
var prm = (string)null;

return AssertQuery(
isAsync,
ss => ss.Set<Gear>().Where(g => prm != null),
ss => ss.Set<Gear>().Where(g => false));
}

[ConditionalTheory(Skip = "issue #18492")]
[MemberData(nameof(IsAsyncData))]
public virtual Task OrderBy_StartsWith_with_null_parameter_as_argument(bool isAsync)
{
var prm = (string)null;

return AssertQuery(
isAsync,
ss => ss.Set<Gear>().OrderBy(g => g.FullName.StartsWith(prm)).ThenBy(g => g.Nickname),
ss => ss.Set<Gear>().OrderBy(g => false).ThenBy(g => g.Nickname),
assertOrder: true);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Where_with_enum_flags_parameter(bool isAsync)
{
MilitaryRank? rank = MilitaryRank.Private;

await AssertQuery(
isAsync,
ss => ss.Set<Gear>().Where(g => (g.Rank & rank) == rank));

rank = null;

await AssertQuery(
isAsync,
ss => ss.Set<Gear>().Where(g => (g.Rank & rank) == rank));

rank = MilitaryRank.Corporal;

await AssertQuery(
isAsync,
ss => ss.Set<Gear>().Where(g => (g.Rank | rank) != rank));

rank = null;

await AssertQuery(
isAsync,
ss => ss.Set<Gear>().Where(g => (g.Rank | rank) != rank));
}

protected async Task AssertTranslationFailed(Func<Task> testCode)
{
Assert.Contains(
Expand Down
Loading