Skip to content

Commit

Permalink
Extend sieve's input to allow opened datasets and multiband objects (#2838)
Browse files Browse the repository at this point in the history

* Extend sieve's input to allow opened datasets and multiband objects

Resolves #2782

* Remove commented code, clean up imports
  • Loading branch information
sgillies authored May 22, 2023
1 parent 78fc94a commit 5dd56f3
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 47 deletions.
2 changes: 2 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ Changes
1.3.7 (2023-05-22)
------------------

- The sieve function now accepts as input opened datasets or multiband
Band objects (#2838).
- Allow color values greater than 256 in colormaps (#2769).
- Fix the GDAL datatype mapping of Rasterio's uint64 and int64 data types. They
were reversed in previous versions.
Expand Down
63 changes: 40 additions & 23 deletions rasterio/_features.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -148,20 +148,20 @@ def _shapes(image, mask, connectivity, transform):


def _sieve(image, size, out, mask, connectivity):
"""
Replaces small polygons in `image` with the value of their largest
neighbor. Polygons are found for each set of neighboring pixels of the
same value.
"""Remove small polygon regions from a raster.
Parameters
----------
image : array or dataset object opened in 'r' mode or Band or tuple(dataset, bidx)
Must be of type rasterio.int16, rasterio.int32, rasterio.uint8,
rasterio.uint16, or rasterio.float32.
image : ndarray or Band
The source is a 2 or 3-D ndarray, or a single-band or multiband
Rasterio Band object. Must be of type rasterio.int16,
rasterio.int32, rasterio.uint8, rasterio.uint16, or
rasterio.float32
size : int
minimum polygon size (number of pixels) to retain.
out : numpy ndarray
Array of same shape and data type as `image` in which to store results.
Array of same shape and data type as `image` in which to store
results.
mask : numpy ndarray or rasterio Band object
Values of False or 0 will be excluded from feature generation.
Must evaluate to bool (rasterio.bool_ or rasterio.uint8)
Expand All @@ -182,8 +182,7 @@ def _sieve(image, size, out, mask, connectivity):
valid_dtypes = ('int16', 'int32', 'uint8', 'uint16')

if _getnpdtype(image.dtype).name not in valid_dtypes:
valid_types_str = ', '.join(('rasterio.{0}'.format(t) for t
in valid_dtypes))
valid_types_str = ', '.join(('rasterio.{0}'.format(t) for t in valid_dtypes))
raise ValueError(
"image dtype must be one of: {0}".format(valid_types_str))

Expand All @@ -206,26 +205,40 @@ def _sieve(image, size, out, mask, connectivity):
try:

if dtypes.is_ndarray(image):
if len(image.shape) == 2:
image = image.reshape(1, *image.shape)
src_count = image.shape[0]
src_bidx = list(range(1, src_count + 1))
in_mem_ds = MemoryDataset(image)
in_band = in_mem_ds.band(1)
src_dataset = in_mem_ds

elif isinstance(image, tuple):
rdr = image.ds
in_band = (<DatasetReaderBase?>rdr).band(image.bidx)
src_dataset, src_bidx, dtype, shape = image
if isinstance(src_bidx, int):
src_bidx = [src_bidx]

else:
raise ValueError("Invalid source image")

if dtypes.is_ndarray(out):
log.debug("out array: %r", out)
if len(out.shape) == 2:
out = out.reshape(1, *out.shape)
dst_count = out.shape[0]
dst_bidx = list(range(1, dst_count + 1))
out_mem_ds = MemoryDataset(out)
out_band = out_mem_ds.band(1)
dst_dataset = out_mem_ds

elif isinstance(out, tuple):
udr = out.ds
out_band = (<DatasetReaderBase?>udr).band(out.bidx)
dst_dataset, dst_bidx, _, _ = out
if isinstance(dst_bidx, int):
dst_bidx = [dst_bidx]

else:
raise ValueError("Invalid out image")

if mask is not None:
if mask.shape != image.shape:
if mask.shape != image.shape[-2:]:
raise ValueError("Mask must have same shape as image")

if _getnpdtype(mask.dtype) not in ('bool', 'uint8'):
Expand All @@ -241,12 +254,11 @@ def _sieve(image, size, out, mask, connectivity):
mask_reader = mask.ds
mask_band = (<DatasetReaderBase?>mask_reader).band(mask.bidx)

GDALSieveFilter(in_band, mask_band, out_band, size, connectivity,
NULL, NULL, NULL)

else:
# Read from out_band into out
io_auto(out, out_band, False)
for i, j in zip(src_bidx, dst_bidx):
in_band = (<DatasetReaderBase?>src_dataset).band(i)
out_band = (<DatasetReaderBase?>dst_dataset).band(j)
GDALSieveFilter(in_band, mask_band, out_band, size, connectivity, NULL, NULL, NULL)
io_auto(out[i - 1], out_band, False)

finally:
if in_mem_ds is not None:
Expand All @@ -256,6 +268,11 @@ def _sieve(image, size, out, mask, connectivity):
if mask_mem_ds is not None:
mask_mem_ds.close()

if out.shape[0] == 1:
out = out[0]

return out


def _rasterize(shapes, image, transform, all_touched, merge_alg):
"""
Expand Down
50 changes: 26 additions & 24 deletions rasterio/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,16 @@
import numpy as np

import rasterio
from rasterio.dtypes import (
validate_dtype,
can_cast_dtype,
get_minimum_dtype,
_getnpdtype,
)
from rasterio import warp
from rasterio._base import DatasetBase
from rasterio._features import _shapes, _sieve, _rasterize, _bounds
from rasterio.dtypes import validate_dtype, can_cast_dtype, get_minimum_dtype, _getnpdtype
from rasterio.enums import MergeAlg
from rasterio.env import ensure_env, GDALVersion
from rasterio.errors import ShapeSkipWarning
from rasterio._features import _shapes, _sieve, _rasterize, _bounds
from rasterio import warp
from rasterio.rio.helpers import coords
from rasterio.transform import Affine
from rasterio.transform import IDENTITY, guard_transform, rowcol
from rasterio.transform import IDENTITY, guard_transform
from rasterio.windows import Window

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -133,19 +129,23 @@ def shapes(source, mask=None, connectivity=4, transform=IDENTITY):

@ensure_env
def sieve(source, size, out=None, mask=None, connectivity=4):
"""Replace small polygons in `source` with value of their largest neighbor.
"""Remove small polygon regions from a raster.
Polygons are found for each set of neighboring pixels of the same value.
Polygons are found for each set of neighboring pixels of the same
value.
Parameters
----------
source : array or dataset object opened in 'r' mode or Band or tuple(dataset, bidx)
Must be of type rasterio.int16, rasterio.int32, rasterio.uint8,
rasterio.uint16, or rasterio.float32
source : ndarray, dataset, or Band
The source is a 2 or 3-D ndarray, a dataset opened in "r" mode,
or a single-band or multiband Rasterio Band object. Must be of type
rasterio.int16, rasterio.int32, rasterio.uint8, rasterio.uint16,
or rasterio.float32
size : int
minimum polygon size (number of pixels) to retain.
out : numpy ndarray, optional
Array of same shape and data type as `source` in which to store results.
Array of same shape and data type as `source` in which to store
results.
mask : numpy ndarray or rasterio Band object, optional
Values of False or 0 will be excluded from feature generation
Must evaluate to bool (rasterio.bool_ or rasterio.uint8)
Expand All @@ -159,21 +159,23 @@ def sieve(source, size, out=None, mask=None, connectivity=4):
Notes
-----
GDAL only supports values that can be cast to 32-bit integers for this
operation.
GDAL only supports values that can be cast to 32-bit integers for
this operation.
The amount of memory used by this algorithm is proportional to the number
and complexity of polygons found in the image. This algorithm is most
appropriate for simple thematic data. Data with high pixel-to-pixel
variability, such as imagery, may produce one polygon per pixel and consume
large amounts of memory.
The amount of memory used by this algorithm is proportional to the
number and complexity of polygons found in the image. This
algorithm is most appropriate for simple thematic data. Data with
high pixel-to-pixel variability, such as imagery, may produce one
polygon per pixel and consume large amounts of memory.
"""
if isinstance(source, DatasetBase):
source = rasterio.band(source, source.indexes)

if out is None:
out = np.zeros(source.shape, source.dtype)
_sieve(source, size, out, mask, connectivity)
return out

return _sieve(source, size, out, mask, connectivity)


@ensure_env
Expand Down
28 changes: 28 additions & 0 deletions tests/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -1272,3 +1272,31 @@ def test_zz_no_dataset_leaks(capfd):
env._dump_open_datasets()
captured = capfd.readouterr()
assert not captured.err


def test_sieve_bands(pixelated_image, pixelated_image_file):
    """Sieving a multiband Band object matches sieving the array (gh-2782)."""
    expected = sieve(pixelated_image, 9)

    with rasterio.open(pixelated_image_file) as src:
        # A Band built from a list of band indexes is accepted as the source.
        from_band = sieve(rasterio.band(src, [1]), 9)
        assert np.array_equal(expected, from_band)

        # A Band is also accepted as the mask; with this mask the result
        # is expected to equal the unsieved input image.
        with_mask = sieve(
            rasterio.band(src, [1]), 9, mask=rasterio.band(src, 1)
        )
        assert np.array_equal(pixelated_image, with_mask)


def test_sieve_dataset(pixelated_image, pixelated_image_file):
    """Sieving an opened dataset matches sieving the array (gh-2782)."""
    expected = sieve(pixelated_image, 9)

    with rasterio.open(pixelated_image_file) as src:
        # The opened dataset itself is passed as the source.
        from_dataset = sieve(src, 9)
        assert np.array_equal(expected, from_dataset)

        # A Band is also accepted as the mask; with this mask the result
        # is expected to equal the unsieved input image.
        with_mask = sieve(src, 9, mask=rasterio.band(src, 1))
        assert np.array_equal(pixelated_image, with_mask)

0 comments on commit 5dd56f3

Please sign in to comment.