diff --git a/synchros2/synchros2/executors.py b/synchros2/synchros2/executors.py index 2eb02a3..9bd5fba 100644 --- a/synchros2/synchros2/executors.py +++ b/synchros2/synchros2/executors.py @@ -17,6 +17,7 @@ import rclpy.callback_groups import rclpy.executors import rclpy.node +import rclpy.timer from synchros2.futures import FutureLike from synchros2.utilities import bind_to_thread, fqn @@ -592,6 +593,7 @@ def __init__( max_thread_idle_time: typing.Optional[float] = None, max_threads_per_callback_group: typing.Optional[int] = None, *, + num_threads_for_timers: typing.Optional[int] = None, context: typing.Optional[rclpy.context.Context] = None, logger: typing.Optional[logging.Logger] = None, ) -> None: @@ -607,24 +609,41 @@ def __init__( max_threads_per_callback_group: optional maximum number of concurrent callbacks the default thread pool should service for a given callback group. Useful to avoid reentrant callback groups from starving the default thread pool. + num_threads_for_timers: optional number of threads to dedicate to timer callbacks. + Defaults to 10% of all available threads, which may be 0 if there are less than + 10 threads, in which case timer callbacks will be serviced by the default thread pool. context: An optional instance of the ros context. logger: An optional logger instance. 
""" super().__init__(context=context) if logger is None: logger = rclpy.logging.get_logger(fqn(self.__class__)) + if max_threads is None: + max_threads = 32 * (os.cpu_count() or 1) + if num_threads_for_timers is None: + num_threads_for_timers = max_threads // 10 + if num_threads_for_timers == 0: + logger.warning("Not enough threads available, timers will be serviced by the default thread pool") + max_threads -= num_threads_for_timers self._logger = logger self._is_shutdown = False self._spin_lock = threading.Lock() self._shutdown_lock = threading.RLock() - self._thread_pools = [ - AutoScalingThreadPool( - max_workers=max_threads, - max_idle_time=max_thread_idle_time, + self._default_thread_pool = AutoScalingThreadPool( + max_workers=max_threads, + max_idle_time=max_thread_idle_time, + submission_quota=max_threads_per_callback_group, + logger=self._logger, + ) + self._timers_thread_pool: typing.Optional[AutoScalingThreadPool] = None + if num_threads_for_timers != 0: + self._timers_thread_pool = AutoScalingThreadPool( + min_workers=num_threads_for_timers, + max_workers=num_threads_for_timers, submission_quota=max_threads_per_callback_group, logger=self._logger, - ), - ] + ) + self._static_thread_pools: typing.List[AutoScalingThreadPool] = [] self._callback_group_affinity: weakref.WeakKeyDictionary[ rclpy.callback_groups.CallbackGroup, AutoScalingThreadPool, @@ -637,12 +656,21 @@ def __init__( @property def default_thread_pool(self) -> AutoScalingThreadPool: """Default autoscaling thread pool.""" - return self._thread_pools[0] + return self._default_thread_pool + + @property + def timers_thread_pool(self) -> typing.Optional[AutoScalingThreadPool]: + """Autoscaling thread pool for timer callbacks.""" + return self._timers_thread_pool @property def thread_pools(self) -> typing.List[AutoScalingThreadPool]: """Autoscaling thread pools in use.""" - return list(self._thread_pools) + thread_pools = [self._default_thread_pool] + if self._timers_thread_pool is not None: + 
thread_pools.append(self._timers_thread_pool) + thread_pools.extend(self._static_thread_pools) + return thread_pools def add_static_thread_pool(self, num_threads: typing.Optional[int] = None) -> AutoScalingThreadPool: """Add a thread pool that keeps a steady number of workers.""" @@ -653,8 +681,8 @@ def add_static_thread_pool(self, num_threads: typing.Optional[int] = None) -> Au max_workers=num_threads, logger=self._logger, ) - self._thread_pools.append(thread_pool) - self._logger.debug(f"Added static thread pool #{len(self._thread_pools) - 1}") + self._static_thread_pools.append(thread_pool) + self._logger.debug(f"Added static thread pool #{len(self._static_thread_pools) - 1}") return thread_pool def bind(self, callback_group: rclpy.callback_groups.CallbackGroup, thread_pool: AutoScalingThreadPool) -> None: @@ -663,9 +691,13 @@ def bind(self, callback_group: rclpy.callback_groups.CallbackGroup, thread_pool: Thread pool must be known to the executor. That is, instantiated through add_*_thread_pool() methods. """ with self._shutdown_lock: - if thread_pool not in self._thread_pools: + if thread_pool not in self._static_thread_pools: + if thread_pool is self._default_thread_pool: + raise ValueError("cannot rebind to default thread pool") + if thread_pool is self._timers_thread_pool: + raise ValueError("cannot bind to timers thread pool") raise ValueError("thread pool unknown to executor") - thread_pool_index = self._thread_pools.index(thread_pool) + thread_pool_index = self._static_thread_pools.index(thread_pool) callback_group_name = f"{fqn(type(callback_group))}@{id(callback_group)}" self._logger.debug(f"Binding {callback_group_name} to thread pool #{thread_pool_index}...") self._callback_group_affinity[callback_group] = thread_pool @@ -698,14 +730,16 @@ def _do_spin_once(self, *args: typing.Any, **kwargs: typing.Any) -> None: # dispatch and be missed. Fortunately, this will only delay dispatch until the # next spin cycle. 
if task not in self._work_in_progress or (self._work_in_progress[task].done() and not task.done()): - if task.callback_group is not None: - if task.callback_group not in self._callback_group_affinity: - self._callback_group_affinity[task.callback_group] = self._thread_pools[0] + if task.callback_group is not None and task.callback_group in self._callback_group_affinity: thread_pool = self._callback_group_affinity[task.callback_group] + thread_pool_index = self._static_thread_pools.index(thread_pool) + self._logger.debug(f"Task '{task}' submitted to static thread pool #{thread_pool_index}") + elif self._timers_thread_pool is not None and isinstance(task.entity, rclpy.timer.Timer): + thread_pool = self._timers_thread_pool + self._logger.debug(f"Task '{task}' submitted to timers thread pool") else: - thread_pool = self._thread_pools[0] - thread_pool_index = self._thread_pools.index(thread_pool) - self._logger.debug(f"Task '{task}' submitted to thread pool #{thread_pool_index}") + thread_pool = self._default_thread_pool + self._logger.debug(f"Task '{task}' submitted to default thread pool") self._work_in_progress[task] = thread_pool.submit(task) for task in list(self._work_in_progress): if not task.done(): @@ -781,10 +815,11 @@ def shutdown(self, timeout_sec: typing.Optional[float] = None) -> bool: # must be waited on. Work tracking in rclpy.executors.Executor # base implementation is subject to races, so block thread pool # submissions and wait for all futures to finish. Then shutdown. 
- done = all(thread_pool.wait(timeout_sec) for thread_pool in self._thread_pools) + + done = all(thread_pool.wait(timeout_sec) for thread_pool in self.thread_pools) if done: assert super().shutdown(timeout_sec=0) - for thread_pool in self._thread_pools: + for thread_pool in self.thread_pools: thread_pool.shutdown() self._is_shutdown = True if done: diff --git a/synchros2/synchros2/node.py b/synchros2/synchros2/node.py index dcb7e7a..b796e78 100644 --- a/synchros2/synchros2/node.py +++ b/synchros2/synchros2/node.py @@ -3,13 +3,22 @@ import functools from typing import Any, Callable, Iterable, Optional, Type +try: + from typing import override # type: ignore[attr-defined] +except ImportError: + from typing_extensions import override # type: ignore[import] + + from rclpy.callback_groups import CallbackGroup +from rclpy.clock import Clock from rclpy.exceptions import InvalidHandle from rclpy.node import Node as BaseNode +from rclpy.timer import Rate from rclpy.waitable import Waitable from synchros2.callback_groups import NonReentrantCallbackGroup from synchros2.logging import MemoizingRcutilsLogger +from synchros2.time import SteadyRate def suppressed(exception: Type[BaseException], func: Callable) -> Callable: @@ -55,6 +64,35 @@ def default_callback_group(self) -> CallbackGroup: # NOTE(hidmic): this overrides the hardcoded default group in rclpy.node.Node implementation return self._default_callback_group_override + @override + def create_rate( + self, + frequency: float, + clock: Optional[Clock] = None, + ) -> Rate: + """Create a Rate object. + + :param frequency: The frequency the Rate runs at (Hz). + :param clock: The clock the Rate gets time from. + :raises ValueError: if ``frequency`` is not strictly positive. + """ + if frequency <= 0: + raise ValueError("frequency must be > 0") + if clock is None: + clock = self.get_clock() + return SteadyRate(frequency, clock, context=self._context) + + @override + def destroy_rate(self, rate: Rate) -> bool: + """Destroy a Rate object created by the node. + + :return: ``True`` if successful, ``False`` otherwise. 
+ """ + if isinstance(rate, SteadyRate): + rate.destroy() + return True + return super().destroy_rate(rate) + @property def waitables(self) -> Iterable[Waitable]: """Get patched node waitables. diff --git a/synchros2/synchros2/time.py b/synchros2/synchros2/time.py index 9d18a7f..9773335 100644 --- a/synchros2/synchros2/time.py +++ b/synchros2/synchros2/time.py @@ -1,10 +1,21 @@ # Copyright (c) 2024 Robotics and AI Institute LLC dba RAI Institute. All rights reserved. +import threading from datetime import datetime, timedelta -from typing import Union +from typing import Optional, Union +try: + from typing import override # type: ignore[attr-defined] +except ImportError: + from typing_extensions import override # type: ignore[import] + +from rclpy.clock import Clock +from rclpy.context import Context from rclpy.duration import Duration +from rclpy.exceptions import ROSInterruptException from rclpy.time import Time +from rclpy.timer import Rate +from rclpy.utilities import get_default_context def as_proper_time(time: Union[int, float, datetime, Time]) -> Time: @@ -57,3 +68,72 @@ def as_proper_duration(duration: Union[int, float, timedelta, Duration]) -> Dura if not isinstance(duration, Duration): raise ValueError(f"unsupported duration type: {duration}") return duration + + +class SteadyRate(Rate): + """An rclpy.Rate equivalent that uses clock functionality directly, without timer overhead.""" + + def __init__(self, frequency: float, clock: Clock, *, context: Optional[Context] = None) -> None: + """Initialize the rate. + + :param frequency: The frequency the rate runs at (Hz); the sleep period is its inverse. + :param clock: The clock used to read the current time and to sleep until each deadline. + :param context: The context whose shutdown interrupts sleeping; defaults to the default context. + """ + # NOTE: SteadyRate subclasses Rate for type consistency but does not use any of its functionality. + # Thus, we skip the constructor call entirely. 
+ self._clock = clock + if context is None: + context = get_default_context() + self._context = context + self._period = as_proper_duration(1.0 / frequency) + self._deadline = self._clock.now() + self._period + + self._lock = threading.Lock() + self._num_sleepers = 0 + + self._is_shutdown = False + self._is_destroyed = False + self._context.on_shutdown(self._on_shutdown) + + @override + def _on_shutdown(self) -> None: + self._is_shutdown = True + self.destroy() + + @override + def destroy(self) -> None: + """Destroy the rate.""" + self._is_destroyed = True + + @override + def _presleep(self) -> None: + if self._is_shutdown: + raise ROSInterruptException() + if self._is_destroyed: + raise RuntimeError("SteadyRate cannot sleep because it has been destroyed") + with self._lock: + self._num_sleepers += 1 + + @override + def _postsleep(self) -> None: + with self._lock: + self._num_sleepers -= 1 + if self._num_sleepers == 0: + now = self._clock.now() + next_deadline = self._deadline + self._period + if now < self._deadline or now > next_deadline: + next_deadline = now + self._period + self._deadline = next_deadline + if self._is_shutdown: + self.destroy() + raise ROSInterruptException() + + @override + def sleep(self) -> None: + """Block until the current period is over.""" + self._presleep() + try: + self._clock.sleep_until(self._deadline, context=self._context) + finally: + self._postsleep() diff --git a/synchros2/test/test_executors.py b/synchros2/test/test_executors.py index 3a6a330..9d19a3a 100644 --- a/synchros2/test/test_executors.py +++ b/synchros2/test/test_executors.py @@ -241,6 +241,45 @@ def deferred() -> bool: assert future.result() +def test_autoscaling_executor_with_timers_thread_pool(ros_context: Context, ros_node: Node) -> None: + """Asserts that the autoscaling multithreaded executor routes timer callbacks to the + dedicated timers thread pool and leaves non-timer work to the default thread pool. 
+ """ + with background( + AutoScalingMultiThreadedExecutor( + context=ros_context, + num_threads_for_timers=1, + logger=logging.root, + ), + ) as executor: + assert executor.timers_thread_pool is not None + executor.add_node(ros_node) + + timer_threads: List[threading.Thread] = [] + task_threads: List[threading.Thread] = [] + + def timer_callback() -> None: + timer_threads.append(threading.current_thread()) + + ros_node.create_timer(0.05, timer_callback, ReentrantCallbackGroup()) + + def task_callback() -> None: + task_threads.append(threading.current_thread()) + time.sleep(0.05) + + for _ in range(5): + executor.create_task(task_callback) + + time.sleep(1.0) + + assert len(timer_threads) > 0 + assert len(task_threads) > 0 + # All timer callbacks must run on the same single timers-pool thread + assert all(t is timer_threads[0] for t in timer_threads[1:]) + # Task callbacks must never have run on the timers-pool thread + assert not any(t is timer_threads[0] for t in task_threads) + + @pytest.mark.filterwarnings("ignore") def test_background_executor_shows_errors(ros_context: Context, ros_node: Node) -> None: """Asserts that an background executor does not swallow callback exceptions.""" diff --git a/synchros2/test/test_time.py b/synchros2/test/test_time.py new file mode 100644 index 0000000..fe29288 --- /dev/null +++ b/synchros2/test/test_time.py @@ -0,0 +1,138 @@ +# Copyright (c) 2026 Robotics and AI Institute LLC dba RAI Institute. All rights reserved. 
+import threading +import time +from datetime import datetime, timedelta, timezone +from typing import Generator + +import pytest +import rclpy +from rclpy.clock import Clock, ClockType +from rclpy.context import Context +from rclpy.duration import Duration +from rclpy.exceptions import ROSInterruptException +from rclpy.time import Time + +from synchros2.time import SteadyRate, as_proper_duration, as_proper_time + + +@pytest.fixture +def ros_context(domain_id: int) -> Generator[Context, None, None]: + context = Context() + rclpy.init(context=context, domain_id=domain_id) + try: + yield context + finally: + context.try_shutdown() + + +@pytest.fixture +def steady_clock() -> Clock: + return Clock(clock_type=ClockType.STEADY_TIME) + + +def test_as_proper_time_from_int() -> None: + t = as_proper_time(5) + assert isinstance(t, Time) + assert t.nanoseconds == 5_000_000_000 + + +def test_as_proper_time_from_float() -> None: + t = as_proper_time(1.5) + assert isinstance(t, Time) + assert t.nanoseconds == 1_500_000_000 + + +def test_as_proper_time_from_datetime() -> None: + dt = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) + t = as_proper_time(dt) + assert isinstance(t, Time) + assert t.nanoseconds == int(dt.timestamp() * 1e9) + + +def test_as_proper_time_from_time() -> None: + original = Time(seconds=42) + t = as_proper_time(original) + assert t is original + + +def test_as_proper_time_raises_on_invalid_type() -> None: + with pytest.raises(ValueError): + as_proper_time("not a time") # type: ignore[arg-type] + + +def test_as_proper_duration_from_int() -> None: + d = as_proper_duration(3) + assert isinstance(d, Duration) + assert d.nanoseconds == 3_000_000_000 + + +def test_as_proper_duration_from_float() -> None: + d = as_proper_duration(0.5) + assert isinstance(d, Duration) + assert d.nanoseconds == 500_000_000 + + +def test_as_proper_duration_from_timedelta() -> None: + td = timedelta(seconds=2, milliseconds=500) + d = as_proper_duration(td) + assert isinstance(d, 
Duration) + assert d.nanoseconds == 2_500_000_000 + + +def test_as_proper_duration_from_duration() -> None: + original = Duration(seconds=7) + d = as_proper_duration(original) + assert d is original + + +def test_as_proper_duration_raises_on_invalid_type() -> None: + with pytest.raises(ValueError): + as_proper_duration("not a duration") # type: ignore[arg-type] + + +def test_steady_rate_fires_at_expected_frequency(ros_context: Context, steady_clock: Clock) -> None: + """SteadyRate sleep() fires at approximately the requested frequency.""" + frequency = 10.0 # Hz + rate = SteadyRate(frequency, steady_clock, context=ros_context) + + iterations = 10 + start = time.monotonic() + for _ in range(iterations): + rate.sleep() + elapsed = time.monotonic() - start + + expected = iterations / frequency + # Allow ±5 % tolerance for CI timing jitter + assert abs(elapsed - expected) / expected < 0.05, f"elapsed={elapsed:.3f}s, expected≈{expected:.3f}s" + + +def test_steady_rate_raises_on_context_shutdown(ros_context: Context, steady_clock: Clock) -> None: + """SteadyRate.sleep() raises ROSInterruptException when the context is shut down.""" + rate = SteadyRate(0.1, steady_clock, context=ros_context) # very slow + + exceptions: list = [] + + def sleeper() -> None: + try: + rate.sleep() + except ROSInterruptException as e: + exceptions.append(e) + + worker = threading.Thread(target=sleeper) + worker.start() + time.sleep(0.05) # let the thread enter sleep_until + ros_context.try_shutdown() + worker.join(timeout=2.0) + + assert not worker.is_alive(), "sleeper thread did not unblock after context shutdown" + assert len(exceptions) == 1 + assert isinstance(exceptions[0], ROSInterruptException) + + +def test_steady_rate_raises_on_destroy(ros_context: Context, steady_clock: Clock) -> None: + """SteadyRate.sleep() raises RuntimeError when called after destroy().""" + rate = SteadyRate(10.0, steady_clock, context=ros_context) + rate.destroy() + + with pytest.raises(RuntimeError): + 
rate.sleep()