Skip to content

Commit

Permalink
Fix to #29638 - GroupBy generates invalid SQL when using custom datab…
Browse files Browse the repository at this point in the history
…ase function

Problem is that CloningExpressionVisitor doesn't have proper handling for TableValuedFunctionExpression, and therefore goes through default expression visitor pattern (visit all children, check if there are any changes, if there are return new, if not return the same). Since there are no changes, the same instance is returned from cloning, and causes the problem.
Fix is to add proper handling of TVFExpression in the CloningExpressionVisitor so that it produces a proper copy.

Fixes #29638
  • Loading branch information
maumar committed Apr 4, 2023
1 parent bdd9846 commit f410b99
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics.CodeAnalysis;
using Microsoft.EntityFrameworkCore.Metadata.Internal;

namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions;

Expand Down Expand Up @@ -1008,6 +1009,27 @@ private sealed class CloningExpressionVisitor : ExpressionVisitor
return newTpcTable;
}

if (expression is TableValuedFunctionExpression tableValuedFunctionExpression)
{
var newArguments = new SqlExpression[tableValuedFunctionExpression.Arguments.Count];
for (var i = 0; i < newArguments.Length; i++)
{
newArguments[i] = (SqlExpression)Visit(tableValuedFunctionExpression.Arguments[i]);
}

var newTableValuedFunctionExpression = new TableValuedFunctionExpression(
tableValuedFunctionExpression.StoreFunction,
newArguments);

newTableValuedFunctionExpression.Alias = tableValuedFunctionExpression.Alias;
foreach (var annotation in tableValuedFunctionExpression.GetAnnotations())
{
newTableValuedFunctionExpression.AddAnnotation(annotation.Name, annotation.Value);
}

return newTableValuedFunctionExpression;
}

return expression is IClonableTableExpressionBase cloneable ? cloneable.Clone() : base.Visit(expression);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2133,6 +2133,52 @@ orderby t.ProductId
}
}

[ConditionalFact]
public virtual void TVF_with_navigation_in_projection_groupby_aggregate()
{
using (var context = CreateContext())
{
var query = context.Orders
.Where(c => !context.Set<TopSellingProduct>().Select(x => x.ProductId).Contains(25))
.Select(x => new { x.Customer.FirstName, x.Customer.LastName })
.GroupBy(x => new { x.LastName })
.Select(x => new { x.Key.LastName, SumOfLengths = x.Sum(xx => xx.FirstName.Length) })
.ToList();

Assert.Equal(3, query.Count);
var orderedResult = query.OrderBy(x => x.LastName).ToList();
Assert.Equal("One", orderedResult[0].LastName);
Assert.Equal(24, orderedResult[0].SumOfLengths);
Assert.Equal("Three", orderedResult[1].LastName);
Assert.Equal(8, orderedResult[1].SumOfLengths);
Assert.Equal("Two", orderedResult[2].LastName);
Assert.Equal(16, orderedResult[2].SumOfLengths);
}
}

[ConditionalFact]
public virtual void TVF_with_argument_being_a_subquery_with_navigation_in_projection_groupby_aggregate()
{
using (var context = CreateContext())
{
var query = context.Orders
.Where(c => !context.GetOrdersWithMultipleProducts(context.Customers.OrderBy(x => x.Id).FirstOrDefault().Id).Select(x => x.CustomerId).Contains(25))
.Select(x => new { x.Customer.FirstName, x.Customer.LastName })
.GroupBy(x => new { x.LastName })
.Select(x => new { x.Key.LastName, SumOfLengths = x.Sum(xx => xx.FirstName.Length) })
.ToList();

Assert.Equal(3, query.Count);
var orderedResult = query.OrderBy(x => x.LastName).ToList();
Assert.Equal("One", orderedResult[0].LastName);
Assert.Equal(24, orderedResult[0].SumOfLengths);
Assert.Equal("Three", orderedResult[1].LastName);
Assert.Equal(8, orderedResult[1].SumOfLengths);
Assert.Equal("Two", orderedResult[2].LastName);
Assert.Equal(16, orderedResult[2].SumOfLengths);
}
}

[ConditionalFact]
public virtual void TVF_backing_entity_type_mapped_to_view()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,62 @@ ORDER BY [g].[ProductId]
""");
}

public override void TVF_with_navigation_in_projection_groupby_aggregate()
{
base.TVF_with_navigation_in_projection_groupby_aggregate();

AssertSql(
"""
SELECT [c].[LastName], (
SELECT COALESCE(SUM(CAST(LEN([c1].[FirstName]) AS int)), 0)
FROM [Orders] AS [o0]
INNER JOIN [Customers] AS [c0] ON [o0].[CustomerId] = [c0].[Id]
INNER JOIN [Customers] AS [c1] ON [o0].[CustomerId] = [c1].[Id]
WHERE NOT (EXISTS (
SELECT 1
FROM [dbo].[GetTopTwoSellingProducts]() AS [g0]
WHERE [g0].[ProductId] = 25)) AND ([c].[LastName] = [c0].[LastName] OR ([c].[LastName] IS NULL AND [c0].[LastName] IS NULL))) AS [SumOfLengths]
FROM [Orders] AS [o]
INNER JOIN [Customers] AS [c] ON [o].[CustomerId] = [c].[Id]
WHERE NOT (EXISTS (
SELECT 1
FROM [dbo].[GetTopTwoSellingProducts]() AS [g]
WHERE [g].[ProductId] = 25))
GROUP BY [c].[LastName]
""");
}

public override void TVF_with_argument_being_a_subquery_with_navigation_in_projection_groupby_aggregate()
{
base.TVF_with_argument_being_a_subquery_with_navigation_in_projection_groupby_aggregate();

AssertSql(
"""
SELECT [c0].[LastName], (
SELECT COALESCE(SUM(CAST(LEN([c2].[FirstName]) AS int)), 0)
FROM [Orders] AS [o0]
INNER JOIN [Customers] AS [c1] ON [o0].[CustomerId] = [c1].[Id]
INNER JOIN [Customers] AS [c2] ON [o0].[CustomerId] = [c2].[Id]
WHERE NOT (EXISTS (
SELECT 1
FROM [dbo].[GetOrdersWithMultipleProducts]((
SELECT TOP(1) [c3].[Id]
FROM [Customers] AS [c3]
ORDER BY [c3].[Id])) AS [g0]
WHERE [g0].[CustomerId] = 25)) AND ([c0].[LastName] = [c1].[LastName] OR ([c0].[LastName] IS NULL AND [c1].[LastName] IS NULL))) AS [SumOfLengths]
FROM [Orders] AS [o]
INNER JOIN [Customers] AS [c0] ON [o].[CustomerId] = [c0].[Id]
WHERE NOT (EXISTS (
SELECT 1
FROM [dbo].[GetOrdersWithMultipleProducts]((
SELECT TOP(1) [c].[Id]
FROM [Customers] AS [c]
ORDER BY [c].[Id])) AS [g]
WHERE [g].[CustomerId] = 25))
GROUP BY [c0].[LastName]
""");
}

public override void TVF_backing_entity_type_mapped_to_view()
{
base.TVF_backing_entity_type_mapped_to_view();
Expand Down

0 comments on commit f410b99

Please sign in to comment.