diff --git a/rclpy/rclpy/callback_groups.py b/rclpy/rclpy/callback_groups.py index f0d964a56..bee08d611 100644 --- a/rclpy/rclpy/callback_groups.py +++ b/rclpy/rclpy/callback_groups.py @@ -24,8 +24,7 @@ from rclpy.service import Service from rclpy.waitable import Waitable from rclpy.guard_condition import GuardCondition - Entity = Union[Subscription[Any], Timer, Client[Any, Any], Service[Any, Any], - GuardCondition, Waitable[Any]] + Entity = Union[Subscription, Timer, Client, Service, Waitable[Any], GuardCondition] class CallbackGroup: diff --git a/rclpy/rclpy/executors.py b/rclpy/rclpy/executors.py index 8bf5a1a3b..17169063d 100644 --- a/rclpy/rclpy/executors.py +++ b/rclpy/rclpy/executors.py @@ -24,10 +24,8 @@ from types import TracebackType from typing import Any from typing import Callable -from typing import cast from typing import ContextManager from typing import Coroutine -from typing import Dict from typing import Generator from typing import List from typing import Optional @@ -60,48 +58,33 @@ # For documentation purposes # TODO(jacobperron): Make all entities implement the 'Waitable' interface for better type checking - -T = TypeVar('T') +WaitableEntityType = TypeVar('WaitableEntityType') # Avoid import cycle if TYPE_CHECKING: - from typing import TypeAlias - from rclpy.node import Node # noqa: F401 - from .callback_groups import Entity - EntityT = TypeVar('EntityT', bound=Entity) - - -FunctionOrCoroutineFunction: 'TypeAlias' = Union[Callable[..., T], - Callable[..., Coroutine[None, None, T]]] - - -YieldedCallback: 'TypeAlias' = Generator[Tuple[Task[None], - 'Optional[Entity]', - 'Optional[Node]'], None, None] class _WorkTracker: """Track the amount of work that is in progress.""" - def __init__(self) -> None: + def __init__(self): # Number of tasks that are being executed self._num_work_executing = 0 self._work_condition = Condition() - def __enter__(self) -> None: + def __enter__(self): """Increment the amount of executing work by 1.""" with self._work_condition: self._num_work_executing += 1 - def __exit__(self, exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], exctb: Optional[TracebackType]) -> None: + def __exit__(self, t, v, tb): """Decrement the amount of work executing by 1.""" with self._work_condition: self._num_work_executing -= 1 self._work_condition.notify_all() - def wait(self, timeout_sec: Optional[float] = None) -> bool: + def wait(self, timeout_sec: Optional[float] = None): """ Wait until all work completes. @@ -119,14 +102,12 @@ def wait(self, timeout_sec: Optional[float] = None) -> bool: return True -async def await_or_execute(callback: FunctionOrCoroutineFunction[T], *args: Any) -> T: +async def await_or_execute(callback: Union[Callable, Coroutine], *args) -> Any: """Await a callback if it is a coroutine, else execute it.""" if inspect.iscoroutinefunction(callback): # Await a coroutine - callback = cast(Callable[..., Coroutine[None, None, T]], callback) return await callback(*args) else: - callback = cast(Callable[..., T], callback) # Call a normal function return callback(*args) @@ -158,15 +139,15 @@ class ConditionReachedException(Exception): class TimeoutObject: """Use timeout object to save timeout.""" - def __init__(self, timeout: float) -> None: + def __init__(self, timeout: float): self._timeout = timeout @property - def timeout(self) -> float: + def timeout(self): return self._timeout @timeout.setter - def timeout(self, timeout: float) -> None: + def timeout(self, timeout): self._timeout = timeout @@ -200,10 +181,10 @@ def __init__(self, *, context: Optional[Context] = None) -> None: self._nodes: Set[Node] = set() self._nodes_lock = RLock() # Tasks to be executed (oldest first) 3-tuple Task, Entity, Node - self._tasks: List[Tuple[Task[Any], 'Optional[Entity]', Optional[Node]]] = [] + self._tasks: List[Tuple[Task, Optional[WaitableEntityType], Optional[Node]]] = [] self._tasks_lock = Lock() # This is triggered when wait_for_ready_callbacks should rebuild the wait list - self._guard: Optional[GuardCondition] = GuardCondition( + self._guard = GuardCondition( callback=None, callback_group=None, context=self._context) # True if shutdown has been called self._is_shutdown = False @@ -211,13 +192,12 @@ def __init__(self, *, context: Optional[Context] = None) -> None: # Protect against shutdown() being called in parallel in two threads self._shutdown_lock = Lock() # State for wait_for_ready_callbacks to reuse generator - self._cb_iter: Optional[YieldedCallback] = None - self._last_args: Optional[tuple[object, ...]] = None - self._last_kwargs: Optional[Dict[str, object]] = None + self._cb_iter = None + self._last_args = None + self._last_kwargs = None # Executor cannot use ROS clock because that requires a node self._clock = Clock(clock_type=ClockType.STEADY_TIME) - self._sigint_gc: Optional[SignalHandlerGuardCondition] = \ - SignalHandlerGuardCondition(context) + self._sigint_gc = SignalHandlerGuardCondition(context) self._context.on_shutdown(self.wake) @property @@ -225,8 +205,7 @@ def context(self) -> Context: """Get the context associated with the executor.""" return self._context - def create_task(self, callback: FunctionOrCoroutineFunction[T], *args: Any, **kwargs: Any - ) -> Task[T]: + def create_task(self, callback: Union[Callable, Coroutine], *args, **kwargs) -> Task: """ Add a callback or coroutine to be executed during :meth:`spin` and return a Future. @@ -240,8 +219,7 @@ def create_task(self, callback: FunctionOrCoroutineFunction[T], *args: Any, **kw task = Task(callback, args, kwargs, executor=self) with self._tasks_lock: self._tasks.append((task, None, None)) - if self._guard: - self._guard.trigger() + self._guard.trigger() # Task inherits from Future return task @@ -258,8 +236,7 @@ def shutdown(self, timeout_sec: Optional[float] = None) -> bool: if not self._is_shutdown: self._is_shutdown = True # Tell executor it's been shut down - if self._guard: - self._guard.trigger() + self._guard.trigger() if not self._is_shutdown: if not self._work_tracker.wait(timeout_sec): return False @@ -280,7 +257,7 @@ def shutdown(self, timeout_sec: Optional[float] = None) -> bool: self._last_kwargs = None return True - def __del__(self) -> None: + def __del__(self): if self._sigint_gc is not None: self._sigint_gc.destroy() @@ -296,8 +273,7 @@ def add_node(self, node: 'Node') -> bool: self._nodes.add(node) node.executor = self # Rebuild the wait set so it includes this new node - if self._guard: - self._guard.trigger() + self._guard.trigger() return True return False @@ -314,8 +290,7 @@ def remove_node(self, node: 'Node') -> None: pass else: # Rebuild the wait set so it doesn't include this node - if self._guard: - self._guard.trigger() + self._guard.trigger() def wake(self) -> None: """ @@ -338,7 +313,7 @@ def spin(self) -> None: def spin_until_future_complete( self, - future: Future[Any], + future: Future, timeout_sec: Optional[float] = None ) -> None: """Execute callbacks until a given future is done or a timeout occurs.""" @@ -377,7 +352,7 @@ def spin_once(self, timeout_sec: Optional[float] = None) -> None: def spin_once_until_future_complete( self, - future: Future[Any], + future: Future, timeout_sec: Optional[Union[float, TimeoutObject]] = None ) -> None: """ @@ -392,7 +367,7 @@ def spin_once_until_future_complete( """ raise NotImplementedError() - def _take_timer(self, tmr: Timer) -> Optional[Callable[[], Coroutine[None, None, None]]]: + def _take_timer(self, tmr): try: with tmr.handle: info = tmr.handle.call_timer_with_info() @@ -401,9 +376,7 @@ def _take_timer(self, tmr: Timer) -> Optional[Callable[[], Coroutine[None, None, actual_call_time=info['actual_call_time'], clock_type=tmr.clock.clock_type) - def check_argument_type(callback_func: Union[Callable[[], None], - Callable[[TimerInfo], None]], - target_type: Type[TimerInfo]) -> Optional[str]: + def check_argument_type(callback_func, target_type): sig = inspect.signature(callback_func) for param in sig.parameters.values(): if param.annotation == target_type: @@ -414,19 +387,15 @@ def check_argument_type(callback_func: Union[Callable[[], None], # User might change the Timer.callback function signature at runtime, # so it needs to check the signature every time. - if tmr.callback: - arg_name = check_argument_type(tmr.callback, target_type=TimerInfo) + arg_name = check_argument_type(tmr.callback, target_type=TimerInfo) + prefilled_arg = {arg_name: timer_info} if arg_name is not None: - prefilled_arg = {arg_name: timer_info} - - async def _execute() -> None: - if tmr.callback: - await await_or_execute(partial(tmr.callback, **prefilled_arg)) + async def _execute(): + await await_or_execute(partial(tmr.callback, **prefilled_arg)) return _execute else: - async def _execute() -> None: - if tmr.callback: - await await_or_execute(tmr.callback) + async def _execute(): + await await_or_execute(tmr.callback) return _execute except InvalidHandle: # Timer is a Destroyable, which means that on __enter__ it can throw an @@ -437,8 +406,7 @@ async def _execute() -> None: return None - def _take_subscription(self, sub: Subscription[Any] - ) -> Optional[Callable[[], Coroutine[None, None, None]]]: + def _take_subscription(self, sub): try: with sub.handle: msg_info = sub.handle.take_message(sub.msg_type, sub.raw) @@ -450,7 +418,7 @@ def _take_subscription(self, sub: Subscription[Any] else: msg_tuple = msg_info - async def _execute() -> None: + async def _execute(): await await_or_execute(sub.callback, *msg_tuple) return _execute @@ -463,13 +431,12 @@ async def _execute() -> None: return None - def _take_client(self, client: Client[Any, Any] - ) -> Optional[Callable[[], Coroutine[None, None, None]]]: + def _take_client(self, client): try: with client.handle: header_and_response = client.handle.take_response(client.srv_type.Response) - async def _execute() -> None: + async def _execute(): header, response = header_and_response if header is None: return @@ -493,13 +460,12 @@ async def _execute() -> None: return None - def _take_service(self, srv: Service[Any, Any] - ) -> Optional[Callable[[], Coroutine[None, None, None]]]: + def _take_service(self, srv): try: with srv.handle: request_and_header = srv.handle.service_take_request(srv.srv_type.Request) - async def _execute() -> None: + async def _execute(): (request, header) = request_and_header if header is None: return @@ -516,19 +482,17 @@ async def _execute() -> None: return None - def _take_guard_condition(self, gc: GuardCondition - ) -> Callable[[], Coroutine[None, None, None]]: + def _take_guard_condition(self, gc): gc._executor_triggered = False - async def _execute() -> None: - if gc.callback: - await await_or_execute(gc.callback) + async def _execute(): + await await_or_execute(gc.callback) return _execute - def _take_waitable(self, waitable: Waitable[Any]) -> Callable[[], Coroutine[None, None, None]]: + def _take_waitable(self, waitable): data = waitable.take_data() - async def _execute() -> None: + async def _execute(): for future in waitable._futures: future._set_executor(self) await waitable.execute(data) @@ -536,11 +500,10 @@ async def _execute() -> None: def _make_handler( self, - entity: 'EntityT', + entity: WaitableEntityType, node: 'Node', - take_from_wait_list: Callable[['EntityT'], - Optional[Callable[[], Coroutine[None, None, None]]]], - ) -> Task[None]: + take_from_wait_list: Callable, + ) -> Task: """ Make a handler that performs work on an entity. @@ -551,10 +514,8 @@ def _make_handler( # Mark this so it doesn't get added back to the wait list entity._executor_event = True - async def handler(entity: 'EntityT', gc: GuardCondition, is_shutdown: bool, - work_tracker: _WorkTracker) -> None: - if is_shutdown or entity.callback_group is not None and \ - not entity.callback_group.beginning_execution(entity): + async def handler(entity, gc, is_shutdown, work_tracker): + if is_shutdown or not entity.callback_group.beginning_execution(entity): # Didn't get the callback, or the executor has been ordered to stop entity._executor_event = False gc.trigger() @@ -572,8 +533,7 @@ async def handler(entity: 'EntityT', gc: GuardCondition, is_shutdown: bool, if call_coroutine is not None: await call_coroutine() finally: - if entity.callback_group: - entity.callback_group.ending_execution(entity) + entity.callback_group.ending_execution(entity) # Signal that work has been done so the next callback in a mutually exclusive # callback group can get executed @@ -590,22 +550,21 @@ async def handler(entity: 'EntityT', gc: GuardCondition, is_shutdown: bool, self._tasks.append((task, entity, node)) return task - def can_execute(self, entity: 'Entity') -> bool: + def can_execute(self, entity: WaitableEntityType) -> bool: """ Determine if a callback for an entity can be executed. :param entity: Subscription, Timer, Guard condition, etc :returns: ``True`` if the entity callback can be executed, ``False`` otherwise. """ - return not entity._executor_event and entity.callback_group is not None \ - and entity.callback_group.can_execute(entity) + return not entity._executor_event and entity.callback_group.can_execute(entity) def _wait_for_ready_callbacks( self, timeout_sec: Optional[Union[float, TimeoutObject]] = None, nodes: Optional[List['Node']] = None, condition: Callable[[], bool] = lambda: False, - ) -> YieldedCallback: + ) -> Generator[Tuple[Task, WaitableEntityType, 'Node'], None, None]: """ Yield callbacks that are ready to be executed. @@ -628,7 +587,7 @@ def _wait_for_ready_callbacks( while not yielded_work and not self._is_shutdown and not condition(): # Refresh "all" nodes in case executor was woken by a node being added or removed nodes_to_use = nodes - if nodes_to_use is None: + if nodes is None: nodes_to_use = self.get_nodes() # Yield tasks in-progress before waiting for new work @@ -646,11 +605,11 @@ def _wait_for_ready_callbacks( self._tasks = list(filter(lambda t_e_n: not t_e_n[0].done(), self._tasks)) # Gather entities that can be waited on - subscriptions: List[Subscription[Any, ]] = [] + subscriptions: List[Subscription] = [] guards: List[GuardCondition] = [] timers: List[Timer] = [] - clients: List[Client[Any, Any]] = [] - services: List[Service[Any, Any]] = [] + clients: List[Client] = [] + services: List[Service] = [] waitables: List[Waitable[Any]] = [] for node in nodes_to_use: subscriptions.extend(filter(self.can_execute, node.subscriptions)) @@ -667,10 +626,8 @@ def _wait_for_ready_callbacks( if timeout_timer is not None: timers.append(timeout_timer) - if self._guard: - guards.append(self._guard) - if self._sigint_gc: - guards.append(self._sigint_gc) + guards.append(self._guard) + guards.append(self._sigint_gc) entity_count = NumberOfEntities( len(subscriptions), len(guards), len(timers), len(clients), len(services)) @@ -725,9 +682,6 @@ def _wait_for_ready_callbacks( except InvalidHandle: pass - if self._context.handle is None: - raise RuntimeError('Cannot enter context if context is None') - context_stack.enter_context(self._context.handle) wait_set = _rclpy.WaitSet( @@ -788,7 +742,7 @@ def _wait_for_ready_callbacks( if tmr.handle.pointer in timers_ready: # Check timer is ready to workaround rcl issue with cancelled timers if tmr.handle.is_timer_ready(): - if tmr.callback_group and tmr.callback_group.can_execute(tmr): + if tmr.callback_group.can_execute(tmr): handler = self._make_handler(tmr, node, self._take_timer) yielded_work = True yield handler, tmr, node @@ -802,7 +756,7 @@ def _wait_for_ready_callbacks( for gc in node.guards: if gc._executor_triggered: - if gc.callback_group and gc.callback_group.can_execute(gc): + if gc.callback_group.can_execute(gc): handler = self._make_handler(gc, node, self._take_guard_condition) yielded_work = True yield handler, gc, node @@ -832,9 +786,7 @@ def _wait_for_ready_callbacks( if condition(): raise ConditionReachedException() - def wait_for_ready_callbacks(self, *args: Any, **kwargs: Any) -> Tuple[Task[None], - 'Optional[Entity]', - 'Optional[Node]']: + def wait_for_ready_callbacks(self, *args, **kwargs) -> Tuple[Task, WaitableEntityType, 'Node']: """ Return callbacks that are ready to be executed. @@ -892,9 +844,8 @@ def _spin_once_impl( pass else: handler() - exception = handler.exception() - if exception is not None: - raise exception + if handler.exception() is not None: + raise handler.exception() handler.result() # raise any exceptions @@ -903,7 +854,7 @@ def spin_once(self, timeout_sec: Optional[float] = None) -> None: def spin_once_until_future_complete( self, - future: Future[Any], + future: Future, timeout_sec: Optional[Union[float, TimeoutObject]] = None ) -> None: future.add_done_callback(lambda x: self.wake()) @@ -941,7 +892,7 @@ def __init__( warnings.warn( 'MultiThreadedExecutor is used with a single thread.\n' 'Use the SingleThreadedExecutor instead.') - self._futures: List[Future[Any]] = [] + self._futures = [] self._executor = ThreadPoolExecutor(num_threads) def _spin_once_impl( @@ -975,7 +926,7 @@ def spin_once(self, timeout_sec: Optional[float] = None) -> None: def spin_once_until_future_complete( self, - future: Future[Any], + future: Future, timeout_sec: Optional[Union[float, TimeoutObject]] = None ) -> None: future.add_done_callback(lambda x: self.wake()) @@ -983,7 +934,7 @@ def spin_once_until_future_complete( def shutdown( self, - timeout_sec: Optional[float] = None, + timeout_sec: float = None, *, wait_for_threads: bool = True ) -> bool: diff --git a/rclpy/rclpy/task.py b/rclpy/rclpy/task.py index 10fae2742..81a56ab5b 100644 --- a/rclpy/rclpy/task.py +++ b/rclpy/rclpy/task.py @@ -15,21 +15,16 @@ import inspect import sys import threading -from typing import (Callable, cast, Coroutine, Dict, Generator, Generic, Iterable, List, +from typing import (Callable, cast, Coroutine, Dict, Generator, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union) import warnings import weakref if TYPE_CHECKING: - from typing import TypeAlias - from rclpy.executors import Executor T = TypeVar('T') -FunctionOrCoroutineFunction: 'TypeAlias' = Union[Callable[[], T], - Callable[..., Coroutine[None, None, T]]] - def _fake_weakref() -> None: """Return None when called to simulate a weak reference that has been garbage collected.""" @@ -212,11 +207,13 @@ class Task(Future[T]): """ def __init__(self, - handler: FunctionOrCoroutineFunction[T], - args: Optional[Iterable[object]] = None, + handler: Union[Callable[[], T], Coroutine[None, None, T], None], + args: Optional[List[object]] = None, kwargs: Optional[Dict[str, object]] = None, executor: Optional['Executor'] = None) -> None: super().__init__(executor=executor) + # _handler is either a normal function or a coroutine + self._handler = handler # Arguments passed into the function if args is None: args = [] @@ -224,19 +221,10 @@ def __init__(self, if kwargs is None: kwargs = {} self._kwargs: Optional[Dict[str, object]] = kwargs - - # _handler is either a normal function or a coroutine if inspect.iscoroutinefunction(handler): - self._handler: Union[ - Coroutine[None, None, T], - Callable[[], T], - None - ] = handler(*args, **kwargs) + self._handler = handler(*args, **kwargs) self._args = None self._kwargs = None - else: - handler = cast(Callable[[], T], handler) - self._handler = handler # True while the task is being executed self._executing = False # Lock acquired to prevent task from executing in parallel with itself @@ -260,7 +248,7 @@ def __call__(self) -> None: if inspect.iscoroutine(self._handler): # Execute a coroutine - handler = self._handler + handler = cast(Coroutine[None, None, T], self._handler) try: handler.send(None) except StopIteration as e: diff --git a/rclpy/rclpy/timer.py b/rclpy/rclpy/timer.py index 706234aff..c5b577053 100644 --- a/rclpy/rclpy/timer.py +++ b/rclpy/rclpy/timer.py @@ -67,7 +67,7 @@ class Timer: def __init__( self, callback: Union[Callable[[], None], Callable[[TimerInfo], None], None], - callback_group: Optional[CallbackGroup], + callback_group: CallbackGroup, timer_period_ns: int, clock: Clock, *, diff --git a/rclpy/src/rclpy/action_client.hpp b/rclpy/src/rclpy/action_client.hpp index 49d616e4c..5dcf04b90 100644 --- a/rclpy/src/rclpy/action_client.hpp +++ b/rclpy/src/rclpy/action_client.hpp @@ -207,7 +207,7 @@ class ActionClient : public Destroyable, public std::enable_shared_from_this