Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: formalizing point-visiting strategies #177

Merged
merged 11 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/useq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
GridRowsColumns,
GridWidthHeight,
MultiPointPlan,
OrderMode,
RandomPoints,
RelativeMultiPointPlan,
Shape,
Expand All @@ -20,6 +19,7 @@
from useq._mda_sequence import MDASequence
from useq._plate import WellPlate, WellPlatePlan
from useq._plate_registry import register_well_plates, registered_well_plate_keys
from useq._point_visiting import OrderMode, TraversalOrder
from useq._position import AbsolutePosition, Position, RelativePosition
from useq._time import (
AnyTimePlan,
Expand Down Expand Up @@ -68,6 +68,7 @@
"TDurationLoops",
"TIntervalDuration",
"TIntervalLoops",
"TraversalOrder",
"WellPlate",
"WellPlatePlan",
"ZAboveBelow",
Expand Down
206 changes: 85 additions & 121 deletions src/useq/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,38 @@
import math
import warnings
from enum import Enum
from functools import partial
from typing import Any, Callable, Iterator, Optional, Sequence, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Iterator,
Optional,
Sequence,
Tuple,
Union,
)

import numpy as np
from pydantic import Field, field_validator
from annotated_types import Ge, Gt # noqa: TCH002
from pydantic import Field, field_validator, model_validator

from useq._point_visiting import OrderMode, TraversalOrder
from useq._position import (
AbsolutePosition,
PositionT,
RelativePosition,
_MultiPointPlan,
)

MIN_RANDOM_POINTS = 5000
if TYPE_CHECKING:
from typing_extensions import Annotated, Self, TypeAlias

PointGenerator: TypeAlias = Callable[
[np.random.RandomState, int, float, float], Iterable[tuple[float, float]]
]

MIN_RANDOM_POINTS = 10000


class RelativeTo(Enum):
Expand All @@ -35,93 +53,7 @@ class RelativeTo(Enum):
top_left: str = "top_left"


class OrderMode(Enum):
"""Order in which grid positions will be iterated.

Attributes
----------
row_wise : Literal['row_wise']
Iterate row by row.
column_wise : Literal['column_wise']
Iterate column by column.
row_wise_snake : Literal['row_wise_snake']
Iterate row by row, but alternate the direction of the columns.
column_wise_snake : Literal['column_wise_snake']
Iterate column by column, but alternate the direction of the rows.
spiral : Literal['spiral']
Iterate in a spiral pattern, starting from the center.
"""

row_wise = "row_wise"
column_wise = "column_wise"
row_wise_snake = "row_wise_snake"
column_wise_snake = "column_wise_snake"
spiral = "spiral"


def _spiral_indices(
rows: int, columns: int, center_origin: bool = False
) -> Iterator[Tuple[int, int]]:
"""Return a spiral iterator over a 2D grid.

Parameters
----------
rows : int
Number of rows.
columns : int
Number of columns.
center_origin : bool
If center_origin is True, all indices are centered around (0, 0), and some will
be negative. Otherwise, the indices are centered around (rows//2, columns//2)

Yields
------
(x, y) : tuple[int, int]
Indices of the next element in the spiral.
"""
# direction: first down and then clockwise (assuming positive Y is down)

x = y = 0
if center_origin: # see docstring
xshift = yshift = 0
else:
xshift = (columns - 1) // 2
yshift = (rows - 1) // 2
dx = 0
dy = -1
for _ in range(max(columns, rows) ** 2):
if (-columns / 2 < x <= columns / 2) and (-rows / 2 < y <= rows / 2):
yield y + yshift, x + xshift
if x == y or (x < 0 and x == -y) or (x > 0 and x == 1 - y):
dx, dy = -dy, dx
x, y = x + dx, y + dy


# function that iterates indices (row, col) in a grid where (0, 0) is the top left
def _rect_indices(
rows: int, columns: int, snake: bool = False, row_wise: bool = True
) -> Iterator[Tuple[int, int]]:
"""Return a row or column-wise iterator over a 2D grid."""
c, r = np.meshgrid(np.arange(columns), np.arange(rows))
if snake:
if row_wise:
c[1::2, :] = c[1::2, :][:, ::-1]
else:
r[:, 1::2] = r[:, 1::2][::-1, :]
return zip(r.ravel(), c.ravel()) if row_wise else zip(r.T.ravel(), c.T.ravel())


# used in iter_indices below, to determine the order in which indices are yielded
IndexGenerator = Callable[[int, int], Iterator[Tuple[int, int]]]
_INDEX_GENERATORS: dict[OrderMode, IndexGenerator] = {
OrderMode.row_wise: partial(_rect_indices, snake=False, row_wise=True),
OrderMode.column_wise: partial(_rect_indices, snake=False, row_wise=False),
OrderMode.row_wise_snake: partial(_rect_indices, snake=True, row_wise=True),
OrderMode.column_wise_snake: partial(_rect_indices, snake=True, row_wise=False),
OrderMode.spiral: _spiral_indices,
}


class _GridPlan(_MultiPointPlan[PositionT]):
"""Base class for all grid plans.

Expand Down Expand Up @@ -199,12 +131,12 @@ def iter_grid_positions(
fov_width: float | None = None,
fov_height: float | None = None,
*,
mode: OrderMode | None = None,
order: OrderMode | None = None,
) -> Iterator[PositionT]:
"""Iterate over all grid positions, given a field of view size."""
_fov_width = fov_width or self.fov_width or 1.0
_fov_height = fov_height or self.fov_height or 1.0
mode = self.mode if mode is None else OrderMode(mode)
order = self.mode if order is None else OrderMode(order)

dx, dy = self._step_size(_fov_width, _fov_height)
rows = self._nrows(dy)
Expand All @@ -213,7 +145,7 @@ def iter_grid_positions(
y0 = self._offset_y(dy)

pos_cls = RelativePosition if self.is_relative else AbsolutePosition
for idx, (r, c) in enumerate(_INDEX_GENERATORS[mode](rows, cols)):
for idx, (r, c) in enumerate(order.generate_indices(rows, cols)):
yield pos_cls(
x=x0 + c * dx,
y=y0 - r * dy,
Expand Down Expand Up @@ -431,9 +363,9 @@ class RandomPoints(_MultiPointPlan[RelativePosition]):
num_points : int
Number of points to generate.
max_width : float
Maximum width of the bounding box.
Maximum width of the bounding box in microns.
max_height : float
Maximum height of the bounding box.
Maximum height of the bounding box in microns.
shape : Shape
Shape of the bounding box. Current options are "ellipse" and "rectangle".
random_seed : Optional[int]
Expand All @@ -442,39 +374,71 @@ class RandomPoints(_MultiPointPlan[RelativePosition]):
allow_overlap : bool
By defaut, True. If False and `fov_width` and `fov_height` are specified, points
will not overlap and will be at least `fov_width` and `fov_height apart.
order : TraversalOrder
Order in which the points will be visited. If None, order is simply the order
in which the points are generated (random). Use 'nearest_neighbor' or
'two_opt' to order the points in a more structured way.
start_at : int
Index of the point to start at. This is only used if `order` is
'nearest_neighbor' or 'two_opt'.
"""

num_points: int
max_width: float = np.inf
max_height: float = np.inf
num_points: Annotated[int, Gt(1)]
max_width: Annotated[float, Gt(0)] = 1
max_height: Annotated[float, Gt(0)] = 1
shape: Shape = Shape.ELLIPSE
random_seed: Optional[int] = None
allow_overlap: bool = True
order: TraversalOrder = TraversalOrder.TWO_OPT
start_at: Annotated[int, Ge(0)] = 0

@model_validator(mode="after")
def _validate_startat(self) -> Self:
if self.start_at > (self.num_points - 1):
warnings.warn(
"start_at is greater than the number of points. "
"Setting start_at to last point.",
stacklevel=2,
)
self.start_at = self.num_points - 1
return self

def __iter__(self) -> Iterator[RelativePosition]: # type: ignore [override]
seed = np.random.RandomState(self.random_seed)
func = _POINTS_GENERATORS[self.shape]
n_points = max(self.num_points, MIN_RANDOM_POINTS)
points: list[Tuple[float, float]] = []
for idx, (x, y) in enumerate(
func(seed, n_points, self.max_width, self.max_height)
):
if (
self.allow_overlap
or self.fov_width is None
or self.fov_height is None
or _is_a_valid_point(points, x, y, self.fov_width, self.fov_height)
):
yield RelativePosition(x=x, y=y, name=f"{str(idx).zfill(4)}")
points.append((x, y))
if len(points) >= self.num_points:
break

points: Iterable[Tuple[float, float]]
# in the easy case, just generate the requested number of points
if self.allow_overlap or self.fov_width is None or self.fov_height is None:
points = func(seed, self.num_points, self.max_width, self.max_height)

else:
warnings.warn(
f"Unable to generate {self.num_points} non-overlapping points. "
f"Only {len(points)} points were found.",
stacklevel=2,
)
# if we need to avoid overlap, generate points, check if they are valid, and
# repeat until we have enough
points = []
per_iter = 100
tries = 0
while tries < MIN_RANDOM_POINTS and len(points) < self.num_points:
candidates = func(seed, per_iter, self.max_width, self.max_height)
tries += per_iter
for p in candidates:
if _is_a_valid_point(points, *p, self.fov_width, self.fov_height):
points.append(p)
if len(points) >= self.num_points:
break

if len(points) < self.num_points:
warnings.warn(
f"Unable to generate {self.num_points} non-overlapping points. "
f"Only {len(points)} points were found.",
stacklevel=2,
)

if self.order is not None:
points = self.order(points, start_at=self.start_at)

for idx, (x, y) in enumerate(points):
yield RelativePosition(x=x, y=y, name=f"{str(idx).zfill(4)}")

def num_positions(self) -> int:
return self.num_points
Expand Down Expand Up @@ -504,8 +468,9 @@ def _random_points_in_ellipse(

The point is within +/- radius_x and +/- radius_y at a random angle.
"""
xy = np.sqrt(seed.uniform(0, 1, size=(n_points, 2)))
angle = seed.uniform(0, 2 * np.pi, size=n_points)
points = seed.uniform(0, 1, size=(n_points, 3))
xy = points[:, :2]
angle = points[:, 2] * 2 * np.pi
xy[:, 0] *= (max_width / 2) * np.cos(angle)
xy[:, 1] *= (max_height / 2) * np.sin(angle)
return xy
Expand All @@ -524,7 +489,6 @@ def _random_points_in_rectangle(
return xy


PointGenerator = Callable[[np.random.RandomState, int, float, float], np.ndarray]
_POINTS_GENERATORS: dict[Shape, PointGenerator] = {
Shape.ELLIPSE: _random_points_in_ellipse,
Shape.RECTANGLE: _random_points_in_rectangle,
Expand Down
31 changes: 16 additions & 15 deletions src/useq/_mda_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
from useq._grid import MultiPointPlan # noqa: TCH001
from useq._hardware_autofocus import AnyAutofocusPlan, AxesBasedAF
from useq._iter_sequence import iter_sequence
from useq._plate import WellPlatePlan # noqa: TCH001
from useq._plate import WellPlatePlan
from useq._position import Position, PositionBase
from useq._time import AnyTimePlan # noqa: TCH001
from useq._utils import AXES, Axis, TimeEstimate, estimate_sequence_duration
from useq._z import AnyZPlan # noqa: TCH001

if TYPE_CHECKING:
from typing_extensions import Self

from useq._mda_event import MDAEvent


Expand Down Expand Up @@ -282,27 +284,26 @@ def _validate_axis_order(cls, v: Any) -> tuple[str, ...]:
return order

@model_validator(mode="after")
@classmethod
def _validate_mda(cls, values: Any) -> Any:
if values.axis_order:
cls._check_order(
values.axis_order,
z_plan=values.z_plan,
stage_positions=values.stage_positions,
channels=values.channels,
grid_plan=values.grid_plan,
autofocus_plan=values.autofocus_plan,
def _validate_mda(self) -> Self:
if self.axis_order:
self._check_order(
self.axis_order,
z_plan=self.z_plan,
stage_positions=self.stage_positions,
channels=self.channels,
grid_plan=self.grid_plan,
autofocus_plan=self.autofocus_plan,
)
if values.stage_positions:
for p in values.stage_positions:
if self.stage_positions and not isinstance(self.stage_positions, WellPlatePlan):
for p in self.stage_positions:
if hasattr(p, "sequence") and getattr(
p.sequence, "keep_shutter_open_across", None
): # pragma: no cover
raise ValueError(
"keep_shutter_open_across cannot currently be set on a "
"Position sequence"
)
return values
return self

def __eq__(self, other: Any) -> bool:
"""Return `True` if two `MDASequences` are equal (uid is excluded)."""
Expand All @@ -315,7 +316,7 @@ def __eq__(self, other: Any) -> bool:

@staticmethod
def _check_order(
order: str,
order: tuple[str, ...],
z_plan: Optional[AnyZPlan] = None,
stage_positions: Sequence[Position] = (),
channels: Sequence[Channel] = (),
Expand Down
Loading
Loading