Add rf_local_is_in function #400

Merged 9 commits on Nov 5, 2019
@@ -405,6 +405,9 @@ trait RasterFunctions {
/** Cellwise inequality comparison between a tile and a scalar. */
def rf_local_unequal[T: Numeric](tileCol: Column, value: T): Column = Unequal(tileCol, value)

/** Test if each cell value is in the provided array. */
def rf_local_is_in(tileCol: Column, arrayCol: Column): Column = IsIn(tileCol, arrayCol)

/** Return a tile with ones where the input is NoData, otherwise zero */
def rf_local_no_data(tileCol: Column): Column = Undefined(tileCol)

@@ -0,0 +1,88 @@
/*
* This software is licensed under the Apache 2 license, quoted below.
*
* Copyright 2019 Astraea, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* [http://www.apache.org/licenses/LICENSE-2.0]
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*
* SPDX-License-Identifier: Apache-2.0
*
*/

package org.locationtech.rasterframes.expressions.localops

import geotrellis.raster.Tile
import geotrellis.raster.mapalgebra.local.IfCell
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.types.{ArrayType, DataType}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionDescription}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.rf.TileUDT
import org.locationtech.rasterframes.encoders.CatalystSerializer._
import org.locationtech.rasterframes.expressions.DynamicExtractors._
import org.locationtech.rasterframes.expressions._

@ExpressionDescription(
usage = "_FUNC_(tile, rhs) - In each cell of `tile`, return 1 if the value is in `rhs`, otherwise 0.",
arguments = """
Arguments:
* tile - tile column to test
* rhs - array to test against
""",
examples = """
Examples:
> SELECT _FUNC_(tile, array(33, 66, 99));
..."""
)
case class IsIn(left: Expression, right: Expression) extends BinaryExpression with CodegenFallback {
override val nodeName: String = "rf_local_is_in"

override def dataType: DataType = left.dataType

@transient private lazy val elementType: DataType = right.dataType.asInstanceOf[ArrayType].elementType

override def checkInputDataTypes(): TypeCheckResult =
if(!tileExtractor.isDefinedAt(left.dataType)) {
TypeCheckFailure(s"Input type '${left.dataType}' does not conform to a raster type.")
} else right.dataType match {
case _: ArrayType ⇒ TypeCheckSuccess
case _ ⇒ TypeCheckFailure(s"Input type '${right.dataType}' does not conform to ArrayType.")
}

override protected def nullSafeEval(input1: Any, input2: Any): Any = {
implicit val tileSer = TileUDT.tileSerializer
val (childTile, childCtx) = tileExtractor(left.dataType)(row(input1))

val arr = input2.asInstanceOf[ArrayData].toArray[AnyRef](elementType)

childCtx match {
case Some(ctx) => ctx.toProjectRasterTile(op(childTile, arr)).toInternalRow
case None => op(childTile, arr).toInternalRow
}

}

protected def op(left: Tile, right: IndexedSeq[AnyRef]): Tile = {
def fn(i: Int): Boolean = right.contains(i)
IfCell(left, fn(_), 1, 0)
}

}

object IsIn {
def apply(left: Column, right: Column): Column =
new Column(IsIn(left.expr, right.expr))
}
@@ -86,6 +86,7 @@ package object expressions {
registry.registerExpression[GreaterEqual]("rf_local_greater_equal")
registry.registerExpression[Equal]("rf_local_equal")
registry.registerExpression[Unequal]("rf_local_unequal")
registry.registerExpression[IsIn]("rf_local_is_in")
registry.registerExpression[Undefined]("rf_local_no_data")
registry.registerExpression[Defined]("rf_local_data")
registry.registerExpression[Sum]("rf_tile_sum")
@@ -972,4 +972,28 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
val dResult = df.select($"ld").as[Tile].first()
dResult should be (randNDPRT.localDefined())
}

it("should check values isin"){
checkDocs("rf_local_is_in")

// tile is 3 by 3 with values 1 to 9
val df = Seq(byteArrayTile).toDF("t")
.withColumn("one", lit(1))
.withColumn("five", lit(5))
.withColumn("ten", lit(10))
.withColumn("in_expect_2", rf_local_is_in($"t", array($"one", $"five")))
.withColumn("in_expect_1", rf_local_is_in($"t", array($"ten", $"five")))
.withColumn("in_expect_0", rf_local_is_in($"t", array($"ten")))

val e2Result = df.select(rf_tile_sum($"in_expect_2")).as[Double].first()
e2Result should be (2.0)

val e1Result = df.select(rf_tile_sum($"in_expect_1")).as[Double].first()
e1Result should be (1.0)

val e0Result = df.select($"in_expect_0").as[Tile].first()
e0Result.toArray() should contain only (0)

// lazy val invalid = df.select(rf_local_is_in($"t", lit("foobar"))).as[Tile].first()
}
}
@@ -192,7 +192,7 @@ Parameters `tile_columns` and `tile_rows` are literals, not column expressions.

Tile rf_array_to_tile(Array arrayCol, Int numCols, Int numRows)

Python only. Create a `tile` from a Spark SQL [Array](http://spark.apache.org/docs/2.3.2/api/python/pyspark.sql.html#pyspark.sql.types.ArrayType), filling values in row-major order.
Python only. Create a `tile` from a Spark SQL [Array][Array], filling values in row-major order.
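
For example, a minimal sketch, assuming a DataFrame `df`; the `arr` and `tile` column names here are purely illustrative:

```python
from pyspark.sql.functions import array, lit
from pyrasterframes.rasterfunctions import rf_array_to_tile

# Build a 2 x 2 tile from a four-element array column, filled in row-major order.
filled = df.withColumn('arr', array(*[lit(float(v)) for v in [1, 2, 3, 4]])) \
           .withColumn('tile', rf_array_to_tile('arr', 2, 2))
```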

### rf_assemble_tile

@@ -383,6 +383,13 @@ Returns a `tile` column containing the element-wise equality of `tile1` and `rhs`.

Returns a `tile` column containing the element-wise inequality of `tile1` and `rhs`.

### rf_local_is_in

Tile rf_local_is_in(Tile tile, Array array)
Tile rf_local_is_in(Tile tile, list l)

Returns a `tile` column with cell values of 1 where the `tile` cell value is in the provided array or list, and 0 otherwise. The `array` is a Spark SQL [Array][Array]. A Python `list` of numeric values can also be passed.
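
A minimal Python sketch, assuming a DataFrame `df` with a tile column named `tile` (both calling styles shown):

```python
from pyspark.sql.functions import array, lit
from pyrasterframes.rasterfunctions import rf_local_is_in

# Pass a Spark SQL array column...
df = df.withColumn('in_arr', rf_local_is_in('tile', array(lit(1), lit(5), lit(10))))
# ...or, equivalently, a plain Python list of numeric values.
df = df.withColumn('in_list', rf_local_is_in('tile', [1, 5, 10]))
```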

### rf_round

Tile rf_round(Tile tile)
@@ -630,13 +637,13 @@ Python only. As with @ref:[`rf_explode_tiles`](reference.md#rf-explode-tiles), b

Array rf_tile_to_array_int(Tile tile)

Convert Tile column to Spark SQL [Array](http://spark.apache.org/docs/2.3.2/api/python/pyspark.sql.html#pyspark.sql.types.ArrayType), in row-major order. Float cell types will be coerced to integral type by flooring.
Convert Tile column to Spark SQL [Array][Array], in row-major order. Float cell types will be coerced to integral type by flooring.

### rf_tile_to_array_double

Array rf_tile_to_array_double(Tile tile)

Convert tile column to Spark [Array](http://spark.apache.org/docs/2.3.2/api/python/pyspark.sql.html#pyspark.sql.types.ArrayType), in row-major order. Integral cell types will be coerced to floats.
Convert tile column to Spark [Array][Array], in row-major order. Integral cell types will be coerced to floats.

### rf_render_ascii

@@ -666,3 +673,4 @@ Runs [`rf_rgb_composite`](reference.md#rf-rgb-composite) on the given tile colum

[RasterFunctions]: org.locationtech.rasterframes.RasterFunctions
[scaladoc]: latest/api/index.html
[Array]: http://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.types.ArrayType
1 change: 1 addition & 0 deletions docs/src/main/paradox/release-notes.md
@@ -6,6 +6,7 @@

* _Breaking_ (potentially): removed `GeoTiffCollectionRelation` due to usage limitation and overlap with `RasterSourceDataSource` functionality.
* Upgraded to Spark 2.4.4
* Add `rf_local_is_in` raster function

### 0.8.3

24 changes: 7 additions & 17 deletions pyrasterframes/src/main/python/docs/nodata-handling.pymd
@@ -105,32 +105,23 @@ Drawing on @ref:[local map algebra](local-algebra.md) techniques, we will create
```python, def_mask
from pyspark.sql.functions import lit

mask_part = unmasked.withColumn('nodata', rf_local_equal('scl', lit(0))) \
.withColumn('defect', rf_local_equal('scl', lit(1))) \
.withColumn('cloud8', rf_local_equal('scl', lit(8))) \
.withColumn('cloud9', rf_local_equal('scl', lit(9))) \
.withColumn('cirrus', rf_local_equal('scl', lit(10)))

one_mask = mask_part.withColumn('mask', rf_local_add('nodata', 'defect')) \
.withColumn('mask', rf_local_add('mask', 'cloud8')) \
.withColumn('mask', rf_local_add('mask', 'cloud9')) \
.withColumn('mask', rf_local_add('mask', 'cirrus'))

cell_types = one_mask.select(rf_cell_type('mask')).distinct()
mask = unmasked.withColumn('mask', rf_local_is_in('scl', [0, 1, 8, 9, 10]))

cell_types = mask.select(rf_cell_type('mask')).distinct()
cell_types
```

Because a NoData value is not already defined, we will choose one. In this particular example, the minimum value is greater than zero, so we can use 0 as the NoData value.

```python, pick_nd
blue_min = one_mask.agg(rf_agg_stats('blue').min.alias('blue_min'))
blue_min = mask.agg(rf_agg_stats('blue').min.alias('blue_min'))
blue_min
```

We can now construct the cell type string for our blue band's cell type, designating 0 as NoData.

```python, get_ct_string
blue_ct = one_mask.select(rf_cell_type('blue')).distinct().first()[0][0]
blue_ct = mask.select(rf_cell_type('blue')).distinct().first()[0][0]
masked_blue_ct = CellType(blue_ct).with_no_data_value(0)
masked_blue_ct.cell_type_name
```
@@ -139,9 +130,8 @@ Now we will use the @ref:[`rf_mask_by_value`](reference.md#rf-mask-by-value) to

```python, mask_blu
with_nd = rf_convert_cell_type('blue', masked_blue_ct)
masked = one_mask.withColumn('blue_masked',
rf_mask_by_value(with_nd, 'mask', lit(1))) \
.drop('nodata', 'defect', 'cloud8', 'cloud9', 'cirrus', 'blue')
masked = mask.withColumn('blue_masked',
rf_mask_by_value(with_nd, 'mask', lit(1)))
```

We can verify that the number of NoData cells in the resulting `blue_masked` column matches the total of the boolean `mask` _tile_ to ensure our logic is correct.
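
A minimal sketch of that check, assuming the `masked` DataFrame above and that `rf_no_data_cells` is available to count NoData cells per _tile_:

```python, verify_mask
from pyspark.sql import functions as F
from pyrasterframes.rasterfunctions import rf_tile_sum, rf_no_data_cells

counts = masked.agg(
    F.sum(rf_tile_sum('mask')).alias('mask_total'),
    F.sum(rf_no_data_cells('blue_masked')).alias('blue_nodata_total')
)
counts  # the two totals should agree if the masking logic is correct
```
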
52 changes: 24 additions & 28 deletions pyrasterframes/src/main/python/docs/supervised-learning.pymd
@@ -32,7 +32,8 @@ catalog_df = pd.DataFrame([
{b: uri_base.format(b) for b in cols}
])

df = spark.read.raster(catalog_df, catalog_col_names=cols, tile_dimensions=(128, 128)) \
tile_size = 256
df = spark.read.raster(catalog_df, catalog_col_names=cols, tile_dimensions=(tile_size, tile_size)) \
.repartition(100)

df = df.select(
@@ -91,23 +92,12 @@ To filter only for good quality pixels, we follow roughly the same procedure as
```python, make_mask
from pyspark.sql.functions import lit

mask_part = df_labeled \
.withColumn('nodata', rf_local_equal('scl', lit(0))) \
.withColumn('defect', rf_local_equal('scl', lit(1))) \
.withColumn('cloud8', rf_local_equal('scl', lit(8))) \
.withColumn('cloud9', rf_local_equal('scl', lit(9))) \
.withColumn('cirrus', rf_local_equal('scl', lit(10)))

df_mask_inv = mask_part \
.withColumn('mask', rf_local_add('nodata', 'defect')) \
.withColumn('mask', rf_local_add('mask', 'cloud8')) \
.withColumn('mask', rf_local_add('mask', 'cloud9')) \
.withColumn('mask', rf_local_add('mask', 'cirrus')) \
.drop('nodata', 'defect', 'cloud8', 'cloud9', 'cirrus')

df_labeled = df_labeled \
.withColumn('mask', rf_local_is_in('scl', [0, 1, 8, 9, 10]))

# at this point the mask contains 0 for good cells and 1 for NoData, defect, cloud, or cirrus cells
# convert cell type and set value 1 to NoData
df_mask = df_mask_inv.withColumn('mask',
df_mask = df_labeled.withColumn('mask',
rf_with_no_data(rf_convert_cell_type('mask', 'uint8'), 1.0)
)

@@ -204,29 +194,35 @@ scored = model.transform(df_mask.drop('label'))
retiled = scored \
.groupBy('extent', 'crs') \
.agg(
rf_assemble_tile('column_index', 'row_index', 'prediction', 128, 128).alias('prediction'),
rf_assemble_tile('column_index', 'row_index', 'B04', 128, 128).alias('red'),
rf_assemble_tile('column_index', 'row_index', 'B03', 128, 128).alias('grn'),
rf_assemble_tile('column_index', 'row_index', 'B02', 128, 128).alias('blu')
rf_assemble_tile('column_index', 'row_index', 'prediction', tile_size, tile_size).alias('prediction'),
rf_assemble_tile('column_index', 'row_index', 'B04', tile_size, tile_size).alias('red'),
rf_assemble_tile('column_index', 'row_index', 'B03', tile_size, tile_size).alias('grn'),
rf_assemble_tile('column_index', 'row_index', 'B02', tile_size, tile_size).alias('blu')
)
retiled.printSchema()
```

Take a look at a sample of the resulting output and the corresponding area's red-green-blue composite image.
Recall the label coding: 1 is forest (purple), 2 is cropland (green), and 3 is developed areas (yellow).

```python, display_rgb
sample = retiled \
.select('prediction', rf_rgb_composite('red', 'grn', 'blu').alias('rgb')) \
.select('prediction', 'red', 'grn', 'blu') \
.sort(-rf_tile_sum(rf_local_equal('prediction', lit(3.0)))) \
.first()

sample_rgb = sample['rgb']
mins = np.nanmin(sample_rgb.cells, axis=(0,1))
plt.imshow((sample_rgb.cells - mins) / (np.nanmax(sample_rgb.cells, axis=(0,1)) - mins))
```
sample_rgb = np.concatenate([sample['red'].cells[:, :, None],
sample['grn'].cells[:, :, None],
sample['blu'].cells[:, :, None]], axis=2)
# plot scaled RGB
scaling_quantiles = np.nanpercentile(sample_rgb, [3.00, 97.00], axis=(0,1))
scaled = np.clip(sample_rgb, scaling_quantiles[0, :], scaling_quantiles[1, :])
scaled -= scaling_quantiles[0, :]
scaled /= (scaling_quantiles[1, : ] - scaling_quantiles[0, :])

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.imshow(scaled)

```python, display_prediction
display(sample['prediction'])
# display prediction
ax2.imshow(sample['prediction'].cells)
```
10 changes: 10 additions & 0 deletions pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py
@@ -260,14 +260,24 @@ def rf_local_unequal_int(tile_col, scalar):
"""Return a Tile with values equal 1 if the cell is not equal to a scalar, otherwise 0"""
return _apply_scalar_to_tile('rf_local_unequal_int', tile_col, scalar)


def rf_local_no_data(tile_col):
"""Return a tile with ones where the input is NoData, otherwise zero."""
return _apply_column_function('rf_local_no_data', tile_col)


def rf_local_data(tile_col):
"""Return a tile with zeros where the input is NoData, otherwise one."""
return _apply_column_function('rf_local_data', tile_col)

def rf_local_is_in(tile_col, array):
"""Return a tile with cell values of 1 where the `tile_col` cell is in the provided array."""
from pyspark.sql.functions import array as sql_array, lit
if isinstance(array, list):
array = sql_array([lit(v) for v in array])

return _apply_column_function('rf_local_is_in', tile_col, array)

def _apply_column_function(name, *args):
jfcn = RFContext.active().lookup(name)
jcols = [_to_java_column(arg) for arg in args]
@@ -131,7 +131,7 @@ def test_tile_udt_serialization(self):
cells[1][1] = nd
a_tile = Tile(cells, ct.with_no_data_value(nd))
round_trip = udt.fromInternal(udt.toInternal(a_tile))
self.assertEquals(a_tile, round_trip, "round-trip serialization for " + str(ct))
self.assertEqual(a_tile, round_trip, "round-trip serialization for " + str(ct))

schema = StructType([StructField("tile", TileUDT(), False)])
df = self.spark.createDataFrame([{"tile": a_tile}], schema)