Skip to content

Commit

Permalink
get rid of some users of hasattr()
Browse files Browse the repository at this point in the history
instead, let's use type checking protocols. Note that this trick
cannot be used for the `hasattr("protocol")` instances of `DiagComm`
because this would lead to cyclic imports.

thanks to [at]kayoub5 for suggesting this!

Signed-off-by: Andreas Lauser <andreas.lauser@mbition.io>
Signed-off-by: Florian Jost <florian.jost@mbition.io>
  • Loading branch information
andlaus committed Aug 2, 2024
1 parent a6653d0 commit 7a00f2b
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 11 deletions.
12 changes: 6 additions & 6 deletions odxtools/diaglayers/ecuvariant.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

from typing_extensions import override

from ..diagvariable import DiagVariable
from ..diagvariable import DiagVariable, HasDiagVariables
from ..dyndefinedspec import DynDefinedSpec
from ..ecuvariantpattern import EcuVariantPattern
from ..exceptions import odxassert
from ..nameditemlist import NamedItemList
from ..odxlink import OdxDocFragment, OdxLinkDatabase, OdxLinkRef
from ..parentref import ParentRef
from ..variablegroup import VariableGroup
from ..variablegroup import HasVariableGroups, VariableGroup
from .diaglayer import DiagLayer
from .ecuvariantraw import EcuVariantRaw
from .hierarchyelement import HierarchyElement
Expand Down Expand Up @@ -97,10 +97,10 @@ def _compute_available_diag_variables(self,
odxlinks: OdxLinkDatabase) -> Iterable[DiagVariable]:

def get_local_objects_fn(dl: DiagLayer) -> Iterable[DiagVariable]:
if not hasattr(dl.diag_layer_raw, "diag_variables"):
if not isinstance(dl.diag_layer_raw, HasDiagVariables):
return []

return dl.diag_layer_raw.diag_variables # type: ignore[no-any-return]
return dl.diag_layer_raw.diag_variables

def not_inherited_fn(parent_ref: ParentRef) -> List[str]:
return parent_ref.not_inherited_variables
Expand All @@ -111,10 +111,10 @@ def _compute_available_variable_groups(self,
odxlinks: OdxLinkDatabase) -> Iterable[VariableGroup]:

def get_local_objects_fn(dl: DiagLayer) -> Iterable[VariableGroup]:
if not hasattr(dl.diag_layer_raw, "variable_groups"):
if not isinstance(dl.diag_layer_raw, HasVariableGroups):
return []

return dl.diag_layer_raw.variable_groups # type: ignore[no-any-return]
return dl.diag_layer_raw.variable_groups

def not_inherited_fn(parent_ref: ParentRef) -> List[str]:
return []
Expand Down
11 changes: 10 additions & 1 deletion odxtools/diagvariable.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: MIT
import typing
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, runtime_checkable
from xml.etree import ElementTree

from .admindata import AdminData
Expand All @@ -17,6 +18,14 @@
from .variablegroup import VariableGroup


@runtime_checkable
class HasDiagVariables(typing.Protocol):

@property
def diag_variables(self) -> "NamedItemList[DiagVariable]":
...


@dataclass
class DiagVariable(IdentifiableElement):
"""Representation of a diagnostic variable
Expand Down
7 changes: 4 additions & 3 deletions odxtools/nameditemlist.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# SPDX-License-Identifier: MIT
import abc
import typing
from copy import deepcopy
from keyword import iskeyword
from typing import (Any, Collection, Dict, Iterable, List, Optional, Protocol, SupportsIndex, Tuple,
TypeVar, Union, cast, overload, runtime_checkable)
from typing import (Any, Collection, Dict, Iterable, List, Optional, SupportsIndex, Tuple, TypeVar,
Union, cast, overload, runtime_checkable)

from .exceptions import odxraise


@runtime_checkable
class OdxNamed(Protocol):
class OdxNamed(typing.Protocol):

@property
def short_name(self) -> str:
Expand Down
12 changes: 11 additions & 1 deletion odxtools/variablegroup.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
# SPDX-License-Identifier: MIT
import typing
from dataclasses import dataclass
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, List, runtime_checkable
from xml.etree import ElementTree

from .element import IdentifiableElement, NamedElement
from .nameditemlist import NamedItemList
from .odxlink import OdxDocFragment
from .utils import dataclass_fields_asdict

if TYPE_CHECKING:
pass


@runtime_checkable
class HasVariableGroups(typing.Protocol):

@property
def variable_groups(self) -> NamedItemList["VariableGroup"]:
...


@dataclass
class VariableGroup(IdentifiableElement):

Expand Down

0 comments on commit 7a00f2b

Please sign in to comment.