diff --git a/datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSource.scala b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSource.scala
index 6cea717ec..5aec1a065 100644
--- a/datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSource.scala
+++ b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSource.scala
@@ -44,12 +44,12 @@ object RasterSourceDataSource {
   final val SHORT_NAME = "raster"
   final val PATH_PARAM = "path"
   final val PATHS_PARAM = "paths"
-  final val BAND_INDEXES_PARAM = "bandIndexes"
-  final val TILE_DIMS_PARAM = "tileDimensions"
-  final val CATALOG_TABLE_PARAM = "catalogTable"
-  final val CATALOG_TABLE_COLS_PARAM = "catalogColumns"
-  final val CATALOG_CSV_PARAM = "catalogCSV"
-  final val LAZY_TILES_PARAM = "lazyTiles"
+  final val BAND_INDEXES_PARAM = "band_indexes"
+  final val TILE_DIMS_PARAM = "tile_dimensions"
+  final val CATALOG_TABLE_PARAM = "catalog_table"
+  final val CATALOG_TABLE_COLS_PARAM = "catalog_col_names"
+  final val CATALOG_CSV_PARAM = "catalog_csv"
+  final val LAZY_TILES_PARAM = "lazy_tiles"
   final val DEFAULT_COLUMN_NAME = PROJECTED_RASTER_COLUMN.columnName
diff --git a/pyrasterframes/src/main/python/docs/languages.pymd b/pyrasterframes/src/main/python/docs/languages.pymd
index 1d4895e2e..ead79337f 100644
--- a/pyrasterframes/src/main/python/docs/languages.pymd
+++ b/pyrasterframes/src/main/python/docs/languages.pymd
@@ -42,7 +42,7 @@ red_nir_monthly_2017.printSchema()

 ```python, step_3_python
 red_nir_tiles_monthly_2017 = spark.read.raster(
-    catalog=red_nir_monthly_2017,
+    red_nir_monthly_2017,
     catalog_col_names=['red', 'nir'],
     tile_dimensions=(256, 256)
 )
@@ -97,9 +97,9 @@ sql("""
 CREATE OR REPLACE TEMPORARY VIEW red_nir_tiles_monthly_2017
 USING raster
 OPTIONS (
-    catalogTable='red_nir_monthly_2017',
-    catalogColumns='red,nir',
-    tileDimensions='256,256'
+    catalog_table='red_nir_monthly_2017',
+    catalog_col_names='red,nir',
+    tile_dimensions='256,256'
 )
 """)
 ```
diff --git a/pyrasterframes/src/main/python/docs/local-algebra.pymd b/pyrasterframes/src/main/python/docs/local-algebra.pymd
index 696186313..fc83ae2d2 100644
--- a/pyrasterframes/src/main/python/docs/local-algebra.pymd
+++ b/pyrasterframes/src/main/python/docs/local-algebra.pymd
@@ -40,7 +40,7 @@ catalog_df = spark.createDataFrame([
     Row(red=uri_pattern.format(4), nir=uri_pattern.format(8))
 ])
 df = spark.read.raster(
-    catalog=catalog_df,
+    catalog_df,
     catalog_col_names=['red', 'nir']
 )
 df.printSchema()
diff --git a/pyrasterframes/src/main/python/docs/nodata-handling.pymd b/pyrasterframes/src/main/python/docs/nodata-handling.pymd
index a4534c8b1..90eeacbb5 100644
--- a/pyrasterframes/src/main/python/docs/nodata-handling.pymd
+++ b/pyrasterframes/src/main/python/docs/nodata-handling.pymd
@@ -90,7 +90,7 @@ from pyspark.sql import Row
 blue_uri = 'https://s22s-test-geotiffs.s3.amazonaws.com/luray_snp/B02.tif'
 scl_uri = 'https://s22s-test-geotiffs.s3.amazonaws.com/luray_snp/SCL.tif'
 cat = spark.createDataFrame([Row(blue=blue_uri, scl=scl_uri),])
-unmasked = spark.read.raster(catalog=cat, catalog_col_names=['blue', 'scl'])
+unmasked = spark.read.raster(cat, catalog_col_names=['blue', 'scl'])
 unmasked.printSchema()
 ```
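The renamed keys also apply when the `raster` DataSource is used directly, rather than through the `spark.read.raster` convenience method. A minimal sketch, assuming a hypothetical single-band GeoTIFF URI:

```python
# Hypothetical URI; the option keys match the renamed constants above.
df = spark.read.format('raster') \
    .option('path', 'https://example.com/scene_B01.TIF') \
    .option('tile_dimensions', '128,128') \
    .option('band_indexes', '0') \
    .option('lazy_tiles', 'true') \
    .load()
```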
diff --git a/pyrasterframes/src/main/python/docs/numpy-pandas.pymd b/pyrasterframes/src/main/python/docs/numpy-pandas.pymd
index 86f5bad3f..5622af7b3 100644
--- a/pyrasterframes/src/main/python/docs/numpy-pandas.pymd
+++ b/pyrasterframes/src/main/python/docs/numpy-pandas.pymd
@@ -51,7 +51,7 @@ cat = spark.read.format('aws-pds-modis-catalog').load() \
         (col('acquisition_date') < lit('2018-02-22'))
     )

-spark_df = spark.read.raster(catalog=cat, catalog_col_names=['B01']) \
+spark_df = spark.read.raster(cat, catalog_col_names=['B01']) \
     .select(
         'acquisition_date',
         'granule_id',
diff --git a/pyrasterframes/src/main/python/docs/raster-read.pymd b/pyrasterframes/src/main/python/docs/raster-read.pymd
index f9a1170b2..53f3a96e6 100644
--- a/pyrasterframes/src/main/python/docs/raster-read.pymd
+++ b/pyrasterframes/src/main/python/docs/raster-read.pymd
@@ -101,8 +101,6 @@ modis_catalog = spark.read \
     .withColumn('red' , F.concat('base_url', F.lit("_B01.TIF"))) \
     .withColumn('nir' , F.concat('base_url', F.lit("_B02.TIF")))

-modis_catalog.printSchema()
-
 print("Available scenes: ", modis_catalog.count())
 ```
@@ -124,10 +122,7 @@ equator.select('date', 'gid')
 Now that we have prepared our catalog, we simply pass the DataFrame or CSV string to the `raster` DataSource to load the imagery. The `catalog_col_names` parameter gives the columns that contain the URI's to be read.

 ```python, read_catalog
-rf = spark.read.raster(
-    catalog=equator,
-    catalog_col_names=['red', 'nir']
-)
+rf = spark.read.raster(equator, catalog_col_names=['red', 'nir'])
 rf.printSchema()
 ```
@@ -179,7 +174,7 @@ mb.printSchema()

 If a band is passed into `band_indexes` that exceeds the number of bands in the raster, a projected raster column will still be generated in the schema but the column will be full of `null` values.

-You can also pass a `catalog` and `band_indexes` together into the `raster` reader. This will create a projected raster column for the combination of all items passed into `catalog_col_names` and `band_indexes`. Again if a band in `band_indexes` exceeds the number of bands in a raster, it will have a `null` value for the corresponding column.
+You can also pass a _catalog_ and `band_indexes` together into the `raster` reader. This will create a projected raster column for the combination of all items in `catalog_col_names` and `band_indexes`. Again, if a band in `band_indexes` exceeds the number of bands in a raster, it will have a `null` value for the corresponding column.

 Here is a trivial example with a _catalog_ over multiband rasters. We specify two columns containing URIs and two bands, resulting in four projected raster columns.
@@ -191,7 +186,7 @@ mb_cat = pd.DataFrame([
     },
 ])
 mb2 = spark.read.raster(
-    catalog=spark.createDataFrame(mb_cat),
+    spark.createDataFrame(mb_cat),
     catalog_col_names=['foo', 'bar'],
     band_indexes=[0, 1],
     tile_dimensions=(64,64)
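The `mb` example referenced above drives the same behavior from a single multiband file. A minimal standalone sketch (the URI is hypothetical; inspect `printSchema()` for the exact band-column naming the reader generates):

```python
# Read the first three bands of one multiband raster; each requested band
# becomes its own projected raster column.
mb = spark.read.raster(
    'https://example.com/multiband.tif',
    band_indexes=[0, 1, 2]
)
mb.printSchema()
```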
diff --git a/pyrasterframes/src/main/python/docs/supervised-learning.pymd b/pyrasterframes/src/main/python/docs/supervised-learning.pymd
index 0a3f8c0ef..9f2cd968f 100644
--- a/pyrasterframes/src/main/python/docs/supervised-learning.pymd
+++ b/pyrasterframes/src/main/python/docs/supervised-learning.pymd
@@ -33,10 +33,8 @@ catalog_df = pd.DataFrame([
     {b: uri_base.format(b) for b in cols}
 ])

-df = spark.read.raster(catalog=catalog_df,
-                       catalog_col_names=cols,
-                       tile_dimensions=(128, 128)
-                       ).repartition(100)
+df = spark.read.raster(catalog_df, catalog_col_names=cols, tile_dimensions=(128, 128)) \
+    .repartition(100)

 df = df.select(
     rf_crs(df.B01).alias('crs'),
diff --git a/pyrasterframes/src/main/python/docs/time-series.pymd b/pyrasterframes/src/main/python/docs/time-series.pymd
index eadcb3ffd..0e0cbed00 100644
--- a/pyrasterframes/src/main/python/docs/time-series.pymd
+++ b/pyrasterframes/src/main/python/docs/time-series.pymd
@@ -97,7 +97,7 @@ We then [reproject](https://gis.stackexchange.com/questions/247770/understanding
 ```python read_catalog
 raster_cols = ['B01', 'B02',]  # red and near-infrared respectively
 park_rf = spark.read.raster(
-    catalog=park_cat.select(['acquisition_date', 'granule_id', 'geo_simp'] + raster_cols),
+    park_cat.select(['acquisition_date', 'granule_id', 'geo_simp'] + raster_cols),
     catalog_col_names=raster_cols) \
     .withColumn('park_native', st_reproject('geo_simp', lit('EPSG:4326'), rf_crs('B01'))) \
     .filter(st_intersects('park_native', rf_geometry('B01')))
diff --git a/pyrasterframes/src/main/python/docs/unsupervised-learning.pymd b/pyrasterframes/src/main/python/docs/unsupervised-learning.pymd
index 800f7e749..f2158d807 100644
--- a/pyrasterframes/src/main/python/docs/unsupervised-learning.pymd
+++ b/pyrasterframes/src/main/python/docs/unsupervised-learning.pymd
@@ -37,7 +37,7 @@ filenamePattern = "L8-B{}-Elkton-VA.tiff"
 catalog_df = pd.DataFrame([
     {'b' + str(b): os.path.join(resource_dir_uri(), filenamePattern.format(b)) for b in range(1, 8)}
 ])
-df = spark.read.raster(catalog=catalog_df, catalog_col_names=catalog_df.columns)
+df = spark.read.raster(catalog_df, catalog_col_names=catalog_df.columns)
 df = df.select(
     rf_crs(df.b1).alias('crs'),
     rf_extent(df.b1).alias('extent'),
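The three docs above all build their catalog the same way: a dict comprehension over a band filename pattern, one row per scene, then the catalog passed as the positional `source` argument. A condensed sketch of that idiom, assuming hypothetical local file paths:

```python
import pandas as pd

# Hypothetical location of the sample bands; one row covering bands 1-7.
pattern = '/data/L8-B{}-Elkton-VA.tiff'
cat = pd.DataFrame([{'b{}'.format(b): pattern.format(b) for b in range(1, 8)}])
df = spark.read.raster(cat, catalog_col_names=cat.columns.tolist())
```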
diff --git a/pyrasterframes/src/main/python/pyrasterframes/__init__.py b/pyrasterframes/src/main/python/pyrasterframes/__init__.py
index 1fa5e91cf..905323d60 100644
--- a/pyrasterframes/src/main/python/pyrasterframes/__init__.py
+++ b/pyrasterframes/src/main/python/pyrasterframes/__init__.py
@@ -110,16 +110,34 @@ def _aliased_writer(df_writer, format_key, path, **options):
 def _raster_reader(
         df_reader,
-        path=None,
-        catalog=None,
+        source=None,
         catalog_col_names=None,
         band_indexes=None,
         tile_dimensions=(256, 256),
         lazy_tiles=True,
         **options):
+    """
+    Returns a Spark DataFrame from raster data files specified by URIs.
+    Each row in the returned DataFrame will contain a column with a struct of (CRS, Extent, Tile) for each item in
+    `catalog_col_names`.
+    Multiple bands from the same raster file are spread across columns of the DataFrame; see the `band_indexes` param.
+    If bands from a scene are stored in separate files, provide a DataFrame to the `source` parameter.
+
+    For more details and example usage, consult https://rasterframes.io/raster-read.html
+
+    :param source: a string, list of strings, list of lists of strings, a Pandas DataFrame, or a Spark DataFrame giving URIs to the raster data to read.
+    :param catalog_col_names: required if `source` is a DataFrame or CSV string; a list of strings giving the names of columns containing URIs to read.
+    :param band_indexes: list of integers indicating which bands, zero-based, to read from the raster files specified; default is to read only the first band.
+    :param tile_dimensions: tuple or list of two integers indicating the default tile dimensions as (columns, rows).
+    :param lazy_tiles: if True (default), only generate minimal references to tile contents; if False, fetch tile cell values eagerly.
+    :param options: additional keyword arguments to pass to the Spark DataSource.
+    """
     from pandas import DataFrame as PdDataFrame

+    if 'catalog' in options:
+        # maintain backwards compatibility with 0.8.0; pop so the stray key is
+        # not forwarded to the DataSource as an option
+        source = options.pop('catalog')
+
     def to_csv(comp):
         if isinstance(comp, str):
             return comp
@@ -135,28 +153,66 @@ def temp_name():
         band_indexes = [0]

     options.update({
-        "bandIndexes": to_csv(band_indexes),
-        "tileDimensions": to_csv(tile_dimensions),
-        "lazyTiles": lazy_tiles
+        "band_indexes": to_csv(band_indexes),
+        "tile_dimensions": to_csv(tile_dimensions),
+        "lazy_tiles": lazy_tiles
     })

+    # Parse the `source` argument
+    path = None  # to pass into the `path` param
+    catalog = None  # ensure `catalog` is bound even if no branch below matches
+    if isinstance(source, list):
+        if all([isinstance(i, str) for i in source]):
+            path = None
+            catalog = None
+            options.update(dict(paths='\n'.join([str(i) for i in source])))  # pass in "uri1\nuri2\nuri3\n..."
+        elif all([isinstance(i, list) for i in source]):
+            # list of lists; we will rely on pandas to:
+            # - coerce all data to str (possibly using objects' __str__ or __repr__)
+            # - ensure data is not "ragged": all sublists are the same length
+            path = None
+            catalog_col_names = ['proj_raster_{}'.format(i) for i in range(len(source[0]))]  # assign these names
+            catalog = PdDataFrame(source,
+                                  columns=catalog_col_names,
+                                  dtype=str,
+                                  )
+    elif isinstance(source, str):
+        if '\n' in source or '\r' in source:
+            # then the `source` string is a catalog as a CSV (header is required)
+            path = None
+            catalog = source
+        else:
+            # interpret source as a single URI string
+            path = source
+            catalog = None
+    else:
+        # user has passed in some other type; we will try to interpret it as a catalog
+        catalog = source
+
     if catalog is not None:
         if catalog_col_names is None:
             raise Exception("'catalog_col_names' is required when a catalog DataFrame or CSV string is specified")
+
         if isinstance(catalog, str):
             options.update({
-                "catalogCSV": catalog,
-                "catalogColumns": to_csv(catalog_col_names)
+                "catalog_csv": catalog,
+                "catalog_col_names": to_csv(catalog_col_names)
             })
         elif isinstance(catalog, DataFrame):
+            # check catalog_col_names
+            assert all([c in catalog.columns for c in catalog_col_names]), \
+                "All items in catalog_col_names must be the name of a column in the catalog DataFrame."
             # Create a random view name
             tmp_name = temp_name()
             catalog.createOrReplaceTempView(tmp_name)
             options.update({
-                "catalogTable": tmp_name,
-                "catalogColumns": to_csv(catalog_col_names)
+                "catalog_table": tmp_name,
+                "catalog_col_names": to_csv(catalog_col_names)
             })
         elif isinstance(catalog, PdDataFrame):
+            # check catalog_col_names
+            assert all([c in catalog.columns for c in catalog_col_names]), \
+                "All items in catalog_col_names must be the name of a column in the catalog DataFrame."
+
+            # Handle to active spark session
             session = SparkContext._active_spark_context._rf_context._spark_session
             # Create a random view name
@@ -164,8 +220,8 @@ def temp_name():
             spark_catalog = session.createDataFrame(catalog)
             spark_catalog.createOrReplaceTempView(tmp_name)
             options.update({
-                "catalogTable": tmp_name,
-                "catalogColumns": to_csv(catalog_col_names)
+                "catalog_table": tmp_name,
+                "catalog_col_names": to_csv(catalog_col_names)
             })

     return df_reader \
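Putting the branches above together, `source` accepts five shapes. A sketch of each, where all URIs, filenames, and `catalog_df` are hypothetical placeholders:

```python
# 1. Single URI string -> becomes the `path` option.
df1 = spark.read.raster('https://example.com/scene.tif')

# 2. List of URI strings -> newline-joined into the `paths` option.
df2 = spark.read.raster(['https://example.com/a.tif', 'https://example.com/b.tif'])

# 3. List of lists -> coerced to a pandas catalog with generated column
#    names 'proj_raster_0', 'proj_raster_1', ...
df3 = spark.read.raster([['a1.tif', 'a2.tif'], ['b1.tif', 'b2.tif']])

# 4. Multi-line CSV string (header required) -> the `catalog_csv` option.
csv = 'b1,b2\nx1.tif,x2.tif\ny1.tif,y2.tif'
df4 = spark.read.raster(csv, catalog_col_names=['b1', 'b2'])

# 5. Pandas or Spark DataFrame -> registered as a temp view and passed via
#    the `catalog_table` option.
df5 = spark.read.raster(catalog_df, catalog_col_names=['b1', 'b2'])
```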
diff --git a/pyrasterframes/src/main/python/tests/PyRasterFramesTests.py b/pyrasterframes/src/main/python/tests/PyRasterFramesTests.py
index feee746eb..6092410bb 100644
--- a/pyrasterframes/src/main/python/tests/PyRasterFramesTests.py
+++ b/pyrasterframes/src/main/python/tests/PyRasterFramesTests.py
@@ -410,8 +410,6 @@ def test_raster_join(self):
         self.rf.raster_join(rf_prime, join_exprs=self.rf.extent)

-
-
 def suite():
     function_tests = unittest.TestSuite()
     return function_tests
diff --git a/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py b/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py
index 2a57cf356..ac89c2448 100644
--- a/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py
+++ b/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py
@@ -286,7 +286,7 @@ def test_render_composite(self):
         cat = self.spark.createDataFrame([
             Row(red=self.l8band_uri(4), green=self.l8band_uri(3), blue=self.l8band_uri(2))
         ])
-        rf = self.spark.read.raster(catalog = cat, catalog_col_names=['red', 'green', 'blue'])
+        rf = self.spark.read.raster(cat, catalog_col_names=cat.columns)

         # Test composite construction
         rgb = rf.select(rf_tile(rf_rgb_composite('red', 'green', 'blue')).alias('rgb')).first()['rgb']
diff --git a/pyrasterframes/src/main/python/tests/RasterSourceTest.py b/pyrasterframes/src/main/python/tests/RasterSourceTest.py
new file mode 100644
index 000000000..5f7967a49
--- /dev/null
+++ b/pyrasterframes/src/main/python/tests/RasterSourceTest.py
@@ -0,0 +1,198 @@
+#
+# This software is licensed under the Apache 2 license, quoted below.
+#
+# Copyright 2019 Astraea, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# [http://www.apache.org/licenses/LICENSE-2.0]
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+
+from pyrasterframes.rasterfunctions import *
+from pyrasterframes.rf_types import *
+from pyspark.sql.functions import *
+import pandas as pd
+from shapely.geometry import Point
+import os.path
+from unittest import skip
+from . import TestEnvironment
+
+
+class RasterSourceTest(TestEnvironment):
+
+    @staticmethod
+    def path(scene, band):
+        scene_dict = {
+            1: 'https://landsat-pds.s3.amazonaws.com/c1/L8/015/041/LC08_L1TP_015041_20190305_20190309_01_T1/LC08_L1TP_015041_20190305_20190309_01_T1_B{}.TIF',
+            2: 'https://landsat-pds.s3.amazonaws.com/c1/L8/015/042/LC08_L1TP_015042_20190305_20190309_01_T1/LC08_L1TP_015042_20190305_20190309_01_T1_B{}.TIF',
+            3: 'https://landsat-pds.s3.amazonaws.com/c1/L8/016/041/LC08_L1TP_016041_20190224_20190309_01_T1/LC08_L1TP_016041_20190224_20190309_01_T1_B{}.TIF',
+        }
+
+        assert band in range(1, 12)
+        assert scene in scene_dict.keys()
+        p = scene_dict[scene]
+        return p.format(band)
+
+    def path_pandas_df(self):
+        return pd.DataFrame([
+            {'b1': self.path(1, 1), 'b2': self.path(1, 2), 'b3': self.path(1, 3), 'geo': Point(1, 1)},
+            {'b1': self.path(2, 1), 'b2': self.path(2, 2), 'b3': self.path(2, 3), 'geo': Point(2, 2)},
+            {'b1': self.path(3, 1), 'b2': self.path(3, 2), 'b3': self.path(3, 3), 'geo': Point(3, 3)},
+        ])
+
+    def test_handle_lazy_eval(self):
+        df = self.spark.read.raster(self.path(1, 1))
+        ltdf = df.select('proj_raster')
+        self.assertGreater(ltdf.count(), 0)
+        self.assertIsNotNone(ltdf.first().proj_raster)
+
+        tdf = df.select(rf_tile('proj_raster').alias('pr'))
+        self.assertGreater(tdf.count(), 0)
+        self.assertIsNotNone(tdf.first().pr)
+
+    def test_strict_eval(self):
+        df_lazy = self.spark.read.raster(self.img_uri, lazy_tiles=True)
+        # when doing Show on a lazy tile we will see something like RasterRefTile(RasterRef(JVMGeoTiffRasterSource(...
+        # use this trick to get the `show` string
+        show_str_lazy = df_lazy.select('proj_raster')._jdf.showString(1, -1, False)
+        self.assertTrue('RasterRef' in show_str_lazy)
+
+        # again for strict
+        df_strict = self.spark.read.raster(self.img_uri, lazy_tiles=False)
+        show_str_strict = df_strict.select('proj_raster')._jdf.showString(1, -1, False)
+        self.assertTrue('RasterRef' not in show_str_strict)
+
+    def test_prt_functions(self):
+        df = self.spark.read.raster(self.img_uri) \
+            .withColumn('crs', rf_crs('proj_raster')) \
+            .withColumn('ext', rf_extent('proj_raster')) \
+            .withColumn('geom', rf_geometry('proj_raster'))
+        df.select('crs', 'ext', 'geom').first()
+    def test_list_of_str(self):
+        # much the same as RasterSourceDataSourceSpec here; but using https PDS. Takes about 30s to run
+
+        def l8path(b):
+            assert b in range(1, 12)
+            base = "https://s3-us-west-2.amazonaws.com/landsat-pds/c1/L8/199/026/LC08_L1TP_199026_20180919_20180928_01_T1/LC08_L1TP_199026_20180919_20180928_01_T1_B{}.TIF"
+            return base.format(b)
+
+        path_param = [l8path(b) for b in [1, 2, 3]]
+        tile_size = 512
+
+        df = self.spark.read.raster(
+            path_param,
+            tile_dimensions=(tile_size, tile_size),
+            lazy_tiles=True,
+        ).cache()
+
+        print(df.take(3))
+
+        # schema is tile_path and tile
+        # df.printSchema()
+        self.assertTrue(len(df.columns) == 2 and 'proj_raster_path' in df.columns and 'proj_raster' in df.columns)
+
+        # the most common tile dimensions should be as passed to `options`, showing that options are correctly applied
+        tile_size_df = df.select(rf_dimensions(df.proj_raster).rows.alias('r'), rf_dimensions(df.proj_raster).cols.alias('c')) \
+            .groupby(['r', 'c']).count().toPandas()
+        most_common_size = tile_size_df.loc[tile_size_df['count'].idxmax()]
+        self.assertTrue(most_common_size.r == tile_size and most_common_size.c == tile_size)
+
+        # all rows are from a single source URI
+        path_count = df.groupby(df.proj_raster_path).count()
+        print(path_count.collect())
+        self.assertTrue(path_count.count() == 3)
+
+    def test_list_of_list_of_str(self):
+        lol = [
+            [self.path(1, 1), self.path(1, 2)],
+            [self.path(2, 1), self.path(2, 2)],
+            [self.path(3, 1), self.path(3, 2)]
+        ]
+        df = self.spark.read.raster(lol)
+        self.assertTrue(len(df.columns) == 4)  # 2 cols of uris plus 2 cols of proj_rasters
+        self.assertEqual(sorted(df.columns), sorted(['proj_raster_0_path', 'proj_raster_1_path',
+                                                     'proj_raster_0', 'proj_raster_1']))
+        uri_df = df.select('proj_raster_0_path', 'proj_raster_1_path').distinct().collect()
+        uri_list = [list(r.asDict().values()) for r in uri_df]
+        self.assertTrue(lol[0] in uri_list)
+        self.assertTrue(lol[1] in uri_list)
+        self.assertTrue(lol[2] in uri_list)
+
+    def test_schemeless_string(self):
+        import os.path
+        path = os.path.join(self.resource_dir, "L8-B8-Robinson-IL.tiff")
+        self.assertTrue(not path.startswith('file://'))
+        self.assertTrue(os.path.exists(path))
+        df = self.spark.read.raster(path)
+        self.assertTrue(df.count() > 0)
+
+    def test_spark_df_source(self):
+        catalog_columns = ['b1', 'b2', 'b3']
+        catalog = self.spark.createDataFrame(self.path_pandas_df())
+
+        df = self.spark.read.raster(
+            catalog,
+            tile_dimensions=(512, 512),
+            catalog_col_names=catalog_columns,
+            lazy_tiles=True  # We'll get an OOM error if we try to read 9 scenes all at once!
+        )
+
+        self.assertTrue(len(df.columns) == 7)  # three bands times {path, tile} plus geo
+        self.assertTrue(df.select('b1_path').distinct().count() == 3)  # as per scene_dict
+        b1_paths_maybe = df.select('b1_path').distinct().collect()
+        b1_paths = [self.path(s, 1) for s in [1, 2, 3]]
+        self.assertTrue(all([row.b1_path in b1_paths for row in b1_paths_maybe]))
+    def test_pandas_source(self):
+        df = self.spark.read.raster(
+            self.path_pandas_df(),
+            catalog_col_names=['b1', 'b2', 'b3']
+        )
+        self.assertEqual(len(df.columns), 7)  # three path cols, three tile cols, and geo
+        self.assertTrue('geo' in df.columns)
+        self.assertTrue(df.select('b1_path').distinct().count() == 3)
+
+    def test_geopandas_source(self):
+        from geopandas import GeoDataFrame
+        # Same test as test_pandas_source with geopandas
+        geo_df = GeoDataFrame(self.path_pandas_df(), crs={'init': 'EPSG:4326'}, geometry='geo')
+        df = self.spark.read.raster(geo_df, ['b1', 'b2', 'b3'])
+
+        self.assertEqual(len(df.columns), 7)  # three path cols, three tile cols, and geo
+        self.assertTrue('geo' in df.columns)
+        self.assertTrue(df.select('b1_path').distinct().count() == 3)
+
+    def test_csv_string(self):
+        s = """metadata,b1,b2
+        a,{},{}
+        b,{},{}
+        c,{},{}
+        """.format(
+            self.path(1, 1), self.path(1, 2),
+            self.path(2, 1), self.path(2, 2),
+            self.path(3, 1), self.path(3, 2),
+        )
+
+        df = self.spark.read.raster(s, ['b1', 'b2'])
+        self.assertEqual(len(df.columns), 3 + 2)  # number of columns in original DF plus cardinality of catalog_col_names
+        self.assertTrue(len(df.take(1)))  # non-empty check
+
+    def test_catalog_named_arg(self):
+        # through version 0.8.1, reading a catalog was via named argument only.
+        df = self.spark.read.raster(catalog=self.path_pandas_df(), catalog_col_names=['b1', 'b2', 'b3'])
+        self.assertEqual(len(df.columns), 7)  # three path cols, three tile cols, and geo
+        self.assertTrue(df.select('b1_path').distinct().count() == 3)
diff --git a/pyrasterframes/src/main/python/tests/RasterSourceTests.py b/pyrasterframes/src/main/python/tests/RasterSourceTests.py
deleted file mode 100644
index 08ebe078c..000000000
--- a/pyrasterframes/src/main/python/tests/RasterSourceTests.py
+++ /dev/null
@@ -1,174 +0,0 @@
-#
-# This software is licensed under the Apache 2 license, quoted below.
-#
-# Copyright 2019 Astraea, Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may not
-# use this file except in compliance with the License. You may obtain a copy of
-# the License at
-#
-# [http://www.apache.org/licenses/LICENSE-2.0]
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations under
-# the License.
-#
-# SPDX-License-Identifier: Apache-2.0
-#
-
-
-from pyrasterframes.rasterfunctions import *
-from . import TestEnvironment
-
-class RasterSource(TestEnvironment):
-
-    def test_handle_lazy_eval(self):
-        df = self.spark.read.raster(self.img_uri)
-        ltdf = df.select('proj_raster')
-        self.assertGreater(ltdf.count(), 0)
-        self.assertIsNotNone(ltdf.first())
-
-        tdf = df.select(rf_tile('proj_raster'))
-        self.assertGreater(tdf.count(), 0)
-        self.assertIsNotNone(tdf.first())
-
-    def test_strict_eval(self):
-        df_lazy = self.spark.read.raster(self.img_uri, lazy_tiles=True)
-        # when doing Show on a lazy tile we will see something like RasterRefTile(RasterRef(JVMGeoTiffRasterSource(...
-        # use this trick to get the `show` string
-        show_str_lazy = df_lazy.select('proj_raster')._jdf.showString(1, -1, False)
-        self.assertTrue('RasterRef' in show_str_lazy)
-
-        # again for strict
-        df_strict = self.spark.read.raster(self.img_uri, lazy_tiles=False)
-        show_str_strict = df_strict.select('proj_raster')._jdf.showString(1, -1, False)
-        self.assertTrue('RasterRef' not in show_str_strict)
-
-
-    def test_prt_functions(self):
-        df = self.spark.read.raster(self.img_uri) \
-            .withColumn('crs', rf_crs('proj_raster')) \
-            .withColumn('ext', rf_extent('proj_raster')) \
-            .withColumn('geom', rf_geometry('proj_raster'))
-        df.select('crs', 'ext', 'geom').first()
-
-    def test_raster_source_reader(self):
-        # much the same as RasterSourceDataSourceSpec here; but using https PDS. Takes about 30s to run
-
-        def l8path(b):
-            assert b in range(1, 12)
-            base = "https://s3-us-west-2.amazonaws.com/landsat-pds/c1/L8/199/026/LC08_L1TP_199026_20180919_20180928_01_T1/LC08_L1TP_199026_20180919_20180928_01_T1_B{}.TIF"
-            return base.format(b)
-
-        path_param = '\n'.join([l8path(b) for b in [1, 2, 3]])  # "http://foo.com/file1.tif,http://foo.com/file2.tif"
-        tile_size = 512
-
-        df = self.spark.read.raster(
-            tile_dimensions=(tile_size, tile_size),
-            paths=path_param,
-            lazy_tiles=True,
-        ).cache()
-
-        # schema is tile_path and tile
-        # df.printSchema()
-        self.assertTrue(len(df.columns) == 2 and 'proj_raster_path' in df.columns and 'proj_raster' in df.columns)
-
-        # the most common tile dimensions should be as passed to `options`, showing that options are correctly applied
-        tile_size_df = df.select(rf_dimensions(df.proj_raster).rows.alias('r'), rf_dimensions(df.proj_raster).cols.alias('c')) \
-            .groupby(['r', 'c']).count().toPandas()
-        most_common_size = tile_size_df.loc[tile_size_df['count'].idxmax()]
-        self.assertTrue(most_common_size.r == tile_size and most_common_size.c == tile_size)
-
-        # all rows are from a single source URI
-        path_count = df.groupby(df.proj_raster_path).count()
-        print(path_count.toPandas())
-        self.assertTrue(path_count.count() == 3)
-
-    def test_raster_source_reader_schemeless(self):
-        import os.path
-        path = os.path.join(self.resource_dir, "L8-B8-Robinson-IL.tiff")
-        self.assertTrue(not path.startswith('file://'))
-        df = self.spark.read.raster(path)
-        self.assertTrue(df.count() > 0)
-
-    def test_raster_source_catalog_reader(self):
-        import pandas as pd
-
-        scene_dict = {
-            1: 'http://landsat-pds.s3.amazonaws.com/c1/L8/015/041/LC08_L1TP_015041_20190305_20190309_01_T1/LC08_L1TP_015041_20190305_20190309_01_T1_B{}.TIF',
-            2: 'http://landsat-pds.s3.amazonaws.com/c1/L8/015/042/LC08_L1TP_015042_20190305_20190309_01_T1/LC08_L1TP_015042_20190305_20190309_01_T1_B{}.TIF',
-            3: 'http://landsat-pds.s3.amazonaws.com/c1/L8/016/041/LC08_L1TP_016041_20190224_20190309_01_T1/LC08_L1TP_016041_20190224_20190309_01_T1_B{}.TIF',
-        }
-
-        def path(scene, band):
-            assert band in range(1, 12)
-            p = scene_dict[scene]
-            return p.format(band)
-
-        # Create a pandas dataframe (makes it easy to create spark df)
-        path_pandas = pd.DataFrame([
-            {'b1': path(1, 1), 'b2': path(1, 2), 'b3': path(1, 3)},
-            {'b1': path(2, 1), 'b2': path(2, 2), 'b3': path(2, 3)},
-            {'b1': path(3, 1), 'b2': path(3, 2), 'b3': path(3, 3)},
-        ])
-        # comma separated list of column names containing URI's to read.
-        catalog_columns = ','.join(path_pandas.columns.tolist())  # 'b1,b2,b3'
-        path_table = self.spark.createDataFrame(path_pandas)
-
-        path_df = self.spark.read.raster(
-            tile_dimensions=(512, 512),
-            catalog=path_table,
-            catalog_col_names=catalog_columns,
-            lazy_tiles=True  # We'll get an OOM error if we try to read 9 scenes all at once!
-        )
-
-        self.assertTrue(len(path_df.columns) == 6)  # three bands times {path, tile}
-        self.assertTrue(path_df.select('b1_path').distinct().count() == 3)  # as per scene_dict
-        b1_paths_maybe = path_df.select('b1_path').distinct().collect()
-        b1_paths = [s.format('1') for s in scene_dict.values()]
-        self.assertTrue(all([row.b1_path in b1_paths for row in b1_paths_maybe]))
-
-    def test_raster_source_catalog_reader_with_pandas(self):
-        import pandas as pd
-        import geopandas
-        from shapely.geometry import Point
-
-        scene_dict = {
-            1: 'http://landsat-pds.s3.amazonaws.com/c1/L8/015/041/LC08_L1TP_015041_20190305_20190309_01_T1/LC08_L1TP_015041_20190305_20190309_01_T1_B{}.TIF',
-            2: 'http://landsat-pds.s3.amazonaws.com/c1/L8/015/042/LC08_L1TP_015042_20190305_20190309_01_T1/LC08_L1TP_015042_20190305_20190309_01_T1_B{}.TIF',
-            3: 'http://landsat-pds.s3.amazonaws.com/c1/L8/016/041/LC08_L1TP_016041_20190224_20190309_01_T1/LC08_L1TP_016041_20190224_20190309_01_T1_B{}.TIF',
-        }
-
-        def path(scene, band):
-            assert band in range(1, 12)
-            p = scene_dict[scene]
-            return p.format(band)
-
-        # Create a pandas dataframe (makes it easy to create spark df)
-        path_pandas = pd.DataFrame([
-            {'b1': path(1, 1), 'b2': path(1, 2), 'b3': path(1, 3), 'geo': Point(1, 1)},
-            {'b1': path(2, 1), 'b2': path(2, 2), 'b3': path(2, 3), 'geo': Point(2, 2)},
-            {'b1': path(3, 1), 'b2': path(3, 2), 'b3': path(3, 3), 'geo': Point(3, 3)},
-        ])
-
-        # here a subtle difference with the test_raster_source_catalog_reader test, feed the DataFrame not a CSV and not an already created spark DF.
-        df = self.spark.read.raster(
-            catalog=path_pandas,
-            catalog_col_names=['b1', 'b2', 'b3']
-        )
-        self.assertEqual(len(df.columns), 7)  # three path cols, three tile cols, and geo
-        self.assertTrue('geo' in df.columns)
-        self.assertTrue(df.select('b1_path').distinct().count() == 3)
-
-
-        # Same test with geopandas
-        geo_df = geopandas.GeoDataFrame(path_pandas, crs={'init': 'EPSG:4326'}, geometry='geo')
-        df2 = self.spark.read.raster(
-            catalog=geo_df,
-            catalog_col_names=['b1', 'b2', 'b3']
-        )
-        self.assertEqual(len(df2.columns), 7)  # three path cols, three tile cols, and geo
-        self.assertTrue('geo' in df2.columns)
-        self.assertTrue(df2.select('b1_path').distinct().count() == 3)
diff --git a/pyrasterframes/src/main/python/tests/__init__.py b/pyrasterframes/src/main/python/tests/__init__.py
index 152859fb0..2fe44a1dd 100644
--- a/pyrasterframes/src/main/python/tests/__init__.py
+++ b/pyrasterframes/src/main/python/tests/__init__.py
@@ -71,7 +71,9 @@ def setUpClass(cls):
         cls.spark = spark_test_session()

-        cls.img_uri = 'file://' + os.path.join(cls.resource_dir, 'L8-B8-Robinson-IL.tiff')
+        cls.img_path = os.path.join(cls.resource_dir, 'L8-B8-Robinson-IL.tiff')
+
+        cls.img_uri = 'file://' + cls.img_path

     @classmethod
     def l8band_uri(cls, band_index):
@@ -88,4 +90,3 @@ def create_layer(self):
         self.rf = rf.withColumn('tile2', rf_convert_cell_type('tile', 'float32')) \
            .drop('tile') \
            .withColumnRenamed('tile2', 'tile').as_layer()
-        # cls.rf.show()
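With `img_path` and `img_uri` both available to the test environment, schemeless filesystem paths and `file://` URIs can be exercised side by side, as `test_schemeless_string` does above. A minimal sketch, assuming a hypothetical local file:

```python
# Both forms should load the same raster (path is hypothetical).
df_from_path = spark.read.raster('/data/L8-B8-Robinson-IL.tiff')
df_from_uri = spark.read.raster('file:///data/L8-B8-Robinson-IL.tiff')
assert df_from_path.count() == df_from_uri.count()
```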