Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exploratory work on custom aggregate methods #1531

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/EFCore.PG/Extensions/NpgsqlDbFunctionsExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;

Expand Down Expand Up @@ -51,5 +53,11 @@ public static bool ILike(
/// <returns>The reversed string.</returns>
public static string Reverse([CanBeNull] this DbFunctions _, [CanBeNull] string value)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Reverse)));

public static string StringAggregate<TSource, TResult>(
[CanBeNull] this DbFunctions _,
[NotNull] IEnumerable<TSource> source,
[NotNull] Expression<Func<TSource, TResult>> selector)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Reverse)));
}
}
257 changes: 257 additions & 0 deletions src/EFCore.PG/Extensions/NpgsqlQueryableExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Query;
using Npgsql.EntityFrameworkCore.PostgreSQL.Utilities;

// ReSharper disable once CheckNamespace
namespace Microsoft.EntityFrameworkCore
{
public static class NpgsqlQueryableExtensions
{
#region StringAggregate

internal static readonly MethodInfo StringAggregateWithoutSelectorMethod
= typeof(NpgsqlQueryableExtensions).GetTypeInfo().GetDeclaredMethods(nameof(StringAggregate))
.Single(m => m.GetParameters().Length == 2 && m.GetParameters()[1].ParameterType == typeof(string));

internal static readonly MethodInfo StringAggregateWithSelectorMethod
= typeof(NpgsqlQueryableExtensions).GetTypeInfo().GetDeclaredMethods(nameof(StringAggregate))
.Single(m => m.GetParameters().Length == 3);

/// <summary>
/// Concatenates the non-null input values into a string.
/// Each value after the first is preceded by the corresponding delimiter (if it's not null).
/// </summary>
/// <param name="source">An <see cref="IQueryable{T}" /> that contains the elements to concatenate.</param>
/// <param name="delimiter">The delimiter for the concatenation. Defaults to empty string.</param>
/// <typeparam name="TSource">An <see cref="IQueryable{T}" /> that contains the elements to concatenate.</typeparam>
/// <returns>A task that represents the asynchronous operation. The task result contains sequence concatenation.</returns>
/// <remarks>
/// Calls PostgreSQL <c>string_agg</c>, see https://www.postgresql.org/docs/current/functions-aggregate.html.
/// </remarks>
public static string StringAggregate<TSource>([NotNull] this IQueryable<TSource> source, [CanBeNull] string delimiter = null)
{
Check.NotNull(source, nameof(source));

return source.Provider.Execute<string>(
Expression.Call(
instance: null,
StringAggregateWithoutSelectorMethod.MakeGenericMethod(typeof(TSource)),
source.Expression,
Expression.Constant(delimiter ?? string.Empty)));
}

/// <summary>
/// Concatenates the non-null input values into a string.
/// Each value after the first is preceded by the corresponding delimiter (if it's not null).
/// </summary>
/// <param name="source">An <see cref="IQueryable{T}" /> that contains the elements to concatenate.</param>
/// <param name="delimiter">The delimiter for the concatenation. Defaults to empty string.</param>
/// <param name="selector"> A projection function to apply to each element. </param>
/// <typeparam name="TSource">An <see cref="IQueryable{T}" /> that contains the elements to concatenate.</typeparam>
/// <typeparam name="TResult">
/// The type of the value returned by the function represented by <paramref name="selector" />.
/// </typeparam>
/// <returns>A task that represents the asynchronous operation. The task result contains sequence concatenation.</returns>
/// <remarks>
/// Calls PostgreSQL <c>string_agg</c>, see https://www.postgresql.org/docs/current/functions-aggregate.html.
/// </remarks>
public static string StringAggregate<TSource, TResult>(
[NotNull] this IQueryable<TSource> source,
[CanBeNull] string delimiter,
[NotNull] Expression<Func<TSource, TResult>> selector)
{
Check.NotNull(source, nameof(source));
Check.NotNull(selector, nameof(selector));

return source.Provider.Execute<string>(
Expression.Call(
instance: null,
StringAggregateWithSelectorMethod.MakeGenericMethod(typeof(TSource), typeof(TResult)),
source.Expression,
Expression.Constant(delimiter ?? string.Empty),
Expression.Quote(selector)));
}

/// <summary>
/// Concatenates the non-null input values into a string.
/// </summary>
/// <param name="source">An <see cref="IQueryable{T}" /> that contains the elements to concatenate.</param>
/// <param name="selector"> A projection function to apply to each element. </param>
/// <typeparam name="TSource">An <see cref="IQueryable{T}" /> that contains the elements to concatenate.</typeparam>
/// <typeparam name="TResult">
/// The type of the value returned by the function represented by <paramref name="selector" />.
/// </typeparam>
/// <returns>A task that represents the asynchronous operation. The task result contains sequence concatenation.</returns>
/// <remarks>
/// Calls PostgreSQL <c>string_agg</c>, see https://www.postgresql.org/docs/current/functions-aggregate.html.
/// </remarks>
public static string StringAggregate<TSource, TResult>(
[NotNull] this IQueryable<TSource> source,
[NotNull] Expression<Func<TSource, TResult>> selector)
{
Check.NotNull(source, nameof(source));
Check.NotNull(selector, nameof(selector));

return source.Provider.Execute<string>(
Expression.Call(
instance: null,
StringAggregateWithSelectorMethod.MakeGenericMethod(typeof(TSource), typeof(TResult)),
source.Expression,
Expression.Constant(string.Empty),
Expression.Quote(selector)));
}

/// <summary>
/// Concatenates the non-null input values into a string.
/// Each value after the first is preceded by the corresponding delimiter (if it's not null).
/// </summary>
/// <param name="source">An <see cref="IQueryable{T}" /> that contains the elements to concatenate.</param>
/// <param name="delimiter">The delimiter for the concatenation. Defaults to empty string.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken" /> to observe while waiting for the task to complete.</param>
/// <typeparam name="TSource">An <see cref="IQueryable{T}" /> that contains the elements to concatenate.</typeparam>
/// <returns>A task that represents the asynchronous operation. The task result contains sequence concatenation.</returns>
/// <remarks>
/// Calls PostgreSQL <c>string_agg</c>, see https://www.postgresql.org/docs/current/functions-aggregate.html.
/// </remarks>
public static Task<string> StringAggregateAsync<TSource>(
[NotNull] this IQueryable<TSource> source,
[CanBeNull] string delimiter = null,
CancellationToken cancellationToken = default)
{
Check.NotNull(source, nameof(source));

return ExecuteAsync<TSource, Task<string>>(
StringAggregateWithoutSelectorMethod,
source,
Expression.Constant(delimiter ?? string.Empty),
cancellationToken);
}

/// <summary>
/// Concatenates the non-null input values into a string.
/// Each value after the first is preceded by the corresponding delimiter (if it's not null).
/// </summary>
/// <param name="source">An <see cref="IQueryable{T}" /> that contains the elements to concatenate.</param>
/// <param name="delimiter">The delimiter for the concatenation. Defaults to empty string.</param>
/// <param name="selector"> A projection function to apply to each element. </param>
/// <param name="cancellationToken">A <see cref="CancellationToken" /> to observe while waiting for the task to complete.</param>
/// <typeparam name="TSource">An <see cref="IQueryable{T}" /> that contains the elements to concatenate.</typeparam>
/// <typeparam name="TResult">
/// The type of the value returned by the function represented by <paramref name="selector" />.
/// </typeparam>
/// <returns>A task that represents the asynchronous operation. The task result contains sequence concatenation.</returns>
/// <remarks>
/// Calls PostgreSQL <c>string_agg</c>, see https://www.postgresql.org/docs/current/functions-aggregate.html.
/// </remarks>
public static Task<string> StringAggregateAsync<TSource, TResult>(
[NotNull] this IQueryable<TSource> source,
[CanBeNull] string delimiter,
[NotNull] Expression<Func<TSource, TResult>> selector,
CancellationToken cancellationToken = default)
{
Check.NotNull(source, nameof(source));
Check.NotNull(selector, nameof(selector));

return ExecuteAsync<TSource, Task<string>>(
StringAggregateWithSelectorMethod,
source,
Expression.Constant(delimiter ?? string.Empty),
Expression.Quote(selector),
cancellationToken);
}

/// <summary>
/// Concatenates the non-null input values into a string.
/// </summary>
/// <param name="source">An <see cref="IQueryable{T}" /> that contains the elements to concatenate.</param>
/// <param name="selector"> A projection function to apply to each element. </param>
/// <param name="cancellationToken">A <see cref="CancellationToken" /> to observe while waiting for the task to complete.</param>
/// <typeparam name="TSource">An <see cref="IQueryable{T}" /> that contains the elements to concatenate.</typeparam>
/// <typeparam name="TResult">
/// The type of the value returned by the function represented by <paramref name="selector" />.
/// </typeparam>
/// <returns>A task that represents the asynchronous operation. The task result contains sequence concatenation.</returns>
/// <remarks>
/// Calls PostgreSQL <c>string_agg</c>, see https://www.postgresql.org/docs/current/functions-aggregate.html.
/// </remarks>
public static Task<string> StringAggregateAsync<TSource, TResult>(
[NotNull] this IQueryable<TSource> source,
[NotNull] Expression<Func<TSource, TResult>> selector,
CancellationToken cancellationToken = default)
{
Check.NotNull(source, nameof(source));
Check.NotNull(selector, nameof(selector));

return ExecuteAsync<TSource, Task<string>>(
StringAggregateWithSelectorMethod,
source,
Expression.Constant(string.Empty),
Expression.Quote(selector),
cancellationToken);
}

#endregion StringAggregate

#region Impl.

// Copied from EntityFrameworkQueryableExtensions

static TResult ExecuteAsync<TSource, TResult>(
MethodInfo operatorMethodInfo,
IQueryable<TSource> source,
Expression arg1,
CancellationToken cancellationToken = default)
{
if (source.Provider is IAsyncQueryProvider provider)
{
if (operatorMethodInfo.IsGenericMethod)
{
operatorMethodInfo
= operatorMethodInfo.GetGenericArguments().Length == 2
? operatorMethodInfo.MakeGenericMethod(typeof(TSource), typeof(TResult).GetGenericArguments().Single())
: operatorMethodInfo.MakeGenericMethod(typeof(TSource));
}

return provider.ExecuteAsync<TResult>(
Expression.Call(instance: null, operatorMethodInfo, source.Expression, arg1),
cancellationToken);
}

throw new InvalidOperationException(CoreStrings.IQueryableProviderNotAsync);
}

static TResult ExecuteAsync<TSource, TResult>(
MethodInfo operatorMethodInfo,
IQueryable<TSource> source,
Expression arg1,
Expression arg2,
CancellationToken cancellationToken = default)
{
if (source.Provider is IAsyncQueryProvider provider)
{
if (operatorMethodInfo.IsGenericMethod)
{
operatorMethodInfo
= operatorMethodInfo.GetGenericArguments().Length == 2
? operatorMethodInfo.MakeGenericMethod(typeof(TSource), typeof(TResult).GetGenericArguments().Single())
: operatorMethodInfo.MakeGenericMethod(typeof(TSource));
}

return provider.ExecuteAsync<TResult>(
Expression.Call(instance: null, operatorMethodInfo, source.Expression, arg1, arg2),
cancellationToken);
}

throw new InvalidOperationException(CoreStrings.IQueryableProviderNotAsync);
}

#endregion
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ public static IServiceCollection AddEntityFrameworkNpgsql([NotNull] this IServic
.TryAdd<IMemberTranslatorProvider, NpgsqlMemberTranslatorProvider>()
.TryAdd<IEvaluatableExpressionFilter, NpgsqlEvaluatableExpressionFilter>()
.TryAdd<IQuerySqlGeneratorFactory, NpgsqlQuerySqlGeneratorFactory>()
.TryAdd<IQueryableMethodTranslatingExpressionVisitorFactory, NpgsqlQueryableMethodTranslatingExpressionVisitorFactory>()
.TryAdd<IRelationalSqlTranslatingExpressionVisitorFactory, NpgsqlSqlTranslatingExpressionVisitorFactory>()
.TryAdd<IRelationalParameterBasedSqlProcessorFactory, NpgsqlParameterBasedSqlProcessorFactory>()
.TryAdd<ISqlExpressionFactory, NpgsqlSqlExpressionFactory>()
Expand Down
Loading