diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala index fff0822859..ccc7ffe96c 100644 --- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala +++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes} import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT import org.apache.spark.sql.sedona_sql.expressions.implicits.GeometryEnhancer +import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -36,7 +37,7 @@ import org.apache.spark.unsafe.types.UTF8String * string, the second parameter is the delimiter. 
String format should be similar to CSV/TSV */ case class ST_PointFromText(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Constructors.pointFromText) with FoldableExpression { + extends InferredExpression(Constructors.pointFromText _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } @@ -48,7 +49,7 @@ case class ST_PointFromText(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_PolygonFromText(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Constructors.polygonFromText) with FoldableExpression { + extends InferredExpression(Constructors.polygonFromText _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } @@ -60,7 +61,7 @@ case class ST_PolygonFromText(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_LineFromText(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Constructors.lineFromText) with FoldableExpression { + extends InferredExpression(Constructors.lineFromText _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } @@ -71,7 +72,7 @@ case class ST_LineFromText(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_LineStringFromText(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Constructors.lineStringFromText) with FoldableExpression { + extends InferredExpression(Constructors.lineStringFromText _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } @@ -83,7 +84,7 @@ case class ST_LineStringFromText(inputExpressions: Seq[Expression]) * @param inputExpressions This function takes a geometry string and a srid. The string format must be WKT. 
*/ case class ST_GeomFromWKT(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Constructors.geomFromWKT) with FoldableExpression { + extends InferredExpression(Constructors.geomFromWKT _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -97,7 +98,7 @@ case class ST_GeomFromWKT(inputExpressions: Seq[Expression]) * @param inputExpressions This function takes a geometry string and a srid. The string format must be WKT. */ case class ST_GeomFromText(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Constructors.geomFromWKT) with FoldableExpression { + extends InferredExpression(Constructors.geomFromWKT _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -180,7 +181,7 @@ case class ST_GeomFromGeoJSON(inputExpressions: Seq[Expression]) * @param inputExpressions This function takes 2 parameter which are point x, y. */ case class ST_Point(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Constructors.point) with FoldableExpression { + extends InferredExpression(Constructors.point _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -193,7 +194,7 @@ case class ST_Point(inputExpressions: Seq[Expression]) * @param inputExpressions This function takes 4 parameter which are point x, y, z and srid (default 0). 
*/ case class ST_PointZ(inputExpressions: Seq[Expression]) - extends InferredQuarternaryExpression(Constructors.pointZ) with FoldableExpression { + extends InferredExpression(Constructors.pointZ _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -207,7 +208,7 @@ case class ST_PointZ(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_PolygonFromEnvelope(inputExpressions: Seq[Expression]) - extends InferredQuarternaryExpression(Constructors.polygonFromEnvelope) with FoldableExpression { + extends InferredExpression(Constructors.polygonFromEnvelope _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -226,22 +227,21 @@ trait UserDataGeneratator { } case class ST_GeomFromGeoHash(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Constructors.geomFromGeoHash) with FoldableExpression { + extends InferredExpression(InferrableFunction.allowRightNull(Constructors.geomFromGeoHash)) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } - override def allowRightNull: Boolean = true } case class ST_GeomFromGML(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Constructors.geomFromGML) with FoldableExpression { + extends InferredExpression(Constructors.geomFromGML _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } } case class ST_GeomFromKML(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Constructors.geomFromKML) with FoldableExpression { + extends InferredExpression(Constructors.geomFromKML _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } @@ -253,7 +253,7 @@ case class ST_GeomFromKML(inputExpressions: Seq[Expression]) * @param inputExpressions 
This function takes a geometry string and a srid. The string format must be WKT. */ case class ST_MPolyFromText(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Constructors.mPolyFromText) with FoldableExpression { + extends InferredExpression(Constructors.mPolyFromText _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } @@ -265,7 +265,7 @@ case class ST_MPolyFromText(inputExpressions: Seq[Expression]) * @param inputExpressions This function takes a geometry string and a srid. The string format must be WKT. */ case class ST_MLineFromText(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Constructors.mLineFromText) with FoldableExpression { + extends InferredExpression(Constructors.mLineFromText _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/FoldableExpression.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/FoldableExpression.scala new file mode 100644 index 0000000000..08c0acb2ca --- /dev/null +++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/FoldableExpression.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.sedona_sql.expressions + +import org.apache.spark.sql.catalyst.expressions.Expression + +/** + * Makes an expression eligible for the constant-folding optimizer: the expression + * is foldable if and only if all of its children are foldable. + */ +trait FoldableExpression extends Expression { + override def foldable: Boolean = children.forall(_.foldable) +} diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala index a324da69bc..c084eec1e6 100644 --- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala +++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.sedona_sql.expressions.implicits._ import org.apache.spark.sql.types._ import org.locationtech.jts.algorithm.MinimumBoundingCircle import org.locationtech.jts.geom._ +import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._ /** * Return the distance between two geometries. @@ -36,7 +37,7 @@ import org.locationtech.jts.geom._ * @param inputExpressions This function takes two geometries and calculates the distance between two objects.
*/ case class ST_Distance(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.distance) with FoldableExpression { + extends InferredExpression(Functions.distance _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -45,7 +46,7 @@ case class ST_Distance(inputExpressions: Seq[Expression]) case class ST_YMax(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.yMax) with FoldableExpression { + extends InferredExpression(Functions.yMax _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -53,7 +54,7 @@ case class ST_YMax(inputExpressions: Seq[Expression]) } case class ST_YMin(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.yMin) with FoldableExpression { + extends InferredExpression(Functions.yMin _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -66,7 +67,7 @@ case class ST_YMin(inputExpressions: Seq[Expression]) * @param inputExpressions This function takes a geometry and returns the maximum of all Z-coordinate values. */ case class ST_ZMax(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.zMax) with FoldableExpression { + extends InferredExpression(Functions.zMax _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -79,7 +80,7 @@ case class ST_ZMax(inputExpressions: Seq[Expression]) * @param inputExpressions This function takes a geometry and returns the minimum of all Z-coordinate values. 
*/ case class ST_ZMin(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.zMin) with FoldableExpression { + extends InferredExpression(Functions.zMin _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -87,7 +88,7 @@ case class ST_ZMin(inputExpressions: Seq[Expression]) } case class ST_3DDistance(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.distance3d) with FoldableExpression { + extends InferredExpression(Functions.distance3d _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -100,7 +101,7 @@ case class ST_3DDistance(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_ConcaveHull(inputExpressions: Seq[Expression]) - extends InferredTernaryExpression(Functions.concaveHull) with FoldableExpression { + extends InferredExpression(Functions.concaveHull _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = { copy(inputExpressions = newChildren) @@ -113,7 +114,7 @@ case class ST_ConcaveHull(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_ConvexHull(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.convexHull) with FoldableExpression { + extends InferredExpression(Functions.convexHull _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -126,7 +127,7 @@ case class ST_ConvexHull(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_NPoints(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.nPoints) with FoldableExpression { + extends InferredExpression(Functions.nPoints _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -139,7 +140,7 @@ case class 
ST_NPoints(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_NDims(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.nDims) with FoldableExpression { + extends InferredExpression(Functions.nDims _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -152,7 +153,7 @@ case class ST_NDims(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_Buffer(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.buffer) with FoldableExpression { + extends InferredExpression(Functions.buffer _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -166,7 +167,7 @@ case class ST_Buffer(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_Envelope(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.envelope) with FoldableExpression { + extends InferredExpression(Functions.envelope _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -179,7 +180,7 @@ case class ST_Envelope(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_Length(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.length) with FoldableExpression { + extends InferredExpression(Functions.length _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -192,7 +193,7 @@ case class ST_Length(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_Area(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.area) with FoldableExpression { + extends InferredExpression(Functions.area _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) 
@@ -205,7 +206,7 @@ case class ST_Area(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_Centroid(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.getCentroid) with FoldableExpression { + extends InferredExpression(Functions.getCentroid _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -218,7 +219,7 @@ case class ST_Centroid(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_Transform(inputExpressions: Seq[Expression]) - extends InferredQuarternaryExpression(Functions.transform) with FoldableExpression { + extends InferredExpression(inferrableFunction4(Functions.transform)) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -231,7 +232,7 @@ case class ST_Transform(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_Intersection(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.intersection) with FoldableExpression { + extends InferredExpression(Functions.intersection _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -245,7 +246,7 @@ case class ST_Intersection(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_MakeValid(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.makeValid) with FoldableExpression { + extends InferredExpression(Functions.makeValid _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -258,7 +259,7 @@ case class ST_MakeValid(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_IsValid(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.isValid) with FoldableExpression { + extends InferredExpression(Functions.isValid 
_) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -271,7 +272,7 @@ case class ST_IsValid(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_IsSimple(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.isSimple) with FoldableExpression { + extends InferredExpression(Functions.isSimple _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -287,7 +288,7 @@ case class ST_IsSimple(inputExpressions: Seq[Expression]) * second arg is distance tolerance for the simplification(all vertices in the simplified geometry will be within this distance of the original geometry) */ case class ST_SimplifyPreserveTopology(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.simplifyPreserveTopology) with FoldableExpression { + extends InferredExpression(Functions.simplifyPreserveTopology _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -301,7 +302,7 @@ case class ST_SimplifyPreserveTopology(inputExpressions: Seq[Expression]) * be rounded to the nearest number. 
*/ case class ST_PrecisionReduce(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.reducePrecision) with FoldableExpression { + extends InferredExpression(Functions.reducePrecision _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -309,7 +310,7 @@ case class ST_PrecisionReduce(inputExpressions: Seq[Expression]) } case class ST_AsText(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.asWKT) with FoldableExpression { + extends InferredExpression(Functions.asWKT _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -317,7 +318,7 @@ case class ST_AsText(inputExpressions: Seq[Expression]) } case class ST_AsGeoJSON(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.asGeoJson) with FoldableExpression { + extends InferredExpression(Functions.asGeoJson _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -325,7 +326,7 @@ case class ST_AsGeoJSON(inputExpressions: Seq[Expression]) } case class ST_AsBinary(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.asWKB) with FoldableExpression { + extends InferredExpression(Functions.asWKB _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -333,7 +334,7 @@ case class ST_AsBinary(inputExpressions: Seq[Expression]) } case class ST_AsEWKB(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.asEWKB) with FoldableExpression { + extends InferredExpression(Functions.asEWKB _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -341,7 +342,7 @@ case class ST_AsEWKB(inputExpressions: Seq[Expression]) } case class ST_SRID(inputExpressions: 
Seq[Expression]) - extends InferredUnaryExpression(Functions.getSRID) with FoldableExpression { + extends InferredExpression(Functions.getSRID _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -349,7 +350,7 @@ case class ST_SRID(inputExpressions: Seq[Expression]) } case class ST_SetSRID(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.setSRID) with FoldableExpression { + extends InferredExpression(Functions.setSRID _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -357,7 +358,7 @@ case class ST_SetSRID(inputExpressions: Seq[Expression]) } case class ST_GeometryType(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.geometryType) with FoldableExpression { + extends InferredExpression(Functions.geometryType _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -372,7 +373,7 @@ case class ST_GeometryType(inputExpressions: Seq[Expression]) * @param inputExpressions Geometry */ case class ST_LineMerge(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.lineMerge) with FoldableExpression { + extends InferredExpression(Functions.lineMerge _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -380,7 +381,7 @@ case class ST_LineMerge(inputExpressions: Seq[Expression]) } case class ST_Azimuth(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.azimuth) with FoldableExpression { + extends InferredExpression(Functions.azimuth _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -388,7 +389,7 @@ case class ST_Azimuth(inputExpressions: Seq[Expression]) } case class ST_X(inputExpressions: Seq[Expression]) - extends 
InferredUnaryExpression(Functions.x) with FoldableExpression { + extends InferredExpression(Functions.x _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -397,7 +398,7 @@ case class ST_X(inputExpressions: Seq[Expression]) case class ST_Y(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.y) with FoldableExpression { + extends InferredExpression(Functions.y _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -405,7 +406,7 @@ case class ST_Y(inputExpressions: Seq[Expression]) } case class ST_Z(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.z) with FoldableExpression { + extends InferredExpression(Functions.z _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -413,7 +414,7 @@ case class ST_Z(inputExpressions: Seq[Expression]) } case class ST_StartPoint(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.startPoint) with FoldableExpression { + extends InferredExpression(Functions.startPoint _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -421,7 +422,7 @@ case class ST_StartPoint(inputExpressions: Seq[Expression]) } case class ST_Boundary(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.boundary) with FoldableExpression { + extends InferredExpression(Functions.boundary _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -471,7 +472,7 @@ case class ST_MinimumBoundingRadius(inputExpressions: Seq[Expression]) case class ST_MinimumBoundingCircle(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.minimumBoundingCircle) with FoldableExpression { + extends 
InferredExpression(Functions.minimumBoundingCircle _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -486,7 +487,7 @@ case class ST_MinimumBoundingCircle(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_LineSubstring(inputExpressions: Seq[Expression]) - extends InferredTernaryExpression(Functions.lineSubString) with FoldableExpression { + extends InferredExpression(Functions.lineSubString _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -501,7 +502,7 @@ case class ST_LineSubstring(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_LineInterpolatePoint(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.lineInterpolatePoint) with FoldableExpression { + extends InferredExpression(Functions.lineInterpolatePoint _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -509,7 +510,7 @@ case class ST_LineInterpolatePoint(inputExpressions: Seq[Expression]) } case class ST_EndPoint(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.endPoint) with FoldableExpression { + extends InferredExpression(Functions.endPoint _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -517,7 +518,7 @@ case class ST_EndPoint(inputExpressions: Seq[Expression]) } case class ST_ExteriorRing(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.exteriorRing) with FoldableExpression { + extends InferredExpression(Functions.exteriorRing _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -526,7 +527,7 @@ case class ST_ExteriorRing(inputExpressions: Seq[Expression]) case class ST_GeometryN(inputExpressions: 
Seq[Expression]) - extends InferredBinaryExpression(Functions.geometryN) with FoldableExpression { + extends InferredExpression(Functions.geometryN _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -534,7 +535,7 @@ case class ST_GeometryN(inputExpressions: Seq[Expression]) } case class ST_InteriorRingN(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.interiorRingN) with FoldableExpression { + extends InferredExpression(Functions.interiorRingN _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -542,7 +543,7 @@ case class ST_InteriorRingN(inputExpressions: Seq[Expression]) } case class ST_Dump(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.dump) with FoldableExpression { + extends InferredExpression(Functions.dump _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -550,7 +551,7 @@ case class ST_Dump(inputExpressions: Seq[Expression]) } case class ST_DumpPoints(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.dumpPoints) with FoldableExpression { + extends InferredExpression(Functions.dumpPoints _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -559,7 +560,7 @@ case class ST_DumpPoints(inputExpressions: Seq[Expression]) case class ST_IsClosed(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.isClosed) with FoldableExpression { + extends InferredExpression(Functions.isClosed _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -567,7 +568,7 @@ case class ST_IsClosed(inputExpressions: Seq[Expression]) } case class ST_NumInteriorRings(inputExpressions: Seq[Expression]) - extends 
InferredUnaryExpression(Functions.numInteriorRings) with FoldableExpression { + extends InferredExpression(Functions.numInteriorRings _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -575,7 +576,7 @@ case class ST_NumInteriorRings(inputExpressions: Seq[Expression]) } case class ST_AddPoint(inputExpressions: Seq[Expression]) - extends InferredTernaryExpression(Functions.addPoint) with FoldableExpression { + extends InferredExpression(inferrableFunction3(Functions.addPoint)) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -583,7 +584,7 @@ case class ST_AddPoint(inputExpressions: Seq[Expression]) } case class ST_RemovePoint(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.removePoint) with FoldableExpression { + extends InferredExpression(inferrableFunction2(Functions.removePoint)) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -591,7 +592,7 @@ case class ST_RemovePoint(inputExpressions: Seq[Expression]) } case class ST_SetPoint(inputExpressions: Seq[Expression]) - extends InferredTernaryExpression(Functions.setPoint) with FoldableExpression { + extends InferredExpression(Functions.setPoint _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -599,24 +600,22 @@ case class ST_SetPoint(inputExpressions: Seq[Expression]) } case class ST_IsRing(inputExpressions: Seq[Expression]) - extends UnaryGeometryExpression with FoldableExpression with CodegenFallback { - - override protected def nullSafeEval(geometry: Geometry): Any = { - geometry match { - case string: LineString => Functions.isRing(string) - case _ => null - } - } - - override def dataType: DataType = BooleanType - - override def children: Seq[Expression] = inputExpressions + extends 
InferredExpression(ST_IsRing.isRing _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } } +object ST_IsRing { + def isRing(geom: Geometry): Option[Boolean] = { + geom match { + case _: LineString => Some(Functions.isRing(geom)) + case _ => None + } + } +} + /** * Returns the number of Geometries. If geometry is a GEOMETRYCOLLECTION (or MULTI*) return the number of geometries, * for single geometries will return 1 @@ -626,7 +625,7 @@ case class ST_IsRing(inputExpressions: Seq[Expression]) * @param inputExpressions Geometry */ case class ST_NumGeometries(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.numGeometries) with FoldableExpression { + extends InferredExpression(Functions.numGeometries _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -639,7 +638,7 @@ case class ST_NumGeometries(inputExpressions: Seq[Expression]) * @param inputExpressions Geometry */ case class ST_FlipCoordinates(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.flipCoordinates) with FoldableExpression { + extends InferredExpression(Functions.flipCoordinates _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -647,7 +646,7 @@ case class ST_FlipCoordinates(inputExpressions: Seq[Expression]) } case class ST_SubDivide(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.subDivide) with FoldableExpression { + extends InferredExpression(Functions.subDivide _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -680,17 +679,15 @@ case class ST_SubDivideExplode(children: Seq[Expression]) } case class ST_MakePolygon(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.makePolygon) with FoldableExpression { + 
extends InferredExpression(InferrableFunction.allowRightNull(Functions.makePolygon)) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } - - override def allowRightNull: Boolean = true } case class ST_GeoHash(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.geohash) with FoldableExpression { + extends InferredExpression(Functions.geohash _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -703,7 +700,7 @@ case class ST_GeoHash(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_Difference(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.difference) with FoldableExpression { + extends InferredExpression(Functions.difference _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -716,7 +713,7 @@ case class ST_Difference(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_SymDifference(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.symDifference) with FoldableExpression { + extends InferredExpression(Functions.symDifference _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -729,7 +726,7 @@ case class ST_SymDifference(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_Union(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.union) with FoldableExpression { + extends InferredExpression(Functions.union _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -737,7 +734,7 @@ case class ST_Union(inputExpressions: Seq[Expression]) } case class ST_Multi(inputExpressions: Seq[Expression]) - extends 
InferredUnaryExpression(Functions.createMultiGeometryFromOneElement) with FoldableExpression { + extends InferredExpression(Functions.createMultiGeometryFromOneElement _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -750,7 +747,7 @@ case class ST_Multi(inputExpressions: Seq[Expression]) * @param inputExpressions Geometry */ case class ST_PointOnSurface(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.pointOnSurface) with FoldableExpression { + extends InferredExpression(Functions.pointOnSurface _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -763,7 +760,7 @@ case class ST_PointOnSurface(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_Reverse(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.reverse) with FoldableExpression { + extends InferredExpression(Functions.reverse _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -776,7 +773,7 @@ case class ST_Reverse(inputExpressions: Seq[Expression]) * @param inputExpressions sequence of 2 input arguments, a geometry and a value 'n' */ case class ST_PointN(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.pointN) with FoldableExpression { + extends InferredExpression(Functions.pointN _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -789,7 +786,7 @@ case class ST_PointN(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_Force_2D(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.force2D) with FoldableExpression { + extends InferredExpression(Functions.force2D _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { 
copy(inputExpressions = newChildren) @@ -802,7 +799,7 @@ case class ST_Force_2D(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_AsEWKT(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.asEWKT) with FoldableExpression { + extends InferredExpression(Functions.asEWKT _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -810,7 +807,7 @@ case class ST_AsEWKT(inputExpressions: Seq[Expression]) } case class ST_AsGML(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.asGML) with FoldableExpression { + extends InferredExpression(Functions.asGML _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -818,7 +815,7 @@ case class ST_AsGML(inputExpressions: Seq[Expression]) } case class ST_AsKML(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.asKML) with FoldableExpression { + extends InferredExpression(Functions.asKML _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -831,7 +828,7 @@ case class ST_AsKML(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_IsEmpty(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.isEmpty) with FoldableExpression { + extends InferredExpression(Functions.isEmpty _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -844,7 +841,7 @@ case class ST_IsEmpty(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_XMax(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.xMax) with FoldableExpression { + extends InferredExpression(Functions.xMax _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = 
newChildren) @@ -857,7 +854,7 @@ case class ST_XMax(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_XMin(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.xMin) with FoldableExpression { + extends InferredExpression(Functions.xMin _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -871,7 +868,7 @@ case class ST_XMin(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_BuildArea(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.buildArea) with FoldableExpression { + extends InferredExpression(Functions.buildArea _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = { copy(inputExpressions = newChildren) @@ -884,7 +881,7 @@ case class ST_BuildArea(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_Normalize(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.normalize) with FoldableExpression { + extends InferredExpression(Functions.normalize _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = { copy(inputExpressions = newChildren) @@ -897,7 +894,7 @@ case class ST_Normalize(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_LineFromMultiPoint(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.lineFromMultiPoint) with FoldableExpression { + extends InferredExpression(Functions.lineFromMultiPoint _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -910,7 +907,7 @@ case class ST_LineFromMultiPoint(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_Split(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.split) with FoldableExpression { + extends 
InferredExpression(Functions.split _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -918,7 +915,7 @@ case class ST_Split(inputExpressions: Seq[Expression]) } case class ST_S2CellIDs(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.s2CellIDs) with FoldableExpression { + extends InferredExpression(Functions.s2CellIDs _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -926,13 +923,11 @@ case class ST_S2CellIDs(inputExpressions: Seq[Expression]) } case class ST_CollectionExtract(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.collectionExtract) with FoldableExpression { + extends InferredExpression(InferrableFunction.allowRightNull(Functions.collectionExtract)) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } - - override def allowRightNull: Boolean = true } /** @@ -942,7 +937,7 @@ case class ST_CollectionExtract(inputExpressions: Seq[Expression]) * @param inputExpressions Geometry */ case class ST_GeometricMedian(inputExpressions: Seq[Expression]) - extends InferredQuarternaryExpression(Functions.geometricMedian) with FoldableExpression { + extends InferredExpression(inferrableFunction4(Functions.geometricMedian)) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -950,7 +945,7 @@ case class ST_GeometricMedian(inputExpressions: Seq[Expression]) } case class ST_DistanceSphere(inputExpressions: Seq[Expression]) - extends InferredTernaryExpression(Haversine.distance) with FoldableExpression { + extends InferredExpression(inferrableFunction3(Haversine.distance)) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -958,7 +953,7 @@ case 
class ST_DistanceSphere(inputExpressions: Seq[Expression]) } case class ST_DistanceSpheroid(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Spheroid.distance) with FoldableExpression { + extends InferredExpression(Spheroid.distance _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -966,7 +961,7 @@ case class ST_DistanceSpheroid(inputExpressions: Seq[Expression]) } case class ST_AreaSpheroid(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Spheroid.area) with FoldableExpression { + extends InferredExpression(Spheroid.area _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -974,7 +969,7 @@ case class ST_AreaSpheroid(inputExpressions: Seq[Expression]) } case class ST_LengthSpheroid(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Spheroid.length) with FoldableExpression { + extends InferredExpression(Spheroid.length _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -982,14 +977,14 @@ case class ST_LengthSpheroid(inputExpressions: Seq[Expression]) } case class ST_NumPoints(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.numPoints) with FoldableExpression { + extends InferredExpression(Functions.numPoints _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } } case class ST_Force3D(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(Functions.force3D) with FoldableExpression { + extends InferredExpression(inferrableFunction2(Functions.force3D)) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -997,14 +992,14 @@ case class ST_Force3D(inputExpressions: Seq[Expression]) } case class 
ST_NRings(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.nRings) with FoldableExpression { + extends InferredExpression(Functions.nRings _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } } case class ST_Translate(inputExpressions: Seq[Expression]) - extends InferredQuarternaryExpression(Functions.translate) with FoldableExpression { + extends InferredExpression(inferrableFunction4(Functions.translate)) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } @@ -1018,14 +1013,14 @@ case class ST_Dimension(inputExpressions: Seq[Expression]) } case class ST_BoundingDiagonal(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Functions.boundingDiagonal) with FoldableExpression { + extends InferredExpression(Functions.boundingDiagonal _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } } case class ST_HausdorffDistance(inputExpressions: Seq[Expression]) - extends InferredTernaryExpression(Functions.hausdorffDistance) with FoldableExpression { + extends InferredExpression(inferrableFunction3(Functions.hausdorffDistance)) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferrableFunctionConverter.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferrableFunctionConverter.scala new file mode 100644 index 0000000000..83156db94a --- /dev/null +++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferrableFunctionConverter.scala @@ -0,0 +1,494 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.sedona_sql.expressions + +import scala.reflect.runtime.universe.TypeTag + +/** + * Implicit conversions from Java/Scala functions to [[InferrableFunction]]. This should be used in conjunction with + * [[InferredExpression]] to make wrapping Java/Scala functions as catalyst expressions much easier. 
+ */ +object InferrableFunctionConverter { + // scalastyle:off line.size.limit + implicit def inferrableFunction1[R: InferrableType, A1: InferrableType](f: (A1) => R)(implicit typeTag: TypeTag[(A1) => R]) + : InferrableFunction = InferrableFunction(typeTag, argExtractors => { + val func = f.asInstanceOf[(Any) => Any] + val extractor1 = argExtractors(0) + input => { + val arg1 = extractor1(input) + if (arg1 != null) { + func(arg1) + } else { + null + } + } + }) + + implicit def inferrableFunction2[R: InferrableType, A1: InferrableType, A2: InferrableType](f: (A1, A2) => R)(implicit typeTag: TypeTag[(A1, A2) => R]) + : InferrableFunction = InferrableFunction(typeTag, argExtractors => { + val func = f.asInstanceOf[(Any, Any) => Any] + val extractor1 = argExtractors(0) + val extractor2 = argExtractors(1) + input => { + val arg1 = extractor1(input) + val arg2 = extractor2(input) + if (arg1 != null && arg2 != null) { + func(arg1, arg2) + } else { + null + } + } + }) + + implicit def inferrableFunction3[R: InferrableType, A1: InferrableType, A2: InferrableType, A3: InferrableType](f: (A1, A2, A3) => R)(implicit typeTag: TypeTag[(A1, A2, A3) => R]) + : InferrableFunction = InferrableFunction(typeTag, argExtractors => { + val func = f.asInstanceOf[(Any, Any, Any) => Any] + val extractor1 = argExtractors(0) + val extractor2 = argExtractors(1) + val extractor3 = argExtractors(2) + input => { + val arg1 = extractor1(input) + val arg2 = extractor2(input) + val arg3 = extractor3(input) + if (arg1 != null && arg2 != null && arg3 != null) { + func(arg1, arg2, arg3) + } else { + null + } + } + }) + + implicit def inferrableFunction4[R: InferrableType, A1: InferrableType, A2: InferrableType, A3: InferrableType, A4: InferrableType](f: (A1, A2, A3, A4) => R)(implicit typeTag: TypeTag[(A1, A2, A3, A4) => R]) + : InferrableFunction = InferrableFunction(typeTag, argExtractors => { + val func = f.asInstanceOf[(Any, Any, Any, Any) => Any] + val extractor1 = argExtractors(0) + val 
extractor2 = argExtractors(1) + val extractor3 = argExtractors(2) + val extractor4 = argExtractors(3) + input => { + val arg1 = extractor1(input) + val arg2 = extractor2(input) + val arg3 = extractor3(input) + val arg4 = extractor4(input) + if (arg1 != null && arg2 != null && arg3 != null && arg4 != null) { + func(arg1, arg2, arg3, arg4) + } else { + null + } + } + }) + + implicit def inferrableFunction5[R: InferrableType, A1: InferrableType, A2: InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType](f: (A1, A2, A3, A4, A5) => R)(implicit typeTag: TypeTag[(A1, A2, A3, A4, A5) => R]) + : InferrableFunction = InferrableFunction(typeTag, argExtractors => { + val func = f.asInstanceOf[(Any, Any, Any, Any, Any) => Any] + val extractor1 = argExtractors(0) + val extractor2 = argExtractors(1) + val extractor3 = argExtractors(2) + val extractor4 = argExtractors(3) + val extractor5 = argExtractors(4) + input => { + val arg1 = extractor1(input) + val arg2 = extractor2(input) + val arg3 = extractor3(input) + val arg4 = extractor4(input) + val arg5 = extractor5(input) + if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 != null) { + func(arg1, arg2, arg3, arg4, arg5) + } else { + null + } + } + }) + + implicit def inferrableFunction6[R: InferrableType, A1: InferrableType, A2: InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: InferrableType](f: (A1, A2, A3, A4, A5, A6) => R)(implicit typeTag: TypeTag[(A1, A2, A3, A4, A5, A6) => R]) + : InferrableFunction = InferrableFunction(typeTag, argExtractors => { + val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any] + val extractor1 = argExtractors(0) + val extractor2 = argExtractors(1) + val extractor3 = argExtractors(2) + val extractor4 = argExtractors(3) + val extractor5 = argExtractors(4) + val extractor6 = argExtractors(5) + input => { + val arg1 = extractor1(input) + val arg2 = extractor2(input) + val arg3 = extractor3(input) + val arg4 = 
extractor4(input) + val arg5 = extractor5(input) + val arg6 = extractor6(input) + if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 != null && arg6 != null) { + func(arg1, arg2, arg3, arg4, arg5, arg6) + } else { + null + } + } + }) + + implicit def inferrableFunction7[R: InferrableType, A1: InferrableType, A2: InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: InferrableType, A7: InferrableType](f: (A1, A2, A3, A4, A5, A6, A7) => R)(implicit typeTag: TypeTag[(A1, A2, A3, A4, A5, A6, A7) => R]) + : InferrableFunction = InferrableFunction(typeTag, argExtractors => { + val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any] + val extractor1 = argExtractors(0) + val extractor2 = argExtractors(1) + val extractor3 = argExtractors(2) + val extractor4 = argExtractors(3) + val extractor5 = argExtractors(4) + val extractor6 = argExtractors(5) + val extractor7 = argExtractors(6) + input => { + val arg1 = extractor1(input) + val arg2 = extractor2(input) + val arg3 = extractor3(input) + val arg4 = extractor4(input) + val arg5 = extractor5(input) + val arg6 = extractor6(input) + val arg7 = extractor7(input) + if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 != null && arg6 != null && arg7 != null) { + func(arg1, arg2, arg3, arg4, arg5, arg6, arg7) + } else { + null + } + } + }) + + implicit def inferrableFunction8[R: InferrableType, A1: InferrableType, A2: InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: InferrableType, A7: InferrableType, A8: InferrableType](f: (A1, A2, A3, A4, A5, A6, A7, A8) => R)(implicit typeTag: TypeTag[(A1, A2, A3, A4, A5, A6, A7, A8) => R]) + : InferrableFunction = InferrableFunction(typeTag, argExtractors => { + val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val extractor1 = argExtractors(0) + val extractor2 = argExtractors(1) + val extractor3 = argExtractors(2) + val extractor4 = argExtractors(3) + val 
extractor5 = argExtractors(4) + val extractor6 = argExtractors(5) + val extractor7 = argExtractors(6) + val extractor8 = argExtractors(7) + input => { + val arg1 = extractor1(input) + val arg2 = extractor2(input) + val arg3 = extractor3(input) + val arg4 = extractor4(input) + val arg5 = extractor5(input) + val arg6 = extractor6(input) + val arg7 = extractor7(input) + val arg8 = extractor8(input) + if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 != null && arg6 != null && arg7 != null && arg8 != null) { + func(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) + } else { + null + } + } + }) + + implicit def inferrableFunction9[R: InferrableType, A1: InferrableType, A2: InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: InferrableType, A7: InferrableType, A8: InferrableType, A9: InferrableType](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9) => R)(implicit typeTag: TypeTag[(A1, A2, A3, A4, A5, A6, A7, A8, A9) => R]) + : InferrableFunction = InferrableFunction(typeTag, argExtractors => { + val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val extractor1 = argExtractors(0) + val extractor2 = argExtractors(1) + val extractor3 = argExtractors(2) + val extractor4 = argExtractors(3) + val extractor5 = argExtractors(4) + val extractor6 = argExtractors(5) + val extractor7 = argExtractors(6) + val extractor8 = argExtractors(7) + val extractor9 = argExtractors(8) + input => { + val arg1 = extractor1(input) + val arg2 = extractor2(input) + val arg3 = extractor3(input) + val arg4 = extractor4(input) + val arg5 = extractor5(input) + val arg6 = extractor6(input) + val arg7 = extractor7(input) + val arg8 = extractor8(input) + val arg9 = extractor9(input) + if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 != null && arg6 != null && arg7 != null && arg8 != null && arg9 != null) { + func(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9) + } else { + null + } + } + }) + + implicit def 
inferrableFunction10[R: InferrableType, A1: InferrableType, A2: InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: InferrableType, A7: InferrableType, A8: InferrableType, A9: InferrableType, A10: InferrableType](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10) => R)(implicit typeTag: TypeTag[(A1, A2, A3, A4, A5, A6, A7, A8, A9, A10) => R]) + : InferrableFunction = InferrableFunction(typeTag, argExtractors => { + val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val extractor1 = argExtractors(0) + val extractor2 = argExtractors(1) + val extractor3 = argExtractors(2) + val extractor4 = argExtractors(3) + val extractor5 = argExtractors(4) + val extractor6 = argExtractors(5) + val extractor7 = argExtractors(6) + val extractor8 = argExtractors(7) + val extractor9 = argExtractors(8) + val extractor10 = argExtractors(9) + input => { + val arg1 = extractor1(input) + val arg2 = extractor2(input) + val arg3 = extractor3(input) + val arg4 = extractor4(input) + val arg5 = extractor5(input) + val arg6 = extractor6(input) + val arg7 = extractor7(input) + val arg8 = extractor8(input) + val arg9 = extractor9(input) + val arg10 = extractor10(input) + if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 != null && arg6 != null && arg7 != null && arg8 != null && arg9 != null && arg10 != null) { + func(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10) + } else { + null + } + } + }) + + implicit def inferrableFunction11[R: InferrableType, A1: InferrableType, A2: InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: InferrableType, A7: InferrableType, A8: InferrableType, A9: InferrableType, A10: InferrableType, A11: InferrableType](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11) => R)(implicit typeTag: TypeTag[(A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11) => R]) + : InferrableFunction = InferrableFunction(typeTag, argExtractors => { + val func = f.asInstanceOf[(Any, 
Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val extractor1 = argExtractors(0) + val extractor2 = argExtractors(1) + val extractor3 = argExtractors(2) + val extractor4 = argExtractors(3) + val extractor5 = argExtractors(4) + val extractor6 = argExtractors(5) + val extractor7 = argExtractors(6) + val extractor8 = argExtractors(7) + val extractor9 = argExtractors(8) + val extractor10 = argExtractors(9) + val extractor11 = argExtractors(10) + input => { + val arg1 = extractor1(input) + val arg2 = extractor2(input) + val arg3 = extractor3(input) + val arg4 = extractor4(input) + val arg5 = extractor5(input) + val arg6 = extractor6(input) + val arg7 = extractor7(input) + val arg8 = extractor8(input) + val arg9 = extractor9(input) + val arg10 = extractor10(input) + val arg11 = extractor11(input) + if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 != null && arg6 != null && arg7 != null && arg8 != null && arg9 != null && arg10 != null && arg11 != null) { + func(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11) + } else { + null + } + } + }) + + implicit def inferrableFunction12[R: InferrableType, A1: InferrableType, A2: InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: InferrableType, A7: InferrableType, A8: InferrableType, A9: InferrableType, A10: InferrableType, A11: InferrableType, A12: InferrableType](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12) => R)(implicit typeTag: TypeTag[(A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12) => R]) + : InferrableFunction = InferrableFunction(typeTag, argExtractors => { + val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val extractor1 = argExtractors(0) + val extractor2 = argExtractors(1) + val extractor3 = argExtractors(2) + val extractor4 = argExtractors(3) + val extractor5 = argExtractors(4) + val extractor6 = argExtractors(5) + val extractor7 = argExtractors(6) + val extractor8 = 
argExtractors(7) + val extractor9 = argExtractors(8) + val extractor10 = argExtractors(9) + val extractor11 = argExtractors(10) + val extractor12 = argExtractors(11) + input => { + val arg1 = extractor1(input) + val arg2 = extractor2(input) + val arg3 = extractor3(input) + val arg4 = extractor4(input) + val arg5 = extractor5(input) + val arg6 = extractor6(input) + val arg7 = extractor7(input) + val arg8 = extractor8(input) + val arg9 = extractor9(input) + val arg10 = extractor10(input) + val arg11 = extractor11(input) + val arg12 = extractor12(input) + if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 != null && arg6 != null && arg7 != null && arg8 != null && arg9 != null && arg10 != null && arg11 != null && arg12 != null) { + func(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12) + } else { + null + } + } + }) + + implicit def inferrableFunction13[R: InferrableType, A1: InferrableType, A2: InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: InferrableType, A7: InferrableType, A8: InferrableType, A9: InferrableType, A10: InferrableType, A11: InferrableType, A12: InferrableType, A13: InferrableType](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13) => R)(implicit typeTag: TypeTag[(A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13) => R]) + : InferrableFunction = InferrableFunction(typeTag, argExtractors => { + val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val extractor1 = argExtractors(0) + val extractor2 = argExtractors(1) + val extractor3 = argExtractors(2) + val extractor4 = argExtractors(3) + val extractor5 = argExtractors(4) + val extractor6 = argExtractors(5) + val extractor7 = argExtractors(6) + val extractor8 = argExtractors(7) + val extractor9 = argExtractors(8) + val extractor10 = argExtractors(9) + val extractor11 = argExtractors(10) + val extractor12 = argExtractors(11) + val extractor13 = argExtractors(12) + 
input => { + val arg1 = extractor1(input) + val arg2 = extractor2(input) + val arg3 = extractor3(input) + val arg4 = extractor4(input) + val arg5 = extractor5(input) + val arg6 = extractor6(input) + val arg7 = extractor7(input) + val arg8 = extractor8(input) + val arg9 = extractor9(input) + val arg10 = extractor10(input) + val arg11 = extractor11(input) + val arg12 = extractor12(input) + val arg13 = extractor13(input) + if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 != null && arg6 != null && arg7 != null && arg8 != null && arg9 != null && arg10 != null && arg11 != null && arg12 != null && arg13 != null) { + func(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13) + } else { + null + } + } + }) + + /** Implicitly wraps a 14-argument function as an [[InferrableFunction]]: the generated evaluator reads each argument from the input row through the extractor at the matching position (extractor1 is argExtractors(0), and so on), calls the function only when all 14 extracted values are non-null, and yields null otherwise. */ implicit def inferrableFunction14[R: InferrableType, A1: InferrableType, A2: InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: InferrableType, A7: InferrableType, A8: InferrableType, A9: InferrableType, A10: InferrableType, A11: InferrableType, A12: InferrableType, A13: InferrableType, A14: InferrableType](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14) => R)(implicit typeTag: TypeTag[(A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14) => R]) + : InferrableFunction = InferrableFunction(typeTag, argExtractors => { + val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val extractor1 = argExtractors(0) + val extractor2 = argExtractors(1) + val extractor3 = argExtractors(2) + val extractor4 = argExtractors(3) + val extractor5 = argExtractors(4) + val extractor6 = argExtractors(5) + val extractor7 = argExtractors(6) + val extractor8 = argExtractors(7) + val extractor9 = argExtractors(8) + val extractor10 = argExtractors(9) + val extractor11 = argExtractors(10) + val extractor12 = argExtractors(11) + val extractor13 = argExtractors(12) + val extractor14 = argExtractors(13) + input => { + val arg1 = 
extractor1(input) + val arg2 = extractor2(input) + val arg3 = extractor3(input) + val arg4 = extractor4(input) + val arg5 = extractor5(input) + val arg6 = extractor6(input) + val arg7 = extractor7(input) + val arg8 = extractor8(input) + val arg9 = extractor9(input) + val arg10 = extractor10(input) + val arg11 = extractor11(input) + val arg12 = extractor12(input) + val arg13 = extractor13(input) + val arg14 = extractor14(input) + if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 != null && arg6 != null && arg7 != null && arg8 != null && arg9 != null && arg10 != null && arg11 != null && arg12 != null && arg13 != null && arg14 != null) { + func(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14) + } else { + null + } + } + }) + + /** Implicitly wraps a 15-argument function as an [[InferrableFunction]]: the generated evaluator reads each argument from the input row through the extractor at the matching position, calls the function only when all 15 extracted values are non-null, and yields null otherwise. */ implicit def inferrableFunction15[R: InferrableType, A1: InferrableType, A2: InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: InferrableType, A7: InferrableType, A8: InferrableType, A9: InferrableType, A10: InferrableType, A11: InferrableType, A12: InferrableType, A13: InferrableType, A14: InferrableType, A15: InferrableType](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15) => R)(implicit typeTag: TypeTag[(A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15) => R]) + : InferrableFunction = InferrableFunction(typeTag, argExtractors => { + val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val extractor1 = argExtractors(0) + val extractor2 = argExtractors(1) + val extractor3 = argExtractors(2) + val extractor4 = argExtractors(3) + val extractor5 = argExtractors(4) + val extractor6 = argExtractors(5) + val extractor7 = argExtractors(6) + val extractor8 = argExtractors(7) + val extractor9 = argExtractors(8) + val extractor10 = argExtractors(9) + val extractor11 = argExtractors(10) + val extractor12 = argExtractors(11) + val extractor13 = argExtractors(12) + val 
extractor14 = argExtractors(13) + val extractor15 = argExtractors(14) + input => { + val arg1 = extractor1(input) + val arg2 = extractor2(input) + val arg3 = extractor3(input) + val arg4 = extractor4(input) + val arg5 = extractor5(input) + val arg6 = extractor6(input) + val arg7 = extractor7(input) + val arg8 = extractor8(input) + val arg9 = extractor9(input) + val arg10 = extractor10(input) + val arg11 = extractor11(input) + val arg12 = extractor12(input) + val arg13 = extractor13(input) + val arg14 = extractor14(input) + val arg15 = extractor15(input) + if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 != null && arg6 != null && arg7 != null && arg8 != null && arg9 != null && arg10 != null && arg11 != null && arg12 != null && arg13 != null && arg14 != null && arg15 != null) { + func(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15) + } else { + null + } + } + }) + + /** Implicitly wraps a 16-argument function as an [[InferrableFunction]]: the generated evaluator reads each argument from the input row through the extractor at the matching position, calls the function only when all 16 extracted values are non-null, and yields null otherwise. */ implicit def inferrableFunction16[R: InferrableType, A1: InferrableType, A2: InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: InferrableType, A7: InferrableType, A8: InferrableType, A9: InferrableType, A10: InferrableType, A11: InferrableType, A12: InferrableType, A13: InferrableType, A14: InferrableType, A15: InferrableType, A16: InferrableType](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16) => R)(implicit typeTag: TypeTag[(A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16) => R]) + : InferrableFunction = InferrableFunction(typeTag, argExtractors => { + val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val extractor1 = argExtractors(0) + val extractor2 = argExtractors(1) + val extractor3 = argExtractors(2) + val extractor4 = argExtractors(3) + val extractor5 = argExtractors(4) + val extractor6 = argExtractors(5) + val extractor7 = argExtractors(6) + val extractor8 = argExtractors(7) + val 
extractor9 = argExtractors(8) + val extractor10 = argExtractors(9) + val extractor11 = argExtractors(10) + val extractor12 = argExtractors(11) + val extractor13 = argExtractors(12) + val extractor14 = argExtractors(13) + val extractor15 = argExtractors(14) + val extractor16 = argExtractors(15) + input => { + val arg1 = extractor1(input) + val arg2 = extractor2(input) + val arg3 = extractor3(input) + val arg4 = extractor4(input) + val arg5 = extractor5(input) + val arg6 = extractor6(input) + val arg7 = extractor7(input) + val arg8 = extractor8(input) + val arg9 = extractor9(input) + val arg10 = extractor10(input) + val arg11 = extractor11(input) + val arg12 = extractor12(input) + val arg13 = extractor13(input) + val arg14 = extractor14(input) + val arg15 = extractor15(input) + val arg16 = extractor16(input) + if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 != null && arg6 != null && arg7 != null && arg8 != null && arg9 != null && arg10 != null && arg11 != null && arg12 != null && arg13 != null && arg14 != null && arg15 != null && arg16 != null) { + func(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16) + } else { + null + } + } + }) + // scalastyle:on +} diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala new file mode 100644 index 0000000000..a1ce24b9e8 --- /dev/null +++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.sedona_sql.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, DataType, DataTypes, DoubleType, IntegerType, LongType, StringType} +import org.apache.spark.unsafe.types.UTF8String +import org.locationtech.jts.geom.Geometry +import org.apache.spark.sql.sedona_sql.expressions.implicits._ + +import scala.reflect.runtime.universe.TypeTag +import scala.reflect.runtime.universe.Type +import scala.reflect.runtime.universe.typeOf + +/** + * This is the base class for wrapping Java/Scala functions as a catalyst expression in Spark SQL. + * @param f The function to be wrapped. Subclasses can simply pass a function to this constructor, + * and the function will be converted to [[InferrableFunction]] by [[InferrableFunctionConverter]] + * automatically. 
+ */
+abstract class InferredExpression(f: InferrableFunction)
+  extends Expression with ImplicitCastInputTypes with SerdeAware with CodegenFallback with FoldableExpression
+  with Serializable {
+
+  /** Argument expressions of this call; provided by the concrete subclass. */
+  def inputExpressions: Seq[Expression]
+
+  override def children: Seq[Expression] = inputExpressions
+  override def toString: String = s" **${getClass.getName}** "
+  override def nullable: Boolean = true
+
+  /** Input types and result type are taken from the wrapped function's inferred signature. */
+  override def inputTypes: Seq[AbstractDataType] = f.sparkInputTypes
+  override def dataType: DataType = f.sparkReturnType
+
+  // Both values are derived from `inputExpressions`, an abstract member implemented by the
+  // subclass. They are lazy so that this superclass constructor never reads that member before
+  // the subclass is fully initialized, and so that instances which are only copied (see
+  // `withNewChildrenInternal` in the subclasses) but never evaluated do not build any closures.
+  private lazy val argExtractors: Array[InternalRow => Any] = f.buildExtractors(inputExpressions)
+  private lazy val evaluator: InternalRow => Any = f.evaluatorBuilder(argExtractors)
+
+  /** Evaluates the wrapped function on `input` and passes the result through `f.serializer`. */
+  override def eval(input: InternalRow): Any = f.serializer(evaluator(input))
+
+  /** Evaluates the wrapped function on `input` and returns its result without applying `f.serializer`. */
+  override def evalWithoutSerialization(input: InternalRow): Any = evaluator(input)
+}
+
+// This is a compile time type shield for the types we are able to infer. Anything
+// other than these types will cause a compilation error. This is the Scala
+// 2 way of making a union type.
+sealed class InferrableType[T: TypeTag]
+
+/** Implicit evidence instances, one for each type accepted by the `InferrableType` context bound. */
+object InferrableType {
+  implicit val geometryInstance: InferrableType[Geometry] =
+    new InferrableType[Geometry] {}
+  implicit val geometryArrayInstance: InferrableType[Array[Geometry]] =
+    new InferrableType[Array[Geometry]] {}
+  implicit val javaDoubleInstance: InferrableType[java.lang.Double] =
+    new InferrableType[java.lang.Double] {}
+  implicit val javaIntegerInstance: InferrableType[java.lang.Integer] =
+    new InferrableType[java.lang.Integer] {}
+  implicit val doubleInstance: InferrableType[Double] =
+    new InferrableType[Double] {}
+  implicit val booleanInstance: InferrableType[Boolean] =
+    new InferrableType[Boolean] {}
+  implicit val booleanOptInstance: InferrableType[Option[Boolean]] =
+    new InferrableType[Option[Boolean]] {}
+  implicit val intInstance: InferrableType[Int] =
+    new InferrableType[Int] {}
+  implicit val stringInstance: InferrableType[String] =
+    new InferrableType[String] {}
+  implicit val binaryInstance: InferrableType[Array[Byte]] =
+    new InferrableType[Array[Byte]] {}
+  implicit val longArrayInstance: InferrableType[Array[java.lang.Long]] =
+    new InferrableType[Array[java.lang.Long]] {}
+}
+
+/**
+ * Maps a reflected Scala/Java type to the three things needed to expose it through a catalyst
+ * expression: how to read an argument of that type, how to convert a result of that type, and
+ * which Spark SQL data type describes it.
+ */
+object InferredTypes {
+
+  /** Lifts `convert` into a function that returns null for a null input without calling `convert`. */
+  private def nullSafe(convert: Any => Any): Any => Any =
+    output => if (output != null) convert(output) else null
+
+  /**
+   * Builds the argument extractor for a parameter of type `t`.
+   * @param t Reflected type of the function parameter.
+   * @return A function that, given the argument expression, yields a reader from the input row.
+   *         Geometry, geometry-array and string parameters use the dedicated helpers brought in
+   *         by the `implicits` import; every other type falls back to plain `expr.eval`.
+   */
+  def buildArgumentExtractor(t: Type): Expression => InternalRow => Any = {
+    if (t =:= typeOf[Geometry]) {
+      expr => input => expr.toGeometry(input)
+    } else if (t =:= typeOf[Array[Geometry]]) {
+      expr => input => expr.toGeometryArray(input)
+    } else if (t =:= typeOf[String]) {
+      expr => input => expr.asString(input)
+    } else {
+      expr => input => expr.eval(input)
+    }
+  }
+
+  /**
+   * Builds the converter applied to a function result of type `t` before it is returned from
+   * `eval`.
+   * @param t Reflected return type of the wrapped function.
+   * @return A converter that maps null to null. Types without a dedicated branch are passed
+   *         through unchanged.
+   */
+  def buildSerializer(t: Type): Any => Any = {
+    if (t =:= typeOf[Geometry]) {
+      nullSafe(output => output.asInstanceOf[Geometry].toGenericArrayData)
+    } else if (t =:= typeOf[String]) {
+      nullSafe(output => UTF8String.fromString(output.asInstanceOf[String]))
+    } else if (t =:= typeOf[Array[java.lang.Long]]) {
+      nullSafe(output => ArrayData.toArrayData(output))
+    } else if (t =:= typeOf[Array[Geometry]]) {
+      nullSafe(output => ArrayData.toArrayData(output.asInstanceOf[Array[Geometry]].map(_.toGenericArrayData)))
+    } else if (t =:= typeOf[Option[Boolean]]) {
+      // None becomes SQL NULL, Some(b) becomes b.
+      nullSafe(output => output.asInstanceOf[Option[Boolean]].orNull)
+    } else {
+      output => output
+    }
+  }
+
+  /**
+   * Infers the Spark SQL data type that corresponds to `t`.
+   * @param t Reflected type of a function parameter or result.
+   * @return The matching Spark data type.
+   * @throws IllegalArgumentException if `t` is not one of the supported types. Unsupported types
+   *                                  used to be reported as `BooleanType`, which silently produced
+   *                                  a wrong schema.
+   */
+  def inferSparkType(t: Type): DataType = {
+    if (t =:= typeOf[Geometry]) {
+      GeometryUDT
+    } else if (t =:= typeOf[Array[Geometry]]) {
+      DataTypes.createArrayType(GeometryUDT)
+    } else if (t =:= typeOf[Double] || t =:= typeOf[java.lang.Double]) {
+      DoubleType
+    } else if (t =:= typeOf[Int] || t =:= typeOf[java.lang.Integer]) {
+      IntegerType
+    } else if (t =:= typeOf[String]) {
+      StringType
+    } else if (t =:= typeOf[Array[Byte]]) {
+      BinaryType
+    } else if (t =:= typeOf[Array[java.lang.Long]]) {
+      DataTypes.createArrayType(LongType)
+    } else if (t =:= typeOf[Boolean] || t =:= typeOf[java.lang.Boolean] || t =:= typeOf[Option[Boolean]]) {
+      BooleanType
+    } else {
+      throw new IllegalArgumentException(s"Cannot infer Spark type for $t")
+    }
+  }
+}
+
+/**
+ * A function together with everything inferred from its signature.
+ * @param sparkInputTypes Spark types of the arguments, in declaration order.
+ * @param sparkReturnType Spark type of the result.
+ * @param serializer Converter applied to the raw result.
+ * @param argExtractorBuilders One extractor builder per argument, in declaration order.
+ * @param evaluatorBuilder Turns the array of argument extractors into the row evaluator.
+ */
+case class InferrableFunction(sparkInputTypes: Seq[AbstractDataType],
+  sparkReturnType: DataType,
+  serializer: Any => Any,
+  argExtractorBuilders: Seq[Expression => InternalRow => Any],
+  evaluatorBuilder: Array[InternalRow => Any] => InternalRow => Any) {
+
+  /**
+   * Pairs each extractor builder with the expression at the same position.
+   * @param expressions Argument expressions of the call.
+   * @return One extractor per declared argument. Expressions beyond the declared arguments are
+   *         ignored; if fewer expressions than declared arguments are given, the remaining
+   *         builders are invoked with a null expression.
+   */
+  def buildExtractors(expressions: Seq[Expression]): Array[InternalRow => Any] = {
+    // zipAll pads the shorter side with null; a null builder marks a surplus expression.
+    argExtractorBuilders.zipAll(expressions, null, null).flatMap {
+      case (null, _) => None
+      case (builder, expr) => Some(builder(expr))
+    }.toArray
+  }
+}
+
+object InferrableFunction {
+  /**
+   * Infer input types and return type from a type tag, and construct builder for argument extractors.
+   * @param typeTag Type tag of the function. Its type arguments are read as the argument types
+   *                followed by the return type.
+   * @param evaluatorBuilder Builder for the evaluator.
+   * @return InferrableFunction.
+   * @throws IllegalArgumentException if the tagged type has no type arguments, or if one of them
+   *                                  has no Spark type.
+   */
+  def apply(typeTag: TypeTag[_], evaluatorBuilder: Array[InternalRow => Any] => InternalRow => Any): InferrableFunction = {
+    val typeArgs = typeTag.tpe.typeArgs
+    require(typeArgs.nonEmpty, s"Expected a function type with type arguments, but got ${typeTag.tpe}")
+    val argTypes = typeArgs.init
+    val returnType = typeArgs.last
+    val sparkInputTypes: Seq[AbstractDataType] = argTypes.map(InferredTypes.inferSparkType)
+    val sparkReturnType: DataType = InferredTypes.inferSparkType(returnType)
+    val serializer = InferredTypes.buildSerializer(returnType)
+    val argExtractorBuilders = argTypes.map(InferredTypes.buildArgumentExtractor)
+    InferrableFunction(sparkInputTypes, sparkReturnType, serializer, argExtractorBuilders, evaluatorBuilder)
+  }
+
+  /**
+   * A variant of binary inferred expression which allows the second argument to be null.
+   * The wrapped function is still skipped, and null returned, when the first argument is null.
+   * @param f Function to be wrapped as a catalyst expression.
+   * @param typeTag Type tag of the function.
+   * @tparam R Return type of the function.
+   * @tparam A1 Type of the first argument.
+   * @tparam A2 Type of the second argument.
+   * @return InferrableFunction.
+   */
+  def allowRightNull[R, A1, A2](f: (A1, A2) => R)(implicit typeTag: TypeTag[(A1, A2) => R]): InferrableFunction = {
+    apply(typeTag, extractors => {
+      val func = f.asInstanceOf[(Any, Any) => Any]
+      val extractor1 = extractors(0)
+      val extractor2 = extractors(1)
+      input => {
+        val arg1 = extractor1(input)
+        val arg2 = extractor2(input)
+        if (arg1 != null) {
+          func(arg1, arg2)
+        } else {
+          null
+        }
+      }
+    })
+  }
+}
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
deleted file mode 100644
index f526baf0cf..0000000000
--- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
+++ /dev/null
@@ -1,366 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark.sql.sedona_sql.expressions - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.sedona_sql.expressions.implicits._ -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String -import org.locationtech.jts.geom.Geometry - -import scala.reflect.runtime.universe._ - -/** - * Make expression foldable by constant folding optimizer. If all children - * expressions are foldable, then the expression itself is foldable. 
- */ -trait FoldableExpression extends Expression { - override def foldable: Boolean = children.forall(_.foldable) -} - -abstract class UnaryGeometryExpression extends Expression with SerdeAware with ExpectsInputTypes { - def inputExpressions: Seq[Expression] - - override def nullable: Boolean = true - - override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT) - - override def eval(input: InternalRow): Any = { - val result = evalWithoutSerialization(input) - serializeResult(result) - } - - override def evalWithoutSerialization(input: InternalRow): Any ={ - val inputExpression = inputExpressions.head - val geometry = inputExpression match { - case expr: SerdeAware => expr.evalWithoutSerialization(input) - case expr: Any => expr.toGeometry(input) - } - - (geometry) match { - case (geometry: Geometry) => nullSafeEval(geometry) - case _ => null - } - } - - protected def serializeResult(result: Any): Any = { - result match { - case geometry: Geometry => geometry.toGenericArrayData - case _ => result - } - } - - protected def nullSafeEval(geometry: Geometry): Any - - -} - -abstract class BinaryGeometryExpression extends Expression with SerdeAware with ExpectsInputTypes { - def inputExpressions: Seq[Expression] - - override def nullable: Boolean = true - - override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT, GeometryUDT) - - override def eval(input: InternalRow): Any = { - val result = evalWithoutSerialization(input) - serializeResult(result) - } - - override def evalWithoutSerialization(input: InternalRow): Any = { - val leftExpression = inputExpressions(0) - val leftGeometry = leftExpression match { - case expr: SerdeAware => expr.evalWithoutSerialization(input) - case _ => leftExpression.toGeometry(input) - } - - val rightExpression = inputExpressions(1) - val rightGeometry = rightExpression match { - case expr: SerdeAware => expr.evalWithoutSerialization(input) - case _ => rightExpression.toGeometry(input) - } - - (leftGeometry, rightGeometry) 
match { - case (leftGeometry: Geometry, rightGeometry: Geometry) => nullSafeEval(leftGeometry, rightGeometry) - case _ => null - } - } - - protected def serializeResult(result: Any): Any = { - result match { - case geometry: Geometry => geometry.toGenericArrayData - case _ => result - } - } - - protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any -} - -// This is a compile time type shield for the types we are able to infer. Anything -// other than these types will cause a compilation error. This is the Scala -// 2 way of making a union type. -sealed class InferrableType[T: TypeTag] -object InferrableType { - implicit val geometryInstance: InferrableType[Geometry] = - new InferrableType[Geometry] {} - implicit val geometryArrayInstance: InferrableType[Array[Geometry]] = - new InferrableType[Array[Geometry]] {} - implicit val javaDoubleInstance: InferrableType[java.lang.Double] = - new InferrableType[java.lang.Double] {} - implicit val javaIntegerInstance: InferrableType[java.lang.Integer] = - new InferrableType[java.lang.Integer] {} - implicit val doubleInstance: InferrableType[Double] = - new InferrableType[Double] {} - implicit val booleanInstance: InferrableType[Boolean] = - new InferrableType[Boolean] {} - implicit val intInstance: InferrableType[Int] = - new InferrableType[Int] {} - implicit val stringInstance: InferrableType[String] = - new InferrableType[String] {} - implicit val binaryInstance: InferrableType[Array[Byte]] = - new InferrableType[Array[Byte]] {} - implicit val longArrayInstance: InferrableType[Array[java.lang.Long]] = - new InferrableType[Array[java.lang.Long]] {} -} - -object InferredTypes { - def buildExtractor[T: TypeTag](expr: Expression): InternalRow => T = { - if (typeOf[T] =:= typeOf[Geometry]) { - input: InternalRow => expr.toGeometry(input).asInstanceOf[T] - } else if (typeOf[T] =:= typeOf[Array[Geometry]]) { - input: InternalRow => expr.toGeometryArray(input).asInstanceOf[T] - } else if (typeOf[T] =:= 
typeOf[String]) { - input: InternalRow => expr.asString(input).asInstanceOf[T] - } else { - input: InternalRow => expr.eval(input).asInstanceOf[T] - } - } - - def buildSerializer[T: TypeTag]: T => Any = { - if (typeOf[T] =:= typeOf[Geometry]) { - output: T => if (output != null) { - output.asInstanceOf[Geometry].toGenericArrayData - } else { - null - } - } else if (typeOf[T] =:= typeOf[String]) { - output: T => if (output != null) { - UTF8String.fromString(output.asInstanceOf[String]) - } else { - null - } - } else if (typeOf[T] =:= typeOf[Array[java.lang.Long]]) { - output: T => - if (output != null) { - ArrayData.toArrayData(output) - } else { - null - } - } else if (typeOf[T] =:= typeOf[Array[Geometry]]) { - output: T => - if (output != null) { - ArrayData.toArrayData(output.asInstanceOf[Array[Geometry]].map(_.toGenericArrayData)) - } else { - null - } - } else { - output: T => output - } - } - - def inferSparkType[T: TypeTag]: DataType = { - if (typeOf[T] =:= typeOf[Geometry]) { - GeometryUDT - } else if (typeOf[T] =:= typeOf[Array[Geometry]]) { - DataTypes.createArrayType(GeometryUDT) - } else if (typeOf[T] =:= typeOf[java.lang.Double]) { - DoubleType - } else if (typeOf[T] =:= typeOf[java.lang.Integer]) { - IntegerType - } else if (typeOf[T] =:= typeOf[Double]) { - DoubleType - } else if (typeOf[T] =:= typeOf[Int]) { - IntegerType - } else if (typeOf[T] =:= typeOf[String]) { - StringType - } else if (typeOf[T] =:= typeOf[Array[Byte]]) { - BinaryType - } else if (typeOf[T] =:= typeOf[Array[java.lang.Long]]) { - DataTypes.createArrayType(LongType) - } else { - BooleanType - } - } -} - -/** - * The implicit TypeTag's tell Scala to maintain generic type info at runtime. Normally type - * erasure would remove any knowledge of what the passed in generic type is. 
- */ -abstract class InferredUnaryExpression[A1: InferrableType, R: InferrableType] - (f: (A1) => R) - (implicit val a1Tag: TypeTag[A1], implicit val rTag: TypeTag[R]) - extends Expression with ImplicitCastInputTypes with SerdeAware with CodegenFallback with Serializable { - import InferredTypes._ - - def inputExpressions: Seq[Expression] - - override def children: Seq[Expression] = inputExpressions - - override def toString: String = s" **${getClass.getName}** " - - override def inputTypes: Seq[AbstractDataType] = Seq(inferSparkType[A1]) - - override def nullable: Boolean = true - - override def dataType = inferSparkType[R] - - lazy val extract = buildExtractor[A1](inputExpressions(0)) - - lazy val serialize = buildSerializer[R] - - override def eval(input: InternalRow): Any = serialize(evalWithoutSerialization(input).asInstanceOf[R]) - - override def evalWithoutSerialization(input: InternalRow): Any = { - val value = extract(input) - if (value != null) { - f(value) - } else { - null - } - } -} - -abstract class InferredBinaryExpression[A1: InferrableType, A2: InferrableType, R: InferrableType] - (f: (A1, A2) => R) - (implicit val a1Tag: TypeTag[A1], implicit val a2Tag: TypeTag[A2], implicit val rTag: TypeTag[R]) - extends Expression with ImplicitCastInputTypes with SerdeAware with CodegenFallback with Serializable { - import InferredTypes._ - - def inputExpressions: Seq[Expression] - - override def children: Seq[Expression] = inputExpressions - - override def toString: String = s" **${getClass.getName}** " - - override def inputTypes: Seq[AbstractDataType] = Seq(inferSparkType[A1], inferSparkType[A2]) - - override def nullable: Boolean = true - - def allowRightNull: Boolean = false - - override def dataType = inferSparkType[R] - - lazy val extractLeft = buildExtractor[A1](inputExpressions(0)) - lazy val extractRight = buildExtractor[A2](inputExpressions(1)) - - lazy val serialize = buildSerializer[R] - - override def eval(input: InternalRow): Any = 
serialize(evalWithoutSerialization(input).asInstanceOf[R]) - - override def evalWithoutSerialization(input: InternalRow): Any = { - val left = extractLeft(input) - val right = extractRight(input) - if (left != null && (right != null || allowRightNull)) { - f(left, right) - } else { - null - } - } -} - -abstract class InferredTernaryExpression[A1: InferrableType, A2: InferrableType, A3: InferrableType, R: InferrableType] -(f: (A1, A2, A3) => R) -(implicit val a1Tag: TypeTag[A1], implicit val a2Tag: TypeTag[A2], implicit val a3Tag: TypeTag[A3], implicit val rTag: TypeTag[R]) - extends Expression with ImplicitCastInputTypes with SerdeAware with CodegenFallback with Serializable { - import InferredTypes._ - - def inputExpressions: Seq[Expression] - - override def children: Seq[Expression] = inputExpressions - - override def toString: String = s" **${getClass.getName}** " - - override def inputTypes: Seq[AbstractDataType] = Seq(inferSparkType[A1], inferSparkType[A2], inferSparkType[A3]) - - override def nullable: Boolean = true - - override def dataType = inferSparkType[R] - - lazy val extractFirst = buildExtractor[A1](inputExpressions(0)) - lazy val extractSecond = buildExtractor[A2](inputExpressions(1)) - lazy val extractThird = buildExtractor[A3](inputExpressions(2)) - - lazy val serialize = buildSerializer[R] - - override def eval(input: InternalRow): Any = serialize(evalWithoutSerialization(input).asInstanceOf[R]) - - override def evalWithoutSerialization(input: InternalRow): Any = { - val first = extractFirst(input) - val second = extractSecond(input) - val third = extractThird(input) - if (first != null && second != null && third != null) { - f(first, second, third) - } else { - null - } - } -} - -abstract class InferredQuarternaryExpression[A1: InferrableType, A2: InferrableType, A3: InferrableType, A4: InferrableType, R: InferrableType] -(f: (A1, A2, A3, A4) => R) -(implicit val a1Tag: TypeTag[A1], implicit val a2Tag: TypeTag[A2], implicit val a3Tag: 
TypeTag[A3], implicit val a4Tag: TypeTag[A4], implicit val rTag: TypeTag[R]) - extends Expression with ImplicitCastInputTypes with CodegenFallback with Serializable { - import InferredTypes._ - - def inputExpressions: Seq[Expression] - - override def children: Seq[Expression] = inputExpressions - - override def toString: String = s" **${getClass.getName}** " - - override def inputTypes: Seq[AbstractDataType] = Seq(inferSparkType[A1], inferSparkType[A2], inferSparkType[A3], inferSparkType[A4]) - - override def nullable: Boolean = true - - override def dataType = inferSparkType[R] - - lazy val extractFirst = buildExtractor[A1](inputExpressions(0)) - lazy val extractSecond = buildExtractor[A2](inputExpressions(1)) - lazy val extractThird = buildExtractor[A3](inputExpressions(2)) - lazy val extractForth = buildExtractor[A4](inputExpressions(3)) - - lazy val serialize = buildSerializer[R] - - override def eval(input: InternalRow): Any = { - val first = extractFirst(input) - val second = extractSecond(input) - val third = extractThird(input) - val forth = extractForth(input) - if (first != null && second != null && third != null && forth != null) { - serialize(f(first, second, third, forth)) - } else { - null - } - } -}