Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement one-liner coregistration #267

Merged
merged 14 commits into from
Oct 18, 2022
221 changes: 204 additions & 17 deletions xdem/coreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_has_cv2 = False
import fiona
import geoutils as gu
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio as rio
Expand All @@ -25,6 +26,7 @@
import scipy.optimize
import skimage.transform
from geoutils import spatial_tools
from geoutils._typing import AnyNumber
from geoutils.georaster import RasterType
from rasterio import Affine
from tqdm import tqdm, trange
Expand Down Expand Up @@ -1483,6 +1485,7 @@ def apply_matrix(
centroid: tuple[float, float, float] | None = None,
resampling: int | str = "bilinear",
dilate_mask: bool = False,
fill_max_search: int = 0,
) -> NDArrayf:
"""
Apply a 3D transformation matrix to a 2.5D DEM.
Expand All @@ -1503,7 +1506,10 @@ def apply_matrix(
:param invert: Invert the transformation matrix.
:param centroid: The X/Y/Z transformation centroid. Irrelevant for pure translations. Defaults to the midpoint (Z=0)
:param resampling: The resampling method to use. Can be `nearest`, `bilinear`, `cubic` or an integer from 0-5.
:param dilate_mask: Dilate the nan mask to exclude edge pixels that could be wrong.
:param dilate_mask: DEPRECATED - This option does not do anything anymore. Will be removed in the future.
adehecq marked this conversation as resolved.
Show resolved Hide resolved
:param fill_max_search: Set to > 0 value to fill the DEM before applying the transformation, to avoid spreading\
gaps. The DEM will be filled with rasterio.fill.fillnodata with max_search_distance set to fill_max_search.\
This is experimental, use at your own risk !

:returns: The transformed DEM with NaNs as nodata values (replaces a potential mask of the input `dem`).
"""
Expand Down Expand Up @@ -1536,8 +1542,11 @@ def apply_matrix(

nan_mask = spatial_tools.get_mask(dem)
assert np.count_nonzero(~nan_mask) > 0, "Given DEM had all nans."
# Create a filled version of the DEM. (skimage doesn't like nans)
filled_dem = np.where(~nan_mask, demc, np.nan)
# Optionally, fill DEM around gaps to reduce spread of gaps
if fill_max_search > 0:
filled_dem = rio.fill.fillnodata(demc, mask=(~nan_mask).astype("uint8"), max_search_distance=fill_max_search)
else:
filled_dem = demc # np.where(~nan_mask, demc, np.nan) # I don't know why this was needed - to delete

# Get the centre coordinates of the DEM pixels.
x_coords, y_coords = _get_x_and_y_coords(demc.shape, transform)
Expand Down Expand Up @@ -1579,9 +1588,11 @@ def apply_matrix(
# Shift the elevation values of the soon-to-be-warped DEM.
filled_dem -= deramp(x_coords, y_coords)

# Create gap-free arrays of x and y coordinates to be converted into index coordinates.
x_inds = rio.fill.fillnodata(transformed_points[:, :, 0].copy(), mask=(~nan_mask).astype("uint8"))
y_inds = rio.fill.fillnodata(transformed_points[:, :, 1].copy(), mask=(~nan_mask).astype("uint8"))
# Create arrays of x and y coordinates to be converted into index coordinates.
x_inds = transformed_points[:, :, 0].copy()
x_inds[x_inds == 0] = np.nan
y_inds = transformed_points[:, :, 1].copy()
y_inds[y_inds == 0] = np.nan

# Divide the coordinates by the resolution to create index coordinates.
x_inds /= resolution
Expand All @@ -1601,19 +1612,20 @@ def apply_matrix(
transformed_dem = skimage.transform.warp(
filled_dem, inds, order=resampling_order, mode="constant", cval=np.nan, preserve_range=True
)
# Warp the NaN mask, setting true to all values outside the new frame.
tr_nan_mask = (
skimage.transform.warp(
nan_mask.astype("uint8"), inds, order=resampling_order, mode="constant", cval=1, preserve_range=True
)
> 0
)
# TODO: remove these lines when dilate_mask is deprecated
# # Warp the NaN mask, setting true to all values outside the new frame.
# tr_nan_mask = (
# skimage.transform.warp(
# nan_mask.astype("uint8"), inds, order=resampling_order, mode="constant", cval=1, preserve_range=True
# )
# > 0
# )

if dilate_mask:
tr_nan_mask = scipy.ndimage.binary_dilation(tr_nan_mask, iterations=resampling_order)
# if dilate_mask:
# tr_nan_mask = scipy.ndimage.binary_dilation(tr_nan_mask, iterations=resampling_order)

# Apply the transformed nan_mask
transformed_dem[tr_nan_mask] = np.nan
# # Apply the transformed nan_mask
# transformed_dem[tr_nan_mask] = np.nan

assert np.count_nonzero(~np.isnan(transformed_dem)) > 0, "Transformed DEM has all nans."

Expand Down Expand Up @@ -2141,3 +2153,178 @@ def warp_dem(
assert not np.all(np.isnan(warped)), "All-NaN output."

return warped.reshape(dem.shape)


hmodes_dict = {
"nuth_kaab": NuthKaab(),
"nuth_kaab_block": BlockwiseCoreg(coreg=NuthKaab(), subdivision=16),
"icp": ICP(),
}

vmodes_dict = {
"median": BiasCorr(bias_func=np.median),
"mean": BiasCorr(bias_func=np.mean),
"deramp": Deramp(),
}


def dem_coregistration(
src_dem_path: str,
ref_dem_path: str,
out_dem_path: str | None = None,
shpfile: str | None = None,
coreg_method: Coreg | None = None,
hmode: str = "nuth_kaab",
vmode: str = "median",
deramp_degree: int = 1,
grid: str = "ref",
filtering: bool = True,
slope_lim: list[AnyNumber] | tuple[AnyNumber, AnyNumber] = (0.1, 40),
plot: bool = False,
out_fig: str = None,
verbose: bool = False,
) -> tuple[xdem.DEM, pd.DataFrame]:
"""
A one-line function to coregister a selected DEM to a reference DEM.
Reads both DEMs, reprojects them on the same grid, mask content of shpfile, filter steep slopes and outliers, \
run the coregistration, returns the coregistered DEM and some statistics.
Optionally, save the coregistered DEM to file and make a figure.

:param src_dem_path: path to the input DEM to be coregistered
:param ref_dem: path to the reference DEM
:param out_dem_path: Path where to save the coregistered DEM. If set to None (default), will not save to file.
:param shpfile: path to a vector file containing areas to be masked for coregistration
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably use filename_vector for clarity with the parameter naming of GeoUtils functions, or maybe something similar than this function https://github.com/GlacioHack/xdem/blob/main/xdem/spatialstats.py#L445 for flexibility on inclusion or exclusion of the shapefile.

:param coreg_method: The xdem coregistration method, or pipeline. If set to None, DEMs will be resampled to \
ref grid and optionally filtered, but not coregistered. Will be used in priority over hmode and vmode.
:param hmode: The method to be used for horizontally aligning the DEMs, e.g. Nuth & Kaab or ICP. Can be any \
of {list(vmodes_dict.keys())}.
:param vmode: The method to be used for vertically aligning the DEMs, e.g. mean/median bias correction or \
deramping. Can be any of {list(hmodes_dict.keys())}.
:param deramp_degree: The degree of the polynomial for deramping.
:param grid: the grid to be used during coregistration, set either to "ref" or "src".
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name grid is not so transparent as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would you use?

:param filtering: if set to True, filtering will be applied prior to coregistration
:param plot: Set to True to plot a figure of elevation diff before/after coregistration
:param out_fig: Path to the output figure. If None will display to screen.
:param verbose: set to True to print details on screen during coregistration.

:returns: a tuple containing 1) coregistered DEM as an xdem.DEM instance and 2) DataFrame of coregistration \
statistics (count of obs, median and NMAD over stable terrain) before and after coreg.
"""
# Check input arguments
if (coreg_method is not None) and ((hmode is not None) or (vmode is not None)):
warnings.warn("Both `coreg_method` and `hmode/vmode` are set. Using coreg_method.")

if hmode not in list(hmodes_dict.keys()):
raise ValueError(f"vhmode must be in {list(hmodes_dict.keys())}")

if vmode not in list(vmodes_dict.keys()):
raise ValueError(f"vmode must be in {list(vmodes_dict.keys())}")

# Load both DEMs
if verbose:
print("Loading and reprojecting input data")
if grid == "ref":
ref_dem, src_dem = gu.spatial_tools.load_multiple_rasters([ref_dem_path, src_dem_path], ref_grid=0)
elif grid == "src":
ref_dem, src_dem = gu.spatial_tools.load_multiple_rasters([ref_dem_path, src_dem_path], ref_grid=1)
else:
raise ValueError(f"`grid` must be either 'ref' or 'src' - currently set to {grid}")

# Convert to DEM instance with Float32 dtype
ref_dem = xdem.DEM(ref_dem.astype(np.float32))
src_dem = xdem.DEM(src_dem.astype(np.float32))

# Create raster mask
if shpfile is not None:
outlines = gu.Vector(shpfile)
stable_mask = ~outlines.create_mask(src_dem)
else:
stable_mask = np.ones(src_dem.data.shape, dtype="bool")

# Calculate dDEM
ddem = src_dem - ref_dem

# Filter gross outliers in stable terrain
if filtering:
# Remove gross blunders where dh differ by 5 NMAD from the median
inlier_mask = stable_mask & (np.abs(ddem.data - np.median(ddem)) < 5 * xdem.spatialstats.nmad(ddem)).filled(
False
)

# Exclude steep slopes for coreg
slope = xdem.terrain.slope(ref_dem)
inlier_mask[slope.data < slope_lim[0]] = False
inlier_mask[slope.data > slope_lim[1]] = False

else:
inlier_mask = stable_mask

# Calculate dDEM statistics on pixels used for coreg
inlier_data = ddem.data[inlier_mask].compressed()
nstable_orig, mean_orig = len(inlier_data), np.mean(inlier_data)
med_orig, nmad_orig = np.median(inlier_data), xdem.spatialstats.nmad(inlier_data)

# Coregister to reference - Note: this will spread NaN
# Better strategy: calculate shift, update transform, resample
if isinstance(coreg_method, xdem.coreg.Coreg):
coreg_method.fit(ref_dem, src_dem, inlier_mask, verbose=verbose)
dem_coreg = coreg_method.apply(src_dem, dilate_mask=False)
elif coreg_method is None:
# Horizontal coregistration
hcoreg_method = hmodes_dict[hmode]
hcoreg_method.fit(ref_dem, src_dem, inlier_mask, verbose=verbose)
dem_hcoreg = hcoreg_method.apply(src_dem, dilate_mask=False)

# Vertical coregistration
vcoreg_method = vmodes_dict[vmode]
if vmode == "deramp":
vcoreg_method.degree = deramp_degree
vcoreg_method.fit(ref_dem, dem_hcoreg, inlier_mask, verbose=verbose)
dem_coreg = vcoreg_method.apply(dem_hcoreg, dilate_mask=False)

ddem_coreg = dem_coreg - ref_dem

# Calculate new stats
inlier_data = ddem_coreg.data[inlier_mask].compressed()
nstable_coreg, mean_coreg = len(inlier_data), np.mean(inlier_data)
med_coreg, nmad_coreg = np.median(inlier_data), xdem.spatialstats.nmad(inlier_data)

# Plot results
if plot:
# Max colorbar value - 98th percentile rounded to nearest 5
vmax = np.percentile(np.abs(ddem.data.compressed()), 98) // 5 * 5

plt.figure(figsize=(11, 5))

ax1 = plt.subplot(121)
plt.imshow(ddem.data.squeeze(), cmap="coolwarm_r", vmin=-vmax, vmax=vmax)
cb = plt.colorbar()
cb.set_label("Elevation change (m)")
ax1.set_title(f"Before coreg\n\nmean = {mean_orig:.1f} m - med = {med_orig:.1f} m - NMAD = {nmad_orig:.1f} m")

ax2 = plt.subplot(122, sharex=ax1, sharey=ax1)
plt.imshow(ddem_coreg.data.squeeze(), cmap="coolwarm_r", vmin=-vmax, vmax=vmax)
cb = plt.colorbar()
cb.set_label("Elevation change (m)")
ax2.set_title(
f"After coreg\n\n\nmean = {mean_coreg:.1f} m - med = {med_coreg:.1f} m - NMAD = {nmad_coreg:.1f} m"
)

plt.tight_layout()
if out_fig is None:
plt.show()
else:
plt.savefig(out_fig, dpi=200)
plt.close()

# Save coregistered DEM
if out_dem_path is not None:
dem_coreg.save(out_dem_path, tiled=True)

# Save stats to DataFrame
out_stats = pd.DataFrame(
((nstable_orig, med_orig, nmad_orig, nstable_coreg, med_coreg, nmad_coreg),),
columns=("nstable_orig", "med_orig", "nmad_orig", "nstable_coreg", "med_coreg", "nmad_coreg"),
)

return dem_coreg, out_stats