Skip to content

Commit

Permalink
Expressions constructors toSeq conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
pomadchin committed Jan 3, 2023
1 parent 43e8d3d commit a2d5a7a
Showing 1 changed file with 115 additions and 151 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ package org.locationtech.rasterframes
import geotrellis.raster.{DoubleConstantNoDataCellType, Tile}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ExpressionInfo, ScalaUDF}
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, ScalaReflection}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.SQLContext
Expand All @@ -36,19 +36,23 @@ import org.locationtech.rasterframes.expressions.localops._
import org.locationtech.rasterframes.expressions.focalops._
import org.locationtech.rasterframes.expressions.tilestats._
import org.locationtech.rasterframes.expressions.transformers._
import shapeless.HList
import shapeless.ops.function.FnToProduct
import shapeless.ops.traversable.FromTraversable

import scala.reflect.ClassTag
import scala.reflect.runtime.universe._
import scala.language.implicitConversions

/**
* Module of Catalyst expressions for efficiently working with tiles.
*
* @since 10/10/17
*/
package object expressions {
type HasTernaryExpressionCopy = {def copy(first: Expression, second: Expression, third: Expression): Expression}
type HasBinaryExpressionCopy = {def copy(left: Expression, right: Expression): Expression}
type HasUnaryExpressionCopy = {def copy(child: Expression): Expression}
type HasTernaryExpressionCopy = { def copy(first: Expression, second: Expression, third: Expression): Expression }
type HasBinaryExpressionCopy = { def copy(left: Expression, right: Expression): Expression }
type HasUnaryExpressionCopy = { def copy(child: Expression): Expression }

private[expressions] def row(input: Any) = input.asInstanceOf[InternalRow]
/** Convert the tile to a floating point type as needed for scalar operations. */
Expand All @@ -67,33 +71,6 @@ package object expressions {

}

private def expressionInfo[T : ClassTag](name: String, since: Option[String], database: Option[String]): ExpressionInfo = {
val clazz = scala.reflect.classTag[T].runtimeClass
val df = clazz.getAnnotation(classOf[ExpressionDescription])
if (df != null) {
if (df.extended().isEmpty) {
new ExpressionInfo(
clazz.getCanonicalName,
database.orNull,
name,
df.usage(),
df.arguments(),
df.examples(),
df.note(),
df.group(),
since.getOrElse(df.since()),
df.deprecated(),
df.source())
} else {
// This exists for the backward compatibility with old `ExpressionDescription`s defining
// the extended description in `extended()`.
new ExpressionInfo(clazz.getCanonicalName, database.orNull, name, df.usage(), df.extended())
}
} else {
new ExpressionInfo(clazz.getCanonicalName, name)
}
}

def register(sqlContext: SQLContext, database: Option[String] = None): Unit = {
val registry = sqlContext.sparkSession.sessionState.functionRegistry

Expand All @@ -103,127 +80,114 @@ package object expressions {
registry.registerFunction(id, info, builder)
}

def register1[T <: Expression : ClassTag](
name: String,
builder: Expression => T
): Unit = registerFunction[T](name, None){ args => builder(args(0))
/** Converts (expr1: Expression, ..., exprn: Expression) => R into a Seq[Expression] => R */
implicit def expressionArgumentsSequencer[F, I <: HList, R](f: F)(implicit ftp: FnToProduct.Aux[F, I => R], ft: FromTraversable[I]): Seq[Expression] => R = { list: Seq[Expression] =>
ft(list) match {
case Some(l) => ftp(f)(l)
case None => throw new IllegalArgumentException(s"registerFunction application failed: arity mismatch: $list.")
}
}

def register2[T <: Expression : ClassTag](
name: String,
builder: (Expression, Expression) => T
): Unit = registerFunction[T](name, None){ args => builder(args(0), args(1)) }

def register3[T <: Expression : ClassTag](
name: String,
builder: (Expression, Expression, Expression) => T
): Unit = registerFunction[T](name, None){ args => builder(args(0), args(1), args(2)) }

def register5[T <: Expression : ClassTag](
name: String,
builder: (Expression, Expression, Expression, Expression, Expression) => T
): Unit = registerFunction[T](name, None){ args => builder(args(0), args(1), args(2), args(3), args(4)) }

register2("rf_local_add", Add(_, _))
register2("rf_local_subtract", Subtract(_, _))
registerFunction("rf_explode_tiles"){ExplodeTiles(1.0, None, _)}
register5("rf_assemble_tile", TileAssembler(_, _, _, _, _))
register1("rf_cell_type", GetCellType(_))
register2("rf_convert_cell_type", SetCellType(_, _))
register2("rf_interpret_cell_type_as", InterpretAs(_, _))
register2("rf_with_no_data", SetNoDataValue(_,_))
register1("rf_dimensions", GetDimensions(_))
register1("st_geometry", ExtentToGeometry(_))
register1("rf_geometry", GetGeometry(_))
register1("st_extent", GeometryToExtent(_))
register1("rf_extent", GetExtent(_))
register1("rf_crs", GetCRS(_))
register1("rf_tile", RealizeTile(_))
register3("rf_proj_raster", CreateProjectedRaster(_, _, _))
register2("rf_local_multiply", Multiply(_, _))
register2("rf_local_divide", Divide(_, _))
register2("rf_normalized_difference", NormalizedDifference(_,_))
register2("rf_local_less", Less(_, _))
register2("rf_local_greater", Greater(_, _))
register2("rf_local_less_equal", LessEqual(_, _))
register2("rf_local_greater_equal", GreaterEqual(_, _))
register2("rf_local_equal", Equal(_, _))
register2("rf_local_unequal", Unequal(_, _))
register2("rf_local_is_in", IsIn(_, _))
register1("rf_local_no_data", Undefined(_))
register1("rf_local_data", Defined(_))
register2("rf_local_min", Min(_, _))
register2("rf_local_max", Max(_, _))
register3("rf_local_clamp", Clamp(_, _, _))
register3("rf_where", Where(_, _, _))
register3("rf_standardize", Standardize(_, _, _))
register3("rf_rescale", Rescale(_, _ , _))
register1("rf_tile_sum", Sum(_))
register1("rf_round", Round(_))
register1("rf_abs", Abs(_))
register1("rf_log", Log(_))
register1("rf_log10", Log10(_))
register1("rf_log2", Log2(_))
register1("rf_log1p", Log1p(_))
register1("rf_exp", Exp(_))
register1("rf_exp10", Exp10(_))
register1("rf_exp2", Exp2(_))
register1("rf_expm1", ExpM1(_))
register1("rf_sqrt", Sqrt(_))
register3("rf_resample", Resample(_, _, _))
register2("rf_resample_nearest", ResampleNearest(_, _))
register1("rf_tile_to_array_double", TileToArrayDouble(_))
register1("rf_tile_to_array_int", TileToArrayInt(_))
register1("rf_data_cells", DataCells(_))
register1("rf_no_data_cells", NoDataCells(_))
register1("rf_is_no_data_tile", IsNoDataTile(_))
register1("rf_exists", Exists(_))
register1("rf_for_all", ForAll(_))
register1("rf_tile_min", TileMin(_))
register1("rf_tile_max", TileMax(_))
register1("rf_tile_mean", TileMean(_))
register1("rf_tile_stats", TileStats(_))
register1("rf_tile_histogram", TileHistogram(_))
register1("rf_agg_data_cells", DataCells(_))
register1("rf_agg_no_data_cells", CellCountAggregate.NoDataCells(_))
register1("rf_agg_stats", CellStatsAggregate.CellStatsAggregateUDAF(_))
register1("rf_agg_approx_histogram", HistogramAggregate.HistogramAggregateUDAF(_))
register1("rf_agg_local_stats", LocalStatsAggregate.LocalStatsAggregateUDAF(_))
register1("rf_agg_local_min",LocalTileOpAggregate.LocalMinUDAF(_))
register1("rf_agg_local_max", LocalTileOpAggregate.LocalMaxUDAF(_))
register1("rf_agg_local_data_cells", LocalCountAggregate.LocalDataCellsUDAF(_))
register1("rf_agg_local_no_data_cells", LocalCountAggregate.LocalNoDataCellsUDAF(_))
register1("rf_agg_local_mean", LocalMeanAggregate(_))
register3(FocalMax.name, FocalMax(_, _, _))
register3(FocalMin.name, FocalMin(_, _, _))
register3(FocalMean.name, FocalMean(_, _, _))
register3(FocalMode.name, FocalMode(_, _, _))
register3(FocalMedian.name, FocalMedian(_, _, _))
register3(FocalMoransI.name, FocalMoransI(_, _, _))
register3(FocalStdDev.name, FocalStdDev(_, _, _))
register3(Convolve.name, Convolve(_, _, _))

register3(Slope.name, Slope(_, _, _))
register2(Aspect.name, Aspect(_, _))
register5(Hillshade.name, Hillshade(_, _, _, _, _))

register2("rf_mask", MaskByDefined(_, _))
register2("rf_inverse_mask", InverseMaskByDefined(_, _))
register3("rf_mask_by_value", MaskByValue(_, _, _))
register3("rf_inverse_mask_by_value", InverseMaskByValue(_, _, _))
register3("rf_mask_by_values", MaskByValues(_, _, _))

register1("rf_render_ascii", DebugRender.RenderAscii(_))
register1("rf_render_matrix", DebugRender.RenderMatrix(_))
register1("rf_render_png", RenderPNG.RenderCompositePNG(_))
register3("rf_rgb_composite", RGBComposite(_, _, _))

register2("rf_xz2_index", XZ2Indexer(_, _, 18.toShort))
register2("rf_z2_index", Z2Indexer(_, _, 31.toShort))

register3("st_reproject", ReprojectGeometry(_, _, _))

register3[ExtractBits]("rf_local_extract_bits", ExtractBits(_: Expression, _: Expression, _: Expression))
register3[ExtractBits]("rf_local_extract_bit", ExtractBits(_: Expression, _: Expression, _: Expression))
registerFunction[Add](name = "rf_local_add")(Add.apply)
registerFunction[Subtract](name = "rf_local_subtract")(Subtract.apply)
registerFunction[ExplodeTiles](name = "rf_explode_tiles")(ExplodeTiles(1.0, None, _))
registerFunction[TileAssembler](name = "rf_assemble_tile")(TileAssembler.apply)
registerFunction[GetCellType](name = "rf_cell_type")(GetCellType.apply)
registerFunction[SetCellType](name = "rf_convert_cell_type")(SetCellType.apply)
registerFunction[InterpretAs](name = "rf_interpret_cell_type_as")(InterpretAs.apply)
registerFunction[SetNoDataValue](name = "rf_with_no_data")(SetNoDataValue.apply)
registerFunction[GetDimensions](name = "rf_dimensions")(GetDimensions.apply)
registerFunction[ExtentToGeometry](name = "st_geometry")(ExtentToGeometry.apply)
registerFunction[GetGeometry](name = "rf_geometry")(GetGeometry.apply)
registerFunction[GeometryToExtent](name = "st_extent")(GeometryToExtent.apply)
registerFunction[GetExtent](name = "rf_extent")(GetExtent.apply)
registerFunction[GetCRS](name = "rf_crs")(GetCRS.apply)
registerFunction[RealizeTile](name = "rf_tile")(RealizeTile.apply)
registerFunction[CreateProjectedRaster](name = "rf_proj_raster")(CreateProjectedRaster.apply)
registerFunction[Multiply](name = "rf_local_multiply")(Multiply.apply)
registerFunction[Divide](name = "rf_local_divide")(Divide.apply)
registerFunction[NormalizedDifference](name = "rf_normalized_difference")(NormalizedDifference.apply)
registerFunction[Less](name = "rf_local_less")(Less.apply)
registerFunction[Greater](name = "rf_local_greater")(Greater.apply)
registerFunction[LessEqual](name = "rf_local_less_equal")(LessEqual.apply)
registerFunction[GreaterEqual](name = "rf_local_greater_equal")(GreaterEqual.apply)
registerFunction[Equal](name = "rf_local_equal")(Equal.apply)
registerFunction[Unequal](name = "rf_local_unequal")(Unequal.apply)
registerFunction[IsIn](name = "rf_local_is_in")(IsIn.apply)
registerFunction[Undefined](name = "rf_local_no_data")(Undefined.apply)
registerFunction[Defined](name = "rf_local_data")(Defined.apply)
registerFunction[Min](name = "rf_local_min")(Min.apply)
registerFunction[Max](name = "rf_local_max")(Max.apply)
registerFunction[Clamp](name = "rf_local_clamp")(Clamp.apply)
registerFunction[Where](name = "rf_where")(Where.apply)
registerFunction[Standardize](name = "rf_standardize")(Standardize.apply)
registerFunction[Rescale](name = "rf_rescale")(Rescale.apply)
registerFunction[Sum](name = "rf_tile_sum")(Sum.apply)
registerFunction[Round](name = "rf_round")(Round.apply)
registerFunction[Abs](name = "rf_abs")(Abs.apply)
registerFunction[Log](name = "rf_log")(Log.apply)
registerFunction[Log10](name = "rf_log10")(Log10.apply)
registerFunction[Log2](name = "rf_log2")(Log2.apply)
registerFunction[Log1p](name = "rf_log1p")(Log1p.apply)
registerFunction[Exp](name = "rf_exp")(Exp.apply)
registerFunction[Exp10](name = "rf_exp10")(Exp10.apply)
registerFunction[Exp2](name = "rf_exp2")(Exp2.apply)
registerFunction[ExpM1](name = "rf_expm1")(ExpM1.apply)
registerFunction[Sqrt](name = "rf_sqrt")(Sqrt.apply)
registerFunction[Resample](name = "rf_resample")(Resample.apply)
registerFunction[ResampleNearest](name = "rf_resample_nearest")(ResampleNearest.apply)
registerFunction[TileToArrayDouble](name = "rf_tile_to_array_double")(TileToArrayDouble.apply)
registerFunction[TileToArrayInt](name = "rf_tile_to_array_int")(TileToArrayInt.apply)
registerFunction[DataCells](name = "rf_data_cells")(DataCells.apply)
registerFunction[NoDataCells](name = "rf_no_data_cells")(NoDataCells.apply)
registerFunction[IsNoDataTile](name = "rf_is_no_data_tile")(IsNoDataTile.apply)
registerFunction[Exists](name = "rf_exists")(Exists.apply)
registerFunction[ForAll](name = "rf_for_all")(ForAll.apply)
registerFunction[TileMin](name = "rf_tile_min")(TileMin.apply)
registerFunction[TileMax](name = "rf_tile_max")(TileMax.apply)
registerFunction[TileMean](name = "rf_tile_mean")(TileMean.apply)
registerFunction[TileStats](name = "rf_tile_stats")(TileStats.apply)
registerFunction[TileHistogram](name = "rf_tile_histogram")(TileHistogram.apply)
registerFunction[DataCells](name = "rf_agg_data_cells")(DataCells.apply)
registerFunction[CellCountAggregate.NoDataCells](name = "rf_agg_no_data_cells")(CellCountAggregate.NoDataCells.apply)
registerFunction[CellStatsAggregate.CellStatsAggregateUDAF](name = "rf_agg_stats")(CellStatsAggregate.CellStatsAggregateUDAF.apply)
registerFunction[HistogramAggregate.HistogramAggregateUDAF](name = "rf_agg_approx_histogram")(HistogramAggregate.HistogramAggregateUDAF.apply)
registerFunction[LocalStatsAggregate.LocalStatsAggregateUDAF](name = "rf_agg_local_stats")(LocalStatsAggregate.LocalStatsAggregateUDAF.apply)
registerFunction[LocalTileOpAggregate.LocalMinUDAF](name = "rf_agg_local_min")(LocalTileOpAggregate.LocalMinUDAF.apply)
registerFunction[LocalTileOpAggregate.LocalMaxUDAF](name = "rf_agg_local_max")(LocalTileOpAggregate.LocalMaxUDAF.apply)
registerFunction[LocalCountAggregate.LocalDataCellsUDAF](name = "rf_agg_local_data_cells")(LocalCountAggregate.LocalDataCellsUDAF.apply)
registerFunction[LocalCountAggregate.LocalNoDataCellsUDAF](name = "rf_agg_local_no_data_cells")(LocalCountAggregate.LocalNoDataCellsUDAF.apply)
registerFunction[LocalMeanAggregate](name = "rf_agg_local_mean")(LocalMeanAggregate.apply)
registerFunction[FocalMax](FocalMax.name)(FocalMax.apply)
registerFunction[FocalMin](FocalMin.name)(FocalMin.apply)
registerFunction[FocalMean](FocalMean.name)(FocalMean.apply)
registerFunction[FocalMode](FocalMode.name)(FocalMode.apply)
registerFunction[FocalMedian](FocalMedian.name)(FocalMedian.apply)
registerFunction[FocalMoransI](FocalMoransI.name)(FocalMoransI.apply)
registerFunction[FocalStdDev](FocalStdDev.name)(FocalStdDev.apply)
registerFunction[Convolve](Convolve.name)(Convolve.apply)

registerFunction[Slope](Slope.name)(Slope.apply)
registerFunction[Aspect](Aspect.name)(Aspect.apply)
registerFunction[Hillshade](Hillshade.name)(Hillshade.apply)

registerFunction[MaskByDefined](name = "rf_mask")(MaskByDefined.apply)
registerFunction[InverseMaskByDefined](name = "rf_inverse_mask")(InverseMaskByDefined.apply)
registerFunction[MaskByValue](name = "rf_mask_by_value")(MaskByValue.apply)
registerFunction[InverseMaskByValue](name = "rf_inverse_mask_by_value")(InverseMaskByValue.apply)
registerFunction[MaskByValues](name = "rf_mask_by_values")(MaskByValues.apply)

registerFunction[DebugRender.RenderAscii](name = "rf_render_ascii")(DebugRender.RenderAscii.apply)
registerFunction[DebugRender.RenderMatrix](name = "rf_render_matrix")(DebugRender.RenderMatrix.apply)
registerFunction[RenderPNG.RenderCompositePNG](name = "rf_render_png")(RenderPNG.RenderCompositePNG.apply)
registerFunction[RGBComposite](name = "rf_rgb_composite")(RGBComposite.apply)

registerFunction[XZ2Indexer](name = "rf_xz2_index")(XZ2Indexer(_: Expression, _: Expression, 18.toShort))
registerFunction[Z2Indexer](name = "rf_z2_index")(Z2Indexer(_: Expression, _: Expression, 31.toShort))

registerFunction[ReprojectGeometry](name = "st_reproject")(ReprojectGeometry.apply)

registerFunction[ExtractBits]("rf_local_extract_bits")(ExtractBits(_: Expression, _: Expression, _: Expression))
registerFunction[ExtractBits]("rf_local_extract_bit")(ExtractBits(_: Expression, _: Expression, _: Expression))
}
}

0 comments on commit a2d5a7a

Please sign in to comment.