From fbf1aa2670b44d70f1fc4ef9e0b7deafacc26d45 Mon Sep 17 00:00:00 2001 From: Joris Snellenburg Date: Sat, 17 Aug 2024 02:54:00 +0200 Subject: [PATCH] Improved typing and removed unused ignores as per mypy directives --- .pre-commit-config.yaml | 2 +- glotaran/model/model.py | 15 +++++----- glotaran/optimization/data_provider.py | 9 ++++-- glotaran/optimization/estimation_provider.py | 30 +++++++++++--------- glotaran/optimization/matrix_provider.py | 5 ++-- 5 files changed, 35 insertions(+), 26 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8a5634acf..d42bf5be0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -120,7 +120,7 @@ repos: hooks: - id: mypy exclude: "docs|benchmark/|.*/tests?/.*" - additional_dependencies: [types-all, types-attrs] + additional_dependencies: [types-attrs, types-tabulate] - repo: https://github.com/econchick/interrogate rev: 1.5.0 diff --git a/glotaran/model/model.py b/glotaran/model/model.py index bb58ad6bd..452b7118d 100644 --- a/glotaran/model/model.py +++ b/glotaran/model/model.py @@ -60,7 +60,9 @@ def __init__(self, error: str): def _load_item_from_dict( - item_type: type[Item], value: Item | dict[str, Any], extra: dict[str, Any] | None = None + item_type: type[Item], + value: Item | dict[str, Any], + extra: dict[str, Any] | None = None, ) -> Item: """Load an item from a dictionary. @@ -136,7 +138,7 @@ def _load_global_items_from_dict( def _load_dataset_groups( - dataset_groups: dict[str, DatasetGroupModel | Any] + dataset_groups: dict[str, DatasetGroupModel | Any], ) -> dict[str, DatasetGroupModel]: """Add the default dataset group if not present. @@ -340,7 +342,9 @@ def get_dataset_groups(self) -> dict[str, DatasetGroup]: groups[group].dataset_models[dataset_model.label] = dataset_model return groups - def iterate_items(self) -> Generator[tuple[str, dict[str, Item] | list[Item]], None, None]: + def iterate_items( + self, + ) -> Generator[tuple[str, dict[str, Item] | list[Item]], None, None]: """Iterate items. Yields @@ -387,10 +391,7 @@ def generate_parameters(self) -> Parameters: .. # noqa: D414 """ return Parameters( - { - label: Parameter(label=label, value=0) # type:ignore[call-arg] - for label in self.get_parameter_labels() - } + {label: Parameter(label=label, value=0) for label in self.get_parameter_labels()} ) def get_issues(self, *, parameters: Parameters | None = None) -> list[ItemIssue]: diff --git a/glotaran/optimization/data_provider.py b/glotaran/optimization/data_provider.py index 7864ec648..5b0c3ab25 100644 --- a/glotaran/optimization/data_provider.py +++ b/glotaran/optimization/data_provider.py @@ -67,7 +67,7 @@ def __init__(self, scheme: Scheme, dataset_group: DatasetGroup): ) self.add_model_weight(scheme.model, label, model_dimension, global_dimension) - self._data[label] = self.get_from_dataset( # type:ignore[assignment] + self._data[label] = self.get_from_dataset( dataset, "data", model_dimension, global_dimension ) if self._weight[label] is not None: @@ -585,7 +585,7 @@ def align_dataset_indices(self, aligned_global_axes: dict[str, ArrayLike]) -> li @staticmethod def align_groups( - aligned_global_axes: dict[str, ArrayLike] + aligned_global_axes: dict[str, ArrayLike], ) -> tuple[ArrayLike, dict[str, list[str]]]: """Align the groups in a dataset group. @@ -622,7 +622,10 @@ def align_groups( for i, group_label in enumerate(aligned_group_labels): if group_label not in group_definitions: group_definitions[group_label] = list( - filter(lambda label: label != "", aligned_groups.isel({"global": i}).data) + filter( + lambda label: label != "", + aligned_groups.isel({"global": i}).data, + ) ) return aligned_group_labels, group_definitions diff --git a/glotaran/optimization/estimation_provider.py b/glotaran/optimization/estimation_provider.py index 505ad9655..863a2e8f0 100644 --- a/glotaran/optimization/estimation_provider.py +++ b/glotaran/optimization/estimation_provider.py @@ -353,7 +353,10 @@ def get_result( ).clp_labels clps[label] = xr.DataArray( np.array(self._clps[label]).reshape((len(global_clp_labels), len(clp_labels))), - coords={"global_clp_label": global_clp_labels, "clp_label": clp_labels}, + coords={ + "global_clp_label": global_clp_labels, + "clp_label": clp_labels, + }, dims=["global_clp_label", "clp_label"], ) @@ -397,8 +400,8 @@ def calculate_estimation(self, dataset_model: DatasetModel): The dataset model. """ label = dataset_model.label - self._clps[label].clear() # type:ignore[union-attr] - self._residuals[label].clear() # type:ignore[union-attr] + self._clps[label].clear() + self._residuals[label].clear() global_axis = self._data_provider.get_global_axis(label) data = self._data_provider.get_data(label) @@ -411,14 +414,19 @@ def calculate_estimation(self, dataset_model: DatasetModel): ) clp_labels.append(self._matrix_provider.get_matrix_container(label).clp_labels) clp = self.retrieve_clps( - clp_labels[index], matrix_container.clp_labels, reduced_clps, global_index_value + clp_labels[index], + matrix_container.clp_labels, + reduced_clps, + global_index_value, ) - self._clps[label].append(clp) # type:ignore[union-attr] - self._residuals[label].append(residual) # type:ignore[union-attr] + self._clps[label].append(clp) + self._residuals[label].append(residual) self._clp_penalty += self.calculate_clp_penalties( - clp_labels, self._clps[label], global_axis # type:ignore[arg-type] + clp_labels, + self._clps[label], + global_axis, ) @@ -445,12 +453,8 @@ def __init__( super().__init__(dataset_group) self._data_provider = data_provider self._matrix_provider = matrix_provider - self._clps: list[ArrayLike] = [ - None # type:ignore[list-item] - ] * self._data_provider.aligned_global_axis.size - self._residuals: list[ArrayLike] = [ - None # type:ignore[list-item] - ] * self._data_provider.aligned_global_axis.size + self._clps: list[ArrayLike] = [None] * self._data_provider.aligned_global_axis.size + self._residuals: list[ArrayLike] = [None] * self._data_provider.aligned_global_axis.size def estimate(self): """Calculate the estimation.""" diff --git a/glotaran/optimization/matrix_provider.py b/glotaran/optimization/matrix_provider.py index d7f353923..40042efe8 100644 --- a/glotaran/optimization/matrix_provider.py +++ b/glotaran/optimization/matrix_provider.py @@ -195,7 +195,7 @@ def calculate_dataset_matrix( clp_labels, matrix = MatrixProvider.combine_megacomplex_matrices( matrix, this_matrix, clp_labels, this_clp_labels ) - return MatrixContainer(clp_labels, matrix) # type:ignore[arg-type] + return MatrixContainer(clp_labels, matrix) @staticmethod def combine_megacomplex_matrices( @@ -718,7 +718,8 @@ def calculate_aligned_matrices(self): ] group_matrix = self.align_matrices( - matrix_containers, matrix_scales # type:ignore[arg-type] + matrix_containers, + matrix_scales, # type:ignore[arg-type] ) self._aligned_full_clp_labels[i] = full_clp_labels[group_label]