Skip to content

Commit

Permalink
Python type stub improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
cgevans committed Jul 2, 2024
1 parent f0a1392 commit 342c998
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 6 deletions.
2 changes: 2 additions & 0 deletions py-rgrow/rgrow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ def _system_plot_canvas(
ax.text(j, i, n, ha="center", va="center", color="white")

if annotate_mismatches:
if isinstance(state, np.ndarray):
raise ValueError("Cannot currently annotate mismatches on a numpy array.")
mml = sys.calc_mismatch_locations(state)
for i, j in zip(*mml.nonzero()):
d = mml[i, j]
Expand Down
179 changes: 173 additions & 6 deletions py-rgrow/rgrow/rgrow.pyi
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
# flake8: noqa: PYI021
from typing import Any, List, Sequence, Self, TypeAlias
from typing import Any, List, Sequence, Self, TypeAlias, overload
from numpy import dtype, ndarray
import numpy as np
import polars as pl
from numpy.typing import NDArray

class ATAM:

@property
def tile_names(self) -> list[str]: ...

@property
def tile_colors(self) -> NDArray[np.uint]: ...

def calc_dimers(self) -> List[DimerInfo]:
"""
Calculate information about the dimers the system is able to form.
Expand Down Expand Up @@ -55,6 +62,41 @@ class ATAM:
Calculate the location and direction of mismatches, not jus the number.
"""



@overload
def evolve(
self,
state: State,
for_events: int | None = None,
total_events: int | None = None,
for_time: float | None = None,
total_time: float | None = None,
size_min: int | None = None,
size_max: int | None = None,
for_wall_time: float | None = None,
require_strong_bound: bool = True,
show_window: bool = False,
parallel: bool = True,
) -> EvolveOutcome: ...

@overload
def evolve(
self,
state: Sequence[State],
for_events: int | None = None,
total_events: int | None = None,
for_time: float | None = None,
total_time: float | None = None,
size_min: int | None = None,
size_max: int | None = None,
for_wall_time: float | None = None,
require_strong_bound: bool = True,
show_window: bool = False,
parallel: bool = True,
) -> List[EvolveOutcome]: ...

@overload
def evolve(
self,
state: State | Sequence[State],
Expand Down Expand Up @@ -110,6 +152,8 @@ class ATAM:

def get_param(self, param_name): ...
def print_debug(self): ...

@staticmethod
def read_json(filename: str) -> None:
"""
Read a system from a JSON file.
Expand Down Expand Up @@ -211,6 +255,13 @@ class ATAM:
filename : str
The name of the file to write to.
"""

def color_canvas(self, state: State | FFSStateRef | NDArray[np.uint]) -> NDArray[np.uint8]:
...

def name_canvas(self, state: State | FFSStateRef | NDArray[np.uint]) -> NDArray[np.str_]:
...


class EvolveBounds:
def __init__(self, for_time: float | None = None): ...
Expand Down Expand Up @@ -274,12 +325,25 @@ class FFSStateRef:
"""A copy of the state's canvas. This is safe, but can't be modified and is slower than `canvas_view`."""

def clone_state(self): ...
def n_tiles(self): ...
def time(self): ...
def total_events(self): ...
def tracking_copy(self): ...


@property
def canvas_view(self) -> NDArray[np.uint]:
...

def n_tiles(self) -> int: ...
def time(self) -> float: ...
def total_events(self) -> int: ...
def tracking_copy(self) -> Any: ...

class KTAM:
@property
def tile_names(self) -> list[str]: ...

@property
def tile_colors(self) -> NDArray[np.uint]: ...


def calc_dimers(self) -> List[DimerInfo]:
"""
Calculate information about the dimers the system is able to form.
Expand Down Expand Up @@ -329,6 +393,40 @@ class KTAM:
Calculate the location and direction of mismatches, not jus the number.
"""


@overload
def evolve(
self,
state: State,
for_events: int | None = None,
total_events: int | None = None,
for_time: float | None = None,
total_time: float | None = None,
size_min: int | None = None,
size_max: int | None = None,
for_wall_time: float | None = None,
require_strong_bound: bool = True,
show_window: bool = False,
parallel: bool = True,
) -> EvolveOutcome: ...

@overload
def evolve(
self,
state: Sequence[State],
for_events: int | None = None,
total_events: int | None = None,
for_time: float | None = None,
total_time: float | None = None,
size_min: int | None = None,
size_max: int | None = None,
for_wall_time: float | None = None,
require_strong_bound: bool = True,
show_window: bool = False,
parallel: bool = True,
) -> List[EvolveOutcome]: ...

@overload
def evolve(
self,
state: State | Sequence[State],
Expand Down Expand Up @@ -385,6 +483,8 @@ class KTAM:
def from_tileset(tileset): ...
def get_param(self, param_name): ...
def print_debug(self): ...

@staticmethod
def read_json(filename: str) -> None:
"""
Read a system from a JSON file.
Expand Down Expand Up @@ -486,8 +586,28 @@ class KTAM:
filename : str
The name of the file to write to.
"""
def color_canvas(self, state: State | FFSStateRef | NDArray[np.uint]) -> NDArray[np.uint8]:
...

def name_canvas(self, state: State | FFSStateRef | NDArray[np.uint]) -> NDArray[np.str_]:
...


class OldKTAM:

@property
def tile_names(self) -> list[str]: ...

@property
def tile_colors(self) -> NDArray[np.uint]: ...

def color_canvas(self, state: State | FFSStateRef | NDArray[np.uint]) -> NDArray[np.uint8]:
...

def name_canvas(self, state: State | FFSStateRef | NDArray[np.uint]) -> NDArray[np.str_]:
...


def calc_dimers(self) -> List[DimerInfo]:
"""
Calculate information about the dimers the system is able to form.
Expand Down Expand Up @@ -537,6 +657,39 @@ class OldKTAM:
Calculate the location and direction of mismatches, not jus the number.
"""

@overload
def evolve(
self,
state: State,
for_events: int | None = None,
total_events: int | None = None,
for_time: float | None = None,
total_time: float | None = None,
size_min: int | None = None,
size_max: int | None = None,
for_wall_time: float | None = None,
require_strong_bound: bool = True,
show_window: bool = False,
parallel: bool = True,
) -> EvolveOutcome: ...

@overload
def evolve(
self,
state: Sequence[State],
for_events: int | None = None,
total_events: int | None = None,
for_time: float | None = None,
total_time: float | None = None,
size_min: int | None = None,
size_max: int | None = None,
for_wall_time: float | None = None,
require_strong_bound: bool = True,
show_window: bool = False,
parallel: bool = True,
) -> List[EvolveOutcome]: ...

@overload
def evolve(
self,
state: State | Sequence[State],
Expand Down Expand Up @@ -592,6 +745,8 @@ class OldKTAM:

def get_param(self, param_name): ...
def print_debug(self): ...

@staticmethod
def read_json(filename: str) -> None:
"""
Read a system from a JSON file.
Expand Down Expand Up @@ -693,12 +848,13 @@ class OldKTAM:
filename : str
The name of the file to write to.
"""


System: TypeAlias = ATAM | KTAM | OldKTAM

class State:
@property
def canvas_view(self) -> ndarray[int, dtype[int]]:
def canvas_view(self) -> NDArray[np.uint]:
"""A view of the state's canvas. This is fast but unsafe."""

def canvas_copy(self) -> ndarray:
Expand All @@ -710,11 +866,22 @@ class State:
def read_json(filename: str) -> State: ...
def tracking_copy(self) -> ndarray: ...
def write_json(self, filename: str) -> None: ...

@property
def ntiles(self) -> int: ...

@property
def total_events(self) -> int: ...

@property
def time(self) -> float: ...


class TileSet:
def __init__(self, **kwargs: Any): ...
def create_state(self, system: System | None = None) -> State: ...
def create_state_from_canvas(self, canvas: NDArray[np.uint]) -> State: ...
def create_state_empty(self, system: System | None = None) -> State: ...
def create_system(self) -> System: ...
def create_system_and_state(self) -> tuple[System, State]: ...
@classmethod
Expand Down

0 comments on commit 342c998

Please sign in to comment.