Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 55 additions & 20 deletions synchros2/synchros2/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import rclpy.callback_groups
import rclpy.executors
import rclpy.node
import rclpy.timer

from synchros2.futures import FutureLike
from synchros2.utilities import bind_to_thread, fqn
Expand Down Expand Up @@ -592,6 +593,7 @@ def __init__(
max_thread_idle_time: typing.Optional[float] = None,
max_threads_per_callback_group: typing.Optional[int] = None,
*,
num_threads_for_timers: typing.Optional[int] = None,
context: typing.Optional[rclpy.context.Context] = None,
logger: typing.Optional[logging.Logger] = None,
) -> None:
Expand All @@ -607,24 +609,41 @@ def __init__(
max_threads_per_callback_group: optional maximum number of concurrent callbacks the
default thread pool should service for a given callback group. Useful to avoid
reentrant callback groups from starving the default thread pool.
num_threads_for_timers: optional number of threads to dedicate to timer callbacks.
Defaults to 10% of all available threads, which may be 0 if there are less than
10 threads, in which case timer callbacks will be serviced by the default thread pool.
context: An optional instance of the ros context.
logger: An optional logger instance.
"""
super().__init__(context=context)
if logger is None:
logger = rclpy.logging.get_logger(fqn(self.__class__))
if max_threads is None:
max_threads = 32 * (os.cpu_count() or 1)
if num_threads_for_timers is None:
num_threads_for_timers = max_threads // 10
if num_threads_for_timers == 0:
logger.warning("Not enough threads available, timers will be serviced by the default thread pool")
max_threads -= num_threads_for_timers
self._logger = logger
self._is_shutdown = False
self._spin_lock = threading.Lock()
self._shutdown_lock = threading.RLock()
self._thread_pools = [
AutoScalingThreadPool(
max_workers=max_threads,
max_idle_time=max_thread_idle_time,
self._default_thread_pool = AutoScalingThreadPool(
max_workers=max_threads,
max_idle_time=max_thread_idle_time,
submission_quota=max_threads_per_callback_group,
logger=self._logger,
)
self._timers_thread_pool: typing.Optional[AutoScalingThreadPool] = None
if num_threads_for_timers != 0:
self._timers_thread_pool = AutoScalingThreadPool(
min_workers=num_threads_for_timers,
max_workers=num_threads_for_timers,
submission_quota=max_threads_per_callback_group,
logger=self._logger,
),
]
)
self._static_thread_pools: typing.List[AutoScalingThreadPool] = []
self._callback_group_affinity: weakref.WeakKeyDictionary[
rclpy.callback_groups.CallbackGroup,
AutoScalingThreadPool,
Expand All @@ -637,12 +656,21 @@ def __init__(
@property
def default_thread_pool(self) -> AutoScalingThreadPool:
"""Default autoscaling thread pool."""
return self._thread_pools[0]
return self._default_thread_pool

@property
def timers_thread_pool(self) -> typing.Optional[AutoScalingThreadPool]:
"""Autoscaling thread pool for timer callbacks."""
return self._timers_thread_pool

@property
def thread_pools(self) -> typing.List[AutoScalingThreadPool]:
"""Autoscaling thread pools in use."""
return list(self._thread_pools)
thread_pools = [self._default_thread_pool]
if self._timers_thread_pool is not None:
thread_pools.append(self._timers_thread_pool)
thread_pools.extend(self._static_thread_pools)
return thread_pools

def add_static_thread_pool(self, num_threads: typing.Optional[int] = None) -> AutoScalingThreadPool:
"""Add a thread pool that keeps a steady number of workers."""
Expand All @@ -653,8 +681,8 @@ def add_static_thread_pool(self, num_threads: typing.Optional[int] = None) -> Au
max_workers=num_threads,
logger=self._logger,
)
self._thread_pools.append(thread_pool)
self._logger.debug(f"Added static thread pool #{len(self._thread_pools) - 1}")
self._static_thread_pools.append(thread_pool)
self._logger.debug(f"Added static thread pool #{len(self._static_thread_pools) - 1}")
return thread_pool

def bind(self, callback_group: rclpy.callback_groups.CallbackGroup, thread_pool: AutoScalingThreadPool) -> None:
Expand All @@ -663,9 +691,13 @@ def bind(self, callback_group: rclpy.callback_groups.CallbackGroup, thread_pool:
Thread pool must be known to the executor. That is, instantiated through add_*_thread_pool() methods.
"""
with self._shutdown_lock:
if thread_pool not in self._thread_pools:
if thread_pool not in self._static_thread_pools:
if thread_pool is self._default_thread_pool:
raise ValueError("cannot rebind to default thread pool")
if thread_pool is self._timers_thread_pool:
raise ValueError("cannot bind to timers thread pool")
raise ValueError("thread pool unknown to executor")
thread_pool_index = self._thread_pools.index(thread_pool)
thread_pool_index = self._static_thread_pools.index(thread_pool)
callback_group_name = f"{fqn(type(callback_group))}@{id(callback_group)}"
self._logger.debug(f"Binding {callback_group_name} to thread pool #{thread_pool_index}...")
self._callback_group_affinity[callback_group] = thread_pool
Expand Down Expand Up @@ -698,14 +730,16 @@ def _do_spin_once(self, *args: typing.Any, **kwargs: typing.Any) -> None:
# dispatch and be missed. Fortunately, this will only delay dispatch until the
# next spin cycle.
if task not in self._work_in_progress or (self._work_in_progress[task].done() and not task.done()):
if task.callback_group is not None:
if task.callback_group not in self._callback_group_affinity:
self._callback_group_affinity[task.callback_group] = self._thread_pools[0]
if task.callback_group is not None and task.callback_group in self._callback_group_affinity:
thread_pool = self._callback_group_affinity[task.callback_group]
thread_pool_index = self._static_thread_pools.index(thread_pool)
self._logger.debug(f"Task '{task}' submitted to static thread pool #{thread_pool_index}")
elif self._timers_thread_pool is not None and isinstance(task.entity, rclpy.timer.Timer):
thread_pool = self._timers_thread_pool
self._logger.debug(f"Task '{task}' submitted to timers thread pool")
else:
thread_pool = self._thread_pools[0]
thread_pool_index = self._thread_pools.index(thread_pool)
self._logger.debug(f"Task '{task}' submitted to thread pool #{thread_pool_index}")
thread_pool = self._default_thread_pool
self._logger.debug(f"Task '{task}' submitted to default thread pool")
self._work_in_progress[task] = thread_pool.submit(task)
for task in list(self._work_in_progress):
if not task.done():
Expand Down Expand Up @@ -781,10 +815,11 @@ def shutdown(self, timeout_sec: typing.Optional[float] = None) -> bool:
# must be waited on. Work tracking in rclpy.executors.Executor
# base implementation is subject to races, so block thread pool
# submissions and wait for all futures to finish. Then shutdown.
done = all(thread_pool.wait(timeout_sec) for thread_pool in self._thread_pools)

done = all(thread_pool.wait(timeout_sec) for thread_pool in self.thread_pools)
if done:
assert super().shutdown(timeout_sec=0)
for thread_pool in self._thread_pools:
for thread_pool in self.thread_pools:
thread_pool.shutdown()
self._is_shutdown = True
if done:
Expand Down
35 changes: 35 additions & 0 deletions synchros2/synchros2/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,22 @@
import functools
from typing import Any, Callable, Iterable, Optional, Type

try:
from typing import override # type: ignore[attr-defined]
except ImportError:
from typing_extensions import override # type: ignore[import]


from rclpy.callback_groups import CallbackGroup
from rclpy.clock import Clock
from rclpy.exceptions import InvalidHandle
from rclpy.node import Node as BaseNode
from rclpy.timer import Rate
from rclpy.waitable import Waitable

from synchros2.callback_groups import NonReentrantCallbackGroup
from synchros2.logging import MemoizingRcutilsLogger
from synchros2.time import SteadyRate


def suppressed(exception: Type[BaseException], func: Callable) -> Callable:
Expand Down Expand Up @@ -55,6 +64,32 @@ def default_callback_group(self) -> CallbackGroup:
# NOTE(hidmic): this overrides the hardcoded default group in rclpy.node.Node implementation
return self._default_callback_group_override

@override
def create_rate(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be marked override?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are technically not overriding but shadowing: return types are different. Hmm, but maybe we could make it look like Rate was returned. Let me see.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PTAL @ 230c702

self,
frequency: float,
clock: Optional[Clock] = None,
) -> Rate:
"""Create a Rate object.

:param frequency: The frequency the Rate runs at (Hz).
:param clock: The clock the Rate gets time from.
"""
if clock is None:
clock = self.get_clock()
return SteadyRate(frequency, clock, context=self._context)

@override
def destroy_rate(self, rate: Rate) -> bool:
"""Destroy a Rate object created by the node.

:return: ``True`` if successful, ``False`` otherwise.
"""
if isinstance(rate, SteadyRate):
rate.destroy()
return True
return super().destroy_rate(rate)

@property
def waitables(self) -> Iterable[Waitable]:
"""Get patched node waitables.
Expand Down
75 changes: 74 additions & 1 deletion synchros2/synchros2/time.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
# Copyright (c) 2024 Robotics and AI Institute LLC dba RAI Institute. All rights reserved.

import threading
from datetime import datetime, timedelta
from typing import Union
from typing import Optional, Union

try:
from typing import override # type: ignore[attr-defined]
except ImportError:
from typing_extensions import override # type: ignore[import]

from rclpy.context import Context
from rclpy.duration import Duration
from rclpy.exceptions import ROSInterruptException
from rclpy.time import Time
from rclpy.timer import Rate
from rclpy.utilities import get_default_context


def as_proper_time(time: Union[int, float, datetime, Time]) -> Time:
Expand Down Expand Up @@ -57,3 +67,66 @@ def as_proper_duration(duration: Union[int, float, timedelta, Duration]) -> Dura
if not isinstance(duration, Duration):
raise ValueError(f"unsupported duration type: {duration}")
return duration


class SteadyRate(Rate):
"""An rclpy.Rate equivalent that uses clock functionality directly, without timer overhead."""

def __init__(self, frequency: float, clock: Time, *, context: Optional[Context] = None) -> None:
# NOTE: SteadyRate subclasses Rate for type consistency but does not use any of its functionality.
# Thus, we skip the constructor call entirely.
self._clock = clock
if context is None:
context = get_default_context()
self._context = context
self._period = as_proper_duration(1.0 / frequency)
self._deadline = self._clock.now() + self._period

self._lock = threading.Lock()
self._num_sleepers = 0

self._is_shutdown = False
self._is_destroyed = False
self._context.on_shutdown(self._on_shutdown)

@override
def _on_shutdown(self) -> None:
self._is_shutdown = True
self.destroy()

@override
def destroy(self) -> None:
"""Destroy the rate."""
self._is_destroyed = True

@override
def _presleep(self) -> None:
if self._is_shutdown:
raise ROSInterruptException()
if self._is_destroyed:
raise RuntimeError("MonotonicRate cannot sleep because it has been destroyed")
with self._lock:
self._num_sleepers += 1

@override
def _postsleep(self) -> None:
with self._lock:
self._num_sleepers -= 1
if self._num_sleepers == 0:
now = self._clock.now()
next_deadline = self._deadline + self._period
if now < self._deadline or now > next_deadline:
next_deadline = now + self._period
self._deadline = next_deadline
if self._is_shutdown:
self.destroy()
raise ROSInterruptException()

@override
def sleep(self) -> None:
"""Block until the current period is over."""
self._presleep()
try:
self._clock.sleep_until(self._deadline, context=self._context)
finally:
self._postsleep()
39 changes: 39 additions & 0 deletions synchros2/test/test_executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,45 @@ def deferred() -> bool:
assert future.result()


def test_autoscaling_executor_with_timers_thread_pool(ros_context: Context, ros_node: Node) -> None:
"""Asserts that the autoscaling multithreaded executor routes timer callbacks to the
dedicated timers thread pool and leaves non-timer work to the default thread pool.
"""
with background(
AutoScalingMultiThreadedExecutor(
context=ros_context,
num_threads_for_timers=1,
logger=logging.root,
),
) as executor:
assert executor.timers_thread_pool is not None
executor.add_node(ros_node)

timer_threads: List[threading.Thread] = []
task_threads: List[threading.Thread] = []

def timer_callback() -> None:
timer_threads.append(threading.current_thread())

ros_node.create_timer(0.05, timer_callback, ReentrantCallbackGroup())

def task_callback() -> None:
task_threads.append(threading.current_thread())
time.sleep(0.05)

for _ in range(5):
executor.create_task(task_callback)

time.sleep(1.0)

assert len(timer_threads) > 0
assert len(task_threads) > 0
# All timer callbacks must run on the same single timers-pool thread
assert all(t is timer_threads[0] for t in timer_threads[1:])
# Task callbacks must never have run on the timers-pool thread
assert not any(t is timer_threads[0] for t in task_threads)


@pytest.mark.filterwarnings("ignore")
def test_background_executor_shows_errors(ros_context: Context, ros_node: Node) -> None:
"""Asserts that an background executor does not swallow callback exceptions."""
Expand Down
Loading
Loading