Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Translate spatial aggregates #2398

Merged
merged 1 commit into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public static IServiceCollection AddEntityFrameworkNpgsqlNetTopologySuite(
.TryAdd<ISingletonOptions, INpgsqlNetTopologySuiteOptions>(p => p.GetRequiredService<INpgsqlNetTopologySuiteOptions>())
.TryAdd<IRelationalTypeMappingSourcePlugin, NpgsqlNetTopologySuiteTypeMappingSourcePlugin>()
.TryAdd<IMethodCallTranslatorPlugin, NpgsqlNetTopologySuiteMethodCallTranslatorPlugin>()
.TryAdd<IAggregateMethodCallTranslatorPlugin, NpgsqlNetTopologySuiteAggregateMethodCallTranslatorPlugin>()
.TryAdd<IMemberTranslatorPlugin, NpgsqlNetTopologySuiteMemberTranslatorPlugin>()
.TryAdd<IConventionSetPlugin, NpgsqlNetTopologySuiteConventionSetPlugin>()
.TryAddProviderSpecificServices(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
using NetTopologySuite.Algorithm;
using NetTopologySuite.Geometries.Utilities;
using NetTopologySuite.Operation.Union;

// ReSharper disable once CheckNamespace
namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal;

public class NpgsqlNetTopologySuiteAggregateMethodCallTranslatorPlugin : IAggregateMethodCallTranslatorPlugin
{
public NpgsqlNetTopologySuiteAggregateMethodCallTranslatorPlugin(
IRelationalTypeMappingSource typeMappingSource,
ISqlExpressionFactory sqlExpressionFactory)
{
if (sqlExpressionFactory is not NpgsqlSqlExpressionFactory npgsqlSqlExpressionFactory)
{
throw new ArgumentException($"Must be an {nameof(NpgsqlSqlExpressionFactory)}", nameof(sqlExpressionFactory));
}

Translators = new IAggregateMethodCallTranslator[]
{
new NpgsqlNetTopologySuiteAggregateMethodTranslator(npgsqlSqlExpressionFactory, typeMappingSource)
};
}

public virtual IEnumerable<IAggregateMethodCallTranslator> Translators { get; }
}

public class NpgsqlNetTopologySuiteAggregateMethodTranslator : IAggregateMethodCallTranslator
{
private static readonly MethodInfo GeometryCombineMethod
= typeof(GeometryCombiner).GetRuntimeMethod(nameof(GeometryCombiner.Combine), new[] { typeof(IEnumerable<Geometry>) })!;

private static readonly MethodInfo ConvexHullMethod
= typeof(ConvexHull).GetRuntimeMethod(nameof(ConvexHull.Create), new[] { typeof(IEnumerable<Geometry>) })!;

private static readonly MethodInfo UnionMethod
= typeof(UnaryUnionOp).GetRuntimeMethod(nameof(UnaryUnionOp.Union), new[] { typeof(IEnumerable<Geometry>) })!;

private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory;
private readonly IRelationalTypeMappingSource _typeMappingSource;

public NpgsqlNetTopologySuiteAggregateMethodTranslator(
NpgsqlSqlExpressionFactory sqlExpressionFactory,
IRelationalTypeMappingSource typeMappingSource)
{
_sqlExpressionFactory = sqlExpressionFactory;
_typeMappingSource = typeMappingSource;
}

public virtual SqlExpression? Translate(
MethodInfo method, EnumerableExpression source, IReadOnlyList<SqlExpression> arguments,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
if (source.Selector is not SqlExpression sqlExpression
|| (method != GeometryCombineMethod && method != UnionMethod && method != ConvexHullMethod))
{
return null;
}

var resultTypeMapping = _typeMappingSource.FindMapping(typeof(Geometry), sqlExpression.TypeMapping?.StoreType ?? "geometry");

if (method == GeometryCombineMethod || method == UnionMethod)
{
return _sqlExpressionFactory.AggregateFunction(
method == GeometryCombineMethod ? "ST_Collect" : "ST_Union",
new[] { sqlExpression },
nullable: true,
argumentsPropagateNullability: new[] { false },
source,
typeof(Geometry),
resultTypeMapping);
}

// PostGIS has no built-in aggregate convex hull, but we can simply apply ST_Collect beforehand as recommended in the docs
// https://postgis.net/docs/ST_ConvexHull.html
return _sqlExpressionFactory.Function(
"ST_ConvexHull",
new[]
{
_sqlExpressionFactory.AggregateFunction(
"ST_Collect",
new[] { sqlExpression },
nullable: true,
argumentsPropagateNullability: new[] { false },
source,
typeof(Geometry),
resultTypeMapping)
},
nullable: true,
argumentsPropagateNullability: new[] { true },
typeof(Geometry),
resultTypeMapping);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ public NpgsqlNetTopologySuiteMethodCallTranslatorPlugin(
IRelationalTypeMappingSource typeMappingSource,
ISqlExpressionFactory sqlExpressionFactory)
{
if (!(sqlExpressionFactory is NpgsqlSqlExpressionFactory npgsqlSqlExpressionFactory))
if (sqlExpressionFactory is not NpgsqlSqlExpressionFactory npgsqlSqlExpressionFactory)
{
throw new ArgumentException($"Must be an {nameof(NpgsqlSqlExpressionFactory)}", nameof(sqlExpressionFactory));
}

Translators = new IMethodCallTranslator[] { new NpgsqlGeometryMethodTranslator(npgsqlSqlExpressionFactory, typeMappingSource), };
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,10 @@ public override async Task ToText(bool async)
#region Not supported on geography

public override Task Boundary(bool async) => Task.CompletedTask;
public override Task Combine_aggregate(bool async) => Task.CompletedTask;
public override Task Contains(bool async) => Task.CompletedTask;
public override Task ConvexHull(bool async) => Task.CompletedTask;
public override Task ConvexHull_aggregate(bool async) => Task.CompletedTask;
public override Task Crosses(bool async) => Task.CompletedTask;
public override Task Difference(bool async) => Task.CompletedTask;
public override Task Dimension(bool async) => Task.CompletedTask;
Expand Down Expand Up @@ -300,6 +302,7 @@ public override async Task ToText(bool async)
public override Task SymmetricDifference(bool async) => Task.CompletedTask;
public override Task Touches(bool async) => Task.CompletedTask;
public override Task Union(bool async) => Task.CompletedTask;
public override Task Union_aggregate(bool async) => Task.CompletedTask;
public override Task Union_void(bool async) => Task.CompletedTask;
public override Task Within(bool async) => Task.CompletedTask;
public override Task X(bool async) => Task.CompletedTask;
Expand Down Expand Up @@ -339,4 +342,4 @@ public override DbContextOptionsBuilder AddOptions(DbContextOptionsBuilder build
return optionsBuilder;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,17 @@ public override async Task Centroid(bool async)
FROM ""PolygonEntity"" AS p");
}

public override async Task Combine_aggregate(bool async)
{
await base.Combine_aggregate(async);

AssertSql(
@"SELECT p.""Group"" AS ""Id"", ST_Collect(p.""Point"") AS ""Combined""
FROM ""PointEntity"" AS p
WHERE (p.""Point"" IS NOT NULL)
GROUP BY p.""Group""");
}

public override async Task Contains(bool async)
{
await base.Contains(async);
Expand All @@ -100,6 +111,17 @@ public override async Task ConvexHull(bool async)
FROM ""PolygonEntity"" AS p");
}

public override async Task ConvexHull_aggregate(bool async)
{
await base.ConvexHull_aggregate(async);

AssertSql(
@"SELECT p.""Group"" AS ""Id"", ST_ConvexHull(ST_Collect(p.""Point"")) AS ""ConvexHull""
FROM ""PointEntity"" AS p
WHERE (p.""Point"" IS NOT NULL)
GROUP BY p.""Group""");
}

public override async Task IGeometryCollection_Count(bool async)
{
await base.IGeometryCollection_Count(async);
Expand Down Expand Up @@ -526,6 +548,17 @@ public override async Task Union(bool async)
FROM ""PolygonEntity"" AS p");
}

public override async Task Union_aggregate(bool async)
{
await base.Union_aggregate(async);

AssertSql(
@"SELECT p.""Group"" AS ""Id"", ST_Union(p.""Point"") AS ""Union""
FROM ""PointEntity"" AS p
WHERE (p.""Point"" IS NOT NULL)
GROUP BY p.""Group""");
}

public override async Task Union_void(bool async)
{
await base.Union_void(async);
Expand Down Expand Up @@ -585,4 +618,4 @@ public override DbContextOptionsBuilder AddOptions(DbContextOptionsBuilder build
return optionsBuilder;
}
}
}
}