Allow band indexing in RasterDataset #687

Merged 9 commits on Sep 14, 2022

ashnair1
Collaborator

RasterDataset by default loads all bands. This PR allows us to be more selective.

@github-actions github-actions bot added the datasets Geospatial or benchmark datasets label Jul 26, 2022
Collaborator

@adamjstewart adamjstewart left a comment

There is already a self.bands attribute for this. Right now, it's only used when self.separate_files is True, but we should be able to reuse it for what you're doing here. The only problem is that self.bands holds string names, not integer indices. Maybe we can translate from one to the other using self.all_bands?

@ashnair1
Collaborator Author

ashnair1 commented Jul 27, 2022

I did notice those attributes and don't think they can be reused.

bands, all_bands and rgb_bands are lists of band names, as they appear in filenames, used when bands are stored separately. Suppose I have a directory of 8-band GeoTIFFs and I want to create a RasterDataset using only bands 2, 3 and 4. How could I do that? There are no band names to go off of.

Conceptually, I see these as two different things. The current bands/all_bands attributes help assemble a multi-band tensor from separate band files. The band_indexes attribute I'm proposing gives users a way to select the bands they need from the assembled multi-band tensor.
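
The two operations described in this comment can be illustrated with a minimal sketch. It uses plain nested lists in place of tensors, and the helper names `assemble` and `select` are hypothetical; neither is part of TorchGeo.

```python
from typing import Dict, List, Sequence

Band = List[List[int]]  # one band as a 2D grid of pixel values


def assemble(band_files: Dict[str, Band], names: Sequence[str]) -> List[Band]:
    """Stack per-band rasters (keyed by band name) into one multi-band image."""
    return [band_files[name] for name in names]


def select(image: List[Band], band_indexes: Sequence[int]) -> List[Band]:
    """Pick bands from an already assembled image using 1-based indexes."""
    return [image[i - 1] for i in band_indexes]


# Hypothetical toy data: eight 1x1 bands whose single pixel equals the band number.
files = {f"B{i}": [[i]] for i in range(1, 9)}
image = assemble(files, [f"B{i}" for i in range(1, 9)])
subset = select(image, [2, 3, 4])
print(len(image), len(subset))  # 8 3
```

Assembly needs band names because it has to locate files, while selection only needs positions, which is why a multi-band file with no named bands can still be subset by index.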

@adamjstewart
Collaborator

I think it's possible to reuse these. There's no reason that bands/all_bands/rgb_bands should only apply to separate_files = True; we also use them in datasets like NAIP where separate_files = False. If we enforce that the bands in all_bands are in the same order as the actual band indices in a single file, then we can 1) check that the requested bands are real bands listed in all_bands, and 2) use the ordering of all_bands to compute band_indices from bands. The only thing that would need to change is to add a bands parameter to the affected datasets. I actually think it would be worth adding this directly to RasterDataset, since most of our datasets are multi-band and the ones that aren't can just define a single band.
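
The two steps in this comment (validate the requested names against all_bands, then use the ordering of all_bands to derive indices) might look roughly like the sketch below. The function name `bands_to_indexes` and the choice of `ValueError` are assumptions for illustration, not the code that was merged.

```python
from typing import List, Optional, Sequence


def bands_to_indexes(
    bands: Sequence[str], all_bands: Sequence[str]
) -> Optional[List[int]]:
    """Translate band names into 1-based band indexes.

    Returns None when no bands are requested, meaning "load every band".
    """
    if not bands:
        return None
    # Step 1: every requested band must be listed in all_bands.
    unknown = [b for b in bands if b not in all_bands]
    if unknown:
        raise ValueError(f"Unknown bands {unknown}; expected a subset of all_bands")
    # Step 2: position in all_bands + 1, since raster band indexes start at 1.
    return [all_bands.index(b) + 1 for b in bands]


all_bands = [f"B{i}" for i in range(1, 9)]
print(bands_to_indexes(["B2", "B3", "B4"], all_bands))  # [2, 3, 4]
```

This only works if all_bands is listed in the same order as the bands inside the file, which is the constraint the comment proposes to enforce.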

@ashnair1
Collaborator Author

ashnair1 commented Jul 28, 2022

I think I understand. Do you mean something like this in the __init__ method?

if self.all_bands and self.bands:
    band_indexes = [self.all_bands.index(i) + 1 for i in self.bands]
    assert len(band_indexes) == len(self.bands)
else:
    band_indexes = None

self.band_indexes = band_indexes

@adamjstewart
Collaborator

Yes, something like that would be half of it. The other half is users need a way of specifying this when they instantiate a dataset (bands parameter). Some datasets have this, others don't. I think it's worth adding this to RasterDataset, then we can actually get rid of the __init__ in some subclasses.

@ashnair1
Collaborator Author

ashnair1 commented Jul 28, 2022

So we have a bands parameter that determines our bands of interest and all_bands which are all the bands in the dataset.

If we go back to my previous example of creating a RasterDataset from a directory of 8-band TIFFs using only bands 2, 3, and 4, it should look like this, right? (Assuming we added bands directly to RasterDataset.)

class Raster8BandFolder(RasterDataset):

    all_bands = [f"B{i}" for i in range(1, 9)]

    def __init__(
        self,
        root: str = "data",
        crs: Optional[CRS] = None,
        res: Optional[float] = None,
        bands: Sequence[str] = [],
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        cache: bool = True,
    ) -> None:
        # Forward everything, including bands, to RasterDataset.
        super().__init__(root, crs, res, bands, transforms, cache)


raster8 = Raster8BandFolder(root="/path/to/folder", bands=["B2", "B3", "B4"])

@ashnair1
Collaborator Author

Modification to RasterDataset

class RasterDataset(GeoDataset):
    """Abstract base class for :class:`GeoDataset` stored as raster files."""

    #: Glob expression used to search for files.
    #:
    #: This expression should be specific enough that it will not pick up files from
    #: other datasets. It should not include a file extension, as the dataset may be in
    #: a different file format than what it was originally downloaded as.
    filename_glob = "*"

    #: Regular expression used to extract date from filename.
    #:
    #: The expression should use named groups. The expression may contain any number of
    #: groups. The following groups are specifically searched for by the base class:
    #:
    #: * ``date``: used to calculate ``mint`` and ``maxt`` for ``index`` insertion
    #:
    #: When :attr:`separate_files` is True, the following additional groups are
    #: searched for to find other files:
    #:
    #: * ``band``: replaced with requested band name
    #: * ``resolution``: replaced with a glob character
    filename_regex = ".*"

    #: Date format string used to parse date from filename.
    #:
    #: Not used if :attr:`filename_regex` does not contain a ``date`` group.
    date_format = "%Y%m%d"

    #: True if dataset contains imagery, False if dataset contains mask
    is_image = True

    #: True if data is stored in a separate file for each band, else False.
    separate_files = False

    #: Names of all available bands in the dataset
    all_bands: List[str] = []

    #: Names of RGB bands in the dataset, used for plotting
    rgb_bands: List[str] = []

    #: Color map for the dataset, used for plotting
    cmap: Dict[int, Tuple[int, int, int, int]] = {}

    def __init__(
        self,
        root: str,
        crs: Optional[CRS] = None,
        res: Optional[float] = None,
+        bands: List[str] = [],
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        cache: bool = True,
    ) -> None:
        """Initialize a new Dataset instance.

        Args:
            root: root directory where dataset can be found
            crs: :term:`coordinate reference system (CRS)` to warp to
                (defaults to the CRS of the first file found)
            res: resolution of the dataset in units of CRS
                (defaults to the resolution of the first file found)
+            bands: list of band names to be used
            transforms: a function/transform that takes an input sample
                and returns a transformed version
            cache: if True, cache file handle to speed up repeated sampling

        Raises:
            FileNotFoundError: if no files are found in ``root``
        """
        super().__init__(transforms)

        self.root = root
        self.cache = cache

        # Populate the dataset index
        i = 0
        pathname = os.path.join(root, "**", self.filename_glob)
        filename_regex = re.compile(self.filename_regex, re.VERBOSE)
        for filepath in glob.iglob(pathname, recursive=True):
            match = re.match(filename_regex, os.path.basename(filepath))
            if match is not None:
                try:
                    with rasterio.open(filepath) as src:
                        # See if file has a color map
                        if len(self.cmap) == 0:
                            try:
                                self.cmap = src.colormap(1)
                            except ValueError:
                                pass

                        if crs is None:
                            crs = src.crs
                        if res is None:
                            res = src.res[0]

                        with WarpedVRT(src, crs=crs) as vrt:
                            minx, miny, maxx, maxy = vrt.bounds
                except rasterio.errors.RasterioIOError:
                    # Skip files that rasterio is unable to read
                    continue
                else:
                    mint: float = 0
                    maxt: float = sys.maxsize
                    if "date" in match.groupdict():
                        date = match.group("date")
                        mint, maxt = disambiguate_timestamp(date, self.date_format)

                    coords = (minx, maxx, miny, maxy, mint, maxt)
                    self.index.insert(i, coords, filepath)
                    i += 1

        if i == 0:
            raise FileNotFoundError(
                f"No {self.__class__.__name__} data was found in '{root}'"
            )

+        if not self.all_bands:
+            band_indexes = None
+        else:
+            if self.bands:
+                band_indexes = [self.all_bands.index(i) + 1 for i in self.bands]
+                assert len(band_indexes) == len(self.bands)
+            else:
+                band_indexes = None
+
+            if self.rgb_bands:
+                rgb_band_indexes = [self.all_bands.index(i) + 1 for i in self.rgb_bands]
+                assert len(rgb_band_indexes) == len(self.rgb_bands)
+
+        self.band_indexes = band_indexes
        self._crs = cast(CRS, crs)
        self.res = cast(float, res)

@adamjstewart
Collaborator

Yes, that looks correct to me. You wouldn't even need to override __init__ in your Raster8BandFolder example dataset since you aren't changing anything.
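
To illustrate the point about not overriding __init__, here is a minimal sketch. `BandSelectingBase` is a hypothetical stand-in for a base class that already accepts a bands argument; it is not the real RasterDataset and skips all file handling.

```python
from typing import List, Optional, Sequence


class BandSelectingBase:
    """Hypothetical stand-in for a base dataset that accepts ``bands``."""

    all_bands: List[str] = []

    def __init__(self, root: str, bands: Optional[Sequence[str]] = None) -> None:
        self.root = root
        self.bands = list(bands) if bands else list(self.all_bands)
        self.band_indexes: Optional[List[int]] = None
        if bands:
            # 1-based position of each requested band within all_bands.
            self.band_indexes = [self.all_bands.index(b) + 1 for b in bands]


class Raster8BandFolder(BandSelectingBase):
    # Only the class attribute is overridden; __init__ is inherited as-is.
    all_bands = [f"B{i}" for i in range(1, 9)]


raster8 = Raster8BandFolder(root="/path/to/folder", bands=["B2", "B3", "B4"])
print(raster8.band_indexes)  # [2, 3, 4]
```

Because the subclass changes nothing about construction, declaring all_bands is enough for the inherited constructor to resolve the requested names.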

@ashnair1
Collaborator Author

Great. I'll push a PR tomorrow (once mypy is satisfied) and we can continue from there.

@ashnair1 ashnair1 requested a review from adamjstewart August 2, 2022 08:30
Collaborator

@adamjstewart adamjstewart left a comment

Apologies for taking so long to review this!

@ashnair1
Collaborator Author

Looks like we can remove Sentinel2's __init__

@adamjstewart
Collaborator

Yep, let's remove Sentinel-2's __init__

@ashnair1 ashnair1 requested a review from adamjstewart August 13, 2022 06:27
@adamjstewart adamjstewart added this to the 0.4.0 milestone Aug 13, 2022
Collaborator

@adamjstewart adamjstewart left a comment

We could also add a dummy _verify() method to the base class and get rid of the __init__ for most subclasses, but let's save that for another PR. There's also a bug where classes with no __init__ don't get any docs, but there's a Sphinx setting to fix that, which I need to play around with.

Can you add unit tests for this? We'll want to test selecting a subset of bands for at least one separate_files = True and one separate_files = False dataset. The test just needs to make sure that the total number of bands returned actually changes.
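
A rough sketch of the kind of test requested here, assuming hypothetical in-memory classes (`FakeRaster`, `FakeSeparateFiles`) rather than real TorchGeo datasets. The separate_files flag does not change behavior in this toy; it only marks the two cases the test should cover.

```python
from typing import List, Optional, Sequence


class FakeRaster:
    """Hypothetical in-memory dataset; each band is a single-pixel list."""

    all_bands = ["B1", "B2", "B3", "B4"]
    separate_files = False

    def __init__(self, bands: Optional[Sequence[str]] = None) -> None:
        self.bands = list(bands) if bands else list(self.all_bands)

    def load(self) -> List[List[int]]:
        # Return one entry per selected band.
        return [[self.all_bands.index(b) + 1] for b in self.bands]


class FakeSeparateFiles(FakeRaster):
    separate_files = True


def test_band_subset_changes_band_count() -> None:
    for dataset_cls in (FakeRaster, FakeSeparateFiles):
        full = dataset_cls().load()
        subset = dataset_cls(bands=["B2", "B3"]).load()
        assert len(full) == len(dataset_cls.all_bands)
        assert len(subset) == 2
        assert len(subset) < len(full)


test_band_subset_changes_band_count()
```

A real test would instantiate actual datasets against fixture files and compare the channel dimension of the returned image.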

@github-actions github-actions bot added the testing Continuous integration testing label Aug 15, 2022
@ashnair1 ashnair1 requested a review from adamjstewart August 15, 2022 18:19
@adamjstewart
Collaborator

Remind me to review this this weekend.

@ashnair1
Collaborator Author

Pinging @adamjstewart

@ashnair1 ashnair1 force-pushed the bands branch 3 times, most recently from 77ea1f1 to 528fcc8 September 6, 2022 16:50
adamjstewart previously approved these changes Sep 14, 2022
Collaborator

@adamjstewart adamjstewart left a comment

This looks great! Couple minor formatting requests, but otherwise this looks ready to me.

@adamjstewart adamjstewart merged commit 43916e0 into microsoft:main Sep 14, 2022
@ashnair1 ashnair1 deleted the bands branch September 14, 2022 15:45
yichiac pushed a commit to yichiac/torchgeo that referenced this pull request Apr 29, 2023
* Allow band indexing

* Add bands attribute to RasterDataset

* Review comments#1

* Remove sentinel2 __init__ & fix landsat test

* Add tests

* Add test for coverage

* Review comments#2

* Review comments#3

* Trigger build