-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor/handle solve args #748
Conversation
for more information, see https://pre-commit.ci
This reverts commit 0b06cc4.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, overall really nice. Let's see what Mike and Marco say on the ott-jax side!
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
Hi @MUCDK, except the notebooks I think it is ready to merge. But here are noteworthy side-effects that was caused by the ottjax update to 0.5.0
|
This reverts commit ef50b55.
for more information, see https://pre-commit.ci
This reverts commit 45509e0.
tests/_utils.py
Outdated
def create_lr_initializer( | ||
initializer, | ||
rank, | ||
**kwargs, | ||
) -> lr_init_lib.LRInitializer: # noqa: D102 | ||
if isinstance(initializer, lr_init_lib.LRInitializer): | ||
return initializer | ||
if initializer == "random": | ||
return lr_init_lib.RandomInitializer(rank=rank, **kwargs) | ||
if initializer == "rank2": | ||
return lr_init_lib.Rank2Initializer(rank=rank, **kwargs) | ||
if initializer == "k-means": | ||
return lr_init_lib.KMeansInitializer(rank=rank, **kwargs) | ||
if initializer == "generalized-k-means": | ||
return lr_init_lib.GeneralizedKMeansInitializer(rank=rank, **kwargs) | ||
raise NotImplementedError(f"Initializer `{initializer}` is not yet implemented.") | ||
|
||
|
||
def create_fr_initializer( | ||
initializer, | ||
**kwargs, | ||
) -> init_lib.SinkhornInitializer: # noqa: D102 | ||
if isinstance(initializer, init_lib.SinkhornInitializer): | ||
return initializer | ||
if initializer == "default": | ||
return init_lib.DefaultInitializer(**kwargs) | ||
if initializer == "gaussian": | ||
return init_lib.GaussianInitializer(**kwargs) | ||
if initializer == "sorting": | ||
return init_lib.SortingInitializer(**kwargs) | ||
if initializer == "subsample": | ||
return init_lib.SubsampleInitializer(**kwargs) | ||
raise NotImplementedError(f"Initializer `{initializer}` is not yet implemented.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we keep strings for initializers , or do Literal[...], init_lib.SinkhornInitializer
and port the code to src?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks great, thanks @selmanozleyen !
Wdyt about allowing strings for initialisers additionally? Might be too hard for people not familiar with these things otherwise.
The issue there is differentiating the arguments required when creating the initializer objects vs the arguments required when calling initializers. For example you don't know if |
I see. Yes, then maybe let's create a class , but then call it from inside the |
for more information, see https://pre-commit.ci
hi @MUCDK , I made the changes. If the user wants to create more advanced initializers (ie giving something to the initializer constructor as kwargs) they can do by creating it themselves. I think its reasonable to assume that someone who'd want to do that they can look ott initializers up. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That looks great, thanks so much @selmanozleyen !
Please merge, and then do a new release v0.4.0 of moscot! |
hi @MUCDK ,
So good news is we currently do a good job on partitioning the
kwargs
for solve. In solve we give anykwarg
we don't know to eitherSinkhornSolver
orGWSolver
constructors.SinkhornSolver
usesSinkhorn
orLRSinkhorn
fromottjax
, these classes don't havekwargs
in their constructors so when usingSinkhornSolver
as a backend we are good.GWSolver
usesGromovWasserstein
orLRGromovWasserstein
fromottjax
. The parent class of these classWassersteinSolver
don't throw an error on unrecognized args. The tests will pass after the ottjax PR merges.Here is the PR in
ott-jax
: ott-jax/ott#579Other things done:
CompoundProblem
or any other more abstract class. It's handled inGWSolver
as it should.Additionally closes:
solve
methods #720