Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/move taggedarray #457

Merged
merged 7 commits into from
Feb 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/moscot/_docs/_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,10 @@
"""
_scale_cost = """\
scale_cost
Method to scale cost matrices. If `None` no scaling is applied.
How to rescale the cost matrix. Implemented scalings are
'median', 'mean', 'max_cost', 'max_norm' and 'max_bound'.
Alternatively, a float factor can be given to rescale the cost such
that ``cost_matrix /= scale_cost``.
"""
_cost_lin = """\
cost
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/backends/ott/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def select_values(last_k: Optional[int] = None) -> Tuple[str, jnp.ndarray, jnp.n
kind, values, xs = select_values(last_k)

ax.plot(xs, values, **kwargs)
ax.set_xlabel("iteration")
ax.set_xlabel("iteration (logged)")
ax.set_ylabel(kind)
if title is None:
title = "converged" if self.converged else "not converged" # type: ignore[attr-defined]
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/problems/generic/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def push(
**kwargs: Any,
) -> Optional[ApplyOutput_t[K]]:
"""
Push distribution of cells through time.
Push distribution of cells from source to target.

Parameters
----------
Expand Down Expand Up @@ -160,7 +160,7 @@ def pull(
**kwargs: Any,
) -> Optional[ApplyOutput_t[K]]:
"""
Pull distribution of cells through time.
Pull distribution of cells from target to source.

Parameters
----------
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/problems/time/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def _get_data(
else:
raise ValueError(f"No data found for `{target}` time point.")

return ( # type: ignore[return-value]
return ( # type:ignore[return-value]
source_data,
growth_rates_source,
intermediate_data,
Expand Down
1 change: 1 addition & 0 deletions src/moscot/solvers/_tagged_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


def get_cost_function(cost: str, *, backend: Literal["ott"] = "ott", **kwargs: Any) -> Callable[..., Any]:
"""Get backend-dependent cost function."""
if backend == "ott":
from moscot.backends.ott._solver import OTTCost

Expand Down