Skip to content

Commit

Permalink
Translate spatial aggregates
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed May 27, 2022
1 parent bbac00d commit f77330d
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public static IServiceCollection AddEntityFrameworkNpgsqlNetTopologySuite(
.TryAdd<ISingletonOptions, INpgsqlNetTopologySuiteOptions>(p => p.GetRequiredService<INpgsqlNetTopologySuiteOptions>())
.TryAdd<IRelationalTypeMappingSourcePlugin, NpgsqlNetTopologySuiteTypeMappingSourcePlugin>()
.TryAdd<IMethodCallTranslatorPlugin, NpgsqlNetTopologySuiteMethodCallTranslatorPlugin>()
.TryAdd<IAggregateMethodCallTranslatorPlugin, NpgsqlNetTopologySuiteAggregateMethodCallTranslatorPlugin>()
.TryAdd<IMemberTranslatorPlugin, NpgsqlNetTopologySuiteMemberTranslatorPlugin>()
.TryAdd<IConventionSetPlugin, NpgsqlNetTopologySuiteConventionSetPlugin>()
.TryAddProviderSpecificServices(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
using NetTopologySuite.Algorithm;
using NetTopologySuite.Geometries.Utilities;
using NetTopologySuite.Operation.Union;

// ReSharper disable once CheckNamespace
namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal;

public class NpgsqlNetTopologySuiteAggregateMethodCallTranslatorPlugin : IAggregateMethodCallTranslatorPlugin
{
public NpgsqlNetTopologySuiteAggregateMethodCallTranslatorPlugin(
IRelationalTypeMappingSource typeMappingSource,
ISqlExpressionFactory sqlExpressionFactory)
{
if (sqlExpressionFactory is not NpgsqlSqlExpressionFactory npgsqlSqlExpressionFactory)
{
throw new ArgumentException($"Must be an {nameof(NpgsqlSqlExpressionFactory)}", nameof(sqlExpressionFactory));
}

Translators = new IAggregateMethodCallTranslator[]
{
new NpgsqlNetTopologySuiteAggregateMethodTranslator(npgsqlSqlExpressionFactory, typeMappingSource)
};
}

public virtual IEnumerable<IAggregateMethodCallTranslator> Translators { get; }
}

public class NpgsqlNetTopologySuiteAggregateMethodTranslator : IAggregateMethodCallTranslator
{
private static readonly MethodInfo GeometryCombineMethod
= typeof(GeometryCombiner).GetRuntimeMethod(nameof(GeometryCombiner.Combine), new[] { typeof(IEnumerable<Geometry>) })!;

private static readonly MethodInfo ConvexHullMethod
= typeof(ConvexHull).GetRuntimeMethod(nameof(ConvexHull.Create), new[] { typeof(IEnumerable<Geometry>) })!;

private static readonly MethodInfo UnionMethod
= typeof(UnaryUnionOp).GetRuntimeMethod(nameof(UnaryUnionOp.Union), new[] { typeof(IEnumerable<Geometry>) })!;

private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory;
private readonly IRelationalTypeMappingSource _typeMappingSource;

public NpgsqlNetTopologySuiteAggregateMethodTranslator(
NpgsqlSqlExpressionFactory sqlExpressionFactory,
IRelationalTypeMappingSource typeMappingSource)
{
_sqlExpressionFactory = sqlExpressionFactory;
_typeMappingSource = typeMappingSource;
}

public virtual SqlExpression? Translate(
MethodInfo method, EnumerableExpression source, IReadOnlyList<SqlExpression> arguments,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
if (source.Selector is not SqlExpression sqlExpression
|| (method != GeometryCombineMethod && method != UnionMethod && method != ConvexHullMethod))
{
return null;
}

var resultTypeMapping = _typeMappingSource.FindMapping(typeof(Geometry), sqlExpression.TypeMapping?.StoreType ?? "geometry");

if (method == GeometryCombineMethod || method == UnionMethod)
{
return _sqlExpressionFactory.AggregateFunction(
method == GeometryCombineMethod ? "ST_Collect" : "ST_Union",
new[] { sqlExpression },
source,
nullable: true,
argumentsPropagateNullability: new[] { false },
typeof(Geometry),
resultTypeMapping);
}

// PostGIS has no built-in aggregate convex hull, but we can simply apply ST_Collect beforehand as recommended in the docs
// https://postgis.net/docs/ST_ConvexHull.html
return _sqlExpressionFactory.Function(
"ST_ConvexHull",
new[]
{
_sqlExpressionFactory.Function(
"ST_Collect",
new[] { sqlExpression },
nullable: true,
argumentsPropagateNullability: new[] { false },
typeof(Geometry),
resultTypeMapping)
},
nullable: true,
argumentsPropagateNullability: new[] { true },
typeof(Geometry),
resultTypeMapping);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ public NpgsqlNetTopologySuiteMethodCallTranslatorPlugin(
IRelationalTypeMappingSource typeMappingSource,
ISqlExpressionFactory sqlExpressionFactory)
{
if (!(sqlExpressionFactory is NpgsqlSqlExpressionFactory npgsqlSqlExpressionFactory))
if (sqlExpressionFactory is not NpgsqlSqlExpressionFactory npgsqlSqlExpressionFactory)
{
throw new ArgumentException($"Must be an {nameof(NpgsqlSqlExpressionFactory)}", nameof(sqlExpressionFactory));
}

Translators = new IMethodCallTranslator[] { new NpgsqlGeometryMethodTranslator(npgsqlSqlExpressionFactory, typeMappingSource), };
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,8 @@ public NpgsqlMiscAggregateMethodTranslator(NpgsqlSqlExpressionFactory sqlExpress
return _sqlExpressionFactory.AggregateFunction(
"string_agg",
new[] { sqlExpression, arguments[0] },
nullable: true,
argumentsPropagateNullability: new[] { false, true },
source,
typeof(string));
nullable: true, argumentsPropagateNullability: new[] { false, true }, returnType: typeof(string));
}

if (method.DeclaringType == typeof(NpgsqlAggregateDbFunctionsExtensions))
Expand All @@ -45,11 +43,9 @@ public NpgsqlMiscAggregateMethodTranslator(NpgsqlSqlExpressionFactory sqlExpress
return _sqlExpressionFactory.AggregateFunction(
"array_agg",
new[] { sqlExpression },
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
arrayClrType,
sqlExpression.TypeMapping is null
nullable: true,
argumentsPropagateNullability: FalseArrays[1], returnType: arrayClrType, typeMapping: sqlExpression.TypeMapping is null
? null
: new NpgsqlArrayArrayTypeMapping(arrayClrType, sqlExpression.TypeMapping));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,16 @@ public NpgsqlQueryableAggregateMethodTranslator(
_sqlExpressionFactory.AggregateFunction(
"AVG",
new[] { averageSqlExpression },
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
typeof(double)),
nullable: true, argumentsPropagateNullability: FalseArrays[1], returnType: typeof(double)),
averageSqlExpression.Type,
averageSqlExpression.TypeMapping)
: _sqlExpressionFactory.AggregateFunction(
"AVG",
new[] { averageSqlExpression },
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
averageSqlExpression.Type,
averageSqlExpression.TypeMapping);
nullable: true,
argumentsPropagateNullability: FalseArrays[1], returnType: averageSqlExpression.Type, typeMapping: averageSqlExpression.TypeMapping);

// PostgreSQL COUNT() always returns bigint, so we need to downcast to int
case nameof(Queryable.Count)
Expand All @@ -70,10 +66,8 @@ public NpgsqlQueryableAggregateMethodTranslator(
_sqlExpressionFactory.AggregateFunction(
"COUNT",
new[] { countSqlExpression },
nullable: false,
argumentsPropagateNullability: FalseArrays[1],
source,
typeof(long))),
nullable: false, argumentsPropagateNullability: FalseArrays[1], returnType: typeof(long))),
typeof(int), _typeMappingSource.FindMapping(typeof(int)));

case nameof(Queryable.LongCount)
Expand All @@ -84,10 +78,8 @@ public NpgsqlQueryableAggregateMethodTranslator(
_sqlExpressionFactory.AggregateFunction(
"COUNT",
new[] { longCountSqlExpression },
nullable: false,
argumentsPropagateNullability: FalseArrays[1],
source,
typeof(long)));
nullable: false, argumentsPropagateNullability: FalseArrays[1], returnType: typeof(long)));

case nameof(Queryable.Max)
when (methodInfo == QueryableMethods.MaxWithoutSelector
Expand All @@ -96,11 +88,9 @@ public NpgsqlQueryableAggregateMethodTranslator(
return _sqlExpressionFactory.AggregateFunction(
"MAX",
new[] { maxSqlExpression },
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
maxSqlExpression.Type,
maxSqlExpression.TypeMapping);
nullable: true,
argumentsPropagateNullability: FalseArrays[1], returnType: maxSqlExpression.Type, typeMapping: maxSqlExpression.TypeMapping);

case nameof(Queryable.Min)
when (methodInfo == QueryableMethods.MinWithoutSelector
Expand All @@ -109,11 +99,9 @@ public NpgsqlQueryableAggregateMethodTranslator(
return _sqlExpressionFactory.AggregateFunction(
"MIN",
new[] { minSqlExpression },
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
minSqlExpression.Type,
minSqlExpression.TypeMapping);
nullable: true,
argumentsPropagateNullability: FalseArrays[1], returnType: minSqlExpression.Type, typeMapping: minSqlExpression.TypeMapping);

// In PostgreSQL SUM() doesn't return the same type as its argument for smallint, int and bigint.
// Cast to get the same type.
Expand All @@ -131,10 +119,8 @@ public NpgsqlQueryableAggregateMethodTranslator(
_sqlExpressionFactory.AggregateFunction(
"SUM",
new[] { sumSqlExpression },
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
typeof(long)),
nullable: true, argumentsPropagateNullability: FalseArrays[1], returnType: typeof(long)),
sumInputType,
sumSqlExpression.TypeMapping);
}
Expand All @@ -145,22 +131,18 @@ public NpgsqlQueryableAggregateMethodTranslator(
_sqlExpressionFactory.AggregateFunction(
"SUM",
new[] { sumSqlExpression },
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
typeof(decimal)),
nullable: true, argumentsPropagateNullability: FalseArrays[1], returnType: typeof(decimal)),
sumInputType,
sumSqlExpression.TypeMapping);
}

return _sqlExpressionFactory.AggregateFunction(
"SUM",
new[] { sumSqlExpression },
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
source,
sumInputType,
sumSqlExpression.TypeMapping);
nullable: true,
argumentsPropagateNullability: FalseArrays[1], returnType: sumInputType, typeMapping: sumSqlExpression.TypeMapping);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,8 @@ public static PostgresFunctionExpression CreateWithNamedArguments(

return new PostgresFunctionExpression(
name, arguments, argumentNames, argumentSeparators: null,
nullable, argumentsPropagateNullability,
aggregateDistinct: false, aggregatePredicate: null, aggregateOrderings: Array.Empty<OrderingExpression>(),
type, typeMapping);
nullable: nullable, argumentsPropagateNullability: argumentsPropagateNullability, type: type, typeMapping: typeMapping);
}

public static PostgresFunctionExpression CreateWithArgumentSeparators(
Expand All @@ -77,22 +76,21 @@ public static PostgresFunctionExpression CreateWithArgumentSeparators(
Check.NotNull(argumentSeparators, nameof(argumentSeparators));

return new PostgresFunctionExpression(
name, arguments, argumentNames: null, argumentSeparators,
nullable, argumentsPropagateNullability,
name, arguments, argumentNames: null, argumentSeparators: argumentSeparators,
aggregateDistinct: false, aggregatePredicate: null, aggregateOrderings: Array.Empty<OrderingExpression>(),
type, typeMapping);
nullable: nullable, argumentsPropagateNullability: argumentsPropagateNullability, type: type, typeMapping: typeMapping);
}

public PostgresFunctionExpression(
string name,
IEnumerable<SqlExpression> arguments,
IEnumerable<string?>? argumentNames,
IEnumerable<string?>? argumentSeparators,
bool nullable,
IEnumerable<bool> argumentsPropagateNullability,
bool aggregateDistinct,
SqlExpression? aggregatePredicate,
IReadOnlyList<OrderingExpression> aggregateOrderings,
bool nullable,
IEnumerable<bool> argumentsPropagateNullability,
Type type,
RelationalTypeMapping? typeMapping)
: base(name, arguments, nullable, argumentsPropagateNullability, type, typeMapping)
Expand Down Expand Up @@ -175,11 +173,10 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
return changed
? new PostgresFunctionExpression(
Name, visitedArguments ?? Arguments, ArgumentNames, ArgumentSeparators,
IsNullable, ArgumentsPropagateNullability!,
IsAggregateDistinct,
visitedAggregatePredicate ?? AggregatePredicate,
visitedAggregateOrderings ?? AggregateOrderings,
Type, TypeMapping)
IsNullable, ArgumentsPropagateNullability!, Type, TypeMapping)
: this;
}

Expand All @@ -189,13 +186,11 @@ public override SqlFunctionExpression ApplyTypeMapping(RelationalTypeMapping? ty
Arguments,
ArgumentNames,
ArgumentSeparators,
IsNullable,
ArgumentsPropagateNullability,
IsAggregateDistinct,
AggregatePredicate,
AggregateOrderings,
Type,
typeMapping ?? TypeMapping);
IsNullable,
ArgumentsPropagateNullability, Type, typeMapping ?? TypeMapping);

public override SqlFunctionExpression Update(SqlExpression? instance, IReadOnlyList<SqlExpression>? arguments)
{
Expand All @@ -209,11 +204,10 @@ public override SqlFunctionExpression Update(SqlExpression? instance, IReadOnlyL
return !arguments.SequenceEqual(Arguments)
? new PostgresFunctionExpression(
Name, arguments, ArgumentNames, ArgumentSeparators,
IsNullable, ArgumentsPropagateNullability,
IsAggregateDistinct,
AggregatePredicate,
AggregateOrderings,
Type, TypeMapping)
IsNullable, ArgumentsPropagateNullability, Type, TypeMapping)
: this;
}

Expand All @@ -222,11 +216,10 @@ public virtual SqlFunctionExpression UpdateAggregateComponents(SqlExpression? pr
return predicate != AggregatePredicate || orderings != AggregateOrderings
? new PostgresFunctionExpression(
Name, Arguments, ArgumentNames, ArgumentSeparators,
IsNullable, ArgumentsPropagateNullability,
IsAggregateDistinct,
predicate,
orderings,
Type, TypeMapping)
IsNullable, ArgumentsPropagateNullability, Type, TypeMapping)
: this;
}

Expand Down
8 changes: 3 additions & 5 deletions src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,9 @@ public virtual PostgresBinaryExpression Overlaps(SqlExpression left, SqlExpressi
public virtual PostgresFunctionExpression AggregateFunction(
string name,
IEnumerable<SqlExpression> arguments,
EnumerableExpression aggregateEnumerableExpression,
bool nullable,
IEnumerable<bool> argumentsPropagateNullability,
EnumerableExpression aggregateEnumerableExpression,
Type returnType,
RelationalTypeMapping? typeMapping = null)
{
Expand All @@ -307,13 +307,11 @@ public virtual PostgresFunctionExpression AggregateFunction(
typeMappedArguments,
argumentNames: null,
argumentSeparators: null,
nullable,
argumentsPropagateNullability,
aggregateEnumerableExpression.IsDistinct,
aggregateEnumerableExpression.Predicate,
aggregateEnumerableExpression.Orderings,
returnType,
typeMapping);
nullable: nullable,
argumentsPropagateNullability: argumentsPropagateNullability, type: returnType, typeMapping: typeMapping);
}

#endregion Expression factory methods
Expand Down
Loading

0 comments on commit f77330d

Please sign in to comment.