Refactor/handle solve args #748

Merged
merged 70 commits into from
Dec 12, 2024

Collaborator

@selmanozleyen selmanozleyen commented Sep 22, 2024

Hi @MUCDK,

The good news is that we already do a good job of partitioning the kwargs for solve. In solve, we pass any kwarg we don't recognize to the constructor of either SinkhornSolver or GWSolver. SinkhornSolver uses Sinkhorn or LRSinkhorn from ott-jax; these classes don't accept kwargs in their constructors, so the SinkhornSolver backend is already fine. GWSolver uses GromovWasserstein or LRGromovWasserstein from ott-jax. Their parent class, WassersteinSolver, doesn't raise an error on unrecognized args. The tests will pass once the ott-jax PR is merged.
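
To illustrate the difference described above, here is a minimal sketch with hypothetical stand-in classes (StrictSolver and LenientSolver are illustrative names, not ott-jax classes). It assumes only standard Python behavior: a constructor without **kwargs rejects unknown arguments with a TypeError, while a constructor that accepts **kwargs swallows them silently.

```python
class StrictSolver:
    """Hypothetical stand-in for a solver whose constructor has no **kwargs."""

    def __init__(self, threshold: float = 1e-3, max_iterations: int = 2000):
        self.threshold = threshold
        self.max_iterations = max_iterations


class LenientSolver:
    """Hypothetical stand-in for a solver whose constructor accepts **kwargs."""

    def __init__(self, epsilon: float = 1.0, **kwargs):
        self.epsilon = epsilon
        self.ignored = kwargs  # unknown arguments end up here unnoticed


def build(solver_cls, **kwargs):
    """Forward every keyword argument to the solver constructor."""
    return solver_cls(**kwargs)


def rejects_unknown(solver_cls) -> bool:
    """Return True if the constructor raises on an unrecognized argument."""
    try:
        build(solver_cls, not_a_real_argument=1)
    except TypeError:
        return True
    return False
```

With the strict class a typo in a keyword surfaces immediately; with the lenient class it would go unnoticed unless the parent constructor is changed to reject leftovers.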

Here is the PR in ott-jax: ott-jax/ott#579

Other things done:

  • I added tests that check we raise appropriate errors on completely unrecognized arguments.
  • I refactored where we handle the checks for alpha. They are now completely independent of CompoundProblem or any other more abstract class, and are handled in GWSolver, as they should be.
  • I added extra tests for the errors we raise for alpha or the data given.
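
As a rough illustration of what solver-level validation of alpha could look like, here is a minimal sketch. The function names, the accepted range, and the rules tying alpha to the presence of a linear term are all assumptions made for this example, not the checks actually implemented in GWSolver.

```python
def check_alpha(alpha: float, *, has_linear_term: bool) -> None:
    """Hypothetical solver-level validation; the rules below are assumptions."""
    if not 0.0 < alpha <= 1.0:
        raise ValueError(f"Expected `alpha` to be in (0, 1], found `{alpha}`.")
    if alpha == 1.0 and has_linear_term:
        raise ValueError("`alpha=1` ignores the linear term, but one was provided.")
    if alpha < 1.0 and not has_linear_term:
        raise ValueError("`alpha<1` requires a linear term, but none was provided.")


def is_valid(alpha: float, *, has_linear_term: bool) -> bool:
    """Convenience wrapper that reports whether the combination passes the check."""
    try:
        check_alpha(alpha, has_linear_term=has_linear_term)
    except ValueError:
        return False
    return True
```

Keeping the check in one function next to the solver means problem classes never need to know about alpha at all.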

Additionally closes:

@selmanozleyen selmanozleyen marked this pull request as draft September 22, 2024 21:35
Collaborator

@MUCDK MUCDK left a comment

Thanks, overall really nice. Let's see what Mike and Marco say on the ott-jax side!

Collaborator Author

selmanozleyen commented Dec 10, 2024

Hi @MUCDK, apart from the notebooks, I think this is ready to merge. Here are some noteworthy side effects of the ott-jax update to 0.5.0:

  • The inner_iterations parameter is removed, since it is fixed to 1 in the new version of ott-jax.
  • We no longer accept str for initializers, since ott-jax dropped support for this; we take initializer objects instead. I modified the tests accordingly, but the notebooks also need to comply with this.

@selmanozleyen selmanozleyen requested a review from MUCDK December 10, 2024 19:59
tests/_utils.py Outdated
Comment on lines 108 to 140
def create_lr_initializer(
    initializer,
    rank,
    **kwargs,
) -> lr_init_lib.LRInitializer:  # noqa: D102
    if isinstance(initializer, lr_init_lib.LRInitializer):
        return initializer
    if initializer == "random":
        return lr_init_lib.RandomInitializer(rank=rank, **kwargs)
    if initializer == "rank2":
        return lr_init_lib.Rank2Initializer(rank=rank, **kwargs)
    if initializer == "k-means":
        return lr_init_lib.KMeansInitializer(rank=rank, **kwargs)
    if initializer == "generalized-k-means":
        return lr_init_lib.GeneralizedKMeansInitializer(rank=rank, **kwargs)
    raise NotImplementedError(f"Initializer `{initializer}` is not yet implemented.")


def create_fr_initializer(
    initializer,
    **kwargs,
) -> init_lib.SinkhornInitializer:  # noqa: D102
    if isinstance(initializer, init_lib.SinkhornInitializer):
        return initializer
    if initializer == "default":
        return init_lib.DefaultInitializer(**kwargs)
    if initializer == "gaussian":
        return init_lib.GaussianInitializer(**kwargs)
    if initializer == "sorting":
        return init_lib.SortingInitializer(**kwargs)
    if initializer == "subsample":
        return init_lib.SubsampleInitializer(**kwargs)
    raise NotImplementedError(f"Initializer `{initializer}` is not yet implemented.")
Collaborator

Can we keep strings for initializers, or accept either Literal[...] or init_lib.SinkhornInitializer, and port this code to src?

Collaborator

@MUCDK MUCDK left a comment

This looks great, thanks @selmanozleyen !

Wdyt about allowing strings for initialisers additionally? Might be too hard for people not familiar with these things otherwise.

@selmanozleyen
Collaborator Author

> This looks great, thanks @selmanozleyen !
>
> Wdyt about allowing strings for initialisers additionally? Might be too hard for people not familiar with these things otherwise.

The issue is distinguishing the arguments needed to create the initializer objects from the arguments needed when calling the initializers. For example, you can't tell whether rank in initializer_kwargs belongs to RandomInitializer or to solve(rank=rank) unless we keep a hard-coded list. We might not be able to use the inspect module, because some classes might accept kwargs in their constructors, in which case we can't know what the initializer takes. I will check whether the inspect module can work, but I think providing a class for creating initializers from strings is a better approach, since it makes clear which argument goes where.
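
The limitation of the inspect module mentioned here can be seen in a small sketch. The two classes are hypothetical stand-ins, not ott-jax initializers; only `inspect.signature` and `inspect.Parameter.VAR_KEYWORD` come from the standard library.

```python
import inspect


class ExplicitInitializer:
    """Hypothetical initializer with an explicit constructor signature."""

    def __init__(self, rank: int, seed: int = 0):
        self.rank = rank
        self.seed = seed


class OpaqueInitializer:
    """Hypothetical initializer that hides its options behind **kwargs."""

    def __init__(self, **kwargs):
        self.options = kwargs


def accepted_names(cls):
    """Return the named constructor parameters and whether **kwargs is present."""
    params = inspect.signature(cls).parameters.values()
    names = {p.name for p in params if p.kind is not inspect.Parameter.VAR_KEYWORD}
    has_var_kwargs = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params)
    return names, has_var_kwargs
```

For the explicit class, introspection tells us that rank is a constructor argument. For the opaque class it reports no named parameters at all, so there is no way to decide automatically whether rank should go to the constructor or stay with the call.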

Collaborator

MUCDK commented Dec 11, 2024

> > This looks great, thanks @selmanozleyen !
> > Wdyt about allowing strings for initialisers additionally? Might be too hard for people not familiar with these things otherwise.
>
> The issue is distinguishing the arguments needed to create the initializer objects from the arguments needed when calling the initializers. For example, you can't tell whether rank in initializer_kwargs belongs to RandomInitializer or to solve(rank=rank) unless we keep a hard-coded list. We might not be able to use the inspect module, because some classes might accept kwargs in their constructors, in which case we can't know what the initializer takes. I will check whether the inspect module can work, but I think providing a class for creating initializers from strings is a better approach, since it makes clear which argument goes where.

I see. Yes, then maybe let's create a class, but call it from inside the solve method, because moscot users are not expected to know OOP. We really should keep this simple.
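
One way this suggestion could look is sketched below. Everything here is hypothetical: DefaultInit, GaussianInit, InitializerFactory, and this simplified solve are illustrative stand-ins, not the moscot or ott-jax API. The sketch only shows the idea of accepting either a string or a ready-made object and resolving the string inside solve.

```python
from dataclasses import dataclass
from typing import Literal, Union


@dataclass
class DefaultInit:
    """Hypothetical stand-in for a backend initializer class."""


@dataclass
class GaussianInit:
    """Hypothetical stand-in for a backend initializer class."""


InitializerName = Literal["default", "gaussian"]
Initializer = Union[DefaultInit, GaussianInit]


class InitializerFactory:
    """Hypothetical helper that turns a string into an initializer object."""

    _registry = {"default": DefaultInit, "gaussian": GaussianInit}

    @classmethod
    def create(cls, initializer: Union[InitializerName, Initializer]) -> Initializer:
        if not isinstance(initializer, str):
            return initializer  # advanced users pass a ready-made object
        try:
            return cls._registry[initializer]()
        except KeyError:
            raise ValueError(f"Unknown initializer `{initializer}`.") from None


def solve(initializer: Union[InitializerName, Initializer] = "default") -> Initializer:
    """Hypothetical `solve`: only the string-to-object step is shown."""
    return InitializerFactory.create(initializer)
```

Users who only know the string names never see the factory, while anyone who needs constructor options builds the object themselves and passes it in unchanged.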

@selmanozleyen selmanozleyen requested a review from MUCDK December 12, 2024 10:10
@selmanozleyen
Collaborator Author

Hi @MUCDK, I made the changes. If users want to create more advanced initializers (i.e. pass kwargs to the initializer constructor), they can do so by creating the initializer themselves. I think it's reasonable to assume that someone who wants to do that can look up the ott initializers.

Collaborator

@MUCDK MUCDK left a comment

That looks great, thanks so much @selmanozleyen !

Collaborator

MUCDK commented Dec 12, 2024

Please merge, and then do a new release v0.4.0 of moscot!

@selmanozleyen selmanozleyen merged commit 29764d4 into main Dec 12, 2024
8 checks passed
2 participants