diff --git a/geoutils/raster/georeferencing.py b/geoutils/raster/georeferencing.py index 604a718a..91859065 100644 --- a/geoutils/raster/georeferencing.py +++ b/geoutils/raster/georeferencing.py @@ -22,6 +22,7 @@ from __future__ import annotations +import math import warnings from typing import Iterable, Literal @@ -284,3 +285,49 @@ def _cast_nodata(out_dtype: DTypeLike, nodata: int | float | None) -> int | floa nodata = nodata return nodata + + +def _generate_tiling_grid( + row_min: int, + col_min: int, + row_max: int, + col_max: int, + row_split: int, + col_split: int, + overlap: int = 0, +) -> NDArrayNum: + """ + Generate a grid of positions by splitting [row_min, row_max] x + [col_min, col_max] into tiles of size row_split x col_split with optional overlap. + + :param row_min: Minimum row index of the bounding box to split. + :param col_min: Minimum column index of the bounding box to split. + :param row_max: Maximum row index of the bounding box to split. + :param col_max: Maximum column index of the bounding box to split. + :param row_split: Height of each tile. + :param col_split: Width of each tile. + :param overlap: size of overlapping between tiles (both vertically and horizontally). + :return: A numpy array grid with splits in two dimensions (0: row, 1: column), + where each cell contains [row_min, row_max, col_min, col_max]. + """ + # Calculate the number of splits considering overlap + nb_col_split = math.ceil((col_max - col_min) / (col_split - overlap)) + nb_row_split = math.ceil((row_max - row_min) / (row_split - overlap)) + + # Initialize the output grid + tiling_grid = np.zeros(shape=(nb_row_split, nb_col_split, 4), dtype=int) + + for row in range(nb_row_split): + for col in range(nb_col_split): + # Calculate the start of the tile + row_start = row_min + row * (row_split - overlap) + col_start = col_min + col * (col_split - overlap) + + # Calculate the end of the tile ensuring it doesn't exceed the bounds + row_end = min(row_max, row_start + row_split) + col_end = min(col_max, col_start + col_split) + + # Populate the grid with the tile boundaries + tiling_grid[row, col] = [row_start, row_end, col_start, col_end] + + return tiling_grid diff --git a/geoutils/raster/raster.py b/geoutils/raster/raster.py index 78570f0a..65851268 100644 --- a/geoutils/raster/raster.py +++ b/geoutils/raster/raster.py @@ -41,6 +41,7 @@ import rioxarray import xarray as xr from affine import Affine +from matplotlib.patches import Rectangle from mpl_toolkits.axes_grid1 import make_axes_locatable from packaging.version import Version from rasterio.crs import CRS @@ -77,6 +78,7 @@ _cast_pixel_interpretation, _coords, _default_nodata, + _generate_tiling_grid, _ij2xy, _outside_image, _res, @@ -2655,6 +2657,44 @@ def translate( raster_copy.transform = translated_transform return raster_copy + def compute_tiling( + self, + tile_size: int, + raster_ref: RasterType, + overlap: int = 0, + ) -> NDArrayNum: + """ + Compute the raster tiling grid to coregister raster by block. + + :param tile_size: Size of each tile (square tiles) + :param raster_ref: The other raster to coregister, use to validate the shape + :param overlap: Size of overlap between tiles (optional) + :return: tiling_grid (array of tile boundaries), new_shape (shape of the tiled grid) + """ + if self.shape != raster_ref.shape: + raise Exception("Reference and secondary rasters do not have the same shape") + row_max, col_max = self.shape + + # Generate tiling + tiling_grid = _generate_tiling_grid(0, 0, row_max, col_max, tile_size, tile_size, overlap=overlap) + return tiling_grid + + def plot_tiling(self, tiling_grid: NDArrayNum) -> None: + """ + Plot raster with its tiling. + + :param tiling_grid: tiling given by Raster.compute_tiling. + """ + ax, caxes = self.plot(return_axes=True) + for tile in tiling_grid.reshape(-1, 4): + row_min, row_max, col_min, col_max = tile + x_min, y_min = self.transform * (col_min, row_min) # Bottom-left corner + x_max, y_max = self.transform * (col_max, row_max) # Top-right corne + rect = Rectangle( + (x_min, y_min), x_max - x_min, y_max - y_min, edgecolor="red", facecolor="none", linewidth=1.5 + ) + ax.add_patch(rect) + def save( self, filename: str | pathlib.Path | IO[bytes], @@ -2922,6 +2962,38 @@ def intersection(self, raster: str | Raster, match_ref: bool = True) -> tuple[fl # mypy raises a type issue, not sure how to address the fact that output of merge_bounds can be () return intersection # type: ignore + @overload + def plot( + self, + bands: int | tuple[int, ...] | None = None, + cmap: matplotlib.colors.Colormap | str | None = None, + vmin: float | int | None = None, + vmax: float | int | None = None, + alpha: float | int | None = None, + cbar_title: str | None = None, + add_cbar: bool = True, + ax: matplotlib.axes.Axes | Literal["new"] | None = None, + *, + return_axes: Literal[False] = False, + **kwargs: Any, + ) -> None: ... + + @overload + def plot( + self, + bands: int | tuple[int, ...] | None = None, + cmap: matplotlib.colors.Colormap | str | None = None, + vmin: float | int | None = None, + vmax: float | int | None = None, + alpha: float | int | None = None, + cbar_title: str | None = None, + add_cbar: bool = True, + ax: matplotlib.axes.Axes | Literal["new"] | None = None, + *, + return_axes: Literal[True], + **kwargs: Any, + ) -> tuple[matplotlib.axes.Axes, matplotlib.colors.Colormap]: ... + def plot( self, bands: int | tuple[int, ...] | None = None, @@ -3059,8 +3131,7 @@ def plot( # If returning axes if return_axes: return ax0, cax - else: - return None + return None def reduce_points( self,