Python raster reader argument refactor #329

Merged

Changes from 12 commits
pyrasterframes/src/main/python/docs/languages.pymd (2 changes: 1 addition & 1 deletion)

@@ -42,7 +42,7 @@ red_nir_monthly_2017.printSchema()
 
 ```python, step_3_python
 red_nir_tiles_monthly_2017 = spark.read.raster(
-    catalog=red_nir_monthly_2017,
+    red_nir_monthly_2017,
     catalog_col_names=['red', 'nir'],
     tile_dimensions=(256, 256)
 )
pyrasterframes/src/main/python/docs/local-algebra.pymd (2 changes: 1 addition & 1 deletion)

@@ -40,7 +40,7 @@ catalog_df = spark.createDataFrame([
     Row(red=uri_pattern.format(4), nir=uri_pattern.format(8))
 ])
 df = spark.read.raster(
-    catalog=catalog_df,
+    catalog_df,
     catalog_col_names=['red', 'nir']
 )
 df.printSchema()
pyrasterframes/src/main/python/docs/nodata-handling.pymd (2 changes: 1 addition & 1 deletion)

@@ -90,7 +90,7 @@ from pyspark.sql import Row
 blue_uri = 'https://s22s-test-geotiffs.s3.amazonaws.com/luray_snp/B02.tif'
 scl_uri = 'https://s22s-test-geotiffs.s3.amazonaws.com/luray_snp/SCL.tif'
 cat = spark.createDataFrame([Row(blue=blue_uri, scl=scl_uri),])
-unmasked = spark.read.raster(catalog=cat, catalog_col_names=['blue', 'scl'])
+unmasked = spark.read.raster(cat, catalog_col_names=['blue', 'scl'])
 unmasked.printSchema()
 ```
pyrasterframes/src/main/python/docs/numpy-pandas.pymd (2 changes: 1 addition & 1 deletion)

@@ -51,7 +51,7 @@ cat = spark.read.format('aws-pds-modis-catalog').load() \
         (col('acquisition_date') < lit('2018-02-22'))
     )
 
-spark_df = spark.read.raster(catalog=cat, catalog_col_names=['B01']) \
+spark_df = spark.read.raster(cat, catalog_col_names=['B01']) \
     .select(
         'acquisition_date',
         'granule_id',
pyrasterframes/src/main/python/docs/raster-read.pymd (11 changes: 3 additions & 8 deletions)

@@ -101,8 +101,6 @@ modis_catalog = spark.read \
     .withColumn('red' , F.concat('base_url', F.lit("_B01.TIF"))) \
     .withColumn('nir' , F.concat('base_url', F.lit("_B02.TIF")))
 
-modis_catalog.printSchema()
-
 print("Available scenes: ", modis_catalog.count())
 ```

@@ -124,10 +122,7 @@ equator.select('date', 'gid')
 
 Now that we have prepared our catalog, we simply pass the DataFrame or CSV string to the `raster` DataSource to load the imagery. The `catalog_col_names` parameter gives the columns that contain the URIs to be read.
 
 ```python, read_catalog
-rf = spark.read.raster(
-    catalog=equator,
-    catalog_col_names=['red', 'nir']
-)
+rf = spark.read.raster(equator, catalog_col_names=['red', 'nir'])
 rf.printSchema()
 ```

@@ -179,7 +174,7 @@ mb.printSchema()
 
 If a band is passed into `band_indexes` that exceeds the number of bands in the raster, a projected raster column will still be generated in the schema, but the column will be full of `null` values.
 
-You can also pass a `catalog` and `band_indexes` together into the `raster` reader. This will create a projected raster column for the combination of all items passed into `catalog_col_names` and `band_indexes`. Again if a band in `band_indexes` exceeds the number of bands in a raster, it will have a `null` value for the corresponding column.
+You can also pass a _catalog_ and `band_indexes` together into the `raster` reader. This will create a projected raster column for the combination of all items in `catalog_col_names` and `band_indexes`. Again, if a band in `band_indexes` exceeds the number of bands in a raster, it will have a `null` value for the corresponding column.
 
 Here is a trivial example with a _catalog_ over multiband rasters. We specify two columns containing URIs and two bands, resulting in four projected raster columns.

@@ -191,7 +186,7 @@ mb_cat = pd.DataFrame([
     },
 ])
 mb2 = spark.read.raster(
-    catalog=spark.createDataFrame(mb_cat),
+    spark.createDataFrame(mb_cat),
     catalog_col_names=['foo', 'bar'],
     band_indexes=[0, 1],
     tile_dimensions=(64,64)
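The over-range `band_indexes` behavior described in this file can be illustrated with a short sketch. The example below is not part of the PR: the URI is a hypothetical placeholder, and it assumes the reader's default `proj_raster_b0`/`proj_raster_b1` column naming when a single-band raster is read with two band indexes.

```python
# Illustrative sketch (not in this PR): reading a single-band raster with
# band_indexes=[0, 1] still produces a second projected raster column,
# but every value in it is null.
uri = 'https://example.com/single_band_scene.tif'  # hypothetical placeholder

df = spark.read.raster(uri, band_indexes=[0, 1])

# proj_raster_b0 holds real tiles; proj_raster_b1 should be entirely null
# (assumes the default b0/b1 column naming shown in the docs above).
df.select(df.proj_raster_b1.isNull().alias('b1_is_null')).distinct().show()
```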
pyrasterframes/src/main/python/docs/supervised-learning.pymd (6 changes: 2 additions & 4 deletions)

@@ -33,10 +33,8 @@ catalog_df = pd.DataFrame([
     {b: uri_base.format(b) for b in cols}
 ])
 
-df = spark.read.raster(catalog=catalog_df,
-                       catalog_col_names=cols,
-                       tile_dimensions=(128, 128)
-                       ).repartition(100)
+df = spark.read.raster(catalog_df, catalog_col_names=cols, tile_dimensions=(128, 128)) \
+    .repartition(100)
 
 df = df.select(
     rf_crs(df.B01).alias('crs'),
pyrasterframes/src/main/python/docs/time-series.pymd (2 changes: 1 addition & 1 deletion)

@@ -97,7 +97,7 @@ We then [reproject](https://gis.stackexchange.com/questions/247770/understanding
 ```python read_catalog
 raster_cols = ['B01', 'B02',] # red and near-infrared respectively
 park_rf = spark.read.raster(
-    catalog=park_cat.select(['acquisition_date', 'granule_id', 'geo_simp'] + raster_cols),
+    park_cat.select(['acquisition_date', 'granule_id', 'geo_simp'] + raster_cols),
     catalog_col_names=raster_cols) \
     .withColumn('park_native', st_reproject('geo_simp', lit('EPSG:4326'), rf_crs('B01'))) \
     .filter(st_intersects('park_native', rf_geometry('B01')))
@@ -37,7 +37,7 @@ filenamePattern = "L8-B{}-Elkton-VA.tiff"
 catalog_df = pd.DataFrame([
     {'b' + str(b): os.path.join(resource_dir_uri(), filenamePattern.format(b)) for b in range(1, 8)}
 ])
-df = spark.read.raster(catalog=catalog_df, catalog_col_names=catalog_df.columns)
+df = spark.read.raster(catalog_df, catalog_col_names=catalog_df.columns)
 df = df.select(
     rf_crs(df.b1).alias('crs'),
     rf_extent(df.b1).alias('extent'),
pyrasterframes/src/main/python/pyrasterframes/__init__.py (59 changes: 57 additions & 2 deletions)

@@ -110,16 +110,33 @@ def _aliased_writer(df_writer, format_key, path, **options):
 
 def _raster_reader(
         df_reader,
-        path=None,
-        catalog=None,
+        source=None,
         catalog_col_names=None,
         band_indexes=None,
         tile_dimensions=(256, 256),
         lazy_tiles=True,
         **options):
+    """
+    Returns a Spark DataFrame from raster data files specified by URI pointers.
+
+    The returned DataFrame will have a column of (CRS, Extent, Tile) for each URI read.
+    Multiple bands from the same raster file are spread across rows of the DataFrame. See the band_indexes param.
+    If bands from a scene are stored in separate files, provide a DataFrame to the `source` parameter. Each row in the returned DataFrame will contain one (CRS, Extent, Tile) for each item in `catalog_col_names`.
+
+    For more details and example usage, consult https://rasterframes.io/raster-read.html
+
+    :param source: a string, list of strings, list of lists of strings, a pandas DataFrame, or a Spark DataFrame giving URIs to the raster data to read
+    :param catalog_col_names: required if source is a DataFrame or CSV string; a list of strings giving the names of columns containing URIs to read
+    :param band_indexes: list of integers indicating which bands, zero-based, to read from the raster files specified; default is to read only the first band
+    :param tile_dimensions: tuple or list of two ints giving the default tile dimensions as (columns, rows)
+    :param lazy_tiles: if True (default), only generate minimal references to tile contents; if False, fetch tile cell values
+    :param options: additional keyword arguments to pass to the Spark DataSource
+    """
 
     from pandas import DataFrame as PdDataFrame
 
+    if 'catalog' in options:
+        source = options['catalog']  # maintain backward compatibility with 0.8.0
+
     def to_csv(comp):
         if isinstance(comp, str):
             return comp

@@ -140,15 +157,49 @@ def temp_name():
         "lazyTiles": lazy_tiles
     })
 
+    # Parse the `source` argument
+    path = None  # to pass into the `path` param
+    if isinstance(source, list):
+        if all([isinstance(i, str) for i in source]):
+            path = None
+            catalog = None
+            options.update(dict(paths='\n'.join([str(i) for i in source])))  # pass in "uri1\nuri2\nuri3\n..."
+        if all([isinstance(i, list) for i in source]):
+            # list of lists; we will rely on pandas to
+            # - coerce all data to str (possibly using objects' __str__ or __repr__)
+            # - ensure data is not "ragged": all sublists are the same length
+            path = None
+            catalog_col_names = ['proj_raster_{}'.format(i) for i in range(len(source[0]))]
+            catalog = PdDataFrame(source,
+                                  columns=catalog_col_names,
+                                  dtype=str,
+                                  )
+    elif isinstance(source, str):
+        if '\n' in source or '\r' in source:
+            # then the `source` string is a catalog as a CSV (header is required)
+            path = None
+            catalog = source
+        else:
+            # interpret source as a single URI string
+            path = source
+            catalog = None
+    else:
+        # user has passed in some other type; we will try to interpret it as a catalog
+        catalog = source
+
     if catalog is not None:
         if catalog_col_names is None:
             raise Exception("'catalog_col_names' required when DataFrame 'catalog' specified")
 
         if isinstance(catalog, str):
             options.update({
                 "catalogCSV": catalog,
                 "catalogColumns": to_csv(catalog_col_names)
             })
         elif isinstance(catalog, DataFrame):
+            # check catalog_col_names
+            assert all([c in catalog.columns for c in catalog_col_names]), \
+                "All items in catalog_col_names must be the name of a column in the catalog DataFrame."
             # Create a random view name
             tmp_name = temp_name()
             catalog.createOrReplaceTempView(tmp_name)

@@ -157,6 +208,10 @@ def temp_name():
                 "catalogColumns": to_csv(catalog_col_names)
             })
         elif isinstance(catalog, PdDataFrame):
+            # check catalog_col_names
+            assert all([c in catalog.columns for c in catalog_col_names]), \
+                "All items in catalog_col_names must be the name of a column in the catalog DataFrame."
+
             # Handle to active spark session
             session = SparkContext._active_spark_context._rf_context._spark_session
             # Create a random view name
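Taken together, the `source` parsing above accepts several call forms. The following sketch summarizes them; it is not part of the PR, it assumes an active SparkSession with RasterFrames initialized, and all URIs as well as `some_catalog_df` are hypothetical placeholders.

```python
# Call forms accepted by the refactored reader, per the `source` parsing above.

# 1. A single URI string is treated as a path:
df1 = spark.read.raster('https://example.com/scene_B01.TIF')

# 2. A list of URI strings is newline-joined and passed as `paths`:
df2 = spark.read.raster([
    'https://example.com/scene1_B01.TIF',
    'https://example.com/scene2_B01.TIF',
])

# 3. A list of lists is coerced into a pandas catalog with generated
#    column names proj_raster_0, proj_raster_1, ...:
df3 = spark.read.raster([
    ['https://example.com/scene1_B01.TIF', 'https://example.com/scene1_B02.TIF'],
    ['https://example.com/scene2_B01.TIF', 'https://example.com/scene2_B02.TIF'],
])

# 4. A string containing newlines is interpreted as a CSV catalog
#    (header required); catalog_col_names selects the URI columns:
csv_catalog = 'red,nir\nhttps://example.com/s1_B01.TIF,https://example.com/s1_B02.TIF'
df4 = spark.read.raster(csv_catalog, catalog_col_names=['red', 'nir'])

# 5. Any other type (pandas or Spark DataFrame) is treated as a catalog;
#    the 0.8.0 `catalog=` keyword also still works via the shim above:
df5 = spark.read.raster(catalog=some_catalog_df, catalog_col_names=['red', 'nir'])
```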
pyrasterframes/src/main/python/tests/PyRasterFramesTests.py (2 changes: 0 additions & 2 deletions)

@@ -410,8 +410,6 @@ def test_raster_join(self):
         self.rf.raster_join(rf_prime, join_exprs=self.rf.extent)
 
-
-
 def suite():
     function_tests = unittest.TestSuite()
     return function_tests
@@ -286,7 +286,7 @@ def test_render_composite(self):
         cat = self.spark.createDataFrame([
             Row(red=self.l8band_uri(4), green=self.l8band_uri(3), blue=self.l8band_uri(2))
         ])
-        rf = self.spark.read.raster(catalog = cat, catalog_col_names=['red', 'green', 'blue'])
+        rf = self.spark.read.raster(catalog=cat, catalog_col_names=cat.columns)
 
         # Test composite construction
         rgb = rf.select(rf_tile(rf_rgb_composite('red', 'green', 'blue')).alias('rgb')).first()['rgb']