diff --git a/src/EFCore.Relational/Query/PipeLine/NullSemanticsRewritingVisitor.cs b/src/EFCore.Relational/Query/PipeLine/NullSemanticsRewritingVisitor.cs new file mode 100644 index 00000000000..eeac58f64eb --- /dev/null +++ b/src/EFCore.Relational/Query/PipeLine/NullSemanticsRewritingVisitor.cs @@ -0,0 +1,519 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore.Relational.Query.Pipeline; +using Microsoft.EntityFrameworkCore.Relational.Query.Pipeline.SqlExpressions; + +namespace Microsoft.EntityFrameworkCore.Query.Pipeline +{ + public class NullSemanticsRewritingVisitor : ExpressionVisitor + { + private readonly ISqlExpressionFactory _sqlExpressionFactory; + + private bool _isNullable = false; + + public NullSemanticsRewritingVisitor(ISqlExpressionFactory sqlExpressionFactory) + { + _sqlExpressionFactory = sqlExpressionFactory; + } + + protected override Expression VisitExtension(Expression extensionExpression) + { + switch (extensionExpression) + { + case SqlConstantExpression sqlConstantExpression: + return VisitSqlConstantExpression(sqlConstantExpression); + + case ColumnExpression columnExpression: + return VisitColumnExpression(columnExpression); + + case SqlParameterExpression sqlParameterExpression: + return VisitSqlParameterExpression(sqlParameterExpression); + + case SqlUnaryExpression sqlUnaryExpression: + return VisitSqlUnaryExpression(sqlUnaryExpression); + + case LikeExpression likeExpression: + return VisitLikeExpression(likeExpression); + + case SqlFunctionExpression sqlFunctionExpression: + return VisitSqlFunctionExpression(sqlFunctionExpression); + + case SqlBinaryExpression sqlBinaryExpression: + return VisitSqlBinaryExpression(sqlBinaryExpression); + + case CaseExpression caseExpression: + return VisitCaseExpression(caseExpression); + + case InnerJoinExpression innerJoinExpression: + return VisitInnerJoinExpression(innerJoinExpression); + + case LeftJoinExpression leftJoinExpression: + return VisitLeftJoinExpression(leftJoinExpression); + + default: + return base.VisitExtension(extensionExpression); + } + } + + private SqlConstantExpression VisitSqlConstantExpression(SqlConstantExpression sqlConstantExpression) + { + _isNullable = sqlConstantExpression.Value == null; + + return sqlConstantExpression; + } + + private ColumnExpression VisitColumnExpression(ColumnExpression columnExpression) + { + _isNullable = columnExpression.Nullable; + + return columnExpression; + } + + private SqlParameterExpression VisitSqlParameterExpression(SqlParameterExpression sqlParameterExpression) + { + // at this point we assume every parameter is nullable, we will filter out the non-nullable ones once we know the actual values + _isNullable = true; + + return sqlParameterExpression; + } + + private SqlUnaryExpression VisitSqlUnaryExpression(SqlUnaryExpression sqlUnaryExpression) + { + var newOperand = (SqlExpression)Visit(sqlUnaryExpression.Operand); + + // IsNull/IsNotNull + if (sqlUnaryExpression.OperatorType == ExpressionType.Equal + || sqlUnaryExpression.OperatorType == ExpressionType.NotEqual) + { + _isNullable = false; + } + + return sqlUnaryExpression.Update(newOperand); + } + + private LikeExpression VisitLikeExpression(LikeExpression likeExpression) + { + var newMatch = (SqlExpression)Visit(likeExpression.Match); + var isNullable = _isNullable; + var newPattern = (SqlExpression)Visit(likeExpression.Pattern); + isNullable |= _isNullable; + var newEscapeChar = (SqlExpression)Visit(likeExpression.EscapeChar); + _isNullable |= isNullable; + + return likeExpression.Update(newMatch, newPattern, newEscapeChar); + } + + private InnerJoinExpression VisitInnerJoinExpression(InnerJoinExpression innerJoinExpression) + { + var newTable = (TableExpressionBase)Visit(innerJoinExpression.Table); + var newJoinPredicate = VisitJoinPredicate((SqlBinaryExpression)innerJoinExpression.JoinPredicate); + + return innerJoinExpression.Update(newTable, newJoinPredicate); + } + + private LeftJoinExpression VisitLeftJoinExpression(LeftJoinExpression leftJoinExpression) + { + var newTable = (TableExpressionBase)Visit(leftJoinExpression.Table); + var newJoinPredicate = VisitJoinPredicate((SqlBinaryExpression)leftJoinExpression.JoinPredicate); + + return leftJoinExpression.Update(newTable, newJoinPredicate); + } + + private SqlExpression VisitJoinPredicate(SqlBinaryExpression predicate) + { + if (predicate.OperatorType == ExpressionType.Equal) + { + var newLeft = (SqlExpression)Visit(predicate.Left); + var newRight = (SqlExpression)Visit(predicate.Right); + + return predicate.Update(newLeft, newRight); + } + + if (predicate.OperatorType == ExpressionType.AndAlso) + { + return VisitSqlBinaryExpression(predicate); + } + + throw new InvalidOperationException("Unexpected join predicate shape: " + predicate); + } + + private CaseExpression VisitCaseExpression(CaseExpression caseExpression) + { + // if there is no 'else' there is a possibility of null, when none of the conditions are met + // otherwise the result is nullable if any of the WhenClause results OR ElseResult is nullable + var isNullable = caseExpression.ElseResult == null; + + var newOperand = (SqlExpression)Visit(caseExpression.Operand); + var newWhenClauses = new List(); + foreach (var whenClause in caseExpression.WhenClauses) + { + var newTest = (SqlExpression)Visit(whenClause.Test); + var newResult = (SqlExpression)Visit(whenClause.Result); + isNullable |= _isNullable; + newWhenClauses.Add(new CaseWhenClause(newTest, newResult)); + } + + var newElseResult = (SqlExpression)Visit(caseExpression.ElseResult); + _isNullable |= isNullable; + + return caseExpression.Update(newOperand, newWhenClauses, newElseResult); + } + + private SqlFunctionExpression VisitSqlFunctionExpression(SqlFunctionExpression sqlFunctionExpression) + { + var newInstance = (SqlExpression)Visit(sqlFunctionExpression.Instance); + var isNullable = _isNullable; + var newArguments = new SqlExpression[sqlFunctionExpression.Arguments.Count]; + for (var i = 0; i < newArguments.Length; i++) + { + newArguments[i] = (SqlExpression)Visit(sqlFunctionExpression.Arguments[i]); + isNullable |= _isNullable; + } + + _isNullable = isNullable; + + return sqlFunctionExpression.Update(newInstance, newArguments); + } + + private SqlBinaryExpression VisitSqlBinaryExpression(SqlBinaryExpression sqlBinaryExpression) + { + var newLeft = (SqlExpression)Visit(sqlBinaryExpression.Left); + var leftNullable = _isNullable; + + var newRight = (SqlExpression)Visit(sqlBinaryExpression.Right); + var rightNullable = _isNullable; + + if (sqlBinaryExpression.OperatorType == ExpressionType.Coalesce) + { + _isNullable = leftNullable && rightNullable; + + return sqlBinaryExpression.Update(newLeft, newRight); + } + + if (sqlBinaryExpression.OperatorType == ExpressionType.Equal + || sqlBinaryExpression.OperatorType == ExpressionType.NotEqual) + { + var leftUnary = newLeft as SqlUnaryExpression; + var rightUnary = newRight as SqlUnaryExpression; + + var leftNegated = leftUnary?.OperatorType == ExpressionType.Not; + var rightNegated = rightUnary?.OperatorType == ExpressionType.Not; + + if (leftNegated) + { + newLeft = leftUnary.Operand; + } + + if (rightNegated) + { + newRight = rightUnary.Operand; + } + + // TODO: optimize this by looking at subcomponents, e.g. f(a, b) == null <=> a == null || b == null + var leftIsNull = _sqlExpressionFactory.IsNull(newLeft); + var rightIsNull = _sqlExpressionFactory.IsNull(newRight); + + // doing a full null semantics rewrite - removing all nulls from truth table + _isNullable = false; + + if (sqlBinaryExpression.OperatorType == ExpressionType.Equal) + { + if (!leftNullable && !rightNullable) + { + // a == b <=> !a == !b -> a == b + // !a == b <=> a == !b -> a != b + return leftNegated == rightNegated + ? _sqlExpressionFactory.Equal(newLeft, newRight) + : _sqlExpressionFactory.NotEqual(newLeft, newRight); + } + + if (leftNullable && rightNullable) + { + // ?a == ?b <=> !(?a) == !(?b) -> [(a == b) && (a != null && b != null)] || (a == null && b == null)) + // !(?a) == ?b <=> ?a == !(?b) -> [(a != b) && (a != null && b != null)] || (a == null && b == null) + return leftNegated == rightNegated + ? ExpandNullableEqualNullable(newLeft, newRight, leftIsNull, rightIsNull) + : ExpandNegatedNullableEqualNullable(newLeft, newRight, leftIsNull, rightIsNull); + } + + if (leftNullable && !rightNullable) + { + // ?a == b <=> !(?a) == !b -> (a == b) && (a != null) + // !(?a) == b <=> ?a == !b -> (a != b) && (a != null) + return leftNegated == rightNegated + ? ExpandNullableEqualNonNullable(newLeft, newRight, leftIsNull) + : ExpandNegatedNullableEqualNonNullable(newLeft, newRight, leftIsNull); + } + + if (rightNullable && !leftNullable) + { + // a == ?b <=> !a == !(?b) -> (a == b) && (b != null) + // !a == ?b <=> a == !(?b) -> (a != b) && (b != null) + return leftNegated == rightNegated + ? ExpandNullableEqualNonNullable(newLeft, newRight, rightIsNull) + : ExpandNegatedNullableEqualNonNullable(newLeft, newRight, rightIsNull); + } + } + + if (sqlBinaryExpression.OperatorType == ExpressionType.NotEqual) + { + if (!leftNullable && !rightNullable) + { + // a != b <=> !a != !b -> a != b + // !a != b <=> a != !b -> a == b + return leftNegated == rightNegated + ? _sqlExpressionFactory.NotEqual(newLeft, newRight) + : _sqlExpressionFactory.Equal(newLeft, newRight); + } + + if (leftNullable && rightNullable) + { + // ?a != ?b <=> !(?a) != !(?b) -> [(a != b) || (a == null || b == null)] && (a != null || b != null) + // !(?a) != ?b <=> ?a != !(?b) -> [(a == b) || (a == null || b == null)] && (a != null || b != null) + return leftNegated == rightNegated + ? ExpandNullableNotEqualNullable(newLeft, newRight, leftIsNull, rightIsNull) + : ExpandNegatedNullableNotEqualNullable(newLeft, newRight, leftIsNull, rightIsNull); + } + + if (leftNullable) + { + // ?a != b <=> !(?a) != !b -> (a != b) || (a == null) + // !(?a) != b <=> ?a != !b -> (a == b) || (a == null) + return leftNegated == rightNegated + ? ExpandNullableNotEqualNonNullable(newLeft, newRight, leftIsNull) + : ExpandNegatedNullableNotEqualNonNullable(newLeft, newRight, leftIsNull); + } + + if (rightNullable) + { + // a != ?b <=> !a != !(?b) -> (a != b) || (b == null) + // !a != ?b <=> a != !(?b) -> (a == b) || (b == null) + return leftNegated == rightNegated + ? ExpandNullableNotEqualNonNullable(newLeft, newRight, rightIsNull) + : ExpandNegatedNullableNotEqualNonNullable(newLeft, newRight, rightIsNull); + } + } + } + + _isNullable = leftNullable || rightNullable; + + return sqlBinaryExpression.Update(newLeft, newRight); + } + + // ?a == ?b -> [(a == b) && (a != null && b != null)] || (a == null && b == null)) + // + // a | b | F1 = a == b | F2 = (a != null && b != null) | F3 = F1 && F2 | + // | | | | | + // 0 | 0 | 1 | 1 | 1 | + // 0 | 1 | 0 | 1 | 0 | + // 0 | N | N | 0 | 0 | + // 1 | 0 | 0 | 1 | 0 | + // 1 | 1 | 1 | 1 | 1 | + // 1 | N | N | 0 | 0 | + // N | 0 | N | 0 | 0 | + // N | 1 | N | 0 | 0 | + // N | N | N | 0 | 0 | + // + // a | b | F4 = (a == null && b == null) | Final = F3 OR F4 | + // | | | | + // 0 | 0 | 0 | 1 OR 0 = 1 | + // 0 | 1 | 0 | 0 OR 0 = 0 | + // 0 | N | 0 | 0 OR 0 = 0 | + // 1 | 0 | 0 | 0 OR 0 = 0 | + // 1 | 1 | 0 | 1 OR 0 = 1 | + // 1 | N | 0 | 0 OR 0 = 0 | + // N | 0 | 0 | 0 OR 0 = 0 | + // N | 1 | 0 | 0 OR 0 = 0 | + // N | N | 1 | 0 OR 1 = 1 | + private SqlBinaryExpression ExpandNullableEqualNullable( + SqlExpression left, SqlExpression right, SqlExpression leftIsNull, SqlExpression rightIsNull) + => _sqlExpressionFactory.OrElse( + _sqlExpressionFactory.AndAlso( + _sqlExpressionFactory.Equal(left, right), + _sqlExpressionFactory.AndAlso( + _sqlExpressionFactory.Not(leftIsNull), + _sqlExpressionFactory.Not(rightIsNull))), + _sqlExpressionFactory.AndAlso( + leftIsNull, + rightIsNull)); + + // !(?a) == ?b -> [(a != b) && (a != null && b != null)] || (a == null && b == null) + // + // a | b | F1 = a != b | F2 = (a != null && b != null) | F3 = F1 && F2 | + // | | | | | + // 0 | 0 | 0 | 1 | 0 | + // 0 | 1 | 1 | 1 | 1 | + // 0 | N | N | 0 | 0 | + // 1 | 0 | 1 | 1 | 1 | + // 1 | 1 | 0 | 1 | 0 | + // 1 | N | N | 0 | 0 | + // N | 0 | N | 0 | 0 | + // N | 1 | N | 0 | 0 | + // N | N | N | 0 | 0 | + // + // a | b | F4 = (a == null && b == null) | Final = F3 OR F4 | + // | | | | + // 0 | 0 | 0 | 0 OR 0 = 0 | + // 0 | 1 | 0 | 1 OR 0 = 1 | + // 0 | N | 0 | 0 OR 0 = 0 | + // 1 | 0 | 0 | 1 OR 0 = 1 | + // 1 | 1 | 0 | 0 OR 0 = 0 | + // 1 | N | 0 | 0 OR 0 = 0 | + // N | 0 | 0 | 0 OR 0 = 0 | + // N | 1 | 0 | 0 OR 0 = 0 | + // N | N | 1 | 0 OR 1 = 1 | + private SqlBinaryExpression ExpandNegatedNullableEqualNullable( + SqlExpression left, SqlExpression right, SqlExpression leftIsNull, SqlExpression rightIsNull) + => _sqlExpressionFactory.OrElse( + _sqlExpressionFactory.AndAlso( + _sqlExpressionFactory.NotEqual(left, right), + _sqlExpressionFactory.AndAlso( + _sqlExpressionFactory.Not(leftIsNull), + _sqlExpressionFactory.Not(rightIsNull))), + _sqlExpressionFactory.AndAlso( + leftIsNull, + rightIsNull)); + + // ?a == b -> (a == b) && (a != null) + // + // a | b | F1 = a == b | F2 = (a != null) | Final = F1 && F2 | + // | | | | | + // 0 | 0 | 1 | 1 | 1 | + // 0 | 1 | 0 | 1 | 0 | + // 1 | 0 | 0 | 1 | 0 | + // 1 | 1 | 1 | 1 | 1 | + // N | 0 | N | 0 | 0 | + // N | 1 | N | 0 | 0 | + private SqlBinaryExpression ExpandNullableEqualNonNullable( + SqlExpression left, SqlExpression right, SqlExpression leftIsNull) + => _sqlExpressionFactory.AndAlso( + _sqlExpressionFactory.Equal(left, right), + _sqlExpressionFactory.Not(leftIsNull)); + + // !(?a) == b -> (a != b) && (a != null) + // + // a | b | F1 = a != b | F2 = (a != null) | Final = F1 && F2 | + // | | | | | + // 0 | 0 | 0 | 1 | 0 | + // 0 | 1 | 1 | 1 | 1 | + // 1 | 0 | 1 | 1 | 1 | + // 1 | 1 | 0 | 1 | 0 | + // N | 0 | N | 0 | 0 | + // N | 1 | N | 0 | 0 | + private SqlBinaryExpression ExpandNegatedNullableEqualNonNullable( + SqlExpression left, SqlExpression right, SqlExpression leftIsNull) + => _sqlExpressionFactory.AndAlso( + _sqlExpressionFactory.NotEqual(left, right), + _sqlExpressionFactory.Not(leftIsNull)); + + // ?a != ?b -> [(a != b) || (a == null || b == null)] && (a != null || b != null) + // + // a | b | F1 = a != b | F2 = (a == null || b == null) | F3 = F1 || F2 | + // | | | | | + // 0 | 0 | 0 | 0 | 0 | + // 0 | 1 | 1 | 0 | 1 | + // 0 | N | N | 1 | 1 | + // 1 | 0 | 1 | 0 | 1 | + // 1 | 1 | 0 | 0 | 0 | + // 1 | N | N | 1 | 1 | + // N | 0 | N | 1 | 1 | + // N | 1 | N | 1 | 1 | + // N | N | N | 1 | 1 | + // + // a | b | F4 = (a != null || b != null) | Final = F3 && F4 | + // | | | | + // 0 | 0 | 1 | 0 && 1 = 0 | + // 0 | 1 | 1 | 1 && 1 = 1 | + // 0 | N | 1 | 1 && 1 = 1 | + // 1 | 0 | 1 | 1 && 1 = 1 | + // 1 | 1 | 1 | 0 && 1 = 0 | + // 1 | N | 1 | 1 && 1 = 1 | + // N | 0 | 1 | 1 && 1 = 1 | + // N | 1 | 1 | 1 && 1 = 1 | + // N | N | 0 | 1 && 0 = 0 | + private SqlBinaryExpression ExpandNullableNotEqualNullable( + SqlExpression left, SqlExpression right, SqlExpression leftIsNull, SqlExpression rightIsNull) + => _sqlExpressionFactory.AndAlso( + _sqlExpressionFactory.OrElse( + _sqlExpressionFactory.NotEqual(left, right), + _sqlExpressionFactory.OrElse( + leftIsNull, + rightIsNull)), + _sqlExpressionFactory.OrElse( + _sqlExpressionFactory.Not(leftIsNull), + _sqlExpressionFactory.Not(rightIsNull))); + + // !(?a) != ?b -> [(a == b) || (a == null || b == null)] && (a != null || b != null) + // + // a | b | F1 = a == b | F2 = (a == null || b == null) | F3 = F1 || F2 | + // | | | | | + // 0 | 0 | 1 | 0 | 1 | + // 0 | 1 | 0 | 0 | 0 | + // 0 | N | N | 1 | 1 | + // 1 | 0 | 0 | 0 | 0 | + // 1 | 1 | 1 | 0 | 1 | + // 1 | N | N | 1 | 1 | + // N | 0 | N | 1 | 1 | + // N | 1 | N | 1 | 1 | + // N | N | N | 1 | 1 | + // + // a | b | F4 = (a != null || b != null) | Final = F3 && F4 | + // | | | | + // 0 | 0 | 1 | 1 && 1 = 1 | + // 0 | 1 | 1 | 0 && 1 = 0 | + // 0 | N | 1 | 1 && 1 = 1 | + // 1 | 0 | 1 | 0 && 1 = 0 | + // 1 | 1 | 1 | 1 && 1 = 1 | + // 1 | N | 1 | 1 && 1 = 1 | + // N | 0 | 1 | 1 && 1 = 1 | + // N | 1 | 1 | 1 && 1 = 1 | + // N | N | 0 | 1 && 0 = 0 | + private SqlBinaryExpression ExpandNegatedNullableNotEqualNullable( + SqlExpression left, SqlExpression right, SqlExpression leftIsNull, SqlExpression rightIsNull) + => _sqlExpressionFactory.AndAlso( + _sqlExpressionFactory.OrElse( + _sqlExpressionFactory.Equal(left, right), + _sqlExpressionFactory.OrElse( + leftIsNull, + rightIsNull)), + _sqlExpressionFactory.OrElse( + _sqlExpressionFactory.Not(leftIsNull), + _sqlExpressionFactory.Not(rightIsNull))); + + // ?a != b -> (a != b) || (a == null) + // + // a | b | F1 = a != b | F2 = (a == null) | Final = F1 OR F2 | + // | | | | | + // 0 | 0 | 0 | 0 | 0 | + // 0 | 1 | 1 | 0 | 1 | + // 1 | 0 | 1 | 0 | 1 | + // 1 | 1 | 0 | 0 | 0 | + // N | 0 | N | 1 | 1 | + // N | 1 | N | 1 | 1 | + private SqlBinaryExpression ExpandNullableNotEqualNonNullable( + SqlExpression left, SqlExpression right, SqlExpression leftIsNull) + => _sqlExpressionFactory.OrElse( + _sqlExpressionFactory.NotEqual(left, right), + leftIsNull); + + // !(?a) != b -> (a == b) || (a == null) + // + // a | b | F1 = a == b | F2 = (a == null) | F3 = F1 OR F2 | + // | | | | | + // 0 | 0 | 1 | 0 | 1 | + // 0 | 1 | 0 | 0 | 0 | + // 1 | 0 | 0 | 0 | 0 | + // 1 | 1 | 1 | 0 | 1 | + // N | 0 | N | 1 | 1 | + // N | 1 | N | 1 | 1 | + private SqlBinaryExpression ExpandNegatedNullableNotEqualNonNullable( + SqlExpression left, SqlExpression right, SqlExpression leftIsNull) + => _sqlExpressionFactory.OrElse( + _sqlExpressionFactory.Equal(left, right), + leftIsNull); + } +} diff --git a/src/EFCore.Relational/Query/PipeLine/SqlExpressionOptimizingVisitor.cs b/src/EFCore.Relational/Query/PipeLine/SqlExpressionOptimizingVisitor.cs new file mode 100644 index 00000000000..fe00bd0afab --- /dev/null +++ b/src/EFCore.Relational/Query/PipeLine/SqlExpressionOptimizingVisitor.cs @@ -0,0 +1,196 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore.Relational.Query.Pipeline; +using Microsoft.EntityFrameworkCore.Relational.Query.Pipeline.SqlExpressions; + +namespace Microsoft.EntityFrameworkCore.Query.Pipeline +{ + public class SqlExpressionOptimizingVisitor : ExpressionVisitor + { + private readonly ISqlExpressionFactory _sqlExpressionFactory; + + private readonly Dictionary _expressionTypesNegationMap + = new Dictionary + { + { ExpressionType.AndAlso, ExpressionType.OrElse }, + { ExpressionType.OrElse, ExpressionType.AndAlso }, + { ExpressionType.Equal, ExpressionType.NotEqual }, + { ExpressionType.NotEqual, ExpressionType.Equal }, + { ExpressionType.GreaterThan, ExpressionType.LessThanOrEqual }, + { ExpressionType.GreaterThanOrEqual, ExpressionType.LessThan }, + { ExpressionType.LessThan, ExpressionType.GreaterThanOrEqual }, + { ExpressionType.LessThanOrEqual, ExpressionType.GreaterThan }, + }; + + public SqlExpressionOptimizingVisitor(ISqlExpressionFactory sqlExpressionFactory) + { + _sqlExpressionFactory = sqlExpressionFactory; + } + + protected override Expression VisitExtension(Expression extensionExpression) + { + if (extensionExpression is SqlUnaryExpression sqlUnaryExpression) + { + return VisitSqlUnaryExpression(sqlUnaryExpression); + } + + if (extensionExpression is SqlBinaryExpression sqlBinaryExpression) + { + return VisitSqlBinaryExpression(sqlBinaryExpression); + } + + return base.VisitExtension(extensionExpression); + } + + private Expression VisitSqlUnaryExpression(SqlUnaryExpression sqlUnaryExpression) + { + // !(true) -> false + // !(false) -> true + if (sqlUnaryExpression.OperatorType == ExpressionType.Not + && sqlUnaryExpression.Operand is SqlConstantExpression innerConstantBool + && innerConstantBool.Value is bool value) + { + return value + ? _sqlExpressionFactory.Constant(false, sqlUnaryExpression.TypeMapping) + : _sqlExpressionFactory.Constant(true, sqlUnaryExpression.TypeMapping); + } + + // NULL IS NULL -> true + // non_nullablee_constant IS NULL -> false + if (sqlUnaryExpression.OperatorType == ExpressionType.Equal + && sqlUnaryExpression.Operand is SqlConstantExpression innerConstantNull1) + { + return _sqlExpressionFactory.Constant(innerConstantNull1.Value == null, sqlUnaryExpression.TypeMapping); + } + + // NULL IS NOT NULL -> false + // non_nullablee_constant IS NOT NULL -> true + if (sqlUnaryExpression.OperatorType == ExpressionType.NotEqual + && sqlUnaryExpression.Operand is SqlConstantExpression innerConstantNull2) + { + return _sqlExpressionFactory.Constant(innerConstantNull2.Value != null, sqlUnaryExpression.TypeMapping); + } + + if (sqlUnaryExpression.Operand is SqlUnaryExpression innerUnary) + { + if (sqlUnaryExpression.OperatorType == ExpressionType.Not) + { + // !(!a) -> a + if (innerUnary.OperatorType == ExpressionType.Not) + { + return Visit(innerUnary.Operand); + } + + if (innerUnary.OperatorType == ExpressionType.Equal) + { + //!(a IS NULL) -> a IS NOT NULL + return Visit(_sqlExpressionFactory.IsNotNull(innerUnary.Operand)); + } + + //!(a IS NOT NULL) -> a IS NULL + if (innerUnary.OperatorType == ExpressionType.NotEqual) + { + return Visit(_sqlExpressionFactory.IsNull(innerUnary.Operand)); + } + } + + // (!a) IS NULL <==> a IS NULL + if (sqlUnaryExpression.OperatorType == ExpressionType.Equal + && innerUnary.OperatorType == ExpressionType.Not) + { + return Visit(_sqlExpressionFactory.IsNull(innerUnary.Operand)); + } + + // (!a) IS NOT NULL <==> a IS NOT NULL + if (sqlUnaryExpression.OperatorType == ExpressionType.NotEqual + && innerUnary.OperatorType == ExpressionType.Not) + { + return Visit(_sqlExpressionFactory.IsNotNull(innerUnary.Operand)); + } + } + + if (sqlUnaryExpression.Operand is SqlBinaryExpression innerBinary) + { + // De Morgan's + if (innerBinary.OperatorType == ExpressionType.AndAlso + || innerBinary.OperatorType == ExpressionType.OrElse) + { + var newLeft = (SqlExpression)Visit(_sqlExpressionFactory.Not(innerBinary.Left)); + var newRight = (SqlExpression)Visit(_sqlExpressionFactory.Not(innerBinary.Right)); + + return innerBinary.OperatorType == ExpressionType.AndAlso + ? _sqlExpressionFactory.OrElse(newLeft, newRight) + : _sqlExpressionFactory.AndAlso(newLeft, newRight); + } + + // note that those optimizations are only valid in 2-value logic + // they are safe to do here because null semantics removes possibility of nulls in the tree + // however if we decide to do "partial" null semantics (that doesn't distinguish between NULL and FALSE, e.g. for predicates) + // we need to be extra careful here + if (_expressionTypesNegationMap.ContainsKey(innerBinary.OperatorType)) + { + return Visit( + _sqlExpressionFactory.MakeBinary( + _expressionTypesNegationMap[innerBinary.OperatorType], + innerBinary.Left, + innerBinary.Right, + innerBinary.TypeMapping)); + } + } + + var newOperand = (SqlExpression)Visit(sqlUnaryExpression.Operand); + + return sqlUnaryExpression.Update(newOperand); + } + + private Expression VisitSqlBinaryExpression(SqlBinaryExpression sqlBinaryExpression) + { + var newLeft = (SqlExpression)Visit(sqlBinaryExpression.Left); + var newRight = (SqlExpression)Visit(sqlBinaryExpression.Right); + + if (sqlBinaryExpression.OperatorType == ExpressionType.AndAlso + || sqlBinaryExpression.OperatorType == ExpressionType.OrElse) + { + + var newLeftConstant = newLeft as SqlConstantExpression; + var newRightConstant = newRight as SqlConstantExpression; + + // true && a -> a + // true || a -> true + // false && a -> false + // false || a -> a + if (newLeftConstant != null) + { + return sqlBinaryExpression.OperatorType == ExpressionType.AndAlso + ? (bool)newLeftConstant.Value + ? newRight + : newLeftConstant + : (bool)newLeftConstant.Value + ? newLeftConstant + : newRight; + } + else if (newRightConstant != null) + { + // a && true -> a + // a || true -> true + // a && false -> false + // a || false -> a + return sqlBinaryExpression.OperatorType == ExpressionType.AndAlso + ? (bool)newRightConstant.Value + ? newLeft + : newRightConstant + : (bool)newRightConstant.Value + ? newRightConstant + : newLeft; + } + + return sqlBinaryExpression.Update(newLeft, newRight); + } + + return sqlBinaryExpression.Update(newLeft, newRight); + } + } +} diff --git a/src/EFCore.Relational/Query/Pipeline/RelationalShapedQueryOptimizingExpressionVisitors.cs b/src/EFCore.Relational/Query/Pipeline/RelationalShapedQueryOptimizingExpressionVisitors.cs index d967cfd30d4..2dbbb73ac75 100644 --- a/src/EFCore.Relational/Query/Pipeline/RelationalShapedQueryOptimizingExpressionVisitors.cs +++ b/src/EFCore.Relational/Query/Pipeline/RelationalShapedQueryOptimizingExpressionVisitors.cs @@ -2,25 +2,38 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Query.Pipeline; namespace Microsoft.EntityFrameworkCore.Relational.Query.Pipeline { public class RelationalShapedQueryOptimizer : ShapedQueryOptimizer { - private QueryCompilationContext2 _queryCompilationContext; + private readonly QueryCompilationContext2 _queryCompilationContext; - public RelationalShapedQueryOptimizer(QueryCompilationContext2 queryCompilationContext) + public RelationalShapedQueryOptimizer( + QueryCompilationContext2 queryCompilationContext, + ISqlExpressionFactory sqlExpressionFactory) { _queryCompilationContext = queryCompilationContext; + SqlExpressionFactory = sqlExpressionFactory; } + protected ISqlExpressionFactory SqlExpressionFactory { get; private set; } + public override Expression Visit(Expression query) { query = base.Visit(query); query = new ShaperExpressionDedupingExpressionVisitor().Process(query); query = new SelectExpressionProjectionApplyingExpressionVisitor().Visit(query); query = new SelectExpressionTableAliasUniquifyingExpressionVisitor().Visit(query); + + if (!RelationalOptionsExtension.Extract(_queryCompilationContext.ContextOptions).UseRelationalNulls) + { + query = new NullSemanticsRewritingVisitor(SqlExpressionFactory).Visit(query); + } + + query = new SqlExpressionOptimizingVisitor(SqlExpressionFactory).Visit(query); query = new NullComparisonTransformingExpressionVisitor().Visit(query); return query; diff --git a/src/EFCore.Relational/Query/Pipeline/RelationalShapedQueryOptimizingExpressionVisitorsFactory.cs b/src/EFCore.Relational/Query/Pipeline/RelationalShapedQueryOptimizingExpressionVisitorsFactory.cs index c4a47132c55..9b8d4766475 100644 --- a/src/EFCore.Relational/Query/Pipeline/RelationalShapedQueryOptimizingExpressionVisitorsFactory.cs +++ b/src/EFCore.Relational/Query/Pipeline/RelationalShapedQueryOptimizingExpressionVisitorsFactory.cs @@ -7,9 +7,16 @@ namespace Microsoft.EntityFrameworkCore.Relational.Query.Pipeline { public class RelationalShapedQueryOptimizerFactory : ShapedQueryOptimizerFactory { + protected ISqlExpressionFactory SqlExpressionFactory { get; private set; } + + public RelationalShapedQueryOptimizerFactory(ISqlExpressionFactory sqlExpressionFactory) + { + SqlExpressionFactory = sqlExpressionFactory; + } + public override ShapedQueryOptimizer Create(QueryCompilationContext2 queryCompilationContext) { - return new RelationalShapedQueryOptimizer(queryCompilationContext); + return new RelationalShapedQueryOptimizer(queryCompilationContext, SqlExpressionFactory); } } } diff --git a/src/EFCore.Relational/Query/Pipeline/SqlExpressions/CaseExpression.cs b/src/EFCore.Relational/Query/Pipeline/SqlExpressions/CaseExpression.cs index 7de293657d8..7d26989c63b 100644 --- a/src/EFCore.Relational/Query/Pipeline/SqlExpressions/CaseExpression.cs +++ b/src/EFCore.Relational/Query/Pipeline/SqlExpressions/CaseExpression.cs @@ -82,9 +82,13 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) } public virtual CaseExpression Update( - SqlExpression operand, IReadOnlyList whenClauses, SqlExpression elseResult) + SqlExpression operand, + IReadOnlyList whenClauses, + SqlExpression elseResult) { - return new CaseExpression(operand, whenClauses, elseResult); + return operand != Operand || !whenClauses.SequenceEqual(WhenClauses) || elseResult != ElseResult + ? new CaseExpression(operand, whenClauses, elseResult) + : this; } #endregion diff --git a/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SqlBinaryExpression.cs b/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SqlBinaryExpression.cs index cde35b5a0ab..642b5e13587 100644 --- a/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SqlBinaryExpression.cs +++ b/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SqlBinaryExpression.cs @@ -136,7 +136,7 @@ public override void Print(ExpressionPrinter expressionPrinter) expressionPrinter.StringBuilder.Append(")"); } - expressionPrinter.StringBuilder.Append(" " + expressionPrinter.GenerateBinaryOperator(OperatorType) + " "); + expressionPrinter.StringBuilder.Append(expressionPrinter.GenerateBinaryOperator(OperatorType)); requiresBrackets = RequiresBrackets(Right); diff --git a/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SqlFunctionExpression.cs b/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SqlFunctionExpression.cs index a467408fdc8..3bdce32f63b 100644 --- a/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SqlFunctionExpression.cs +++ b/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SqlFunctionExpression.cs @@ -136,7 +136,7 @@ public SqlFunctionExpression ApplyTypeMapping(RelationalTypeMapping typeMapping) public SqlFunctionExpression Update(SqlExpression instance, IReadOnlyList arguments) { - return instance != Instance || arguments != Arguments + return instance != Instance || !arguments.SequenceEqual(Arguments) ? new SqlFunctionExpression(instance, Schema, FunctionName, IsNiladic, arguments, Type, TypeMapping) : this; } diff --git a/src/EFCore.SqlServer/Query/Pipeline/SearchConditionConvertingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Pipeline/SearchConditionConvertingExpressionVisitor.cs index 76cceb85b81..9419c3b171f 100644 --- a/src/EFCore.SqlServer/Query/Pipeline/SearchConditionConvertingExpressionVisitor.cs +++ b/src/EFCore.SqlServer/Query/Pipeline/SearchConditionConvertingExpressionVisitor.cs @@ -24,9 +24,9 @@ private Expression ApplyConversion(SqlExpression sqlExpression, bool condition) : ConvertToValue(sqlExpression, condition); private Expression ConvertToSearchCondition(SqlExpression sqlExpression, bool condition) - => condition - ? sqlExpression - : BuildCompareToExpression(sqlExpression); + => condition + ? sqlExpression + : BuildCompareToExpression(sqlExpression); private Expression ConvertToValue(SqlExpression sqlExpression, bool condition) { @@ -246,67 +246,15 @@ protected override Expression VisitSqlFunction(SqlFunctionExpression sqlFunction { var parentSearchCondition = _isSearchCondition; _isSearchCondition = false; - var changed = false; var instance = (SqlExpression)Visit(sqlFunctionExpression.Instance); - changed |= instance != sqlFunctionExpression.Instance; var arguments = new SqlExpression[sqlFunctionExpression.Arguments.Count]; for (var i = 0; i < arguments.Length; i++) { arguments[i] = (SqlExpression)Visit(sqlFunctionExpression.Arguments[i]); - changed |= arguments[i] != sqlFunctionExpression.Arguments[i]; } _isSearchCondition = parentSearchCondition; - SqlExpression newFunction; - if (changed) - { - if (sqlFunctionExpression.Instance != null) - { - if (sqlFunctionExpression.IsNiladic) - { - newFunction = _sqlExpressionFactory.Function( - instance, - sqlFunctionExpression.FunctionName, - sqlFunctionExpression.IsNiladic, - sqlFunctionExpression.Type, - sqlFunctionExpression.TypeMapping); - } - else - { - newFunction = _sqlExpressionFactory.Function( - instance, - sqlFunctionExpression.FunctionName, - arguments, - sqlFunctionExpression.Type, - sqlFunctionExpression.TypeMapping); - } - } - else - { - if (sqlFunctionExpression.IsNiladic) - { - newFunction = _sqlExpressionFactory.Function( - sqlFunctionExpression.Schema, - sqlFunctionExpression.FunctionName, - sqlFunctionExpression.IsNiladic, - sqlFunctionExpression.Type, - sqlFunctionExpression.TypeMapping); - } - else - { - newFunction = _sqlExpressionFactory.Function( - sqlFunctionExpression.Schema, - sqlFunctionExpression.FunctionName, - arguments, - sqlFunctionExpression.Type, - sqlFunctionExpression.TypeMapping); - } - } - } - else - { - newFunction = sqlFunctionExpression; - } + var newFunction = sqlFunctionExpression.Update(instance, arguments); var condition = string.Equals(sqlFunctionExpression.FunctionName, "FREETEXT") || string.Equals(sqlFunctionExpression.FunctionName, "CONTAINS"); diff --git a/src/EFCore.SqlServer/Query/Pipeline/SqlServerShapedQueryOptimizingExpressionVisitors.cs b/src/EFCore.SqlServer/Query/Pipeline/SqlServerShapedQueryOptimizingExpressionVisitors.cs index 4495ddf42a7..020992d9724 100644 --- a/src/EFCore.SqlServer/Query/Pipeline/SqlServerShapedQueryOptimizingExpressionVisitors.cs +++ b/src/EFCore.SqlServer/Query/Pipeline/SqlServerShapedQueryOptimizingExpressionVisitors.cs @@ -1,7 +1,6 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System.Collections.Generic; using System.Linq.Expressions; using Microsoft.EntityFrameworkCore.Query.Pipeline; using Microsoft.EntityFrameworkCore.Relational.Query.Pipeline; @@ -10,20 +9,17 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Pipeline { public class SqlServerShapedQueryOptimizer : RelationalShapedQueryOptimizer { - private readonly ISqlExpressionFactory _sqlExpressionFactory; - public SqlServerShapedQueryOptimizer( QueryCompilationContext2 queryCompilationContext, ISqlExpressionFactory sqlExpressionFactory) - : base(queryCompilationContext) + : base(queryCompilationContext, sqlExpressionFactory) { - _sqlExpressionFactory = sqlExpressionFactory; } public override Expression Visit(Expression query) { query = base.Visit(query); - query = new SearchConditionConvertingExpressionVisitor(_sqlExpressionFactory).Visit(query); + query = new SearchConditionConvertingExpressionVisitor(SqlExpressionFactory).Visit(query); return query; } diff --git a/src/EFCore.SqlServer/Query/Pipeline/SqlServerShapedQueryOptimizingExpressionVisitorsFactory.cs b/src/EFCore.SqlServer/Query/Pipeline/SqlServerShapedQueryOptimizingExpressionVisitorsFactory.cs index 867b048ac29..72fbea5e4ed 100644 --- a/src/EFCore.SqlServer/Query/Pipeline/SqlServerShapedQueryOptimizingExpressionVisitorsFactory.cs +++ b/src/EFCore.SqlServer/Query/Pipeline/SqlServerShapedQueryOptimizingExpressionVisitorsFactory.cs @@ -3,22 +3,19 @@ using Microsoft.EntityFrameworkCore.Query.Pipeline; using Microsoft.EntityFrameworkCore.Relational.Query.Pipeline; -using Microsoft.EntityFrameworkCore.Storage; namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Pipeline { public class SqlServerShapedQueryOptimizerFactory : RelationalShapedQueryOptimizerFactory { - private readonly ISqlExpressionFactory _sqlExpressionFactory; - public SqlServerShapedQueryOptimizerFactory(ISqlExpressionFactory sqlExpressionFactory) + : base(sqlExpressionFactory) { - _sqlExpressionFactory = sqlExpressionFactory; } public override ShapedQueryOptimizer Create(QueryCompilationContext2 queryCompilationContext) { - return new SqlServerShapedQueryOptimizer(queryCompilationContext, _sqlExpressionFactory); + return new SqlServerShapedQueryOptimizer(queryCompilationContext, SqlExpressionFactory); } } } diff --git a/src/EFCore.SqlServer/Query/Pipeline/SqlServerStringMethodTranslator.cs b/src/EFCore.SqlServer/Query/Pipeline/SqlServerStringMethodTranslator.cs index c12de9dc366..f5e7c0a066c 100644 --- a/src/EFCore.SqlServer/Query/Pipeline/SqlServerStringMethodTranslator.cs +++ b/src/EFCore.SqlServer/Query/Pipeline/SqlServerStringMethodTranslator.cs @@ -266,10 +266,7 @@ private SqlExpression TranslateStartsEndsWith(SqlExpression instance, SqlExpress instance, _sqlExpressionFactory.Constant(null, stringTypeMapping)); } - if (constantString.Length == 0) - { - return _sqlExpressionFactory.Constant(true); - } + return constantString.Any(c => IsLikeWildChar(c)) ? _sqlExpressionFactory.Like( instance, @@ -289,42 +286,34 @@ private SqlExpression TranslateStartsEndsWith(SqlExpression instance, SqlExpress // because of wildchars). if (startsWith) { - return _sqlExpressionFactory.OrElse( - _sqlExpressionFactory.AndAlso( - _sqlExpressionFactory.Like( + return _sqlExpressionFactory.AndAlso( + _sqlExpressionFactory.Like( + instance, + _sqlExpressionFactory.Add( instance, - _sqlExpressionFactory.Add( - instance, - _sqlExpressionFactory.Constant("%"))), - _sqlExpressionFactory.Equal( - _sqlExpressionFactory.Function( - "LEFT", - new[] { - instance, - _sqlExpressionFactory.Function("LEN", new[] { pattern }, typeof(int)) - }, - typeof(string), - stringTypeMapping), - pattern)), + _sqlExpressionFactory.Constant("%"))), _sqlExpressionFactory.Equal( - pattern, - _sqlExpressionFactory.Constant(string.Empty))); + _sqlExpressionFactory.Function( + "LEFT", + new[] { + instance, + _sqlExpressionFactory.Function("LEN", new[] { pattern }, typeof(int)) + }, + typeof(string), + stringTypeMapping), + pattern)); } - return _sqlExpressionFactory.OrElse( - _sqlExpressionFactory.Equal( - _sqlExpressionFactory.Function( - "RIGHT", - new[] { - instance, - _sqlExpressionFactory.Function("LEN", new[] { pattern }, typeof(int)) - }, - typeof(string), - stringTypeMapping), - pattern), - _sqlExpressionFactory.Equal( - pattern, - _sqlExpressionFactory.Constant(string.Empty))); + return _sqlExpressionFactory.Equal( + _sqlExpressionFactory.Function( + "RIGHT", + new[] { + instance, + _sqlExpressionFactory.Function("LEN", new[] { pattern }, typeof(int)) + }, + typeof(string), + stringTypeMapping), + pattern); } // See https://docs.microsoft.com/en-us/sql/t-sql/language-elements/like-transact-sql @@ -340,8 +329,10 @@ private string EscapeLikePattern(string pattern) { builder.Append(LikeEscapeChar); } + builder.Append(c); } + return builder.ToString(); } } diff --git a/src/EFCore/Query/NavigationExpansion/NavigationExpander.cs b/src/EFCore/Query/NavigationExpansion/NavigationExpander.cs index 44aa564ffa3..2832b263801 100644 --- a/src/EFCore/Query/NavigationExpansion/NavigationExpander.cs +++ b/src/EFCore/Query/NavigationExpansion/NavigationExpander.cs @@ -3,7 +3,6 @@ using System.Linq.Expressions; using JetBrains.Annotations; -using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Visitors; using Microsoft.EntityFrameworkCore.Utilities; diff --git a/src/EFCore/Query/PipeLine/NegationOptimizingVisitor.cs b/src/EFCore/Query/PipeLine/NegationOptimizingVisitor.cs new file mode 100644 index 00000000000..5c1ccdd6b37 --- /dev/null +++ b/src/EFCore/Query/PipeLine/NegationOptimizingVisitor.cs @@ -0,0 +1,70 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Linq.Expressions; + +namespace Microsoft.EntityFrameworkCore.Query.Pipeline +{ + public class NegationOptimizingVisitor : ExpressionVisitor + { + private readonly Dictionary _expressionTypesNegationMap + = new Dictionary + { + { ExpressionType.AndAlso, ExpressionType.OrElse }, + { ExpressionType.OrElse, ExpressionType.AndAlso }, + { ExpressionType.Equal, ExpressionType.NotEqual }, + { ExpressionType.NotEqual, ExpressionType.Equal }, + { ExpressionType.GreaterThan, ExpressionType.LessThanOrEqual }, + { ExpressionType.GreaterThanOrEqual, ExpressionType.LessThan }, + { ExpressionType.LessThan, ExpressionType.GreaterThanOrEqual }, + { ExpressionType.LessThanOrEqual, ExpressionType.GreaterThan }, + }; + + protected override Expression VisitUnary(UnaryExpression unaryExpression) + { + if (unaryExpression.NodeType == ExpressionType.Not) + { + if (unaryExpression.Operand is ConstantExpression innerConstant + && innerConstant.Value is bool value) + { + // !(true) -> false + // !(false) -> true + return Expression.Constant(!value); + } + + if (unaryExpression.Operand is UnaryExpression innerUnary + && innerUnary.NodeType == ExpressionType.Not) + { + // !(!a) -> a + return Visit(innerUnary.Operand); + } + + if (unaryExpression.Operand is BinaryExpression innerBinary) + { + // De Morgan's + if (innerBinary.NodeType == ExpressionType.AndAlso + || innerBinary.NodeType == ExpressionType.OrElse) + { + return Visit( + Expression.MakeBinary( + _expressionTypesNegationMap[innerBinary.NodeType], + Expression.Not(innerBinary.Left), + Expression.Not(innerBinary.Right))); + } + + if (_expressionTypesNegationMap.ContainsKey(innerBinary.NodeType)) + { + return Visit( + Expression.MakeBinary( + _expressionTypesNegationMap[innerBinary.NodeType], + innerBinary.Left, + innerBinary.Right)); + } + } + } + + return base.VisitUnary(unaryExpression); + } + } +} diff --git a/src/EFCore/Query/Pipeline/FunctionPreprocessingVisitor.cs b/src/EFCore/Query/Pipeline/FunctionPreprocessingVisitor.cs new file mode 100644 index 00000000000..e4c674811bd --- /dev/null +++ b/src/EFCore/Query/Pipeline/FunctionPreprocessingVisitor.cs @@ -0,0 +1,87 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Linq.Expressions; +using System.Reflection; + +namespace Microsoft.EntityFrameworkCore.Query.Pipeline +{ + public class FunctionPreprocessingVisitor : ExpressionVisitor + { + private static readonly MethodInfo _startsWithMethodInfo + = typeof(string).GetRuntimeMethod(nameof(string.StartsWith), new[] { typeof(string) }); + + private static readonly MethodInfo _endsWithMethodInfo + = typeof(string).GetRuntimeMethod(nameof(string.EndsWith), new[] { typeof(string) }); + + private static Expression _constantNullString = Expression.Constant(null, typeof(string)); + + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + { + if (_startsWithMethodInfo.Equals(methodCallExpression.Method) + || _endsWithMethodInfo.Equals(methodCallExpression.Method)) + { + if (methodCallExpression.Arguments[0] is ConstantExpression constantArgument + && (string)constantArgument.Value == string.Empty) + { + // every string starts/ends with empty string. + return Expression.Constant(true); + } + + var newObject = Visit(methodCallExpression.Object); + var newArgument = Visit(methodCallExpression.Arguments[0]); + + var result = Expression.AndAlso( + Expression.NotEqual(newObject, _constantNullString), + Expression.AndAlso( + Expression.NotEqual(newArgument, _constantNullString), + methodCallExpression.Update(newObject, new[] { newArgument }))); + + return newArgument is ConstantExpression + ? result + : Expression.OrElse( + Expression.Equal( + newArgument, + Expression.Constant(string.Empty)), + result); + } + + return base.VisitMethodCall(methodCallExpression); + } + + protected override Expression VisitUnary(UnaryExpression unaryExpression) + { + if (unaryExpression.NodeType == ExpressionType.Not + && unaryExpression.Operand is MethodCallExpression innerMethodCall + && (_startsWithMethodInfo.Equals(innerMethodCall.Method) + || _endsWithMethodInfo.Equals(innerMethodCall.Method))) + { + if (innerMethodCall.Arguments[0] is ConstantExpression constantArgument + && (string)constantArgument.Value == string.Empty) + { + // every string starts/ends with empty string. + return Expression.Constant(false); + } + + var newObject = Visit(innerMethodCall.Object); + var newArgument = Visit(innerMethodCall.Arguments[0]); + + var result = Expression.AndAlso( + Expression.NotEqual(newObject, _constantNullString), + Expression.AndAlso( + Expression.NotEqual(newArgument, _constantNullString), + Expression.Not(innerMethodCall.Update(newObject, new[] { newArgument })))); + + return newArgument is ConstantExpression + ? result + : Expression.AndAlso( + Expression.NotEqual( + newArgument, + Expression.Constant(string.Empty)), + result); + } + + return base.VisitUnary(unaryExpression); + } + } +} diff --git a/src/EFCore/Query/Pipeline/QueryCompilationContext2.cs b/src/EFCore/Query/Pipeline/QueryCompilationContext2.cs index 824e15440da..b4251b39775 100644 --- a/src/EFCore/Query/Pipeline/QueryCompilationContext2.cs +++ b/src/EFCore/Query/Pipeline/QueryCompilationContext2.cs @@ -4,7 +4,6 @@ using System; using System.Linq.Expressions; using Microsoft.EntityFrameworkCore.Infrastructure; -using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Metadata; namespace Microsoft.EntityFrameworkCore.Query.Pipeline @@ -27,11 +26,14 @@ public QueryCompilationContext2( IShapedQueryOptimizerFactory shapedQueryOptimizerFactory, IShapedQueryCompilingExpressionVisitorFactory shapedQueryCompilingExpressionVisitorFactory, ICurrentDbContext currentDbContext, + IDbContextOptions contextOptions, bool async) { Async = async; TrackQueryResults = currentDbContext.Context.ChangeTracker.QueryTrackingBehavior == QueryTrackingBehavior.TrackAll; Model = model; + ContextOptions = contextOptions; + _queryOptimizerFactory = queryOptimizerFactory; _entityQueryableTranslatorFactory = entityQuerableTranslatorFactory; _queryableMethodTranslatingExpressionVisitorFactory = queryableMethodTranslatingExpressionVisitorFactory; @@ -41,6 +43,7 @@ public QueryCompilationContext2( public bool Async { get; } public IModel Model { get; } + public IDbContextOptions ContextOptions { get; } public bool TrackQueryResults { get; internal set; } public virtual Func CreateQueryExecutor(Expression query) diff --git a/src/EFCore/Query/Pipeline/QueryCompilationContextFactory2.cs b/src/EFCore/Query/Pipeline/QueryCompilationContextFactory2.cs index 04123f5f827..5501d919f0f 100644 --- a/src/EFCore/Query/Pipeline/QueryCompilationContextFactory2.cs +++ b/src/EFCore/Query/Pipeline/QueryCompilationContextFactory2.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using Microsoft.EntityFrameworkCore.Infrastructure; -using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Metadata; namespace Microsoft.EntityFrameworkCore.Query.Pipeline @@ -16,6 +15,7 @@ public class QueryCompilationContextFactory2 : IQueryCompilationContextFactory2 private readonly IShapedQueryOptimizerFactory _shapedQueryOptimizerFactory; private readonly IShapedQueryCompilingExpressionVisitorFactory _shapedQueryCompilingExpressionVisitorFactory; private readonly ICurrentDbContext _currentDbContext; + private readonly IDbContextOptions _contextOptions; public QueryCompilationContextFactory2( IModel model, @@ -24,7 +24,8 @@ public QueryCompilationContextFactory2( IQueryableMethodTranslatingExpressionVisitorFactory queryableMethodTranslatingExpressionVisitorFactory, IShapedQueryOptimizerFactory shapedQueryOptimizerFactory, IShapedQueryCompilingExpressionVisitorFactory shapedQueryCompilingExpressionVisitorFactory, - ICurrentDbContext currentDbContext) + ICurrentDbContext currentDbContext, + IDbContextOptions contextOptions) { _model = model; _queryOptimizerFactory = queryOptimizerFactory; @@ -33,6 +34,7 @@ public QueryCompilationContextFactory2( _shapedQueryOptimizerFactory = shapedQueryOptimizerFactory; _shapedQueryCompilingExpressionVisitorFactory = shapedQueryCompilingExpressionVisitorFactory; _currentDbContext = currentDbContext; + _contextOptions = contextOptions; } public QueryCompilationContext2 Create(bool async) @@ -45,6 +47,7 @@ public QueryCompilationContext2 Create(bool async) _shapedQueryOptimizerFactory, _shapedQueryCompilingExpressionVisitorFactory, _currentDbContext, + _contextOptions, async); return queryCompilationContext; diff --git a/src/EFCore/Query/Pipeline/QueryOptimizingExpressionVisitor.cs b/src/EFCore/Query/Pipeline/QueryOptimizingExpressionVisitor.cs index d8eed338b7e..07cb00a269e 100644 --- a/src/EFCore/Query/Pipeline/QueryOptimizingExpressionVisitor.cs +++ b/src/EFCore/Query/Pipeline/QueryOptimizingExpressionVisitor.cs @@ -26,6 +26,7 @@ public Expression Visit(Expression query) query = new QueryMetadataExtractingExpressionVisitor(_queryCompilationContext).Visit(query); query = new GroupJoinFlatteningExpressionVisitor().Visit(query); query = new NullCheckRemovingExpressionVisitor().Visit(query); + query = new FunctionPreprocessingVisitor().Visit(query); new EnumerableVerifyingExpressionVisitor().Visit(query); return query; diff --git a/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs index e00712239d3..a9d93eb3625 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs @@ -531,15 +531,15 @@ public virtual void Null_comparison_in_order_by_with_relational_nulls() } } - [Fact] + [Fact(Skip = "issue #15743")] public virtual void Null_comparison_in_join_key_with_relational_nulls() { using (var ctx = CreateContext(useRelationalNulls: true)) { var query = ctx.Entities1.Join(ctx.Entities2, e1 => e1.NullableStringA != "Foo", e2 => e2.NullableBoolB != true, (o, i) => new { o, i }); - var result = query.ToList(); - Assert.Equal(405, result.Count); + var result = query.ToList(); + Assert.Equal(162, result.Count); } } @@ -729,7 +729,7 @@ public virtual void Switching_parameter_value_to_null_produces_different_cache_e } } - [Fact] + [Fact(Skip = "issue #15704")] public virtual void From_sql_composed_with_relational_null_comparison() { using (var context = CreateContext(useRelationalNulls: true)) @@ -879,7 +879,109 @@ public virtual void Null_semantics_applied_when_comparing_two_functions_with_mul useRelationalNulls: false); } - public static TResult? MaybeScalar(object caller, Func expression) + [Fact] + public virtual void Null_semantics_coalesce() + { + AssertQuery(es => es.Where(e => e.NullableBoolA == (e.NullableBoolB ?? e.BoolC))); + AssertQuery(es => es.Where(e => e.NullableBoolA == (e.NullableBoolB ?? e.NullableBoolC))); + AssertQuery(es => es.Where(e => (e.NullableBoolB ?? e.BoolC) != e.NullableBoolA)); + AssertQuery(es => es.Where(e => (e.NullableBoolB ?? e.NullableBoolC) != e.NullableBoolA)); + } + + [Fact] + public virtual void Null_semantics_conditional() + { + AssertQuery(es => es.Where(e => e.BoolA == (e.BoolB ? e.NullableBoolB : e.NullableBoolC))); + AssertQuery(es => es.Where(e => (e.NullableBoolA != e.NullableBoolB ? e.BoolB : e.BoolC) == e.BoolA)); + AssertQuery(es => es.Where(e => (e.BoolA ? e.NullableBoolA != e.NullableBoolB : e.BoolC) != e.BoolB ? e.BoolA : e.NullableBoolB == e.NullableBoolC)); + } + + [Fact] + public virtual void Null_semantics_function() + { + AssertQuery( + es => es.Where(e => e.NullableStringA.Substring(0, e.IntA) != e.NullableStringB), + es => es.Where(e => Maybe(e.NullableIntA, () => e.NullableStringA.Substring(0, e.IntA)) != e.NullableStringB), + useRelationalNulls: false); + } + + [Fact] + public virtual void Null_semantics_join_with_composite_key() + { + using (var ctx = CreateContext()) + { + var query = from e1 in ctx.Entities1 + join e2 in ctx.Entities2 + on new + { + one = e1.NullableStringA, + two = e1.NullableStringB != e1.NullableStringC, + three = true + } + equals new + { + one = e2.NullableStringB, + two = e2.NullableBoolA ?? e2.BoolC, + three = true + } + select new { e1, e2 }; + + var result = query.ToList(); + + var expected = (from e1 in ctx.Entities1.ToList() + join e2 in ctx.Entities2.ToList() + on new + { + one = e1.NullableStringA, + two = e1.NullableStringB != e1.NullableStringC, + three = true + } + equals new + { + one = e2.NullableStringB, + two = e2.NullableBoolA ?? e2.BoolC, + three = true + } + select new { e1, e2 }).ToList(); + + Assert.Equal(result.Count, expected.Count); + } + } + + [Fact(Skip = "issue #14171")] + public virtual void Null_semantics_contains() + { + using (var ctx = CreateContext()) + { + var ids = new List { 1, 2 }; + var query1 = ctx.Entities1.Where(e => ids.Contains(e.NullableIntA)); + var result1 = query1.ToList(); + + var query2 = ctx.Entities1.Where(e => !ids.Contains(e.NullableIntA)); + var result2 = query2.ToList(); + + var ids2 = new List { 1, 2, null }; + var query3 = ctx.Entities1.Where(e => ids.Contains(e.NullableIntA)); + var result3 = query3.ToList(); + + var query4 = ctx.Entities1.Where(e => !ids.Contains(e.NullableIntA)); + var result4 = query4.ToList(); + + var query5 = ctx.Entities1.Where(e => !new List { 1, 2 }.Contains(e.NullableIntA)); + var result5 = query5.ToList(); + + var query6 = ctx.Entities1.Where(e => !new List { 1, 2, null }.Contains(e.NullableIntA)); + var result6 = query6.ToList(); + } + } + + protected static TResult Maybe(object caller, Func expression) + where TResult : class + { + return caller == null ? null : expression(); + } + + protected static TResult? MaybeScalar(object caller, Func expression) where TResult : struct { return caller == null ? null : expression(); diff --git a/test/EFCore.Specification.Tests/Query/QueryTestBase.cs b/test/EFCore.Specification.Tests/Query/QueryTestBase.cs index 1b09ba6ebb7..8b37a0da833 100644 --- a/test/EFCore.Specification.Tests/Query/QueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/QueryTestBase.cs @@ -1605,19 +1605,19 @@ public static Action CollectionAsserter( #region Helpers - Maybe - public static TResult Maybe(object caller, Func expression) + protected static TResult Maybe(object caller, Func expression) where TResult : class { return caller == null ? null : expression(); } - public static TResult? MaybeScalar(object caller, Func expression) + protected static TResult? MaybeScalar(object caller, Func expression) where TResult : struct { return caller == null ? null : expression(); } - public static IEnumerable MaybeDefaultIfEmpty(IEnumerable caller) + protected static IEnumerable MaybeDefaultIfEmpty(IEnumerable caller) where TResult : class { return caller == null diff --git a/test/EFCore.SqlServer.FunctionalTests/BuiltInDataTypesSqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/BuiltInDataTypesSqlServerTest.cs index 094b22f6d5b..6e7efca1bd1 100644 --- a/test/EFCore.SqlServer.FunctionalTests/BuiltInDataTypesSqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/BuiltInDataTypesSqlServerTest.cs @@ -49,7 +49,7 @@ var results Assert.Equal( @"SELECT [m].[Int] FROM [MappedNullableDataTypes] AS [m] -WHERE [m].[TimeSpanAsTime] = '00:01:02'", +WHERE ([m].[TimeSpanAsTime] = '00:01:02') AND [m].[TimeSpanAsTime] IS NOT NULL", Sql, ignoreLineEndingDifferences: true); } @@ -93,7 +93,7 @@ var results SELECT [m].[Int] FROM [MappedNullableDataTypes] AS [m] -WHERE [m].[TimeSpanAsTime] = @__timeSpan_0", +WHERE (([m].[TimeSpanAsTime] = @__timeSpan_0) AND ([m].[TimeSpanAsTime] IS NOT NULL AND @__timeSpan_0 IS NOT NULL)) OR ([m].[TimeSpanAsTime] IS NULL AND @__timeSpan_0 IS NULL)", Sql, ignoreLineEndingDifferences: true); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs index 69ace1943aa..187c08428ed 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs @@ -7,7 +7,7 @@ namespace Microsoft.EntityFrameworkCore.Query { - internal class NullSemanticsQuerySqlServerTest : NullSemanticsQueryTestBase + public class NullSemanticsQuerySqlServerTest : NullSemanticsQueryTestBase { // ReSharper disable once UnusedParameter.Local public NullSemanticsQuerySqlServerTest(NullSemanticsQuerySqlServerFixture fixture, ITestOutputHelper testOutputHelper) @@ -1295,6 +1295,93 @@ FROM [Entities1] AS [e] WHERE ((REPLACE([e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC]) <> [e].[NullableStringA]) OR ((([e].[NullableStringA] IS NULL OR [e].[NullableStringB] IS NULL) OR [e].[NullableStringC] IS NULL) OR [e].[NullableStringA] IS NULL)) AND ((([e].[NullableStringA] IS NOT NULL AND [e].[NullableStringB] IS NOT NULL) AND [e].[NullableStringC] IS NOT NULL) OR [e].[NullableStringA] IS NOT NULL)"); } + public override void Null_semantics_coalesce() + { + base.Null_semantics_coalesce(); + + AssertSql( + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE ([e].[NullableBoolA] = CAST(COALESCE([e].[NullableBoolB], [e].[BoolC]) AS bit)) AND [e].[NullableBoolA] IS NOT NULL", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE (([e].[NullableBoolA] = COALESCE([e].[NullableBoolB], [e].[NullableBoolC])) AND ([e].[NullableBoolA] IS NOT NULL AND COALESCE([e].[NullableBoolB], [e].[NullableBoolC]) IS NOT NULL)) OR ([e].[NullableBoolA] IS NULL AND COALESCE([e].[NullableBoolB], [e].[NullableBoolC]) IS NULL)", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE (CAST(COALESCE([e].[NullableBoolB], [e].[BoolC]) AS bit) <> [e].[NullableBoolA]) OR [e].[NullableBoolA] IS NULL", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE ((COALESCE([e].[NullableBoolB], [e].[NullableBoolC]) <> [e].[NullableBoolA]) OR (COALESCE([e].[NullableBoolB], [e].[NullableBoolC]) IS NULL OR [e].[NullableBoolA] IS NULL)) AND (COALESCE([e].[NullableBoolB], [e].[NullableBoolC]) IS NOT NULL OR [e].[NullableBoolA] IS NOT NULL)"); + } + + public override void Null_semantics_conditional() + { + base.Null_semantics_conditional(); + + AssertSql( + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE ([e].[BoolA] = CASE + WHEN [e].[BoolB] = CAST(1 AS bit) THEN [e].[NullableBoolB] + ELSE [e].[NullableBoolC] +END) AND CASE + WHEN [e].[BoolB] = CAST(1 AS bit) THEN [e].[NullableBoolB] + ELSE [e].[NullableBoolC] +END IS NOT NULL", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE CASE + WHEN (([e].[NullableBoolA] <> [e].[NullableBoolB]) OR ([e].[NullableBoolA] IS NULL OR [e].[NullableBoolB] IS NULL)) AND ([e].[NullableBoolA] IS NOT NULL OR [e].[NullableBoolB] IS NOT NULL) THEN [e].[BoolB] + ELSE [e].[BoolC] +END = [e].[BoolA]", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE CASE + WHEN CASE + WHEN [e].[BoolA] = CAST(1 AS bit) THEN CASE + WHEN (([e].[NullableBoolA] <> [e].[NullableBoolB]) OR ([e].[NullableBoolA] IS NULL OR [e].[NullableBoolB] IS NULL)) AND ([e].[NullableBoolA] IS NOT NULL OR [e].[NullableBoolB] IS NOT NULL) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) + END + ELSE [e].[BoolC] + END <> [e].[BoolB] THEN [e].[BoolA] + ELSE CASE + WHEN (([e].[NullableBoolB] = [e].[NullableBoolC]) AND ([e].[NullableBoolB] IS NOT NULL AND [e].[NullableBoolC] IS NOT NULL)) OR ([e].[NullableBoolB] IS NULL AND [e].[NullableBoolC] IS NULL) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) + END +END = CAST(1 AS bit)"); + } + + public override void Null_semantics_function() + { + base.Null_semantics_function(); + + AssertSql( + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE ((SUBSTRING([e].[NullableStringA], 0 + 1, [e].[IntA]) <> [e].[NullableStringB]) OR (SUBSTRING([e].[NullableStringA], 0 + 1, [e].[IntA]) IS NULL OR [e].[NullableStringB] IS NULL)) AND (SUBSTRING([e].[NullableStringA], 0 + 1, [e].[IntA]) IS NOT NULL OR [e].[NullableStringB] IS NOT NULL)"); + } + + public override void Null_semantics_join_with_composite_key() + { + base.Null_semantics_join_with_composite_key(); + + AssertSql( + @""); + } + + public override void Null_semantics_contains() + { + base.Null_semantics_contains(); + + AssertSql( + @""); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/NullSemanticsQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/NullSemanticsQuerySqliteTest.cs index 6c8797a074d..fa84596f92f 100644 --- a/test/EFCore.Sqlite.FunctionalTests/Query/NullSemanticsQuerySqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/Query/NullSemanticsQuerySqliteTest.cs @@ -6,7 +6,7 @@ namespace Microsoft.EntityFrameworkCore.Query { - internal class NullSemanticsQuerySqliteTest : NullSemanticsQueryTestBase + public class NullSemanticsQuerySqliteTest : NullSemanticsQueryTestBase { public NullSemanticsQuerySqliteTest(NullSemanticsQuerySqliteFixture fixture) : base(fixture)