Skip to content

Commit

Permalink
[SEDONA-311] Refactor Inferred*Expression base class for Sedona SQL (
Browse files Browse the repository at this point in the history
  • Loading branch information
Kontinuation committed Jun 25, 2023
1 parent 20e0de4 commit debac33
Show file tree
Hide file tree
Showing 6 changed files with 858 additions and 483 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.sedona_sql.expressions.implicits.GeometryEnhancer
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand All @@ -36,7 +37,7 @@ import org.apache.spark.unsafe.types.UTF8String
* string, the second parameter is the delimiter. String format should be similar to CSV/TSV
*/
case class ST_PointFromText(inputExpressions: Seq[Expression])
extends InferredBinaryExpression(Constructors.pointFromText) with FoldableExpression {
extends InferredExpression(Constructors.pointFromText _) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
Expand All @@ -48,7 +49,7 @@ case class ST_PointFromText(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_PolygonFromText(inputExpressions: Seq[Expression])
extends InferredBinaryExpression(Constructors.polygonFromText) with FoldableExpression {
extends InferredExpression(Constructors.polygonFromText _) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
Expand All @@ -60,7 +61,7 @@ case class ST_PolygonFromText(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_LineFromText(inputExpressions: Seq[Expression])
extends InferredUnaryExpression(Constructors.lineFromText) with FoldableExpression {
extends InferredExpression(Constructors.lineFromText _) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
Expand All @@ -71,7 +72,7 @@ case class ST_LineFromText(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_LineStringFromText(inputExpressions: Seq[Expression])
extends InferredBinaryExpression(Constructors.lineStringFromText) with FoldableExpression {
extends InferredExpression(Constructors.lineStringFromText _) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
Expand All @@ -83,7 +84,7 @@ case class ST_LineStringFromText(inputExpressions: Seq[Expression])
* @param inputExpressions This function takes a geometry string and a srid. The string format must be WKT.
*/
case class ST_GeomFromWKT(inputExpressions: Seq[Expression])
extends InferredBinaryExpression(Constructors.geomFromWKT) with FoldableExpression {
extends InferredExpression(Constructors.geomFromWKT _) {

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
Expand All @@ -97,7 +98,7 @@ case class ST_GeomFromWKT(inputExpressions: Seq[Expression])
* @param inputExpressions This function takes a geometry string and a srid. The string format must be WKT.
*/
case class ST_GeomFromText(inputExpressions: Seq[Expression])
extends InferredBinaryExpression(Constructors.geomFromWKT) with FoldableExpression {
extends InferredExpression(Constructors.geomFromWKT _) {

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
Expand Down Expand Up @@ -180,7 +181,7 @@ case class ST_GeomFromGeoJSON(inputExpressions: Seq[Expression])
* @param inputExpressions This function takes 2 parameter which are point x, y.
*/
case class ST_Point(inputExpressions: Seq[Expression])
extends InferredBinaryExpression(Constructors.point) with FoldableExpression {
extends InferredExpression(Constructors.point _) {

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
Expand All @@ -193,7 +194,7 @@ case class ST_Point(inputExpressions: Seq[Expression])
* @param inputExpressions This function takes 4 parameter which are point x, y, z and srid (default 0).
*/
case class ST_PointZ(inputExpressions: Seq[Expression])
extends InferredQuarternaryExpression(Constructors.pointZ) with FoldableExpression {
extends InferredExpression(Constructors.pointZ _) {

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
Expand All @@ -207,7 +208,7 @@ case class ST_PointZ(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_PolygonFromEnvelope(inputExpressions: Seq[Expression])
extends InferredQuarternaryExpression(Constructors.polygonFromEnvelope) with FoldableExpression {
extends InferredExpression(Constructors.polygonFromEnvelope _) {

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
Expand All @@ -226,22 +227,21 @@ trait UserDataGeneratator {
}

case class ST_GeomFromGeoHash(inputExpressions: Seq[Expression])
extends InferredBinaryExpression(Constructors.geomFromGeoHash) with FoldableExpression {
extends InferredExpression(InferrableFunction.allowRightNull(Constructors.geomFromGeoHash)) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
override def allowRightNull: Boolean = true
}

case class ST_GeomFromGML(inputExpressions: Seq[Expression])
extends InferredUnaryExpression(Constructors.geomFromGML) with FoldableExpression {
extends InferredExpression(Constructors.geomFromGML _) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
}

case class ST_GeomFromKML(inputExpressions: Seq[Expression])
extends InferredUnaryExpression(Constructors.geomFromKML) with FoldableExpression {
extends InferredExpression(Constructors.geomFromKML _) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
Expand All @@ -253,7 +253,7 @@ case class ST_GeomFromKML(inputExpressions: Seq[Expression])
* @param inputExpressions This function takes a geometry string and a srid. The string format must be WKT.
*/
case class ST_MPolyFromText(inputExpressions: Seq[Expression])
extends InferredBinaryExpression(Constructors.mPolyFromText) with FoldableExpression {
extends InferredExpression(Constructors.mPolyFromText _) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
Expand All @@ -265,7 +265,7 @@ case class ST_MPolyFromText(inputExpressions: Seq[Expression])
* @param inputExpressions This function takes a geometry string and a srid. The string format must be WKT.
*/
case class ST_MLineFromText(inputExpressions: Seq[Expression])
extends InferredBinaryExpression(Constructors.mLineFromText) with FoldableExpression {
extends InferredExpression(Constructors.mLineFromText _) {

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.spark.sql.sedona_sql.expressions

import org.apache.spark.sql.catalyst.expressions.Expression

/**
* Make expression foldable by constant folding optimizer. If all children
* expressions are foldable, then the expression itself is foldable.
*/
trait FoldableExpression extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
}
Loading

0 comments on commit debac33

Please sign in to comment.