Skip to content

Commit

Permalink
Translate PostgreSQL custom aggregates
Browse files Browse the repository at this point in the history
* string_agg (string.Join)
* array_agg
* json_agg/jsonb_agg
* json_object_agg/jsonb_object_agg
* range_agg, range_intersect_agg
* Aggregate statistics functions

Closes npgsql#2395
Closes npgsql#532
  • Loading branch information
roji committed Jun 24, 2022
1 parent ba7d2a5 commit a22027b
Show file tree
Hide file tree
Showing 12 changed files with 1,274 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ public NpgsqlNetTopologySuiteAggregateMethodTranslator(
return _sqlExpressionFactory.AggregateFunction(
method == GeometryCombineMethod ? "ST_Collect" : "ST_Union",
new[] { sqlExpression },
source,
nullable: true,
argumentsPropagateNullability: new[] { false },
source,
typeof(Geometry),
resultTypeMapping);
}
Expand All @@ -80,9 +80,9 @@ public NpgsqlNetTopologySuiteAggregateMethodTranslator(
_sqlExpressionFactory.AggregateFunction(
"ST_Collect",
new[] { sqlExpression },
source,
nullable: true,
argumentsPropagateNullability: new[] { false },
source,
typeof(Geometry),
resultTypeMapping)
},
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ public NpgsqlAggregateMethodCallTranslatorProvider(RelationalAggregateMethodCall
AddTranslators(
new IAggregateMethodCallTranslator[]
{
new NpgsqlQueryableAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource)
new NpgsqlQueryableAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource),
new NpgsqlStatisticsAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource),
new NpgsqlMiscAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource)
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal;
using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping;
using static Npgsql.EntityFrameworkCore.PostgreSQL.Utilities.Statics;

namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal;

/// <summary>
///     Translates <see cref="string.Join(string, IEnumerable{string})" /> and the aggregate methods declared on
///     <see cref="NpgsqlAggregateDbFunctionsExtensions" /> (ArrayAgg, JsonAgg, JsonbAgg, RangeAgg, RangeIntersectAgg,
///     JsonObjectAgg, JsonbObjectAgg) into the corresponding PostgreSQL aggregate function calls.
/// </summary>
public class NpgsqlMiscAggregateMethodTranslator : IAggregateMethodCallTranslator
{
    // string.Join(string separator, IEnumerable<string> values), the only string.Join overload handled here.
    private static readonly MethodInfo StringJoin
        = typeof(string).GetRuntimeMethod(nameof(string.Join), new[] { typeof(string), typeof(IEnumerable<string>) })!;

    private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory;
    private readonly IRelationalTypeMappingSource _typeMappingSource;

    /// <summary>
    ///     Creates the translator.
    /// </summary>
    /// <param name="sqlExpressionFactory">Factory used to build the aggregate function expressions.</param>
    /// <param name="typeMappingSource">Source used to look up type mappings for the aggregate results.</param>
    public NpgsqlMiscAggregateMethodTranslator(
        NpgsqlSqlExpressionFactory sqlExpressionFactory,
        IRelationalTypeMappingSource typeMappingSource)
    {
        _sqlExpressionFactory = sqlExpressionFactory;
        _typeMappingSource = typeMappingSource;
    }

    /// <summary>
    ///     Attempts to translate the given aggregate method call.
    /// </summary>
    /// <param name="method">The method being invoked over the enumerable.</param>
    /// <param name="source">The enumerable being aggregated; its selector supplies the aggregated expression.</param>
    /// <param name="arguments">The non-enumerable arguments of the call (the separator, for string.Join).</param>
    /// <param name="logger">The query logger (not used by this translator).</param>
    /// <returns>The aggregate function expression, or <see langword="null" /> if the call isn't handled here.</returns>
    public virtual SqlExpression? Translate(
        MethodInfo method,
        EnumerableExpression source,
        IReadOnlyList<SqlExpression> arguments,
        IDiagnosticsLogger<DbLoggerCategory.Query> logger)
    {
        // Nothing can be translated unless the aggregated selector is itself a SQL expression.
        if (source.Selector is not SqlExpression selector)
        {
            return null;
        }

        if (method == StringJoin)
        {
            // TODO: Align this with the upstream support, including Concat translation
            // TODO: Confirm result type mapping, what happens with e.g. varchar(8)
            var stringAggArguments = new[] { selector, arguments[0] };

            return _sqlExpressionFactory.AggregateFunction(
                "string_agg",
                stringAggArguments,
                source,
                nullable: true,
                argumentsPropagateNullability: new[] { false, true },
                returnType: typeof(string));
        }

        // Everything below is specific to the provider's own aggregate extension methods.
        if (method.DeclaringType != typeof(NpgsqlAggregateDbFunctionsExtensions))
        {
            return null;
        }

        switch (method.Name)
        {
            case nameof(NpgsqlAggregateDbFunctionsExtensions.ArrayAgg):
            {
                var resultType = selector.Type.MakeArrayType();
                var elementMapping = selector.TypeMapping;

                // Without an element mapping the result mapping is left unset as well.
                return SingleArgumentAggregate(
                    "array_agg",
                    selector,
                    source,
                    resultType,
                    elementMapping is null ? null : new NpgsqlArrayArrayTypeMapping(resultType, elementMapping));
            }

            case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonAgg):
            {
                var resultType = selector.Type.MakeArrayType();

                return SingleArgumentAggregate(
                    "json_agg", selector, source, resultType, _typeMappingSource.FindMapping(resultType, "json"));
            }

            case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbAgg):
            {
                var resultType = selector.Type.MakeArrayType();

                return SingleArgumentAggregate(
                    "jsonb_agg", selector, source, resultType, _typeMappingSource.FindMapping(resultType, "jsonb"));
            }

            case nameof(NpgsqlAggregateDbFunctionsExtensions.RangeAgg):
            {
                var resultType = selector.Type.MakeArrayType();

                return SingleArgumentAggregate(
                    "range_agg", selector, source, resultType, _typeMappingSource.FindMapping(resultType));
            }

            case nameof(NpgsqlAggregateDbFunctionsExtensions.RangeIntersectAgg):
                // The result keeps the CLR type and type mapping of the aggregated expression.
                return SingleArgumentAggregate(
                    "range_intersect_agg", selector, source, selector.Type, selector.TypeMapping);

            case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbObjectAgg):
            case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonObjectAgg):
            {
                // These methods accept two enumerable (column) arguments; this is represented in LINQ as a projection from
                // the grouping to a tuple of the two columns. Since we generally translate tuples to
                // PostgresRowValueExpression, we take it apart here.
                if (source.Selector is not PostgresRowValueExpression row)
                {
                    return null;
                }

                var useJsonb = method.Name == nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbObjectAgg);
                var keyColumn = row.Values[0];
                var valueColumn = row.Values[1];

                return _sqlExpressionFactory.AggregateFunction(
                    useJsonb ? "jsonb_object_agg" : "json_object_agg",
                    new[] { keyColumn, valueColumn },
                    source,
                    nullable: true,
                    argumentsPropagateNullability: FalseArrays[2],
                    returnType: method.ReturnType,
                    typeMapping: _typeMappingSource.FindMapping(method.ReturnType, useJsonb ? "jsonb" : "json"));
            }

            default:
                return null;
        }
    }

    // Builds a nullable aggregate function call that takes the given operand as its only argument.
    private SqlExpression SingleArgumentAggregate(
        string functionName,
        SqlExpression operand,
        EnumerableExpression source,
        Type returnType,
        RelationalTypeMapping? typeMapping)
        => _sqlExpressionFactory.AggregateFunction(
            functionName,
            new[] { operand },
            source,
            nullable: true,
            argumentsPropagateNullability: FalseArrays[1],
            returnType: returnType,
            typeMapping: typeMapping);
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,16 @@ public NpgsqlQueryableAggregateMethodTranslator(
_sqlExpressionFactory.AggregateFunction(
"avg",
new[] { averageSqlExpression },
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
typeof(double)),
nullable: true, argumentsPropagateNullability: FalseArrays[1], returnType: typeof(double)),
averageSqlExpression.Type,
averageSqlExpression.TypeMapping)
: _sqlExpressionFactory.AggregateFunction(
"avg",
new[] { averageSqlExpression },
source,
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
averageSqlExpression.Type,
averageSqlExpression.TypeMapping);

Expand All @@ -69,11 +67,12 @@ public NpgsqlQueryableAggregateMethodTranslator(
_sqlExpressionFactory.AggregateFunction(
"count",
new[] { countSqlExpression },
source,
nullable: false,
argumentsPropagateNullability: FalseArrays[1],
source,
typeof(long)),
typeof(int), _typeMappingSource.FindMapping(typeof(int)));
typeof(int),
_typeMappingSource.FindMapping(typeof(int)));

case nameof(Queryable.LongCount)
when methodInfo == QueryableMethods.LongCountWithoutPredicate
Expand All @@ -82,9 +81,9 @@ public NpgsqlQueryableAggregateMethodTranslator(
return _sqlExpressionFactory.AggregateFunction(
"count",
new[] { longCountSqlExpression },
source,
nullable: false,
argumentsPropagateNullability: FalseArrays[1],
source,
typeof(long));

case nameof(Queryable.Max)
Expand All @@ -94,9 +93,9 @@ public NpgsqlQueryableAggregateMethodTranslator(
return _sqlExpressionFactory.AggregateFunction(
"max",
new[] { maxSqlExpression },
source,
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
maxSqlExpression.Type,
maxSqlExpression.TypeMapping);

Expand All @@ -107,9 +106,9 @@ public NpgsqlQueryableAggregateMethodTranslator(
return _sqlExpressionFactory.AggregateFunction(
"min",
new[] { minSqlExpression },
source,
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
minSqlExpression.Type,
minSqlExpression.TypeMapping);

Expand All @@ -129,9 +128,9 @@ public NpgsqlQueryableAggregateMethodTranslator(
_sqlExpressionFactory.AggregateFunction(
"sum",
new[] { sumSqlExpression },
source,
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
typeof(long)),
sumInputType,
sumSqlExpression.TypeMapping);
Expand All @@ -143,9 +142,9 @@ public NpgsqlQueryableAggregateMethodTranslator(
_sqlExpressionFactory.AggregateFunction(
"sum",
new[] { sumSqlExpression },
source,
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
typeof(decimal)),
sumInputType,
sumSqlExpression.TypeMapping);
Expand All @@ -154,9 +153,9 @@ public NpgsqlQueryableAggregateMethodTranslator(
return _sqlExpressionFactory.AggregateFunction(
"sum",
new[] { sumSqlExpression },
source,
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
sumInputType,
sumSqlExpression.TypeMapping);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal;
using static Npgsql.EntityFrameworkCore.PostgreSQL.Utilities.Statics;

namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal;

/// <summary>
///     Translates the statistics aggregate methods declared on <see cref="NpgsqlAggregateDbFunctionsExtensions" />
///     (standard deviation, variance, correlation, covariance and the regr_* family) into PostgreSQL aggregate
///     function calls.
/// </summary>
/// <remarks>
///     Docs: https://www.postgresql.org/docs/current/functions-aggregate.html#FUNCTIONS-AGGREGATE-STATISTICS-TABLE
/// </remarks>
public class NpgsqlStatisticsAggregateMethodTranslator : IAggregateMethodCallTranslator
{
    private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory;
    private readonly RelationalTypeMapping _doubleTypeMapping, _longTypeMapping;

    /// <summary>
    ///     Creates the translator and resolves the double and long type mappings used for the results.
    /// </summary>
    /// <param name="sqlExpressionFactory">Factory used to build the aggregate function expressions.</param>
    /// <param name="typeMappingSource">Source from which the double and long mappings are looked up.</param>
    public NpgsqlStatisticsAggregateMethodTranslator(
        NpgsqlSqlExpressionFactory sqlExpressionFactory,
        IRelationalTypeMappingSource typeMappingSource)
    {
        _sqlExpressionFactory = sqlExpressionFactory;
        _doubleTypeMapping = typeMappingSource.FindMapping(typeof(double))!;
        _longTypeMapping = typeMappingSource.FindMapping(typeof(long))!;
    }

    /// <summary>
    ///     Attempts to translate the given aggregate method call.
    /// </summary>
    /// <param name="method">The method being invoked over the enumerable.</param>
    /// <param name="source">The enumerable being aggregated; its selector supplies the aggregated expression(s).</param>
    /// <param name="arguments">The non-enumerable arguments of the call (not used by this translator).</param>
    /// <param name="logger">The query logger (not used by this translator).</param>
    /// <returns>The aggregate function expression, or <see langword="null" /> if the call isn't handled here.</returns>
    public virtual SqlExpression? Translate(
        MethodInfo method,
        EnumerableExpression source,
        IReadOnlyList<SqlExpression> arguments,
        IDiagnosticsLogger<DbLoggerCategory.Query> logger)
    {
        if (method.DeclaringType != typeof(NpgsqlAggregateDbFunctionsExtensions))
        {
            return null;
        }

        if (source.Selector is not SqlExpression selector)
        {
            return null;
        }

        // Simple functions which take a single enumerable argument and produce a double.
        var singleColumnFunction = GetSingleColumnFunctionName(method.Name);
        if (singleColumnFunction is not null)
        {
            return _sqlExpressionFactory.AggregateFunction(
                singleColumnFunction,
                new[] { selector },
                source,
                nullable: true,
                argumentsPropagateNullability: FalseArrays[1],
                returnType: typeof(double),
                typeMapping: _doubleTypeMapping);
        }

        var twoColumnFunction = GetTwoColumnFunctionName(method.Name);
        if (twoColumnFunction is null)
        {
            return null;
        }

        // These methods accept two enumerable (column) arguments; this is represented in LINQ as a projection from the
        // grouping to a tuple of the two columns. Since we generally translate tuples to PostgresRowValueExpression, we
        // take it apart here.
        if (source.Selector is not PostgresRowValueExpression row)
        {
            return null;
        }

        var dependent = row.Values[0];
        var independent = row.Values[1];

        // regr_count is the only two-column function mapped to long; all the others are mapped to double.
        // NOTE(review): regr_count looks count-like, yet is flagged nullable like the rest - confirm whether that is intended.
        var returnsCount = method.Name == nameof(NpgsqlAggregateDbFunctionsExtensions.RegrCount);

        return _sqlExpressionFactory.AggregateFunction(
            twoColumnFunction,
            new[] { dependent, independent },
            source,
            nullable: true,
            argumentsPropagateNullability: FalseArrays[2],
            returnType: returnsCount ? typeof(long) : typeof(double),
            typeMapping: returnsCount ? _longTypeMapping : _doubleTypeMapping);
    }

    // Maps a single-column statistics method name to its PostgreSQL function name, or null if it isn't one.
    private static string? GetSingleColumnFunctionName(string methodName)
    {
        switch (methodName)
        {
            case nameof(NpgsqlAggregateDbFunctionsExtensions.StandardDeviationSample):
                return "stddev_samp";
            case nameof(NpgsqlAggregateDbFunctionsExtensions.StandardDeviationPopulation):
                return "stddev_pop";
            case nameof(NpgsqlAggregateDbFunctionsExtensions.VarianceSample):
                return "var_samp";
            case nameof(NpgsqlAggregateDbFunctionsExtensions.VariancePopulation):
                return "var_pop";
            default:
                return null;
        }
    }

    // Maps a two-column statistics method name to its PostgreSQL function name, or null if it isn't one.
    private static string? GetTwoColumnFunctionName(string methodName)
    {
        switch (methodName)
        {
            case nameof(NpgsqlAggregateDbFunctionsExtensions.Correlation):
                return "corr";
            case nameof(NpgsqlAggregateDbFunctionsExtensions.CovariancePopulation):
                return "covar_pop";
            case nameof(NpgsqlAggregateDbFunctionsExtensions.CovarianceSample):
                return "covar_samp";
            case nameof(NpgsqlAggregateDbFunctionsExtensions.RegrAverageX):
                return "regr_avgx";
            case nameof(NpgsqlAggregateDbFunctionsExtensions.RegrAverageY):
                return "regr_avgy";
            case nameof(NpgsqlAggregateDbFunctionsExtensions.RegrCount):
                return "regr_count";
            case nameof(NpgsqlAggregateDbFunctionsExtensions.RegrIntercept):
                return "regr_intercept";
            case nameof(NpgsqlAggregateDbFunctionsExtensions.RegrR2):
                return "regr_r2";
            case nameof(NpgsqlAggregateDbFunctionsExtensions.RegrSlope):
                return "regr_slope";
            case nameof(NpgsqlAggregateDbFunctionsExtensions.RegrSXX):
                return "regr_sxx";
            case nameof(NpgsqlAggregateDbFunctionsExtensions.RegrSXY):
                return "regr_sxy";
            default:
                return null;
        }
    }
}
Loading

0 comments on commit a22027b

Please sign in to comment.