diff --git a/src/EFCore.PG.NTS/Query/ExpressionTranslators/Internal/NpgsqlNetTopologySuiteAggregateMethodCallTranslatorPlugin.cs b/src/EFCore.PG.NTS/Query/ExpressionTranslators/Internal/NpgsqlNetTopologySuiteAggregateMethodCallTranslatorPlugin.cs index 4f22e94cf..a55ead297 100644 --- a/src/EFCore.PG.NTS/Query/ExpressionTranslators/Internal/NpgsqlNetTopologySuiteAggregateMethodCallTranslatorPlugin.cs +++ b/src/EFCore.PG.NTS/Query/ExpressionTranslators/Internal/NpgsqlNetTopologySuiteAggregateMethodCallTranslatorPlugin.cs @@ -64,9 +64,9 @@ public NpgsqlNetTopologySuiteAggregateMethodTranslator( return _sqlExpressionFactory.AggregateFunction( method == GeometryCombineMethod ? "ST_Collect" : "ST_Union", new[] { sqlExpression }, + source, nullable: true, argumentsPropagateNullability: new[] { false }, - source, typeof(Geometry), resultTypeMapping); } @@ -80,9 +80,9 @@ public NpgsqlNetTopologySuiteAggregateMethodTranslator( _sqlExpressionFactory.AggregateFunction( "ST_Collect", new[] { sqlExpression }, + source, nullable: true, argumentsPropagateNullability: new[] { false }, - source, typeof(Geometry), resultTypeMapping) }, diff --git a/src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlAggregateDbFunctionsExtensions.cs b/src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlAggregateDbFunctionsExtensions.cs new file mode 100644 index 000000000..435a9feec --- /dev/null +++ b/src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlAggregateDbFunctionsExtensions.cs @@ -0,0 +1,514 @@ +// ReSharper disable once CheckNamespace +namespace Microsoft.EntityFrameworkCore; + +/// +/// Provides extension methods supporting aggregate function translation for PostgreSQL. +/// +public static class NpgsqlAggregateDbFunctionsExtensions +{ + /// + /// Collects all the input values, including nulls, into a PostgreSQL array. + /// Corresponds to the PostgreSQL array_agg aggregate function. + /// + /// The instance. + /// The input values to be aggregated into an array. + /// PostgreSQL documentation for aggregate functions. + public static T[] ArrayAgg(this DbFunctions _, IEnumerable input) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(ArrayAgg))); + + /// + /// Collects all the input values, including nulls, into a json array. Values are converted to JSON as per to_json or + /// to_jsonb. Corresponds to the PostgreSQL json_agg aggregate function. + /// + /// The instance. + /// The input values to be aggregated into a JSON array. + /// PostgreSQL documentation for aggregate functions. + public static T[] JsonAgg(this DbFunctions _, IEnumerable input) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(JsonAgg))); + + /// + /// Collects all the input values, including nulls, into a jsonb array. Values are converted to JSON as per to_json or + /// to_jsonb. Corresponds to the PostgreSQL jsonb_agg aggregate function. + /// + /// The instance. + /// The input values to be aggregated into a JSON array. + /// PostgreSQL documentation for aggregate functions. + public static T[] JsonbAgg(this DbFunctions _, IEnumerable input) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(JsonbAgg))); + + #region Range + + /// + /// Computes the union of the non-null input ranges. Corresponds to the PostgreSQL range_agg aggregate function. + /// + /// The instance. + /// The ranges to be aggregated via union into a multirange. + /// PostgreSQL documentation for aggregate functions. + public static NpgsqlRange[] RangeAgg(this DbFunctions _, IEnumerable> input) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RangeAgg))); + + /// + /// Computes the intersection of the non-null input ranges. Corresponds to the PostgreSQL range_intersect_agg aggregate function. + /// + /// The instance. + /// The ranges to be aggregated via intersection into a multirange. + /// PostgreSQL documentation for aggregate functions. + public static NpgsqlRange RangeIntersectAgg(this DbFunctions _, IEnumerable> input) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RangeAgg))); + + /// + /// Computes the intersection of the non-null input multiranges. + /// Corresponds to the PostgreSQL range_intersect_agg aggregate function. + /// + /// The instance. + /// The ranges to be aggregated via intersection into a multirange. + /// PostgreSQL documentation for aggregate functions. + public static NpgsqlRange[] RangeIntersectAgg(this DbFunctions _, IEnumerable[]> input) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RangeAgg))); + + #endregion Range + + #region JsonObjectAgg + + /// + /// Collects all the key/value pairs into a JSON object. Key arguments are coerced to text; value arguments are converted as per + /// to_json. Values can be , but not keys. + /// Corresponds to the PostgreSQL json_object_agg aggregate function. + /// + /// The instance. + /// An enumerable of key-value pairs to be aggregated into a JSON object. + /// PostgreSQL documentation for aggregate functions. + public static string JsonObjectAgg(this DbFunctions _, IEnumerable<(T1, T2)> keyValuePairs) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(JsonObjectAgg))); + + /// + /// Collects all the key/value pairs into a JSON object. Key arguments are coerced to text; value arguments are converted as per + /// to_json. Values can be , but not keys. + /// Corresponds to the PostgreSQL json_object_agg aggregate function. + /// + /// The instance. + /// An enumerable of key-value pairs to be aggregated into a JSON object. + /// PostgreSQL documentation for aggregate functions. + public static TReturn JsonObjectAgg(this DbFunctions _, IEnumerable<(T1, T2)> keyValuePairs) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(JsonObjectAgg))); + + /// + /// Collects all the key/value pairs into a JSON object. Key arguments are coerced to text; value arguments are converted as per + /// to_jsonb. Values can be , but not keys. + /// Corresponds to the PostgreSQL jsonb_object_agg aggregate function. + /// + /// The instance. + /// An enumerable of key-value pairs to be aggregated into a JSON object. + /// PostgreSQL documentation for aggregate functions. + public static string JsonbObjectAgg(this DbFunctions _, IEnumerable<(T1, T2)> keyValuePairs) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(JsonbObjectAgg))); + + /// + /// Collects all the key/value pairs into a JSON object. Key arguments are coerced to text; value arguments are converted as per + /// to_jsonb. Values can be , but not keys. + /// Corresponds to the PostgreSQL jsonb_object_agg aggregate function. + /// + /// The instance. + /// An enumerable of key-value pairs to be aggregated into a JSON object. + /// PostgreSQL documentation for aggregate functions. + public static TReturn JsonbObjectAgg(this DbFunctions _, IEnumerable<(T1, T2)> keyValuePairs) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(JsonbObjectAgg))); + + #endregion JsonObjectAgg + + #region Sample standard deviation + + /// + /// Returns the sample standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_samp function. + /// + /// The instance. + /// The values. + /// The computed sample standard deviation. + public static double? StandardDeviationSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationSample))); + + /// + /// Returns the sample standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_samp function. + /// + /// The instance. + /// The values. + /// The computed sample standard deviation. + public static double? StandardDeviationSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationSample))); + + /// + /// Returns the sample standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_samp function. + /// + /// The instance. + /// The values. + /// The computed sample standard deviation. + public static double? StandardDeviationSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationSample))); + + /// + /// Returns the sample standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_samp function. + /// + /// The instance. + /// The values. + /// The computed sample standard deviation. + public static double? StandardDeviationSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationSample))); + + /// + /// Returns the sample standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_samp function. + /// + /// The instance. + /// The values. + /// The computed sample standard deviation. + public static double? StandardDeviationSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationSample))); + + /// + /// Returns the sample standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_samp function. + /// + /// The instance. + /// The values. + /// The computed sample standard deviation. + public static double? StandardDeviationSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationSample))); + + /// + /// Returns the sample standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_samp function. + /// + /// The instance. + /// The values. + /// The computed sample standard deviation. + public static double? StandardDeviationSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationSample))); + + #endregion Sample standard deviation + + #region Population standard deviation + + /// + /// Returns the population standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_pop function. + /// + /// The instance. + /// The values. + /// The computed population standard deviation. + public static double? StandardDeviationPopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationPopulation))); + + /// + /// Returns the population standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_pop function. + /// + /// The instance. + /// The values. + /// The computed population standard deviation. + public static double? StandardDeviationPopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationPopulation))); + + /// + /// Returns the population standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_pop function. + /// + /// The instance. + /// The values. + /// The computed population standard deviation. + public static double? StandardDeviationPopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationPopulation))); + + /// + /// Returns the population standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_pop function. + /// + /// The instance. + /// The values. + /// The computed population standard deviation. + public static double? StandardDeviationPopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationPopulation))); + + /// + /// Returns the population standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_pop function. + /// + /// The instance. + /// The values. + /// The computed population standard deviation. + public static double? StandardDeviationPopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationPopulation))); + + /// + /// Returns the population standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_pop function. + /// + /// The instance. + /// The values. + /// The computed population standard deviation. + public static double? StandardDeviationPopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationPopulation))); + + /// + /// Returns the population standard deviation of all values in the specified expression. + /// Corresponds to the PostgreSQL stddev_pop function. + /// + /// The instance. + /// The values. + /// The computed population standard deviation. + public static double? StandardDeviationPopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(StandardDeviationPopulation))); + + #endregion Population standard deviation + + #region Sample variance + + /// + /// Returns the sample variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_samp function. + /// + /// The instance. + /// The values. + /// The computed sample variance. + public static double? VarianceSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VarianceSample))); + + /// + /// Returns the sample variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_samp function. + /// + /// The instance. + /// The values. + /// The computed sample variance. + public static double? VarianceSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VarianceSample))); + + /// + /// Returns the sample variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_samp function. + /// + /// The instance. + /// The values. + /// The computed sample variance. + public static double? VarianceSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VarianceSample))); + + /// + /// Returns the sample variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_samp function. + /// + /// The instance. + /// The values. + /// The computed sample variance. + public static double? VarianceSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VarianceSample))); + + /// + /// Returns the sample variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_samp function. + /// + /// The instance. + /// The values. + /// The computed sample variance. + public static double? VarianceSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VarianceSample))); + + /// + /// Returns the sample variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_samp function. + /// + /// The instance. + /// The values. + /// The computed sample variance. + public static double? VarianceSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VarianceSample))); + + /// + /// Returns the sample variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_samp function. + /// + /// The instance. + /// The values. + /// The computed sample variance. + public static double? VarianceSample(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VarianceSample))); + + #endregion Sample variance + + #region Population variance + + /// + /// Returns the population variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_pop function. + /// + /// The instance. + /// The values. + /// The computed population variance. + public static double? VariancePopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VariancePopulation))); + + /// + /// Returns the population variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_pop function. + /// + /// The instance. + /// The values. + /// The computed population variance. + public static double? VariancePopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VariancePopulation))); + + /// + /// Returns the population variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_pop function. + /// + /// The instance. + /// The values. + /// The computed population variance. + public static double? VariancePopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VariancePopulation))); + + /// + /// Returns the population variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_pop function. + /// + /// The instance. + /// The values. + /// The computed population variance. + public static double? VariancePopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VariancePopulation))); + + /// + /// Returns the population variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_pop function. + /// + /// The instance. + /// The values. + /// The computed population variance. + public static double? VariancePopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VariancePopulation))); + + /// + /// Returns the population variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_pop function. + /// + /// The instance. + /// The values. + /// The computed population variance. + public static double? VariancePopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VariancePopulation))); + + /// + /// Returns the population variance of all values in the specified expression. + /// Corresponds to the PostgreSQL var_pop function. + /// + /// The instance. + /// The values. + /// The computed population variance. + public static double? VariancePopulation(this DbFunctions _, IEnumerable values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VariancePopulation))); + + #endregion Population variance + + #region Other statistics functions + + /// + /// Computes the correlation coefficient. Corresponds to the PostgreSQL corr function. + /// + /// The instance. + /// The values. + public static double? Correlation(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Correlation))); + + /// + /// Computes the population covariance. Corresponds to the PostgreSQL covar_pop function. + /// + /// The instance. + /// The values. + public static double? CovariancePopulation(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(CovariancePopulation))); + + /// + /// Computes the sample covariance. Corresponds to the PostgreSQL covar_samp function. + /// + /// The instance. + /// The values. + public static double? CovarianceSample(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(CovarianceSample))); + + /// + /// Computes the average of the independent variable, sum(X)/N. + /// Corresponds to the PostgreSQL regr_avgx function. + /// + /// The instance. + /// The values. + public static double? RegrAverageX(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RegrAverageX))); + + /// + /// Computes the average of the dependent variable, sum(Y)/N. + /// Corresponds to the PostgreSQL regr_avgy function. + /// + /// The instance. + /// The values. + public static double? RegrAverageY(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RegrAverageY))); + + /// + /// Computes the number of rows in which both inputs are non-null. + /// Corresponds to the PostgreSQL regr_count function. + /// + /// The instance. + /// The values. + public static long? RegrCount(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RegrCount))); + + /// + /// Computes the y-intercept of the least-squares-fit linear equation determined by the (X, Y) pairs. + /// Corresponds to the PostgreSQL regr_intercept function. + /// + /// The instance. + /// The values. + public static double? RegrIntercept(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RegrIntercept))); + + /// + /// Computes the square of the correlation coefficient. + /// Corresponds to the PostgreSQL regr_r2 function. + /// + /// The instance. + /// The values. + public static double? RegrR2(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RegrR2))); + + /// + /// Computes the slope of the least-squares-fit linear equation determined by the (X, Y) pairs. + /// Corresponds to the PostgreSQL regr_slope function. + /// + /// The instance. + /// The values. + public static double? RegrSlope(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RegrSlope))); + + /// + /// Computes the “sum of squares” of the independent variable, sum(X^2) - sum(X)^2/N. + /// Corresponds to the PostgreSQL regr_sxx function. + /// + /// The instance. + /// The values. + public static double? RegrSXX(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RegrSXX))); + + /// + /// Computes the “sum of products” of independent times dependent variables, sum(X*Y) - sum(X) * sum(Y)/N. + /// Corresponds to the PostgreSQL regr_sxy function. + /// + /// The instance. + /// The values. + public static double? RegrSXY(this DbFunctions _, IEnumerable<(double, double)> values) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(RegrSXY))); + + #endregion Other statistics functions +} diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlAggregateMethodCallTranslatorProvider.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlAggregateMethodCallTranslatorProvider.cs index c0687a6d1..b0a789317 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlAggregateMethodCallTranslatorProvider.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlAggregateMethodCallTranslatorProvider.cs @@ -11,7 +11,9 @@ public NpgsqlAggregateMethodCallTranslatorProvider(RelationalAggregateMethodCall AddTranslators( new IAggregateMethodCallTranslator[] { - new NpgsqlQueryableAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource) + new NpgsqlQueryableAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource), + new NpgsqlStatisticsAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource), + new NpgsqlMiscAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource) }); } } diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMiscAggregateMethodTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMiscAggregateMethodTranslator.cs new file mode 100644 index 000000000..33ca1dabd --- /dev/null +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMiscAggregateMethodTranslator.cs @@ -0,0 +1,158 @@ +using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal; +using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping; +using static Npgsql.EntityFrameworkCore.PostgreSQL.Utilities.Statics; + +namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal; + +public class NpgsqlMiscAggregateMethodTranslator : IAggregateMethodCallTranslator +{ + private static readonly MethodInfo StringJoin + = typeof(string).GetRuntimeMethod(nameof(string.Join), new[] { typeof(string), typeof(IEnumerable) })!; + + private static readonly MethodInfo StringConcat + = typeof(string).GetRuntimeMethod(nameof(string.Concat), new[] { typeof(IEnumerable) })!; + + private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; + private readonly IRelationalTypeMappingSource _typeMappingSource; + + public NpgsqlMiscAggregateMethodTranslator( + NpgsqlSqlExpressionFactory sqlExpressionFactory, + IRelationalTypeMappingSource typeMappingSource) + { + _sqlExpressionFactory = sqlExpressionFactory; + _typeMappingSource = typeMappingSource; + } + + public virtual SqlExpression? Translate( + MethodInfo method, + EnumerableExpression source, + IReadOnlyList arguments, + IDiagnosticsLogger logger) + { + // Docs: https://www.postgresql.org/docs/current/functions-aggregate.html + + if (source.Selector is not SqlExpression sqlExpression) + { + return null; + } + + if (method == StringJoin || method == StringConcat) + { + // string_agg filters out nulls, but string.Join treats them as empty strings; coalesce unless we know we're aggregating over + // a non-nullable column. + if (sqlExpression is not ColumnExpression { IsNullable: false }) + { + sqlExpression = _sqlExpressionFactory.Coalesce( + sqlExpression, + _sqlExpressionFactory.Constant(string.Empty, typeof(string))); + } + + // string_agg returns null when there are no rows (or non-null values), but string.Join returns an empty string. + return _sqlExpressionFactory.Coalesce( + _sqlExpressionFactory.AggregateFunction( + "string_agg", + new[] + { + sqlExpression, + method == StringJoin ? arguments[0] : _sqlExpressionFactory.Constant(string.Empty, typeof(string)) + }, + source, + nullable: true, + argumentsPropagateNullability: new[] { false, true }, + typeof(string), + _typeMappingSource.FindMapping("text")), // Note that string_agg returns text even if its inputs are varchar(x) + _sqlExpressionFactory.Constant(string.Empty, typeof(string))); + } + + if (method.DeclaringType == typeof(NpgsqlAggregateDbFunctionsExtensions)) + { + switch (method.Name) + { + case nameof(NpgsqlAggregateDbFunctionsExtensions.ArrayAgg): + var arrayClrType = sqlExpression.Type.MakeArrayType(); + + return _sqlExpressionFactory.AggregateFunction( + "array_agg", + new[] { sqlExpression }, + source, + nullable: true, + argumentsPropagateNullability: FalseArrays[1], + returnType: arrayClrType, + typeMapping: sqlExpression.TypeMapping is null + ? null + : new NpgsqlArrayArrayTypeMapping(arrayClrType, sqlExpression.TypeMapping)); + + case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonAgg): + arrayClrType = sqlExpression.Type.MakeArrayType(); + + return _sqlExpressionFactory.AggregateFunction( + "json_agg", + new[] { sqlExpression }, + source, + nullable: true, + argumentsPropagateNullability: FalseArrays[1], + returnType: arrayClrType, + _typeMappingSource.FindMapping(arrayClrType, "json")); + + case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbAgg): + arrayClrType = sqlExpression.Type.MakeArrayType(); + + return _sqlExpressionFactory.AggregateFunction( + "jsonb_agg", + new[] { sqlExpression }, + source, + nullable: true, + argumentsPropagateNullability: FalseArrays[1], + returnType: arrayClrType, + _typeMappingSource.FindMapping(arrayClrType, "jsonb")); + + case nameof(NpgsqlAggregateDbFunctionsExtensions.RangeAgg): + arrayClrType = sqlExpression.Type.MakeArrayType(); + + return _sqlExpressionFactory.AggregateFunction( + "range_agg", + new[] { sqlExpression }, + source, + nullable: true, + argumentsPropagateNullability: FalseArrays[1], + returnType: arrayClrType, + _typeMappingSource.FindMapping(arrayClrType)); + + case nameof(NpgsqlAggregateDbFunctionsExtensions.RangeIntersectAgg): + return _sqlExpressionFactory.AggregateFunction( + "range_intersect_agg", + new[] { sqlExpression }, + source, + nullable: true, + argumentsPropagateNullability: FalseArrays[1], + returnType: sqlExpression.Type, + sqlExpression.TypeMapping); + + case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbObjectAgg): + case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonObjectAgg): + var isJsonb = method.Name == nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbObjectAgg); + + // These methods accept two enumerable (column) arguments; this is represented in LINQ as a projection from the grouping + // to a tuple of the two columns. Since we generally translate tuples to PostgresRowValueExpression, we take it apart + // here. + if (source.Selector is not PostgresRowValueExpression rowValueExpression) + { + return null; + } + + var (keys, values) = (rowValueExpression.Values[0], rowValueExpression.Values[1]); + + return _sqlExpressionFactory.AggregateFunction( + isJsonb ? "jsonb_object_agg" : "json_object_agg", + new[] { keys, values }, + source, + nullable: true, + argumentsPropagateNullability: FalseArrays[2], + returnType: method.ReturnType, + _typeMappingSource.FindMapping(method.ReturnType, isJsonb ? "jsonb" : "json")); + } + } + + return null; + } +} diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlQueryableAggregateMethodTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlQueryableAggregateMethodTranslator.cs index 6048a67d7..5cec5f55d 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlQueryableAggregateMethodTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlQueryableAggregateMethodTranslator.cs @@ -45,18 +45,18 @@ public NpgsqlQueryableAggregateMethodTranslator( _sqlExpressionFactory.AggregateFunction( "avg", new[] { averageSqlExpression }, - nullable: true, - argumentsPropagateNullability: FalseArrays[1], source, - typeof(double)), + nullable: true, + argumentsPropagateNullability: FalseArrays[1], + returnType: typeof(double)), averageSqlExpression.Type, averageSqlExpression.TypeMapping) : _sqlExpressionFactory.AggregateFunction( "avg", new[] { averageSqlExpression }, + source, nullable: true, argumentsPropagateNullability: FalseArrays[1], - source, averageSqlExpression.Type, averageSqlExpression.TypeMapping); @@ -69,11 +69,12 @@ public NpgsqlQueryableAggregateMethodTranslator( _sqlExpressionFactory.AggregateFunction( "count", new[] { countSqlExpression }, + source, nullable: false, argumentsPropagateNullability: FalseArrays[1], - source, typeof(long)), - typeof(int), _typeMappingSource.FindMapping(typeof(int))); + typeof(int), + _typeMappingSource.FindMapping(typeof(int))); case nameof(Queryable.LongCount) when methodInfo == QueryableMethods.LongCountWithoutPredicate @@ -82,9 +83,9 @@ public NpgsqlQueryableAggregateMethodTranslator( return _sqlExpressionFactory.AggregateFunction( "count", new[] { longCountSqlExpression }, + source, nullable: false, argumentsPropagateNullability: FalseArrays[1], - source, typeof(long)); case nameof(Queryable.Max) @@ -94,9 +95,9 @@ public NpgsqlQueryableAggregateMethodTranslator( return _sqlExpressionFactory.AggregateFunction( "max", new[] { maxSqlExpression }, + source, nullable: true, argumentsPropagateNullability: FalseArrays[1], - source, maxSqlExpression.Type, maxSqlExpression.TypeMapping); @@ -107,9 +108,9 @@ public NpgsqlQueryableAggregateMethodTranslator( return _sqlExpressionFactory.AggregateFunction( "min", new[] { minSqlExpression }, + source, nullable: true, argumentsPropagateNullability: FalseArrays[1], - source, minSqlExpression.Type, minSqlExpression.TypeMapping); @@ -129,9 +130,9 @@ public NpgsqlQueryableAggregateMethodTranslator( _sqlExpressionFactory.AggregateFunction( "sum", new[] { sumSqlExpression }, + source, nullable: true, argumentsPropagateNullability: FalseArrays[1], - source, typeof(long)), sumInputType, sumSqlExpression.TypeMapping); @@ -143,9 +144,9 @@ public NpgsqlQueryableAggregateMethodTranslator( _sqlExpressionFactory.AggregateFunction( "sum", new[] { sumSqlExpression }, + source, nullable: true, argumentsPropagateNullability: FalseArrays[1], - source, typeof(decimal)), sumInputType, sumSqlExpression.TypeMapping); @@ -154,9 +155,9 @@ public NpgsqlQueryableAggregateMethodTranslator( return _sqlExpressionFactory.AggregateFunction( "sum", new[] { sumSqlExpression }, + source, nullable: true, argumentsPropagateNullability: FalseArrays[1], - source, sumInputType, sumSqlExpression.TypeMapping); } diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlStatisticsAggregateMethodTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlStatisticsAggregateMethodTranslator.cs new file mode 100644 index 000000000..d438eb3eb --- /dev/null +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlStatisticsAggregateMethodTranslator.cs @@ -0,0 +1,104 @@ +using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal; +using static Npgsql.EntityFrameworkCore.PostgreSQL.Utilities.Statics; + +namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal; + +public class NpgsqlStatisticsAggregateMethodTranslator : IAggregateMethodCallTranslator +{ + private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; + private readonly RelationalTypeMapping _doubleTypeMapping, _longTypeMapping; + + public NpgsqlStatisticsAggregateMethodTranslator( + NpgsqlSqlExpressionFactory sqlExpressionFactory, + IRelationalTypeMappingSource typeMappingSource) + { + _sqlExpressionFactory = sqlExpressionFactory; + _doubleTypeMapping = typeMappingSource.FindMapping(typeof(double))!; + _longTypeMapping = typeMappingSource.FindMapping(typeof(long))!; + } + + public virtual SqlExpression? Translate( + MethodInfo method, + EnumerableExpression source, + IReadOnlyList arguments, + IDiagnosticsLogger logger) + { + // Docs: https://www.postgresql.org/docs/current/functions-aggregate.html#FUNCTIONS-AGGREGATE-STATISTICS-TABLE + + if (method.DeclaringType != typeof(NpgsqlAggregateDbFunctionsExtensions) + || source.Selector is not SqlExpression sqlExpression) + { + return null; + } + + // These four functions are simple and take a single enumerable argument + var functionName = method.Name switch + { + nameof(NpgsqlAggregateDbFunctionsExtensions.StandardDeviationSample) => "stddev_samp", + nameof(NpgsqlAggregateDbFunctionsExtensions.StandardDeviationPopulation) => "stddev_pop", + nameof(NpgsqlAggregateDbFunctionsExtensions.VarianceSample) => "var_samp", + nameof(NpgsqlAggregateDbFunctionsExtensions.VariancePopulation) => "var_pop", + _ => null + }; + + if (functionName is not null) + { + return _sqlExpressionFactory.AggregateFunction( + functionName, + new[] { sqlExpression }, + source, + nullable: true, + argumentsPropagateNullability: FalseArrays[1], + typeof(double), + _doubleTypeMapping); + } + + functionName = method.Name switch + { + nameof(NpgsqlAggregateDbFunctionsExtensions.Correlation) => "corr", + nameof(NpgsqlAggregateDbFunctionsExtensions.CovariancePopulation) => "covar_pop", + nameof(NpgsqlAggregateDbFunctionsExtensions.CovarianceSample) => "covar_samp", + nameof(NpgsqlAggregateDbFunctionsExtensions.RegrAverageX) => "regr_avgx", + nameof(NpgsqlAggregateDbFunctionsExtensions.RegrAverageY) => "regr_avgy", + nameof(NpgsqlAggregateDbFunctionsExtensions.RegrCount) => "regr_count", + nameof(NpgsqlAggregateDbFunctionsExtensions.RegrIntercept) => "regr_intercept", + nameof(NpgsqlAggregateDbFunctionsExtensions.RegrR2) => "regr_r2", + nameof(NpgsqlAggregateDbFunctionsExtensions.RegrSlope) => "regr_slope", + nameof(NpgsqlAggregateDbFunctionsExtensions.RegrSXX) => "regr_sxx", + nameof(NpgsqlAggregateDbFunctionsExtensions.RegrSXY) => "regr_sxy", + _ => null + }; + + if (functionName is not null) + { + // These methods accept two enumerable (column) arguments; this is represented in LINQ as a projection from the grouping + // to a tuple of the two columns. Since we generally translate tuples to PostgresRowValueExpression, we take it apart here. + if (source.Selector is not PostgresRowValueExpression rowValueExpression) + { + return null; + } + + var (y, x) = (rowValueExpression.Values[0], rowValueExpression.Values[1]); + + return method.Name == nameof(NpgsqlAggregateDbFunctionsExtensions.RegrCount) + ? _sqlExpressionFactory.AggregateFunction( + functionName, + new[] { y, x }, + source, + nullable: true, + argumentsPropagateNullability: FalseArrays[2], + typeof(long), + _longTypeMapping) + : _sqlExpressionFactory.AggregateFunction( + functionName, + new[] { y, x }, + source, + nullable: true, + argumentsPropagateNullability: FalseArrays[2], + typeof(double), + _doubleTypeMapping); + } + + return null; + } +} diff --git a/src/EFCore.PG/Query/Expressions/Internal/PostgresFunctionExpression.cs b/src/EFCore.PG/Query/Expressions/Internal/PostgresFunctionExpression.cs index df76f797f..bbb2094d5 100644 --- a/src/EFCore.PG/Query/Expressions/Internal/PostgresFunctionExpression.cs +++ b/src/EFCore.PG/Query/Expressions/Internal/PostgresFunctionExpression.cs @@ -58,9 +58,8 @@ public static PostgresFunctionExpression CreateWithNamedArguments( return new PostgresFunctionExpression( name, arguments, argumentNames, argumentSeparators: null, - nullable, argumentsPropagateNullability, aggregateDistinct: false, aggregatePredicate: null, aggregateOrderings: Array.Empty(), - type, typeMapping); + nullable: nullable, argumentsPropagateNullability: argumentsPropagateNullability, type: type, typeMapping: typeMapping); } public static PostgresFunctionExpression CreateWithArgumentSeparators( @@ -77,10 +76,9 @@ public static PostgresFunctionExpression CreateWithArgumentSeparators( Check.NotNull(argumentSeparators, nameof(argumentSeparators)); return new PostgresFunctionExpression( - name, arguments, argumentNames: null, argumentSeparators, - nullable, argumentsPropagateNullability, + name, arguments, argumentNames: null, argumentSeparators: argumentSeparators, aggregateDistinct: false, aggregatePredicate: null, aggregateOrderings: Array.Empty(), - type, typeMapping); + nullable: nullable, argumentsPropagateNullability: argumentsPropagateNullability, type: type, typeMapping: typeMapping); } public PostgresFunctionExpression( @@ -88,11 +86,11 @@ public PostgresFunctionExpression( IEnumerable arguments, IEnumerable? argumentNames, IEnumerable? argumentSeparators, - bool nullable, - IEnumerable argumentsPropagateNullability, bool aggregateDistinct, SqlExpression? aggregatePredicate, IReadOnlyList aggregateOrderings, + bool nullable, + IEnumerable argumentsPropagateNullability, Type type, RelationalTypeMapping? typeMapping) : base(name, arguments, nullable, argumentsPropagateNullability, type, typeMapping) @@ -175,11 +173,10 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) return changed ? new PostgresFunctionExpression( Name, visitedArguments ?? Arguments, ArgumentNames, ArgumentSeparators, - IsNullable, ArgumentsPropagateNullability!, IsAggregateDistinct, visitedAggregatePredicate ?? AggregatePredicate, visitedAggregateOrderings ?? AggregateOrderings, - Type, TypeMapping) + IsNullable, ArgumentsPropagateNullability!, Type, TypeMapping) : this; } @@ -189,13 +186,11 @@ public override SqlFunctionExpression ApplyTypeMapping(RelationalTypeMapping? ty Arguments, ArgumentNames, ArgumentSeparators, - IsNullable, - ArgumentsPropagateNullability, IsAggregateDistinct, AggregatePredicate, AggregateOrderings, - Type, - typeMapping ?? TypeMapping); + IsNullable, + ArgumentsPropagateNullability, Type, typeMapping ?? TypeMapping); public override SqlFunctionExpression Update(SqlExpression? instance, IReadOnlyList? arguments) { @@ -209,11 +204,10 @@ public override SqlFunctionExpression Update(SqlExpression? instance, IReadOnlyL return !arguments.SequenceEqual(Arguments) ? new PostgresFunctionExpression( Name, arguments, ArgumentNames, ArgumentSeparators, - IsNullable, ArgumentsPropagateNullability, IsAggregateDistinct, AggregatePredicate, AggregateOrderings, - Type, TypeMapping) + IsNullable, ArgumentsPropagateNullability, Type, TypeMapping) : this; } @@ -224,11 +218,10 @@ public virtual PostgresFunctionExpression UpdateAggregateComponents( return predicate != AggregatePredicate || orderings != AggregateOrderings ? new PostgresFunctionExpression( Name, Arguments, ArgumentNames, ArgumentSeparators, - IsNullable, ArgumentsPropagateNullability, IsAggregateDistinct, predicate, orderings, - Type, TypeMapping) + IsNullable, ArgumentsPropagateNullability, Type, TypeMapping) : this; } diff --git a/src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs b/src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs index f8c328ace..67416cac7 100644 --- a/src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs +++ b/src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs @@ -283,9 +283,9 @@ public virtual PostgresBinaryExpression Overlaps(SqlExpression left, SqlExpressi public virtual PostgresFunctionExpression AggregateFunction( string name, IEnumerable arguments, + EnumerableExpression aggregateEnumerableExpression, bool nullable, IEnumerable argumentsPropagateNullability, - EnumerableExpression aggregateEnumerableExpression, Type returnType, RelationalTypeMapping? typeMapping = null) { @@ -301,13 +301,11 @@ public virtual PostgresFunctionExpression AggregateFunction( typeMappedArguments, argumentNames: null, argumentSeparators: null, - nullable, - argumentsPropagateNullability, aggregateEnumerableExpression.IsDistinct, aggregateEnumerableExpression.Predicate, aggregateEnumerableExpression.Orderings, - returnType, - typeMapping); + nullable: nullable, + argumentsPropagateNullability: argumentsPropagateNullability, type: returnType, typeMapping: typeMapping); } #endregion Expression factory methods diff --git a/test/EFCore.PG.FunctionalTests/Query/MultirangeQueryNpgsqlTest.cs b/test/EFCore.PG.FunctionalTests/Query/MultirangeQueryNpgsqlTest.cs index cfa7a217d..e427c9f43 100644 --- a/test/EFCore.PG.FunctionalTests/Query/MultirangeQueryNpgsqlTest.cs +++ b/test/EFCore.PG.FunctionalTests/Query/MultirangeQueryNpgsqlTest.cs @@ -438,6 +438,34 @@ public void Intersect_multirange() LIMIT 2"); } + [ConditionalFact] + public void Intersect_aggregate() + { + using var context = CreateContext(); + + var intersection = context.TestEntities + .Where(x => x.Id == 1 || x.Id == 2) + .GroupBy(x => true) + .Select(g => EF.Functions.RangeIntersectAgg(g.Select(x => x.IntMultirange))) + .Single(); + + Assert.Equal(new NpgsqlRange[] + { + new(4, true, 6, false), + new(7, true, 9, false) + }, intersection); + + AssertSql( + @"SELECT range_intersect_agg(t0.""IntMultirange"") +FROM ( + SELECT t.""IntMultirange"", TRUE AS ""Key"" + FROM ""TestEntities"" AS t + WHERE t.""Id"" IN (1, 2) +) AS t0 +GROUP BY t0.""Key"" +LIMIT 2"); + } + [ConditionalFact] public void Except_multirange() { diff --git a/test/EFCore.PG.FunctionalTests/Query/NorthwindFunctionsQueryNpgsqlTest.cs b/test/EFCore.PG.FunctionalTests/Query/NorthwindFunctionsQueryNpgsqlTest.cs index c17f3b346..c4de29d18 100644 --- a/test/EFCore.PG.FunctionalTests/Query/NorthwindFunctionsQueryNpgsqlTest.cs +++ b/test/EFCore.PG.FunctionalTests/Query/NorthwindFunctionsQueryNpgsqlTest.cs @@ -1,4 +1,5 @@ -using System.Text.RegularExpressions; +using System.Text.Json; +using System.Text.RegularExpressions; using Microsoft.EntityFrameworkCore.TestModels.Northwind; using Npgsql.EntityFrameworkCore.PostgreSQL.TestUtilities; @@ -12,7 +13,7 @@ public NorthwindFunctionsQueryNpgsqlTest( : base(fixture) { ClearLog(); - //Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper); + Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper); } public override async Task IsNullOrWhiteSpace_in_predicate(bool async) @@ -275,6 +276,431 @@ public Task PadRight_char_with_parameter(bool async) #endregion + #region Aggregate functions + + public override async Task String_Join_over_non_nullable_column(bool async) + { + await base.String_Join_over_non_nullable_column(async); + + AssertSql( + @"SELECT c.""City"", COALESCE(string_agg(c.""CustomerID"", '|'), '') AS ""Customers"" +FROM ""Customers"" AS c +GROUP BY c.""City"""); + } + + public override async Task String_Join_over_nullable_column(bool async) + { + await base.String_Join_over_nullable_column(async); + + AssertSql( + @"SELECT c.""City"", COALESCE(string_agg(COALESCE(c.""Region"", ''), '|'), '') AS ""Regions"" +FROM ""Customers"" AS c +GROUP BY c.""City"""); + } + + public override async Task String_Join_with_predicate(bool async) + { + await base.String_Join_with_predicate(async); + + AssertSql( + @"SELECT c.""City"", COALESCE(string_agg(c.""CustomerID"", '|') FILTER (WHERE length(c.""ContactName"")::int > 10), '') AS ""Customers"" +FROM ""Customers"" AS c +GROUP BY c.""City"""); + } + + public override async Task String_Join_with_ordering(bool async) + { + await base.String_Join_with_ordering(async); + + AssertSql( + @"SELECT c.""City"", COALESCE(string_agg(c.""CustomerID"", '|' ORDER BY c.""CustomerID"" DESC NULLS LAST), '') AS ""Customers"" +FROM ""Customers"" AS c +GROUP BY c.""City"""); + } + + public override async Task String_Concat(bool async) + { + await base.String_Concat(async); + + AssertSql( + @"SELECT c.""City"", COALESCE(string_agg(c.""CustomerID"", ''), '') AS ""Customers"" +FROM ""Customers"" AS c +GROUP BY c.""City"""); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task GroupBy_ArrayAgg(bool async) + { + using var ctx = CreateContext(); + + var query = ctx.Set() + .GroupBy(c => c.City) + .Select(g => new + { + City = g.Key, + FaxNumbers = EF.Functions.ArrayAgg(g.Select(c => c.Fax).OrderBy(id => id)) + }); + + var results = async + ? await query.ToListAsync() + : query.ToList(); + + var london = results.Single(x => x.City == "London"); + Assert.Collection( + london.FaxNumbers, + Assert.Null, + f => Assert.Equal("(171) 555-2530", f), + f => Assert.Equal("(171) 555-3373", f), + f => Assert.Equal("(171) 555-5646", f), + f => Assert.Equal("(171) 555-6750", f), + f => Assert.Equal("(171) 555-9199", f)); + + AssertSql( + @"SELECT c.""City"", array_agg(c.""Fax"" ORDER BY c.""Fax"" NULLS FIRST) AS ""FaxNumbers"" +FROM ""Customers"" AS c +GROUP BY c.""City"""); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task GroupBy_JsonAgg(bool async) + { + using var ctx = CreateContext(); + + var query = ctx.Set() + .GroupBy(c => c.City) + .Select(g => new + { + City = g.Key, + FaxNumbers = EF.Functions.JsonAgg(g.Select(c => c.Fax).OrderBy(id => id)) + }); + + var results = async + ? await query.ToListAsync() + : query.ToList(); + + var london = results.Single(x => x.City == "London"); + Assert.Collection( + london.FaxNumbers, + Assert.Null, + f => Assert.Equal("(171) 555-2530", f), + f => Assert.Equal("(171) 555-3373", f), + f => Assert.Equal("(171) 555-5646", f), + f => Assert.Equal("(171) 555-6750", f), + f => Assert.Equal("(171) 555-9199", f)); + + AssertSql( + @"SELECT c.""City"", json_agg(c.""Fax"" ORDER BY c.""Fax"" NULLS FIRST) AS ""FaxNumbers"" +FROM ""Customers"" AS c +GROUP BY c.""City"""); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task GroupBy_JsonbAgg(bool async) + { + using var ctx = CreateContext(); + + var query = ctx.Set() + .GroupBy(c => c.City) + .Select(g => new + { + City = g.Key, + FaxNumbers = EF.Functions.JsonbAgg(g.Select(c => c.Fax).OrderBy(id => id)) + }); + + var results = async + ? await query.ToListAsync() + : query.ToList(); + + var london = results.Single(x => x.City == "London"); + Assert.Collection( + london.FaxNumbers, + Assert.Null, + f => Assert.Equal("(171) 555-2530", f), + f => Assert.Equal("(171) 555-3373", f), + f => Assert.Equal("(171) 555-5646", f), + f => Assert.Equal("(171) 555-6750", f), + f => Assert.Equal("(171) 555-9199", f)); + + AssertSql( + @"SELECT c.""City"", jsonb_agg(c.""Fax"" ORDER BY c.""Fax"" NULLS FIRST) AS ""FaxNumbers"" +FROM ""Customers"" AS c +GROUP BY c.""City"""); + } + + #endregion Aggregate functions + + #region JsonObjectAgg + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task GroupBy_JsonObjectAgg(bool async) + { + using var ctx = CreateContext(); + + var query = ctx.Set() + .GroupBy(c => c.City) + .Select( + g => new + { + City = g.Key, + Companies = EF.Functions.JsonObjectAgg( + g + .OrderBy(c => c.CompanyName) + .Select(c => ValueTuple.Create(c.CompanyName, c.ContactName))) + }); + + var results = async + ? await query.ToListAsync() + : query.ToList(); + + var london = results.Single(r => r.City == "London"); + + Assert.Equal( + @"{ ""Around the Horn"" : ""Thomas Hardy"", ""B's Beverages"" : ""Victoria Ashworth"", ""Consolidated Holdings"" : ""Elizabeth Brown"", ""Eastern Connection"" : ""Ann Devon"", ""North/South"" : ""Simon Crowther"", ""Seven Seas Imports"" : ""Hari Kumar"" }", + london.Companies); + + AssertSql( + @"SELECT c.""City"", json_object_agg(c.""CompanyName"", c.""ContactName"" ORDER BY c.""CompanyName"" NULLS FIRST) AS ""Companies"" +FROM ""Customers"" AS c +GROUP BY c.""City"""); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task GroupBy_JsonObjectAgg_as_Dictionary(bool async) + { + using var ctx = CreateContext(); + + var query = ctx.Set() + .GroupBy(c => c.City) + .Select( + g => new + { + City = g.Key, + Companies = EF.Functions.JsonObjectAgg>( + g.Select(c => ValueTuple.Create(c.CompanyName, c.ContactName))) + }); + + var results = async + ? await query.ToListAsync() + : query.ToList(); + + var london = results.Single(r => r.City == "London"); + + Assert.Equal( + new Dictionary + { + ["Around the Horn"] = "Thomas Hardy", + ["B's Beverages"] = "Victoria Ashworth", + ["Consolidated Holdings"] = "Elizabeth Brown", + ["Eastern Connection"] = "Ann Devon", + ["North/South"] = "Simon Crowther", + ["Seven Seas Imports"] = "Hari Kumar" + }, + london.Companies); + + AssertSql( + @"SELECT c.""City"", json_object_agg(c.""CompanyName"", c.""ContactName"") AS ""Companies"" +FROM ""Customers"" AS c +GROUP BY c.""City"""); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task GroupBy_JsonbObjectAgg(bool async) + { + using var ctx = CreateContext(); + + // Note that unlike with json, jsonb doesn't guarantee ordering; so we parse the JSON string client-side. + var query = ctx.Set() + .GroupBy(c => c.City) + .Select( + g => new + { + City = g.Key, + Companies = EF.Functions.JsonbObjectAgg(g.Select(c => ValueTuple.Create(c.CompanyName, c.ContactName))) + }); + + var results = async + ? await query.ToListAsync() + : query.ToList(); + + var london = results.Single(r => r.City == "London"); + var companiesDictionary = JsonSerializer.Deserialize>(london.Companies); + + Assert.Equal( + new Dictionary + { + ["Around the Horn"] = "Thomas Hardy", + ["B's Beverages"] = "Victoria Ashworth", + ["Consolidated Holdings"] = "Elizabeth Brown", + ["Eastern Connection"] = "Ann Devon", + ["North/South"] = "Simon Crowther", + ["Seven Seas Imports"] = "Hari Kumar" + }, + companiesDictionary); + + AssertSql( + @"SELECT c.""City"", jsonb_object_agg(c.""CompanyName"", c.""ContactName"") AS ""Companies"" +FROM ""Customers"" AS c +GROUP BY c.""City"""); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task GroupBy_JsonbObjectAgg_as_Dictionary(bool async) + { + using var ctx = CreateContext(); + + var query = ctx.Set() + .GroupBy(c => c.City) + .Select( + g => new + { + City = g.Key, + Companies = EF.Functions.JsonbObjectAgg>( + g.Select(c => ValueTuple.Create(c.CompanyName, c.ContactName))) + }); + + var results = async + ? await query.ToListAsync() + : query.ToList(); + + var london = results.Single(r => r.City == "London"); + + Assert.Equal( + new Dictionary + { + ["Around the Horn"] = "Thomas Hardy", + ["B's Beverages"] = "Victoria Ashworth", + ["Consolidated Holdings"] = "Elizabeth Brown", + ["Eastern Connection"] = "Ann Devon", + ["North/South"] = "Simon Crowther", + ["Seven Seas Imports"] = "Hari Kumar" + }, + london.Companies); + + AssertSql( + @"SELECT c.""City"", jsonb_object_agg(c.""CompanyName"", c.""ContactName"") AS ""Companies"" +FROM ""Customers"" AS c +GROUP BY c.""City"""); + } + + #endregion JsonObjectAgg + + #region Statistics + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task StandardDeviation(bool async) + { + await using var ctx = CreateContext(); + + var query = ctx.Set() + .GroupBy(od => od.ProductID) + .Select(g => new + { + ProductID = g.Key, + SampleStandardDeviation = EF.Functions.StandardDeviationSample(g.Select(od => od.UnitPrice)), + PopulationStandardDeviation = EF.Functions.StandardDeviationPopulation(g.Select(od => od.UnitPrice)) + }); + + var results = async + ? await query.ToListAsync() + : query.ToList(); + + var product9 = results.Single(r => r.ProductID == 9); + Assert.Equal(8.675943752699023, product9.SampleStandardDeviation.Value, 5); + Assert.Equal(7.759999999999856, product9.PopulationStandardDeviation.Value, 5); + + AssertSql( + @"SELECT o.""ProductID"", stddev_samp(o.""UnitPrice"") AS ""SampleStandardDeviation"", stddev_pop(o.""UnitPrice"") AS ""PopulationStandardDeviation"" +FROM ""Order Details"" AS o +GROUP BY o.""ProductID"""); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Variance(bool async) + { + await using var ctx = CreateContext(); + + var query = ctx.Set() + .GroupBy(od => od.ProductID) + .Select( + g => new + { + ProductID = g.Key, + SampleStandardDeviation = EF.Functions.VarianceSample(g.Select(od => od.UnitPrice)), + PopulationStandardDeviation = EF.Functions.VariancePopulation(g.Select(od => od.UnitPrice)) + }); + + var results = async + ? await query.ToListAsync() + : query.ToList(); + + var product9 = results.Single(r => r.ProductID == 9); + Assert.Equal(75.2719999999972, product9.SampleStandardDeviation.Value, 5); + Assert.Equal(60.217599999997766, product9.PopulationStandardDeviation.Value, 5); + + AssertSql( + @"SELECT o.""ProductID"", var_samp(o.""UnitPrice"") AS ""SampleStandardDeviation"", var_pop(o.""UnitPrice"") AS ""PopulationStandardDeviation"" +FROM ""Order Details"" AS o +GROUP BY o.""ProductID"""); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Other_statistics_functions(bool async) + { + await using var ctx = CreateContext(); + + var query = ctx.Set() + .GroupBy(od => od.ProductID) + .Select(g => new + { + ProductID = g.Key, + Correlation = EF.Functions.Correlation(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + CovariancePopulation = EF.Functions.CovariancePopulation(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + CovarianceSample = EF.Functions.CovarianceSample(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + RegrAverageX = EF.Functions.RegrAverageX(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + RegrAverageY = EF.Functions.RegrAverageY(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + RegrCount = EF.Functions.RegrCount(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + RegrIntercept = EF.Functions.RegrIntercept(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + RegrR2 = EF.Functions.RegrR2(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + RegrSlope = EF.Functions.RegrSlope(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + RegrSXX = EF.Functions.RegrSXX(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + RegrSXY = EF.Functions.RegrSXY(g.Select(od => ValueTuple.Create((double)od.Quantity, (double)od.Discount))), + }); + + var results = async + ? await query.ToListAsync() + : query.ToList(); + + var product9 = results.Single(r => r.ProductID == 9); + Assert.Equal(0.9336470941441423, product9.Correlation.Value, 5); + Assert.Equal(1.4799999967217445, product9.CovariancePopulation.Value, 5); + Assert.Equal(1.8499999959021807, product9.CovarianceSample.Value, 5); + Assert.Equal(0.10000000149011612, product9.RegrAverageX.Value, 5); + Assert.Equal(19, product9.RegrAverageY.Value, 5); + Assert.Equal(5, product9.RegrCount.Value); + Assert.Equal(2.5555555647538144, product9.RegrIntercept.Value, 5); + Assert.Equal(0.871696896403801, product9.RegrR2.Value, 5); + Assert.Equal(164.44444190204874, product9.RegrSlope.Value, 5); + Assert.Equal(0.045000000596046474, product9.RegrSXX.Value, 5); + Assert.Equal(7.399999983608723, product9.RegrSXY.Value, 5); + + AssertSql( + @"SELECT o.""ProductID"", corr(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""Correlation"", covar_pop(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""CovariancePopulation"", covar_samp(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""CovarianceSample"", regr_avgx(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""RegrAverageX"", regr_avgy(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""RegrAverageY"", regr_count(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""RegrCount"", regr_intercept(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""RegrIntercept"", regr_r2(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""RegrR2"", regr_slope(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""RegrSlope"", regr_sxx(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""RegrSXX"", regr_sxy(o.""Quantity""::double precision, o.""Discount""::double precision) AS ""RegrSXY"" +FROM ""Order Details"" AS o +GROUP BY o.""ProductID"""); + } + + #endregion Statistics + #region Unsupported // PostgreSQL does not have strpos with starting position diff --git a/test/EFCore.PG.FunctionalTests/Query/NorthwindGroupByQueryNpgsqlTest.cs b/test/EFCore.PG.FunctionalTests/Query/NorthwindGroupByQueryNpgsqlTest.cs index d61a10025..007d37f6c 100644 --- a/test/EFCore.PG.FunctionalTests/Query/NorthwindGroupByQueryNpgsqlTest.cs +++ b/test/EFCore.PG.FunctionalTests/Query/NorthwindGroupByQueryNpgsqlTest.cs @@ -1,3 +1,5 @@ +using Microsoft.EntityFrameworkCore.TestModels.Northwind; + namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query; public class NorthwindGroupByQueryNpgsqlTest : NorthwindGroupByQueryRelationalTestBase> @@ -3144,6 +3146,23 @@ LEFT JOIN ( ORDER BY t0.""Key"" NULLS FIRST, t1.""OrderID"" NULLS FIRST"); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task GroupBy_OrderBy_Average(bool async) + { + await AssertQueryScalar( + async, + ss => from o in ss.Set() + group o by new { o.CustomerID } + into g + select g.OrderBy(e => e.OrderID).Select(e => (int?)e.OrderID).Average()); + + AssertSql( + @"SELECT avg(o.""OrderID""::double precision ORDER BY o.""OrderID"" NULLS FIRST) +FROM ""Orders"" AS o +GROUP BY o.""CustomerID"""); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); diff --git a/test/EFCore.PG.FunctionalTests/Query/RangeQueryNpgsqlTest.cs b/test/EFCore.PG.FunctionalTests/Query/RangeQueryNpgsqlTest.cs index 9fa8ad42d..a34c1b17f 100644 --- a/test/EFCore.PG.FunctionalTests/Query/RangeQueryNpgsqlTest.cs +++ b/test/EFCore.PG.FunctionalTests/Query/RangeQueryNpgsqlTest.cs @@ -223,6 +223,31 @@ public void Union() LIMIT 2"); } + [ConditionalFact] + [MinimumPostgresVersion(14, 0)] // Multiranges were introduced in PostgreSQL 14 + public void Union_aggregate() + { + using var context = CreateContext(); + + var union = context.RangeTestEntities + .Where(x => x.Id == 1 || x.Id == 2) + .GroupBy(x => true) + .Select(g => EF.Functions.RangeAgg(g.Select(x => x.IntRange))) + .Single(); + + Assert.Equal(new NpgsqlRange[] { new(1, true, 16, false) }, union); + + AssertSql( + @"SELECT range_agg(t.""IntRange"") +FROM ( + SELECT r.""IntRange"", TRUE AS ""Key"" + FROM ""RangeTestEntities"" AS r + WHERE r.""Id"" IN (1, 2) +) AS t +GROUP BY t.""Key"" +LIMIT 2"); + } + [ConditionalFact] public void Intersect() { @@ -240,6 +265,31 @@ public void Intersect() LIMIT 2"); } + [ConditionalFact] + [MinimumPostgresVersion(14, 0)] // Multiranges were introduced in PostgreSQL 14 + public void Intersect_aggregate() + { + using var context = CreateContext(); + + var intersection = context.RangeTestEntities + .Where(x => x.Id == 1 || x.Id == 2) + .GroupBy(x => true) + .Select(g => EF.Functions.RangeIntersectAgg(g.Select(x => x.IntRange))) + .Single(); + + Assert.Equal(new NpgsqlRange(5, true, 11, false), intersection); + + AssertSql( + @"SELECT range_intersect_agg(t.""IntRange"") +FROM ( + SELECT r.""IntRange"", TRUE AS ""Key"" + FROM ""RangeTestEntities"" AS r + WHERE r.""Id"" IN (1, 2) +) AS t +GROUP BY t.""Key"" +LIMIT 2"); + } + [ConditionalFact] public void Except() {