Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dOTC #1787

Merged
merged 106 commits into from
Aug 7, 2024
Merged

dOTC #1787

Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
106 commits
Select commit Hold shift + click to select a range
08c8ae8
OTC
LamAdr Jun 17, 2024
51d8d67
comments
LamAdr Jun 17, 2024
56f60c5
more comments
LamAdr Jun 17, 2024
366dd82
keep_attrs and remove _adjust_src
LamAdr Jun 17, 2024
771a1b3
dOTC
LamAdr Jun 18, 2024
b9725ea
Merge remote-tracking branch 'origin/main' into dOTC
LamAdr Jun 18, 2024
984c663
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2024
98a8085
pre-compile
LamAdr Jun 18, 2024
709d1ba
Merge remote-tracking branch 'origin/dOTC' into dOTC
LamAdr Jun 18, 2024
ef0dc58
keep the right attrs
LamAdr Jun 18, 2024
456976c
restructure
LamAdr Jun 18, 2024
e7c7123
map_groups
LamAdr Jun 19, 2024
0d64085
Inputs are Datasets
LamAdr Jun 19, 2024
f6122ec
MultivariateAdjust from PR#1580
LamAdr Jun 25, 2024
7a19690
test OTC vs SBCK
LamAdr Jun 25, 2024
092151b
series_Dataset -> series_dataset
LamAdr Jun 25, 2024
8a10e3d
dropna, rename core_dims and reindex_like
LamAdr Jun 25, 2024
2f34427
Merge branch 'main' into dOTC
Zeitsperre Jun 25, 2024
6307d45
test dOTC vs SBCK
LamAdr Jun 26, 2024
36e3125
dOTC rename time, reindex
LamAdr Jun 26, 2024
1f24e89
reindex correction
LamAdr Jun 26, 2024
6b31c4f
Merge branch 'main' of https://github.com/Ouranosinc/xclim into dOTC
LamAdr Jun 26, 2024
9cca4e1
Merge branch 'dOTC' of https://github.com/LamAdr/xclim into dOTC
LamAdr Jun 26, 2024
51cb784
sqeuclidean
LamAdr Jul 2, 2024
2e3ca2a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 2, 2024
326f1a8
numItermax doc
LamAdr Jul 2, 2024
8246d8d
revert Multivariate classes
LamAdr Jul 2, 2024
d615fea
Merge branch 'main' into dOTC
Zeitsperre Jul 2, 2024
0347d59
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 2, 2024
b41812d
Update sdba_utils.py
Zeitsperre Jul 2, 2024
229dbe8
inherits from Adjust
LamAdr Jul 2, 2024
d7fb1ac
update tests
LamAdr Jul 2, 2024
60861cb
use map_groups dim
LamAdr Jul 2, 2024
8e3a8a9
Adjust class pass as kwarg and signature changes
LamAdr Jul 3, 2024
88d57dd
docstrings
LamAdr Jul 3, 2024
b073548
pre-merge
LamAdr Jul 3, 2024
0790f3a
numItermax -> num_iter_max
LamAdr Jul 3, 2024
07e049e
cholesky ref
LamAdr Jul 3, 2024
ae1e6ec
sim _is_hist
LamAdr Jul 3, 2024
615eb06
test_shape
LamAdr Jul 3, 2024
8b67168
kind
LamAdr Jul 3, 2024
37de293
spray_bins
LamAdr Jul 3, 2024
c79ad13
spry_bins formulation
LamAdr Jul 4, 2024
5740984
speedup
LamAdr Jul 4, 2024
3d72ea3
halt cholesky
LamAdr Jul 4, 2024
ea3a6bc
Sturges is evil
LamAdr Jul 4, 2024
685b507
Scott
LamAdr Jul 4, 2024
602ed3a
window
LamAdr Jul 4, 2024
05e7b87
remove
LamAdr Jul 8, 2024
69fe3b8
POT optional
LamAdr Jul 8, 2024
b892331
pot opt 2
LamAdr Jul 8, 2024
531ec22
pot opt 3
LamAdr Jul 8, 2024
045a87a
sort codespell
LamAdr Jul 9, 2024
e58ba0d
test NAs
LamAdr Jul 9, 2024
da31098
authors
LamAdr Jul 9, 2024
3ef5cf1
cholesky iterations
LamAdr Jul 10, 2024
7239002
docstrings
LamAdr Jul 10, 2024
386d2e3
kind bug
LamAdr Jul 11, 2024
7644100
dask enforces map_blocks shape
LamAdr Jul 11, 2024
ad8e7a1
pts_dim
LamAdr Jul 11, 2024
fdf432b
pot opt 4
LamAdr Jul 11, 2024
6c02dda
opt dep tests and workflows
LamAdr Jul 11, 2024
041c4c5
doc
LamAdr Jul 15, 2024
791d7a5
Merge branch 'main' into dOTC
Zeitsperre Jul 15, 2024
3e30b1f
distance unit is now bin_width
LamAdr Jul 15, 2024
0e550a9
POT optional dependence doc
LamAdr Jul 15, 2024
5a87120
adapt_freq_thresh
LamAdr Jul 15, 2024
23c8fc9
Merge branch 'dOTC' of https://github.com/LamAdr/xclim into dOTC
LamAdr Jul 15, 2024
4e1796e
do not adapt_freq sim
LamAdr Jul 16, 2024
0a8b1c2
comment
LamAdr Jul 16, 2024
d9f94c2
dropna how=all
LamAdr Jul 16, 2024
7320b1b
notebook
LamAdr Jul 17, 2024
e0aeed5
docstrings
LamAdr Jul 17, 2024
d0a7f8b
changelog
LamAdr Jul 17, 2024
25168c4
Merge branch 'main' into dOTC
Zeitsperre Jul 23, 2024
7267993
spray_bins -> jitter_inside_bins
LamAdr Jul 23, 2024
8a64a87
test dask
LamAdr Jul 23, 2024
d67b8b1
Merge branch 'dOTC' of https://github.com/LamAdr/xclim into dOTC
LamAdr Jul 23, 2024
eb31e88
@ looks cool
LamAdr Jul 23, 2024
915587e
cov_factor
LamAdr Jul 23, 2024
5a4997f
Merge branch 'main' into dOTC
Zeitsperre Jul 24, 2024
0781a0a
Merge branch 'dOTC' of https://github.com/LamAdr/xclim into dOTC
LamAdr Jul 24, 2024
3384cef
convert thresh units
LamAdr Jul 24, 2024
ad6240b
missing adapt_freq_thresh
LamAdr Jul 24, 2024
8ce3cb0
validate cov_factor
LamAdr Jul 24, 2024
c61de5c
'identity' -> None + notebook explanation
LamAdr Jul 24, 2024
dc860cb
cholesky rescale
LamAdr Jul 25, 2024
2d46bfa
transform options
LamAdr Aug 5, 2024
d2f4940
transform doc
LamAdr Aug 5, 2024
908289b
adapt_freq_thresh modification
LamAdr Aug 5, 2024
02b2e66
adapt_freq_thresh bug 2
LamAdr Aug 5, 2024
da5c51d
notebook transform
LamAdr Aug 5, 2024
7e0a30b
bin_* are dicts
LamAdr Aug 5, 2024
b15efc0
differences from SBCK
LamAdr Aug 5, 2024
5f7b747
bin_* can be float
LamAdr Aug 6, 2024
afa0432
kind and adapt_freq_thresh can be str
LamAdr Aug 6, 2024
7b6c4ad
notebook correlations : center title
LamAdr Aug 6, 2024
dd375a7
Merge branch 'main' into dOTC
Zeitsperre Aug 7, 2024
c6ec66f
list acceptable transforms
LamAdr Aug 7, 2024
b3e3a31
Merge branch 'dOTC' of https://github.com/LamAdr/xclim into dOTC
LamAdr Aug 7, 2024
df05bdd
docstring ORs
LamAdr Aug 7, 2024
9575db2
orcid
LamAdr Aug 7, 2024
167c406
Merge branch 'main' into dOTC
Zeitsperre Aug 7, 2024
b66e05a
importorskip SBCK
LamAdr Aug 7, 2024
1ddaa96
Merge branch 'dOTC' of https://github.com/LamAdr/xclim into dOTC
LamAdr Aug 7, 2024
73aa04c
Merge branch 'main' into dOTC
Zeitsperre Aug 7, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ values = [

[tool.codespell]
skip = 'xclim/data/*.json,docs/_build,docs/notebooks/xclim_training/*.ipynb,docs/references.bib,__pycache__,*.gz,*.nc,*.png,*.svg,*.whl'
ignore-words-list = "absolue,bloc,bui,callendar,degreee,environnement,hanel,inferrable,lond,nam,nd,ressources,socio-economic,sie,vas"
ignore-words-list = "absolue,bloc,bui,callendar,degreee,environnement,hanel,inferrable,lond,nam,nd,ressources,socio-economic,sie,vas,ot"

[tool.coverage.run]
relative_files = true
Expand Down
274 changes: 274 additions & 0 deletions xclim/sdba/_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,3 +615,277 @@ def extremes_adjust(
adjusted: xr.DataArray = (transition * scen) + ((1 - transition) * ds.scen)
out = adjusted.rename("scen").squeeze("group", drop=True).to_dataset()
return out


def _otc_adjust(
X: np.ndarray,
Y: np.ndarray,
src: np.ndarray | None = None,
bin_width: np.ndarray | None = None,
bin_origin: np.ndarray | None = None,
numItermax: int | None = 100_000_000,
):
"""Optimal Transport Correction of the bias of X with respect to Y.

Parameters
----------
X : np.ndarray
Historical data to be corrected.
Y : np.ndarray
Bias correction reference, target of optimal transport.
src : np.ndarray | None
If not None, X is still used to compute the optimal transport, but it is used to correct src.
LamAdr marked this conversation as resolved.
Show resolved Hide resolved
bin_width : np.ndarray | None
Bin widths for all dimensions.
bin_origin : np.ndarray | None
Bin origins for all dimensions.
numItermax : int | None
Number of iterations of the earth mover distance algorithm.
LamAdr marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
np.ndarray
Adjusted data
"""
# Remove NaNs
Y = Y[~np.isnan(Y).any(axis=1), :]
X = X[~np.isnan(X).any(axis=1), :]

# Initialize parameters
bin_width = u.bin_width_estimator([Y, X]) if bin_width is None else bin_width
bin_origin = np.zeros(bin_width.size) if bin_origin is None else bin_origin
numItermax = 100_000_000 if numItermax is None else numItermax

# Get the bin positions and frequencies of X and Y
gridX, muX, idx_binX = u.histogram(X, bin_width, bin_origin)
gridY, muY, _ = u.histogram(Y, bin_width, bin_origin)

# Compute the optimal transportation plan
plan = u.optimal_transport(gridX, gridY, muX, muY, numItermax)

# Get the source positions expressed in terms of bin indices
idx_gridX = ((gridX - bin_origin) / bin_width).astype(int)

if src is not None:
# Source data is different from the data used to compute the transportation plan
src = src[~np.isnan(src).any(axis=1), :]
# Find the bin indices of source points
idx_binX = (src - bin_origin) / bin_width
out = np.empty(src.shape)
else:
out = np.empty(X.shape)

idx_binX = idx_binX.astype(int)
rng = np.random.default_rng()

# The plan row corresponding to a source bin indicates its probabilities to be transported to every target bin
for i, b in enumerate(idx_binX):
# Get the plan row of this source bin
pi = np.where((b == idx_gridX).all(1))[0][0]
# Pick one index of this plan row
choice = rng.choice(range(muY.size), p=plan[pi, :])
out[i] = gridY[choice]

return out


@map_groups(scen=[Grouper.DIM])
def otc_adjust(
ds: xr.Dataset,
dim: str,
bin_width: np.ndarray | None = None,
bin_origin: np.ndarray | None = None,
numItermax: int | None = 100_000_000,
):
"""Optimal Transport Correction of the bias of `hist` with respect to `ref`.

Notes
-----
Dataset variables:
ref : training target
hist : training data

Parameters
LamAdr marked this conversation as resolved.
Show resolved Hide resolved
----------
ds : xr.Dataset
The dataset containing the data.
dim : str
The dimension along which to compute the quantiles.
bin_width : np.ndarray | None
Bin widths for all dimensions.
bin_origin : np.ndarray | None
Bin origins for all dimensions.
numItermax : int | None
Number of iterations of the earth mover distance algorithm.

Returns
-------
xr.Dataset
Adjusted data
"""
scen = xr.apply_ufunc(
_otc_adjust,
ds.hist,
ds.ref,
kwargs=dict(
bin_width=bin_width,
bin_origin=bin_origin,
numItermax=numItermax,
),
join="outer",
input_core_dims=[["time", "multivar"], ["time", "multivar"]],
output_core_dims=[["time", "multivar"]],
keep_attrs=True,
vectorize=True,
).rename("scen")

return scen.T.to_dataset()


def _dotc_adjust(
X1: np.ndarray,
Y0: np.ndarray,
X0: np.ndarray,
bin_width: np.ndarray | None = None,
bin_origin: np.ndarray | None = None,
numItermax: int | None = 100_000_000,
cov_factor: str | None = "std",
):
"""Dynamical Optimal Transport Correction of the bias of X with respect to Y.

Parameters
LamAdr marked this conversation as resolved.
Show resolved Hide resolved
----------
X1 : np.ndarray
Simulation data to adjust.
Y0 : np.ndarray
Bias correction reference.
X0 : np.ndarray
Historical simulation data.
bin_width : np.ndarray | None
Bin widths for all dimensions.
bin_origin : np.ndarray | None
Bin origins for all dimensions.
numItermax : int | None
Number of iterations of the earth mover distance algorithm.
cov_factor : str | None = "std"
Rescaling factor.

Returns
-------
np.ndarray
Adjusted data
"""
# Remove NaNs
Y0 = Y0[~np.isnan(Y0).any(axis=1), :]
X0 = X0[~np.isnan(X0).any(axis=1), :]
X1 = X1[~np.isnan(X1).any(axis=1), :]
LamAdr marked this conversation as resolved.
Show resolved Hide resolved

# Initialize parameters
bin_width = u.bin_width_estimator([Y0, X0, X1]) if bin_width is None else bin_width
if cov_factor == "cholesky":
fact0 = u.eps_cholesky(np.cov(Y0, rowvar=False))
fact1 = u.eps_cholesky(np.cov(X0, rowvar=False))
cov_factor = np.dot(fact0, np.linalg.inv(fact1))
elif cov_factor == "std":
fact0 = np.std(Y0, axis=0)
fact1 = np.std(X0, axis=0)
cov_factor = np.diag(fact0 / fact1)
else:
cov_factor = np.identity(Y0.shape[1])

# Map ref to hist
yX0 = _otc_adjust(
Y0, X0, bin_width=bin_width, bin_origin=bin_origin, numItermax=numItermax
)

# Map hist to sim
yX1 = _otc_adjust(
X0,
X1,
src=yX0,
bin_width=bin_width,
bin_origin=bin_origin,
numItermax=numItermax,
)

# Temporal evolution
motion = yX1 - yX0
# Apply a variance dependent rescaling factor
motion = np.apply_along_axis(lambda x: np.dot(cov_factor, x), 1, motion)

# Apply the evolution to ref
Y1 = Y0 + motion

# Map sim to the evolution of ref
Z1 = _otc_adjust(
X1, Y1, bin_width=bin_width, bin_origin=bin_origin, numItermax=numItermax
)

return Z1


@map_groups(scen=[Grouper.DIM])
def dotc_adjust(
ds: xr.Dataset,
dim: str,
bin_width: np.ndarray | None = None,
bin_origin: np.ndarray | None = None,
numItermax: int | None = 100_000_000,
LamAdr marked this conversation as resolved.
Show resolved Hide resolved
cov_factor: str | None = "std",
):
"""Dynamical Optimal Transport Correction of the bias of X with respect to Y.

Notes
-----
Dataset variables:
ref : training target
hist : training data
sim : simulated data

Parameters
LamAdr marked this conversation as resolved.
Show resolved Hide resolved
----------
ds : xr.Dataset
The dataset containing the data.
LamAdr marked this conversation as resolved.
Show resolved Hide resolved
dim : str
The dimension along which to compute the quantiles.
bin_width : np.ndarray | None
Bin widths for all dimensions.
bin_origin : np.ndarray | None
Bin origins for all dimensions.
numItermax : int | None
Number of iterations of the earth mover distance algorithm.
cov_factor : str | None = "std"
Rescaling factor.

Returns
-------
xr.DataArray
Adjusted data
"""
ref = ds.ref.dropna(dim="time").rename(time="time_cal")
hist = ds.hist.dropna(dim="time").rename(time="time_cal")
sim = ds.sim.dropna(dim="time").rename(time="time_tgt")

scen = xr.apply_ufunc(
_dotc_adjust,
sim,
ref,
hist,
kwargs=dict(
bin_width=bin_width,
bin_origin=bin_origin,
numItermax=numItermax,
cov_factor=cov_factor,
),
join="outer",
input_core_dims=[
["time_tgt", "multivar"],
["time_cal", "multivar"],
["time_cal", "multivar"],
],
output_core_dims=[["time_tgt", "multivar"]],
keep_attrs=True,
vectorize=True,
).rename("scen", time_tgt="time")

return scen.to_dataset()
Loading