-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/init plan util #27
Changes from all commits
51ee308
017d492
c7e3d31
95265db
ed9133b
d152cc9
f941c81
80703d4
478376f
03459bc
237853f
c98ca0a
97639d5
e407081
e640617
92ba6c9
19c6b31
6c4d2ef
ddfd780
d8a6e58
b1803ca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
import numpy as np | ||
import torch | ||
|
||
from ot import emd_1d | ||
from rich.console import Console | ||
from rich.progress import ( | ||
BarColumn, | ||
|
@@ -189,6 +190,68 @@ def _add_dict(d, new_d): | |
return d | ||
|
||
|
||
def init_plan_dense( | ||
n_source, | ||
n_target, | ||
weights_source=None, | ||
weights_target=None, | ||
method="entropic", | ||
): | ||
"""Initialize transport plan with dense tensor. | ||
|
||
Generate a matrix satisfying the constraints of a transport plan. | ||
In particular, marginal constraints on lines and columns are satisfied. | ||
|
||
Parameters | ||
---------- | ||
n_source: int | ||
Number of source points | ||
n_target: int | ||
Number of target points | ||
weights_source: torch.Tensor of size(n_source), optional, defaults to None | ||
Source weights used in entropic init | ||
weights_target: torch.Tensor of size(n_target), optional, defaults to None | ||
Target weights used in entropic init | ||
method: str, optional, defaults to "entropic" | ||
Method to use for initialization. | ||
Can be "entropic", "permutation" or "identity". | ||
If "entropic", weights_source and weights_target must be provided ; | ||
the initial plan is then given by the product of the two arrays. | ||
If "permutation", the initial plan is the solution to a 1D | ||
optimal transport problem between two random arrays, which can be | ||
understood as a soft permutation between source and target points. | ||
If "identity", the number of source and target points must be equal ; | ||
the initial plan is then the identity matrix. | ||
|
||
Returns | ||
------- | ||
init_plan: torch.Tensor of size(n_source, n_target) | ||
""" | ||
|
||
if method == "identity": | ||
assert n_source == n_target, ( | ||
"Number of source and target points must be equal " | ||
"when using identity initialization." | ||
) | ||
plan = torch.eye(n_source, dtype=torch.float32) | ||
plan = plan / plan.sum() | ||
elif method == "entropic": | ||
if weights_source is None: | ||
weights_source = torch.ones(n_source, dtype=torch.float32) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same |
||
if weights_target is None: | ||
weights_target = torch.ones(n_target, dtype=torch.float32) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same |
||
plan = weights_source[:, None] * weights_target[None, :] | ||
plan = plan / plan.sum() | ||
elif method == "permutation": | ||
xa = torch.rand(n_source) | ||
xb = torch.rand(n_target) | ||
plan = emd_1d(xa, xb).to(dtype=torch.float32) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same |
||
else: | ||
raise Exception(f"Unknown initialisation method {method}") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Typo There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is my attempt to defend UK English over US English ;) |
||
|
||
return plan | ||
|
||
|
||
def save_mapping(mapping, fname): | ||
"""Save mapping in pickle file, separating hyperparams and weights. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe you should replace
dtype=torch.float32
withdtype=torch.float64
to avoid casting issues.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The default type we use in
_make_tensor()
istorch.float32
(ortorch.int32
), which is why I chose this type here 🙂fugw/src/fugw/utils.py
Lines 38 to 55 in ffb6fc0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay let's keep
float32
for coherence then ! 👍