Skip to content

Commit

Permalink
Implement sum and average over TimeSpan (npgsql#2423)
Browse files Browse the repository at this point in the history
  • Loading branch information
roji authored Jul 9, 2022
1 parent 82a758c commit 1eed21e
Show file tree
Hide file tree
Showing 10 changed files with 282 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// ReSharper disable once CheckNamespace
namespace Microsoft.EntityFrameworkCore;

/// <summary>
/// Provides extension methods supporting NodaTime function translation for PostgreSQL.
/// </summary>
public static class NpgsqlNodaTimeDbFunctionsExtensions
{
/// <summary>
/// Computes the sum of the non-null input intervals. Corresponds to the PostgreSQL <c>sum</c> aggregate function.
/// </summary>
/// <param name="_">The <see cref="DbFunctions" /> instance.</param>
/// <param name="input">The input values to be summed.</param>
/// <seealso href="https://www.postgresql.org/docs/current/functions-aggregate.html">PostgreSQL documentation for aggregate functions.</seealso>
public static Period? Sum(this DbFunctions _, IEnumerable<Period> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Sum)));

/// <summary>
/// Computes the sum of the non-null input intervals. Corresponds to the PostgreSQL <c>sum</c> aggregate function.
/// </summary>
/// <param name="_">The <see cref="DbFunctions" /> instance.</param>
/// <param name="input">The input values to be summed.</param>
/// <seealso href="https://www.postgresql.org/docs/current/functions-aggregate.html">PostgreSQL documentation for aggregate functions.</seealso>
public static Duration? Sum(this DbFunctions _, IEnumerable<Duration> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Sum)));

/// <summary>
/// Computes the average (arithmetic mean) of the non-null input intervals. Corresponds to the PostgreSQL <c>avg</c> aggregate function.
/// </summary>
/// <param name="_">The <see cref="DbFunctions" /> instance.</param>
/// <param name="input">The input values to be computed into an average.</param>
/// <seealso href="https://www.postgresql.org/docs/current/functions-aggregate.html">PostgreSQL documentation for aggregate functions.</seealso>
public static Period? Average(this DbFunctions _, IEnumerable<Period> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Average)));

/// <summary>
/// Computes the average (arithmetic mean) of the non-null input intervals. Corresponds to the PostgreSQL <c>avg</c> aggregate function.
/// </summary>
/// <param name="_">The <see cref="DbFunctions" /> instance.</param>
/// <param name="input">The input values to be computed into an average.</param>
/// <seealso href="https://www.postgresql.org/docs/current/functions-aggregate.html">PostgreSQL documentation for aggregate functions.</seealso>
public static Duration? Average(this DbFunctions _, IEnumerable<Duration> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Average)));
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public static IServiceCollection AddEntityFrameworkNpgsqlNodaTime(
new EntityFrameworkRelationalServicesBuilder(serviceCollection)
.TryAdd<IRelationalTypeMappingSourcePlugin, NpgsqlNodaTimeTypeMappingSourcePlugin>()
.TryAdd<IMethodCallTranslatorPlugin, NpgsqlNodaTimeMethodCallTranslatorPlugin>()
.TryAdd<IAggregateMethodCallTranslatorPlugin, NpgsqlNodaTimeAggregateMethodCallTranslatorPlugin>()
.TryAdd<IMemberTranslatorPlugin, NpgsqlNodaTimeMemberTranslatorPlugin>()
.TryAdd<IEvaluatableExpressionFilterPlugin, NpgsqlNodaTimeEvaluatableExpressionFilterPlugin>();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
using Npgsql.EntityFrameworkCore.PostgreSQL.Query;

namespace Npgsql.EntityFrameworkCore.PostgreSQL.NodaTime.Query.Internal;

public class NpgsqlNodaTimeAggregateMethodCallTranslatorPlugin : IAggregateMethodCallTranslatorPlugin
{
public NpgsqlNodaTimeAggregateMethodCallTranslatorPlugin(ISqlExpressionFactory sqlExpressionFactory)
{
if (sqlExpressionFactory is not NpgsqlSqlExpressionFactory npgsqlSqlExpressionFactory)
{
throw new ArgumentException($"Must be an {nameof(NpgsqlSqlExpressionFactory)}", nameof(sqlExpressionFactory));
}

Translators = new IAggregateMethodCallTranslator[]
{
new NpgsqlNodaTimeAggregateMethodTranslator(npgsqlSqlExpressionFactory)
};
}

public virtual IEnumerable<IAggregateMethodCallTranslator> Translators { get; }
}

public class NpgsqlNodaTimeAggregateMethodTranslator : IAggregateMethodCallTranslator
{
private static readonly bool[][] FalseArrays = { Array.Empty<bool>(), new[] { false } };

private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory;

public NpgsqlNodaTimeAggregateMethodTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory)
=> _sqlExpressionFactory = sqlExpressionFactory;

public virtual SqlExpression? Translate(
MethodInfo method,
EnumerableExpression source,
IReadOnlyList<SqlExpression> arguments,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
if (source.Selector is not SqlExpression sqlExpression || method.DeclaringType != typeof(NpgsqlNodaTimeDbFunctionsExtensions))
{
return null;
}

switch (method.Name)
{
case nameof(NpgsqlNodaTimeDbFunctionsExtensions.Sum):
return _sqlExpressionFactory.AggregateFunction(
"sum",
new[] { sqlExpression },
source,
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
returnType: sqlExpression.Type,
sqlExpression.TypeMapping);

case nameof(NpgsqlNodaTimeDbFunctionsExtensions.Average):
return _sqlExpressionFactory.AggregateFunction(
"avg",
new[] { sqlExpression },
source,
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
returnType: sqlExpression.Type,
sqlExpression.TypeMapping);

default:
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,24 @@ public static T[] JsonAgg<T>(this DbFunctions _, IEnumerable<T> input)
public static T[] JsonbAgg<T>(this DbFunctions _, IEnumerable<T> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(JsonbAgg)));

/// <summary>
/// Computes the sum of the non-null input intervals. Corresponds to the PostgreSQL <c>sum</c> aggregate function.
/// </summary>
/// <param name="_">The <see cref="DbFunctions" /> instance.</param>
/// <param name="input">The input values to be summed.</param>
/// <seealso href="https://www.postgresql.org/docs/current/functions-aggregate.html">PostgreSQL documentation for aggregate functions.</seealso>
public static TimeSpan? Sum(this DbFunctions _, IEnumerable<TimeSpan> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Sum)));

/// <summary>
/// Computes the average (arithmetic mean) of the non-null input intervals. Corresponds to the PostgreSQL <c>avg</c> aggregate function.
/// </summary>
/// <param name="_">The <see cref="DbFunctions" /> instance.</param>
/// <param name="input">The input values to be computed into an average.</param>
/// <seealso href="https://www.postgresql.org/docs/current/functions-aggregate.html">PostgreSQL documentation for aggregate functions.</seealso>
public static TimeSpan? Average(this DbFunctions _, IEnumerable<TimeSpan> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Average)));

#region Range

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,26 @@ public NpgsqlMiscAggregateMethodTranslator(
returnType: sqlExpression.Type,
sqlExpression.TypeMapping);

case nameof(NpgsqlAggregateDbFunctionsExtensions.Sum):
return _sqlExpressionFactory.AggregateFunction(
"sum",
new[] { sqlExpression },
source,
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
returnType: sqlExpression.Type,
sqlExpression.TypeMapping);

case nameof(NpgsqlAggregateDbFunctionsExtensions.Average):
return _sqlExpressionFactory.AggregateFunction(
"avg",
new[] { sqlExpression },
source,
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
returnType: sqlExpression.Type,
sqlExpression.TypeMapping);

case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbObjectAgg):
case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonObjectAgg):
var isJsonb = method.Name == nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbObjectAgg);
Expand Down
9 changes: 9 additions & 0 deletions src/EFCore.PG/Query/Internal/NpgsqlSqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,15 @@ protected override SqlExpression VisitSqlFunction(
&& visitedBase is SqlFunctionExpression { Name: "COALESCE", Arguments: { } } coalesceExpression
&& coalesceExpression.Arguments[0] is PostgresFunctionExpression wrappedFunctionExpression)
{
// The base logic assumes sum is operating over numbers, which breaks sum over PG interval.
// Detect that case and remove the coalesce entirely (note that we don't need coalescing since sum function is in
// EF.Functions.Sum, and returns nullable. This is a temporary hack until #38158 is fixed.
if (sqlFunctionExpression.Type == typeof(TimeSpan)
|| sqlFunctionExpression.Type.FullName is "NodaTime.Period" or "NodaTime.Duration")
{
return coalesceExpression.Arguments[0];
}

var visitedArguments = coalesceExpression.Arguments!.ToArray();
visitedArguments[0] = VisitPostgresFunctionComponents(wrappedFunctionExpression);

Expand Down
38 changes: 38 additions & 0 deletions test/EFCore.PG.FunctionalTests/Query/GearsOfWarQueryNpgsqlTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,44 @@ await AssertQuery(
WHERE (date_part('epoch', m.""Duration"") / 0.001) < 3700000.0");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task GroupBy_Property_Select_Sum_over_TimeSpan(bool async)
{
await AssertQueryScalar(
async,
ss => ss.Set<Mission>()
.GroupBy(o => o.Id)
.Select(g => EF.Functions.Sum(g.Select(o => o.Duration))),
ss => ss.Set<Mission>()
.GroupBy(o => o.Id)
.Select(g => (TimeSpan?)new TimeSpan(g.Sum(o => o.Duration.Ticks))));

AssertSql(
@"SELECT sum(m.""Duration"")
FROM ""Missions"" AS m
GROUP BY m.""Id""");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task GroupBy_Property_Select_Average_over_TimeSpan(bool async)
{
await AssertQueryScalar(
async,
ss => ss.Set<Mission>()
.GroupBy(o => o.Id)
.Select(g => EF.Functions.Average(g.Select(o => o.Duration))),
ss => ss.Set<Mission>()
.GroupBy(o => o.Id)
.Select(g => (TimeSpan?)new TimeSpan((long)g.Average(o => o.Duration.Ticks))));

AssertSql(
@"SELECT avg(m.""Duration"")
FROM ""Missions"" AS m
GROUP BY m.""Id""");
}

#endregion TimeSpan

#region DateOnly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public NorthwindFunctionsQueryNpgsqlTest(
: base(fixture)
{
ClearLog();
Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
//Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
}

public override async Task IsNullOrWhiteSpace_in_predicate(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3163,6 +3163,8 @@ into g
GROUP BY o.""CustomerID""");
}

// See aggregate tests over TimeSpan in GearsOfWarQueryNpsgqlTest

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
80 changes: 80 additions & 0 deletions test/EFCore.PG.NodaTime.FunctionalTests/NodaTimeQueryNpgsqlTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,48 @@ public Task Period_FromTicks_is_not_translated()
() => ctx.Set<NodaTimeTypes>().Where(t => Period.FromNanoseconds(t.Id).Seconds == 1).ToListAsync());
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public async Task GroupBy_Property_Select_Sum_over_Period(bool async)
{
using var ctx = CreateContext();

// Note: Unlike Duration, Period can't be converted to total ticks (because its absolute time varies).
var query = ctx.Set<NodaTimeTypes>()
.GroupBy(o => o.Id)
.Select(g => EF.Functions.Sum(g.Select(o => o.Period)));

_ = async
? await query.ToListAsync()
: query.ToList();

AssertSql(
@"SELECT sum(n.""Period"")
FROM ""NodaTimeTypes"" AS n
GROUP BY n.""Id""");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public async Task GroupBy_Property_Select_Average_over_Period(bool async)
{
using var ctx = CreateContext();

// Note: Unlike Duration, Period can't be converted to total ticks (because its absolute time varies).
var query = ctx.Set<NodaTimeTypes>()
.GroupBy(o => o.Id)
.Select(g => EF.Functions.Average(g.Select(o => o.Period)));

_ = async
? await query.ToListAsync()
: query.ToList();

AssertSql(
@"SELECT avg(n.""Period"")
FROM ""NodaTimeTypes"" AS n
GROUP BY n.""Id""");
}

#endregion Period

#region Duration
Expand Down Expand Up @@ -894,6 +936,44 @@ await AssertQuery(
WHERE floor(date_part('second', n.""Duration""))::int = 8");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public async Task GroupBy_Property_Select_Sum_over_Duration(bool async)
{
await AssertQueryScalar(
async,
ss => ss.Set<NodaTimeTypes>()
.GroupBy(o => o.Id)
.Select(g => EF.Functions.Sum(g.Select(o => o.Duration))),
expectedQuery: ss => ss.Set<NodaTimeTypes>()
.GroupBy(o => o.Id)
.Select(g => (Duration?)Duration.FromTicks(g.Sum(o => o.Duration.TotalTicks))));

AssertSql(
@"SELECT sum(n.""Duration"")
FROM ""NodaTimeTypes"" AS n
GROUP BY n.""Id""");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public async Task GroupBy_Property_Select_Average_over_Duration(bool async)
{
await AssertQueryScalar(
async,
ss => ss.Set<NodaTimeTypes>()
.GroupBy(o => o.Id)
.Select(g => EF.Functions.Average(g.Select(o => o.Duration))),
expectedQuery: ss => ss.Set<NodaTimeTypes>()
.GroupBy(o => o.Id)
.Select(g => (Duration?)Duration.FromTicks((long)g.Average(o => o.Duration.TotalTicks))));

AssertSql(
@"SELECT avg(n.""Duration"")
FROM ""NodaTimeTypes"" AS n
GROUP BY n.""Id""");
}

#endregion

#region Interval
Expand Down

0 comments on commit 1eed21e

Please sign in to comment.