Feature/init plan util #27
Conversation
Excellent overall 👏, I just found a typo and I'm concerned that the use of `torch.float32` might lead to casting errors later on.
```python
        xb = torch.rand(n_target)
        plan = emd_1d(xa, xb).to(dtype=torch.float32)
    else:
        raise Exception(f"Unknown initialisation method {method}")
```
Typo: `initialisation` -> `initialization`
This is my attempt to defend UK English over US English ;)
"Number of source and target points must be equal " | ||
"when using identity initialization." | ||
) | ||
plan = torch.eye(n_source, dtype=torch.float32) |
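For context, a minimal sketch of what the identity init produces, combining only what the diff shows (`torch.eye` here, plus the normalization from the other hunk below); illustrative, not from the PR:

```python
import torch

n = 3
# Identity init: only valid when n_source == n_target
plan = torch.eye(n, dtype=torch.float32)
plan = plan / plan.sum()  # normalize to a coupling with total mass 1
print(plan)  # 1/n on the diagonal, 0 elsewhere
```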
Maybe you should replace `dtype=torch.float32` with `dtype=torch.float64` to avoid casting issues.
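To make the concern concrete, a minimal sketch (illustrative, not from the PR) of the kind of error mixed dtypes can trigger in PyTorch, since `matmul` does not promote across float32/float64:

```python
import torch

plan = torch.eye(3, dtype=torch.float64)          # plan in float64
features = torch.rand(3, 5, dtype=torch.float32)  # features in float32

try:
    plan @ features  # mixed-dtype matmul is not promoted by torch
except RuntimeError as e:
    print(e)  # e.g. "expected ... Double ... but found Float"
```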
The default type we use in `_make_tensor()` is `torch.float32` (or `torch.int32`), which is why I chose this type here 🙂
Lines 38 to 55 in ffb6fc0
```python
def _make_tensor(x, device=None, dtype=None):
    """Turn x into a torch.Tensor with suited type and device."""
    if isinstance(x, np.ndarray):
        tensor = torch.tensor(x)
    elif isinstance(x, torch.Tensor):
        tensor = x
    else:
        raise Exception(f"Expected np.ndarray or torch.Tensor, got {type(x)}")
    # By default, cast x to float32 or int32
    # depending on its original type
    if dtype is None:
        if tensor.is_floating_point():
            dtype = torch.float32
        else:
            dtype = torch.int32
    return tensor.to(device=device, dtype=dtype)
```
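A quick usage sketch of the defaults described above (assuming `_make_tensor`, `numpy`, and `torch` are in scope):

```python
import numpy as np
import torch

# Floating-point input is cast to float32 by default ...
assert _make_tensor(np.zeros(3)).dtype == torch.float32
# ... integer input to int32 ...
assert _make_tensor(np.arange(3)).dtype == torch.int32
# ... and an explicit dtype overrides the default.
assert _make_tensor(np.zeros(3), dtype=torch.float64).dtype == torch.float64
```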
Okay, let's keep `float32` for coherence then! 👍
```python
        plan = plan / plan.sum()
    elif method == "entropic":
        if weights_source is None:
            weights_source = torch.ones(n_source, dtype=torch.float32)
```
Same
```python
        if weights_source is None:
            weights_source = torch.ones(n_source, dtype=torch.float32)
        if weights_target is None:
            weights_target = torch.ones(n_target, dtype=torch.float32)
```
Same
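For illustration, one way an entropic-style init could combine these default weights is the independent coupling of the normalized marginals; this is an assumption for the sketch, not necessarily what fugw does:

```python
import torch

n_source, n_target = 3, 4
# Defaults from the diff: uniform weights when none are provided
weights_source = torch.ones(n_source, dtype=torch.float32)
weights_target = torch.ones(n_target, dtype=torch.float32)

# Independent coupling of the normalized weights (assumption):
plan = torch.outer(
    weights_source / weights_source.sum(),
    weights_target / weights_target.sum(),
)
assert torch.isclose(plan.sum(), torch.tensor(1.0))
```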
```python
    elif method == "permutation":
        xa = torch.rand(n_source)
        xb = torch.rand(n_target)
        plan = emd_1d(xa, xb).to(dtype=torch.float32)
```
Same
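A runnable sketch of the permutation init above, assuming `emd_1d` is POT's `ot.lp.emd_1d` (the torch-tensor call in the diff suggests POT's backend dispatch):

```python
import torch
from ot.lp import emd_1d  # assumption: emd_1d comes from POT

n = 4
xa = torch.rand(n)
xb = torch.rand(n)

# With uniform weights and equal sizes, 1D OT matches points by rank,
# so the plan is a permutation matrix scaled by 1/n.
plan = emd_1d(xa, xb).to(dtype=torch.float32)
print(plan)  # a single entry of 1/n in each row and column
```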
Add util function to initialise OT plan.