Skip to content

Commit

Permalink
Revert "refactor the gwproblem and fgwproblem inheritance"
Browse files Browse the repository at this point in the history
This reverts commit 45509e0.
  • Loading branch information
selmanozleyen committed Dec 10, 2024
1 parent 3766a04 commit f06d4c6
Showing 1 changed file with 92 additions and 64 deletions.
156 changes: 92 additions & 64 deletions src/moscot/problems/generic/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ def _valid_policies(self) -> Tuple[Policy_t, ...]:
return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value]


class FGWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ignore[misc]
"""Class for solving the :term:`FGW <fused Gromov-Wasserstein>` problem.
class GWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ignore[misc]
"""Class for solving the :term:`GW <Gromov-Wasserstein>` or :term:`FGW <fused Gromov-Wasserstein>` problems.
Parameters
----------
Expand All @@ -281,23 +281,20 @@ def __init__(self, adata: AnnData, **kwargs: Any):
def prepare(
self,
key: str,
joint_attr: Optional[Union[str, Mapping[str, Any]]] = None,
x_attr: Optional[Union[str, Mapping[str, Any]]] = None,
y_attr: Optional[Union[str, Mapping[str, Any]]] = None,
policy: Literal["sequential", "explicit", "star"] = "sequential",
cost: OttCostFnMap_t = "sq_euclidean",
cost_kwargs: CostKwargs_t = types.MappingProxyType({}),
a: Optional[Union[bool, str]] = None,
b: Optional[Union[bool, str]] = None,
xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
subset: Optional[Sequence[Tuple[K, K]]] = None,
reference: Optional[Any] = None,
x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
xy_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
x_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
y_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
subset: Optional[Sequence[Tuple[K, K]]] = None,
reference: Optional[Any] = None,
) -> "FGWProblem[K, B]":
) -> "GWProblem[K, B]":
"""Prepare the individual :term:`quadratic subproblems <quadratic problem>`.
.. seealso::
Expand All @@ -307,16 +304,6 @@ def prepare(
----------
key
Key in :attr:`~anndata.AnnData.obs` for the :class:`~moscot.utils.subset_policy.SubsetPolicy`.
joint_attr
How to get the data for the :term:`linear term` in the :term:`fused <fused Gromov-Wasserstein>` case:
- :obj:`None` - run `PCA <https://en.wikipedia.org/wiki/Principal_component_analysis>`_
on :attr:`~anndata.AnnData.X` is computed.
- :class:`str` - a key in :attr:`~anndata.AnnData.obsm` where the data is stored.
- :class:`dict` - it should contain ``'attr'`` and ``'key'``, the attribute and the key
in :class:`~anndata.AnnData`, and optionally ``'tag'``, one of :class:`~moscot.utils.tagged_array.Tag`.
By default, :attr:`tag = 'point_cloud' <moscot.utils.tagged_array.Tag.POINT_CLOUD>` is used.
x_attr
How to get the data for the source :term:`quadratic term`:
Expand Down Expand Up @@ -368,18 +355,6 @@ def prepare(
:meth:`estimate the marginals <moscot.base.problems.OTProblem.estimate_marginals>`,
otherwise use uniform marginals.
- :obj:`None` - uniform marginals.
xy_callback
Callback function used to prepare the data in the :term:`linear term`.
x_callback
Callback function used to prepare the data in the source :term:`quadratic term`.
y_callback
Callback function used to prepare the data in the target :term:`quadratic term`.
xy_callback_kwargs
Keyword arguments for the ``xy_callback``.
x_callback_kwargs
Keyword arguments for the ``x_callback``.
y_callback_kwargs
Keyword arguments for the ``y_callback``.
Returns
-------
Expand All @@ -394,16 +369,15 @@ def prepare(
self.batch_key = key
x = set_quad_defaults(x_attr) if x_callback is None else {}
y = set_quad_defaults(y_attr) if y_callback is None else {}
xy, xy_callback, xy_callback_kwargs = handle_joint_attr(joint_attr, xy_callback, xy_callback_kwargs)

xy, x, y = handle_cost(
xy=xy,
xy={},
x=x,
y=y,
cost=cost,
cost_kwargs=cost_kwargs,
x_callback=x_callback,
y_callback=y_callback,
xy_callback=xy_callback,
cost_kwargs=cost_kwargs,
)
return super().prepare( # type: ignore[return-value]
key=key,
Expand All @@ -413,19 +387,16 @@ def prepare(
policy=policy,
a=a,
b=b,
reference=reference,
subset=subset,
x_callback=x_callback,
y_callback=y_callback,
xy_callback=xy_callback,
x_callback_kwargs=x_callback_kwargs,
y_callback_kwargs=y_callback_kwargs,
xy_callback_kwargs=xy_callback_kwargs,
subset=subset,
reference=reference,
)

def solve(
self,
alpha: float = 0.5,
epsilon: float = 1e-3,
tau_a: float = 1.0,
tau_b: float = 1.0,
Expand All @@ -442,7 +413,7 @@ def solve(
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
**kwargs: Any,
) -> "FGWProblem[K,B]":
) -> "GWProblem[K,B]":
r"""Solve the individual :term:`quadratic subproblems <quadratic problem>`.
.. seealso:
Expand All @@ -453,10 +424,6 @@ def solve(
Parameters
----------
alpha
Parameter in :math:`(0, 1)` that interpolates between the :term:`quadratic term` and
the :term:`linear term`. :math:`\alpha = 1` corresponds to the pure :term:`Gromov-Wasserstein` problem while
:math:`\alpha \to 0` corresponds to the pure :term:`linear problem`.
epsilon
:term:`Entropic regularization`.
tau_a
Expand Down Expand Up @@ -504,10 +471,8 @@ def solve(
- :attr:`solutions` - the :term:`OT` solutions for each subproblem.
- :attr:`stage` - set to ``'solved'``.
"""
if alpha == 1.0:
raise ValueError("The `FGWProblem` is equivalent to the `GWProblem` when `alpha=1.0`.")
return super().solve(
alpha=alpha,
return super().solve( # type: ignore[return-value]
alpha=1.0,
epsilon=epsilon,
tau_a=tau_a,
tau_b=tau_b,
Expand All @@ -524,7 +489,7 @@ def solve(
linear_solver_kwargs=linear_solver_kwargs,
device=device,
**kwargs,
) # type: ignore[return-value]
)

@property
def _base_problem_type(self) -> Type[B]:
Expand All @@ -535,8 +500,8 @@ def _valid_policies(self) -> Tuple[Policy_t, ...]:
return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value]


class GWProblem(FGWProblem[K, B]):
"""Class for solving the :term:`GW <Gromov-Wasserstein>` or :term:`FGW <fused Gromov-Wasserstein>` problems.
class FGWProblem(GWProblem[K, B]):
"""Class for solving the :term:`FGW <fused Gromov-Wasserstein>` problem.
Parameters
----------
Expand All @@ -549,20 +514,23 @@ class GWProblem(FGWProblem[K, B]):
def prepare(
self,
key: str,
joint_attr: Optional[Union[str, Mapping[str, Any]]] = None,
x_attr: Optional[Union[str, Mapping[str, Any]]] = None,
y_attr: Optional[Union[str, Mapping[str, Any]]] = None,
policy: Literal["sequential", "explicit", "star"] = "sequential",
cost: OttCostFnMap_t = "sq_euclidean",
cost_kwargs: CostKwargs_t = types.MappingProxyType({}),
a: Optional[Union[bool, str]] = None,
b: Optional[Union[bool, str]] = None,
subset: Optional[Sequence[Tuple[K, K]]] = None,
reference: Optional[Any] = None,
xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
xy_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
x_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
y_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
) -> "GWProblem[K, B]":
subset: Optional[Sequence[Tuple[K, K]]] = None,
reference: Optional[Any] = None,
) -> "FGWProblem[K, B]":
"""Prepare the individual :term:`quadratic subproblems <quadratic problem>`.
.. seealso::
Expand All @@ -572,6 +540,16 @@ def prepare(
----------
key
Key in :attr:`~anndata.AnnData.obs` for the :class:`~moscot.utils.subset_policy.SubsetPolicy`.
joint_attr
How to get the data for the :term:`linear term` in the :term:`fused <fused Gromov-Wasserstein>` case:
- :obj:`None` - run `PCA <https://en.wikipedia.org/wiki/Principal_component_analysis>`_
on :attr:`~anndata.AnnData.X` is computed.
- :class:`str` - a key in :attr:`~anndata.AnnData.obsm` where the data is stored.
- :class:`dict` - it should contain ``'attr'`` and ``'key'``, the attribute and the key
in :class:`~anndata.AnnData`, and optionally ``'tag'``, one of :class:`~moscot.utils.tagged_array.Tag`.
By default, :attr:`tag = 'point_cloud' <moscot.utils.tagged_array.Tag.POINT_CLOUD>` is used.
x_attr
How to get the data for the source :term:`quadratic term`:
Expand Down Expand Up @@ -623,6 +601,24 @@ def prepare(
:meth:`estimate the marginals <moscot.base.problems.OTProblem.estimate_marginals>`,
otherwise use uniform marginals.
- :obj:`None` - uniform marginals.
xy
Data for the :term:`linear term`.
x
Data for the source :term:`quadratic term`.
y
Data for the target :term:`quadratic term`.
xy_callback
Callback function used to prepare the data in the :term:`linear term`.
x_callback
Callback function used to prepare the data in the source :term:`quadratic term`.
y_callback
Callback function used to prepare the data in the target :term:`quadratic term`.
xy_callback_kwargs
Keyword arguments for the ``xy_callback``.
x_callback_kwargs
Keyword arguments for the ``x_callback``.
y_callback_kwargs
Keyword arguments for the ``y_callback``.
Returns
-------
Expand All @@ -634,25 +630,42 @@ def prepare(
- :attr:`stage` - set to ``'prepared'``.
- :attr:`problem_kind` - set to ``'quadratic'``.
"""
return super().prepare( # type: ignore[return-value]
self.batch_key = key
x = set_quad_defaults(x_attr) if x_callback is None else {}
y = set_quad_defaults(y_attr) if y_callback is None else {}
xy, xy_callback, xy_callback_kwargs = handle_joint_attr(joint_attr, xy_callback, xy_callback_kwargs)
xy, x, y = handle_cost(
xy=xy,
x=x,
y=y,
cost=cost,
x_callback=x_callback,
y_callback=y_callback,
xy_callback=xy_callback,
cost_kwargs=cost_kwargs,
)
return CompoundProblem.prepare(
self, # type: ignore[return-value, arg-type]
key=key,
xy=xy,
x=x,
y=y,
policy=policy,
a=a,
b=b,
x_attr=x_attr,
y_attr=y_attr,
cost=cost,
cost_kwargs=cost_kwargs,
reference=reference,
subset=subset, # type: ignore[arg-type]
x_callback=x_callback,
y_callback=y_callback,
xy_callback=xy_callback,
x_callback_kwargs=x_callback_kwargs,
y_callback_kwargs=y_callback_kwargs,
subset=subset,
reference=reference,
xy_callback_kwargs=xy_callback_kwargs,
)

def solve(
self,
alpha: float = 0.5,
epsilon: float = 1e-3,
tau_a: float = 1.0,
tau_b: float = 1.0,
Expand All @@ -669,7 +682,7 @@ def solve(
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
**kwargs: Any,
) -> "GWProblem[K,B]":
) -> "FGWProblem[K,B]":
r"""Solve the individual :term:`quadratic subproblems <quadratic problem>`.
.. seealso:
Expand All @@ -680,6 +693,10 @@ def solve(
Parameters
----------
alpha
Parameter in :math:`(0, 1)` that interpolates between the :term:`quadratic term` and
the :term:`linear term`. :math:`\alpha = 1` corresponds to the pure :term:`Gromov-Wasserstein` problem while
:math:`\alpha \to 0` corresponds to the pure :term:`linear problem`.
epsilon
:term:`Entropic regularization`.
tau_a
Expand Down Expand Up @@ -727,8 +744,11 @@ def solve(
- :attr:`solutions` - the :term:`OT` solutions for each subproblem.
- :attr:`stage` - set to ``'solved'``.
"""
return super().solve( # type: ignore[return-value]
alpha=1.0,
if alpha == 1.0:
raise ValueError("The `FGWProblem` is equivalent to the `GWProblem` when `alpha=1.0`.")
return CompoundProblem.solve(
self, # type: ignore[return-value, arg-type]
alpha=alpha,
epsilon=epsilon,
tau_a=tau_a,
tau_b=tau_b,
Expand All @@ -747,6 +767,14 @@ def solve(
**kwargs,
)

@property
def _base_problem_type(self) -> Type[B]:
return OTProblem # type: ignore[return-value]

@property
def _valid_policies(self) -> Tuple[Policy_t, ...]:
return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value]


class GENOTLinProblem(CondOTProblem):
"""Class for solving Conditional Parameterized Monge Map problems / Conditional Neural OT problems."""
Expand Down

0 comments on commit f06d4c6

Please sign in to comment.