
Feature/init plan util #27

Merged
alexisthual merged 21 commits into main from feature/init_plan_util on Jun 5, 2023

Conversation

alexisthual (Owner)

Add a util function to initialise the OT plan.

@pbarbarant (Collaborator) left a comment

Excellent overall 👏, I just found a typo, and I'm concerned that the use of torch.float32 might lead to casting errors later on.

        xb = torch.rand(n_target)
        plan = emd_1d(xa, xb).to(dtype=torch.float32)
    else:
        raise Exception(f"Unknown initialisation method {method}")
pbarbarant (Collaborator):

Typo: initialisation -> initialization

alexisthual (Owner, Author):

This is my attempt to defend UK English over US English ;)

"Number of source and target points must be equal "
"when using identity initialization."
)
plan = torch.eye(n_source, dtype=torch.float32)
pbarbarant (Collaborator):

Maybe you should replace dtype=torch.float32 with dtype=torch.float64 to avoid casting issues.
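For context, a minimal sketch of the kind of dtype mismatch being hinted at here (tensor names are purely illustrative, not from the PR):

import torch

# Mixing float64 and float32 tensors in a matrix product does not promote
# silently; on most PyTorch versions it raises a RuntimeError about
# mismatched scalar types (Double vs Float).
plan = torch.eye(3, dtype=torch.float64)          # e.g. a plan kept in float64
features = torch.rand(3, 5, dtype=torch.float32)  # downstream float32 data

try:
    _ = plan @ features
except RuntimeError as err:
    print(err)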

alexisthual (Owner, Author):

The default type we use in _make_tensor() is torch.float32 (or torch.int32), which is why I chose this type here 🙂

fugw/src/fugw/utils.py, lines 38 to 55 in ffb6fc0:

def _make_tensor(x, device=None, dtype=None):
    """Turn x into a torch.Tensor with suited type and device."""
    if isinstance(x, np.ndarray):
        tensor = torch.tensor(x)
    elif isinstance(x, torch.Tensor):
        tensor = x
    else:
        raise Exception(f"Expected np.ndarray or torch.Tensor, got {type(x)}")

    # By default, cast x to float32 or int32
    # depending on its original type
    if dtype is None:
        if tensor.is_floating_point():
            dtype = torch.float32
        else:
            dtype = torch.int32

    return tensor.to(device=device, dtype=dtype)
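A quick, illustrative call of the function above showing the default cast (not part of the repository code):

import numpy as np

x = np.random.rand(3)            # float64 ndarray
print(_make_tensor(x).dtype)     # torch.float32

idx = np.arange(3)               # integer ndarray
print(_make_tensor(idx).dtype)   # torch.int32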

pbarbarant (Collaborator):

Okay, let's keep float32 for coherence then! 👍

        plan = plan / plan.sum()
    elif method == "entropic":
        if weights_source is None:
            weights_source = torch.ones(n_source, dtype=torch.float32)
pbarbarant (Collaborator):

Same

        if weights_source is None:
            weights_source = torch.ones(n_source, dtype=torch.float32)
        if weights_target is None:
            weights_target = torch.ones(n_target, dtype=torch.float32)
pbarbarant (Collaborator):

Same

elif method == "permutation":
xa = torch.rand(n_source)
xb = torch.rand(n_target)
plan = emd_1d(xa, xb).to(dtype=torch.float32)
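Assuming emd_1d here is POT's ot.emd_1d (an assumption, the import is not shown in this excerpt), a small sketch of why this yields a permutation-like plan: with equally many source and target points and uniform weights, 1D optimal transport matches sorted points one-to-one, so the plan is a permutation matrix scaled by 1/n.

import torch
from ot import emd_1d  # assumption: emd_1d comes from the POT package

n = 4
xa = torch.rand(n)
xb = torch.rand(n)

# Uniform weights, equal sizes: the optimal 1D plan couples the i-th
# smallest source point with the i-th smallest target point, giving a
# scaled permutation matrix (one non-zero entry of 1/n per row and column).
plan = emd_1d(xa, xb)
print(plan)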
pbarbarant (Collaborator):

Same

Base automatically changed from feature/allow_not_setting_convergence_criteria to main on June 5, 2023, 10:59
@alexisthual merged commit 8775aac into main on Jun 5, 2023
@alexisthual deleted the feature/init_plan_util branch on June 5, 2023, 11:39