Skip to content

Commit

Permalink
Add normalized burn ratio and tests (microsoft#284)
Browse files Browse the repository at this point in the history
  • Loading branch information
RitwikGupta authored Dec 17, 2021
1 parent 53b2986 commit b402252
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 1 deletion.
13 changes: 13 additions & 0 deletions tests/transforms/test_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ def test_ndwi(sample: Dict[str, Tensor]) -> None:
assert index.shape[-2:] == sample["image"].shape[-2:]


def test_nbr(sample: Dict[str, Tensor]) -> None:
index = indices.nbr(nir=sample["image"], swir=sample["image"])
assert index.ndim == 3
assert index.shape[-2:] == sample["image"].shape[-2:]


def test_append_ndbi(batch: Dict[str, Tensor]) -> None:
b, c, h, w = batch["image"].shape
tr = indices.AppendNDBI(index_swir=0, index_nir=0)
Expand All @@ -89,3 +95,10 @@ def test_append_ndwi(batch: Dict[str, Tensor]) -> None:
tr = indices.AppendNDWI(index_green=0, index_nir=0)
output = tr(batch)
assert output["image"].shape == (b, c + 1, h, w)


def test_append_nbr(batch: Dict[str, Tensor]) -> None:
b, c, h, w = batch["image"].shape
tr = indices.AppendNBR(index_nir=0, index_swir=0)
output = tr(batch)
assert output["image"].shape == (b, c + 1, h, w)
3 changes: 2 additions & 1 deletion torchgeo/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

"""TorchGeo transforms."""

from .indices import AppendNDBI, AppendNDSI, AppendNDVI, AppendNDWI
from .indices import AppendNBR, AppendNDBI, AppendNDSI, AppendNDVI, AppendNDWI
from .transforms import AugmentationSequential

__all__ = (
"AppendNDBI",
"AppendNBR",
"AppendNDSI",
"AppendNDVI",
"AppendNDWI",
Expand Down
56 changes: 56 additions & 0 deletions torchgeo/transforms/indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,26 @@ def ndbi(swir: Tensor, nir: Tensor) -> Tensor:
Returns:
tensor containing computed NDBI values
"""
return (swir - nir) / ((swir + nir) + _EPSILON)


def nbr(nir: Tensor, swir: Tensor) -> Tensor:
"""Compute Normalized Burn Ratio (NBR).
Args:
nir: tensor containing nir band
swir: tensor containing swir band
Returns:
tensor containing computed NBR values
.. versionadded:: 0.2.0
"""
return (nir - swir) / ((nir + swir) + _EPSILON)


def ndsi(green: Tensor, swir: Tensor) -> Tensor:
"""Compute Normalized Different Snow Index (NDSI).
Expand Down Expand Up @@ -116,6 +132,46 @@ def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
return sample


class AppendNBR(Module): # type: ignore[misc,name-defined]
"""Normalized Burn Ratio (NBR).
.. versionadded:: 0.2.0
"""

def __init__(self, index_nir: int, index_swir: int) -> None:
"""Initialize a new transform instance.
Args:
index_nir: index of the Near Infrared (NIR) band in the image
index_swir: index of the Short-wave Infrared (SWIR) band in the image
"""
super().__init__()
self.dim = -3
self.index_nir = index_nir
self.index_swir = index_swir

def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""Create a band for NBR and append to image channels.
Args:
sample: a single data sample
Returns:
a sample where the image has an additional channel representing NBR
"""
if "image" in sample:
index = nbr(
nir=sample["image"][:, self.index_nir],
swir=sample["image"][:, self.index_swir],
)
index = index.unsqueeze(self.dim)
sample["image"] = torch.cat( # type: ignore[attr-defined]
[sample["image"], index], dim=self.dim
)

return sample


class AppendNDSI(Module): # type: ignore[misc,name-defined]
"""Normalized Difference Snow Index (NDSI).
Expand Down

0 comments on commit b402252

Please sign in to comment.