Skip to content

Commit b26c794

Browse files
committed
Additional refactoring of Null Semantics:
- moving NullSemantics visitor after 2nd level cache - we need to know the parameter values to properly handle IN expressions wrt null semantics, - NullSemantics visitor needs to go before SqlExpressionOptimizer and SearchCondition, so those two are also moved after 2nd level cache, - moving optimizations that depend on knowing the nullability to NullSemantics visitor - optimizer now only contains optimizations that also work in 3-value logic, or when we know nulls can't happen, - merging InExpressionValuesExpandingExpressionVisitor int NullSemantics visitor, so that we don't apply the rewrite for UseRelationalNulls. Resolves #11464 Resolves #15722 Resolved #18338 Resolves #18597 Resolves #18689
1 parent 29de3d6 commit b26c794

20 files changed

+951
-688
lines changed

src/EFCore.Relational/Query/Internal/NullSemanticsRewritingExpressionVisitor.cs

+666-169
Large diffs are not rendered by default.

src/EFCore.Relational/Query/Internal/SqlExpressionOptimizingExpressionVisitor.cs

+19-252
Large diffs are not rendered by default.

src/EFCore.Relational/Query/RelationalParameterBasedQueryTranslationPostprocessor.cs

+8-164
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

44
using System;
5-
using System.Collections;
65
using System.Collections.Generic;
76
using System.Data.Common;
87
using System.Linq.Expressions;
@@ -38,17 +37,19 @@ public virtual (SelectExpression selectExpression, bool canCache) Optimize(
3837
SelectExpression selectExpression, IReadOnlyDictionary<string, object> parametersValues)
3938
{
4039
var canCache = true;
40+
var nullSemanticsRewritingExpressionVisitor = new NullSemanticsRewritingExpressionVisitor(
41+
UseRelationalNulls,
42+
Dependencies.SqlExpressionFactory,
43+
parametersValues);
4144

42-
var inExpressionOptimized = new InExpressionValuesExpandingExpressionVisitor(
43-
Dependencies.SqlExpressionFactory, parametersValues).Visit(selectExpression);
44-
45-
if (!ReferenceEquals(selectExpression, inExpressionOptimized))
45+
var nullSemanticsOptimized = nullSemanticsRewritingExpressionVisitor.Visit(selectExpression);
46+
if (!nullSemanticsRewritingExpressionVisitor.CanCache)
4647
{
4748
canCache = false;
4849
}
4950

50-
var nullParametersOptimized = new ParameterNullabilityBasedSqlExpressionOptimizingExpressionVisitor(
51-
Dependencies.SqlExpressionFactory, UseRelationalNulls, parametersValues).Visit(inExpressionOptimized);
51+
var nullParametersOptimized = new SqlExpressionOptimizingExpressionVisitor(
52+
Dependencies.SqlExpressionFactory, UseRelationalNulls, parametersValues).Visit(nullSemanticsOptimized);
5253

5354
var fromSqlParameterOptimized = new FromSqlParameterApplyingExpressionVisitor(
5455
Dependencies.SqlExpressionFactory,
@@ -63,163 +64,6 @@ public virtual (SelectExpression selectExpression, bool canCache) Optimize(
6364
return (selectExpression: (SelectExpression)fromSqlParameterOptimized, canCache);
6465
}
6566

66-
private sealed class ParameterNullabilityBasedSqlExpressionOptimizingExpressionVisitor : SqlExpressionOptimizingExpressionVisitor
67-
{
68-
private readonly IReadOnlyDictionary<string, object> _parametersValues;
69-
70-
public ParameterNullabilityBasedSqlExpressionOptimizingExpressionVisitor(
71-
ISqlExpressionFactory sqlExpressionFactory,
72-
bool useRelationalNulls,
73-
IReadOnlyDictionary<string, object> parametersValues)
74-
: base(sqlExpressionFactory, useRelationalNulls)
75-
{
76-
_parametersValues = parametersValues;
77-
}
78-
79-
protected override Expression VisitSqlUnaryExpression(SqlUnaryExpression sqlUnaryExpression)
80-
{
81-
var result = base.VisitSqlUnaryExpression(sqlUnaryExpression);
82-
if (result is SqlUnaryExpression newUnaryExpression
83-
&& newUnaryExpression.Operand is SqlParameterExpression parameterOperand)
84-
{
85-
var parameterValue = _parametersValues[parameterOperand.Name];
86-
if (sqlUnaryExpression.OperatorType == ExpressionType.Equal)
87-
{
88-
return SqlExpressionFactory.Constant(parameterValue == null, sqlUnaryExpression.TypeMapping);
89-
}
90-
91-
if (sqlUnaryExpression.OperatorType == ExpressionType.NotEqual)
92-
{
93-
return SqlExpressionFactory.Constant(parameterValue != null, sqlUnaryExpression.TypeMapping);
94-
}
95-
}
96-
97-
return result;
98-
}
99-
100-
protected override Expression VisitSqlBinaryExpression(SqlBinaryExpression sqlBinaryExpression)
101-
{
102-
var result = base.VisitSqlBinaryExpression(sqlBinaryExpression);
103-
if (result is SqlBinaryExpression sqlBinaryResult)
104-
{
105-
var leftNullParameter = sqlBinaryResult.Left is SqlParameterExpression leftParameter
106-
&& _parametersValues[leftParameter.Name] == null;
107-
108-
var rightNullParameter = sqlBinaryResult.Right is SqlParameterExpression rightParameter
109-
&& _parametersValues[rightParameter.Name] == null;
110-
111-
if ((sqlBinaryResult.OperatorType == ExpressionType.Equal || sqlBinaryResult.OperatorType == ExpressionType.NotEqual)
112-
&& (leftNullParameter || rightNullParameter))
113-
{
114-
return SimplifyNullComparisonExpression(
115-
sqlBinaryResult.OperatorType,
116-
sqlBinaryResult.Left,
117-
sqlBinaryResult.Right,
118-
leftNullParameter,
119-
rightNullParameter,
120-
sqlBinaryResult.TypeMapping);
121-
}
122-
}
123-
124-
return result;
125-
}
126-
}
127-
128-
private sealed class InExpressionValuesExpandingExpressionVisitor : ExpressionVisitor
129-
{
130-
private readonly ISqlExpressionFactory _sqlExpressionFactory;
131-
private readonly IReadOnlyDictionary<string, object> _parametersValues;
132-
133-
public InExpressionValuesExpandingExpressionVisitor(
134-
ISqlExpressionFactory sqlExpressionFactory, IReadOnlyDictionary<string, object> parametersValues)
135-
{
136-
_sqlExpressionFactory = sqlExpressionFactory;
137-
_parametersValues = parametersValues;
138-
}
139-
140-
public override Expression Visit(Expression expression)
141-
{
142-
if (expression is InExpression inExpression
143-
&& inExpression.Values != null)
144-
{
145-
var inValues = new List<object>();
146-
var hasNullValue = false;
147-
RelationalTypeMapping typeMapping = null;
148-
149-
switch (inExpression.Values)
150-
{
151-
case SqlConstantExpression sqlConstant:
152-
{
153-
typeMapping = sqlConstant.TypeMapping;
154-
var values = (IEnumerable)sqlConstant.Value;
155-
foreach (var value in values)
156-
{
157-
if (value == null)
158-
{
159-
hasNullValue = true;
160-
continue;
161-
}
162-
163-
inValues.Add(value);
164-
}
165-
166-
break;
167-
}
168-
169-
case SqlParameterExpression sqlParameter:
170-
{
171-
typeMapping = sqlParameter.TypeMapping;
172-
var values = (IEnumerable)_parametersValues[sqlParameter.Name];
173-
foreach (var value in values)
174-
{
175-
if (value == null)
176-
{
177-
hasNullValue = true;
178-
continue;
179-
}
180-
181-
inValues.Add(value);
182-
}
183-
184-
break;
185-
}
186-
}
187-
188-
var updatedInExpression = inValues.Count > 0
189-
? _sqlExpressionFactory.In(
190-
(SqlExpression)Visit(inExpression.Item),
191-
_sqlExpressionFactory.Constant(inValues, typeMapping),
192-
inExpression.IsNegated)
193-
: null;
194-
195-
var nullCheckExpression = hasNullValue
196-
? inExpression.IsNegated
197-
? _sqlExpressionFactory.IsNotNull(inExpression.Item)
198-
: _sqlExpressionFactory.IsNull(inExpression.Item)
199-
: null;
200-
201-
if (updatedInExpression != null
202-
&& nullCheckExpression != null)
203-
{
204-
return inExpression.IsNegated
205-
? _sqlExpressionFactory.AndAlso(updatedInExpression, nullCheckExpression)
206-
: _sqlExpressionFactory.OrElse(updatedInExpression, nullCheckExpression);
207-
}
208-
209-
if (updatedInExpression == null
210-
&& nullCheckExpression == null)
211-
{
212-
return _sqlExpressionFactory.Equal(
213-
_sqlExpressionFactory.Constant(true), _sqlExpressionFactory.Constant(inExpression.IsNegated));
214-
}
215-
216-
return (SqlExpression)updatedInExpression ?? nullCheckExpression;
217-
}
218-
219-
return base.Visit(expression);
220-
}
221-
}
222-
22367
private sealed class FromSqlParameterApplyingExpressionVisitor : ExpressionVisitor
22468
{
22569
private readonly IDictionary<FromSqlExpression, Expression> _visitedFromSqlExpressions

src/EFCore.Relational/Query/RelationalQueryTranslationPostprocessor.cs

+2-11
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ namespace Microsoft.EntityFrameworkCore.Query
99
{
1010
public class RelationalQueryTranslationPostprocessor : QueryTranslationPostprocessor
1111
{
12-
private readonly SqlExpressionOptimizingExpressionVisitor _sqlExpressionOptimizingExpressionVisitor;
13-
1412
public RelationalQueryTranslationPostprocessor(
1513
QueryTranslationPostprocessorDependencies dependencies,
1614
RelationalQueryTranslationPostprocessorDependencies relationalDependencies,
@@ -20,8 +18,6 @@ public RelationalQueryTranslationPostprocessor(
2018
RelationalDependencies = relationalDependencies;
2119
UseRelationalNulls = RelationalOptionsExtension.Extract(queryCompilationContext.ContextOptions).UseRelationalNulls;
2220
SqlExpressionFactory = relationalDependencies.SqlExpressionFactory;
23-
_sqlExpressionOptimizingExpressionVisitor
24-
= new SqlExpressionOptimizingExpressionVisitor(SqlExpressionFactory, UseRelationalNulls);
2521
}
2622

2723
protected virtual RelationalQueryTranslationPostprocessorDependencies RelationalDependencies { get; }
@@ -37,17 +33,12 @@ public override Expression Process(Expression query)
3733
query = new CollectionJoinApplyingExpressionVisitor().Visit(query);
3834
query = new TableAliasUniquifyingExpressionVisitor().Visit(query);
3935
query = new CaseWhenFlatteningExpressionVisitor(SqlExpressionFactory).Visit(query);
40-
41-
if (!UseRelationalNulls)
42-
{
43-
query = new NullSemanticsRewritingExpressionVisitor(SqlExpressionFactory).Visit(query);
44-
}
45-
4636
query = OptimizeSqlExpression(query);
4737

4838
return query;
4939
}
5040

51-
protected virtual Expression OptimizeSqlExpression(Expression query) => _sqlExpressionOptimizingExpressionVisitor.Visit(query);
41+
protected virtual Expression OptimizeSqlExpression(Expression query)
42+
=> query;
5243
}
5344
}

src/EFCore.SqlServer/Query/Internal/SqlServerParameterBasedQueryTranslationPostprocessor.cs

+5-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using System.Collections.Generic;
55
using Microsoft.EntityFrameworkCore.Query;
6+
using Microsoft.EntityFrameworkCore.Query.Internal;
67
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
78

89
namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal
@@ -25,7 +26,10 @@ public override (SelectExpression selectExpression, bool canCache) Optimize(
2526
var searchConditionOptimized = (SelectExpression)new SearchConditionConvertingExpressionVisitor(Dependencies.SqlExpressionFactory)
2627
.Visit(optimizedSelectExpression);
2728

28-
return (searchConditionOptimized, canCache);
29+
var optimized = (SelectExpression)new SqlExpressionOptimizingExpressionVisitor(
30+
Dependencies.SqlExpressionFactory, UseRelationalNulls, parametersValues).Visit(searchConditionOptimized);
31+
32+
return (optimized, canCache);
2933
}
3034
}
3135
}
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
// Copyright (c) .NET Foundation. All rights reserved.
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

4-
using System.Linq.Expressions;
54
using Microsoft.EntityFrameworkCore.Query;
65

76
namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal
@@ -15,13 +14,5 @@ public SqlServerQueryTranslationPostprocessor(
1514
: base(dependencies, relationalDependencies, queryCompilationContext)
1615
{
1716
}
18-
19-
public override Expression Process(Expression query)
20-
{
21-
query = base.Process(query);
22-
query = new SearchConditionConvertingExpressionVisitor(SqlExpressionFactory).Visit(query);
23-
24-
return query;
25-
}
2617
}
2718
}

test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs

+65-27
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ public virtual void Contains_with_local_array_closure_false_with_null()
316316
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !ids.Contains(e.NullableStringA)));
317317
}
318318

319-
[ConditionalFact(Skip = "issue #14171")]
319+
[ConditionalFact]
320320
public virtual void Contains_with_local_nullable_array_closure_negated()
321321
{
322322
string[] ids = { "Foo" };
@@ -946,40 +946,58 @@ join e2 in _clientData._entities2
946946
}
947947
}
948948

949-
[ConditionalFact(Skip = "issue #14171")]
949+
[ConditionalFact]
950950
public virtual void Null_semantics_contains()
951951
{
952-
using var ctx = CreateContext();
953952
var ids = new List<int?> { 1, 2 };
954-
var query1 = ctx.Entities1.Where(e => ids.Contains(e.NullableIntA));
955-
var result1 = query1.ToList();
953+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => ids.Contains(e.NullableIntA)));
954+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !ids.Contains(e.NullableIntA)));
956955

957-
var query2 = ctx.Entities1.Where(e => !ids.Contains(e.NullableIntA));
958-
var result2 = query2.ToList();
956+
var ids2 = new List<int?> { 1, 2, null };
957+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => ids2.Contains(e.NullableIntA)));
958+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !ids2.Contains(e.NullableIntA)));
959959

960-
var ids2 = new List<int?>
961-
{
962-
1,
963-
2,
964-
null
965-
};
966-
var query3 = ctx.Entities1.Where(e => ids.Contains(e.NullableIntA));
967-
var result3 = query3.ToList();
960+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => new List<int?> { 1, 2 }.Contains(e.NullableIntA)));
961+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !new List<int?> { 1, 2 }.Contains(e.NullableIntA)));
962+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => new List<int?> { 1, 2, null }.Contains(e.NullableIntA)));
963+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !new List<int?> { 1, 2, null }.Contains(e.NullableIntA)));
964+
}
968965

969-
var query4 = ctx.Entities1.Where(e => !ids.Contains(e.NullableIntA));
970-
var result4 = query4.ToList();
966+
[ConditionalFact]
967+
public virtual void Null_semantics_contains_array_with_no_values()
968+
{
969+
var ids = new List<int?>();
970+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => ids.Contains(e.NullableIntA)));
971+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !ids.Contains(e.NullableIntA)));
971972

972-
var query5 = ctx.Entities1.Where(e => !new List<int?> { 1, 2 }.Contains(e.NullableIntA));
973-
var result5 = query5.ToList();
973+
var ids2 = new List<int?> { null };
974+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => ids2.Contains(e.NullableIntA)));
975+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !ids2.Contains(e.NullableIntA)));
974976

975-
var query6 = ctx.Entities1.Where(
976-
e => !new List<int?>
977-
{
978-
1,
979-
2,
980-
null
981-
}.Contains(e.NullableIntA));
982-
var result6 = query6.ToList();
977+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => new List<int?>().Contains(e.NullableIntA)));
978+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !new List<int?>().Contains(e.NullableIntA)));
979+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => new List<int?> { null }.Contains(e.NullableIntA)));
980+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !new List<int?> { null }.Contains(e.NullableIntA)));
981+
}
982+
983+
[ConditionalFact]
984+
public virtual void Null_semantics_contains_non_nullable_argument()
985+
{
986+
var ids = new List<int?> { 1, 2, null };
987+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => ids.Contains(e.IntA)));
988+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !ids.Contains(e.IntA)));
989+
990+
var ids2 = new List<int?> { 1, 2, };
991+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => ids2.Contains(e.IntA)));
992+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !ids2.Contains(e.IntA)));
993+
994+
var ids3 = new List<int?>();
995+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => ids3.Contains(e.IntA)));
996+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !ids3.Contains(e.IntA)));
997+
998+
var ids4 = new List<int?> { null };
999+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => ids4.Contains(e.IntA)));
1000+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !ids4.Contains(e.IntA)));
9831001
}
9841002

9851003
[ConditionalFact]
@@ -1044,6 +1062,26 @@ public virtual void Coalesce_not_equal()
10441062
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => (e.NullableIntA ?? 0) != 0));
10451063
}
10461064

1065+
[ConditionalFact]
1066+
public virtual void Negated_order_comparison_on_non_nullable_arguments_gets_optimized()
1067+
{
1068+
var i = 1;
1069+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !(e.IntA > i)));
1070+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !(e.IntA >= i)));
1071+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !(e.IntA < i)));
1072+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !(e.IntA <= i)));
1073+
}
1074+
1075+
[ConditionalFact(Skip = "issue #9544")]
1076+
public virtual void Negated_order_comparison_on_nullable_arguments_doesnt_get_optimized()
1077+
{
1078+
var i = 1;
1079+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !(e.NullableIntA > i)));
1080+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !(e.NullableIntA >= i)));
1081+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !(e.NullableIntA < i)));
1082+
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !(e.NullableIntA <= i)));
1083+
}
1084+
10471085
protected static TResult Maybe<TResult>(object caller, Func<TResult> expression)
10481086
where TResult : class
10491087
{

0 commit comments

Comments
 (0)