Skip to content

Commit

Permalink
Use standard collections for type hints
Browse files Browse the repository at this point in the history
Python ≥3.9 feature.
  • Loading branch information
khaeru committed Sep 24, 2024
1 parent 19709cf commit 030434a
Show file tree
Hide file tree
Showing 16 changed files with 81 additions and 106 deletions.
10 changes: 5 additions & 5 deletions ixmp/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from copy import copy
from dataclasses import asdict, dataclass, field, fields, make_dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Type
from typing import Any, Optional

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -121,7 +121,7 @@ def delete_field(self, name):
data.pop(name)
return new_cls, new_cls(**data)

def keys(self) -> Tuple[str, ...]:
def keys(self) -> tuple[str, ...]:
return tuple(map(lambda f: f.name.replace("_", " "), fields(self)))

def set(self, name: str, value: Any, strict: bool = True):
Expand Down Expand Up @@ -214,7 +214,7 @@ class Config:
#: ``ixmp.config.values["platform"]["platform name"]…``.
values: BaseValues

_ValuesClass: Type
_ValuesClass: type[BaseValues]

def __init__(self, read: bool = True):
self._ValuesClass = BaseValues
Expand Down Expand Up @@ -261,7 +261,7 @@ def get(self, name: str) -> Any:
"""Return the value of a configuration key `name`."""
return self.values[name]

def keys(self) -> Tuple[str, ...]:
def keys(self) -> tuple[str, ...]:
"""Return the names of all registered configuration keys."""
return self.values.keys()

Expand Down Expand Up @@ -383,7 +383,7 @@ def add_platform(self, name: str, *args, **kwargs):

self.values["platform"][name] = info

def get_platform_info(self, name: str) -> Tuple[str, Dict[str, Any]]:
def get_platform_info(self, name: str) -> tuple[str, dict[str, Any]]:
"""Return information on configured Platform `name`.
Parameters
Expand Down
9 changes: 6 additions & 3 deletions ixmp/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Backend API."""

from enum import IntFlag
from typing import Dict, List, Type, Union
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
import ixmp.backend.base

#: Lists of field names for tuples returned by Backend API methods.
#:
Expand Down Expand Up @@ -46,12 +49,12 @@
#: Partial list of dimensions for the IAMC data structure, or “IAMC format”. This omits
#: "year" and "subannual" which appear in some variants of the structure, but not in
#: others.
IAMC_IDX: List[Union[str, int]] = ["model", "scenario", "region", "variable", "unit"]
IAMC_IDX: list[Union[str, int]] = ["model", "scenario", "region", "variable", "unit"]


#: Mapping from names to available backends. To register additional backends, add
#: entries to this dictionary.
BACKENDS: Dict[str, Type] = {}
BACKENDS: dict[str, type["ixmp.backend.base.Backend"]] = {}


class ItemType(IntFlag):
Expand Down
59 changes: 28 additions & 31 deletions ixmp/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,12 @@
from pathlib import Path
from typing import (
Any,
Dict,
Hashable,
Iterable,
List,
Literal,
MutableMapping,
Optional,
Sequence,
Tuple,
Union,
)

Expand All @@ -33,7 +30,7 @@ class Backend(ABC):
# Typing:
# - All methods MUST be fully typed.
# - Use more permissive types, e.g. Sequence[str], for inputs.
# - Use precise types, e.g. List[str], for return values.
# - Use precise types, e.g. list[str], for return values.
# - Backend subclasses do not need to repeat the type annotations; these are implied
# by this parent class.
#
Expand All @@ -57,7 +54,7 @@ def __call__(self, obj, method, *args, **kwargs):
# Platform methods

@classmethod
def handle_config(cls, args: Sequence, kwargs: MutableMapping) -> Dict[str, Any]:
def handle_config(cls, args: Sequence, kwargs: MutableMapping) -> dict[str, Any]:
"""OPTIONAL: Handle platform/backend config arguments.
Returns a :class:`dict` to be stored in the configuration file. This
Expand Down Expand Up @@ -124,7 +121,7 @@ def set_doc(self, domain: str, docs) -> None:
"""

@abstractmethod
def get_doc(self, domain: str, name: Optional[str] = None) -> Union[str, Dict]:
def get_doc(self, domain: str, name: Optional[str] = None) -> Union[str, dict]:
"""Read documentation from database
Parameters
Expand Down Expand Up @@ -155,7 +152,7 @@ def close_db(self) -> None:
Close any database connection(s), if open.
"""

def get_auth(self, user: str, models: Sequence[str], kind: str) -> Dict[str, bool]:
def get_auth(self, user: str, models: Sequence[str], kind: str) -> dict[str, bool]:
"""OPTIONAL: Return user authorization for `models`.
If the Backend implements access control, this method **must** indicate whether
Expand Down Expand Up @@ -215,7 +212,7 @@ def set_node(
"""

@abstractmethod
def get_nodes(self) -> Iterable[Tuple[str, Optional[str], str, str]]:
def get_nodes(self) -> Iterable[tuple[str, Optional[str], str, str]]:
"""Iterate over all nodes stored on the Platform.
Yields
Expand All @@ -238,7 +235,7 @@ def get_nodes(self) -> Iterable[Tuple[str, Optional[str], str, str]]:
"""

@abstractmethod
def get_timeslices(self) -> Iterable[Tuple[str, str, float]]:
def get_timeslices(self) -> Iterable[tuple[str, str, float]]:
"""Iterate over subannual timeslices defined on the Platform instance.
Yields
Expand Down Expand Up @@ -321,7 +318,7 @@ def get_scenario_names(self) -> Iterable[str]:
def get_scenarios(
self, default: bool, model: Optional[str], scenario: Optional[str]
) -> Iterable[
Tuple[str, str, str, bool, bool, str, str, str, str, str, str, str, int]
tuple[str, str, str, bool, bool, str, str, str, str, str, str, str, int]
]:
"""Iterate over TimeSeries stored on the Platform.
Expand Down Expand Up @@ -377,7 +374,7 @@ def set_unit(self, name: str, comment: str) -> None:
"""

@abstractmethod
def get_units(self) -> List[str]:
def get_units(self) -> list[str]:
"""Return all registered symbols for units of measurement.
Returns
Expand Down Expand Up @@ -592,7 +589,7 @@ def preload(self, ts: TimeSeries) -> None:
"""OPTIONAL: Load `ts` data into memory."""

@staticmethod
def _handle_rw_filters(filters: dict) -> Tuple[Optional[TimeSeries], Dict]:
def _handle_rw_filters(filters: dict) -> tuple[Optional[TimeSeries], dict]:
"""Helper for :meth:`read_file` and :meth:`write_file`.
The `filters` argument is unpacked if the 'scenarios' key is a single
Expand All @@ -617,7 +614,7 @@ def get_data(
variable: Sequence[str],
unit: Sequence[str],
year: Sequence[str],
) -> Iterable[Tuple[str, str, str, int, float]]:
) -> Iterable[tuple[str, str, str, int, float]]:
"""Retrieve time series data.
Parameters
Expand Down Expand Up @@ -650,7 +647,7 @@ def get_data(
@abstractmethod
def get_geo(
self, ts: TimeSeries
) -> Iterable[Tuple[str, str, int, str, str, str, bool]]:
) -> Iterable[tuple[str, str, int, str, str, str, bool]]:
"""Retrieve time-series 'geodata'.
Yields
Expand All @@ -677,7 +674,7 @@ def set_data(
ts: TimeSeries,
region: str,
variable: str,
data: Dict[int, float],
data: dict[int, float],
unit: str,
subannual: str,
meta: bool,
Expand Down Expand Up @@ -831,7 +828,7 @@ def has_solution(self, s: Scenario) -> bool:
"""

@abstractmethod
def list_items(self, s: Scenario, type: str) -> List[str]:
def list_items(self, s: Scenario, type: str) -> list[str]:
"""Return a list of names of items of `type`.
Parameters
Expand Down Expand Up @@ -882,7 +879,7 @@ def delete_item(self, s: Scenario, type: str, name: str) -> None:
"""

@abstractmethod
def item_index(self, s: Scenario, name: str, sets_or_names: str) -> List[str]:
def item_index(self, s: Scenario, name: str, sets_or_names: str) -> list[str]:
"""Return the index sets or names of item `name`.
Parameters
Expand All @@ -900,8 +897,8 @@ def item_get_elements(
s: Scenario,
type: Literal["equ", "par", "set", "var"],
name: str,
filters: Optional[Dict[str, List[Any]]] = None,
) -> Union[Dict[str, Any], pd.Series, pd.DataFrame]:
filters: Optional[dict[str, list[Any]]] = None,
) -> Union[dict[str, Any], pd.Series, pd.DataFrame]:
"""Return elements of item `name`.
Parameters
Expand Down Expand Up @@ -945,7 +942,7 @@ def item_set_elements(
s: Scenario,
type: str,
name: str,
elements: Iterable[Tuple[Any, Optional[float], Optional[str], Optional[str]]],
elements: Iterable[tuple[Any, Optional[float], Optional[str], Optional[str]]],
) -> None:
"""Add keys or values to item `name`.
Expand Down Expand Up @@ -1011,7 +1008,7 @@ def get_meta(
scenario: Optional[str],
version: Optional[int],
strict: bool,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Retrieve all metadata attached to a specific target.
Depending on which of `model`, `scenario`, `version` are :obj:`None`, metadata
Expand Down Expand Up @@ -1121,7 +1118,7 @@ def clear_solution(self, s: Scenario, from_year=None):
# Methods for message_ix.Scenario

@abstractmethod
def cat_list(self, ms: Scenario, name: str) -> List[str]:
def cat_list(self, ms: Scenario, name: str) -> list[str]:
"""Return list of categories in mapping `name`.
Parameters
Expand All @@ -1136,7 +1133,7 @@ def cat_list(self, ms: Scenario, name: str) -> List[str]:
"""

@abstractmethod
def cat_get_elements(self, ms: Scenario, name: str, cat: str) -> List[str]:
def cat_get_elements(self, ms: Scenario, name: str, cat: str) -> list[str]:
"""Get elements of a category mapping.
Parameters
Expand Down Expand Up @@ -1188,11 +1185,11 @@ class CachingBackend(Backend):

#: Cache of values. Keys are given by :meth:`_cache_key`; values depend on the
#: subclass' usage of the cache.
_cache: Dict[Tuple, object] = {}
_cache: dict[tuple, object] = {}

#: Count of number of times a value was retrieved from cache successfully
#: using :meth:`cache_get`.
_cache_hit: Dict[Tuple, int] = {}
_cache_hit: dict[tuple, int] = {}

# Backend API methods

Expand All @@ -1217,8 +1214,8 @@ def _cache_key(
ts: TimeSeries,
ix_type: Optional[str],
name: Optional[str],
filters: Optional[Dict[str, Hashable]] = None,
) -> Tuple[Hashable, ...]:
filters: Optional[dict[str, Hashable]] = None,
) -> tuple[Hashable, ...]:
"""Return a hashable cache key.
ixmp `filters` (a :class:`dict` of :class:`list`) are converted to a unique id
Expand All @@ -1237,7 +1234,7 @@ def _cache_key(
return (ts_id, ix_type, name, hash(json.dumps(sorted(filters.items()))))

def cache_get(
self, ts: TimeSeries, ix_type: str, name: str, filters: Dict
self, ts: TimeSeries, ix_type: str, name: str, filters: dict
) -> Optional[Any]:
"""Retrieve value from cache.
Expand All @@ -1258,7 +1255,7 @@ def cache_get(
raise KeyError(ts, ix_type, name, filters)

def cache(
self, ts: TimeSeries, ix_type: str, name: str, filters: Dict, value: Any
self, ts: TimeSeries, ix_type: str, name: str, filters: dict, value: Any
) -> bool:
"""Store `value` in cache.
Expand All @@ -1284,7 +1281,7 @@ def cache_invalidate(
ts: TimeSeries,
ix_type: Optional[str] = None,
name: Optional[str] = None,
filters: Optional[Dict] = None,
filters: Optional[dict] = None,
) -> None:
"""Invalidate cached values.
Expand All @@ -1300,7 +1297,7 @@ def cache_invalidate(

if filters is None:
i = slice(1) if (ix_type is name is None) else slice(3)
to_remove: Iterable[Tuple] = filter(
to_remove: Iterable[tuple] = filter(
lambda k: k[i] == key[i], self._cache.keys()
)
else:
Expand Down
2 changes: 1 addition & 1 deletion ixmp/backend/jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def _handle_jexception():


@lru_cache
def _fixed_index_sets(scheme: str) -> Mapping[str, List[str]]:
def _fixed_index_sets(scheme: str) -> Mapping[str, list[str]]:
"""Return index sets for items that are fixed in the Java code.
See :meth:`JDBCBackend.init_item`. The return value is cached so the method is only
Expand Down
3 changes: 1 addition & 2 deletions ixmp/cli.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from pathlib import Path
from typing import Type

import click

import ixmp

ScenarioClass: Type[ixmp.Scenario] = ixmp.Scenario
ScenarioClass: type[ixmp.Scenario] = ixmp.Scenario


class VersionType(click.ParamType):
Expand Down
6 changes: 3 additions & 3 deletions ixmp/core/platform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from os import PathLike
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Optional, Sequence, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -264,7 +264,7 @@ def add_unit(self, unit: str, comment: str = "None") -> None:

self._backend.set_unit(unit, comment)

def units(self) -> List[str]:
def units(self) -> list[str]:
"""Return all units defined on the Platform.
Returns
Expand Down Expand Up @@ -386,7 +386,7 @@ def add_timeslice(self, name: str, category: str, duration: float) -> None:

def check_access(
self, user: str, models: Union[str, Sequence[str]], access: str = "view"
) -> Union[bool, Dict[str, bool]]:
) -> Union[bool, dict[str, bool]]:
"""Check access to specific models.
Parameters
Expand Down
Loading

0 comments on commit 030434a

Please sign in to comment.