forked from npgsql/efcore.pg
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Translate PostgreSQL custom aggregates
* string_agg (string.Join) * array_agg * json_agg/jsonb_agg * json_object_agg/jsonb_object_agg * range_agg, range_intersect_agg * Aggregate statistics functions Closes npgsql#532
- Loading branch information
Showing
11 changed files
with
1,287 additions
and
56 deletions.
There are no files selected for viewing
514 changes: 514 additions & 0 deletions
514
src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlAggregateDbFunctionsExtensions.cs
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
136 changes: 136 additions & 0 deletions
136
src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMiscAggregateMethodTranslator.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal; | ||
using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping; | ||
using static Npgsql.EntityFrameworkCore.PostgreSQL.Utilities.Statics; | ||
|
||
namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal; | ||
|
||
public class NpgsqlMiscAggregateMethodTranslator : IAggregateMethodCallTranslator | ||
{ | ||
private static readonly MethodInfo StringJoin | ||
= typeof(string).GetRuntimeMethod(nameof(string.Join), new[] { typeof(string), typeof(IEnumerable<string>) })!; | ||
|
||
private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; | ||
private readonly IRelationalTypeMappingSource _typeMappingSource; | ||
|
||
public NpgsqlMiscAggregateMethodTranslator( | ||
NpgsqlSqlExpressionFactory sqlExpressionFactory, | ||
IRelationalTypeMappingSource typeMappingSource) | ||
{ | ||
_sqlExpressionFactory = sqlExpressionFactory; | ||
_typeMappingSource = typeMappingSource; | ||
} | ||
|
||
public virtual SqlExpression? Translate( | ||
MethodInfo method, | ||
EnumerableExpression source, | ||
IReadOnlyList<SqlExpression> arguments, | ||
IDiagnosticsLogger<DbLoggerCategory.Query> logger) | ||
{ | ||
if (source.Selector is not SqlExpression sqlExpression) | ||
{ | ||
return null; | ||
} | ||
|
||
if (method == StringJoin) | ||
{ | ||
// TODO: Align this with the upstream support, including Concat translation | ||
// TODO: Confirm result type mapping, what happens with e.g. varchar(8) | ||
return _sqlExpressionFactory.AggregateFunction( | ||
"string_agg", | ||
new[] { sqlExpression, arguments[0] }, | ||
source, | ||
nullable: true, argumentsPropagateNullability: new[] { false, true }, returnType: typeof(string)); | ||
} | ||
|
||
if (method.DeclaringType == typeof(NpgsqlAggregateDbFunctionsExtensions)) | ||
{ | ||
switch (method.Name) | ||
{ | ||
case nameof(NpgsqlAggregateDbFunctionsExtensions.ArrayAgg): | ||
var arrayClrType = sqlExpression.Type.MakeArrayType(); | ||
|
||
return _sqlExpressionFactory.AggregateFunction( | ||
"array_agg", | ||
new[] { sqlExpression }, | ||
source, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
returnType: arrayClrType, | ||
typeMapping: sqlExpression.TypeMapping is null | ||
? null | ||
: new NpgsqlArrayArrayTypeMapping(arrayClrType, sqlExpression.TypeMapping)); | ||
|
||
case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonAgg): | ||
arrayClrType = sqlExpression.Type.MakeArrayType(); | ||
|
||
return _sqlExpressionFactory.AggregateFunction( | ||
"json_agg", | ||
new[] { sqlExpression }, | ||
source, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
returnType: arrayClrType, | ||
_typeMappingSource.FindMapping(arrayClrType, "json")); | ||
|
||
case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbAgg): | ||
arrayClrType = sqlExpression.Type.MakeArrayType(); | ||
|
||
return _sqlExpressionFactory.AggregateFunction( | ||
"jsonb_agg", | ||
new[] { sqlExpression }, | ||
source, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
returnType: arrayClrType, | ||
_typeMappingSource.FindMapping(arrayClrType, "jsonb")); | ||
|
||
case nameof(NpgsqlAggregateDbFunctionsExtensions.RangeAgg): | ||
arrayClrType = sqlExpression.Type.MakeArrayType(); | ||
|
||
return _sqlExpressionFactory.AggregateFunction( | ||
"range_agg", | ||
new[] { sqlExpression }, | ||
source, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
returnType: arrayClrType, | ||
_typeMappingSource.FindMapping(arrayClrType)); | ||
|
||
case nameof(NpgsqlAggregateDbFunctionsExtensions.RangeIntersectAgg): | ||
return _sqlExpressionFactory.AggregateFunction( | ||
"range_intersect_agg", | ||
new[] { sqlExpression }, | ||
source, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
returnType: sqlExpression.Type, | ||
sqlExpression.TypeMapping); | ||
|
||
case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbObjectAgg): | ||
case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonObjectAgg): | ||
var isJsonb = method.Name == nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbObjectAgg); | ||
|
||
// These methods accept two enumerable (column) arguments; this is represented in LINQ as a projection from the grouping | ||
// to a tuple of the two columns. Since we generally translate tuples to PostgresRowValueExpression, we take it apart | ||
// here. | ||
if (source.Selector is not PostgresRowValueExpression rowValueExpression) | ||
{ | ||
return null; | ||
} | ||
|
||
var (keys, values) = (rowValueExpression.Values[0], rowValueExpression.Values[1]); | ||
|
||
return _sqlExpressionFactory.AggregateFunction( | ||
isJsonb ? "jsonb_object_agg" : "json_object_agg", | ||
new[] { keys, values }, | ||
source, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[2], | ||
returnType: method.ReturnType, | ||
_typeMappingSource.FindMapping(method.ReturnType, isJsonb ? "jsonb" : "json")); | ||
} | ||
} | ||
|
||
return null; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
104 changes: 104 additions & 0 deletions
104
...Core.PG/Query/ExpressionTranslators/Internal/NpgsqlStatisticsAggregateMethodTranslator.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal; | ||
using static Npgsql.EntityFrameworkCore.PostgreSQL.Utilities.Statics; | ||
|
||
namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal; | ||
|
||
public class NpgsqlStatisticsAggregateMethodTranslator : IAggregateMethodCallTranslator | ||
{ | ||
private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; | ||
private readonly RelationalTypeMapping _doubleTypeMapping, _longTypeMapping; | ||
|
||
public NpgsqlStatisticsAggregateMethodTranslator( | ||
NpgsqlSqlExpressionFactory sqlExpressionFactory, | ||
IRelationalTypeMappingSource typeMappingSource) | ||
{ | ||
_sqlExpressionFactory = sqlExpressionFactory; | ||
_doubleTypeMapping = typeMappingSource.FindMapping(typeof(double))!; | ||
_longTypeMapping = typeMappingSource.FindMapping(typeof(long))!; | ||
} | ||
|
||
public virtual SqlExpression? Translate( | ||
MethodInfo method, | ||
EnumerableExpression source, | ||
IReadOnlyList<SqlExpression> arguments, | ||
IDiagnosticsLogger<DbLoggerCategory.Query> logger) | ||
{ | ||
// Docs: https://www.postgresql.org/docs/current/functions-aggregate.html#FUNCTIONS-AGGREGATE-STATISTICS-TABLE | ||
|
||
if (method.DeclaringType != typeof(NpgsqlAggregateDbFunctionsExtensions) | ||
|| source.Selector is not SqlExpression sqlExpression) | ||
{ | ||
return null; | ||
} | ||
|
||
// These four functions are simple and take a single enumerable argument | ||
var functionName = method.Name switch | ||
{ | ||
nameof(NpgsqlAggregateDbFunctionsExtensions.StandardDeviationSample) => "stddev_samp", | ||
nameof(NpgsqlAggregateDbFunctionsExtensions.StandardDeviationPopulation) => "stddev_pop", | ||
nameof(NpgsqlAggregateDbFunctionsExtensions.VarianceSample) => "var_samp", | ||
nameof(NpgsqlAggregateDbFunctionsExtensions.VariancePopulation) => "var_pop", | ||
_ => null | ||
}; | ||
|
||
if (functionName is not null) | ||
{ | ||
return _sqlExpressionFactory.AggregateFunction( | ||
functionName, | ||
new[] { sqlExpression }, | ||
source, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
typeof(double), | ||
_doubleTypeMapping); | ||
} | ||
|
||
functionName = method.Name switch | ||
{ | ||
nameof(NpgsqlAggregateDbFunctionsExtensions.Correlation) => "corr", | ||
nameof(NpgsqlAggregateDbFunctionsExtensions.CovariancePopulation) => "covar_pop", | ||
nameof(NpgsqlAggregateDbFunctionsExtensions.CovarianceSample) => "covar_samp", | ||
nameof(NpgsqlAggregateDbFunctionsExtensions.RegrAverageX) => "regr_avgx", | ||
nameof(NpgsqlAggregateDbFunctionsExtensions.RegrAverageY) => "regr_avgy", | ||
nameof(NpgsqlAggregateDbFunctionsExtensions.RegrCount) => "regr_count", | ||
nameof(NpgsqlAggregateDbFunctionsExtensions.RegrIntercept) => "regr_intercept", | ||
nameof(NpgsqlAggregateDbFunctionsExtensions.RegrR2) => "regr_r2", | ||
nameof(NpgsqlAggregateDbFunctionsExtensions.RegrSlope) => "regr_slope", | ||
nameof(NpgsqlAggregateDbFunctionsExtensions.RegrSXX) => "regr_sxx", | ||
nameof(NpgsqlAggregateDbFunctionsExtensions.RegrSXY) => "regr_sxy", | ||
_ => null | ||
}; | ||
|
||
if (functionName is not null) | ||
{ | ||
// These methods accept two enumerable (column) arguments; this is represented in LINQ as a projection from the grouping | ||
// to a tuple of the two columns. Since we generally translate tuples to PostgresRowValueExpression, we take it apart here. | ||
if (source.Selector is not PostgresRowValueExpression rowValueExpression) | ||
{ | ||
return null; | ||
} | ||
|
||
var (y, x) = (rowValueExpression.Values[0], rowValueExpression.Values[1]); | ||
|
||
return method.Name == nameof(NpgsqlAggregateDbFunctionsExtensions.RegrCount) | ||
? _sqlExpressionFactory.AggregateFunction( | ||
functionName, | ||
new[] { y, x }, | ||
source, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[2], | ||
typeof(long), | ||
_longTypeMapping) | ||
: _sqlExpressionFactory.AggregateFunction( | ||
functionName, | ||
new[] { y, x }, | ||
source, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[2], | ||
typeof(double), | ||
_doubleTypeMapping); | ||
} | ||
|
||
return null; | ||
} | ||
} |
Oops, something went wrong.