Skip to content

Commit

Permalink
[SEDONA-648] Throw unsupported operation exception when ST_KNN is use…
Browse files Browse the repository at this point in the history
…d as UDF
  • Loading branch information
zhangfengcdt committed Sep 13, 2024
1 parent 20488c0 commit 7b4ea42
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 14 deletions.
5 changes: 2 additions & 3 deletions common/src/main/java/org/apache/sedona/common/Predicates.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,11 @@ public static boolean relateMatch(String matrix1, String matrix2) {
}

public static boolean knn(Geometry leftGeometry, Geometry rightGeometry, int k) {
return knn(leftGeometry, rightGeometry, k, false);
throw new UnsupportedOperationException("KNN predicate is not supported");
}

public static boolean knn(
Geometry leftGeometry, Geometry rightGeometry, int k, boolean useSpheroid) {
// This should only be used as a test predicate used with extra join condition
return true;
throw new UnsupportedOperationException("KNN predicate is not supported");
}
}
36 changes: 34 additions & 2 deletions docs/api/sql/NearestNeighbourSearching.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ In case there are ties in the distance, the result will include all the tied geo
spark.sedona.join.knn.includeTieBreakers=true
```

Filter Pushdown Considerations:
### Filter Pushdown Considerations:

When using ST_KNN with filters applied to the resulting DataFrame, some of these filters may be pushed down to the object side of the kNN join. This means the filters will be applied to the object side reader before the kNN join is executed. If you want the filters to be applied after the kNN join, ensure that you first materialize the kNN join results and then apply the filters.

Expand All @@ -43,7 +43,39 @@ CACHE TABLE knnResult;
SELECT * FROM knnResult WHERE condition;
```

SQL Example
### Handling SQL-Defined Tables in ST_KNN Joins:

When creating DataFrames from hard-coded SQL select statements in Sedona, and later using them in `ST_KNN` joins, Sedona may attempt to optimize the query in a way that bypasses the intended kNN join logic. Specifically, if you create DataFrames with hard-coded SQL, such as:

```scala
val df1 = sedona.sql("SELECT ST_Point(0.0, 0.0) as geom1")
val df2 = sedona.sql("SELECT ST_Point(0.0, 0.0) as geom2")

val df = df1.join(df2, expr("ST_KNN(geom1, geom2, 1)"))
```

Sedona may optimize the join to a form like this:

```sql
SELECT ST_KNN(ST_Point(0.0, 0.0), ST_Point(0.0, 0.0), 1)
```

As a result, the ST_KNN function is handled as a User-Defined Function (UDF) instead of a proper join operation, preventing Sedona from initiating the kNN join execution path. Unlike typical UDFs, the ST_KNN function operates on multiple rows across DataFrames, not just individual rows. When this occurs, the query fails with an UnsupportedOperationException, indicating that the KNN predicate is not supported.

Workaround:

To prevent Spark's optimization from bypassing the kNN join logic, the DataFrames created with hard-coded SQL select statements must be materialized before performing the join. By caching the DataFrames, you can instruct Spark to avoid this undesired optimization:

```scala
val df1 = sedona.sql("SELECT ST_Point(0.0, 0.0) as geom1").cache()
val df2 = sedona.sql("SELECT ST_Point(0.0, 0.0) as geom2").cache()

val df = df1.join(df2, expr("ST_KNN(geom1, geom2, 1)"))
```

Materializing the DataFrames with .cache() ensures that the correct kNN join path is followed in the Spark logical plan and prevents the optimization that would treat ST_KNN as a simple UDF.

### SQL Example

Suppose we have two tables `QUERIES` and `OBJECTS` with the following data:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
import org.apache.spark.sql.sedona_sql.expressions._
import org.apache.spark.sql.sedona_sql.expressions.{ST_KNN, _}
import org.apache.spark.sql.sedona_sql.expressions.raster._
import org.apache.spark.sql.sedona_sql.optimization.ExpressionUtils.splitConjunctivePredicates
import org.apache.spark.sql.{SparkSession, Strategy}
Expand Down Expand Up @@ -602,7 +602,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
spatialPredicate = null,
isGeography,
condition,
extraCondition) :: Nil
extractExtraKNNJoinCondition(condition)) :: Nil
}

private def planDistanceJoin(
Expand Down Expand Up @@ -664,6 +664,24 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
}
}

private def extractExtraKNNJoinCondition(condition: Expression): Option[Expression] = {
condition match {
case and: And =>
// Check both left and right sides for ST_KNN or ST_AKNN
if (and.left.isInstanceOf[ST_KNN]) {
Some(and.right)
} else if (and.right.isInstanceOf[ST_KNN]) {
Some(and.left)
} else {
None
}
case _: ST_KNN =>
None
case _ =>
Some(condition)
}
}

private def planBroadcastJoin(
left: LogicalPlan,
right: LogicalPlan,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
df,
numNeighbors = 3,
useApproximate = false,
expressionSize = 5,
expressionSize = 4,
isGeography = true,
mustInclude = "")
}
Expand All @@ -83,7 +83,7 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
df,
numNeighbors = 3,
useApproximate = true,
expressionSize = 5,
expressionSize = 4,
isGeography = false,
mustInclude = "")
}
Expand All @@ -98,7 +98,7 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
df,
numNeighbors = 3,
useApproximate = true,
expressionSize = 5,
expressionSize = 4,
isGeography = false,
mustInclude = "")
}
Expand All @@ -112,7 +112,7 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
df,
numNeighbors = 3,
useApproximate = true,
expressionSize = 5,
expressionSize = 4,
isGeography = false,
mustInclude = "as int) <= 88))")
}
Expand All @@ -124,7 +124,7 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
df,
numNeighbors = 3,
useApproximate = true,
expressionSize = 5,
expressionSize = 4,
isGeography = false,
mustInclude = "= point))")
}
Expand All @@ -136,7 +136,7 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
df,
numNeighbors = 3,
useApproximate = true,
expressionSize = 5,
expressionSize = 4,
isGeography = false,
mustInclude = "= point))")
}
Expand All @@ -148,7 +148,7 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
df,
numNeighbors = 3,
useApproximate = true,
expressionSize = 5,
expressionSize = 4,
isGeography = false,
mustInclude = "")
}
Expand Down Expand Up @@ -216,6 +216,12 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
resultAll.length should be(8) // 2 queries (filtered out 1) and 4 neighbors each
resultAll.mkString should be("[2,1][2,5][2,11][2,15][3,3][3,9][3,13][3,19]")
}

it("Should throw KNN predicate is not supported exception") {
intercept[Exception] {
sparkSession.sql("SELECT ST_KNN(ST_Point(0.0, 0.0), ST_Point(0.0, 0.0), 1)").show()
}
}
}

describe("KNN spatial join SQLs should be executed correctly with complex join conditions") {
Expand Down

0 comments on commit 7b4ea42

Please sign in to comment.