From edab5c889aaf7dbeae906e8aa01ebc6fb68460cd Mon Sep 17 00:00:00 2001 From: Furqaanahmed Khan Date: Wed, 10 Jul 2024 18:18:17 -0400 Subject: [PATCH 01/14] feat: add ST_GeneratePoints --- .../org/apache/sedona/common/Functions.java | 8 ++++++ .../apache/sedona/common/FunctionsTest.java | 14 ++++++++++ docs/api/flink/Function.md | 26 +++++++++++++++++++ docs/api/snowflake/vector-data/Function.md | 23 ++++++++++++++++ docs/api/sql/Function.md | 26 +++++++++++++++++++ .../java/org/apache/sedona/flink/Catalog.java | 1 + .../sedona/flink/expressions/Functions.java | 10 +++++++ .../org/apache/sedona/flink/FunctionTest.java | 25 ++++++++++++++++++ python/sedona/sql/st_functions.py | 13 ++++++++++ python/tests/sql/test_dataframe_api.py | 3 +++ python/tests/sql/test_function.py | 9 +++++++ .../snowflake/snowsql/TestFunctions.java | 9 +++++++ .../snowflake/snowsql/TestFunctionsV2.java | 9 +++++++ .../apache/sedona/snowflake/snowsql/UDFs.java | 6 +++++ .../sedona/snowflake/snowsql/UDFsV2.java | 8 ++++++ .../org/apache/sedona/sql/UDF/Catalog.scala | 1 + .../sedona_sql/expressions/Functions.scala | 7 +++++ .../sedona_sql/expressions/st_functions.scala | 7 +++++ .../sedona/sql/dataFrameAPITestScala.scala | 22 ++++++++++++++++ .../apache/sedona/sql/functionTestScala.scala | 14 ++++++++++ 20 files changed, 241 insertions(+) diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java b/common/src/main/java/org/apache/sedona/common/Functions.java index a43519db19..054694acdb 100644 --- a/common/src/main/java/org/apache/sedona/common/Functions.java +++ b/common/src/main/java/org/apache/sedona/common/Functions.java @@ -56,6 +56,7 @@ import org.locationtech.jts.operation.valid.TopologyValidationError; import org.locationtech.jts.precision.GeometryPrecisionReducer; import org.locationtech.jts.precision.MinimumClearance; +import org.locationtech.jts.shape.random.RandomPointsBuilder; import org.locationtech.jts.simplify.PolygonHullSimplifier; import org.locationtech.jts.simplify.TopologyPreservingSimplifier; import org.locationtech.jts.simplify.VWSimplifier; @@ -1812,6 +1813,13 @@ private static Geometry[] convertGeometryToArray(Geometry geom) { return array; } + public static Geometry generatePoints(Geometry geom, int numPoints) { + RandomPointsBuilder pointsBuilder = new RandomPointsBuilder(geom.getFactory()); + pointsBuilder.setExtent(geom); + pointsBuilder.setNumPoints(numPoints); + return pointsBuilder.getGeometry(); + } + public static Integer nRings(Geometry geometry) throws Exception { String geometryType = geometry.getGeometryType(); if (!(geometry instanceof Polygon || geometry instanceof MultiPolygon)) { diff --git a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java index 3868b45395..dae3df01fc 100644 --- a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java +++ b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java @@ -2112,6 +2112,20 @@ public void minimumClearanceLine() throws ParseException { assertEquals(expected, actual); } + @Test + public void generatePoints() throws ParseException { + Geometry geom = Constructors.geomFromEWKT("LINESTRING(50 50,150 150,150 50)"); + Geometry actual = + Functions.generatePoints(Functions.buffer(geom, 10, false, "endcap=round join=round"), 12); + assertEquals(actual.getNumGeometries(), 12); + + geom = + Constructors.geomFromEWKT( + "MULTIPOLYGON (((10 0, 10 10, 20 10, 20 0, 10 0)), ((50 0, 50 10, 70 10, 70 0, 50 0)))"); + actual = Functions.generatePoints(geom, 30); + assertEquals(actual.getNumGeometries(), 30); + } + @Test public void nRingsPolygonOnlyExternal() throws Exception { Polygon polygon = GEOMETRY_FACTORY.createPolygon(coordArray(1, 0, 1, 1, 2, 1, 2, 0, 1, 0)); diff --git a/docs/api/flink/Function.md b/docs/api/flink/Function.md index 2a77036db2..f251b83e38 100644 --- a/docs/api/flink/Function.md +++ b/docs/api/flink/Function.md @@ -1552,6 +1552,32 @@ Output: 5.0990195135927845 ``` +## ST_GeneratePoints + +Introduction: Generates a specified quantity of pseudo-random points within the boundaries of the provided polygonal geometry. + +Format: `ST_GeneratePoints(geom: Geometry, numPoints: Integer)` + +Since: `v1.6.1` + +SQL Example: + +```sql +SELECT ST_GeneratePoints( + ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'), 4 +) +``` + +Output: + +!!!Note + Due to the pseudo-random nature of point generation, the output of this function will vary between executions and may not match any provided examples. + + +``` +MULTIPOINT ((0.2393028905520183 0.9721563442837837), (0.3805848547053376 0.7546556656982678), (0.0950295778200995 0.2494334895495989), (0.4133520939987385 0.3447046312451945)) +``` + ## ST_GeoHash Introduction: Returns GeoHash of the geometry with given precision diff --git a/docs/api/snowflake/vector-data/Function.md b/docs/api/snowflake/vector-data/Function.md index c067f6a1fa..8e0b590a4a 100644 --- a/docs/api/snowflake/vector-data/Function.md +++ b/docs/api/snowflake/vector-data/Function.md @@ -1163,6 +1163,29 @@ Output: 5.0990195135927845 ``` +## ST_GeneratePoints + +Introduction: Generates a specified quantity of pseudo-random points within the boundaries of the provided polygonal geometry. + +Format: `ST_GeneratePoints(geom: Geometry, numPoints: Integer)` + +SQL Example: + +```sql +SELECT ST_GeneratePoints( + ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'), 4 +) +``` + +Output: + +!!!Note + Due to the pseudo-random nature of point generation, the output of this function will vary between executions and may not match any provided examples. + +``` +MULTIPOINT ((0.2393028905520183 0.9721563442837837), (0.3805848547053376 0.7546556656982678), (0.0950295778200995 0.2494334895495989), (0.4133520939987385 0.3447046312451945)) +``` + ## ST_GeoHash Introduction: Returns GeoHash of the geometry with given precision diff --git a/docs/api/sql/Function.md b/docs/api/sql/Function.md index f2e1e9afb1..f6cca656c9 100644 --- a/docs/api/sql/Function.md +++ b/docs/api/sql/Function.md @@ -1557,6 +1557,32 @@ Output: 5.0990195135927845 ``` +## ST_GeneratePoints + +Introduction: Generates a specified quantity of pseudo-random points within the boundaries of the provided polygonal geometry. + +Format: `ST_GeneratePoints(geom: Geometry, numPoints: Integer)` + +Since: `v1.6.1` + +SQL Example: + +```sql +SELECT ST_GeneratePoints( + ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'), 4 +) +``` + +Output: + +!!!Note + Due to the pseudo-random nature of point generation, the output of this function will vary between executions and may not match any provided examples. + + +``` +MULTIPOINT ((0.2393028905520183 0.9721563442837837), (0.3805848547053376 0.7546556656982678), (0.0950295778200995 0.2494334895495989), (0.4133520939987385 0.3447046312451945)) +``` + ## ST_GeoHash Introduction: Returns GeoHash of the geometry with given precision diff --git a/flink/src/main/java/org/apache/sedona/flink/Catalog.java b/flink/src/main/java/org/apache/sedona/flink/Catalog.java index c043a85d05..b19e99166b 100644 --- a/flink/src/main/java/org/apache/sedona/flink/Catalog.java +++ b/flink/src/main/java/org/apache/sedona/flink/Catalog.java @@ -184,6 +184,7 @@ public static UserDefinedFunction[] getFuncs() { new Functions.ST_ForceCollection(), new Functions.ST_ForcePolygonCW(), new Functions.ST_ForceRHR(), + new Functions.ST_GeneratePoints(), new Functions.ST_NRings(), new Functions.ST_IsPolygonCCW(), new Functions.ST_ForcePolygonCCW(), diff --git a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java index 87df74658d..782f63510a 100644 --- a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java +++ b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java @@ -1549,6 +1549,16 @@ public Geometry eval( } } + public static class ST_GeneratePoints extends ScalarFunction { + @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) + public Geometry eval( + @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) Object o, + @DataTypeHint(value = "Integer") Integer numPoints) { + Geometry geom = (Geometry) o; + return org.apache.sedona.common.Functions.generatePoints(geom, numPoints); + } + } + public static class ST_NRings extends ScalarFunction { @DataTypeHint(value = "Integer") public int eval( diff --git a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java index 08c8632b1c..136360ee35 100644 --- a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java +++ b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java @@ -2142,6 +2142,31 @@ public void testIsPolygonCW() { assertTrue(actual); } + @Test + public void testGeneratePoints() { + Table polyTable = + tableEnv.sqlQuery( + "SELECT ST_Buffer(ST_GeomFromWKT('LINESTRING(50 50,150 150,150 50)'), 10, false, 'endcap=round join=round') AS geom"); + Geometry actual = + (Geometry) + first( + polyTable.select( + call(Functions.ST_GeneratePoints.class.getSimpleName(), $("geom"), 15))) + .getField(0); + assertEquals(actual.getNumGeometries(), 15); + + polyTable = + tableEnv.sqlQuery( + "SELECT ST_GeomFromWKT('MULTIPOLYGON (((10 0, 10 10, 20 10, 20 0, 10 0)), ((50 0, 50 10, 70 10, 70 0, 50 0)))') AS geom"); + actual = + (Geometry) + first( + polyTable.select( + call(Functions.ST_GeneratePoints.class.getSimpleName(), $("geom"), 30))) + .getField(0); + assertEquals(actual.getNumGeometries(), 30); + } + @Test public void testNRings() { Integer expected = 1; diff --git a/python/sedona/sql/st_functions.py b/python/sedona/sql/st_functions.py index 8b5d1ef752..9bec9813ce 100644 --- a/python/sedona/sql/st_functions.py +++ b/python/sedona/sql/st_functions.py @@ -548,6 +548,19 @@ def ST_Force_2D(geometry: ColumnOrName) -> Column: return _call_st_function("ST_Force_2D", geometry) +@validate_argument_types +def ST_GeneratePoints(geometry: ColumnOrName, numPoints: Union[ColumnOrName, int]) -> Column: + """Generate random points in given geometry. + + :param geometry: Geometry column to hash. + :type geometry: ColumnOrName + :param numPoints: Precision level to hash geometry at, given as an integer or an integer column. + :type numPoints: Union[ColumnOrName, int] + :return: Generate random points in given geometry + :rtype: Column + """ + return _call_st_function("ST_GeneratePoints", (geometry, numPoints)) + @validate_argument_types def ST_GeoHash(geometry: ColumnOrName, precision: Union[ColumnOrName, int]) -> Column: """Return the geohash of a geometry column at a given precision level. diff --git a/python/tests/sql/test_dataframe_api.py b/python/tests/sql/test_dataframe_api.py index 1a97bb6ac1..758824d3bc 100644 --- a/python/tests/sql/test_dataframe_api.py +++ b/python/tests/sql/test_dataframe_api.py @@ -137,6 +137,7 @@ (stf.ST_ForceRHR, ("geom",), "geom_with_hole", "", "POLYGON ((0 0, 3 3, 3 0, 0 0), (1 1, 2 1, 2 2, 1 1))"), (stf.ST_FrechetDistance, ("point", "line",), "point_and_line", "", 5.0990195135927845), (stf.ST_GeometricMedian, ("multipoint",), "multipoint_geom", "", "POINT (22.500002656424286 21.250001168173426)"), + (stf.ST_GeneratePoints, ("geom", 15), "square_geom", "ST_NumGeometries(geom)", 15), (stf.ST_GeometryN, ("geom", 0), "multipoint", "", "POINT (0 0)"), (stf.ST_GeometryType, ("point",), "point_geom", "", "ST_Point"), (stf.ST_HausdorffDistance, ("point", "line",), "point_and_line", "", 5.0990195135927845), @@ -351,6 +352,8 @@ (stf.ST_GeometryN, ("", None)), (stf.ST_GeometryN, ("", 0.0)), (stf.ST_GeometryType, (None,)), + (stf.ST_GeneratePoints, (None, 0.0)), + (stf.ST_GeneratePoints, ("", None)), (stf.ST_InteriorRingN, (None, 0)), (stf.ST_InteriorRingN, ("", None)), (stf.ST_InteriorRingN, ("", 0.0)), diff --git a/python/tests/sql/test_function.py b/python/tests/sql/test_function.py index 5fd8dd2b5b..e361a1a832 100644 --- a/python/tests/sql/test_function.py +++ b/python/tests/sql/test_function.py @@ -1606,6 +1606,15 @@ def test_forceRHR(self): expected = "POLYGON ((20 35, 45 20, 30 5, 10 10, 10 30, 20 35), (30 20, 20 25, 20 15, 30 20))" assert expected == actual + def test_generate_points(self): + actual = self.spark.sql("SELECT ST_NumGeometries(ST_GeneratePoints(ST_Buffer(ST_GeomFromWKT('LINESTRING(50 50,150 150,150 50)'), 10, false, 'endcap=round join=round'), 15))")\ + .first()[0] + assert actual == 15 + + actual = self.spark.sql("SELECT ST_NumGeometries(ST_GeneratePoints(ST_GeomFromWKT('MULTIPOLYGON (((10 0, 10 10, 20 10, 20 0, 10 0)), ((50 0, 50 10, 70 10, 70 0, 50 0)))'), 30))")\ + .first()[0] + assert actual == 30 + def test_nRings(self): expected = 1 actualDf = self.spark.sql("SELECT ST_GeomFromText('POLYGON ((1 0, 1 1, 2 1, 2 0, 1 0))') AS geom") diff --git a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctions.java b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctions.java index c95ab35a37..9efaae5cb3 100644 --- a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctions.java +++ b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctions.java @@ -394,6 +394,15 @@ public void test_ST_Force2D() { "select sedona.ST_AsText(sedona.ST_Force2D(sedona.ST_POINTZ(1, 2, 3)))", "POINT (1 2)"); } + @Test + public void test_ST_GeneratePoints() { + registerUDF("ST_GeneratePoints", byte[].class, int.class); + registerUDF("ST_NumGeometries", byte[].class); + verifySqlSingleRes( + "select sedona.ST_NumGeometries(sedona.ST_GeneratePoints(sedona.ST_GeomFromWKT('POLYGON ((1 0, 1 1, 2 1, 2 0, 1 0))'), 15))", + 15); + } + @Test public void test_ST_GeoHash() { registerUDF("ST_GeoHash", byte[].class, int.class); diff --git a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java index 4060e9447b..96ce0b3e91 100644 --- a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java +++ b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java @@ -376,6 +376,15 @@ public void test_ST_Force2D() { verifySqlSingleRes("select ST_AsText(sedona.ST_Force2D(ST_GEOMPOINT(1, 2)))", "POINT(1 2)"); } + @Test + public void test_ST_GeneratePoints() { + registerUDFV2("ST_GeneratePoints", String.class, int.class); + registerUDFV2("ST_NumGeometries", String.class); + verifySqlSingleRes( + "select sedona.ST_NumGeometries(sedona.ST_GeneratePoints(ST_GeomFromWKT('POLYGON ((1 0, 1 1, 2 1, 2 0, 1 0))'), 15))", + 15); + } + @Test public void test_ST_GeoHash() { registerUDFV2("ST_GeoHash", String.class, int.class); diff --git a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFs.java b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFs.java index 7985017cb3..0dcb82ba9f 100644 --- a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFs.java +++ b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFs.java @@ -366,6 +366,12 @@ public static byte[] ST_Force2D(byte[] geometry) { return GeometrySerde.serialize(Functions.force2D(GeometrySerde.deserialize(geometry))); } + @UDFAnnotations.ParamMeta(argNames = {"geometry", "numPoints"}) + public static byte[] ST_GeneratePoints(byte[] geometry, int numPoints) { + return GeometrySerde.serialize( + Functions.generatePoints(GeometrySerde.deserialize(geometry), numPoints)); + } + @UDFAnnotations.ParamMeta(argNames = {"geometry", "precision"}) public static String ST_GeoHash(byte[] geometry, int precision) { return Functions.geohash(GeometrySerde.deserialize(geometry), precision); diff --git a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java index fa75a51b02..9d334ca136 100644 --- a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java +++ b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java @@ -509,6 +509,14 @@ public static String ST_Force2D(String geometry) { return GeometrySerde.serGeoJson(Functions.force2D(GeometrySerde.deserGeoJson(geometry))); } + @UDFAnnotations.ParamMeta( + argNames = {"geometry", "numPoints"}, + argTypes = {"Geometry", "int"}) + public static String ST_GeneratePoints(String geometry, int numPoints) { + return GeometrySerde.serGeoJson( + Functions.generatePoints(GeometrySerde.deserGeoJson(geometry), numPoints)); + } + @UDFAnnotations.ParamMeta( argNames = {"geometry", "precision"}, argTypes = {"Geometry", "int"}) diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala index 2b9aefd845..4fb0aa044b 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala @@ -213,6 +213,7 @@ object Catalog { function[ST_Force3DZ](0.0), function[ST_Force4D](), function[ST_ForceCollection](), + function[ST_GeneratePoints](), function[ST_NRings](), function[ST_Translate](0.0), function[ST_TriangulatePolygon](), diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala index 47711716a3..a69797d8e7 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala @@ -1435,6 +1435,13 @@ case class ST_ForceRHR(inputExpressions: Seq[Expression]) } } +case class ST_GeneratePoints(inputExpressions: Seq[Expression]) + extends InferredExpression(Functions.generatePoints _) { + protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { + copy(inputExpressions = newChildren) + } +} + case class ST_NRings(inputExpressions: Seq[Expression]) extends InferredExpression(Functions.nRings _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala index cc7a756ee9..5427b555f3 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala @@ -676,6 +676,13 @@ object st_functions extends DataFrameAPI { def ST_ForceRHR(geometry: Column): Column = wrapExpression[ST_ForceRHR](geometry) def ST_ForceRHR(geometry: String): Column = wrapExpression[ST_ForceRHR](geometry) + def ST_GeneratePoints(geometry: Column, numPoints: Column): Column = + wrapExpression[ST_GeneratePoints](geometry, numPoints) + def ST_GeneratePoints(geometry: String, numPoints: String): Column = + wrapExpression[ST_GeneratePoints](geometry, numPoints) + def ST_GeneratePoints(geometry: String, numPoints: Integer): Column = + wrapExpression[ST_GeneratePoints](geometry, numPoints) + def ST_NRings(geometry: Column): Column = wrapExpression[ST_NRings](geometry) def ST_NRings(geometry: String): Column = wrapExpression[ST_NRings](geometry) diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala index 81b849592f..d01cb67990 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala @@ -1833,6 +1833,28 @@ class dataFrameAPITestScala extends TestBaseScala { assertTrue(actual) } + it("Should pass ST_GeneratePoints") { + var poly = sparkSession + .sql( + "SELECT ST_Buffer(ST_GeomFromWKT('LINESTRING(50 50,150 150,150 50)'), 10, false, 'endcap=round join=round') AS geom") + var actual = poly + .select(ST_NumGeometries(ST_GeneratePoints("geom", 15))) + .first() + .get(0) + .asInstanceOf[Int] + assert(actual == 15) + + poly = sparkSession + .sql( + "SELECT ST_GeomFromWKT('MULTIPOLYGON (((10 0, 10 10, 20 10, 20 0, 10 0)), ((50 0, 50 10, 70 10, 70 0, 50 0)))') AS geom") + actual = poly + .select(ST_NumGeometries(ST_GeneratePoints("geom", 30))) + .first() + .get(0) + .asInstanceOf[Int] + assert(actual == 30) + } + it("Passed ST_NRings") { val polyDf = sparkSession.sql("SELECT ST_GeomFromWKT('POLYGON ((1 0, 1 1, 2 1, 2 0, 1 0))') AS geom") diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala index ea399f4f15..e1503ab6eb 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala @@ -2916,6 +2916,20 @@ class functionTestScala assert(actual == 1) } + it("Should pass ST_GeneratePoints") { + var actual = sparkSession + .sql("SELECT ST_NumGeometries(ST_GeneratePoints(ST_Buffer(ST_GeomFromWKT('LINESTRING(50 50,150 150,150 50)'), 10, false, 'endcap=round join=round'), 15))") + .first() + .get(0) + assert(actual == 15) + + actual = sparkSession + .sql("SELECT ST_NumGeometries(ST_GeneratePoints(ST_GeomFromWKT('MULTIPOLYGON (((10 0, 10 10, 20 10, 20 0, 10 0)), ((50 0, 50 10, 70 10, 70 0, 50 0)))'), 30))") + .first() + .get(0) + assert(actual == 30) + } + it("should pass ST_NRings") { val geomTestCases = Map( ("'POLYGON ((1 0, 1 1, 2 1, 2 0, 1 0))'") -> 1, From ac3bffbef54de3542409f5106785e645d6e82201 Mon Sep 17 00:00:00 2001 From: Furqaanahmed Khan Date: Wed, 10 Jul 2024 18:24:23 -0400 Subject: [PATCH 02/14] fix: lint errors --- docs/api/flink/Function.md | 1 - docs/api/sql/Function.md | 1 - 2 files changed, 2 deletions(-) diff --git a/docs/api/flink/Function.md b/docs/api/flink/Function.md index f251b83e38..10ff7ddfee 100644 --- a/docs/api/flink/Function.md +++ b/docs/api/flink/Function.md @@ -1573,7 +1573,6 @@ Output: !!!Note Due to the pseudo-random nature of point generation, the output of this function will vary between executions and may not match any provided examples. - ``` MULTIPOINT ((0.2393028905520183 0.9721563442837837), (0.3805848547053376 0.7546556656982678), (0.0950295778200995 0.2494334895495989), (0.4133520939987385 0.3447046312451945)) ``` diff --git a/docs/api/sql/Function.md b/docs/api/sql/Function.md index f6cca656c9..224b2ba2c0 100644 --- a/docs/api/sql/Function.md +++ b/docs/api/sql/Function.md @@ -1578,7 +1578,6 @@ Output: !!!Note Due to the pseudo-random nature of point generation, the output of this function will vary between executions and may not match any provided examples. - ``` MULTIPOINT ((0.2393028905520183 0.9721563442837837), (0.3805848547053376 0.7546556656982678), (0.0950295778200995 0.2494334895495989), (0.4133520939987385 0.3447046312451945)) ``` From c6de12e708f152ba3ba62c4939e617f9cbe910cf Mon Sep 17 00:00:00 2001 From: Furqaanahmed Khan Date: Wed, 10 Jul 2024 21:43:17 -0400 Subject: [PATCH 03/14] fix: snowflake v2 registeration --- .../main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java index 9d334ca136..de48b186eb 100644 --- a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java +++ b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java @@ -511,7 +511,8 @@ public static String ST_Force2D(String geometry) { @UDFAnnotations.ParamMeta( argNames = {"geometry", "numPoints"}, - argTypes = {"Geometry", "int"}) + argTypes = {"Geometry", "int"}, + returnTypes = "Geometry") public static String ST_GeneratePoints(String geometry, int numPoints) { return GeometrySerde.serGeoJson( Functions.generatePoints(GeometrySerde.deserGeoJson(geometry), numPoints)); From ba33b3c6cefabd470c8b32f2ed1f90ccb1deae4b Mon Sep 17 00:00:00 2001 From: Furqaanahmed Khan Date: Fri, 12 Jul 2024 13:58:41 -0400 Subject: [PATCH 04/14] feat: add Nondeterministic trait to spark --- .../org/apache/sedona/common/Functions.java | 9 ++-- .../common/utils/RandomPointsBuilderSeed.java | 47 ++++++++++++++++ .../apache/sedona/common/FunctionsTest.java | 12 ++++- .../sedona_sql/expressions/Functions.scala | 53 +++++++++++++++++-- .../apache/sedona/sql/functionTestScala.scala | 8 +++ 5 files changed, 122 insertions(+), 7 deletions(-) create mode 100644 common/src/main/java/org/apache/sedona/common/utils/RandomPointsBuilderSeed.java diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java b/common/src/main/java/org/apache/sedona/common/Functions.java index 054694acdb..3b213a94cc 100644 --- a/common/src/main/java/org/apache/sedona/common/Functions.java +++ b/common/src/main/java/org/apache/sedona/common/Functions.java @@ -56,7 +56,6 @@ import org.locationtech.jts.operation.valid.TopologyValidationError; import org.locationtech.jts.precision.GeometryPrecisionReducer; import org.locationtech.jts.precision.MinimumClearance; -import org.locationtech.jts.shape.random.RandomPointsBuilder; import org.locationtech.jts.simplify.PolygonHullSimplifier; import org.locationtech.jts.simplify.TopologyPreservingSimplifier; import org.locationtech.jts.simplify.VWSimplifier; @@ -1813,13 +1812,17 @@ private static Geometry[] convertGeometryToArray(Geometry geom) { return array; } - public static Geometry generatePoints(Geometry geom, int numPoints) { - RandomPointsBuilder pointsBuilder = new RandomPointsBuilder(geom.getFactory()); + public static Geometry generatePoints(Geometry geom, int numPoints, int seed) { + RandomPointsBuilderSeed pointsBuilder = new RandomPointsBuilderSeed(geom.getFactory(), seed); pointsBuilder.setExtent(geom); pointsBuilder.setNumPoints(numPoints); return pointsBuilder.getGeometry(); } + public static Geometry generatePoints(Geometry geom, int numPoints) { + return generatePoints(geom, numPoints, 0); + } + public static Integer nRings(Geometry geometry) throws Exception { String geometryType = geometry.getGeometryType(); if (!(geometry instanceof Polygon || geometry instanceof MultiPolygon)) { diff --git a/common/src/main/java/org/apache/sedona/common/utils/RandomPointsBuilderSeed.java b/common/src/main/java/org/apache/sedona/common/utils/RandomPointsBuilderSeed.java new file mode 100644 index 0000000000..24269cbfb4 --- /dev/null +++ b/common/src/main/java/org/apache/sedona/common/utils/RandomPointsBuilderSeed.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.common.utils; + +import java.util.Random; +import org.locationtech.jts.geom.*; +import org.locationtech.jts.shape.random.RandomPointsBuilder; + +public class RandomPointsBuilderSeed extends RandomPointsBuilder { + double seed; + int counter = 0; + Random rand; + + public RandomPointsBuilderSeed(GeometryFactory geometryFactory, double seed) { + super(geometryFactory); + this.seed = seed; + if (seed > 0) { + this.rand = new Random((long) seed); + return; + } + this.rand = new Random(); + } + + @Override + protected Coordinate createRandomCoord(Envelope env) { + counter++; + double x = env.getMinX() + env.getWidth() * rand.nextDouble(); + double y = env.getMinY() + env.getHeight() * rand.nextDouble(); + return createCoord(x, y); + } +} diff --git a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java index dae3df01fc..96211c86d8 100644 --- a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java +++ b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java @@ -2114,11 +2114,21 @@ public void minimumClearanceLine() throws ParseException { @Test public void generatePoints() throws ParseException { - Geometry geom = Constructors.geomFromEWKT("LINESTRING(50 50,150 150,150 50)"); + Geometry geom = Constructors.geomFromEWKT("LINESTRING(50 50,10 10,10 50)"); + Geometry actual = Functions.generatePoints(Functions.buffer(geom, 10, false, "endcap=round join=round"), 12); assertEquals(actual.getNumGeometries(), 12); + actual = + Functions.reducePrecision( + Functions.generatePoints( + Functions.buffer(geom, 10, false, "endcap=round join=round"), 5, 100), + 5); + String expected = + "MULTIPOINT ((40.02957 46.70645), (37.11646 37.38582), (14.2051 29.23363), (40.82533 31.47273), (28.16839 34.16338))"; + assertEquals(expected, actual.toString()); + geom = Constructors.geomFromEWKT( "MULTIPOLYGON (((10 0, 10 10, 20 10, 20 0, 10 0)), ((50 0, 50 10, 70 10, 70 0, 50 0)))"); diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala index a69797d8e7..41b71f6752 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala @@ -23,8 +23,9 @@ import org.apache.sedona.common.sphere.{Haversine, Spheroid} import org.apache.sedona.common.utils.{InscribedCircle, ValidDetail} import org.apache.sedona.sql.utils.GeometrySerializer import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Generator} +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionWithRandomSeed, Generator, Literal, Nondeterministic} import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT import org.apache.spark.sql.sedona_sql.expressions.implicits._ @@ -34,6 +35,8 @@ import org.locationtech.jts.geom._ import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._ import org.apache.spark.unsafe.types.UTF8String +import scala.util.Random + /** * Return the distance between two geometries. * @@ -1435,11 +1438,55 @@ case class ST_ForceRHR(inputExpressions: Seq[Expression]) } } -case class ST_GeneratePoints(inputExpressions: Seq[Expression]) - extends InferredExpression(Functions.generatePoints _) { +case class ST_GeneratePoints(inputExpressions: Seq[Expression], randomSeed: Option[Long] = None) + extends Expression + with Nondeterministic + with ExpectsInputTypes + with CodegenFallback + with ExpressionWithRandomSeed { + + def this(inputExpressions: Seq[Expression]) = this(inputExpressions, None) + + def seedExpression: Expression = randomSeed.map(Literal.apply).getOrElse(Literal(0L)) + + @transient private[this] var random: Random = _ + protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } + + private val nArgs = children.length + + override protected def initializeInternal(partitionIndex: Int): Unit = random = new Random( + randomSeed.getOrElse(0L) + partitionIndex) + + override protected def evalInternal(input: InternalRow): Any = { + val geom = children.head.toGeometry(input) + val numPoints = children(1).eval(input).asInstanceOf[Int] + if (nArgs == 3) { + val seed = children(2).eval(input).asInstanceOf[Int] + return GeometrySerializer.serialize(Functions.generatePoints(geom, numPoints, seed)) + } + GeometrySerializer.serialize(Functions.generatePoints(geom, numPoints)) + } + + override def nullable: Boolean = true + + override def dataType: DataType = GeometryUDT + + override def inputTypes: Seq[AbstractDataType] = { + if (nArgs == 3) { + Seq(GeometryUDT, IntegerType, IntegerType) + } else if (nArgs == 2) { + Seq(GeometryUDT, IntegerType) + } else { + throw new IllegalArgumentException(s"Invalid number of arguments: $nArgs") + } + } + + override def children: Seq[Expression] = inputExpressions + + override def withNewSeed(seed: Long): ST_GeneratePoints = copy(randomSeed = Some(seed)) } case class ST_NRings(inputExpressions: Seq[Expression]) diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala index e1503ab6eb..176ab296ad 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala @@ -2928,6 +2928,14 @@ class functionTestScala .first() .get(0) assert(actual == 30) + + actual = sparkSession + .sql("SELECT ST_AsText(ST_ReducePrecision(ST_GeneratePoints(ST_GeomFromWKT('MULTIPOLYGON (((10 0, 10 10, 20 10, 20 0, 10 0)), ((50 0, 50 10, 70 10, 70 0, 50 0)))'), 5, 10), 5))") + .first() + .get(0) + val expected = + "MULTIPOINT ((53.82582 2.57803), (13.55212 2.44117), (59.12854 3.70611), (61.37698 7.14985), (10.49657 4.40622))" + assertEquals(expected, actual) } it("should pass ST_NRings") { From c996fe8457e4a199e8a00d286d6cc0ff5a570474 Mon Sep 17 00:00:00 2001 From: Furqaanahmed Khan Date: Fri, 12 Jul 2024 14:03:16 -0400 Subject: [PATCH 05/14] fix: optimize imports --- .../org/apache/spark/sql/sedona_sql/expressions/Functions.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala index 41b71f6752..7b38a22859 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala @@ -23,7 +23,6 @@ import org.apache.sedona.common.sphere.{Haversine, Spheroid} import org.apache.sedona.common.utils.{InscribedCircle, ValidDetail} import org.apache.sedona.sql.utils.GeometrySerializer import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionWithRandomSeed, Generator, Literal, Nondeterministic} import org.apache.spark.sql.catalyst.util.ArrayData From 83c98d76f46bcacba0f861fd0aaf2d58e281ff9f Mon Sep 17 00:00:00 2001 From: Furqaanahmed Khan Date: Fri, 12 Jul 2024 15:08:26 -0400 Subject: [PATCH 06/14] fix: optimize imports and spark registration --- python/sedona/sql/st_functions.py | 9 +++++++-- python/tests/sql/test_dataframe_api.py | 1 + python/tests/sql/test_function.py | 5 +++++ .../spark/sql/sedona_sql/expressions/Functions.scala | 7 +++++-- .../spark/sql/sedona_sql/expressions/st_functions.scala | 6 ++++++ 5 files changed, 24 insertions(+), 4 deletions(-) diff --git a/python/sedona/sql/st_functions.py b/python/sedona/sql/st_functions.py index 9bec9813ce..8100da893f 100644 --- a/python/sedona/sql/st_functions.py +++ b/python/sedona/sql/st_functions.py @@ -549,7 +549,7 @@ def ST_Force_2D(geometry: ColumnOrName) -> Column: @validate_argument_types -def ST_GeneratePoints(geometry: ColumnOrName, numPoints: Union[ColumnOrName, int]) -> Column: +def ST_GeneratePoints(geometry: ColumnOrName, numPoints: Union[ColumnOrName, int], seed:Optional[Union[ColumnOrName, int]] = None) -> Column: """Generate random points in given geometry. :param geometry: Geometry column to hash. @@ -559,7 +559,12 @@ def ST_GeneratePoints(geometry: ColumnOrName, numPoints: Union[ColumnOrName, int :return: Generate random points in given geometry :rtype: Column """ - return _call_st_function("ST_GeneratePoints", (geometry, numPoints)) + if seed is None: + args = (geometry, numPoints) + else: + args = (geometry, numPoints, seed) + + return _call_st_function("ST_GeneratePoints", args) @validate_argument_types def ST_GeoHash(geometry: ColumnOrName, precision: Union[ColumnOrName, int]) -> Column: diff --git a/python/tests/sql/test_dataframe_api.py b/python/tests/sql/test_dataframe_api.py index 758824d3bc..49a2e68f8d 100644 --- a/python/tests/sql/test_dataframe_api.py +++ b/python/tests/sql/test_dataframe_api.py @@ -138,6 +138,7 @@ (stf.ST_FrechetDistance, ("point", "line",), "point_and_line", "", 5.0990195135927845), (stf.ST_GeometricMedian, ("multipoint",), "multipoint_geom", "", "POINT (22.500002656424286 21.250001168173426)"), (stf.ST_GeneratePoints, ("geom", 15), "square_geom", "ST_NumGeometries(geom)", 15), + (stf.ST_GeneratePoints, ("geom", 15, 100), "square_geom", "ST_NumGeometries(geom)", 15), (stf.ST_GeometryN, ("geom", 0), "multipoint", "", "POINT (0 0)"), (stf.ST_GeometryType, ("point",), "point_geom", "", "ST_Point"), (stf.ST_HausdorffDistance, ("point", "line",), "point_and_line", "", 5.0990195135927845), diff --git a/python/tests/sql/test_function.py b/python/tests/sql/test_function.py index e361a1a832..71eae37526 100644 --- a/python/tests/sql/test_function.py +++ b/python/tests/sql/test_function.py @@ -1611,6 +1611,11 @@ def test_generate_points(self): .first()[0] assert actual == 15 + actual = self.spark.sql( + "SELECT ST_NumGeometries(ST_GeneratePoints(ST_Buffer(ST_GeomFromWKT('LINESTRING(50 50,150 150,150 50)'), 10, false, 'endcap=round join=round'), 15, 100))") \ + .first()[0] + assert actual == 15 + actual = self.spark.sql("SELECT ST_NumGeometries(ST_GeneratePoints(ST_GeomFromWKT('MULTIPOLYGON (((10 0, 10 10, 20 10, 20 0, 10 0)), ((50 0, 50 10, 70 10, 70 0, 50 0)))'), 30))")\ .first()[0] assert actual == 30 diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala index 7b38a22859..6fe3263078 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala @@ -1444,9 +1444,9 @@ case class ST_GeneratePoints(inputExpressions: Seq[Expression], randomSeed: Opti with CodegenFallback with ExpressionWithRandomSeed { - def this(inputExpressions: Seq[Expression]) = this(inputExpressions, None) + def this(inputExpressions: Seq[Expression]) = this(inputExpressions, Some(0L)) - def seedExpression: Expression = randomSeed.map(Literal.apply).getOrElse(Literal(0L)) + override def seedExpression: Expression = randomSeed.map(Literal.apply).getOrElse(Literal(0L)) @transient private[this] var random: Random = _ @@ -1483,6 +1483,9 @@ case class ST_GeneratePoints(inputExpressions: Seq[Expression], randomSeed: Opti } } + override lazy val resolved: Boolean = + childrenResolved && checkInputDataTypes().isSuccess && randomSeed.isDefined + override def children: Seq[Expression] = inputExpressions override def withNewSeed(seed: Long): ST_GeneratePoints = copy(randomSeed = Some(seed)) diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala index 5427b555f3..e28cd120a2 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala @@ -682,6 +682,12 @@ object st_functions extends DataFrameAPI { wrapExpression[ST_GeneratePoints](geometry, numPoints) def ST_GeneratePoints(geometry: String, numPoints: Integer): Column = wrapExpression[ST_GeneratePoints](geometry, numPoints) + def ST_GeneratePoints(geometry: String, numPoints: Integer, seed: Integer): Column = + wrapExpression[ST_GeneratePoints](geometry, numPoints, seed) + def ST_GeneratePoints(geometry: String, numPoints: String, seed: String): Column = + wrapExpression[ST_GeneratePoints](geometry, numPoints, seed) + def ST_GeneratePoints(geometry: Column, numPoints: Column, seed: Column): Column = + wrapExpression[ST_GeneratePoints](geometry, numPoints, seed) def ST_NRings(geometry: Column): Column = wrapExpression[ST_NRings](geometry) From 23b1843251a60b7a8842b75a3986b872fe2e9746 Mon Sep 17 00:00:00 2001 From: Furqaanahmed Khan Date: Fri, 12 Jul 2024 15:13:05 -0400 Subject: [PATCH 07/14] fix: spark registration 1 --- .../org/apache/spark/sql/sedona_sql/expressions/Functions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala index 6fe3263078..9a662ec669 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala @@ -1446,7 +1446,7 @@ case class ST_GeneratePoints(inputExpressions: Seq[Expression], randomSeed: Opti def this(inputExpressions: Seq[Expression]) = this(inputExpressions, Some(0L)) - override def seedExpression: Expression = randomSeed.map(Literal.apply).getOrElse(Literal(0L)) + def seedExpression: Expression = randomSeed.map(Literal.apply).getOrElse(Literal(0L)) @transient private[this] var random: Random = _ From 1e578f8efa2f560f69673ac08942b098afed3589 Mon Sep 17 00:00:00 2001 From: Furqaanahmed Khan Date: Mon, 15 Jul 2024 12:29:22 -0400 Subject: [PATCH 08/14] fix: spark registration 2 and remove unnecessary code --- .../common/utils/RandomPointsBuilderSeed.java | 8 ++--- .../apache/sedona/common/FunctionsTest.java | 3 +- .../sedona_sql/expressions/Functions.scala | 36 ++++++++++--------- 3 files changed, 24 insertions(+), 23 deletions(-) diff --git a/common/src/main/java/org/apache/sedona/common/utils/RandomPointsBuilderSeed.java b/common/src/main/java/org/apache/sedona/common/utils/RandomPointsBuilderSeed.java index 24269cbfb4..c42e450587 100644 --- a/common/src/main/java/org/apache/sedona/common/utils/RandomPointsBuilderSeed.java +++ b/common/src/main/java/org/apache/sedona/common/utils/RandomPointsBuilderSeed.java @@ -23,15 +23,14 @@ import org.locationtech.jts.shape.random.RandomPointsBuilder; public class RandomPointsBuilderSeed extends RandomPointsBuilder { - double seed; - int counter = 0; + long seed; Random rand; - public RandomPointsBuilderSeed(GeometryFactory geometryFactory, double seed) { + public RandomPointsBuilderSeed(GeometryFactory geometryFactory, long seed) { super(geometryFactory); this.seed = seed; if (seed > 0) { - this.rand = new Random((long) seed); + this.rand = new Random(seed); return; } this.rand = new Random(); @@ -39,7 +38,6 @@ public RandomPointsBuilderSeed(GeometryFactory geometryFactory, double seed) { @Override protected Coordinate createRandomCoord(Envelope env) { - counter++; double x = env.getMinX() + env.getWidth() * rand.nextDouble(); double y = env.getMinY() + env.getHeight() * rand.nextDouble(); return createCoord(x, y); diff --git a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java index 96211c86d8..8601044567 100644 --- a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java +++ b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java @@ -2114,7 +2114,7 @@ public void minimumClearanceLine() throws ParseException { @Test public void generatePoints() throws ParseException { - Geometry geom = Constructors.geomFromEWKT("LINESTRING(50 50,10 10,10 50)"); + Geometry geom = Constructors.geomFromWKT("LINESTRING(50 50,10 10,10 50)", 4326); Geometry actual = Functions.generatePoints(Functions.buffer(geom, 10, false, "endcap=round join=round"), 12); @@ -2128,6 +2128,7 @@ public void generatePoints() throws ParseException { String expected = "MULTIPOINT ((40.02957 46.70645), (37.11646 37.38582), (14.2051 29.23363), (40.82533 31.47273), (28.16839 34.16338))"; assertEquals(expected, actual.toString()); + assertEquals(4326, actual.getSRID()); geom = Constructors.geomFromEWKT( diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala index 9a662ec669..9125817b50 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.sedona_sql.expressions import org.apache.sedona.common.{Functions, FunctionsGeoTools} import org.apache.sedona.common.sphere.{Haversine, Spheroid} -import org.apache.sedona.common.utils.{InscribedCircle, ValidDetail} +import org.apache.sedona.common.utils.{InscribedCircle, RandomPointsBuilderSeed, ValidDetail} import org.apache.sedona.sql.utils.GeometrySerializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback @@ -33,6 +33,7 @@ import org.locationtech.jts.algorithm.MinimumBoundingCircle import org.locationtech.jts.geom._ import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils import scala.util.Random @@ -1437,36 +1438,36 @@ case class ST_ForceRHR(inputExpressions: Seq[Expression]) } } -case class ST_GeneratePoints(inputExpressions: Seq[Expression], randomSeed: Option[Long] = None) +case class ST_GeneratePoints(inputExpressions: Seq[Expression], randomSeed: Long) extends Expression - with Nondeterministic - with ExpectsInputTypes with CodegenFallback + with ExpectsInputTypes + with Nondeterministic with ExpressionWithRandomSeed { - def this(inputExpressions: Seq[Expression]) = this(inputExpressions, Some(0L)) + def this(inputExpressions: Seq[Expression]) = this(inputExpressions, Utils.random.nextLong()) - def seedExpression: Expression = randomSeed.map(Literal.apply).getOrElse(Literal(0L)) + override def seedExpression: Expression = Literal(randomSeed) @transient private[this] var random: Random = _ - protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { - copy(inputExpressions = newChildren) - } - private val nArgs = children.length override protected def initializeInternal(partitionIndex: Int): Unit = random = new Random( - randomSeed.getOrElse(0L) + partitionIndex) + randomSeed + partitionIndex) override protected def evalInternal(input: InternalRow): Any = { val geom = children.head.toGeometry(input) val numPoints = children(1).eval(input).asInstanceOf[Int] - if (nArgs == 3) { + val randomPointsBuilder = if (nArgs == 3) { val seed = children(2).eval(input).asInstanceOf[Int] - return GeometrySerializer.serialize(Functions.generatePoints(geom, numPoints, seed)) + new RandomPointsBuilderSeed(geom.getFactory, seed) + } else { + new RandomPointsBuilderSeed(geom.getFactory, 0) } - GeometrySerializer.serialize(Functions.generatePoints(geom, numPoints)) + randomPointsBuilder.setExtent(geom) + randomPointsBuilder.setNumPoints(numPoints) + GeometrySerializer.serialize(randomPointsBuilder.getGeometry) } override def nullable: Boolean = true @@ -1483,12 +1484,13 @@ case class ST_GeneratePoints(inputExpressions: Seq[Expression], randomSeed: Opti } } - override lazy val resolved: Boolean = - childrenResolved && checkInputDataTypes().isSuccess && randomSeed.isDefined + override def withNewSeed(seed: Long): ST_GeneratePoints = copy(randomSeed = seed) override def children: Seq[Expression] = inputExpressions - override def withNewSeed(seed: Long): ST_GeneratePoints = copy(randomSeed = Some(seed)) + protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { + copy(inputExpressions = newChildren) + } } case class ST_NRings(inputExpressions: Seq[Expression]) From 9e22c5b4e30ff1b1fc13e6e5fc3c2e2e2eb0afd5 Mon Sep 17 00:00:00 2001 From: Furqaanahmed Khan Date: Mon, 15 Jul 2024 12:39:14 -0400 Subject: [PATCH 09/14] fix: spark registration 3 --- .../org/apache/spark/sql/sedona_sql/expressions/Functions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala index 9125817b50..5c27f86872 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala @@ -1447,7 +1447,7 @@ case class ST_GeneratePoints(inputExpressions: Seq[Expression], randomSeed: Long def this(inputExpressions: Seq[Expression]) = this(inputExpressions, Utils.random.nextLong()) - override def seedExpression: Expression = Literal(randomSeed) + def seedExpression: Expression = Literal(randomSeed) @transient private[this] var random: Random = _ From 39087dfadd0a4a54b21face3119c23b2c5f83e83 Mon Sep 17 00:00:00 2001 From: Furqaanahmed Khan Date: Mon, 15 Jul 2024 22:50:10 -0400 Subject: [PATCH 10/14] fix: remove ExpressionWithRandomSeed --- .../spark/sql/sedona_sql/expressions/Functions.scala | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala index 5c27f86872..f1ffae43d2 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala @@ -1442,13 +1442,10 @@ case class ST_GeneratePoints(inputExpressions: Seq[Expression], randomSeed: Long extends Expression with CodegenFallback with ExpectsInputTypes - with Nondeterministic - with ExpressionWithRandomSeed { + with Nondeterministic { def this(inputExpressions: Seq[Expression]) = this(inputExpressions, Utils.random.nextLong()) - def seedExpression: Expression = Literal(randomSeed) - @transient private[this] var random: Random = _ private val nArgs = children.length @@ -1484,8 +1481,6 @@ case class ST_GeneratePoints(inputExpressions: Seq[Expression], randomSeed: Long } } - override def withNewSeed(seed: Long): ST_GeneratePoints = copy(randomSeed = seed) - override def children: Seq[Expression] = inputExpressions protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { From 0925e744797bf6feb7c1f35f05212ad5d43fa13c Mon Sep 17 00:00:00 2001 From: Furqaanahmed Khan Date: Mon, 15 Jul 2024 23:12:22 -0400 Subject: [PATCH 11/14] docs: add new seed parameter to docs --- docs/api/flink/Function.md | 8 ++++++-- docs/api/snowflake/vector-data/Function.md | 8 ++++++-- docs/api/sql/Function.md | 8 ++++++-- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/docs/api/flink/Function.md b/docs/api/flink/Function.md index 10ff7ddfee..1b2163d193 100644 --- a/docs/api/flink/Function.md +++ b/docs/api/flink/Function.md @@ -1554,9 +1554,13 @@ Output: ## ST_GeneratePoints -Introduction: Generates a specified quantity of pseudo-random points within the boundaries of the provided polygonal geometry. +Introduction: Generates a specified quantity of pseudo-random points within the boundaries of the provided polygonal geometry. When `seed` is either zero or not defined then output will be random. -Format: `ST_GeneratePoints(geom: Geometry, numPoints: Integer)` +Format: + +`ST_GeneratePoints(geom: Geometry, numPoints: Integer, seed: Long = 0)` + +`ST_GeneratePoints(geom: Geometry, numPoints: Integer)` Since: `v1.6.1` diff --git a/docs/api/snowflake/vector-data/Function.md b/docs/api/snowflake/vector-data/Function.md index 8e0b590a4a..196afa55ff 100644 --- a/docs/api/snowflake/vector-data/Function.md +++ b/docs/api/snowflake/vector-data/Function.md @@ -1165,9 +1165,13 @@ Output: ## ST_GeneratePoints -Introduction: Generates a specified quantity of pseudo-random points within the boundaries of the provided polygonal geometry. +Introduction: Generates a specified quantity of pseudo-random points within the boundaries of the provided polygonal geometry. When `seed` is either zero or not defined then output will be random. -Format: `ST_GeneratePoints(geom: Geometry, numPoints: Integer)` +Format: + +`ST_GeneratePoints(geom: Geometry, numPoints: Integer, seed: Long = 0)` + +`ST_GeneratePoints(geom: Geometry, numPoints: Integer)` SQL Example: diff --git a/docs/api/sql/Function.md b/docs/api/sql/Function.md index 224b2ba2c0..9d70cee0c5 100644 --- a/docs/api/sql/Function.md +++ b/docs/api/sql/Function.md @@ -1559,9 +1559,13 @@ Output: ## ST_GeneratePoints -Introduction: Generates a specified quantity of pseudo-random points within the boundaries of the provided polygonal geometry. +Introduction: Generates a specified quantity of pseudo-random points within the boundaries of the provided polygonal geometry. When `seed` is either zero or not defined then output will be random. -Format: `ST_GeneratePoints(geom: Geometry, numPoints: Integer)` +Format: + +`ST_GeneratePoints(geom: Geometry, numPoints: Integer, seed: Long = 0)` + +`ST_GeneratePoints(geom: Geometry, numPoints: Integer)` Since: `v1.6.1` From 7f63a99be569eec79adaa9324ef3e479cc4cebee Mon Sep 17 00:00:00 2001 From: Furqaanahmed Khan Date: Tue, 16 Jul 2024 10:39:57 -0400 Subject: [PATCH 12/14] chore: add the new parameter to all engines and add tests --- .../org/apache/sedona/common/Functions.java | 2 +- .../sedona/flink/expressions/Functions.java | 9 +++++++++ .../org/apache/sedona/flink/FunctionTest.java | 20 ++++++++++++++++++- python/tests/sql/test_function.py | 6 ++++++ .../snowflake/snowsql/TestFunctions.java | 3 +++ .../snowflake/snowsql/TestFunctionsV2.java | 3 +++ .../apache/sedona/snowflake/snowsql/UDFs.java | 6 ++++++ .../sedona/snowflake/snowsql/UDFsV2.java | 9 +++++++++ 8 files changed, 56 insertions(+), 2 deletions(-) diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java b/common/src/main/java/org/apache/sedona/common/Functions.java index 3b213a94cc..87a6ebd0db 100644 --- a/common/src/main/java/org/apache/sedona/common/Functions.java +++ b/common/src/main/java/org/apache/sedona/common/Functions.java @@ -1812,7 +1812,7 @@ private static Geometry[] convertGeometryToArray(Geometry geom) { return array; } - public static Geometry generatePoints(Geometry geom, int numPoints, int seed) { + public static Geometry generatePoints(Geometry geom, int numPoints, long seed) { RandomPointsBuilderSeed pointsBuilder = new RandomPointsBuilderSeed(geom.getFactory(), seed); pointsBuilder.setExtent(geom); pointsBuilder.setNumPoints(numPoints); diff --git a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java index 782f63510a..9b378e0e7d 100644 --- a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java +++ b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java @@ -1557,6 +1557,15 @@ public Geometry eval( Geometry geom = (Geometry) o; return org.apache.sedona.common.Functions.generatePoints(geom, numPoints); } + + @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) + public Geometry eval( + @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) Object o, + @DataTypeHint(value = "Integer") Integer numPoints, + @DataTypeHint(value = "BIGINT") Long seed) { + Geometry geom = (Geometry) o; + return org.apache.sedona.common.Functions.generatePoints(geom, numPoints, seed); + } } public static class ST_NRings extends ScalarFunction { diff --git a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java index 136360ee35..ea2cb07549 100644 --- a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java +++ b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java @@ -2146,7 +2146,7 @@ public void testIsPolygonCW() { public void testGeneratePoints() { Table polyTable = tableEnv.sqlQuery( - "SELECT ST_Buffer(ST_GeomFromWKT('LINESTRING(50 50,150 150,150 50)'), 10, false, 'endcap=round join=round') AS geom"); + "SELECT ST_Buffer(ST_GeomFromWKT('LINESTRING(50 50,10 10,10 50)'), 10, false, 'endcap=round join=round') AS geom"); Geometry actual = (Geometry) first( @@ -2155,6 +2155,24 @@ public void testGeneratePoints() { .getField(0); assertEquals(actual.getNumGeometries(), 15); + actual = + (Geometry) + first( + polyTable + .select( + call( + Functions.ST_GeneratePoints.class.getSimpleName(), + $("geom"), + 5, + 100L)) + .as("geom") + .select( + call(Functions.ST_ReducePrecision.class.getSimpleName(), $("geom"), 5))) + .getField(0); + String expected = + "MULTIPOINT ((40.02957 46.70645), (37.11646 37.38582), (14.2051 29.23363), (40.82533 31.47273), (28.16839 34.16338))"; + assertEquals(expected, actual.toString()); + polyTable = tableEnv.sqlQuery( "SELECT ST_GeomFromWKT('MULTIPOLYGON (((10 0, 10 10, 20 10, 20 0, 10 0)), ((50 0, 50 10, 70 10, 70 0, 50 0)))') AS geom"); diff --git a/python/tests/sql/test_function.py b/python/tests/sql/test_function.py index 71eae37526..5f923b5cbf 100644 --- a/python/tests/sql/test_function.py +++ b/python/tests/sql/test_function.py @@ -1611,6 +1611,12 @@ def test_generate_points(self): .first()[0] assert actual == 15 + actual = self.spark.sql( + "SELECT ST_AsText(ST_ReducePrecision(ST_GeneratePoints(ST_GeomFromWKT('MULTIPOLYGON (((10 0, 10 10, 20 10, 20 0, 10 0)), ((50 0, 50 10, 70 10, 70 0, 50 0)))'), 5, 10), 5))") \ + .first()[0] + expected = "MULTIPOINT ((53.82582 2.57803), (13.55212 2.44117), (59.12854 3.70611), (61.37698 7.14985), (10.49657 4.40622))" + assert expected == actual + actual = self.spark.sql( "SELECT ST_NumGeometries(ST_GeneratePoints(ST_Buffer(ST_GeomFromWKT('LINESTRING(50 50,150 150,150 50)'), 10, false, 'endcap=round join=round'), 15, 100))") \ .first()[0] diff --git a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctions.java b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctions.java index 9efaae5cb3..7197c8d61f 100644 --- a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctions.java +++ b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctions.java @@ -401,6 +401,9 @@ public void test_ST_GeneratePoints() { verifySqlSingleRes( "select sedona.ST_NumGeometries(sedona.ST_GeneratePoints(sedona.ST_GeomFromWKT('POLYGON ((1 0, 1 1, 2 1, 2 0, 1 0))'), 15))", 15); + verifySqlSingleRes( + "select sedona.ST_NumGeometries(sedona.ST_GeneratePoints(sedona.ST_GeomFromWKT('POLYGON ((1 0, 1 1, 2 1, 2 0, 1 0))'), 15, 100))", + 15); } @Test diff --git a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java index 96ce0b3e91..83af3ec800 100644 --- a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java +++ b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java @@ -383,6 +383,9 @@ public void test_ST_GeneratePoints() { verifySqlSingleRes( "select sedona.ST_NumGeometries(sedona.ST_GeneratePoints(ST_GeomFromWKT('POLYGON ((1 0, 1 1, 2 1, 2 0, 1 0))'), 15))", 15); + verifySqlSingleRes( + "select sedona.ST_NumGeometries(sedona.ST_GeneratePoints(ST_GeomFromWKT('POLYGON ((1 0, 1 1, 2 1, 2 0, 1 0))'), 15, 100))", + 15); } @Test diff --git a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFs.java b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFs.java index 0dcb82ba9f..2a6429dd65 100644 --- a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFs.java +++ b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFs.java @@ -372,6 +372,12 @@ public static byte[] ST_GeneratePoints(byte[] geometry, int numPoints) { Functions.generatePoints(GeometrySerde.deserialize(geometry), numPoints)); } + @UDFAnnotations.ParamMeta(argNames = {"geometry", "numPoints", "seed"}) + public static byte[] ST_GeneratePoints(byte[] geometry, int numPoints, long seed) { + return GeometrySerde.serialize( + Functions.generatePoints(GeometrySerde.deserialize(geometry), numPoints, seed)); + } + @UDFAnnotations.ParamMeta(argNames = {"geometry", "precision"}) public static String ST_GeoHash(byte[] geometry, int precision) { return Functions.geohash(GeometrySerde.deserialize(geometry), precision); diff --git a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java index de48b186eb..1b78807855 100644 --- a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java +++ b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java @@ -518,6 +518,15 @@ public static String ST_GeneratePoints(String geometry, int numPoints) { Functions.generatePoints(GeometrySerde.deserGeoJson(geometry), numPoints)); } + @UDFAnnotations.ParamMeta( + argNames = {"geometry", "numPoints", "seed"}, + argTypes = {"Geometry", "int", "long"}, + returnTypes = "Geometry") + public static String ST_GeneratePoints(String geometry, int numPoints, long seed) { + return GeometrySerde.serGeoJson( + Functions.generatePoints(GeometrySerde.deserGeoJson(geometry), numPoints, seed)); + } + @UDFAnnotations.ParamMeta( argNames = {"geometry", "precision"}, argTypes = {"Geometry", "int"}) From 0b40b5cd27da7ed337c2cdcc5ab9f8865d59be5c Mon Sep 17 00:00:00 2001 From: Furqaanahmed Khan Date: Tue, 16 Jul 2024 13:44:19 -0400 Subject: [PATCH 13/14] fix: snowflake test --- .../java/org/apache/sedona/snowflake/snowsql/TestFunctions.java | 2 ++ .../org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java | 2 ++ 2 files changed, 4 insertions(+) diff --git a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctions.java b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctions.java index 7197c8d61f..f59f639050 100644 --- a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctions.java +++ b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctions.java @@ -401,6 +401,8 @@ public void test_ST_GeneratePoints() { verifySqlSingleRes( "select sedona.ST_NumGeometries(sedona.ST_GeneratePoints(sedona.ST_GeomFromWKT('POLYGON ((1 0, 1 1, 2 1, 2 0, 1 0))'), 15))", 15); + + registerUDF("ST_GeneratePoints", byte[].class, int.class, long.class); verifySqlSingleRes( "select sedona.ST_NumGeometries(sedona.ST_GeneratePoints(sedona.ST_GeomFromWKT('POLYGON ((1 0, 1 1, 2 1, 2 0, 1 0))'), 15, 100))", 15); diff --git a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java index 83af3ec800..896cd9cb74 100644 --- a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java +++ b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java @@ -383,6 +383,8 @@ public void test_ST_GeneratePoints() { verifySqlSingleRes( "select sedona.ST_NumGeometries(sedona.ST_GeneratePoints(ST_GeomFromWKT('POLYGON ((1 0, 1 1, 2 1, 2 0, 1 0))'), 15))", 15); + + registerUDFV2("ST_GeneratePoints", String.class, int.class, long.class); verifySqlSingleRes( "select sedona.ST_NumGeometries(sedona.ST_GeneratePoints(ST_GeomFromWKT('POLYGON ((1 0, 1 1, 2 1, 2 0, 1 0))'), 15, 100))", 15); From 9d47badc13caf4e412ed595af7558ff4be16deca Mon Sep 17 00:00:00 2001 From: Kristin Cowalcijk Date: Thu, 18 Jul 2024 03:19:10 +0800 Subject: [PATCH 14/14] Use per-partition random number generator when seed is not specified. --- .../org/apache/sedona/common/Functions.java | 7 +++++ .../common/utils/RandomPointsBuilderSeed.java | 11 +++++--- .../apache/sedona/common/FunctionsTest.java | 17 ++++++++++++ .../sedona_sql/expressions/Functions.scala | 26 +++++++++---------- .../apache/sedona/sql/PreserveSRIDSuite.scala | 3 ++- 5 files changed, 46 insertions(+), 18 deletions(-) diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java b/common/src/main/java/org/apache/sedona/common/Functions.java index 87a6ebd0db..f0ffd0bc57 100644 --- a/common/src/main/java/org/apache/sedona/common/Functions.java +++ b/common/src/main/java/org/apache/sedona/common/Functions.java @@ -1823,6 +1823,13 @@ public static Geometry generatePoints(Geometry geom, int numPoints) { return generatePoints(geom, numPoints, 0); } + public static Geometry generatePoints(Geometry geom, int numPoints, Random random) { + RandomPointsBuilderSeed pointsBuilder = new RandomPointsBuilderSeed(geom.getFactory(), random); + pointsBuilder.setExtent(geom); + pointsBuilder.setNumPoints(numPoints); + return pointsBuilder.getGeometry(); + } + public static Integer nRings(Geometry geometry) throws Exception { String geometryType = geometry.getGeometryType(); if (!(geometry instanceof Polygon || geometry instanceof MultiPolygon)) { diff --git a/common/src/main/java/org/apache/sedona/common/utils/RandomPointsBuilderSeed.java b/common/src/main/java/org/apache/sedona/common/utils/RandomPointsBuilderSeed.java index c42e450587..0d396f3923 100644 --- a/common/src/main/java/org/apache/sedona/common/utils/RandomPointsBuilderSeed.java +++ b/common/src/main/java/org/apache/sedona/common/utils/RandomPointsBuilderSeed.java @@ -23,17 +23,20 @@ import org.locationtech.jts.shape.random.RandomPointsBuilder; public class RandomPointsBuilderSeed extends RandomPointsBuilder { - long seed; Random rand; public RandomPointsBuilderSeed(GeometryFactory geometryFactory, long seed) { super(geometryFactory); - this.seed = seed; if (seed > 0) { this.rand = new Random(seed); - return; + } else { + this.rand = new Random(); } - this.rand = new Random(); + } + + public RandomPointsBuilderSeed(GeometryFactory geometryFactory, Random random) { + super(geometryFactory); + this.rand = random; } @Override diff --git a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java index 8601044567..a4a3a8c9a9 100644 --- a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java +++ b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java @@ -2135,6 +2135,23 @@ public void generatePoints() throws ParseException { "MULTIPOLYGON (((10 0, 10 10, 20 10, 20 0, 10 0)), ((50 0, 50 10, 70 10, 70 0, 50 0)))"); actual = Functions.generatePoints(geom, 30); assertEquals(actual.getNumGeometries(), 30); + + // Deterministic when using the same seed + Geometry first = Functions.generatePoints(geom, 10, 100); + Geometry second = Functions.generatePoints(geom, 10, 100); + assertEquals(first, second); + + // Deterministic when using the same random number generator + geom = geom.buffer(10, 48); + Random rand = new Random(100); + Random rand2 = new Random(100); + first = Functions.generatePoints(geom, 100, rand); + second = Functions.generatePoints(geom, 100, rand); + Geometry first2 = Functions.generatePoints(geom, 100, rand2); + Geometry second2 = Functions.generatePoints(geom, 100, rand2); + assertNotEquals(first, second); + assertEquals(first, first2); + assertEquals(second, second2); } @Test diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala index f1ffae43d2..70e3582f20 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.sedona_sql.expressions import org.apache.sedona.common.{Functions, FunctionsGeoTools} import org.apache.sedona.common.sphere.{Haversine, Spheroid} -import org.apache.sedona.common.utils.{InscribedCircle, RandomPointsBuilderSeed, ValidDetail} +import org.apache.sedona.common.utils.{InscribedCircle, ValidDetail} import org.apache.sedona.sql.utils.GeometrySerializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionWithRandomSeed, Generator, Literal, Nondeterministic} +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Generator, Nondeterministic} import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT import org.apache.spark.sql.sedona_sql.expressions.implicits._ @@ -35,8 +35,6 @@ import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -import scala.util.Random - /** * Return the distance between two geometries. * @@ -1446,25 +1444,27 @@ case class ST_GeneratePoints(inputExpressions: Seq[Expression], randomSeed: Long def this(inputExpressions: Seq[Expression]) = this(inputExpressions, Utils.random.nextLong()) - @transient private[this] var random: Random = _ + @transient private[this] var random: java.util.Random = _ private val nArgs = children.length - override protected def initializeInternal(partitionIndex: Int): Unit = random = new Random( - randomSeed + partitionIndex) + override protected def initializeInternal(partitionIndex: Int): Unit = random = + new java.util.Random(randomSeed + partitionIndex) override protected def evalInternal(input: InternalRow): Any = { val geom = children.head.toGeometry(input) val numPoints = children(1).eval(input).asInstanceOf[Int] - val randomPointsBuilder = if (nArgs == 3) { + val generatedPoints = if (nArgs == 3) { val seed = children(2).eval(input).asInstanceOf[Int] - new RandomPointsBuilderSeed(geom.getFactory, seed) + if (seed > 0) { + Functions.generatePoints(geom, numPoints, seed) + } else { + Functions.generatePoints(geom, numPoints, random) + } } else { - new RandomPointsBuilderSeed(geom.getFactory, 0) + Functions.generatePoints(geom, numPoints, random) } - randomPointsBuilder.setExtent(geom) - randomPointsBuilder.setNumPoints(numPoints) - GeometrySerializer.serialize(randomPointsBuilder.getGeometry) + GeometrySerializer.serialize(generatedPoints) } override def nullable: Boolean = true diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/PreserveSRIDSuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/PreserveSRIDSuite.scala index 1d4cc9e8ba..414a9aeedc 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/PreserveSRIDSuite.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/PreserveSRIDSuite.scala @@ -108,7 +108,8 @@ class PreserveSRIDSuite extends TestBaseScala with TableDrivenPropertyChecks { ("ST_BoundingDiagonal(geom1)", 1000), ("ST_DelaunayTriangles(geom4)", 1000), ("ST_Rotate(geom1, 10)", 1000), - ("ST_Collect(geom1, geom2, geom3)", 1000)) + ("ST_Collect(geom1, geom2, geom3)", 1000), + ("ST_GeneratePoints(geom1, 3)", 1000)) forAll(testCases) { case (expression: String, srid: Int) => it(s"$expression") {