Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement sum and average over TimeSpan #2423

Merged
merged 1 commit into from
Jul 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// ReSharper disable once CheckNamespace
namespace Microsoft.EntityFrameworkCore;

/// <summary>
/// Provides extension methods supporting NodaTime function translation for PostgreSQL.
/// </summary>
public static class NpgsqlNodaTimeDbFunctionsExtensions
{
/// <summary>
/// Computes the sum of the non-null input intervals. Corresponds to the PostgreSQL <c>sum</c> aggregate function.
/// </summary>
/// <param name="_">The <see cref="DbFunctions" /> instance.</param>
/// <param name="input">The input values to be summed.</param>
/// <seealso href="https://www.postgresql.org/docs/current/functions-aggregate.html">PostgreSQL documentation for aggregate functions.</seealso>
public static Period? Sum(this DbFunctions _, IEnumerable<Period> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Sum)));

/// <summary>
/// Computes the sum of the non-null input intervals. Corresponds to the PostgreSQL <c>sum</c> aggregate function.
/// </summary>
/// <param name="_">The <see cref="DbFunctions" /> instance.</param>
/// <param name="input">The input values to be summed.</param>
/// <seealso href="https://www.postgresql.org/docs/current/functions-aggregate.html">PostgreSQL documentation for aggregate functions.</seealso>
public static Duration? Sum(this DbFunctions _, IEnumerable<Duration> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Sum)));

/// <summary>
/// Computes the average (arithmetic mean) of the non-null input intervals. Corresponds to the PostgreSQL <c>avg</c> aggregate function.
/// </summary>
/// <param name="_">The <see cref="DbFunctions" /> instance.</param>
/// <param name="input">The input values to be computed into an average.</param>
/// <seealso href="https://www.postgresql.org/docs/current/functions-aggregate.html">PostgreSQL documentation for aggregate functions.</seealso>
public static Period? Average(this DbFunctions _, IEnumerable<Period> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Average)));

/// <summary>
/// Computes the average (arithmetic mean) of the non-null input intervals. Corresponds to the PostgreSQL <c>avg</c> aggregate function.
/// </summary>
/// <param name="_">The <see cref="DbFunctions" /> instance.</param>
/// <param name="input">The input values to be computed into an average.</param>
/// <seealso href="https://www.postgresql.org/docs/current/functions-aggregate.html">PostgreSQL documentation for aggregate functions.</seealso>
public static Duration? Average(this DbFunctions _, IEnumerable<Duration> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Average)));
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public static IServiceCollection AddEntityFrameworkNpgsqlNodaTime(
new EntityFrameworkRelationalServicesBuilder(serviceCollection)
.TryAdd<IRelationalTypeMappingSourcePlugin, NpgsqlNodaTimeTypeMappingSourcePlugin>()
.TryAdd<IMethodCallTranslatorPlugin, NpgsqlNodaTimeMethodCallTranslatorPlugin>()
.TryAdd<IAggregateMethodCallTranslatorPlugin, NpgsqlNodaTimeAggregateMethodCallTranslatorPlugin>()
.TryAdd<IMemberTranslatorPlugin, NpgsqlNodaTimeMemberTranslatorPlugin>()
.TryAdd<IEvaluatableExpressionFilterPlugin, NpgsqlNodaTimeEvaluatableExpressionFilterPlugin>();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
using Npgsql.EntityFrameworkCore.PostgreSQL.Query;

namespace Npgsql.EntityFrameworkCore.PostgreSQL.NodaTime.Query.Internal;

public class NpgsqlNodaTimeAggregateMethodCallTranslatorPlugin : IAggregateMethodCallTranslatorPlugin
{
public NpgsqlNodaTimeAggregateMethodCallTranslatorPlugin(ISqlExpressionFactory sqlExpressionFactory)
{
if (sqlExpressionFactory is not NpgsqlSqlExpressionFactory npgsqlSqlExpressionFactory)
{
throw new ArgumentException($"Must be an {nameof(NpgsqlSqlExpressionFactory)}", nameof(sqlExpressionFactory));
}

Translators = new IAggregateMethodCallTranslator[]
{
new NpgsqlNodaTimeAggregateMethodTranslator(npgsqlSqlExpressionFactory)
};
}

public virtual IEnumerable<IAggregateMethodCallTranslator> Translators { get; }
}

public class NpgsqlNodaTimeAggregateMethodTranslator : IAggregateMethodCallTranslator
{
private static readonly bool[][] FalseArrays = { Array.Empty<bool>(), new[] { false } };

private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory;

public NpgsqlNodaTimeAggregateMethodTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory)
=> _sqlExpressionFactory = sqlExpressionFactory;

public virtual SqlExpression? Translate(
MethodInfo method,
EnumerableExpression source,
IReadOnlyList<SqlExpression> arguments,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
if (source.Selector is not SqlExpression sqlExpression || method.DeclaringType != typeof(NpgsqlNodaTimeDbFunctionsExtensions))
{
return null;
}

switch (method.Name)
{
case nameof(NpgsqlNodaTimeDbFunctionsExtensions.Sum):
return _sqlExpressionFactory.AggregateFunction(
"sum",
new[] { sqlExpression },
source,
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
returnType: sqlExpression.Type,
sqlExpression.TypeMapping);

case nameof(NpgsqlNodaTimeDbFunctionsExtensions.Average):
return _sqlExpressionFactory.AggregateFunction(
"avg",
new[] { sqlExpression },
source,
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
returnType: sqlExpression.Type,
sqlExpression.TypeMapping);

default:
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,24 @@ public static T[] JsonAgg<T>(this DbFunctions _, IEnumerable<T> input)
public static T[] JsonbAgg<T>(this DbFunctions _, IEnumerable<T> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(JsonbAgg)));

/// <summary>
/// Computes the sum of the non-null input intervals. Corresponds to the PostgreSQL <c>sum</c> aggregate function.
/// </summary>
/// <param name="_">The <see cref="DbFunctions" /> instance.</param>
/// <param name="input">The input values to be summed.</param>
/// <seealso href="https://www.postgresql.org/docs/current/functions-aggregate.html">PostgreSQL documentation for aggregate functions.</seealso>
public static TimeSpan? Sum(this DbFunctions _, IEnumerable<TimeSpan> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Sum)));

/// <summary>
/// Computes the average (arithmetic mean) of the non-null input intervals. Corresponds to the PostgreSQL <c>avg</c> aggregate function.
/// </summary>
/// <param name="_">The <see cref="DbFunctions" /> instance.</param>
/// <param name="input">The input values to be computed into an average.</param>
/// <seealso href="https://www.postgresql.org/docs/current/functions-aggregate.html">PostgreSQL documentation for aggregate functions.</seealso>
public static TimeSpan? Average(this DbFunctions _, IEnumerable<TimeSpan> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Average)));

#region Range

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,26 @@ public NpgsqlMiscAggregateMethodTranslator(
returnType: sqlExpression.Type,
sqlExpression.TypeMapping);

case nameof(NpgsqlAggregateDbFunctionsExtensions.Sum):
return _sqlExpressionFactory.AggregateFunction(
"sum",
new[] { sqlExpression },
source,
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
returnType: sqlExpression.Type,
sqlExpression.TypeMapping);

case nameof(NpgsqlAggregateDbFunctionsExtensions.Average):
return _sqlExpressionFactory.AggregateFunction(
"avg",
new[] { sqlExpression },
source,
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
returnType: sqlExpression.Type,
sqlExpression.TypeMapping);

case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbObjectAgg):
case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonObjectAgg):
var isJsonb = method.Name == nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbObjectAgg);
Expand Down
9 changes: 9 additions & 0 deletions src/EFCore.PG/Query/Internal/NpgsqlSqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,15 @@ protected override SqlExpression VisitSqlFunction(
&& visitedBase is SqlFunctionExpression { Name: "COALESCE", Arguments: { } } coalesceExpression
&& coalesceExpression.Arguments[0] is PostgresFunctionExpression wrappedFunctionExpression)
{
// The base logic assumes sum is operating over numbers, which breaks sum over PG interval.
// Detect that case and remove the coalesce entirely (note that we don't need coalescing since sum function is in
// EF.Functions.Sum, and returns nullable. This is a temporary hack until #38158 is fixed.
if (sqlFunctionExpression.Type == typeof(TimeSpan)
|| sqlFunctionExpression.Type.FullName is "NodaTime.Period" or "NodaTime.Duration")
{
return coalesceExpression.Arguments[0];
}

var visitedArguments = coalesceExpression.Arguments!.ToArray();
visitedArguments[0] = VisitPostgresFunctionComponents(wrappedFunctionExpression);

Expand Down
38 changes: 38 additions & 0 deletions test/EFCore.PG.FunctionalTests/Query/GearsOfWarQueryNpgsqlTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,44 @@ await AssertQuery(
WHERE (date_part('epoch', m.""Duration"") / 0.001) < 3700000.0");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task GroupBy_Property_Select_Sum_over_TimeSpan(bool async)
{
await AssertQueryScalar(
async,
ss => ss.Set<Mission>()
.GroupBy(o => o.Id)
.Select(g => EF.Functions.Sum(g.Select(o => o.Duration))),
ss => ss.Set<Mission>()
.GroupBy(o => o.Id)
.Select(g => (TimeSpan?)new TimeSpan(g.Sum(o => o.Duration.Ticks))));

AssertSql(
@"SELECT sum(m.""Duration"")
FROM ""Missions"" AS m
GROUP BY m.""Id""");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task GroupBy_Property_Select_Average_over_TimeSpan(bool async)
{
await AssertQueryScalar(
async,
ss => ss.Set<Mission>()
.GroupBy(o => o.Id)
.Select(g => EF.Functions.Average(g.Select(o => o.Duration))),
ss => ss.Set<Mission>()
.GroupBy(o => o.Id)
.Select(g => (TimeSpan?)new TimeSpan((long)g.Average(o => o.Duration.Ticks))));

AssertSql(
@"SELECT avg(m.""Duration"")
FROM ""Missions"" AS m
GROUP BY m.""Id""");
}

#endregion TimeSpan

#region DateOnly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public NorthwindFunctionsQueryNpgsqlTest(
: base(fixture)
{
ClearLog();
Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
//Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
}

public override async Task IsNullOrWhiteSpace_in_predicate(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3163,6 +3163,8 @@ into g
GROUP BY o.""CustomerID""");
}

// See aggregate tests over TimeSpan in GearsOfWarQueryNpsgqlTest

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
80 changes: 80 additions & 0 deletions test/EFCore.PG.NodaTime.FunctionalTests/NodaTimeQueryNpgsqlTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,48 @@ public Task Period_FromTicks_is_not_translated()
() => ctx.Set<NodaTimeTypes>().Where(t => Period.FromNanoseconds(t.Id).Seconds == 1).ToListAsync());
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public async Task GroupBy_Property_Select_Sum_over_Period(bool async)
{
using var ctx = CreateContext();

// Note: Unlike Duration, Period can't be converted to total ticks (because its absolute time varies).
var query = ctx.Set<NodaTimeTypes>()
.GroupBy(o => o.Id)
.Select(g => EF.Functions.Sum(g.Select(o => o.Period)));

_ = async
? await query.ToListAsync()
: query.ToList();

AssertSql(
@"SELECT sum(n.""Period"")
FROM ""NodaTimeTypes"" AS n
GROUP BY n.""Id""");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public async Task GroupBy_Property_Select_Average_over_Period(bool async)
{
using var ctx = CreateContext();

// Note: Unlike Duration, Period can't be converted to total ticks (because its absolute time varies).
var query = ctx.Set<NodaTimeTypes>()
.GroupBy(o => o.Id)
.Select(g => EF.Functions.Average(g.Select(o => o.Period)));

_ = async
? await query.ToListAsync()
: query.ToList();

AssertSql(
@"SELECT avg(n.""Period"")
FROM ""NodaTimeTypes"" AS n
GROUP BY n.""Id""");
}

#endregion Period

#region Duration
Expand Down Expand Up @@ -894,6 +936,44 @@ await AssertQuery(
WHERE floor(date_part('second', n.""Duration""))::int = 8");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public async Task GroupBy_Property_Select_Sum_over_Duration(bool async)
{
await AssertQueryScalar(
async,
ss => ss.Set<NodaTimeTypes>()
.GroupBy(o => o.Id)
.Select(g => EF.Functions.Sum(g.Select(o => o.Duration))),
expectedQuery: ss => ss.Set<NodaTimeTypes>()
.GroupBy(o => o.Id)
.Select(g => (Duration?)Duration.FromTicks(g.Sum(o => o.Duration.TotalTicks))));

AssertSql(
@"SELECT sum(n.""Duration"")
FROM ""NodaTimeTypes"" AS n
GROUP BY n.""Id""");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public async Task GroupBy_Property_Select_Average_over_Duration(bool async)
{
await AssertQueryScalar(
async,
ss => ss.Set<NodaTimeTypes>()
.GroupBy(o => o.Id)
.Select(g => EF.Functions.Average(g.Select(o => o.Duration))),
expectedQuery: ss => ss.Set<NodaTimeTypes>()
.GroupBy(o => o.Id)
.Select(g => (Duration?)Duration.FromTicks((long)g.Average(o => o.Duration.TotalTicks))));

AssertSql(
@"SELECT avg(n.""Duration"")
FROM ""NodaTimeTypes"" AS n
GROUP BY n.""Id""");
}

#endregion

#region Interval
Expand Down