Skip to content

Commit

Permalink
Improved typing and removed unused ignores
Browse files Browse the repository at this point in the history
as per mypy directives
  • Loading branch information
jsnel committed Aug 17, 2024
1 parent ba5c981 commit fbf1aa2
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ repos:
hooks:
- id: mypy
exclude: "docs|benchmark/|.*/tests?/.*"
additional_dependencies: [types-all, types-attrs]
additional_dependencies: [types-attrs, types-tabulate]

- repo: https://github.com/econchick/interrogate
rev: 1.5.0
Expand Down
15 changes: 8 additions & 7 deletions glotaran/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def __init__(self, error: str):


def _load_item_from_dict(
item_type: type[Item], value: Item | dict[str, Any], extra: dict[str, Any] | None = None
item_type: type[Item],
value: Item | dict[str, Any],
extra: dict[str, Any] | None = None,
) -> Item:
"""Load an item from a dictionary.
Expand Down Expand Up @@ -136,7 +138,7 @@ def _load_global_items_from_dict(


def _load_dataset_groups(
dataset_groups: dict[str, DatasetGroupModel | Any]
dataset_groups: dict[str, DatasetGroupModel | Any],
) -> dict[str, DatasetGroupModel]:
"""Add the default dataset group if not present.
Expand Down Expand Up @@ -340,7 +342,9 @@ def get_dataset_groups(self) -> dict[str, DatasetGroup]:
groups[group].dataset_models[dataset_model.label] = dataset_model
return groups

def iterate_items(self) -> Generator[tuple[str, dict[str, Item] | list[Item]], None, None]:
def iterate_items(
self,
) -> Generator[tuple[str, dict[str, Item] | list[Item]], None, None]:
"""Iterate items.
Yields
Expand Down Expand Up @@ -387,10 +391,7 @@ def generate_parameters(self) -> Parameters:
.. # noqa: D414
"""
return Parameters(
{
label: Parameter(label=label, value=0) # type:ignore[call-arg]
for label in self.get_parameter_labels()
}
{label: Parameter(label=label, value=0) for label in self.get_parameter_labels()}
)

def get_issues(self, *, parameters: Parameters | None = None) -> list[ItemIssue]:
Expand Down
9 changes: 6 additions & 3 deletions glotaran/optimization/data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(self, scheme: Scheme, dataset_group: DatasetGroup):
)
self.add_model_weight(scheme.model, label, model_dimension, global_dimension)

self._data[label] = self.get_from_dataset( # type:ignore[assignment]
self._data[label] = self.get_from_dataset(
dataset, "data", model_dimension, global_dimension
)
if self._weight[label] is not None:
Expand Down Expand Up @@ -585,7 +585,7 @@ def align_dataset_indices(self, aligned_global_axes: dict[str, ArrayLike]) -> li

@staticmethod
def align_groups(
aligned_global_axes: dict[str, ArrayLike]
aligned_global_axes: dict[str, ArrayLike],
) -> tuple[ArrayLike, dict[str, list[str]]]:
"""Align the groups in a dataset group.
Expand Down Expand Up @@ -622,7 +622,10 @@ def align_groups(
for i, group_label in enumerate(aligned_group_labels):
if group_label not in group_definitions:
group_definitions[group_label] = list(
filter(lambda label: label != "", aligned_groups.isel({"global": i}).data)
filter(
lambda label: label != "",
aligned_groups.isel({"global": i}).data,
)
)
return aligned_group_labels, group_definitions

Expand Down
30 changes: 17 additions & 13 deletions glotaran/optimization/estimation_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,10 @@ def get_result(
).clp_labels
clps[label] = xr.DataArray(
np.array(self._clps[label]).reshape((len(global_clp_labels), len(clp_labels))),
coords={"global_clp_label": global_clp_labels, "clp_label": clp_labels},
coords={
"global_clp_label": global_clp_labels,
"clp_label": clp_labels,
},
dims=["global_clp_label", "clp_label"],
)

Expand Down Expand Up @@ -397,8 +400,8 @@ def calculate_estimation(self, dataset_model: DatasetModel):
The dataset model.
"""
label = dataset_model.label
self._clps[label].clear() # type:ignore[union-attr]
self._residuals[label].clear() # type:ignore[union-attr]
self._clps[label].clear()
self._residuals[label].clear()

global_axis = self._data_provider.get_global_axis(label)
data = self._data_provider.get_data(label)
Expand All @@ -411,14 +414,19 @@ def calculate_estimation(self, dataset_model: DatasetModel):
)
clp_labels.append(self._matrix_provider.get_matrix_container(label).clp_labels)
clp = self.retrieve_clps(
clp_labels[index], matrix_container.clp_labels, reduced_clps, global_index_value
clp_labels[index],
matrix_container.clp_labels,
reduced_clps,
global_index_value,
)

self._clps[label].append(clp) # type:ignore[union-attr]
self._residuals[label].append(residual) # type:ignore[union-attr]
self._clps[label].append(clp)
self._residuals[label].append(residual)

self._clp_penalty += self.calculate_clp_penalties(
clp_labels, self._clps[label], global_axis # type:ignore[arg-type]
clp_labels,
self._clps[label],
global_axis,
)


Expand All @@ -445,12 +453,8 @@ def __init__(
super().__init__(dataset_group)
self._data_provider = data_provider
self._matrix_provider = matrix_provider
self._clps: list[ArrayLike] = [
None # type:ignore[list-item]
] * self._data_provider.aligned_global_axis.size
self._residuals: list[ArrayLike] = [
None # type:ignore[list-item]
] * self._data_provider.aligned_global_axis.size
self._clps: list[ArrayLike] = [None] * self._data_provider.aligned_global_axis.size
self._residuals: list[ArrayLike] = [None] * self._data_provider.aligned_global_axis.size

def estimate(self):
"""Calculate the estimation."""
Expand Down
5 changes: 3 additions & 2 deletions glotaran/optimization/matrix_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def calculate_dataset_matrix(
clp_labels, matrix = MatrixProvider.combine_megacomplex_matrices(
matrix, this_matrix, clp_labels, this_clp_labels
)
return MatrixContainer(clp_labels, matrix) # type:ignore[arg-type]
return MatrixContainer(clp_labels, matrix)

@staticmethod
def combine_megacomplex_matrices(
Expand Down Expand Up @@ -718,7 +718,8 @@ def calculate_aligned_matrices(self):
]

group_matrix = self.align_matrices(
matrix_containers, matrix_scales # type:ignore[arg-type]
matrix_containers,
matrix_scales, # type:ignore[arg-type]
)

self._aligned_full_clp_labels[i] = full_clp_labels[group_label]
Expand Down

0 comments on commit fbf1aa2

Please sign in to comment.