From 9def04fb3fb0dfc1b118c4246119d33b9e27cae5 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Sat, 4 Jun 2022 18:32:43 +0200 Subject: [PATCH] Translate PostgreSQL custom aggregates * string_agg (string.Join) * array_agg * json_agg/jsonb_agg * json_object_agg/jsonb_object_agg * range_agg, range_intersect_agg * Aggregate statistics functions Closes #2395 Closes #532 --- .../NpgsqlAggregateDbFunctionsExtensions.cs | 514 ++++++++++++++++++ ...qlAggregateMethodCallTranslatorProvider.cs | 4 +- .../NpgsqlMiscAggregateMethodTranslator.cs | 136 +++++ ...pgsqlQueryableAggregateMethodTranslator.cs | 44 +- ...gsqlStatisticsAggregateMethodTranslator.cs | 104 ++++ .../Internal/PostgresFunctionExpression.cs | 27 +- .../Query/NpgsqlSqlExpressionFactory.cs | 8 +- .../Query/MultirangeQueryNpgsqlTest.cs | 31 ++ .../NorthwindFunctionsQueryNpgsqlTest.cs | 400 +++++++++++++- .../Query/NorthwindGroupByQueryNpgsqlTest.cs | 19 + .../Query/RangeQueryNpgsqlTest.cs | 56 ++ 11 files changed, 1287 insertions(+), 56 deletions(-) create mode 100644 src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlAggregateDbFunctionsExtensions.cs create mode 100644 src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMiscAggregateMethodTranslator.cs create mode 100644 src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlStatisticsAggregateMethodTranslator.cs diff --git a/src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlAggregateDbFunctionsExtensions.cs b/src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlAggregateDbFunctionsExtensions.cs new file mode 100644 index 0000000000..435a9feecd --- /dev/null +++ b/src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlAggregateDbFunctionsExtensions.cs @@ -0,0 +1,514 @@ +// ReSharper disable once CheckNamespace +namespace Microsoft.EntityFrameworkCore; + +/// +/// Provides extension methods supporting aggregate function translation for PostgreSQL. +/// +public static class NpgsqlAggregateDbFunctionsExtensions +{ + /// + /// Collects all the input values, including nulls, into a PostgreSQL array. + /// Corresponds to the PostgreSQL array_agg aggregate function. + /// + /// The instance. + /// The input values to be aggregated into an array. + /// PostgreSQL documentation for aggregate functions. + public static T[] ArrayAgg(this DbFunctions _, IEnumerable input) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(ArrayAgg))); + + /// + /// Collects all the input values, including nulls, into a json array. Values are converted to JSON as per to_json or + /// to_jsonb. Corresponds to the PostgreSQL json_agg aggregate function. + /// + /// The instance. + /// The input values to be aggregated into a JSON array. + /// PostgreSQL documentation for aggregate functions. + public static T[] JsonAgg(this DbFunctions _, IEnumerable input) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(JsonAgg))); + + /// + /// Collects all the input values, including nulls, into a jsonb array. Values are converted to JSON as per to_json or + /// to_jsonb. Corresponds to the PostgreSQL jsonb_agg aggregate function. + /// + /// The instance. + /// The input values to be aggregated into a JSON array. + /// PostgreSQL documentation for aggregate functions. + public static T[] JsonbAgg(this DbFunctions _, IEnumerable input) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(JsonbAgg))); + + #region Range + + /// + /// Computes the union of the non-null input ranges. Corresponds to the PostgreSQL range_agg aggregate function. + /// + /// The instance. + /// The ranges to be aggregated via union into a multirange. + /// PostgreSQL documentation for aggregate functions. + public static NpgsqlRange[] RangeAgg(this DbFunctions _, IEnumerable> input) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RangeAgg))); + + /// + /// Computes the intersection of the non-null input ranges. Corresponds to the PostgreSQL range_intersect_agg aggregate function. + /// + /// The instance. + /// The ranges to be aggregated via intersection into a multirange. + /// PostgreSQL documentation for aggregate functions. + public static NpgsqlRange RangeIntersectAgg(this DbFunctions _, IEnumerable> input) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RangeAgg))); + + /// + /// Computes the intersection of the non-null input multiranges. + /// Corresponds to the PostgreSQL range_intersect_agg aggregate function. + /// + /// The instance. + /// The ranges to be aggregated via intersection into a multirange. + /// PostgreSQL documentation for aggregate functions. + public static NpgsqlRange[] RangeIntersectAgg(this DbFunctions _, IEnumerable[]> input) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RangeAgg))); + + #endregion Range + + #region JsonObjectAgg + + /// + /// Collects all the key/value pairs into a JSON object. Key arguments are coerced to text; value arguments are converted as per + /// to_json. Values can be , but not keys. + /// Corresponds to the PostgreSQL json_object_agg aggregate function. + /// + /// The instance. + /// An enumerable of key-value pairs to be aggregated into a JSON object. + /// PostgreSQL documentation for aggregate functions. + public static string JsonObjectAgg(this DbFunctions _, IEnumerable<(T1, T2)> keyValuePairs) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(JsonObjectAgg))); + + /// + /// Collects all the key/value pairs into a JSON object. Key arguments are coerced to text; value arguments are converted as per + /// to_json. Values can be , but not keys. + /// Corresponds to the PostgreSQL json_object_agg aggregate function. + /// + /// The instance. + /// An enumerable of key-value pairs to be aggregated into a JSON object. + /// PostgreSQL documentation for aggregate functions. + public static TReturn JsonObjectAgg(this DbFunctions _, IEnumerable<(T1, T2)> keyValuePairs) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(JsonObjectAgg))); + + /// + /// Collects all the key/value pairs into a JSON object. Key arguments are coerced to text; value arguments are converted as per + /// to_jsonb. Values can be , but not keys. + /// Corresponds to the PostgreSQL jsonb_object_agg aggregate function. + /// + /// The instance. + /// An enumerable of key-value pairs to be aggregated into a JSON object. + /// PostgreSQL documentation for aggregate functions. + public static string JsonbObjectAgg(this DbFunctions _, IEnumerable<(T1, T2)> keyValuePairs) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(JsonbObjectAgg))); + + /// + /// Collects all the key/value pairs into a JSON object. Key arguments are coerced to text; value arguments are converted as per + /// to_jsonb. Values can be , but not keys. + /// Corresponds to the PostgreSQL jsonb_object_agg aggregate function. + /// + /// The instance. + /// An enumerable of key-value pairs to be aggregated into a JSON object. + /// PostgreSQL documentation for aggregate functions. + public static TReturn JsonbObjectAgg(this DbFunctions _, IEnumerable<(T1, T2)> keyValuePairs) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(JsonbObjectAgg))); + + #endregion JsonObjectAgg + + #region Sample standard deviation + + /// + /// Returns the sample standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_samp function. + /// + /// The instance. + /// The values. + /// The computed sample standard deviation. + public static double? StandardDeviationSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationSample))); + + /// + /// Returns the sample standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_samp function. + /// + /// The instance. + /// The values. + /// The computed sample standard deviation. + public static double? StandardDeviationSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationSample))); + + /// + /// Returns the sample standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_samp function. + /// + /// The instance. + /// The values. + /// The computed sample standard deviation. + public static double? StandardDeviationSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationSample))); + + /// + /// Returns the sample standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_samp function. + /// + /// The instance. + /// The values. + /// The computed sample standard deviation. + public static double? StandardDeviationSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationSample))); + + /// + /// Returns the sample standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_samp function. + /// + /// The instance. + /// The values. + /// The computed sample standard deviation. + public static double? StandardDeviationSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationSample))); + + /// + /// Returns the sample standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_samp function. + /// + /// The instance. + /// The values. + /// The computed sample standard deviation. + public static double? StandardDeviationSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationSample))); + + /// + /// Returns the sample standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_samp function. + /// + /// The instance. + /// The values. + /// The computed sample standard deviation. + public static double? StandardDeviationSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationSample))); + + #endregion Sample standard deviation + + #region Population standard deviation + + /// + /// Returns the population standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_pop function. + /// + /// The instance. + /// The values. + /// The computed population standard deviation. + public static double? StandardDeviationPopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationPopulation))); + + /// + /// Returns the population standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_pop function. + /// + /// The instance. + /// The values. + /// The computed population standard deviation. + public static double? StandardDeviationPopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationPopulation))); + + /// + /// Returns the population standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_pop function. + /// + /// The instance. + /// The values. + /// The computed population standard deviation. + public static double? StandardDeviationPopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationPopulation))); + + /// + /// Returns the population standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_pop function. + /// + /// The instance. + /// The values. + /// The computed population standard deviation. + public static double? StandardDeviationPopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationPopulation))); + + /// + /// Returns the population standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_pop function. + /// + /// The instance. + /// The values. + /// The computed population standard deviation. + public static double? StandardDeviationPopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationPopulation))); + + /// + /// Returns the population standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_pop function. + /// + /// The instance. + /// The values. + /// The computed population standard deviation. + public static double? StandardDeviationPopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationPopulation))); + + /// + /// Returns the population standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_pop function. + /// + /// The instance. + /// The values. + /// The computed population standard deviation. + public static double? StandardDeviationPopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationPopulation))); + + #endregion Population standard deviation + + #region Sample variance + + /// + /// Returns the sample variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_samp function. + /// + /// The instance. + /// The values. + /// The computed sample variance. + public static double? VarianceSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VarianceSample))); + + /// + /// Returns the sample variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_samp function. + /// + /// The instance. + /// The values. + /// The computed sample variance. + public static double? VarianceSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VarianceSample))); + + /// + /// Returns the sample variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_samp function. + /// + /// The instance. + /// The values. + /// The computed sample variance. + public static double? VarianceSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VarianceSample))); + + /// + /// Returns the sample variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_samp function. + /// + /// The instance. + /// The values. + /// The computed sample variance. + public static double? VarianceSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VarianceSample))); + + /// + /// Returns the sample variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_samp function. + /// + /// The instance. + /// The values. + /// The computed sample variance. + public static double? VarianceSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VarianceSample))); + + /// + /// Returns the sample variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_samp function. + /// + /// The instance. + /// The values. + /// The computed sample variance. + public static double? VarianceSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VarianceSample))); + + /// + /// Returns the sample variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_samp function. + /// + /// The instance. + /// The values. + /// The computed sample variance. + public static double? VarianceSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VarianceSample))); + + #endregion Sample variance + + #region Population variance + + /// + /// Returns the population variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_pop function. + /// + /// The instance. + /// The values. + /// The computed population variance. + public static double? VariancePopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VariancePopulation))); + + /// + /// Returns the population variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_pop function. + /// + /// The instance. + /// The values. + /// The computed population variance. + public static double? VariancePopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VariancePopulation))); + + /// + /// Returns the population variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_pop function. + /// + /// The instance. + /// The values. + /// The computed population variance. + public static double? VariancePopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VariancePopulation))); + + /// + /// Returns the population variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_pop function. + /// + /// The instance. + /// The values. + /// The computed population variance. + public static double? VariancePopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VariancePopulation))); + + /// + /// Returns the population variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_pop function. + /// + /// The instance. + /// The values. + /// The computed population variance. + public static double? VariancePopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VariancePopulation))); + + /// + /// Returns the population variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_pop function. + /// + /// The instance. + /// The values. + /// The computed population variance. + public static double? VariancePopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VariancePopulation))); + + /// + /// Returns the population variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_pop function. + /// + /// The instance. + /// The values. + /// The computed population variance. + public static double? VariancePopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VariancePopulation))); + + #endregion Population variance + + #region Other statistics functions + + /// + /// Computes the correlation coefficient. Corresponds to the PostgreSQL corr function. + /// + /// The instance. + /// The values. + public static double? Correlation(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Correlation))); + + /// + /// Computes the population covariance. Corresponds to the PostgreSQL covar_pop function. + /// + /// The instance. + /// The values. + public static double? CovariancePopulation(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(CovariancePopulation))); + + /// + /// Computes the sample covariance. Corresponds to the PostgreSQL covar_samp function. + /// + /// The instance. + /// The values. + public static double? CovarianceSample(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(CovarianceSample))); + + /// + /// Computes the average of the independent variable, sum(X)/N. + /// Corresponds to the PostgreSQL regr_avgx function. + /// + /// The instance. + /// The values. + public static double? RegrAverageX(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RegrAverageX))); + + /// + /// Computes the average of the dependent variable, sum(Y)/N. + /// Corresponds to the PostgreSQL regr_avgy function. + /// + /// The instance. + /// The values. + public static double? RegrAverageY(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RegrAverageY))); + + /// + /// Computes the number of rows in which both inputs are non-null. + /// Corresponds to the PostgreSQL regr_count function. + /// + /// The instance. + /// The values. + public static long? RegrCount(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RegrCount))); + + /// + /// Computes the y-intercept of the least-squares-fit linear equation determined by the (X, Y) pairs. + /// Corresponds to the PostgreSQL regr_intercept function. + /// + /// The instance. + /// The values. + public static double? RegrIntercept(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RegrIntercept))); + + /// + /// Computes the square of the correlation coefficient. + /// Corresponds to the PostgreSQL regr_r2 function. + /// + /// The instance. + /// The values. + public static double? RegrR2(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RegrR2))); + + /// + /// Computes the slope of the least-squares-fit linear equation determined by the (X, Y) pairs. + /// Corresponds to the PostgreSQL regr_slope function. + /// + /// The instance. + /// The values. + public static double? RegrSlope(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RegrSlope))); + + /// + /// Computes the “sum of squares” of the independent variable, sum(X^2) - sum(X)^2/N. + /// Corresponds to the PostgreSQL regr_sxx function. + /// + /// The instance. + /// The values. + public static double? RegrSXX(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RegrSXX))); + + /// + /// Computes the “sum of products” of independent times dependent variables, sum(X*Y) - sum(X) * sum(Y)/N. + /// Corresponds to the PostgreSQL regr_sxy function. + /// + /// The instance. + /// The values. + public static double? RegrSXY(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RegrSXY))); + + #endregion Other statistics functions +} diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlAggregateMethodCallTranslatorProvider.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlAggregateMethodCallTranslatorProvider.cs index c0687a6d11..b0a789317c 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlAggregateMethodCallTranslatorProvider.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlAggregateMethodCallTranslatorProvider.cs @@ -11,7 +11,9 @@ public NpgsqlAggregateMethodCallTranslatorProvider(RelationalAggregateMethodCall AddTranslators( new IAggregateMethodCallTranslator[] { - new NpgsqlQueryableAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource) + new NpgsqlQueryableAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource), + new NpgsqlStatisticsAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource), + new NpgsqlMiscAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource) }); } } diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMiscAggregateMethodTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMiscAggregateMethodTranslator.cs new file mode 100644 index 0000000000..4787093355 --- /dev/null +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMiscAggregateMethodTranslator.cs @@ -0,0 +1,136 @@ +using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal; +using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping; +using static Npgsql.EntityFrameworkCore.PostgreSQL.Utilities.Statics; + +namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal; + +public class NpgsqlMiscAggregateMethodTranslator : IAggregateMethodCallTranslator +{ + private static readonly MethodInfo StringJoin + = typeof(string).GetRuntimeMethod(nameof(string.Join), new[] { typeof(string), typeof(IEnumerable) })!; + + private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; + private readonly IRelationalTypeMappingSource _typeMappingSource; + + public NpgsqlMiscAggregateMethodTranslator( + NpgsqlSqlExpressionFactory sqlExpressionFactory, + IRelationalTypeMappingSource typeMappingSource) + { + _sqlExpressionFactory = sqlExpressionFactory; + _typeMappingSource = typeMappingSource; + } + + public virtual SqlExpression? Translate( + MethodInfo method, + EnumerableExpression source, + IReadOnlyList arguments, + IDiagnosticsLogger logger) + { + if (source.Selector is not SqlExpression sqlExpression) + { + return null; + } + + if (method == StringJoin) + { + // TODO: Align this with the upstream support, including Concat translation + // TODO: Confirm result type mapping, what happens with e.g. varchar(8) + return _sqlExpressionFactory.AggregateFunction( + "string_agg", + new[] { sqlExpression, arguments[0] }, + source, + nullable: true, argumentsPropagateNullability: new[] { false, true }, returnType: typeof(string)); + } + + if (method.DeclaringType == typeof(NpgsqlAggregateDbFunctionsExtensions)) + { + switch (method.Name) + { + case nameof(NpgsqlAggregateDbFunctionsExtensions.ArrayAgg): + var arrayClrType = sqlExpression.Type.MakeArrayType(); + + return _sqlExpressionFactory.AggregateFunction( + "array_agg", + new[] { sqlExpression }, + source, + nullable: true, + argumentsPropagateNullability: FalseArrays[1], + returnType: arrayClrType, + typeMapping: sqlExpression.TypeMapping is null + ? null + : new NpgsqlArrayArrayTypeMapping(arrayClrType, sqlExpression.TypeMapping)); + + case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonAgg): + arrayClrType = sqlExpression.Type.MakeArrayType(); + + return _sqlExpressionFactory.AggregateFunction( + "json_agg", + new[] { sqlExpression }, + source, + nullable: true, + argumentsPropagateNullability: FalseArrays[1], + returnType: arrayClrType, + _typeMappingSource.FindMapping(arrayClrType, "json")); + + case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbAgg): + arrayClrType = sqlExpression.Type.MakeArrayType(); + + return _sqlExpressionFactory.AggregateFunction( + "jsonb_agg", + new[] { sqlExpression }, + source, + nullable: true, + argumentsPropagateNullability: FalseArrays[1], + returnType: arrayClrType, + _typeMappingSource.FindMapping(arrayClrType, "jsonb")); + + case nameof(NpgsqlAggregateDbFunctionsExtensions.RangeAgg): + arrayClrType = sqlExpression.Type.MakeArrayType(); + + return _sqlExpressionFactory.AggregateFunction( + "range_agg", + new[] { sqlExpression }, + source, + nullable: true, + argumentsPropagateNullability: FalseArrays[1], + returnType: arrayClrType, + _typeMappingSource.FindMapping(arrayClrType)); + + case nameof(NpgsqlAggregateDbFunctionsExtensions.RangeIntersectAgg): + return _sqlExpressionFactory.AggregateFunction( + "range_intersect_agg", + new[] { sqlExpression }, + source, + nullable: true, + argumentsPropagateNullability: FalseArrays[1], + returnType: sqlExpression.Type, + sqlExpression.TypeMapping); + + case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbObjectAgg): + case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonObjectAgg): + var isJsonb = method.Name == nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbObjectAgg); + + // These methods accept two enumerable (column) arguments; this is represented in LINQ as a projection from the grouping + // to a tuple of the two columns. Since we generally translate tuples to PostgresRowValueExpression, we take it apart + // here. + if (source.Selector is not PostgresRowValueExpression rowValueExpression) + { + return null; + } + + var (keys, values) = (rowValueExpression.Values[0], rowValueExpression.Values[1]); + + return _sqlExpressionFactory.AggregateFunction( + isJsonb ? "jsonb_object_agg" : "json_object_agg", + new[] { keys, values }, + source, + nullable: true, + argumentsPropagateNullability: FalseArrays[2], + returnType: method.ReturnType, + _typeMappingSource.FindMapping(method.ReturnType, isJsonb ? "jsonb" : "json")); + } + } + + return null; + } +} diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlQueryableAggregateMethodTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlQueryableAggregateMethodTranslator.cs index 5bb1c494bf..58afd6b02a 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlQueryableAggregateMethodTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlQueryableAggregateMethodTranslator.cs @@ -45,20 +45,16 @@ public NpgsqlQueryableAggregateMethodTranslator( _sqlExpressionFactory.AggregateFunction( "AVG", new[] { averageSqlExpression }, - nullable: true, - argumentsPropagateNullability: FalseArrays[1], source, - typeof(double)), + nullable: true, argumentsPropagateNullability: FalseArrays[1], returnType: typeof(double)), averageSqlExpression.Type, averageSqlExpression.TypeMapping) : _sqlExpressionFactory.AggregateFunction( "AVG", new[] { averageSqlExpression }, - nullable: true, - argumentsPropagateNullability: FalseArrays[1], source, - averageSqlExpression.Type, - averageSqlExpression.TypeMapping); + nullable: true, + argumentsPropagateNullability: FalseArrays[1], returnType: averageSqlExpression.Type, typeMapping: averageSqlExpression.TypeMapping); // PostgreSQL COUNT() always returns bigint, so we need to downcast to int case nameof(Queryable.Count) @@ -70,10 +66,8 @@ public NpgsqlQueryableAggregateMethodTranslator( _sqlExpressionFactory.AggregateFunction( "COUNT", new[] { countSqlExpression }, - nullable: false, - argumentsPropagateNullability: FalseArrays[1], source, - typeof(long))), + nullable: false, argumentsPropagateNullability: FalseArrays[1], returnType: typeof(long))), typeof(int), _typeMappingSource.FindMapping(typeof(int))); case nameof(Queryable.LongCount) @@ -84,10 +78,8 @@ public NpgsqlQueryableAggregateMethodTranslator( _sqlExpressionFactory.AggregateFunction( "COUNT", new[] { longCountSqlExpression }, - nullable: false, - argumentsPropagateNullability: FalseArrays[1], source, - typeof(long))); + nullable: false, argumentsPropagateNullability: FalseArrays[1], returnType: typeof(long))); case nameof(Queryable.Max) when (methodInfo == QueryableMethods.MaxWithoutSelector @@ -96,11 +88,9 @@ public NpgsqlQueryableAggregateMethodTranslator( return _sqlExpressionFactory.AggregateFunction( "MAX", new[] { maxSqlExpression }, - nullable: true, - argumentsPropagateNullability: FalseArrays[1], source, - maxSqlExpression.Type, - maxSqlExpression.TypeMapping); + nullable: true, + argumentsPropagateNullability: FalseArrays[1], returnType: maxSqlExpression.Type, typeMapping: maxSqlExpression.TypeMapping); case nameof(Queryable.Min) when (methodInfo == QueryableMethods.MinWithoutSelector @@ -109,11 +99,9 @@ public NpgsqlQueryableAggregateMethodTranslator( return _sqlExpressionFactory.AggregateFunction( "MIN", new[] { minSqlExpression }, - nullable: true, - argumentsPropagateNullability: FalseArrays[1], source, - minSqlExpression.Type, - minSqlExpression.TypeMapping); + nullable: true, + argumentsPropagateNullability: FalseArrays[1], returnType: minSqlExpression.Type, typeMapping: minSqlExpression.TypeMapping); // In PostgreSQL SUM() doesn't return the same type as its argument for smallint, int and bigint. // Cast to get the same type. @@ -131,10 +119,8 @@ public NpgsqlQueryableAggregateMethodTranslator( _sqlExpressionFactory.AggregateFunction( "SUM", new[] { sumSqlExpression }, - nullable: true, - argumentsPropagateNullability: FalseArrays[1], source, - typeof(long)), + nullable: true, argumentsPropagateNullability: FalseArrays[1], returnType: typeof(long)), sumInputType, sumSqlExpression.TypeMapping); } @@ -145,10 +131,8 @@ public NpgsqlQueryableAggregateMethodTranslator( _sqlExpressionFactory.AggregateFunction( "SUM", new[] { sumSqlExpression }, - nullable: true, - argumentsPropagateNullability: FalseArrays[1], source, - typeof(decimal)), + nullable: true, argumentsPropagateNullability: FalseArrays[1], returnType: typeof(decimal)), sumInputType, sumSqlExpression.TypeMapping); } @@ -156,11 +140,9 @@ public NpgsqlQueryableAggregateMethodTranslator( return _sqlExpressionFactory.AggregateFunction( "SUM", new[] { sumSqlExpression }, - nullable: true, - argumentsPropagateNullability: FalseArrays[1], source, - sumInputType, - sumSqlExpression.TypeMapping); + nullable: true, + argumentsPropagateNullability: FalseArrays[1], returnType: sumInputType, typeMapping: sumSqlExpression.TypeMapping); } } diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlStatisticsAggregateMethodTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlStatisticsAggregateMethodTranslator.cs new file mode 100644 index 0000000000..d438eb3eb8 --- /dev/null +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlStatisticsAggregateMethodTranslator.cs @@ -0,0 +1,104 @@ +using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal; +using static Npgsql.EntityFrameworkCore.PostgreSQL.Utilities.Statics; + +namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal; + +public class NpgsqlStatisticsAggregateMethodTranslator : IAggregateMethodCallTranslator +{ + private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; + private readonly RelationalTypeMapping _doubleTypeMapping, _longTypeMapping; + + public NpgsqlStatisticsAggregateMethodTranslator( + NpgsqlSqlExpressionFactory sqlExpressionFactory, + IRelationalTypeMappingSource typeMappingSource) + { + _sqlExpressionFactory = sqlExpressionFactory; + _doubleTypeMapping = typeMappingSource.FindMapping(typeof(double))!; + _longTypeMapping = typeMappingSource.FindMapping(typeof(long))!; + } + + public virtual SqlExpression? Translate( + MethodInfo method, + EnumerableExpression source, + IReadOnlyList arguments, + IDiagnosticsLogger logger) + { + // Docs: https://www.postgresql.org/docs/current/functions-aggregate.html#FUNCTIONS-AGGREGATE-STATISTICS-TABLE + + if (method.DeclaringType != typeof(NpgsqlAggregateDbFunctionsExtensions) + || source.Selector is not SqlExpression sqlExpression) + { + return null; + } + + // These four functions are simple and take a single enumerable argument + var functionName = method.Name switch + { + nameof(NpgsqlAggregateDbFunctionsExtensions.StandardDeviationSample) => "stddev_samp", + nameof(NpgsqlAggregateDbFunctionsExtensions.StandardDeviationPopulation) => "stddev_pop", + nameof(NpgsqlAggregateDbFunctionsExtensions.VarianceSample) => "var_samp", + nameof(NpgsqlAggregateDbFunctionsExtensions.VariancePopulation) => "var_pop", + _ => null + }; + + if (functionName is not null) + { + return _sqlExpressionFactory.AggregateFunction( + functionName, + new[] { sqlExpression }, + source, + nullable: true, + argumentsPropagateNullability: FalseArrays[1], + typeof(double), + _doubleTypeMapping); + } + + functionName = method.Name switch + { + nameof(NpgsqlAggregateDbFunctionsExtensions.Correlation) => "corr", + nameof(NpgsqlAggregateDbFunctionsExtensions.CovariancePopulation) => "covar_pop", + nameof(NpgsqlAggregateDbFunctionsExtensions.CovarianceSample) => "covar_samp", + nameof(NpgsqlAggregateDbFunctionsExtensions.RegrAverageX) => "regr_avgx", + nameof(NpgsqlAggregateDbFunctionsExtensions.RegrAverageY) => "regr_avgy", + nameof(NpgsqlAggregateDbFunctionsExtensions.RegrCount) => "regr_count", + nameof(NpgsqlAggregateDbFunctionsExtensions.RegrIntercept) => "regr_intercept", + nameof(NpgsqlAggregateDbFunctionsExtensions.RegrR2) => "regr_r2", + nameof(NpgsqlAggregateDbFunctionsExtensions.RegrSlope) => "regr_slope", + nameof(NpgsqlAggregateDbFunctionsExtensions.RegrSXX) => "regr_sxx", + nameof(NpgsqlAggregateDbFunctionsExtensions.RegrSXY) => "regr_sxy", + _ => null + }; + + if (functionName is not null) + { + // These methods accept two enumerable (column) arguments; this is represented in LINQ as a projection from the grouping + // to a tuple of the two columns. Since we generally translate tuples to PostgresRowValueExpression, we take it apart here. + if (source.Selector is not PostgresRowValueExpression rowValueExpression) + { + return null; + } + + var (y, x) = (rowValueExpression.Values[0], rowValueExpression.Values[1]); + + return method.Name == nameof(NpgsqlAggregateDbFunctionsExtensions.RegrCount) + ? _sqlExpressionFactory.AggregateFunction( + functionName, + new[] { y, x }, + source, + nullable: true, + argumentsPropagateNullability: FalseArrays[2], + typeof(long), + _longTypeMapping) + : _sqlExpressionFactory.AggregateFunction( + functionName, + new[] { y, x }, + source, + nullable: true, + argumentsPropagateNullability: FalseArrays[2], + typeof(double), + _doubleTypeMapping); + } + + return null; + } +} diff --git a/src/EFCore.PG/Query/Expressions/Internal/PostgresFunctionExpression.cs b/src/EFCore.PG/Query/Expressions/Internal/PostgresFunctionExpression.cs index b1534d84e2..1c8739aa81 100644 --- a/src/EFCore.PG/Query/Expressions/Internal/PostgresFunctionExpression.cs +++ b/src/EFCore.PG/Query/Expressions/Internal/PostgresFunctionExpression.cs @@ -58,9 +58,8 @@ public static PostgresFunctionExpression CreateWithNamedArguments( return new PostgresFunctionExpression( name, arguments, argumentNames, argumentSeparators: null, - nullable, argumentsPropagateNullability, aggregateDistinct: false, aggregatePredicate: null, aggregateOrderings: Array.Empty(), - type, typeMapping); + nullable: nullable, argumentsPropagateNullability: argumentsPropagateNullability, type: type, typeMapping: typeMapping); } public static PostgresFunctionExpression CreateWithArgumentSeparators( @@ -77,10 +76,9 @@ public static PostgresFunctionExpression CreateWithArgumentSeparators( Check.NotNull(argumentSeparators, nameof(argumentSeparators)); return new PostgresFunctionExpression( - name, arguments, argumentNames: null, argumentSeparators, - nullable, argumentsPropagateNullability, + name, arguments, argumentNames: null, argumentSeparators: argumentSeparators, aggregateDistinct: false, aggregatePredicate: null, aggregateOrderings: Array.Empty(), - type, typeMapping); + nullable: nullable, argumentsPropagateNullability: argumentsPropagateNullability, type: type, typeMapping: typeMapping); } public PostgresFunctionExpression( @@ -88,11 +86,11 @@ public PostgresFunctionExpression( IEnumerable arguments, IEnumerable? argumentNames, IEnumerable? argumentSeparators, - bool nullable, - IEnumerable argumentsPropagateNullability, bool aggregateDistinct, SqlExpression? aggregatePredicate, IReadOnlyList aggregateOrderings, + bool nullable, + IEnumerable argumentsPropagateNullability, Type type, RelationalTypeMapping? typeMapping) : base(name, arguments, nullable, argumentsPropagateNullability, type, typeMapping) @@ -175,11 +173,10 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) return changed ? new PostgresFunctionExpression( Name, visitedArguments ?? Arguments, ArgumentNames, ArgumentSeparators, - IsNullable, ArgumentsPropagateNullability!, IsAggregateDistinct, visitedAggregatePredicate ?? AggregatePredicate, visitedAggregateOrderings ?? AggregateOrderings, - Type, TypeMapping) + IsNullable, ArgumentsPropagateNullability!, Type, TypeMapping) : this; } @@ -189,13 +186,11 @@ public override SqlFunctionExpression ApplyTypeMapping(RelationalTypeMapping? ty Arguments, ArgumentNames, ArgumentSeparators, - IsNullable, - ArgumentsPropagateNullability, IsAggregateDistinct, AggregatePredicate, AggregateOrderings, - Type, - typeMapping ?? TypeMapping); + IsNullable, + ArgumentsPropagateNullability, Type, typeMapping ?? TypeMapping); public override SqlFunctionExpression Update(SqlExpression? instance, IReadOnlyList? arguments) { @@ -209,11 +204,10 @@ public override SqlFunctionExpression Update(SqlExpression? instance, IReadOnlyL return !arguments.SequenceEqual(Arguments) ? new PostgresFunctionExpression( Name, arguments, ArgumentNames, ArgumentSeparators, - IsNullable, ArgumentsPropagateNullability, IsAggregateDistinct, AggregatePredicate, AggregateOrderings, - Type, TypeMapping) + IsNullable, ArgumentsPropagateNullability, Type, TypeMapping) : this; } @@ -222,11 +216,10 @@ public virtual SqlFunctionExpression UpdateAggregateComponents(SqlExpression? pr return predicate != AggregatePredicate || orderings != AggregateOrderings ? new PostgresFunctionExpression( Name, Arguments, ArgumentNames, ArgumentSeparators, - IsNullable, ArgumentsPropagateNullability, IsAggregateDistinct, predicate, orderings, - Type, TypeMapping) + IsNullable, ArgumentsPropagateNullability, Type, TypeMapping) : this; } diff --git a/src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs b/src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs index f14de1eed2..5771e007e9 100644 --- a/src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs +++ b/src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs @@ -289,9 +289,9 @@ public virtual PostgresBinaryExpression Overlaps(SqlExpression left, SqlExpressi public virtual PostgresFunctionExpression AggregateFunction( string name, IEnumerable arguments, + EnumerableExpression aggregateEnumerableExpression, bool nullable, IEnumerable argumentsPropagateNullability, - EnumerableExpression aggregateEnumerableExpression, Type returnType, RelationalTypeMapping? typeMapping = null) { @@ -307,13 +307,11 @@ public virtual PostgresFunctionExpression AggregateFunction( typeMappedArguments, argumentNames: null, argumentSeparators: null, - nullable, - argumentsPropagateNullability, aggregateEnumerableExpression.IsDistinct, aggregateEnumerableExpression.Predicate, aggregateEnumerableExpression.Orderings, - returnType, - typeMapping); + nullable: nullable, + argumentsPropagateNullability: argumentsPropagateNullability, type: returnType, typeMapping: typeMapping); } #endregion Expression factory methods diff --git a/test/EFCore.PG.FunctionalTests/Query/MultirangeQueryNpgsqlTest.cs b/test/EFCore.PG.FunctionalTests/Query/MultirangeQueryNpgsqlTest.cs index cfa7a217d6..f371ba583b 100644 --- a/test/EFCore.PG.FunctionalTests/Query/MultirangeQueryNpgsqlTest.cs +++ b/test/EFCore.PG.FunctionalTests/Query/MultirangeQueryNpgsqlTest.cs @@ -438,6 +438,37 @@ public void Intersect_multirange() LIMIT 2"); } + [ConditionalFact] + public void Intersect_aggregate() + { + using var context = CreateContext(); + + var group = context.TestEntities + .Where(x => x.Id == 1 || x.Id == 2) + .GroupBy(x => true) + .Select(g => new + { + Union = EF.Functions.RangeIntersectAgg(g.Select(x => x.IntMultirange)) + }) + .Single(); + + Assert.Equal(new NpgsqlRange[] + { + new(4, true, 6, false), + new(7, true, 9, false) + }, group.Union); + + AssertSql( + @"SELECT range_intersect_agg(t0.""IntMultirange"") AS ""Union"" +FROM ( + SELECT t.""IntMultirange"", TRUE AS ""Key"" + FROM ""TestEntities"" AS t + WHERE t.""Id"" IN (1, 2) +) AS t0 +GROUP BY t0.""Key"" +LIMIT 2"); + } + [ConditionalFact] public void Except_multirange() { diff --git a/test/EFCore.PG.FunctionalTests/Query/NorthwindFunctionsQueryNpgsqlTest.cs b/test/EFCore.PG.FunctionalTests/Query/NorthwindFunctionsQueryNpgsqlTest.cs index c17f3b346c..b6f7f6ae4b 100644 --- a/test/EFCore.PG.FunctionalTests/Query/NorthwindFunctionsQueryNpgsqlTest.cs +++ b/test/EFCore.PG.FunctionalTests/Query/NorthwindFunctionsQueryNpgsqlTest.cs @@ -1,4 +1,5 @@ -using System.Text.RegularExpressions; +using System.Text.Json; +using System.Text.RegularExpressions; using Microsoft.EntityFrameworkCore.TestModels.Northwind; using Npgsql.EntityFrameworkCore.PostgreSQL.TestUtilities; @@ -12,7 +13,7 @@ public NorthwindFunctionsQueryNpgsqlTest( : base(fixture) { ClearLog(); - //Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper); + // Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper); } public override async Task IsNullOrWhiteSpace_in_predicate(bool async) @@ -275,6 +276,401 @@ public Task PadRight_char_with_parameter(bool async) #endregion + #region Aggregate functions + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task GroupBy_String_Join(bool async) + { + await AssertQuery( + async, + ss => ss.Set() + .GroupBy(o => o.EmployeeID) + .Select(g => new + { + EmployeeID = g.Key, + Customers = string.Join(", ", g.OrderBy(e => e.OrderID).Select(e => e.CustomerID)) + })); + + AssertSql( + @"SELECT o.""EmployeeID"", string_agg(o.""CustomerID"", ', ' ORDER BY o.""OrderID"" NULLS FIRST) AS ""Customers"" +FROM ""Orders"" AS o +GROUP BY o.""EmployeeID"""); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task GroupBy_ArrayAgg(bool async) + { + using var ctx = CreateContext(); + + var query = ctx.Set() + .GroupBy(c => c.City) + .Select(g => new + { + City = g.Key, + FaxNumbers = EF.Functions.ArrayAgg(g.Select(c => c.Fax).OrderBy(id => id)) + }); + + var results = async + ? await query.ToListAsync() + : query.ToList(); + + var london = results.Single(x => x.City == "London"); + Assert.Collection( + london.FaxNumbers, + Assert.Null, + f => Assert.Equal("(171) 555-2530", f), + f => Assert.Equal("(171) 555-3373", f), + f => Assert.Equal("(171) 555-5646", f), + f => Assert.Equal("(171) 555-6750", f), + f => Assert.Equal("(171) 555-9199", f)); + + AssertSql( + @"SELECT c.""City"", array_agg(c.""Fax"" ORDER BY c.""Fax"" NULLS FIRST) AS ""FaxNumbers"" +FROM ""Customers"" AS c +GROUP BY c.""City"""); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task GroupBy_JsonAgg(bool async) + { + using var ctx = CreateContext(); + + var query = ctx.Set() + .GroupBy(c => c.City) + .Select(g => new + { + City = g.Key, + FaxNumbers = EF.Functions.JsonAgg(g.Select(c => c.Fax).OrderBy(id => id)) + }); + + var results = async + ? await query.ToListAsync() + : query.ToList(); + + var london = results.Single(x => x.City == "London"); + Assert.Collection( + london.FaxNumbers, + Assert.Null, + f => Assert.Equal("(171) 555-2530", f), + f => Assert.Equal("(171) 555-3373", f), + f => Assert.Equal("(171) 555-5646", f), + f => Assert.Equal("(171) 555-6750", f), + f => Assert.Equal("(171) 555-9199", f)); + + AssertSql( + @"SELECT c.""City"", json_agg(c.""Fax"" ORDER BY c.""Fax"" NULLS FIRST) AS ""FaxNumbers"" +FROM ""Customers"" AS c +GROUP BY c.""City"""); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task GroupBy_JsonbAgg(bool async) + { + using var ctx = CreateContext(); + + var query = ctx.Set() + .GroupBy(c => c.City) + .Select(g => new + { + City = g.Key, + FaxNumbers = EF.Functions.JsonbAgg(g.Select(c => c.Fax).OrderBy(id => id)) + }); + + var results = async + ? await query.ToListAsync() + : query.ToList(); + + var london = results.Single(x => x.City == "London"); + Assert.Collection( + london.FaxNumbers, + Assert.Null, + f => Assert.Equal("(171) 555-2530", f), + f => Assert.Equal("(171) 555-3373", f), + f => Assert.Equal("(171) 555-5646", f), + f => Assert.Equal("(171) 555-6750", f), + f => Assert.Equal("(171) 555-9199", f)); + + AssertSql( + @"SELECT c.""City"", jsonb_agg(c.""Fax"" ORDER BY c.""Fax"" NULLS FIRST) AS ""FaxNumbers"" +FROM ""Customers"" AS c +GROUP BY c.""City"""); + } + + #endregion Aggregate functions + + #region JsonObjectAgg + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task GroupBy_JsonObjectAgg(bool async) + { + using var ctx = CreateContext(); + + var query = ctx.Set() + .GroupBy(c => c.City) + .Select( + g => new + { + City = g.Key, + Companies = EF.Functions.JsonObjectAgg( + g + .OrderBy(c => c.CompanyName) + .Select(c => ValueTuple.Create(c.CompanyName, c.ContactName))) + }); + + var results = async + ? await query.ToListAsync() + : query.ToList(); + + var london = results.Single(r => r.City == "London"); + + Assert.Equal( + @"{ ""Around the Horn"" : ""Thomas Hardy"", ""B's Beverages"" : ""Victoria Ashworth"", ""Consolidated Holdings"" : ""Elizabeth Brown"", ""Eastern Connection"" : ""Ann Devon"", ""North/South"" : ""Simon Crowther"", ""Seven Seas Imports"" : ""Hari Kumar"" }", + london.Companies); + + AssertSql( + @"SELECT c.""City"", json_object_agg(c.""CompanyName"", c.""ContactName"" ORDER BY c.""CompanyName"" NULLS FIRST) AS ""Companies"" +FROM ""Customers"" AS c +GROUP BY c.""City"""); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task GroupBy_JsonObjectAgg_as_Dictionary(bool async) + { + using var ctx = CreateContext(); + + var query = ctx.Set() + .GroupBy(c => c.City) + .Select( + g => new + { + City = g.Key, + Companies = EF.Functions.JsonObjectAgg>( + g.Select(c => ValueTuple.Create(c.CompanyName, c.ContactName))) + }); + + var results = async + ? await query.ToListAsync() + : query.ToList(); + + var london = results.Single(r => r.City == "London"); + + Assert.Equal( + new Dictionary + { + ["Around the Horn"] = "Thomas Hardy", + ["B's Beverages"] = "Victoria Ashworth", + ["Consolidated Holdings"] = "Elizabeth Brown", + ["Eastern Connection"] = "Ann Devon", + ["North/South"] = "Simon Crowther", + ["Seven Seas Imports"] = "Hari Kumar" + }, + london.Companies); + + AssertSql( + @"SELECT c.""City"", json_object_agg(c.""CompanyName"", c.""ContactName"") AS ""Companies"" +FROM ""Customers"" AS c +GROUP BY c.""City"""); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task GroupBy_JsonbObjectAgg(bool async) + { + using var ctx = CreateContext(); + + // Note that unlike with json, jsonb doesn't guarantee ordering; so we parse the JSON string client-side. + var query = ctx.Set() + .GroupBy(c => c.City) + .Select( + g => new + { + City = g.Key, + Companies = EF.Functions.JsonbObjectAgg(g.Select(c => ValueTuple.Create(c.CompanyName, c.ContactName))) + }); + + var results = async + ? await query.ToListAsync() + : query.ToList(); + + var london = results.Single(r => r.City == "London"); + var companiesDictionary = JsonSerializer.Deserialize>(london.Companies); + + Assert.Equal( + new Dictionary + { + ["Around the Horn"] = "Thomas Hardy", + ["B's Beverages"] = "Victoria Ashworth", + ["Consolidated Holdings"] = "Elizabeth Brown", + ["Eastern Connection"] = "Ann Devon", + ["North/South"] = "Simon Crowther", + ["Seven Seas Imports"] = "Hari Kumar" + }, + companiesDictionary); + + AssertSql( + @"SELECT c.""City"", jsonb_object_agg(c.""CompanyName"", c.""ContactName"") AS ""Companies"" +FROM ""Customers"" AS c +GROUP BY c.""City"""); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task GroupBy_JsonbObjectAgg_as_Dictionary(bool async) + { + using var ctx = CreateContext(); + + var query = ctx.Set() + .GroupBy(c => c.City) + .Select( + g => new + { + City = g.Key, + Companies = EF.Functions.JsonbObjectAgg>( + g.Select(c => ValueTuple.Create(c.CompanyName, c.ContactName))) + }); + + var results = async + ? await query.ToListAsync() + : query.ToList(); + + var london = results.Single(r => r.City == "London"); + + Assert.Equal( + new Dictionary + { + ["Around the Horn"] = "Thomas Hardy", + ["B's Beverages"] = "Victoria Ashworth", + ["Consolidated Holdings"] = "Elizabeth Brown", + ["Eastern Connection"] = "Ann Devon", + ["North/South"] = "Simon Crowther", + ["Seven Seas Imports"] = "Hari Kumar" + }, + london.Companies); + + AssertSql( + @"SELECT c.""City"", jsonb_object_agg(c.""CompanyName"", c.""ContactName"") AS ""Companies"" +FROM ""Customers"" AS c +GROUP BY c.""City"""); + } + + #endregion JsonObjectAgg + + #region Statistics + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task StandardDeviation(bool async) + { + await using var ctx = CreateContext(); + + var query = ctx.Set() + .GroupBy(od => od.ProductID) + .Select(g => new + { + ProductID = g.Key, + SampleStandardDeviation = EF.Functions.StandardDeviationSample(g.Select(od => od.UnitPrice)), + PopulationStandardDeviation = EF.Functions.StandardDeviationPopulation(g.Select(od => od.UnitPrice)) + }); + + var results = async + ? await query.ToListAsync() + : query.ToList(); + + var product9 = results.Single(r => r.ProductID == 9); + Assert.Equal(8.675943752699023, product9.SampleStandardDeviation.Value, 5); + Assert.Equal(7.759999999999856, product9.PopulationStandardDeviation.Value, 5); + + AssertSql( + @"SELECT o.""ProductID"", stddev_samp(o.""UnitPrice"") AS ""SampleStandardDeviation"", stddev_pop(o.""UnitPrice"") AS ""PopulationStandardDeviation"" +FROM ""Order Details"" AS o +GROUP BY o.""ProductID"""); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Variance(bool async) + { + await using var ctx = CreateContext(); + + var query = ctx.Set() + .GroupBy(od => od.ProductID) + .Select( + g => new + { + ProductID = g.Key, + SampleStandardDeviation = EF.Functions.VarianceSample(g.Select(od => od.UnitPrice)), + PopulationStandardDeviation = EF.Functions.VariancePopulation(g.Select(od => od.UnitPrice)) + }); + + var results = async + ? await query.ToListAsync() + : query.ToList(); + + var product9 = results.Single(r => r.ProductID == 9); + Assert.Equal(75.2719999999972, product9.SampleStandardDeviation.Value, 5); + Assert.Equal(60.217599999997766, product9.PopulationStandardDeviation.Value, 5); + + AssertSql( + @"SELECT o.""ProductID"", var_samp(o.""UnitPrice"") AS ""SampleStandardDeviation"", var_pop(o.""UnitPrice"") AS ""PopulationStandardDeviation"" +FROM ""Order Details"" AS o +GROUP BY o.""ProductID"""); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Other_statistics_functions(bool async) + { + await using var ctx = CreateContext(); + + var query = ctx.Set() + .GroupBy(od => od.ProductID) + .Select(g => new + { + ProductID = g.Key, + Correlation = EF.Functions.Correlation(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + CovariancePopulation = EF.Functions.CovariancePopulation(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + CovarianceSample = EF.Functions.CovarianceSample(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + RegrAverageX = EF.Functions.RegrAverageX(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + RegrAverageY = EF.Functions.RegrAverageY(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + RegrCount = EF.Functions.RegrCount(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + RegrIntercept = EF.Functions.RegrIntercept(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + RegrR2 = EF.Functions.RegrR2(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + RegrSlope = EF.Functions.RegrSlope(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + RegrSXX = EF.Functions.RegrSXX(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + RegrSXY = EF.Functions.RegrSXY(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + }); + + var results = async + ? await query.ToListAsync() + : query.ToList(); + + var product9 = results.Single(r => r.ProductID == 9); + Assert.Equal(0.9336470941441423, product9.Correlation.Value, 5); + Assert.Equal(1.4799999967217445, product9.CovariancePopulation.Value, 5); + Assert.Equal(1.8499999959021807, product9.CovarianceSample.Value, 5); + Assert.Equal(0.10000000149011612, product9.RegrAverageX.Value, 5); + Assert.Equal(19, product9.RegrAverageY.Value, 5); + Assert.Equal(5, product9.RegrCount.Value); + Assert.Equal(2.5555555647538144, product9.RegrIntercept.Value, 5); + Assert.Equal(0.871696896403801, product9.RegrR2.Value, 5); + Assert.Equal(164.44444190204874, product9.RegrSlope.Value, 5); + Assert.Equal(0.045000000596046474, product9.RegrSXX.Value, 5); + Assert.Equal(7.399999983608723, product9.RegrSXY.Value, 5); + + AssertSql( + @"SELECT o.""ProductID"", corr(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""Correlation"", covar_pop(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""CovariancePopulation"", covar_samp(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""CovarianceSample"", regr_avgx(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""RegrAverageX"", regr_avgy(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""RegrAverageY"", regr_count(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""RegrCount"", regr_intercept(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""RegrIntercept"", regr_r2(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""RegrR2"", regr_slope(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""RegrSlope"", regr_sxx(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""RegrSXX"", regr_sxy(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""RegrSXY"" +FROM ""Order Details"" AS o +GROUP BY o.""ProductID"""); + } + + #endregion Statistics + #region Unsupported // PostgreSQL does not have strpos with starting position diff --git a/test/EFCore.PG.FunctionalTests/Query/NorthwindGroupByQueryNpgsqlTest.cs b/test/EFCore.PG.FunctionalTests/Query/NorthwindGroupByQueryNpgsqlTest.cs index ad485b6114..06c605152b 100644 --- a/test/EFCore.PG.FunctionalTests/Query/NorthwindGroupByQueryNpgsqlTest.cs +++ b/test/EFCore.PG.FunctionalTests/Query/NorthwindGroupByQueryNpgsqlTest.cs @@ -1,3 +1,5 @@ +using Microsoft.EntityFrameworkCore.TestModels.Northwind; + namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query; public class NorthwindGroupByQueryNpgsqlTest : NorthwindGroupByQueryRelationalTestBase> @@ -3144,6 +3146,23 @@ LEFT JOIN ( ORDER BY t0.""Key"" NULLS FIRST, t1.""OrderID"" NULLS FIRST"); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task GroupBy_OrderBy_Average(bool async) + { + await AssertQueryScalar( + async, + ss => from o in ss.Set() + group o by new { o.CustomerID } + into g + select g.OrderBy(e => e.OrderID).Select(e => (int?)e.OrderID).Average()); + + AssertSql( + @"SELECT AVG(o.""OrderID""::double precision ORDER BY o.""OrderID"" NULLS FIRST) +FROM ""Orders"" AS o +GROUP BY o.""CustomerID"""); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); diff --git a/test/EFCore.PG.FunctionalTests/Query/RangeQueryNpgsqlTest.cs b/test/EFCore.PG.FunctionalTests/Query/RangeQueryNpgsqlTest.cs index d57cd618aa..c84c9f6405 100644 --- a/test/EFCore.PG.FunctionalTests/Query/RangeQueryNpgsqlTest.cs +++ b/test/EFCore.PG.FunctionalTests/Query/RangeQueryNpgsqlTest.cs @@ -223,6 +223,34 @@ public void Union() LIMIT 2"); } + [ConditionalFact] + [MinimumPostgresVersion(14, 0)] // Multiranges were introduced in PostgreSQL 14 + public void Union_aggregate() + { + using var context = CreateContext(); + + var group = context.RangeTestEntities + .Where(x => x.Id == 1 || x.Id == 2) + .GroupBy(x => true) + .Select(g => new + { + Union = EF.Functions.RangeAgg(g.Select(x => x.IntRange)) + }) + .Single(); + + Assert.Equal(new NpgsqlRange[] { new(1, true, 16, false) }, group.Union); + + AssertSql( + @"SELECT range_agg(t.""IntRange"") AS ""Union"" +FROM ( + SELECT r.""IntRange"", TRUE AS ""Key"" + FROM ""RangeTestEntities"" AS r + WHERE r.""Id"" IN (1, 2) +) AS t +GROUP BY t.""Key"" +LIMIT 2"); + } + [ConditionalFact] public void Intersect() { @@ -240,6 +268,34 @@ public void Intersect() LIMIT 2"); } + [ConditionalFact] + [MinimumPostgresVersion(14, 0)] // Multiranges were introduced in PostgreSQL 14 + public void Intersect_aggregate() + { + using var context = CreateContext(); + + var group = context.RangeTestEntities + .Where(x => x.Id == 1 || x.Id == 2) + .GroupBy(x => true) + .Select(g => new + { + Union = EF.Functions.RangeIntersectAgg(g.Select(x => x.IntRange)) + }) + .Single(); + + Assert.Equal(new NpgsqlRange(5, true, 11, false), group.Union); + + AssertSql( + @"SELECT range_intersect_agg(t.""IntRange"") AS ""Union"" +FROM ( + SELECT r.""IntRange"", TRUE AS ""Key"" + FROM ""RangeTestEntities"" AS r + WHERE r.""Id"" IN (1, 2) +) AS t +GROUP BY t.""Key"" +LIMIT 2"); + } + [ConditionalFact] public void Except() {