
Feature/init plan util #27

Merged · 21 commits · Jun 5, 2023

Commits
51ee308 - Bugfix init plan not being on device (alexisthual, May 27, 2023)
017d492 - Allow not setting convergence criteria and add test (alexisthual, May 27, 2023)
c7e3d31 - Rename early_stopping_threshold and add tests (alexisthual, May 27, 2023)
95265db - Bugfix test (alexisthual, May 27, 2023)
ed9133b - Make stopping criteria more explicit in solvers and allow to not stop… (alexisthual, May 27, 2023)
d152cc9 - Make stopping criteria more explicit in bcd iters and allow to not st… (alexisthual, May 27, 2023)
f941c81 - Convert sparse init plan to CSR matrix (alexisthual, May 27, 2023)
80703d4 - Add POT to dependencies (alexisthual, May 27, 2023)
478376f - Add util function for initializing dense ot plan (alexisthual, May 27, 2023)
03459bc - Add tests (alexisthual, May 27, 2023)
237853f - Fix test name and torch eye (alexisthual, May 27, 2023)
c98ca0a - Fix marginal constraints test (alexisthual, May 27, 2023)
97639d5 - Add assertion failed error message (alexisthual, May 27, 2023)
e407081 - Use init plan util in dense mapping (alexisthual, May 27, 2023)
e640617 - Bugfix missing var name in util function (alexisthual, May 27, 2023)
92ba6c9 - Merge branch 'bugfix/init_plan_device' of github.com:alexisthual/fugw… (alexisthual, May 27, 2023)
19c6b31 - Merge branch 'feature/allow_not_setting_convergence_criteria' of gith… (alexisthual, May 27, 2023)
6c4d2ef - Fix incorrect stoping criterium (alexisthual, May 27, 2023)
ddfd780 - Merge branch 'feature/allow_not_setting_convergence_criteria' of gith… (alexisthual, May 27, 2023)
d8a6e58 - Merge branch 'feature/allow_not_setting_convergence_criteria' of gith… (alexisthual, May 27, 2023)
b1803ca - Merge branch 'main' of github.com:alexisthual/fugw into feature/init_… (alexisthual, Jun 5, 2023)
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -19,14 +19,15 @@ classifiers = [
 dynamic = ["version"]
 requires-python = ">=3.7"
 dependencies = [
+    "dijkstra3d>=1.12.1",
     "joblib>=1.2.0",
     "numpy>=1.20",
-    "torch>=1.13",
     "rich>=13.3.1",
+    "POT>=0.9.0",
     "scikit-learn",
     "scipy",
+    "torch>=1.13",
     "tvb-gdist>=2.1.1",
-    "dijkstra3d>=1.12.1",
 ]

 [project.optional-dependencies]
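The new POT dependency supplies the emd_1d solver used by the init-plan utility introduced in src/fugw/utils.py below. A minimal sanity check, assuming POT >= 0.9.0 is installed (editor's sketch, not part of the diff):

import torch
from ot import emd_1d

# emd_1d returns a dense coupling between 1D samples,
# with uniform marginals by default
plan = emd_1d(torch.rand(5), torch.rand(7))
assert plan.shape == (5, 7)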
17 changes: 14 additions & 3 deletions src/fugw/mappings/dense.py
@@ -3,7 +3,7 @@

 from fugw.solvers.dense import FUGWSolver
 from fugw.mappings.utils import BaseMapping, console
-from fugw.utils import _make_tensor
+from fugw.utils import _make_tensor, init_plan_dense
 
 
 class FUGW(BaseMapping):
@@ -85,6 +85,7 @@ def fit(
             will be set to 1 / m.
         init_plan: ndarray(n, m) or None
             Transport plan to use at initialisation.
+            If None, an entropic initialization will be used.
         init_duals: tuple of [ndarray(n), ndarray(m)] or None
             Dual potentials to use at initialisation.
         solver: "sinkhorn" or "mm" or "ibpp"
@@ -140,11 +141,21 @@ def fit(
         else:
             wt = _make_tensor(target_weights, device=device)
 
-        # If initial plan is provided, move it to device
+        # If initial plan is provided, move it to device.
+        # Otherwise, initialize it with entropic initialization
         pi_init = (
             _make_tensor(init_plan, device=device)
             if init_plan is not None
-            else None
+            else _make_tensor(
+                init_plan_dense(
+                    source_features.shape[1],
+                    target_features.shape[1],
+                    weights_source=ws,
+                    weights_target=wt,
+                    method="entropic",
+                ),
+                device=device,
+            )
         )

         # Compute distance matrix between features
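For reference, a minimal sketch of what this fallback computes when init_plan is None (editor's example with arbitrary shapes and uniform weights, not part of the diff):

import torch

from fugw.utils import _make_tensor, init_plan_dense

n_source, n_target = 4, 5
ws = torch.ones(n_source) / n_source  # uniform source weights
wt = torch.ones(n_target) / n_target  # uniform target weights

pi_init = _make_tensor(
    init_plan_dense(
        n_source,
        n_target,
        weights_source=ws,
        weights_target=wt,
        method="entropic",
    ),
    device="cpu",
)
# With normalized weights, the entropic plan is their outer product
assert torch.allclose(pi_init, ws[:, None] * wt[None, :])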
63 changes: 63 additions & 0 deletions src/fugw/utils.py
@@ -3,6 +3,7 @@
 import numpy as np
 import torch
 
+from ot import emd_1d
 from rich.console import Console
 from rich.progress import (
     BarColumn,
@@ -189,6 +190,68 @@ def _add_dict(d, new_d):
     return d


+def init_plan_dense(
+    n_source,
+    n_target,
+    weights_source=None,
+    weights_target=None,
+    method="entropic",
+):
+    """Initialize transport plan with dense tensor.
+
+    Generate a matrix satisfying the constraints of a transport plan.
+    In particular, marginal constraints on lines and columns are satisfied.
+
+    Parameters
+    ----------
+    n_source: int
+        Number of source points
+    n_target: int
+        Number of target points
+    weights_source: torch.Tensor of size(n_source), optional, defaults to None
+        Source weights used in entropic init
+    weights_target: torch.Tensor of size(n_target), optional, defaults to None
+        Target weights used in entropic init
+    method: str, optional, defaults to "entropic"
+        Method to use for initialization.
+        Can be "entropic", "permutation" or "identity".
+        If "entropic", weights_source and weights_target must be provided;
+        the initial plan is then given by the product of the two arrays.
+        If "permutation", the initial plan is the solution to a 1D
+        optimal transport problem between two random arrays, which can be
+        understood as a soft permutation between source and target points.
+        If "identity", the number of source and target points must be equal;
+        the initial plan is then the identity matrix.
+
+    Returns
+    -------
+    init_plan: torch.Tensor of size(n_source, n_target)
+    """
+
+    if method == "identity":
+        assert n_source == n_target, (
+            "Number of source and target points must be equal "
+            "when using identity initialization."
+        )
+        plan = torch.eye(n_source, dtype=torch.float32)
Collaborator: Maybe you should replace dtype=torch.float32 with dtype=torch.float64 to avoid casting issues.

Owner Author (alexisthual): The default type we use in _make_tensor() is torch.float32 (or torch.int32), which is why I chose this type here 🙂

Referenced: fugw/src/fugw/utils.py, lines 38 to 55 in ffb6fc0

def _make_tensor(x, device=None, dtype=None):
    """Turn x into a torch.Tensor with suited type and device."""
    if isinstance(x, np.ndarray):
        tensor = torch.tensor(x)
    elif isinstance(x, torch.Tensor):
        tensor = x
    else:
        raise Exception(f"Expected np.ndarray or torch.Tensor, got {type(x)}")
    # By default, cast x to float32 or int32
    # depending on its original type
    if dtype is None:
        if tensor.is_floating_point():
            dtype = torch.float32
        else:
            dtype = torch.int32
    return tensor.to(device=device, dtype=dtype)

Collaborator: Okay, let's keep float32 for coherence then! 👍
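A quick illustration of the casting behaviour discussed above, based on the quoted _make_tensor code (an editor's sketch, not part of the PR):

import numpy as np
import torch

from fugw.utils import _make_tensor

# Floating-point inputs are cast to float32 by default...
assert _make_tensor(np.ones((2, 3))).dtype == torch.float32  # float64 in
# ...and integer inputs to int32, matching the float32 choice above
assert _make_tensor(torch.arange(4)).dtype == torch.int32  # int64 in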

+        plan = plan / plan.sum()
+    elif method == "entropic":
+        if weights_source is None:
+            weights_source = torch.ones(n_source, dtype=torch.float32)
Collaborator: Same

+        if weights_target is None:
+            weights_target = torch.ones(n_target, dtype=torch.float32)
Collaborator: Same

+        plan = weights_source[:, None] * weights_target[None, :]
+        plan = plan / plan.sum()
+    elif method == "permutation":
+        xa = torch.rand(n_source)
+        xb = torch.rand(n_target)
+        plan = emd_1d(xa, xb).to(dtype=torch.float32)
Collaborator: Same

+    else:
+        raise Exception(f"Unknown initialisation method {method}")
Collaborator: Typo initialisation -> initialization

Owner Author (alexisthual): This is my attempt to defend UK English over US English ;)


+    return plan


 def save_mapping(mapping, fname):
     """Save mapping in pickle file, separating hyperparams and weights.
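A usage sketch of the new utility across its three methods (editor's example, not part of the diff; shapes chosen arbitrarily):

import torch

from fugw.utils import init_plan_dense

# Entropic init handles rectangular plans; without weights it is uniform
plan = init_plan_dense(101, 99, method="entropic")
assert plan.shape == (101, 99)
assert torch.isclose(plan.sum(), torch.tensor(1.0))

# Soft permutation via 1D optimal transport on random samples (POT's emd_1d)
plan = init_plan_dense(101, 99, method="permutation")
assert plan.shape == (101, 99)

# Identity init requires as many source points as target points
plan = init_plan_dense(100, 100, method="identity")
assert torch.allclose(plan, torch.eye(100) / 100)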
28 changes: 28 additions & 0 deletions tests/test_utils.py
@@ -11,6 +11,7 @@
 from fugw.utils import (
     _init_mock_distribution,
     _make_tensor,
+    init_plan_dense,
     load_mapping,
     save_mapping,
 )
@@ -156,3 +157,30 @@ def test_saving_and_loading(device, return_numpy, solver):

         weights = pickle.load(f)
         assert weights.shape == (n_voxels_source, n_voxels_target)


+@pytest.mark.parametrize(
+    "method", ["identity", "entropic", "permutation", "unknown"]
+)
+def test_init_plan(method):
+    n_source = 101
+    n_target = 99
+
+    if method == "unknown":
+        with pytest.raises(Exception, match="Unknown initialisation method.*"):
+            init_plan_dense(n_source, n_target, method=method)
+    else:
+        if method == "identity":
+            with pytest.raises(
+                AssertionError, match="Number of source and target.*"
+            ):
+                init_plan_dense(n_source, n_target, method=method)
+
+            n_source = 100
+            n_target = 100
+
+        plan = init_plan_dense(n_source, n_target, method=method)
+        assert plan.shape == (n_source, n_target)
+        # Check that plan satisfies marginal constraints
+        assert torch.allclose(plan.sum(dim=0), torch.ones(n_target) / n_target)
+        assert torch.allclose(plan.sum(dim=1), torch.ones(n_source) / n_source)
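Why these marginals equal 1/n: with uniform weights the entropic plan is the constant matrix 1 / (n_source * n_target), so each column of n_source entries sums to 1/n_target, and each row of n_target entries sums to 1/n_source. A minimal check (editor's sketch):

import torch

n_source, n_target = 100, 100
plan = torch.full((n_source, n_target), 1.0 / (n_source * n_target))
assert torch.allclose(plan.sum(dim=0), torch.ones(n_target) / n_target)
assert torch.allclose(plan.sum(dim=1), torch.ones(n_source) / n_source)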