From 30e700400661cd80fcb04748444ec5326616c30a Mon Sep 17 00:00:00 2001 From: "Simeon H.K. Fitch" Date: Mon, 14 Mar 2022 17:33:28 -0400 Subject: [PATCH 01/34] bumped dev version --- version.sbt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.sbt b/version.sbt index 1f9ae9ed8..e1b9bbdca 100644 --- a/version.sbt +++ b/version.sbt @@ -1 +1 @@ -ThisBuild / version := "0.10.1" +ThisBuild / version := "0.10.2-SNAPSHOT" From 0b4309e744541e0fac51045524435653b3b619ce Mon Sep 17 00:00:00 2001 From: "Simeon H.K. Fitch" Date: Thu, 7 Apr 2022 09:22:32 -0400 Subject: [PATCH 02/34] CI fix. --- build.sbt | 2 +- project/RFDependenciesPlugin.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/build.sbt b/build.sbt index d54d0cac8..ba9f00115 100644 --- a/build.sbt +++ b/build.sbt @@ -19,7 +19,7 @@ * */ -// Leave me an my custom keys alone! +// Leave me and my custom keys alone! Global / lintUnusedKeysOnLoad := false addCommandAlias("makeSite", "docs/makeSite") diff --git a/project/RFDependenciesPlugin.scala b/project/RFDependenciesPlugin.scala index 8929ab69a..95c19b1c7 100644 --- a/project/RFDependenciesPlugin.scala +++ b/project/RFDependenciesPlugin.scala @@ -56,7 +56,7 @@ object RFDependenciesPlugin extends AutoPlugin { val sttpCatsCe2 = "com.softwaremill.sttp.client3" %% "async-http-client-backend-cats-ce2" % "3.3.15" val frameless = "org.typelevel" %% "frameless-dataset-spark31" % "0.11.1" val framelessRefined = "org.typelevel" %% "frameless-refined-spark31" % "0.11.1" - val `better-files` = "com.github.pathikrit" %% "better-files" % "3.9.1" % Test + val `better-files` = "com.github.pathikrit" %% "better-files" % "3.9.1" } import autoImport._ From 80e69928411a889bec4aea0efba577e17817411a Mon Sep 17 00:00:00 2001 From: "Simeon H.K. Fitch" Date: Thu, 7 Apr 2022 11:20:17 -0400 Subject: [PATCH 03/34] Dependency updates. --- project/RFDependenciesPlugin.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/project/RFDependenciesPlugin.scala b/project/RFDependenciesPlugin.scala index 95c19b1c7..c0e9b5b33 100644 --- a/project/RFDependenciesPlugin.scala +++ b/project/RFDependenciesPlugin.scala @@ -46,14 +46,14 @@ object RFDependenciesPlugin extends AutoPlugin { } } val scalatest = "org.scalatest" %% "scalatest" % "3.2.5" % Test - val shapeless = "com.chuusai" %% "shapeless" % "2.3.7" - val `jts-core` = "org.locationtech.jts" % "jts-core" % "1.17.0" - val `slf4j-api` = "org.slf4j" % "slf4j-api" % "1.7.28" - val scaffeine = "com.github.blemale" %% "scaffeine" % "4.0.2" - val `spray-json` = "io.spray" %% "spray-json" % "1.3.4" - val `scala-logging` = "com.typesafe.scala-logging" %% "scala-logging" % "3.8.0" + val shapeless = "com.chuusai" %% "shapeless" % "2.3.9" + val `jts-core` = "org.locationtech.jts" % "jts-core" % "1.18.2" + val `slf4j-api` = "org.slf4j" % "slf4j-api" % "1.7.36" + val scaffeine = "com.github.blemale" %% "scaffeine" % "5.1.2" + val `spray-json` = "io.spray" %% "spray-json" % "1.3.6" + val `scala-logging` = "com.typesafe.scala-logging" %% "scala-logging" % "3.9.4" val stac4s = "com.azavea.stac4s" %% "client" % "0.7.2" - val sttpCatsCe2 = "com.softwaremill.sttp.client3" %% "async-http-client-backend-cats-ce2" % "3.3.15" + val sttpCatsCe2 = "com.softwaremill.sttp.client3" %% "async-http-client-backend-cats-ce2" % "3.5.1" val frameless = "org.typelevel" %% "frameless-dataset-spark31" % "0.11.1" val framelessRefined = "org.typelevel" %% "frameless-refined-spark31" % "0.11.1" val `better-files` = "com.github.pathikrit" %% "better-files" % "3.9.1" From d0e5bd588f4b25a03288e975e63d320dbdca252a Mon Sep 17 00:00:00 2001 From: "Simeon H.K. Fitch" Date: Wed, 4 May 2022 11:20:21 -0400 Subject: [PATCH 04/34] Spark 3.1.3 --- project/RFDependenciesPlugin.scala | 2 +- pyrasterframes/src/main/python/setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/project/RFDependenciesPlugin.scala b/project/RFDependenciesPlugin.scala index c0e9b5b33..bfad75f74 100644 --- a/project/RFDependenciesPlugin.scala +++ b/project/RFDependenciesPlugin.scala @@ -70,7 +70,7 @@ object RFDependenciesPlugin extends AutoPlugin { "jitpack" at "https://jitpack.io" ), // NB: Make sure to update the Spark version in pyrasterframes/python/setup.py - rfSparkVersion := "3.1.2", + rfSparkVersion := "3.1.3", rfGeoTrellisVersion := "3.6.1", rfGeoMesaVersion := "3.2.0", excludeDependencies += "log4j" % "log4j" diff --git a/pyrasterframes/src/main/python/setup.py b/pyrasterframes/src/main/python/setup.py index d7a665cdf..4032d23eb 100644 --- a/pyrasterframes/src/main/python/setup.py +++ b/pyrasterframes/src/main/python/setup.py @@ -140,7 +140,7 @@ def dest_file(self, src_file): # to throw a `NotImplementedError: Can't perform this operation for unregistered loader type` pytest = 'pytest>=4.0.0,<5.0.0' -pyspark = 'pyspark==3.1.2' +pyspark = 'pyspark==3.1.3' boto3 = 'boto3' deprecation = 'deprecation' descartes = 'descartes' From ff00b4137976193a309efc97ed9f9da5f28ea8e9 Mon Sep 17 00:00:00 2001 From: "Simeon H.K. Fitch" Date: Mon, 14 Mar 2022 17:33:28 -0400 Subject: [PATCH 05/34] bumped dev version Spark 3.2 Lets get it compiled, spark 2 support is well out the window anyway. --- .../apache/spark/sql/rf/VersionShims.scala | 119 +++--------------- project/RFDependenciesPlugin.scala | 10 +- project/build.properties | 2 +- 3 files changed, 20 insertions(+), 111 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/sql/rf/VersionShims.scala b/core/src/main/scala/org/apache/spark/sql/rf/VersionShims.scala index bb05573d1..511a4dc49 100644 --- a/core/src/main/scala/org/apache/spark/sql/rf/VersionShims.scala +++ b/core/src/main/scala/org/apache/spark/sql/rf/VersionShims.scala @@ -1,21 +1,17 @@ package org.apache.spark.sql.rf import java.lang.reflect.Constructor - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, FunctionRegistryBase} +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder} import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, InvokeLike} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionDescription, ExpressionInfo} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.DataType import scala.reflect._ -import scala.util.{Failure, Success, Try} /** * Collection of Spark version compatibility adapters. @@ -27,18 +23,6 @@ object VersionShims { val lrClazz = classOf[LogicalRelation] val ctor = lrClazz.getConstructors.head.asInstanceOf[Constructor[LogicalRelation]] ctor.getParameterTypes.length match { - // In Spark 2.1.0 the signature looks like this: - // - // case class LogicalRelation( - // relation: BaseRelation, - // expectedOutputAttributes: Option[Seq[Attribute]] = None, - // catalogTable: Option[CatalogTable] = None) - // extends LeafNode with MultiInstanceRelation - // In Spark 2.2.0 it's like this: - // case class LogicalRelation( - // relation: BaseRelation, - // output: Seq[AttributeReference], - // catalogTable: Option[CatalogTable]) case 3 => val arg2: Seq[AttributeReference] = lr.output val arg3: Option[CatalogTable] = lr.catalogTable @@ -49,14 +33,6 @@ object VersionShims { ctor.newInstance(base, arg2, arg3) } - // In Spark 2.3.0 this signature is this: - // - // case class LogicalRelation( - // relation: BaseRelation, - // output: Seq[AttributeReference], - // catalogTable: Option[CatalogTable], - // override val isStreaming: Boolean) - // extends LeafNode with MultiInstanceRelation { case 4 => val arg2: Seq[AttributeReference] = lr.output val arg3: Option[CatalogTable] = lr.catalogTable @@ -75,25 +51,8 @@ object VersionShims { val ctor = classOf[Invoke].getConstructors.head val TRUE = Boolean.box(true) ctor.getParameterTypes.length match { - // In Spark 2.1.0 the signature looks like this: - // - // case class Invoke( - // targetObject: Expression, - // functionName: String, - // dataType: DataType, - // arguments: Seq[Expression] = Nil, - // propagateNull: Boolean = true) extends InvokeLike case 5 => ctor.newInstance(targetObject, functionName, dataType, Nil, TRUE).asInstanceOf[InvokeLike] - // In spark 2.2.0 the signature looks like this: - // - // case class Invoke( - // targetObject: Expression, - // functionName: String, - // dataType: DataType, - // arguments: Seq[Expression] = Nil, - // propagateNull: Boolean = true, - // returnNullable : Boolean = true) extends InvokeLike case 6 => ctor.newInstance(targetObject, functionName, dataType, Nil, TRUE, TRUE).asInstanceOf[InvokeLike] @@ -125,68 +84,18 @@ object VersionShims { } } - // Much of the code herein is copied from org.apache.spark.sql.catalyst.analysis.FunctionRegistry - def registerExpression[T <: Expression: ClassTag](name: String): Unit = { - val clazz = classTag[T].runtimeClass - - def expressionInfo: ExpressionInfo = { - val df = clazz.getAnnotation(classOf[ExpressionDescription]) - if (df != null) { - if (df.extended().isEmpty) { - new ExpressionInfo(clazz.getCanonicalName, null, name, df.usage(), df.arguments(), df.examples(), df.note(), df.group(), df.since(), df.deprecated()) - } else { - // This exists for the backward compatibility with old `ExpressionDescription`s defining - // the extended description in `extended()`. - new ExpressionInfo(clazz.getCanonicalName, null, name, df.usage(), df.extended()) - } - } else { - new ExpressionInfo(clazz.getCanonicalName, name) - } + def registerExpression[T <: Expression : ClassTag]( + name: String, + setAlias: Boolean = false, + since: Option[String] = None + ): (String, (ExpressionInfo, FunctionBuilder)) = { + val (expressionInfo, builder) = FunctionRegistryBase.build[T](name, since) + val newBuilder = (expressions: Seq[Expression]) => { + val expr = builder(expressions) + if (setAlias) expr.setTagValue(FUNC_ALIAS, name) + expr } - def findBuilder: FunctionBuilder = { - val constructors = clazz.getConstructors - // See if we can find a constructor that accepts Seq[Expression] - val varargCtor = constructors.find(_.getParameterTypes.toSeq == Seq(classOf[Seq[_]])) - val builder = (expressions: Seq[Expression]) => { - if (varargCtor.isDefined) { - // If there is an apply method that accepts Seq[Expression], use that one. - Try(varargCtor.get.newInstance(expressions).asInstanceOf[Expression]) match { - case Success(e) => e - case Failure(e) => - // the exception is an invocation exception. To get a meaningful message, we need the - // cause. - throw new AnalysisException(e.getCause.getMessage) - } - } else { - // Otherwise, find a constructor method that matches the number of arguments, and use that. - val params = Seq.fill(expressions.size)(classOf[Expression]) - val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse { - val validParametersCount = constructors - .filter(_.getParameterTypes.forall(_ == classOf[Expression])) - .map(_.getParameterCount).distinct.sorted - val expectedNumberOfParameters = if (validParametersCount.length == 1) { - validParametersCount.head.toString - } else { - validParametersCount.init.mkString("one of ", ", ", " and ") + - validParametersCount.last - } - throw new AnalysisException(s"Invalid number of arguments for function ${clazz.getSimpleName}. " + - s"Expected: $expectedNumberOfParameters; Found: ${params.length}") - } - Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match { - case Success(e) => e - case Failure(e) => - // the exception is an invocation exception. To get a meaningful message, we need the - // cause. - throw new AnalysisException(e.getCause.getMessage) - } - } - } - - builder - } - - registry.registerFunction(FunctionIdentifier(name), expressionInfo, findBuilder) + (name, (expressionInfo, newBuilder)) } } } diff --git a/project/RFDependenciesPlugin.scala b/project/RFDependenciesPlugin.scala index bfad75f74..ed6ab4dde 100644 --- a/project/RFDependenciesPlugin.scala +++ b/project/RFDependenciesPlugin.scala @@ -53,10 +53,10 @@ object RFDependenciesPlugin extends AutoPlugin { val `spray-json` = "io.spray" %% "spray-json" % "1.3.6" val `scala-logging` = "com.typesafe.scala-logging" %% "scala-logging" % "3.9.4" val stac4s = "com.azavea.stac4s" %% "client" % "0.7.2" - val sttpCatsCe2 = "com.softwaremill.sttp.client3" %% "async-http-client-backend-cats-ce2" % "3.5.1" - val frameless = "org.typelevel" %% "frameless-dataset-spark31" % "0.11.1" - val framelessRefined = "org.typelevel" %% "frameless-refined-spark31" % "0.11.1" - val `better-files` = "com.github.pathikrit" %% "better-files" % "3.9.1" + val sttpCatsCe2 = "com.softwaremill.sttp.client3" %% "async-http-client-backend-cats-ce2" % "3.3.15" + val frameless = "org.typelevel" %% "frameless-dataset-spark31" % "0.12.0" + val framelessRefined = "org.typelevel" %% "frameless-refined-spark31" % "0.12.0" + val `better-files` = "com.github.pathikrit" %% "better-files" % "3.9.1" % Test } import autoImport._ @@ -70,7 +70,7 @@ object RFDependenciesPlugin extends AutoPlugin { "jitpack" at "https://jitpack.io" ), // NB: Make sure to update the Spark version in pyrasterframes/python/setup.py - rfSparkVersion := "3.1.3", + rfSparkVersion := "3.2.1", rfGeoTrellisVersion := "3.6.1", rfGeoMesaVersion := "3.2.0", excludeDependencies += "log4j" % "log4j" diff --git a/project/build.properties b/project/build.properties index 10fd9eee0..c8fcab543 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.5.5 +sbt.version=1.6.2 From 7f5e078c6c3115c912570fa0160bbc8f1c6d2a28 Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Thu, 30 Jun 2022 00:58:01 -0400 Subject: [PATCH 06/34] withNewChildrenInternal --- .../expressions/BinaryRasterFunction.scala | 5 ++- .../expressions/OnCellGridExpression.scala | 5 ++- .../expressions/OnTileContextExpression.scala | 5 ++- .../expressions/SpatialRelation.scala | 10 ++++- .../expressions/TileAssembler.scala | 8 ++++ .../expressions/UnaryRasterFunction.scala | 6 ++- .../expressions/UnaryRasterOp.scala | 5 ++- .../expressions/accessors/GetCRS.scala | 1 + .../expressions/accessors/GetCellType.scala | 2 + .../expressions/accessors/GetEnvelope.scala | 2 + .../expressions/accessors/GetGeometry.scala | 1 + .../expressions/accessors/RealizeTile.scala | 2 + .../expressions/focalops/Aspect.scala | 2 + .../expressions/focalops/Convolve.scala | 18 +++++---- .../expressions/focalops/FocalMax.scala | 2 +- .../expressions/focalops/FocalMean.scala | 3 +- .../expressions/focalops/FocalMedian.scala | 2 +- .../expressions/focalops/FocalMin.scala | 2 +- .../expressions/focalops/FocalMode.scala | 2 +- .../expressions/focalops/FocalMoransI.scala | 2 +- .../focalops/FocalNeighborhoodOp.scala | 28 ++++++------- .../expressions/focalops/FocalStdDev.scala | 2 +- .../expressions/focalops/Hillshade.scala | 3 ++ .../expressions/focalops/Slope.scala | 20 +++++----- .../expressions/generators/ExplodeTiles.scala | 3 ++ .../generators/RasterSourceToRasterRefs.scala | 2 + .../generators/RasterSourceToTiles.scala | 2 + .../expressions/localops/Abs.scala | 1 + .../expressions/localops/Clamp.scala | 26 ++++++------- .../expressions/localops/Equal.scala | 1 + .../expressions/localops/IsIn.scala | 1 + .../localops/NormalizedDifference.scala | 2 + .../expressions/localops/Resample.scala | 11 +++++- .../expressions/localops/Where.scala | 31 +++++++-------- .../rasterframes/expressions/package.scala | 4 ++ .../transformers/CreateProjectedRaster.scala | 7 +++- .../transformers/DebugRender.scala | 4 +- .../transformers/ExtentToGeometry.scala | 2 + .../transformers/ExtractBits.scala | 26 ++++++------- .../transformers/GeometryToExtent.scala | 2 + .../transformers/InterpretAs.scala | 2 + .../expressions/transformers/Mask.scala | 39 ++++++++++++------- .../transformers/RGBComposite.scala | 7 +++- .../transformers/RasterRefToTile.scala | 2 + .../expressions/transformers/RenderPNG.scala | 5 ++- .../transformers/ReprojectGeometry.scala | 12 +++--- .../expressions/transformers/Rescale.scala | 26 ++++++------- .../transformers/SetCellType.scala | 2 + .../transformers/SetNoDataValue.scala | 2 + .../transformers/Standardize.scala | 26 ++++++------- .../transformers/URIToRasterSource.scala | 2 + .../expressions/transformers/XZ2Indexer.scala | 3 ++ .../expressions/transformers/Z2Indexer.scala | 3 ++ 53 files changed, 248 insertions(+), 146 deletions(-) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/BinaryRasterFunction.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/BinaryRasterFunction.scala index 425e6c4e7..edf61ea2b 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/BinaryRasterFunction.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/BinaryRasterFunction.scala @@ -25,13 +25,14 @@ import com.typesafe.scalalogging.Logger import geotrellis.raster.Tile import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.BinaryExpression +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression} import org.apache.spark.sql.types.DataType import org.locationtech.rasterframes.expressions.DynamicExtractors._ import org.slf4j.LoggerFactory /** Operation combining two tiles or a tile and a scalar into a new tile. */ -trait BinaryRasterFunction extends BinaryExpression with RasterResult { +trait BinaryRasterFunction extends BinaryExpression with RasterResult { self: HasBinaryExpressionCopy => + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) @transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/OnCellGridExpression.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/OnCellGridExpression.scala index 741a85a8e..c10df97c1 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/OnCellGridExpression.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/OnCellGridExpression.scala @@ -26,7 +26,7 @@ import geotrellis.raster.CellGrid import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.UnaryExpression +import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression} /** * Implements boilerplate for subtype expressions processing TileUDT, RasterSourceUDT, and RasterRefs @@ -34,7 +34,8 @@ import org.apache.spark.sql.catalyst.expressions.UnaryExpression * * @since 11/4/18 */ -trait OnCellGridExpression extends UnaryExpression { +trait OnCellGridExpression extends UnaryExpression { self: HasUnaryExpressionCopy => + override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) private lazy val fromRow: InternalRow => CellGrid[Int] = { if (child.resolved) gridExtractor(child.dataType) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/OnTileContextExpression.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/OnTileContextExpression.scala index 3767b4d0f..1c02b1a95 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/OnTileContextExpression.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/OnTileContextExpression.scala @@ -25,7 +25,7 @@ import org.locationtech.rasterframes.expressions.DynamicExtractors._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.UnaryExpression +import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression} import org.locationtech.rasterframes.model.TileContext /** @@ -34,7 +34,8 @@ import org.locationtech.rasterframes.model.TileContext * * @since 11/3/18 */ -trait OnTileContextExpression extends UnaryExpression { +trait OnTileContextExpression extends UnaryExpression { self: HasUnaryExpressionCopy => + override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) override def checkInputDataTypes(): TypeCheckResult = { if (!projectedRasterLikeExtractor.isDefinedAt(child.dataType)) { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/SpatialRelation.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/SpatialRelation.scala index bc6249d1d..a2589fd5b 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/SpatialRelation.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/SpatialRelation.scala @@ -39,7 +39,10 @@ import org.locationtech.geomesa.spark.jts.udf.SpatialRelationFunctions._ * * @since 12/28/17 */ -abstract class SpatialRelation extends BinaryExpression with CodegenFallback { +abstract class SpatialRelation extends BinaryExpression with CodegenFallback { this: HasBinaryExpressionCopy => + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = + copy(left = newLeft, right = newRight) def extractGeometry(expr: Expression, input: Any): Geometry = { input match { @@ -72,8 +75,11 @@ object SpatialRelation { type RelationPredicate = (Geometry, Geometry) => java.lang.Boolean case class Intersects(left: Expression, right: Expression) extends SpatialRelation { - override def nodeName = "intersects" + override def nodeName: String = "intersects" val relation = ST_Intersects + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = + copy(left = newLeft, right = newRight) } case class Contains(left: Expression, right: Expression) extends SpatialRelation { override def nodeName = "contains" diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/TileAssembler.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/TileAssembler.scala index ea187e662..9015513c8 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/TileAssembler.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/TileAssembler.scala @@ -140,6 +140,14 @@ case class TileAssembler( def serialize(buffer: TileBuffer): Array[Byte] = buffer.serialize() def deserialize(storageFormat: Array[Byte]): TileBuffer = new TileBuffer(storageFormat) + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy( + colIndex = newChildren(0), + rowIndex = newChildren(1), + cellValue = newChildren(2), + tileCols = newChildren(3), + tileRows = newChildren(4) + ) } object TileAssembler { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterFunction.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterFunction.scala index 6eb4e7a69..70a8180c8 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterFunction.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterFunction.scala @@ -25,11 +25,13 @@ import org.locationtech.rasterframes.expressions.DynamicExtractors._ import geotrellis.raster.Tile import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.UnaryExpression +import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression} import org.locationtech.rasterframes.model.TileContext /** Boilerplate for expressions operating on a single Tile-like . */ -trait UnaryRasterFunction extends UnaryExpression { +trait UnaryRasterFunction extends UnaryExpression { self: HasUnaryExpressionCopy => + override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) + override def checkInputDataTypes(): TypeCheckResult = { if (!tileExtractor.isDefinedAt(child.dataType)) { TypeCheckFailure(s"Input type '${child.dataType}' does not conform to a raster type.") diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterOp.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterOp.scala index dcb4871c8..da9232600 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterOp.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterOp.scala @@ -23,12 +23,13 @@ package org.locationtech.rasterframes.expressions import com.typesafe.scalalogging.Logger import geotrellis.raster.Tile +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.DataType import org.locationtech.rasterframes.model.TileContext import org.slf4j.LoggerFactory /** Operation on a tile returning a tile. */ -trait UnaryRasterOp extends UnaryRasterFunction with RasterResult { +trait UnaryRasterOp extends UnaryRasterFunction with RasterResult { this: HasUnaryExpressionCopy => @transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName)) def dataType: DataType = child.dataType @@ -37,5 +38,7 @@ trait UnaryRasterOp extends UnaryRasterFunction with RasterResult { toInternalRow(op(tile), ctx) protected def op(child: Tile): Tile + + override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetCRS.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetCRS.scala index 0ffc0d78e..1f5484b73 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetCRS.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetCRS.scala @@ -97,6 +97,7 @@ case class GetCRS(child: Expression) extends UnaryExpression with CodegenFallbac } } + override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object GetCRS { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetCellType.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetCellType.scala index b5966733c..89180d757 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetCellType.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetCellType.scala @@ -55,6 +55,8 @@ case class GetCellType(child: Expression) extends OnCellGridExpression with Code /** Implemented by subtypes to process incoming ProjectedRasterLike entity. */ def eval(cg: CellGrid[Int]): Any = resultConverter(cg.cellType) + + override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object GetCellType { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetEnvelope.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetEnvelope.scala index 00ba62e83..67b32ce49 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetEnvelope.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetEnvelope.scala @@ -57,6 +57,8 @@ case class GetEnvelope(child: Expression) extends UnaryExpression with CodegenFa } def dataType: DataType = envelopeEncoder.schema + + override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object GetEnvelope { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetGeometry.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetGeometry.scala index 760263292..de8470180 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetGeometry.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetGeometry.scala @@ -49,6 +49,7 @@ case class GetGeometry(child: Expression) extends OnTileContextExpression with C override def nodeName: String = "rf_geometry" def eval(ctx: TileContext): InternalRow = JTSTypes.GeometryTypeInstance.serialize(ctx.extent.toPolygon()) + } object GetGeometry { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/RealizeTile.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/RealizeTile.scala index e5d9f9f45..9e37c62d6 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/RealizeTile.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/RealizeTile.scala @@ -56,6 +56,8 @@ case class RealizeTile(child: Expression) extends UnaryExpression with CodegenFa val tile = tileableExtractor(child.dataType)(in) tileSer(tile.toArrayTile()) } + + override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object RealizeTile { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/Aspect.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/Aspect.scala index 68083293b..385051443 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/Aspect.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/Aspect.scala @@ -73,6 +73,8 @@ case class Aspect(left: Expression, right: Expression) extends BinaryExpression case bt: BufferTile => bt.aspect(CellSize(ctx.extent, cols = t.cols, rows = t.rows), target = target) case _ => t.aspect(CellSize(ctx.extent, cols = t.cols, rows = t.rows), target = target) } + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object Aspect { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/Convolve.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/Convolve.scala index 2d6cc1638..91dd13b95 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/Convolve.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/Convolve.scala @@ -49,24 +49,23 @@ import org.slf4j.LoggerFactory > SELECT _FUNC_(tile, kernel, 'all'); ...""" ) -case class Convolve(left: Expression, middle: Expression, right: Expression) extends TernaryExpression with RasterResult with CodegenFallback { +case class Convolve(first: Expression, second: Expression, third: Expression) extends TernaryExpression with RasterResult with CodegenFallback { @transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName)) override def nodeName: String = Convolve.name - def dataType: DataType = left.dataType - val children: Seq[Expression] = Seq(left, middle, right) + def dataType: DataType = first.dataType override def checkInputDataTypes(): TypeCheckResult = - if (!tileExtractor.isDefinedAt(left.dataType)) TypeCheckFailure(s"Input type '${left.dataType}' does not conform to a raster type.") - else if (!middle.dataType.conformsToSchema(kernelEncoder.schema)) TypeCheckFailure(s"Input type '${middle.dataType}' does not conform to a Kernel type.") - else if (!targetCellExtractor.isDefinedAt(right.dataType)) TypeCheckFailure(s"Input type '${right.dataType}' does not conform to a TargetCell type.") + if (!tileExtractor.isDefinedAt(first.dataType)) TypeCheckFailure(s"Input type '${first.dataType}' does not conform to a raster type.") + else if (!second.dataType.conformsToSchema(kernelEncoder.schema)) TypeCheckFailure(s"Input type '${second.dataType}' does not conform to a Kernel type.") + else if (!targetCellExtractor.isDefinedAt(third.dataType)) TypeCheckFailure(s"Input type '${third.dataType}' does not conform to a TargetCell type.") else TypeCheckSuccess override protected def nullSafeEval(tileInput: Any, kernelInput: Any, targetCellInput: Any): Any = { - val (tile, ctx) = tileExtractor(left.dataType)(row(tileInput)) + val (tile, ctx) = tileExtractor(first.dataType)(row(tileInput)) val kernel = row(kernelInput).as[Kernel] - val target = targetCellExtractor(right.dataType)(targetCellInput) + val target = targetCellExtractor(third.dataType)(targetCellInput) val result = op(extractBufferTile(tile), kernel, target) toInternalRow(result, ctx) } @@ -75,6 +74,9 @@ case class Convolve(left: Expression, middle: Expression, right: Expression) ext case bt: BufferTile => bt.convolve(kernel, target = target) case _ => t.convolve(kernel, target = target) } + + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + copy(newFirst, newSecond, newThird) } object Convolve { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMax.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMax.scala index b8ad6d908..5ca4f386f 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMax.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMax.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescript > SELECT _FUNC_(tile, 'square-1', 'all'); ...""" ) -case class FocalMax(left: Expression, middle: Expression, right: Expression) extends FocalNeighborhoodOp { +case class FocalMax(first: Expression, second: Expression, third: Expression) extends FocalNeighborhoodOp { override def nodeName: String = FocalMax.name protected def op(t: Tile, neighborhood: Neighborhood, target: TargetCell): Tile = t match { case bt: BufferTile => bt.focalMax(neighborhood, target = target) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMean.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMean.scala index b6fb8ba0d..f612d118a 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMean.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMean.scala @@ -38,12 +38,13 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescript > SELECT _FUNC_(tile, 'square-1', 'all'); ...""" ) -case class FocalMean(left: Expression, middle: Expression, right: Expression) extends FocalNeighborhoodOp { +case class FocalMean(first: Expression, second: Expression, third: Expression) extends FocalNeighborhoodOp { override def nodeName: String = FocalMean.name protected def op(t: Tile, neighborhood: Neighborhood, target: TargetCell): Tile = t match { case bt: BufferTile => bt.focalMean(neighborhood, target = target) case _ => t.focalMean(neighborhood, target = target) } + } object FocalMean { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMedian.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMedian.scala index b72a4ed8d..7830bae41 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMedian.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMedian.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescript > SELECT _FUNC_(tile, 'square-1', 'all'); ...""" ) -case class FocalMedian(left: Expression, middle: Expression, right: Expression) extends FocalNeighborhoodOp { +case class FocalMedian(first: Expression, second: Expression, third: Expression) extends FocalNeighborhoodOp { override def nodeName: String = FocalMedian.name protected def op(t: Tile, neighborhood: Neighborhood, target: TargetCell): Tile = t match { case bt: BufferTile => bt.focalMedian(neighborhood, target = target) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMin.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMin.scala index 439a8ae9f..0baead593 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMin.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMin.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescript > SELECT _FUNC_(tile, 'square-1', 'all'); ...""" ) -case class FocalMin(left: Expression, middle: Expression, right: Expression) extends FocalNeighborhoodOp { +case class FocalMin(first: Expression, second: Expression, third: Expression) extends FocalNeighborhoodOp { override def nodeName: String = FocalMin.name protected def op(t: Tile, neighborhood: Neighborhood, target: TargetCell): Tile = t match { case bt: BufferTile => bt.focalMin(neighborhood, target = target) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMode.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMode.scala index 6ea049cc6..4e4d08c67 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMode.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMode.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescript > SELECT _FUNC_(tile, 'square-1', 'all'); ...""" ) -case class FocalMode(left: Expression, middle: Expression, right: Expression) extends FocalNeighborhoodOp { +case class FocalMode(first: Expression, second: Expression, third: Expression) extends FocalNeighborhoodOp { override def nodeName: String = FocalMode.name protected def op(t: Tile, neighborhood: Neighborhood, target: TargetCell): Tile = t match { case bt: BufferTile => bt.focalMode(neighborhood, target = target) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMoransI.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMoransI.scala index d4db3192f..7ab8f1d97 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMoransI.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMoransI.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescript > SELECT _FUNC_(tile, 'square-1', 'all'); ...""" ) -case class FocalMoransI(left: Expression, middle: Expression, right: Expression) extends FocalNeighborhoodOp { +case class FocalMoransI(first: Expression, second: Expression, third: Expression) extends FocalNeighborhoodOp { override def nodeName: String = FocalMoransI.name protected def op(t: Tile, neighborhood: Neighborhood, target: TargetCell): Tile = t match { case bt: BufferTile => bt.tileMoransI(neighborhood, target = target) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalNeighborhoodOp.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalNeighborhoodOp.scala index 64bbd313e..2303c7b7c 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalNeighborhoodOp.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalNeighborhoodOp.scala @@ -29,32 +29,34 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, TernaryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types.DataType import org.locationtech.rasterframes.expressions.DynamicExtractors.{neighborhoodExtractor, targetCellExtractor, tileExtractor} -import org.locationtech.rasterframes.expressions.{RasterResult, row} +import org.locationtech.rasterframes.expressions.{HasTernaryExpressionCopy, RasterResult, row} import org.slf4j.LoggerFactory -trait FocalNeighborhoodOp extends TernaryExpression with RasterResult with CodegenFallback { +trait FocalNeighborhoodOp extends TernaryExpression with RasterResult with CodegenFallback {self: HasTernaryExpressionCopy => + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + copy(newFirst, newSecond, newThird) + @transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName)) // Tile - def left: Expression + def first: Expression // Neighborhood - def middle: Expression + def second: Expression // TargetCell - def right: Expression + def third: Expression - def dataType: DataType = left.dataType - def children: Seq[Expression] = Seq(left, middle, right) + def dataType: DataType = first.dataType override def checkInputDataTypes(): TypeCheckResult = - if (!tileExtractor.isDefinedAt(left.dataType)) TypeCheckFailure(s"Input type '${left.dataType}' does not conform to a raster type.") - else if(!neighborhoodExtractor.isDefinedAt(middle.dataType)) TypeCheckFailure(s"Input type '${middle.dataType}' does not conform to a string Neighborhood type.") - else if(!targetCellExtractor.isDefinedAt(right.dataType)) TypeCheckFailure(s"Input type '${right.dataType}' does not conform to a string TargetCell type.") + if (!tileExtractor.isDefinedAt(first.dataType)) TypeCheckFailure(s"Input type '${first.dataType}' does not conform to a raster type.") + else if(!neighborhoodExtractor.isDefinedAt(second.dataType)) TypeCheckFailure(s"Input type '${second.dataType}' does not conform to a string Neighborhood type.") + else if(!targetCellExtractor.isDefinedAt(third.dataType)) TypeCheckFailure(s"Input type '${third.dataType}' does not conform to a string TargetCell type.") else TypeCheckSuccess override protected def nullSafeEval(tileInput: Any, neighborhoodInput: Any, targetCellInput: Any): Any = { - val (tile, ctx) = tileExtractor(left.dataType)(row(tileInput)) - val neighborhood = neighborhoodExtractor(middle.dataType)(neighborhoodInput) - val target = targetCellExtractor(right.dataType)(targetCellInput) + val (tile, ctx) = tileExtractor(first.dataType)(row(tileInput)) + val neighborhood = neighborhoodExtractor(second.dataType)(neighborhoodInput) + val target = targetCellExtractor(third.dataType)(targetCellInput) val result = op(extractBufferTile(tile), neighborhood, target) toInternalRow(result, ctx) } diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalStdDev.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalStdDev.scala index ed05e077f..3887d079c 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalStdDev.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalStdDev.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescript > SELECT _FUNC_(tile, 'square-1', 'all'); ...""" ) -case class FocalStdDev(left: Expression, middle: Expression, right: Expression) extends FocalNeighborhoodOp { +case class FocalStdDev(first: Expression, second: Expression, third: Expression) extends FocalNeighborhoodOp { override def nodeName: String = FocalStdDev.name protected def op(t: Tile, neighborhood: Neighborhood, target: TargetCell): Tile = t match { case bt: BufferTile => bt.focalStandardDeviation(neighborhood, target = target) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/Hillshade.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/Hillshade.scala index 3a917337b..ca5bc3bec 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/Hillshade.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/Hillshade.scala @@ -91,6 +91,9 @@ case class Hillshade(first: Expression, second: Expression, third: Expression, f case bt: BufferTile => bt.mapTile(_.hillshade(CellSize(ctx.extent, cols = t.cols, rows = t.rows), azimuth, altitude, zFactor, target = target)) case _ => t.hillshade(CellSize(ctx.extent, cols = t.cols, rows = t.rows), azimuth, altitude, zFactor, target = target) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + copy(newChildren(0), newChildren(1), newChildren(2), newChildren(3), newChildren(4)) } object Hillshade { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/Slope.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/Slope.scala index 2bd256ce2..79d2257f8 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/Slope.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/Slope.scala @@ -48,28 +48,26 @@ import org.slf4j.LoggerFactory > SELECT _FUNC_(tile, 0.2, 'all'); ...""" ) -case class Slope(left: Expression, middle: Expression, right: Expression) extends TernaryExpression with RasterResult with CodegenFallback { +case class Slope(first: Expression, second: Expression, third: Expression) extends TernaryExpression with RasterResult with CodegenFallback { @transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName)) override def nodeName: String = Slope.name - def dataType: DataType = left.dataType - - val children: Seq[Expression] = Seq(left, middle, right) + def dataType: DataType = first.dataType override def checkInputDataTypes(): TypeCheckResult = - if (!tileExtractor.isDefinedAt(left.dataType)) TypeCheckFailure(s"Input type '${left.dataType}' does not conform to a raster type.") - else if (!numberArgExtractor.isDefinedAt(middle.dataType)) TypeCheckFailure(s"Input type '${middle.dataType}' does not conform to a numeric type.") - else if (!targetCellExtractor.isDefinedAt(right.dataType)) TypeCheckFailure(s"Input type '${right.dataType}' does not conform to a TargetCell type.") + if (!tileExtractor.isDefinedAt(first.dataType)) TypeCheckFailure(s"Input type '${first.dataType}' does not conform to a raster type.") + else if (!numberArgExtractor.isDefinedAt(second.dataType)) TypeCheckFailure(s"Input type '${second.dataType}' does not conform to a numeric type.") + else if (!targetCellExtractor.isDefinedAt(third.dataType)) TypeCheckFailure(s"Input type '${third.dataType}' does not conform to a TargetCell type.") else TypeCheckSuccess override protected def nullSafeEval(tileInput: Any, zFactorInput: Any, targetCellInput: Any): Any = { - val (tile, ctx) = tileExtractor(left.dataType)(row(tileInput)) - val zFactor = numberArgExtractor(middle.dataType)(zFactorInput) match { + val (tile, ctx) = tileExtractor(first.dataType)(row(tileInput)) + val zFactor = numberArgExtractor(second.dataType)(zFactorInput) match { case DoubleArg(value) => value case IntegerArg(value) => value.toDouble } - val target = targetCellExtractor(right.dataType)(targetCellInput) + val target = targetCellExtractor(third.dataType)(targetCellInput) eval(extractBufferTile(tile), ctx, zFactor, target) } protected def eval(tile: Tile, ctx: Option[TileContext], zFactor: Double, target: TargetCell): Any = ctx match { @@ -81,6 +79,8 @@ case class Slope(left: Expression, middle: Expression, right: Expression) extend case bt: BufferTile => bt.slope(CellSize(ctx.extent, cols = t.cols, rows = t.rows), zFactor, target = target) case _ => t.slope(CellSize(ctx.extent, cols = t.cols, rows = t.rows), zFactor, target = target) } + + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) } object Slope { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/ExplodeTiles.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/ExplodeTiles.scala index 7ebbad7cc..8dd46c2fc 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/ExplodeTiles.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/ExplodeTiles.scala @@ -78,6 +78,7 @@ case class ExplodeTiles(sampleFraction: Double , seed: Option[Long], override va val Dimensions(cols, rows) = dims.head val retval = Array.ofDim[InternalRow](cols * rows) + cfor(0)(_ < rows, _ + 1) { row => cfor(0)(_ < cols, _ + 1) { col => val rowIndex = row * cols + col @@ -95,6 +96,8 @@ case class ExplodeTiles(sampleFraction: Double , seed: Option[Long], override va else retval } } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children=newChildren) } object ExplodeTiles { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourceToRasterRefs.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourceToRasterRefs.scala index 1d9b82abc..8fd4c951d 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourceToRasterRefs.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourceToRasterRefs.scala @@ -82,6 +82,8 @@ case class RasterSourceToRasterRefs(children: Seq[Expression], bandIndexes: Seq[ .toOption.toSeq.flatten.mkString(", ") throw new java.lang.IllegalArgumentException(description, ex) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children=newChildren) } object RasterSourceToRasterRefs { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourceToTiles.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourceToTiles.scala index 8f28eb916..713811ca6 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourceToTiles.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourceToTiles.scala @@ -84,6 +84,8 @@ case class RasterSourceToTiles(children: Seq[Expression], bandIndexes: Seq[Int], Traversable.empty } } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children=newChildren) } object RasterSourceToTiles { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Abs.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Abs.scala index 19cbe3090..ed6cdd950 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Abs.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Abs.scala @@ -41,6 +41,7 @@ case class Abs(child: Expression) extends UnaryRasterOp with NullToValue with Co override def nodeName: String = "rf_abs" def na: Any = null protected def op(t: Tile): Tile = t.localAbs() + } object Abs { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Clamp.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Clamp.scala index 0b974e230..464e5f730 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Clamp.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Clamp.scala @@ -19,28 +19,26 @@ import org.locationtech.rasterframes.expressions.{RasterResult, row} * min - scalar or tile setting the minimum value for each cell * max - scalar or tile setting the maximum value for each cell""" ) -case class Clamp(left: Expression, middle: Expression, right: Expression) extends TernaryExpression with CodegenFallback with RasterResult with Serializable { - def dataType: DataType = left.dataType - - def children: Seq[Expression] = Seq(left, middle, right) +case class Clamp(first: Expression, second: Expression, third: Expression) extends TernaryExpression with CodegenFallback with RasterResult with Serializable { + def dataType: DataType = first.dataType override val nodeName = "rf_local_clamp" override def checkInputDataTypes(): TypeCheckResult = { - if (!tileExtractor.isDefinedAt(left.dataType)) { - TypeCheckFailure(s"Input type '${left.dataType}' does not conform to a Tile type") - } else if (!tileExtractor.isDefinedAt(middle.dataType) && !numberArgExtractor.isDefinedAt(middle.dataType)) { - TypeCheckFailure(s"Input type '${middle.dataType}' does not conform to a Tile or numeric type") - } else if (!tileExtractor.isDefinedAt(right.dataType) && !numberArgExtractor.isDefinedAt(right.dataType)) { - TypeCheckFailure(s"Input type '${right.dataType}' does not conform to a Tile or numeric type") + if (!tileExtractor.isDefinedAt(first.dataType)) { + TypeCheckFailure(s"Input type '${first.dataType}' does not conform to a Tile type") + } else if (!tileExtractor.isDefinedAt(second.dataType) && !numberArgExtractor.isDefinedAt(second.dataType)) { + TypeCheckFailure(s"Input type '${second.dataType}' does not conform to a Tile or numeric type") + } else if (!tileExtractor.isDefinedAt(third.dataType) && !numberArgExtractor.isDefinedAt(third.dataType)) { + TypeCheckFailure(s"Input type '${third.dataType}' does not conform to a Tile or numeric type") } else TypeCheckSuccess } override protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = { - val (targetTile, targetCtx) = tileExtractor(left.dataType)(row(input1)) - val minVal = tileOrNumberExtractor(middle.dataType)(input2) - val maxVal = tileOrNumberExtractor(right.dataType)(input3) + val (targetTile, targetCtx) = tileExtractor(first.dataType)(row(input1)) + val minVal = tileOrNumberExtractor(second.dataType)(input2) + val maxVal = tileOrNumberExtractor(third.dataType)(input3) val result = (minVal, maxVal) match { case (mn: TileArg, mx: TileArg) => targetTile.localMin(mx.tile).localMax(mn.tile) @@ -57,6 +55,8 @@ case class Clamp(left: Expression, middle: Expression, right: Expression) extend toInternalRow(result, targetCtx) } + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + copy(newFirst, newSecond, newThird) } object Clamp { def apply(tile: Column, min: Column, max: Column): Column = new Column(Clamp(tile.expr, min.expr, max.expr)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Equal.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Equal.scala index b83fcee7e..29f622c78 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Equal.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Equal.scala @@ -44,6 +44,7 @@ case class Equal(left: Expression, right: Expression) extends BinaryRasterFuncti protected def op(left: Tile, right: Tile): Tile = left.localEqual(right) protected def op(left: Tile, right: Double): Tile = left.localEqual(right) protected def op(left: Tile, right: Int): Tile = left.localEqual(right) + } object Equal { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/IsIn.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/IsIn.scala index bf1d9d7aa..e5472be01 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/IsIn.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/IsIn.scala @@ -72,6 +72,7 @@ case class IsIn(left: Expression, right: Expression) extends BinaryExpression wi IfCell(left, fn(_: Int), 1, 0) } + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object IsIn { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/NormalizedDifference.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/NormalizedDifference.scala index f5a312296..0a7c94eff 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/NormalizedDifference.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/NormalizedDifference.scala @@ -50,6 +50,8 @@ case class NormalizedDifference(left: Expression, right: Expression) extends Bin val sum = fpTile(left.localAdd(right)) diff.localDivide(sum) } + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object NormalizedDifference { def apply(left: Column, right: Column): TypedColumn[Any, Tile] = new Column(NormalizedDifference(left.expr, right.expr)).as[Tile] diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Resample.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Resample.scala index 9bc0d829e..9f2aec49c 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Resample.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Resample.scala @@ -41,8 +41,10 @@ import org.locationtech.rasterframes.expressions.DynamicExtractors._ abstract class ResampleBase(left: Expression, right: Expression, method: Expression) extends TernaryExpression with RasterResult with CodegenFallback with Serializable { override val nodeName: String = "rf_resample" + def first: Expression = left + def second: Expression = right + def third: Expression = method def dataType: DataType = left.dataType - def children: Seq[Expression] = Seq(left, right, method) def targetFloatIfNeeded(t: Tile, method: GTResampleMethod): Tile = method match { @@ -127,7 +129,9 @@ Examples: > SELECT _FUNC_(tile1, tile2, lit("cubic_spline")); ...""" ) -case class Resample(left: Expression, factor: Expression, method: Expression) extends ResampleBase(left, factor, method) +case class Resample(left: Expression, factor: Expression, method: Expression) extends ResampleBase(left, factor, method) { + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) +} object Resample { def apply(left: Column, right: Column, methodName: String): Column = @@ -156,6 +160,9 @@ object Resample { ...""") case class ResampleNearest(tile: Expression, target: Expression) extends ResampleBase(tile, target, Literal("nearest")) { override val nodeName: String = "rf_resample_nearest" + + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + ResampleNearest(tile, target) } object ResampleNearest { def apply(tile: Column, target: Column): Column = new Column(ResampleNearest(tile.expr, target.expr)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Where.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Where.scala index 9b0a605d9..13121b63c 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Where.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Where.scala @@ -21,39 +21,37 @@ import org.slf4j.LoggerFactory * x - tile with cell values to return if condition is true * y - tile with cell values to return if condition is false""" ) -case class Where(left: Expression, middle: Expression, right: Expression) extends TernaryExpression with RasterResult with CodegenFallback with Serializable { +case class Where(first: Expression, second: Expression, third: Expression) extends TernaryExpression with RasterResult with CodegenFallback with Serializable { @transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName)) - def dataType: DataType = middle.dataType - - def children: Seq[Expression] = Seq(left, middle, right) + def dataType: DataType = second.dataType override val nodeName = "rf_where" override def checkInputDataTypes(): TypeCheckResult = { - if (!tileExtractor.isDefinedAt(left.dataType)) { - TypeCheckFailure(s"Input type '${left.dataType}' does not conform to a Tile type") - } else if (!tileExtractor.isDefinedAt(middle.dataType)) { - TypeCheckFailure(s"Input type '${middle.dataType}' does not conform to a Tile type") - } else if (!tileExtractor.isDefinedAt(right.dataType)) { - TypeCheckFailure(s"Input type '${right.dataType}' does not conform to a Tile type") + if (!tileExtractor.isDefinedAt(first.dataType)) { + TypeCheckFailure(s"Input type '${first.dataType}' does not conform to a Tile type") + } else if (!tileExtractor.isDefinedAt(second.dataType)) { + TypeCheckFailure(s"Input type '${second.dataType}' does not conform to a Tile type") + } else if (!tileExtractor.isDefinedAt(third.dataType)) { + TypeCheckFailure(s"Input type '${third.dataType}' does not conform to a Tile type") } else TypeCheckSuccess } override protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = { - val (conditionTile, conditionCtx) = tileExtractor(left.dataType)(row(input1)) - val (xTile, xCtx) = tileExtractor(middle.dataType)(row(input2)) - val (yTile, yCtx) = tileExtractor(right.dataType)(row(input3)) + val (conditionTile, conditionCtx) = tileExtractor(first.dataType)(row(input1)) + val (xTile, xCtx) = tileExtractor(second.dataType)(row(input2)) + val (yTile, yCtx) = tileExtractor(third.dataType)(row(input3)) if (xCtx.isEmpty && yCtx.isDefined) logger.warn( - s"Middle parameter '${middle}' provided an extent and CRS, but the right parameter " + - s"'${right}' didn't have any. Because the middle defines output type, the right-hand context will be lost.") + s"Middle parameter '${second}' provided an extent and CRS, but the right parameter " + + s"'${third}' didn't have any. Because the middle defines output type, the right-hand context will be lost.") if(xCtx.isDefined && yCtx.isDefined && xCtx != yCtx) - logger.warn(s"Both '${middle}' and '${right}' provided an extent and CRS, but they are different. The former will be used.") + logger.warn(s"Both '${second}' and '${third}' provided an extent and CRS, but they are different. The former will be used.") val result = op(conditionTile, xTile, yTile) toInternalRow(result, xCtx) @@ -84,6 +82,7 @@ case class Where(left: Expression, middle: Expression, right: Expression) extend returnTile } + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) } object Where { def apply(condition: Column, x: Column, y: Column): Column = new Column(Where(condition.expr, x.expr, y.expr)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala index 9fa191ae4..40d22f96d 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala @@ -46,6 +46,10 @@ import scala.reflect.runtime.universe._ * @since 10/10/17 */ package object expressions { + type HasTernaryExpressionCopy = {def copy(first: Expression, second: Expression, third: Expression): Expression} + type HasBinaryExpressionCopy = {def copy(left: Expression, right: Expression): Expression} + type HasUnaryExpressionCopy = {def copy(child: Expression): Expression} + private[expressions] def row(input: Any) = input.asInstanceOf[InternalRow] /** Convert the tile to a floating point type as needed for scalar operations. */ @inline diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/CreateProjectedRaster.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/CreateProjectedRaster.scala index 759c14ebf..99c7124e5 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/CreateProjectedRaster.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/CreateProjectedRaster.scala @@ -43,8 +43,9 @@ import org.locationtech.rasterframes.encoders._ ) case class CreateProjectedRaster(tile: Expression, extent: Expression, crs: Expression) extends TernaryExpression with RasterResult with CodegenFallback { override def nodeName: String = "rf_proj_raster" - - def children: Seq[Expression] = Seq(tile, extent, crs) + def first: Expression = tile + def second: Expression = extent + def third: Expression = crs def dataType: DataType = ProjectedRasterTile.projectedRasterTileEncoder.schema @@ -70,6 +71,8 @@ case class CreateProjectedRaster(tile: Expression, extent: Expression, crs: Expr val prt = ProjectedRasterTile(t, e, c) toInternalRow(prt) } + + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) } object CreateProjectedRaster { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/DebugRender.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/DebugRender.scala index 76be3ba16..c310dc80c 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/DebugRender.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/DebugRender.scala @@ -29,11 +29,11 @@ import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.sql.{Column, TypedColumn} import org.apache.spark.unsafe.types.UTF8String import org.locationtech.rasterframes.encoders.SparkBasicEncoders._ -import org.locationtech.rasterframes.expressions.UnaryRasterFunction +import org.locationtech.rasterframes.expressions.{HasUnaryExpressionCopy, UnaryRasterFunction} import org.locationtech.rasterframes.model.TileContext import spire.syntax.cfor.cfor -abstract class DebugRender(asciiArt: Boolean) extends UnaryRasterFunction with CodegenFallback with Serializable { +abstract class DebugRender(asciiArt: Boolean) extends UnaryRasterFunction with CodegenFallback with Serializable { self: HasUnaryExpressionCopy => import org.locationtech.rasterframes.expressions.transformers.DebugRender.TileAsMatrix def dataType: DataType = StringType diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtentToGeometry.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtentToGeometry.scala index e90c7046d..8b922de4d 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtentToGeometry.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtentToGeometry.scala @@ -59,6 +59,8 @@ case class ExtentToGeometry(child: Expression) extends UnaryExpression with Code val geom = extent.toPolygon() JTSTypes.GeometryTypeInstance.serialize(geom) } + + override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object ExtentToGeometry extends SpatialEncoders { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala index 661e3a087..4412c2a9f 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala @@ -44,32 +44,32 @@ import org.locationtech.rasterframes.expressions._ > SELECT _FUNC_(tile, lit(4), lit(2)) ...""" ) -case class ExtractBits(child1: Expression, child2: Expression, child3: Expression) extends TernaryExpression with CodegenFallback with RasterResult with Serializable { +case class ExtractBits(first: Expression, second: Expression, third: Expression) extends TernaryExpression with CodegenFallback with RasterResult with Serializable { override val nodeName: String = "rf_local_extract_bits" - def children: Seq[Expression] = Seq(child1, child2, child3) - - def dataType: DataType = child1.dataType + def dataType: DataType = first.dataType override def checkInputDataTypes(): TypeCheckResult = - if(!tileExtractor.isDefinedAt(child1.dataType)) { - TypeCheckFailure(s"Input type '${child1.dataType}' does not conform to a raster type.") - } else if (!intArgExtractor.isDefinedAt(child2.dataType)) { - TypeCheckFailure(s"Input type '${child2.dataType}' isn't an integral type.") - } else if (!intArgExtractor.isDefinedAt(child3.dataType)) { - TypeCheckFailure(s"Input type '${child3.dataType}' isn't an integral type.") + if(!tileExtractor.isDefinedAt(first.dataType)) { + TypeCheckFailure(s"Input type '${first.dataType}' does not conform to a raster type.") + } else if (!intArgExtractor.isDefinedAt(second.dataType)) { + TypeCheckFailure(s"Input type '${second.dataType}' isn't an integral type.") + } else if (!intArgExtractor.isDefinedAt(third.dataType)) { + TypeCheckFailure(s"Input type '${third.dataType}' isn't an integral type.") } else TypeCheckSuccess override protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = { - val (childTile, childCtx) = tileExtractor(child1.dataType)(row(input1)) - val startBits = intArgExtractor(child2.dataType)(input2).value - val numBits = intArgExtractor(child2.dataType)(input3).value + val (childTile, childCtx) = tileExtractor(first.dataType)(row(input1)) + val startBits = intArgExtractor(second.dataType)(input2).value + val numBits = intArgExtractor(second.dataType)(input3).value val result = op(childTile, startBits, numBits) toInternalRow(result,childCtx) } protected def op(tile: Tile, startBit: Int, numBits: Int): Tile = ExtractBits(tile, startBit, numBits) + + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) } object ExtractBits{ diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/GeometryToExtent.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/GeometryToExtent.scala index 410f9168c..43e96311c 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/GeometryToExtent.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/GeometryToExtent.scala @@ -55,6 +55,8 @@ case class GeometryToExtent(child: Expression) extends UnaryExpression with Code val geom = JTSTypes.GeometryTypeInstance.deserialize(input) Extent(geom.getEnvelopeInternal).toInternalRow } + + override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object GeometryToExtent { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InterpretAs.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InterpretAs.scala index 678df26ab..91fb9ab81 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InterpretAs.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InterpretAs.scala @@ -81,6 +81,8 @@ case class InterpretAs(tile: Expression, cellType: Expression) extends BinaryExp val result = tile.interpretAs(ct) toInternalRow(result, ctx) } + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object InterpretAs{ diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala index 9f528cb92..f225b369f 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala @@ -39,23 +39,21 @@ import org.slf4j.LoggerFactory /** Convert cells in the `left` to NoData based on another tile's contents * - * @param left a tile of data values, with valid nodata cell type - * @param middle a tile indicating locations to set to nodata - * @param right optional, cell values in the `middle` tile indicating locations to set NoData + * @param first a tile of data values, with valid nodata cell type + * @param second a tile indicating locations to set to nodata + * @param third optional, cell values in the `middle` tile indicating locations to set NoData * @param undefined if true, consider NoData in the `middle` as the locations to mask; else use `right` valued cells * @param inverse if true, and defined is true, set `left` to NoData where `middle` is NOT nodata */ -abstract class Mask(val left: Expression, val middle: Expression, val right: Expression, undefined: Boolean, inverse: Boolean) +abstract class Mask(val first: Expression, val second: Expression, val third: Expression, undefined: Boolean, inverse: Boolean) extends TernaryExpression with RasterResult with CodegenFallback with Serializable { // aliases. - def targetExp: Expression = left - def maskExp: Expression = middle - def maskValueExp: Expression = right + def targetExp: Expression = first + def maskExp: Expression = second + def maskValueExp: Expression = third @transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName)) - def children: Seq[Expression] = Seq(left, middle, right) - override def checkInputDataTypes(): TypeCheckResult = if (!tileExtractor.isDefinedAt(targetExp.dataType)) { TypeCheckFailure(s"Input type '${targetExp.dataType}' does not conform to a raster type.") @@ -65,7 +63,7 @@ abstract class Mask(val left: Expression, val middle: Expression, val right: Exp TypeCheckFailure(s"Input type '${maskValueExp.dataType}' isn't an integral type.") } else TypeCheckSuccess - def dataType: DataType = left.dataType + def dataType: DataType = first.dataType override def makeCopy(newArgs: Array[AnyRef]): Expression = super.makeCopy(newArgs) @@ -73,17 +71,17 @@ abstract class Mask(val left: Expression, val middle: Expression, val right: Exp val (targetTile, targetCtx) = tileExtractor(targetExp.dataType)(row(targetInput)) require(! targetTile.cellType.isInstanceOf[NoNoData], - s"Input data expression ${left.prettyName} must have a CellType with NoData defined in order to perform a masking operation. Found CellType ${targetTile.cellType.toString()}.") + s"Input data expression ${first.prettyName} must have a CellType with NoData defined in order to perform a masking operation. Found CellType ${targetTile.cellType.toString()}.") val (maskTile, maskCtx) = tileExtractor(maskExp.dataType)(row(maskInput)) if (targetCtx.isEmpty && maskCtx.isDefined) logger.warn( - s"Right-hand parameter '${middle}' provided an extent and CRS, but the left-hand parameter " + - s"'${left}' didn't have any. Because the left-hand side defines output type, the right-hand context will be lost.") + s"Right-hand parameter '${second}' provided an extent and CRS, but the left-hand parameter " + + s"'${first}' didn't have any. Because the left-hand side defines output type, the right-hand context will be lost.") if (targetCtx.isDefined && maskCtx.isDefined && targetCtx != maskCtx) - logger.warn(s"Both '${left}' and '${middle}' provided an extent and CRS, but they are different. Left-hand side will be used.") + logger.warn(s"Both '${first}' and '${second}' provided an extent and CRS, but they are different. Left-hand side will be used.") val maskValue = intArgExtractor(maskValueExp.dataType)(maskValueInput) @@ -112,6 +110,8 @@ object Mask { ) case class MaskByDefined(target: Expression, mask: Expression) extends Mask(target, mask, Literal(0), true, false) { override def nodeName: String = "rf_mask" + + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = ??? } object MaskByDefined { def apply(targetTile: Column, maskTile: Column): TypedColumn[Any, Tile] = @@ -131,6 +131,9 @@ object Mask { ) case class InverseMaskByDefined(leftTile: Expression, rightTile: Expression) extends Mask(leftTile, rightTile, Literal(0), true, true) { override def nodeName: String = "rf_inverse_mask" + + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + copy(leftTile = newFirst, rightTile = newSecond) } object InverseMaskByDefined { def apply(srcTile: Column, maskingTile: Column): TypedColumn[Any, Tile] = @@ -150,6 +153,9 @@ object Mask { ) case class MaskByValue(leftTile: Expression, rightTile: Expression, maskValue: Expression) extends Mask(leftTile, rightTile, maskValue, false, false) { override def nodeName: String = "rf_mask_by_value" + + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + copy(leftTile = newFirst, rightTile = newSecond, maskValue = newThird) } object MaskByValue { def apply(srcTile: Column, maskingTile: Column, maskValue: Column): TypedColumn[Any, Tile] = @@ -171,6 +177,9 @@ object Mask { ) case class InverseMaskByValue(leftTile: Expression, rightTile: Expression, maskValue: Expression) extends Mask(leftTile, rightTile, maskValue, false, true) { override def nodeName: String = "rf_inverse_mask_by_value" + + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + copy(leftTile = newFirst, rightTile = newSecond) } object InverseMaskByValue { def apply(srcTile: Column, maskingTile: Column, maskValue: Column): TypedColumn[Any, Tile] = @@ -194,6 +203,8 @@ object Mask { def this(dataTile: Expression, maskTile: Expression, maskValues: Expression) = this(dataTile, IsIn(maskTile, maskValues)) override def nodeName: String = "rf_mask_by_values" + + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = ??? } object MaskByValues { def apply(dataTile: Column, maskTile: Column, maskValues: Column): TypedColumn[Any, Tile] = diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RGBComposite.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RGBComposite.scala index 71a580b6f..f33cc8ca0 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RGBComposite.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RGBComposite.scala @@ -50,6 +50,9 @@ import org.locationtech.rasterframes.expressions.{RasterResult, row} case class RGBComposite(red: Expression, green: Expression, blue: Expression) extends TernaryExpression with RasterResult with CodegenFallback { override def nodeName: String = "rf_rgb_composite" + def first: Expression = red + def second: Expression = green + def third: Expression = blue def dataType: DataType = if( tileExtractor.isDefinedAt(red.dataType) || @@ -57,8 +60,6 @@ case class RGBComposite(red: Expression, green: Expression, blue: Expression) ex tileExtractor.isDefinedAt(blue.dataType) ) red.dataType else tileUDT - def children: Seq[Expression] = Seq(red, green, blue) - override def checkInputDataTypes(): TypeCheckResult = { if (!tileExtractor.isDefinedAt(red.dataType)) { TypeCheckFailure(s"Red channel input type '${red.dataType}' does not conform to a raster type.") @@ -86,6 +87,8 @@ case class RGBComposite(red: Expression, green: Expression, blue: Expression) ex ).color() toInternalRow(composite, ctx) } + + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) } object RGBComposite { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RasterRefToTile.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RasterRefToTile.scala index 7c0fb4ba2..261a3a6c5 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RasterRefToTile.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RasterRefToTile.scala @@ -53,6 +53,8 @@ case class RasterRefToTile(child: Expression) extends UnaryExpression val ref = input.asInstanceOf[InternalRow].as[RasterRef] ProjectedRasterTile(ref.tile, ref.extent, ref.crs).toInternalRow } + + override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object RasterRefToTile { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RenderPNG.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RenderPNG.scala index 9d3639910..be539a4dd 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RenderPNG.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RenderPNG.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescript import org.apache.spark.sql.types.{BinaryType, DataType} import org.apache.spark.sql.{Column, TypedColumn} import org.locationtech.rasterframes.encoders.SparkBasicEncoders._ -import org.locationtech.rasterframes.expressions.UnaryRasterFunction +import org.locationtech.rasterframes.expressions.{HasUnaryExpressionCopy, UnaryRasterFunction} import org.locationtech.rasterframes.model.TileContext /** @@ -36,7 +36,7 @@ import org.locationtech.rasterframes.model.TileContext * @param child tile column * @param ramp color ramp to use for non-composite tiles. */ -abstract class RenderPNG(child: Expression, ramp: Option[ColorRamp]) extends UnaryRasterFunction with CodegenFallback with Serializable { +abstract class RenderPNG(child: Expression, ramp: Option[ColorRamp]) extends UnaryRasterFunction with CodegenFallback with Serializable { self: HasUnaryExpressionCopy => def dataType: DataType = BinaryType protected def eval(tile: Tile, ctx: Option[TileContext]): Any = { val png = ramp.map(tile.renderPng).getOrElse(tile.renderPng()) @@ -69,6 +69,7 @@ object RenderPNG { ) case class RenderColorRampPNG(child: Expression, colors: ColorRamp) extends RenderPNG(child, Some(colors)) { override def nodeName: String = "rf_render_png" + def copy(child: Expression): Expression = RenderColorRampPNG(child, colors: ColorRamp) } object RenderColorRampPNG { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ReprojectGeometry.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ReprojectGeometry.scala index 71c7800a4..036d9192d 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ReprojectGeometry.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ReprojectGeometry.scala @@ -49,12 +49,12 @@ import org.locationtech.rasterframes.model.LazyCRS > SELECT _FUNC_(geom, srcCRS, dstCRS); ...""" ) -case class ReprojectGeometry(geometry: Expression, srcCRS: Expression, dstCRS: Expression) extends Expression with CodegenFallback { - +case class ReprojectGeometry(geometry: Expression, srcCRS: Expression, dstCRS: Expression) extends TernaryExpression with CodegenFallback { override def nodeName: String = "st_reproject" + def first: Expression = geometry + def second: Expression = srcCRS + def third: Expression = dstCRS def dataType: DataType = JTSTypes.GeometryTypeInstance - def nullable: Boolean = geometry.nullable || srcCRS.nullable || dstCRS.nullable - def children: Seq[Expression] = Seq(geometry, srcCRS, dstCRS) override def checkInputDataTypes(): TypeCheckResult = { if (!geometry.dataType.isInstanceOf[AbstractGeometryUDT[_]]) @@ -73,7 +73,7 @@ case class ReprojectGeometry(geometry: Expression, srcCRS: Expression, dstCRS: E trans.transform(sourceGeom) } - def eval(input: InternalRow): Any = { + override def eval(input: InternalRow): Any = { val src = DynamicExtractors.crsExtractor(srcCRS.dataType)(srcCRS.eval(input)) val dst = DynamicExtractors.crsExtractor(dstCRS.dataType)(dstCRS.eval(input)) (src, dst) match { @@ -89,6 +89,8 @@ case class ReprojectGeometry(geometry: Expression, srcCRS: Expression, dstCRS: E } } } + + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) } object ReprojectGeometry { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Rescale.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Rescale.scala index 4261c7a36..7dabef32d 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Rescale.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Rescale.scala @@ -46,27 +46,25 @@ import org.locationtech.rasterframes.expressions.tilestats.TileStats > SELECT _FUNC_(tile, lit(-2.2), lit(2.2)) ...""" ) -case class Rescale(child1: Expression, child2: Expression, child3: Expression) extends TernaryExpression with RasterResult with CodegenFallback with Serializable { +case class Rescale(first: Expression, second: Expression, third: Expression) extends TernaryExpression with RasterResult with CodegenFallback with Serializable { override val nodeName: String = "rf_rescale" - def children: Seq[Expression] = Seq(child1, child2, child3) - - def dataType: DataType = child1.dataType + def dataType: DataType = first.dataType override def checkInputDataTypes(): TypeCheckResult = - if(!tileExtractor.isDefinedAt(child1.dataType)) { - TypeCheckFailure(s"Input type '${child1.dataType}' does not conform to a raster type.") - } else if (!doubleArgExtractor.isDefinedAt(child2.dataType)) { - TypeCheckFailure(s"Input type '${child2.dataType}' isn't floating point type.") - } else if (!doubleArgExtractor.isDefinedAt(child3.dataType)) { - TypeCheckFailure(s"Input type '${child3.dataType}' isn't floating point type." ) + if(!tileExtractor.isDefinedAt(first.dataType)) { + TypeCheckFailure(s"Input type '${first.dataType}' does not conform to a raster type.") + } else if (!doubleArgExtractor.isDefinedAt(second.dataType)) { + TypeCheckFailure(s"Input type '${second.dataType}' isn't floating point type.") + } else if (!doubleArgExtractor.isDefinedAt(third.dataType)) { + TypeCheckFailure(s"Input type '${third.dataType}' isn't floating point type." ) } else TypeCheckSuccess override protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = { - val (childTile, childCtx) = tileExtractor(child1.dataType)(row(input1)) - val min = doubleArgExtractor(child2.dataType)(input2).value - val max = doubleArgExtractor(child3.dataType)(input3).value + val (childTile, childCtx) = tileExtractor(first.dataType)(row(input1)) + val min = doubleArgExtractor(second.dataType)(input2).value + val max = doubleArgExtractor(third.dataType)(input3).value val result = op(childTile, min, max) toInternalRow(result, childCtx) } @@ -81,6 +79,8 @@ case class Rescale(child1: Expression, child2: Expression, child3: Expression) e .normalize(min, max, 0.0, 1.0) } + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + copy(newFirst, newSecond, newThird) } object Rescale { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/SetCellType.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/SetCellType.scala index 32a329691..ee311a593 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/SetCellType.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/SetCellType.scala @@ -85,6 +85,8 @@ case class SetCellType(tile: Expression, cellType: Expression) extends BinaryExp val result = tile.convert(ct) toInternalRow(result, ctx) } + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object SetCellType { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/SetNoDataValue.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/SetNoDataValue.scala index 2825d5334..52fdfc6cb 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/SetNoDataValue.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/SetNoDataValue.scala @@ -70,6 +70,8 @@ case class SetNoDataValue(left: Expression, right: Expression) extends BinaryExp toInternalRow(result, leftCtx) } + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object SetNoDataValue { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Standardize.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Standardize.scala index 02a04e54c..3d69682f4 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Standardize.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Standardize.scala @@ -46,28 +46,26 @@ import org.locationtech.rasterframes.expressions.tilestats.TileStats > SELECT _FUNC_(tile, lit(4.0), lit(2.2)) ...""" ) -case class Standardize(child1: Expression, child2: Expression, child3: Expression) extends TernaryExpression with RasterResult with CodegenFallback with Serializable { +case class Standardize(first: Expression, second: Expression, third: Expression) extends TernaryExpression with RasterResult with CodegenFallback with Serializable { override val nodeName: String = "rf_standardize" - def children: Seq[Expression] = Seq(child1, child2, child3) - - def dataType: DataType = child1.dataType + def dataType: DataType = first.dataType override def checkInputDataTypes(): TypeCheckResult = - if(!tileExtractor.isDefinedAt(child1.dataType)) { - TypeCheckFailure(s"Input type '${child1.dataType}' does not conform to a raster type.") - } else if (!doubleArgExtractor.isDefinedAt(child2.dataType)) { - TypeCheckFailure(s"Input type '${child2.dataType}' isn't floating point type.") - } else if (!doubleArgExtractor.isDefinedAt(child3.dataType)) { - TypeCheckFailure(s"Input type '${child3.dataType}' isn't floating point type." ) + if(!tileExtractor.isDefinedAt(first.dataType)) { + TypeCheckFailure(s"Input type '${first.dataType}' does not conform to a raster type.") + } else if (!doubleArgExtractor.isDefinedAt(second.dataType)) { + TypeCheckFailure(s"Input type '${second.dataType}' isn't floating point type.") + } else if (!doubleArgExtractor.isDefinedAt(third.dataType)) { + TypeCheckFailure(s"Input type '${third.dataType}' isn't floating point type." ) } else TypeCheckSuccess override protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = { - val (childTile, childCtx) = tileExtractor(child1.dataType)(row(input1)) + val (childTile, childCtx) = tileExtractor(first.dataType)(row(input1)) - val mean = doubleArgExtractor(child2.dataType)(input2).value - val stdDev = doubleArgExtractor(child3.dataType)(input3).value + val mean = doubleArgExtractor(second.dataType)(input2).value + val stdDev = doubleArgExtractor(third.dataType)(input3).value val result = op(childTile, mean, stdDev) toInternalRow(result, childCtx) @@ -79,6 +77,8 @@ case class Standardize(child1: Expression, child2: Expression, child3: Expressio .localSubtract(mean) .localDivide(stdDev) + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + copy(newFirst, newSecond, newThird) } object Standardize { def apply(tile: Column, mean: Column, stdDev: Column): Column = diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/URIToRasterSource.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/URIToRasterSource.scala index 5356d7864..fcab58900 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/URIToRasterSource.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/URIToRasterSource.scala @@ -53,6 +53,8 @@ case class URIToRasterSource(override val child: Expression) extends UnaryExpres val ref = RFRasterSource(uri) rasterSourceUDT.serialize(ref) } + + override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object URIToRasterSource { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/XZ2Indexer.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/XZ2Indexer.scala index dfca3d49d..28a9a099c 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/XZ2Indexer.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/XZ2Indexer.scala @@ -87,6 +87,9 @@ case class XZ2Indexer(left: Expression, right: Expression, indexResolution: Shor ) index } + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = + copy(newLeft, newRight) } object XZ2Indexer { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Z2Indexer.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Z2Indexer.scala index d8f8a8ade..2b8844e44 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Z2Indexer.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Z2Indexer.scala @@ -82,6 +82,9 @@ case class Z2Indexer(left: Expression, right: Expression, indexResolution: Short indexer.index(pt.getX, pt.getY, lenient = true) } + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = + copy(newLeft, newRight) } object Z2Indexer { From 160f351cc649ab0eb51aada422bfc484ee4dc54e Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Thu, 30 Jun 2022 00:59:03 -0400 Subject: [PATCH 07/34] Try Aggregator implemtnation --- .../aggregates/TileRasterizerAggregate.scala | 62 +++++++------------ 1 file changed, 21 insertions(+), 41 deletions(-) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala index 446e7aeb2..58a54a8d1 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala @@ -25,14 +25,15 @@ import geotrellis.layer._ import geotrellis.proj4.CRS import geotrellis.raster.reproject.Reproject import geotrellis.raster.resample.{Bilinear, ResampleMethod} -import geotrellis.raster.{ArrayTile, CellType, Dimensions, MultibandTile, ProjectedRaster, Tile} +import geotrellis.raster.{ArrayTile, CellType, Dimensions, MultibandTile, MutableArrayTile, ProjectedRaster, Tile} import geotrellis.vector.Extent -import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} -import org.apache.spark.sql.types.{DataType, StructField, StructType} -import org.apache.spark.sql.{Column, DataFrame, Row, TypedColumn} +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.functions.udaf +import org.apache.spark.sql.{Column, DataFrame, Encoder, TypedColumn} import org.locationtech.rasterframes._ -import org.locationtech.rasterframes.encoders.syntax._ +import org.locationtech.rasterframes.encoders.StandardEncoders import org.locationtech.rasterframes.expressions.aggregates.TileRasterizerAggregate.ProjectedRasterDefinition +import org.locationtech.rasterframes.tiles.ProjectedRasterTile import org.locationtech.rasterframes.util._ import org.slf4j.LoggerFactory @@ -41,48 +42,26 @@ import org.slf4j.LoggerFactory * `Tile`, `CRS` and `Extent` columns. * @param prd aggregation settings */ -class TileRasterizerAggregate(prd: ProjectedRasterDefinition) extends UserDefinedAggregateFunction { - +class TileRasterizerAggregate(prd: ProjectedRasterDefinition) extends Aggregator[ProjectedRasterTile, Tile, Tile] { val projOpts = Reproject.Options.DEFAULT.copy(method = prd.sampler) - def deterministic: Boolean = true - - def inputSchema: StructType = StructType(Seq( - StructField("crs", crsUDT, false), - StructField("extent", extentEncoder.schema, false), - StructField("tile", tileUDT) - )) - - def bufferSchema: StructType = StructType(Seq( - StructField("tile_buffer", tileUDT) - )) - - def dataType: DataType = tileUDT - - def initialize(buffer: MutableAggregationBuffer): Unit = - buffer(0) = ArrayTile.empty(prd.destinationCellType, prd.totalCols, prd.totalRows) - - def update(buffer: MutableAggregationBuffer, input: Row): Unit = { - val crs: CRS = input.getAs[CRS](0) - val extent: Extent = input.getAs[Row](1).as[Extent] - - val localExtent = extent.reproject(crs, prd.destinationCRS) + override def zero: MutableArrayTile = ArrayTile.empty(prd.destinationCellType, prd.totalCols, prd.totalRows) + override def reduce(b: Tile, a: ProjectedRasterTile): Tile = { + val localExtent = a.extent.reproject(a.crs, prd.destinationCRS) if (prd.destinationExtent.intersects(localExtent)) { - val localTile = input.getAs[Tile](2).reproject(extent, crs, prd.destinationCRS, projOpts) - val bt = buffer.getAs[Tile](0) - val merged = bt.merge(prd.destinationExtent, localExtent, localTile.tile, prd.sampler) - buffer(0) = merged - } + val localTile = a.tile.reproject(a.extent, a.crs, prd.destinationCRS, projOpts) + b.merge(prd.destinationExtent, localExtent, localTile.tile, prd.sampler) + } else b } - def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { - val leftTile = buffer1.getAs[Tile](0) - val rightTile = buffer2.getAs[Tile](0) - buffer1(0) = leftTile.merge(rightTile) - } + override def merge(b1: Tile, b2: Tile): Tile = b1.merge(b2) + + override def finish(reduction: Tile): Tile = reduction + + override def bufferEncoder: Encoder[Tile] = StandardEncoders.tileEncoder - def evaluate(buffer: Row): Tile = buffer.getAs[Tile](0) + override def outputEncoder: Encoder[Tile] = StandardEncoders.tileEncoder } object TileRasterizerAggregate { @@ -107,7 +86,8 @@ object TileRasterizerAggregate { logger.warn( s"You've asked for the construction of a very large image (${prd.totalCols} x ${prd.totalRows}). Out of memory error likely.") - new TileRasterizerAggregate(prd)(crsCol, extentCol, tileCol) + udaf(new TileRasterizerAggregate(prd)) + .apply(crsCol, extentCol, tileCol) .as("rf_agg_overview_raster") .as[Tile] } From eb8ccb92473c8658b8d57aa3ad879783ca8471dd Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Thu, 30 Jun 2022 00:59:36 -0400 Subject: [PATCH 08/34] more explicit --- .../expressions/aggregates/ApproxCellQuantilesAggregate.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/ApproxCellQuantilesAggregate.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/ApproxCellQuantilesAggregate.scala index 00d3bd2c9..ac99ef6e2 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/ApproxCellQuantilesAggregate.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/ApproxCellQuantilesAggregate.scala @@ -71,7 +71,7 @@ case class ApproxCellQuantilesAggregate(probabilities: Seq[Double], relativeErro def evaluate(buffer: Row): Seq[Double] = { val summaries = buffer.getStruct(0).as[QuantileSummaries] - probabilities.flatMap(summaries.query) + probabilities.flatMap(quantile => summaries.query(quantile)) } } From a92ee4e19e77b5e01270dc832cf5c1e1c97ba9f1 Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Tue, 5 Jul 2022 21:00:41 -0400 Subject: [PATCH 09/34] Fix UDF style Aggregates --- .../rasterframes/expressions/UnaryRasterAggregate.scala | 6 ++++-- .../expressions/aggregates/CellCountAggregate.scala | 5 +++-- project/RFDependenciesPlugin.scala | 6 +++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterAggregate.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterAggregate.scala index 42f886b65..253b1cb0f 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterAggregate.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterAggregate.scala @@ -33,15 +33,17 @@ import org.locationtech.rasterframes.encoders.syntax._ import scala.reflect.runtime.universe._ /** Mixin providing boilerplate for DeclarativeAggrates over tile-conforming columns. */ -trait UnaryRasterAggregate extends DeclarativeAggregate { +trait UnaryRasterAggregate extends DeclarativeAggregate { self: HasUnaryExpressionCopy => def child: Expression def nullable: Boolean = child.nullable - def children = Seq(child) + def children: Seq[Expression] = Seq(child) protected def tileOpAsExpression[R: TypeTag](name: String, op: Tile => R): Expression => ScalaUDF = udfiexpr[R, Any](name, (dataType: DataType) => (a: Any) => if(a == null) null.asInstanceOf[R] else op(UnaryRasterAggregate.extractTileFromAny(dataType, a))) + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(newChildren(0)) } object UnaryRasterAggregate { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/CellCountAggregate.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/CellCountAggregate.scala index 7e845f409..1571a29ac 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/CellCountAggregate.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/CellCountAggregate.scala @@ -22,7 +22,7 @@ package org.locationtech.rasterframes.expressions.aggregates import org.locationtech.rasterframes.encoders.SparkBasicEncoders._ -import org.locationtech.rasterframes.expressions.UnaryRasterAggregate +import org.locationtech.rasterframes.expressions.{HasUnaryExpressionCopy, UnaryRasterAggregate} import org.locationtech.rasterframes.expressions.tilestats.{DataCells, NoDataCells} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ @@ -35,7 +35,7 @@ import org.apache.spark.sql.{Column, TypedColumn} * @since 10/5/17 * @param isData true if count should be of non-NoData cells, false if count should be of NoData cells. */ -abstract class CellCountAggregate(isData: Boolean) extends UnaryRasterAggregate { +abstract class CellCountAggregate(isData: Boolean) extends UnaryRasterAggregate { self: HasUnaryExpressionCopy => private lazy val count = AttributeReference("count", LongType, false, Metadata.empty)() override lazy val aggBufferAttributes = Seq(count) @@ -69,6 +69,7 @@ object CellCountAggregate { case class DataCells(child: Expression) extends CellCountAggregate(true) { override def nodeName: String = "rf_agg_data_cells" } + object DataCells { def apply(tile: Column): TypedColumn[Any, Long] = new Column(DataCells(tile.expr).toAggregateExpression()).as[Long] diff --git a/project/RFDependenciesPlugin.scala b/project/RFDependenciesPlugin.scala index ed6ab4dde..ed55afe9d 100644 --- a/project/RFDependenciesPlugin.scala +++ b/project/RFDependenciesPlugin.scala @@ -54,8 +54,8 @@ object RFDependenciesPlugin extends AutoPlugin { val `scala-logging` = "com.typesafe.scala-logging" %% "scala-logging" % "3.9.4" val stac4s = "com.azavea.stac4s" %% "client" % "0.7.2" val sttpCatsCe2 = "com.softwaremill.sttp.client3" %% "async-http-client-backend-cats-ce2" % "3.3.15" - val frameless = "org.typelevel" %% "frameless-dataset-spark31" % "0.12.0" - val framelessRefined = "org.typelevel" %% "frameless-refined-spark31" % "0.12.0" + val frameless = "org.typelevel" %% "frameless-dataset" % "0.11.1" + val framelessRefined = "org.typelevel" %% "frameless-refined" % "0.11.1" val `better-files` = "com.github.pathikrit" %% "better-files" % "3.9.1" % Test } import autoImport._ @@ -70,7 +70,7 @@ object RFDependenciesPlugin extends AutoPlugin { "jitpack" at "https://jitpack.io" ), // NB: Make sure to update the Spark version in pyrasterframes/python/setup.py - rfSparkVersion := "3.2.1", + rfSparkVersion := "3.2.0", rfGeoTrellisVersion := "3.6.1", rfGeoMesaVersion := "3.2.0", excludeDependencies += "log4j" % "log4j" From 8da8bd7ae056c5e574f31488f59f101a4a8d24a9 Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Mon, 5 Dec 2022 03:00:39 -0500 Subject: [PATCH 10/34] Bring in the Kryo setup GT settings bring in jackson classes and with shading it was getting weird --- .../rasterframes/util/RFKryoRegistrator.scala | 210 +++++++++++++++++- 1 file changed, 207 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/locationtech/rasterframes/util/RFKryoRegistrator.scala b/core/src/main/scala/org/locationtech/rasterframes/util/RFKryoRegistrator.scala index e5aae5162..c9dc8c400 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/util/RFKryoRegistrator.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/util/RFKryoRegistrator.scala @@ -21,11 +21,12 @@ package org.locationtech.rasterframes.util -import org.locationtech.rasterframes.ref.{DelegatingRasterSource, RasterRef, RFRasterSource} +import org.locationtech.rasterframes.ref.{DelegatingRasterSource, RFRasterSource, RasterRef} import org.locationtech.rasterframes.ref._ import com.esotericsoftware.kryo.Kryo import geotrellis.raster.io.geotiff.reader.GeoTiffInfo -import geotrellis.spark.store.kryo.KryoRegistrator +import geotrellis.spark.store.kryo.{GeometrySerializer} +import org.apache.spark.serializer.KryoRegistrator /** * @@ -35,7 +36,210 @@ import geotrellis.spark.store.kryo.KryoRegistrator */ class RFKryoRegistrator extends KryoRegistrator { override def registerClasses(kryo: Kryo): Unit = { - super.registerClasses(kryo) + + // TreeMap serializaiton has a bug; we fix it here as we're stuck on low + // Kryo versions due to Spark. Hack-tastic. + //kryo.register(classOf[util.TreeMap[_, _]], (new XTreeMapSerializer).asInstanceOf[com.esotericsoftware.kryo.Serializer[TreeMap[_, _]]]) + + kryo.register(classOf[(_,_)]) + kryo.register(classOf[::[_]]) + kryo.register(classOf[geotrellis.raster.ByteArrayFiller]) + + // CellTypes + kryo.register(geotrellis.raster.BitCellType.getClass) // Bit + kryo.register(geotrellis.raster.ByteCellType.getClass) // Byte + kryo.register(geotrellis.raster.ByteConstantNoDataCellType.getClass) + kryo.register(classOf[geotrellis.raster.ByteUserDefinedNoDataCellType]) + kryo.register(geotrellis.raster.UByteCellType.getClass) // UByte + kryo.register(geotrellis.raster.UByteConstantNoDataCellType.getClass) + kryo.register(classOf[geotrellis.raster.UByteUserDefinedNoDataCellType]) + kryo.register(geotrellis.raster.ShortCellType.getClass) // Short + kryo.register(geotrellis.raster.ShortConstantNoDataCellType.getClass) + kryo.register(classOf[geotrellis.raster.ShortUserDefinedNoDataCellType]) + kryo.register(geotrellis.raster.UShortCellType.getClass) // UShort + kryo.register(geotrellis.raster.UShortConstantNoDataCellType.getClass) + kryo.register(classOf[geotrellis.raster.UShortUserDefinedNoDataCellType]) + kryo.register(geotrellis.raster.IntCellType.getClass) // Int + kryo.register(geotrellis.raster.IntConstantNoDataCellType.getClass) + kryo.register(classOf[geotrellis.raster.IntUserDefinedNoDataCellType]) + kryo.register(geotrellis.raster.FloatCellType.getClass) // Float + kryo.register(geotrellis.raster.FloatConstantNoDataCellType.getClass) + kryo.register(classOf[geotrellis.raster.FloatUserDefinedNoDataCellType]) + kryo.register(geotrellis.raster.DoubleCellType.getClass) // Double + kryo.register(geotrellis.raster.DoubleConstantNoDataCellType.getClass) + kryo.register(classOf[geotrellis.raster.DoubleUserDefinedNoDataCellType]) + + // ArrayTiles + kryo.register(classOf[geotrellis.raster.BitArrayTile]) // Bit + kryo.register(classOf[geotrellis.raster.ByteArrayTile]) // Byte + kryo.register(classOf[geotrellis.raster.ByteRawArrayTile]) + kryo.register(classOf[geotrellis.raster.ByteConstantNoDataArrayTile]) + kryo.register(classOf[geotrellis.raster.ByteUserDefinedNoDataArrayTile]) + kryo.register(classOf[geotrellis.raster.UByteArrayTile]) // UByte + kryo.register(classOf[geotrellis.raster.UByteRawArrayTile]) + kryo.register(classOf[geotrellis.raster.UByteConstantNoDataArrayTile]) + kryo.register(classOf[geotrellis.raster.UByteUserDefinedNoDataArrayTile]) + kryo.register(classOf[geotrellis.raster.ShortArrayTile]) // Short + kryo.register(classOf[geotrellis.raster.ShortRawArrayTile]) + kryo.register(classOf[geotrellis.raster.ShortConstantNoDataArrayTile]) + kryo.register(classOf[geotrellis.raster.ShortUserDefinedNoDataArrayTile]) + kryo.register(classOf[geotrellis.raster.UShortArrayTile]) // UShort + kryo.register(classOf[geotrellis.raster.UShortRawArrayTile]) + kryo.register(classOf[geotrellis.raster.UShortConstantNoDataArrayTile]) + kryo.register(classOf[geotrellis.raster.UShortUserDefinedNoDataArrayTile]) + kryo.register(classOf[geotrellis.raster.IntArrayTile]) // Int + kryo.register(classOf[geotrellis.raster.IntRawArrayTile]) + kryo.register(classOf[geotrellis.raster.IntConstantNoDataArrayTile]) + kryo.register(classOf[geotrellis.raster.IntUserDefinedNoDataArrayTile]) + kryo.register(classOf[geotrellis.raster.FloatArrayTile]) // Float + kryo.register(classOf[geotrellis.raster.FloatRawArrayTile]) + kryo.register(classOf[geotrellis.raster.FloatConstantNoDataArrayTile]) + kryo.register(classOf[geotrellis.raster.FloatUserDefinedNoDataArrayTile]) + kryo.register(classOf[geotrellis.raster.DoubleArrayTile]) // Double + kryo.register(classOf[geotrellis.raster.DoubleRawArrayTile]) + kryo.register(classOf[geotrellis.raster.DoubleConstantNoDataArrayTile]) + kryo.register(classOf[geotrellis.raster.DoubleUserDefinedNoDataArrayTile]) + + kryo.register(classOf[Array[geotrellis.raster.Tile]]) + kryo.register(classOf[Array[geotrellis.raster.TileFeature[_,_]]]) + kryo.register(classOf[geotrellis.raster.Tile]) + kryo.register(classOf[geotrellis.raster.TileFeature[_,_]]) + + kryo.register(classOf[geotrellis.raster.ArrayMultibandTile]) + kryo.register(classOf[geotrellis.raster.CompositeTile]) + kryo.register(classOf[geotrellis.raster.ConstantTile]) + kryo.register(classOf[geotrellis.raster.CroppedTile]) + kryo.register(classOf[geotrellis.raster.Raster[_]]) + kryo.register(classOf[geotrellis.raster.RasterExtent]) + kryo.register(classOf[geotrellis.raster.CellGrid[_]]) + kryo.register(classOf[geotrellis.raster.CellSize]) + kryo.register(classOf[geotrellis.raster.GridBounds[_]]) + kryo.register(classOf[geotrellis.raster.GridExtent[_]]) + kryo.register(classOf[geotrellis.raster.mapalgebra.focal.TargetCell]) + kryo.register(classOf[geotrellis.raster.summary.GridVisitor[_, _]]) + kryo.register(geotrellis.raster.mapalgebra.focal.TargetCell.All.getClass) + kryo.register(geotrellis.raster.mapalgebra.focal.TargetCell.Data.getClass) + kryo.register(geotrellis.raster.mapalgebra.focal.TargetCell.NoData.getClass) + + kryo.register(classOf[geotrellis.layer.SpatialKey]) + kryo.register(classOf[geotrellis.layer.SpaceTimeKey]) + kryo.register(classOf[geotrellis.store.index.rowmajor.RowMajorSpatialKeyIndex]) + kryo.register(classOf[geotrellis.store.index.zcurve.ZSpatialKeyIndex]) + kryo.register(classOf[geotrellis.store.index.zcurve.ZSpaceTimeKeyIndex]) + kryo.register(classOf[geotrellis.store.index.hilbert.HilbertSpatialKeyIndex]) + kryo.register(classOf[geotrellis.store.index.hilbert.HilbertSpaceTimeKeyIndex]) + kryo.register(classOf[geotrellis.vector.ProjectedExtent]) + kryo.register(classOf[geotrellis.vector.Extent]) + kryo.register(classOf[geotrellis.proj4.CRS]) + + // UnmodifiableCollectionsSerializer.registerSerializers(kryo) + kryo.register(geotrellis.raster.buffer.Direction.Center.getClass) + kryo.register(geotrellis.raster.buffer.Direction.Top.getClass) + kryo.register(geotrellis.raster.buffer.Direction.Bottom.getClass) + kryo.register(geotrellis.raster.buffer.Direction.Left.getClass) + kryo.register(geotrellis.raster.buffer.Direction.Right.getClass) + kryo.register(geotrellis.raster.buffer.Direction.TopLeft.getClass) + kryo.register(geotrellis.raster.buffer.Direction.TopRight.getClass) + kryo.register(geotrellis.raster.buffer.Direction.BottomLeft.getClass) + kryo.register(geotrellis.raster.buffer.Direction.BottomRight.getClass) + + /* Exhaustive Registration */ + kryo.register(classOf[Array[Double]]) + kryo.register(classOf[Array[Float]]) + kryo.register(classOf[Array[Int]]) + kryo.register(classOf[Array[String]]) + kryo.register(classOf[Array[org.locationtech.jts.geom.Coordinate]]) + kryo.register(classOf[Array[org.locationtech.jts.geom.LinearRing]]) + kryo.register(classOf[Array[org.locationtech.jts.geom.Polygon]]) + kryo.register(classOf[Array[geotrellis.store.avro.AvroRecordCodec[Any]]]) + kryo.register(classOf[Array[geotrellis.layer.SpaceTimeKey]]) + kryo.register(classOf[Array[geotrellis.layer.SpatialKey]]) + kryo.register(classOf[Array[geotrellis.vector.Feature[_, Any]]]) + kryo.register(classOf[Array[geotrellis.vector.MultiPolygon]]) + kryo.register(classOf[Array[geotrellis.vector.Point]]) + kryo.register(classOf[Array[geotrellis.vector.Polygon]]) + kryo.register(classOf[Array[scala.collection.Seq[Any]]]) + kryo.register(classOf[Array[(Any, Any)]]) + kryo.register(classOf[Array[(Any, Any, Any)]]) + kryo.register(classOf[org.locationtech.jts.geom.Coordinate]) + kryo.register(classOf[org.locationtech.jts.geom.Envelope]) + kryo.register(classOf[org.locationtech.jts.geom.GeometryFactory]) + kryo.register(classOf[org.locationtech.jts.geom.impl.CoordinateArraySequence]) + kryo.register(classOf[org.locationtech.jts.geom.impl.CoordinateArraySequenceFactory]) + kryo.register(classOf[org.locationtech.jts.geom.LinearRing]) + kryo.register(classOf[org.locationtech.jts.geom.MultiPolygon]) + kryo.register(classOf[org.locationtech.jts.geom.Point]) + kryo.register(classOf[org.locationtech.jts.geom.Polygon]) + kryo.register(classOf[org.locationtech.jts.geom.PrecisionModel]) + kryo.register(classOf[org.locationtech.jts.geom.PrecisionModel.Type]) + kryo.register(classOf[geotrellis.raster.histogram.FastMapHistogram]) + kryo.register(classOf[geotrellis.raster.histogram.Histogram[AnyVal]]) + kryo.register(classOf[geotrellis.raster.histogram.MutableHistogram[AnyVal]]) + kryo.register(classOf[geotrellis.raster.histogram.StreamingHistogram]) + kryo.register(classOf[geotrellis.raster.histogram.StreamingHistogram.DeltaCompare]) + kryo.register(classOf[geotrellis.raster.histogram.StreamingHistogram.Delta]) + kryo.register(classOf[geotrellis.raster.histogram.StreamingHistogram.Bucket]) + kryo.register(classOf[geotrellis.raster.density.KernelStamper]) + kryo.register(classOf[geotrellis.raster.ProjectedRaster[_]]) + kryo.register(classOf[geotrellis.raster.TileLayout]) + kryo.register(classOf[geotrellis.layer.TemporalProjectedExtent]) + kryo.register(classOf[geotrellis.raster.buffer.BufferSizes]) + kryo.register(classOf[geotrellis.store.avro.AvroRecordCodec[Any]]) + kryo.register(classOf[geotrellis.store.avro.AvroUnionCodec[Any]]) + kryo.register(classOf[geotrellis.store.avro.codecs.KeyValueRecordCodec[Any, Any]]) + kryo.register(classOf[geotrellis.store.avro.codecs.TupleCodec[Any, Any]]) + kryo.register(classOf[geotrellis.layer.KeyBounds[Any]]) + kryo.register(classOf[geotrellis.spark.knn.KNearestRDD.Ord[Any]]) + kryo.register(classOf[geotrellis.vector.Feature[_, Any]]) + kryo.register(classOf[geotrellis.vector.Geometry], new GeometrySerializer[geotrellis.vector.Geometry]) + kryo.register(classOf[geotrellis.vector.GeometryCollection]) + kryo.register(classOf[geotrellis.vector.LineString], new GeometrySerializer[geotrellis.vector.LineString]) + kryo.register(classOf[geotrellis.vector.MultiLineString], new GeometrySerializer[geotrellis.vector.MultiLineString]) + kryo.register(classOf[geotrellis.vector.MultiPoint], new GeometrySerializer[geotrellis.vector.MultiPoint]) + kryo.register(classOf[geotrellis.vector.MultiPolygon], new GeometrySerializer[geotrellis.vector.MultiPolygon]) + kryo.register(classOf[geotrellis.vector.Point]) + kryo.register(classOf[geotrellis.vector.Polygon], new GeometrySerializer[geotrellis.vector.Polygon]) + kryo.register(classOf[geotrellis.vector.SpatialIndex[Any]]) + kryo.register(classOf[java.lang.Class[Any]]) + kryo.register(classOf[java.util.TreeMap[Any, Any]]) + kryo.register(classOf[java.util.HashMap[Any, Any]]) + kryo.register(classOf[java.util.HashSet[Any]]) + kryo.register(classOf[java.util.LinkedHashMap[Any, Any]]) + kryo.register(classOf[java.util.LinkedHashSet[Any]]) + kryo.register(classOf[org.apache.hadoop.io.BytesWritable]) + kryo.register(classOf[org.apache.hadoop.io.BigIntWritable]) + kryo.register(classOf[Array[org.apache.hadoop.io.BigIntWritable]]) + kryo.register(classOf[Array[org.apache.hadoop.io.BytesWritable]]) + kryo.register(classOf[org.locationtech.proj4j.CoordinateReferenceSystem]) + kryo.register(classOf[org.locationtech.proj4j.datum.AxisOrder]) + kryo.register(classOf[org.locationtech.proj4j.datum.AxisOrder.Axis]) + kryo.register(classOf[org.locationtech.proj4j.datum.Datum]) + kryo.register(classOf[org.locationtech.proj4j.datum.Ellipsoid]) + kryo.register(classOf[org.locationtech.proj4j.datum.Grid]) + kryo.register(classOf[org.locationtech.proj4j.datum.Grid.ConversionTable]) + kryo.register(classOf[org.locationtech.proj4j.util.PolarCoordinate]) + kryo.register(classOf[org.locationtech.proj4j.util.FloatPolarCoordinate]) + kryo.register(classOf[org.locationtech.proj4j.util.IntPolarCoordinate]) + kryo.register(classOf[Array[org.locationtech.proj4j.util.FloatPolarCoordinate]]) + kryo.register(classOf[org.locationtech.proj4j.datum.PrimeMeridian]) + kryo.register(classOf[org.locationtech.proj4j.proj.LambertConformalConicProjection]) + kryo.register(classOf[org.locationtech.proj4j.proj.LongLatProjection]) + kryo.register(classOf[org.locationtech.proj4j.proj.TransverseMercatorProjection]) + kryo.register(classOf[org.locationtech.proj4j.proj.MercatorProjection]) + kryo.register(classOf[org.locationtech.proj4j.units.DegreeUnit]) + kryo.register(classOf[org.locationtech.proj4j.units.Unit]) + kryo.register(classOf[scala.collection.mutable.WrappedArray.ofInt]) + kryo.register(classOf[scala.collection.mutable.WrappedArray.ofRef[AnyRef]]) + kryo.register(classOf[scala.collection.Seq[Any]]) + kryo.register(classOf[(Any, Any, Any)]) + kryo.register(geotrellis.proj4.LatLng.getClass) + kryo.register(geotrellis.layer.EmptyBounds.getClass) + kryo.register(scala.collection.immutable.Nil.getClass) + kryo.register(scala.math.Ordering.Double.getClass) + kryo.register(scala.math.Ordering.Float.getClass) + kryo.register(scala.math.Ordering.Int.getClass) + kryo.register(scala.math.Ordering.Long.getClass) + kryo.register(scala.None.getClass) kryo.register(classOf[RFRasterSource]) kryo.register(classOf[RasterRef]) kryo.register(classOf[DelegatingRasterSource]) From ef2f4ee5d1cdc5599d115fc35ac338b800422528 Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Mon, 5 Dec 2022 03:01:25 -0500 Subject: [PATCH 11/34] Register functions directly this is a starting point --- .../rasterframes/expressions/package.scala | 276 +++++++++++------- 1 file changed, 164 insertions(+), 112 deletions(-) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala index 40d22f96d..9f5686d10 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala @@ -22,13 +22,12 @@ package org.locationtech.rasterframes import geotrellis.raster.{DoubleConstantNoDataCellType, Tile} -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistryBase} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} -import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} -import org.apache.spark.sql.rf.VersionShims._ +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ExpressionInfo, ScalaUDF} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, ScalaReflection} import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.{SQLContext, rf} +import org.apache.spark.sql.{SQLContext} import org.locationtech.rasterframes.expressions.accessors._ import org.locationtech.rasterframes.expressions.aggregates.CellCountAggregate.DataCells import org.locationtech.rasterframes.expressions.aggregates._ @@ -38,6 +37,7 @@ import org.locationtech.rasterframes.expressions.focalops._ import org.locationtech.rasterframes.expressions.tilestats._ import org.locationtech.rasterframes.expressions.transformers._ +import scala.reflect.ClassTag import scala.reflect.runtime.universe._ /** @@ -64,114 +64,166 @@ package object expressions { def udfiexpr[RT: TypeTag, A1: TypeTag](name: String, f: DataType => A1 => RT): Expression => ScalaUDF = (child: Expression) => { val ScalaReflection.Schema(dataType, _) = ScalaReflection.schemaFor[RT] ScalaUDF((row: A1) => f(child.dataType)(row), dataType, Seq(child), Seq(Option(ExpressionEncoder[RT]().resolveAndBind())), udfName = Some(name)) + + } + + private def expressionInfo[T : ClassTag](name: String, since: Option[String], database: Option[String]): ExpressionInfo = { + val clazz = scala.reflect.classTag[T].runtimeClass + val df = clazz.getAnnotation(classOf[ExpressionDescription]) + if (df != null) { + if (df.extended().isEmpty) { + new ExpressionInfo( + clazz.getCanonicalName, + database.orNull, + name, + df.usage(), + df.arguments(), + df.examples(), + df.note(), + df.group(), + since.getOrElse(df.since()), + df.deprecated(), + df.source()) + } else { + // This exists for the backward compatibility with old `ExpressionDescription`s defining + // the extended description in `extended()`. + new ExpressionInfo(clazz.getCanonicalName, database.orNull, name, df.usage(), df.extended()) + } + } else { + new ExpressionInfo(clazz.getCanonicalName, name) + } } - def register(sqlContext: SQLContext): Unit = { - // Expression-oriented functions have a different registration scheme - // Currently have to register with the `builtin` registry due to Spark data hiding. - val registry: FunctionRegistry = rf.registry(sqlContext) - - registry.registerExpression[Add]("rf_local_add") - registry.registerExpression[Subtract]("rf_local_subtract") - registry.registerExpression[TileAssembler]("rf_assemble_tile") - registry.registerExpression[ExplodeTiles]("rf_explode_tiles") - registry.registerExpression[GetCellType]("rf_cell_type") - registry.registerExpression[SetCellType]("rf_convert_cell_type") - registry.registerExpression[InterpretAs]("rf_interpret_cell_type_as") - registry.registerExpression[SetNoDataValue]("rf_with_no_data") - registry.registerExpression[GetDimensions]("rf_dimensions") - registry.registerExpression[ExtentToGeometry]("st_geometry") - registry.registerExpression[GetGeometry]("rf_geometry") - registry.registerExpression[GeometryToExtent]("st_extent") - registry.registerExpression[GetExtent]("rf_extent") - registry.registerExpression[GetCRS]("rf_crs") - registry.registerExpression[RealizeTile]("rf_tile") - registry.registerExpression[CreateProjectedRaster]("rf_proj_raster") - registry.registerExpression[Multiply]("rf_local_multiply") - registry.registerExpression[Divide]("rf_local_divide") - registry.registerExpression[NormalizedDifference]("rf_normalized_difference") - registry.registerExpression[Less]("rf_local_less") - registry.registerExpression[Greater]("rf_local_greater") - registry.registerExpression[LessEqual]("rf_local_less_equal") - registry.registerExpression[GreaterEqual]("rf_local_greater_equal") - registry.registerExpression[Equal]("rf_local_equal") - registry.registerExpression[Unequal]("rf_local_unequal") - registry.registerExpression[IsIn]("rf_local_is_in") - registry.registerExpression[Undefined]("rf_local_no_data") - registry.registerExpression[Defined]("rf_local_data") - registry.registerExpression[Min]("rf_local_min") - registry.registerExpression[Max]("rf_local_max") - registry.registerExpression[Clamp]("rf_local_clamp") - registry.registerExpression[Where]("rf_where") - registry.registerExpression[Standardize]("rf_standardize") - registry.registerExpression[Rescale]("rf_rescale") - registry.registerExpression[Sum]("rf_tile_sum") - registry.registerExpression[Round]("rf_round") - registry.registerExpression[Abs]("rf_abs") - registry.registerExpression[Log]("rf_log") - registry.registerExpression[Log10]("rf_log10") - registry.registerExpression[Log2]("rf_log2") - registry.registerExpression[Log1p]("rf_log1p") - registry.registerExpression[Exp]("rf_exp") - registry.registerExpression[Exp10]("rf_exp10") - registry.registerExpression[Exp2]("rf_exp2") - registry.registerExpression[ExpM1]("rf_expm1") - registry.registerExpression[Sqrt]("rf_sqrt") - registry.registerExpression[Resample]("rf_resample") - registry.registerExpression[ResampleNearest]("rf_resample_nearest") - registry.registerExpression[TileToArrayDouble]("rf_tile_to_array_double") - registry.registerExpression[TileToArrayInt]("rf_tile_to_array_int") - registry.registerExpression[DataCells]("rf_data_cells") - registry.registerExpression[NoDataCells]("rf_no_data_cells") - registry.registerExpression[IsNoDataTile]("rf_is_no_data_tile") - registry.registerExpression[Exists]("rf_exists") - registry.registerExpression[ForAll]("rf_for_all") - registry.registerExpression[TileMin]("rf_tile_min") - registry.registerExpression[TileMax]("rf_tile_max") - registry.registerExpression[TileMean]("rf_tile_mean") - registry.registerExpression[TileStats]("rf_tile_stats") - registry.registerExpression[TileHistogram]("rf_tile_histogram") - registry.registerExpression[DataCells]("rf_agg_data_cells") - registry.registerExpression[CellCountAggregate.NoDataCells]("rf_agg_no_data_cells") - registry.registerExpression[CellStatsAggregate.CellStatsAggregateUDAF]("rf_agg_stats") - registry.registerExpression[HistogramAggregate.HistogramAggregateUDAF]("rf_agg_approx_histogram") - registry.registerExpression[LocalStatsAggregate.LocalStatsAggregateUDAF]("rf_agg_local_stats") - registry.registerExpression[LocalTileOpAggregate.LocalMinUDAF]("rf_agg_local_min") - registry.registerExpression[LocalTileOpAggregate.LocalMaxUDAF]("rf_agg_local_max") - registry.registerExpression[LocalCountAggregate.LocalDataCellsUDAF]("rf_agg_local_data_cells") - registry.registerExpression[LocalCountAggregate.LocalNoDataCellsUDAF]("rf_agg_local_no_data_cells") - registry.registerExpression[LocalMeanAggregate]("rf_agg_local_mean") - - registry.registerExpression[FocalMax](FocalMax.name) - registry.registerExpression[FocalMin](FocalMin.name) - registry.registerExpression[FocalMean](FocalMean.name) - registry.registerExpression[FocalMode](FocalMode.name) - registry.registerExpression[FocalMedian](FocalMedian.name) - registry.registerExpression[FocalMoransI](FocalMoransI.name) - registry.registerExpression[FocalStdDev](FocalStdDev.name) - registry.registerExpression[Convolve](Convolve.name) - - registry.registerExpression[Slope](Slope.name) - registry.registerExpression[Aspect](Aspect.name) - registry.registerExpression[Hillshade](Hillshade.name) - - registry.registerExpression[Mask.MaskByDefined]("rf_mask") - registry.registerExpression[Mask.InverseMaskByDefined]("rf_inverse_mask") - registry.registerExpression[Mask.MaskByValue]("rf_mask_by_value") - registry.registerExpression[Mask.InverseMaskByValue]("rf_inverse_mask_by_value") - registry.registerExpression[Mask.MaskByValues]("rf_mask_by_values") - - registry.registerExpression[DebugRender.RenderAscii]("rf_render_ascii") - registry.registerExpression[DebugRender.RenderMatrix]("rf_render_matrix") - registry.registerExpression[RenderPNG.RenderCompositePNG]("rf_render_png") - registry.registerExpression[RGBComposite]("rf_rgb_composite") - - registry.registerExpression[XZ2Indexer]("rf_xz2_index") - registry.registerExpression[Z2Indexer]("rf_z2_index") - - registry.registerExpression[transformers.ReprojectGeometry]("st_reproject") - - registry.registerExpression[ExtractBits]("rf_local_extract_bits") - registry.registerExpression[ExtractBits]("rf_local_extract_bit") + def register(sqlContext: SQLContext, database: Option[String] = None): Unit = { + val registry = sqlContext.sparkSession.sessionState.functionRegistry + + def registerFunction[T <: Expression : ClassTag](name: String, since: Option[String] = None)(builder: Seq[Expression] => T): Unit = { + val id = FunctionIdentifier(name, database) + val info = FunctionRegistryBase.expressionInfo[T](name, since) + registry.registerFunction(id, info, builder) + } + + def register1[T <: Expression : ClassTag]( + name: String, + builder: Expression => T + ): Unit = registerFunction[T](name, None){ case Seq(a) => builder(a) + } + + def register2[T <: Expression : ClassTag]( + name: String, + builder: (Expression, Expression) => T + ): Unit = registerFunction[T](name, None){ case Seq(a, b) => builder(a, b) } + + def register3[T <: Expression : ClassTag]( + name: String, + builder: (Expression, Expression, Expression) => T + ): Unit = registerFunction[T](name, None){ case Seq(a, b, c) => builder(a, b, c) } + + def register5[T <: Expression : ClassTag]( + name: String, + builder: (Expression, Expression, Expression, Expression, Expression) => T + ): Unit = registerFunction[T](name, None){ case Seq(a, b, c, d, e) => builder(a, b, c, d, e) } + + register2("rf_local_add", Add(_, _)) + register2("rf_local_subtract", Subtract(_, _)) + registerFunction("rf_explode_tiles"){ExplodeTiles(1.0, None, _)} + register5("rf_assemble_tile", TileAssembler(_, _, _, _, _)) + register1("rf_cell_type", GetCellType(_)) + register2("rf_convert_cell_type", SetCellType(_, _)) + register2("rf_interpret_cell_type_as", InterpretAs(_, _)) + register2("rf_with_no_data", SetNoDataValue(_,_)) + register1("rf_dimensions", GetDimensions(_)) + register1("st_geometry", ExtentToGeometry(_)) + register1("rf_geometry", GetGeometry(_)) + register1("st_extent", GeometryToExtent(_)) + register1("rf_extent", GetExtent(_)) + register1("rf_crs", GetCRS(_)) + register1("rf_tile", RealizeTile(_)) + register3("rf_proj_raster", CreateProjectedRaster(_, _, _)) + register2("rf_local_multiply", Multiply(_, _)) + register2("rf_local_divide", Divide(_, _)) + register2("rf_normalized_difference", NormalizedDifference(_,_)) + register2("rf_local_less", Less(_, _)) + register2("rf_local_greater", Greater(_, _)) + register2("rf_local_less_equal", LessEqual(_, _)) + register2("rf_local_greater_equal", GreaterEqual(_, _)) + register2("rf_local_equal", Equal(_, _)) + register2("rf_local_unequal", Unequal(_, _)) + register2("rf_local_is_in", IsIn(_, _)) + register1("rf_local_no_data", Undefined(_)) + register1("rf_local_data", Defined(_)) + register2("rf_local_min", Min(_, _)) + register2("rf_local_max", Max(_, _)) + register3("rf_local_clamp", Clamp(_, _, _)) + register3("rf_where", Where(_, _, _)) + register3("rf_standardize", Standardize(_, _, _)) + register3("rf_rescale", Rescale(_, _ , _)) + register1("rf_tile_sum", Sum(_)) + register1("rf_round", Round(_)) + register1("rf_abs", Abs(_)) + register1("rf_log", Log(_)) + register1("rf_log10", Log10(_)) + register1("rf_log2", Log2(_)) + register1("rf_log1p", Log1p(_)) + register1("rf_exp", Exp(_)) + register1("rf_exp10", Exp10(_)) + register1("rf_exp2", Exp2(_)) + register1("rf_expm1", ExpM1(_)) + register1("rf_sqrt", Sqrt(_)) + register3("rf_resample", Resample(_, _, _)) + register2("rf_resample_nearest", ResampleNearest(_, _)) + register1("rf_tile_to_array_double", TileToArrayDouble(_)) + register1("rf_tile_to_array_int", TileToArrayInt(_)) + register1("rf_data_cells", DataCells(_)) + register1("rf_no_data_cells", NoDataCells(_)) + register1("rf_is_no_data_tile", IsNoDataTile(_)) + register1("rf_exists", Exists(_)) + register1("rf_for_all", ForAll(_)) + register1("rf_tile_min", TileMin(_)) + register1("rf_tile_max", TileMax(_)) + register1("rf_tile_mean", TileMean(_)) + register1("rf_tile_stats", TileStats(_)) + register1("rf_tile_histogram", TileHistogram(_)) + register1("rf_agg_data_cells", DataCells(_)) + register1("rf_agg_no_data_cells", CellCountAggregate.NoDataCells(_)) + register1("rf_agg_stats", CellStatsAggregate.CellStatsAggregateUDAF(_)) + register1("rf_agg_approx_histogram", HistogramAggregate.HistogramAggregateUDAF(_)) + register1("rf_agg_local_stats", LocalStatsAggregate.LocalStatsAggregateUDAF(_)) + register1("rf_agg_local_min",LocalTileOpAggregate.LocalMinUDAF(_)) + register1("rf_agg_local_max", LocalTileOpAggregate.LocalMaxUDAF(_)) + register1("rf_agg_local_data_cells", LocalCountAggregate.LocalDataCellsUDAF(_)) + register1("rf_agg_local_no_data_cells", LocalCountAggregate.LocalNoDataCellsUDAF(_)) + register1("rf_agg_local_mean", LocalMeanAggregate(_)) + register3(FocalMax.name, FocalMax(_, _, _)) + register3(FocalMin.name, FocalMin(_, _, _)) + register3(FocalMean.name, FocalMean(_, _, _)) + register3(FocalMode.name, FocalMode(_, _, _)) + register3(FocalMedian.name, FocalMedian(_, _, _)) + register3(FocalMoransI.name, FocalMoransI(_, _, _)) + register3(FocalStdDev.name, FocalStdDev(_, _, _)) + register3(Convolve.name, Convolve(_, _, _)) + + register3(Slope.name, Slope(_, _, _)) + register2(Aspect.name, Aspect(_, _)) + register5(Hillshade.name, Hillshade(_, _, _, _, _)) + + register2("rf_mask", Mask.MaskByDefined(_, _)) + register2("rf_inverse_mask", Mask.InverseMaskByDefined(_, _)) + register3("rf_mask_by_value", Mask.MaskByValue(_, _, _)) + register3("rf_inverse_mask_by_value", Mask.InverseMaskByValue(_, _, _)) + register2("rf_mask_by_values", Mask.MaskByValues(_, _)) + + register1("rf_render_ascii", DebugRender.RenderAscii(_)) + register1("rf_render_matrix", DebugRender.RenderMatrix(_)) + register1("rf_render_png", RenderPNG.RenderCompositePNG(_)) + register3("rf_rgb_composite", RGBComposite(_, _, _)) + + register2("rf_xz2_index", XZ2Indexer(_, _, 18.toShort)) + register2("rf_z2_index", Z2Indexer(_, _, 31.toShort)) + + register3("st_reproject", ReprojectGeometry(_, _, _)) + + register3[ExtractBits]("rf_local_extract_bits", ExtractBits(_: Expression, _: Expression, _: Expression)) + register3[ExtractBits]("rf_local_extract_bit", ExtractBits(_: Expression, _: Expression, _: Expression)) } } From d12ace5a8e3f420c55347388fe53575a34c49ec3 Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Mon, 5 Dec 2022 03:02:01 -0500 Subject: [PATCH 12/34] Bump versions --- project/RFDependenciesPlugin.scala | 12 ++++++------ project/build.properties | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/project/RFDependenciesPlugin.scala b/project/RFDependenciesPlugin.scala index ed55afe9d..938e0d1a2 100644 --- a/project/RFDependenciesPlugin.scala +++ b/project/RFDependenciesPlugin.scala @@ -54,8 +54,8 @@ object RFDependenciesPlugin extends AutoPlugin { val `scala-logging` = "com.typesafe.scala-logging" %% "scala-logging" % "3.9.4" val stac4s = "com.azavea.stac4s" %% "client" % "0.7.2" val sttpCatsCe2 = "com.softwaremill.sttp.client3" %% "async-http-client-backend-cats-ce2" % "3.3.15" - val frameless = "org.typelevel" %% "frameless-dataset" % "0.11.1" - val framelessRefined = "org.typelevel" %% "frameless-refined" % "0.11.1" + val frameless = "org.typelevel" %% "frameless-dataset" % "0.12.0" + val framelessRefined = "org.typelevel" %% "frameless-refined" % "0.12.0" val `better-files` = "com.github.pathikrit" %% "better-files" % "3.9.1" % Test } import autoImport._ @@ -70,9 +70,9 @@ object RFDependenciesPlugin extends AutoPlugin { "jitpack" at "https://jitpack.io" ), // NB: Make sure to update the Spark version in pyrasterframes/python/setup.py - rfSparkVersion := "3.2.0", - rfGeoTrellisVersion := "3.6.1", - rfGeoMesaVersion := "3.2.0", - excludeDependencies += "log4j" % "log4j" + rfSparkVersion := "3.2.1", + rfGeoTrellisVersion := "3.6.3", + rfGeoMesaVersion := "3.4.1" + //excludeDependencies += "log4j" % "log4j" ) } diff --git a/project/build.properties b/project/build.properties index c8fcab543..8b9a0b0ab 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.6.2 +sbt.version=1.8.0 From aab54860c7b0e72c17be9e3cc878e0ed8f362847 Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Thu, 8 Dec 2022 21:20:34 -0500 Subject: [PATCH 13/34] Landsat PDS is gone :( --- .../locationtech/rasterframes/ref/RasterRefIT.scala | 3 +-- .../scala/org/locationtech/rasterframes/TestData.scala | 10 +++++----- .../datasource/raster/RaterSourceDataSourceIT.scala | 4 ++-- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/core/src/it/scala/org/locationtech/rasterframes/ref/RasterRefIT.scala b/core/src/it/scala/org/locationtech/rasterframes/ref/RasterRefIT.scala index 3ea7a65d0..2e098c008 100644 --- a/core/src/it/scala/org/locationtech/rasterframes/ref/RasterRefIT.scala +++ b/core/src/it/scala/org/locationtech/rasterframes/ref/RasterRefIT.scala @@ -32,8 +32,7 @@ class RasterRefIT extends TestEnvironment { describe("practical subregion reads") { it("should construct a natural color composite") { import spark.implicits._ - def scene(idx: Int) = URI.create(s"https://landsat-pds.s3.us-west-2.amazonaws.com" + - s"/c1/L8/176/039/LC08_L1TP_176039_20190703_20190718_01_T1/LC08_L1TP_176039_20190703_20190718_01_T1_B$idx.TIF") + def scene(idx: Int) = TestData.remoteCOGSingleBand(idx) val redScene = RFRasterSource(scene(4)) // [west, south, east, north] diff --git a/core/src/test/scala/org/locationtech/rasterframes/TestData.scala b/core/src/test/scala/org/locationtech/rasterframes/TestData.scala index 09a8139c5..98948fc5b 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/TestData.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/TestData.scala @@ -143,15 +143,15 @@ trait TestData { } // Check the URL exists as of 2020-09-30; strictly these are not COGs because they do not have internal overviews - private def remoteCOGSingleBand(b: Int) = URI.create(s"https://landsat-pds.s3.us-west-2.amazonaws.com/c1/L8/017/029/LC08_L1TP_017029_20200422_20200509_01_T1/LC08_L1TP_017029_20200422_20200509_01_T1_B${b}.TIF") - lazy val remoteCOGSingleband1: URI = remoteCOGSingleBand(1) - lazy val remoteCOGSingleband2: URI = remoteCOGSingleBand(2) + def remoteCOGSingleBand(b: Int) = URI.create(s"https://geotrellis-test.s3.us-east-1.amazonaws.com/landsat/LC80030172015001LGN00_B${b}.tiff") + lazy val remoteCOGSingleband1: URI = remoteCOGSingleBand(2) + lazy val remoteCOGSingleband2: URI = remoteCOGSingleBand(3) // a public 4 band COG TIF - lazy val remoteCOGMultiband: URI = URI.create("https://s22s-rasterframes-integration-tests.s3.amazonaws.com/m_4411708_ne_11_1_20141005.cog.tif") + lazy val remoteCOGMultiband: URI = URI.create("https://geotrellis-test.s3.us-east-1.amazonaws.com/landsat-multiband-band-cropped.tif") lazy val remoteMODIS: URI = URI.create("https://modis-pds.s3.amazonaws.com/MCD43A4.006/31/11/2017158/MCD43A4.A2017158.h31v11.006.2017171203421_B01.TIF") - lazy val remoteL8: URI = URI.create("https://landsat-pds.s3.amazonaws.com/c1/L8/017/033/LC08_L1TP_017033_20181010_20181030_01_T1/LC08_L1TP_017033_20181010_20181030_01_T1_B4.TIF") + lazy val remoteL8: URI = URI.create("https://geotrellis-test.s3.us-east-1.amazonaws.com/landsat/LC80030172015001LGN00_B4.tiff") lazy val remoteHttpMrfPath: URI = URI.create("https://s3.amazonaws.com/s22s-rasterframes-integration-tests/m_3607526_sw_18_1_20160708.mrf") lazy val remoteS3MrfPath: URI = URI.create("s3://naip-analytic/va/2016/100cm/rgbir/37077/m_3707764_sw_18_1_20160708.mrf") diff --git a/datasource/src/it/scala/org/locationtech/rasterframes/datasource/raster/RaterSourceDataSourceIT.scala b/datasource/src/it/scala/org/locationtech/rasterframes/datasource/raster/RaterSourceDataSourceIT.scala index 4fa5d08c2..2e00d8bf2 100644 --- a/datasource/src/it/scala/org/locationtech/rasterframes/datasource/raster/RaterSourceDataSourceIT.scala +++ b/datasource/src/it/scala/org/locationtech/rasterframes/datasource/raster/RaterSourceDataSourceIT.scala @@ -37,9 +37,9 @@ class RaterSourceDataSourceIT extends TestEnvironment with TestData { rf.select(rf_extent($"proj_raster").alias("extent"), rf_crs($"proj_raster").alias("crs"), rf_tile($"proj_raster").alias("target")) val cat = - """ + s""" B3,B5 - https://landsat-pds.s3.us-west-2.amazonaws.com/c1/L8/021/028/LC08_L1TP_021028_20180515_20180604_01_T1/LC08_L1TP_021028_20180515_20180604_01_T1_B3.TIF,https://landsat-pds.s3.us-west-2.amazonaws.com/c1/L8/021/028/LC08_L1TP_021028_20180515_20180604_01_T1/LC08_L1TP_021028_20180515_20180604_01_T1_B5.TIF + ${remoteCOGSingleband1},${remoteCOGSingleband2} """ val features_rf = spark.read.raster From a3ac4cf614e3bb7c335983337d2632743036ded1 Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Sat, 10 Dec 2022 09:46:26 -0500 Subject: [PATCH 14/34] Fix Resample and ResampleNearest Also untangle the super weird inheritance relationship between the two --- .../expressions/localops/Resample.scala | 171 +++++++----------- .../localops/ResampleNearest.scala | 84 +++++++++ 2 files changed, 146 insertions(+), 109 deletions(-) create mode 100644 core/src/main/scala/org/locationtech/rasterframes/expressions/localops/ResampleNearest.scala diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Resample.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Resample.scala index 9f2aec49c..55fd7afc3 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Resample.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Resample.scala @@ -22,96 +22,20 @@ package org.locationtech.rasterframes.expressions.localops import geotrellis.raster.Tile -import geotrellis.raster.resample._ -import geotrellis.raster.resample.{Max => RMax, Min => RMin, ResampleMethod => GTResampleMethod} +import geotrellis.raster.resample.{Mode, NearestNeighbor, Sum, Max => RMax, Min => RMin, ResampleMethod => GTResampleMethod} import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, Literal, TernaryExpression} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, TernaryExpression} import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.unsafe.types.UTF8String -import org.locationtech.rasterframes.util.ResampleMethod -import org.locationtech.rasterframes.expressions.{RasterResult, fpTile, row} import org.locationtech.rasterframes.expressions.DynamicExtractors._ +import org.locationtech.rasterframes.expressions.{RasterResult, fpTile, row} +import org.locationtech.rasterframes.util.ResampleMethod -abstract class ResampleBase(left: Expression, right: Expression, method: Expression) extends TernaryExpression with RasterResult with CodegenFallback with Serializable { - - override val nodeName: String = "rf_resample" - def first: Expression = left - def second: Expression = right - def third: Expression = method - def dataType: DataType = left.dataType - - def targetFloatIfNeeded(t: Tile, method: GTResampleMethod): Tile = - method match { - case NearestNeighbor | Mode | RMax | RMin | Sum => t - case _ => fpTile(t) - } - - // These methods define the core algorithms to be used. - def op(left: Tile, right: Tile, method: GTResampleMethod): Tile = - op(left, right.cols, right.rows, method) - - def op(left: Tile, right: Double, method: GTResampleMethod): Tile = - op(left, (left.cols * right).toInt, (left.rows * right).toInt, method) - - def op(tile: Tile, newCols: Int, newRows: Int, method: GTResampleMethod): Tile = - targetFloatIfNeeded(tile, method).resample(newCols, newRows, method) - - override def checkInputDataTypes(): TypeCheckResult = { - // copypasta from BinaryLocalRasterOp - if (!tileExtractor.isDefinedAt(left.dataType)) { - TypeCheckFailure(s"Input type '${left.dataType}' does not conform to a raster type.") - } - else if (!tileOrNumberExtractor.isDefinedAt(right.dataType)) { - TypeCheckFailure(s"Input type '${right.dataType}' does not conform to a compatible type.") - } else method.dataType match { - case StringType => TypeCheckSuccess - case _ => TypeCheckFailure(s"Cannot interpret value of type `${method.dataType.simpleString}` for resampling method; please provide a String method name.") - } - } - - override def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = { - // more copypasta from BinaryLocalRasterOp - - val (leftTile, leftCtx) = tileExtractor(left.dataType)(row(input1)) - val methodString = input3.asInstanceOf[UTF8String].toString - - val resamplingMethod = methodString match { - case ResampleMethod(mm) => mm - case _ => throw new IllegalArgumentException("Unrecognized resampling method specified") - } - - val result: Tile = tileOrNumberExtractor(right.dataType)(input2) match { - // in this case we expect the left and right contexts to vary. no warnings raised. - case TileArg(rightTile, _) => op(leftTile, rightTile, resamplingMethod) - case DoubleArg(d) => op(leftTile, d, resamplingMethod) - case IntegerArg(i) => op(leftTile, i.toDouble, resamplingMethod) - } - - // reassemble the leftTile with its context. Note that this operation does not change Extent and CRS - toInternalRow(result, leftCtx) - } - - override def eval(input: InternalRow): Any = { - if(input == null) null - else { - val l = left.eval(input) - val r = right.eval(input) - val m = method.eval(input) - if (m == null) null // no method, return null - else if (l == null) null // no l tile, return null - else if (r == null) l // no target tile or factor, return l without changin it - else nullSafeEval(l, r, m) - } - } - -} - @ExpressionDescription( usage = "_FUNC_(tile, factor, method_name) - Resample tile to different dimension based on scalar `factor` or a tile whose dimension to match. Scalar less than one will downsample tile; greater than one will upsample. Uses resampling method named in the `method_name`." + "Methods average, mode, median, max, min, and sum aggregate over cells when downsampling", @@ -129,45 +53,74 @@ Examples: > SELECT _FUNC_(tile1, tile2, lit("cubic_spline")); ...""" ) -case class Resample(left: Expression, factor: Expression, method: Expression) extends ResampleBase(left, factor, method) { +case class Resample(tile: Expression, factor: Expression, method: Expression) extends TernaryExpression with RasterResult with CodegenFallback { + override val nodeName: String = "rf_resample" + def dataType: DataType = tile.dataType + def first: Expression = tile + def second: Expression = factor + def third: Expression = method + + override def checkInputDataTypes(): TypeCheckResult = { + if (!tileExtractor.isDefinedAt(tile.dataType)) { + TypeCheckFailure(s"Input type '${tile.dataType}' does not conform to a raster type.") + } else if (!tileOrNumberExtractor.isDefinedAt(factor.dataType)) { + TypeCheckFailure(s"Input type '${factor.dataType}' does not conform to a compatible type.") + } else + method.dataType match { + case StringType => TypeCheckSuccess + case _ => + TypeCheckFailure( + s"Cannot interpret value of type `${method.dataType.simpleString}` for resampling method; please provide a String method name." + ) + } + } + override def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = { + val (leftTile, leftCtx) = tileExtractor(tile.dataType)(row(input1)) + val ton = tileOrNumberExtractor(factor.dataType)(input2) + val methodString = input3.asInstanceOf[UTF8String].toString + val resamplingMethod = methodString match { + case ResampleMethod(mm) => mm + case _ => throw new IllegalArgumentException("Unrecognized resampling method specified") + } + + val result: Tile = Resample.op(leftTile, ton, resamplingMethod) + toInternalRow(result, leftCtx) + } + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) } object Resample { - def apply(left: Column, right: Column, methodName: String): Column = - new Column(Resample(left.expr, right.expr, lit(methodName).expr)) + def op(tile: Tile, target: TileOrNumberArg, method: GTResampleMethod): Tile = { + val sourceTile = method match { + case NearestNeighbor | Mode | RMax | RMin | Sum => tile + case _ => fpTile(tile) + } + target match { + case TileArg(targetTile, _) => + sourceTile.resample(targetTile.cols, targetTile.rows, method) + case DoubleArg(d) => + sourceTile.resample((tile.cols * d).toInt, (tile.rows * d).toInt, method) + case IntegerArg(i) => + sourceTile.resample(tile.cols * i,tile.rows * i, method) + } + } - def apply(left: Column, right: Column, method: Column): Column = - new Column(Resample(left.expr, right.expr, method.expr)) + def apply(tile: Column, factor: Column, methodName: String): Column = + new Column(Resample(tile.expr, factor.expr, lit(methodName).expr)) - def apply[N: Numeric](left: Column, right: N, method: String): Column = new Column(Resample(left.expr, lit(right).expr, lit(method).expr)) + def apply(tile: Column, factor: Column, method: Column): Column = + new Column(Resample(tile.expr, factor.expr, method.expr)) - def apply[N: Numeric](left: Column, right: N, method: Column): Column = new Column(Resample(left.expr, lit(right).expr, method.expr)) + def apply[N: Numeric](tile: Column, factor: N, method: String): Column = + new Column(Resample(tile.expr, lit(factor).expr, lit(method).expr)) + def apply[N: Numeric](tile: Column, factor: N, method: Column): Column = + new Column(Resample(tile.expr, lit(factor).expr, method.expr)) } -@ExpressionDescription( - usage = "_FUNC_(tile, factor) - Resample tile to different size based on scalar factor or tile whose dimension to match. Scalar less than one will downsample tile; greater than one will upsample. Uses nearest-neighbor value.", - arguments = """ - Arguments: - * tile - tile - * rhs - scalar or tile to match dimension""", - examples = """ - Examples: - > SELECT _FUNC_(tile, 2.0); - ... - > SELECT _FUNC_(tile1, tile2); - ...""") -case class ResampleNearest(tile: Expression, target: Expression) extends ResampleBase(tile, target, Literal("nearest")) { - override val nodeName: String = "rf_resample_nearest" - - override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = - ResampleNearest(tile, target) -} -object ResampleNearest { - def apply(tile: Column, target: Column): Column = new Column(ResampleNearest(tile.expr, target.expr)) - def apply[N: Numeric](tile: Column, value: N): Column = new Column(ResampleNearest(tile.expr, lit(value).expr)) -} + + diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/ResampleNearest.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/ResampleNearest.scala new file mode 100644 index 000000000..d902cca6c --- /dev/null +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/ResampleNearest.scala @@ -0,0 +1,84 @@ +/* + * This software is licensed under the Apache 2 license, quoted below. + * + * Copyright 2019 Astraea, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * [http://www.apache.org/licenses/LICENSE-2.0] + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + * + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.locationtech.rasterframes.expressions.localops + +import geotrellis.raster.Tile +import geotrellis.raster.resample._ +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionDescription} +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.types.DataType +import org.locationtech.rasterframes.expressions.{RasterResult, row} +import org.locationtech.rasterframes.expressions.DynamicExtractors._ + + +@ExpressionDescription( + usage = + "_FUNC_(tile, factor) - Resample tile to different size based on scalar factor or tile whose dimension to match. Scalar less than one will downsample tile; greater than one will upsample. Uses nearest-neighbor value.", + arguments = """ + Arguments: + * tile - tile + * rhs - scalar or tile to match dimension""", + examples = """ + Examples: + > SELECT _FUNC_(tile, 2.0); + ... + > SELECT _FUNC_(tile1, tile2); + ...""" +) +case class ResampleNearest(tile: Expression, factor: Expression) extends BinaryExpression with RasterResult with CodegenFallback { + override val nodeName: String = "rf_resample_nearest" + def dataType: DataType = tile.dataType + def left: Expression = tile + def right: Expression = factor + + override def checkInputDataTypes(): TypeCheckResult = { + if (!tileExtractor.isDefinedAt(tile.dataType)) + TypeCheckFailure(s"Input type '${tile.dataType}' does not conform to a raster type.") + else if (!tileOrNumberExtractor.isDefinedAt(factor.dataType)) + TypeCheckFailure(s"Input type '${factor.dataType}' does not conform to a compatible type.") + else + TypeCheckSuccess + } + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val (leftTile, leftCtx) = tileExtractor(tile.dataType)(row(input1)) + val ton = tileOrNumberExtractor(factor.dataType)(input2) + + val result: Tile = Resample.op(leftTile, ton, NearestNeighbor) + toInternalRow(result, leftCtx) + } + + override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = + ResampleNearest(newLeft, newRight) +} + +object ResampleNearest { + def apply(tile: Column, target: Column): Column = + new Column(ResampleNearest(tile.expr, target.expr)) + + def apply[N: Numeric](tile: Column, value: N): Column = + new Column(ResampleNearest(tile.expr, lit(value).expr)) +} From 0214fa22577716da776d417595fb986b459e3f41 Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Sat, 10 Dec 2022 20:15:54 -0500 Subject: [PATCH 15/34] fix masking functions Made them more direct. Good for fixing things and better for performance because these versions don't need to create intermediate mask tiles. --- .../expressions/DynamicExtractors.scala | 19 ++ .../rasterframes/expressions/package.scala | 22 +- .../transformers/InverseMaskByDefined.scala | 85 +++++++ .../transformers/InverseMaskByValue.scala | 92 ++++++++ .../expressions/transformers/Mask.scala | 213 ------------------ .../transformers/MaskByDefined.scala | 84 +++++++ .../transformers/MaskByValue.scala | 92 ++++++++ .../transformers/MaskByValues.scala | 93 ++++++++ .../functions/LocalFunctions.scala | 22 +- .../functions/MaskingFunctionsSpec.scala | 6 +- 10 files changed, 491 insertions(+), 237 deletions(-) create mode 100644 core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InverseMaskByDefined.scala create mode 100644 core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InverseMaskByValue.scala delete mode 100644 core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala create mode 100644 core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByDefined.scala create mode 100644 core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByValue.scala create mode 100644 core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByValues.scala diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/DynamicExtractors.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/DynamicExtractors.scala index 9f337d226..efc71a01c 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/DynamicExtractors.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/DynamicExtractors.scala @@ -26,6 +26,7 @@ import geotrellis.raster.{CellGrid, Neighborhood, Raster, TargetCell, Tile} import geotrellis.vector.Extent import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.jts.JTSTypes import org.apache.spark.sql.rf.{RasterSourceUDT, TileUDT} import org.apache.spark.sql.types._ @@ -106,6 +107,24 @@ object DynamicExtractors { (row: InternalRow) => row.as[ProjectedRasterTile] } + lazy val intArrayExtractor: PartialFunction[DataType, ArrayData => Array[Int]] = { + case ArrayType(t, true) => + throw new IllegalArgumentException(s"Can't turn array of $t to array") + case ArrayType(DoubleType, false) => + unsafe => unsafe.toDoubleArray.map(_.toInt) + case ArrayType(FloatType, false) => + unsafe => unsafe.toFloatArray.map(_.toInt) + case ArrayType(IntegerType, false) => + unsafe => unsafe.toIntArray + case ArrayType(ShortType, false) => + unsafe => unsafe.toShortArray.map(_.toInt) + case ArrayType(ByteType, false) => + unsafe => unsafe.toByteArray.map(_.toInt) + case ArrayType(BooleanType, false) => + unsafe => unsafe.toBooleanArray().map(x => if (x) 1 else 0) + + } + lazy val crsExtractor: PartialFunction[DataType, Any => CRS] = { val base: PartialFunction[DataType, Any => CRS] = { case _: StringType => (v: Any) => LazyCRS(v.asInstanceOf[UTF8String].toString) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala index 9f5686d10..7f23b197c 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala @@ -22,12 +22,12 @@ package org.locationtech.rasterframes import geotrellis.raster.{DoubleConstantNoDataCellType, Tile} -import org.apache.spark.sql.catalyst.analysis.{FunctionRegistryBase} +import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ExpressionInfo, ScalaUDF} import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, ScalaReflection} import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.{SQLContext} +import org.apache.spark.sql.SQLContext import org.locationtech.rasterframes.expressions.accessors._ import org.locationtech.rasterframes.expressions.aggregates.CellCountAggregate.DataCells import org.locationtech.rasterframes.expressions.aggregates._ @@ -106,23 +106,23 @@ package object expressions { def register1[T <: Expression : ClassTag]( name: String, builder: Expression => T - ): Unit = registerFunction[T](name, None){ case Seq(a) => builder(a) + ): Unit = registerFunction[T](name, None){ args => builder(args(0)) } def register2[T <: Expression : ClassTag]( name: String, builder: (Expression, Expression) => T - ): Unit = registerFunction[T](name, None){ case Seq(a, b) => builder(a, b) } + ): Unit = registerFunction[T](name, None){ args => builder(args(0), args(1)) } def register3[T <: Expression : ClassTag]( name: String, builder: (Expression, Expression, Expression) => T - ): Unit = registerFunction[T](name, None){ case Seq(a, b, c) => builder(a, b, c) } + ): Unit = registerFunction[T](name, None){ args => builder(args(0), args(1), args(2)) } def register5[T <: Expression : ClassTag]( name: String, builder: (Expression, Expression, Expression, Expression, Expression) => T - ): Unit = registerFunction[T](name, None){ case Seq(a, b, c, d, e) => builder(a, b, c, d, e) } + ): Unit = registerFunction[T](name, None){ args => builder(args(0), args(1), args(2), args(3), args(4)) } register2("rf_local_add", Add(_, _)) register2("rf_local_subtract", Subtract(_, _)) @@ -207,11 +207,11 @@ package object expressions { register2(Aspect.name, Aspect(_, _)) register5(Hillshade.name, Hillshade(_, _, _, _, _)) - register2("rf_mask", Mask.MaskByDefined(_, _)) - register2("rf_inverse_mask", Mask.InverseMaskByDefined(_, _)) - register3("rf_mask_by_value", Mask.MaskByValue(_, _, _)) - register3("rf_inverse_mask_by_value", Mask.InverseMaskByValue(_, _, _)) - register2("rf_mask_by_values", Mask.MaskByValues(_, _)) + register2("rf_mask", MaskByDefined(_, _)) + register2("rf_inverse_mask", InverseMaskByDefined(_, _)) + register3("rf_mask_by_value", MaskByValue(_, _, _)) + register3("rf_inverse_mask_by_value", InverseMaskByValue(_, _, _)) + register3("rf_mask_by_values", MaskByValues(_, _, _)) register1("rf_render_ascii", DebugRender.RenderAscii(_)) register1("rf_render_matrix", DebugRender.RenderMatrix(_)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InverseMaskByDefined.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InverseMaskByDefined.scala new file mode 100644 index 000000000..b340c5583 --- /dev/null +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InverseMaskByDefined.scala @@ -0,0 +1,85 @@ +/* + * This software is licensed under the Apache 2 license, quoted below. + * + * Copyright 2019 Astraea, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * [http://www.apache.org/licenses/LICENSE-2.0] + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + * + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.locationtech.rasterframes.expressions.transformers + +import geotrellis.raster.{NODATA, Tile, isNoData} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.{Column, TypedColumn} +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionDescription} +import org.apache.spark.sql.types.DataType +import org.locationtech.rasterframes.expressions.DynamicExtractors.tileExtractor +import org.locationtech.rasterframes.expressions.{RasterResult, row} +import org.locationtech.rasterframes.tileEncoder + + +@ExpressionDescription( + usage = "_FUNC_(target, mask) - Generate a tile with the values from the data tile, but where cells in the masking tile DO NOT contain NODATA, replace the data value with NODATA", + arguments = """ + Arguments: + * target - tile to mask + * mask - masking definition""", + examples = """ + Examples: + > SELECT _FUNC_(target, mask); + ...""" +) +case class InverseMaskByDefined(targetTile: Expression, maskTile: Expression) + extends BinaryExpression + with CodegenFallback + with RasterResult { + override def nodeName: String = "rf_inverse_mask" + + def dataType: DataType = targetTile.dataType + def left: Expression = targetTile + def right: Expression = maskTile + + protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = + InverseMaskByDefined(newLeft, newRight) + + override def checkInputDataTypes(): TypeCheckResult = { + if (!tileExtractor.isDefinedAt(targetTile.dataType)) { + TypeCheckFailure(s"Input type '${targetTile.dataType}' does not conform to a raster type.") + } else if (!tileExtractor.isDefinedAt(maskTile.dataType)) { + TypeCheckFailure(s"Input type '${maskTile.dataType}' does not conform to a raster type.") + } else TypeCheckSuccess + } + + private lazy val targetTileExtractor = tileExtractor(targetTile.dataType) + private lazy val maskTileExtractor = tileExtractor(maskTile.dataType) + + override protected def nullSafeEval(targetInput: Any, maskInput: Any): Any = { + val (targetTile, targetCtx) = targetTileExtractor(row(targetInput)) + val (mask, maskCtx) = maskTileExtractor(row(maskInput)) + + val result = targetTile.dualCombine(mask) + { (v, m) => if (isNoData(m)) v else NODATA } + { (v, m) => if (isNoData(m)) v else NODATA } + toInternalRow(result, targetCtx) + } +} + +object InverseMaskByDefined { + def apply(srcTile: Column, maskingTile: Column): TypedColumn[Any, Tile] = + new Column(InverseMaskByDefined(srcTile.expr, maskingTile.expr)).as[Tile] +} diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InverseMaskByValue.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InverseMaskByValue.scala new file mode 100644 index 000000000..1e87a160b --- /dev/null +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InverseMaskByValue.scala @@ -0,0 +1,92 @@ +/* + * This software is licensed under the Apache 2 license, quoted below. + * + * Copyright 2019 Astraea, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * [http://www.apache.org/licenses/LICENSE-2.0] + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + * + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.locationtech.rasterframes.expressions.transformers + +import geotrellis.raster.{NODATA, Tile, d2i} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.{Column, TypedColumn} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, TernaryExpression} +import org.apache.spark.sql.types.DataType +import org.locationtech.rasterframes.expressions.DynamicExtractors.{intArgExtractor, tileExtractor} +import org.locationtech.rasterframes.expressions.{RasterResult, row} +import org.locationtech.rasterframes.tileEncoder + + +@ExpressionDescription( + usage = "_FUNC_(target, mask, maskValue) - Generate a tile with the values from the data tile, but where cells in the masking tile DO NOT contain the masking value, replace the data value with NODATA.", + arguments = """ + Arguments: + * target - tile to mask + * mask - masking definition + * maskValue - value in the `mask` for which to mark `target` as data cells + """, + examples = """ + Examples: + > SELECT _FUNC_(target, mask, maskValue); + ...""" +) +case class InverseMaskByValue(targetTile: Expression, maskTile: Expression, maskValue: Expression) + extends TernaryExpression + with CodegenFallback + with RasterResult { + override def nodeName: String = "rf_inverse_mask_by_value" + + def dataType: DataType = targetTile.dataType + def first: Expression = targetTile + def second: Expression = maskTile + def third: Expression = maskValue + + protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + InverseMaskByValue(newFirst, newSecond, newThird) + + override def checkInputDataTypes(): TypeCheckResult = { + if (!tileExtractor.isDefinedAt(targetTile.dataType)) { + TypeCheckFailure(s"Input type '${targetTile.dataType}' does not conform to a raster type.") + } else if (!tileExtractor.isDefinedAt(maskTile.dataType)) { + TypeCheckFailure(s"Input type '${maskTile.dataType}' does not conform to a raster type.") + } else if (!intArgExtractor.isDefinedAt(maskValue.dataType)) { + TypeCheckFailure(s"Input type '${maskValue.dataType}' isn't an integral type.") + } else TypeCheckSuccess + } + + private lazy val targetTileExtractor = tileExtractor(targetTile.dataType) + private lazy val maskTileExtractor = tileExtractor(maskTile.dataType) + private lazy val maskValueExtractor = intArgExtractor(maskValue.dataType) + + override protected def nullSafeEval(targetInput: Any, maskInput: Any, maskValueInput: Any): Any = { + val (targetTile, targetCtx) = targetTileExtractor(row(targetInput)) + val (mask, maskCtx) = maskTileExtractor(row(maskInput)) + val maskValue = maskValueExtractor(maskValueInput).value + + val result = targetTile.dualCombine(mask) + { (v, m) => if (m != maskValue) NODATA else v } + { (v, m) => if (d2i(m) != maskValue) NODATA else v } + toInternalRow(result, targetCtx) + } +} + +object InverseMaskByValue { + def apply(srcTile: Column, maskingTile: Column, maskValue: Column): TypedColumn[Any, Tile] = + new Column(InverseMaskByValue(srcTile.expr, maskingTile.expr, maskValue.expr)).as[Tile] +} diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala deleted file mode 100644 index f225b369f..000000000 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala +++ /dev/null @@ -1,213 +0,0 @@ -/* - * This software is licensed under the Apache 2 license, quoted below. - * - * Copyright 2019 Astraea, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not - * use this file except in compliance with the License. You may obtain a copy of - * the License at - * - * [http://www.apache.org/licenses/LICENSE-2.0] - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations under - * the License. - * - * SPDX-License-Identifier: Apache-2.0 - * - */ - -package org.locationtech.rasterframes.expressions.transformers - -import com.typesafe.scalalogging.Logger -import geotrellis.raster -import geotrellis.raster.{NoNoData, Tile} -import geotrellis.raster.mapalgebra.local.{Undefined, InverseMask => gtInverseMask, Mask => gtMask} -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, Literal, TernaryExpression} -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.{Column, TypedColumn} -import org.locationtech.rasterframes._ -import org.locationtech.rasterframes.expressions.DynamicExtractors._ -import org.locationtech.rasterframes.expressions.localops.IsIn -import org.locationtech.rasterframes.expressions.{RasterResult, row} -import org.slf4j.LoggerFactory - -/** Convert cells in the `left` to NoData based on another tile's contents - * - * @param first a tile of data values, with valid nodata cell type - * @param second a tile indicating locations to set to nodata - * @param third optional, cell values in the `middle` tile indicating locations to set NoData - * @param undefined if true, consider NoData in the `middle` as the locations to mask; else use `right` valued cells - * @param inverse if true, and defined is true, set `left` to NoData where `middle` is NOT nodata - */ -abstract class Mask(val first: Expression, val second: Expression, val third: Expression, undefined: Boolean, inverse: Boolean) - extends TernaryExpression with RasterResult with CodegenFallback with Serializable { - // aliases. - def targetExp: Expression = first - def maskExp: Expression = second - def maskValueExp: Expression = third - - @transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName)) - - override def checkInputDataTypes(): TypeCheckResult = - if (!tileExtractor.isDefinedAt(targetExp.dataType)) { - TypeCheckFailure(s"Input type '${targetExp.dataType}' does not conform to a raster type.") - } else if (!tileExtractor.isDefinedAt(maskExp.dataType)) { - TypeCheckFailure(s"Input type '${maskExp.dataType}' does not conform to a raster type.") - } else if (!intArgExtractor.isDefinedAt(maskValueExp.dataType)) { - TypeCheckFailure(s"Input type '${maskValueExp.dataType}' isn't an integral type.") - } else TypeCheckSuccess - - def dataType: DataType = first.dataType - - override def makeCopy(newArgs: Array[AnyRef]): Expression = super.makeCopy(newArgs) - - override protected def nullSafeEval(targetInput: Any, maskInput: Any, maskValueInput: Any): Any = { - val (targetTile, targetCtx) = tileExtractor(targetExp.dataType)(row(targetInput)) - - require(! targetTile.cellType.isInstanceOf[NoNoData], - s"Input data expression ${first.prettyName} must have a CellType with NoData defined in order to perform a masking operation. Found CellType ${targetTile.cellType.toString()}.") - - val (maskTile, maskCtx) = tileExtractor(maskExp.dataType)(row(maskInput)) - - if (targetCtx.isEmpty && maskCtx.isDefined) - logger.warn( - s"Right-hand parameter '${second}' provided an extent and CRS, but the left-hand parameter " + - s"'${first}' didn't have any. Because the left-hand side defines output type, the right-hand context will be lost.") - - if (targetCtx.isDefined && maskCtx.isDefined && targetCtx != maskCtx) - logger.warn(s"Both '${first}' and '${second}' provided an extent and CRS, but they are different. Left-hand side will be used.") - - val maskValue = intArgExtractor(maskValueExp.dataType)(maskValueInput) - - // Get a tile where values of 1 indicate locations to set to ND in the target tile - // When `undefined` is true, setting targetTile locations to ND for ND locations of the `maskTile` - val masking = if (undefined) Undefined(maskTile) - else maskTile.localEqual(maskValue.value) // Otherwise if `maskTile` locations equal `maskValue`, set location to ND - - // apply the `masking` where values are 1 set to ND (possibly inverted!) - val result = if (inverse) gtInverseMask(targetTile, masking, 1, raster.NODATA) else gtMask(targetTile, masking, 1, raster.NODATA) - - toInternalRow(result, targetCtx) - } -} -object Mask { - @ExpressionDescription( - usage = "_FUNC_(target, mask) - Generate a tile with the values from the data tile, but where cells in the masking tile contain NODATA, replace the data value with NODATA.", - arguments = """ - Arguments: - * target - tile to mask - * mask - masking definition""", - examples = """ - Examples: - > SELECT _FUNC_(target, mask); - ...""" - ) - case class MaskByDefined(target: Expression, mask: Expression) extends Mask(target, mask, Literal(0), true, false) { - override def nodeName: String = "rf_mask" - - override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = ??? - } - object MaskByDefined { - def apply(targetTile: Column, maskTile: Column): TypedColumn[Any, Tile] = - new Column(MaskByDefined(targetTile.expr, maskTile.expr)).as[Tile] - } - - @ExpressionDescription( - usage = "_FUNC_(target, mask) - Generate a tile with the values from the data tile, but where cells in the masking tile DO NOT contain NODATA, replace the data value with NODATA", - arguments = """ - Arguments: - * target - tile to mask - * mask - masking definition""", - examples = """ - Examples: - > SELECT _FUNC_(target, mask); - ...""" - ) - case class InverseMaskByDefined(leftTile: Expression, rightTile: Expression) extends Mask(leftTile, rightTile, Literal(0), true, true) { - override def nodeName: String = "rf_inverse_mask" - - override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = - copy(leftTile = newFirst, rightTile = newSecond) - } - object InverseMaskByDefined { - def apply(srcTile: Column, maskingTile: Column): TypedColumn[Any, Tile] = - new Column(InverseMaskByDefined(srcTile.expr, maskingTile.expr)).as[Tile] - } - - @ExpressionDescription( - usage = "_FUNC_(target, mask, maskValue) - Generate a tile with the values from the data tile, but where cells in the masking tile contain the masking value, replace the data value with NODATA.", - arguments = """ - Arguments: - * target - tile to mask - * mask - masking definition""", - examples = """ - Examples: - > SELECT _FUNC_(target, mask, maskValue); - ...""" - ) - case class MaskByValue(leftTile: Expression, rightTile: Expression, maskValue: Expression) extends Mask(leftTile, rightTile, maskValue, false, false) { - override def nodeName: String = "rf_mask_by_value" - - override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = - copy(leftTile = newFirst, rightTile = newSecond, maskValue = newThird) - } - object MaskByValue { - def apply(srcTile: Column, maskingTile: Column, maskValue: Column): TypedColumn[Any, Tile] = - new Column(MaskByValue(srcTile.expr, maskingTile.expr, maskValue.expr)).as[Tile] - } - - @ExpressionDescription( - usage = "_FUNC_(target, mask, maskValue) - Generate a tile with the values from the data tile, but where cells in the masking tile DO NOT contain the masking value, replace the data value with NODATA.", - arguments = """ - Arguments: - * target - tile to mask - * mask - masking definition - * maskValue - value in the `mask` for which to mark `target` as data cells - """, - examples = """ - Examples: - > SELECT _FUNC_(target, mask, maskValue); - ...""" - ) - case class InverseMaskByValue(leftTile: Expression, rightTile: Expression, maskValue: Expression) extends Mask(leftTile, rightTile, maskValue, false, true) { - override def nodeName: String = "rf_inverse_mask_by_value" - - override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = - copy(leftTile = newFirst, rightTile = newSecond) - } - object InverseMaskByValue { - def apply(srcTile: Column, maskingTile: Column, maskValue: Column): TypedColumn[Any, Tile] = - new Column(InverseMaskByValue(srcTile.expr, maskingTile.expr, maskValue.expr)).as[Tile] - } - - @ExpressionDescription( - usage = "_FUNC_(data, mask, maskValues) - Generate a tile with the values from `data` tile but where cells in the `mask` tile are in the `maskValues` list, replace the value with NODATA.", - arguments = """ - Arguments: - * target - tile to mask - * mask - masking definition - * maskValues - sequence of values to consider as masks candidates - """, - examples = """ - Examples: - > SELECT _FUNC_(data, mask, array(1, 2, 3)) - ...""" - ) - case class MaskByValues(dataTile: Expression, maskTile: Expression) extends Mask(dataTile, maskTile, Literal(1), false, false) { - def this(dataTile: Expression, maskTile: Expression, maskValues: Expression) = - this(dataTile, IsIn(maskTile, maskValues)) - override def nodeName: String = "rf_mask_by_values" - - override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = ??? - } - object MaskByValues { - def apply(dataTile: Column, maskTile: Column, maskValues: Column): TypedColumn[Any, Tile] = - new Column(MaskByValues(dataTile.expr, IsIn(maskTile, maskValues).expr)).as[Tile] - } -} diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByDefined.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByDefined.scala new file mode 100644 index 000000000..7420be708 --- /dev/null +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByDefined.scala @@ -0,0 +1,84 @@ +/* + * This software is licensed under the Apache 2 license, quoted below. + * + * Copyright 2019 Astraea, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * [http://www.apache.org/licenses/LICENSE-2.0] + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + * + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.locationtech.rasterframes.expressions.transformers +import geotrellis.raster.{NODATA, Tile, isNoData} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.{Column, TypedColumn} +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionDescription} +import org.apache.spark.sql.types.DataType +import org.locationtech.rasterframes.expressions.DynamicExtractors.{tileExtractor} +import org.locationtech.rasterframes.expressions.{RasterResult, row} +import org.locationtech.rasterframes.tileEncoder + + +@ExpressionDescription( + usage = "_FUNC_(target, mask) - Generate a tile with the values from the data tile, but where cells in the masking tile contain NODATA, replace the data value with NODATA.", + arguments = """ + Arguments: + * target - tile to mask + * mask - masking definition""", + examples = """ + Examples: + > SELECT _FUNC_(target, mask); + ...""" +) +case class MaskByDefined(targetTile: Expression, maskTile: Expression) + extends BinaryExpression + with CodegenFallback + with RasterResult { + override def nodeName: String = "rf_mask" + + def dataType: DataType = targetTile.dataType + def left: Expression = targetTile + def right: Expression = maskTile + + protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = + MaskByDefined(newLeft, newRight) + + override def checkInputDataTypes(): TypeCheckResult = { + if (!tileExtractor.isDefinedAt(targetTile.dataType)) { + TypeCheckFailure(s"Input type '${targetTile.dataType}' does not conform to a raster type.") + } else if (!tileExtractor.isDefinedAt(maskTile.dataType)) { + TypeCheckFailure(s"Input type '${maskTile.dataType}' does not conform to a raster type.") + } else TypeCheckSuccess + } + + private lazy val targetTileExtractor = tileExtractor(targetTile.dataType) + private lazy val maskTileExtractor = tileExtractor(maskTile.dataType) + + override protected def nullSafeEval(targetInput: Any, maskInput: Any): Any = { + val (targetTile, targetCtx) = targetTileExtractor(row(targetInput)) + val (mask, maskCtx) = maskTileExtractor(row(maskInput)) + + val result = targetTile.dualCombine(mask) + { (v, m) => if (isNoData(m)) NODATA else v } + { (v, m) => if (isNoData(m)) NODATA else v } + toInternalRow(result, targetCtx) + } +} + +object MaskByDefined { + def apply(targetTile: Column, maskTile: Column): TypedColumn[Any, Tile] = + new Column(MaskByDefined(targetTile.expr, maskTile.expr)).as[Tile] +} diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByValue.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByValue.scala new file mode 100644 index 000000000..eda992bdc --- /dev/null +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByValue.scala @@ -0,0 +1,92 @@ +/* + * This software is licensed under the Apache 2 license, quoted below. + * + * Copyright 2019 Astraea, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * [http://www.apache.org/licenses/LICENSE-2.0] + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + * + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.locationtech.rasterframes.expressions.transformers + +import geotrellis.raster.{NODATA, Tile, d2i} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.{Column, TypedColumn} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, TernaryExpression} +import org.apache.spark.sql.types.{DataType} +import org.locationtech.rasterframes.expressions.DynamicExtractors.{intArgExtractor, tileExtractor} +import org.locationtech.rasterframes.expressions.{RasterResult, row} +import org.locationtech.rasterframes.tileEncoder + + +@ExpressionDescription( + usage = "_FUNC_(target, mask, maskValue) - Generate a tile with the values from the data tile, but where cells in the masking tile contain the masking value, replace the data value with NODATA.", + arguments = """ + Arguments: + * target - tile to mask + * mask - masking definition + * maskValue - pixel value to consider as mask location when found in mask tile + """, + examples = """ + Examples: + > SELECT _FUNC_(target, mask, maskValue); + ...""" +) +case class MaskByValue(dataTile: Expression, maskTile: Expression, maskValue: Expression) + extends TernaryExpression + with CodegenFallback + with RasterResult { + override def nodeName: String = "rf_mask_by_value" + + def dataType: DataType = dataTile.dataType + def first: Expression = dataTile + def second: Expression = maskTile + def third: Expression = maskValue + + protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + MaskByValue(newFirst, newSecond, newThird) + + override def checkInputDataTypes(): TypeCheckResult = { + if (!tileExtractor.isDefinedAt(dataTile.dataType)) { + TypeCheckFailure(s"Input type '${dataTile.dataType}' does not conform to a raster type.") + } else if (!tileExtractor.isDefinedAt(maskTile.dataType)) { + TypeCheckFailure(s"Input type '${maskTile.dataType}' does not conform to a raster type.") + } else if (!intArgExtractor.isDefinedAt(maskValue.dataType)) { + TypeCheckFailure(s"Input type '${maskValue.dataType}' isn't an integral type.") + } else TypeCheckSuccess + } + + private lazy val dataTileExtractor = tileExtractor(dataTile.dataType) + private lazy val maskTileExtractor = tileExtractor(maskTile.dataType) + private lazy val maskValueExtractor = intArgExtractor(maskValue.dataType) + + override protected def nullSafeEval(targetInput: Any, maskInput: Any, maskValueInput: Any): Any = { + val (targetTile, targetCtx) = dataTileExtractor(row(targetInput)) + val (mask, maskCtx) = maskTileExtractor(row(maskInput)) + val maskValue = maskValueExtractor(maskValueInput).value + + val result = targetTile.dualCombine(mask) + { (v, m) => if (m == maskValue) NODATA else v } + { (v, m) => if (d2i(m) == maskValue) NODATA else v } + toInternalRow(result, targetCtx) + } +} + +object MaskByValue { + def apply(srcTile: Column, maskingTile: Column, maskValue: Column): TypedColumn[Any, Tile] = + new Column(MaskByValue(srcTile.expr, maskingTile.expr, maskValue.expr)).as[Tile] +} diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByValues.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByValues.scala new file mode 100644 index 000000000..39d9d9dd3 --- /dev/null +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByValues.scala @@ -0,0 +1,93 @@ +/* + * This software is licensed under the Apache 2 license, quoted below. + * + * Copyright 2019 Astraea, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * [http://www.apache.org/licenses/LICENSE-2.0] + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + * + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.locationtech.rasterframes.expressions.transformers + +import geotrellis.raster.{NODATA, Tile, d2i} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, TernaryExpression} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{Column, TypedColumn} +import org.locationtech.rasterframes.expressions.DynamicExtractors.{intArrayExtractor, tileExtractor} +import org.locationtech.rasterframes.expressions.{RasterResult, row} +import org.locationtech.rasterframes.tileEncoder + +@ExpressionDescription( + usage = + "_FUNC_(data, mask, maskValues) - Generate a tile with the values from `data` tile but where cells in the `mask` tile are in the `maskValues` list, replace the value with NODATA.", + arguments = """ + Arguments: + * target - tile to mask + * mask - masking definition + * maskValues - sequence of values to consider as masks candidates + """, + examples = """ + Examples: + > SELECT _FUNC_(data, mask, array(1, 2, 3)) + ...""" +) +case class MaskByValues(targetTile: Expression, maskTile: Expression, maskValues: Expression) + extends TernaryExpression + with CodegenFallback + with RasterResult { + override def nodeName: String = "rf_mask_by_values" + + def dataType: DataType = targetTile.dataType + def first: Expression = targetTile + def second: Expression = maskTile + def third: Expression = maskValues + + protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + MaskByValues(newFirst, newSecond, newThird) + + override def checkInputDataTypes(): TypeCheckResult = + if (!tileExtractor.isDefinedAt(targetTile.dataType)) { + TypeCheckFailure(s"Input type '${targetTile.dataType}' does not conform to a raster type.") + } else if (!tileExtractor.isDefinedAt(maskTile.dataType)) { + TypeCheckFailure(s"Input type '${maskTile.dataType}' does not conform to a raster type.") + } else if (!intArrayExtractor.isDefinedAt(maskValues.dataType)) { + TypeCheckFailure(s"Input type '${maskValues.dataType}' does not translate to an array.") + } else TypeCheckSuccess + + private lazy val targetTileExtractor = tileExtractor(targetTile.dataType) + private lazy val maskTileExtractor = tileExtractor(maskTile.dataType) + private lazy val maskValuesExtractor = intArrayExtractor(maskValues.dataType) + + override protected def nullSafeEval(targetInput: Any, maskInput: Any, maskValuesInput: Any): Any = { + val (targetTile, targetCtx) = targetTileExtractor(row(targetInput)) + val (mask, maskCtx) = maskTileExtractor(row(maskInput)) + val maskValues: Array[Int] = maskValuesExtractor(maskValuesInput.asInstanceOf[ArrayData]) + + val result = targetTile.dualCombine(mask) + { (v, m) => if (maskValues.contains(m)) NODATA else v } + { (v, m) => if (maskValues.contains(d2i(m))) NODATA else v } + + toInternalRow(result, targetCtx) + } +} + +object MaskByValues { + def apply(dataTile: Column, maskTile: Column, maskValues: Column): TypedColumn[Any, Tile] = + new Column(MaskByValues(dataTile.expr, maskTile.expr, maskValues.expr)).as[Tile] +} diff --git a/core/src/main/scala/org/locationtech/rasterframes/functions/LocalFunctions.scala b/core/src/main/scala/org/locationtech/rasterframes/functions/LocalFunctions.scala index 1388a82fb..1b066418f 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/functions/LocalFunctions.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/functions/LocalFunctions.scala @@ -130,13 +130,13 @@ trait LocalFunctions { /** Where the rf_mask tile contains NODATA, replace values in the source tile with NODATA */ def rf_mask(sourceTile: Column, maskTile: Column, inverse: Boolean = false): TypedColumn[Any, Tile] = - if (!inverse) Mask.MaskByDefined(sourceTile, maskTile) - else Mask.InverseMaskByDefined(sourceTile, maskTile) + if (!inverse) MaskByDefined(sourceTile, maskTile) + else InverseMaskByDefined(sourceTile, maskTile) /** Where the `maskTile` equals `maskValue`, replace values in the source tile with `NoData` */ def rf_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Column, inverse: Boolean = false): TypedColumn[Any, Tile] = - if (!inverse) Mask.MaskByValue(sourceTile, maskTile, maskValue) - else Mask.InverseMaskByValue(sourceTile, maskTile, maskValue) + if (!inverse) MaskByValue(sourceTile, maskTile, maskValue) + else InverseMaskByValue(sourceTile, maskTile, maskValue) /** Where the `maskTile` equals `maskValue`, replace values in the source tile with `NoData` */ def rf_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Int, inverse: Boolean): TypedColumn[Any, Tile] = @@ -149,7 +149,7 @@ trait LocalFunctions { /** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values` list, replace the value with NODATA. */ def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Column): TypedColumn[Any, Tile] = - Mask.MaskByValues(sourceTile, maskTile, maskValues) + MaskByValues(sourceTile, maskTile, maskValues) /** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values` list, replace the value with NODATA. */ @@ -161,15 +161,15 @@ trait LocalFunctions { /** Where the `maskTile` does **not** contain `NoData`, replace values in the source tile with `NoData` */ def rf_inverse_mask(sourceTile: Column, maskTile: Column): TypedColumn[Any, Tile] = - Mask.InverseMaskByDefined(sourceTile, maskTile) + InverseMaskByDefined(sourceTile, maskTile) /** Where the `maskTile` does **not** equal `maskValue`, replace values in the source tile with `NoData` */ def rf_inverse_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Column): TypedColumn[Any, Tile] = - Mask.InverseMaskByValue(sourceTile, maskTile, maskValue) + InverseMaskByValue(sourceTile, maskTile, maskValue) /** Where the `maskTile` does **not** equal `maskValue`, replace values in the source tile with `NoData` */ def rf_inverse_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Int): TypedColumn[Any, Tile] = - Mask.InverseMaskByValue(sourceTile, maskTile, lit(maskValue)) + InverseMaskByValue(sourceTile, maskTile, lit(maskValue)) /** Applies a mask using bit values in the `mask_tile`. Working from the right, extract the bit at `bitPosition` from the `maskTile`. In all locations where these are equal to the `valueToMask`, the returned tile is set to NoData, else the original `dataTile` cell value. */ def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Int, valueToMask: Boolean): TypedColumn[Any, Tile] = @@ -192,7 +192,11 @@ trait LocalFunctions { rf_mask_by_values(dataTile, bitMask, valuesToMask) } - /** Applies a mask from blacklisted bit values in the `mask_tile`. Working from the right, the bits from `start_bit` to `start_bit + num_bits` are @ref:[extracted](reference.md#rf_local_extract_bits) from cell values of the `mask_tile`. In all locations where these are in the `mask_values`, the returned tile is set to NoData; otherwise the original `tile` cell value is returned. */ + /** Applies a mask from blacklisted bit values in the `mask_tile`. + * Working from the right, the bits from `start_bit` to `start_bit + num_bits` are @ref:[extracted](reference.md#rf_local_extract_bits) from cell values of the `mask_tile`. + * In all locations where these are in the `mask_values`, the returned tile is set to NoData; + * otherwise the original `tile` cell value is returned. + **/ def rf_mask_by_bits(dataTile: Column, maskTile: Column, startBit: Int, numBits: Int, valuesToMask: Int*): TypedColumn[Any, Tile] = { import org.apache.spark.sql.functions.array val values = array(valuesToMask.map(lit): _*) diff --git a/core/src/test/scala/org/locationtech/rasterframes/functions/MaskingFunctionsSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/functions/MaskingFunctionsSpec.scala index b29ba29fa..a408cc057 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/functions/MaskingFunctionsSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/functions/MaskingFunctionsSpec.scala @@ -97,15 +97,13 @@ class MaskingFunctionsSpec extends TestEnvironment with RasterMatchers { checkDocs("rf_inverse_mask") } - it("should throw if no nodata"){ + it("should mask over no nodata"){ val noNoDataCellType = UByteCellType val df = Seq(Option(TestData.projectedRasterTile(5, 5, 42, TestData.extent, TestData.crs, noNoDataCellType))).toDF("tile") - an [IllegalArgumentException] should be thrownBy { - df.select(rf_mask($"tile", $"tile")).collect() - } + df.select(rf_mask($"tile", $"tile")) } } From 285e03d4f02dd3d447f0456052630f8a5e183cea Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Sat, 10 Dec 2022 21:25:14 -0500 Subject: [PATCH 16/34] Fix test: 6900 bit at position 4 is 1 -- expect NODATA after mask --- .../functions/LocalFunctions.scala | 10 ++++++++-- .../functions/MaskingFunctionsSpec.scala | 20 +++++++++---------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/locationtech/rasterframes/functions/LocalFunctions.scala b/core/src/main/scala/org/locationtech/rasterframes/functions/LocalFunctions.scala index 1b066418f..c4c8a21e0 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/functions/LocalFunctions.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/functions/LocalFunctions.scala @@ -175,13 +175,19 @@ trait LocalFunctions { def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Int, valueToMask: Boolean): TypedColumn[Any, Tile] = rf_mask_by_bit(dataTile, maskTile, lit(bitPosition), lit(if (valueToMask) 1 else 0)) - /** Applies a mask using bit values in the `mask_tile`. Working from the right, extract the bit at `bitPosition` from the `maskTile`. In all locations where these are equal to the `valueToMask`, the returned tile is set to NoData, else the original `dataTile` cell value. */ + /** Applies a mask using bit values in the `mask_tile`. Working from the right, extract the bit at `bitPosition` from the `maskTile`. + * In all locations where these are equal to the `valueToMask`, the returned tile is set to NoData, else the original `dataTile` cell value. + **/ def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Column, valueToMask: Column): TypedColumn[Any, Tile] = { import org.apache.spark.sql.functions.array rf_mask_by_bits(dataTile, maskTile, bitPosition, lit(1), array(valueToMask)) } - /** Applies a mask from blacklisted bit values in the `mask_tile`. Working from the right, the bits from `start_bit` to `start_bit + num_bits` are @ref:[extracted](reference.md#rf_local_extract_bits) from cell values of the `mask_tile`. In all locations where these are in the `mask_values`, the returned tile is set to NoData; otherwise the original `tile` cell value is returned. */ + /** Applies a mask from blacklisted bit values in the `mask_tile`. + * Working from the right, the bits from `start_bit` to `start_bit + num_bits` are @ref:[extracted](reference.md#rf_local_extract_bits) from cell values of the `mask_tile`. + * In all locations where these are in the `mask_values`, the returned tile is set to NoData; + * otherwise the original `tile` cell value is returned. + **/ def rf_mask_by_bits( dataTile: Column, maskTile: Column, diff --git a/core/src/test/scala/org/locationtech/rasterframes/functions/MaskingFunctionsSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/functions/MaskingFunctionsSpec.scala index a408cc057..f930c28e7 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/functions/MaskingFunctionsSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/functions/MaskingFunctionsSpec.scala @@ -349,8 +349,7 @@ class MaskingFunctionsSpec extends TestEnvironment with RasterMatchers { .withColumn("fill_no", rf_mask_by_bit($"data", $"mask", 0, true)) .withColumn("sat_0", rf_mask_by_bits($"data", $"mask", 2, 2, 1, 2, 3)) // strict no bands .withColumn("sat_2", rf_mask_by_bits($"data", $"mask", 2, 2, 2, 3)) // up to 2 bands contain sat - .withColumn("sat_4", - rf_mask_by_bits($"data", $"mask", lit(2), lit(2), array(lit(3)))) // up to 4 bands contain sat + .withColumn("sat_4", rf_mask_by_bits($"data", $"mask", lit(2), lit(2), array(lit(3)))) // up to 4 bands contain sat .withColumn("cloud_no", rf_mask_by_bit($"data", $"mask", lit(4), lit(true))) .withColumn("cloud_only", rf_mask_by_bit($"data", $"mask", 4, false)) // mask if *not* cloud .withColumn("cloud_conf_low", rf_mask_by_bits($"data", $"mask", lit(5), lit(2), array(lit(0), lit(1)))) @@ -360,14 +359,14 @@ class MaskingFunctionsSpec extends TestEnvironment with RasterMatchers { result.select(rf_cell_type($"fill_no")).first() should be (dataColumnCellType) def checker(columnName: String, maskValueFilter: Int, resultIsNoData: Boolean = true): Unit = { - /** in this unit test setup, the `val` column is an integer that the entire row's mask is full of - * filter for the maskValueFilter - * then check the columnName and look at the masked data tile given by `columnName` - * assert that the `columnName` tile is / is not all nodata based on `resultIsNoData` + /** in this unit test setup, the `val` column is an integer that the entire row's mask is full of + * - filter for the maskValueFilter + * - then check the columnName + * - look at the masked data tile given by `columnName` + * - assert that the `columnName` tile is / is not all nodata based on `resultIsNoData` * */ - val printOutcome = if (resultIsNoData) "all NoData cells" - else "all data cells" + val printOutcome = if (resultIsNoData) "all NoData cells" else "all data cells" logger.debug(s"${columnName} should contain ${printOutcome} for qa val ${maskValueFilter}") val resultDf = result @@ -380,13 +379,12 @@ class MaskingFunctionsSpec extends TestEnvironment with RasterMatchers { val dataTile = resultDf.select(col(columnName)).as[Option[ProjectedRasterTile]].first().get logger.debug(s"\tData tile values for col ${columnName}: ${dataTile.toArray().mkString(",")}") - resultToCheck should be (resultIsNoData) + resultToCheck should be(resultIsNoData) } - checker("fill_no", fill, true) checker("cloud_only", clear, true) checker("cloud_only", hi_cirrus, false) - checker("cloud_no", hi_cirrus, false) + checker("cloud_no", hi_cirrus, true) checker("sat_0", clear, false) checker("cloud_no", clear, false) checker("cloud_no", med_cloud, false) From 725c9d52b53044d1ee5c4745738eea93778cd0b3 Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Sun, 11 Dec 2022 08:10:10 -0500 Subject: [PATCH 17/34] TileRasterizerAggregate expects column in rf_raster_proj order This is a change but it's towards less surprising --- .../expressions/aggregates/TileRasterizerAggregate.scala | 6 +++--- .../rasterframes/functions/AggregateFunctions.scala | 2 +- .../org/locationtech/rasterframes/RasterJoinSpec.scala | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala index 58a54a8d1..b916ee301 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala @@ -20,7 +20,6 @@ */ package org.locationtech.rasterframes.expressions.aggregates - import geotrellis.layer._ import geotrellis.proj4.CRS import geotrellis.raster.reproject.Reproject @@ -48,6 +47,7 @@ class TileRasterizerAggregate(prd: ProjectedRasterDefinition) extends Aggregator override def zero: MutableArrayTile = ArrayTile.empty(prd.destinationCellType, prd.totalCols, prd.totalRows) override def reduce(b: Tile, a: ProjectedRasterTile): Tile = { + // TODO: this is not right, got to use dynamic reprojection for this extent val localExtent = a.extent.reproject(a.crs, prd.destinationCRS) if (prd.destinationExtent.intersects(localExtent)) { val localTile = a.tile.reproject(a.extent, a.crs, prd.destinationCRS, projOpts) @@ -81,13 +81,13 @@ object TileRasterizerAggregate { } } - def apply(prd: ProjectedRasterDefinition, crsCol: Column, extentCol: Column, tileCol: Column): TypedColumn[Any, Tile] = { + def apply(prd: ProjectedRasterDefinition, tileCol: Column, extentCol: Column, crsCol: Column): TypedColumn[Any, Tile] = { if (prd.totalCols.toDouble * prd.totalRows * 64.0 > Runtime.getRuntime.totalMemory() * 0.5) logger.warn( s"You've asked for the construction of a very large image (${prd.totalCols} x ${prd.totalRows}). Out of memory error likely.") udaf(new TileRasterizerAggregate(prd)) - .apply(crsCol, extentCol, tileCol) + .apply(tileCol, extentCol, crsCol) .as("rf_agg_overview_raster") .as[Tile] } diff --git a/core/src/main/scala/org/locationtech/rasterframes/functions/AggregateFunctions.scala b/core/src/main/scala/org/locationtech/rasterframes/functions/AggregateFunctions.scala index 13d8e13b6..a1f9af1f9 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/functions/AggregateFunctions.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/functions/AggregateFunctions.scala @@ -108,7 +108,7 @@ trait AggregateFunctions { */ def rf_agg_overview_raster(tile: Column, tileExtent: Column, tileCRS: Column, cols: Int, rows: Int, areaOfInterest: Extent, sampler: ResampleMethod): TypedColumn[Any, Tile] = { val params = ProjectedRasterDefinition(cols, rows, IntConstantNoDataCellType, WebMercator, areaOfInterest, sampler) - TileRasterizerAggregate(params, tileCRS, tileExtent, tile) + TileRasterizerAggregate(params, tile, tileExtent, tileCRS) } import org.apache.spark.sql.functions._ diff --git a/core/src/test/scala/org/locationtech/rasterframes/RasterJoinSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/RasterJoinSpec.scala index e57255ea0..b8a810fb5 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/RasterJoinSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/RasterJoinSpec.scala @@ -81,7 +81,7 @@ class RasterJoinSpec extends TestEnvironment with TestData with RasterMatchers { // create a Raster from tile2 which should be almost equal to b4nativeTif val agg = joined.agg(TileRasterizerAggregate( ProjectedRasterDefinition(b4nativeTif.cols, b4nativeTif.rows, b4nativeTif.cellType, b4nativeTif.crs, b4nativeTif.extent, Bilinear), - $"crs", $"extent", $"tile2") as "raster" + $"tile2", $"extent", $"crs") as "raster" ).select(col("raster").as[Tile]) val raster = Raster(agg.first(), srcExtent) From 91468b4639319c48fccc28007c87390e3d777abc Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Sun, 11 Dec 2022 09:29:27 -0500 Subject: [PATCH 18/34] Use spark-testing-base - core tests green - fixed weird init order in tests - all tests share same context now thanks to base - exclude scala-xml from tests --- build.sbt | 2 +- .../aggregates/TileRasterizerAggregate.scala | 2 +- .../rasterframes/BaseUdtSpec.scala | 3 - .../locationtech/rasterframes/CrsSpec.scala | 1 - .../rasterframes/GeometryFunctionsSpec.scala | 5 +- .../rasterframes/RasterFunctionsSpec.scala | 3 +- .../rasterframes/RasterJoinSpec.scala | 19 +++-- .../rasterframes/SpatialKeySpec.scala | 8 +-- .../rasterframes/StandardEncodersSpec.scala | 4 -- .../rasterframes/TestEnvironment.scala | 45 +++++++----- .../rasterframes/TileUDTSpec.scala | 2 - .../rasterframes/encoders/EncodingSpec.scala | 4 +- .../functions/AggregateFunctionsSpec.scala | 3 +- .../functions/FocalFunctionsSpec.scala | 3 +- .../functions/LocalFunctionsSpec.scala | 3 +- .../functions/MaskingFunctionsSpec.scala | 23 ++++-- .../functions/StatFunctionsSpec.scala | 70 ++++++++++++++++--- .../functions/TileFunctionsSpec.scala | 10 +-- .../geotrellis/GeoTrellisDataSourceSpec.scala | 6 +- .../raster/RasterSourceDataSourceSpec.scala | 38 +++++++++- .../slippy/SlippyDataSourceSpec.scala | 22 +++--- .../tiles/TilesDataSourceSpec.scala | 54 +++++--------- project/RFDependenciesPlugin.scala | 1 + 23 files changed, 204 insertions(+), 127 deletions(-) diff --git a/build.sbt b/build.sbt index ba9f00115..d7cd1b102 100644 --- a/build.sbt +++ b/build.sbt @@ -78,7 +78,7 @@ lazy val core = project ExclusionRule(organization = "com.github.mpilquist") ), scaffeine, - scalatest, + sparktestingbase excludeAll ExclusionRule("org.scala-lang.modules", "scala-xml_2.12"), `scala-logging` ), libraryDependencies ++= { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala index b916ee301..5bd914d41 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala @@ -137,7 +137,7 @@ object TileRasterizerAggregate { destExtent.map { ext => c.copy(destinationExtent = ext) } - val aggs = tileCols.map(t => TileRasterizerAggregate(config, rf_crs(crsCol), extCol, rf_tile(t)).as(t.columnName)) + val aggs = tileCols.map(t => TileRasterizerAggregate(config, rf_tile(t), extCol, rf_crs(crsCol)).as(t.columnName)) val agg = df.select(aggs: _*) diff --git a/core/src/test/scala/org/locationtech/rasterframes/BaseUdtSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/BaseUdtSpec.scala index ad5897ff4..ad61c972e 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/BaseUdtSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/BaseUdtSpec.scala @@ -27,8 +27,6 @@ import org.scalatest.Inspectors class BaseUdtSpec extends TestEnvironment with TestData with Inspectors { - spark.version - it("should (de)serialize CRS") { val udt = new CrsUDT() val in = geotrellis.proj4.LatLng @@ -37,6 +35,5 @@ class BaseUdtSpec extends TestEnvironment with TestData with Inspectors { out shouldBe in assert(out.isInstanceOf[LazyCRS]) info(out.toString()) - } } diff --git a/core/src/test/scala/org/locationtech/rasterframes/CrsSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/CrsSpec.scala index 888a87004..0b3d8c8c7 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/CrsSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/CrsSpec.scala @@ -28,7 +28,6 @@ import org.locationtech.rasterframes.ref.RFRasterSource import org.locationtech.rasterframes.ref.RasterRef class CrsSpec extends TestEnvironment with TestData with Inspectors { - spark.version import spark.implicits._ describe("CrsUDT") { diff --git a/core/src/test/scala/org/locationtech/rasterframes/GeometryFunctionsSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/GeometryFunctionsSpec.scala index 04573c9e5..8df91b6db 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/GeometryFunctionsSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/GeometryFunctionsSpec.scala @@ -32,10 +32,8 @@ import org.locationtech.jts.geom.{Coordinate, GeometryFactory} * @since 12/16/17 */ class GeometryFunctionsSpec extends TestEnvironment with TestData with StandardColumns { - import spark.implicits._ - describe("Vector geometry operations") { - val rf = l8Sample(1).projectedRaster.toLayer(10, 10).withGeometry() + lazy val rf = l8Sample(1).projectedRaster.toLayer(10, 10).withGeometry() it("should allow joining and filtering of tiles based on points") { import spark.implicits._ @@ -136,6 +134,7 @@ class GeometryFunctionsSpec extends TestEnvironment with TestData with StandardC } it("should rasterize geometry") { + import spark.implicits._ val rf = l8Sample(1).projectedRaster.toLayer.withGeometry() val df = GeomData.features.map(f => ( f.geom.reproject(LatLng, rf.crs), diff --git a/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala index 2e9987b99..0a2cfeb00 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala @@ -22,11 +22,10 @@ package org.locationtech.rasterframes import geotrellis.raster._ -import geotrellis.raster.testkit.RasterMatchers import org.apache.spark.sql.functions._ import org.locationtech.rasterframes.tiles.ProjectedRasterTile -class RasterFunctionsSpec extends TestEnvironment with RasterMatchers { +class RasterFunctionsSpec extends TestEnvironment { import TestData._ import spark.implicits._ diff --git a/core/src/test/scala/org/locationtech/rasterframes/RasterJoinSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/RasterJoinSpec.scala index b8a810fb5..beae2909c 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/RasterJoinSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/RasterJoinSpec.scala @@ -23,7 +23,6 @@ package org.locationtech.rasterframes import geotrellis.proj4.CRS import geotrellis.raster.resample._ -import geotrellis.raster.testkit.RasterMatchers import geotrellis.raster.{Dimensions, IntConstantNoDataCellType, Raster, Tile} import geotrellis.vector.Extent import org.apache.spark.SparkConf @@ -32,18 +31,18 @@ import org.locationtech.rasterframes.expressions.aggregates.TileRasterizerAggreg import org.locationtech.rasterframes.expressions.aggregates.TileRasterizerAggregate.ProjectedRasterDefinition -class RasterJoinSpec extends TestEnvironment with TestData with RasterMatchers { - import spark.implicits._ +class RasterJoinSpec extends TestEnvironment with TestData { describe("Raster join between two DataFrames") { val b4nativeTif = readSingleband("L8-B4-Elkton-VA.tiff") // Same data, reprojected to EPSG:4326 val b4warpedTif = readSingleband("L8-B4-Elkton-VA-4326.tiff") - val b4nativeRf = b4nativeTif.toDF(Dimensions(10, 10)) - val b4warpedRf = b4warpedTif.toDF(Dimensions(10, 10)) + lazy val b4nativeRf = b4nativeTif.toDF(Dimensions(10, 10)) + lazy val b4warpedRf = b4warpedTif.toDF(Dimensions(10, 10)) .withColumnRenamed("tile", "tile2") it("should join the same scene correctly") { + import spark.implicits._ val b4nativeRfPrime = b4nativeTif.toDF(Dimensions(10, 10)) .withColumnRenamed("tile", "tile2") val joined = b4nativeRf.rasterJoin(b4nativeRfPrime.hint("broadcast")) @@ -59,6 +58,7 @@ class RasterJoinSpec extends TestEnvironment with TestData with RasterMatchers { } it("should join same scene in different tile sizes"){ + import spark.implicits._ val r1prime = b4nativeTif.toDF(Dimensions(25, 25)).withColumnRenamed("tile", "tile2") r1prime.select(rf_dimensions($"tile2").getField("rows")).as[Int].first() should be (25) val joined = b4nativeRf.rasterJoin(r1prime) @@ -75,6 +75,7 @@ class RasterJoinSpec extends TestEnvironment with TestData with RasterMatchers { } it("should join same scene in two projections, same tile size") { + import spark.implicits._ val srcExtent = b4nativeTif.extent // b4warpedRf source data is gdal warped b4nativeRf data; join them together. val joined = b4nativeRf.rasterJoin(b4warpedRf) @@ -112,6 +113,7 @@ class RasterJoinSpec extends TestEnvironment with TestData with RasterMatchers { } it("should join multiple RHS tile columns"){ + import spark.implicits._ // join multiple native CRS bands to the EPSG 4326 RF val multibandRf = b4nativeRf @@ -126,6 +128,7 @@ class RasterJoinSpec extends TestEnvironment with TestData with RasterMatchers { } it("should join with heterogeneous LHS CRS and coverages"){ + import spark.implicits._ val df17 = readSingleband("m_3607824_se_17_1_20160620_subset.tif") .toDF(Dimensions(50, 50)) @@ -165,6 +168,7 @@ class RasterJoinSpec extends TestEnvironment with TestData with RasterMatchers { } it("should handle proj_raster types") { + import spark.implicits._ val df1 = Seq(Option(one)).toDF("one") val df2 = Seq(Option(two)).toDF("two") noException shouldBe thrownBy { @@ -174,6 +178,7 @@ class RasterJoinSpec extends TestEnvironment with TestData with RasterMatchers { } it("should raster join multiple times on projected raster"){ + import spark.implicits._ val df0 = Seq(Option(one)).toDF("proj_raster") val result = df0.select($"proj_raster" as "t1") .rasterJoin(df0.select($"proj_raster" as "t2")) @@ -184,6 +189,7 @@ class RasterJoinSpec extends TestEnvironment with TestData with RasterMatchers { } it("should honor resampling options") { + import spark.implicits._ // test case. replicate existing test condition and check that resampling option results in different output val filterExpr = st_intersects(rf_geometry($"tile"), st_point(704940.0, 4251130.0)) val result = b4nativeRf.rasterJoin(b4warpedRf.withColumnRenamed("tile2", "nearest"), NearestNeighbor) @@ -200,6 +206,7 @@ class RasterJoinSpec extends TestEnvironment with TestData with RasterMatchers { // Failed to execute user defined function(package$$$Lambda$4417/0x00000008019e2840: (struct, string, array,bandIndex:int,subextent:struct,subgrid:struct>>>, array>, array, struct, string) => struct,bandIndex:int,subextent:struct,subgrid:struct>>) it("should raster join with null left head") { + import spark.implicits._ // https://github.com/locationtech/rasterframes/issues/462 val prt = TestData.projectedRasterTile( 10, 10, 1, @@ -264,5 +271,5 @@ class RasterJoinSpec extends TestEnvironment with TestData with RasterMatchers { } - override def additionalConf: SparkConf = super.additionalConf.set("spark.sql.codegen.comments", "true") + override def additionalConf(conf: SparkConf) = conf.set("spark.sql.codegen.comments", "true") } diff --git a/core/src/test/scala/org/locationtech/rasterframes/SpatialKeySpec.scala b/core/src/test/scala/org/locationtech/rasterframes/SpatialKeySpec.scala index cd38d7791..ca76992e4 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/SpatialKeySpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/SpatialKeySpec.scala @@ -31,14 +31,10 @@ import org.locationtech.geomesa.curve.Z2SFC * @since 12/15/17 */ class SpatialKeySpec extends TestEnvironment with TestData { - assert(!spark.sparkContext.isStopped) - - import spark.implicits._ - describe("Spatial key conversions") { val raster = sampleGeoTiff.projectedRaster // Create a raster frame with a single row - val rf = raster.toLayer(raster.tile.cols, raster.tile.rows) + lazy val rf = raster.toLayer(raster.tile.cols, raster.tile.rows) it("should add an extent column") { val expected = raster.extent.toPolygon() @@ -53,12 +49,14 @@ class SpatialKeySpec extends TestEnvironment with TestData { } it("should add a center lat/lng value") { + import spark.implicits._ val expected = raster.extent.center.reproject(raster.crs, LatLng) val result = rf.withCenterLatLng().select($"center".as[(Double, Double)]).first assert( Point(result._1, result._2) === expected) } it("should add a z-index value") { + import spark.implicits._ val center = raster.extent.center.reproject(raster.crs, LatLng) val expected = Z2SFC.index(center.x, center.y) val result = rf.withSpatialIndex().select($"spatial_index".as[Long]).first diff --git a/core/src/test/scala/org/locationtech/rasterframes/StandardEncodersSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/StandardEncodersSpec.scala index a2cbad0b7..a2fe6f057 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/StandardEncodersSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/StandardEncodersSpec.scala @@ -37,7 +37,6 @@ import org.scalatest.Inspectors class StandardEncodersSpec extends TestEnvironment with TestData with Inspectors { it("Dimensions encoder") { - spark.version import spark.implicits._ val data = Dimensions[Int](256, 256) val df = List(data).toDF() @@ -47,7 +46,6 @@ class StandardEncodersSpec extends TestEnvironment with TestData with Inspectors } it("TileDataContext encoder") { - spark.version import spark.implicits._ val data = TileDataContext(IntCellType, Dimensions[Int](256, 256)) val df = List(data).toDF() @@ -57,7 +55,6 @@ class StandardEncodersSpec extends TestEnvironment with TestData with Inspectors } it("ProjectedExtent encoder") { - spark.version import spark.implicits._ val data = ProjectedExtent(Extent(0, 0, 1, 1), LatLng) val df = List(data).toDF() @@ -68,7 +65,6 @@ class StandardEncodersSpec extends TestEnvironment with TestData with Inspectors } it("TileLayerMetadata encoder"){ - spark.version import spark.implicits._ val data = TileLayerMetadata( IntCellType, diff --git a/core/src/test/scala/org/locationtech/rasterframes/TestEnvironment.scala b/core/src/test/scala/org/locationtech/rasterframes/TestEnvironment.scala index 953881cbd..ad21b18bc 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/TestEnvironment.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/TestEnvironment.scala @@ -20,8 +20,9 @@ */ package org.locationtech.rasterframes -import java.nio.file.{Files, Path} +import com.holdenkarau.spark.testing.DataFrameSuiteBase +import java.nio.file.{Files, Path} import com.typesafe.scalalogging.Logger import geotrellis.raster.Tile import geotrellis.raster.render.{ColorMap, ColorRamps} @@ -39,11 +40,10 @@ import org.scalactic.Tolerance import org.scalatest._ import org.scalatest.funspec.AnyFunSpec import org.scalatest.matchers.should.Matchers - import org.scalatest.matchers.{MatchResult, Matcher} import org.slf4j.LoggerFactory -trait TestEnvironment extends AnyFunSpec with Matchers with Inspectors with Tolerance with RasterMatchers { +trait TestEnvironment extends AnyFunSpec with DataFrameSuiteBase with Matchers with RasterMatchers with Inspectors with Tolerance { @transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName)) @@ -56,22 +56,29 @@ trait TestEnvironment extends AnyFunSpec with Matchers with Inspectors with Tole // allow 2 retries, should stabilize CI builds. https://spark.apache.org/docs/2.4.7/submitting-applications.html#master-urls def sparkMaster: String = "local[*, 2]" - def additionalConf: SparkConf = - new SparkConf(false) - .set("spark.driver.port", "0") - .set("spark.hostPort", "0") - .set("spark.ui.enabled", "false") - - implicit val spark: SparkSession = - SparkSession - .builder - .master(sparkMaster) - .withKryoSerialization - .config(additionalConf) - .getOrCreate() - .withRasterFrames - - implicit def sc: SparkContext = spark.sparkContext + protected def additionalConf(conf: SparkConf): SparkConf = conf + + override def conf: SparkConf = { + val base = new SparkConf(). + setAppName("RasterFrames Test"). + setMaster(sparkMaster). + set("spark.serializer", "org.apache.spark.serializer.KryoSerializer"). + set("spark.kryo.registrator", "org.locationtech.rasterframes.util.RFKryoRegistrator"). + set("spark.ui.enabled", "false"). + set("spark.driver.port", "0"). + set("spark.hostPort", "0"). + set("spark.ui.enabled", "true") + additionalConf(base) + } + + override def setup(sc: SparkContext): Unit = { + sc.setCheckpointDir(com.holdenkarau.spark.testing.Utils.createTempDir().toPath().toString) + sc.setLogLevel("ERROR") + org.locationtech.rasterframes.initRF(sqlContext) + } + + implicit def sparkSession: SparkSession = spark + implicit def sparkContext: SparkContext = spark.sparkContext lazy val sql: String => DataFrame = spark.sql diff --git a/core/src/test/scala/org/locationtech/rasterframes/TileUDTSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/TileUDTSpec.scala index d2ae04559..0d1f2d6d5 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/TileUDTSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/TileUDTSpec.scala @@ -35,8 +35,6 @@ import org.scalatest.Inspectors class TileUDTSpec extends TestEnvironment with TestData with Inspectors { import TestData.randomTile - spark.version - describe("TileUDT") { val tileSizes = Seq(2, 7, 64, 128, 511) val ct = functions.cellTypes().filter(_ != "bool") diff --git a/core/src/test/scala/org/locationtech/rasterframes/encoders/EncodingSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/encoders/EncodingSpec.scala index 95fc4fb41..cf638a6ca 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/encoders/EncodingSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/encoders/EncodingSpec.scala @@ -196,7 +196,7 @@ class EncodingSpec extends TestEnvironment with TestData { } } - override def additionalConf: SparkConf = { - super.additionalConf.set("spark.sql.codegen.logging.maxLines", Int.MaxValue.toString) + override def additionalConf(conf: SparkConf) = { + conf.set("spark.sql.codegen.logging.maxLines", Int.MaxValue.toString) } } diff --git a/core/src/test/scala/org/locationtech/rasterframes/functions/AggregateFunctionsSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/functions/AggregateFunctionsSpec.scala index 7e5049da2..64500c018 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/functions/AggregateFunctionsSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/functions/AggregateFunctionsSpec.scala @@ -25,7 +25,6 @@ import geotrellis.proj4.{CRS, WebMercator} import geotrellis.raster._ import geotrellis.raster.render.Png import geotrellis.raster.resample.Bilinear -import geotrellis.raster.testkit.RasterMatchers import geotrellis.vector.Extent import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -36,7 +35,7 @@ import org.locationtech.rasterframes.encoders.StandardEncoders import org.locationtech.rasterframes.stats._ import org.locationtech.rasterframes.tiles.ProjectedRasterTile -class AggregateFunctionsSpec extends TestEnvironment with RasterMatchers { +class AggregateFunctionsSpec extends TestEnvironment { import spark.implicits._ describe("aggregate statistics") { diff --git a/core/src/test/scala/org/locationtech/rasterframes/functions/FocalFunctionsSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/functions/FocalFunctionsSpec.scala index 9ec4e46dc..6e5ac9ee5 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/functions/FocalFunctionsSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/functions/FocalFunctionsSpec.scala @@ -23,7 +23,6 @@ package org.locationtech.rasterframes.functions import geotrellis.raster.mapalgebra.focal.{Circle, Kernel, Square} import geotrellis.raster.{BufferTile, CellSize} -import geotrellis.raster.testkit.RasterMatchers import org.locationtech.rasterframes.ref.{RFRasterSource, RasterRef, Subgrid} import org.locationtech.rasterframes.tiles.ProjectedRasterTile import org.locationtech.rasterframes._ @@ -33,7 +32,7 @@ import org.locationtech.rasterframes.encoders.serialized_literal import java.nio.file.Paths -class FocalFunctionsSpec extends TestEnvironment with RasterMatchers { +class FocalFunctionsSpec extends TestEnvironment { import spark.implicits._ diff --git a/core/src/test/scala/org/locationtech/rasterframes/functions/LocalFunctionsSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/functions/LocalFunctionsSpec.scala index ee8940b61..c9ae3eeee 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/functions/LocalFunctionsSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/functions/LocalFunctionsSpec.scala @@ -23,13 +23,12 @@ package org.locationtech.rasterframes.functions import org.locationtech.rasterframes.TestEnvironment import geotrellis.raster._ -import geotrellis.raster.testkit.RasterMatchers import org.apache.spark.sql.functions._ import org.locationtech.rasterframes.expressions.accessors.ExtractTile import org.locationtech.rasterframes.tiles.ProjectedRasterTile import org.locationtech.rasterframes._ -class LocalFunctionsSpec extends TestEnvironment with RasterMatchers { +class LocalFunctionsSpec extends TestEnvironment { import TestData._ import spark.implicits._ diff --git a/core/src/test/scala/org/locationtech/rasterframes/functions/MaskingFunctionsSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/functions/MaskingFunctionsSpec.scala index f930c28e7..8d6f94314 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/functions/MaskingFunctionsSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/functions/MaskingFunctionsSpec.scala @@ -22,19 +22,17 @@ package org.locationtech.rasterframes.functions import geotrellis.raster._ -import geotrellis.raster.testkit.RasterMatchers import org.apache.spark.sql.functions._ import org.locationtech.rasterframes._ import org.locationtech.rasterframes.tiles.ProjectedRasterTile -class MaskingFunctionsSpec extends TestEnvironment with RasterMatchers { +class MaskingFunctionsSpec extends TestEnvironment { import TestData._ - import spark.implicits._ describe("masking by defined") { - spark.version it("should mask one tile against another") { + import spark.implicits._ val df = Seq[Tile](randPRT).toDF("tile") val withMask = df.withColumn("mask", @@ -54,6 +52,7 @@ class MaskingFunctionsSpec extends TestEnvironment with RasterMatchers { } it("should mask with expected results") { + import spark.implicits._ val df = Seq((byteArrayTile, maskingTile)).toDF("tile", "mask") val withMasked = df.withColumn("masked", @@ -65,6 +64,7 @@ class MaskingFunctionsSpec extends TestEnvironment with RasterMatchers { } it("should mask without mutating cell type") { + import spark.implicits._ val result = Seq((byteArrayTile, maskingTile)) .toDF("tile", "mask") .select(rf_mask($"tile", $"mask").as("masked_tile")) @@ -75,6 +75,7 @@ class MaskingFunctionsSpec extends TestEnvironment with RasterMatchers { } it("should inverse mask one tile against another") { + import spark.implicits._ val df = Seq[Tile](randPRT).toDF("tile") val baseND = df.select(rf_agg_no_data_cells($"tile")).first() @@ -98,6 +99,7 @@ class MaskingFunctionsSpec extends TestEnvironment with RasterMatchers { } it("should mask over no nodata"){ + import spark.implicits._ val noNoDataCellType = UByteCellType val df = @@ -111,6 +113,7 @@ class MaskingFunctionsSpec extends TestEnvironment with RasterMatchers { describe("mask by value") { it("should mask tile by another identified by specified value") { + import spark.implicits._ val df = Seq[Tile](randPRT).toDF("tile") val mask_value = 4 @@ -132,6 +135,7 @@ class MaskingFunctionsSpec extends TestEnvironment with RasterMatchers { } it("should mask by value for value 0.") { + import spark.implicits._ // maskingTile has -4, ND, and -15 values. Expect mask by value with 0 to not change the val df = Seq((byteArrayTile, maskingTile)).toDF("data", "mask") @@ -151,6 +155,7 @@ class MaskingFunctionsSpec extends TestEnvironment with RasterMatchers { } it("should inverse mask tile by another identified by specified value") { + import spark.implicits._ val df = Seq[Tile](randPRT).toDF("tile") val mask_value = 4 @@ -177,6 +182,7 @@ class MaskingFunctionsSpec extends TestEnvironment with RasterMatchers { } it("should mask tile by another identified by sequence of specified values") { + import spark.implicits._ val squareIncrementingPRT = ProjectedRasterTile(squareIncrementingTile(six.rows), six.extent, six.crs) val df = Seq((six, squareIncrementingPRT)) .toDF("tile", "mask") @@ -218,10 +224,13 @@ class MaskingFunctionsSpec extends TestEnvironment with RasterMatchers { ) } - val df = tiles.toDF("data", "mask") - .withColumn("val", rf_tile_min($"mask")) + lazy val df = { + import spark.implicits._ + tiles.toDF("data", "mask").withColumn("val", rf_tile_min(col("mask"))) + } it("should give LHS cell type"){ + import spark.implicits._ val resultMask = df.select( rf_cell_type( rf_mask($"data", $"mask") @@ -262,6 +271,7 @@ class MaskingFunctionsSpec extends TestEnvironment with RasterMatchers { it("should unpack QA bits"){ + import spark.implicits._ checkDocs("rf_local_extract_bits") val result = df @@ -345,6 +355,7 @@ class MaskingFunctionsSpec extends TestEnvironment with RasterMatchers { } it("should mask by QA bits"){ + import spark.implicits._ val result = df .withColumn("fill_no", rf_mask_by_bit($"data", $"mask", 0, true)) .withColumn("sat_0", rf_mask_by_bits($"data", $"mask", 2, 2, 1, 2, 3)) // strict no bands diff --git a/core/src/test/scala/org/locationtech/rasterframes/functions/StatFunctionsSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/functions/StatFunctionsSpec.scala index cebc6d938..a7f01b6ca 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/functions/StatFunctionsSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/functions/StatFunctionsSpec.scala @@ -32,16 +32,17 @@ import org.locationtech.rasterframes.stats._ import org.locationtech.rasterframes.util.DataBiasedOp._ class StatFunctionsSpec extends TestEnvironment with TestData { - import spark.implicits._ - val df = TestData.sampleGeoTiff - .toDF() - .withColumn("tilePlus2", rf_local_add(col("tile"), 2)) + lazy val df = { + TestData.sampleGeoTiff.toDF().withColumn("tilePlus2", rf_local_add(col("tile"), 2)) + } describe("Tile quantiles through built-in functions") { it("should compute approx percentiles for a single tile col") { + import spark.implicits._ + // Use "explode" val result = df .select(rf_explode_tiles($"tile")) @@ -68,6 +69,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { describe("Tile quantiles through custom aggregate") { it("should compute approx percentiles for a single tile col") { + import spark.implicits._ + val result = df .select(rf_agg_approx_quantiles($"tile", Seq(0.10, 0.50, 0.90), 0.0000001)) .first() @@ -81,6 +84,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { describe("per-tile stats") { it("should compute data cell counts") { + import spark.implicits._ + val df = Seq(Option(TestData.injectND(numND)(two))).toDF("two") df.select(rf_data_cells($"two")).first() shouldBe (cols * rows - numND).toLong @@ -94,6 +99,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { checkDocs("rf_data_cells") } it("should compute no-data cell counts") { + import spark.implicits._ + val df = Seq(Option(TestData.injectND(numND)(two))).toDF("two") df.select(rf_no_data_cells($"two")).first() should be(numND) @@ -108,6 +115,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { } it("should properly count data and nodata cells on constant tiles") { + import spark.implicits._ + val rf = Seq(Option(randPRT)).toDF("tile") val df = rf @@ -128,6 +137,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { } it("should detect no-data tiles") { + import spark.implicits._ + val df = Seq(Option(nd)).toDF("nd") df.select(rf_is_no_data_tile($"nd")).first() should be(true) val df2 = Seq(Option(two)).toDF("not_nd") @@ -136,6 +147,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { } it("should evaluate exists and for_all") { + import spark.implicits._ + val df0 = Seq(Option(zero)).toDF("tile") df0.select(rf_exists($"tile")).first() should be(false) df0.select(rf_for_all($"tile")).first() should be(false) @@ -152,6 +165,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { } it("should check values is_in") { + import spark.implicits._ + checkDocs("rf_local_is_in") // tile is 3 by 3 with values, 1 to 9 @@ -177,6 +192,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { e0Result.toArray() should contain only (0) } it("should find the minimum cell value") { + import spark.implicits._ + val min = randNDPRT.toArray().filter(c => isData(c)).min.toDouble val df = Seq(randNDPRT).toDF("rand") df.select(rf_tile_min($"rand")).first() should be(min) @@ -185,6 +202,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { } it("should find the maximum cell value") { + import spark.implicits._ + val max = randNDPRT.toArray().filter(c => isData(c)).max.toDouble val df = Seq(randNDPRT).toDF("rand") df.select(rf_tile_max($"rand")).first() should be(max) @@ -192,6 +211,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { checkDocs("rf_tile_max") } it("should compute the tile mean cell value") { + import spark.implicits._ + val values = randNDPRT.toArray().filter(c => isData(c)) val mean = values.sum.toDouble / values.length val df = Seq(Option(randNDPRT)).toDF("rand") @@ -201,6 +222,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { } it("should compute the tile summary statistics") { + import spark.implicits._ + val values = randNDPRT.toArray().filter(c => isData(c)) val mean = values.sum.toDouble / values.length val df = Seq(Option(randNDPRT)).toDF("rand") @@ -233,6 +256,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { } it("should compute the tile histogram") { + import spark.implicits._ + val df = Seq(Option(randNDPRT)).toDF("rand") val h1 = df.select(rf_tile_histogram($"rand")).first() @@ -250,6 +275,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { describe("computing statistics over tiles") { //import org.apache.spark.sql.execution.debug._ it("should report dimensions") { + import spark.implicits._ + val df = Seq[(Tile, Tile)]((byteArrayTile, byteArrayTile)).toDF("tile1", "tile2") val dims = df.select(rf_dimensions($"tile1") as "dims").select("dims.*") @@ -273,6 +300,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { } it("should report cell type") { + import spark.implicits._ + val ct = functions.cellTypes().filter(_ != "bool") forEvery(ct) { c => val expected = CellType.fromName(c) @@ -288,6 +317,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { val tile3 = randomTile(255, 255, IntCellType) it("should compute accurate item counts") { + import spark.implicits._ + val ds = Seq[Option[Tile]](Option(tile1), Option(tile2), Option(tile3)).toDF("tiles") val checkedValues = Seq[Double](0, 4, 7, 13, 26) val result = checkedValues.map(x => ds.select(rf_tile_histogram($"tiles")).first().itemCount(x)) @@ -297,6 +328,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { } it("Should compute quantiles") { + import spark.implicits._ + val ds = Seq[Option[Tile]](Option(tile1), Option(tile2), Option(tile3)).toDF("tiles") val numBreaks = 5 val breaks = ds.select(rf_tile_histogram($"tiles")).map(_.quantileBreaks(numBreaks)).collect() @@ -355,6 +388,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { } it("should compute per-tile histogram") { + import spark.implicits._ + val ds = Seq.fill[Option[Tile]](3)(Option(randomTile(5, 5, FloatCellType))).toDF("tiles") ds.createOrReplaceTempView("tmp") @@ -382,6 +417,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { } it("should compute aggregate histogram") { + import spark.implicits._ + val tileSize = 5 val rows = 10 val ds = Seq @@ -406,6 +443,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { } it("should compute aggregate mean") { + import spark.implicits._ + val ds = (Seq.fill[Tile](10)(randomTile(5, 5, FloatCellType)) :+ null).toDF("tiles") val agg = ds.select(rf_agg_mean($"tiles")) val stats = ds.select(rf_agg_stats($"tiles") as "stats").select($"stats.mean".as[Double]) @@ -413,6 +452,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { } it("should compute aggregate statistics") { + import spark.implicits._ + val ds = Seq.fill[Tile](10)(randomTile(5, 5, FloatConstantNoDataCellType)).toDF("tiles") val exploded = ds.select(rf_explode_tiles($"tiles")) @@ -473,6 +514,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { } it("should compute accurate statistics") { + import spark.implicits._ + val completeTile = squareIncrementingTile(4).convert(IntConstantNoDataCellType) val incompleteTile = injectND(2)(completeTile) @@ -506,13 +549,18 @@ class StatFunctionsSpec extends TestEnvironment with TestData { val tsize = 5 val count = 20 val nds = 2 - val tiles = (Seq - .fill[Tile](count)(randomTile(tsize, tsize, UByteUserDefinedNoDataCellType(255.toByte))) - .map(injectND(nds)) :+ null) - .map(Option.apply) - .toDF("tiles") + lazy val tiles = { + import spark.implicits._ + (Seq + .fill[Tile](count)(randomTile(tsize, tsize, UByteUserDefinedNoDataCellType(255.toByte))) + .map(injectND(nds)) :+ null) + .map(Option.apply) + .toDF("tiles") + } it("should count cells by NoData state") { + import spark.implicits._ + val counts = tiles.select(rf_no_data_cells($"tiles")).collect().dropRight(1) forEvery(counts)(c => assert(c === nds)) val counts2 = tiles.select(rf_data_cells($"tiles")).collect().dropRight(1) @@ -520,6 +568,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { } it("should detect all NoData tiles") { + import spark.implicits._ + val ndCount = tiles.select("*").where(rf_is_no_data_tile($"tiles")).count() ndCount should be(1) @@ -584,6 +634,8 @@ class StatFunctionsSpec extends TestEnvironment with TestData { describe("proj_raster handling") { it("should handle proj_raster structures") { + import spark.implicits._ + val df = Seq(lazyPRT, lazyPRT).map(Option(_)).toDF("tile") val targets = Seq[Column => Column]( diff --git a/core/src/test/scala/org/locationtech/rasterframes/functions/TileFunctionsSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/functions/TileFunctionsSpec.scala index 94754a15b..8a6ea895e 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/functions/TileFunctionsSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/functions/TileFunctionsSpec.scala @@ -21,18 +21,18 @@ package org.locationtech.rasterframes.functions import java.io.ByteArrayInputStream - import geotrellis.raster._ -import geotrellis.raster.testkit.RasterMatchers + import javax.imageio.ImageIO import org.apache.spark.sql.Encoders -import org.apache.spark.sql.functions.{count, sum, isnull} +import org.apache.spark.sql.functions.{count, isnull, sum} import org.locationtech.rasterframes._ import org.locationtech.rasterframes.ref.RasterRef import org.locationtech.rasterframes.tiles.ProjectedRasterTile import org.locationtech.rasterframes.util.ColorRampNames +import org.scalatest.Assertions -class TileFunctionsSpec extends TestEnvironment with RasterMatchers { +class TileFunctionsSpec extends TestEnvironment { import TestData._ import spark.implicits._ @@ -469,7 +469,7 @@ class TileFunctionsSpec extends TestEnvironment with RasterMatchers { it("should convert names to ColorRamps") { forEvery(ColorRampNames()) { case ColorRampNames(ramp) => ramp.numStops should be > (0) - case o => fail(s"Expected $o to convert to color ramp") + case o => (this: Assertions).fail(s"Expected $o to convert to color ramp") } } it("should return None on unrecognized names") { diff --git a/datasource/src/test/scala/org/locationtech/rasterframes/datasource/geotrellis/GeoTrellisDataSourceSpec.scala b/datasource/src/test/scala/org/locationtech/rasterframes/datasource/geotrellis/GeoTrellisDataSourceSpec.scala index 486d8122c..5a3039c43 100644 --- a/datasource/src/test/scala/org/locationtech/rasterframes/datasource/geotrellis/GeoTrellisDataSourceSpec.scala +++ b/datasource/src/test/scala/org/locationtech/rasterframes/datasource/geotrellis/GeoTrellisDataSourceSpec.scala @@ -31,7 +31,6 @@ import org.locationtech.rasterframes.util._ import geotrellis.proj4.LatLng import geotrellis.raster._ import geotrellis.raster.resample.NearestNeighbor -import geotrellis.raster.testkit.RasterMatchers import geotrellis.spark._ import geotrellis.spark.store.LayerWriter import geotrellis.store._ @@ -49,7 +48,7 @@ import org.scalatest.{BeforeAndAfterAll, Inspectors} import scala.math.{max, min} -class GeoTrellisDataSourceSpec extends TestEnvironment with BeforeAndAfterAll with Inspectors with RasterMatchers with DataSourceOptions { +class GeoTrellisDataSourceSpec extends TestEnvironment with BeforeAndAfterAll with Inspectors with DataSourceOptions { import TestData._ val tileSize = 12 @@ -84,6 +83,7 @@ class GeoTrellisDataSourceSpec extends TestEnvironment with BeforeAndAfterAll wi } override def beforeAll = { + super.beforeAll() val outputDir = new File(layer.base) FileUtil.fullyDelete(outputDir) outputDir.deleteOnExit() @@ -279,7 +279,7 @@ class GeoTrellisDataSourceSpec extends TestEnvironment with BeforeAndAfterAll wi min(pt1.y, pt2.y), max(pt1.y, pt2.y) ) - val targetKey = testRdd.metadata.mapTransform(pt1) + lazy val targetKey = testRdd.metadata.mapTransform(pt1) it("should support extent against a geometry literal") { val df: DataFrame = layerReader diff --git a/datasource/src/test/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSourceSpec.scala b/datasource/src/test/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSourceSpec.scala index 1ab0ffa6f..6dab928af 100644 --- a/datasource/src/test/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSourceSpec.scala +++ b/datasource/src/test/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSourceSpec.scala @@ -31,7 +31,6 @@ import org.scalatest.BeforeAndAfter import org.locationtech.rasterframes.ref.RasterRef class RasterSourceDataSourceSpec extends TestEnvironment with TestData with BeforeAndAfter { - import spark.implicits._ describe("DataSource parameter processing") { def singleCol(paths: Iterable[String]) = { @@ -109,6 +108,8 @@ class RasterSourceDataSourceSpec extends TestEnvironment with TestData with Befo } it("should read a multiband file") { + import spark.implicits._ + val df = spark.read .raster .withBandIndexes(0, 1, 2) @@ -122,7 +123,10 @@ class RasterSourceDataSourceSpec extends TestEnvironment with TestData with Befo stats.select($"s0.mean" =!= $"s1.mean").as[Boolean].first() should be(true) stats.select($"s0.mean" =!= $"s2.mean").as[Boolean].first() should be(true) } + it("should read a single file") { + import spark.implicits._ + // Image is 1028 x 989 -> 9 x 8 tiles val df = spark.read.raster .withTileDimensions(128, 128) @@ -136,7 +140,10 @@ class RasterSourceDataSourceSpec extends TestEnvironment with TestData with Befo df.select($"${b}_path").distinct().count() should be(1) } + it("should read a multiple files with one band") { + import spark.implicits._ + val df = spark.read.raster .from(Seq(cogPath, l8B1SamplePath, nonCogPath)) .withTileDimensions(128, 128) @@ -144,7 +151,10 @@ class RasterSourceDataSourceSpec extends TestEnvironment with TestData with Befo df.select($"${b}_path").distinct().count() should be(3) df.schema.size should be(2) } + it("should read a multiple files with heterogeneous bands") { + import spark.implicits._ + val df = spark.read.raster .from(Seq(cogPath, l8B1SamplePath, nonCogPath)) .withLazyTiles(false) @@ -163,6 +173,8 @@ class RasterSourceDataSourceSpec extends TestEnvironment with TestData with Befo } it("should read a set of coherent bands from multiple files from a CSV") { + import spark.implicits._ + val bands = Seq("B1", "B2", "B3") val paths = Seq( l8SamplePath(1).toASCIIString, @@ -187,6 +199,8 @@ class RasterSourceDataSourceSpec extends TestEnvironment with TestData with Befo } it("should read a set of coherent bands from multiple files in a dataframe") { + import spark.implicits._ + val bandPaths = Seq(( l8SamplePath(1).toASCIIString, l8SamplePath(2).toASCIIString, @@ -214,6 +228,8 @@ class RasterSourceDataSourceSpec extends TestEnvironment with TestData with Befo } it("should read a set of coherent bands from multiple files in a csv") { + import spark.implicits._ + def b(i: Int) = l8SamplePath(i).toASCIIString val csv = @@ -240,6 +256,8 @@ class RasterSourceDataSourceSpec extends TestEnvironment with TestData with Befo } it("should support lazy and strict reading of tiles") { + import spark.implicits._ + val is_lazy = udf((t: Tile) => { t.isInstanceOf[RasterRef] }) @@ -260,29 +278,35 @@ class RasterSourceDataSourceSpec extends TestEnvironment with TestData with Befo } describe("RasterSource breaks up scenes into tiles") { - val modis_df = spark.read.raster + lazy val modis_df = spark.read.raster .withTileDimensions(256, 256) .withLazyTiles(true) .load(remoteMODIS.toASCIIString) - val l8_df = spark.read.raster + lazy val l8_df = spark.read.raster .withTileDimensions(32, 33) .withLazyTiles(true) .load(remoteL8.toASCIIString) it("should have at most four tile dimensions reading MODIS") { + import spark.implicits._ + val dims = modis_df.select(rf_dimensions($"proj_raster")).distinct().collect() dims.length should be > 0 dims.length should be <= 4 } it("should have at most four tile dimensions reading landsat") { + import spark.implicits._ + val dims = l8_df.select(rf_dimensions($"proj_raster")).distinct().collect() dims.length should be > 0 dims.length should be <= 4 } it("should read the correct size") { + import spark.implicits._ + val cat = Seq(( l8SamplePath(4).toASCIIString, l8SamplePath(3).toASCIIString, @@ -298,6 +322,8 @@ class RasterSourceDataSourceSpec extends TestEnvironment with TestData with Befo } it("should provide MODIS tiles with requested size") { + import spark.implicits._ + val res = modis_df .withColumn("dims", rf_dimensions($"proj_raster")) .select($"dims".as[Dimensions[Int]]).distinct().collect() @@ -309,6 +335,8 @@ class RasterSourceDataSourceSpec extends TestEnvironment with TestData with Befo } it("should provide Landsat tiles with requested size") { + import spark.implicits._ + val dims = l8_df .withColumn("dims", rf_dimensions($"proj_raster")) .select($"dims".as[Dimensions[Int]]).distinct().collect() @@ -320,6 +348,8 @@ class RasterSourceDataSourceSpec extends TestEnvironment with TestData with Befo } it("should have consistent tile resolution reading MODIS") { + import spark.implicits._ + val res = modis_df .withColumn("ext", rf_extent($"proj_raster")) .withColumn("dims", rf_dimensions($"proj_raster")) @@ -331,6 +361,8 @@ class RasterSourceDataSourceSpec extends TestEnvironment with TestData with Befo } it("should have consistent tile resolution reading Landsat") { + import spark.implicits._ + val res = l8_df .withColumn("ext", rf_extent($"proj_raster")) .withColumn("dims", rf_dimensions($"proj_raster")) diff --git a/datasource/src/test/scala/org/locationtech/rasterframes/datasource/slippy/SlippyDataSourceSpec.scala b/datasource/src/test/scala/org/locationtech/rasterframes/datasource/slippy/SlippyDataSourceSpec.scala index 6b13bfa0e..1b7149a97 100644 --- a/datasource/src/test/scala/org/locationtech/rasterframes/datasource/slippy/SlippyDataSourceSpec.scala +++ b/datasource/src/test/scala/org/locationtech/rasterframes/datasource/slippy/SlippyDataSourceSpec.scala @@ -5,16 +5,18 @@ package org.locationtech.rasterframes.datasource.slippy import better.files._ +import org.apache.spark.sql.functions.col import org.locationtech.rasterframes._ import org.locationtech.rasterframes.datasource.raster._ -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.BeforeAndAfterAll class SlippyDataSourceSpec extends TestEnvironment with TestData with BeforeAndAfterAll { - import spark.implicits._ - val baseDir = File("target") / "slippy" - override def beforeAll() = baseDir.delete(swallowIOExceptions = true) + override def beforeAll() = { + super.beforeAll() + baseDir.delete(swallowIOExceptions = true) + } def countFiles(dir: File, extension: String): Int = { dir.list(f => f.isRegularFile && f.name.endsWith(extension)).length @@ -44,7 +46,7 @@ class SlippyDataSourceSpec extends TestEnvironment with TestData with BeforeAndA val l8RGBPath = Resource.getUrl("LC08_RGB_Norfolk_COG.tiff").toURI describe("Slippy writing") { - val rf = spark.read.raster + lazy val rf = spark.read.raster .from(Seq(l8RGBPath)) .withLazyTiles(false) .withTileDimensions(128, 128) @@ -57,7 +59,7 @@ class SlippyDataSourceSpec extends TestEnvironment with TestData with BeforeAndA it("should write a singleband") { val dir = mkOutdir("single-") - rf.select($"red") + rf.select(col("red")) .write.slippy.withHTML.save(dir.toString) tileFilesCount(dir) should be (155L) view(dir) @@ -65,7 +67,7 @@ class SlippyDataSourceSpec extends TestEnvironment with TestData with BeforeAndA it("should write with non-uniform coloring") { val dir = mkOutdir("quick-") - rf.select($"green") + rf.select(col("green")) .write.slippy.withColorRamp("BlueToOrange") .withHTML.save(dir.toString) @@ -75,7 +77,7 @@ class SlippyDataSourceSpec extends TestEnvironment with TestData with BeforeAndA it("should write with uniform coloring") { val dir = mkOutdir("uniform-") - rf.select($"green") + rf.select(col("green")) .write.slippy .withColorRamp("Viridis") .withUniformColor @@ -86,7 +88,7 @@ class SlippyDataSourceSpec extends TestEnvironment with TestData with BeforeAndA } it("should write greyscale") { val dir = mkOutdir("relation-hist-noramp-") - rf.select($"green") + rf.select(col("green")) .write.slippy .withUniformColor .withHTML @@ -123,7 +125,7 @@ class SlippyDataSourceSpec extends TestEnvironment with TestData with BeforeAndA ignore("should write non-homogenous cell types") { val dir = mkOutdir(s"mixed-celltypes-") noException should be thrownBy { - rf.select(rf_log($"red"), $"green", $"blue") + rf.select(rf_log(col("red")), col("green"), col("blue")) .write.slippy.withHTML.save(dir.toString) } diff --git a/datasource/src/test/scala/org/locationtech/rasterframes/datasource/tiles/TilesDataSourceSpec.scala b/datasource/src/test/scala/org/locationtech/rasterframes/datasource/tiles/TilesDataSourceSpec.scala index df442ef97..e296e6d12 100644 --- a/datasource/src/test/scala/org/locationtech/rasterframes/datasource/tiles/TilesDataSourceSpec.scala +++ b/datasource/src/test/scala/org/locationtech/rasterframes/datasource/tiles/TilesDataSourceSpec.scala @@ -7,14 +7,13 @@ package org.locationtech.rasterframes.datasource.tiles import better.files.File import geotrellis.raster.io.geotiff.SinglebandGeoTiff import org.apache.spark.SparkConf +import org.apache.spark.sql.functions.col import org.apache.spark.sql.{functions => F} import org.locationtech.rasterframes._ import org.locationtech.rasterframes.datasource.raster._ import org.scalatest.BeforeAndAfter - class TilesDataSourceSpec extends TestEnvironment with TestData with BeforeAndAfter { - import spark.implicits._ val baseDir = File("target") / "tiles" def mkOutdir(prefix: String) = { @@ -22,51 +21,36 @@ class TilesDataSourceSpec extends TestEnvironment with TestData with BeforeAndAf File.newTemporaryDirectory(prefix, Some(resultsDir)) } - //override def afterAll() = baseDir.delete(swallowIOExceptions = true) - - describe("Tile writing") { + describe("Tile writing") { def tileFiles(dir: File, ext: String = ".tif") = dir.listRecursively.filter(f => f.extension.contains(ext)) def countTiles(dir: File, ext: String = ".tif"): Int = tileFiles(dir, ext).length - val df = spark.read.raster - .from(Seq(cogPath, l8B1SamplePath, nonCogPath)) - .withLazyTiles(false) - .withTileDimensions(128, 128) - .load() - .cache() + lazy val df = spark.read.raster + .from(Seq(cogPath, l8B1SamplePath, nonCogPath)) + .withLazyTiles(false) + .withTileDimensions(128, 128) + .load() + .cache() it("should write tiles with defaults") { - df.count() should be > 0L - val dest = mkOutdir("defaults-") - - df.write.tiles - .save(dest.toString) - - countTiles(dest) should be (df.count()) + df.write.tiles.save(dest.toString) + countTiles(dest) should be(df.count()) } it("should write png tiles") { - df.count() should be > 0L - val dest = mkOutdir("png-") - - df.write.tiles - .asPNG - .withCatalog - .save(dest.toString) - - countTiles(dest, ".png") should be (df.count()) + df.write.tiles.asPNG.withCatalog.save(dest.toString) + countTiles(dest, ".png") should be(df.count()) } it("should write tiles with custom filename") { val dest = mkOutdir("filename-") - val df2 = df .withColumn("filename", F.concat_ws("-", F.lit("bunny"), F.monotonically_increasing_id())) @@ -89,9 +73,9 @@ class TilesDataSourceSpec extends TestEnvironment with TestData with BeforeAndAf .withColumn("testval", F.when(F.rand() > 0.5, "test").otherwise("train")) .withColumn( "filename", - F.concat_ws("/", $"label", $"testval", F.monotonically_increasing_id()) + F.concat_ws("/", col("label"), col("testval"), F.monotonically_increasing_id()) ) - .repartition($"filename") + .repartition(col("filename")) df2.write.tiles .withFilenameColumn("filename") @@ -102,7 +86,7 @@ class TilesDataSourceSpec extends TestEnvironment with TestData with BeforeAndAf countTiles(dest) should be(df.count()) val cat = dest / "catalog.csv" - cat.exists should be (true) + cat.exists should be(true) cat.lineIterator.exists(_.contains("testval")) should be(true) cat.lineIterator.exists(_.contains("dog")) should be(true) @@ -110,12 +94,10 @@ class TilesDataSourceSpec extends TestEnvironment with TestData with BeforeAndAf val sample = tileFiles(dest).next() val tags = SinglebandGeoTiff(sample.toString()).tags.headTags - tags.keys should contain ("testval") + tags.keys should contain("testval") } } - override def additionalConf: SparkConf = { - new SparkConf() - .set("spark.debug.maxToStringFields", "100") - } + override def additionalConf(conf: SparkConf) = + conf.set("spark.debug.maxToStringFields", "100") } diff --git a/project/RFDependenciesPlugin.scala b/project/RFDependenciesPlugin.scala index 938e0d1a2..2d3c5029d 100644 --- a/project/RFDependenciesPlugin.scala +++ b/project/RFDependenciesPlugin.scala @@ -57,6 +57,7 @@ object RFDependenciesPlugin extends AutoPlugin { val frameless = "org.typelevel" %% "frameless-dataset" % "0.12.0" val framelessRefined = "org.typelevel" %% "frameless-refined" % "0.12.0" val `better-files` = "com.github.pathikrit" %% "better-files" % "3.9.1" % Test + val sparktestingbase = "com.holdenkarau" %% "spark-testing-base" % "3.2.1_1.3.0" % Test } import autoImport._ From 72a76a095c78623de8a59f11546a731679034845 Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Wed, 14 Dec 2022 00:45:33 -0500 Subject: [PATCH 19/34] Update StacApiDataSourceTest.scala --- .../datasource/stac/api/StacApiDataSourceTest.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datasource/src/test/scala/org/locationtech/rasterframes/datasource/stac/api/StacApiDataSourceTest.scala b/datasource/src/test/scala/org/locationtech/rasterframes/datasource/stac/api/StacApiDataSourceTest.scala index 500af6c53..bf330d16c 100644 --- a/datasource/src/test/scala/org/locationtech/rasterframes/datasource/stac/api/StacApiDataSourceTest.scala +++ b/datasource/src/test/scala/org/locationtech/rasterframes/datasource/stac/api/StacApiDataSourceTest.scala @@ -34,8 +34,9 @@ import sttp.client3.UriContext class StacApiDataSourceTest extends TestEnvironment { self => + //TODO: franklin.nasa-hsi.azavea.com is gone, we need some way to test this without external services describe("STAC API spark reader") { - it("should read items from Franklin service") { + ignore("should read items from Franklin service") { import spark.implicits._ val results = From ae3acc477d7258ba297e21aecd00c11f78a6b472 Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Wed, 14 Dec 2022 00:53:51 -0500 Subject: [PATCH 20/34] disable GeoTrellisDataSourceSpec Who read these anyway? --- .../datasource/geotrellis/GeoTrellisDataSourceSpec.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datasource/src/test/scala/org/locationtech/rasterframes/datasource/geotrellis/GeoTrellisDataSourceSpec.scala b/datasource/src/test/scala/org/locationtech/rasterframes/datasource/geotrellis/GeoTrellisDataSourceSpec.scala index 5a3039c43..c9d1512b6 100644 --- a/datasource/src/test/scala/org/locationtech/rasterframes/datasource/geotrellis/GeoTrellisDataSourceSpec.scala +++ b/datasource/src/test/scala/org/locationtech/rasterframes/datasource/geotrellis/GeoTrellisDataSourceSpec.scala @@ -48,7 +48,8 @@ import org.scalatest.{BeforeAndAfterAll, Inspectors} import scala.math.{max, min} -class GeoTrellisDataSourceSpec extends TestEnvironment with BeforeAndAfterAll with Inspectors with DataSourceOptions { +trait GeoTrellisDataSourceSpec extends TestEnvironment with BeforeAndAfterAll with Inspectors with DataSourceOptions { + // because this is a trait and not a class, the test does not run, here for posterity import TestData._ val tileSize = 12 From 460971a24b0ce865d42320ea04a7774c4f6b855f Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Wed, 14 Dec 2022 01:25:31 -0500 Subject: [PATCH 21/34] Shade caffeine popular caching library --- project/RFAssemblyPlugin.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/project/RFAssemblyPlugin.scala b/project/RFAssemblyPlugin.scala index 6a3646509..a3bb2038c 100644 --- a/project/RFAssemblyPlugin.scala +++ b/project/RFAssemblyPlugin.scala @@ -51,6 +51,8 @@ object RFAssemblyPlugin extends AutoPlugin { assembly / assemblyShadeRules:= { val shadePrefixes = Seq( "shapeless", + "com.github.ben-manes.caffeine", + "com.github.benmanes.caffeine", "com.github.mpilquist", "com.amazonaws", "org.apache.avro", From 43e8d3d83f175a6cec90e4d783dce8e1fb98ec8f Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Wed, 14 Dec 2022 15:07:57 -0500 Subject: [PATCH 22/34] boop --- project/RFDependenciesPlugin.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/RFDependenciesPlugin.scala b/project/RFDependenciesPlugin.scala index 9c1b5301e..b4515f7ab 100644 --- a/project/RFDependenciesPlugin.scala +++ b/project/RFDependenciesPlugin.scala @@ -74,7 +74,7 @@ object RFDependenciesPlugin extends AutoPlugin { // NB: Make sure to update the Spark version in pyrasterframes/python/setup.py rfSparkVersion := "3.2.1", rfGeoTrellisVersion := "3.6.3", - rfGeoMesaVersion := "3.4.1" + rfGeoMesaVersion := "3.4.1", excludeDependencies += "log4j" % "log4j" ) } From d1cfb99ce9780e6f1cc52c49fbb52813d1a21bda Mon Sep 17 00:00:00 2001 From: Grigory Pomadchin Date: Tue, 3 Jan 2023 10:11:14 -0500 Subject: [PATCH 23/34] Expressions constructors toSeq conversion --- .../rasterframes/expressions/package.scala | 266 ++++++++---------- 1 file changed, 115 insertions(+), 151 deletions(-) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala index 7f23b197c..5aa09eb20 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala @@ -24,7 +24,7 @@ package org.locationtech.rasterframes import geotrellis.raster.{DoubleConstantNoDataCellType, Tile} import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ExpressionInfo, ScalaUDF} +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, ScalaReflection} import org.apache.spark.sql.types.DataType import org.apache.spark.sql.SQLContext @@ -36,9 +36,13 @@ import org.locationtech.rasterframes.expressions.localops._ import org.locationtech.rasterframes.expressions.focalops._ import org.locationtech.rasterframes.expressions.tilestats._ import org.locationtech.rasterframes.expressions.transformers._ +import shapeless.HList +import shapeless.ops.function.FnToProduct +import shapeless.ops.traversable.FromTraversable import scala.reflect.ClassTag import scala.reflect.runtime.universe._ +import scala.language.implicitConversions /** * Module of Catalyst expressions for efficiently working with tiles. @@ -46,9 +50,9 @@ import scala.reflect.runtime.universe._ * @since 10/10/17 */ package object expressions { - type HasTernaryExpressionCopy = {def copy(first: Expression, second: Expression, third: Expression): Expression} - type HasBinaryExpressionCopy = {def copy(left: Expression, right: Expression): Expression} - type HasUnaryExpressionCopy = {def copy(child: Expression): Expression} + type HasTernaryExpressionCopy = { def copy(first: Expression, second: Expression, third: Expression): Expression } + type HasBinaryExpressionCopy = { def copy(left: Expression, right: Expression): Expression } + type HasUnaryExpressionCopy = { def copy(child: Expression): Expression } private[expressions] def row(input: Any) = input.asInstanceOf[InternalRow] /** Convert the tile to a floating point type as needed for scalar operations. */ @@ -67,33 +71,6 @@ package object expressions { } - private def expressionInfo[T : ClassTag](name: String, since: Option[String], database: Option[String]): ExpressionInfo = { - val clazz = scala.reflect.classTag[T].runtimeClass - val df = clazz.getAnnotation(classOf[ExpressionDescription]) - if (df != null) { - if (df.extended().isEmpty) { - new ExpressionInfo( - clazz.getCanonicalName, - database.orNull, - name, - df.usage(), - df.arguments(), - df.examples(), - df.note(), - df.group(), - since.getOrElse(df.since()), - df.deprecated(), - df.source()) - } else { - // This exists for the backward compatibility with old `ExpressionDescription`s defining - // the extended description in `extended()`. - new ExpressionInfo(clazz.getCanonicalName, database.orNull, name, df.usage(), df.extended()) - } - } else { - new ExpressionInfo(clazz.getCanonicalName, name) - } - } - def register(sqlContext: SQLContext, database: Option[String] = None): Unit = { val registry = sqlContext.sparkSession.sessionState.functionRegistry @@ -103,127 +80,114 @@ package object expressions { registry.registerFunction(id, info, builder) } - def register1[T <: Expression : ClassTag]( - name: String, - builder: Expression => T - ): Unit = registerFunction[T](name, None){ args => builder(args(0)) + /** Converts (expr1: Expression, ..., exprn: Expression) => R into a Seq[Expression] => R function */ + implicit def expressionArgumentsSequencer[F, I <: HList, R](f: F)(implicit ftp: FnToProduct.Aux[F, I => R], ft: FromTraversable[I]): Seq[Expression] => R = { list: Seq[Expression] => + ft(list) match { + case Some(l) => ftp(f)(l) + case None => throw new IllegalArgumentException(s"registerFunction application failed: arity mismatch: $list.") + } } - def register2[T <: Expression : ClassTag]( - name: String, - builder: (Expression, Expression) => T - ): Unit = registerFunction[T](name, None){ args => builder(args(0), args(1)) } - - def register3[T <: Expression : ClassTag]( - name: String, - builder: (Expression, Expression, Expression) => T - ): Unit = registerFunction[T](name, None){ args => builder(args(0), args(1), args(2)) } - - def register5[T <: Expression : ClassTag]( - name: String, - builder: (Expression, Expression, Expression, Expression, Expression) => T - ): Unit = registerFunction[T](name, None){ args => builder(args(0), args(1), args(2), args(3), args(4)) } - - register2("rf_local_add", Add(_, _)) - register2("rf_local_subtract", Subtract(_, _)) - registerFunction("rf_explode_tiles"){ExplodeTiles(1.0, None, _)} - register5("rf_assemble_tile", TileAssembler(_, _, _, _, _)) - register1("rf_cell_type", GetCellType(_)) - register2("rf_convert_cell_type", SetCellType(_, _)) - register2("rf_interpret_cell_type_as", InterpretAs(_, _)) - register2("rf_with_no_data", SetNoDataValue(_,_)) - register1("rf_dimensions", GetDimensions(_)) - register1("st_geometry", ExtentToGeometry(_)) - register1("rf_geometry", GetGeometry(_)) - register1("st_extent", GeometryToExtent(_)) - register1("rf_extent", GetExtent(_)) - register1("rf_crs", GetCRS(_)) - register1("rf_tile", RealizeTile(_)) - register3("rf_proj_raster", CreateProjectedRaster(_, _, _)) - register2("rf_local_multiply", Multiply(_, _)) - register2("rf_local_divide", Divide(_, _)) - register2("rf_normalized_difference", NormalizedDifference(_,_)) - register2("rf_local_less", Less(_, _)) - register2("rf_local_greater", Greater(_, _)) - register2("rf_local_less_equal", LessEqual(_, _)) - register2("rf_local_greater_equal", GreaterEqual(_, _)) - register2("rf_local_equal", Equal(_, _)) - register2("rf_local_unequal", Unequal(_, _)) - register2("rf_local_is_in", IsIn(_, _)) - register1("rf_local_no_data", Undefined(_)) - register1("rf_local_data", Defined(_)) - register2("rf_local_min", Min(_, _)) - register2("rf_local_max", Max(_, _)) - register3("rf_local_clamp", Clamp(_, _, _)) - register3("rf_where", Where(_, _, _)) - register3("rf_standardize", Standardize(_, _, _)) - register3("rf_rescale", Rescale(_, _ , _)) - register1("rf_tile_sum", Sum(_)) - register1("rf_round", Round(_)) - register1("rf_abs", Abs(_)) - register1("rf_log", Log(_)) - register1("rf_log10", Log10(_)) - register1("rf_log2", Log2(_)) - register1("rf_log1p", Log1p(_)) - register1("rf_exp", Exp(_)) - register1("rf_exp10", Exp10(_)) - register1("rf_exp2", Exp2(_)) - register1("rf_expm1", ExpM1(_)) - register1("rf_sqrt", Sqrt(_)) - register3("rf_resample", Resample(_, _, _)) - register2("rf_resample_nearest", ResampleNearest(_, _)) - register1("rf_tile_to_array_double", TileToArrayDouble(_)) - register1("rf_tile_to_array_int", TileToArrayInt(_)) - register1("rf_data_cells", DataCells(_)) - register1("rf_no_data_cells", NoDataCells(_)) - register1("rf_is_no_data_tile", IsNoDataTile(_)) - register1("rf_exists", Exists(_)) - register1("rf_for_all", ForAll(_)) - register1("rf_tile_min", TileMin(_)) - register1("rf_tile_max", TileMax(_)) - register1("rf_tile_mean", TileMean(_)) - register1("rf_tile_stats", TileStats(_)) - register1("rf_tile_histogram", TileHistogram(_)) - register1("rf_agg_data_cells", DataCells(_)) - register1("rf_agg_no_data_cells", CellCountAggregate.NoDataCells(_)) - register1("rf_agg_stats", CellStatsAggregate.CellStatsAggregateUDAF(_)) - register1("rf_agg_approx_histogram", HistogramAggregate.HistogramAggregateUDAF(_)) - register1("rf_agg_local_stats", LocalStatsAggregate.LocalStatsAggregateUDAF(_)) - register1("rf_agg_local_min",LocalTileOpAggregate.LocalMinUDAF(_)) - register1("rf_agg_local_max", LocalTileOpAggregate.LocalMaxUDAF(_)) - register1("rf_agg_local_data_cells", LocalCountAggregate.LocalDataCellsUDAF(_)) - register1("rf_agg_local_no_data_cells", LocalCountAggregate.LocalNoDataCellsUDAF(_)) - register1("rf_agg_local_mean", LocalMeanAggregate(_)) - register3(FocalMax.name, FocalMax(_, _, _)) - register3(FocalMin.name, FocalMin(_, _, _)) - register3(FocalMean.name, FocalMean(_, _, _)) - register3(FocalMode.name, FocalMode(_, _, _)) - register3(FocalMedian.name, FocalMedian(_, _, _)) - register3(FocalMoransI.name, FocalMoransI(_, _, _)) - register3(FocalStdDev.name, FocalStdDev(_, _, _)) - register3(Convolve.name, Convolve(_, _, _)) - - register3(Slope.name, Slope(_, _, _)) - register2(Aspect.name, Aspect(_, _)) - register5(Hillshade.name, Hillshade(_, _, _, _, _)) - - register2("rf_mask", MaskByDefined(_, _)) - register2("rf_inverse_mask", InverseMaskByDefined(_, _)) - register3("rf_mask_by_value", MaskByValue(_, _, _)) - register3("rf_inverse_mask_by_value", InverseMaskByValue(_, _, _)) - register3("rf_mask_by_values", MaskByValues(_, _, _)) - - register1("rf_render_ascii", DebugRender.RenderAscii(_)) - register1("rf_render_matrix", DebugRender.RenderMatrix(_)) - register1("rf_render_png", RenderPNG.RenderCompositePNG(_)) - register3("rf_rgb_composite", RGBComposite(_, _, _)) - - register2("rf_xz2_index", XZ2Indexer(_, _, 18.toShort)) - register2("rf_z2_index", Z2Indexer(_, _, 31.toShort)) - - register3("st_reproject", ReprojectGeometry(_, _, _)) - - register3[ExtractBits]("rf_local_extract_bits", ExtractBits(_: Expression, _: Expression, _: Expression)) - register3[ExtractBits]("rf_local_extract_bit", ExtractBits(_: Expression, _: Expression, _: Expression)) + registerFunction[Add](name = "rf_local_add")(Add.apply) + registerFunction[Subtract](name = "rf_local_subtract")(Subtract.apply) + registerFunction[ExplodeTiles](name = "rf_explode_tiles")(ExplodeTiles(1.0, None, _)) + registerFunction[TileAssembler](name = "rf_assemble_tile")(TileAssembler.apply) + registerFunction[GetCellType](name = "rf_cell_type")(GetCellType.apply) + registerFunction[SetCellType](name = "rf_convert_cell_type")(SetCellType.apply) + registerFunction[InterpretAs](name = "rf_interpret_cell_type_as")(InterpretAs.apply) + registerFunction[SetNoDataValue](name = "rf_with_no_data")(SetNoDataValue.apply) + registerFunction[GetDimensions](name = "rf_dimensions")(GetDimensions.apply) + registerFunction[ExtentToGeometry](name = "st_geometry")(ExtentToGeometry.apply) + registerFunction[GetGeometry](name = "rf_geometry")(GetGeometry.apply) + registerFunction[GeometryToExtent](name = "st_extent")(GeometryToExtent.apply) + registerFunction[GetExtent](name = "rf_extent")(GetExtent.apply) + registerFunction[GetCRS](name = "rf_crs")(GetCRS.apply) + registerFunction[RealizeTile](name = "rf_tile")(RealizeTile.apply) + registerFunction[CreateProjectedRaster](name = "rf_proj_raster")(CreateProjectedRaster.apply) + registerFunction[Multiply](name = "rf_local_multiply")(Multiply.apply) + registerFunction[Divide](name = "rf_local_divide")(Divide.apply) + registerFunction[NormalizedDifference](name = "rf_normalized_difference")(NormalizedDifference.apply) + registerFunction[Less](name = "rf_local_less")(Less.apply) + registerFunction[Greater](name = "rf_local_greater")(Greater.apply) + registerFunction[LessEqual](name = "rf_local_less_equal")(LessEqual.apply) + registerFunction[GreaterEqual](name = "rf_local_greater_equal")(GreaterEqual.apply) + registerFunction[Equal](name = "rf_local_equal")(Equal.apply) + registerFunction[Unequal](name = "rf_local_unequal")(Unequal.apply) + registerFunction[IsIn](name = "rf_local_is_in")(IsIn.apply) + registerFunction[Undefined](name = "rf_local_no_data")(Undefined.apply) + registerFunction[Defined](name = "rf_local_data")(Defined.apply) + registerFunction[Min](name = "rf_local_min")(Min.apply) + registerFunction[Max](name = "rf_local_max")(Max.apply) + registerFunction[Clamp](name = "rf_local_clamp")(Clamp.apply) + registerFunction[Where](name = "rf_where")(Where.apply) + registerFunction[Standardize](name = "rf_standardize")(Standardize.apply) + registerFunction[Rescale](name = "rf_rescale")(Rescale.apply) + registerFunction[Sum](name = "rf_tile_sum")(Sum.apply) + registerFunction[Round](name = "rf_round")(Round.apply) + registerFunction[Abs](name = "rf_abs")(Abs.apply) + registerFunction[Log](name = "rf_log")(Log.apply) + registerFunction[Log10](name = "rf_log10")(Log10.apply) + registerFunction[Log2](name = "rf_log2")(Log2.apply) + registerFunction[Log1p](name = "rf_log1p")(Log1p.apply) + registerFunction[Exp](name = "rf_exp")(Exp.apply) + registerFunction[Exp10](name = "rf_exp10")(Exp10.apply) + registerFunction[Exp2](name = "rf_exp2")(Exp2.apply) + registerFunction[ExpM1](name = "rf_expm1")(ExpM1.apply) + registerFunction[Sqrt](name = "rf_sqrt")(Sqrt.apply) + registerFunction[Resample](name = "rf_resample")(Resample.apply) + registerFunction[ResampleNearest](name = "rf_resample_nearest")(ResampleNearest.apply) + registerFunction[TileToArrayDouble](name = "rf_tile_to_array_double")(TileToArrayDouble.apply) + registerFunction[TileToArrayInt](name = "rf_tile_to_array_int")(TileToArrayInt.apply) + registerFunction[DataCells](name = "rf_data_cells")(DataCells.apply) + registerFunction[NoDataCells](name = "rf_no_data_cells")(NoDataCells.apply) + registerFunction[IsNoDataTile](name = "rf_is_no_data_tile")(IsNoDataTile.apply) + registerFunction[Exists](name = "rf_exists")(Exists.apply) + registerFunction[ForAll](name = "rf_for_all")(ForAll.apply) + registerFunction[TileMin](name = "rf_tile_min")(TileMin.apply) + registerFunction[TileMax](name = "rf_tile_max")(TileMax.apply) + registerFunction[TileMean](name = "rf_tile_mean")(TileMean.apply) + registerFunction[TileStats](name = "rf_tile_stats")(TileStats.apply) + registerFunction[TileHistogram](name = "rf_tile_histogram")(TileHistogram.apply) + registerFunction[DataCells](name = "rf_agg_data_cells")(DataCells.apply) + registerFunction[CellCountAggregate.NoDataCells](name = "rf_agg_no_data_cells")(CellCountAggregate.NoDataCells.apply) + registerFunction[CellStatsAggregate.CellStatsAggregateUDAF](name = "rf_agg_stats")(CellStatsAggregate.CellStatsAggregateUDAF.apply) + registerFunction[HistogramAggregate.HistogramAggregateUDAF](name = "rf_agg_approx_histogram")(HistogramAggregate.HistogramAggregateUDAF.apply) + registerFunction[LocalStatsAggregate.LocalStatsAggregateUDAF](name = "rf_agg_local_stats")(LocalStatsAggregate.LocalStatsAggregateUDAF.apply) + registerFunction[LocalTileOpAggregate.LocalMinUDAF](name = "rf_agg_local_min")(LocalTileOpAggregate.LocalMinUDAF.apply) + registerFunction[LocalTileOpAggregate.LocalMaxUDAF](name = "rf_agg_local_max")(LocalTileOpAggregate.LocalMaxUDAF.apply) + registerFunction[LocalCountAggregate.LocalDataCellsUDAF](name = "rf_agg_local_data_cells")(LocalCountAggregate.LocalDataCellsUDAF.apply) + registerFunction[LocalCountAggregate.LocalNoDataCellsUDAF](name = "rf_agg_local_no_data_cells")(LocalCountAggregate.LocalNoDataCellsUDAF.apply) + registerFunction[LocalMeanAggregate](name = "rf_agg_local_mean")(LocalMeanAggregate.apply) + registerFunction[FocalMax](FocalMax.name)(FocalMax.apply) + registerFunction[FocalMin](FocalMin.name)(FocalMin.apply) + registerFunction[FocalMean](FocalMean.name)(FocalMean.apply) + registerFunction[FocalMode](FocalMode.name)(FocalMode.apply) + registerFunction[FocalMedian](FocalMedian.name)(FocalMedian.apply) + registerFunction[FocalMoransI](FocalMoransI.name)(FocalMoransI.apply) + registerFunction[FocalStdDev](FocalStdDev.name)(FocalStdDev.apply) + registerFunction[Convolve](Convolve.name)(Convolve.apply) + + registerFunction[Slope](Slope.name)(Slope.apply) + registerFunction[Aspect](Aspect.name)(Aspect.apply) + registerFunction[Hillshade](Hillshade.name)(Hillshade.apply) + + registerFunction[MaskByDefined](name = "rf_mask")(MaskByDefined.apply) + registerFunction[InverseMaskByDefined](name = "rf_inverse_mask")(InverseMaskByDefined.apply) + registerFunction[MaskByValue](name = "rf_mask_by_value")(MaskByValue.apply) + registerFunction[InverseMaskByValue](name = "rf_inverse_mask_by_value")(InverseMaskByValue.apply) + registerFunction[MaskByValues](name = "rf_mask_by_values")(MaskByValues.apply) + + registerFunction[DebugRender.RenderAscii](name = "rf_render_ascii")(DebugRender.RenderAscii.apply) + registerFunction[DebugRender.RenderMatrix](name = "rf_render_matrix")(DebugRender.RenderMatrix.apply) + registerFunction[RenderPNG.RenderCompositePNG](name = "rf_render_png")(RenderPNG.RenderCompositePNG.apply) + registerFunction[RGBComposite](name = "rf_rgb_composite")(RGBComposite.apply) + + registerFunction[XZ2Indexer](name = "rf_xz2_index")(XZ2Indexer(_: Expression, _: Expression, 18.toShort)) + registerFunction[Z2Indexer](name = "rf_z2_index")(Z2Indexer(_: Expression, _: Expression, 31.toShort)) + + registerFunction[ReprojectGeometry](name = "st_reproject")(ReprojectGeometry.apply) + + registerFunction[ExtractBits]("rf_local_extract_bits")(ExtractBits.apply) + registerFunction[ExtractBits]("rf_local_extract_bit")(ExtractBits.apply) } } From b28a10b383ef7aecd436ab50319f275c824fd1be Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Tue, 3 Jan 2023 12:07:16 -0500 Subject: [PATCH 24/34] Downgrade scaffeine to 4.1.0 for JDK 8 support in caffeine 2.9 --- project/RFDependenciesPlugin.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/RFDependenciesPlugin.scala b/project/RFDependenciesPlugin.scala index b4515f7ab..3b83e1798 100644 --- a/project/RFDependenciesPlugin.scala +++ b/project/RFDependenciesPlugin.scala @@ -49,7 +49,7 @@ object RFDependenciesPlugin extends AutoPlugin { val shapeless = "com.chuusai" %% "shapeless" % "2.3.9" val `jts-core` = "org.locationtech.jts" % "jts-core" % "1.18.2" val `slf4j-api` = "org.slf4j" % "slf4j-api" % "1.7.36" - val scaffeine = "com.github.blemale" %% "scaffeine" % "5.1.2" + val scaffeine = "com.github.blemale" %% "scaffeine" % "4.1.0" val `spray-json` = "io.spray" %% "spray-json" % "1.3.6" val `scala-logging` = "com.typesafe.scala-logging" %% "scala-logging" % "3.9.4" val stac4s = "com.azavea.stac4s" %% "client" % "0.7.2" From 15b420c28dad14925508a61002175dbdae37d510 Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Tue, 3 Jan 2023 12:08:55 -0500 Subject: [PATCH 25/34] pyspark version 3.2.1 --- pyrasterframes/src/main/python/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrasterframes/src/main/python/setup.py b/pyrasterframes/src/main/python/setup.py index 4032d23eb..8f70b36b0 100644 --- a/pyrasterframes/src/main/python/setup.py +++ b/pyrasterframes/src/main/python/setup.py @@ -140,7 +140,7 @@ def dest_file(self, src_file): # to throw a `NotImplementedError: Can't perform this operation for unregistered loader type` pytest = 'pytest>=4.0.0,<5.0.0' -pyspark = 'pyspark==3.1.3' +pyspark = 'pyspark==3.2.1' boto3 = 'boto3' deprecation = 'deprecation' descartes = 'descartes' From 05b4c4412f6e33786dbc238bf9060e55ec8b830b Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Tue, 3 Jan 2023 12:13:47 -0500 Subject: [PATCH 26/34] why exclude log4j ? tests need it --- project/RFDependenciesPlugin.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/project/RFDependenciesPlugin.scala b/project/RFDependenciesPlugin.scala index 3b83e1798..fd7200fe4 100644 --- a/project/RFDependenciesPlugin.scala +++ b/project/RFDependenciesPlugin.scala @@ -74,7 +74,6 @@ object RFDependenciesPlugin extends AutoPlugin { // NB: Make sure to update the Spark version in pyrasterframes/python/setup.py rfSparkVersion := "3.2.1", rfGeoTrellisVersion := "3.6.3", - rfGeoMesaVersion := "3.4.1", - excludeDependencies += "log4j" % "log4j" + rfGeoMesaVersion := "3.4.1" ) } From ec3c5f440325d37cf29dfdcc217c85e1631ac730 Mon Sep 17 00:00:00 2001 From: "Simeon H.K. Fitch" Date: Mon, 27 Sep 2021 13:56:27 -0400 Subject: [PATCH 27/34] GitHub actions build. - Split out docs build into separate workflow. - Removed umlimit call (not needed) --- .circleci/README.md | 6 - .github/disabled-workflows/build-test.yml | 131 ---------------------- .github/image/.dockerignore | 3 + .github/image/Dockerfile | 28 +++++ .github/image/Makefile | 27 +++++ .github/image/requirements-conda.txt | 5 + .github/workflows/build-test.yml | 71 ++++++++++++ .github/workflows/docs.yml | 68 +++++++++++ 8 files changed, 202 insertions(+), 137 deletions(-) delete mode 100644 .circleci/README.md delete mode 100644 .github/disabled-workflows/build-test.yml create mode 100644 .github/image/.dockerignore create mode 100644 .github/image/Dockerfile create mode 100644 .github/image/Makefile create mode 100644 .github/image/requirements-conda.txt create mode 100644 .github/workflows/build-test.yml create mode 100644 .github/workflows/docs.yml diff --git a/.circleci/README.md b/.circleci/README.md deleted file mode 100644 index 6a507cc5f..000000000 --- a/.circleci/README.md +++ /dev/null @@ -1,6 +0,0 @@ -# CircleCI Dockerfile Build file - -```bash -make -docker push s22s/rasterframes-circleci:latest -``` diff --git a/.github/disabled-workflows/build-test.yml b/.github/disabled-workflows/build-test.yml deleted file mode 100644 index e4406498b..000000000 --- a/.github/disabled-workflows/build-test.yml +++ /dev/null @@ -1,131 +0,0 @@ -name: Build and Test - -on: - pull_request: - branches: ['**'] - push: - branches: ['master', 'develop', 'release/*'] - tags: [v*] - release: - types: [published] - -jobs: - build: - runs-on: ubuntu-latest - container: - image: s22s/circleci-openjdk-conda-gdal:b8e30ee - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - uses: coursier/cache-action@v6 - - uses: olafurpg/setup-scala@v13 - with: - java-version: adopt@1.11 - - - name: Set up Python 3.8 - uses: actions/setup-python@v2 - with: - python-version: 3.8 - - - name: Install Conda dependencies - run: | - # $CONDA is an environment variable pointing to the root of the miniconda directory - $CONDA/bin/conda install -c conda-forge --yes --file pyrasterframes/src/main/python/requirements-condaforge.txt - - - run: ulimit -c unlimited -S - - # Do just the compilation stage to minimize sbt memory footprint - - name: Compile - run: sbt -v -batch compile test:compile it:compile - - - name: Core tests - run: sbt -batch core/test - - - name: Datasource tests - run: sbt -batch datasource/test - - - name: Experimental tests - run: sbt -batch experimental/test - - - name: Create PyRasterFrames package - run: sbt -v -batch pyrasterframes/package - - - name: Python tests - run: sbt -batch pyrasterframes/test - - - name: Collect artifacts - if: ${{ failure() }} - run: | - mkdir -p /tmp/core_dumps - ls -lh /tmp - cp core.* *.hs /tmp/core_dumps/ 2> /dev/null || true - cp ./core/*.log /tmp/core_dumps/ 2> /dev/null || true - cp -r /tmp/hsperfdata* /tmp/*.hprof /tmp/core_dumps/ 2> /dev/null || true - cp repo/core/core/* /tmp/core_dumps/ 2> /dev/null || true - - - name: Upload core dumps - if: ${{ failure() }} - uses: actions/upload-artifact@v2 - with: - name: core-dumps - path: /tmp/core_dumps - - docs: - runs-on: ubuntu-latest - container: - image: s22s/circleci-openjdk-conda-gdal:b8e30ee - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - uses: coursier/cache-action@v6 - - uses: olafurpg/setup-scala@v13 - with: - java-version: adopt@1.11 - - - name: Set up Python 3.8 - uses: actions/setup-python@v2 - with: - python-version: 3.8 - - - name: Install Conda dependencies - run: | - # $CONDA is an environment variable pointing to the root of the miniconda directory - $CONDA/bin/conda install -c conda-forge --yes --file pyrasterframes/src/main/python/requirements-condaforge.txt - - - run: ulimit -c unlimited -S - - - name: Build documentation - run: sbt makeSite - - - name: Collect artifacts - if: ${{ failure() }} - run: | - mkdir -p /tmp/core_dumps - cp core.* *.hs /tmp/core_dumps 2> /dev/null || true - mkdir -p /tmp/markdown - cp pyrasterframes/target/python/docs/*.md /tmp/markdown 2> /dev/null || true - - - name: Upload core dumps - if: ${{ failure() }} - uses: actions/upload-artifact@v2 - with: - name: core-dumps - path: /tmp/core_dumps - - - name: Upload markdown - if: ${{ failure() }} - uses: actions/upload-artifact@v2 - with: - name: markdown - path: /tmp/markdown - - - name: Upload rf-site - if: ${{ failure() }} - uses: actions/upload-artifact@v2 - with: - name: rf-site - path: docs/target/site diff --git a/.github/image/.dockerignore b/.github/image/.dockerignore new file mode 100644 index 000000000..dbe9a91d7 --- /dev/null +++ b/.github/image/.dockerignore @@ -0,0 +1,3 @@ +* +!requirements-conda.txt +!fix-permissions diff --git a/.github/image/Dockerfile b/.github/image/Dockerfile new file mode 100644 index 000000000..27cd7a1aa --- /dev/null +++ b/.github/image/Dockerfile @@ -0,0 +1,28 @@ +FROM adoptopenjdk/openjdk11:debian-slim + +# See: https://docs.conda.io/projects/conda/en/latest/user-guide/install/rpm-debian.html +RUN \ + apt-get update && \ + apt-get install -yq gpg && \ + curl -s https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmor > conda.gpg && \ + install -o root -g root -m 644 conda.gpg /usr/share/keyrings/conda-archive-keyring.gpg && \ + gpg --keyring /usr/share/keyrings/conda-archive-keyring.gpg --no-default-keyring --fingerprint 34161F5BF5EB1D4BFBBB8F0A8AEB4F8B29D82806 && \ + echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" > /etc/apt/sources.list.d/conda.list && \ + apt-get update && \ + apt-get install -yq --no-install-recommends conda && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +ENV CONDA_DIR=/opt/conda +ENV PATH=$CONDA_DIR/bin:$PATH + +COPY requirements-conda.txt /tmp +RUN \ + conda install --quiet --yes --channel=conda-forge --file=/tmp/requirements-conda.txt && \ + echo "$CONDA_DIR/lib" > /etc/ld.so.conf.d/conda.conf && \ + ldconfig && \ + conda clean --all --force-pkgs-dirs --yes --quiet + +# Work-around for pyproj issue https://github.com/pyproj4/pyproj/issues/415 +ENV PROJ_LIB=/opt/conda/share/proj + diff --git a/.github/image/Makefile b/.github/image/Makefile new file mode 100644 index 000000000..1dab66b65 --- /dev/null +++ b/.github/image/Makefile @@ -0,0 +1,27 @@ +IMAGE_NAME=debian-openjdk-conda-gdal +SHA=$(shell git log -n1 --format=format:"%H" | cut -c 1-7) +VERSION?=$(SHA) +HOST=docker.io +REPO=$(HOST)/s22s +FULL_NAME=$(REPO)/$(IMAGE_NAME):$(VERSION) + +.DEFAULT_GOAL := help +help: +# http://marmelab.com/blog/2016/02/29/auto-documented-makefile.html + @echo "Usage: make [target]" + @echo "Targets: " + @grep -E '^[a-zA-Z0-9_%/-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\t\033[36m%-20s\033[0m %s\n", $$1, $$2}' + +all: build push ## Build and then push image + +build: ## Build the docker image + docker build . -t ${FULL_NAME} + +login: ## Login to the docker registry + docker login + +push: login ## Push docker image to registry + docker push ${FULL_NAME} + +run: build ## Build image and launch shell + docker run --rm -it ${FULL_NAME} bash diff --git a/.github/image/requirements-conda.txt b/.github/image/requirements-conda.txt new file mode 100644 index 000000000..a8ebfd56b --- /dev/null +++ b/.github/image/requirements-conda.txt @@ -0,0 +1,5 @@ +python==3.8 +gdal==3.1.2 +libspatialindex +rasterio[s3] +rtree \ No newline at end of file diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml new file mode 100644 index 000000000..5a6b6a55d --- /dev/null +++ b/.github/workflows/build-test.yml @@ -0,0 +1,71 @@ +name: Build and Test + +on: + pull_request: + branches: ['**'] + push: + branches: ['master', 'develop', 'release/*'] + tags: [v*] + release: + types: [published] + +jobs: + build: + runs-on: ubuntu-latest + container: + image: s22s/debian-openjdk-conda-gdal:6790f8d + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - uses: coursier/cache-action@v6 + - uses: olafurpg/setup-scala@v13 + with: + java-version: adopt@1.11 + + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - name: Install Conda dependencies + run: | + # $CONDA_DIR is an environment variable pointing to the root of the miniconda directory + $CONDA_DIR/bin/conda install -c conda-forge --yes --file pyrasterframes/src/main/python/requirements-condaforge.txt + + # Do just the compilation stage to minimize sbt memory footprint + - name: Compile + run: sbt -v -batch compile test:compile it:compile + + - name: Core tests + run: sbt -batch core/test + + - name: Datasource tests + run: sbt -batch datasource/test + + - name: Experimental tests + run: sbt -batch experimental/test + + - name: Create PyRasterFrames package + run: sbt -v -batch pyrasterframes/package + + - name: Python tests + run: sbt -batch pyrasterframes/test + + - name: Collect artifacts + if: ${{ failure() }} + run: | + mkdir -p /tmp/core_dumps + ls -lh /tmp + cp core.* *.hs /tmp/core_dumps/ 2> /dev/null || true + cp ./core/*.log /tmp/core_dumps/ 2> /dev/null || true + cp -r /tmp/hsperfdata* /tmp/*.hprof /tmp/core_dumps/ 2> /dev/null || true + cp repo/core/core/* /tmp/core_dumps/ 2> /dev/null || true + + - name: Upload core dumps + if: ${{ failure() }} + uses: actions/upload-artifact@v2 + with: + name: core-dumps + path: /tmp/core_dumps \ No newline at end of file diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 000000000..100b78d4f --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,68 @@ +name: Compile documentation + +on: + workflow_dispatch: + + pull_request: + branches: ['**docs*'] + push: + branches: ['master', 'release/*'] + release: + types: [published] + +jobs: + docs: + runs-on: ubuntu-latest + container: + image: s22s/debian-openjdk-conda-gdal:6790f8d + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - uses: coursier/cache-action@v6 + - uses: olafurpg/setup-scala@v13 + with: + java-version: adopt@1.11 + + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - name: Install Conda dependencies + run: | + # $CONDA_DIR is an environment variable pointing to the root of the miniconda directory + $CONDA_DIR/bin/conda install -c conda-forge --yes --file pyrasterframes/src/main/python/requirements-condaforge.txt + + - name: Build documentation + run: sbt makeSite + + - name: Collect artifacts + if: ${{ failure() }} + run: | + mkdir -p /tmp/core_dumps + cp core.* *.hs /tmp/core_dumps 2> /dev/null || true + mkdir -p /tmp/markdown + cp pyrasterframes/target/python/docs/*.md /tmp/markdown 2> /dev/null || true + + - name: Upload core dumps + if: ${{ failure() }} + uses: actions/upload-artifact@v2 + with: + name: core-dumps + path: /tmp/core_dumps + + - name: Upload markdown + if: ${{ failure() }} + uses: actions/upload-artifact@v2 + with: + name: markdown + path: /tmp/markdown + + - name: Upload rf-site + if: ${{ failure() }} + uses: actions/upload-artifact@v2 + with: + name: rf-site + path: docs/target/site \ No newline at end of file From 3a7b90f26ffe92fc267c1a31fe2df5a73392d7cc Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Tue, 3 Jan 2023 13:01:34 -0500 Subject: [PATCH 28/34] Fix formatting --- .../expressions/generators/RasterSourceToRasterRefs.scala | 2 +- .../expressions/generators/RasterSourceToTiles.scala | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourceToRasterRefs.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourceToRasterRefs.scala index 8fd4c951d..13b8c59a7 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourceToRasterRefs.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourceToRasterRefs.scala @@ -83,7 +83,7 @@ case class RasterSourceToRasterRefs(children: Seq[Expression], bandIndexes: Seq[ throw new java.lang.IllegalArgumentException(description, ex) } - override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children=newChildren) + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) } object RasterSourceToRasterRefs { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourceToTiles.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourceToTiles.scala index 713811ca6..1d92431bb 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourceToTiles.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourceToTiles.scala @@ -85,7 +85,7 @@ case class RasterSourceToTiles(children: Seq[Expression], bandIndexes: Seq[Int], } } - override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children=newChildren) + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) } object RasterSourceToTiles { @@ -95,5 +95,3 @@ object RasterSourceToTiles { def apply(subtileDims: Option[Dimensions[Int]], bandIndexes: Seq[Int], bufferSize: Short, rrs: Column*): TypedColumn[Any, ProjectedRasterTile] = new Column(new RasterSourceToTiles(rrs.map(_.expr), bandIndexes, subtileDims, bufferSize)).as[ProjectedRasterTile] } - - From 4f24ad50dcbdb2440d50ca0fdc381fcf17de8d6f Mon Sep 17 00:00:00 2001 From: Grigory Pomadchin Date: Tue, 3 Jan 2023 13:13:57 -0500 Subject: [PATCH 29/34] Fix Expressions arity issue --- .../rasterframes/expressions/package.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala index 5aa09eb20..7237c720c 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala @@ -39,6 +39,8 @@ import org.locationtech.rasterframes.expressions.transformers._ import shapeless.HList import shapeless.ops.function.FnToProduct import shapeless.ops.traversable.FromTraversable +import shapeless.syntax.std.function._ +import shapeless.syntax.std.traversable._ import scala.reflect.ClassTag import scala.reflect.runtime.universe._ @@ -81,17 +83,17 @@ package object expressions { } /** Converts (expr1: Expression, ..., exprn: Expression) => R into a Seq[Expression] => R function */ - implicit def expressionArgumentsSequencer[F, I <: HList, R](f: F)(implicit ftp: FnToProduct.Aux[F, I => R], ft: FromTraversable[I]): Seq[Expression] => R = { list: Seq[Expression] => - ft(list) match { - case Some(l) => ftp(f)(l) - case None => throw new IllegalArgumentException(s"registerFunction application failed: arity mismatch: $list.") + implicit def expressionArgumentsSequencer[F, L <: HList, R](f: F)(implicit ftp: FnToProduct.Aux[F, L => R], ft: FromTraversable[L]): Seq[Expression] => R = { list: Seq[Expression] => + list.toHList match { + case Some(l) => f.toProduct(l) + case None => throw new IllegalArgumentException(s"registerFunction application failed; arity mismatch: $list.") } } registerFunction[Add](name = "rf_local_add")(Add.apply) registerFunction[Subtract](name = "rf_local_subtract")(Subtract.apply) registerFunction[ExplodeTiles](name = "rf_explode_tiles")(ExplodeTiles(1.0, None, _)) - registerFunction[TileAssembler](name = "rf_assemble_tile")(TileAssembler.apply) + registerFunction[TileAssembler](name = "rf_assemble_tile")(TileAssembler(_: Expression, _: Expression, _: Expression, _: Expression, _: Expression)) registerFunction[GetCellType](name = "rf_cell_type")(GetCellType.apply) registerFunction[SetCellType](name = "rf_convert_cell_type")(SetCellType.apply) registerFunction[InterpretAs](name = "rf_interpret_cell_type_as")(InterpretAs.apply) From 9be3cb653f5226e49e468a276f5610271009f766 Mon Sep 17 00:00:00 2001 From: Grigory Pomadchin Date: Tue, 3 Jan 2023 13:59:57 -0500 Subject: [PATCH 30/34] Add .jvmopts --- .jvmopts | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 .jvmopts diff --git a/.jvmopts b/.jvmopts new file mode 100644 index 000000000..7e7a068ea --- /dev/null +++ b/.jvmopts @@ -0,0 +1,2 @@ +-Xms2g +-Xmx4g From 61081b7df7a8be08882dac778f3cf0e8542fc663 Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Tue, 3 Jan 2023 19:52:37 -0500 Subject: [PATCH 31/34] Fix: Mask operations preserver the target tile cell type --- .../transformers/InverseMaskByDefined.scala | 23 ++---- .../transformers/InverseMaskByValue.scala | 25 +++---- .../transformers/MaskByDefined.scala | 23 ++---- .../transformers/MaskByValue.scala | 31 +++----- .../transformers/MaskByValues.scala | 25 +++---- .../transformers/MaskExpression.scala | 74 +++++++++++++++++++ 6 files changed, 114 insertions(+), 87 deletions(-) create mode 100644 core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskExpression.scala diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InverseMaskByDefined.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InverseMaskByDefined.scala index b340c5583..5230e5204 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InverseMaskByDefined.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InverseMaskByDefined.scala @@ -23,12 +23,9 @@ package org.locationtech.rasterframes.expressions.transformers import geotrellis.raster.{NODATA, Tile, isNoData} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.{Column, TypedColumn} import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionDescription} -import org.apache.spark.sql.types.DataType -import org.locationtech.rasterframes.expressions.DynamicExtractors.tileExtractor import org.locationtech.rasterframes.expressions.{RasterResult, row} import org.locationtech.rasterframes.tileEncoder @@ -45,36 +42,26 @@ import org.locationtech.rasterframes.tileEncoder ...""" ) case class InverseMaskByDefined(targetTile: Expression, maskTile: Expression) - extends BinaryExpression + extends BinaryExpression with MaskExpression with CodegenFallback with RasterResult { override def nodeName: String = "rf_inverse_mask" - def dataType: DataType = targetTile.dataType def left: Expression = targetTile def right: Expression = maskTile protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = InverseMaskByDefined(newLeft, newRight) - override def checkInputDataTypes(): TypeCheckResult = { - if (!tileExtractor.isDefinedAt(targetTile.dataType)) { - TypeCheckFailure(s"Input type '${targetTile.dataType}' does not conform to a raster type.") - } else if (!tileExtractor.isDefinedAt(maskTile.dataType)) { - TypeCheckFailure(s"Input type '${maskTile.dataType}' does not conform to a raster type.") - } else TypeCheckSuccess - } - - private lazy val targetTileExtractor = tileExtractor(targetTile.dataType) - private lazy val maskTileExtractor = tileExtractor(maskTile.dataType) + override def checkInputDataTypes(): TypeCheckResult = checkTileDataTypes() override protected def nullSafeEval(targetInput: Any, maskInput: Any): Any = { val (targetTile, targetCtx) = targetTileExtractor(row(targetInput)) val (mask, maskCtx) = maskTileExtractor(row(maskInput)) - - val result = targetTile.dualCombine(mask) - { (v, m) => if (isNoData(m)) v else NODATA } + val result = maskEval(targetTile, mask, + { (v, m) => if (isNoData(m)) v else NODATA }, { (v, m) => if (isNoData(m)) v else NODATA } + ) toInternalRow(result, targetCtx) } } diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InverseMaskByValue.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InverseMaskByValue.scala index 1e87a160b..a44981c96 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InverseMaskByValue.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InverseMaskByValue.scala @@ -21,14 +21,13 @@ package org.locationtech.rasterframes.expressions.transformers -import geotrellis.raster.{NODATA, Tile, d2i} +import geotrellis.raster.{NODATA, Tile} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.{Column, TypedColumn} import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, TernaryExpression} -import org.apache.spark.sql.types.DataType -import org.locationtech.rasterframes.expressions.DynamicExtractors.{intArgExtractor, tileExtractor} +import org.locationtech.rasterframes.expressions.DynamicExtractors.intArgExtractor import org.locationtech.rasterframes.expressions.{RasterResult, row} import org.locationtech.rasterframes.tileEncoder @@ -47,12 +46,11 @@ import org.locationtech.rasterframes.tileEncoder ...""" ) case class InverseMaskByValue(targetTile: Expression, maskTile: Expression, maskValue: Expression) - extends TernaryExpression + extends TernaryExpression with MaskExpression with CodegenFallback with RasterResult { override def nodeName: String = "rf_inverse_mask_by_value" - def dataType: DataType = targetTile.dataType def first: Expression = targetTile def second: Expression = maskTile def third: Expression = maskValue @@ -61,17 +59,11 @@ case class InverseMaskByValue(targetTile: Expression, maskTile: Expression, mask InverseMaskByValue(newFirst, newSecond, newThird) override def checkInputDataTypes(): TypeCheckResult = { - if (!tileExtractor.isDefinedAt(targetTile.dataType)) { - TypeCheckFailure(s"Input type '${targetTile.dataType}' does not conform to a raster type.") - } else if (!tileExtractor.isDefinedAt(maskTile.dataType)) { - TypeCheckFailure(s"Input type '${maskTile.dataType}' does not conform to a raster type.") - } else if (!intArgExtractor.isDefinedAt(maskValue.dataType)) { + if (!intArgExtractor.isDefinedAt(maskValue.dataType)) { TypeCheckFailure(s"Input type '${maskValue.dataType}' isn't an integral type.") - } else TypeCheckSuccess + } else checkTileDataTypes() } - private lazy val targetTileExtractor = tileExtractor(targetTile.dataType) - private lazy val maskTileExtractor = tileExtractor(maskTile.dataType) private lazy val maskValueExtractor = intArgExtractor(maskValue.dataType) override protected def nullSafeEval(targetInput: Any, maskInput: Any, maskValueInput: Any): Any = { @@ -79,9 +71,10 @@ case class InverseMaskByValue(targetTile: Expression, maskTile: Expression, mask val (mask, maskCtx) = maskTileExtractor(row(maskInput)) val maskValue = maskValueExtractor(maskValueInput).value - val result = targetTile.dualCombine(mask) + val result = maskEval(targetTile, mask, + { (v, m) => if (m != maskValue) NODATA else v }, { (v, m) => if (m != maskValue) NODATA else v } - { (v, m) => if (d2i(m) != maskValue) NODATA else v } + ) toInternalRow(result, targetCtx) } } diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByDefined.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByDefined.scala index 7420be708..a41813ed1 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByDefined.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByDefined.scala @@ -22,12 +22,9 @@ package org.locationtech.rasterframes.expressions.transformers import geotrellis.raster.{NODATA, Tile, isNoData} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.{Column, TypedColumn} import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionDescription} -import org.apache.spark.sql.types.DataType -import org.locationtech.rasterframes.expressions.DynamicExtractors.{tileExtractor} import org.locationtech.rasterframes.expressions.{RasterResult, row} import org.locationtech.rasterframes.tileEncoder @@ -44,36 +41,26 @@ import org.locationtech.rasterframes.tileEncoder ...""" ) case class MaskByDefined(targetTile: Expression, maskTile: Expression) - extends BinaryExpression + extends BinaryExpression with MaskExpression with CodegenFallback with RasterResult { override def nodeName: String = "rf_mask" - def dataType: DataType = targetTile.dataType def left: Expression = targetTile def right: Expression = maskTile protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = MaskByDefined(newLeft, newRight) - override def checkInputDataTypes(): TypeCheckResult = { - if (!tileExtractor.isDefinedAt(targetTile.dataType)) { - TypeCheckFailure(s"Input type '${targetTile.dataType}' does not conform to a raster type.") - } else if (!tileExtractor.isDefinedAt(maskTile.dataType)) { - TypeCheckFailure(s"Input type '${maskTile.dataType}' does not conform to a raster type.") - } else TypeCheckSuccess - } - - private lazy val targetTileExtractor = tileExtractor(targetTile.dataType) - private lazy val maskTileExtractor = tileExtractor(maskTile.dataType) + override def checkInputDataTypes(): TypeCheckResult = checkTileDataTypes() override protected def nullSafeEval(targetInput: Any, maskInput: Any): Any = { val (targetTile, targetCtx) = targetTileExtractor(row(targetInput)) val (mask, maskCtx) = maskTileExtractor(row(maskInput)) - - val result = targetTile.dualCombine(mask) - { (v, m) => if (isNoData(m)) NODATA else v } + val result = maskEval(targetTile, mask, + { (v, m) => if (isNoData(m)) NODATA else v }, { (v, m) => if (isNoData(m)) NODATA else v } + ) toInternalRow(result, targetCtx) } } diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByValue.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByValue.scala index eda992bdc..b981ddea2 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByValue.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByValue.scala @@ -21,14 +21,13 @@ package org.locationtech.rasterframes.expressions.transformers -import geotrellis.raster.{NODATA, Tile, d2i} +import geotrellis.raster.{NODATA, Tile} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.{Column, TypedColumn} import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, TernaryExpression} -import org.apache.spark.sql.types.{DataType} -import org.locationtech.rasterframes.expressions.DynamicExtractors.{intArgExtractor, tileExtractor} +import org.locationtech.rasterframes.expressions.DynamicExtractors.intArgExtractor import org.locationtech.rasterframes.expressions.{RasterResult, row} import org.locationtech.rasterframes.tileEncoder @@ -46,14 +45,13 @@ import org.locationtech.rasterframes.tileEncoder > SELECT _FUNC_(target, mask, maskValue); ...""" ) -case class MaskByValue(dataTile: Expression, maskTile: Expression, maskValue: Expression) - extends TernaryExpression +case class MaskByValue(targetTile: Expression, maskTile: Expression, maskValue: Expression) + extends TernaryExpression with MaskExpression with CodegenFallback with RasterResult { override def nodeName: String = "rf_mask_by_value" - def dataType: DataType = dataTile.dataType - def first: Expression = dataTile + def first: Expression = targetTile def second: Expression = maskTile def third: Expression = maskValue @@ -61,27 +59,22 @@ case class MaskByValue(dataTile: Expression, maskTile: Expression, maskValue: Ex MaskByValue(newFirst, newSecond, newThird) override def checkInputDataTypes(): TypeCheckResult = { - if (!tileExtractor.isDefinedAt(dataTile.dataType)) { - TypeCheckFailure(s"Input type '${dataTile.dataType}' does not conform to a raster type.") - } else if (!tileExtractor.isDefinedAt(maskTile.dataType)) { - TypeCheckFailure(s"Input type '${maskTile.dataType}' does not conform to a raster type.") - } else if (!intArgExtractor.isDefinedAt(maskValue.dataType)) { + if (!intArgExtractor.isDefinedAt(maskValue.dataType)) { TypeCheckFailure(s"Input type '${maskValue.dataType}' isn't an integral type.") - } else TypeCheckSuccess + } else checkTileDataTypes() } - private lazy val dataTileExtractor = tileExtractor(dataTile.dataType) - private lazy val maskTileExtractor = tileExtractor(maskTile.dataType) private lazy val maskValueExtractor = intArgExtractor(maskValue.dataType) override protected def nullSafeEval(targetInput: Any, maskInput: Any, maskValueInput: Any): Any = { - val (targetTile, targetCtx) = dataTileExtractor(row(targetInput)) + val (targetTile, targetCtx) = targetTileExtractor(row(targetInput)) val (mask, maskCtx) = maskTileExtractor(row(maskInput)) val maskValue = maskValueExtractor(maskValueInput).value - val result = targetTile.dualCombine(mask) + val result = maskEval(targetTile, mask, + { (v, m) => if (m == maskValue) NODATA else v }, { (v, m) => if (m == maskValue) NODATA else v } - { (v, m) => if (d2i(m) == maskValue) NODATA else v } + ) toInternalRow(result, targetCtx) } } diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByValues.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByValues.scala index 39d9d9dd3..6d78a6c61 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByValues.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByValues.scala @@ -21,15 +21,14 @@ package org.locationtech.rasterframes.expressions.transformers -import geotrellis.raster.{NODATA, Tile, d2i} +import geotrellis.raster.{NODATA, Tile} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, TernaryExpression} import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, TypedColumn} -import org.locationtech.rasterframes.expressions.DynamicExtractors.{intArrayExtractor, tileExtractor} +import org.locationtech.rasterframes.expressions.DynamicExtractors.intArrayExtractor import org.locationtech.rasterframes.expressions.{RasterResult, row} import org.locationtech.rasterframes.tileEncoder @@ -48,12 +47,11 @@ import org.locationtech.rasterframes.tileEncoder ...""" ) case class MaskByValues(targetTile: Expression, maskTile: Expression, maskValues: Expression) - extends TernaryExpression + extends TernaryExpression with MaskExpression with CodegenFallback with RasterResult { override def nodeName: String = "rf_mask_by_values" - def dataType: DataType = targetTile.dataType def first: Expression = targetTile def second: Expression = maskTile def third: Expression = maskValues @@ -62,16 +60,10 @@ case class MaskByValues(targetTile: Expression, maskTile: Expression, maskValues MaskByValues(newFirst, newSecond, newThird) override def checkInputDataTypes(): TypeCheckResult = - if (!tileExtractor.isDefinedAt(targetTile.dataType)) { - TypeCheckFailure(s"Input type '${targetTile.dataType}' does not conform to a raster type.") - } else if (!tileExtractor.isDefinedAt(maskTile.dataType)) { - TypeCheckFailure(s"Input type '${maskTile.dataType}' does not conform to a raster type.") - } else if (!intArrayExtractor.isDefinedAt(maskValues.dataType)) { + if (!intArrayExtractor.isDefinedAt(maskValues.dataType)) { TypeCheckFailure(s"Input type '${maskValues.dataType}' does not translate to an array.") - } else TypeCheckSuccess + } else checkTileDataTypes() - private lazy val targetTileExtractor = tileExtractor(targetTile.dataType) - private lazy val maskTileExtractor = tileExtractor(maskTile.dataType) private lazy val maskValuesExtractor = intArrayExtractor(maskValues.dataType) override protected def nullSafeEval(targetInput: Any, maskInput: Any, maskValuesInput: Any): Any = { @@ -79,9 +71,10 @@ case class MaskByValues(targetTile: Expression, maskTile: Expression, maskValues val (mask, maskCtx) = maskTileExtractor(row(maskInput)) val maskValues: Array[Int] = maskValuesExtractor(maskValuesInput.asInstanceOf[ArrayData]) - val result = targetTile.dualCombine(mask) + val result = maskEval(targetTile, mask, + { (v, m) => if (maskValues.contains(m)) NODATA else v }, { (v, m) => if (maskValues.contains(m)) NODATA else v } - { (v, m) => if (maskValues.contains(d2i(m))) NODATA else v } + ) toInternalRow(result, targetCtx) } diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskExpression.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskExpression.scala new file mode 100644 index 000000000..a8dbe8e24 --- /dev/null +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskExpression.scala @@ -0,0 +1,74 @@ +/* + * This software is licensed under the Apache 2 license, quoted below. + * + * Copyright 2019 Astraea, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * [http://www.apache.org/licenses/LICENSE-2.0] + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + * + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.locationtech.rasterframes.expressions.transformers + +import geotrellis.raster.Tile +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.types.DataType +import org.locationtech.rasterframes.expressions.DynamicExtractors.tileExtractor + +import spire.syntax.cfor._ + +trait MaskExpression { self: Expression => + + def targetTile: Expression + def maskTile: Expression + + def dataType: DataType = targetTile.dataType + + protected lazy val targetTileExtractor = tileExtractor(targetTile.dataType) + protected lazy val maskTileExtractor = tileExtractor(maskTile.dataType) + + def checkTileDataTypes(): TypeCheckResult = { + if (!tileExtractor.isDefinedAt(targetTile.dataType)) { + TypeCheckFailure(s"Input type '${targetTile.dataType}' does not conform to a raster type.") + } else if (!tileExtractor.isDefinedAt(maskTile.dataType)) { + TypeCheckFailure(s"Input type '${maskTile.dataType}' does not conform to a raster type.") + } else TypeCheckSuccess + } + + def maskEval(targetTile: Tile, maskTile: Tile, maskInt: (Int, Int) => Int, maskDouble: (Double, Int) => Double): Tile = { + val result = targetTile.mutable + + if (targetTile.cellType.isFloatingPoint) { + cfor(0)(_ < targetTile.rows, _ + 1) { row => + cfor(0)(_ < targetTile.cols, _ + 1) { col => + val v = targetTile.getDouble(col, row) + val m = maskTile.get(col, row) + result.setDouble(col, row, maskDouble(v, m)) + } + } + } else { + cfor(0)(_ < targetTile.rows, _ + 1) { row => + cfor(0)(_ < targetTile.cols, _ + 1) { col => + val v = targetTile.get(col, row) + val m = maskTile.get(col, row) + result.set(col, row, maskInt(v, m)) + } + } + } + + result + } +} From b14adaab9d5fa4d031db7d10bf908ea271c14f63 Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Fri, 13 Jan 2023 16:33:32 -0500 Subject: [PATCH 32/34] Pin GitHub Actions to ubuntu-20.04 --- .github/workflows/build-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 5a6b6a55d..e1105fb1c 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -11,7 +11,7 @@ on: jobs: build: - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 container: image: s22s/debian-openjdk-conda-gdal:6790f8d From df552b82a52252ee3dc2712dde409431acecb2c5 Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Fri, 13 Jan 2023 18:17:47 -0500 Subject: [PATCH 33/34] Implement withNewChildrenInternal directly avoid reflection which is done at runtime by structural types --- .../expressions/BinaryRasterFunction.scala | 5 ++--- .../expressions/OnCellGridExpression.scala | 6 ++---- .../expressions/OnTileContextExpression.scala | 6 ++---- .../expressions/SpatialRelation.scala | 15 +++++++++------ .../expressions/UnaryRasterAggregate.scala | 4 +--- .../expressions/UnaryRasterFunction.scala | 6 ++---- .../rasterframes/expressions/UnaryRasterOp.scala | 5 +---- .../expressions/accessors/ExtractTile.scala | 2 ++ .../expressions/accessors/GetCRS.scala | 2 +- .../expressions/accessors/GetCellType.scala | 2 +- .../expressions/accessors/GetDimensions.scala | 2 ++ .../expressions/accessors/GetEnvelope.scala | 2 +- .../expressions/accessors/GetExtent.scala | 2 ++ .../expressions/accessors/GetGeometry.scala | 1 + .../expressions/accessors/GetTileContext.scala | 2 ++ .../expressions/accessors/RealizeTile.scala | 2 +- .../aggregates/CellCountAggregate.scala | 8 ++++++-- .../aggregates/CellMeanAggregate.scala | 2 ++ .../aggregates/LocalMeanAggregate.scala | 2 ++ .../expressions/focalops/FocalMax.scala | 2 ++ .../expressions/focalops/FocalMean.scala | 1 + .../expressions/focalops/FocalMedian.scala | 1 + .../expressions/focalops/FocalMin.scala | 2 ++ .../expressions/focalops/FocalMode.scala | 1 + .../expressions/focalops/FocalMoransI.scala | 1 + .../focalops/FocalNeighborhoodOp.scala | 7 ++----- .../expressions/focalops/FocalStdDev.scala | 1 + .../rasterframes/expressions/localops/Abs.scala | 1 + .../rasterframes/expressions/localops/Add.scala | 5 +++-- .../expressions/localops/BiasedAdd.scala | 5 +++-- .../expressions/localops/Defined.scala | 5 +++-- .../expressions/localops/Divide.scala | 2 ++ .../rasterframes/expressions/localops/Equal.scala | 1 + .../rasterframes/expressions/localops/Exp.scala | 4 ++++ .../expressions/localops/Greater.scala | 2 ++ .../expressions/localops/GreaterEqual.scala | 2 ++ .../expressions/localops/Identity.scala | 1 + .../rasterframes/expressions/localops/Less.scala | 2 ++ .../expressions/localops/LessEqual.scala | 2 ++ .../rasterframes/expressions/localops/Log.scala | 4 ++++ .../rasterframes/expressions/localops/Max.scala | 2 ++ .../rasterframes/expressions/localops/Min.scala | 2 ++ .../expressions/localops/Multiply.scala | 2 ++ .../rasterframes/expressions/localops/Round.scala | 1 + .../rasterframes/expressions/localops/Sqrt.scala | 1 + .../expressions/localops/Subtract.scala | 2 ++ .../expressions/localops/Undefined.scala | 1 + .../expressions/localops/Unequal.scala | 2 ++ .../rasterframes/expressions/package.scala | 4 ---- .../expressions/tilestats/DataCells.scala | 2 ++ .../expressions/tilestats/Exists.scala | 2 +- .../expressions/tilestats/ForAll.scala | 1 + .../expressions/tilestats/IsNoDataTile.scala | 1 + .../expressions/tilestats/NoDataCells.scala | 1 + .../rasterframes/expressions/tilestats/Sum.scala | 1 + .../expressions/tilestats/TileHistogram.scala | 1 + .../expressions/tilestats/TileMax.scala | 1 + .../expressions/tilestats/TileMean.scala | 1 + .../expressions/tilestats/TileMin.scala | 1 + .../expressions/tilestats/TileStats.scala | 1 + .../transformers/CreateProjectedRaster.scala | 2 +- .../expressions/transformers/DebugRender.scala | 8 ++++++-- .../transformers/ExtentToGeometry.scala | 2 +- .../expressions/transformers/ExtractBits.scala | 2 +- .../transformers/GeometryToExtent.scala | 2 +- .../expressions/transformers/InterpretAs.scala | 2 +- .../expressions/transformers/RGBComposite.scala | 2 +- .../transformers/RasterRefToTile.scala | 2 +- .../expressions/transformers/RenderPNG.scala | 6 ++++-- .../transformers/ReprojectGeometry.scala | 2 +- .../expressions/transformers/Rescale.scala | 2 +- .../expressions/transformers/SetCellType.scala | 2 +- .../expressions/transformers/SetNoDataValue.scala | 2 +- .../expressions/transformers/Standardize.scala | 2 +- .../transformers/TileToArrayDouble.scala | 1 + .../expressions/transformers/TileToArrayInt.scala | 2 ++ .../transformers/URIToRasterSource.scala | 2 +- .../expressions/transformers/XZ2Indexer.scala | 2 +- .../expressions/transformers/Z2Indexer.scala | 2 +- 79 files changed, 136 insertions(+), 69 deletions(-) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/BinaryRasterFunction.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/BinaryRasterFunction.scala index edf61ea2b..425e6c4e7 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/BinaryRasterFunction.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/BinaryRasterFunction.scala @@ -25,14 +25,13 @@ import com.typesafe.scalalogging.Logger import geotrellis.raster.Tile import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression} +import org.apache.spark.sql.catalyst.expressions.BinaryExpression import org.apache.spark.sql.types.DataType import org.locationtech.rasterframes.expressions.DynamicExtractors._ import org.slf4j.LoggerFactory /** Operation combining two tiles or a tile and a scalar into a new tile. */ -trait BinaryRasterFunction extends BinaryExpression with RasterResult { self: HasBinaryExpressionCopy => - override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) +trait BinaryRasterFunction extends BinaryExpression with RasterResult { @transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/OnCellGridExpression.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/OnCellGridExpression.scala index c10df97c1..7d20049d4 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/OnCellGridExpression.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/OnCellGridExpression.scala @@ -26,7 +26,7 @@ import geotrellis.raster.CellGrid import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.UnaryExpression /** * Implements boilerplate for subtype expressions processing TileUDT, RasterSourceUDT, and RasterRefs @@ -34,9 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression} * * @since 11/4/18 */ -trait OnCellGridExpression extends UnaryExpression { self: HasUnaryExpressionCopy => - override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) - +trait OnCellGridExpression extends UnaryExpression { private lazy val fromRow: InternalRow => CellGrid[Int] = { if (child.resolved) gridExtractor(child.dataType) else throw new IllegalStateException(s"Child expression unbound: ${child}") diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/OnTileContextExpression.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/OnTileContextExpression.scala index 1c02b1a95..3913ef1cb 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/OnTileContextExpression.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/OnTileContextExpression.scala @@ -25,7 +25,7 @@ import org.locationtech.rasterframes.expressions.DynamicExtractors._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.UnaryExpression import org.locationtech.rasterframes.model.TileContext /** @@ -34,9 +34,7 @@ import org.locationtech.rasterframes.model.TileContext * * @since 11/3/18 */ -trait OnTileContextExpression extends UnaryExpression { self: HasUnaryExpressionCopy => - override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) - +trait OnTileContextExpression extends UnaryExpression { override def checkInputDataTypes(): TypeCheckResult = { if (!projectedRasterLikeExtractor.isDefinedAt(child.dataType)) { TypeCheckFailure(s"Input type '${child.dataType}' does not conform to `ProjectedRasterLike`.") diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/SpatialRelation.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/SpatialRelation.scala index a2589fd5b..3b84797fe 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/SpatialRelation.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/SpatialRelation.scala @@ -39,10 +39,7 @@ import org.locationtech.geomesa.spark.jts.udf.SpatialRelationFunctions._ * * @since 12/28/17 */ -abstract class SpatialRelation extends BinaryExpression with CodegenFallback { this: HasBinaryExpressionCopy => - - override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = - copy(left = newLeft, right = newRight) +abstract class SpatialRelation extends BinaryExpression with CodegenFallback { def extractGeometry(expr: Expression, input: Any): Geometry = { input match { @@ -78,36 +75,42 @@ object SpatialRelation { override def nodeName: String = "intersects" val relation = ST_Intersects - override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = - copy(left = newLeft, right = newRight) + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } case class Contains(left: Expression, right: Expression) extends SpatialRelation { override def nodeName = "contains" val relation = ST_Contains + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } case class Covers(left: Expression, right: Expression) extends SpatialRelation { override def nodeName = "covers" val relation = ST_Covers + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } case class Crosses(left: Expression, right: Expression) extends SpatialRelation { override def nodeName = "crosses" val relation = ST_Crosses + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } case class Disjoint(left: Expression, right: Expression) extends SpatialRelation { override def nodeName = "disjoint" val relation = ST_Disjoint + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } case class Overlaps(left: Expression, right: Expression) extends SpatialRelation { override def nodeName = "overlaps" val relation = ST_Overlaps + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } case class Touches(left: Expression, right: Expression) extends SpatialRelation { override def nodeName = "touches" val relation = ST_Touches + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } case class Within(left: Expression, right: Expression) extends SpatialRelation { override def nodeName = "within" val relation = ST_Within + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } private val predicateMap = Map( diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterAggregate.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterAggregate.scala index 253b1cb0f..585de1530 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterAggregate.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterAggregate.scala @@ -33,7 +33,7 @@ import org.locationtech.rasterframes.encoders.syntax._ import scala.reflect.runtime.universe._ /** Mixin providing boilerplate for DeclarativeAggrates over tile-conforming columns. */ -trait UnaryRasterAggregate extends DeclarativeAggregate { self: HasUnaryExpressionCopy => +trait UnaryRasterAggregate extends DeclarativeAggregate { def child: Expression def nullable: Boolean = child.nullable @@ -42,8 +42,6 @@ trait UnaryRasterAggregate extends DeclarativeAggregate { self: HasUnaryExpressi protected def tileOpAsExpression[R: TypeTag](name: String, op: Tile => R): Expression => ScalaUDF = udfiexpr[R, Any](name, (dataType: DataType) => (a: Any) => if(a == null) null.asInstanceOf[R] else op(UnaryRasterAggregate.extractTileFromAny(dataType, a))) - - override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(newChildren(0)) } object UnaryRasterAggregate { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterFunction.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterFunction.scala index 70a8180c8..6eb4e7a69 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterFunction.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterFunction.scala @@ -25,13 +25,11 @@ import org.locationtech.rasterframes.expressions.DynamicExtractors._ import geotrellis.raster.Tile import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.UnaryExpression import org.locationtech.rasterframes.model.TileContext /** Boilerplate for expressions operating on a single Tile-like . */ -trait UnaryRasterFunction extends UnaryExpression { self: HasUnaryExpressionCopy => - override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) - +trait UnaryRasterFunction extends UnaryExpression { override def checkInputDataTypes(): TypeCheckResult = { if (!tileExtractor.isDefinedAt(child.dataType)) { TypeCheckFailure(s"Input type '${child.dataType}' does not conform to a raster type.") diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterOp.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterOp.scala index da9232600..dcb4871c8 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterOp.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterOp.scala @@ -23,13 +23,12 @@ package org.locationtech.rasterframes.expressions import com.typesafe.scalalogging.Logger import geotrellis.raster.Tile -import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.DataType import org.locationtech.rasterframes.model.TileContext import org.slf4j.LoggerFactory /** Operation on a tile returning a tile. */ -trait UnaryRasterOp extends UnaryRasterFunction with RasterResult { this: HasUnaryExpressionCopy => +trait UnaryRasterOp extends UnaryRasterFunction with RasterResult { @transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName)) def dataType: DataType = child.dataType @@ -38,7 +37,5 @@ trait UnaryRasterOp extends UnaryRasterFunction with RasterResult { this: HasUna toInternalRow(op(tile), ctx) protected def op(child: Tile): Tile - - override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/ExtractTile.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/ExtractTile.scala index ea615843a..c11daac57 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/ExtractTile.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/ExtractTile.scala @@ -43,6 +43,8 @@ case class ExtractTile(child: Expression) extends UnaryRasterFunction with Codeg case prt: ProjectedRasterTile => tileUDT.serialize(prt.tile) case tile: Tile => tileSer(tile) } + + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object ExtractTile { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetCRS.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetCRS.scala index 1f5484b73..d5633741b 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetCRS.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetCRS.scala @@ -97,7 +97,7 @@ case class GetCRS(child: Expression) extends UnaryExpression with CodegenFallbac } } - override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object GetCRS { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetCellType.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetCellType.scala index 89180d757..114533cee 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetCellType.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetCellType.scala @@ -56,7 +56,7 @@ case class GetCellType(child: Expression) extends OnCellGridExpression with Code /** Implemented by subtypes to process incoming ProjectedRasterLike entity. */ def eval(cg: CellGrid[Int]): Any = resultConverter(cg.cellType) - override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object GetCellType { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetDimensions.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetDimensions.scala index 7539d6caa..4ec583cc4 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetDimensions.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetDimensions.scala @@ -46,6 +46,8 @@ case class GetDimensions(child: Expression) extends OnCellGridExpression with Co def dataType = dimensionsEncoder[Int].schema def eval(grid: CellGrid[Int]): Any = Dimensions[Int](grid.cols, grid.rows).toInternalRow + + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object GetDimensions { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetEnvelope.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetEnvelope.scala index 67b32ce49..8ff2443e9 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetEnvelope.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetEnvelope.scala @@ -58,7 +58,7 @@ case class GetEnvelope(child: Expression) extends UnaryExpression with CodegenFa def dataType: DataType = envelopeEncoder.schema - override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object GetEnvelope { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetExtent.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetExtent.scala index 5dfb6781a..1920cd47d 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetExtent.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetExtent.scala @@ -49,6 +49,8 @@ case class GetExtent(child: Expression) extends OnTileContextExpression with Cod override def nodeName: String = "rf_extent" def eval(ctx: TileContext): InternalRow = ctx.extent.toInternalRow + + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object GetExtent { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetGeometry.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetGeometry.scala index de8470180..722624bbb 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetGeometry.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetGeometry.scala @@ -50,6 +50,7 @@ case class GetGeometry(child: Expression) extends OnTileContextExpression with C def eval(ctx: TileContext): InternalRow = JTSTypes.GeometryTypeInstance.serialize(ctx.extent.toPolygon()) + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object GetGeometry { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetTileContext.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetTileContext.scala index eb1fb9675..a41dc697d 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetTileContext.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetTileContext.scala @@ -38,6 +38,8 @@ case class GetTileContext(child: Expression) extends UnaryRasterFunction with Co protected def eval(tile: Tile, ctx: Option[TileContext]): Any = ctx.map(SerializersCache.serializer[TileContext].apply).orNull + + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object GetTileContext { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/RealizeTile.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/RealizeTile.scala index 9e37c62d6..f9381f6e0 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/RealizeTile.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/RealizeTile.scala @@ -57,7 +57,7 @@ case class RealizeTile(child: Expression) extends UnaryExpression with CodegenFa tileSer(tile.toArrayTile()) } - override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object RealizeTile { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/CellCountAggregate.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/CellCountAggregate.scala index 1571a29ac..b36ae27e6 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/CellCountAggregate.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/CellCountAggregate.scala @@ -22,7 +22,7 @@ package org.locationtech.rasterframes.expressions.aggregates import org.locationtech.rasterframes.encoders.SparkBasicEncoders._ -import org.locationtech.rasterframes.expressions.{HasUnaryExpressionCopy, UnaryRasterAggregate} +import org.locationtech.rasterframes.expressions.UnaryRasterAggregate import org.locationtech.rasterframes.expressions.tilestats.{DataCells, NoDataCells} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ @@ -35,7 +35,7 @@ import org.apache.spark.sql.{Column, TypedColumn} * @since 10/5/17 * @param isData true if count should be of non-NoData cells, false if count should be of NoData cells. */ -abstract class CellCountAggregate(isData: Boolean) extends UnaryRasterAggregate { self: HasUnaryExpressionCopy => +abstract class CellCountAggregate(isData: Boolean) extends UnaryRasterAggregate { private lazy val count = AttributeReference("count", LongType, false, Metadata.empty)() override lazy val aggBufferAttributes = Seq(count) @@ -68,6 +68,8 @@ object CellCountAggregate { ) case class DataCells(child: Expression) extends CellCountAggregate(true) { override def nodeName: String = "rf_agg_data_cells" + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(newChildren.head) } object DataCells { @@ -86,6 +88,8 @@ object CellCountAggregate { ) case class NoDataCells(child: Expression) extends CellCountAggregate(false) { override def nodeName: String = "rf_agg_no_data_cells" + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(newChildren.head) } object NoDataCells { def apply(tile: Column): TypedColumn[Any, Long] = diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/CellMeanAggregate.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/CellMeanAggregate.scala index 38b2e453f..d39b80e6e 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/CellMeanAggregate.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/CellMeanAggregate.scala @@ -69,6 +69,8 @@ case class CellMeanAggregate(child: Expression) extends UnaryRasterAggregate { val evaluateExpression = sum / new Cast(count, DoubleType) def dataType: DataType = DoubleType + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(newChildren.head) } object CellMeanAggregate { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/LocalMeanAggregate.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/LocalMeanAggregate.scala index c749b1b8f..ccda2b033 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/LocalMeanAggregate.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/LocalMeanAggregate.scala @@ -69,6 +69,8 @@ case class LocalMeanAggregate(child: Expression) extends UnaryRasterAggregate { BiasedAdd(sum.left, sum.right) ) lazy val evaluateExpression: Expression = DivideTiles(sum, count) + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(newChildren.head) } object LocalMeanAggregate { def apply(tile: Column): TypedColumn[Any, Tile] = diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMax.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMax.scala index 5ca4f386f..c2f829d18 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMax.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMax.scala @@ -44,6 +44,8 @@ case class FocalMax(first: Expression, second: Expression, third: Expression) ex case bt: BufferTile => bt.focalMax(neighborhood, target = target) case _ => t.focalMax(neighborhood, target = target) } + + def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) } object FocalMax { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMean.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMean.scala index f612d118a..2b64d2dda 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMean.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMean.scala @@ -45,6 +45,7 @@ case class FocalMean(first: Expression, second: Expression, third: Expression) e case _ => t.focalMean(neighborhood, target = target) } + def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) } object FocalMean { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMedian.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMedian.scala index 7830bae41..3c213d0df 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMedian.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMedian.scala @@ -44,6 +44,7 @@ case class FocalMedian(first: Expression, second: Expression, third: Expression) case bt: BufferTile => bt.focalMedian(neighborhood, target = target) case _ => t.focalMedian(neighborhood, target = target) } + def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) } object FocalMedian { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMin.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMin.scala index 0baead593..01fe11e8a 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMin.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMin.scala @@ -44,6 +44,8 @@ case class FocalMin(first: Expression, second: Expression, third: Expression) ex case bt: BufferTile => bt.focalMin(neighborhood, target = target) case _ => t.focalMin(neighborhood, target = target) } + + def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) } object FocalMin { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMode.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMode.scala index 4e4d08c67..daf493bb7 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMode.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMode.scala @@ -44,6 +44,7 @@ case class FocalMode(first: Expression, second: Expression, third: Expression) e case bt: BufferTile => bt.focalMode(neighborhood, target = target) case _ => t.focalMode(neighborhood, target = target) } + def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) } object FocalMode { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMoransI.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMoransI.scala index 7ab8f1d97..d26bb6996 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMoransI.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalMoransI.scala @@ -44,6 +44,7 @@ case class FocalMoransI(first: Expression, second: Expression, third: Expression case bt: BufferTile => bt.tileMoransI(neighborhood, target = target) case _ => t.tileMoransI(neighborhood, target = target) } + def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) } object FocalMoransI { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalNeighborhoodOp.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalNeighborhoodOp.scala index 2303c7b7c..4fb409cc3 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalNeighborhoodOp.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalNeighborhoodOp.scala @@ -29,13 +29,10 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, TernaryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types.DataType import org.locationtech.rasterframes.expressions.DynamicExtractors.{neighborhoodExtractor, targetCellExtractor, tileExtractor} -import org.locationtech.rasterframes.expressions.{HasTernaryExpressionCopy, RasterResult, row} +import org.locationtech.rasterframes.expressions.{RasterResult, row} import org.slf4j.LoggerFactory -trait FocalNeighborhoodOp extends TernaryExpression with RasterResult with CodegenFallback {self: HasTernaryExpressionCopy => - override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = - copy(newFirst, newSecond, newThird) - +trait FocalNeighborhoodOp extends TernaryExpression with RasterResult with CodegenFallback { @transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName)) // Tile diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalStdDev.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalStdDev.scala index 3887d079c..81f133483 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalStdDev.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/focalops/FocalStdDev.scala @@ -44,6 +44,7 @@ case class FocalStdDev(first: Expression, second: Expression, third: Expression) case bt: BufferTile => bt.focalStandardDeviation(neighborhood, target = target) case _ => t.focalStandardDeviation(neighborhood, target = target) } + def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) } object FocalStdDev { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Abs.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Abs.scala index ed6cdd950..007886caa 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Abs.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Abs.scala @@ -42,6 +42,7 @@ case class Abs(child: Expression) extends UnaryRasterOp with NullToValue with Co def na: Any = null protected def op(t: Tile): Tile = t.localAbs() + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object Abs { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Add.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Add.scala index 7f231797b..016156167 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Add.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Add.scala @@ -43,8 +43,7 @@ import org.locationtech.rasterframes.expressions.DynamicExtractors > SELECT _FUNC_(tile1, tile2); ...""" ) -case class Add(left: Expression, right: Expression) extends BinaryRasterFunction - with CodegenFallback { +case class Add(left: Expression, right: Expression) extends BinaryRasterFunction with CodegenFallback { override val nodeName: String = "rf_local_add" protected def op(left: Tile, right: Tile): Tile = left.localAdd(right) protected def op(left: Tile, right: Double): Tile = left.localAdd(right) @@ -62,6 +61,8 @@ case class Add(left: Expression, right: Expression) extends BinaryRasterFunction else nullSafeEval(l, r) } } + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object Add { def apply(left: Column, right: Column): Column = new Column(Add(left.expr, right.expr)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/BiasedAdd.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/BiasedAdd.scala index 300103154..e35ee8382 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/BiasedAdd.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/BiasedAdd.scala @@ -45,8 +45,7 @@ import org.locationtech.rasterframes.util.DataBiasedOp > SELECT _FUNC_(tile1, tile2); ...""" ) -case class BiasedAdd(left: Expression, right: Expression) extends BinaryRasterFunction - with CodegenFallback { +case class BiasedAdd(left: Expression, right: Expression) extends BinaryRasterFunction with CodegenFallback { override val nodeName: String = "rf_local_biased_add" protected def op(left: Tile, right: Tile): Tile = DataBiasedOp.BiasedAdd(left, right) protected def op(left: Tile, right: Double): Tile = DataBiasedOp.BiasedAdd(left, right) @@ -64,6 +63,8 @@ case class BiasedAdd(left: Expression, right: Expression) extends BinaryRasterFu else nullSafeEval(l, r) } } + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object BiasedAdd { def apply(left: Column, right: Column): Column = new Column(BiasedAdd(left.expr, right.expr)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Defined.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Defined.scala index 035a5ad84..280fd41f2 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Defined.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Defined.scala @@ -37,11 +37,12 @@ import org.locationtech.rasterframes.expressions.{NullToValue, UnaryRasterOp} > SELECT _FUNC_(tile); ...""" ) -case class Defined(child: Expression) extends UnaryRasterOp - with NullToValue with CodegenFallback { +case class Defined(child: Expression) extends UnaryRasterOp with NullToValue with CodegenFallback { override def nodeName: String = "rf_local_data" def na: Any = null protected def op(child: Tile): Tile = child.localDefined() + + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object Defined{ def apply(tile: Column): Column = new Column(Defined(tile.expr)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Divide.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Divide.scala index ce0d0be1c..0f81cc788 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Divide.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Divide.scala @@ -46,6 +46,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryRasterFunct protected def op(left: Tile, right: Tile): Tile = left.localDivide(right) protected def op(left: Tile, right: Double): Tile = left.localDivide(right) protected def op(left: Tile, right: Int): Tile = left.localDivide(right) + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object Divide { def apply(left: Column, right: Column): Column = new Column(Divide(left.expr, right.expr)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Equal.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Equal.scala index 29f622c78..36692b2d9 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Equal.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Equal.scala @@ -45,6 +45,7 @@ case class Equal(left: Expression, right: Expression) extends BinaryRasterFuncti protected def op(left: Tile, right: Double): Tile = left.localEqual(right) protected def op(left: Tile, right: Int): Tile = left.localEqual(right) + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object Equal { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Exp.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Exp.scala index 21f57d1f6..89499b234 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Exp.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Exp.scala @@ -44,6 +44,7 @@ case class Exp(child: Expression) extends UnaryRasterOp with CodegenFallback { protected def op(tile: Tile): Tile = fpTile(tile).localPowValue(math.E) override def dataType: DataType = child.dataType + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object Exp { def apply(tile: Column): Column = new Column(Exp(tile.expr)) @@ -65,6 +66,7 @@ case class Exp10(child: Expression) extends UnaryRasterOp with CodegenFallback { override protected def op(tile: Tile): Tile = fpTile(tile).localPowValue(10.0) override def dataType: DataType = child.dataType + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object Exp10 { def apply(tile: Column): Column = new Column(Exp10(tile.expr)) @@ -86,6 +88,7 @@ case class Exp2(child: Expression) extends UnaryRasterOp with CodegenFallback { protected def op(tile: Tile): Tile = fpTile(tile).localPowValue(2.0) override def dataType: DataType = child.dataType + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object Exp2 { def apply(tile: Column): Column = new Column(Exp2(tile.expr)) @@ -107,6 +110,7 @@ case class ExpM1(child: Expression) extends UnaryRasterOp with CodegenFallback { protected def op(tile: Tile): Tile = fpTile(tile).localPowValue(math.E).localSubtract(1.0) override def dataType: DataType = child.dataType + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object ExpM1 { def apply(tile: Column): Column = new Column(ExpM1(tile.expr)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Greater.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Greater.scala index e820f94f5..688326cd6 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Greater.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Greater.scala @@ -43,6 +43,8 @@ case class Greater(left: Expression, right: Expression) extends BinaryRasterFunc protected def op(left: Tile, right: Tile): Tile = left.localGreater(right) protected def op(left: Tile, right: Double): Tile = left.localGreater(right) protected def op(left: Tile, right: Int): Tile = left.localGreater(right) + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object Greater { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/GreaterEqual.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/GreaterEqual.scala index dd33e3415..cce792479 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/GreaterEqual.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/GreaterEqual.scala @@ -44,6 +44,8 @@ case class GreaterEqual(left: Expression, right: Expression) extends BinaryRaste protected def op(left: Tile, right: Tile): Tile = left.localGreaterOrEqual(right) protected def op(left: Tile, right: Double): Tile = left.localGreaterOrEqual(right) protected def op(left: Tile, right: Int): Tile = left.localGreaterOrEqual(right) + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object GreaterEqual { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Identity.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Identity.scala index 418ddf780..9c441e636 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Identity.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Identity.scala @@ -41,6 +41,7 @@ case class Identity(child: Expression) extends UnaryRasterOp with NullToValue wi override def nodeName: String = "rf_identity" def na: Any = null protected def op(t: Tile): Tile = t + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object Identity { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Less.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Less.scala index 8f5ac719f..d570a7901 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Less.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Less.scala @@ -43,6 +43,8 @@ case class Less(left: Expression, right: Expression) extends BinaryRasterFunctio protected def op(left: Tile, right: Tile): Tile = left.localLess(right) protected def op(left: Tile, right: Double): Tile = left.localLess(right) protected def op(left: Tile, right: Int): Tile = left.localLess(right) + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object Less { def apply(left: Column, right: Column): Column = new Column(Less(left.expr, right.expr)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/LessEqual.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/LessEqual.scala index ae51ab2f1..7ca5f51a0 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/LessEqual.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/LessEqual.scala @@ -44,6 +44,8 @@ case class LessEqual(left: Expression, right: Expression) extends BinaryRasterFu protected def op(left: Tile, right: Tile): Tile = left.localLessOrEqual(right) protected def op(left: Tile, right: Double): Tile = left.localLessOrEqual(right) protected def op(left: Tile, right: Int): Tile = left.localLessOrEqual(right) + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object LessEqual { def apply(left: Column, right: Column): Column = new Column(LessEqual(left.expr, right.expr)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Log.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Log.scala index 2ebd84412..53b443a1f 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Log.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Log.scala @@ -44,6 +44,7 @@ case class Log(child: Expression) extends UnaryRasterOp with CodegenFallback { protected def op(tile: Tile): Tile = fpTile(tile).localLog() override def dataType: DataType = child.dataType + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object Log { def apply(tile: Column): Column = new Column(Log(tile.expr)) @@ -65,6 +66,7 @@ case class Log10(child: Expression) extends UnaryRasterOp with CodegenFallback { protected def op(tile: Tile): Tile = fpTile(tile).localLog10() override def dataType: DataType = child.dataType + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object Log10 { def apply(tile: Column): Column = new Column(Log10(tile.expr)) @@ -86,6 +88,7 @@ case class Log2(child: Expression) extends UnaryRasterOp with CodegenFallback { protected def op(tile: Tile): Tile = fpTile(tile).localLog() / math.log(2.0) override def dataType: DataType = child.dataType + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object Log2 { def apply(tile: Column): Column = new Column(Log2(tile.expr)) @@ -107,6 +110,7 @@ case class Log1p(child: Expression) extends UnaryRasterOp with CodegenFallback { protected def op(tile: Tile): Tile = fpTile(tile).localAdd(1.0).localLog() override def dataType: DataType = child.dataType + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object Log1p { def apply(tile: Column): Column = new Column(Log1p(tile.expr)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Max.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Max.scala index 01019543f..d075b65d4 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Max.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Max.scala @@ -47,6 +47,8 @@ case class Max(left: Expression, right:Expression) extends BinaryRasterFunction protected def op(left: Tile, right: Tile): Tile = left.localMax(right) protected def op(left: Tile, right: Double): Tile = left.localMax(right) protected def op(left: Tile, right: Int): Tile = left.localMax(right) + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object Max { def apply(left: Column, right: Column): Column = new Column(Max(left.expr, right.expr)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Min.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Min.scala index 171812929..61bf7b180 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Min.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Min.scala @@ -47,6 +47,8 @@ case class Min(left: Expression, right:Expression) extends BinaryRasterFunction protected def op(left: Tile, right: Tile): Tile = left.localMin(right) protected def op(left: Tile, right: Double): Tile = left.localMin(right) protected def op(left: Tile, right: Int): Tile = left.localMin(right) + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object Min { def apply(left: Column, right: Column): Column = new Column(Min(left.expr, right.expr)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Multiply.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Multiply.scala index 7bf3367d4..bc822c16c 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Multiply.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Multiply.scala @@ -46,6 +46,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryRasterFun protected def op(left: Tile, right: Tile): Tile = left.localMultiply(right) protected def op(left: Tile, right: Double): Tile = left.localMultiply(right) protected def op(left: Tile, right: Int): Tile = left.localMultiply(right) + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object Multiply { def apply(left: Column, right: Column): Column = new Column(Multiply(left.expr, right.expr)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Round.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Round.scala index d4238c27f..acadc93f6 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Round.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Round.scala @@ -41,6 +41,7 @@ case class Round(child: Expression) extends UnaryRasterOp with NullToValue with override def nodeName: String = "rf_round" def na: Any = null protected def op(child: Tile): Tile = child.localRound() + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object Round{ def apply(tile: Column): Column = new Column(Round(tile.expr)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Sqrt.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Sqrt.scala index ad3ed376d..d98f0bb8b 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Sqrt.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Sqrt.scala @@ -44,6 +44,7 @@ case class Sqrt(child: Expression) extends UnaryRasterOp with CodegenFallback { override val nodeName: String = "rf_sqrt" protected def op(tile: Tile): Tile = fpTile(tile).localPow(0.5) override def dataType: DataType = child.dataType + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object Sqrt { def apply(tile: Column): Column = new Column(Sqrt(tile.expr)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Subtract.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Subtract.scala index 708e7e207..bfcd403fc 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Subtract.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Subtract.scala @@ -46,6 +46,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryRasterFun protected def op(left: Tile, right: Tile): Tile = left.localSubtract(right) protected def op(left: Tile, right: Double): Tile = left.localSubtract(right) protected def op(left: Tile, right: Int): Tile = left.localSubtract(right) + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object Subtract { def apply(left: Column, right: Column): Column = new Column(Subtract(left.expr, right.expr)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Undefined.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Undefined.scala index bd533f4b7..863fadb94 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Undefined.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Undefined.scala @@ -41,6 +41,7 @@ case class Undefined(child: Expression) extends UnaryRasterOp with NullToValue w override def nodeName: String = "rf_local_no_data" def na: Any = null protected def op(child: Tile): Tile = child.localUndefined() + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object Undefined { def apply(tile: Column): Column = new Column(Undefined(tile.expr)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Unequal.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Unequal.scala index 9bab9b86b..72c526ce9 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Unequal.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Unequal.scala @@ -44,6 +44,8 @@ case class Unequal(left: Expression, right: Expression) extends BinaryRasterFunc protected def op(left: Tile, right: Tile): Tile = left.localUnequal(right) protected def op(left: Tile, right: Double): Tile = left.localUnequal(right) protected def op(left: Tile, right: Int): Tile = left.localUnequal(right) + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object Unequal { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala index 7237c720c..8fea88ee2 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala @@ -52,10 +52,6 @@ import scala.language.implicitConversions * @since 10/10/17 */ package object expressions { - type HasTernaryExpressionCopy = { def copy(first: Expression, second: Expression, third: Expression): Expression } - type HasBinaryExpressionCopy = { def copy(left: Expression, right: Expression): Expression } - type HasUnaryExpressionCopy = { def copy(child: Expression): Expression } - private[expressions] def row(input: Any) = input.asInstanceOf[InternalRow] /** Convert the tile to a floating point type as needed for scalar operations. */ @inline diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/DataCells.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/DataCells.scala index 52dc8c1ed..1694ffb75 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/DataCells.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/DataCells.scala @@ -45,6 +45,8 @@ case class DataCells(child: Expression) extends UnaryRasterFunction with Codegen def dataType: DataType = LongType protected def eval(tile: Tile, ctx: Option[TileContext]): Any = DataCells.op(tile) def na: Any = 0L + + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object DataCells { def apply(tile: Column): TypedColumn[Any, Long] = new Column(DataCells(tile.expr)).as[Long] diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/Exists.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/Exists.scala index ebb2156d7..4941d6500 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/Exists.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/Exists.scala @@ -28,7 +28,7 @@ case class Exists(child: Expression) extends UnaryRasterFunction with CodegenFal override def nodeName: String = "exists" def dataType: DataType = BooleanType protected def eval(tile: Tile, ctx: Option[TileContext]): Any = Exists.op(tile) - + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object Exists { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/ForAll.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/ForAll.scala index f553de047..d60de56a7 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/ForAll.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/ForAll.scala @@ -28,6 +28,7 @@ case class ForAll(child: Expression) extends UnaryRasterFunction with CodegenFal override def nodeName: String = "for_all" def dataType: DataType = BooleanType protected def eval(tile: Tile, ctx: Option[TileContext]): Any = ForAll.op(tile) + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object ForAll { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/IsNoDataTile.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/IsNoDataTile.scala index e03b96194..4e5f25c51 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/IsNoDataTile.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/IsNoDataTile.scala @@ -46,6 +46,7 @@ case class IsNoDataTile(child: Expression) extends UnaryRasterFunction def na: Any = true def dataType: DataType = BooleanType protected def eval(tile: Tile, ctx: Option[TileContext]): Any = tile.isNoDataTile + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object IsNoDataTile { def apply(tile: Column): TypedColumn[Any, Boolean] = new Column(IsNoDataTile(tile.expr)).as[Boolean] diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/NoDataCells.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/NoDataCells.scala index 556abd715..8077544e3 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/NoDataCells.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/NoDataCells.scala @@ -45,6 +45,7 @@ case class NoDataCells(child: Expression) extends UnaryRasterFunction with Codeg def dataType: DataType = LongType protected def eval(tile: Tile, ctx: Option[TileContext]): Any = NoDataCells.op(tile) def na: Any = 0L + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object NoDataCells { def apply(tile: Column): TypedColumn[Any, Long] = diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/Sum.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/Sum.scala index 9e3ff1f8c..4576c0117 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/Sum.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/Sum.scala @@ -44,6 +44,7 @@ case class Sum(child: Expression) extends UnaryRasterFunction with CodegenFallba override def nodeName: String = "rf_tile_sum" def dataType: DataType = DoubleType protected def eval(tile: Tile, ctx: Option[TileContext]): Any = Sum.op(tile) + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object Sum { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/TileHistogram.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/TileHistogram.scala index a4a5fffa3..60cc6a047 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/TileHistogram.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/TileHistogram.scala @@ -46,6 +46,7 @@ case class TileHistogram(child: Expression) extends UnaryRasterFunction with Cod protected def eval(tile: Tile, ctx: Option[TileContext]): Any = TileHistogram.converter(TileHistogram.op(tile)) def dataType: DataType = CellHistogram.schema + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object TileHistogram { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/TileMax.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/TileMax.scala index cbbe1a52c..ce6ee2e99 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/TileMax.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/TileMax.scala @@ -45,6 +45,7 @@ case class TileMax(child: Expression) extends UnaryRasterFunction with NullToVal protected def eval(tile: Tile, ctx: Option[TileContext]): Any = TileMax.op(tile) def dataType: DataType = DoubleType def na: Any = Double.MinValue + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object TileMax { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/TileMean.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/TileMean.scala index 2f0bdedb5..52227171d 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/TileMean.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/TileMean.scala @@ -45,6 +45,7 @@ case class TileMean(child: Expression) extends UnaryRasterFunction with NullToVa protected def eval(tile: Tile, ctx: Option[TileContext]): Any = TileMean.op(tile) def dataType: DataType = DoubleType def na: Any = Double.NaN + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object TileMean { def apply(tile: Column): TypedColumn[Any, Double] = new Column(TileMean(tile.expr)).as[Double] diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/TileMin.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/TileMin.scala index c3d26fb4a..f68e6f0a6 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/TileMin.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/TileMin.scala @@ -45,6 +45,7 @@ case class TileMin(child: Expression) extends UnaryRasterFunction with NullToVal protected def eval(tile: Tile, ctx: Option[TileContext]): Any = TileMin.op(tile) def dataType: DataType = DoubleType def na: Any = Double.MaxValue + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object TileMin { def apply(tile: Column): TypedColumn[Any, Double] = new Column(TileMin(tile.expr)).as[Double] diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/TileStats.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/TileStats.scala index ebf6bf67c..2eb8b4d3f 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/TileStats.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/TileStats.scala @@ -46,6 +46,7 @@ case class TileStats(child: Expression) extends UnaryRasterFunction with Codegen protected def eval(tile: Tile, ctx: Option[TileContext]): Any = TileStats.converter(TileStats.op(tile).orNull) def dataType: DataType = CellStatistics.schema + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object TileStats { def apply(tile: Column): TypedColumn[Any, CellStatistics] = new Column(TileStats(tile.expr)).as[CellStatistics] diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/CreateProjectedRaster.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/CreateProjectedRaster.scala index 99c7124e5..3a98ceab9 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/CreateProjectedRaster.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/CreateProjectedRaster.scala @@ -72,7 +72,7 @@ case class CreateProjectedRaster(tile: Expression, extent: Expression, crs: Expr toInternalRow(prt) } - override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) + def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) } object CreateProjectedRaster { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/DebugRender.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/DebugRender.scala index c310dc80c..53c211393 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/DebugRender.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/DebugRender.scala @@ -29,11 +29,11 @@ import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.sql.{Column, TypedColumn} import org.apache.spark.unsafe.types.UTF8String import org.locationtech.rasterframes.encoders.SparkBasicEncoders._ -import org.locationtech.rasterframes.expressions.{HasUnaryExpressionCopy, UnaryRasterFunction} +import org.locationtech.rasterframes.expressions.UnaryRasterFunction import org.locationtech.rasterframes.model.TileContext import spire.syntax.cfor.cfor -abstract class DebugRender(asciiArt: Boolean) extends UnaryRasterFunction with CodegenFallback with Serializable { self: HasUnaryExpressionCopy => +abstract class DebugRender(asciiArt: Boolean) extends UnaryRasterFunction with CodegenFallback with Serializable { import org.locationtech.rasterframes.expressions.transformers.DebugRender.TileAsMatrix def dataType: DataType = StringType @@ -55,6 +55,8 @@ object DebugRender { ) case class RenderAscii(child: Expression) extends DebugRender(true) { override def nodeName: String = "rf_render_ascii" + + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object RenderAscii { def apply(tile: Column): TypedColumn[Any, String] = new Column(RenderAscii(tile.expr)).as[String] @@ -68,6 +70,8 @@ object DebugRender { ) case class RenderMatrix(child: Expression) extends DebugRender(false) { override def nodeName: String = "rf_render_matrix" + + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object RenderMatrix { def apply(tile: Column): TypedColumn[Any, String] = new Column(RenderMatrix(tile.expr)).as[String] diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtentToGeometry.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtentToGeometry.scala index 8b922de4d..09586279b 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtentToGeometry.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtentToGeometry.scala @@ -60,7 +60,7 @@ case class ExtentToGeometry(child: Expression) extends UnaryExpression with Code JTSTypes.GeometryTypeInstance.serialize(geom) } - override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object ExtentToGeometry extends SpatialEncoders { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala index 4412c2a9f..b077df1ae 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala @@ -69,7 +69,7 @@ case class ExtractBits(first: Expression, second: Expression, third: Expression) protected def op(tile: Tile, startBit: Int, numBits: Int): Tile = ExtractBits(tile, startBit, numBits) - override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) + def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) } object ExtractBits{ diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/GeometryToExtent.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/GeometryToExtent.scala index 43e96311c..97ffdda13 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/GeometryToExtent.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/GeometryToExtent.scala @@ -56,7 +56,7 @@ case class GeometryToExtent(child: Expression) extends UnaryExpression with Code Extent(geom.getEnvelopeInternal).toInternalRow } - override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object GeometryToExtent { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InterpretAs.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InterpretAs.scala index 91fb9ab81..b5eeb29c6 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InterpretAs.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InterpretAs.scala @@ -82,7 +82,7 @@ case class InterpretAs(tile: Expression, cellType: Expression) extends BinaryExp toInternalRow(result, ctx) } - override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) + def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object InterpretAs{ diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RGBComposite.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RGBComposite.scala index f33cc8ca0..cd6173cdd 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RGBComposite.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RGBComposite.scala @@ -88,7 +88,7 @@ case class RGBComposite(red: Expression, green: Expression, blue: Expression) ex toInternalRow(composite, ctx) } - override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) + def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) } object RGBComposite { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RasterRefToTile.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RasterRefToTile.scala index 261a3a6c5..e364f68ef 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RasterRefToTile.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RasterRefToTile.scala @@ -54,7 +54,7 @@ case class RasterRefToTile(child: Expression) extends UnaryExpression ProjectedRasterTile(ref.tile, ref.extent, ref.crs).toInternalRow } - override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object RasterRefToTile { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RenderPNG.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RenderPNG.scala index be539a4dd..8e1324b71 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RenderPNG.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RenderPNG.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescript import org.apache.spark.sql.types.{BinaryType, DataType} import org.apache.spark.sql.{Column, TypedColumn} import org.locationtech.rasterframes.encoders.SparkBasicEncoders._ -import org.locationtech.rasterframes.expressions.{HasUnaryExpressionCopy, UnaryRasterFunction} +import org.locationtech.rasterframes.expressions.UnaryRasterFunction import org.locationtech.rasterframes.model.TileContext /** @@ -36,7 +36,7 @@ import org.locationtech.rasterframes.model.TileContext * @param child tile column * @param ramp color ramp to use for non-composite tiles. */ -abstract class RenderPNG(child: Expression, ramp: Option[ColorRamp]) extends UnaryRasterFunction with CodegenFallback with Serializable { self: HasUnaryExpressionCopy => +abstract class RenderPNG(child: Expression, ramp: Option[ColorRamp]) extends UnaryRasterFunction with CodegenFallback with Serializable { def dataType: DataType = BinaryType protected def eval(tile: Tile, ctx: Option[TileContext]): Any = { val png = ramp.map(tile.renderPng).getOrElse(tile.renderPng()) @@ -54,6 +54,7 @@ object RenderPNG { ) case class RenderCompositePNG(child: Expression) extends RenderPNG(child, None) { override def nodeName: String = "rf_render_png" + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object RenderCompositePNG { @@ -70,6 +71,7 @@ object RenderPNG { case class RenderColorRampPNG(child: Expression, colors: ColorRamp) extends RenderPNG(child, Some(colors)) { override def nodeName: String = "rf_render_png" def copy(child: Expression): Expression = RenderColorRampPNG(child, colors: ColorRamp) + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object RenderColorRampPNG { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ReprojectGeometry.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ReprojectGeometry.scala index 036d9192d..94b3768ed 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ReprojectGeometry.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ReprojectGeometry.scala @@ -90,7 +90,7 @@ case class ReprojectGeometry(geometry: Expression, srcCRS: Expression, dstCRS: E } } - override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) + def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) } object ReprojectGeometry { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Rescale.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Rescale.scala index 7dabef32d..c241431be 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Rescale.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Rescale.scala @@ -79,7 +79,7 @@ case class Rescale(first: Expression, second: Expression, third: Expression) ext .normalize(min, max, 0.0, 1.0) } - override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) } diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/SetCellType.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/SetCellType.scala index ee311a593..f23671858 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/SetCellType.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/SetCellType.scala @@ -86,7 +86,7 @@ case class SetCellType(tile: Expression, cellType: Expression) extends BinaryExp toInternalRow(result, ctx) } - override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) + def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object SetCellType { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/SetNoDataValue.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/SetNoDataValue.scala index 52fdfc6cb..8d27c7b41 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/SetNoDataValue.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/SetNoDataValue.scala @@ -71,7 +71,7 @@ case class SetNoDataValue(left: Expression, right: Expression) extends BinaryExp toInternalRow(result, leftCtx) } - override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) + def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } object SetNoDataValue { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Standardize.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Standardize.scala index 3d69682f4..e2440726f 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Standardize.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Standardize.scala @@ -77,7 +77,7 @@ case class Standardize(first: Expression, second: Expression, third: Expression) .localSubtract(mean) .localDivide(stdDev) - override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird) } object Standardize { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/TileToArrayDouble.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/TileToArrayDouble.scala index a856b917b..3731fdcb8 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/TileToArrayDouble.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/TileToArrayDouble.scala @@ -42,6 +42,7 @@ case class TileToArrayDouble(child: Expression) extends UnaryRasterFunction with def dataType: DataType = DataTypes.createArrayType(DoubleType, false) protected def eval(tile: Tile, ctx: Option[TileContext]): Any = ArrayData.toArrayData(tile.toArrayDouble()) + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object TileToArrayDouble { def apply(tile: Column): TypedColumn[Any, Array[Double]] = diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/TileToArrayInt.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/TileToArrayInt.scala index e6bbbd4a7..ebee7f25e 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/TileToArrayInt.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/TileToArrayInt.scala @@ -42,6 +42,8 @@ case class TileToArrayInt(child: Expression) extends UnaryRasterFunction with Co def dataType: DataType = DataTypes.createArrayType(IntegerType, false) protected def eval(tile: Tile, ctx: Option[TileContext]): Any = ArrayData.toArrayData(tile.toArray()) + + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object TileToArrayInt { def apply(tile: Column): TypedColumn[Any, Array[Int]] = diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/URIToRasterSource.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/URIToRasterSource.scala index fcab58900..786908282 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/URIToRasterSource.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/URIToRasterSource.scala @@ -54,7 +54,7 @@ case class URIToRasterSource(override val child: Expression) extends UnaryExpres rasterSourceUDT.serialize(ref) } - override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) + def withNewChildInternal(newChild: Expression): Expression = copy(newChild) } object URIToRasterSource { diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/XZ2Indexer.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/XZ2Indexer.scala index 28a9a099c..649c9a55b 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/XZ2Indexer.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/XZ2Indexer.scala @@ -88,7 +88,7 @@ case class XZ2Indexer(left: Expression, right: Expression, indexResolution: Shor index } - override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = + def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Z2Indexer.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Z2Indexer.scala index 2b8844e44..b2c8a2f6f 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Z2Indexer.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Z2Indexer.scala @@ -83,7 +83,7 @@ case class Z2Indexer(left: Expression, right: Expression, indexResolution: Short indexer.index(pt.getX, pt.getY, lenient = true) } - override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = + def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = copy(newLeft, newRight) } From c5cf70c0abe4c32e3e0a21a533dc3b19b0886048 Mon Sep 17 00:00:00 2001 From: Eugene Cheipesh Date: Fri, 13 Jan 2023 18:45:41 -0500 Subject: [PATCH 34/34] Remove python build from CI It needs more work at another time --- .github/workflows/build-test.yml | 37 ++++++++++++++------------------ 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index e1105fb1c..97afa087b 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -4,35 +4,26 @@ on: pull_request: branches: ['**'] push: - branches: ['master', 'develop', 'release/*'] + branches: ['master', 'develop', 'release/*', 'spark-3.2'] tags: [v*] release: types: [published] jobs: build: - runs-on: ubuntu-20.04 - container: - image: s22s/debian-openjdk-conda-gdal:6790f8d + runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 with: fetch-depth: 0 - uses: coursier/cache-action@v6 - - uses: olafurpg/setup-scala@v13 + - name: Setup JDK + uses: actions/setup-java@v3 with: - java-version: adopt@1.11 - - - name: Set up Python 3.8 - uses: actions/setup-python@v2 - with: - python-version: 3.8 - - - name: Install Conda dependencies - run: | - # $CONDA_DIR is an environment variable pointing to the root of the miniconda directory - $CONDA_DIR/bin/conda install -c conda-forge --yes --file pyrasterframes/src/main/python/requirements-condaforge.txt + distribution: temurin + java-version: 8 + cache: sbt # Do just the compilation stage to minimize sbt memory footprint - name: Compile @@ -47,11 +38,15 @@ jobs: - name: Experimental tests run: sbt -batch experimental/test - - name: Create PyRasterFrames package - run: sbt -v -batch pyrasterframes/package - - - name: Python tests - run: sbt -batch pyrasterframes/test + ## TODO: Update python build to be PEP 517 compatible + # - name: Install Conda dependencies + # run: | + # # $CONDA_DIR is an environment variable pointing to the root of the miniconda directory + # $CONDA_DIR/bin/conda install -c conda-forge --yes --file pyrasterframes/src/main/python/requirements-condaforge.txt + # - name: Create PyRasterFrames package + # run: sbt -v -batch pyrasterframes/package + # - name: Python tests + # run: sbt -batch pyrasterframes/test - name: Collect artifacts if: ${{ failure() }}