-
Notifications
You must be signed in to change notification settings - Fork 228
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
518 additions
and
143 deletions.
There are no files selected for viewing
14 changes: 14 additions & 0 deletions
14
src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlAggregateDbFunctionsExtensions.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
// ReSharper disable once CheckNamespace | ||
namespace Microsoft.EntityFrameworkCore; | ||
|
||
/// <summary> | ||
/// Provides extension methods supporting aggregate function translation for PostgreSQL. | ||
/// </summary> | ||
public static class NpgsqlAggregateDbFunctionsExtensions | ||
{ | ||
// TODO: List | ||
// TODO: Docs | ||
// TODO: Distinguish between the two array_agg... Probably by name - though the multidim version accepts arrays of arrays/enumerables | ||
public static T[] ArrayAgg<T>(this DbFunctions _, IEnumerable<T> input) | ||
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(ArrayAgg))); | ||
} |
5 changes: 1 addition & 4 deletions
5
src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlDbFunctionsExtensions.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 changes: 0 additions & 4 deletions
4
src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlMultirangeDbFunctionsExtensions.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 changes: 1 addition & 3 deletions
4
src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlRangeDbFunctionsExtensions.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 changes: 1 addition & 3 deletions
4
src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlTrigramsDbFunctionsExtensions.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
18 changes: 18 additions & 0 deletions
18
...re.PG/Query/ExpressionTranslators/Internal/NpgsqlAggregateMethodCallTranslatorProvider.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal; | ||
|
||
public class NpgsqlAggregateMethodCallTranslatorProvider : RelationalAggregateMethodCallTranslatorProvider | ||
{ | ||
public NpgsqlAggregateMethodCallTranslatorProvider(RelationalAggregateMethodCallTranslatorProviderDependencies dependencies) | ||
: base(dependencies) | ||
{ | ||
var sqlExpressionFactory = (NpgsqlSqlExpressionFactory)dependencies.SqlExpressionFactory; | ||
var typeMappingSource = dependencies.RelationalTypeMappingSource; | ||
|
||
AddTranslators( | ||
new IAggregateMethodCallTranslator[] | ||
{ | ||
new NpgsqlQueryableAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource), | ||
new NpgsqlMiscAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource), | ||
}); | ||
} | ||
} |
60 changes: 60 additions & 0 deletions
60
src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMiscAggregateMethodTranslator.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
using static Npgsql.EntityFrameworkCore.PostgreSQL.Utilities.Statics; | ||
|
||
namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal; | ||
|
||
public class NpgsqlMiscAggregateMethodTranslator : IAggregateMethodCallTranslator | ||
{ | ||
private static readonly MethodInfo StringJoin | ||
= typeof(string).GetRuntimeMethod(nameof(string.Join), new[] { typeof(string), typeof(IEnumerable<string>) })!; | ||
|
||
private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; | ||
private readonly IRelationalTypeMappingSource _typeMappingSource; | ||
|
||
public NpgsqlMiscAggregateMethodTranslator( | ||
NpgsqlSqlExpressionFactory sqlExpressionFactory, | ||
IRelationalTypeMappingSource typeMappingSource) | ||
{ | ||
_sqlExpressionFactory = sqlExpressionFactory; | ||
_typeMappingSource = typeMappingSource; | ||
} | ||
|
||
public virtual SqlExpression? Translate( | ||
MethodInfo method, | ||
EnumerableExpression source, | ||
IReadOnlyList<SqlExpression> arguments, | ||
IDiagnosticsLogger<DbLoggerCategory.Query> logger) | ||
{ | ||
if (source.Selector is not SqlExpression sqlExpression) | ||
{ | ||
return null; | ||
} | ||
|
||
if (method == StringJoin) | ||
{ | ||
return _sqlExpressionFactory.AggregateFunction( | ||
"string_agg", | ||
new[] { sqlExpression, arguments[0] }, | ||
nullable: true, | ||
argumentsPropagateNullability: new[] { false, true }, | ||
source, | ||
typeof(string)); | ||
} | ||
|
||
if (method.DeclaringType == typeof(NpgsqlAggregateDbFunctionsExtensions)) | ||
{ | ||
switch (method.Name) | ||
{ | ||
case nameof(NpgsqlAggregateDbFunctionsExtensions.ArrayAgg): | ||
return _sqlExpressionFactory.AggregateFunction( | ||
"array_agg", | ||
new[] { sqlExpression }, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
source, | ||
sqlExpression.Type.MakeArrayType()); // TODO: No, infer type mapping | ||
} | ||
} | ||
|
||
return null; | ||
} | ||
} |
169 changes: 169 additions & 0 deletions
169
...FCore.PG/Query/ExpressionTranslators/Internal/NpgsqlQueryableAggregateMethodTranslator.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
using static Npgsql.EntityFrameworkCore.PostgreSQL.Utilities.Statics; | ||
|
||
namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal; | ||
|
||
public class NpgsqlQueryableAggregateMethodTranslator : IAggregateMethodCallTranslator | ||
{ | ||
private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; | ||
private readonly IRelationalTypeMappingSource _typeMappingSource; | ||
|
||
public NpgsqlQueryableAggregateMethodTranslator( | ||
NpgsqlSqlExpressionFactory sqlExpressionFactory, | ||
IRelationalTypeMappingSource typeMappingSource) | ||
{ | ||
_sqlExpressionFactory = sqlExpressionFactory; | ||
_typeMappingSource = typeMappingSource; | ||
} | ||
|
||
public virtual SqlExpression? Translate( | ||
MethodInfo method, | ||
EnumerableExpression source, | ||
IReadOnlyList<SqlExpression> arguments, | ||
IDiagnosticsLogger<DbLoggerCategory.Query> logger) | ||
{ | ||
if (method.DeclaringType == typeof(Queryable)) | ||
{ | ||
var methodInfo = method.IsGenericMethod | ||
? method.GetGenericMethodDefinition() | ||
: method; | ||
switch (methodInfo.Name) | ||
{ | ||
case nameof(Queryable.Average) | ||
when (QueryableMethods.IsAverageWithoutSelector(methodInfo) | ||
|| QueryableMethods.IsAverageWithSelector(methodInfo)) | ||
&& source.Selector is SqlExpression averageSqlExpression: | ||
var averageInputType = averageSqlExpression.Type; | ||
if (averageInputType == typeof(int) | ||
|| averageInputType == typeof(long)) | ||
{ | ||
averageSqlExpression = _sqlExpressionFactory.ApplyDefaultTypeMapping( | ||
_sqlExpressionFactory.Convert(averageSqlExpression, typeof(double))); | ||
} | ||
|
||
return averageInputType == typeof(float) | ||
? _sqlExpressionFactory.Convert( | ||
_sqlExpressionFactory.AggregateFunction( | ||
"AVG", | ||
new[] { averageSqlExpression }, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
source, | ||
typeof(double)), | ||
averageSqlExpression.Type, | ||
averageSqlExpression.TypeMapping) | ||
: _sqlExpressionFactory.AggregateFunction( | ||
"AVG", | ||
new[] { averageSqlExpression }, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
source, | ||
averageSqlExpression.Type, | ||
averageSqlExpression.TypeMapping); | ||
|
||
// PostgreSQL COUNT() always returns bigint, so we need to downcast to int | ||
case nameof(Queryable.Count) | ||
when methodInfo == QueryableMethods.CountWithoutPredicate | ||
|| methodInfo == QueryableMethods.CountWithPredicate: | ||
var countSqlExpression = (source.Selector as SqlExpression) ?? _sqlExpressionFactory.Fragment("*"); | ||
return _sqlExpressionFactory.Convert( | ||
_sqlExpressionFactory.ApplyDefaultTypeMapping( | ||
_sqlExpressionFactory.AggregateFunction( | ||
"COUNT", | ||
new[] { countSqlExpression }, | ||
nullable: false, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
source, | ||
typeof(long))), | ||
typeof(int), _typeMappingSource.FindMapping(typeof(int))); | ||
|
||
case nameof(Queryable.LongCount) | ||
when methodInfo == QueryableMethods.LongCountWithoutPredicate | ||
|| methodInfo == QueryableMethods.LongCountWithPredicate: | ||
var longCountSqlExpression = (source.Selector as SqlExpression) ?? _sqlExpressionFactory.Fragment("*"); | ||
return _sqlExpressionFactory.ApplyDefaultTypeMapping( | ||
_sqlExpressionFactory.AggregateFunction( | ||
"COUNT", | ||
new[] { longCountSqlExpression }, | ||
nullable: false, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
source, | ||
typeof(long))); | ||
|
||
case nameof(Queryable.Max) | ||
when (methodInfo == QueryableMethods.MaxWithoutSelector | ||
|| methodInfo == QueryableMethods.MaxWithSelector) | ||
&& source.Selector is SqlExpression maxSqlExpression: | ||
return _sqlExpressionFactory.AggregateFunction( | ||
"MAX", | ||
new[] { maxSqlExpression }, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
source, | ||
maxSqlExpression.Type, | ||
maxSqlExpression.TypeMapping); | ||
|
||
case nameof(Queryable.Min) | ||
when (methodInfo == QueryableMethods.MinWithoutSelector | ||
|| methodInfo == QueryableMethods.MinWithSelector) | ||
&& source.Selector is SqlExpression minSqlExpression: | ||
return _sqlExpressionFactory.AggregateFunction( | ||
"MIN", | ||
new[] { minSqlExpression }, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
source, | ||
minSqlExpression.Type, | ||
minSqlExpression.TypeMapping); | ||
|
||
// In PostgreSQL SUM() doesn't return the same type as its argument for smallint, int and bigint. | ||
// Cast to get the same type. | ||
// http://www.postgresql.org/docs/current/static/functions-aggregate.html | ||
case nameof(Queryable.Sum) | ||
when (QueryableMethods.IsSumWithoutSelector(methodInfo) | ||
|| QueryableMethods.IsSumWithSelector(methodInfo)) | ||
&& source.Selector is SqlExpression sumSqlExpression: | ||
var sumInputType = sumSqlExpression.Type; | ||
|
||
// Note that there is no Sum over short in LINQ | ||
if (sumInputType == typeof(int)) | ||
{ | ||
return _sqlExpressionFactory.Convert( | ||
_sqlExpressionFactory.AggregateFunction( | ||
"SUM", | ||
new[] { sumSqlExpression }, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
source, | ||
typeof(long)), | ||
sumInputType, | ||
sumSqlExpression.TypeMapping); | ||
} | ||
|
||
if (sumInputType == typeof(long)) | ||
{ | ||
return _sqlExpressionFactory.Convert( | ||
_sqlExpressionFactory.AggregateFunction( | ||
"SUM", | ||
new[] { sumSqlExpression }, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
source, | ||
typeof(decimal)), | ||
sumInputType, | ||
sumSqlExpression.TypeMapping); | ||
} | ||
|
||
return _sqlExpressionFactory.AggregateFunction( | ||
"SUM", | ||
new[] { sumSqlExpression }, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
source, | ||
sumInputType, | ||
sumSqlExpression.TypeMapping); | ||
} | ||
} | ||
|
||
return null; | ||
} | ||
} |
Oops, something went wrong.