Skip to content

Commit

Permalink
Expose TargetCells in focal ops
Browse files Browse the repository at this point in the history
  • Loading branch information
pomadchin committed Oct 1, 2021
1 parent 2750171 commit b97daee
Show file tree
Hide file tree
Showing 22 changed files with 660 additions and 223 deletions.
111 changes: 111 additions & 0 deletions core/src/main/scala/org/apache/spark/sql/rf/QuinaryExpression.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package org.apache.spark.sql.rf

import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode, FalseLiteral}

/**
 * An expression with five inputs and one output. By default the output evaluates to null
 * if any input evaluates to null.
 *
 * The default [[eval]] and [[nullSafeCodeGen]] implementations read `children(0)` through
 * `children(4)`, so subclasses must supply at least five children.
 */
abstract class QuinaryExpression extends Expression {

  /** Foldable only when every child is foldable. */
  override def foldable: Boolean = children.forall(_.foldable)

  /** Nullable as soon as any child is nullable. */
  override def nullable: Boolean = children.exists(_.nullable)

  /**
   * Default behavior of evaluation according to the default nullability of QuinaryExpression.
   * If a subclass of QuinaryExpression overrides nullable, it probably should also override this.
   *
   * Children are evaluated in order and evaluation stops at the first null, so later children
   * are never evaluated once an earlier one produced null.
   *
   * @param input the row the children are evaluated against
   * @return the result of [[nullSafeEval]], or null if any child evaluated to null
   */
  override def eval(input: InternalRow): Any = {
    val exprs = children
    val value1 = exprs(0).eval(input)
    if (value1 == null) null
    else {
      val value2 = exprs(1).eval(input)
      if (value2 == null) null
      else {
        val value3 = exprs(2).eval(input)
        if (value3 == null) null
        else {
          val value4 = exprs(3).eval(input)
          if (value4 == null) null
          else {
            val value5 = exprs(4).eval(input)
            if (value5 == null) null
            else nullSafeEval(value1, value2, value3, value4, value5)
          }
        }
      }
    }
  }

  /**
   * Called by the default [[eval]] implementation. If a subclass of QuinaryExpression keeps the
   * default nullability, it can override this method to save null-check code. If full control
   * of the evaluation process is needed, override [[eval]] instead.
   *
   * The default implementation always fails via `sys.error`.
   */
  protected def nullSafeEval(input1: Any, input2: Any, input3: Any, input4: Any, input5: Any): Any =
    sys.error("QuinaryExpressions must override either eval or nullSafeEval")

  /**
   * Short hand for generating quinary evaluation code.
   * If any of the sub-expressions is null, the result of this computation
   * is assumed to be null.
   *
   * @param ctx codegen context handed to [[nullSafeCodeGen]]
   * @param ev  expression code whose `value` is assigned the result of `f`
   * @param f   accepts five variable names and returns Java code to compute the output.
   */
  protected def defineCodeGen(ctx: CodegenContext, ev: ExprCode, f: (String, String, String, String, String) => String): ExprCode =
    nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3, eval4, eval5) =>
      s"${ev.value} = ${f(eval1, eval2, eval3, eval4, eval5)};"
    )

  /**
   * Short hand for generating quinary evaluation code.
   * If any of the sub-expressions is null, the result of this computation
   * is assumed to be null.
   *
   * @param ctx codegen context used to generate the children's code
   * @param ev  expression code that is copied with the generated code
   * @param f   function that accepts the 5 non-null evaluation result names of children
   *            and returns Java code to compute the output.
   */
  protected def nullSafeCodeGen(ctx: CodegenContext, ev: ExprCode, f: (String, String, String, String, String) => String): ExprCode = {
    val firstGen = children(0).genCode(ctx)
    val secondGen = children(1).genCode(ctx)
    val thirdGen = children(2).genCode(ctx)
    val fourthGen = children(3).genCode(ctx)
    val fifthGen = children(4).genCode(ctx)
    val resultCode = f(firstGen.value, secondGen.value, thirdGen.value, fourthGen.value, fifthGen.value)

    if (nullable) {
      // Each child's code is emitted inside the null guard of the previous child, so a child
      // is only evaluated when every earlier child was non-null (same order as eval).
      val nullSafeEval =
        firstGen.code + ctx.nullSafeExec(children(0).nullable, firstGen.isNull) {
          secondGen.code + ctx.nullSafeExec(children(1).nullable, secondGen.isNull) {
            thirdGen.code + ctx.nullSafeExec(children(2).nullable, thirdGen.isNull) {
              fourthGen.code + ctx.nullSafeExec(children(3).nullable, fourthGen.isNull) {
                fifthGen.code + ctx.nullSafeExec(children(4).nullable, fifthGen.isNull) {
                  s"""
                  ${ev.isNull} = false; // resultCode could change nullability.
                  $resultCode
                  """
                }
              }
            }
          }
        }

      // The result starts out as null/default and is only overwritten in the innermost guard.
      ev.copy(code = code"""
        boolean ${ev.isNull} = true;
        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
        $nullSafeEval""")
    } else {
      // No child is nullable: emit all child code unconditionally and mark the result non-null.
      ev.copy(code = code"""
        ${firstGen.code}
        ${secondGen.code}
        ${thirdGen.code}
        ${fourthGen.code}
        ${fifthGen.code}
        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
        $resultCode""", isNull = FalseLiteral)
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.util.QuantileSummaries
import org.locationtech.geomesa.spark.jts.encoders.SpatialEncoders
import org.locationtech.rasterframes.model.{CellContext, LongExtent, TileContext, TileDataContext}
import frameless.TypedEncoder
import geotrellis.raster.mapalgebra.focal.{Kernel, Neighborhood}
import geotrellis.raster.mapalgebra.focal.{Kernel, Neighborhood, TargetCell}

import java.net.URI
import java.sql.Timestamp
Expand All @@ -55,6 +55,7 @@ trait StandardEncoders extends SpatialEncoders with TypedEncoders {

implicit lazy val uriEncoder: ExpressionEncoder[URI] = typedExpressionEncoder[URI]
implicit lazy val neighborhoodEncoder: ExpressionEncoder[Neighborhood] = typedExpressionEncoder[Neighborhood]
implicit lazy val targetCellEncoder: ExpressionEncoder[TargetCell] = typedExpressionEncoder[TargetCell]
implicit lazy val kernelEncoder: ExpressionEncoder[Kernel] = typedExpressionEncoder[Kernel]
implicit lazy val quantileSummariesEncoder: ExpressionEncoder[QuantileSummaries] = typedExpressionEncoder[QuantileSummaries]
implicit lazy val envelopeEncoder: ExpressionEncoder[Envelope] = typedExpressionEncoder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ package org.locationtech.rasterframes.encoders
import frameless._
import geotrellis.layer.{KeyBounds, LayoutDefinition, TileLayerMetadata}
import geotrellis.proj4.CRS
import geotrellis.raster.mapalgebra.focal.{Kernel, Neighborhood}
import geotrellis.raster.mapalgebra.focal.{Kernel, Neighborhood, TargetCell}
import geotrellis.raster.{CellGrid, CellType, Dimensions, GridBounds, Raster, Tile}
import geotrellis.vector.Extent
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.util.QuantileSummaries
import org.apache.spark.sql.rf.{CrsUDT, RasterSourceUDT, TileUDT}
import org.locationtech.jts.geom.Envelope
import org.locationtech.rasterframes.util.{FocalNeighborhood, KryoSupport}
import org.locationtech.rasterframes.util.{FocalNeighborhood, FocalTargetCell, KryoSupport}

import java.net.URI
import java.nio.ByteBuffer
Expand All @@ -37,6 +37,9 @@ trait TypedEncoders {
implicit val neighborhoodInjection: Injection[Neighborhood, String] = Injection(FocalNeighborhood(_), FocalNeighborhood.fromString(_).get)
implicit val neighborhoodTypedEncoder: TypedEncoder[Neighborhood] = TypedEncoder.usingInjection

implicit val targetCellInjection: Injection[TargetCell, String] = Injection(FocalTargetCell(_), FocalTargetCell.fromString)
implicit val targetCellTypedEncoder: TypedEncoder[TargetCell] = TypedEncoder.usingInjection

implicit val envelopeTypedEncoder: TypedEncoder[Envelope] =
ManualTypedEncoder.newInstance[Envelope](
fields = List(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
package org.locationtech.rasterframes.expressions

import geotrellis.proj4.CRS
import geotrellis.raster.{CellGrid, Neighborhood, Raster, Tile}
import geotrellis.raster.{CellGrid, Neighborhood, Raster, TargetCell, Tile}
import geotrellis.vector.Extent
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
Expand All @@ -38,7 +38,7 @@ import org.locationtech.rasterframes.model.{LazyCRS, LongExtent, TileContext}
import org.locationtech.rasterframes.ref.{ProjectedRasterLike, RasterRef}
import org.locationtech.rasterframes.tiles.ProjectedRasterTile
import org.apache.spark.sql.rf.CrsUDT
import org.locationtech.rasterframes.util.FocalNeighborhood
import org.locationtech.rasterframes.util.{FocalNeighborhood, FocalTargetCell}

private[rasterframes]
object DynamicExtractors {
Expand Down Expand Up @@ -230,4 +230,9 @@ object DynamicExtractors {
case _: StringType => (v: Any) => FocalNeighborhood.fromString(v.asInstanceOf[UTF8String].toString).get
case n if n.conformsToSchema(neighborhoodEncoder.schema) => { case ir: InternalRow => ir.as[Neighborhood] }
}

/**
 * Maps a Spark [[DataType]] to a function that extracts a [[TargetCell]] from a value of that type.
 *
 * The partial function is defined for two kinds of input type:
 *  - `StringType`: the value is cast to `UTF8String`, converted to a `String` and parsed
 *    with `FocalTargetCell.fromString`.
 *  - any type conforming to `targetCellEncoder.schema`: the value is matched as an
 *    `InternalRow` and decoded with `as[TargetCell]`.
 *
 * Callers can use `isDefinedAt` on a column's data type to check whether it is supported.
 */
lazy val targetCellExtractor: PartialFunction[DataType, Any => TargetCell] = {
// String input: parse the textual form of the target cell.
case _: StringType => (v: Any) => FocalTargetCell.fromString(v.asInstanceOf[UTF8String].toString)
// Encoded input: the returned function only matches InternalRow values.
case n if n.conformsToSchema(targetCellEncoder.schema) => { case ir: InternalRow => ir.as[TargetCell] }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,55 +21,61 @@

package org.locationtech.rasterframes.expressions.focalops

import geotrellis.raster.{BufferTile, CellSize}
import geotrellis.raster.{BufferTile, CellSize, TargetCell, Tile}
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
import org.locationtech.rasterframes.expressions.{NullToValue, RasterResult, UnaryRasterFunction, row}
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionDescription}
import org.locationtech.rasterframes.expressions.{RasterResult, row}
import org.locationtech.rasterframes.encoders.syntax._
import org.locationtech.rasterframes.expressions.DynamicExtractors._
import org.locationtech.rasterframes.model.TileContext
import geotrellis.raster.Tile
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types.DataType
import org.slf4j.LoggerFactory
import com.typesafe.scalalogging.Logger
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}

@ExpressionDescription(
usage = "_FUNC_(tile) - Performs aspect on tile.",
usage = "_FUNC_(tile, target) - Performs aspect on tile.",
arguments = """
Arguments:
* tile - a tile to apply operation""",
* tile - a tile to apply operation
* target - the target cells to apply focal operation: data, nodata, all""",
examples = """
Examples:
> SELECT _FUNC_(tile);
> SELECT _FUNC_(tile, 'all');
..."""
)
case class Aspect(child: Expression) extends UnaryRasterFunction with RasterResult with NullToValue with CodegenFallback {
case class Aspect(left: Expression, right: Expression) extends BinaryExpression with RasterResult with CodegenFallback {
@transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName))

def na: Any = null
def dataType: DataType = left.dataType

def dataType: DataType = child.dataType
override def checkInputDataTypes(): TypeCheckResult =
if (!tileExtractor.isDefinedAt(left.dataType)) TypeCheckFailure(s"Input type '${left.dataType}' does not conform to a raster type.")
else if(!targetCellExtractor.isDefinedAt(right.dataType)) TypeCheckFailure(s"Input type '${right.dataType}' does not conform to a string TargetCell type.")
else TypeCheckSuccess

override protected def nullSafeEval(input: Any): Any = {
val (tile, ctx) = tileExtractor(child.dataType)(row(input))
eval(extractBufferTile(tile), ctx)
override protected def nullSafeEval(tileInput: Any, targetCellInput: Any): Any = {
val (tile, ctx) = tileExtractor(left.dataType)(row(tileInput))
val target = targetCellExtractor(right.dataType)(targetCellInput)
eval(extractBufferTile(tile), ctx, target)
}

protected def eval(tile: Tile, ctx: Option[TileContext]): Any = ctx match {
case Some(ctx) => ctx.toProjectRasterTile(op(tile, ctx)).toInternalRow
/**
 * Runs `op` on the tile and encodes the result as an internal row.
 *
 * @param tile   tile to apply the operation to
 * @param ctx    optional tile context; it is passed to `op` and used to wrap the result
 * @param target target cells forwarded to `op`
 * @return the result of `op`, wrapped with `ctx.toProjectRasterTile` and converted via `toInternalRow`
 * @throws NotImplementedError when no tile context is available
 */
protected def eval(tile: Tile, ctx: Option[TileContext], target: TargetCell): Any = ctx match {
  case Some(ctx) => ctx.toProjectRasterTile(op(tile, ctx, target)).toInternalRow
  // Throw the error: previously the exception instance was constructed and returned as the
  // expression's result value instead of being raised.
  case None => throw new NotImplementedError("Surface operation requires ProjectedRasterTile")
}

override def nodeName: String = Aspect.name

def op(t: Tile, ctx: TileContext): Tile = t match {
case bt: BufferTile => bt.aspect(CellSize(ctx.extent, cols = t.cols, rows = t.rows))
case _ => t.aspect(CellSize(ctx.extent, cols = t.cols, rows = t.rows))
def op(t: Tile, ctx: TileContext, target: TargetCell): Tile = t match {
case bt: BufferTile => bt.aspect(CellSize(ctx.extent, cols = t.cols, rows = t.rows), target = target)
case _ => t.aspect(CellSize(ctx.extent, cols = t.cols, rows = t.rows), target = target)
}
}

object Aspect {
def name: String = "rf_aspect"
def apply(tile: Column): Column = new Column(Aspect(tile.expr))
def apply(tile: Column, target: Column): Column = new Column(Aspect(tile.expr, target.expr))
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,59 +22,62 @@
package org.locationtech.rasterframes.expressions.focalops

import com.typesafe.scalalogging.Logger
import geotrellis.raster.{BufferTile, Tile}
import geotrellis.raster.{BufferTile, TargetCell, Tile}
import geotrellis.raster.mapalgebra.focal.Kernel
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionDescription}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, TernaryExpression}
import org.apache.spark.sql.types.DataType
import org.locationtech.rasterframes._
import org.locationtech.rasterframes.encoders._
import org.locationtech.rasterframes.encoders.syntax._
import org.locationtech.rasterframes.expressions.DynamicExtractors.tileExtractor
import org.locationtech.rasterframes.expressions.DynamicExtractors.{targetCellExtractor, tileExtractor}
import org.locationtech.rasterframes.expressions.{RasterResult, row}
import org.slf4j.LoggerFactory

@ExpressionDescription(
usage = "_FUNC_(tile, kernel) - Performs convolve on tile in the neighborhood.",
usage = "_FUNC_(tile, kernel, target) - Performs convolve on tile in the neighborhood.",
arguments = """
Arguments:
* tile - a tile to apply operation
* kernel - a focal operation kernel""",
* kernel - a focal operation kernel
* target - the target cells to apply focal operation: data, nodata, all""",
examples = """
Examples:
> SELECT _FUNC_(tile, kernel);
> SELECT _FUNC_(tile, kernel, 'all');
..."""
)
case class Convolve(left: Expression, right: Expression) extends BinaryExpression with RasterResult with CodegenFallback {
case class Convolve(left: Expression, middle: Expression, right: Expression) extends TernaryExpression with RasterResult with CodegenFallback {
@transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName))

override def nodeName: String = Convolve.name

def dataType: DataType = left.dataType
val children: Seq[Expression] = Seq(left, middle, right)

override def checkInputDataTypes(): TypeCheckResult =
if (!tileExtractor.isDefinedAt(left.dataType)) TypeCheckFailure(s"Input type '${left.dataType}' does not conform to a raster type.")
else if (!right.dataType.conformsToSchema(kernelEncoder.schema)) {
TypeCheckFailure(s"Input type '${right.dataType}' does not conform to a kernel type.")
} else TypeCheckSuccess
else if (!middle.dataType.conformsToSchema(kernelEncoder.schema)) TypeCheckFailure(s"Input type '${middle.dataType}' does not conform to a Kernel type.")
else if (!targetCellExtractor.isDefinedAt(right.dataType)) TypeCheckFailure(s"Input type '${right.dataType}' does not conform to a TargetCell type.")
else TypeCheckSuccess

override protected def nullSafeEval(tileInput: Any, kernelInput: Any): Any = {
override protected def nullSafeEval(tileInput: Any, kernelInput: Any, targetCellInput: Any): Any = {
val (tile, ctx) = tileExtractor(left.dataType)(row(tileInput))
val kernel = row(kernelInput).as[Kernel]
val result = op(extractBufferTile(tile), kernel)
val target = targetCellExtractor(right.dataType)(targetCellInput)
val result = op(extractBufferTile(tile), kernel, target)
toInternalRow(result, ctx)
}

protected def op(t: Tile, kernel: Kernel): Tile = t match {
case bt: BufferTile => bt.convolve(kernel)
case _ => t.convolve(kernel)
protected def op(t: Tile, kernel: Kernel, target: TargetCell): Tile = t match {
case bt: BufferTile => bt.convolve(kernel, target = target)
case _ => t.convolve(kernel, target = target)
}
}

object Convolve {
def name: String = "rf_convolve"
def apply(tile: Column, kernel: Column): Column = new Column(Convolve(tile.expr, kernel.expr))
def apply(tile: Column, kernel: Column, target: Column): Column = new Column(Convolve(tile.expr, kernel.expr, target.expr))
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,31 +21,32 @@

package org.locationtech.rasterframes.expressions.focalops

import geotrellis.raster.{BufferTile, Tile}
import geotrellis.raster.{BufferTile, TargetCell, Tile}
import geotrellis.raster.mapalgebra.focal.Neighborhood
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}

@ExpressionDescription(
usage = "_FUNC_(tile, neighborhood) - Performs focalMax on tile in the neighborhood.",
usage = "_FUNC_(tile, neighborhood, target) - Performs focalMax on tile in the neighborhood.",
arguments = """
Arguments:
* tile - a tile to apply operation
* neighborhood - a focal operation neighborhood""",
* neighborhood - a focal operation neighborhood
* target - the target cells to apply focal operation: data, nodata, all""",
examples = """
Examples:
> SELECT _FUNC_(tile, 'square-1');
> SELECT _FUNC_(tile, 'square-1', 'all');
..."""
)
case class FocalMax(left: Expression, right: Expression) extends FocalNeighborhoodOp {
case class FocalMax(left: Expression, middle: Expression, right: Expression) extends FocalNeighborhoodOp {
override def nodeName: String = FocalMax.name
protected def op(t: Tile, neighborhood: Neighborhood): Tile = t match {
case bt: BufferTile => bt.focalMax(neighborhood)
case _ => t.focalMax(neighborhood)
protected def op(t: Tile, neighborhood: Neighborhood, target: TargetCell): Tile = t match {
case bt: BufferTile => bt.focalMax(neighborhood, target = target)
case _ => t.focalMax(neighborhood, target = target)
}
}

object FocalMax {
def name: String = "rf_focal_max"
def apply(tile: Column, neighborhood: Column): Column = new Column(FocalMax(tile.expr, neighborhood.expr))
def apply(tile: Column, neighborhood: Column, target: Column): Column = new Column(FocalMax(tile.expr, neighborhood.expr, target.expr))
}
Loading

0 comments on commit b97daee

Please sign in to comment.