diff --git a/rclpy/rclpy/node.py b/rclpy/rclpy/node.py index 555085637..01353b964 100644 --- a/rclpy/rclpy/node.py +++ b/rclpy/rclpy/node.py @@ -16,13 +16,12 @@ import time from types import TracebackType -from typing import Any from typing import Callable from typing import Dict from typing import Iterator from typing import List from typing import Optional -from typing import Sequence +from typing import overload from typing import Tuple from typing import Type from typing import TypeVar @@ -67,7 +66,8 @@ from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy from rclpy.logging import get_logger from rclpy.logging_service import LoggingService -from rclpy.parameter import Parameter, PARAMETER_SEPARATOR_STRING +from rclpy.parameter import (AllowableParameterValue, AllowableParameterValueT, Parameter, + PARAMETER_SEPARATOR_STRING) from rclpy.parameter_service import ParameterService from rclpy.publisher import Publisher from rclpy.qos import qos_profile_parameter_events @@ -164,7 +164,7 @@ def __init__( """ self.__handle = None self._context = get_default_context() if context is None else context - self._parameters: dict = {} + self._parameters: Dict[str, Parameter] = {} self._publishers: List[Publisher] = [] self._subscriptions: List[Subscription] = [] self._clients: List[Client] = [] @@ -179,8 +179,8 @@ def __init__( self._post_set_parameters_callbacks: List[Callable[[List[Parameter]], None]] = [] self._rate_group = ReentrantCallbackGroup() self._allow_undeclared_parameters = allow_undeclared_parameters - self._parameter_overrides = {} - self._descriptors = {} + self._parameter_overrides: Dict[str, Parameter] = {} + self._descriptors: Dict[str, ParameterDescriptor] = {} namespace = namespace or '' if not self._context.ok(): @@ -358,10 +358,23 @@ def get_logger(self): """Get the nodes logger.""" return self._logger + @overload + def declare_parameter(self, name: str, value: Union[AllowableParameterValueT, + Parameter.Type, ParameterValue], + descriptor: Optional[ParameterDescriptor] = None, + ignore_override: bool = False + ) -> Parameter[AllowableParameterValueT]: ... + + @overload + def declare_parameter(self, name: str, + value: Union[None, Parameter.Type, ParameterValue] = None, + descriptor: Optional[ParameterDescriptor] = None, + ignore_override: bool = False) -> Parameter[None]: ... + def declare_parameter( self, name: str, - value: Any = None, + value: Union[AllowableParameterValue, Parameter.Type, ParameterValue] = None, descriptor: Optional[ParameterDescriptor] = None, ignore_override: bool = False ) -> Parameter: @@ -386,7 +399,9 @@ def declare_parameter( """ if value is None and descriptor is None: # Temporal patch so we get deprecation warning if only a name is provided. - args = (name, ) + args: Union[Tuple[str], Tuple[str, Union[AllowableParameterValue, + Parameter.Type, ParameterValue], + ParameterDescriptor]] = (name, ) else: descriptor = ParameterDescriptor() if descriptor is None else descriptor args = (name, value, descriptor) @@ -398,7 +413,8 @@ def declare_parameters( parameters: List[Union[ Tuple[str], Tuple[str, Parameter.Type], - Tuple[str, Any, ParameterDescriptor], + Tuple[str, Union[AllowableParameterValue, Parameter.Type, ParameterValue], + ParameterDescriptor], ]], ignore_override: bool = False ) -> List[Parameter]: @@ -448,8 +464,8 @@ def declare_parameters( :raises: InvalidParameterValueException if the registered callback rejects any parameter. :raises: TypeError if any tuple in **parameters** does not match the annotated type. """ - parameter_list = [] - descriptors = {} + parameter_list: List[Parameter] = [] + descriptors: Dict[str, ParameterDescriptor] = {} for index, parameter_tuple in enumerate(parameters): if len(parameter_tuple) < 1 or len(parameter_tuple) > 3: raise TypeError( @@ -473,9 +489,8 @@ def declare_parameters( # Note(jubeira): declare_parameters verifies the name, but set_parameters doesn't. validate_parameter_name(name) - second_arg = parameter_tuple[1] if 1 < len(parameter_tuple) else None - descriptor = parameter_tuple[2] if 2 < len(parameter_tuple) else ParameterDescriptor() - + second_arg = parameter_tuple[1] if len(parameter_tuple) > 1 else None + descriptor = parameter_tuple[2] if len(parameter_tuple) > 2 else ParameterDescriptor() if not isinstance(descriptor, ParameterDescriptor): raise TypeError( f'Third element {descriptor} at index {index} in parameters list ' @@ -519,6 +534,11 @@ def declare_parameters( if not ignore_override and name in self._parameter_overrides: value = self._parameter_overrides[name].value + if isinstance(value, ParameterValue): + raise ValueError('Cannot declare a Parameter from a ParameterValue without it ' + 'being included in self._parameter_overrides, and ', + 'ignore_override=False') + parameter_list.append(Parameter(name, value=value)) descriptors.update({name: descriptor}) @@ -719,10 +739,7 @@ def get_parameter_or( return self._parameters[name] - def get_parameters_by_prefix(self, prefix: str) -> Dict[str, Optional[Union[ - bool, int, float, str, bytes, - Sequence[bool], Sequence[int], Sequence[float], Sequence[str] - ]]]: + def get_parameters_by_prefix(self, prefix: str) -> Dict[str, Parameter]: """ Get parameters that have a given prefix in their names as a dictionary. @@ -1039,7 +1056,7 @@ def _check_undeclared_parameters(self, parameter_list: List[Parameter]): if not self._allow_undeclared_parameters and any(undeclared_parameters): raise ParameterNotDeclaredException(list(undeclared_parameters)) - def _call_pre_set_parameters_callback(self, parameter_list: [List[Parameter]]): + def _call_pre_set_parameters_callback(self, parameter_list: List[Parameter]): if self._pre_set_parameters_callbacks: modified_parameter_list = [] for callback in self._pre_set_parameters_callbacks: @@ -1049,7 +1066,7 @@ def _call_pre_set_parameters_callback(self, parameter_list: [List[Parameter]]): else: return None - def _call_post_set_parameters_callback(self, parameter_list: [List[Parameter]]): + def _call_post_set_parameters_callback(self, parameter_list: List[Parameter]): if self._post_set_parameters_callbacks: for callback in self._post_set_parameters_callbacks: callback(parameter_list) diff --git a/rclpy/rclpy/parameter.py b/rclpy/rclpy/parameter.py index cd2cc820f..464ee7065 100644 --- a/rclpy/rclpy/parameter.py +++ b/rclpy/rclpy/parameter.py @@ -13,10 +13,14 @@ # limitations under the License. import array -from enum import Enum +from enum import IntEnum +import sys +from typing import Any from typing import Dict +from typing import Generic from typing import List from typing import Optional +from typing import overload from typing import Tuple from typing import TYPE_CHECKING from typing import Union @@ -29,17 +33,38 @@ PARAMETER_SEPARATOR_STRING = '.' if TYPE_CHECKING: - AllowableParameterValue = Union[None, bool, int, float, str, - List[bytes], Tuple[bytes, ...], - List[bool], Tuple[bool, ...], - List[int], Tuple[int, ...], array.array[int], - List[float], Tuple[float, ...], array.array[float], - List[str], Tuple[str, ...], array.array[str]] - - -class Parameter: - - class Type(Enum): + from typing_extensions import TypeVar + # Mypy does not handle string literals of array.array[int/str/float] very well + # So if user has newer version of python can use proper array types. + if sys.version_info > (3, 9): + AllowableParameterValue = Union[None, bool, int, float, str, + list[bytes], Tuple[bytes, ...], + list[bool], Tuple[bool, ...], + list[int], Tuple[int, ...], array.array[int], + list[float], Tuple[float, ...], array.array[float], + list[str], Tuple[str, ...], array.array[str]] + else: + AllowableParameterValue = Union[None, bool, int, float, str, + List[bytes], Tuple[bytes, ...], + List[bool], Tuple[bool, ...], + List[int], Tuple[int, ...], 'array.array[int]', + List[float], Tuple[float, ...], 'array.array[float]', + List[str], Tuple[str, ...], 'array.array[str]'] + + AllowableParameterValueT = TypeVar('AllowableParameterValueT', + bound=AllowableParameterValue, + default=AllowableParameterValue) +else: + from typing import TypeVar + # Done to prevent runtime errors of undefined values. + # after python3.13 is minimum support this could be removed. + AllowableParameterValue = Any + AllowableParameterValueT = TypeVar('AllowableParameterValueT') + + +class Parameter(Generic[AllowableParameterValueT]): + + class Type(IntEnum): NOT_SET = ParameterType.PARAMETER_NOT_SET BOOL = ParameterType.PARAMETER_BOOL INTEGER = ParameterType.PARAMETER_INTEGER @@ -52,7 +77,9 @@ class Type(Enum): STRING_ARRAY = ParameterType.PARAMETER_STRING_ARRAY @classmethod - def from_parameter_value(cls, parameter_value): + def from_parameter_value(cls, + parameter_value: AllowableParameterValueT + ) -> 'Parameter.Type': """ Get a Parameter.Type from a given variable. @@ -88,7 +115,7 @@ def from_parameter_value(cls, parameter_value): raise TypeError( f"The given value is not one of the allowed types '{parameter_value}'.") - def check(self, parameter_value): + def check(self, parameter_value: AllowableParameterValueT) -> bool: if Parameter.Type.NOT_SET == self: return parameter_value is None if Parameter.Type.BOOL == self: @@ -117,7 +144,7 @@ def check(self, parameter_value): return False @classmethod - def from_parameter_msg(cls, param_msg): + def from_parameter_msg(cls, param_msg: ParameterMsg) -> 'Parameter[AllowableParameterValueT]': value = None type_ = Parameter.Type(value=param_msg.value.type) if Parameter.Type.BOOL == type_: @@ -140,7 +167,14 @@ def from_parameter_msg(cls, param_msg): value = param_msg.value.string_array_value return cls(param_msg.name, type_, value) - def __init__(self, name, type_=None, value=None): + @overload + def __init__(self, name: str, type_: Optional['Parameter.Type'] = None) -> None: ... + + @overload + def __init__(self, name: str, type_: Optional['Parameter.Type'], + value: AllowableParameterValueT) -> None: ... + + def __init__(self, name: str, type_: Optional['Parameter.Type'] = None, value=None) -> None: if type_ is None: # This will raise a TypeError if it is not possible to get a type from the value. type_ = Parameter.Type.from_parameter_value(value) @@ -156,18 +190,18 @@ def __init__(self, name, type_=None, value=None): self._value = value @property - def name(self): + def name(self) -> str: return self._name @property - def type_(self): + def type_(self) -> 'Parameter.Type': return self._type_ @property - def value(self): + def value(self) -> AllowableParameterValueT: return self._value - def get_parameter_value(self): + def get_parameter_value(self) -> ParameterValue: parameter_value = ParameterValue(type=self.type_.value) if Parameter.Type.BOOL == self.type_: parameter_value.bool_value = self.value @@ -189,7 +223,7 @@ def get_parameter_value(self): parameter_value.string_array_value = self.value return parameter_value - def to_parameter_msg(self): + def to_parameter_msg(self) -> ParameterMsg: return ParameterMsg(name=self.name, value=self.get_parameter_value()) @@ -237,7 +271,7 @@ def get_parameter_value(string_value: str) -> ParameterValue: return value -def parameter_value_to_python(parameter_value: ParameterValue): +def parameter_value_to_python(parameter_value: ParameterValue) -> AllowableParameterValue: """ Get the value for the Python builtin type from a rcl_interfaces/msg/ParameterValue object. @@ -295,7 +329,7 @@ def parameter_dict_from_yaml_file( """ with open(parameter_file, 'r') as f: param_file = yaml.safe_load(f) - param_keys = [] + param_keys: List[str] = [] param_dict = {} if use_wildcard and '/**' in param_file: @@ -325,7 +359,8 @@ def parameter_dict_from_yaml_file( return _unpack_parameter_dict(namespace, param_dict) -def _unpack_parameter_dict(namespace, parameter_dict): +def _unpack_parameter_dict(namespace: str, + parameter_dict: Dict[str, ParameterMsg]) -> Dict[str, ParameterMsg]: """ Flatten a parameter dictionary recursively.