Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SEDONA-630] Improve ST_Union_Aggr performance #1526

Merged
merged 5 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tests/sql/test_dataframe_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@
# aggregates
(sta.ST_Envelope_Aggr, ("geom",), "exploded_points", "", "POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))"),
(sta.ST_Intersection_Aggr, ("geom",), "exploded_polys", "", "LINESTRING (1 0, 1 1)"),
(sta.ST_Union_Aggr, ("geom",), "exploded_polys", "", "POLYGON ((1 0, 0 0, 0 1, 1 1, 2 1, 2 0, 1 0))"),
(sta.ST_Union_Aggr, ("geom",), "exploded_polys", "", "POLYGON ((0 0, 0 1, 1 1, 2 1, 2 0, 1 0, 0 0))"),
]

wrong_type_configurations = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.apache.spark.sql.sedona_sql.expressions.raster._
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.operation.buffer.BufferParameters

import scala.collection.mutable.ListBuffer
import scala.reflect.ClassTag

object Catalog {
Expand Down Expand Up @@ -327,8 +328,13 @@ object Catalog {
function[RS_FromNetCDF](),
function[RS_NetCDFInfo]())

// Aggregate functions with Geometry as buffer
val aggregateExpressions: Seq[Aggregator[Geometry, Geometry, Geometry]] =
Seq(new ST_Union_Aggr, new ST_Envelope_Aggr, new ST_Intersection_Aggr)
Seq(new ST_Envelope_Aggr, new ST_Intersection_Aggr)

// Aggregate functions with List as buffer
val aggregateExpressions2: Seq[Aggregator[Geometry, ListBuffer[Geometry], Geometry]] =
Seq(new ST_Union_Aggr())

private def function[T <: Expression: ClassTag](defaultArgs: Any*): FunctionDescription = {
val classTag = implicitly[ClassTag[T]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ object UdfRegistrator {
}
Catalog.aggregateExpressions.foreach(f =>
sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f))) // SPARK3 anchor
//Catalog.aggregateExpressions_UDAF.foreach(f => sparkSession.udf.register(f.getClass.getSimpleName, f)) // SPARK2 anchor

Catalog.aggregateExpressions2.foreach(f =>
sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f))) // SPARK3 anchor
}

def dropAll(sparkSession: SparkSession): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ package org.apache.spark.sql.sedona_sql.expressions
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
import org.locationtech.jts.operation.overlayng.OverlayNGRobust

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer

/**
* traits for creating Aggregate Function
Expand Down Expand Up @@ -50,22 +54,44 @@ trait TraitSTAggregateExec {
def finish(out: Geometry): Geometry = out
}

/**
* Return the polygon union of all Polygon in the given column
*/
class ST_Union_Aggr extends Aggregator[Geometry, Geometry, Geometry] with TraitSTAggregateExec {
class ST_Union_Aggr(bufferSize: Int = 1000)
extends Aggregator[Geometry, ListBuffer[Geometry], Geometry]
with Serializable {

override def reduce(buffer: ListBuffer[Geometry], input: Geometry): ListBuffer[Geometry] = {
buffer += input
if (buffer.size >= bufferSize) {
// Perform the union when buffer size is reached
val unionGeometry = OverlayNGRobust.union(buffer.asJava)
buffer.clear()
buffer += unionGeometry
}
buffer
}

def reduce(buffer: Geometry, input: Geometry): Geometry = {
if (buffer.equalsExact(initialGeometry)) input
else buffer.union(input)
override def merge(
buffer1: ListBuffer[Geometry],
buffer2: ListBuffer[Geometry]): ListBuffer[Geometry] = {
buffer1 ++= buffer2
if (buffer1.size >= bufferSize) {
// Perform the union when buffer size is reached
val unionGeometry = OverlayNGRobust.union(buffer1.asJava)
buffer1.clear()
buffer1 += unionGeometry
}
buffer1
}

def merge(buffer1: Geometry, buffer2: Geometry): Geometry = {
if (buffer1.equals(initialGeometry)) buffer2
else if (buffer2.equals(initialGeometry)) buffer1
else buffer1.union(buffer2)
override def finish(reduction: ListBuffer[Geometry]): Geometry = {
OverlayNGRobust.union(reduction.asJava)
}

def bufferEncoder: ExpressionEncoder[ListBuffer[Geometry]] =
ExpressionEncoder[ListBuffer[Geometry]]()

def outputEncoder: ExpressionEncoder[Geometry] = ExpressionEncoder[Geometry]()

override def zero: ListBuffer[Geometry] = ListBuffer.empty
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@
*/
package org.apache.sedona.sql

import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.javalang.typed
import org.apache.spark.sql.sedona_sql.expressions.ST_Union_Aggr
import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory, Polygon}
import org.locationtech.jts.io.WKTReader

import scala.util.Random

class aggregateFunctionTestScala extends TestBaseScala {

Expand Down Expand Up @@ -62,6 +68,35 @@ class aggregateFunctionTestScala extends TestBaseScala {
assert(union.take(1)(0).get(0).asInstanceOf[Geometry].getArea == 10100)
}

it("Measured ST_Union_aggr wall time") {
// number of random polygons to generate
val numPolygons = 1000
val df = createPolygonDataFrame(numPolygons)

df.createOrReplaceTempView("geometry_table_for_measuring_union_aggr")

// cache the table to eliminate the time of table scan
df.cache()
sparkSession
.sql("select count(*) from geometry_table_for_measuring_union_aggr")
.take(1)(0)
.get(0)

// measure time for optimized ST_Union_Aggr
val startTimeOptimized = System.currentTimeMillis()
val unionOptimized =
sparkSession.sql(
"SELECT ST_Union_Aggr(geom) AS union_geom FROM geometry_table_for_measuring_union_aggr")
assert(unionOptimized.take(1)(0).get(0).asInstanceOf[Geometry].getArea > 0)
val endTimeOptimized = System.currentTimeMillis()
val durationOptimized = endTimeOptimized - startTimeOptimized

assert(durationOptimized > 0, "Duration of optimized ST_Union_Aggr should be positive")

// clear cache
df.unpersist()
}

it("Passed ST_Intersection_aggr") {

val twoPolygonsAsWktDf =
Expand Down Expand Up @@ -97,4 +132,24 @@ class aggregateFunctionTestScala extends TestBaseScala {
assertResult(0.0)(intersectionDF.take(1)(0).get(0).asInstanceOf[Geometry].getArea)
}
}

def generateRandomPolygon(index: Int): String = {
val random = new Random()
val x = random.nextDouble() * index
val y = random.nextDouble() * index
s"POLYGON (($x $y, ${x + 1} $y, ${x + 1} ${y + 1}, $x ${y + 1}, $x $y))"
}

def createPolygonDataFrame(numPolygons: Int): DataFrame = {
val polygons = (1 to numPolygons).map(generateRandomPolygon).toArray
val polygonArray = polygons.map(polygon => s"ST_GeomFromWKT('$polygon')")
val polygonArrayStr = polygonArray.mkString(", ")

val sqlQuery =
s"""
|SELECT explode(array($polygonArrayStr)) AS geom
""".stripMargin

sparkSession.sql(sqlQuery)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1585,7 +1585,7 @@ class dataFrameAPITestScala extends TestBaseScala {
"SELECT explode(array(ST_GeomFromWKT('POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))'), ST_GeomFromWKT('POLYGON ((1 0, 2 0, 2 1, 1 1, 1 0))'))) AS geom")
val df = baseDf.select(ST_Union_Aggr("geom"))
val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry].toText()
val expectedResult = "POLYGON ((1 0, 0 0, 0 1, 1 1, 2 1, 2 0, 1 0))"
val expectedResult = "POLYGON ((0 0, 0 1, 1 1, 2 1, 2 0, 1 0, 0 0))"
assert(actualResult == expectedResult)
}

Expand Down
Loading