Skip to content

Commit

Permalink
Implement withNewChildrenInternal directly
Browse files Browse the repository at this point in the history
avoid reflection which is done at runtime by structural types
  • Loading branch information
echeipesh committed Jan 13, 2023
1 parent b14adaa commit df552b8
Show file tree
Hide file tree
Showing 79 changed files with 136 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,13 @@ import com.typesafe.scalalogging.Logger
import geotrellis.raster.Tile
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression}
import org.apache.spark.sql.catalyst.expressions.BinaryExpression
import org.apache.spark.sql.types.DataType
import org.locationtech.rasterframes.expressions.DynamicExtractors._
import org.slf4j.LoggerFactory

/** Operation combining two tiles or a tile and a scalar into a new tile. */
trait BinaryRasterFunction extends BinaryExpression with RasterResult { self: HasBinaryExpressionCopy =>
override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight)
trait BinaryRasterFunction extends BinaryExpression with RasterResult {

@transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,15 @@ import geotrellis.raster.CellGrid
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.UnaryExpression

/**
* Implements boilerplate for subtype expressions processing TileUDT, RasterSourceUDT, and RasterRefs
* as Grid types.
*
* @since 11/4/18
*/
trait OnCellGridExpression extends UnaryExpression { self: HasUnaryExpressionCopy =>
override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild)

trait OnCellGridExpression extends UnaryExpression {
private lazy val fromRow: InternalRow => CellGrid[Int] = {
if (child.resolved) gridExtractor(child.dataType)
else throw new IllegalStateException(s"Child expression unbound: ${child}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.locationtech.rasterframes.expressions.DynamicExtractors._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.UnaryExpression
import org.locationtech.rasterframes.model.TileContext

/**
Expand All @@ -34,9 +34,7 @@ import org.locationtech.rasterframes.model.TileContext
*
* @since 11/3/18
*/
trait OnTileContextExpression extends UnaryExpression { self: HasUnaryExpressionCopy =>
override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild)

trait OnTileContextExpression extends UnaryExpression {
override def checkInputDataTypes(): TypeCheckResult = {
if (!projectedRasterLikeExtractor.isDefinedAt(child.dataType)) {
TypeCheckFailure(s"Input type '${child.dataType}' does not conform to `ProjectedRasterLike`.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,7 @@ import org.locationtech.geomesa.spark.jts.udf.SpatialRelationFunctions._
*
* @since 12/28/17
*/
abstract class SpatialRelation extends BinaryExpression with CodegenFallback { this: HasBinaryExpressionCopy =>

override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
copy(left = newLeft, right = newRight)
abstract class SpatialRelation extends BinaryExpression with CodegenFallback {

def extractGeometry(expr: Expression, input: Any): Geometry = {
input match {
Expand Down Expand Up @@ -78,36 +75,42 @@ object SpatialRelation {
override def nodeName: String = "intersects"
val relation = ST_Intersects

override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
copy(left = newLeft, right = newRight)
override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight)
}
case class Contains(left: Expression, right: Expression) extends SpatialRelation {
override def nodeName = "contains"
val relation = ST_Contains
override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight)
}
case class Covers(left: Expression, right: Expression) extends SpatialRelation {
override def nodeName = "covers"
val relation = ST_Covers
override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight)
}
case class Crosses(left: Expression, right: Expression) extends SpatialRelation {
override def nodeName = "crosses"
val relation = ST_Crosses
override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight)
}
case class Disjoint(left: Expression, right: Expression) extends SpatialRelation {
override def nodeName = "disjoint"
val relation = ST_Disjoint
override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight)
}
case class Overlaps(left: Expression, right: Expression) extends SpatialRelation {
override def nodeName = "overlaps"
val relation = ST_Overlaps
override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight)
}
case class Touches(left: Expression, right: Expression) extends SpatialRelation {
override def nodeName = "touches"
val relation = ST_Touches
override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight)
}
case class Within(left: Expression, right: Expression) extends SpatialRelation {
override def nodeName = "within"
val relation = ST_Within
override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight)
}

private val predicateMap = Map(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import org.locationtech.rasterframes.encoders.syntax._
import scala.reflect.runtime.universe._

/** Mixin providing boilerplate for DeclarativeAggrates over tile-conforming columns. */
trait UnaryRasterAggregate extends DeclarativeAggregate { self: HasUnaryExpressionCopy =>
trait UnaryRasterAggregate extends DeclarativeAggregate {
def child: Expression

def nullable: Boolean = child.nullable
Expand All @@ -42,8 +42,6 @@ trait UnaryRasterAggregate extends DeclarativeAggregate { self: HasUnaryExpressi

protected def tileOpAsExpression[R: TypeTag](name: String, op: Tile => R): Expression => ScalaUDF =
udfiexpr[R, Any](name, (dataType: DataType) => (a: Any) => if(a == null) null.asInstanceOf[R] else op(UnaryRasterAggregate.extractTileFromAny(dataType, a)))

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(newChildren(0))
}

object UnaryRasterAggregate {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,11 @@ import org.locationtech.rasterframes.expressions.DynamicExtractors._
import geotrellis.raster.Tile
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.UnaryExpression
import org.locationtech.rasterframes.model.TileContext

/** Boilerplate for expressions operating on a single Tile-like . */
trait UnaryRasterFunction extends UnaryExpression { self: HasUnaryExpressionCopy =>
override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild)

trait UnaryRasterFunction extends UnaryExpression {
override def checkInputDataTypes(): TypeCheckResult = {
if (!tileExtractor.isDefinedAt(child.dataType)) {
TypeCheckFailure(s"Input type '${child.dataType}' does not conform to a raster type.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@ package org.locationtech.rasterframes.expressions

import com.typesafe.scalalogging.Logger
import geotrellis.raster.Tile
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.types.DataType
import org.locationtech.rasterframes.model.TileContext
import org.slf4j.LoggerFactory

/** Operation on a tile returning a tile. */
trait UnaryRasterOp extends UnaryRasterFunction with RasterResult { this: HasUnaryExpressionCopy =>
trait UnaryRasterOp extends UnaryRasterFunction with RasterResult {
@transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName))

def dataType: DataType = child.dataType
Expand All @@ -38,7 +37,5 @@ trait UnaryRasterOp extends UnaryRasterFunction with RasterResult { this: HasUna
toInternalRow(op(tile), ctx)

protected def op(child: Tile): Tile

override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
}

Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ case class ExtractTile(child: Expression) extends UnaryRasterFunction with Codeg
case prt: ProjectedRasterTile => tileUDT.serialize(prt.tile)
case tile: Tile => tileSer(tile)
}

def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
}

object ExtractTile {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ case class GetCRS(child: Expression) extends UnaryExpression with CodegenFallbac
}
}

override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
}

object GetCRS {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ case class GetCellType(child: Expression) extends OnCellGridExpression with Code
/** Implemented by subtypes to process incoming ProjectedRasterLike entity. */
def eval(cg: CellGrid[Int]): Any = resultConverter(cg.cellType)

override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
}

object GetCellType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ case class GetDimensions(child: Expression) extends OnCellGridExpression with Co
def dataType = dimensionsEncoder[Int].schema

def eval(grid: CellGrid[Int]): Any = Dimensions[Int](grid.cols, grid.rows).toInternalRow

def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
}

object GetDimensions {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ case class GetEnvelope(child: Expression) extends UnaryExpression with CodegenFa

def dataType: DataType = envelopeEncoder.schema

override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
}

object GetEnvelope {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ case class GetExtent(child: Expression) extends OnTileContextExpression with Cod

override def nodeName: String = "rf_extent"
def eval(ctx: TileContext): InternalRow = ctx.extent.toInternalRow

def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
}

object GetExtent {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ case class GetGeometry(child: Expression) extends OnTileContextExpression with C
def eval(ctx: TileContext): InternalRow =
JTSTypes.GeometryTypeInstance.serialize(ctx.extent.toPolygon())

def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
}

object GetGeometry {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ case class GetTileContext(child: Expression) extends UnaryRasterFunction with Co

protected def eval(tile: Tile, ctx: Option[TileContext]): Any =
ctx.map(SerializersCache.serializer[TileContext].apply).orNull

def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
}

object GetTileContext {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ case class RealizeTile(child: Expression) extends UnaryExpression with CodegenFa
tileSer(tile.toArrayTile())
}

override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
}

object RealizeTile {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
package org.locationtech.rasterframes.expressions.aggregates

import org.locationtech.rasterframes.encoders.SparkBasicEncoders._
import org.locationtech.rasterframes.expressions.{HasUnaryExpressionCopy, UnaryRasterAggregate}
import org.locationtech.rasterframes.expressions.UnaryRasterAggregate
import org.locationtech.rasterframes.expressions.tilestats.{DataCells, NoDataCells}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
Expand All @@ -35,7 +35,7 @@ import org.apache.spark.sql.{Column, TypedColumn}
* @since 10/5/17
* @param isData true if count should be of non-NoData cells, false if count should be of NoData cells.
*/
abstract class CellCountAggregate(isData: Boolean) extends UnaryRasterAggregate { self: HasUnaryExpressionCopy =>
abstract class CellCountAggregate(isData: Boolean) extends UnaryRasterAggregate {
private lazy val count = AttributeReference("count", LongType, false, Metadata.empty)()

override lazy val aggBufferAttributes = Seq(count)
Expand Down Expand Up @@ -68,6 +68,8 @@ object CellCountAggregate {
)
case class DataCells(child: Expression) extends CellCountAggregate(true) {
override def nodeName: String = "rf_agg_data_cells"

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(newChildren.head)
}

object DataCells {
Expand All @@ -86,6 +88,8 @@ object CellCountAggregate {
)
case class NoDataCells(child: Expression) extends CellCountAggregate(false) {
override def nodeName: String = "rf_agg_no_data_cells"

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(newChildren.head)
}
object NoDataCells {
def apply(tile: Column): TypedColumn[Any, Long] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ case class CellMeanAggregate(child: Expression) extends UnaryRasterAggregate {
val evaluateExpression = sum / new Cast(count, DoubleType)

def dataType: DataType = DoubleType

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(newChildren.head)
}

object CellMeanAggregate {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ case class LocalMeanAggregate(child: Expression) extends UnaryRasterAggregate {
BiasedAdd(sum.left, sum.right)
)
lazy val evaluateExpression: Expression = DivideTiles(sum, count)

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(newChildren.head)
}
object LocalMeanAggregate {
def apply(tile: Column): TypedColumn[Any, Tile] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ case class FocalMax(first: Expression, second: Expression, third: Expression) ex
case bt: BufferTile => bt.focalMax(neighborhood, target = target)
case _ => t.focalMax(neighborhood, target = target)
}

def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird)
}

object FocalMax {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ case class FocalMean(first: Expression, second: Expression, third: Expression) e
case _ => t.focalMean(neighborhood, target = target)
}

def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird)
}

object FocalMean {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ case class FocalMedian(first: Expression, second: Expression, third: Expression)
case bt: BufferTile => bt.focalMedian(neighborhood, target = target)
case _ => t.focalMedian(neighborhood, target = target)
}
def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird)
}

object FocalMedian {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ case class FocalMin(first: Expression, second: Expression, third: Expression) ex
case bt: BufferTile => bt.focalMin(neighborhood, target = target)
case _ => t.focalMin(neighborhood, target = target)
}

def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird)
}

object FocalMin {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ case class FocalMode(first: Expression, second: Expression, third: Expression) e
case bt: BufferTile => bt.focalMode(neighborhood, target = target)
case _ => t.focalMode(neighborhood, target = target)
}
def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird)
}

object FocalMode {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ case class FocalMoransI(first: Expression, second: Expression, third: Expression
case bt: BufferTile => bt.tileMoransI(neighborhood, target = target)
case _ => t.tileMoransI(neighborhood, target = target)
}
def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird)
}

object FocalMoransI {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,10 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, TernaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types.DataType
import org.locationtech.rasterframes.expressions.DynamicExtractors.{neighborhoodExtractor, targetCellExtractor, tileExtractor}
import org.locationtech.rasterframes.expressions.{HasTernaryExpressionCopy, RasterResult, row}
import org.locationtech.rasterframes.expressions.{RasterResult, row}
import org.slf4j.LoggerFactory

trait FocalNeighborhoodOp extends TernaryExpression with RasterResult with CodegenFallback {self: HasTernaryExpressionCopy =>
override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
copy(newFirst, newSecond, newThird)

trait FocalNeighborhoodOp extends TernaryExpression with RasterResult with CodegenFallback {
@transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName))

// Tile
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ case class FocalStdDev(first: Expression, second: Expression, third: Expression)
case bt: BufferTile => bt.focalStandardDeviation(neighborhood, target = target)
case _ => t.focalStandardDeviation(neighborhood, target = target)
}
def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird)
}

object FocalStdDev {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ case class Abs(child: Expression) extends UnaryRasterOp with NullToValue with Co
def na: Any = null
protected def op(t: Tile): Tile = t.localAbs()

def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
}

object Abs {
Expand Down
Loading

0 comments on commit df552b8

Please sign in to comment.