Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type annotations, part 5 #217

Merged
merged 16 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/somersaultecu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,7 +1247,7 @@ class SomersaultSID(IntEnum):
is_visible_raw=None,
byte_size=None,
dtc_values=[],
parameters=[
parameters=NamedItemList([
ValueParameter(
short_name="flip_speed",
long_name="Flip Speed",
Expand All @@ -1272,7 +1272,7 @@ class SomersaultSID(IntEnum):
bit_position=None,
sdgs=[],
),
],
]),
)
}

Expand Down
2 changes: 1 addition & 1 deletion odxtools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
__author__ = "Katrin Bauer"


def _main():
def _main() -> None:
# Command line tool
from .cli import main as _main

Expand Down
10 changes: 5 additions & 5 deletions odxtools/database.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: MIT
from itertools import chain
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Tuple
from xml.etree import ElementTree
from zipfile import ZipFile

Expand All @@ -13,7 +13,7 @@
from .odxlink import OdxLinkDatabase


def version(v: str):
def version(v: str) -> Tuple[int, ...]:
return tuple(map(int, (v.split("."))))


Expand Down Expand Up @@ -135,13 +135,13 @@ def diag_layers(self) -> NamedItemList[DiagLayer]:
return self._diag_layers

@property
def diag_layer_containers(self):
def diag_layer_containers(self) -> NamedItemList[DiagLayerContainer]:
return self._diag_layer_containers

@diag_layer_containers.setter
def diag_layer_containers(self, value):
def diag_layer_containers(self, value: NamedItemList[DiagLayerContainer]) -> None:
self._diag_layer_containers = value

@property
def comparam_subsets(self):
def comparam_subsets(self) -> NamedItemList[ComparamSubset]:
return self._comparam_subsets
8 changes: 4 additions & 4 deletions odxtools/dataobjectproperty.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .encodestate import EncodeState
from .exceptions import DecodeError, EncodeError, odxassert, odxrequire
from .odxlink import OdxDocFragment, OdxLinkDatabase, OdxLinkId, OdxLinkRef
from .odxtypes import odxstr_to_bool
from .odxtypes import AtomicOdxType, ParameterValue, odxstr_to_bool
from .physicaltype import PhysicalType
from .unit import Unit
from .utils import dataclass_fields_asdict
Expand Down Expand Up @@ -99,7 +99,7 @@ def _build_odxlinks(self) -> Dict[OdxLinkId, Any]:
result.update(self.diag_coded_type._build_odxlinks())
return result

def _resolve_odxlinks(self, odxlinks: OdxLinkDatabase):
def _resolve_odxlinks(self, odxlinks: OdxLinkDatabase) -> None:
"""Resolves the reference to the unit"""
super()._resolve_odxlinks(odxlinks)

Expand Down Expand Up @@ -165,5 +165,5 @@ def convert_bytes_to_physical(self,
f"DOP {self.short_name} could not convert the coded value "
f" {repr(internal)} to physical type {self.physical_type.base_data_type}.")

def is_valid_physical_value(self, physical_value):
return self.compu_method.is_valid_physical_value(physical_value)
def is_valid_physical_value(self, physical_value: ParameterValue) -> bool:
return self.compu_method.is_valid_physical_value(cast(AtomicOdxType, physical_value))
15 changes: 0 additions & 15 deletions odxtools/diagdatadictionaryspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,6 @@ class DiagDataDictionarySpec:
unit_spec: Optional[UnitSpec]
sdgs: List[SpecialDataGroup]

def __post_init__(self):
self._all_data_object_properties = NamedItemList(
chain(
self.data_object_props,
self.structures,
self.end_of_pdu_fields,
self.dynamic_length_fields,
self.dtc_dops,
self.tables,
),)

@staticmethod
def from_et(et_element: ElementTree.Element,
doc_frags: List[OdxDocFragment]) -> "DiagDataDictionarySpec":
Expand Down Expand Up @@ -220,7 +209,3 @@ def _resolve_snrefs(self, diag_layer: "DiagLayer") -> None:

if self.unit_spec is not None:
self.unit_spec._resolve_snrefs(diag_layer)

@property
def all_data_object_properties(self):
return self._all_data_object_properties
28 changes: 14 additions & 14 deletions odxtools/diaglayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def diag_data_dictionary_spec(self) -> DiagDataDictionarySpec:
#####
# <value inheritance mechanism helpers>
#####
def _get_parent_refs_sorted_by_priority(self, reverse=False) -> Iterable[ParentRef]:
def _get_parent_refs_sorted_by_priority(self, reverse: bool = False) -> Iterable[ParentRef]:
return sorted(
self.diag_layer_raw.parent_refs,
key=lambda pr: pr.layer.variant_type.inheritance_priority,
Expand Down Expand Up @@ -444,21 +444,21 @@ def _get_local_unit_groups(self) -> Iterable[UnitGroup]:
def _compute_available_diag_comms(self, odxlinks: OdxLinkDatabase
) -> Iterable[Union[DiagService, SingleEcuJob]]:

def get_local_objects_fn(dl):
def get_local_objects_fn(dl: DiagLayer) -> Iterable[Union[DiagService, SingleEcuJob]]:
return dl._get_local_diag_comms(odxlinks)

def not_inherited_fn(parent_ref):
def not_inherited_fn(parent_ref: ParentRef) -> List[str]:
return parent_ref.not_inherited_diag_comms

return self._compute_available_objects(get_local_objects_fn, not_inherited_fn)

def _compute_available_global_neg_responses(self, odxlinks: OdxLinkDatabase) \
-> Iterable[Response]:

def get_local_objects_fn(dl):
def get_local_objects_fn(dl: DiagLayer) -> Iterable[Response]:
return dl.diag_layer_raw.global_negative_responses

def not_inherited_fn(parent_ref):
def not_inherited_fn(parent_ref: ParentRef) -> List[str]:
return parent_ref.not_inherited_global_neg_responses

return self._compute_available_objects(get_local_objects_fn, not_inherited_fn)
Expand All @@ -469,7 +469,7 @@ def _compute_available_ddd_spec_items(
exclude: Callable[["ParentRef"], List[str]],
) -> NamedItemList[TNamed]:

def get_local_objects_fn(dl: "DiagLayer"):
def get_local_objects_fn(dl: DiagLayer) -> Iterable[TNamed]:
if dl.diag_layer_raw.diag_data_dictionary_spec is None:
return []
return include(dl.diag_layer_raw.diag_data_dictionary_spec)
Expand All @@ -479,40 +479,40 @@ def get_local_objects_fn(dl: "DiagLayer"):

def _compute_available_functional_classes(self) -> Iterable[FunctionalClass]:

def get_local_objects_fn(dl):
def get_local_objects_fn(dl: DiagLayer) -> Iterable[FunctionalClass]:
return dl.diag_layer_raw.functional_classes

def not_inherited_fn(parent_ref):
def not_inherited_fn(parent_ref: ParentRef) -> List[str]:
return []

return self._compute_available_objects(get_local_objects_fn, not_inherited_fn)

def _compute_available_additional_audiences(self) -> Iterable[AdditionalAudience]:

def get_local_objects_fn(dl):
def get_local_objects_fn(dl: DiagLayer) -> Iterable[AdditionalAudience]:
return dl.diag_layer_raw.additional_audiences

def not_inherited_fn(parent_ref):
def not_inherited_fn(parent_ref: ParentRef) -> List[str]:
return []

return self._compute_available_objects(get_local_objects_fn, not_inherited_fn)

def _compute_available_state_charts(self) -> Iterable[StateChart]:

def get_local_objects_fn(dl):
def get_local_objects_fn(dl: DiagLayer) -> Iterable[StateChart]:
return dl.diag_layer_raw.state_charts

def not_inherited_fn(parent_ref):
def not_inherited_fn(parent_ref: ParentRef) -> List[str]:
return []

return self._compute_available_objects(get_local_objects_fn, not_inherited_fn)

def _compute_available_unit_groups(self) -> Iterable[UnitGroup]:

def get_local_objects_fn(dl):
def get_local_objects_fn(dl: DiagLayer) -> Iterable[UnitGroup]:
return dl._get_local_unit_groups()

def not_inherited_fn(parent_ref):
def not_inherited_fn(parent_ref: ParentRef) -> List[str]:
return []

return self._compute_available_objects(get_local_objects_fn, not_inherited_fn)
Expand Down
8 changes: 4 additions & 4 deletions odxtools/diaglayercontainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: MIT
from dataclasses import dataclass
from itertools import chain
from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Union
from xml.etree import ElementTree

from .admindata import AdminData
Expand All @@ -12,7 +12,7 @@
from .element import IdentifiableElement
from .exceptions import odxrequire
from .nameditemlist import NamedItemList
from .odxlink import OdxDocFragment, OdxLinkDatabase
from .odxlink import OdxDocFragment, OdxLinkDatabase, OdxLinkId
from .specialdatagroup import SpecialDataGroup
from .utils import dataclass_fields_asdict

Expand Down Expand Up @@ -82,7 +82,7 @@ def from_et(et_element: ElementTree.Element,
sdgs=sdgs,
**kwargs)

def _build_odxlinks(self):
def _build_odxlinks(self) -> Dict[OdxLinkId, Any]:
result = {self.odx_id: self}

if self.admin_data is not None:
Expand Down Expand Up @@ -137,7 +137,7 @@ def _finalize_init(self, odxlinks: OdxLinkDatabase) -> None:
ecu_variant._finalize_init(odxlinks)

@property
def diag_layers(self):
def diag_layers(self) -> NamedItemList[DiagLayer]:
return self._diag_layers

def __getitem__(self, key: Union[int, str]) -> DiagLayer:
Expand Down
37 changes: 25 additions & 12 deletions odxtools/diagservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@
from .createsdgs import create_sdgs_from_et
from .element import IdentifiableElement
from .exceptions import DecodeError, odxassert, odxrequire
from .functionalclass import FunctionalClass
from .message import Message
from .nameditemlist import NamedItemList
from .odxlink import OdxDocFragment, OdxLinkDatabase, OdxLinkId, OdxLinkRef
from .odxtypes import ParameterValue
from .parameters.parameter import Parameter
from .request import Request
from .response import Response
from .specialdatagroup import SpecialDataGroup
from .state import State
from .statetransition import StateTransition
from .utils import dataclass_fields_asdict

if TYPE_CHECKING:
Expand Down Expand Up @@ -119,15 +123,15 @@ def negative_responses(self) -> NamedItemList[Response]:
return self._negative_responses

@property
def functional_classes(self):
def functional_classes(self) -> NamedItemList[FunctionalClass]:
return self._functional_classes

@property
def pre_condition_states(self):
def pre_condition_states(self) -> NamedItemList[State]:
return self._pre_condition_states

@property
def state_transitions(self):
def state_transitions(self) -> NamedItemList[StateTransition]:
return self._state_transitions

def _build_odxlinks(self) -> Dict[OdxLinkId, Any]:
Expand All @@ -148,11 +152,11 @@ def _resolve_odxlinks(self, odxlinks: OdxLinkDatabase) -> None:
[odxlinks.resolve(x, Response) for x in self.neg_response_refs])

self._functional_classes = NamedItemList(
[odxlinks.resolve(fc_id) for fc_id in self.functional_class_refs])
[odxlinks.resolve(fc_ref, FunctionalClass) for fc_ref in self.functional_class_refs])
self._pre_condition_states = NamedItemList(
[odxlinks.resolve(st_id) for st_id in self.pre_condition_state_refs])
[odxlinks.resolve(st_ref, State) for st_ref in self.pre_condition_state_refs])
self._state_transitions = NamedItemList(
[odxlinks.resolve(stt_id) for stt_id in self.state_transition_refs])
[odxlinks.resolve(stt_ref, StateTransition) for stt_ref in self.state_transition_refs])

if self.admin_data:
self.admin_data._resolve_odxlinks(odxlinks)
Expand Down Expand Up @@ -203,7 +207,7 @@ def decode_message(self, raw_message: bytes) -> Message:
coding_object=coding_object,
param_dict=param_dict)

def encode_request(self, **params):
def encode_request(self, **params: ParameterValue) -> bytes:
"""
Composes an UDS request as list of bytes for this service.
Parameters:
Expand All @@ -214,6 +218,9 @@ def encode_request(self, **params):
# make sure that all parameters which are required for
# encoding are specified (parameters which have a default are
# optional)
if self.request is None:
return b''

missing_params = {x.short_name
for x in self.request.required_parameters}.difference(params.keys())
odxassert(not missing_params, f"The parameters {missing_params} are required but missing!")
Expand All @@ -224,21 +231,27 @@ def encode_request(self, **params):
set(params.keys()).issubset(rq_all_param_names),
f"Unknown parameters specified for encoding: {params.keys()}, "
f"known parameters are: {rq_all_param_names}")
return self.request.encode(**params)
return self.request.encode(coded_request=None, **params)

def encode_positive_response(self, coded_request, response_index=0, **params):
def encode_positive_response(self,
coded_request: bytes,
response_index: int = 0,
**params: ParameterValue) -> bytes:
# TODO: Should the user decide the positive response or what are the differences?
return self.positive_responses[response_index].encode(coded_request, **params)

def encode_negative_response(self, coded_request, response_index=0, **params):
def encode_negative_response(self,
coded_request: bytes,
response_index: int = 0,
**params: ParameterValue) -> bytes:
return self.negative_responses[response_index].encode(coded_request, **params)

def __call__(self, **params) -> bytes:
def __call__(self, **params: ParameterValue) -> bytes:
"""Encode a request."""
return self.encode_request(**params)

def __hash__(self) -> int:
return hash(self.odx_id)

def __eq__(self, o: object) -> bool:
def __eq__(self, o: Any) -> bool:
return isinstance(o, DiagService) and self.odx_id == o.odx_id
6 changes: 4 additions & 2 deletions odxtools/dynamiclengthfield.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: MIT
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from xml.etree import ElementTree

from .decodestate import DecodeState
Expand Down Expand Up @@ -55,5 +55,7 @@ def convert_physical_to_bytes(
) -> bytes:
raise NotImplementedError()

def convert_bytes_to_physical(self, decode_state: DecodeState, bit_position: int = 0):
def convert_bytes_to_physical(self,
decode_state: DecodeState,
bit_position: int = 0) -> Tuple[ParameterValue, int]:
raise NotImplementedError()
6 changes: 4 additions & 2 deletions odxtools/endofpdufield.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: MIT
from copy import copy
from dataclasses import dataclass
from typing import List, Optional
from typing import List, Optional, Tuple
from xml.etree import ElementTree

from .decodestate import DecodeState
Expand Down Expand Up @@ -59,7 +59,9 @@ def convert_physical_to_bytes(
coded_message += self.structure.convert_physical_to_bytes(value, encode_state)
return coded_message

def convert_bytes_to_physical(self, decode_state: DecodeState, bit_position: int = 0):
def convert_bytes_to_physical(self,
decode_state: DecodeState,
bit_position: int = 0) -> Tuple[ParameterValue, int]:
decode_state = copy(decode_state)
cursor_position = decode_state.cursor_position
byte_code = decode_state.coded_message
Expand Down
Loading