Skip to content

Commit

Permalink
Add types to parameter.py. (#1246)
Browse files Browse the repository at this point in the history
* Add types to parameter.py

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* Revert "Add types to TypeHash and moved away from __slots__ usage (#1232)" (#1243)

This reverts commit b06baef.

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* Add back Type hash __slots__ and add test cases. (#1245)

* Add types to TypeHash and add test cases

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* Add types to context.py (#1240)

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* Add types to qos_overriding_options.py (#1248)

Signed-off-by: Michael Carlstrom <rmc170@case.edu>
Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* Small fixes for modern flake8. (#1264)

It doesn't like to compare types with ==, so switch to
isinstance as appropriate.

Signed-off-by: Chris Lalancette <clalancette@gmail.com>
Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* Add types to time_source.py (#1259)

* Add types to time_source.py

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* fix small bug

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* Switch to overloads to avoid mypy3737

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* Update parameter declaration from Node

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* add back rclpy.

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* Fix flake8 imports

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* Add proper array.array[]

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* Update types of declare_parameter, declare_parameters api

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* Add non conflicting types back to constructor

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* Move sys import

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* Update error message

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* Add default value for generic Parameter

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* Add explanation comment

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* Add TypeVar import inside else case

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* push to rerun ci

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* push to rerun ci

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

* Switch back to union

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>

---------

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>
Signed-off-by: Michael Carlstrom <rmc170@case.edu>
Signed-off-by: Chris Lalancette <clalancette@gmail.com>
Co-authored-by: Chris Lalancette <clalancette@gmail.com>
Co-authored-by: Shane Loretz <sloretz@intrinsic.ai>
  • Loading branch information
3 people authored Aug 2, 2024
1 parent adfcb2b commit db98d90
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 44 deletions.
57 changes: 37 additions & 20 deletions rclpy/rclpy/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@
import time

from types import TracebackType
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
from typing import overload
from typing import Tuple
from typing import Type
from typing import TypeVar
Expand Down Expand Up @@ -67,7 +66,8 @@
from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy
from rclpy.logging import get_logger
from rclpy.logging_service import LoggingService
from rclpy.parameter import Parameter, PARAMETER_SEPARATOR_STRING
from rclpy.parameter import (AllowableParameterValue, AllowableParameterValueT, Parameter,
PARAMETER_SEPARATOR_STRING)
from rclpy.parameter_service import ParameterService
from rclpy.publisher import Publisher
from rclpy.qos import qos_profile_parameter_events
Expand Down Expand Up @@ -164,7 +164,7 @@ def __init__(
"""
self.__handle = None
self._context = get_default_context() if context is None else context
self._parameters: dict = {}
self._parameters: Dict[str, Parameter] = {}
self._publishers: List[Publisher] = []
self._subscriptions: List[Subscription] = []
self._clients: List[Client] = []
Expand All @@ -179,8 +179,8 @@ def __init__(
self._post_set_parameters_callbacks: List[Callable[[List[Parameter]], None]] = []
self._rate_group = ReentrantCallbackGroup()
self._allow_undeclared_parameters = allow_undeclared_parameters
self._parameter_overrides = {}
self._descriptors = {}
self._parameter_overrides: Dict[str, Parameter] = {}
self._descriptors: Dict[str, ParameterDescriptor] = {}

namespace = namespace or ''
if not self._context.ok():
Expand Down Expand Up @@ -358,10 +358,23 @@ def get_logger(self):
"""Get the nodes logger."""
return self._logger

@overload
def declare_parameter(self, name: str, value: Union[AllowableParameterValueT,
Parameter.Type, ParameterValue],
descriptor: Optional[ParameterDescriptor] = None,
ignore_override: bool = False
) -> Parameter[AllowableParameterValueT]: ...

@overload
def declare_parameter(self, name: str,
value: Union[None, Parameter.Type, ParameterValue] = None,
descriptor: Optional[ParameterDescriptor] = None,
ignore_override: bool = False) -> Parameter[None]: ...

def declare_parameter(
self,
name: str,
value: Any = None,
value: Union[AllowableParameterValue, Parameter.Type, ParameterValue] = None,
descriptor: Optional[ParameterDescriptor] = None,
ignore_override: bool = False
) -> Parameter:
Expand All @@ -386,7 +399,9 @@ def declare_parameter(
"""
if value is None and descriptor is None:
# Temporal patch so we get deprecation warning if only a name is provided.
args = (name, )
args: Union[Tuple[str], Tuple[str, Union[AllowableParameterValue,
Parameter.Type, ParameterValue],
ParameterDescriptor]] = (name, )
else:
descriptor = ParameterDescriptor() if descriptor is None else descriptor
args = (name, value, descriptor)
Expand All @@ -398,7 +413,8 @@ def declare_parameters(
parameters: List[Union[
Tuple[str],
Tuple[str, Parameter.Type],
Tuple[str, Any, ParameterDescriptor],
Tuple[str, Union[AllowableParameterValue, Parameter.Type, ParameterValue],
ParameterDescriptor],
]],
ignore_override: bool = False
) -> List[Parameter]:
Expand Down Expand Up @@ -448,8 +464,8 @@ def declare_parameters(
:raises: InvalidParameterValueException if the registered callback rejects any parameter.
:raises: TypeError if any tuple in **parameters** does not match the annotated type.
"""
parameter_list = []
descriptors = {}
parameter_list: List[Parameter] = []
descriptors: Dict[str, ParameterDescriptor] = {}
for index, parameter_tuple in enumerate(parameters):
if len(parameter_tuple) < 1 or len(parameter_tuple) > 3:
raise TypeError(
Expand All @@ -473,9 +489,8 @@ def declare_parameters(
# Note(jubeira): declare_parameters verifies the name, but set_parameters doesn't.
validate_parameter_name(name)

second_arg = parameter_tuple[1] if 1 < len(parameter_tuple) else None
descriptor = parameter_tuple[2] if 2 < len(parameter_tuple) else ParameterDescriptor()

second_arg = parameter_tuple[1] if len(parameter_tuple) > 1 else None
descriptor = parameter_tuple[2] if len(parameter_tuple) > 2 else ParameterDescriptor()
if not isinstance(descriptor, ParameterDescriptor):
raise TypeError(
f'Third element {descriptor} at index {index} in parameters list '
Expand Down Expand Up @@ -519,6 +534,11 @@ def declare_parameters(
if not ignore_override and name in self._parameter_overrides:
value = self._parameter_overrides[name].value

if isinstance(value, ParameterValue):
raise ValueError('Cannot declare a Parameter from a ParameterValue without it '
'being included in self._parameter_overrides, and ',
'ignore_override=False')

parameter_list.append(Parameter(name, value=value))
descriptors.update({name: descriptor})

Expand Down Expand Up @@ -719,10 +739,7 @@ def get_parameter_or(

return self._parameters[name]

def get_parameters_by_prefix(self, prefix: str) -> Dict[str, Optional[Union[
bool, int, float, str, bytes,
Sequence[bool], Sequence[int], Sequence[float], Sequence[str]
]]]:
def get_parameters_by_prefix(self, prefix: str) -> Dict[str, Parameter]:
"""
Get parameters that have a given prefix in their names as a dictionary.
Expand Down Expand Up @@ -1039,7 +1056,7 @@ def _check_undeclared_parameters(self, parameter_list: List[Parameter]):
if not self._allow_undeclared_parameters and any(undeclared_parameters):
raise ParameterNotDeclaredException(list(undeclared_parameters))

def _call_pre_set_parameters_callback(self, parameter_list: [List[Parameter]]):
def _call_pre_set_parameters_callback(self, parameter_list: List[Parameter]):
if self._pre_set_parameters_callbacks:
modified_parameter_list = []
for callback in self._pre_set_parameters_callbacks:
Expand All @@ -1049,7 +1066,7 @@ def _call_pre_set_parameters_callback(self, parameter_list: [List[Parameter]]):
else:
return None

def _call_post_set_parameters_callback(self, parameter_list: [List[Parameter]]):
def _call_post_set_parameters_callback(self, parameter_list: List[Parameter]):
if self._post_set_parameters_callbacks:
for callback in self._post_set_parameters_callbacks:
callback(parameter_list)
Expand Down
83 changes: 59 additions & 24 deletions rclpy/rclpy/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@
# limitations under the License.

import array
from enum import Enum
from enum import IntEnum
import sys
from typing import Any
from typing import Dict
from typing import Generic
from typing import List
from typing import Optional
from typing import overload
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
Expand All @@ -29,17 +33,38 @@
PARAMETER_SEPARATOR_STRING = '.'

if TYPE_CHECKING:
AllowableParameterValue = Union[None, bool, int, float, str,
List[bytes], Tuple[bytes, ...],
List[bool], Tuple[bool, ...],
List[int], Tuple[int, ...], array.array[int],
List[float], Tuple[float, ...], array.array[float],
List[str], Tuple[str, ...], array.array[str]]


class Parameter:

class Type(Enum):
from typing_extensions import TypeVar
# Mypy does not handle string literals of array.array[int/str/float] very well
# So if user has newer version of python can use proper array types.
if sys.version_info > (3, 9):
AllowableParameterValue = Union[None, bool, int, float, str,
list[bytes], Tuple[bytes, ...],
list[bool], Tuple[bool, ...],
list[int], Tuple[int, ...], array.array[int],
list[float], Tuple[float, ...], array.array[float],
list[str], Tuple[str, ...], array.array[str]]
else:
AllowableParameterValue = Union[None, bool, int, float, str,
List[bytes], Tuple[bytes, ...],
List[bool], Tuple[bool, ...],
List[int], Tuple[int, ...], 'array.array[int]',
List[float], Tuple[float, ...], 'array.array[float]',
List[str], Tuple[str, ...], 'array.array[str]']

AllowableParameterValueT = TypeVar('AllowableParameterValueT',
bound=AllowableParameterValue,
default=AllowableParameterValue)
else:
from typing import TypeVar
# Done to prevent runtime errors of undefined values.
# after python3.13 is minimum support this could be removed.
AllowableParameterValue = Any
AllowableParameterValueT = TypeVar('AllowableParameterValueT')


class Parameter(Generic[AllowableParameterValueT]):

class Type(IntEnum):
NOT_SET = ParameterType.PARAMETER_NOT_SET
BOOL = ParameterType.PARAMETER_BOOL
INTEGER = ParameterType.PARAMETER_INTEGER
Expand All @@ -52,7 +77,9 @@ class Type(Enum):
STRING_ARRAY = ParameterType.PARAMETER_STRING_ARRAY

@classmethod
def from_parameter_value(cls, parameter_value):
def from_parameter_value(cls,
parameter_value: AllowableParameterValueT
) -> 'Parameter.Type':
"""
Get a Parameter.Type from a given variable.
Expand Down Expand Up @@ -88,7 +115,7 @@ def from_parameter_value(cls, parameter_value):
raise TypeError(
f"The given value is not one of the allowed types '{parameter_value}'.")

def check(self, parameter_value):
def check(self, parameter_value: AllowableParameterValueT) -> bool:
if Parameter.Type.NOT_SET == self:
return parameter_value is None
if Parameter.Type.BOOL == self:
Expand Down Expand Up @@ -117,7 +144,7 @@ def check(self, parameter_value):
return False

@classmethod
def from_parameter_msg(cls, param_msg):
def from_parameter_msg(cls, param_msg: ParameterMsg) -> 'Parameter[AllowableParameterValueT]':
value = None
type_ = Parameter.Type(value=param_msg.value.type)
if Parameter.Type.BOOL == type_:
Expand All @@ -140,7 +167,14 @@ def from_parameter_msg(cls, param_msg):
value = param_msg.value.string_array_value
return cls(param_msg.name, type_, value)

def __init__(self, name, type_=None, value=None):
@overload
def __init__(self, name: str, type_: Optional['Parameter.Type'] = None) -> None: ...

@overload
def __init__(self, name: str, type_: Optional['Parameter.Type'],
value: AllowableParameterValueT) -> None: ...

def __init__(self, name: str, type_: Optional['Parameter.Type'] = None, value=None) -> None:
if type_ is None:
# This will raise a TypeError if it is not possible to get a type from the value.
type_ = Parameter.Type.from_parameter_value(value)
Expand All @@ -156,18 +190,18 @@ def __init__(self, name, type_=None, value=None):
self._value = value

@property
def name(self):
def name(self) -> str:
return self._name

@property
def type_(self):
def type_(self) -> 'Parameter.Type':
return self._type_

@property
def value(self):
def value(self) -> AllowableParameterValueT:
return self._value

def get_parameter_value(self):
def get_parameter_value(self) -> ParameterValue:
parameter_value = ParameterValue(type=self.type_.value)
if Parameter.Type.BOOL == self.type_:
parameter_value.bool_value = self.value
Expand All @@ -189,7 +223,7 @@ def get_parameter_value(self):
parameter_value.string_array_value = self.value
return parameter_value

def to_parameter_msg(self):
def to_parameter_msg(self) -> ParameterMsg:
return ParameterMsg(name=self.name, value=self.get_parameter_value())


Expand Down Expand Up @@ -237,7 +271,7 @@ def get_parameter_value(string_value: str) -> ParameterValue:
return value


def parameter_value_to_python(parameter_value: ParameterValue):
def parameter_value_to_python(parameter_value: ParameterValue) -> AllowableParameterValue:
"""
Get the value for the Python builtin type from a rcl_interfaces/msg/ParameterValue object.
Expand Down Expand Up @@ -295,7 +329,7 @@ def parameter_dict_from_yaml_file(
"""
with open(parameter_file, 'r') as f:
param_file = yaml.safe_load(f)
param_keys = []
param_keys: List[str] = []
param_dict = {}

if use_wildcard and '/**' in param_file:
Expand Down Expand Up @@ -325,7 +359,8 @@ def parameter_dict_from_yaml_file(
return _unpack_parameter_dict(namespace, param_dict)


def _unpack_parameter_dict(namespace, parameter_dict):
def _unpack_parameter_dict(namespace: str,
parameter_dict: Dict[str, ParameterMsg]) -> Dict[str, ParameterMsg]:
"""
Flatten a parameter dictionary recursively.
Expand Down

0 comments on commit db98d90

Please sign in to comment.