Handle IN optimizer transformation of ordering to constant (take 2) #16422

Closed
wants to merge 2 commits into from
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq.Expressions;
@@ -15,12 +16,15 @@ private class InExpressionValuesExpandingExpressionVisitor : ExpressionVisitor
{
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private IReadOnlyDictionary<string, object> _parametersValues;
+ private ConstantExpressionIdentifier _constantExpressionIdentifier;
+ private bool _innerExpressionChangedToConstant;

public InExpressionValuesExpandingExpressionVisitor(
ISqlExpressionFactory sqlExpressionFactory, IReadOnlyDictionary<string, object> parametersValues)
{
_sqlExpressionFactory = sqlExpressionFactory;
_parametersValues = parametersValues;
+ _constantExpressionIdentifier = new ConstantExpressionIdentifier();
}

public override Expression Visit(Expression expression)
@@ -34,6 +38,7 @@ public override Expression Visit(Expression expression)

switch (inExpression.Values)
{
+ // TODO: This shouldn't be here - should be part of compilation instead (#16375)
case SqlConstantExpression sqlConstant:
{
typeMapping = sqlConstant.TypeMapping;
@@ -91,6 +96,7 @@ public override Expression Visit(Expression expression)

if (updatedInExpression == null && nullCheckExpression == null)
{
+ _innerExpressionChangedToConstant = true;
return _sqlExpressionFactory.Equal(_sqlExpressionFactory.Constant(true), _sqlExpressionFactory.Constant(inExpression.Negated));
}

@@ -99,6 +105,55 @@ public override Expression Visit(Expression expression)

return base.Visit(expression);
}

+ protected override Expression VisitExtension(Expression extensionExpression)
+ {
+ var visited = base.VisitExtension(extensionExpression);
+ if (_innerExpressionChangedToConstant && visited is OrderingExpression orderingExpression)
Contributor

A few design points (no need to address right now):

  • Instead of storing state, you can just check visited.Equals(extensionExpression). If anything inside changed, the visited expression will have been reconstructed and will differ.

Member Author

The idea here was to limit the check to only the case where we rewrite the expression into a constant. Not every change inside the ordering means we need to perform this (relatively expensive) check. This might be a somewhat unneeded optimization; let me know what you think.

Contributor

Possibly unneeded. We can also use reference equality, thanks to immutability and reconstruction. It should give the same result.

Member Author

I may be misunderstanding, but once again, this visitor performs some rewrites that do not require our check. For example, if there's a null in the values list, the visitor adds null semantics. A reference equality check would report the expression as modified, although there's no need to check whether it was converted to a constant.

It may not be important to skip these cases perf-wise, but then again the stateful flag approach seems easy and safe enough, and allows us to check only when we know we've performed a "dangerous" rewrite.

Contributor

State is dangerous. No need for premature optimization.

Member Author

We have state in a lot of visitors around the pipeline, this is outside of the compilation phase (so perf-sensitive), and the statefulness is extremely limited.

But if you're dead set against it, I can remove it.
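
For readers following the thread above, here is a minimal sketch of the two detection strategies under discussion. It uses plain System.Linq.Expressions rather than EF Core's SQL expression types. SketchVisitor, Sketch.Check and both rewrites are hypothetical names invented for illustration; they are not part of this PR. The sketch relies on the fact that ExpressionVisitor returns the same node instance when nothing beneath it changed.

```csharp
using System;
using System.Linq.Expressions;

// Illustrative only: not the EF Core visitor. It performs two rewrites:
//   * int constant 0 -> 1           (the "interesting" rewrite, tracked by a flag)
//   * string constant -> upper-case (an unrelated rewrite the flag ignores)
public sealed class SketchVisitor : ExpressionVisitor
{
    public bool InterestingRewriteHappened { get; private set; }

    protected override Expression VisitConstant(ConstantExpression node)
    {
        switch (node.Value)
        {
            case int i when i == 0:
                InterestingRewriteHappened = true;
                return Expression.Constant(1);
            case string s:
                return Expression.Constant(s.ToUpperInvariant());
            default:
                return node; // unchanged nodes are returned as the same instance
        }
    }
}

public static class Sketch
{
    // Reports what each strategy would conclude after one visit.
    public static (bool ByReference, bool ByFlag) Check(Expression original)
    {
        var visitor = new SketchVisitor();
        var visited = visitor.Visit(original);
        return (!object.ReferenceEquals(visited, original), visitor.InterestingRewriteHappened);
    }
}
```

In this sketch the reference check fires for any rewrite, including the string one, while the flag fires only for the rewrite that sets it. That is the trade-off debated above: the flag narrows the follow-up check at the cost of carrying mutable state in the visitor.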

+ {
+ // Our rewriting of InExpression above may have left an ordering which contains only constants -
+ // invalid in SqlServer. If so, rewrite the entire ordering to a single constant expression which
+ // QuerySqlGenerator will recognize and render as (SELECT 1).
+ _innerExpressionChangedToConstant = false;
+ return _constantExpressionIdentifier.IsExpressionConstant(orderingExpression)
+ ? new OrderingExpression(
+ new SqlConstantExpression(Expression.Constant(1), _sqlExpressionFactory.FindMapping(typeof(int))),
+ true)
+ : orderingExpression;
+ }
+
+ return visited;
+ }
+ }
+
+ /// <summary>
+ /// Visits an expression tree and identifies whether it contains any non-constant nodes, e.g. columns.
+ /// </summary>
+ private class ConstantExpressionIdentifier : ExpressionVisitor
+ {
+ private bool _wasConstant;
+
+ public bool IsExpressionConstant(Expression node)
+ {
+ _wasConstant = true;
+ Visit(node);
+ return _wasConstant;
+ }
+
+ public override Expression Visit(Expression node)
+ {
+ // TODO: can be optimized by immediately returning null when a matching node is found (but must be handled everywhere...)
+
+ switch (node)
+ {
+ case ColumnExpression _:
+ case SqlFragmentExpression _: // May or may not be constant, but better to error than to give wrong results
+ _wasConstant = false;
+ break;
+ }
+
+ return base.Visit(node);
+ }
}
}
}
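
The tree walk performed by ConstantExpressionIdentifier can be illustrated outside EF Core. The following is a sketch under stated assumptions: ParameterFreeChecker is a hypothetical class, and ParameterExpression stands in for ColumnExpression as the "non-constant" node kind. The early exit shown here (returning the node unchanged once a non-constant node has been seen) is one possible way to approach the TODO above without returning null; it is not what the PR does.

```csharp
using System;
using System.Linq.Expressions;

// Illustrative only: reports whether a LINQ expression tree is free of parameters.
public sealed class ParameterFreeChecker : ExpressionVisitor
{
    private bool _isConstant;

    public bool IsConstant(Expression node)
    {
        _isConstant = true; // reset so the instance can be reused across calls
        Visit(node);
        return _isConstant;
    }

    public override Expression Visit(Expression node)
    {
        if (node is ParameterExpression)
        {
            _isConstant = false;
            return node;
        }

        // Once a non-constant node has been seen, skip the remaining subtrees.
        // Returning the node itself (never null) keeps parent nodes valid.
        return _isConstant ? base.Visit(node) : node;
    }
}
```

As with the original class, the result is carried in a field that is reset at the start of each call, so one instance can be reused for several orderings.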
@@ -1181,9 +1181,9 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
var orderings = new List<OrderingExpression>();
foreach (var ordering in _orderings)
{
- var orderingExpression = (SqlExpression)visitor.Visit(ordering.Expression);
- changed |= orderingExpression != ordering.Expression;
- orderings.Add(ordering.Update(orderingExpression));
+ var newOrdering = (OrderingExpression)visitor.Visit(ordering);
+ changed |= newOrdering != ordering;
+ orderings.Add(newOrdering);
}

var offset = (SqlExpression)visitor.Visit(Offset);
13 changes: 13 additions & 0 deletions test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.cs
@@ -3642,6 +3642,7 @@ FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND ((c[""OrderID""] > 690) AND (c[""OrderID""] < 710)))");
}

+ [ConditionalTheory(Skip = "Should work, requires testing etc.")]
public override async Task OrderBy_empty_list_contains(bool isAsync)
{
await base.OrderBy_empty_list_contains(isAsync);
@@ -3652,6 +3653,18 @@ FROM root c
WHERE (c[""Discriminator""] = ""Customer"")");
}

+ [ConditionalTheory(Skip = "Should work, requires testing etc.")]
+ public override async Task OrderBy_contains_and_count(bool isAsync)
+ {
+ await base.OrderBy_contains_and_count(isAsync);
+
+ AssertSql(
+ @"SELECT c
+ FROM root c
+ WHERE (c[""Discriminator""] = ""Customer"")");
+ }
+
+ [ConditionalTheory(Skip = "Should work, requires testing etc.")]
public override async Task OrderBy_empty_list_does_not_contains(bool isAsync)
{
await base.OrderBy_empty_list_does_not_contains(isAsync);
@@ -6794,7 +6794,7 @@ public virtual Task Cast_ordered_subquery_to_base_type_using_typed_ToArray(bool
elementAsserter: CollectionAsserter<Gear>(e => e.Nickname, (e, a) => Assert.Equal(e.Nickname, a.Nickname)));
}

- [ConditionalTheory(Skip = "Issue#15713")]
+ [ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Correlated_collection_with_complex_order_by_funcletized_to_constant_bool(bool isAsync)
{
4 changes: 2 additions & 2 deletions test/EFCore.Specification.Tests/Query/IncludeTestBase.cs
@@ -3933,7 +3933,7 @@ var orders
}
}

- [ConditionalTheory(Skip = "Issue#15713")]
+ [ConditionalTheory]
[InlineData(false)]
[InlineData(true)]
public virtual void Include_collection_OrderBy_empty_list_contains(bool useString)
@@ -3968,7 +3968,7 @@ var customers
}
}

- [ConditionalTheory(Skip = "Issue#15713")]
+ [ConditionalTheory]
[InlineData(false)]
[InlineData(true)]
public virtual void Include_collection_OrderBy_empty_list_does_not_contains(bool useString)
16 changes: 14 additions & 2 deletions test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs
@@ -5735,7 +5735,7 @@ on o.CustomerID equals c.CustomerID
.Take(5));
}

- [ConditionalTheory(Skip = "Issue#15713")]
+ [ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task OrderBy_empty_list_contains(bool isAsync)
{
@@ -5747,7 +5747,19 @@ public virtual Task OrderBy_empty_list_contains(bool isAsync)
entryCount: 91);
}

- [ConditionalTheory(Skip = "Issue#15713")]
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task OrderBy_contains_and_count(bool isAsync)
+ {
+ var list = new List<string>();
+
+ return AssertQuery<Customer>(
+ isAsync,
+ cs => cs.OrderBy(c => (list.Contains(c.CustomerID) ? 1 : 0) + cs.Count()).Select(c => c),
+ entryCount: 91);
+ }
+
+ [ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task OrderBy_empty_list_does_not_contains(bool isAsync)
{
@@ -7343,19 +7343,11 @@ public override async Task Correlated_collection_with_complex_order_by_funcletiz
await base.Correlated_collection_with_complex_order_by_funcletized_to_constant_bool(isAsync);

AssertSql(
- @"SELECT [g].[Nickname], [g].[FullName]
+ @"SELECT [g].[Nickname], [g].[SquadId], [w].[Name], [w].[Id], [w].[OwnerFullName]
FROM [Gears] AS [g]
- WHERE [g].[Discriminator] IN (N'Officer', N'Gear')
- ORDER BY [g].[Nickname], [g].[SquadId], [g].[FullName]",
- //
- @"SELECT [t].[c], [t].[Nickname], [t].[SquadId], [t].[FullName], [g.Weapons].[Name], [g.Weapons].[OwnerFullName]
- FROM [Weapons] AS [g.Weapons]
- INNER JOIN (
- SELECT CAST(0 AS bit) AS [c], [g0].[Nickname], [g0].[SquadId], [g0].[FullName]
- FROM [Gears] AS [g0]
- WHERE [g0].[Discriminator] IN (N'Officer', N'Gear')
- ) AS [t] ON [g.Weapons].[OwnerFullName] = [t].[FullName]
- ORDER BY [t].[c] DESC, [t].[Nickname], [t].[SquadId], [t].[FullName]");
+ LEFT JOIN [Weapons] AS [w] ON [g].[FullName] = [w].[OwnerFullName]
+ WHERE [g].[Discriminator] IN (N'Gear', N'Officer')
+ ORDER BY [g].[Nickname], [g].[SquadId], [w].[Id]");
}

public override async Task Double_order_by_on_nullable_bool_coming_from_optional_navigation(bool isAsync)
58 changes: 24 additions & 34 deletions test/EFCore.SqlServer.FunctionalTests/Query/IncludeSqlServerTest.cs
@@ -1120,24 +1120,19 @@ public override void Include_collection_OrderBy_empty_list_contains(bool useStri
AssertSql(
@"@__p_1='1'

- SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
- FROM [Customers] AS [c]
- WHERE [c].[CustomerID] LIKE N'A%'
- ORDER BY (SELECT 1), [c].[CustomerID]
- OFFSET @__p_1 ROWS",
- //
- @"@__p_1='1'
-
- SELECT [c.Orders].[OrderID], [c.Orders].[CustomerID], [c.Orders].[EmployeeID], [c.Orders].[OrderDate]
- FROM [Orders] AS [c.Orders]
- INNER JOIN (
- SELECT [c0].[CustomerID], CAST(0 AS bit) AS [c]
- FROM [Customers] AS [c0]
- WHERE [c0].[CustomerID] LIKE N'A%'
- ORDER BY [c], [c0].[CustomerID]
+ SELECT [t].[CustomerID], [t].[Address], [t].[City], [t].[CompanyName], [t].[ContactName], [t].[ContactTitle], [t].[Country], [t].[Fax], [t].[Phone], [t].[PostalCode], [t].[Region], [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
+ FROM (
+ SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region], CASE
+ WHEN CAST(1 AS bit) = CAST(0 AS bit) THEN CAST(1 AS bit)
+ ELSE CAST(0 AS bit)
+ END AS [c]
+ FROM [Customers] AS [c]
+ WHERE [c].[CustomerID] LIKE N'A%'
+ ORDER BY (SELECT 1)
OFFSET @__p_1 ROWS
- ) AS [t] ON [c.Orders].[CustomerID] = [t].[CustomerID]
- ORDER BY [t].[c], [t].[CustomerID]");
+ ) AS [t]
+ LEFT JOIN [Orders] AS [o] ON [t].[CustomerID] = [o].[CustomerID]
+ ORDER BY [t].[c], [t].[CustomerID], [o].[OrderID]");
}

public override void Include_collection_OrderBy_empty_list_does_not_contains(bool useString)
@@ -1147,24 +1142,19 @@ public override void Include_collection_OrderBy_empty_list_does_not_contains(boo
AssertSql(
@"@__p_1='1'

- SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
- FROM [Customers] AS [c]
- WHERE [c].[CustomerID] LIKE N'A%'
- ORDER BY (SELECT 1), [c].[CustomerID]
- OFFSET @__p_1 ROWS",
- //
- @"@__p_1='1'
-
- SELECT [c.Orders].[OrderID], [c.Orders].[CustomerID], [c.Orders].[EmployeeID], [c.Orders].[OrderDate]
- FROM [Orders] AS [c.Orders]
- INNER JOIN (
- SELECT [c0].[CustomerID], CAST(1 AS bit) AS [c]
- FROM [Customers] AS [c0]
- WHERE [c0].[CustomerID] LIKE N'A%'
- ORDER BY [c], [c0].[CustomerID]
+ SELECT [t].[CustomerID], [t].[Address], [t].[City], [t].[CompanyName], [t].[ContactName], [t].[ContactTitle], [t].[Country], [t].[Fax], [t].[Phone], [t].[PostalCode], [t].[Region], [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
+ FROM (
+ SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region], CASE
+ WHEN CAST(1 AS bit) = CAST(1 AS bit) THEN CAST(1 AS bit)
+ ELSE CAST(0 AS bit)
+ END AS [c]
+ FROM [Customers] AS [c]
+ WHERE [c].[CustomerID] LIKE N'A%'
+ ORDER BY (SELECT 1)
OFFSET @__p_1 ROWS
- ) AS [t] ON [c.Orders].[CustomerID] = [t].[CustomerID]
- ORDER BY [t].[c], [t].[CustomerID]");
+ ) AS [t]
+ LEFT JOIN [Orders] AS [o] ON [t].[CustomerID] = [o].[CustomerID]
+ ORDER BY [t].[c], [t].[CustomerID], [o].[OrderID]");
}

public override void Include_collection_OrderBy_list_contains(bool useString)
@@ -4681,6 +4681,18 @@ public override async Task OrderBy_empty_list_contains(bool isAsync)
FROM [Customers] AS [c]");
}

+ public override async Task OrderBy_contains_and_count(bool isAsync)
+ {
+ await base.OrderBy_contains_and_count(isAsync);
+
+ AssertSql(
+ @"SELECT COUNT(*)
+ FROM [Customers] AS [c]",
+ //
+ @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
+ FROM [Customers] AS [c]");
+ }
+
public override async Task OrderBy_empty_list_does_not_contains(bool isAsync)
{
await base.OrderBy_empty_list_does_not_contains(isAsync);