Skip to content

Commit

Permalink
pyfabm: add NamedObjectList.index, improve annotations and mask access (
Browse files Browse the repository at this point in the history
#95)

* pyfabm: make index method of NamedObjectList take str

* suppress exception chaining

* improve type annotation

* improve mask access
  • Loading branch information
jornbr authored Oct 21, 2024
1 parent 1509670 commit effe676
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions src/pyfabm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,19 +886,27 @@ def __init__(self, *data: Iterable[T]):
def __len__(self) -> int:
return len(self._data)

def __getitem__(self, key: Union[str, int]) -> T:
def __getitem__(self, key: Union[int, str]) -> T:
if isinstance(key, str):
return self.find(key)
return self._data[key]

def __contains__(self, key: Union[str, int]) -> bool:
def __contains__(self, key: Union[T, str]) -> bool:
if isinstance(key, str):
try:
self.find(key)
return True
except KeyError:
return False
return super().__contains__(key)
return key in self._data

def index(self, key: Union[T, str], *args) -> int:
if isinstance(key, str):
try:
key = self.find(key)
except KeyError:
raise ValueError from None
return self._data.index(key, *args)

def __repr__(self) -> str:
return repr(self._data)
Expand Down Expand Up @@ -1079,10 +1087,15 @@ def link_mask(self, *masks: np.ndarray):
self._mask = masks
self.fabm.set_mask(self.pmodel, *self._mask)

def _get_mask(self) -> Union[np.ndarray, Sequence[np.ndarray]]:
return self._mask[0] if len(self._mask) == 1 else self._mask

def _set_mask(self, values: Union[npt.ArrayLike, Sequence[npt.ArrayLike]]):
@property
def mask(self) -> Union[np.ndarray, Sequence[np.ndarray], None]:
mask = self._mask
if mask is not None and len(mask) == 1:
mask = mask[0]
return mask

@mask.setter
def mask(self, values: Union[npt.ArrayLike, Sequence[npt.ArrayLike]]):
if self.fabm.mask_type == 1:
values = (values,)
if len(values) != self.fabm.mask_type:
Expand All @@ -1096,8 +1109,6 @@ def _set_mask(self, values: Union[npt.ArrayLike, Sequence[npt.ArrayLike]]):
if value is not mask:
mask[...] = value

mask = property(_get_mask, _set_mask)

def link_bottom_index(self, indices: np.ndarray):
if not self.fabm.variable_bottom_index:
raise FABMException(
Expand Down Expand Up @@ -1401,7 +1412,7 @@ def _update_configuration(self, settings: Optional[Tuple] = None):
+ self.horizontal_dependencies
+ self.scalar_dependencies
)
self.variables = (
self.variables: NamedObjectList[VariableFromPointer] = (
self.state_variables + self.diagnostic_variables + self.dependencies
)

Expand All @@ -1414,7 +1425,7 @@ def _update_configuration(self, settings: Optional[Tuple] = None):

self.itime = -1.0

def getRates(self, t: float = None, surface: bool = True, bottom: bool = True):
def getRates(self, t: Optional[float] = None, surface: bool = True, bottom: bool = True):
"""Returns the local rate of change in state variables,
given the current state and environment.
"""
Expand Down Expand Up @@ -1451,7 +1462,7 @@ def getRates(self, t: float = None, surface: bool = True, bottom: bool = True):

def get_sources(
self,
t: float = None,
t: Optional[float] = None,
out: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
if t is None:
Expand Down

0 comments on commit effe676

Please sign in to comment.