[SEDONA-630] Improve ST_Union_Aggr performance (#1526)
* [SEDONA-630] Improve ST_Union_Aggr performance

Switch to the JTS `OverlayNGRobust.union` function to perform the geometry union and add a
geometry buffering (cache) capability; a sketch of the collection-level union follows the change list below.

* fix python test

* add unit test to measure the ST_Union_aggr time

* address review comments by refactoring unit tests

* rename test table
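
A minimal standalone sketch of the collection-level union this commit switches to (plain JTS, no Spark; the `OverlayNGUnionSketch` object and the two squares are illustrative, not part of the commit):

```scala
import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
import org.locationtech.jts.operation.overlayng.OverlayNGRobust

import scala.collection.JavaConverters._

object OverlayNGUnionSketch extends App {
  val gf = new GeometryFactory()

  // A unit square with its lower-left corner at (x, 0).
  def square(x: Double): Geometry =
    gf.createPolygon(Array(
      new Coordinate(x, 0), new Coordinate(x + 1, 0),
      new Coordinate(x + 1, 1), new Coordinate(x, 1),
      new Coordinate(x, 0)))

  // OverlayNGRobust.union takes a java.util.Collection of geometries and
  // unions them all at once with robust noding.
  val union = OverlayNGRobust.union(Seq(square(0), square(1)).asJava)
  println(union) // POLYGON ((0 0, 0 1, 1 1, 2 1, 2 0, 1 0, 0 0))
}
```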
zhangfengcdt committed Jul 22, 2024
1 parent 8118894 commit bab1f77
Showing 6 changed files with 105 additions and 16 deletions.
2 changes: 1 addition & 1 deletion python/tests/sql/test_dataframe_api.py
@@ -262,7 +262,7 @@
# aggregates
(sta.ST_Envelope_Aggr, ("geom",), "exploded_points", "", "POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))"),
(sta.ST_Intersection_Aggr, ("geom",), "exploded_polys", "", "LINESTRING (1 0, 1 1)"),
-    (sta.ST_Union_Aggr, ("geom",), "exploded_polys", "", "POLYGON ((1 0, 0 0, 0 1, 1 1, 2 1, 2 0, 1 0))"),
+    (sta.ST_Union_Aggr, ("geom",), "exploded_polys", "", "POLYGON ((0 0, 0 1, 1 1, 2 1, 2 0, 1 0, 0 0))"),
]

wrong_type_configurations = [
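Only the expected WKT string changes here: OverlayNGRobust returns the same polygon but may start the shell at a different vertex. A minimal JTS check of that equivalence (the `SameGeometryCheck` object is illustrative):

```scala
import org.locationtech.jts.io.WKTReader

object SameGeometryCheck extends App {
  val reader = new WKTReader()
  val before = reader.read("POLYGON ((1 0, 0 0, 0 1, 1 1, 2 1, 2 0, 1 0))")
  val after = reader.read("POLYGON ((0 0, 0 1, 1 1, 2 1, 2 0, 1 0, 0 0))")
  // norm() converts each geometry to canonical form; the two are identical.
  assert(before.norm().equalsExact(after.norm()))
}
```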
Catalog.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.sedona_sql.expressions.raster._
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.operation.buffer.BufferParameters

+import scala.collection.mutable.ListBuffer
import scala.reflect.ClassTag

object Catalog {
@@ -327,8 +328,13 @@
function[RS_FromNetCDF](),
function[RS_NetCDFInfo]())

+  // Aggregate functions with Geometry as buffer
   val aggregateExpressions: Seq[Aggregator[Geometry, Geometry, Geometry]] =
-    Seq(new ST_Union_Aggr, new ST_Envelope_Aggr, new ST_Intersection_Aggr)
+    Seq(new ST_Envelope_Aggr, new ST_Intersection_Aggr)
+
+  // Aggregate functions with List as buffer
+  val aggregateExpressions2: Seq[Aggregator[Geometry, ListBuffer[Geometry], Geometry]] =
+    Seq(new ST_Union_Aggr())

private def function[T <: Expression: ClassTag](defaultArgs: Any*): FunctionDescription = {
val classTag = implicitly[ClassTag[T]]
UdfRegistrator.scala
@@ -36,7 +36,9 @@ object UdfRegistrator {
}
Catalog.aggregateExpressions.foreach(f =>
sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f))) // SPARK3 anchor
//Catalog.aggregateExpressions_UDAF.foreach(f => sparkSession.udf.register(f.getClass.getSimpleName, f)) // SPARK2 anchor

+    Catalog.aggregateExpressions2.foreach(f =>
+      sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f))) // SPARK3 anchor
}

def dropAll(sparkSession: SparkSession): Unit = {
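For context, this is the standard Spark 3 pattern for exposing a typed `Aggregator` to SQL. A minimal self-contained sketch with a hypothetical stand-in aggregator (`SumAgg` and `RegisterSketch` are illustrative names, not Sedona code):

```scala
import org.apache.spark.sql.{SparkSession, functions}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical stand-in: sums Long values, but registers exactly like the
// geometry aggregators above.
class SumAgg extends Aggregator[Long, Long, Long] {
  def zero: Long = 0L
  def reduce(b: Long, a: Long): Long = b + a
  def merge(b1: Long, b2: Long): Long = b1 + b2
  def finish(r: Long): Long = r
  def bufferEncoder: ExpressionEncoder[Long] = ExpressionEncoder[Long]()
  def outputEncoder: ExpressionEncoder[Long] = ExpressionEncoder[Long]()
}

object RegisterSketch {
  def register(spark: SparkSession): Unit = {
    val agg = new SumAgg
    // functions.udaf wraps the typed Aggregator as an untyped UDAF, so it can
    // be called from SQL by the registered name.
    spark.udf.register(agg.getClass.getSimpleName, functions.udaf(agg))
  }
}
```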
AggregateFunctions.scala
@@ -21,6 +21,10 @@ package org.apache.spark.sql.sedona_sql.expressions
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
+import org.locationtech.jts.operation.overlayng.OverlayNGRobust
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ListBuffer

/**
* traits for creating Aggregate Function
@@ -50,22 +54,44 @@ trait TraitSTAggregateExec {
def finish(out: Geometry): Geometry = out
}

/**
* Return the polygon union of all Polygon in the given column
*/
-class ST_Union_Aggr extends Aggregator[Geometry, Geometry, Geometry] with TraitSTAggregateExec {
-
-  def reduce(buffer: Geometry, input: Geometry): Geometry = {
-    if (buffer.equalsExact(initialGeometry)) input
-    else buffer.union(input)
-  }
-
-  def merge(buffer1: Geometry, buffer2: Geometry): Geometry = {
-    if (buffer1.equals(initialGeometry)) buffer2
-    else if (buffer2.equals(initialGeometry)) buffer1
-    else buffer1.union(buffer2)
-  }
+class ST_Union_Aggr(bufferSize: Int = 1000)
+    extends Aggregator[Geometry, ListBuffer[Geometry], Geometry]
+    with Serializable {
+
+  override def reduce(buffer: ListBuffer[Geometry], input: Geometry): ListBuffer[Geometry] = {
+    buffer += input
+    if (buffer.size >= bufferSize) {
+      // Perform the union when buffer size is reached
+      val unionGeometry = OverlayNGRobust.union(buffer.asJava)
+      buffer.clear()
+      buffer += unionGeometry
+    }
+    buffer
+  }
+
+  override def merge(
+      buffer1: ListBuffer[Geometry],
+      buffer2: ListBuffer[Geometry]): ListBuffer[Geometry] = {
+    buffer1 ++= buffer2
+    if (buffer1.size >= bufferSize) {
+      // Perform the union when buffer size is reached
+      val unionGeometry = OverlayNGRobust.union(buffer1.asJava)
+      buffer1.clear()
+      buffer1 += unionGeometry
+    }
+    buffer1
+  }
+
+  override def finish(reduction: ListBuffer[Geometry]): Geometry = {
+    OverlayNGRobust.union(reduction.asJava)
+  }
+
+  def bufferEncoder: ExpressionEncoder[ListBuffer[Geometry]] =
+    ExpressionEncoder[ListBuffer[Geometry]]()
+
+  def outputEncoder: ExpressionEncoder[Geometry] = ExpressionEncoder[Geometry]()
+
+  override def zero: ListBuffer[Geometry] = ListBuffer.empty
 }

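The performance win comes from replacing incremental pairwise unions with one collection-level union per buffer flush. A minimal standalone comparison of the two strategies, with hypothetical test data (a row of overlapping unit squares; `UnionStrategiesSketch` is illustrative):

```scala
import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
import org.locationtech.jts.operation.overlayng.OverlayNGRobust

import scala.collection.JavaConverters._

object UnionStrategiesSketch extends App {
  val gf = new GeometryFactory()
  def square(x: Double): Geometry = gf.createPolygon(Array(
    new Coordinate(x, 0), new Coordinate(x + 1, 0),
    new Coordinate(x + 1, 1), new Coordinate(x, 1),
    new Coordinate(x, 0)))
  val squares: Seq[Geometry] = (0 until 500).map(i => square(i * 0.5))

  // Old strategy: fold with pairwise union, re-processing an ever-growing
  // intermediate geometry on every step.
  val pairwise = squares.reduce(_.union(_))

  // New strategy: hand the whole collection to JTS once and let it choose
  // an efficient union order.
  val collected = OverlayNGRobust.union(squares.asJava)

  assert(math.abs(pairwise.getArea - collected.getArea) < 1e-6)
}
```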
aggregateFunctionTestScala.scala
@@ -18,7 +18,13 @@
*/
package org.apache.sedona.sql

-import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.expressions.javalang.typed
+import org.apache.spark.sql.sedona_sql.expressions.ST_Union_Aggr
+import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory, Polygon}
+import org.locationtech.jts.io.WKTReader
+
+import scala.util.Random

class aggregateFunctionTestScala extends TestBaseScala {

@@ -62,6 +68,35 @@ class aggregateFunctionTestScala extends TestBaseScala {
assert(union.take(1)(0).get(0).asInstanceOf[Geometry].getArea == 10100)
}

it("Measured ST_Union_aggr wall time") {
// number of random polygons to generate
val numPolygons = 1000
val df = createPolygonDataFrame(numPolygons)

df.createOrReplaceTempView("geometry_table_for_measuring_union_aggr")

// cache the table to eliminate the time of table scan
df.cache()
sparkSession
.sql("select count(*) from geometry_table_for_measuring_union_aggr")
.take(1)(0)
.get(0)

// measure time for optimized ST_Union_Aggr
val startTimeOptimized = System.currentTimeMillis()
val unionOptimized =
sparkSession.sql(
"SELECT ST_Union_Aggr(geom) AS union_geom FROM geometry_table_for_measuring_union_aggr")
assert(unionOptimized.take(1)(0).get(0).asInstanceOf[Geometry].getArea > 0)
val endTimeOptimized = System.currentTimeMillis()
val durationOptimized = endTimeOptimized - startTimeOptimized

assert(durationOptimized > 0, "Duration of optimized ST_Union_Aggr should be positive")

// clear cache
df.unpersist()
}

it("Passed ST_Intersection_aggr") {

val twoPolygonsAsWktDf =
@@ -97,4 +132,24 @@
assertResult(0.0)(intersectionDF.take(1)(0).get(0).asInstanceOf[Geometry].getArea)
}
}

+  def generateRandomPolygon(index: Int): String = {
+    val random = new Random()
+    val x = random.nextDouble() * index
+    val y = random.nextDouble() * index
+    s"POLYGON (($x $y, ${x + 1} $y, ${x + 1} ${y + 1}, $x ${y + 1}, $x $y))"
+  }
+
+  def createPolygonDataFrame(numPolygons: Int): DataFrame = {
+    val polygons = (1 to numPolygons).map(generateRandomPolygon).toArray
+    val polygonArray = polygons.map(polygon => s"ST_GeomFromWKT('$polygon')")
+    val polygonArrayStr = polygonArray.mkString(", ")
+
+    val sqlQuery =
+      s"""
+         |SELECT explode(array($polygonArrayStr)) AS geom
+      """.stripMargin
+
+    sparkSession.sql(sqlQuery)
+  }
}
dataFrameAPITestScala.scala
@@ -1585,7 +1585,7 @@ class dataFrameAPITestScala extends TestBaseScala {
"SELECT explode(array(ST_GeomFromWKT('POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))'), ST_GeomFromWKT('POLYGON ((1 0, 2 0, 2 1, 1 1, 1 0))'))) AS geom")
val df = baseDf.select(ST_Union_Aggr("geom"))
val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry].toText()
val expectedResult = "POLYGON ((1 0, 0 0, 0 1, 1 1, 2 1, 2 0, 1 0))"
val expectedResult = "POLYGON ((0 0, 0 1, 1 1, 2 1, 2 0, 1 0, 0 0))"
assert(actualResult == expectedResult)
}
