Skip to content

Commit

Permalink
Translate PostgreSQL custom aggregates
Browse files Browse the repository at this point in the history
* string_agg (string.Join)
* array_agg
* json_agg/jsonb_agg
* json_object_agg/jsonb_object_agg
* range_agg, range_intersect_agg
* Aggregate statistics functions

Closes npgsql#532
  • Loading branch information
roji committed Jun 4, 2022
1 parent 1016470 commit 1b5472d
Show file tree
Hide file tree
Showing 11 changed files with 1,287 additions and 56 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ public NpgsqlAggregateMethodCallTranslatorProvider(RelationalAggregateMethodCall
AddTranslators(
new IAggregateMethodCallTranslator[]
{
new NpgsqlQueryableAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource)
new NpgsqlQueryableAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource),
new NpgsqlStatisticsAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource),
new NpgsqlMiscAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource)
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal;
using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping;
using static Npgsql.EntityFrameworkCore.PostgreSQL.Utilities.Statics;

namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal;

/// <summary>
///     Translates <c>string.Join</c> over a grouping, as well as the ArrayAgg, JsonAgg, JsonbAgg, RangeAgg, RangeIntersectAgg,
///     JsonObjectAgg and JsonbObjectAgg methods declared on <see cref="NpgsqlAggregateDbFunctionsExtensions" />, into the
///     corresponding PostgreSQL aggregate function calls.
/// </summary>
public class NpgsqlMiscAggregateMethodTranslator : IAggregateMethodCallTranslator
{
    private static readonly MethodInfo StringJoin
        = typeof(string).GetRuntimeMethod(nameof(string.Join), new[] { typeof(string), typeof(IEnumerable<string>) })!;

    private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory;
    private readonly IRelationalTypeMappingSource _typeMappingSource;

    /// <summary>
    ///     Creates the translator, keeping hold of the SQL expression factory and the type mapping source it needs.
    /// </summary>
    public NpgsqlMiscAggregateMethodTranslator(
        NpgsqlSqlExpressionFactory sqlExpressionFactory,
        IRelationalTypeMappingSource typeMappingSource)
    {
        _sqlExpressionFactory = sqlExpressionFactory;
        _typeMappingSource = typeMappingSource;
    }

    /// <summary>
    ///     Attempts to translate <paramref name="method" /> into a PostgreSQL aggregate function over <paramref name="source" />.
    /// </summary>
    /// <returns>The aggregate function expression, or <see langword="null" /> when the call is not handled here.</returns>
    public virtual SqlExpression? Translate(
        MethodInfo method,
        EnumerableExpression source,
        IReadOnlyList<SqlExpression> arguments,
        IDiagnosticsLogger<DbLoggerCategory.Query> logger)
    {
        // Nothing below can work unless the selector is already a SQL expression.
        if (source.Selector is not SqlExpression selector)
        {
            return null;
        }

        if (method == StringJoin)
        {
            // TODO: Align this with the upstream support, including Concat translation
            // TODO: Confirm result type mapping, what happens with e.g. varchar(8)
            // string.Join(separator, values) becomes string_agg(values, separator).
            return _sqlExpressionFactory.AggregateFunction(
                "string_agg",
                new[] { selector, arguments[0] },
                source,
                nullable: true,
                argumentsPropagateNullability: new[] { false, true },
                returnType: typeof(string));
        }

        if (method.DeclaringType != typeof(NpgsqlAggregateDbFunctionsExtensions))
        {
            return null;
        }

        switch (method.Name)
        {
            case nameof(NpgsqlAggregateDbFunctionsExtensions.ArrayAgg):
            {
                var arrayType = selector.Type.MakeArrayType();

                // An array mapping can only be built when the element's own mapping is known.
                RelationalTypeMapping? arrayMapping = selector.TypeMapping is { } elementMapping
                    ? new NpgsqlArrayArrayTypeMapping(arrayType, elementMapping)
                    : null;

                return SingleArgumentAggregate("array_agg", selector, source, arrayType, arrayMapping);
            }

            case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonAgg):
            {
                var arrayType = selector.Type.MakeArrayType();

                return SingleArgumentAggregate(
                    "json_agg", selector, source, arrayType, _typeMappingSource.FindMapping(arrayType, "json"));
            }

            case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbAgg):
            {
                var arrayType = selector.Type.MakeArrayType();

                return SingleArgumentAggregate(
                    "jsonb_agg", selector, source, arrayType, _typeMappingSource.FindMapping(arrayType, "jsonb"));
            }

            case nameof(NpgsqlAggregateDbFunctionsExtensions.RangeAgg):
            {
                var arrayType = selector.Type.MakeArrayType();

                return SingleArgumentAggregate(
                    "range_agg", selector, source, arrayType, _typeMappingSource.FindMapping(arrayType));
            }

            case nameof(NpgsqlAggregateDbFunctionsExtensions.RangeIntersectAgg):
                // The result reuses the input's CLR type and type mapping.
                return SingleArgumentAggregate(
                    "range_intersect_agg", selector, source, selector.Type, selector.TypeMapping);

            case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonObjectAgg):
            case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbObjectAgg):
            {
                var (functionName, storeType) = method.Name == nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbObjectAgg)
                    ? ("jsonb_object_agg", "jsonb")
                    : ("json_object_agg", "json");

                // These methods accept two enumerable (column) arguments; this is represented in LINQ as a projection from the
                // grouping to a tuple of the two columns. Since we generally translate tuples to PostgresRowValueExpression, we
                // take it apart here.
                if (source.Selector is not PostgresRowValueExpression row)
                {
                    return null;
                }

                var keyColumn = row.Values[0];
                var valueColumn = row.Values[1];

                return _sqlExpressionFactory.AggregateFunction(
                    functionName,
                    new[] { keyColumn, valueColumn },
                    source,
                    nullable: true,
                    argumentsPropagateNullability: FalseArrays[2],
                    returnType: method.ReturnType,
                    typeMapping: _typeMappingSource.FindMapping(method.ReturnType, storeType));
            }

            default:
                return null;
        }
    }

    // Builds a nullable aggregate function call whose only argument is the given expression.
    private SqlExpression SingleArgumentAggregate(
        string functionName,
        SqlExpression argument,
        EnumerableExpression source,
        Type returnType,
        RelationalTypeMapping? typeMapping)
        => _sqlExpressionFactory.AggregateFunction(
            functionName,
            new[] { argument },
            source,
            nullable: true,
            argumentsPropagateNullability: FalseArrays[1],
            returnType: returnType,
            typeMapping: typeMapping);
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,16 @@ public NpgsqlQueryableAggregateMethodTranslator(
_sqlExpressionFactory.AggregateFunction(
"AVG",
new[] { averageSqlExpression },
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
typeof(double)),
nullable: true, argumentsPropagateNullability: FalseArrays[1], returnType: typeof(double)),
averageSqlExpression.Type,
averageSqlExpression.TypeMapping)
: _sqlExpressionFactory.AggregateFunction(
"AVG",
new[] { averageSqlExpression },
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
averageSqlExpression.Type,
averageSqlExpression.TypeMapping);
nullable: true,
argumentsPropagateNullability: FalseArrays[1], returnType: averageSqlExpression.Type, typeMapping: averageSqlExpression.TypeMapping);

// PostgreSQL COUNT() always returns bigint, so we need to downcast to int
case nameof(Queryable.Count)
Expand All @@ -70,10 +66,8 @@ public NpgsqlQueryableAggregateMethodTranslator(
_sqlExpressionFactory.AggregateFunction(
"COUNT",
new[] { countSqlExpression },
nullable: false,
argumentsPropagateNullability: FalseArrays[1],
source,
typeof(long))),
nullable: false, argumentsPropagateNullability: FalseArrays[1], returnType: typeof(long))),
typeof(int), _typeMappingSource.FindMapping(typeof(int)));

case nameof(Queryable.LongCount)
Expand All @@ -84,10 +78,8 @@ public NpgsqlQueryableAggregateMethodTranslator(
_sqlExpressionFactory.AggregateFunction(
"COUNT",
new[] { longCountSqlExpression },
nullable: false,
argumentsPropagateNullability: FalseArrays[1],
source,
typeof(long)));
nullable: false, argumentsPropagateNullability: FalseArrays[1], returnType: typeof(long)));

case nameof(Queryable.Max)
when (methodInfo == QueryableMethods.MaxWithoutSelector
Expand All @@ -96,11 +88,9 @@ public NpgsqlQueryableAggregateMethodTranslator(
return _sqlExpressionFactory.AggregateFunction(
"MAX",
new[] { maxSqlExpression },
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
maxSqlExpression.Type,
maxSqlExpression.TypeMapping);
nullable: true,
argumentsPropagateNullability: FalseArrays[1], returnType: maxSqlExpression.Type, typeMapping: maxSqlExpression.TypeMapping);

case nameof(Queryable.Min)
when (methodInfo == QueryableMethods.MinWithoutSelector
Expand All @@ -109,11 +99,9 @@ public NpgsqlQueryableAggregateMethodTranslator(
return _sqlExpressionFactory.AggregateFunction(
"MIN",
new[] { minSqlExpression },
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
minSqlExpression.Type,
minSqlExpression.TypeMapping);
nullable: true,
argumentsPropagateNullability: FalseArrays[1], returnType: minSqlExpression.Type, typeMapping: minSqlExpression.TypeMapping);

// In PostgreSQL SUM() doesn't return the same type as its argument for smallint, int and bigint.
// Cast to get the same type.
Expand All @@ -131,10 +119,8 @@ public NpgsqlQueryableAggregateMethodTranslator(
_sqlExpressionFactory.AggregateFunction(
"SUM",
new[] { sumSqlExpression },
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
typeof(long)),
nullable: true, argumentsPropagateNullability: FalseArrays[1], returnType: typeof(long)),
sumInputType,
sumSqlExpression.TypeMapping);
}
Expand All @@ -145,22 +131,18 @@ public NpgsqlQueryableAggregateMethodTranslator(
_sqlExpressionFactory.AggregateFunction(
"SUM",
new[] { sumSqlExpression },
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
typeof(decimal)),
nullable: true, argumentsPropagateNullability: FalseArrays[1], returnType: typeof(decimal)),
sumInputType,
sumSqlExpression.TypeMapping);
}

return _sqlExpressionFactory.AggregateFunction(
"SUM",
new[] { sumSqlExpression },
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
sumInputType,
sumSqlExpression.TypeMapping);
nullable: true,
argumentsPropagateNullability: FalseArrays[1], returnType: sumInputType, typeMapping: sumSqlExpression.TypeMapping);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal;
using static Npgsql.EntityFrameworkCore.PostgreSQL.Utilities.Statics;

namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal;

/// <summary>
///     Translates the statistics aggregate methods declared on <see cref="NpgsqlAggregateDbFunctionsExtensions" />
///     (standard deviation, variance, correlation, covariance and the regr_* family) into PostgreSQL aggregate function calls.
/// </summary>
public class NpgsqlStatisticsAggregateMethodTranslator : IAggregateMethodCallTranslator
{
    private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory;
    private readonly RelationalTypeMapping _doubleTypeMapping;
    private readonly RelationalTypeMapping _longTypeMapping;

    /// <summary>
    ///     Creates the translator and looks up, once, the type mappings for <see cref="double" /> and <see cref="long" />
    ///     that are attached to the generated function calls.
    /// </summary>
    public NpgsqlStatisticsAggregateMethodTranslator(
        NpgsqlSqlExpressionFactory sqlExpressionFactory,
        IRelationalTypeMappingSource typeMappingSource)
    {
        _sqlExpressionFactory = sqlExpressionFactory;
        _doubleTypeMapping = typeMappingSource.FindMapping(typeof(double))!;
        _longTypeMapping = typeMappingSource.FindMapping(typeof(long))!;
    }

    /// <summary>
    ///     Attempts to translate <paramref name="method" /> into a PostgreSQL statistics aggregate over
    ///     <paramref name="source" />.
    /// </summary>
    /// <returns>The aggregate function expression, or <see langword="null" /> when the call is not handled here.</returns>
    public virtual SqlExpression? Translate(
        MethodInfo method,
        EnumerableExpression source,
        IReadOnlyList<SqlExpression> arguments,
        IDiagnosticsLogger<DbLoggerCategory.Query> logger)
    {
        // Docs: https://www.postgresql.org/docs/current/functions-aggregate.html#FUNCTIONS-AGGREGATE-STATISTICS-TABLE

        if (method.DeclaringType != typeof(NpgsqlAggregateDbFunctionsExtensions))
        {
            return null;
        }

        if (source.Selector is not SqlExpression column)
        {
            return null;
        }

        return method.Name switch
        {
            // These four functions are simple and take a single enumerable argument
            nameof(NpgsqlAggregateDbFunctionsExtensions.StandardDeviationSample) => OneColumn("stddev_samp", column, source),
            nameof(NpgsqlAggregateDbFunctionsExtensions.StandardDeviationPopulation) => OneColumn("stddev_pop", column, source),
            nameof(NpgsqlAggregateDbFunctionsExtensions.VarianceSample) => OneColumn("var_samp", column, source),
            nameof(NpgsqlAggregateDbFunctionsExtensions.VariancePopulation) => OneColumn("var_pop", column, source),

            // The remaining functions operate on a pair of columns
            nameof(NpgsqlAggregateDbFunctionsExtensions.Correlation) => TwoColumns("corr", source),
            nameof(NpgsqlAggregateDbFunctionsExtensions.CovariancePopulation) => TwoColumns("covar_pop", source),
            nameof(NpgsqlAggregateDbFunctionsExtensions.CovarianceSample) => TwoColumns("covar_samp", source),
            nameof(NpgsqlAggregateDbFunctionsExtensions.RegrAverageX) => TwoColumns("regr_avgx", source),
            nameof(NpgsqlAggregateDbFunctionsExtensions.RegrAverageY) => TwoColumns("regr_avgy", source),
            nameof(NpgsqlAggregateDbFunctionsExtensions.RegrCount) => TwoColumns("regr_count", source, returnsCount: true),
            nameof(NpgsqlAggregateDbFunctionsExtensions.RegrIntercept) => TwoColumns("regr_intercept", source),
            nameof(NpgsqlAggregateDbFunctionsExtensions.RegrR2) => TwoColumns("regr_r2", source),
            nameof(NpgsqlAggregateDbFunctionsExtensions.RegrSlope) => TwoColumns("regr_slope", source),
            nameof(NpgsqlAggregateDbFunctionsExtensions.RegrSXX) => TwoColumns("regr_sxx", source),
            nameof(NpgsqlAggregateDbFunctionsExtensions.RegrSXY) => TwoColumns("regr_sxy", source),

            _ => null
        };
    }

    // Builds a nullable, double-typed aggregate call over a single column.
    private SqlExpression OneColumn(string functionName, SqlExpression column, EnumerableExpression source)
        => _sqlExpressionFactory.AggregateFunction(
            functionName,
            new[] { column },
            source,
            nullable: true,
            argumentsPropagateNullability: FalseArrays[1],
            returnType: typeof(double),
            typeMapping: _doubleTypeMapping);

    // Builds a nullable aggregate call over two columns. The result is typed as long when returnsCount is set
    // (used for regr_count) and as double otherwise.
    private SqlExpression? TwoColumns(string functionName, EnumerableExpression source, bool returnsCount = false)
    {
        // These methods accept two enumerable (column) arguments; this is represented in LINQ as a projection from the grouping
        // to a tuple of the two columns. Since we generally translate tuples to PostgresRowValueExpression, we take it apart here.
        if (source.Selector is not PostgresRowValueExpression row)
        {
            return null;
        }

        var yColumn = row.Values[0];
        var xColumn = row.Values[1];

        var resultType = returnsCount ? typeof(long) : typeof(double);
        var resultMapping = returnsCount ? _longTypeMapping : _doubleTypeMapping;

        return _sqlExpressionFactory.AggregateFunction(
            functionName,
            new[] { yColumn, xColumn },
            source,
            nullable: true,
            argumentsPropagateNullability: FalseArrays[2],
            returnType: resultType,
            typeMapping: resultMapping);
    }
}
Loading

0 comments on commit 1b5472d

Please sign in to comment.