Skip to content

Commit ecdeed4

Browse files
committed
Fix to #18555 - Query: when rewriting null semantics for comparisons with functions use function specific metadata to get better SQL
When we need to compute whether a function is null, we often can just evaluate nullability of it's constituents (instance & arguments), e.g. SUBSTRING(stringProperty, 0, 5) == null -> stringProperty == null Adding metadata to SqlFunctionExpression: nullResultAllowed - indicates whether function can ever be null, instancePropagatesNullability - indicates whether function instance can be used to calculate nullability of the entire function argumentsPropagateNullability - array indicating which (if any) function arguments can be used to calculate nullability of the entire function If "canBeNull" is set to false we can instantly compute IsNull/IsNotNull of that function. Otherwise, we look at values of instancePropagatesNullability and argumentsPropagateNullability - if any of them are set to true, we use corresponding argument(s) to compute function nullability. If all of them are set to false we must fallback to the old method and evaluate nullability of the entire function.
1 parent d42926c commit ecdeed4

File tree

63 files changed

+843
-127
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+843
-127
lines changed

src/EFCore.Relational/Query/ISqlExpressionFactory.cs

+48
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,54 @@ SqlFunctionExpression Function(
131131
[NotNull] Type returnType,
132132
[CanBeNull] RelationalTypeMapping typeMapping = null);
133133

134+
SqlFunctionExpression Function(
135+
[NotNull] string name,
136+
[NotNull] IEnumerable<SqlExpression> arguments,
137+
bool nullResultAllowed,
138+
[NotNull] IEnumerable<bool> argumentsPropagateNullability,
139+
[NotNull] Type returnType,
140+
[CanBeNull] RelationalTypeMapping typeMapping = null);
141+
142+
SqlFunctionExpression Function(
143+
[CanBeNull] string schema,
144+
[NotNull] string name,
145+
[NotNull] IEnumerable<SqlExpression> arguments,
146+
bool nullResultAllowed,
147+
[NotNull] IEnumerable<bool> argumentsPropagateNullability,
148+
[NotNull] Type returnType,
149+
[CanBeNull] RelationalTypeMapping typeMapping = null);
150+
151+
SqlFunctionExpression Function(
152+
[CanBeNull] SqlExpression instance,
153+
[NotNull] string name,
154+
[NotNull] IEnumerable<SqlExpression> arguments,
155+
bool nullResultAllowed,
156+
bool instancePropagatesNullability,
157+
[NotNull] IEnumerable<bool> argumentsPropagateNullability,
158+
[NotNull] Type returnType,
159+
[CanBeNull] RelationalTypeMapping typeMapping = null);
160+
161+
SqlFunctionExpression Function(
162+
[NotNull] string name,
163+
bool nullResultAllowed,
164+
[NotNull] Type returnType,
165+
[CanBeNull] RelationalTypeMapping typeMapping = null);
166+
167+
SqlFunctionExpression Function(
168+
[NotNull] string schema,
169+
[NotNull] string name,
170+
bool nullResultAllowed,
171+
[NotNull] Type returnType,
172+
[CanBeNull] RelationalTypeMapping typeMapping = null);
173+
174+
SqlFunctionExpression Function(
175+
[CanBeNull] SqlExpression instance,
176+
[NotNull] string name,
177+
bool nullResultAllowed,
178+
bool instancePropagatesNullability,
179+
[NotNull] Type returnType,
180+
[CanBeNull] RelationalTypeMapping typeMapping = null);
181+
134182
ExistsExpression Exists([NotNull] SelectExpression subquery, bool negated);
135183
InExpression In([NotNull] SqlExpression item, [NotNull] SqlExpression values, bool negated);
136184
InExpression In([NotNull] SqlExpression item, [NotNull] SelectExpression subquery, bool negated);

src/EFCore.Relational/Query/NullabilityBasedSqlProcessingExpressionVisitor.cs

+87-29
Original file line numberDiff line numberDiff line change
@@ -898,8 +898,7 @@ private SqlExpression RewriteNullSemantics(
898898
return sqlBinaryExpression.Update(left, right);
899899
}
900900

901-
private SqlExpression SimplifyLogicalSqlBinaryExpression(
902-
SqlBinaryExpression sqlBinaryExpression)
901+
private SqlExpression SimplifyLogicalSqlBinaryExpression(SqlBinaryExpression sqlBinaryExpression)
903902
{
904903
var leftUnary = sqlBinaryExpression.Left as SqlUnaryExpression;
905904
var rightUnary = sqlBinaryExpression.Right as SqlUnaryExpression;
@@ -1253,37 +1252,96 @@ protected virtual SqlExpression ProcessNullNotNull(
12531252
sqlUnaryExpression.TypeMapping));
12541253
}
12551254

1256-
case SqlFunctionExpression sqlFunctionExpression
1257-
when sqlFunctionExpression.IsBuiltIn && string.Equals("COALESCE", sqlFunctionExpression.Name, StringComparison.OrdinalIgnoreCase):
1255+
case SqlFunctionExpression sqlFunctionExpression:
12581256
{
1259-
// for coalesce:
1260-
// (a ?? b) == null -> a == null && b == null
1261-
// (a ?? b) != null -> a != null || b != null
1262-
var left = ProcessNullNotNull(
1263-
SqlExpressionFactory.MakeUnary(
1264-
sqlUnaryExpression.OperatorType,
1265-
sqlFunctionExpression.Arguments[0],
1266-
typeof(bool),
1267-
sqlUnaryExpression.TypeMapping),
1268-
operandNullable: null);
1257+
if (sqlFunctionExpression.IsBuiltIn && string.Equals("COALESCE", sqlFunctionExpression.Name, StringComparison.OrdinalIgnoreCase))
1258+
{
1259+
// for coalesce:
1260+
// (a ?? b) == null -> a == null && b == null
1261+
// (a ?? b) != null -> a != null || b != null
1262+
var left = ProcessNullNotNull(
1263+
SqlExpressionFactory.MakeUnary(
1264+
sqlUnaryExpression.OperatorType,
1265+
sqlFunctionExpression.Arguments[0],
1266+
typeof(bool),
1267+
sqlUnaryExpression.TypeMapping),
1268+
operandNullable: null);
1269+
1270+
var right = ProcessNullNotNull(
1271+
SqlExpressionFactory.MakeUnary(
1272+
sqlUnaryExpression.OperatorType,
1273+
sqlFunctionExpression.Arguments[1],
1274+
typeof(bool),
1275+
sqlUnaryExpression.TypeMapping),
1276+
operandNullable: null);
12691277

1270-
var right = ProcessNullNotNull(
1271-
SqlExpressionFactory.MakeUnary(
1272-
sqlUnaryExpression.OperatorType,
1273-
sqlFunctionExpression.Arguments[1],
1274-
typeof(bool),
1275-
sqlUnaryExpression.TypeMapping),
1276-
operandNullable: null);
1278+
return SimplifyLogicalSqlBinaryExpression(
1279+
SqlExpressionFactory.MakeBinary(
1280+
sqlUnaryExpression.OperatorType == ExpressionType.Equal
1281+
? ExpressionType.AndAlso
1282+
: ExpressionType.OrElse,
1283+
left,
1284+
right,
1285+
sqlUnaryExpression.TypeMapping));
1286+
}
12771287

1278-
return SimplifyLogicalSqlBinaryExpression(
1279-
SqlExpressionFactory.MakeBinary(
1280-
sqlUnaryExpression.OperatorType == ExpressionType.Equal
1281-
? ExpressionType.AndAlso
1282-
: ExpressionType.OrElse,
1283-
left,
1284-
right,
1285-
sqlUnaryExpression.TypeMapping));
1288+
if (!sqlFunctionExpression.NullResultAllowed)
1289+
{
1290+
// when we know that function can't be nullable:
1291+
// non_nullable_function() is null-> false
1292+
// non_nullable_function() is not null -> true
1293+
return SqlExpressionFactory.Constant(
1294+
sqlUnaryExpression.OperatorType == ExpressionType.NotEqual,
1295+
sqlUnaryExpression.TypeMapping);
1296+
}
1297+
1298+
// see if we can derive function nullability from it's instance and/or arguments
1299+
// rather than evaluating nullability of the entire function
1300+
var nullabilityPropagationElements = new List<SqlExpression>();
1301+
if (sqlFunctionExpression.Instance != null
1302+
&& sqlFunctionExpression.InstancPropagatesNullability == true)
1303+
{
1304+
nullabilityPropagationElements.Add(sqlFunctionExpression.Instance);
1305+
}
1306+
1307+
for (var i = 0; i < sqlFunctionExpression.Arguments.Count; i++)
1308+
{
1309+
if (sqlFunctionExpression.ArgumentsPropagateNullability[i])
1310+
{
1311+
nullabilityPropagationElements.Add(sqlFunctionExpression.Arguments[i]);
1312+
}
1313+
}
1314+
1315+
if (nullabilityPropagationElements.Count > 0)
1316+
{
1317+
var result = ProcessNullNotNull(
1318+
SqlExpressionFactory.MakeUnary(
1319+
sqlUnaryExpression.OperatorType,
1320+
nullabilityPropagationElements[0],
1321+
sqlUnaryExpression.Type,
1322+
sqlUnaryExpression.TypeMapping),
1323+
operandNullable: null);
1324+
1325+
foreach (var element in nullabilityPropagationElements.Skip(1))
1326+
{
1327+
result = SimplifyLogicalSqlBinaryExpression(
1328+
sqlUnaryExpression.OperatorType == ExpressionType.Equal
1329+
? SqlExpressionFactory.OrElse(
1330+
result,
1331+
ProcessNullNotNull(
1332+
SqlExpressionFactory.IsNull(element),
1333+
operandNullable: null))
1334+
: SqlExpressionFactory.AndAlso(
1335+
result,
1336+
ProcessNullNotNull(
1337+
SqlExpressionFactory.IsNotNull(element),
1338+
operandNullable: null)));
1339+
}
1340+
1341+
return result;
1342+
}
12861343
}
1344+
break;
12871345
}
12881346

12891347
return sqlUnaryExpression;

src/EFCore.Relational/Query/RelationalMethodCallTranslatorProvider.cs

+2
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ public virtual SqlExpression Translate(
5656
dbFunction.Schema,
5757
dbFunction.Name,
5858
arguments,
59+
nullResultAllowed: true,
60+
argumentsPropagateNullability: arguments.Select(a => true).ToList(),
5961
method.ReturnType);
6062
}
6163

src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs

+49-8
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,20 @@ public virtual SqlExpression TranslateAverage([NotNull] Expression expression)
9999
return inputType == typeof(float)
100100
? SqlExpressionFactory.Convert(
101101
SqlExpressionFactory.Function(
102-
"AVG", new[] { sqlExpression }, typeof(double)),
102+
"AVG",
103+
new[] { sqlExpression },
104+
nullResultAllowed: true,
105+
argumentsPropagateNullability: new[] { false },
106+
typeof(double)),
103107
sqlExpression.Type,
104108
sqlExpression.TypeMapping)
105109
: (SqlExpression)SqlExpressionFactory.Function(
106-
"AVG", new[] { sqlExpression }, sqlExpression.Type, sqlExpression.TypeMapping);
110+
"AVG",
111+
new[] { sqlExpression },
112+
nullResultAllowed: true,
113+
argumentsPropagateNullability: new[] { false },
114+
sqlExpression.Type,
115+
sqlExpression.TypeMapping);
107116
}
108117

109118
public virtual SqlExpression TranslateCount([CanBeNull] Expression expression = null)
@@ -115,7 +124,12 @@ public virtual SqlExpression TranslateCount([CanBeNull] Expression expression =
115124
}
116125

117126
return SqlExpressionFactory.ApplyDefaultTypeMapping(
118-
SqlExpressionFactory.Function("COUNT", new[] { SqlExpressionFactory.Fragment("*") }, typeof(int)));
127+
SqlExpressionFactory.Function(
128+
"COUNT",
129+
new[] { SqlExpressionFactory.Fragment("*") },
130+
nullResultAllowed: false,
131+
argumentsPropagateNullability: new[] { false },
132+
typeof(int)));
119133
}
120134

121135
public virtual SqlExpression TranslateLongCount([CanBeNull] Expression expression = null)
@@ -127,7 +141,12 @@ public virtual SqlExpression TranslateLongCount([CanBeNull] Expression expressio
127141
}
128142

129143
return SqlExpressionFactory.ApplyDefaultTypeMapping(
130-
SqlExpressionFactory.Function("COUNT", new[] { SqlExpressionFactory.Fragment("*") }, typeof(long)));
144+
SqlExpressionFactory.Function(
145+
"COUNT",
146+
new[] { SqlExpressionFactory.Fragment("*") },
147+
nullResultAllowed: false,
148+
argumentsPropagateNullability: new[] { false },
149+
typeof(long)));
131150
}
132151

133152
public virtual SqlExpression TranslateMax([NotNull] Expression expression)
@@ -140,7 +159,13 @@ public virtual SqlExpression TranslateMax([NotNull] Expression expression)
140159
}
141160

142161
return sqlExpression != null
143-
? SqlExpressionFactory.Function("MAX", new[] { sqlExpression }, sqlExpression.Type, sqlExpression.TypeMapping)
162+
? SqlExpressionFactory.Function(
163+
"MAX",
164+
new[] { sqlExpression },
165+
nullResultAllowed: true,
166+
argumentsPropagateNullability: new[] { false },
167+
sqlExpression.Type,
168+
sqlExpression.TypeMapping)
144169
: null;
145170
}
146171

@@ -154,7 +179,13 @@ public virtual SqlExpression TranslateMin([NotNull] Expression expression)
154179
}
155180

156181
return sqlExpression != null
157-
? SqlExpressionFactory.Function("MIN", new[] { sqlExpression }, sqlExpression.Type, sqlExpression.TypeMapping)
182+
? SqlExpressionFactory.Function(
183+
"MIN",
184+
new[] { sqlExpression },
185+
nullResultAllowed: true,
186+
argumentsPropagateNullability: new[] { false },
187+
sqlExpression.Type,
188+
sqlExpression.TypeMapping)
158189
: null;
159190
}
160191

@@ -176,11 +207,21 @@ public virtual SqlExpression TranslateSum([NotNull] Expression expression)
176207

177208
return inputType == typeof(float)
178209
? SqlExpressionFactory.Convert(
179-
SqlExpressionFactory.Function("SUM", new[] { sqlExpression }, typeof(double)),
210+
SqlExpressionFactory.Function(
211+
"SUM",
212+
new[] { sqlExpression },
213+
nullResultAllowed: true,
214+
argumentsPropagateNullability: new[] { false },
215+
typeof(double)),
180216
inputType,
181217
sqlExpression.TypeMapping)
182218
: (SqlExpression)SqlExpressionFactory.Function(
183-
"SUM", new[] { sqlExpression }, inputType, sqlExpression.TypeMapping);
219+
"SUM",
220+
new[] { sqlExpression },
221+
nullResultAllowed: true,
222+
argumentsPropagateNullability: new[] { false },
223+
inputType,
224+
sqlExpression.TypeMapping);
184225
}
185226

186227
private sealed class SqlTypeMappingVerifyingExpressionVisitor : ExpressionVisitor

0 commit comments

Comments
 (0)