Skip to content

Commit

Permalink
remove force_same_dtype arg from MultiRasterSource
Browse files Browse the repository at this point in the history
Require sub raster sources to have the same dtype.
  • Loading branch information
AdeelH committed Aug 6, 2024
1 parent d152210 commit 0f45c2c
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,25 @@


class MultiRasterSource(RasterSource):
"""Merge multiple ``RasterSources`` by concatenating along channel dim."""
"""
Merge multiple ``RasterSources`` by concatenating along the channel dim.
"""

def __init__(self,
raster_sources: Sequence[RasterSource],
primary_source_idx: NonNegInt = 0,
force_same_dtype: bool = False,
channel_order: Sequence[NonNegInt] | None = None,
raster_transformers: Sequence = [],
raster_transformers: Sequence['RasterTransformer'] = [],
bbox: Box | None = None):
"""Constructor.
Args:
raster_sources: Sequence of RasterSources.
primary_source_idx (0 <= int < len(raster_sources)): Index of the
raster source whose CRS, dtype, and other attributes will
override those of the other raster sources.
force_same_dtype: If true, force all sub-chips to have the
same dtype as the primary_source_idx-th sub-chip. No careful
conversion is done, just a quick cast. Use with caution.
primary_source_idx: Index of the raster source whose CRS, dtype,
and other attributes will override those of the other raster
sources.
channel_order: Channel ordering that will be used by
:meth:`MultiRasterSource.get_chip()`. Defaults to ``None``.
:meth:`.MultiRasterSource.get_chip`. Defaults to ``None``.
raster_transformers: List of transformers. Defaults to ``[]``.
bbox: User-specified crop of the extent. If specified, the primary
raster source's bbox is set to this. If ``None``, the full
Expand Down Expand Up @@ -65,14 +63,12 @@ def __init__(self,
bbox=bbox,
raster_transformers=raster_transformers)

self.force_same_dtype = force_same_dtype
self.raster_sources = raster_sources
self.primary_source_idx = primary_source_idx
self.non_primary_sources = [
rs for i, rs in enumerate(raster_sources)
if i != primary_source_idx
]

self.validate_raster_sources()

@classmethod
Expand All @@ -82,7 +78,6 @@ def from_stac(
assets: list[str] | None,
primary_source_idx: NonNegInt = 0,
raster_transformers: list['RasterTransformer'] = [],
force_same_dtype: bool = False,
channel_order: Sequence[int] | None = None,
bbox: Box | tuple[int, int, int, int] | None = None,
bbox_map_coords: Box | tuple[int, int, int, int] | None = None,
Expand All @@ -99,14 +94,11 @@ def from_stac(
item: STAC Item.
assets: List of names of assets to use. If ``None``, all assets
present in the item will be used. Defaults to ``None``.
primary_source_idx (0 <= int < len(raster_sources)): Index of the
raster source whose CRS, dtype, and other attributes will
override those of the other raster sources.
primary_source_idx: Index of the raster source whose CRS, dtype,
and other attributes will override those of the other raster
sources.
raster_transformers: RasterTransformers to use to transform chips
after they are read.
force_same_dtype: If true, force all sub-chips to have the
same dtype as the primary_source_idx-th sub-chip. No careful
conversion is done, just a quick cast. Use with caution.
after they are read. Defaults to ``[]``.
channel_order: List of indices of channels to extract from raw
imagery. Can be a subset of the available channels. If None,
all channels available in the image will be read.
Expand Down Expand Up @@ -148,25 +140,19 @@ def from_stac(
primary_source_idx=primary_source_idx,
raster_transformers=raster_transformers,
channel_order=channel_order,
force_same_dtype=force_same_dtype,
bbox=bbox)
return raster_source

def validate_raster_sources(self) -> None:
"""Validate sub-``RasterSources``.
Checks if:
- dtypes are same or ``force_same_dtype`` is True.
Checks if all raster sources have the same dtype.
"""
dtypes = [rs.dtype for rs in self.raster_sources]
if not self.force_same_dtype and not all_equal(dtypes):
if not all_equal(dtypes):
raise ValueError(
'dtypes of all sub raster sources must be the same. '
f'Got: {dtypes} '
'(Use force_same_dtype to cast all to the dtype of the '
'primary source)')
f'Got: {dtypes}.')

@property
def primary_source(self) -> RasterSource:
Expand Down Expand Up @@ -234,10 +220,6 @@ def get_chip(rs: RasterSource,
]
sub_chips.insert(self.primary_source_idx, primary_sub_chip)

if self.force_same_dtype:
dtype = sub_chips[self.primary_source_idx].dtype
sub_chips = [chip.astype(dtype) for chip in sub_chips]

return sub_chips

def _get_chip(self, window: Box,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ def multi_rs_config_upgrader(cfg_dict: dict, version: int) -> dict:
if version == 1:
# field renamed in version 2
cfg_dict['primary_source_idx'] = cfg_dict.get('crs_source', 0)
try:
del cfg_dict['crs_source']
except KeyError:
pass
cfg_dict.pop('crs_source', None)
elif version == 13:
# field removed in version 14
cfg_dict.pop('force_same_dtype', None)
return cfg_dict


Expand All @@ -36,10 +36,6 @@ class MultiRasterSourceConfig(RasterSourceConfig):
description=
'Index of the raster source whose CRS, dtype, and other attributes '
'will override those of the other raster sources. Defaults to 0.')
force_same_dtype: bool = Field(
False,
description='Force all subchips to be of the same dtype as the '
'primary_source_idx-th subchip.')
temporal: bool = Field(
False,
description='Stack images from sub raster sources into a time-series '
Expand Down Expand Up @@ -82,14 +78,12 @@ def build(self, tmp_dir: str | None = None,
multi_raster_source = TemporalMultiRasterSource(
raster_sources=built_raster_sources,
primary_source_idx=self.primary_source_idx,
force_same_dtype=self.force_same_dtype,
raster_transformers=raster_transformers,
bbox=bbox)
else:
multi_raster_source = MultiRasterSource(
raster_sources=built_raster_sources,
primary_source_idx=self.primary_source_idx,
force_same_dtype=self.force_same_dtype,
channel_order=self.channel_order,
raster_transformers=raster_transformers,
bbox=bbox)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Sequence
from typing import TYPE_CHECKING, Any, Sequence

from pydantic import NonNegativeInt as NonNegInt
import numpy as np
Expand All @@ -8,32 +8,30 @@
MultiRasterSource)
from rastervision.core.data.utils import all_equal, parse_array_slices_Nd

if TYPE_CHECKING:
from rastervision.core.data import RasterTransformer


class TemporalMultiRasterSource(MultiRasterSource):
"""Merge multiple ``RasterSources`` by stacking them along a new dim."""

def __init__(self,
raster_sources: Sequence[RasterSource],
primary_source_idx: NonNegInt = 0,
force_same_dtype: bool = False,
raster_transformers: Sequence = [],
raster_transformers: Sequence['RasterTransformer'] = [],
bbox: Box | None = None):
"""Constructor.
Args:
raster_sources (Sequence[RasterSource]): Sequence of RasterSources.
primary_source_idx (0 <= int < len(raster_sources)): Index of the
raster source whose CRS, dtype, and other attributes will
override those of the other raster sources.
force_same_dtype (bool): If true, force all sub-chips to have the
same dtype as the primary_source_idx-th sub-chip. No careful
conversion is done, just a quick cast. Use with caution.
raster_transformers (Sequence): Sequence of transformers.
Defaults to [].
bbox (Box | None): User-specified crop of the extent.
If given, the primary raster source's bbox is set to this.
If None, the full extent available in the source file of the
primary raster source is used.
raster_sources: Sequence of RasterSources.
primary_source_idx: Index of the raster source whose CRS, dtype,
and other attributes will override those of the other raster
sources.
raster_transformers: Sequence of transformers. Defaults to ``[]``.
bbox: User-specified crop of the extent. If given, the primary
raster source's bbox is set to this. If ``None``, the full
extent available in the source file of the primary raster
source is used.
"""
if not all_equal([rs.num_channels for rs in raster_sources]):
raise ValueError(
Expand Down Expand Up @@ -62,7 +60,6 @@ def __init__(self,
bbox=bbox,
raster_transformers=raster_transformers)

self.force_same_dtype = force_same_dtype
self.raster_sources = raster_sources
self.primary_source_idx = primary_source_idx

Expand Down
25 changes: 14 additions & 11 deletions tests/core/data/raster_source/test_multi_raster_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
RasterioSourceConfig, MultiRasterSource, MultiRasterSourceConfig,
ReclassTransformerConfig, CastTransformerConfig, XarraySource,
IdentityCRSTransformer, TemporalMultiRasterSource)
from rastervision.core.data.raster_source.multi_raster_source_config import (
multi_rs_config_upgrader)

from tests import data_file_path

Expand Down Expand Up @@ -83,6 +85,13 @@ def test_build_temporal(self):
self.assertIsInstance(rs, TemporalMultiRasterSource)
self.assertEqual(rs.shape, (3, 256, 256, 3))

def test_upgrader_v14(self):
cfg = make_cfg()
cfg_dict_old = cfg.dict()
cfg_dict_old['force_same_dtype'] = True
cfg_dict_new = multi_rs_config_upgrader(cfg_dict_old, 13)
self.assertNotIn('force_same_dtype', cfg_dict_new)


class TestMultiRasterSource(unittest.TestCase):
def assertNoError(self, fn: Callable, msg: str = ''):
Expand Down Expand Up @@ -113,25 +122,19 @@ def test_extent(self):

def test_primary_source_idx(self):
primary_source_idx = 2
non_primary_source_idx = 1

cfg = make_cfg_diverse(
diff_dtypes=True,
force_same_dtype=True,
primary_source_idx=primary_source_idx)
cfg = make_cfg_diverse(primary_source_idx=primary_source_idx)
rs = cfg.build(tmp_dir=self.tmp_dir)
primary_rs = rs.raster_sources[primary_source_idx]
non_primary_rs = rs.raster_sources[non_primary_source_idx]

self.assertEqual(rs.extent, primary_rs.extent)
self.assertNotEqual(rs.extent, non_primary_rs.extent)

self.assertEqual(rs.dtype, primary_rs.dtype)
self.assertNotEqual(rs.dtype, non_primary_rs.dtype)

self.assertEqual(rs.crs_transformer.transform,
primary_rs.crs_transformer.transform)
self.assertNotEqual(rs.crs_transformer, non_primary_rs.crs_transformer)

def test_dtype_validation(self):
cfg = make_cfg_diverse(diff_dtypes=True)
self.assertRaises(ValueError, lambda: cfg.build(tmp_dir=self.tmp_dir))

def test_bbox(self):
# /wo user specified extent
Expand Down

0 comments on commit 0f45c2c

Please sign in to comment.